#!/usr/bin/env python3
"""
Export publication-grade figures from multi-solver summary.

Generates 5 figures:
1. Pass rate bar chart with Wilson CI whiskers
2. RMS% violin plot (overall)
3. RMS% violin plot split by LSB/HSB
4. RMS% vs r_max scatter (color = solver, marker = LSB/HSB tag)
5. Head-to-head wins matrix

Usage:
    python -m scripts.export_figs --summary reports/_summary_SP120.json --out reports/figs
"""

from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np


def load_summary(summary_path: Path) -> Tuple[Dict, List[Dict], Dict[str, Dict]]:
    """
    Load the multi-solver summary JSON and extract its key sections.

    Args:
        summary_path: Path to the summary JSON file.

    Returns:
        ``(summary, rows, solver_totals)``: the full decoded document, its
        ``"rows"`` list and its ``"solver_totals"`` mapping. A section that is
        missing *or* ``null`` is returned empty (``[]`` / ``{}``) so callers
        can iterate it unconditionally.
    """
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
    with summary_path.open("r", encoding="utf-8") as f:
        summary = json.load(f)

    # `or` also covers keys that are present but null in the JSON.
    rows = summary.get("rows") or []
    solver_totals = summary.get("solver_totals") or {}

    return summary, rows, solver_totals


def fig_passrate_bar(solver_totals: Dict[str, Dict], out_dir: Path) -> None:
    """
    Figure 1: Pass rate bars with Wilson 95% CI error bars.

    Args:
        solver_totals: Mapping of solver name to its totals. Each entry must
            provide ``"pass_rate"`` (drawn as the bar height on a 0-1 axis)
            and ``"wilson_95"`` (a ``(lo, hi)`` pair drawn as a capped whisker).
        out_dir: Directory that receives ``fig_passrate_bar.png``; created if
            it does not exist.

    Bars are coloured by solver *name*, using the same mapping as the scatter
    figure, so a solver keeps its colour whichever other solvers are present.
    Solvers not in the mapping are drawn in grey.
    """
    # Name-keyed palette: RFT red, MOND green, NFW blue. Indexing a colour list
    # by sorted position would instead make "mond" red and "rft_geom" blue.
    solver_colors = {"rft_geom": "#d62728", "mond": "#2ca02c", "nfw_fit": "#1f77b4"}

    solvers = sorted(solver_totals)
    pass_rates = [solver_totals[s]["pass_rate"] for s in solvers]
    wilson_cis = [solver_totals[s]["wilson_95"] for s in solvers]
    xs = np.arange(len(solvers))

    fig, ax = plt.subplots(figsize=(8, 5))

    bars = ax.bar(
        xs,
        pass_rates,
        color=[solver_colors.get(s, "#888888") for s in solvers],
        alpha=0.8,
        edgecolor="black",
    )

    # Wilson CI whiskers: a vertical span from lo to hi with horizontal caps.
    cap_half_width = 0.15
    for x, (lo, hi) in zip(xs, wilson_cis):
        ax.plot([x, x], [lo, hi], color="black", linewidth=2, zorder=10)
        for y in (lo, hi):
            ax.plot([x - cap_half_width, x + cap_half_width], [y, y],
                    color="black", linewidth=2, zorder=10)

    ax.set_ylabel("Pass Rate", fontsize=12)
    ax.set_xlabel("Solver", fontsize=12)
    ax.set_title("Pass Rate by Solver (with 95% Wilson CI)", fontsize=14, fontweight="bold")
    ax.set_xticks(xs)
    ax.set_xticklabels(solvers, rotation=15, ha="right")
    # Extra headroom so a label above a whisker reaching 1.0 stays in the axes.
    ax.set_ylim(0, 1.12)
    ax.grid(axis="y", alpha=0.3)

    # Put each pass-rate label above the higher of the bar top and its CI
    # whisker, so the text never sits on top of the whisker.
    for bar, rate, (_, hi) in zip(bars, pass_rates, wilson_cis):
        label_y = max(bar.get_height(), hi) + 0.02
        ax.text(bar.get_x() + bar.get_width() / 2, label_y, f"{rate:.1%}",
                ha="center", va="bottom", fontsize=10, fontweight="bold")

    fig.tight_layout()
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / "fig_passrate_bar.png", dpi=180)
    plt.close(fig)


def fig_violin_rms(rows: List[Dict], out_dir: Path) -> None:
    """
    Figure 2: RMS% violin plot (overall distributions by solver).

    Args:
        rows: Per-galaxy summary rows. Each row's ``"solvers"`` mapping is
            scanned for ``"rms_pct"``; missing, ``None`` and non-finite values
            are skipped, as are ``None`` metric entries.
        out_dir: Directory that receives ``fig_violin_rms.png``; created if it
            does not exist.

    One violin (with mean and median markers) is drawn for each solver that
    has at least one usable value. Violins are coloured by solver name, and
    unknown solvers are grey. If no solver has data, a placeholder figure
    reading "No data available" is written instead.
    """
    # Same name-keyed palette as the other figures.
    solver_colors = {"rft_geom": "#d62728", "mond": "#2ca02c", "nfw_fit": "#1f77b4"}

    # Only solvers that contribute at least one finite value get a key.
    data_by_solver: Dict[str, List[float]] = {}
    for row in rows:
        for solver, metrics in (row.get("solvers") or {}).items():
            rms = (metrics or {}).get("rms_pct")
            if rms is not None and math.isfinite(rms):
                data_by_solver.setdefault(solver, []).append(rms)

    valid_solvers = sorted(data_by_solver)
    out_path = out_dir / "fig_violin_rms.png"
    out_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 6))

    if not valid_solvers:
        # Axes coordinates keep the message centred regardless of data limits.
        ax.text(0.5, 0.5, "No data available", ha="center", va="center",
                transform=ax.transAxes)
        fig.savefig(out_path, dpi=180)
        plt.close(fig)
        return

    positions = np.arange(len(valid_solvers))
    parts = ax.violinplot(
        [data_by_solver[s] for s in valid_solvers],
        positions=positions,
        showmeans=True,
        showmedians=True,
    )

    # Colour by solver name, not by position, so a solver keeps its colour
    # even when some other solver has no data here.
    for solver, body in zip(valid_solvers, parts["bodies"]):
        body.set_facecolor(solver_colors.get(solver, "#888888"))
        body.set_alpha(0.7)

    ax.set_ylabel("RMS% Error", fontsize=12)
    ax.set_xlabel("Solver", fontsize=12)
    ax.set_title("RMS% Distribution by Solver", fontsize=14, fontweight="bold")
    ax.set_xticks(positions)
    ax.set_xticklabels(valid_solvers, rotation=15, ha="right")
    ax.grid(axis="y", alpha=0.3)

    fig.tight_layout()
    fig.savefig(out_path, dpi=180)
    plt.close(fig)


def fig_violin_rms_lsb_hsb(rows: List[Dict], out_dir: Path) -> None:
    """
    Figure 3: RMS% violin plot split by LSB/HSB tags.

    Args:
        rows: Per-galaxy summary rows. A row is assigned to "LSB" if its
            ``"tags"`` contain "LSB", otherwise to "HSB" if they contain "HSB",
            and it is ignored if it has neither. Within each group, finite
            ``"rms_pct"`` values are collected per solver.
        out_dir: Directory that receives ``fig_violin_rms_lsb_hsb.png``;
            created if it does not exist.

    Draws two side-by-side panels with a shared y-axis. Violins are coloured
    by solver name, so a solver has the same colour in both panels even when
    it is missing from one of them. A panel with no data shows a
    "No <tag> data" placeholder.
    """
    # Same name-keyed palette as the other figures.
    solver_colors = {"rft_geom": "#d62728", "mond": "#2ca02c", "nfw_fit": "#1f77b4"}
    groups = ("LSB", "HSB")

    # data[tag][solver] -> finite RMS% values (only solvers with data get a key).
    data: Dict[str, Dict[str, List[float]]] = {tag: {} for tag in groups}
    for row in rows:
        tags = row.get("tags") or []
        # LSB takes precedence when a row carries both tags.
        if "LSB" in tags:
            tag_group = "LSB"
        elif "HSB" in tags:
            tag_group = "HSB"
        else:
            continue

        for solver, metrics in (row.get("solvers") or {}).items():
            rms = (metrics or {}).get("rms_pct")
            if rms is not None and math.isfinite(rms):
                data[tag_group].setdefault(solver, []).append(rms)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True)

    for idx, (ax, tag) in enumerate(zip(axes, groups)):
        # Titles and labels are set even on empty panels so the layout is stable.
        ax.set_ylabel("RMS% Error" if idx == 0 else "", fontsize=12)
        ax.set_title(f"{tag} Galaxies", fontsize=13, fontweight="bold")

        valid_solvers = sorted(data[tag])
        if not valid_solvers:
            ax.text(0.5, 0.5, f"No {tag} data", ha="center", va="center", transform=ax.transAxes)
            ax.set_xticks([])
            continue

        positions = np.arange(len(valid_solvers))
        parts = ax.violinplot(
            [data[tag][s] for s in valid_solvers],
            positions=positions,
            showmeans=True,
            showmedians=True,
        )

        for solver, body in zip(valid_solvers, parts["bodies"]):
            body.set_facecolor(solver_colors.get(solver, "#888888"))
            body.set_alpha(0.7)

        ax.set_xlabel("Solver", fontsize=12)
        ax.set_xticks(positions)
        ax.set_xticklabels(valid_solvers, rotation=15, ha="right")
        ax.grid(axis="y", alpha=0.3)

    fig.suptitle("RMS% Distribution: LSB vs HSB", fontsize=14, fontweight="bold")
    fig.tight_layout()
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / "fig_violin_rms_lsb_hsb.png", dpi=180)
    plt.close(fig)


def fig_scatter_rms_vs_rmax(rows: List[Dict], out_dir: Path) -> None:
    """
    Figure 4: RMS% vs r_max scatter, colored by solver, shaped by tag.

    Args:
        rows: Per-galaxy summary rows. A point is plotted for each
            (row, solver) pair where the row has ``"r_kpc_max"`` and the solver
            has a finite ``"rms_pct"``. The marker is "o" for LSB-tagged rows,
            "s" for HSB-tagged rows and "D" for all other rows.
        out_dir: Directory that receives ``fig_scatter_rms_vs_rmax.png``;
            created if it does not exist.

    Points are grouped by (solver, tag group) and drawn with one scatter call
    per group. The tag legend lists "Other" only when such points were drawn.
    """
    from matplotlib.lines import Line2D

    colors = {"rft_geom": "#d62728", "mond": "#2ca02c", "nfw_fit": "#1f77b4"}
    markers = {"LSB": "o", "HSB": "s", "other": "D"}
    tag_labels = {"LSB": "LSB", "HSB": "HSB", "other": "Other"}

    solvers = set()
    # (solver, tag group) -> (r_max values, rms values)
    points: Dict[Tuple[str, str], Tuple[List[float], List[float]]] = {}
    for row in rows:
        row_solvers = row.get("solvers") or {}
        # The legend lists every solver seen, even ones with no plottable point.
        solvers.update(row_solvers)

        r_max = row.get("r_kpc_max")
        if r_max is None:
            continue

        tags = row.get("tags") or []
        if "LSB" in tags:
            group = "LSB"
        elif "HSB" in tags:
            group = "HSB"
        else:
            group = "other"

        for solver, metrics in row_solvers.items():
            if not metrics:
                continue
            rms = metrics.get("rms_pct")
            if rms is None or not math.isfinite(rms):
                continue
            xs, ys = points.setdefault((solver, group), ([], []))
            xs.append(r_max)
            ys.append(rms)

    fig, ax = plt.subplots(figsize=(10, 6))

    for (solver, group), (xs, ys) in sorted(points.items()):
        ax.scatter(xs, ys, c=colors.get(solver, "#888888"), marker=markers[group],
                   s=60, alpha=0.7, edgecolors="black", linewidths=0.5)

    ax.set_xlabel("r_max (kpc)", fontsize=12)
    ax.set_ylabel("RMS% Error", fontsize=12)
    ax.set_title("RMS% vs Maximum Radius", fontsize=14, fontweight="bold")
    ax.grid(True, alpha=0.3)

    solver_handles = [
        Line2D([0], [0], marker="o", color="w", markerfacecolor=colors.get(s, "#888888"),
               markersize=8, label=s)
        for s in sorted(solvers)
    ]
    # Without the "other" entry, diamond markers would appear unexplained.
    used_groups = {group for _, group in points}
    legend_groups = ["LSB", "HSB"] + (["other"] if "other" in used_groups else [])
    tag_handles = [
        Line2D([0], [0], marker=markers[t], color="w", markerfacecolor="gray",
               markersize=8, label=tag_labels[t])
        for t in legend_groups
    ]

    # add_artist keeps the first legend when the second one is created.
    solver_legend = ax.legend(handles=solver_handles, title="Solver", loc="upper left")
    ax.add_artist(solver_legend)
    ax.legend(handles=tag_handles, title="Galaxy Type", loc="upper right")

    fig.tight_layout()
    out_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_dir / "fig_scatter_rms_vs_rmax.png", dpi=180)
    plt.close(fig)


def fig_headtohead_matrix(summary: Dict, out_dir: Path) -> None:
    """
    Figure 5: Head-to-head wins matrix (N×N heatmap over the solvers listed in
    ``summary["solver_totals"]``).

    Args:
        summary: Decoded summary document. Cell (i, j) is read from
            ``summary["wins"]["<A>><B>"]``, which must provide ``"wins"`` and
            ``"of"``; the cell value is ``wins / of``, i.e. how often row
            solver A beat column solver B.
        out_dir: Directory that receives ``fig_headtohead_matrix.png``;
            created if it does not exist.

    The diagonal is fixed at 0.5 and labelled "—". A pair with no record, or
    with ``of == 0``, has no defined win rate. Such cells are masked and drawn
    in light grey rather than being coloured as a 0% win rate. If there are
    no solvers, a placeholder figure is written.
    """
    wins_data = summary.get("wins") or {}
    solvers = sorted(summary.get("solver_totals") or {})
    n = len(solvers)
    out_path = out_dir / "fig_headtohead_matrix.png"
    out_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 7))

    if n == 0:
        ax.text(0.5, 0.5, "No data available", ha="center", va="center",
                transform=ax.transAxes)
        fig.savefig(out_path, dpi=180)
        plt.close(fig)
        return

    # NaN marks "no data" so it is never confused with a genuine 0% win rate.
    matrix = np.full((n, n), np.nan)
    labels = [["N/A"] * n for _ in range(n)]
    for i, solver_a in enumerate(solvers):
        for j, solver_b in enumerate(solvers):
            if i == j:
                matrix[i, j] = 0.5  # Diagonal (vs self): neutral midpoint
                labels[i][j] = "—"
                continue
            record = wins_data.get(f"{solver_a}>{solver_b}")
            if record is None:
                continue
            wins, total = record["wins"], record["of"]
            if total > 0:
                matrix[i, j] = wins / total
                labels[i][j] = f"{wins}/{total}\n{matrix[i, j]:.0%}"
            else:
                labels[i][j] = f"{wins}/{total}"

    # Masked (NaN) cells are left transparent by imshow, so this grey axes
    # background shows through them.
    ax.set_facecolor("lightgray")
    im = ax.imshow(np.ma.masked_invalid(matrix), cmap="RdYlGn", vmin=0, vmax=1, aspect="auto")

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Win Rate", rotation=270, labelpad=20, fontsize=12)

    ax.set_xticks(np.arange(n))
    ax.set_yticks(np.arange(n))
    ax.set_xticklabels(solvers)
    ax.set_yticklabels(solvers)
    ax.set_xlabel("Opponent (Solver B)", fontsize=12)
    ax.set_ylabel("Solver A", fontsize=12)
    ax.set_title("Head-to-Head Win Rate Matrix\n(Row beats Column)", fontsize=14, fontweight="bold")

    # White text on the dark (low win-rate) end of the colormap. NaN < 0.5 is
    # False, so the grey "no data" cells get black text.
    for i, row_labels in enumerate(labels):
        for j, text in enumerate(row_labels):
            color = "white" if matrix[i, j] < 0.5 else "black"
            ax.text(j, i, text, ha="center", va="center", color=color, fontsize=10, fontweight="bold")

    fig.tight_layout()
    fig.savefig(out_path, dpi=180)
    plt.close(fig)


def main(summary_path: str, out_dir: str) -> None:
    """
    Export all five figures from a summary JSON.

    Args:
        summary_path: Path to the summary JSON (a ``str``; a ``Path`` also works).
        out_dir: Output directory for the PNG files (a ``str``; a ``Path``
            also works). It is created if it does not exist.

    Raises:
        FileNotFoundError: If ``summary_path`` does not exist. This is raised
            before anything is created on disk.
    """
    summary_file = Path(summary_path)
    out_path = Path(out_dir)

    if not summary_file.exists():
        raise FileNotFoundError(f"Summary not found: {summary_file}")

    print(f"Loading summary from {summary_file}...")
    summary, rows, solver_totals = load_summary(summary_file)

    # Create the output directory once, up front. Otherwise every later figure
    # writer depends on the first one having created it.
    out_path.mkdir(parents=True, exist_ok=True)
    print(f"Exporting figures to {out_path}...")

    print("  [1/5] Pass rate bar chart...")
    fig_passrate_bar(solver_totals, out_path)

    print("  [2/5] RMS% violin plot (overall)...")
    fig_violin_rms(rows, out_path)

    print("  [3/5] RMS% violin plot (LSB/HSB split)...")
    fig_violin_rms_lsb_hsb(rows, out_path)

    print("  [4/5] RMS% vs r_max scatter...")
    fig_scatter_rms_vs_rmax(rows, out_path)

    print("  [5/5] Head-to-head wins matrix...")
    fig_headtohead_matrix(summary, out_path)

    print(f"\n✅ All figures exported to {out_path}/")
    print("\nGenerated files:")
    for fig_file in sorted(out_path.glob("fig_*.png")):
        print(f"  - {fig_file.name}")


if __name__ == "__main__":
    # Command-line entry point: both flags are mandatory and forwarded to main().
    cli = argparse.ArgumentParser(
        description="Export publication-grade figures from multi-solver summary"
    )
    cli.add_argument(
        "--summary",
        required=True,
        help="Path to summary JSON (e.g., reports/_summary_SP120.json)",
    )
    cli.add_argument(
        "--out",
        required=True,
        help="Output directory for figures",
    )

    cli_args = cli.parse_args()
    main(summary_path=cli_args.summary, out_dir=cli_args.out)
