"""
Extract spectral parameters (integrated peak areas) from raw, normalized and
ATR-corrected spectra, and provide a plotting function.
"""

from pathlib import Path
import json
from typing import Dict, Tuple

import numpy as np

# -----------------------------------------------------------------------------
# Peak regions (wavenumber ranges)
# -----------------------------------------------------------------------------

# Integration window for each named spectral feature, given as
# (lower, upper) wavenumber bounds. integrate_area treats both bounds as
# inclusive, so a sample lying exactly on a shared edge (e.g. 735 between
# "aliphatic_CH_rocking" and "aromatic_bending") contributes to both windows.
# Rows are listed in the order the features appear in the output JSON.
PEAK_RANGES: Dict[str, Tuple[float, float]] = {
    feature: (lower, upper)
    for feature, lower, upper in (
        ("aliphatic_CH_stretch", 2800, 3000),
        ("carbonyl_C=O", 1650, 1750),
        ("aromaticity", 1520, 1650),
        ("sulfoxides", 980, 1070),
        ("aromatic_bending", 735, 910),
        ("aliphatic_CH_rocking", 710, 735),
    )
}

# -----------------------------------------------------------------------------
# Area integration
# -----------------------------------------------------------------------------

def integrate_area(wavenumbers: np.ndarray, absorbance: np.ndarray, wn_range: Tuple[float, float]) -> float:
    """Integrate the absorbance over a given wavenumber range.

    Uses the trapezoidal rule over every sample whose wavenumber lies in the
    inclusive window ``[wn_range[0], wn_range[1]]``.

    Args:
        wavenumbers: Sample positions (array-like). May be stored in
            ascending, descending or arbitrary order.
        absorbance: Absorbance values, same length as ``wavenumbers``.
        wn_range: ``(wn_min, wn_max)`` integration bounds, both inclusive.

    Returns:
        The integrated area as a plain ``float``. Returns ``0.0`` when no
        sample falls inside the window, and also when only a single sample
        does (a single point spans zero width).
    """
    wn = np.asarray(wavenumbers)
    values = np.asarray(absorbance)
    wn_min, wn_max = wn_range
    mask = (wn >= wn_min) & (wn <= wn_max)
    if not np.any(mask):
        return 0.0
    x = wn[mask]
    y = values[mask]
    # Integrate along increasing wavenumber so the area's sign does not depend
    # on how the spectrum is stored. Negating np.trapz only worked for
    # strictly descending input.
    order = np.argsort(x, kind="stable")
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; fall
    # back to np.trapz on older NumPy versions that lack trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return float(trapezoid(y[order], x[order]))

# -----------------------------------------------------------------------------
# Parameter extraction
# -----------------------------------------------------------------------------

def _load_json(path: Path) -> dict:
    """Read and return the JSON document stored at *path*, decoded as UTF-8."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _peak_areas(wn: np.ndarray, absorbance: np.ndarray, absorbance_norm: np.ndarray) -> Dict[str, float]:
    """Integrate every PEAK_RANGES window for a spectrum and its normalized variant.

    Returns a dict with key ``<peak>`` (area of *absorbance*) followed by
    ``<peak>_norm`` (area of *absorbance_norm*) for each peak, in
    PEAK_RANGES order. Both arrays are integrated on the same grid *wn*.
    """
    areas: Dict[str, float] = {}
    for peak, rng in PEAK_RANGES.items():
        areas[peak] = integrate_area(wn, absorbance, rng)
        areas[f"{peak}_norm"] = integrate_area(wn, absorbance_norm, rng)
    return areas


def extract_parameters(raw_json: Path, normalized_json: Path, corrected_json: Path, output_json: Path) -> None:
    """
    Extract areas for predefined peaks for:
      - raw spectra
      - normalized raw spectra
      - ATR-corrected raw spectra
      - ATR-corrected normalized spectra
    Saves results to JSON and collects indices for plotting.

    Args:
        raw_json: JSON mapping spectrum key -> {"wavenumber": [...], "mean": [...]}.
        normalized_json: JSON mapping the same keys -> {"normalized_absorbance": [...]}.
        corrected_json: JSON mapping (a subset of) keys ->
            {"corrected_absorbance": [...], "normalized_corrected_absorbance": [...]}.
            Keys absent here get an empty "after_correction" dict.
        output_json: Destination file (``str`` or ``Path``). Parent
            directories are created if needed.

    Output layout::

        {"parameters": {key: {"before_correction": {<peak>: area, <peak>_norm: area, ...},
                              "after_correction": {...} or {},
                              "indices": {"raw": i, "raw_corrected": i,
                                          "normalized": i, "normalized_corrected": i}}},
         "indices": {"raw": [0, 1, ...], "raw_corrected": [...], ...}}

    ``i`` is the position of the key in raw_json's iteration order.

    Raises:
        KeyError: if a spectrum in raw_json is missing from normalized_json,
            or if an entry lacks one of the expected array fields.
    """
    raw_data = _load_json(raw_json)
    norm_data = _load_json(normalized_json)
    corrected = _load_json(corrected_json)

    index_kinds = ("raw", "raw_corrected", "normalized", "normalized_corrected")
    results = {}
    # NOTE: every index list is appended for every spectrum, including
    # "*_corrected" entries for keys with no corrected data (preserved as-is).
    indices = {kind: [] for kind in index_kinds}

    for idx, key in enumerate(raw_data):
        wn = np.asarray(raw_data[key]["wavenumber"])
        A_raw = np.asarray(raw_data[key]["mean"])  # true raw spectrum
        if key not in norm_data:
            raise KeyError(f"spectrum {key!r} from {raw_json} is missing in {normalized_json}")
        A_norm = np.asarray(norm_data[key]["normalized_absorbance"])  # normalized raw

        res = {
            "before_correction": _peak_areas(wn, A_raw, A_norm),
            "after_correction": {},
            "indices": {kind: idx for kind in index_kinds},
        }

        if key in corrected:
            # Corrected spectra are integrated on the raw wavenumber grid;
            # assumes they are sampled on that same grid — TODO confirm.
            A_corr = np.asarray(corrected[key]["corrected_absorbance"])
            A_corr_norm = np.asarray(corrected[key]["normalized_corrected_absorbance"])
            res["after_correction"] = _peak_areas(wn, A_corr, A_corr_norm)

        results[key] = res
        for kind in index_kinds:
            indices[kind].append(idx)

    output_path = Path(output_json)  # accept plain str paths too
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({"parameters": results, "indices": indices}, f, indent=2)




