#!/usr/bin/env python3
"""Run all ablation configs and compare to frozen baseline."""
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd()))

import json
import numpy as np
from sparc_rft.case import load_case
from solver.rft_v2_gated_tail import apply_v2_gated_tail
from metrics.rc_metrics import compute_metrics

def run_config(config_path, manifest_path, cohort_name, cases_dir="cases"):
    """Run a single config over every case listed in a cohort manifest.

    Args:
        config_path: JSON config file containing ``"kernel"`` and ``"tail"``
            sections, passed through to ``apply_v2_gated_tail``.
        manifest_path: Text file with one case entry per line. Blank lines
            and lines whose first non-whitespace character is ``#`` are ignored.
        cohort_name: Label copied into the returned summary.
        cases_dir: Directory prefix joined onto each manifest entry
            (defaults to ``"cases"``).

    Returns:
        A dict with ``config``, ``cohort``, ``n_cases`` (number of cases that
        were actually scored), ``pass_20_count``, ``pass_20_rate`` (percent of
        scored cases with ``rms_percent <= 20``), ``rms_median`` and the
        per-case ``results`` list. Rate and median are 0.0 when nothing was
        scored.

    A case that raises is reported on stderr and skipped. A case with fewer
    than 3 points in the 1-30 kpc window (with sigma > 0) is skipped without a
    message, so ``n_cases`` can be smaller than the manifest length.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open(config_path) as f:
        cfg = json.load(f)
    kernel_config = cfg["kernel"]
    tail_config = cfg["tail"]

    with open(manifest_path) as f:
        # Strip before testing for "#" so indented comment lines are skipped too.
        entries = [line.strip() for line in f]
    case_paths = [
        f"{cases_dir}/{entry}"
        for entry in entries
        if entry and not entry.startswith("#")
    ]

    results = []
    for case_path in case_paths:
        try:
            case = load_case(case_path)
            result = apply_v2_gated_tail(case, kernel_config, tail_config)
            r = np.array(result["r_kpc"])
            v_pred = np.array(result["v_pred_kms"])
            v_obs = np.array(case.v_obs_kms)
            sigma = np.array(case.sigma_v_kms)

            # Score only points inside the 1-30 kpc window that have a usable
            # (positive) uncertainty; require at least 3 such points.
            mask = (r >= 1.0) & (r <= 30.0) & (sigma > 0)
            if mask.sum() >= 3:
                metrics = compute_metrics(v_obs_kms=v_obs[mask], v_pred_kms=v_pred[mask])
                results.append({
                    "name": case.name,
                    "rms_percent": float(metrics["rms_percent"]),
                    "pass_20": bool(metrics["rms_percent"] <= 20.0)
                })
        except Exception as e:
            # Best-effort batch run: report the failing case and keep going.
            print(f"  ERROR {case_path}: {e}", file=sys.stderr)

    pass_20_count = sum(row["pass_20"] for row in results)
    pass_20_rate = 100.0 * pass_20_count / len(results) if results else 0.0
    rms_median = float(np.median([row["rms_percent"] for row in results])) if results else 0.0

    return {
        "config": str(config_path),
        "cohort": cohort_name,
        "n_cases": len(results),
        "pass_20_count": int(pass_20_count),
        "pass_20_rate": float(pass_20_rate),
        "rms_median": float(rms_median),
        "results": results
    }

# Label -> config file for every run. "frozen_baseline" is the reference
# configuration; each remaining entry is an ablation compared against it.
configs = dict(
    frozen_baseline="config/global_rc_v2_frozen.json",
    tail_off="config/ablations/rc_v2_tail_off.json",
    no_g_gate="config/ablations/rc_v2_no_g_gate.json",
    no_r_gate="config/ablations/rc_v2_no_r_gate.json",
    alpha1_const_v="config/ablations/rc_v2_alpha1_const_v.json",
    gate_soft="config/ablations/rc_v2_gate_soft.json",
)

print("=" * 72)
print("RFT V2 ABLATION ANALYSIS")
print("=" * 72)
# Derive the ablation count from `configs` so the banner cannot drift from it.
n_ablations = sum(1 for name in configs if name != "frozen_baseline")
print(f"Testing {n_ablations} ablations + frozen baseline on TEST cohort")
print("Window: 1.0-30.0 kpc, pass@20% threshold")
print("")

# Run the TEST cohort for every config. A config that fails as a whole is
# reported on stderr (matching run_config's per-case errors) and skipped,
# so the remaining configs still run.
test_results = {}
for name, config_path in configs.items():
    print(f"Running {name}...")
    try:
        result = run_config(config_path, "cases/SP99-TEST.manifest.txt", "TEST")
    except Exception as e:
        print(f"  ERROR {name} ({config_path}): {e}", file=sys.stderr)
        continue
    test_results[name] = result
    print(f"  {name:20s}: {result['pass_20_rate']:.1f}% pass@20%, RMS {result['rms_median']:.1f}%")

# Save whatever ran, before the comparison step (which needs the baseline).
out_dir = Path("reports/ablations")
out_dir.mkdir(parents=True, exist_ok=True)
with open(out_dir / "ablation_results.json", "w") as f:
    json.dump(test_results, f, indent=2)

print("")
print("=" * 72)
print("ABLATION COMPARISON (vs frozen baseline)")
print("=" * 72)

# Deltas are meaningless without the baseline. Exit with a clear message
# (sys.exit(str) prints to stderr, status 1) instead of a bare KeyError.
baseline = test_results.get("frozen_baseline")
if baseline is None:
    sys.exit(
        "ERROR: frozen_baseline produced no result; cannot compare ablations. "
        "Partial results: reports/ablations/ablation_results.json"
    )
print(f"Baseline: {baseline['pass_20_rate']:.1f}% pass@20%, RMS {baseline['rms_median']:.1f}%")
print("")

# Every non-baseline entry of `configs` is an ablation; keep its order.
for name in configs:
    if name == "frozen_baseline":
        continue
    ablation = test_results.get(name)
    if ablation is None:
        # The failure was already reported while running; keep it visible here.
        print(f"{name:20s}: no result (run failed)")
        continue
    delta_pass = ablation["pass_20_rate"] - baseline["pass_20_rate"]
    delta_rms = ablation["rms_median"] - baseline["rms_median"]

    print(f"{name:20s}: Δpass={delta_pass:+6.1f}pp, ΔRMS={delta_rms:+6.1f}%")

print("")
print("=" * 72)
print("Saved: reports/ablations/ablation_results.json")
