#!/usr/bin/env python3
"""Grid search for C10.2 power-tail parameters (A0, alpha, A_flat, r_turn)."""
from __future__ import annotations

import argparse
import itertools
import json
import math
import subprocess
import sys
from pathlib import Path
from typing import Tuple

# Default value of the --train option: path to the training manifest.
TRAIN_DEFAULT = "cases/SP99-TRAIN.manifest.txt"
# Default value of the --base option: the config JSON every grid point starts from.
BASE_DEFAULT = "config/global_c10p_tail.json"
# Passed to the aggregator as --suffix; also names reports/_summary_<suffix>.json.
SUMMARY_SUFFIX = "C10P_TRAIN"


def _run_solver(manifest: str, cfg_path: Path, quiet: bool) -> None:
    """Run the ``cli.rft_rc_bench`` batch solver in a child interpreter.

    Args:
        manifest: Manifest path forwarded as ``--batch``.
        cfg_path: Config file forwarded as ``--global-config``.
        quiet: When true, the child's stdout/stderr are captured instead of
            being shown on the terminal.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            status.
    """
    # Flag/value pairs, in the exact order they are placed on the command line.
    valued_options = (
        ("--batch", manifest),
        ("--solver", "rft_geom"),
        ("--global-config", str(cfg_path)),
        ("--min-radius", "1.0"),
        ("--max-radius", "30.0"),
        ("--min-points", "10"),
    )

    argv = [sys.executable, "-m", "cli.rft_rc_bench"]
    for flag, value in valued_options:
        argv += [flag, value]
    # --emit-curves is a bare switch; --max-workers is passed "0"
    # (its exact meaning is defined by the solver CLI, not visible here).
    argv += ["--emit-curves", "--max-workers", "0"]

    subprocess.run(argv, check=True, capture_output=quiet)


def _run_aggregator(manifest: str, quiet: bool) -> Path:
    """Run ``batch.aggregate`` in a child interpreter and locate its summary.

    Args:
        manifest: Manifest path forwarded as ``--restrict-manifest``.
        quiet: When true, the child's stdout/stderr are captured instead of
            being shown on the terminal.

    Returns:
        ``reports/_summary_<SUMMARY_SUFFIX>.json`` as a relative ``Path``.
        The path is constructed here, not reported by the child, so it is
        presumably where the aggregator writes its output (not verified).

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            status.
    """
    launcher = [sys.executable, "-m", "batch.aggregate"]
    options = [
        "--suffix", SUMMARY_SUFFIX,
        "--restrict-manifest", manifest,
        "--pass-threshold", "20.0",
    ]
    subprocess.run(launcher + options, check=True, capture_output=quiet)

    summary_name = f"_summary_{SUMMARY_SUFFIX}.json"
    return Path(f"reports/{summary_name}")


def _extract_metrics(summary_path: Path) -> Tuple[float, float, int, int]:
    """Read an aggregator summary and return the ``rft_geom`` headline metrics.

    A row counts as a pass when its ``solvers.rft_geom.rms_pct`` value is
    present, non-null and ``<= 20.0``. A passing row additionally counts as an
    LSB pass when ``"LSB"`` is in its ``tags``.

    Args:
        summary_path: Path to the summary JSON file.

    Returns:
        ``(pass_rate, median_rms, pass_count, lsb_pass_count)`` where
        ``pass_rate`` is ``pass_count`` divided by
        ``solver_totals.rft_geom.n_cases`` (falling back to the number of rows,
        and never dividing by less than 1) and ``median_rms`` is
        ``solver_totals.rft_geom.rms_pct_median`` (NaN when absent or null).

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If ``solver_totals`` or its ``rft_geom`` entry is missing.
    """
    pass_threshold = 20.0  # RMS percentage at or below which a case passes

    data = json.loads(summary_path.read_text())
    solver_totals = data["solver_totals"]["rft_geom"]
    # ``or`` fallbacks cover JSON nulls, which ``.get(key, default)`` lets through.
    rows = data.get("rows") or []

    pass_count = 0
    lsb_pass_count = 0
    for row in rows:
        solver_row = (row.get("solvers") or {}).get("rft_geom") or {}
        rms_pct = solver_row.get("rms_pct")
        # Missing/null RMS is a non-pass; ``not (x <= t)`` also rejects NaN.
        if rms_pct is None or not rms_pct <= pass_threshold:
            continue
        pass_count += 1
        if "LSB" in (row.get("tags") or []):
            lsb_pass_count += 1

    n_cases = solver_totals.get("n_cases")
    if n_cases is None:
        n_cases = len(rows)
    pass_rate = pass_count / max(n_cases, 1)

    median_rms = solver_totals.get("rms_pct_median")
    if median_rms is None:
        median_rms = float("nan")

    return pass_rate, median_rms, pass_count, lsb_pass_count


def _wilson_interval(p_hat: float, n: int, z: float = 1.96) -> Tuple[float, float]:
    """Return the Wilson score interval ``(lo, hi)`` for a proportion.

    Args:
        p_hat: Observed proportion; clamped to [0, 1] so the square root
            below never receives a negative argument.
        n: Number of trials (must be >= 1).
        z: Normal quantile; 1.96 corresponds to a 95% interval.
    """
    p_hat = min(max(p_hat, 0.0), 1.0)
    denom = 1 + z**2 / n
    center = (p_hat + z**2 / (2 * n)) / denom
    margin = z * math.sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / denom
    return max(0.0, center - margin), min(1.0, center + margin)


def main() -> None:
    """Run the C10.2 power-tail grid search and report the best configuration.

    For every grid point the base config is copied, the tail / flattening /
    shelf sections are overridden, the solver and aggregator are run, and the
    resulting metrics are ranked by pass rate, then by lower median RMS.
    The best config is written to ``--out`` and a summary with a Wilson 95%
    interval and a decision rubric is printed. If no grid point succeeds,
    a message is printed and nothing is written.
    """
    parser = argparse.ArgumentParser(description="C10.2 power-tail grid search")
    parser.add_argument(
        "--train", default=TRAIN_DEFAULT, help="Training manifest path"
    )
    parser.add_argument(
        "--base", default=BASE_DEFAULT, help="Base config to start from"
    )
    parser.add_argument(
        "--out",
        default="config/global_c10p_best.json",
        help="Path to write best config",
    )
    parser.add_argument(
        "--verbose", action="store_true", help="Show solver/aggregator logs"
    )
    args = parser.parse_args()

    train_manifest = args.train
    quiet = not args.verbose

    base_cfg = json.loads(Path(args.base).read_text())

    # Grid parameters: 3 x 3 x 2 x 2 = 36 configs total
    # A0 ∈ {120, 180, 260} (km²/s²/kpc)
    # alpha ∈ {0.6, 0.8, 1.0} (power-law exponent)
    # A_flat ∈ {0.28, 0.34} (flattening amplitude)
    # r_turn ∈ {0, 2.0} (onset gate radius, kpc)
    a0_values = [120.0, 180.0, 260.0]
    alpha_values = [0.6, 0.8, 1.0]
    a_flat_values = [0.28, 0.34]
    r_turn_values = [0.0, 2.0]

    combos = list(
        itertools.product(a0_values, alpha_values, a_flat_values, r_turn_values)
    )
    total = len(combos)

    print(f"C10.2 power-tail grid search: {total} configs")
    print(f"Manifest: {train_manifest}")
    print(f"Base config: {args.base}")
    print(
        f"Grid: A0={a0_values}, alpha={alpha_values}, A_flat={a_flat_values}, r_turn={r_turn_values}\n"
    )

    best = None  # dict describing the best grid point seen so far

    tmp_path = Path(".grid_c10p_tmp.json")

    try:
        for idx, (a0, alpha, a_flat, r_turn) in enumerate(combos, 1):
            # JSON round-trip gives an independent deep copy of the base config.
            cfg = json.loads(json.dumps(base_cfg))

            # Update tail mode
            cfg.setdefault("mode_tail", {}).update(
                {
                    "enabled": True,
                    "A0_kms2_per_kpc": a0,
                    "alpha": alpha,
                    "r_scale": "r_geo",
                    "r_turn_kpc": r_turn,
                    "p": 2.0,
                }
            )

            # Update flattening (section is created when the base config lacks it,
            # consistent with how the tail and shelf sections are handled)
            cfg.setdefault("mode_flattening", {})["A_flat"] = a_flat

            # Ensure shelf is OFF
            cfg.setdefault("mode_shelf", {})["A_shelf"] = 0.0

            tmp_path.write_text(json.dumps(cfg, indent=2))

            try:
                _run_solver(train_manifest, tmp_path, quiet)
                summary_path = _run_aggregator(train_manifest, quiet)
                pass_rate, median_rms, pass_count, lsb_pass = _extract_metrics(
                    summary_path
                )
            except subprocess.CalledProcessError as exc:
                print(
                    f"[{idx}/{total}] A0={a0:.0f} α={alpha:.1f} A_flat={a_flat:.2f} r_turn={r_turn:.1f} -> ERROR {exc}"
                )
                continue
            except (OSError, KeyError, ValueError) as exc:
                # Missing or malformed summary: report this grid point as failed
                # instead of aborting the remaining grid.
                print(
                    f"[{idx}/{total}] A0={a0:.0f} α={alpha:.1f} A_flat={a_flat:.2f} r_turn={r_turn:.1f}"
                    f" -> ERROR {type(exc).__name__}: {exc}"
                )
                continue

            marker = ""
            # NaN compares false against everything, which would make a NaN
            # incumbent unbeatable on ties; rank an unknown median as worst.
            rank_rms = math.inf if math.isnan(median_rms) else median_rms
            score = (pass_rate, -rank_rms)
            if best is None or score > best["score"]:
                best = {
                    "score": score,
                    "cfg": cfg,
                    "pass_rate": pass_rate,
                    "median_rms": median_rms,
                    "pass_count": pass_count,
                    "lsb_pass": lsb_pass,
                    "params": (a0, alpha, a_flat, r_turn),
                }
                marker = " ⭐ NEW BEST"

            print(
                f"[{idx:2d}/{total}] A0={a0:5.0f}  α={alpha:.1f}  A_flat={a_flat:.2f}  r_turn={r_turn:.1f}"
                f" -> pass@20%={pass_rate*100:5.1f}% ({pass_count})  medRMS={median_rms:5.2f}%  LSB_pass={lsb_pass}"
                + marker
            )
    finally:
        # Remove the scratch config even when the loop is interrupted or fails.
        tmp_path.unlink(missing_ok=True)

    if best is None:
        print("\n❌ No successful configurations.")
        return

    best_cfg = best["cfg"]
    pass_rate = best["pass_rate"]
    median_rms = best["median_rms"]
    pass_count = best["pass_count"]
    lsb_pass = best["lsb_pass"]
    best_a0, best_alpha, best_aflat, best_rturn = best["params"]

    # Cohort size: pass_rate was computed as pass_count / n_cases, so recover
    # n_cases from it when possible; with zero passes fall back to the
    # expected TRAIN cohort size.
    expected_train_n = 55
    if pass_count > 0 and pass_rate > 0:
        n = max(1, round(pass_count / pass_rate))
    else:
        n = expected_train_n

    Path(args.out).write_text(json.dumps(best_cfg, indent=2))

    print("\n" + "=" * 70)
    print("✅ C10.2 Grid Search Complete")
    print("=" * 70)
    print(f"Best config: A0={best_a0:.0f}, α={best_alpha:.1f}, A_flat={best_aflat:.2f}, r_turn={best_rturn:.1f}")
    print(f"Pass@20%: {pass_rate*100:.1f}% ({pass_count}/{n})")
    print(f"Median RMS: {median_rms:.2f}%")
    print(f"LSB passes: {lsb_pass}")
    print(f"\nConfig saved to: {args.out}")

    # Wilson 95% CI
    lo, hi = _wilson_interval(pass_rate, n, z=1.96)

    print(f"\nWilson 95% CI: [{lo*100:.1f}%, {hi*100:.1f}%]")
    print(f"Predicted TEST range (n=34): [{int(34*lo)}, {int(34*hi)}]")

    # Decision rubric
    print("\n" + "=" * 70)
    if pass_rate >= 0.40:
        print("🟢 GREEN ZONE - Proceed to TEST validation")
        print("   Action: Freeze config, pre-register, run full TEST suite")
    elif pass_rate >= 0.30:
        print("🟢 GREEN ZONE (marginal) - Proceed to TEST with caution")
        print("   Action: Check LSB lift, inspect residuals, then run TEST")
    elif pass_rate >= 0.20:
        print("🟡 YELLOW ZONE - Marginal improvement")
        print(
            "   Action: Consider one more iteration (vary sigma_ln_r?) or publish exploratory result"
        )
    else:
        print("🔴 RED ZONE - Model inadequacy confirmed")
        print("   Action: Document negative result, publish comprehensive analysis")
    print("=" * 70)


# Run the grid search only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
