"""
ATR penetration-depth correction with dispersive refractive index fitting.
"""

from __future__ import annotations

import json
import numpy as np
from pathlib import Path
from typing import Dict, Tuple
import matplotlib.pyplot as plt
from scipy.optimize import minimize


# -----------------------------------------------------------------------------
# Physics
# -----------------------------------------------------------------------------

def penetration_depth(
    wavenumbers: np.ndarray,
    *,
    n1: float = 2.42,
    n2_0: float = 1.5,
    n2_slope: float = 0.0,
    theta_deg: float = 45.0,
    nu_ref: float = 2000.0,
) -> np.ndarray:
    """
    Return the ATR penetration depth d_p in micrometres.

    Evaluates d_p = λ / (2π n1 sqrt(sin²θ − (n2/n1)²)) with λ = 1/ν and a
    sample index that varies linearly with wavenumber around ``nu_ref``.

    Parameters
    ----------
    wavenumbers : array_like
        Wavenumbers ν in cm⁻¹ (converted to a float array).
    n1 : float
        Refractive index of the ATR crystal.
    n2_0 : float
        Sample refractive index at ``nu_ref``.
    n2_slope : float
        Linear dispersion of the sample index, per cm⁻¹:
        n2(ν) = n2_0 + n2_slope · (ν − nu_ref).
    theta_deg : float
        Angle of incidence in degrees.
    nu_ref : float
        Wavenumber (cm⁻¹) at which the sample index equals ``n2_0``.

    Returns
    -------
    numpy.ndarray
        Penetration depth in µm, one value per wavenumber.
    """
    nu = np.asarray(wavenumbers, float)

    # Sample index with a linear dispersion term anchored at nu_ref.
    n_sample = n2_0 + n2_slope * (nu - nu_ref)

    # sin²θ − (n2/n1)², floored at 1e-12 so the square root stays real
    # (and the depth finite) when the total-internal-reflection condition fails.
    sin_sq = np.sin(np.radians(theta_deg)) ** 2
    radicand = np.clip(sin_sq - (n_sample / n1) ** 2, 1e-12, None)

    # λ = 1/ν in cm; the final factor 1e4 converts cm to µm.
    depth_cm = (1.0 / nu) / (2 * np.pi * n1 * np.sqrt(radicand))
    return depth_cm * 1e4


def atr_correction(
    wavenumbers: np.ndarray,
    absorbance: np.ndarray,
    *,
    d0: float = 0.32,
    n1: float = 2.42,
    n2_0: float = 1.5,
    n2_slope: float = 0.0,
    theta_deg: float = 45.0,
    nu_ref: float = 2000.0,
) -> np.ndarray:
    """
    Apply ATR penetration-depth correction.

    Each absorbance value is rescaled by ``d0 / d_p(ν)``, where ``d_p`` is the
    wavenumber-dependent penetration depth (µm) from ``penetration_depth``.
    This removes the ATR bias that inflates bands at low wavenumber, where the
    evanescent wave penetrates deeper.

    Parameters
    ----------
    wavenumbers : array_like
        Wavenumbers in cm⁻¹.
    absorbance : array_like
        Absorbance values aligned with ``wavenumbers``.
    d0 : float
        Reference depth the spectrum is scaled to. It must be in µm so that
        ``d0 / d_p`` is dimensionless.
    n1, n2_0, n2_slope, theta_deg, nu_ref : float
        Forwarded unchanged to ``penetration_depth``. ``nu_ref`` defaults to
        the same 2000 cm⁻¹ anchor that ``penetration_depth`` uses.

    Returns
    -------
    numpy.ndarray
        Corrected absorbance as a float array.
    """
    dp = penetration_depth(
        wavenumbers,
        n1=n1,
        n2_0=n2_0,
        n2_slope=n2_slope,
        theta_deg=theta_deg,
        nu_ref=nu_ref,  # previously not forwardable; dispersion anchor was fixed
    )
    return np.asarray(absorbance, float) * d0 / dp


# -----------------------------------------------------------------------------
# Metrics
# -----------------------------------------------------------------------------

def calculate_area(
    wavenumbers: np.ndarray,
    absorbance: np.ndarray,
    wn_min: float,
    wn_max: float,
) -> float:
    """
    Integrate absorbance over the closed wavenumber interval [wn_min, wn_max].

    The selected points are sorted by ascending wavenumber before the
    trapezoidal rule is applied. The result therefore does not depend on
    whether the spectrum is stored in ascending or descending order:
    descending storage gives the same value as the previous ``-trapz``
    implementation, and ascending storage no longer flips the sign. A
    positive band thus always yields a positive area.

    Parameters
    ----------
    wavenumbers : array_like
        Wavenumbers in cm⁻¹ (any order).
    absorbance : array_like
        Absorbance values aligned with ``wavenumbers``.
    wn_min, wn_max : float
        Inclusive integration bounds in cm⁻¹.

    Returns
    -------
    float
        Integrated area. Returns 0.0 when no point falls in the interval; a
        single point also integrates to 0.0.
    """
    wn = np.asarray(wavenumbers, float)
    A = np.asarray(absorbance, float)

    mask = (wn >= wn_min) & (wn <= wn_max)
    if not np.any(mask):
        return 0.0

    wn_sel = wn[mask]
    A_sel = A[mask]
    order = np.argsort(wn_sel, kind="stable")

    # np.trapz was renamed np.trapezoid in NumPy 2.0 (old name deprecated).
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return float(trapezoid(A_sel[order], wn_sel[order]))


# -----------------------------------------------------------------------------
# Fitting
# -----------------------------------------------------------------------------

def fit_atr_by_projection_target(
    wavenumbers: np.ndarray,
    absorbance: np.ndarray,
    *,
    CH_STRETCH: Tuple[float, float],
    CH_ROCKING: Tuple[float, float],
    ratio: float = 0.016,
    p0: Tuple[float, float, float] = (0.825, 1.5, 1.0e-5),
    bounds: Tuple[Tuple[float, float], ...] = (
        (0.82499, 0.82501),
        (1.3, 1.7),
        (-5e-5, 20e-5),
    ),
):
    """
    Fit ATR correction parameters by projecting onto a reference line in
    (CH stretch area, CH rocking area) space.

    The raw band areas form a point (x0, y0). That point is orthogonally
    projected onto the line y = ratio · x to obtain a target (x_ref, y_ref).
    The parameters (d0, n2_0, n2_slope) are then optimised with L-BFGS-B so
    that the corrected band areas match the target. The loss is the sum of
    the squared relative deviations from the target.

    Parameters
    ----------
    wavenumbers, absorbance : array_like
        Spectrum to fit (converted to float arrays).
    CH_STRETCH, CH_ROCKING : tuple of float
        (wn_min, wn_max) integration windows for the two bands.
    ratio : float
        Slope of the reference line (rocking area / stretch area).
    p0 : tuple of float
        Initial guess for (d0, n2_0, n2_slope).
    bounds : tuple of (low, high)
        Box bounds for (d0, n2_0, n2_slope). With the defaults, d0 is
        effectively pinned near 0.825.

    Returns
    -------
    scipy.optimize.OptimizeResult
        The result of ``minimize``; ``.x`` holds (d0, n2_0, n2_slope).

    Notes
    -----
    Parameter sets that give a non-positive corrected band area are penalised
    with a constant loss of 1e6.

    NOTE(review): if the projected target has x_ref or y_ref equal to 0, the
    relative-error terms divide by zero. Confirm that the input spectra always
    have positive band areas.
    """
    spectrum_wn = np.asarray(wavenumbers, float)
    spectrum_abs = np.asarray(absorbance, float)

    # Raw band areas: the starting point in area space.
    raw_stretch = calculate_area(spectrum_wn, spectrum_abs, *CH_STRETCH)
    raw_rocking = calculate_area(spectrum_wn, spectrum_abs, *CH_ROCKING)

    # Orthogonal projection onto the line through the origin with direction
    # (1, ratio).
    scale = (raw_stretch + ratio * raw_rocking) / (1 + ratio**2)
    target_stretch = scale
    target_rocking = ratio * scale

    def _loss(params):
        d0, n2_0, n2_slope = params
        corrected = atr_correction(
            spectrum_wn,
            spectrum_abs,
            d0=d0,
            n2_0=n2_0,
            n2_slope=n2_slope,
        )

        stretch = calculate_area(spectrum_wn, corrected, *CH_STRETCH)
        rocking = calculate_area(spectrum_wn, corrected, *CH_ROCKING)

        # Treat non-positive corrected areas as infeasible.
        if stretch <= 0 or rocking <= 0:
            return 1e6

        rel_stretch = (stretch - target_stretch) / target_stretch
        rel_rocking = (rocking - target_rocking) / target_rocking
        return rel_stretch**2 + rel_rocking**2

    solver_options = {"maxiter": 1000}
    return minimize(
        _loss,
        x0=p0,
        bounds=bounds,
        method="L-BFGS-B",
        options=solver_options,
    )

def _save_spectrum_csv(path: Path, wn: np.ndarray, absorbance: np.ndarray) -> None:
    """Write a two-column (wavenumber, absorbance) CSV with a plain header row."""
    np.savetxt(
        path,
        np.column_stack([wn, absorbance]),
        delimiter=",",
        header="wavenumber,absorbance",
        comments="",  # keep the header line free of numpy's default "# " prefix
        fmt="%.6f",
    )


def apply_atr_correction_to_folder(
    input_dir: Path,
    output_dir: Path,
    *,
    CH_STRETCH=(2750, 3000),
    CH_ROCKING=(710, 733),
) -> None:
    """
    Apply ATR correction to extracted raw spectra and save the results.

    Reads ``<input_dir>/extracted_spectra.json``. For every entry, the
    function:

    1. fits (d0, n2_0, n2_slope) with ``fit_atr_by_projection_target``;
    2. corrects the entry's ``"mean"`` spectrum;
    3. min-max normalises the corrected spectrum;
    4. writes ``<key>_corrected_raw.csv`` and ``<key>_corrected_normalized.csv``.

    After all entries are processed, all results (including the fitted
    parameters) are written to ``<output_dir>/corrected_spectra.json``.

    Parameters
    ----------
    input_dir, output_dir : str or Path
        Source folder and destination folder. The destination is created if
        it is missing.
    CH_STRETCH, CH_ROCKING : tuple of float
        Integration windows (cm⁻¹) passed through to the fit.

    Warns
    -----
    RuntimeWarning
        If the optimiser reports that it did not converge for an entry. The
        entry is still corrected and saved with the returned parameters.
    """
    import warnings

    from extract_spectra import min_max_normalize  # reuse your existing normalization

    input_path = Path(input_dir) / "extracted_spectra.json"
    # Accept str as well as Path, consistent with input_dir above.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(input_path, "r", encoding="utf-8") as f:
        extracted: Dict = json.load(f)

    corrected = {}

    for key, data in extracted.items():
        wn = np.array(data["wavenumber"])
        A = np.array(data["mean"])  # raw mean spectrum

        fit = fit_atr_by_projection_target(
            wn,
            A,
            CH_STRETCH=CH_STRETCH,
            CH_ROCKING=CH_ROCKING,
        )
        if not fit.success:
            # Previously ignored silently; surface it but keep processing.
            warnings.warn(
                f"ATR fit for {key!r} did not converge: {fit.message}",
                RuntimeWarning,
                stacklevel=2,
            )

        # Plain floats keep the JSON output free of numpy scalar types.
        d0, n2_0, n2_slope = (float(v) for v in fit.x)

        A_corr = atr_correction(wn, A, d0=d0, n2_0=n2_0, n2_slope=n2_slope)
        A_corr_norm = min_max_normalize(wn, A_corr)

        corrected[key] = {
            "sample": key,
            "fit_parameters": {
                "d0": d0,
                "n2_0": n2_0,
                "n2_slope": n2_slope,
            },
            "wavenumber": wn.tolist(),
            "corrected_absorbance": A_corr.tolist(),
            "normalized_corrected_absorbance": A_corr_norm.tolist(),
        }

        _save_spectrum_csv(output_dir / f"{key}_corrected_raw.csv", wn, A_corr)
        _save_spectrum_csv(
            output_dir / f"{key}_corrected_normalized.csv", wn, A_corr_norm
        )

    with open(output_dir / "corrected_spectra.json", "w", encoding="utf-8") as f:
        json.dump(corrected, f, indent=2)
