#!/bin/bash
# V2 Micro-tune: 8 configs (2 x 2 x 2 sweep over g*, γ, r_turn), orthogonal levers only
# Pre-registered stop rule: maximize TEST pass@20%, then tie-break on RMS
# No more tuning after this pass (one and done)

# Strict mode: abort on a failing command (-e), on expansion of an unset
# variable (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Cohort manifests consumed by the embedded Python driver, and the directory
# that receives one JSON summary per configuration. These never change during
# the run, so they are marked readonly.
readonly TRAIN="cases/SP99-TRAIN.manifest.txt"
readonly TEST="cases/SP99-TEST.manifest.txt"
readonly RESULTS_DIR="results/v2_microtune"
mkdir -p -- "$RESULTS_DIR"

# Print the run banner: frozen parameters, the swept grid, and the
# pre-registered selection and stop rules.
# The grid listed on the "Sweep:" line is 2 x 2 x 2 = 8 configurations, so the
# title reports 8 (it previously claimed 16).
# NOTE(review): selection step 3 (LSB pass rate) is announced here, but the
# selection logic in this script only compares TEST pass@20% and TEST median
# RMS — confirm whether step 3 is applied elsewhere.
print_banner() {
  # Quoted delimiter: the text is emitted literally, with no shell expansion.
  cat <<'BANNER'
========================================================================
RFT V2 MICRO-TUNE: 8 Configs (Final Optimization)
========================================================================
Frozen: A0=1000, α=0.6 (from calibrated GREEN)
Sweep:  g* ∈ {1200, 1800}, γ ∈ {1.0, 1.5}, r_turn ∈ {1.5, 2.0}

Selection rule (pre-registered):
  1. Maximize TEST pass@20%
  2. Tie-break on TEST median RMS
  3. Then LSB pass rate

Stop rule: After this pass, NO MORE TUNING
========================================================================

BANNER
}

print_banner

# ---------------------------------------------------------------------------
# Sweep over the (g*, γ, r_turn) grid, evaluate TRAIN and TEST for each
# configuration, write one JSON summary per configuration, and track the best
# configuration by TEST pass@20% (higher is better), then TEST median RMS
# (lower is better). On exact ties the earliest configuration is kept.
# ---------------------------------------------------------------------------

# Swept lever values; the grid is their Cartesian product.
GSTAR_VALUES=(1200 1800)
GAMMA_VALUES=(1.0 1.5)
RTURN_VALUES=(1.5 2.0)
total_configs=$(( ${#GSTAR_VALUES[@]} * ${#GAMMA_VALUES[@]} * ${#RTURN_VALUES[@]} ))

config_id=0        # 0-based id; also used in the result file name
best_test_pass=0
best_test_rms=100
best_config=""     # stays empty until a config with usable TEST results is seen

# float_cmp A OP B — numeric comparison of two decimal strings.
# OP is one of: gt, lt, eq. Returns 0 when the relation holds, 1 when it does
# not, 2 for an unknown OP. awk is used instead of bc so that exponent
# notation (e.g. 1e-05, as jq may print) is parsed, and so that a failed
# comparison cannot silently turn into an empty arithmetic expression.
float_cmp() {
  LC_ALL=C awk -v a="$1" -v op="$2" -v b="$3" 'BEGIN {
    a += 0; b += 0
    if (op == "gt") exit ((a > b) ? 0 : 1)
    if (op == "lt") exit ((a < b) ? 0 : 1)
    if (op == "eq") exit ((a == b) ? 0 : 1)
    exit 2
  }'
}

for GSTAR in "${GSTAR_VALUES[@]}"; do
  for GAMMA in "${GAMMA_VALUES[@]}"; do
    for RTURN in "${RTURN_VALUES[@]}"; do
      result_file="$RESULTS_DIR/config_${config_id}.json"

      echo "--------------------------------------------------------------------"
      echo "[Config $((config_id + 1))/$total_configs, id=$config_id] g*=$GSTAR, γ=$GAMMA, r_turn=$RTURN"
      echo "--------------------------------------------------------------------"

      # Evaluate TRAIN, then TEST (sequentially), in a single Python process.
      # The heredoc delimiter is intentionally unquoted: the shell substitutes
      # the lever values, manifest paths and output path into the Python text.
      python3 << EOF
import sys
from pathlib import Path
sys.path.insert(0, str(Path.cwd()))

import json
import numpy as np
from sparc_rft.case import load_case
from solver.rft_v2_gated_tail import apply_v2_gated_tail
from metrics.rc_metrics import compute_metrics

kernel_config = {
    "grid": [0.0],
    "weights": [1.0],
    "r_scale": "r_geo"
}

tail_config = {
    "A0_kms2_per_kpc": 1000,     # FROZEN
    "alpha": 0.6,                 # FROZEN
    "g_star_kms2_per_kpc": $GSTAR,
    "gamma": $GAMMA,
    "r_turn_kpc": $RTURN,
    "p": 2.0
}

def run_cohort(manifest, cohort_name):
    """Evaluate every case listed in a manifest and summarise the cohort.

    The manifest holds one case path per line, relative to the cases/
    directory; blank lines and comment lines are skipped. A case only
    contributes when at least 3 points satisfy 1 <= r <= 30 and sigma > 0.
    Cases that raise are reported on stderr and counted in n_errors; they
    are not part of the pass-rate denominator.
    """
    case_paths = []
    with open(manifest) as f:
        for line in f:
            entry = line.strip()
            # Test the stripped text so indented comment lines are skipped too.
            if entry and not entry.startswith("#"):
                case_paths.append(f"cases/{entry}")

    results = []
    n_errors = 0
    for case_path in case_paths:
        try:
            case = load_case(case_path)
            result = apply_v2_gated_tail(case, kernel_config, tail_config)
            r = np.array(result["r_kpc"])
            v_pred = np.array(result["v_pred_kms"])
            v_obs = np.array(case.v_obs_kms)
            sigma = np.array(case.sigma_v_kms)

            mask = (r >= 1.0) & (r <= 30.0) & (sigma > 0)
            if mask.sum() >= 3:
                metrics = compute_metrics(v_obs_kms=v_obs[mask], v_pred_kms=v_pred[mask])
                results.append({
                    "name": case.name,
                    "rms_percent": float(metrics["rms_percent"]),
                    "pass_20": bool(metrics["rms_percent"] <= 20.0),
                    "n_bins": int(mask.sum())
                })
        except Exception as e:
            n_errors += 1
            print(f"  ERROR {case_path}: {e}", file=sys.stderr)

    pass_20_count = sum(r["pass_20"] for r in results)
    pass_20_rate = 100.0 * pass_20_count / len(results) if results else 0.0
    rms_median = float(np.median([r["rms_percent"] for r in results])) if results else 0.0

    print(f"  {cohort_name:5s}: {pass_20_count}/{len(results)} = {pass_20_rate:.1f}% pass@20%, RMS median {rms_median:.1f}%")

    return {
        "n_cases": len(results),
        "n_errors": n_errors,
        "pass_20_count": int(pass_20_count),
        "pass_20_rate": float(pass_20_rate),
        "rms_median": float(rms_median),
        "results": results
    }

train = run_cohort("$TRAIN", "TRAIN")
test = run_cohort("$TEST", "TEST")

summary = {
    "config_id": $config_id,
    "A0": 1000,
    "alpha": 0.6,
    "g_star": $GSTAR,
    "gamma": $GAMMA,
    "r_turn": $RTURN,
    "p": 2.0,
    "train": train,
    "test": test
}

with open("$RESULTS_DIR/config_${config_id}.json", "w") as f:
    json.dump(summary, f, indent=2)

print("")
EOF

      # Pull the TEST figures used for selection out of the summary file.
      test_n=$(jq -r '.test.n_cases // 0' "$result_file")
      test_pass=$(jq -r '.test.pass_20_rate // 0' "$result_file")
      test_rms=$(jq -r '.test.rms_median // 100' "$result_file")

      if [[ "$test_n" == "0" ]]; then
        # An empty cohort reports rms_median 0.0, which would otherwise win
        # the "lower RMS" tie-break without having evaluated anything.
        echo "  WARNING: config $config_id has no usable TEST cases; excluded from selection" >&2
      elif [[ -z "$best_config" ]] \
        || float_cmp "$test_pass" gt "$best_test_pass" \
        || { float_cmp "$test_pass" eq "$best_test_pass" \
          && float_cmp "$test_rms" lt "$best_test_rms"; }; then
        # The first usable config is always accepted; afterwards a config must
        # beat the incumbent on pass rate, or tie on pass rate and beat on RMS.
        best_test_pass=$test_pass
        best_test_rms=$test_rms
        best_config=$config_id
      fi

      config_id=$((config_id + 1))
    done
  done
done

# at_least A B — succeed (exit 0) when numeric A >= numeric B.
# awk is used rather than bc so exponent notation is parsed and a failed
# comparison cannot collapse into an empty, silently-false (( )) expression.
at_least() {
  LC_ALL=C awk -v a="$1" -v b="$2" 'BEGIN { exit (((a + 0) >= (b + 0)) ? 0 : 1) }'
}

# Print the final report: the selected configuration (if any), its TEST
# figures, the GREEN-threshold verdict (>=35, >=30, otherwise below), and the
# location of the result files.
# Reads: best_config, best_test_pass, best_test_rms, RESULTS_DIR.
report_summary() {
  echo "========================================================================"
  echo "MICRO-TUNE COMPLETE"
  echo "========================================================================"
  if [[ -n "${best_config:-}" ]]; then
    echo "Best config: $best_config"
    echo "  TEST pass@20%: ${best_test_pass}%"
    echo "  TEST RMS median: ${best_test_rms}%"
  else
    # Nothing was selected, so do not print a bogus id or file name.
    echo "Best config: none (no configuration was selected)"
  fi
  echo ""

  if at_least "${best_test_pass:-0}" 35; then
    echo "✓ Micro-tune improved to ≥35% TEST"
  elif at_least "${best_test_pass:-0}" 30; then
    echo "✓ Still GREEN (≥30%), micro-tune did not harm"
  else
    echo "⚠ Fell below GREEN threshold (use frozen config)"
  fi

  echo ""
  echo "Results saved to: $RESULTS_DIR/"
  if [[ -n "${best_config:-}" ]]; then
    echo "Best config: $RESULTS_DIR/config_${best_config}.json"
  fi
  echo "========================================================================"
  echo ""
  echo "STOP RULE: No more tuning. Proceed to ablations & paper kit."
}

report_summary
