#!/usr/bin/env python3
"""
Compare Geometric Solver (RFT geom modes) vs RFT v2 (acceleration-gated tail)
on the same 34 TEST galaxies.

This allows direct comparison of the two methodologies:
- Geometric: modes (flat/m1, spiral/m2, core/bar, toroidal)
- RFT v2: acceleration-gated tail with 6 global parameters

Both use k=0 (no per-galaxy tuning).
"""

import json
import sys
from pathlib import Path
from typing import Dict, List

import numpy as np

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from sparc_rft.case import load_case
from solver.rft_geom import rft_geom_predict
from solver.rft_v2_gated_tail import apply_v2_gated_tail


def load_manifest(manifest_path: Path) -> List[Path]:
    """Load galaxy case paths from a manifest file.

    The manifest is read as UTF-8 text with one entry per line. Blank lines
    and comment lines (first non-whitespace character is ``#``) are skipped.
    Each remaining entry is stripped of surrounding whitespace and resolved
    relative to the directory containing the manifest.

    Args:
        manifest_path: Location of the manifest file.

    Returns:
        Case paths in the order they appear in the manifest.
    """
    manifest_path = Path(manifest_path)
    manifest_dir = manifest_path.parent

    case_paths: List[Path] = []
    with open(manifest_path, encoding="utf-8") as f:
        for raw_line in f:
            entry = raw_line.strip()
            # Test the *stripped* text so that indented comments are skipped
            # too, instead of being mistaken for a case path.
            if not entry or entry.startswith("#"):
                continue
            case_paths.append(manifest_dir / entry)
    return case_paths


def compute_rms_percent(v_model: np.ndarray, v_obs: np.ndarray) -> float:
    """Compute the RMS of the relative residuals, in percent.

    The metric is ``100 * sqrt(mean(((v_model - v_obs) / v_obs) ** 2))``.

    Args:
        v_model: Model values (any array-like; flattened to 1-D).
        v_obs: Observed values (any array-like; flattened to 1-D).

    Returns:
        The RMS percentage error as a Python float. NaN is returned when
        either input is empty or when no point has a non-zero observation.
        NaN values in the inputs propagate to a NaN result.

    Raises:
        ValueError: If the two inputs do not contain the same number of
            points.
    """
    model = np.ravel(np.asarray(v_model, dtype=float))
    obs = np.ravel(np.asarray(v_obs, dtype=float))

    if model.size == 0 or obs.size == 0:
        return np.nan
    if model.size != obs.size:
        # Refuse to broadcast: a length-1 or column-shaped input would
        # otherwise yield a silently wrong RMS.
        raise ValueError(
            f"v_model has {model.size} points but v_obs has {obs.size}"
        )

    # The relative residual is undefined where the observation is exactly
    # zero; skip those points instead of dividing by zero.
    usable = obs != 0.0
    if not usable.any():
        return np.nan

    residuals = (model[usable] - obs[usable]) / obs[usable]
    return float(100.0 * np.sqrt(np.mean(residuals ** 2)))


def evaluate_galaxy(
    case_path: Path,
    geom_config: Dict,
    v2_config: Dict,
) -> Dict:
    """Score the geometric solver and RFT v2 on one galaxy case.

    The case is loaded with ``load_case``; each solver is then run in turn
    (geometric first, RFT v2 second) and its ``"v_pred_kms"`` output is
    compared against ``case.v_obs_kms`` via ``compute_rms_percent``. A method
    passes when its RMS percentage is at most 20.0.

    A failure inside one solver is reported on stdout and recorded as
    NaN / ``False`` for that method only. Any other exception raised while
    handling the case is reported as a load failure, and a record with both
    methods failed is returned, named after the stem of ``case_path``.

    Args:
        case_path: Path of the case file to load.
        geom_config: Configuration passed to ``rft_geom_predict``.
        v2_config: Configuration providing the ``"kernel"`` and ``"tail"``
            entries passed to ``apply_v2_gated_tail``.

    Returns:
        Dict with keys ``galaxy``, ``geom_rms``, ``geom_pass``, ``v2_rms``
        and ``v2_pass``.
    """
    try:
        case = load_case(case_path)

        def _score(label, run_solver):
            # Run one solver and grade it; anything that goes wrong is
            # reported and turned into a (NaN, False) outcome.
            try:
                predicted = run_solver()["v_pred_kms"]
                observed = np.asarray(case.v_obs_kms, dtype=float)
                rms = compute_rms_percent(predicted, observed)
                return rms, rms <= 20.0
            except Exception as exc:
                print(f"  {label} failed: {exc}")
                return np.nan, False

        # Geometric solver
        geom_rms, geom_ok = _score(
            "Geometric solver",
            lambda: rft_geom_predict(case, geom_config),
        )

        # RFT v2; the config lookups stay inside the lambda so that a
        # missing key is handled as a solver failure.
        v2_rms, v2_ok = _score(
            "RFT v2",
            lambda: apply_v2_gated_tail(
                case,
                kernel_config=v2_config["kernel"],
                tail_config=v2_config["tail"],
            ),
        )

        return {
            "galaxy": case.name,
            "geom_rms": geom_rms,
            "geom_pass": geom_ok,
            "v2_rms": v2_rms,
            "v2_pass": v2_ok,
        }

    except Exception as exc:
        print(f"  Failed to load case: {exc}")
        return {
            "galaxy": case_path.stem,
            "geom_rms": np.nan,
            "geom_pass": False,
            "v2_rms": np.nan,
            "v2_pass": False,
        }


def _exact_mcnemar_pvalue(n_first_only: int, n_second_only: int) -> float:
    """Return the two-sided exact binomial p-value for discordant pairs.

    Under H0 each discordant pair favours either method with probability
    0.5, so the binomial distribution is symmetric and the two-sided p-value
    is twice the lower tail at the smaller count, capped at 1.0 (the cap
    matters when the two counts are tied).
    """
    n_discordant = n_first_only + n_second_only
    if n_discordant == 0:
        return 1.0

    smaller = min(n_first_only, n_second_only)
    term = 0.5 ** n_discordant  # P(X = 0)
    lower_tail = 0.0
    for k in range(smaller + 1):
        lower_tail += term
        # P(X = k + 1) = P(X = k) * (n - k) / (k + 1) when p = 0.5
        term *= (n_discordant - k) / (k + 1)
    return min(1.0, 2.0 * lower_tail)


def main():
    """Run both solvers on the TEST manifest and report the comparison.

    Steps:
      1. Locate the geometric config (C9 fitted betas, falling back to C8),
         the frozen RFT v2 config and the TEST manifest under the repository
         root; exit with status 1 if any is missing or if the manifest lists
         no galaxies.
      2. Evaluate every galaxy with ``evaluate_galaxy``, printing per-galaxy
         results.
      3. Print pass counts, median RMS, the head-to-head breakdown and an
         exact McNemar test on the discordant pairs.
      4. Write everything to ``results/geom_vs_v2_comparison.json``.
    """
    # Paths
    repo_root = Path(__file__).parent.parent
    manifest_path = repo_root / "cases" / "SP99-TEST.manifest.txt"

    # Load configs
    # Use best geometric config (C9 with fitted betas)
    geom_config_path = repo_root / "config" / "global_c9_betas_fit.json"
    if not geom_config_path.exists():
        # Fallback to C8
        geom_config_path = repo_root / "config" / "c8_best.json"

    if not geom_config_path.exists():
        print(f"ERROR: Geometric config not found at {geom_config_path}")
        print("Available configs:")
        for cfg in (repo_root / "config").glob("c*.json"):
            print(f"  {cfg}")
        sys.exit(1)

    with open(geom_config_path, encoding="utf-8") as f:
        geom_config = json.load(f)

    # Load RFT v2 frozen config
    v2_config_path = repo_root / "config" / "global_rc_v2_frozen.json"
    if not v2_config_path.exists():
        print(f"ERROR: RFT v2 config not found at {v2_config_path}")
        sys.exit(1)

    with open(v2_config_path, encoding="utf-8") as f:
        v2_config = json.load(f)

    # Load TEST manifest
    if not manifest_path.exists():
        print(f"ERROR: TEST manifest not found at {manifest_path}")
        sys.exit(1)

    test_cases = load_manifest(manifest_path)
    if not test_cases:
        # Every rate below divides by the number of galaxies, so an empty
        # manifest must stop here rather than crash with ZeroDivisionError.
        print(f"ERROR: TEST manifest at {manifest_path} lists no galaxies")
        sys.exit(1)

    print(f"Loaded {len(test_cases)} TEST galaxies from {manifest_path}")
    print(f"Geometric config: {geom_config_path}")
    print(f"RFT v2 config: {v2_config_path}")
    print()

    # Evaluate all galaxies
    results = []
    for i, case_path in enumerate(test_cases, 1):
        print(f"[{i}/{len(test_cases)}] {case_path.stem}...")
        result = evaluate_galaxy(case_path, geom_config, v2_config)
        results.append(result)

        # Print result
        if not np.isnan(result["geom_rms"]) and not np.isnan(result["v2_rms"]):
            geom_status = "✓" if result["geom_pass"] else "✗"
            v2_status = "✓" if result["v2_pass"] else "✗"
            print(f"  Geom: {result['geom_rms']:.1f}% {geom_status}  |  V2: {result['v2_rms']:.1f}% {v2_status}")
        else:
            print("  Failed to compute metrics")
        print()

    # Aggregate statistics
    geom_passes = sum(1 for r in results if r["geom_pass"])
    v2_passes = sum(1 for r in results if r["v2_pass"])
    n_total = len(results)

    geom_pass_rate = 100.0 * geom_passes / n_total
    v2_pass_rate = 100.0 * v2_passes / n_total

    # Head-to-head comparison
    both_pass = sum(1 for r in results if r["geom_pass"] and r["v2_pass"])
    geom_only = sum(1 for r in results if r["geom_pass"] and not r["v2_pass"])
    v2_only = sum(1 for r in results if r["v2_pass"] and not r["geom_pass"])
    both_fail = sum(1 for r in results if not r["geom_pass"] and not r["v2_pass"])

    # Median RMS (ignore NaNs)
    geom_rms_values = [r["geom_rms"] for r in results if not np.isnan(r["geom_rms"])]
    v2_rms_values = [r["v2_rms"] for r in results if not np.isnan(r["v2_rms"])]

    geom_median_rms = np.median(geom_rms_values) if geom_rms_values else np.nan
    v2_median_rms = np.median(v2_rms_values) if v2_rms_values else np.nan

    # Print summary
    print("=" * 70)
    print("COMPARISON SUMMARY: Geometric Solver vs RFT v2")
    print("=" * 70)
    print()
    print(f"Test cohort: {n_total} blind TEST galaxies (SP99-TEST)")
    print()
    print("Pass@20% Performance:")
    print(f"  Geometric Solver:  {geom_passes}/{n_total} ({geom_pass_rate:.1f}%)")
    print(f"  RFT v2 (tail):     {v2_passes}/{n_total} ({v2_pass_rate:.1f}%)")
    print()
    print(f"Difference: {v2_pass_rate - geom_pass_rate:+.1f} percentage points (v2 - geom)")
    print()
    print("Median RMS%:")
    print(f"  Geometric Solver:  {geom_median_rms:.1f}%")
    print(f"  RFT v2:            {v2_median_rms:.1f}%")
    print()
    print("Head-to-Head Breakdown:")
    print(f"  Both pass:         {both_pass}/{n_total} ({100.0 * both_pass / n_total:.1f}%)")
    print(f"  Geometric only:    {geom_only}/{n_total} ({100.0 * geom_only / n_total:.1f}%)")
    print(f"  V2 only:           {v2_only}/{n_total} ({100.0 * v2_only / n_total:.1f}%)")
    print(f"  Both fail:         {both_fail}/{n_total} ({100.0 * both_fail / n_total:.1f}%)")
    print()

    # McNemar test (if applicable)
    n_discordant = geom_only + v2_only
    if n_discordant > 0:
        # Exact binomial test under H0: p(geom wins) = p(v2 wins) = 0.5
        try:
            from scipy.stats import binomtest
        except ImportError:
            # Older or missing scipy: use the pure-Python exact test, which
            # never reports a p-value above 1.
            mcnemar_p = _exact_mcnemar_pvalue(geom_only, v2_only)
        else:
            mcnemar_p = binomtest(
                v2_only, n=n_discordant, p=0.5, alternative='two-sided'
            ).pvalue
        print(f"McNemar's exact test: p = {mcnemar_p:.4f}")
        if mcnemar_p < 0.05:
            winner = "V2" if v2_only > geom_only else "Geometric"
            print(f"  → Significant difference (p<0.05): {winner} wins more")
        else:
            print("  → No significant difference (p≥0.05): Competitive")
    else:
        print("McNemar test: N/A (no discordant pairs)")

    print()
    print("=" * 70)

    # Save results to JSON
    output_path = repo_root / "results" / "geom_vs_v2_comparison.json"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # NOTE: failed evaluations carry NaN, which json.dump writes as the
    # non-standard token ``NaN``; strict JSON parsers will reject the file.
    output = {
        "metadata": {
            "test_cohort": str(manifest_path),
            "n_galaxies": n_total,
            "geom_config": str(geom_config_path),
            "v2_config": str(v2_config_path),
        },
        "aggregate": {
            "geom_pass_rate": geom_pass_rate,
            "v2_pass_rate": v2_pass_rate,
            "geom_median_rms": float(geom_median_rms),
            "v2_median_rms": float(v2_median_rms),
        },
        "head_to_head": {
            "both_pass": both_pass,
            "geom_only": geom_only,
            "v2_only": v2_only,
            "both_fail": both_fail,
        },
        "per_galaxy": results,
    }

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    print(f"Results saved to: {output_path}")


# Run the comparison only when this file is executed as a script, so that
# importing the module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()
