"""Grid search for C10 tail-mode parameters (A0, A_flat)."""
from __future__ import annotations

import argparse
import itertools
import json
import math
import subprocess
import sys
from pathlib import Path
from typing import Tuple

# Default value of the --train option: manifest passed to the solver and aggregator.
TRAIN_DEFAULT = "cases/SP99-TRAIN.manifest.txt"
# Default value of the --base option: JSON config each grid point is derived from.
BASE_DEFAULT = "config/global_c10_tail.json"
# Passed to the aggregator as --suffix and used to build the summary file name
# (reports/_summary_<suffix>.json) that the metrics are read from.
SUMMARY_SUFFIX = "C10GRID"


def _run_solver(manifest: str, cfg_path: Path, quiet: bool) -> None:
    """Run the ``cli.rft_rc_bench`` module over a manifest in a child interpreter.

    The child is started with the current interpreter (``sys.executable``),
    the ``rft_geom`` solver, and ``cfg_path`` as its global config.

    Args:
        manifest: Value forwarded as ``--batch``.
        cfg_path: Config file forwarded as ``--global-config``.
        quiet: When true, ``--quiet`` is appended and the child's
            stdout/stderr are captured instead of being shown.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            status (``check=True``).
    """
    argv = [sys.executable, "-m", "cli.rft_rc_bench"]
    argv += ["--batch", manifest]
    argv += ["--solver", "rft_geom"]
    argv += ["--global-config", str(cfg_path)]

    # Fixed benchmark options, kept in the order the command line expects.
    fixed_options = (
        ("--min-radius", "1.0"),
        ("--max-radius", "30.0"),
        ("--min-points", "10"),
        ("--max-workers", "0"),
    )
    for flag, value in fixed_options:
        argv += [flag, value]

    if quiet:
        argv.append("--quiet")

    subprocess.run(argv, check=True, capture_output=quiet)


def _run_aggregator(manifest: str, quiet: bool) -> Path:
    """Run the ``batch.aggregate`` module and return the expected summary path.

    Args:
        manifest: Value forwarded as ``--restrict-manifest``.
        quiet: When true, the child's stdout/stderr are captured instead of
            being shown. No ``--quiet`` flag is passed to the aggregator.

    Returns:
        ``reports/_summary_<SUMMARY_SUFFIX>.json``. The path is built here
        and not checked for existence; presumably it is where the aggregator
        writes its output for the given suffix.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            status (``check=True``).
    """
    argv = [sys.executable, "-m", "batch.aggregate"]
    argv.extend(("--suffix", SUMMARY_SUFFIX))
    argv.extend(("--restrict-manifest", manifest))
    subprocess.run(argv, check=True, capture_output=quiet)

    summary_name = f"reports/_summary_{SUMMARY_SUFFIX}.json"
    return Path(summary_name)


def _extract_metrics(summary_path: Path) -> Tuple[float, float, int]:
    data = json.loads(summary_path.read_text())
    solver_totals = data["solver_totals"]["rft_geom"]
    rows = data.get("rows", [])
    pass_count = sum(
        1
        for row in rows
        if row.get("solvers", {}).get("rft_geom", {}).get("rms_pct", float("inf")) <= 20.0
    )
    n_cases = solver_totals.get("n_cases", len(rows))
    pass_rate = pass_count / max(n_cases, 1)
    median_rms = solver_totals.get("rms_pct_median", float("nan"))
    return pass_rate, median_rms, pass_count


def main() -> None:
    """Run the C10 tail-mode grid search and save the best configuration.

    For every (A0, A_flat) combination a copy of the base config is written
    to a temporary file, the solver and aggregator are run on the training
    manifest, and the resulting metrics are scored. Configurations are
    ranked by pass rate first and by lower median RMS second; the winner is
    written to ``--out``.

    A combination whose solver/aggregator run fails, or whose summary cannot
    be read, is reported and skipped. The temporary config file is removed
    even if the search is interrupted by an unexpected error.
    """
    parser = argparse.ArgumentParser(description="C10 tail-mode grid search")
    parser.add_argument("--train", default=TRAIN_DEFAULT, help="Training manifest path")
    parser.add_argument("--base", default=BASE_DEFAULT, help="Base config to start from")
    parser.add_argument("--out", default="config/global_c10_tail_best.json", help="Path to write best config")
    parser.add_argument("--verbose", action="store_true", help="Show solver/aggregator logs")
    args = parser.parse_args()

    train_manifest = args.train
    quiet = not args.verbose

    base_cfg = json.loads(Path(args.base).read_text())

    a0_values = [80.0, 120.0, 180.0, 260.0]
    a_flat_values = [0.28, 0.34, 0.40]
    combos = list(itertools.product(a0_values, a_flat_values))
    total = len(combos)

    print(f"C10 grid search: {total} configs")
    print(f"Manifest: {train_manifest}")
    print(f"Base config: {args.base}\n")

    best_score = None  # (pass_rate, -median_rms) with NaN medians ranked last
    best_cfg = None
    best_pass_rate = 0.0
    best_median = math.nan

    tmp_path = Path(".grid_c10_tmp.json")

    try:
        for idx, (a0, a_flat) in enumerate(combos, 1):
            # JSON round-trip gives an independent deep copy of the base config.
            cfg = json.loads(json.dumps(base_cfg))
            tail = cfg.setdefault("mode_tail", {})
            tail["enabled"] = True
            tail["A0"] = a0
            # Create the section when the base config lacks it, as is done
            # for "mode_tail", instead of failing with KeyError.
            cfg.setdefault("mode_flattening", {})["A_flat"] = a_flat

            tmp_path.write_text(json.dumps(cfg, indent=2))

            try:
                _run_solver(train_manifest, tmp_path, quiet)
                summary_path = _run_aggregator(train_manifest, quiet)
            except subprocess.CalledProcessError as exc:
                print(f"[{idx}/{total}] A0={a0:.0f} A_flat={a_flat:.2f} -> ERROR {exc}")
                # In quiet mode the child's stderr was captured; surface its
                # last line so the failure can be diagnosed.
                if exc.stderr:
                    err_lines = exc.stderr.decode(errors="replace").strip().splitlines()
                    if err_lines:
                        print(f"    stderr: {err_lines[-1]}")
                continue

            try:
                pass_rate, median_rms, pass_count = _extract_metrics(summary_path)
            except (OSError, ValueError, KeyError) as exc:
                # Missing, malformed or incomplete summary: skip this grid
                # point rather than aborting the whole search.
                print(
                    f"[{idx}/{total}] A0={a0:.0f} A_flat={a_flat:.2f}"
                    f" -> ERROR reading summary: {exc!r}"
                )
                continue

            # NaN never compares greater or smaller, so a NaN median would
            # make the tie-break unusable; rank it as the worst possible RMS.
            rms_key = math.inf if math.isnan(median_rms) else median_rms
            score = (pass_rate, -rms_key)

            marker = ""
            if best_score is None or score > best_score:
                best_score = score
                best_cfg = cfg
                best_pass_rate = pass_rate
                best_median = median_rms
                marker = " ⭐ NEW BEST"

            print(
                f"[{idx}/{total}] A0={a0:5.1f}  A_flat={a_flat:.2f}"
                f" -> pass20={pass_rate*100:5.1f}% ({pass_count})  medianRMS={median_rms:5.2f}%" + marker
            )
    finally:
        tmp_path.unlink(missing_ok=True)

    if best_cfg is None:
        print("No successful configurations.")
        return

    Path(args.out).write_text(json.dumps(best_cfg, indent=2))
    print("\n✅ Saved best config ->", args.out)
    print(f"   pass@20% = {best_pass_rate*100:.1f}%")
    print(f"   median RMS% = {best_median:.2f}%")


# Run the grid search only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
