"""
Main processing pipeline for ATR-FTIR analysis.

Pipeline steps:
1. Extract and normalize spectra
2. Apply ATR + baseline correction on normalized spectra
3. Extract spectral parameters
4. Visualise ATR correction (CH rocking vs stretching) and plot the fitted ATR parameters
5. Plot spectra
6. Plot difference spectra
7. Plot extracted parameters
"""

import argparse
from pathlib import Path

from extract_spectra import extract_folder
from ATR_baseline_correction import apply_atr_correction_to_folder
from extract_parameters import extract_parameters
from visualise_ATR_correction import (plot_rocking_vs_stretching,
                                      plot_fitted_atr_parameters)
from plot_spectra import (plot_spectra,
                          plot_difference_spectra)
from plot_parameters import plot_extracted_parameters


# -----------------------------------------------------------------------------
# Pipeline
# -----------------------------------------------------------------------------

def run_pipeline(
    raw_dir: Path,
    extracted_dir: Path,
    corrected_dir: Path,
    spectra_plot_dir: Path,
    atr_ratio: float = 0.016,
) -> None:
    """Run every step of the ATR-FTIR processing pipeline in order.

    Every file read or written by the pipeline is located under the directory
    arguments, so custom locations (e.g. the CLI ``--extracted``,
    ``--corrected`` and ``--plots`` options) are honoured by all steps.

    Args:
        raw_dir: Folder containing the raw input spectra (passed to step 1).
        extracted_dir: Output folder of step 1. Later steps read
            ``extracted_spectra.json`` and ``normalized_spectra.json`` from
            here, and step 3 writes ``spectral_parameters.json`` here.
        corrected_dir: Output folder of step 2. Later steps read
            ``corrected_spectra.json`` from here; difference plots are written
            to its ``difference_plots`` sub-folder.
        spectra_plot_dir: Folder for overview plots, including
            ``rocking_vs_stretching.png`` and the ``parameter_bars``
            sub-folder.
        atr_ratio: Value forwarded as ``ratio`` to
            ``plot_rocking_vs_stretching`` (default 0.016, the value that was
            previously hard-coded here).
    """
    print("\n=== ATR-FTIR PIPELINE STARTED ===\n")

    # Ensure output directories exist before any step writes into them.
    for directory in (extracted_dir, corrected_dir, spectra_plot_dir):
        directory.mkdir(parents=True, exist_ok=True)

    # File locations shared by several steps. They are derived from the
    # directory arguments (never hard-coded) so every step sees the same files.
    raw_json = extracted_dir / "extracted_spectra.json"
    normalized_json = extracted_dir / "normalized_spectra.json"
    corrected_json = corrected_dir / "corrected_spectra.json"
    parameters_json = extracted_dir / "spectral_parameters.json"

    # Step 1 — Extraction & normalization
    print("Step 1: Extracting and normalizing spectra")
    extract_folder(folder=raw_dir, output_dir=extracted_dir)

    # Step 2 — ATR + baseline correction of the step-1 output.
    # NOTE(review): the module docstring says correction is applied to
    # normalized spectra, while the log message says "raw spectra and
    # normalise" — confirm which matches apply_atr_correction_to_folder.
    print("Step 2: ATR + baseline correction on raw spectra and normalise")
    apply_atr_correction_to_folder(input_dir=extracted_dir, output_dir=corrected_dir)

    # Step 3 — Extract spectral parameters from raw, normalized and corrected data
    print("Step 3: Extracting spectral parameters")
    extract_parameters(
        raw_json=raw_json,
        normalized_json=normalized_json,
        corrected_json=corrected_json,
        output_json=parameters_json,
    )

    # Step 4a — Rocking vs stretching plot; also returns the fitted ATR
    # parameters consumed by step 4b.
    print("Step 4a: Plotting CH rocking vs stretching area")
    fitted_params = plot_rocking_vs_stretching(
        raw_json=raw_json,
        corrected_json=corrected_json,
        output_file=spectra_plot_dir / "rocking_vs_stretching.png",
        ratio=atr_ratio,
    )

    # Step 4b — Plot fitted ATR parameters
    print("Step 4b: Plotting fitted ATR parameters")
    plot_fitted_atr_parameters(fitted_params)

    # Step 5 — Plot spectra overview
    print("Step 5: Plotting spectra overview")
    plot_spectra(
        raw_json=raw_json,
        normalized_json=normalized_json,
        corrected_json=corrected_json,
        output_dir=spectra_plot_dir,
    )

    # Step 6 — Plot difference spectra overview
    print("Step 6: Plotting difference spectra overview")
    plot_difference_spectra(
        raw_json=raw_json,
        normalized_json=normalized_json,
        corrected_json=corrected_json,
        output_dir=corrected_dir / "difference_plots",
    )

    # Step 7 — Plot extracted parameters (reads the step-3 output)
    print("Step 7: Plotting extracted parameters")
    plot_extracted_parameters(
        parameters_json=parameters_json,
        output_dir=spectra_plot_dir / "parameter_bars",
    )

    print("\n=== PIPELINE COMPLETED SUCCESSFULLY ===\n")


# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------


def main():
    """Parse the directory options from the command line and run the pipeline.

    Every option is a folder path (converted to ``Path``) with a default
    under ``data/``.
    """
    cli = argparse.ArgumentParser(description="Run the ATR-FTIR processing pipeline")

    # (flag, default folder, help text) — registered in this order.
    directory_options = (
        (
            "--raw",
            "data/raw",
            "Folder containing raw CSV spectra (default: data/raw)",
        ),
        (
            "--extracted",
            "data/extracted",
            "Folder for extracted/normalized spectra (default: data/extracted)",
        ),
        (
            "--corrected",
            "data/corrected",
            "Folder for ATR-corrected spectra (default: data/corrected)",
        ),
        (
            "--plots",
            "data/spectra_plots",
            "Folder to save overview plots (default: data/spectra_plots)",
        ),
    )
    for flag, default_folder, help_text in directory_options:
        cli.add_argument(flag, type=Path, default=Path(default_folder), help=help_text)

    opts = cli.parse_args()

    run_pipeline(
        raw_dir=opts.raw,
        extracted_dir=opts.extracted,
        corrected_dir=opts.corrected,
        spectra_plot_dir=opts.plots,
    )


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()
