#!/usr/bin/env python3
"""
Fill SPARC120_RESULTS_TEMPLATE.md from reports/_summary_SP120.json.

Usage:
  python -m scripts.fill_results \
    --summary reports/_summary_SP120.json \
    --template SPARC120_RESULTS_TEMPLATE.md \
    --out SPARC120_RESULTS.md
"""
import argparse, json, math, re
from datetime import datetime, timezone
from pathlib import Path

def pct(x, digits=2):
    """Format a fraction as a percentage string, e.g. ``0.5 -> "50.00%"``.

    Args:
        x: Value convertible with ``float()`` (number or numeric string),
            interpreted as a fraction of 1.
        digits: Number of decimal places in the rendered percentage.

    Returns:
        The formatted percentage, or the placeholder ``"—"`` when *x* cannot be
        converted, is NaN/infinite, or *digits* is not a valid precision.
    """
    try:
        value = 100 * float(x)
        # NaN/inf would otherwise render as "nan%"/"inf%" in the report.
        if not math.isfinite(value):
            return "—"
        return f"{value:.{digits}f}%"
    except (TypeError, ValueError, OverflowError):
        # TypeError: None/non-numeric object; ValueError: unparsable string or
        # bad precision; OverflowError: int too large for a float.
        return "—"

def get(d, *path, default=None):
    """Follow *path* through nested dicts starting at *d*.

    Returns the value reached at the end of the path, or *default* as soon as
    a step lands on a non-dict or a missing key. With an empty path, *d* itself
    is returned.
    """
    node = d
    for key in path:
        if isinstance(node, dict) and key in node:
            node = node[key]
        else:
            return default
    return node

def load_summary(p):
    """Load the summary JSON and split it into its sections.

    Args:
        p: Path to the summary JSON file.

    Returns:
        ``(j, rows, totals, wins, meta, subs)`` where ``j`` is the full parsed
        object and the rest are its ``rows``, ``solver_totals``, ``wins``,
        ``meta`` and (optional) ``subgroups`` sections. Missing or ``null``
        sections are returned as an empty list (rows) or empty dict (others).

    Raises:
        ValueError: if the top-level JSON value is not an object.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(p, "r", encoding="utf-8") as fh:
        j = json.load(fh)
    if not isinstance(j, dict):
        raise ValueError(f"{p}: summary must contain a JSON object at the top level")
    # `or` also normalizes explicit JSON nulls, which .get(key, default) would pass through.
    rows = j.get("rows") or []
    totals = j.get("solver_totals") or {}
    wins = j.get("wins") or {}
    meta = j.get("meta") or {}
    subs = j.get("subgroups") or {}  # optional
    return j, rows, totals, wins, meta, subs

def derive_counts(rows, totals):
    """Return the number of cases in the cohort.

    Resolution order, taking the first truthy value:
      1. ``totals["rft_geom"]["n_cases"]``
      2. ``len(rows)``
      3. ``totals["mond"]["n_cases"]``
    Falls back to 0. Non-dict ``totals`` or solver entries (e.g. JSON nulls)
    and ``rows=None`` are tolerated instead of raising.
    """
    totals = totals if isinstance(totals, dict) else {}

    def n_cases(solver):
        # None when the solver entry is missing or not a dict.
        block = totals.get(solver)
        return block.get("n_cases") if isinstance(block, dict) else None

    # Prefer explicit totals; fall back to rows length, then MOND's count.
    n = n_cases("rft_geom")
    if not n:
        n = len(rows or []) or n_cases("mond")
    return int(n or 0)

def default_dt(meta):
    """Return the report timestamp string.

    Uses ``meta["timestamp_utc"]`` verbatim when present and truthy; otherwise
    the current UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``. A non-dict
    *meta* is treated as empty.
    """
    # The previous try/except around `return t` could never fire (returning a
    # value cannot raise), so it is dropped; the stored value is passed through.
    t = meta.get("timestamp_utc") if isinstance(meta, dict) else None
    if t:
        return t
    return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")

def _as_dict(value):
    """Return *value* if it is a dict, else an empty dict (guards JSON nulls)."""
    return value if isinstance(value, dict) else {}


def _wilson_bounds(w95):
    """Return the ``(lo, hi)`` pair of a Wilson 95% interval, or ``(0, 0)``.

    A well-formed interval is a 2-element list/tuple; anything else (missing,
    null, wrong length) falls back to zeros so the template still renders.
    """
    if isinstance(w95, (list, tuple)) and len(w95) == 2:
        return w95[0], w95[1]
    return 0, 0


def build_map(summary, rows, totals, wins, meta, subs):
    """Build the ``PLACEHOLDER -> text`` mapping used to fill the template.

    Args:
        summary: Full parsed summary; accepted for interface compatibility but
            not read (all values come from the unpacked sections).
        rows: Per-case rows, used only to derive the case count.
        totals: ``solver_totals`` section, keyed by solver name.
        wins: Pairwise win counts keyed like ``"rft_geom>mond"``.
        meta: Run metadata (timestamp, cohort, config hash).
        subs: Optional subgroup statistics keyed by tag (``LSB``, ``HSB``, ...).

    Returns:
        dict mapping placeholder names (e.g. ``RFT_GEOM_PASS_RATE``,
        ``WINS_RFT_GEOM_GT_MOND_RATE``, ``SUB_LSB_WILSON_LO``) to strings.
        Missing or malformed sections render as zeros or ``"—"``.
    """
    # Normalize sections so JSON nulls / wrong types don't raise AttributeError.
    totals = _as_dict(totals)
    wins = _as_dict(wins)
    meta = _as_dict(meta)
    subs = _as_dict(subs)
    n = derive_counts(rows, totals)

    m = {
        "DATE_ISO": default_dt(meta),
        "COHORT": meta.get("cohort", "SPARC-120"),
        "GLOBAL_CONFIG_SHA256": meta.get("global_config_sha256", "UNKNOWN"),
        "N_CASES": str(n),
    }

    # Per-solver totals for the solvers we expect.
    for solver in ("rft_geom", "mond", "nfw_fit"):
        t = _as_dict(totals.get(solver))
        key = solver.upper()
        lo, hi = _wilson_bounds(t.get("wilson_95"))
        rms = t.get("rms_pct_median")
        m[f"{key}_N"] = str(t.get("n_cases", n))
        m[f"{key}_PASS"] = str(t.get("pass_count", "0"))
        m[f"{key}_PASS_RATE"] = pct(t.get("pass_rate", 0))
        m[f"{key}_WILSON_LO"] = pct(lo)
        m[f"{key}_WILSON_HI"] = pct(hi)
        # rms_pct_median is already in percent units, so it is not passed to pct().
        m[f"{key}_MEDIAN_RMS"] = f"{rms:.2f}%" if isinstance(rms, (int, float)) else "—"

    # Pairwise wins: "a>b" becomes placeholder prefix "WINS_A_GT_B".
    for pair in ("rft_geom>mond", "rft_geom>nfw_fit", "mond>nfw_fit"):
        key = "WINS_" + pair.replace(">", "_GT_").upper()
        w = _as_dict(wins.get(pair))
        won = w.get("wins", 0)
        of = w.get("of", 0)
        m[key] = str(won)
        m[f"{key}_OF"] = str(of)
        # Only compute a rate when both counts are numeric and the denominator is non-zero.
        numeric = isinstance(won, (int, float)) and isinstance(of, (int, float))
        m[f"{key}_RATE"] = pct(won / of) if numeric and of else "—"

    # Subgroups (overall, not per-solver).
    for tag in ("LSB", "HSB", "bulge", "no_bulge"):
        sg = _as_dict(subs.get(tag))
        key = f"SUB_{tag.upper()}"
        lo, hi = _wilson_bounds(sg.get("wilson_95"))
        m[f"{key}_N"] = str(sg.get("n", 0))
        m[f"{key}_PASS"] = str(sg.get("pass", 0))
        m[f"{key}_PASS_RATE"] = pct(sg.get("pass_rate", 0))
        m[f"{key}_WILSON_LO"] = pct(lo)
        m[f"{key}_WILSON_HI"] = pct(hi)

    return m

def fill_template(tpl_text, mapping):
    """Substitute ``{{KEY}}`` tokens in *tpl_text* with values from *mapping*.

    Keys consist of upper-case letters, digits, underscores or ``>``. Each
    substituted value is passed through ``str()``; tokens whose key is absent
    from *mapping* are left in the output unchanged.
    """
    token = re.compile(r"\{\{([A-Z0-9_>]+)\}\}")

    def substitute(found):
        name = found.group(1)
        if name in mapping:
            return str(mapping[name])
        return found.group(0)

    return token.sub(substitute, tpl_text)

def main(argv=None):
    """CLI entry point: fill the results template and print key metrics.

    Args:
        argv: Optional argument list (defaults to ``sys.argv[1:]`` via
            argparse); lets the command be driven programmatically.
    """
    ap = argparse.ArgumentParser(
        description="Fill the SPARC-120 results template from a summary JSON."
    )
    ap.add_argument("--summary", required=True, help="Path to the summary JSON.")
    ap.add_argument("--template", required=True, help="Markdown template with {{PLACEHOLDER}} tokens.")
    ap.add_argument("--out", required=True, help="Destination path for the filled markdown.")
    args = ap.parse_args(argv)

    summary, rows, totals, wins, meta, subs = load_summary(args.summary)
    mapping = build_map(summary, rows, totals, wins, meta, subs)

    tpl = Path(args.template).read_text(encoding="utf-8")
    out = fill_template(tpl, mapping)
    out_path = Path(args.out)
    # Create the destination directory so a fresh output location doesn't fail.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(out, encoding="utf-8")
    print(f"✅ Wrote {args.out}")

    def metric(key):
        # Placeholder text for any metric the mapping does not contain.
        return mapping.get(key, "—")

    # Print key metrics
    print("\nKey Metrics:")
    print(f"  RFT Pass Rate: {metric('RFT_GEOM_PASS_RATE')}")
    print(f"  RFT vs MOND:   {metric('WINS_RFT_GEOM_GT_MOND')}/{metric('WINS_RFT_GEOM_GT_MOND_OF')} ({metric('WINS_RFT_GEOM_GT_MOND_RATE')})")
    print(f"  RFT vs NFW:    {metric('WINS_RFT_GEOM_GT_NFW_FIT')}/{metric('WINS_RFT_GEOM_GT_NFW_FIT_OF')} ({metric('WINS_RFT_GEOM_GT_NFW_FIT_RATE')})")
    print(f"  LSB Pass Rate: {metric('SUB_LSB_PASS_RATE')}")
    print(f"  HSB Pass Rate: {metric('SUB_HSB_PASS_RATE')}")

if __name__ == "__main__":
    # Run the CLI only when executed directly (or via `python -m`), not on import.
    main()
