#!/usr/bin/env python3
"""
RFT v2 Grid Search: Acceleration-Gated Tail

Pre-registered grid search to find best tail parameters.
Uses existing kernel (global_c13_kernel.json or global_c19_kernel.json).

Grid (values listed as quick mode / full mode where they differ):
    A0:     [200, 260] / [200, 260, 320] km²/s²/kpc            (2 / 3 values)
    alpha:  [0.6, 0.8]                                         (2 values)
    g_star: [1500, 3700] / [900, 1500, 2500, 3700] km²/s²/kpc  (2 / 4 values)
    gamma:  [1.0, 1.5]                                         (2 values)
    r_turn: [1.5, 2.0] kpc                                     (2 values)

Fixed:
    p = 2.0

Total: 32 configs (quick mode) or 96 configs (full mode)

Workflow:
    1. Generate the tail configs (32 in quick mode, 96 in full mode)
    2. Run on SPARC TRAIN (99 galaxies) for each config
    3. Select best by TRAIN pass rate (highest wins)
    4. Run best on SPARC TEST (21 galaxies)
    5. Gate check: GREEN (≥35%), YELLOW (20-35%), RED (<20%)

Author: RFT Cosmology Project
Date: 2025-11-10
"""

import json
import sys
from pathlib import Path
import numpy as np

# Put the directory one level above this script's folder at the front of
# sys.path so the project-local packages imported below resolve when the
# script is run directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from sparc_rft.case import load_case
from solver.rft_v2_gated_tail import apply_v2_gated_tail, validate_v2_config
from metrics.rc_metrics import compute_metrics


def generate_tail_configs(mode="quick"):
    """
    Build the pre-registered grid of tail configurations.

    Every combination of the per-parameter value lists is emitted, with A0
    varying slowest and r_turn varying fastest. Ids are assigned sequentially
    in that order ("v2_00", "v2_01", ...).

    Args:
        mode: "quick" selects the reduced grid (2 values for each of the five
            swept parameters, 2**5 = 32 configs). Any other value selects the
            full grid (3 x 2 x 4 x 2 x 2 = 96 configs).

    Returns:
        List of tail config dicts with keys "id", "A0_kms2_per_kpc", "alpha",
        "g_star_kms2_per_kpc", "gamma", "r_turn_kpc", "p" and "description".
    """
    # Rational gate formula noted by the original author (cleaner than tanh):
    # g_tail = A0 * (r_geo/r)^α * [1 + (gb/g*)^γ]^(-1) * [1 - exp(-(r/r_turn)^p)]

    # Only A0 and g* differ between the two modes.
    if mode == "quick":
        a0_grid = [200, 260]  # km²/s²/kpc
        gstar_grid = [1500, 3700]  # km²/s²/kpc (MOND-like 3700 and lower)
    else:
        a0_grid = [200, 260, 320]
        gstar_grid = [900, 1500, 2500, 3700]
    alpha_grid = [0.6, 0.8]
    gamma_grid = [1.0, 1.5]
    rturn_grid = [1.5, 2.0]  # kpc
    p_fixed = 2.0  # not swept

    # Lazily enumerate the Cartesian product; the clause order fixes which
    # parameter varies fastest and therefore which id each combination gets.
    combos = (
        (a0, alpha, gstar, gamma, rturn)
        for a0 in a0_grid
        for alpha in alpha_grid
        for gstar in gstar_grid
        for gamma in gamma_grid
        for rturn in rturn_grid
    )

    return [
        {
            "id": f"v2_{index:02d}",
            "A0_kms2_per_kpc": a0,
            "alpha": alpha,
            "g_star_kms2_per_kpc": gstar,
            "gamma": gamma,
            "r_turn_kpc": rturn,
            "p": p_fixed,
            "description": f"A0={a0}, α={alpha}, g*={gstar}, γ={gamma}, rturn={rturn}",
        }
        for index, (a0, alpha, gstar, gamma, rturn) in enumerate(combos)
    ]


def run_galaxy_v2(case_path, kernel_config, tail_config, min_radius=1.0, max_radius=30.0):
    """
    Run RFT v2 on a single galaxy and score it inside a radius window.

    Args:
        case_path: Path of the case file handed to load_case.
        kernel_config: Kernel config passed through to apply_v2_gated_tail.
        tail_config: Tail config passed through to apply_v2_gated_tail.
        min_radius: Lower edge (inclusive) of the radius window [kpc].
        max_radius: Upper edge (inclusive) of the radius window [kpc].

    Returns:
        Dict with "name", "pass", "rms_percent" and "n_bins", or None when
        fewer than 3 points survive the window or when any step raises
        (the error is printed, not propagated).
    """
    # Deliberately best-effort: one bad galaxy must not abort a whole cohort
    # run, so every failure is reported and turned into None.
    try:
        case = load_case(case_path)

        # Apply v2 solver
        result = apply_v2_gated_tail(case, kernel_config, tail_config)

        r_kpc = np.asarray(result["r_kpc"])
        v_pred = np.asarray(result["v_pred_kms"])
        v_obs = np.asarray(case.v_obs_kms)
        sigma_v = np.asarray(case.sigma_v_kms)

        # Keep points inside the radius window that have a positive uncertainty.
        mask = (r_kpc >= min_radius) & (r_kpc <= max_radius) & (sigma_v > 0)
        if mask.sum() < 3:
            return None

        v_pred_fit = v_pred[mask]
        v_obs_fit = v_obs[mask]

        metrics = compute_metrics(
            v_obs_kms=v_obs_fit,
            v_pred_kms=v_pred_fit
        )

        # Coerce to builtin types: the result is later written with json.dump,
        # which rejects numpy.bool_ / numpy.float32 scalars.
        # NOTE(review): whether compute_metrics returns numpy scalars is not
        # visible from this file; the coercion is harmless either way.
        return {
            "name": case.name,
            "pass": bool(metrics.get("pass", False)),
            "rms_percent": float(metrics.get("rms_percent", 100.0)),
            "n_bins": len(v_pred_fit)
        }

    except Exception as e:
        print(f"  ERROR on {case_path}: {e}")
        return None


def run_cohort_v2(manifest_path, kernel_config, tail_config, output_path, min_radius=1.0, max_radius=30.0):
    """
    Run RFT v2 on every galaxy in a manifest and write a JSON summary.

    The manifest is a text file with one case path per line. Blank lines and
    lines whose first non-blank character is "#" are ignored. Paths that do
    not already start with "cases/" get that prefix.

    Args:
        manifest_path: Path of the manifest file.
        kernel_config: Kernel config forwarded to run_galaxy_v2.
        tail_config: Tail config dict; its "id" and "description" are used
            for logging and the whole dict is embedded in the summary.
        output_path: Destination of the JSON summary (str or Path). Missing
            parent directories are created.
        min_radius: Lower edge of the radius window [kpc].
        max_radius: Upper edge of the radius window [kpc].

    Returns:
        Summary dict with "config_id", "config", "n_galaxies", "n_success",
        "n_fail", "pass_rate" (percent of scored galaxies passing),
        "rms_percent_median" and the per-galaxy "results" list. With no
        scored galaxies, pass_rate is 0.0 and rms_percent_median is 100.0.
    """
    # Accept plain strings as well as Path objects.
    output_path = Path(output_path)

    with open(manifest_path, encoding="utf-8") as f:
        entries = [line.strip() for line in f]

    # Filter on the stripped text so indented comment lines are skipped too,
    # then prepend 'cases/' where the path doesn't already have it.
    case_paths = [
        entry if entry.startswith("cases/") else f"cases/{entry}"
        for entry in entries
        if entry and not entry.startswith("#")
    ]

    print(f"Running {tail_config['id']}: {tail_config['description']}")
    print(f"  {len(case_paths)} galaxies from {manifest_path}")

    results = []
    n_fail = 0

    for i, case_path in enumerate(case_paths, 1):
        if i % 10 == 0:
            print(f"  Progress: {i}/{len(case_paths)} galaxies...")

        result = run_galaxy_v2(case_path, kernel_config, tail_config, min_radius, max_radius)
        if result:
            results.append(result)
        else:
            # Covers both load/solver errors and galaxies with too few points.
            n_fail += 1

    n_success = len(results)

    # Aggregate stats; fall back to worst-case placeholders when nothing ran.
    if results:
        passes = [r["pass"] for r in results]
        pass_rate = 100.0 * sum(passes) / len(passes)
        rms_median = float(np.median([r["rms_percent"] for r in results]))
    else:
        pass_rate = 0.0
        rms_median = 100.0

    summary = {
        "config_id": tail_config["id"],
        "config": tail_config,
        "n_galaxies": n_success,
        "n_success": n_success,
        "n_fail": n_fail,
        "pass_rate": float(pass_rate),
        "rms_percent_median": rms_median,
        "results": results
    }

    # Write results
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, indent=2)

    print(f"  ✓ {n_success} success, {n_fail} fail")
    print(f"  Pass%: {summary['pass_rate']:.1f}%, RMS%: {summary['rms_percent_median']:.1f}%")
    print(f"  Wrote: {output_path}")
    print()

    return summary


def select_best_config(results_dir):
    """
    Pick the tail config with the highest TRAIN pass rate.

    Reads every "*_train.json" summary in results_dir, prints one line per
    config plus a banner for the winner, and returns the config whose
    "pass_rate" is largest. Ties go to the file that sorts first by path.

    Args:
        results_dir: Directory holding the per-config TRAIN summaries.

    Returns:
        The "config" dict stored in the winning summary file.

    Raises:
        ValueError: If results_dir contains no "*_train.json" files.
    """
    results_dir = Path(results_dir)
    train_files = sorted(results_dir.glob("*_train.json"))
    if not train_files:
        raise ValueError(f"No TRAIN results found in {results_dir}")

    rule = "=" * 70
    print(rule)
    print("TRAIN Results Summary:")
    print(rule)

    candidates = []
    for path in train_files:
        with open(path) as fh:
            payload = json.load(fh)

        # Missing aggregate fields fall back to worst-case placeholders;
        # "config_id" and "config" are mandatory.
        entry = {
            "config_id": payload["config_id"],
            "pass_rate": payload.get("pass_rate", 0),
            "rms_percent_median": payload.get("rms_percent_median", 100),
            "n_galaxies": payload.get("n_galaxies", 0),
            "config": payload["config"]
        }
        candidates.append(entry)

        print(f"{entry['config_id']}: pass={entry['pass_rate']:.1f}%, rms={entry['rms_percent_median']:.1f}%, n={entry['n_galaxies']}")

    # Higher pass rate is better. max() keeps the first of several equal
    # maxima, i.e. the earliest file in sorted order.
    best = max(candidates, key=lambda entry: entry["pass_rate"])

    print()
    print(rule)
    print(f"BEST CONFIG: {best['config_id']}")
    print(f"  {best['config']['description']}")
    print(f"  Pass% median: {best['pass_rate']:.1f}%")
    print(f"  RMS% median: {best['rms_percent_median']:.1f}%")
    print(f"  n_galaxies: {best['n_galaxies']}")
    print(rule)
    print()

    return best["config"]


def main():
    """
    Command-line entry point for the RFT v2 tail grid search.

    Phases:
        1. Run every generated tail config on the TRAIN manifest (skipped
           with --skip-train, in which case existing "*_train.json" files in
           the results directory are reused).
        2. Select the best config from the TRAIN summaries and save it as
           "best_config.json".
        3. Run the best config on the TEST manifest.
        4. Print the TEST figures and the gate criteria (the gate decision
           itself is not computed here).
    """
    import argparse
    parser = argparse.ArgumentParser(description="RFT v2 Grid Search")
    parser.add_argument("--kernel-config", required=True, help="Kernel config JSON (e.g., config/global_c13_kernel.json)")
    parser.add_argument("--train-manifest", default="cases/SP99-TRAIN.manifest.txt", help="TRAIN manifest")
    parser.add_argument("--test-manifest", default="cases/SP99-TEST.manifest.txt", help="TEST manifest")
    parser.add_argument("--results-dir", default="results/v2", help="Output directory")
    parser.add_argument("--min-radius", type=float, default=1.0, help="Minimum radius [kpc]")
    parser.add_argument("--max-radius", type=float, default=30.0, help="Maximum radius [kpc]")
    parser.add_argument("--skip-train", action="store_true", help="Skip TRAIN, jump to TEST with existing results")
    # The quick grid sweeps five parameters with two values each: 2**5 = 32.
    parser.add_argument("--mode", choices=["quick", "full"], default="quick", help="Grid mode: quick (32 configs) or full (96 configs)")

    args = parser.parse_args()

    # Load kernel config
    with open(args.kernel_config) as f:
        kernel_config = json.load(f)

    print("=" * 70)
    print("RFT v2: Acceleration-Gated Tail Grid Search")
    print("=" * 70)
    print(f"Mode: {args.mode.upper()}")
    print(f"Kernel config: {args.kernel_config}")
    print(f"TRAIN manifest: {args.train_manifest}")
    print(f"TEST manifest: {args.test_manifest}")
    print(f"Results dir: {args.results_dir}")
    print(f"Radius window: [{args.min_radius}, {args.max_radius}] kpc")
    print("=" * 70)
    print()

    # Generate configs
    tail_configs = generate_tail_configs(mode=args.mode)
    print(f"Generated {len(tail_configs)} tail configurations ({args.mode} mode)")
    print()

    results_dir = Path(args.results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Save config grid
    with open(results_dir / "config_grid.json", "w") as f:
        json.dump(tail_configs, f, indent=2)
    print(f"Saved config grid: {results_dir / 'config_grid.json'}")
    print()

    if not args.skip_train:
        # Run TRAIN for all configs
        print("=" * 70)
        # Report the real grid size; the galaxy count is printed per config
        # by run_cohort_v2 because it depends on the manifest.
        print(f"PHASE 1: TRAIN Grid Search ({len(tail_configs)} configs)")
        print("=" * 70)
        print()

        for i, tail_config in enumerate(tail_configs, 1):
            print(f"[{i}/{len(tail_configs)}] Running {tail_config['id']}...")
            output_path = results_dir / f"{tail_config['id']}_train.json"

            run_cohort_v2(
                args.train_manifest,
                kernel_config,
                tail_config,
                output_path,
                args.min_radius,
                args.max_radius
            )

    # Select best config
    print("=" * 70)
    print("PHASE 2: Select Best Config")
    print("=" * 70)
    print()

    # NOTE: selection looks at every "*_train.json" in results_dir, including
    # files left over from earlier runs into the same directory.
    best_config = select_best_config(results_dir)

    # Save best config
    best_config_path = results_dir / "best_config.json"
    with open(best_config_path, "w") as f:
        json.dump(best_config, f, indent=2)
    print(f"Saved best config: {best_config_path}")
    print()

    # Run TEST with best config
    print("=" * 70)
    print("PHASE 3: TEST with Best Config")
    print("=" * 70)
    print()

    test_output = results_dir / f"{best_config['id']}_test.json"
    test_result = run_cohort_v2(
        args.test_manifest,
        kernel_config,
        best_config,
        test_output,
        args.min_radius,
        args.max_radius
    )

    # Gate check (placeholder - need to compare to baselines)
    print("=" * 70)
    print("PHASE 4: Gate Check")
    print("=" * 70)
    print()
    print("NOTE: Gate check requires comparing pass rates to baseline.")
    print(f"TEST pass% median: {test_result['pass_rate']:.1f}%")
    print(f"TEST rms% median: {test_result['rms_percent_median']:.1f}%")
    print()
    print("Gate criteria:")
    print("  GREEN (≥30%): Proceed to full grid + publication")
    print("  YELLOW (20-30%): One micro-tune iteration")
    print("  RED (<20%): Archive v2 tail")
    print()
    print("=" * 70)
    print("Grid search complete!")
    print(f"Results in: {results_dir}")
    print("=" * 70)


# Run the grid search only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
