#!/usr/bin/env python3
"""
C13: Learn Universal Kernel from TRAIN cohort

Solves global weighted NNLS problem (residuals are in acceleration units):
    min_w  Σ_galaxies Σ_points ω_i·((v_obs² - v_b²)/r - C_i·w)² + λ·||D²w||²
    subject to w ≥ 0

where:
    - w: kernel weights at knots (nonnegative)
    - C_i: convolution row for point i
    - ω_i: per-point weight (1/σ_g² with --weighted, otherwise 1)
    - D²: second-difference operator (smoothness)
    - λ: regularization parameter (CV to select)

Output: config/global_c13_kernel.json with learned kernel
"""

import argparse
import json
import sys
from pathlib import Path
import numpy as np

sys.path.insert(0, ".")
from sparc_rft.case import GalaxyCase, load_case


def geometric_mean_radius(r: np.ndarray) -> float:
    """Return the geometric mean of the radii, i.e. exp(mean(ln r)).

    Radii smaller than 1e-6 are raised to 1e-6 before taking the logarithm
    so that zero or negative entries do not produce -inf/NaN.
    """
    floored = np.clip(r, 1e-6, None)
    mean_log = np.log(floored).mean()
    return float(np.exp(mean_log))


def _convolution_row(
    i: int, rho: np.ndarray, g_b: np.ndarray, grid: np.ndarray
) -> np.ndarray:
    """Return the design-matrix row for output point ``i`` of one galaxy.

    Every source point ``j`` whose log-radius offset ``rho[i] - rho[j]`` lies
    inside the kernel grid contributes ``g_b[j] * d_rho / n`` to the row,
    split between the two neighbouring knots by linear interpolation.
    """
    n_pts = len(rho)
    n_norm = max(n_pts, 1)
    row = np.zeros(len(grid))

    # Local log-radius spacing at point i: one-sided at the ends,
    # centred difference in the interior.
    if i == 0:
        d_rho = rho[1] - rho[0] if n_pts > 1 else 1.0
    elif i == n_pts - 1:
        d_rho = rho[i] - rho[i - 1]
    else:
        d_rho = 0.5 * (rho[i + 1] - rho[i - 1])

    for j in range(n_pts):
        delta = rho[i] - rho[j]
        if delta < grid[0] or delta > grid[-1]:
            continue  # Out of kernel support

        # Index of the first knot >= delta.
        idx = np.searchsorted(grid, delta)
        if idx == 0:
            # delta sits exactly on the first knot.
            row[0] += g_b[j] * d_rho / n_norm
        elif idx >= len(grid):
            row[-1] += g_b[j] * d_rho / n_norm
        else:
            # Linear interpolation weights between knots idx-1 and idx;
            # the 1e-9 guards against coincident knots.
            alpha = (delta - grid[idx - 1]) / (
                grid[idx] - grid[idx - 1] + 1e-9
            )
            row[idx - 1] += (1 - alpha) * g_b[j] * d_rho / n_norm
            row[idx] += alpha * g_b[j] * d_rho / n_norm

    return row


def build_design_matrix(
    cases: "list[GalaxyCase]",
    grid: np.ndarray,
    r_scale: str = "r_geo",
    min_radius: float = 1.0,
    max_radius: float = 30.0,
    fit_space: str = "v2",
    weighted: bool = False,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build global design matrix A and target vector b.

    For each galaxy, points with ``min_radius <= r <= max_radius`` are kept
    (galaxies with fewer than 3 such points are skipped), radii are mapped to
    ``rho = ln(r / r_star)`` and one row per point is produced.

    Args:
        cases: Galaxy cases; each must expose ``r_kpc``, ``v_obs_kms``,
            ``v_baryon_disk_kms`` and ``v_baryon_gas_kms``.
            ``v_baryon_bulge_kms`` and ``sigma_v_kms`` are optional.
        grid: Sorted kernel knot positions in Δρ.
        r_scale: "r_geo" (geometric mean radius), "median", or anything else
            for r_star = 1.
        min_radius: Lower edge of the radial window (inclusive).
        max_radius: Upper edge of the radial window (inclusive).
        fit_space: "v2" (default, v²) or "accel" (acceleration g). Both give
            the target ``(v_obs² - v_b²) / r``; they differ only in how the
            expression is evaluated.
        weighted: If True, weight by 1/σ_g² using case.sigma_v_kms. A missing
            or wrongly sized ``sigma_v_kms`` falls back to 5 km/s.

    Returns:
        A: (N_total_points, M) design matrix; shape (0, M) if no galaxy
            contributes any point
        b: (N_total_points,) target vector
        weights_vec: (N_total_points,) point weights (for robust fitting)
    """
    rows = []
    targets = []
    point_weights = []

    for case in cases:
        # Extract data
        r = np.asarray(case.r_kpc, dtype=float)
        v_obs = np.asarray(case.v_obs_kms, dtype=float)
        v_disk = np.asarray(case.v_baryon_disk_kms, dtype=float)
        v_gas = np.asarray(case.v_baryon_gas_kms, dtype=float)
        v_bulge = getattr(case, "v_baryon_bulge_kms", None)
        if v_bulge is None:
            v_bulge = np.zeros_like(r)
        else:
            v_bulge = np.asarray(v_bulge, dtype=float)

        # Get sigma_v before filtering (if needed). The None check must
        # happen before np.asarray: np.asarray(None, dtype=float) is a 0-d
        # NaN array, which is neither None nor sized.
        sigma_v = None
        if weighted:
            raw_sigma = getattr(case, "sigma_v_kms", None)
            if raw_sigma is not None:
                sigma_v = np.asarray(raw_sigma, dtype=float)
            if sigma_v is None or sigma_v.shape != r.shape:
                sigma_v = np.full_like(r, 5.0)  # Default 5 km/s if missing

        # Filter to window
        mask = (r >= min_radius) & (r <= max_radius)
        if np.count_nonzero(mask) < 3:
            continue  # Skip galaxies with <3 points in window

        r = r[mask]
        v_obs = v_obs[mask]
        v_disk = v_disk[mask]
        v_gas = v_gas[mask]
        v_bulge = v_bulge[mask]
        if weighted:
            sigma_v = sigma_v[mask]

        safe_r = np.clip(r, 1e-6, None)

        # Compute baryonic baseline
        v_b_sq = v_disk**2 + v_gas**2 + v_bulge**2
        g_b = v_b_sq / safe_r

        # Compute scale
        if r_scale == "r_geo":
            r_star = geometric_mean_radius(r)
        elif r_scale == "median":
            r_star = float(np.median(r))
        else:
            r_star = 1.0

        # Transform to log-radius
        rho = np.log(r / r_star)

        # Compute target based on fit_space
        if fit_space == "accel":
            # Target: g_obs - g_b = g_res (acceleration residuals)
            g_obs = v_obs**2 / safe_r
            target_g_res = g_obs - g_b
        else:
            # Target: (v_obs² - v_b²) / r = g_res (default v² space)
            target_g_res = (v_obs**2 - v_b_sq) / safe_r

        # Compute uncertainties if weighted
        if weighted:
            # Propagate: σ_g ≈ 2·v·σ_v / r (from σ²_g ≈ (∂g/∂v)²·σ²_v)
            sigma_g = 2.0 * v_obs * sigma_v / safe_r
            sigma_g = np.clip(sigma_g, 1.0, None)  # Floor at 1 km²/s²/kpc
        else:
            sigma_g = np.ones_like(r)

        # Build convolution rows for this galaxy
        for i, target in enumerate(target_g_res):
            rows.append(_convolution_row(i, rho, g_b, grid))
            targets.append(target)
            # Weight by 1/σ_g²
            point_weights.append(1.0 / (sigma_g[i] ** 2 + 1e-9))

    # Explicit reshape keeps A two-dimensional even when no rows were built.
    A = np.array(rows, dtype=float).reshape(len(rows), len(grid))
    b = np.array(targets, dtype=float)
    w = np.array(point_weights, dtype=float)

    return A, b, w


def build_smoothness_matrix(n_knots: int) -> np.ndarray:
    """Return the (n_knots - 2, n_knots) second-difference operator D².

    Row i carries the stencil [1, -2, 1] in columns i, i+1, i+2, so that
    (D² w)[i] = w[i] - 2·w[i+1] + w[i+2].
    """
    n_rows = n_knots - 2
    operator = np.zeros((n_rows, n_knots))
    row_idx = np.arange(n_rows)
    for offset, coeff in enumerate((1.0, -2.0, 1.0)):
        operator[row_idx, row_idx + offset] = coeff
    return operator


def solve_nnls_smooth(
    A: np.ndarray,
    b: np.ndarray,
    lambda_reg: float,
    weights: "np.ndarray | None" = None,
    max_iter: int = 5000,
) -> np.ndarray:
    """
    Solve weighted nonnegative least squares with smoothness:
        min_w  ||W^(1/2)(Aw - b)||² + λ||D²w||²
        s.t.   w ≥ 0

    Uses projected coordinate descent (simple, no SciPy dependency) on the
    normal equations: with G = AᵀWA + λ·D²ᵀD² and c = AᵀWb, the optimal
    unconstrained value of w[k] with the others fixed is
    (c[k] - Σ_{j≠k} G[k, j]·w[j]) / G[k, k]. Precomputing G and c makes each
    coordinate update O(n_knots) instead of O(n_points · n_knots).

    Args:
        A: (n_points, n_knots) design matrix.
        b: (n_points,) target vector.
        lambda_reg: Smoothness strength λ; must be >= 0.
        weights: Per-point weights (1/σ²). If None, uniform weighting.
        max_iter: Maximum number of full sweeps over the coordinates.

    Returns:
        (n_knots,) nonnegative weight vector.

    Raises:
        ValueError: If ``lambda_reg`` is negative (its square root would be
            NaN and silently produce an all-zero kernel).
    """
    if lambda_reg < 0:
        raise ValueError(f"lambda_reg must be nonnegative, got {lambda_reg}")

    A = np.asarray(A, dtype=float)
    b = np.asarray(b, dtype=float)
    n_points, n_knots = A.shape

    # Apply weights (W^1/2)
    if weights is None:
        weights = np.ones(n_points)
    W_sqrt = np.sqrt(weights)
    A_weighted = A * W_sqrt[:, None]
    b_weighted = b * W_sqrt

    # Normal-equation pieces of the data term.
    gram = A_weighted.T @ A_weighted
    rhs = A_weighted.T @ b_weighted

    # Smoothness penalty; for λ = 0 it contributes nothing and is skipped.
    if lambda_reg > 0:
        D2 = build_smoothness_matrix(n_knots)
        gram = gram + lambda_reg * (D2.T @ D2)

    # Small ridge on the diagonal avoids division by zero for empty columns.
    diag = np.diag(gram) + 1e-9

    # Initialize w
    w = np.full(n_knots, 0.1)  # Small positive start

    # Coordinate descent with projection
    for _ in range(max_iter):
        w_old = w.copy()

        for k in range(n_knots):
            # Zero the k-th component so gram[k] @ w excludes its own term.
            w[k] = 0.0
            numerator = rhs[k] - gram[k] @ w

            # Optimal unconstrained update, projected to nonnegative.
            w[k] = max(0.0, numerator / diag[k])

        # Check convergence
        if np.max(np.abs(w - w_old)) < 1e-6:
            break

    return w


def cross_validate_lambda(
    cases: list[GalaxyCase],
    grid: np.ndarray,
    lambda_grid: list[float],
    n_folds: int = 5,
    r_scale: str = "r_geo",
    min_radius: float = 1.0,
    max_radius: float = 30.0,
    fit_space: str = "v2",
    weighted: bool = False,
) -> float:
    """
    K-fold cross-validation to select best λ.

    Galaxies (not individual points) are shuffled with a fixed seed and split
    into ``n_folds`` contiguous folds; the last fold absorbs the remainder.
    Folds whose training or validation design matrix has fewer than 10 rows
    are ignored. The validation score is the unweighted RMS of ``A_val·w -
    b_val``.

    Args:
        cases: Training galaxies.
        grid: Kernel knot positions.
        lambda_grid: Candidate λ values (must be non-empty).
        n_folds: Number of folds.
        r_scale, min_radius, max_radius, fit_space, weighted: Forwarded to
            ``build_design_matrix``.

    Returns:
        Best λ value based on median RMS across folds. If no fold is usable,
        the first entry of ``lambda_grid`` is returned.
    """
    # Shuffle cases
    rng = np.random.RandomState(42)
    indices = rng.permutation(len(cases))
    fold_size = len(cases) // n_folds

    best_lambda = lambda_grid[0]
    best_score = float("inf")

    print(f"Cross-validating λ over {lambda_grid} with {n_folds} folds...")

    # The design matrices depend only on the fold split, not on λ, so build
    # them once per fold instead of once per (λ, fold) pair.
    fold_data = []
    for fold in range(n_folds):
        # Split train/val
        val_start = fold * fold_size
        val_end = (fold + 1) * fold_size if fold < n_folds - 1 else len(cases)
        val_indices = indices[val_start:val_end]
        train_indices = np.concatenate(
            [indices[:val_start], indices[val_end:]]
        )

        train_cases = [cases[i] for i in train_indices]
        val_cases = [cases[i] for i in val_indices]

        A_train, b_train, w_train = build_design_matrix(
            train_cases, grid, r_scale, min_radius, max_radius, fit_space, weighted
        )
        if A_train.shape[0] < 10:
            continue  # Too few points

        A_val, b_val, _ = build_design_matrix(
            val_cases, grid, r_scale, min_radius, max_radius, fit_space, weighted
        )
        if A_val.shape[0] < 10:
            continue

        fold_data.append((A_train, b_train, w_train, A_val, b_val))

    if not fold_data:
        print("⚠️  No usable CV folds (too few points); keeping first λ")

    for lam in lambda_grid:
        fold_scores = []

        for A_train, b_train, w_train, A_val, b_val in fold_data:
            # Solve NNLS on the train fold
            w = solve_nnls_smooth(A_train, b_train, lam, w_train)

            # Evaluate on validation fold
            residuals = A_val @ w - b_val
            rms = np.sqrt(np.mean(residuals**2))
            fold_scores.append(rms)

        if fold_scores:
            median_score = float(np.median(fold_scores))
            print(f"  λ={lam:.1e}: median RMS = {median_score:.2f} km²/s²/kpc")

            if median_score < best_score:
                best_score = median_score
                best_lambda = lam

    print(f"✅ Selected λ = {best_lambda:.1e} (median RMS = {best_score:.2f})")
    return best_lambda


def main():
    """Command-line entry point: learn the kernel and write it as JSON.

    Loads the galaxies listed in the manifest, selects λ by cross-validation,
    fits the kernel on all loaded galaxies and writes the result to ``--out``.

    Returns:
        0 on success, 1 if the manifest is missing, the arguments are
        unusable, or no data points are available for fitting.
    """
    parser = argparse.ArgumentParser(description="Learn universal kernel from TRAIN")
    parser.add_argument(
        "--manifest", default="cases/SP99-TRAIN.manifest.txt", help="Training manifest"
    )
    parser.add_argument(
        "--grid-knots", type=int, default=41, help="Number of kernel knots"
    )
    parser.add_argument(
        "--dln",
        type=float,
        default=None,
        help="Grid spacing in Δρ (default: chosen so the grid spans ±3)",
    )
    parser.add_argument(
        "--lambda-grid",
        default="0,1e-3,1e-2,1e-1",
        help="Comma-separated λ values for CV",
    )
    parser.add_argument(
        "--out",
        default="config/global_c13_kernel.json",
        help="Output kernel config",
    )
    parser.add_argument("--r-scale", default="r_geo", help="Radius scale method")
    parser.add_argument("--min-radius", type=float, default=1.0, help="Min radius (kpc)")
    parser.add_argument(
        "--max-radius", type=float, default=30.0, help="Max radius (kpc)"
    )
    parser.add_argument(
        "--fit-space", default="accel", choices=["v2", "accel"],
        help="Fit space: v2 (v² residuals) or accel (acceleration residuals)"
    )
    parser.add_argument(
        "--weighted", action="store_true",
        help="Weight by 1/σ_g² using case.sigma_v_kms"
    )
    args = parser.parse_args()

    # The second-difference penalty needs at least three knots.
    n_knots = args.grid_knots
    if n_knots < 3:
        print(f"❌ --grid-knots must be at least 3, got {n_knots}")
        return 1
    if args.dln is not None and args.dln <= 0:
        print(f"❌ --dln must be positive, got {args.dln}")
        return 1

    # Parse lambda grid (empty tokens such as a trailing comma are ignored)
    lambda_values = [
        float(token) for token in args.lambda_grid.split(",") if token.strip()
    ]
    if not lambda_values:
        print("❌ --lambda-grid contains no values")
        return 1

    # Load cases
    print(f"Loading cases from {args.manifest}...")
    manifest = Path(args.manifest)
    if not manifest.exists():
        print(f"❌ Manifest not found: {manifest}")
        return 1

    galaxy_names = [
        line.strip() for line in manifest.read_text().splitlines() if line.strip()
    ]
    cases = []
    for name in galaxy_names:
        # Handle both direct paths and bare names
        if "/" in name:
            # Direct path like "sparc_all/F571-8.json"
            case_path = Path(f"cases/{name}")
        else:
            # Bare name like "F571-8"
            case_path = Path(f"cases/{name}.json")

        if case_path.exists():
            case = load_case(case_path)
            cases.append(case)
        else:
            print(f"⚠️  Case not found: {case_path}")

    print(f"✅ Loaded {len(cases)} galaxies")
    if not cases:
        print("❌ No galaxies could be loaded; nothing to fit")
        return 1

    # Build kernel grid. Without --dln the grid covers ±3 in log-radius
    # (e^±3 ≈ 0.05-20× scale); with --dln the half-width follows from the
    # requested spacing and the number of knots.
    if args.dln is None:
        rho_max = 3.0
    else:
        rho_max = 0.5 * args.dln * (n_knots - 1)
    grid = np.linspace(-rho_max, rho_max, n_knots)

    print(
        f"Kernel grid: {n_knots} knots over Δρ ∈ [{grid[0]:.2f}, {grid[-1]:.2f}]"
    )

    # Cross-validate λ
    best_lambda = cross_validate_lambda(
        cases,
        grid,
        lambda_values,
        n_folds=5,
        r_scale=args.r_scale,
        min_radius=args.min_radius,
        max_radius=args.max_radius,
        fit_space=args.fit_space,
        weighted=args.weighted,
    )

    # Train on full TRAIN set with best λ
    print(f"\nTraining on full TRAIN set with λ={best_lambda:.1e}...")
    A, b, w_points = build_design_matrix(
        cases, grid, args.r_scale, args.min_radius, args.max_radius,
        args.fit_space, args.weighted
    )

    # No galaxy had enough points inside the radial window.
    if A.ndim != 2 or A.shape[0] == 0:
        print("❌ No data points inside the radial window; nothing to fit")
        return 1

    print(f"Design matrix: {A.shape[0]} points × {A.shape[1]} knots")

    # Solve NNLS
    weights = solve_nnls_smooth(A, b, best_lambda, w_points)

    # Compute training RMS (unweighted)
    residuals = A @ weights - b
    train_rms = np.sqrt(np.mean(residuals**2))
    print(f"Training RMS: {train_rms:.2f} km²/s²/kpc")

    # Check kernel properties
    print("\nKernel properties:")
    print(f"  Max weight: {np.max(weights):.3f}")
    print(f"  Min weight: {np.min(weights):.3f}")
    print(f"  Sum: {np.sum(weights):.3f}")
    print(f"  Nonzero knots: {np.sum(weights > 1e-6)}/{len(weights)}")

    # Save config
    config = {
        "grid": grid.tolist(),
        "weights": weights.tolist(),
        "lambda": best_lambda,
        "r_scale": args.r_scale,
        "n_train_cases": len(cases),
        "n_train_points": A.shape[0],
        "train_rms_kms2_kpc": float(train_rms),
    }

    out_path = Path(args.out)
    # Create the output directory on demand so a fresh checkout does not fail
    # after the (expensive) fit has already completed.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(config, indent=2))
    print(f"\n✅ Kernel saved to {out_path}")

    return 0


if __name__ == "__main__":
    sys.exit(main())
