#!/usr/bin/env python3
"""
Generate Global NFW Baseline (Zero Per-Galaxy Tuning)

Fair comparison to RFT: use a single global NFW halo profile for all galaxies.
Two ways of fixing that profile are:
1. Cosmological c-M relation (ΛCDM prediction)
2. Global best-fit (ρₛ, rₛ) fitted on TRAIN, frozen on TEST

Only option 2 is implemented in the code below (grid search on TRAIN).
This matches RFT's zero per-galaxy tuning.

Author: RFT Validation System
Date: 2025-11-10
"""

import json
import sys
import numpy as np
from pathlib import Path
from datetime import datetime

# Put the directory one level above this file's own directory at the front of
# sys.path (once), so the project-local imports that follow can be resolved
# when the script is run directly. Presumably that directory is the repository
# root holding the `sparc_rft` and `metrics` packages — confirm against layout.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from sparc_rft.case import load_case
from metrics.rc_metrics import compute_metrics


# Gravitational constant in kpc·km²·s⁻²·M☉⁻¹, so that
# v² [km²/s²] = G_KPC · M [M☉] / r [kpc].
G_KPC = 4.300917e-6


def nfw_mass(r_kpc, rho_s, r_s):
    """Return the NFW mass enclosed within radius ``r_kpc``.

    Evaluates M(<r) = 4π ρₛ rₛ³ [ln(1 + x) − x / (1 + x)] with x = r / rₛ.
    The scale radius used in the ratio, and the ratio itself, are both floored
    at 1e-6 so the expression never divides by zero. The rₛ³ prefactor uses
    ``r_s`` as passed in (not the floored value).
    """
    scale_radius = np.clip(r_s, 1e-6, None)
    x = np.clip(r_kpc / scale_radius, 1e-6, None)
    one_plus_x = 1 + x
    shape_term = np.log(one_plus_x) - x / one_plus_x
    prefactor = 4 * np.pi * rho_s * (r_s**3)
    return prefactor * shape_term


def nfw_v_dm_squared(r_kpc, rho_s, r_s):
    """Return the squared circular velocity contributed by the NFW halo.

    Uses v² = G_KPC · M(<r) / r, where M(<r) comes from ``nfw_mass`` and the
    radius in the denominator is floored at 1e-6 to avoid dividing by zero.
    """
    enclosed_mass = nfw_mass(r_kpc, rho_s, r_s)
    radius = np.clip(r_kpc, 1e-6, None)
    return (G_KPC * enclosed_mass) / radius


def run_global_nfw_galaxy(case_path, rho_s_global, r_s_global, min_radius=1.0, max_radius=30.0):
    """Score one galaxy against the global NFW halo (no per-galaxy fitting).

    The predicted rotation speed is the quadrature sum of the disk, bulge (if
    the case provides one) and gas curves plus the NFW halo evaluated with the
    supplied global parameters. Only points with
    ``min_radius <= r <= max_radius`` and a positive velocity uncertainty are
    scored.

    Returns:
        A dict with ``name``, ``n_points``, ``rms_percent``, ``pass_20`` and
        ``pass_10``, or ``None`` when fewer than 3 points survive the window
        or when anything raises (the error is reported on stderr).
    """
    try:
        galaxy = load_case(case_path)

        radius = np.array(galaxy.r_kpc)
        observed = np.array(galaxy.v_obs_kms)
        uncertainty = np.array(galaxy.sigma_v_kms)

        # Baryonic components; a missing bulge contributes zero everywhere.
        disk = np.array(galaxy.v_baryon_disk_kms)
        gas = np.array(galaxy.v_baryon_gas_kms)
        if galaxy.v_baryon_bulge_kms is None:
            bulge = np.zeros_like(disk)
        else:
            bulge = np.array(galaxy.v_baryon_bulge_kms)
        baryon_sq = disk**2 + bulge**2 + gas**2

        # Halo term uses the GLOBAL parameters, identical for every galaxy.
        halo_sq = nfw_v_dm_squared(radius, rho_s_global, r_s_global)
        predicted = np.sqrt(baryon_sq + halo_sq)

        # Keep only points inside the radius window with a usable error bar.
        in_window = (radius >= min_radius) & (radius <= max_radius) & (uncertainty > 0)
        if in_window.sum() < 3:
            return None

        observed_sel = observed[in_window]
        stats = compute_metrics(observed_sel, predicted[in_window])

        summary = {"name": galaxy.name}
        summary["n_points"] = int(len(observed_sel))
        summary["rms_percent"] = float(stats["rms_percent"])
        summary["pass_20"] = bool(stats["rms_percent"] <= 20.0)
        summary["pass_10"] = bool(stats["rms_percent"] <= 10.0)
        return summary

    except Exception as e:
        print(f"ERROR {case_path} (Global NFW): {e}", file=sys.stderr)
        return None


def fit_global_nfw_on_train(train_manifest, min_radius, max_radius):
    """
    Fit a single global (ρₛ, rₛ) on TRAIN cohort.

    Strategy: Grid search over parameter space, minimizing the mean squared
    velocity residual pooled over every in-window point of every TRAIN galaxy.

    Args:
        train_manifest: Path to a text manifest with one case file per line
            (each is prefixed with ``cases/``). Blank lines and lines whose
            first non-blank character is ``#`` are skipped.
        min_radius: Inner edge of the radius window.
        max_radius: Outer edge of the radius window.

    Returns:
        Tuple ``(best_rho, best_r)``: the grid point with the lowest loss.

    Raises:
        ValueError: If no TRAIN galaxy has at least 3 usable points (inside
            the window, with positive velocity uncertainty), or if the loss
            is non-finite at every grid point. In both situations there is
            nothing meaningful to return.
    """
    print("Fitting global NFW parameters on TRAIN...")

    with open(train_manifest) as f:
        # Strip first so indented comment lines are recognized as comments.
        entries = [line.strip() for line in f]
    case_paths = [f"cases/{entry}" for entry in entries if entry and not entry.startswith("#")]

    # Load all TRAIN cases; unreadable cases are reported and skipped.
    cases = []
    for case_path in case_paths:
        try:
            cases.append(load_case(case_path))
        except Exception as e:
            print(f"Skipping {case_path}: {e}", file=sys.stderr)

    print(f"Loaded {len(cases)} TRAIN galaxies")

    # Everything except the halo term is independent of (ρₛ, rₛ), so build the
    # windowed arrays once instead of once per grid point.
    windows = []
    for case in cases:
        r_kpc = np.array(case.r_kpc)
        v_obs = np.array(case.v_obs_kms)
        sigma_v = np.array(case.sigma_v_kms)

        v_disk = np.array(case.v_baryon_disk_kms)
        v_gas = np.array(case.v_baryon_gas_kms)
        if case.v_baryon_bulge_kms is not None:
            v_bulge = np.array(case.v_baryon_bulge_kms)
        else:
            v_bulge = np.zeros_like(v_disk)
        v_baryon_sq = v_disk**2 + v_bulge**2 + v_gas**2

        mask = (r_kpc >= min_radius) & (r_kpc <= max_radius) & (sigma_v > 0)
        if mask.sum() < 3:
            continue
        windows.append((r_kpc[mask], v_obs[mask], v_baryon_sq[mask]))

    n_total = sum(len(r_win) for r_win, _, _ in windows)
    if n_total == 0:
        # Without data every grid point would score a loss of 0 and the first
        # one would be reported as the "best fit".
        raise ValueError(
            f"No usable TRAIN data in {train_manifest}: no galaxy has at least 3 points "
            f"within {min_radius}-{max_radius} kpc with positive velocity uncertainty"
        )

    # Grid search over (ρₛ, rₛ)
    rho_s_grid = np.logspace(6, 9, 20)  # 1e6 to 1e9 M☉/kpc³
    r_s_grid = np.logspace(-0.5, 2, 20)  # 0.3 to 100 kpc

    best_rho = None
    best_r = None
    best_loss = np.inf

    print("Grid searching (ρₛ, rₛ) space...")
    for rho_s in rho_s_grid:
        for r_s in r_s_grid:
            # Accumulate per galaxy (same summation order as a per-case loop).
            total_rss = 0.0
            for r_win, v_obs_win, v_baryon_sq_win in windows:
                v_dm_sq = nfw_v_dm_squared(r_win, rho_s, r_s)
                v_pred = np.sqrt(v_baryon_sq_win + v_dm_sq)
                total_rss += np.sum((v_obs_win - v_pred) ** 2)

            loss = total_rss / n_total

            if loss < best_loss:
                best_loss = loss
                best_rho = rho_s
                best_r = r_s
                print(f"  New best: ρₛ={best_rho:.2e}, rₛ={best_r:.2f} kpc, loss={best_loss:.1f}")

    if best_rho is None:
        # A NaN/inf loss never compares below the running best.
        raise ValueError(
            "Global NFW fit failed: loss was non-finite at every grid point "
            "(check TRAIN velocities for NaN/inf values)"
        )

    print(f"\nBest global NFW parameters:")
    print(f"  ρₛ = {best_rho:.3e} M☉/kpc³")
    print(f"  rₛ = {best_r:.2f} kpc")
    print(f"  TRAIN loss = {best_loss:.1f}")

    return best_rho, best_r


def evaluate_cohort(manifest_path, rho_s, r_s, min_radius, max_radius, cohort_name):
    """Evaluate global NFW on a cohort.

    Args:
        manifest_path: Path to a text manifest with one case file per line
            (each is prefixed with ``cases/``). Blank lines and lines whose
            first non-blank character is ``#`` are skipped.
        rho_s: Global NFW ρₛ applied to every galaxy.
        r_s: Global NFW rₛ applied to every galaxy.
        min_radius: Inner edge of the radius window.
        max_radius: Outer edge of the radius window.
        cohort_name: Label stored under ``"cohort"`` in the result.

    Returns:
        A dict with the cohort label, number of evaluated galaxies, pass@20%
        and pass@10% counts and rates (in percent), the median RMS percent and
        the per-galaxy results; or ``None`` if no galaxy could be evaluated.
    """
    with open(manifest_path) as f:
        # Strip first so indented comment lines are recognized as comments
        # instead of being treated as case paths.
        entries = [line.strip() for line in f]
    case_paths = [f"cases/{entry}" for entry in entries if entry and not entry.startswith("#")]

    results = []
    for i, case_path in enumerate(case_paths, 1):
        print(f"[{i}/{len(case_paths)}] {case_path.split('/')[-1]}...", end=" ")
        result = run_global_nfw_galaxy(case_path, rho_s, r_s, min_radius, max_radius)
        if result:
            results.append(result)
            status = "PASS" if result["pass_20"] else "FAIL"
            print(f"{status} (RMS={result['rms_percent']:.1f}%)")
        else:
            # Too few in-window points, or the case failed to load/score.
            print("SKIP")

    if not results:
        return None

    # Aggregate; rates are relative to evaluated galaxies (skips excluded).
    pass_20_count = sum(r["pass_20"] for r in results)
    pass_10_count = sum(r["pass_10"] for r in results)
    pass_20_rate = 100.0 * pass_20_count / len(results)
    pass_10_rate = 100.0 * pass_10_count / len(results)
    rms_median = float(np.median([r["rms_percent"] for r in results]))

    return {
        "cohort": cohort_name,
        "n_galaxies": len(results),
        "pass_20_count": int(pass_20_count),
        "pass_20_rate": float(pass_20_rate),
        "pass_10_count": int(pass_10_count),
        "pass_10_rate": float(pass_10_rate),
        "rms_median": float(rms_median),
        "per_galaxy": results
    }


def main():
    """Fit the global NFW halo on TRAIN, evaluate TRAIN and TEST, save the TEST result.

    Steps:
      1. Grid-search a single (ρₛ, rₛ) on the TRAIN manifest.
      2. Evaluate that halo on TRAIN (sanity check) and on TEST (blind).
      3. Write the TEST result to ``<output-dir>/nfw_global_test_baseline.json``.
      4. Print a comparison against hard-coded RFT and MOND pass@20% rates.

    Problems that would otherwise pass unnoticed (no TEST result at all, or a
    TEST cohort whose size differs from the one the reference rates assume)
    are reported on stderr; stdout output is unchanged.
    """
    import argparse

    # Reference pass@20% rates hard-coded in this script; the comparison
    # banner states they refer to a TEST cohort of 34 galaxies.
    reference_n = 34
    rft_pass_20 = 58.8
    mond_pass_20 = 23.5
    rule = "=" * 72

    parser = argparse.ArgumentParser(description="Generate global NFW baseline (zero per-galaxy tuning)")
    parser.add_argument("--train-manifest", default="cases/SP99-TRAIN.manifest.txt", help="TRAIN cohort")
    parser.add_argument("--test-manifest", default="cases/SP99-TEST.manifest.txt", help="TEST cohort")
    parser.add_argument("--min-radius", type=float, default=1.0, help="Min radius (kpc)")
    parser.add_argument("--max-radius", type=float, default=30.0, help="Max radius (kpc)")
    parser.add_argument("--output-dir", default="baselines/results", help="Output directory")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(rule)
    print("GLOBAL NFW BASELINE (ZERO PER-GALAXY TUNING)")
    print(rule)
    print("Fair comparison to RFT: global parameters only")
    print(f"TRAIN: {args.train_manifest}")
    print(f"TEST:  {args.test_manifest}")
    print(f"Window: {args.min_radius}-{args.max_radius} kpc")
    print("")

    # Phase 1: Fit global (ρₛ, rₛ) on TRAIN
    rho_s_global, r_s_global = fit_global_nfw_on_train(
        args.train_manifest, args.min_radius, args.max_radius
    )

    # Phase 2: Evaluate on TRAIN (sanity check)
    print("")
    print(rule)
    print("TRAIN EVALUATION (sanity check)")
    print(rule)
    train_result = evaluate_cohort(
        args.train_manifest, rho_s_global, r_s_global,
        args.min_radius, args.max_radius, "TRAIN"
    )

    if train_result:
        print(f"\nTRAIN: {train_result['pass_20_rate']:.1f}% pass@20% ({train_result['pass_20_count']}/{train_result['n_galaxies']})")
        print(f"TRAIN: RMS median = {train_result['rms_median']:.1f}%")

    # Phase 3: Evaluate on TEST (blind) with the parameters frozen from TRAIN
    print("")
    print(rule)
    print("TEST EVALUATION (BLIND)")
    print(rule)
    test_result = evaluate_cohort(
        args.test_manifest, rho_s_global, r_s_global,
        args.min_radius, args.max_radius, "TEST"
    )

    if test_result:
        print(f"\nTEST: {test_result['pass_20_rate']:.1f}% pass@20% ({test_result['pass_20_count']}/{test_result['n_galaxies']})")
        print(f"TEST: {test_result['pass_10_rate']:.1f}% pass@10% ({test_result['pass_10_count']}/{test_result['n_galaxies']})")
        print(f"RMS median: {test_result['rms_median']:.1f}%")

        # Save
        output = {
            "model": "NFW_global",
            "description": "Global NFW with zero per-galaxy tuning (k=0)",
            "timestamp": datetime.now().isoformat(),
            "global_params": {
                "rho_s_Msun_per_kpc3": float(rho_s_global),
                "r_s_kpc": float(r_s_global)
            },
            **test_result
        }
        output_path = output_dir / "nfw_global_test_baseline.json"
        with open(output_path, "w") as f:
            json.dump(output, f, indent=2)
        print(f"\nSaved: {output_path}")
    else:
        # Previously silent: no file is written and the comparison is empty.
        print(
            "WARNING: no TEST galaxy could be evaluated; no baseline file was written",
            file=sys.stderr,
        )

    # Summary
    print("")
    print(rule)
    print(f"COMPARISON (TEST, n={reference_n})")
    print(rule)
    if test_result:
        if test_result["n_galaxies"] != reference_n:
            # The banner and reference rates assume a fixed cohort size; flag
            # a mismatch so the comparison is not taken at face value.
            print(
                f"WARNING: NFW_global was evaluated on n={test_result['n_galaxies']} galaxies; "
                f"the RFT/MOND reference rates assume n={reference_n}",
                file=sys.stderr,
            )
        print(f"NFW_global (k=0):  {test_result['pass_20_rate']:.1f}% pass@20%")
        print(f"RFT v2 (k=0):      {rft_pass_20:.1f}% pass@20%")
        print(f"MOND (k=0):        {mond_pass_20:.1f}% pass@20%")
        print("")
        print("RFT advantage:")
        print(f"  vs NFW_global: {rft_pass_20 - test_result['pass_20_rate']:+.1f} percentage points")
        print(f"  vs MOND:       {rft_pass_20 - mond_pass_20:+.1f} percentage points")

    print("")
    print(rule)
    print("GLOBAL NFW BASELINE COMPLETE")
    print(rule)


# Run the baseline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
