#!/usr/bin/env python3
"""
Generate NFW and MOND Baselines on TEST Cohort

Runs NFW (with free c_vir) and MOND (with a single global a0, canonical
1.2e-10 m/s², not fitted per galaxy) on TEST galaxies to create fair
comparison baselines for RFT v2.

Author: RFT Validation System
Date: 2025-11-10
"""

import json
import sys
import numpy as np
from pathlib import Path
from datetime import datetime

# Put the repository root (the parent of this script's directory) at the front
# of sys.path so the project-local imports below (sparc_rft, baselines, metrics)
# resolve when this file is executed directly as a script. This must run
# before those imports.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from sparc_rft.case import load_case
from baselines.nfw import nfw_fit_predict
from baselines.mond import mond_predict
from metrics.rc_metrics import compute_metrics


def run_nfw_galaxy(case_path, min_radius=1.0, max_radius=30.0):
    """Fit an NFW halo to one galaxy and score it inside a radius window.

    Args:
        case_path: Path to the galaxy case file, passed to ``load_case``.
        min_radius: Inner radius cut in kpc (inclusive) applied when scoring.
        max_radius: Outer radius cut in kpc (inclusive) applied when scoring.

    Returns:
        A JSON-serializable dict with the galaxy name, the number of scored
        points, ``rms_percent`` from ``compute_metrics``, pass flags at the
        20% and 10% thresholds, and the fitted ``rho_s`` / ``r_s`` (NaN when
        the fit's params do not contain them). Returns ``None`` when fewer
        than 3 points survive the window, or when any step raises; the
        error is then printed to stderr so one bad galaxy does not abort
        the cohort run.
    """
    try:
        case = load_case(case_path)

        # Fit NFW (free ρₛ and rₛ parameters). nfw_fit_predict receives the
        # whole case; the radius window below is applied only when scoring.
        v_pred, params = nfw_fit_predict(case)

        r_kpc = np.array(case.r_kpc)
        v_pred = np.array(v_pred)
        v_obs = np.array(case.v_obs_kms)
        sigma_v = np.array(case.sigma_v_kms)

        # Score only points inside the radius window that carry a positive
        # uncertainty; sigma_v is used purely as a filter, not as a weight.
        mask = (r_kpc >= min_radius) & (r_kpc <= max_radius) & (sigma_v > 0)
        if np.count_nonzero(mask) < 3:
            return None

        v_obs_fit = v_obs[mask]
        metrics = compute_metrics(v_obs_fit, v_pred[mask])
        rms = float(metrics["rms_percent"])

        return {
            "name": case.name,
            "n_points": int(len(v_obs_fit)),
            "rms_percent": rms,
            "pass_20": bool(rms <= 20.0),
            "pass_10": bool(rms <= 10.0),
            "rho_s": float(params.get("rho_s", np.nan)),
            "r_s": float(params.get("r_s", np.nan))
        }

    except Exception as e:
        # Deliberate best-effort boundary: report and skip this galaxy.
        print(f"ERROR {case_path} (NFW): {e}", file=sys.stderr)
        return None


def run_mond_galaxy(case_path, a0_m_s2=1.2e-10, min_radius=1.0, max_radius=30.0):
    """Predict one galaxy's rotation curve with MOND and score it in a radius window.

    No per-galaxy fitting is done: the same ``a0_m_s2`` is used for every
    galaxy, with the ``"standard"`` interpolation law passed to
    ``mond_predict``.

    Args:
        case_path: Path to the galaxy case file, passed to ``load_case``.
        a0_m_s2: MOND acceleration scale in m/s^2.
        min_radius: Inner radius cut in kpc (inclusive) applied when scoring.
        max_radius: Outer radius cut in kpc (inclusive) applied when scoring.

    Note:
        ``a0_m_s2`` is the second positional parameter, so callers that pass
        the radius window positionally must bind it by keyword.

    Returns:
        A JSON-serializable dict with the galaxy name, the number of scored
        points, ``rms_percent`` from ``compute_metrics``, pass flags at the
        20% and 10% thresholds, and the ``a0_m_s2`` used. Returns ``None``
        when fewer than 3 points survive the window, or when any step
        raises; the error is then printed to stderr.
    """
    try:
        case = load_case(case_path)

        # Predict MOND (global a0, no fitting); the returned params are not needed.
        v_pred, _ = mond_predict(case, a0_m_s2=a0_m_s2, law="standard")

        r_kpc = np.array(case.r_kpc)
        v_pred = np.array(v_pred)
        v_obs = np.array(case.v_obs_kms)
        sigma_v = np.array(case.sigma_v_kms)

        # Score only points inside the radius window that carry a positive
        # uncertainty; sigma_v is used purely as a filter, not as a weight.
        mask = (r_kpc >= min_radius) & (r_kpc <= max_radius) & (sigma_v > 0)
        if np.count_nonzero(mask) < 3:
            return None

        v_obs_fit = v_obs[mask]
        metrics = compute_metrics(v_obs_fit, v_pred[mask])
        rms = float(metrics["rms_percent"])

        return {
            "name": case.name,
            "n_points": int(len(v_obs_fit)),
            "rms_percent": rms,
            "pass_20": bool(rms <= 20.0),
            "pass_10": bool(rms <= 10.0),
            "a0_m_s2": float(a0_m_s2)
        }

    except Exception as e:
        # Deliberate best-effort boundary: report and skip this galaxy.
        print(f"ERROR {case_path} (MOND): {e}", file=sys.stderr)
        return None


def evaluate_cohort(manifest_path, model_fn, min_radius, max_radius, cases_dir="cases"):
    """Evaluate ``model_fn`` on every galaxy listed in a cohort manifest.

    The manifest lists one case file per line. Blank lines, and lines whose
    first non-whitespace character is ``#``, are ignored. Each remaining
    entry becomes the path ``f"{cases_dir}/{entry}"``, which is resolved
    against the current working directory, not the manifest's location.
    That path is passed to ``model_fn(case_path, min_radius, max_radius)``.
    ``model_fn`` must return a per-galaxy dict with ``pass_20``, ``pass_10``
    and ``rms_percent`` keys, or a falsy value (e.g. ``None``) to skip the
    galaxy.

    Args:
        manifest_path: Path to the manifest text file (read as UTF-8).
        model_fn: Per-galaxy evaluator, called positionally as above.
        min_radius: Inner radius cut (kpc), forwarded to ``model_fn``.
        max_radius: Outer radius cut (kpc), forwarded to ``model_fn``.
        cases_dir: Directory prefix prepended to each manifest entry.

    Returns:
        ``None`` if no galaxy produced a result. Otherwise a dict with
        ``n_galaxies``, pass@20%/pass@10% counts and rates (in percent),
        the median ``rms_percent`` and the ``per_galaxy`` result list.
        Skipped galaxies are excluded from every denominator. ``np.median``
        propagates NaN, so a single NaN ``rms_percent`` makes
        ``rms_median`` NaN.
    """
    with open(manifest_path, encoding="utf-8") as f:
        entries = [line.strip() for line in f]
    # Test for '#' after stripping so indented comment lines are skipped too.
    case_paths = [f"{cases_dir}/{entry}" for entry in entries
                  if entry and not entry.startswith("#")]

    results = []
    total = len(case_paths)
    for i, case_path in enumerate(case_paths, 1):
        print(f"[{i}/{total}] {Path(case_path).name}...", end=" ")
        result = model_fn(case_path, min_radius, max_radius)
        if not result:
            print("SKIP")
            continue
        results.append(result)
        status = "PASS" if result["pass_20"] else "FAIL"
        print(f"{status} (RMS={result['rms_percent']:.1f}%)")

    if not results:
        return None

    n = len(results)
    pass_20_count = sum(bool(r["pass_20"]) for r in results)
    pass_10_count = sum(bool(r["pass_10"]) for r in results)
    rms_median = float(np.median([r["rms_percent"] for r in results]))

    return {
        "n_galaxies": n,
        "pass_20_count": int(pass_20_count),
        "pass_20_rate": float(100.0 * pass_20_count / n),
        "pass_10_count": int(pass_10_count),
        "pass_10_rate": float(100.0 * pass_10_count / n),
        "rms_median": rms_median,
        "per_galaxy": results
    }


def _report_and_save(model, result, out_path, args):
    """Print the cohort summary for *model* and write it, with run metadata, to *out_path*.

    Does nothing when *result* is None (no galaxy produced a result).
    """
    if result is None:
        return
    n = result["n_galaxies"]
    print(f"\n{model}: {result['pass_20_rate']:.1f}% pass@20% ({result['pass_20_count']}/{n})")
    print(f"{model}: {result['pass_10_rate']:.1f}% pass@10% ({result['pass_10_count']}/{n})")
    print(f"RMS median: {result['rms_median']:.1f}%")

    payload = {
        "model": model,
        "timestamp": datetime.now().isoformat(),
        "cohort": "TEST",
        # Record the inputs so a saved baseline can be reproduced.
        "manifest": str(args.test_manifest),
        "min_radius_kpc": float(args.min_radius),
        "max_radius_kpc": float(args.max_radius),
        **result
    }
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    print(f"\nSaved: {out_path}")


def main():
    """Run the NFW and MOND baselines on a cohort manifest and save the results.

    For each model, every galaxy in ``--test-manifest`` is scored within the
    ``--min-radius``..``--max-radius`` window. Pass rates and the median RMS
    are printed, and the results are written to ``nfw_test_baseline.json``
    and ``mond_test_baseline.json`` under ``--output-dir``. A summary then
    compares each available model with the frozen RFT v2 pass@20% rate
    (``--rft-pass20``).
    """
    import argparse
    parser = argparse.ArgumentParser(description="Generate NFW and MOND baselines on TEST cohort")
    parser.add_argument("--test-manifest", default="cases/SP99-TEST.manifest.txt", help="TEST cohort")
    parser.add_argument("--min-radius", type=float, default=1.0, help="Min radius (kpc)")
    parser.add_argument("--max-radius", type=float, default=30.0, help="Max radius (kpc)")
    parser.add_argument("--output-dir", default="baselines/results", help="Output directory")
    parser.add_argument("--a0", type=float, default=1.2e-10,
                        help="MOND acceleration scale a0 in m/s^2, shared by all galaxies "
                             "(default: %(default)s)")
    # '%%' because argparse %-formats help strings.
    parser.add_argument("--rft-pass20", type=float, default=58.8,
                        help="Frozen RFT v2 pass@20%% rate used in the summary "
                             "(default: %(default)s)")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    rule = "=" * 72
    print(rule)
    print("NFW AND MOND BASELINE GENERATION")
    print(rule)
    print(f"TEST manifest: {args.test_manifest}")
    print(f"Radius window: {args.min_radius}-{args.max_radius} kpc")
    print("")

    # Run NFW
    print(rule)
    print("RUNNING NFW (free c_vir)")
    print(rule)
    nfw_result = evaluate_cohort(args.test_manifest, run_nfw_galaxy, args.min_radius, args.max_radius)
    _report_and_save("NFW", nfw_result, output_dir / "nfw_test_baseline.json", args)

    # Run MOND. a0 is bound by keyword because it is run_mond_galaxy's second
    # positional parameter, whereas evaluate_cohort passes the radius window positionally.
    print("")
    print(rule)
    print(f"RUNNING MOND (fixed a0={args.a0:g} m/s²)")
    print(rule)
    a0 = args.a0
    mond_result = evaluate_cohort(
        args.test_manifest,
        lambda p, mn, mx: run_mond_galaxy(p, a0_m_s2=a0, min_radius=mn, max_radius=mx),
        args.min_radius,
        args.max_radius,
    )
    _report_and_save("MOND", mond_result, output_dir / "mond_test_baseline.json", args)

    # Summary comparison. Each model reports its own n, because galaxies that
    # were skipped are excluded from that model's denominator.
    print("")
    print(rule)
    print("SUMMARY COMPARISON (TEST)")
    print(rule)
    available = [(name, res) for name, res in (("NFW", nfw_result), ("MOND", mond_result))
                 if res is not None]
    if available:
        rft = args.rft_pass20
        for name, res in available:
            print(f"{name + ':':<5} {res['pass_20_rate']:.1f}% pass@20% (n={res['n_galaxies']})")
        print(f"RFT:  {rft:.1f}% pass@20% (frozen v2)")
        print("")
        print("RFT advantage:")
        for name, res in available:
            print(f"  vs {name + ':':<5} {rft - res['pass_20_rate']:+.1f} percentage points")
    else:
        print("No baseline produced results; nothing to compare.")

    print("")
    print(rule)
    print("BASELINE GENERATION COMPLETE")
    print(rule)


if __name__ == "__main__":
    # Run the baselines only when executed as a script, not when imported.
    main()
