#!/usr/bin/env python3
"""
RFT v2.1 Grid Search: Shape Refinement Around Frozen Best

Pre-registered 81-config grid search to refine v2 frozen config.
Tests nearby shape variations (α, p, A₀, r_turn) to see if generalization improves.

Grid (pre-registered):
    p ∈ {1.5, 2.0, 2.5}        (onset smoothness)
    α ∈ {0.5, 0.6, 0.7}        (radial decay)
    A₀ ∈ {900, 1000, 1100}     (±10% amplitude)
    r_turn ∈ {1.7, 2.0, 2.3}   (±15% onset radius)

Fixed:
    g* = 1000 km²/s²/kpc
    γ = 0.5

Total: 81 configs

Selection:
    1. Run on TRAIN (n=65)
    2. Select best by BIC
    3. Evaluate on TEST (n=34) - blind

Author: RFT Cosmology Project
Date: 2025-11-10
Pre-reg: RFT_V2.1_PREREG.md
"""

import json
import sys
from pathlib import Path
import numpy as np
from datetime import datetime

# Put the directory one level above this script's folder at the front of
# sys.path so the project-local imports that follow can be resolved when the
# file is executed directly. Presumably that directory is the repository root
# holding sparc_rft/, solver/ and metrics/ — confirm against the repo layout.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from sparc_rft.case import load_case
from solver.rft_v2_gated_tail import apply_v2_gated_tail
from metrics.rc_metrics import compute_metrics


def generate_v21_configs():
    """
    Build the pre-registered v2.1 shape-refinement grid.

    The grid is the Cartesian product of three values each for p, α, A₀ and
    r_turn, enumerated with p varying slowest and r_turn varying fastest.
    g* and γ carry the same fixed value in every config.

    Returns:
        List of 81 tail config dicts with ids "v2.1_000" .. "v2.1_080"
    """
    # Held constant across the whole grid
    frozen_gstar = 1000
    frozen_gamma = 0.5

    # Pre-registered axes, listed from slowest- to fastest-varying
    p_axis = (1.5, 2.0, 2.5)
    alpha_axis = (0.5, 0.6, 0.7)
    amplitude_axis = (900, 1000, 1100)
    rturn_axis = (1.7, 2.0, 2.3)

    # Flatten the four nested axes into one ordered list of grid points
    grid_points = [
        (p_val, alpha_val, amp_val, rturn_val)
        for p_val in p_axis
        for alpha_val in alpha_axis
        for amp_val in amplitude_axis
        for rturn_val in rturn_axis
    ]

    # The running index doubles as the zero-padded config id
    return [
        {
            "id": f"v2.1_{index:03d}",
            "A0_kms2_per_kpc": amp_val,
            "alpha": alpha_val,
            "g_star_kms2_per_kpc": frozen_gstar,
            "gamma": frozen_gamma,
            "r_turn_kpc": rturn_val,
            "p": p_val,
            "description": f"p={p_val}, α={alpha_val}, A0={amp_val}, rturn={rturn_val}",
        }
        for index, (p_val, alpha_val, amp_val, rturn_val) in enumerate(grid_points)
    ]


def run_galaxy_v21(case_path, kernel_config, tail_config, min_radius=1.0, max_radius=30.0):
    """
    Run RFT v2.1 on a single galaxy and score it inside a radius window.

    Args:
        case_path: Path handed to ``load_case``.
        kernel_config: Kernel config forwarded to ``apply_v2_gated_tail``.
        tail_config: Tail config dict forwarded to ``apply_v2_gated_tail``.
        min_radius: Inclusive lower bound of the radius window.
        max_radius: Inclusive upper bound of the radius window.

    Returns:
        Dict with keys ``name``, ``n_points``, ``rms_percent``, ``pass_20``,
        ``pass_10``, ``bic`` and ``rss``; or ``None`` when fewer than 3 usable
        points fall in the window or when any step raises (the error is
        reported on stderr so one bad galaxy does not abort the cohort).
    """
    try:
        case = load_case(case_path)

        # Apply v2 solver (same physics, different params)
        result = apply_v2_gated_tail(case, kernel_config, tail_config)

        r_kpc = np.asarray(result["r_kpc"])
        v_pred = np.asarray(result["v_pred_kms"])
        v_obs = np.asarray(case.v_obs_kms)
        sigma_v = np.asarray(case.sigma_v_kms)

        # Radius window plus positive uncertainty. Non-finite velocities are
        # excluded as well: a single NaN would otherwise turn RSS and BIC into
        # NaN and poison the cohort-level BIC sum used for model selection.
        mask = (
            (r_kpc >= min_radius)
            & (r_kpc <= max_radius)
            & (sigma_v > 0)
            & np.isfinite(v_obs)
            & np.isfinite(v_pred)
        )
        if mask.sum() < 3:
            return None

        v_pred_fit = v_pred[mask]
        v_obs_fit = v_obs[mask]

        # Compute metrics
        metrics = compute_metrics(v_obs_fit, v_pred_fit)

        # BIC = n·ln(RSS/n) + k·ln(n), k=6 (α,p,A0,rturn,g*,γ)
        n = len(v_obs_fit)
        k = 6
        rss = np.sum((v_obs_fit - v_pred_fit) ** 2)
        # Floor RSS at the smallest positive normal float so an exact fit gives
        # a very negative but finite BIC instead of -inf, which json.dump would
        # write as the non-standard token -Infinity.
        rss_for_log = max(rss, np.finfo(float).tiny)
        bic = n * np.log(rss_for_log / n) + k * np.log(n)

        return {
            "name": case.name,
            "n_points": int(n),
            "rms_percent": float(metrics["rms_percent"]),
            "pass_20": bool(metrics["rms_percent"] <= 20.0),
            "pass_10": bool(metrics["rms_percent"] <= 10.0),
            "bic": float(bic),
            "rss": float(rss)
        }

    except Exception as e:
        # Deliberate best-effort boundary: report and skip this galaxy.
        print(f"ERROR {case_path}: {e}", file=sys.stderr)
        return None


def evaluate_config_on_cohort(tail_config, kernel_config, manifest_path, min_radius, max_radius,
                              cases_dir="cases"):
    """
    Evaluate one tail config on an entire cohort.

    The manifest is read line by line. Blank lines and lines whose first
    non-whitespace character is ``#`` are ignored; every other line is taken
    as a case file name relative to ``cases_dir``.

    Args:
        tail_config: Tail config dict; its ``"id"`` is copied to the output.
        kernel_config: Kernel config forwarded to ``run_galaxy_v21``.
        manifest_path: Path of the manifest text file.
        min_radius: Lower radius bound forwarded to ``run_galaxy_v21``.
        max_radius: Upper radius bound forwarded to ``run_galaxy_v21``.
        cases_dir: Prefix joined with ``/`` in front of every manifest entry.
            Defaults to ``"cases"``.

    Returns:
        Dict with the config id, the tail config, galaxy count, pass@20% and
        pass@10% counts and rates (percent), median RMS, median and summed
        BIC, and the per-galaxy result dicts; or ``None`` when no galaxy
        produced a result.
    """
    with open(manifest_path, encoding="utf-8") as f:
        # Strip before testing for '#', so indented comment lines are skipped
        # rather than treated as case file names.
        entries = (line.strip() for line in f)
        case_paths = [
            f"{cases_dir}/{entry}"
            for entry in entries
            if entry and not entry.startswith("#")
        ]

    results = []
    for case_path in case_paths:
        result = run_galaxy_v21(case_path, kernel_config, tail_config, min_radius, max_radius)
        if result:
            results.append(result)

    if not results:
        return None

    # Aggregate metrics over the galaxies that produced a result
    n_galaxies = len(results)
    pass_20_count = sum(r["pass_20"] for r in results)
    pass_10_count = sum(r["pass_10"] for r in results)
    bic_values = [r["bic"] for r in results]

    return {
        "config_id": tail_config["id"],
        "tail_config": tail_config,
        "n_galaxies": n_galaxies,
        "pass_20_count": int(pass_20_count),
        "pass_20_rate": float(100.0 * pass_20_count / n_galaxies),
        "pass_10_count": int(pass_10_count),
        "pass_10_rate": float(100.0 * pass_10_count / n_galaxies),
        "rms_median": float(np.median([r["rms_percent"] for r in results])),
        "bic_median": float(np.median(bic_values)),
        "bic_sum": float(np.sum(bic_values)),
        "per_galaxy": results
    }


def main():
    """
    Command-line entry point for the v2.1 shape-refinement grid search.

    Steps:
        1. Evaluate every grid config on the TRAIN manifest and save all
           results to ``v2.1_train_results.json``.
        2. Select the config with the lowest summed BIC and save it to
           ``v2.1_best_config.json``.
        3. Evaluate that single config on the TEST manifest, assign a gate
           (GREEN >= 30%, YELLOW >= 20%, else RED on pass@20%) and save to
           ``v2.1_test_results.json``.
        4. Print a comparison against the frozen v2 TEST result (20/34).

    Exits with a non-zero status and a message on stderr when no config can
    be ranked on TRAIN or when the TEST evaluation produces no results.
    """
    import argparse
    parser = argparse.ArgumentParser(description="RFT v2.1 shape refinement grid search")
    parser.add_argument("--kernel-config", default="config/global_v2_identity.json", help="Kernel config path")
    parser.add_argument("--train-manifest", default="cases/SP99-TRAIN.manifest.txt", help="TRAIN cohort")
    parser.add_argument("--test-manifest", default="cases/SP99-TEST.manifest.txt", help="TEST cohort")
    parser.add_argument("--min-radius", type=float, default=1.0, help="Min radius (kpc)")
    parser.add_argument("--max-radius", type=float, default=30.0, help="Max radius (kpc)")
    parser.add_argument("--results-dir", default="results/v2.1_refine", help="Output directory")
    args = parser.parse_args()

    results_dir = Path(args.results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Load kernel config (identity)
    with open(args.kernel_config, encoding="utf-8") as f:
        kernel_config = json.load(f)

    rule = "=" * 72

    print(rule)
    print("RFT v2.1 SHAPE REFINEMENT GRID SEARCH")
    print(rule)
    print(f"Kernel: {args.kernel_config}")
    print(f"TRAIN: {args.train_manifest}")
    print(f"TEST:  {args.test_manifest}")
    print(f"Window: {args.min_radius}-{args.max_radius} kpc")
    print("")

    # Generate v2.1 configs
    tail_configs = generate_v21_configs()
    print(f"Generated {len(tail_configs)} v2.1 configurations")
    print("")

    # Phase 1: TRAIN evaluation
    print(rule)
    print("PHASE 1: TRAIN EVALUATION (n=65)")
    print(rule)

    train_results = []
    for i, tail_config in enumerate(tail_configs, 1):
        print(f"\n[{i}/{len(tail_configs)}] Config {tail_config['id']}: {tail_config['description']}")
        result = evaluate_config_on_cohort(
            tail_config, kernel_config, args.train_manifest,
            args.min_radius, args.max_radius
        )
        if result:
            train_results.append(result)
            print(f"  TRAIN: {result['pass_20_rate']:.1f}% pass@20%, BIC_sum={result['bic_sum']:.1f}")

    # Save all TRAIN results (written even if selection fails below, so the
    # raw evaluation is never lost)
    train_output = {
        "phase": "TRAIN",
        "n_configs": len(train_results),
        "timestamp": datetime.now().isoformat(),
        "configs": train_results
    }
    train_path = results_dir / "v2.1_train_results.json"
    with open(train_path, "w", encoding="utf-8") as f:
        json.dump(train_output, f, indent=2)
    print(f"\nSaved TRAIN results: {train_path}")

    # Select best by BIC (lower is better). NaN sums are excluded because
    # every comparison with NaN is False, which would make min() depend on
    # list order instead of on the BIC values.
    rankable = [r for r in train_results if not np.isnan(r["bic_sum"])]
    if not rankable:
        sys.exit("ERROR: no configuration produced a usable TRAIN BIC; cannot select a best config")
    best_config = min(rankable, key=lambda x: x["bic_sum"])
    print("")
    print(rule)
    print("BEST CONFIG (by BIC on TRAIN)")
    print(rule)
    print(f"Config ID: {best_config['config_id']}")
    print(f"  {best_config['tail_config']['description']}")
    print(f"  TRAIN: {best_config['pass_20_rate']:.1f}% pass@20%, BIC_sum={best_config['bic_sum']:.1f}")
    print("")

    # Save best config
    best_config_path = results_dir / "v2.1_best_config.json"
    with open(best_config_path, "w", encoding="utf-8") as f:
        json.dump({
            "kernel": kernel_config,
            "tail": best_config["tail_config"],
            "train_performance": {
                "pass_20_rate": best_config["pass_20_rate"],
                "pass_10_rate": best_config["pass_10_rate"],
                "rms_median": best_config["rms_median"],
                "bic_sum": best_config["bic_sum"]
            }
        }, f, indent=2)
    print(f"Saved best config: {best_config_path}")

    # Phase 2: TEST evaluation (blind)
    print("")
    print(rule)
    print("PHASE 2: TEST EVALUATION (n=34, BLIND)")
    print(rule)
    print("Applying best config to TEST cohort...")

    test_result = evaluate_config_on_cohort(
        best_config["tail_config"], kernel_config, args.test_manifest,
        args.min_radius, args.max_radius
    )
    # evaluate_config_on_cohort returns None when no galaxy could be scored
    if test_result is None:
        sys.exit("ERROR: TEST evaluation produced no results for the selected config")

    n_test = test_result["n_galaxies"]

    print("")
    print(f"Config: {best_config['config_id']}")
    print(f"  TEST: {test_result['pass_20_rate']:.1f}% pass@20% ({test_result['pass_20_count']}/{n_test})")
    print(f"  TEST: {test_result['pass_10_rate']:.1f}% pass@10% ({test_result['pass_10_count']}/{n_test})")
    print(f"  RMS median: {test_result['rms_median']:.1f}%")
    print("")

    # Gate check
    test_pass_rate = test_result["pass_20_rate"]
    if test_pass_rate >= 30.0:
        gate = "GREEN"
    elif test_pass_rate >= 20.0:
        gate = "YELLOW"
    else:
        gate = "RED"

    print(f"GATE: {gate} ({test_pass_rate:.1f}% pass@20%)")

    # Save TEST results
    test_output = {
        "phase": "TEST",
        "config_id": best_config["config_id"],
        "tail_config": best_config["tail_config"],
        "gate": gate,
        "timestamp": datetime.now().isoformat(),
        **test_result
    }
    test_path = results_dir / "v2.1_test_results.json"
    with open(test_path, "w", encoding="utf-8") as f:
        json.dump(test_output, f, indent=2)
    print(f"\nSaved TEST results: {test_path}")

    # Comparison to frozen v2. The baseline is kept as the exact ratio 20/34:
    # comparing against the rounded literal 58.8 would report an identical
    # 20/34 result as a win (58.82... - 58.8 > 0) and never as a tie.
    frozen_pass_count = 20
    frozen_n = 34
    frozen_rate = 100.0 * frozen_pass_count / frozen_n

    print("")
    print(rule)
    print("COMPARISON TO FROZEN v2")
    print(rule)
    print(f"Frozen v2: {frozen_rate:.1f}% pass@20% on TEST ({frozen_pass_count}/{frozen_n})")
    # Report the denominator actually evaluated; galaxies can be skipped
    print(f"v2.1:      {test_pass_rate:.1f}% pass@20% on TEST ({test_result['pass_20_count']}/{n_test})")
    delta = test_pass_rate - frozen_rate
    print(f"Delta:     {delta:+.1f} percentage points")

    if abs(delta) < 1e-9:
        print("\n≈ TIE: Shape refinement had no effect")
    elif delta > 0:
        print("\n✓ v2.1 WINS: Shape refinement improved generalization")
    else:
        print("\n✗ v2 WINS: Frozen config was optimal")

    print("")
    print(rule)
    print("v2.1 GRID SEARCH COMPLETE")
    print(rule)


# Run the grid search only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
