#!/usr/bin/env python3
"""
Generate a sample galaxy rotation curve JSON for the public findings page.

The output contains observed data plus RFT v2, global NFW, and MOND predictions
for a representative SPARC galaxy (NGC3198).
"""

from __future__ import annotations

import hashlib
import json
from pathlib import Path
from typing import Dict

import numpy as np
import sys

# Directory one level above the directory containing this script (presumably
# the repository root); every data/config path below is resolved against it.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Put the project root at the front of sys.path (once) so the project-local
# packages imported below can be found when this file is run directly as a script.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from sparc_rft.case import load_case  # type: ignore
from solver.rft_v2_gated_tail import apply_v2_gated_tail  # type: ignore
from baselines import nfw as nfw_module  # type: ignore
from baselines.mond import mond_predict  # type: ignore

# Name of the SPARC case rendered on the findings page; main() loads
# cases/sparc_all/<SAMPLE_GALAXY>.json.
SAMPLE_GALAXY = "NGC3198"
# Where the generated JSON is written.
# NOTE(review): the filename hard-codes "ngc3198" and does not follow
# SAMPLE_GALAXY; update both together if the sample galaxy changes.
OUTPUT_PATH = PROJECT_ROOT / "app/static/data/sample_curve_ngc3198.json"


def load_global_nfw_params() -> Dict[str, float]:
    baseline_path = PROJECT_ROOT / "baselines/results/nfw_global_test_baseline.json"
    payload = json.loads(baseline_path.read_text())
    params = payload.get("global_params") or {}
    return {
        "rho_s": float(params["rho_s_Msun_per_kpc3"]),
        "r_s": float(params["r_s_kpc"]),
    }


def compute_global_nfw(case, rho_s: float, r_s: float) -> np.ndarray:
    """Predict the total rotation speed of ``case`` under a fixed NFW halo.

    The disk, gas and (optional) bulge baryonic velocity components are added
    in quadrature with the halo term returned by ``nfw_module._v_dm_squared``.

    Args:
        case: Loaded case exposing ``r_kpc``, ``v_baryon_disk_kms``,
            ``v_baryon_gas_kms`` and ``v_baryon_bulge_kms``. The bulge may be
            ``None``, meaning no bulge contribution.
        rho_s: NFW characteristic density (presumably Msun/kpc^3, matching the
            baseline file's ``rho_s_Msun_per_kpc3`` key).
        r_s: NFW scale radius (presumably kpc, matching ``r_s_kpc``).

    Returns:
        Float array of total circular speeds, one per radius. A negative
        squared sum is clipped to zero before the square root.
    """
    radii = np.asarray(case.r_kpc, dtype=float)
    disk = np.asarray(case.v_baryon_disk_kms, dtype=float)
    gas = np.asarray(case.v_baryon_gas_kms, dtype=float)
    bulge = (
        np.zeros_like(disk)
        if case.v_baryon_bulge_kms is None
        else np.asarray(case.v_baryon_bulge_kms, dtype=float)
    )

    # NOTE(review): components are squared, so any sign carried by a component
    # (e.g. negative gas velocities) is discarded. Confirm this matches the
    # baryon model used when the global NFW parameters were fitted.
    baryon_term = disk**2 + gas**2 + bulge**2
    halo_term = nfw_module._v_dm_squared(radii, rho_s, r_s)  # pylint: disable=protected-access
    return np.sqrt(np.clip(baryon_term + halo_term, 0, None))


def main() -> None:
    """Generate the sample rotation-curve JSON and write it to ``OUTPUT_PATH``.

    Steps:

    1. Load the ``SAMPLE_GALAXY`` case.
    2. Evaluate three predictions:
       - the frozen RFT v2 model (``config/global_rc_v2_frozen.json``);
       - the global NFW baseline;
       - MOND with the ``"standard"`` law.
    3. Write as indented JSON: the observed data, the three predictions, the
       config's SHA-256 and the model parameters.

    Raises:
        ValueError: if any value in the payload is NaN or infinite. Strict
            JSON (e.g. a browser's ``JSON.parse``) cannot represent these, so
            the script fails instead of writing an unreadable file. The
            existing output file is left untouched in that case.
    """
    # MOND acceleration scale. It is passed to the model and also recorded in
    # the output, so it is defined once to keep the two in sync.
    mond_a0_m_s2 = 1.2e-10

    # Resolve against PROJECT_ROOT like every other path in this script, so
    # the result does not depend on the current working directory.
    case_path = PROJECT_ROOT / "cases" / "sparc_all" / f"{SAMPLE_GALAXY}.json"
    case = load_case(str(case_path))

    # Read the config once so the recorded hash describes exactly the bytes
    # that were parsed.
    config_path = PROJECT_ROOT / "config/global_rc_v2_frozen.json"
    config_bytes = config_path.read_bytes()
    config = json.loads(config_bytes)
    config_sha = hashlib.sha256(config_bytes).hexdigest()

    rft_payload = apply_v2_gated_tail(case, config["kernel"], config["tail"])

    nfw_params = load_global_nfw_params()
    v_nfw = compute_global_nfw(case, nfw_params["rho_s"], nfw_params["r_s"])

    mond_pred, _ = mond_predict(case, a0_m_s2=mond_a0_m_s2, law="standard")

    def as_floats(values) -> list:
        # Plain Python floats: json cannot serialise NumPy scalar types.
        return [float(v) for v in values]

    sample_payload = {
        "name": case.name,
        "cohort": "SP99-TEST",
        "r_kpc": as_floats(case.r_kpc),
        "v_obs_kms": as_floats(case.v_obs_kms),
        "sigma_v_kms": as_floats(case.sigma_v_kms),
        "rft_v2_kms": as_floats(rft_payload["v_pred_kms"]),
        "nfw_global_kms": as_floats(v_nfw),
        "mond_kms": as_floats(mond_pred),
        "config_sha256": config_sha,
        "nfw_global_params": nfw_params,
        "mond_a0_m_s2": mond_a0_m_s2,
    }

    # Serialise before touching the filesystem. With allow_nan=False, json
    # raises ValueError instead of emitting bare NaN/Infinity tokens, which
    # strict JSON parsers reject.
    text = json.dumps(sample_payload, indent=2, allow_nan=False)
    OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    OUTPUT_PATH.write_text(text + "\n", encoding="utf-8")
    print(f"[sample-curve] wrote {OUTPUT_PATH} for {case.name}")


if __name__ == "__main__":
    # Generate the file only when executed as a script, not when imported.
    main()
