#!/usr/bin/env python3
"""
Compare RFT v2 against standard theories: Newtonian, MOND, NFW halo.

This provides an apples-to-apples comparison on the same TEST cohort
using the same pass@20% threshold.
"""
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd()))

import json
import numpy as np
from sparc_rft.case import load_case
from solver.rft_v2_gated_tail import apply_v2_gated_tail
from solver.mond import apply_mond
from solver.nfw_halo import apply_nfw_halo
from metrics.rc_metrics import compute_metrics


def run_theory(theory_name, apply_func, config, manifest_path, is_v2=False):
    """Run one theory over every case in a manifest and summarise fit quality.

    Args:
        theory_name: Label stored in the summary and used in error messages.
        apply_func: Solver callable. Invoked as ``apply_func(case, config)``,
            or as ``apply_func(case, config["kernel"], config["tail"])`` when
            *is_v2* is true. Its result must provide ``"r_kpc"`` and
            ``"v_pred_kms"`` entries.
        config: Theory configuration; echoed back under ``"config"``.
        manifest_path: Text file listing one case path per line, relative to
            ``cases/``. Blank lines and lines whose first non-blank character
            is ``#`` are ignored.
        is_v2: Select the three-argument (kernel/tail) calling convention.

    Returns:
        A dict with ``theory``, ``n_cases``, ``pass_20_count``,
        ``pass_20_rate``, ``pass_10_count``, ``pass_10_rate``, ``rms_median``,
        ``config`` and ``results`` (one ``name`` / ``rms_percent`` /
        ``pass_20`` / ``pass_10`` row per scored galaxy). ``n_cases`` counts
        only galaxies that were actually scored: cases that raised, or that
        have fewer than 3 points in the window, are excluded. Rates and the
        median are 0.0 when nothing was scored.
    """
    with open(manifest_path, encoding="utf-8") as f:
        # Strip before testing for "#" so indented comment lines are skipped
        # instead of being turned into bogus case paths.
        entries = (line.strip() for line in f)
        case_paths = [
            f"cases/{entry}"
            for entry in entries
            if entry and not entry.startswith("#")
        ]

    results = []
    for case_path in case_paths:
        # Per-case boundary: one broken galaxy is reported and skipped so the
        # rest of the cohort is still evaluated.
        try:
            case = load_case(case_path)
            if is_v2:
                # v2 expects separate kernel_config and tail_config
                result = apply_func(case, config["kernel"], config["tail"])
            else:
                result = apply_func(case, config)

            # NOTE(review): the mask is built from the solver's radii but is
            # applied to the case's observations, which assumes both share the
            # same radial grid — confirm against the solvers.
            r = np.array(result["r_kpc"])
            v_pred = np.array(result["v_pred_kms"])
            v_obs = np.array(case.v_obs_kms)
            sigma = np.array(case.sigma_v_kms)

            # Standard window: 1-30 kpc, points with a positive uncertainty.
            mask = (r >= 1.0) & (r <= 30.0) & (sigma > 0)
            if mask.sum() < 3:
                # Too few usable points to score this galaxy.
                continue

            metrics = compute_metrics(v_obs_kms=v_obs[mask], v_pred_kms=v_pred[mask])
            rms_percent = float(metrics["rms_percent"])
            results.append({
                "name": case.name,
                "rms_percent": rms_percent,
                "pass_20": bool(rms_percent <= 20.0),
                "pass_10": bool(rms_percent <= 10.0),
            })
        except Exception as e:
            print(f"  ERROR {theory_name} on {case_path}: {e}", file=sys.stderr)

    n_scored = len(results)
    pass_20_count = sum(row["pass_20"] for row in results)
    pass_10_count = sum(row["pass_10"] for row in results)
    pass_20_rate = 100.0 * pass_20_count / n_scored if results else 0.0
    pass_10_rate = 100.0 * pass_10_count / n_scored if results else 0.0
    rms_median = float(np.median([row["rms_percent"] for row in results])) if results else 0.0

    return {
        "theory": theory_name,
        "n_cases": n_scored,
        "pass_20_count": int(pass_20_count),
        "pass_20_rate": float(pass_20_rate),
        "pass_10_count": int(pass_10_count),
        "pass_10_rate": float(pass_10_rate),
        "rms_median": float(rms_median),
        "config": config,
        "results": results,
    }


print("="*72)
print("THEORY COMPARISON: RFT v2 vs Standard Models")
print("="*72)
# NOTE: the cohort size in this banner is a hard-coded label; it is not
# derived from the manifest, so update it if the manifest changes.
print("Cohort: TEST (34 galaxies)")
print("Window: 1.0-30.0 kpc, ≥3 data points")
print("")

# Every theory is scored on the same manifest so the pass rates are comparable.
TEST_MANIFEST = "cases/SP99-TEST.manifest.txt"

# Summaries are collected in run order; later code looks up individual
# theories through the named variables below.
theories = []

# 1. Newtonian (baryons only) - baseline
# Reuses the v2 solver with the tail amplitude A0 set to 0.0 and a
# single-point kernel; presumably this reduces it to the baryons-only
# prediction — confirm against apply_v2_gated_tail.
print("Running Newtonian (baryons only)...")
newtonian_config = {
    "kernel": {"grid": [0.0], "weights": [1.0], "r_scale": "r_geo"},
    "tail": {"A0_kms2_per_kpc": 0.0, "alpha": 0.6, "g_star_kms2_per_kpc": 1000,
             "gamma": 0.5, "r_turn_kpc": 2.0, "p": 2.0}
}
newtonian = run_theory("Newtonian", apply_v2_gated_tail, newtonian_config,
                       TEST_MANIFEST, is_v2=True)
theories.append(newtonian)
print(f"  Pass@20%: {newtonian['pass_20_rate']:.1f}%, RMS: {newtonian['rms_median']:.1f}%\n")

# 2. MOND (standard interpolation)
# NOTE(review): a0 = 1.2e-10 m/s^2 is about 3.7e3 (km/s)^2/kpc. The value 3.7
# below is only correct if apply_mond applies its own x1000 scaling or uses
# different units than the key name suggests — confirm before trusting the
# MOND row of this comparison.
print("Running MOND (standard)...")
mond_config = {
    "a0_kms2_per_kpc": 3.7,  # Standard MOND scale
    "mu_form": "standard"
}
mond = run_theory("MOND", apply_mond, mond_config, TEST_MANIFEST)
theories.append(mond)
print(f"  Pass@20%: {mond['pass_20_rate']:.1f}%, RMS: {mond['rms_median']:.1f}%\n")

# 3. NFW halo (ΛCDM)
print("Running NFW halo (ΛCDM)...")
nfw_config = {
    "v200_kms": None,  # Auto-fit from data
    "c": 10.0,
    "core_kpc": 0.5
}
nfw = run_theory("NFW_halo", apply_nfw_halo, nfw_config, TEST_MANIFEST)
theories.append(nfw)
print(f"  Pass@20%: {nfw['pass_20_rate']:.1f}%, RMS: {nfw['rms_median']:.1f}%\n")

# 4. RFT v2 (frozen best)
print("Running RFT v2 (frozen)...")
# Load the frozen config inside a context manager so the handle is closed
# deterministically instead of being left to the garbage collector.
with open("config/global_rc_v2_frozen.json", encoding="utf-8") as config_file:
    rft_v2_config = json.load(config_file)
rft_v2 = run_theory("RFT_v2", apply_v2_gated_tail, rft_v2_config,
                    TEST_MANIFEST, is_v2=True)
theories.append(rft_v2)
print(f"  Pass@20%: {rft_v2['pass_20_rate']:.1f}%, RMS: {rft_v2['rms_median']:.1f}%\n")

# Persist every theory summary (including the per-galaxy rows) as JSON.
Path("reports/comparisons").mkdir(parents=True, exist_ok=True)
with open("reports/comparisons/theory_comparison_test.json", "w") as out_file:
    json.dump({"theories": theories}, out_file, indent=2)

# Summary table: one row per theory; the last column is the pass@20% gap to
# RFT v2 in percentage points (a dash on the RFT v2 row itself).
rule = "=" * 72
print(rule)
print("SUMMARY TABLE")
print(rule)
print(f"{'Theory':<20} {'Pass@20%':>10} {'Pass@10%':>10} {'RMS%':>8} {'vs RFT v2':>12}")
print("-"*72)

rft_v2_pass = rft_v2['pass_20_rate']
for t in theories:
    if t['theory'] == 'RFT_v2':
        delta_str = "—"
    else:
        delta_str = f"{t['pass_20_rate'] - rft_v2_pass:+.1f}pp"
    print(f"{t['theory']:<20} {t['pass_20_rate']:>9.1f}% {t['pass_10_rate']:>9.1f}% "
          f"{t['rms_median']:>7.1f}% {delta_str:>12}")

print(rule)

# Headline numbers, printed in a fixed order independent of run order.
findings = [
    f"  - RFT v2: {rft_v2['pass_20_rate']:.1f}% pass@20% (GREEN gate)",
    f"  - MOND: {mond['pass_20_rate']:.1f}% pass@20% ({mond['pass_20_rate'] - rft_v2_pass:+.1f}pp vs RFT v2)",
    f"  - NFW halo: {nfw['pass_20_rate']:.1f}% pass@20% ({nfw['pass_20_rate'] - rft_v2_pass:+.1f}pp vs RFT v2)",
    f"  - Newtonian: {newtonian['pass_20_rate']:.1f}% pass@20% (baseline)",
]
print("\nKey Findings:")
for finding in findings:
    print(finding)
print("")
print("Saved: reports/comparisons/theory_comparison_test.json")
