"""
Extract, align, average, and normalize ATR-FTIR spectra.

Expected CSV format:
- Column 1: wavenumber (cm⁻¹)
- Column 2: absorbance

Filename pattern:
<sample>_<ageing>_<replicate>.csv
(<replicate> is an integer; <sample> and <ageing> must not contain underscores)
"""

from __future__ import annotations

import argparse
import json
import re
from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd


# -----------------------------------------------------------------------------
# Constants
# -----------------------------------------------------------------------------

# Matches "<sample>_<ageing>_<replicate>.csv". The sample and ageing groups are
# runs of non-underscore characters; the replicate group is one or more digits.
# The pattern has no end anchor, so whether trailing text after ".csv" is
# rejected depends on the matching method the caller uses.
FILENAME_PATTERN = re.compile(
    r"(?P<sample_id>[^_]+)_(?P<ageing>[^_]+)_(?P<replicate>\d+)\.csv"
)

# Common wavenumber axis shared by every spectrum: descending from 4000.0 to
# 400.0 in steps of 0.5 (the 399.5 stop value is exclusive).
STANDARD_WAVENUMBERS = np.arange(4000.0, 399.5, -0.5)


# -----------------------------------------------------------------------------
# I/O utilities
# -----------------------------------------------------------------------------

def read_spectrum(csv_file: Path) -> pd.DataFrame:
    """Load a two-column spectrum CSV as a numeric DataFrame.

    Only the first two columns are read and they are named ``wavenumber``
    and ``absorbance``. Lines starting with ``#`` are treated as comments.
    Any row in which either value cannot be parsed as a number (for example
    a textual header row) is discarded; the surviving rows keep their
    original row index.
    """
    raw = pd.read_csv(csv_file, header=None, usecols=[0, 1], comment="#")
    raw.columns = ["wavenumber", "absorbance"]

    # Unparseable cells become NaN so that the offending rows can be dropped.
    numeric = raw.apply(lambda column: pd.to_numeric(column, errors="coerce"))
    return numeric.dropna()



def interpolate_to_standard_axis(
    df: pd.DataFrame,
    target_axis: np.ndarray,
) -> np.ndarray:
    """Interpolate absorbance onto a standard wavenumber axis.

    Parameters
    ----------
    df : pd.DataFrame
        Spectrum with ``wavenumber`` and ``absorbance`` columns. Rows may be
        stored in descending, ascending, or arbitrary wavenumber order.
    target_axis : np.ndarray
        Wavenumbers at which the absorbance is evaluated (any order).

    Returns
    -------
    np.ndarray
        Linearly interpolated absorbance, one value per entry of
        ``target_axis``. Target points outside the measured range take the
        absorbance of the nearest measured end point (``np.interp`` default).

    Raises
    ------
    ValueError
        If the spectrum contains no data points.
    """
    # Reverse first so that the common descending-file case is already
    # ascending and passes through the stable sort below unchanged.
    wavenumber = df["wavenumber"].to_numpy(dtype=float)[::-1]
    absorbance = df["absorbance"].to_numpy(dtype=float)[::-1]

    if wavenumber.size == 0:
        raise ValueError("Cannot interpolate an empty spectrum")

    # np.interp requires increasing sample points and does not check for it;
    # sorting makes ascending or unsorted files safe instead of silently wrong.
    order = np.argsort(wavenumber, kind="stable")
    return np.interp(target_axis, wavenumber[order], absorbance[order])


# -----------------------------------------------------------------------------
# Processing
# -----------------------------------------------------------------------------

def min_max_normalize(
    wn: np.ndarray,
    absorbance: np.ndarray,
    max_range: tuple[float, float] = (2500, 3500),
    min_range: tuple[float, float] = (1800, 1900),
) -> np.ndarray:
    """
    Min–max normalize a spectrum using custom wavenumber ranges.

    The reference maximum is the largest absorbance inside ``max_range`` and
    the reference minimum is the smallest absorbance inside ``min_range``.
    The whole spectrum is then rescaled so that the reference minimum maps
    to 0 and the reference maximum maps to 1.

    Parameters
    ----------
    wn : array_like
        Wavenumber axis.
    absorbance : array_like
        Absorbance values, one per entry of ``wn``.
    max_range : tuple
        Inclusive wavenumber bounds to search for the maximum. The two
        bounds may be given in either order.
    min_range : tuple
        Inclusive wavenumber bounds to search for the minimum. The two
        bounds may be given in either order.

    Returns
    -------
    normalized : np.ndarray
        Rescaled spectrum. Values are not clipped, so points outside the
        reference ranges may lie below 0 or above 1. An all-zero array is
        returned when the reference maximum and minimum are (nearly) equal.

    Raises
    ------
    ValueError
        If no wavenumber falls inside ``max_range`` or ``min_range``.
    """
    wn = np.asarray(wn)
    absorbance = np.asarray(absorbance, float)

    def _window(bounds):
        # Accept (low, high) as well as (high, low); the latter is how FTIR
        # ranges are commonly written and used to yield an empty mask.
        low = min(bounds[0], bounds[1])
        high = max(bounds[0], bounds[1])
        return (wn >= low) & (wn <= high)

    # Reference maximum from the requested window.
    mask_max = _window(max_range)
    if not np.any(mask_max):
        raise ValueError("No wavenumbers found in max_range")
    max_val = absorbance[mask_max].max()

    # Reference minimum from the requested window.
    mask_min = _window(min_range)
    if not np.any(mask_min):
        raise ValueError("No wavenumbers found in min_range")
    min_val = absorbance[mask_min].min()

    span = max_val - min_val
    if np.isclose(span, 0.0):
        # A flat reference cannot be scaled; avoid dividing by ~0.
        return np.zeros_like(absorbance)

    return (absorbance - min_val) / span


# -----------------------------------------------------------------------------
# Main extraction logic
# -----------------------------------------------------------------------------

def extract_folder(folder: Path, output_dir: Path) -> None:
    """
    Average and normalize every replicate spectrum found in ``folder``.

    CSV files whose names match ``FILENAME_PATTERN`` are read, interpolated
    onto ``STANDARD_WAVENUMBERS`` and grouped by ``"<sample_id>_<ageing>"``;
    other CSV files are skipped. For each group the per-point mean and
    standard deviation across replicates are computed, and the mean spectrum
    is min–max normalized with the default reference ranges.

    Two JSON files are written to ``output_dir`` (created if missing):
    ``extracted_spectra.json`` with the mean/std data and
    ``normalized_spectra.json`` with the normalized mean spectra.
    """
    # Pass 1: collect the interpolated replicates of each sample/ageing pair.
    replicates: Dict[str, List[np.ndarray]] = {}
    for path in sorted(folder.glob("*.csv")):
        parsed = FILENAME_PATTERN.match(path.name)
        if parsed is None:
            continue

        group = f"{parsed.group('sample_id')}_{parsed.group('ageing')}"
        spectrum = interpolate_to_standard_axis(
            read_spectrum(path), STANDARD_WAVENUMBERS
        )

        if group not in replicates:
            replicates[group] = []
        replicates[group].append(spectrum)

    output_dir.mkdir(parents=True, exist_ok=True)

    # Pass 2: reduce each group to summary statistics and a normalized curve.
    extracted_all = {}
    normalized_all = {}
    for group, curves in replicates.items():
        matrix = np.vstack(curves)
        average = matrix.mean(axis=0)
        spread = matrix.std(axis=0)

        extracted_all[group] = {
            "sample": group,
            "n_replicates": matrix.shape[0],
            "wavenumber": STANDARD_WAVENUMBERS.tolist(),
            "mean": average.tolist(),
            "std": spread.tolist(),
        }
        normalized_all[group] = {
            "wavenumber": STANDARD_WAVENUMBERS.tolist(),
            "normalized_absorbance": min_max_normalize(
                STANDARD_WAVENUMBERS, average
            ).tolist(),
        }

    # Write both result files, extracted data first.
    outputs = (
        ("extracted_spectra.json", extracted_all),
        ("normalized_spectra.json", normalized_all),
    )
    for filename, payload in outputs:
        with open(output_dir / filename, "w") as f:
            json.dump(payload, f, indent=2)


