#!/usr/bin/env python3
"""
C9 Grid Search - Automated two-stage parameter optimization
Stage A: β-mapping (beta0, beta1, beta2)
Stage B: Shelf parameters (A_shelf, p)
"""
import itertools
import json
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Tuple

# Unbuffer stdout for real-time progress
import os
# PYTHONUNBUFFERED is only read when an interpreter starts, so this line does
# not change buffering of the current process; child processes spawned later
# (solver and aggregator) inherit it through the environment.
os.environ['PYTHONUNBUFFERED'] = '1'
# This is what makes progress lines appear immediately in this process:
# flush stdout at every newline, even when it is redirected to a file/pipe.
sys.stdout.reconfigure(line_buffering=True)

# Configuration
# TRAIN case list; passed to the solver (--batch) and used to restrict the
# aggregator (--restrict-manifest).
TRAIN_MANIFEST = "/home/rftuser/cases/SP99-TRAIN.manifest.txt"
# Starting config: each trial config and the final config are produced by
# loading this file and overriding only the searched parameters.
BASE_CONFIG = "/home/rftuser/config/global_c9.json"
# Scratch config rewritten before every trial and handed to the solver.
TRY_CONFIG = "/home/rftuser/config/c9_try.json"
# Destination for the winning parameters once both stages finish.
# NOTE(review): this is the same path as BASE_CONFIG, so a completed run
# overwrites its own starting config in place and the original values are not
# kept anywhere by this script — confirm that is intended (or back it up).
FINAL_CONFIG = "/home/rftuser/config/global_c9.json"

# Stage A: β-mapping grid
# (evaluated as the full Cartesian product, with A_shelf=0.0 and p=1.5)
BETA0_RANGE = [0.22, 0.28, 0.34, 0.40]
BETA1_RANGE = [0.00, 0.08, 0.15]
BETA2_RANGE = [0.00, 0.12, 0.24, 0.32]

# Stage B: Shelf grid
# (full Cartesian product, with Stage A's winning β values held fixed)
A_SHELF_RANGE = [0.00, 0.06, 0.12, 0.18]
P_RANGE = [1.0, 1.5, 2.0, 2.5]


def run_solver(config_path: str) -> None:
    """Run the rft_geom solver over the TRAIN manifest with ``config_path``.

    The solver is invoked as ``python3 -m cli.rft_rc_bench`` with fixed
    radius and point-count limits. Its stdout/stderr are captured, not
    streamed to the console.

    Args:
        config_path: Path of the global config JSON handed to the solver.

    Raises:
        RuntimeError: if the solver exits with a non-zero status. The message
            carries the exit code and the tail of the solver's stderr.
    """
    cmd = [
        "python3", "-m", "cli.rft_rc_bench",
        "--batch", TRAIN_MANIFEST,
        "--solver", "rft_geom",
        "--global-config", config_path,
        "--min-radius", "1.0",
        "--max-radius", "30.0",
        "--min-points", "10",
        "--max-workers", "0",
    ]
    # check=True is deliberately not used: raising ourselves lets the error
    # include the solver's stderr, and callers in the grid loops catch it so a
    # single failing configuration does not abort the whole search.
    # NOTE(review): if the bench CLI exits non-zero when only some galaxies
    # fail, the whole configuration is treated as failed — confirm intended.
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # Keep the END of stderr: a Python traceback's actual exception line
        # comes last, while its first lines are only the generic header.
        stderr_tail = result.stderr.strip()[-200:]
        raise RuntimeError(f"Solver failed with exit code {result.returncode}: {stderr_tail}")


def run_aggregator(suffix: str) -> Dict:
    """Aggregate the latest solver results and return the parsed summary.

    Runs ``python3 -m batch.aggregate`` restricted to the TRAIN manifest with
    a 20% pass threshold, then loads ``reports/_summary_<suffix>.json``.

    Args:
        suffix: Report suffix passed to the aggregator; also selects which
            summary file is read back.

    Returns:
        The summary JSON as a dict.

    Raises:
        RuntimeError: if the aggregator exits non-zero. The message includes
            the tail of its stderr so the grid log shows why it failed.
        OSError / json.JSONDecodeError: if the summary file is missing or
            malformed.
    """
    cmd = [
        "python3", "-m", "batch.aggregate",
        "--restrict-manifest", TRAIN_MANIFEST,
        "--pass-threshold", "20.0",
        "--suffix", suffix,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # check=True would raise CalledProcessError, whose message omits the
        # captured stderr; surface the end of it (where errors are reported).
        stderr_tail = result.stderr.strip()[-200:]
        raise RuntimeError(
            f"Aggregator failed with exit code {result.returncode}: {stderr_tail}"
        )

    # Relative path: resolved against the current working directory.
    summary_path = Path(f"reports/_summary_{suffix}.json")
    with summary_path.open("r") as f:
        return json.load(f)


def get_metrics(summary: Dict) -> Tuple[int, float, int, int, int]:
    """Pull the headline rft_geom numbers out of an aggregator summary.

    Returns:
        ``(passes, median_rms, lsb_passes, lsb_total, n_cases)`` where
        ``lsb_total`` counts rows tagged "LSB" and ``lsb_passes`` counts those
        whose rft_geom entry has a truthy "pass". Missing keys default to 0,
        except the median RMS, which defaults to the sentinel 999.0.
    """
    totals = summary.get("solver_totals", {}).get("rft_geom", {})

    lsb_rows = [row for row in summary.get("rows", []) if "LSB" in row.get("tags", [])]
    lsb_pass_count = sum(
        1
        for row in lsb_rows
        if row.get("solvers", {}).get("rft_geom", {}).get("pass", False)
    )

    return (
        totals.get("pass_count", 0),
        totals.get("rms_pct_median", 999.0),
        lsb_pass_count,
        len(lsb_rows),
        totals.get("n_cases", 0),
    )


def update_config(params: Dict, *, base_path=None, out_path=None) -> None:
    """Write a trial config: the base config with ``params`` applied on top.

    Loads the JSON config at ``base_path`` (default BASE_CONFIG), overrides
    the searched parameters, and writes the result to ``out_path`` (default
    TRY_CONFIG). The base file itself is not modified.

    Args:
        params: Parameter overrides.
            * If ``"beta0"`` is present, all three ``mode_flattening``
              coefficients are set; ``beta1``/``beta2`` default to 0.0.
            * If ``"A_shelf"`` is present, ``mode_shelf`` is set; ``p``
              defaults to 1.5.
        base_path: Optional config to start from instead of BASE_CONFIG.
        out_path: Optional file to write instead of TRY_CONFIG.
    """
    base = Path(BASE_CONFIG if base_path is None else base_path)
    out = Path(TRY_CONFIG if out_path is None else out_path)

    with base.open("r") as f:
        config = json.load(f)

    if "beta0" in params:
        # setdefault: tolerate a base config lacking the section, exactly as
        # is done for mode_shelf below (previously this raised KeyError).
        flattening = config.setdefault("mode_flattening", {})
        flattening["beta0"] = params["beta0"]
        flattening["beta1"] = params.get("beta1", 0.0)
        flattening["beta2"] = params.get("beta2", 0.0)

    if "A_shelf" in params:
        shelf = config.setdefault("mode_shelf", {})
        shelf["A_shelf"] = params["A_shelf"]
        shelf["p"] = params.get("p", 1.5)

    with out.open("w") as f:
        json.dump(config, f, indent=2)
        f.write("\n")


def stage_a_grid_search() -> Tuple[Dict, List[Dict]]:
    """Stage A: exhaustive search over the β-mapping coefficients.

    Every (beta0, beta1, beta2) in BETA0_RANGE × BETA1_RANGE × BETA2_RANGE is
    evaluated with A_shelf=0.0 and p=1.5. For each one, the trial config is
    written, the solver is run on the TRAIN manifest, and the aggregator
    summary (suffix "C9_A") is scored.

    Configurations are ranked lexicographically by
    ``(passes, -median_rms, lsb_passes)``: most passes first, then lowest
    median RMS, then most LSB passes. On a tie the earliest configuration
    wins. A configuration whose solver/aggregator run raises is logged as
    FAILED and skipped.

    Returns:
        ``(best_params, results)``: the full parameter dict of the winner
        (including the A_shelf/p used in this stage), and one metrics dict per
        configuration that evaluated successfully.

    Raises:
        RuntimeError: if every configuration failed, so there is no winner to
            hand to Stage B.
    """
    grid = list(itertools.product(BETA0_RANGE, BETA1_RANGE, BETA2_RANGE))
    total_configs = len(grid)

    print("=" * 60)
    print("STAGE A: β-Mapping Grid Search")
    print("=" * 60)
    print(f"Configurations: {total_configs}")
    print()

    results: List[Dict] = []
    best_config = None
    best_score = None  # (passes, -median_rms, lsb_passes); higher is better

    for config_num, (b0, b1, b2) in enumerate(grid, start=1):
        label = f"[{config_num:2d}/{total_configs}] β0={b0:.2f} β1={b1:.2f} β2={b2:.2f}"

        params = {"beta0": b0, "beta1": b1, "beta2": b2, "A_shelf": 0.0, "p": 1.5}
        update_config(params)

        # A failed solver/aggregator run only disqualifies this configuration.
        try:
            run_solver(TRY_CONFIG)
            summary = run_aggregator("C9_A")
            passes, median_rms, lsb_passes, lsb_total, n_cases = get_metrics(summary)
            score = (passes, -median_rms, lsb_passes)
        except Exception as e:
            print(f"{label} | FAILED: {e}")
            continue

        results.append({
            "beta0": b0,
            "beta1": b1,
            "beta2": b2,
            "passes": passes,
            "median_rms": median_rms,
            "lsb_passes": lsb_passes,
            "lsb_total": lsb_total,
        })

        # Track "new best" explicitly: testing score == best_score after the
        # update would also flag configurations that merely tie the leader.
        is_new_best = best_score is None or score > best_score
        if is_new_best:
            best_score = score
            best_config = params

        # An empty case set must not raise here after the result was recorded.
        pass_pct = f"{passes * 100 / n_cases:.1f}%" if n_cases else "n/a"
        status = "⭐ NEW BEST!" if is_new_best else ""
        print(f"{label} | "
              f"pass={passes:2d}/{n_cases} ({pass_pct}) | "
              f"RMS={median_rms:.1f}% | LSB={lsb_passes}/{lsb_total} {status}")

    if best_config is None:
        raise RuntimeError("Stage A: every β configuration failed; no winner to select.")

    print()
    print("=" * 60)
    print("STAGE A WINNER:")
    print(f"  β0={best_config['beta0']:.2f}, β1={best_config['beta1']:.2f}, β2={best_config['beta2']:.2f}")
    print(f"  Performance: {best_score[0]} passes @ 20%")
    print(f"  Median RMS: {-best_score[1]:.1f}%")
    print("=" * 60)
    print()

    return best_config, results


def stage_b_grid_search(beta_params: Dict) -> Tuple[Dict, List[Dict]]:
    """Stage B: exhaustive search over the shelf parameters with β's fixed.

    Every (A_shelf, p) in A_SHELF_RANGE × P_RANGE is evaluated on top of
    ``beta_params``. For each one, the trial config is written, the solver is
    run on the TRAIN manifest, and the aggregator summary (suffix "C9_B") is
    scored.

    Configurations are ranked lexicographically by
    ``(passes, -median_rms, lsb_passes)``. On a tie the earliest
    configuration wins. A configuration whose solver/aggregator run raises is
    logged as FAILED and skipped.

    Args:
        beta_params: Parameters held fixed; must contain beta0/beta1/beta2.

    Returns:
        ``(best_params, results)``: a copy of ``beta_params`` with the winning
        A_shelf/p, and one metrics dict per configuration that evaluated
        successfully.

        If no configuration succeeds, ``best_params`` is an unchanged copy of
        ``beta_params``. It is then expected to already carry A_shelf and p,
        as Stage A's winner does.
    """
    grid = list(itertools.product(A_SHELF_RANGE, P_RANGE))
    total_configs = len(grid)

    print("=" * 60)
    print("STAGE B: Shelf Parameter Grid Search")
    print("=" * 60)
    print(f"Fixed: β0={beta_params['beta0']:.2f}, β1={beta_params['beta1']:.2f}, β2={beta_params['beta2']:.2f}")
    print(f"Configurations: {total_configs}")
    print()

    results: List[Dict] = []
    best_config = dict(beta_params)
    best_score = None  # (passes, -median_rms, lsb_passes); higher is better

    for config_num, (Ash, p) in enumerate(grid, start=1):
        label = f"[{config_num:2d}/{total_configs}] A_shelf={Ash:.2f} p={p:.1f}"

        params = dict(beta_params)
        params["A_shelf"] = Ash
        params["p"] = p
        update_config(params)

        # A failed solver/aggregator run only disqualifies this configuration.
        try:
            run_solver(TRY_CONFIG)
            summary = run_aggregator("C9_B")
            passes, median_rms, lsb_passes, lsb_total, n_cases = get_metrics(summary)
            score = (passes, -median_rms, lsb_passes)
        except Exception as e:
            print(f"{label} | FAILED: {e}")
            continue

        results.append({
            "A_shelf": Ash,
            "p": p,
            "passes": passes,
            "median_rms": median_rms,
            "lsb_passes": lsb_passes,
            "lsb_total": lsb_total,
        })

        # Track "new best" explicitly so ties are not announced as new bests.
        is_new_best = best_score is None or score > best_score
        if is_new_best:
            best_score = score
            best_config = params

        # An empty case set must not raise here after the result was recorded.
        pass_pct = f"{passes * 100 / n_cases:.1f}%" if n_cases else "n/a"
        status = "⭐ NEW BEST!" if is_new_best else ""
        print(f"{label} | "
              f"pass={passes:2d}/{n_cases} ({pass_pct}) | "
              f"RMS={median_rms:.1f}% | LSB={lsb_passes}/{lsb_total} {status}")

    print()
    print("=" * 60)
    if best_score is None:
        # Nothing evaluated: report the fallback instead of sentinel numbers
        # (previously "-1 passes" and "Median RMS: -999.0%").
        print("STAGE B: no configuration succeeded; keeping input parameters")
        print(f"  A_shelf={best_config['A_shelf']:.2f}, p={best_config['p']:.1f}")
    else:
        print("STAGE B WINNER:")
        print(f"  A_shelf={best_config['A_shelf']:.2f}, p={best_config['p']:.1f}")
        print(f"  Performance: {best_score[0]} passes @ 20%")
        print(f"  Median RMS: {-best_score[1]:.1f}%")
    print("=" * 60)
    print()

    return best_config, results


def write_final_config(params: Dict, *, base_path=None, out_path=None) -> None:
    """Write the optimized parameters into the final config file.

    Loads the JSON config at ``base_path`` (default BASE_CONFIG), sets the
    ``mode_flattening`` β coefficients and the ``mode_shelf`` parameters from
    ``params``, and writes the result to ``out_path`` (default FINAL_CONFIG).

    Args:
        params: Must contain beta0, beta1, beta2, A_shelf and p.
        base_path: Optional config to start from instead of BASE_CONFIG.
        out_path: Optional destination instead of FINAL_CONFIG.

    Raises:
        KeyError: if ``params`` lacks one of the required keys.
    """
    base = Path(BASE_CONFIG if base_path is None else base_path)
    out = Path(FINAL_CONFIG if out_path is None else out_path)

    with base.open("r") as f:
        config = json.load(f)

    # setdefault: a base config lacking either section gets one created
    # (previously a missing mode_flattening raised KeyError).
    flattening = config.setdefault("mode_flattening", {})
    flattening["beta0"] = params["beta0"]
    flattening["beta1"] = params["beta1"]
    flattening["beta2"] = params["beta2"]

    shelf = config.setdefault("mode_shelf", {})
    shelf["A_shelf"] = params["A_shelf"]
    shelf["p"] = params["p"]

    # The destination may be the very file we started from, so never truncate
    # it in place. Write a sibling temp file, then atomically swap it in, so an
    # interrupted run cannot leave a half-written config behind.
    tmp = out.with_name(out.name + ".tmp")
    with tmp.open("w") as f:
        json.dump(config, f, indent=2)
        f.write("\n")
    os.replace(tmp, out)

    print(f"✅ Final config written to: {out}")


def main():
    """Run Stage A then Stage B, persist the winner, and print a summary.

    Stage A picks the β coefficients; Stage B tunes the shelf with those β's
    held fixed. The combined winner is written to FINAL_CONFIG.
    NOTE(review): the per-configuration result lists returned by both stages
    are discarded here; persist them if the grid needs to be inspected later.
    """
    beta_params, _stage_a_results = stage_a_grid_search()
    final_params, _stage_b_results = stage_b_grid_search(beta_params)
    write_final_config(final_params)

    banner = "=" * 60
    summary_lines = [
        "",
        banner,
        "C9 GRID SEARCH COMPLETE",
        banner,
        "Final Parameters:",
        f"  β0 = {final_params['beta0']:.3f}",
        f"  β1 = {final_params['beta1']:.3f}",
        f"  β2 = {final_params['beta2']:.3f}",
        f"  A_shelf = {final_params['A_shelf']:.3f}",
        f"  p = {final_params['p']:.2f}",
        "",
        f"Config frozen at: {FINAL_CONFIG}",
        f"Generate SHA256: sha256sum {FINAL_CONFIG}",
        banner,
    ]
    print("\n".join(summary_lines))


if __name__ == "__main__":
    main()
