#!/usr/bin/env python3
"""
Fill PRE_REGISTRATION_C9.md with TRAIN-derived predictions.
Computes Wilson confidence intervals and TEST predictions.
"""
import json
import sys
import math
from pathlib import Path

def wilson_ci(k: int, n: int, z: float = 1.96):
    """Return the Wilson score interval ``(lower, upper)`` for a proportion.

    Args:
        k: Number of successes.
        n: Number of trials.
        z: Normal quantile for the desired confidence level; the default
            of 1.96 is the one used for the 95% interval.

    Returns:
        A ``(lower, upper)`` tuple with both bounds clamped to ``[0, 1]``.
        When ``n`` is zero or negative, ``(0.0, 0.0)`` is returned.
    """
    if n > 0:
        z_sq = z * z
        p_hat = k / n
        scale = 1.0 + z_sq / n
        midpoint = (p_hat + z_sq / (2 * n)) / scale
        spread = (z / scale) * math.sqrt(
            (p_hat * (1 - p_hat) / n) + z_sq / (4 * n * n)
        )
        # Clamp: rounding can push the raw bounds just outside [0, 1].
        return (max(0.0, midpoint - spread), min(1.0, midpoint + spread))

    # No trials: there is no proportion to estimate.
    return (0.0, 0.0)

def main():
    """Print TRAIN-derived pre-registration predictions for C9.

    Reads the ``rft_geom`` solver totals from
    ``reports/_summary_C9_A.json``, computes the Wilson interval on the
    TRAIN pass rate, and scales that interval to a predicted pass-count
    range for the TEST set. Results are printed to stdout; nothing is
    written to disk.

    Exits with status 1 (message on stderr) when the summary file is
    missing or when it reports no TRAIN cases for ``rft_geom``.
    """
    train_summary = Path("reports/_summary_C9_A.json")

    if not train_summary.exists():
        print(f"ERROR: TRAIN summary not found: {train_summary}", file=sys.stderr)
        sys.exit(1)

    # Explicit encoding so decoding does not depend on the platform locale.
    with train_summary.open("r", encoding="utf-8") as f:
        summary = json.load(f)

    # Extract TRAIN metrics
    solver_totals = summary.get("solver_totals", {}).get("rft_geom", {})
    n_train = solver_totals.get("n_cases", 0)
    pass_count = solver_totals.get("pass_count", 0)
    median_rms = solver_totals.get("rms_pct_median", 0.0)

    # The pass-rate percentages below divide by n_train; without any TRAIN
    # cases there is nothing to predict, so fail with a clear message
    # instead of an uncaught ZeroDivisionError.
    if n_train <= 0:
        print(
            f"ERROR: no rft_geom TRAIN cases found in {train_summary}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Compute Wilson CI
    ci_low, ci_high = wilson_ci(pass_count, n_train)

    # Predict TEST range (34 galaxies). Floor the lower bound and ceil the
    # upper bound so the integer range fully covers the scaled interval.
    n_test = 34
    test_low = int(math.floor(n_test * ci_low))
    test_high = int(math.ceil(n_test * ci_high))

    # Percentages that appear in both output sections.
    train_pct = pass_count * 100 / n_train
    test_low_pct = test_low * 100 / n_test
    test_high_pct = test_high * 100 / n_test

    # Print summary
    print("=" * 60)
    print("C9 PRE-REGISTRATION PREDICTIONS")
    print("=" * 60)
    print()
    print(f"TRAIN Results (n={n_train}):")
    print(f"  Pass@20%: {pass_count}/{n_train} = {train_pct:.1f}%")
    print(f"  Wilson 95% CI: [{ci_low*100:.1f}%, {ci_high*100:.1f}%]")
    print(f"  Median RMS: {median_rms:.1f}%")
    print()
    print(f"Predicted TEST (n={n_test}):")
    print(f"  Pass range: {test_low}–{test_high} ({test_low_pct:.1f}%–{test_high_pct:.1f}%)")
    print()
    print("=" * 60)
    print("Update PRE_REGISTRATION_C9.md with these values:")
    print("=" * 60)
    print(f"- TRAIN pass@20%: {pass_count}/{n_train} = {train_pct:.1f}%")
    print(f"- Wilson 95% CI: [{ci_low*100:.1f}%, {ci_high*100:.1f}%]")
    # Use n_test rather than a second hard-coded 34 so the two cannot drift.
    print(f"- Predicted TEST pass (n={n_test}): [{test_low}, {test_high}]")
    print(f"- Pass rate: {test_low_pct:.1f}% to {test_high_pct:.1f}%")
    print(f"- Median RMS: ~{median_rms:.1f}%")
    print()

# Run the report only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
