#!/usr/bin/env python3
"""Fit β coefficients from TRAIN set using least-squares regression on outer v² deficit."""
import json
import numpy as np
import argparse
import pathlib
from typing import Tuple

def load_case(p: pathlib.Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Load one case JSON file and return its radius and velocity arrays.

    Args:
        p: Path to a JSON case file. The keys ``r_kpc``,
            ``v_baryon_disk_kms``, ``v_baryon_gas_kms`` and ``v_obs_kms`` are
            required. ``v_baryon_bulge_kms`` is optional; when it is missing,
            null, or an empty list it is replaced by zeros of the same length
            as ``r_kpc``.

    Returns:
        ``(r, vd, vg, vb2, vobs)`` as float arrays, where ``vb2`` is
        ``vd**2 + vg**2 + vbul**2``.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a required key is absent.
        ValueError: If the file is not valid JSON, a value cannot be
            converted to float, or the component arrays cannot be added
            together because their lengths differ.
    """
    # Use a context manager so the file handle is closed deterministically,
    # even when JSON decoding fails.
    with open(p, "r", encoding="utf-8") as fh:
        c = json.load(fh)

    r = np.array(c["r_kpc"], float)
    vd = np.array(c["v_baryon_disk_kms"], float)
    vg = np.array(c["v_baryon_gas_kms"], float)
    # `or` (rather than a .get default) also covers an explicit null / empty list.
    vbul = np.array(c.get("v_baryon_bulge_kms") or [0] * len(r), float)
    vobs = np.array(c["v_obs_kms"], float)
    vb2 = vd**2 + vg**2 + vbul**2
    return r, vd, vg, vb2, vobs

def descriptors_from_case(r: np.ndarray, vd: np.ndarray, vg: np.ndarray, vb2: np.ndarray) -> Tuple[float, float]:
    """Return the ``(xi_outer, gas_frac_outer)`` descriptors for one case.

    ``gas_frac_outer`` is the median of ``vg**2 / vb2`` (clipped to [0, 1])
    over the last 30% of points. ``xi_outer`` is the negated slope of
    ``ln(vb2)`` against ``ln(r)`` over the last 20% of points, clamped to
    [0, 5]. With fewer than 10 points both windows cover every point; with
    fewer than 2 points in the slope window the slope is taken as 0.

    ``vd`` is accepted for signature symmetry but is not used here.
    """
    count = len(r)
    use_windows = count >= 10

    # Index arrays (not slices) select the trailing windows.
    first_30 = int(0.7 * count) if use_windows else 0
    first_20 = int(0.8 * count) if use_windows else 0
    tail_30 = np.arange(count)[first_30:]
    tail_20 = np.arange(count)[first_20:]

    # Outer gas fraction: gas share of the baryonic v^2, guarded against /0.
    gas_ratio = vg[tail_30]**2 / np.clip(vb2[tail_30], 1e-12, None)
    gas_frac_outer = np.median(np.clip(gas_ratio, 0, 1))

    # Outer logarithmic slope of v_b^2(r).
    log_r = np.log(np.clip(r[tail_20], 1e-6, None))
    log_vb2 = np.log(np.clip(vb2[tail_20], 1e-12, None))
    if len(log_r) >= 2:
        slope_out = float(np.polyfit(log_r, log_vb2, 1)[0])
    else:
        slope_out = 0.0

    # Clamp like in solver
    capped = min(5.0, -slope_out)
    xi_outer = max(0.0, capped)

    return xi_outer, gas_frac_outer

def logmean_gb(r_kpc: np.ndarray, v2_b: np.ndarray) -> Tuple[float, float]:
    """Return the geometric means of ``v2_b / r_kpc`` and of ``r_kpc``.

    Radii are floored at 1e-6 and the ratio at 1e-12 before taking logs, so
    zero or negative entries do not produce ``-inf``/NaN.
    """
    safe_r = np.clip(r_kpc, 1e-6, None)
    safe_gb = np.clip(v2_b / safe_r, 1e-12, None)

    # Geometric mean == exp(mean(log(.)))
    mean_log_gb = np.mean(np.log(safe_gb))
    mean_log_r = np.mean(np.log(safe_r))

    glog = float(np.exp(mean_log_gb))
    rgeo = float(np.exp(mean_log_r))
    return glog, rgeo

def parse_manifest(path: pathlib.Path) -> list:
    """Read a manifest file and return the case paths it lists.

    Each non-blank line that does not start with ``#`` is one entry. Entries
    already starting with ``cases/`` are used as written; every other entry
    is joined onto ``cases``. (With pathlib, joining an absolute entry onto
    ``cases`` yields the absolute entry itself.)

    Args:
        path: Path to the manifest text file (UTF-8).

    Returns:
        A list of ``pathlib.Path`` objects in manifest order.

    Raises:
        OSError: If the manifest cannot be opened.
    """
    paths = []
    # Stream the file inside a context manager so the handle is always closed.
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            entry = raw.strip()
            if not entry or entry.startswith("#"):
                continue
            if entry.startswith("cases/"):
                paths.append(pathlib.Path(entry))
            else:
                paths.append(pathlib.Path("cases") / entry)
    return paths

def main():
    """Fit β0, β1, β2 on the TRAIN manifest and write an updated config.

    For every case in the manifest that has a
    ``<metrics_dir>/<name>/rft_geom/metrics.json`` file (only its existence is
    checked; it is not read), the script:

    1. computes the median outer-30% gap between observed and baryon-only v²,
    2. divides it by the product of the log-mean factors from ``logmean_gb``
       to get a target ``A_flat`` increment, clamped to [0.18, 0.50],
    3. records the row ``[1, xi_outer, gas_frac_outer]`` against that target.

    The coefficients are then obtained from a ridge-regularized normal
    equation and stored under ``cfg["mode_flattening"]`` in a copy of the base
    config, which is written to ``--out``.

    Raises:
        SystemExit: If fewer than 8 usable cases were collected.
        KeyError: If the base config has no ``mode_flattening`` section.
        OSError: If the manifest, base config, or output path cannot be
            opened.
    """
    ap = argparse.ArgumentParser(description="Fit β coefficients from TRAIN set")
    ap.add_argument("--manifest", required=True, help="Path to TRAIN manifest")
    ap.add_argument("--metrics_dir", default="reports", help="Directory with metrics")
    ap.add_argument("--base_config", default="config/global_c9.json", help="Base config to modify")
    ap.add_argument("--out", default="config/global_c9_betas_fit.json", help="Output config path")
    args = ap.parse_args()

    cases = parse_manifest(pathlib.Path(args.manifest))
    X, y = [], []

    for casep in cases:
        # Derive metrics path: reports/<name>/rft_geom/metrics.json
        name = casep.stem
        mpath = pathlib.Path(args.metrics_dir) / name / "rft_geom" / "metrics.json"

        # The metrics file acts only as a marker that the TRAIN run covered
        # this case; its contents are not used below.
        if not mpath.exists():
            print(f"Skipping {name}: no metrics")
            continue

        # Load case data (best-effort: a bad case file must not abort the fit)
        try:
            r, vd, vg, vb2, vobs = load_case(casep)
        except Exception as e:
            print(f"Skipping {name}: failed to load case - {e}")
            continue

        # Outer 30% window (whole curve when there are fewer than 10 points)
        n = len(r)
        outer = np.arange(int(0.7 * n), n) if n >= 10 else np.arange(n)

        # Compute v² deficit in outer region
        # For now, use simple baryon-only baseline
        vpred_baseline = np.sqrt(np.clip(vb2, 0, None))
        dv2 = np.median(np.clip(vobs[outer], 0, None)**2 - np.clip(vpred_baseline[outer], 0, None)**2)

        # Compute log-mean factors used by Mode II
        glog, rgeo = logmean_gb(r, vb2)
        denom = glog * rgeo if glog * rgeo > 1e-12 else 1e-12

        # Target increment in A_flat to close the v² gap
        A_inc = float(dv2 / denom)

        # Compute descriptors
        xi, gas = descriptors_from_case(r, vd, vg, vb2)

        # Target A_flat (clamped to physical range)
        A_tgt = np.clip(A_inc, 0.18, 0.50)

        # A single NaN row (empty arrays, NaN in the data) would turn every
        # fitted coefficient into NaN and put invalid NaN tokens in the output
        # JSON, so drop such cases instead of feeding them to the solver.
        if not np.isfinite([A_tgt, xi, gas]).all():
            print(f"Skipping {name}: non-finite fit inputs")
            continue

        X.append([1.0, xi, gas])
        y.append(A_tgt)

        print(f"{name}: xi={xi:.3f}, gas={gas:.3f}, A_tgt={A_tgt:.3f}")

    if len(X) < 8:
        raise SystemExit(f"Too few cases to fit β ({len(X)} < 8); ensure TRAIN run produced metrics")

    X = np.array(X)
    y = np.array(y)

    # Ridge regression for stability: solve (XᵀX + λI) β = Xᵀy
    lam = 1e-3
    beta = np.linalg.lstsq(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y, rcond=None)[0]
    b0, b1, b2 = map(float, beta)

    print("\n=== Fitted β coefficients ===")
    print(f"beta0 = {b0:.4f}")
    print(f"beta1 = {b1:.4f}")
    print(f"beta2 = {b2:.4f}")
    print(f"N = {len(X)} galaxies")

    # Load base config and update (context manager closes the handle)
    with open(args.base_config, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    cfg["mode_flattening"]["beta0"] = round(b0, 4)
    cfg["mode_flattening"]["beta1"] = round(b1, 4)
    cfg["mode_flattening"]["beta2"] = round(b2, 4)

    # Write output
    with open(args.out, "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)

    print(f"\nConfig written to: {args.out}")

# Run the fit only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
