"""
Coarse grid search for RFT global parameter optimization (C8).

Maximizes pass@20% (the fraction of training cases whose RMS% is at most 20)
on the training set; ties are broken by the lowest median RMS%.
"""
import json
import itertools
import subprocess
import argparse
import tempfile
import sys
from pathlib import Path


def run_batch(
    manifest: str,
    cfg: dict,
    quiet: bool = True,
    pass_threshold: float = 20.0,
) -> tuple[float, float, dict]:
    """Run the RFT solver on *manifest* with *cfg* and score the result.

    The config is written to a temporary JSON file and passed to
    ``cli.rft_rc_bench`` via ``--global-config``. Per-case results are then
    aggregated with ``batch.aggregate --suffix C8GRID``, and the metrics are
    read back from ``reports/_summary_C8GRID.json`` (relative to the current
    working directory).

    Args:
        manifest: Path to the batch manifest, passed as ``--batch``.
        cfg: Global config dict to evaluate; returned unchanged.
        quiet: When True, pass ``--quiet`` to the solver and capture both
            subprocesses' stdout/stderr instead of echoing them.
        pass_threshold: A case passes when its ``rms_pct`` is at or below this
            value (default 20.0, i.e. pass@20%).

    Returns:
        ``(pass_rate, median_rms, cfg)``. ``pass_rate`` is the number of
        passing rows divided by the summary's ``n_cases`` (at least 1).
        ``median_rms`` is the summary's ``rms_pct_median``, or ``inf`` if that
        value is missing or not a number.

    Raises:
        subprocess.CalledProcessError: if the solver or the aggregator fails.
        FileNotFoundError: if the aggregator did not write the summary file.
        KeyError: if the summary lacks ``solver_totals`` for the solver or
            its ``n_cases``.
    """
    solver = "rft_geom"
    summary_path = Path("reports/_summary_C8GRID.json")

    # delete=False so the file survives closing and the solver subprocess can
    # open it by name; it is removed in the ``finally`` below.
    with tempfile.NamedTemporaryFile(
        "w", delete=False, suffix=".json", encoding="utf-8"
    ) as f:
        json.dump(cfg, f, indent=2)
        cfg_path = f.name

    def _passes(row: dict) -> bool:
        """Return True if *row*'s rms_pct for the solver is a number <= threshold."""
        rms = ((row.get("solvers") or {}).get(solver) or {}).get("rms_pct")
        # A missing/None value (e.g. a failed case) counts as a fail instead of
        # raising TypeError on comparison; NaN compares False and fails too.
        return isinstance(rms, (int, float)) and rms <= pass_threshold

    try:
        # Run RFT solver
        cmd = [
            sys.executable, "-m", "cli.rft_rc_bench",
            "--batch", manifest,
            "--solver", solver,
            "--global-config", cfg_path,
            "--min-radius", "1.0",
            "--max-radius", "30.0",
            "--min-points", "10",
            "--max-workers", "0",
        ]
        if quiet:
            cmd.append("--quiet")

        subprocess.run(cmd, check=True, capture_output=quiet)

        # Drop any summary left by a previous run so that an aggregator which
        # exits cleanly without rewriting it cannot make us score stale data.
        summary_path.unlink(missing_ok=True)

        # Aggregate results
        subprocess.run(
            [sys.executable, "-m", "batch.aggregate", "--suffix", "C8GRID"],
            check=True,
            capture_output=quiet,
        )

        # Extract metrics
        with summary_path.open(encoding="utf-8") as fh:
            summary = json.load(fh)

        totals = summary["solver_totals"][solver]

        # Compute the pass rate manually from the per-case rows.
        rows = summary.get("rows", [])
        pass_count = sum(1 for r in rows if _passes(r))
        pass_rate = pass_count / max(totals["n_cases"], 1)

        median_rms = totals.get("rms_pct_median")
        if not isinstance(median_rms, (int, float)):
            median_rms = float("inf")  # unknown median ranks worst

        return pass_rate, median_rms, cfg

    finally:
        # Cleanup temp config
        Path(cfg_path).unlink(missing_ok=True)


def parse_ranges(ranges_str: str) -> dict[str, list[float]]:
    """Parse a grid-range spec into ``{name: [value, ...]}``.

    The spec is a ``;``-separated list of ``name=v1,v2,...`` entries, e.g.::

        "A_flat=0.22,0.28,0.34;A_core=0.08,0.16,0.24;gamma_core=0.4,0.6,0.8;sigma=0.35,0.45,0.55"

    Whitespace around names and values is ignored. Empty entries (e.g. from a
    trailing ``;``) and empty values (e.g. from a trailing ``,``) are skipped.
    If a name appears more than once, its last value list wins.

    Raises:
        ValueError: if an entry lacks ``=``, has an empty name, has no values,
            or contains a value that cannot be parsed as a float.
    """
    ranges: dict[str, list[float]] = {}
    for raw_entry in ranges_str.split(";"):
        entry = raw_entry.strip()
        if not entry:
            continue  # tolerate "a=1;;b=2" and trailing ';'
        key, sep, values_str = entry.partition("=")
        key = key.strip()
        if not sep or not key:
            raise ValueError(
                f"Malformed range entry {entry!r}; expected 'name=v1,v2,...'"
            )
        tokens = [tok.strip() for tok in values_str.split(",") if tok.strip()]
        if not tokens:
            raise ValueError(f"No values given for range {key!r}")
        try:
            ranges[key] = [float(tok) for tok in tokens]
        except ValueError as err:
            raise ValueError(
                f"Non-numeric value in range {key!r}: {values_str!r}"
            ) from err
    return ranges


def main():
    """Command-line entry point for the C8 grid search.

    Loads the ``--base`` config and builds the Cartesian product of the
    ``--ranges`` values (A_flat x A_core x gamma_core x sigma). Each
    combination is written into a copy of the base config and evaluated on
    the ``--train`` manifest with :func:`run_batch`. The best config is
    written to ``--out``: it maximizes pass@20%, and ties go to the lowest
    median RMS%. Grid points whose evaluation raises are reported and skipped.

    Exits with status 1 if no grid point could be evaluated.

    Raises:
        ValueError: if any of the required range names is missing.
    """
    import math  # local: only used to sanitize a NaN median

    parser = argparse.ArgumentParser(
        description="Grid search for optimal RFT global parameters (C8)"
    )
    parser.add_argument("--train", required=True, help="Training manifest path")
    parser.add_argument(
        "--ranges",
        required=True,
        help='Parameter ranges: "A_flat=0.22,0.28;A_core=0.08,0.16;gamma_core=0.4,0.6;sigma=0.45,0.55"'
    )
    parser.add_argument("--out", required=True, help="Output path for best config")
    parser.add_argument("--base", default="config/global.json", help="Base config to modify")
    parser.add_argument("--verbose", action="store_true", help="Show solver output")
    args = parser.parse_args()

    # Load base config (JSON is UTF-8 by spec; don't depend on the locale).
    base_config = json.loads(Path(args.base).read_text(encoding="utf-8"))

    # Parse parameter ranges
    ranges = parse_ranges(args.ranges)

    # Validate required keys
    required = {"A_flat", "A_core", "gamma_core", "sigma"}
    missing = required - ranges.keys()
    if missing:
        raise ValueError(
            f"Missing required ranges {sorted(missing)}. "
            f"Need: {required}, got: {set(ranges.keys())}"
        )

    # Generate grid
    grid = list(itertools.product(
        ranges["A_flat"],
        ranges["A_core"],
        ranges["gamma_core"],
        ranges["sigma"]
    ))
    total = len(grid)

    print(f"Grid search: {total} configurations")
    print(f"Train set: {args.train}")
    print(f"Base config: {args.base}")
    print()

    best = (-1.0, 1e9, None)  # (pass_rate, median_rms, config)

    for i, (a_flat, a_core, gamma, sigma) in enumerate(grid, 1):
        # Create config for this combo
        cfg = json.loads(json.dumps(base_config))  # Deep copy
        cfg["mode_flattening"]["A_flat"] = a_flat
        cfg["mode_core"]["A_core"] = a_core
        cfg["mode_core"]["gamma_core"] = gamma
        cfg["mode_spiral"]["sigma_ln_r"] = sigma

        # Run batch; one failing grid point must not abort the whole search.
        try:
            pass_rate, median_rms, _ = run_batch(args.train, cfg, quiet=not args.verbose)
        except subprocess.CalledProcessError as e:
            print(f"[{i}/{total}] ERROR: {e}")
            # In quiet mode the subprocess output was captured; surface its
            # last lines so the failure cause isn't lost.
            stderr = e.stderr or b""
            if isinstance(stderr, bytes):
                stderr = stderr.decode(errors="replace")
            for line in stderr.strip().splitlines()[-5:]:
                print(f"    {line}")
            continue
        except Exception as e:
            print(f"[{i}/{total}] ERROR: {e}")
            continue

        # A None/NaN median would crash the negation/format below or misorder
        # the comparison; rank it as the worst possible median instead.
        if median_rms is None or math.isnan(median_rms):
            median_rms = math.inf

        # Update best (maximize pass_rate, tie-break minimize median_rms)
        if (pass_rate, -median_rms) > (best[0], -best[1]):
            best = (pass_rate, median_rms, cfg)
            marker = " ⭐ NEW BEST"
        else:
            marker = ""

        print(
            f"[{i}/{total}] "
            f"A_flat={a_flat:.2f}, A_core={a_core:.2f}, γ={gamma:.1f}, σ={sigma:.2f} "
            f"→ pass@20%={pass_rate*100:.1f}%, RMS={median_rms:.1f}%{marker}"
        )

    # Save best config
    if best[2] is not None:
        out_path = Path(args.out)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(json.dumps(best[2], indent=2))
        print()
        print(f"✅ Best configuration saved to {args.out}")
        print(f"   Pass rate @ 20%: {best[0]*100:.1f}%")
        print(f"   Median RMS%: {best[1]:.2f}%")
        print()
        print("Best parameters:")
        print(f"  A_flat:      {best[2]['mode_flattening']['A_flat']:.3f}")
        print(f"  A_core:      {best[2]['mode_core']['A_core']:.3f}")
        print(f"  gamma_core:  {best[2]['mode_core']['gamma_core']:.2f}")
        print(f"  sigma_ln_r:  {best[2]['mode_spiral']['sigma_ln_r']:.3f}")
    else:
        print("❌ No valid configurations found")
        sys.exit(1)


# Run the grid search only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
