"""
Plot CH rocking vs CH stretching area for all corrected ATR-FTIR spectra.
"""

from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------------------------------------------------------
# Helper: area integration
# -----------------------------------------------------------------------------
def calculate_area(wavenumbers: np.ndarray, absorbance: np.ndarray, wn_min: float, wn_max: float) -> float:
    """Integrate absorbance over the closed wavenumber window ``[wn_min, wn_max]``.

    Uses the trapezoidal rule on the samples whose wavenumber lies inside the
    window. The selected samples are first sorted by increasing wavenumber, so
    the sign of the result does not depend on whether the spectrum is stored
    in ascending or descending order.

    Args:
        wavenumbers: Wavenumber axis in any order (same length as *absorbance*).
        absorbance: Absorbance values sampled on *wavenumbers*.
        wn_min: Lower bound of the integration window (inclusive).
        wn_max: Upper bound of the integration window (inclusive).

    Returns:
        The integrated area as a Python float. Returns 0.0 when no sample falls
        inside the window; a single sample also integrates to 0.0.
    """
    wavenumbers = np.asarray(wavenumbers, dtype=float)
    absorbance = np.asarray(absorbance, dtype=float)
    mask = (wavenumbers >= wn_min) & (wavenumbers <= wn_max)
    if not np.any(mask):
        return 0.0
    wn = wavenumbers[mask]
    a = absorbance[mask]
    # Integrate along an ascending axis so positive absorbance gives a positive
    # area. Negating the integral would only be correct for descending data.
    order = np.argsort(wn, kind="stable")
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return float(integrate(a[order], wn[order]))

# -----------------------------------------------------------------------------
# Main pipeline: CH₂ rocking vs CH stretching area plot
# -----------------------------------------------------------------------------
def _axis_upper_limit(*arrays: np.ndarray, pad: float = 1.1) -> float:
    """Return the largest value across *arrays* times *pad*, or 1.0 if none is positive.

    Guards the axis limits against empty arrays (``ndarray.max`` raises on
    size-0 input) and against a non-positive range (identical/inverted limits).
    """
    largest = max((float(np.max(a)) for a in arrays if a.size), default=0.0)
    return largest * pad if largest > 0 else 1.0


def plot_rocking_vs_stretching(raw_json: Path, corrected_json: Path, output_file: Path, ratio: float = 0.05):
    """
    Plot CH₂ rocking area against CH stretching area for all samples.

    For each sample in *raw_json*, the band areas are computed from the raw
    (uncorrected) mean spectrum. For samples also present in *corrected_json*,
    they are additionally computed from the ATR-corrected absorbance. Both sets
    are scattered together with:

    - a least-squares line through the origin fitted to the raw points, and
    - a reference line ``y = ratio * x``.

    The figure is saved to *output_file* and then shown.

    Args:
        raw_json: JSON file mapping sample name -> {"wavenumber": [...], "mean": [...]}.
        corrected_json: JSON file mapping sample name -> {"corrected_absorbance": [...], ...}.
            An entry may also hold "fit_parameters". If it holds its own
            "wavenumber" list, that axis is used for the corrected spectrum;
            otherwise the raw axis is assumed.
        output_file: Destination of the saved figure (str or Path). Parent
            directories are created.
        ratio: Slope of the reference ("target") line.

    Returns:
        dict mapping every sample present in both files to its "fit_parameters"
        entry ({} when the corrected entry has none).
    """
    # CH band integration windows (cm⁻¹)
    CH_STRETCH = (2750, 3000)
    CH_ROCKING = (710, 733)

    # Load JSON data (JSON text is UTF-8 by specification)
    with open(raw_json, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    with open(corrected_json, "r", encoding="utf-8") as f:
        corrected_data = json.load(f)

    stretch_raw, rock_raw = [], []
    stretch_corr, rock_corr = [], []
    fitted_params = {}

    for key, raw_entry in raw_data.items():
        wn = np.array(raw_entry["wavenumber"])
        A_raw = np.array(raw_entry["mean"])  # raw (not normalized)
        stretch_raw.append(calculate_area(wn, A_raw, *CH_STRETCH))
        rock_raw.append(calculate_area(wn, A_raw, *CH_ROCKING))

        if key not in corrected_data:
            continue
        corr_entry = corrected_data[key]
        # Integrate on the corrected spectrum's own axis when it stores one;
        # otherwise assume it is sampled on the raw axis.
        wn_corr = np.array(corr_entry["wavenumber"]) if "wavenumber" in corr_entry else wn
        A_corr = np.array(corr_entry["corrected_absorbance"])  # raw ATR-corrected
        stretch_corr.append(calculate_area(wn_corr, A_corr, *CH_STRETCH))
        rock_corr.append(calculate_area(wn_corr, A_corr, *CH_ROCKING))

        # Store fitted ATR parameters ({} when the entry has none)
        fitted_params[key] = corr_entry.get("fit_parameters", {})

    stretch_raw = np.asarray(stretch_raw, dtype=float)
    rock_raw = np.asarray(rock_raw, dtype=float)
    stretch_corr = np.asarray(stretch_corr, dtype=float)
    rock_corr = np.asarray(rock_corr, dtype=float)

    # --- Plot ---
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(stretch_raw, rock_raw, s=60, label="Raw (uncorrected)", alpha=0.7)
    ax.scatter(stretch_corr, rock_corr, s=60, label="Raw corrected", alpha=0.7)

    # Safe even when one of the datasets is empty.
    x_upper = _axis_upper_limit(stretch_raw, stretch_corr)
    y_upper = _axis_upper_limit(rock_raw, rock_corr)
    x_fit = np.linspace(0, x_upper, 200)

    # "C0"/"C1" are the first and second colours of the active property cycle,
    # i.e. the default colours of the raw and corrected scatter series.
    denom = float(np.sum(stretch_raw**2))
    if denom > 0:
        # Least-squares slope of a line through the origin (raw uncorrected points)
        slope = float(np.sum(stretch_raw * rock_raw)) / denom
        ax.plot(x_fit, slope * x_fit, linestyle='--', color="C0",
                label=f"Fit (raw uncorrected): y = {slope:.4f} x")
        # Dotted guides at the smallest and largest raw stretching areas
        for x_edge in (stretch_raw.min(), stretch_raw.max()):
            ax.vlines(x_edge, ymin=0, ymax=slope * x_edge, colors='gray', linestyles='dotted')
    ax.plot(x_fit, ratio * x_fit, linestyle='--', color="C1",
            label=f"Target (corrected ratio): y = {ratio:.4f} x")

    # Labels, limits, grid
    ax.set_xlabel(f"CH₂ & CH₃ Stretching area ({CH_STRETCH[0]}-{CH_STRETCH[1]} cm⁻¹)")
    ax.set_ylabel(f"CH₂ Rocking area ({CH_ROCKING[0]}-{CH_ROCKING[1]} cm⁻¹)")
    ax.set_xlim(0, x_upper)
    ax.set_ylim(0, y_upper)
    ax.grid(True)
    ax.legend(title="Dataset / Fit")

    fig.tight_layout()
    output_file = Path(output_file)
    output_file.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_file, dpi=300)
    plt.show()
    return fitted_params

# -----------------------------------------------------------------------------
# Plot fitted ATR parameters
# -----------------------------------------------------------------------------
def plot_fitted_atr_parameters(fitted_params: dict, *, d0_ylim=(0.75, 1.0), n2_0_ylim=None, n2_slope_ylim=(0, 5)):
    """
    Plot fitted ATR parameters (d0, n2_0, n2_slope) sorted by increasing n2_0.

    Each parameter is drawn on its own y-axis:

    - d0 on the left;
    - n2_0 on the right;
    - n2_slope, scaled by 1e5, on an offset right-hand axis.

    Samples are placed along x in order of increasing n2_0.

    Args:
        fitted_params: Mapping sample name -> dict with numeric "d0", "n2_0" and
            "n2_slope" entries. Samples whose dict lacks any of these keys are
            skipped rather than raising KeyError. This includes the empty dict
            that ``plot_rocking_vs_stretching`` stores when a corrected entry
            has no "fit_parameters".
        d0_ylim: (bottom, top) limits of the d0 axis, or None to autoscale.
        n2_0_ylim: (bottom, top) limits of the n2_0 axis, or None (default) to
            autoscale.
        n2_slope_ylim: (bottom, top) limits of the n2_slope axis (in units of
            1e-5), or None to autoscale.

    Returns:
        The matplotlib Figure.
    """
    required = ("d0", "n2_0", "n2_slope")
    # Keep only samples that carry every plotted parameter.
    sample_names = [
        name for name, params in fitted_params.items()
        if params and all(k in params for k in required)
    ]
    d0_values = np.array([fitted_params[s]["d0"] for s in sample_names])
    n2_0_values = np.array([fitted_params[s]["n2_0"] for s in sample_names])
    n2_slope_values = np.array([fitted_params[s]["n2_slope"] for s in sample_names])

    # Sort by n2_0
    sorted_idx = np.argsort(n2_0_values)
    sample_names_sorted = [sample_names[i] for i in sorted_idx]
    d0_sorted = d0_values[sorted_idx]
    n2_0_sorted = n2_0_values[sorted_idx]
    n2_slope_sorted = n2_slope_values[sorted_idx]

    x = np.arange(len(sample_names_sorted))

    fig, ax_d0 = plt.subplots(figsize=(12, 6))
    ax_n2_0 = ax_d0.twinx()
    ax_n2_slope = ax_d0.twinx()
    # Push the third axis' spine outward so it does not overlap the n2_0 axis.
    ax_n2_slope.spines["right"].set_position(("axes", 1.1))

    # Plot parameters; n2_slope is scaled by 1e5 to match its axis unit.
    l1 = ax_d0.plot(x, d0_sorted, 'o-', color="blue", label="d₀ [µm]")
    l2 = ax_n2_0.plot(x, n2_0_sorted, 's-', color="red", label="n₂₀")
    l3 = ax_n2_slope.plot(x, n2_slope_sorted * 1e5, 'd-', color="green", label="n₂ slope [10⁻⁵/cm⁻¹]")

    # X axis: one tick per sample, labelled with the full sample name
    ax_d0.set_xticks(x)
    ax_d0.set_xticklabels(sample_names_sorted, rotation=45, ha="right")
    ax_d0.set_xlabel("Sample (group – ageing – condition)")

    # Axis labels and (optional) limits
    ax_d0.set_ylabel("d₀ [µm]")
    ax_n2_0.set_ylabel("n₂₀")
    ax_n2_slope.set_ylabel("n₂ slope [10⁻⁵ / cm⁻¹]")
    for axis, limits in ((ax_d0, d0_ylim), (ax_n2_0, n2_0_ylim), (ax_n2_slope, n2_slope_ylim)):
        if limits is not None:
            axis.set_ylim(*limits)

    # Grid and a single legend gathering the lines from all three axes
    ax_d0.grid(True, axis="y", alpha=0.3)
    lines = l1 + l2 + l3
    labels = [line.get_label() for line in lines]
    ax_d0.legend(lines, labels, loc="upper center")

    fig.tight_layout()
    plt.show()
    return fig
