"""
Plot overview and difference figures for raw, normalized and ATR-corrected ATR-FTIR spectra.
"""

from pathlib import Path
import json

import numpy as np
import matplotlib.pyplot as plt

"""
Overview plots for ATR-FTIR spectra with a broken x-axis at 2000 cm⁻¹.

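`plot_spectra` creates two figures:
1. Raw vs raw ATR-corrected spectra
2. Normalized vs normalized ATR-corrected spectra

`plot_difference_spectra` creates the matching difference-spectrum figures
(each sample minus a chosen baseline spectrum).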
"""

from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt


def _load_overview_json(path: Path) -> dict:
    """Read one spectra JSON file and return the decoded mapping."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _plot_overview_split(ax_left, ax_right, key, wn, series):
    """Plot each ``(absorbance, color, alpha)`` series on both halves of the broken axis.

    The left panel receives points with 2000 <= wn <= 4000 and the right panel
    points with 400 <= wn <= 2000 (2000 cm⁻¹ itself appears on both panels).

    Raises
    ------
    ValueError
        If an absorbance array does not match the shape of ``wn``.
    """
    mask_left = (wn >= 2000) & (wn <= 4000)
    mask_right = (wn >= 400) & (wn <= 2000)
    for absorbance, color, alpha in series:
        # Boolean-mask indexing needs identical shapes; fail with a readable message
        # instead of NumPy's IndexError when the grids disagree.
        if absorbance.shape != wn.shape:
            raise ValueError(
                f"Spectrum {key!r}: absorbance has shape {absorbance.shape} "
                f"but wavenumber has shape {wn.shape}"
            )
        ax_left.plot(wn[mask_left], absorbance[mask_left], color=color, alpha=alpha)
        ax_right.plot(wn[mask_right], absorbance[mask_right], color=color, alpha=alpha)


def _finish_overview_axes(ax_left, ax_right, ylim_left, ylim_right, title, legend):
    """Apply legend, limits, labels and title to a broken-axis panel pair.

    ``legend`` is a sequence of ``(label, color)`` pairs, one per series type.
    """
    # One empty proxy line per series type, so the legend has a single entry per
    # colour instead of one entry per plotted spectrum.
    for label, color in legend:
        ax_left.plot([], [], color=color, label=label)
    ax_left.legend()
    # Descending limits put high wavenumbers on the left (FTIR convention);
    # no separate invert_xaxis() call is needed.
    ax_left.set_xlim(4000, 2000)
    ax_right.set_xlim(2000, 400)
    ax_left.set_ylim(*ylim_left)
    ax_right.set_ylim(*ylim_right)
    ax_left.set_xlabel("Wavenumber (cm⁻¹)")
    ax_right.set_xlabel("Wavenumber (cm⁻¹)")
    ax_left.set_ylabel("Absorbance (a.u.)")
    ax_left.set_title(title)


def plot_spectra(raw_json: Path, normalized_json: Path, corrected_json: Path, output_dir: Path):
    """
    Plot raw vs raw ATR-corrected and normalized vs normalized ATR-corrected spectra.

    Each figure always uses a broken x-axis: the left panel shows 4000–2000 cm⁻¹
    and the right panel 2000–400 cm⁻¹, both with high wavenumbers on the left.
    The figures are saved as ``raw_vs_corrected.png`` and
    ``normalized_vs_corrected.png`` in ``output_dir`` (created if needed) and then
    displayed with ``plt.show()``.

    Parameters
    ----------
    raw_json : Path
        JSON mapping sample name -> {"wavenumber": [...], "mean": [...]}.
    normalized_json : Path
        JSON mapping sample name -> {"wavenumber": [...], "normalized_absorbance": [...]}.
    corrected_json : Path
        JSON mapping sample name -> {"corrected_absorbance": [...],
        "normalized_corrected_absorbance": [...]}. Corrected values are plotted on
        the wavenumber grid of the raw / normalized file, so the grids must match.
    output_dir : Path
        Directory that receives the PNG files.

    Raises
    ------
    KeyError
        If a sample in the raw or normalized file is missing from the corrected file.
    ValueError
        If an absorbance array's length differs from its wavenumber array.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    raw_spectra = _load_overview_json(raw_json)
    norm_spectra = _load_overview_json(normalized_json)
    corrected_spectra = _load_overview_json(corrected_json)

    # --- Raw vs Raw Corrected ---
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=False, figsize=(14, 6))
    for key, spectrum in raw_spectra.items():
        if key not in corrected_spectra:
            raise KeyError(f"Spectrum {key!r} from {raw_json} is missing in {corrected_json}")
        wn = np.array(spectrum["wavenumber"])
        _plot_overview_split(ax1, ax2, key, wn, [
            (np.array(spectrum["mean"]), "gray", 0.5),
            (np.array(corrected_spectra[key]["corrected_absorbance"]), "blue", 0.7),
        ])
    _finish_overview_axes(
        ax1, ax2,
        ylim_left=(-0.025, 0.525),
        ylim_right=(-0.0125, 0.25125),
        title="Raw vs ATR-corrected (raw)",
        legend=[("Raw", "gray"), ("ATR-corrected", "blue")],
    )
    fig.tight_layout()
    fig.savefig(output_dir / "raw_vs_corrected.png", dpi=300)

    # --- Normalized vs Normalized Corrected ---
    fig, (ax1, ax2) = plt.subplots(1, 2, sharey=False, figsize=(14, 6))
    for key, spectrum in norm_spectra.items():
        if key not in corrected_spectra:
            raise KeyError(f"Spectrum {key!r} from {normalized_json} is missing in {corrected_json}")
        wn = np.array(spectrum["wavenumber"])
        _plot_overview_split(ax1, ax2, key, wn, [
            (np.array(spectrum["normalized_absorbance"]), "gray", 0.5),
            (np.array(corrected_spectra[key]["normalized_corrected_absorbance"]), "green", 0.7),
        ])
    _finish_overview_axes(
        ax1, ax2,
        ylim_left=(-0.05, 1.05),
        ylim_right=(-0.025, 0.525),
        title="Normalized vs ATR-corrected (normalized)",
        legend=[("Normalized", "gray"), ("ATR-corrected (normalized)", "green")],
    )
    fig.tight_layout()
    fig.savefig(output_dir / "normalized_vs_corrected.png", dpi=300)
    plt.show()



# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------

if __name__ == "__main__":
    import argparse

    # NOTE: this block sits above compute_diff_* and plot_difference_spectra in
    # the file, so only plot_spectra is defined when it runs.
    parser = argparse.ArgumentParser(description="Plot ATR-FTIR spectra overview")
    parser.add_argument(
        "--raw",
        type=Path,
        required=True,
        help="Path to raw (mean) spectra JSON",
    )
    parser.add_argument(
        "--normalized",
        type=Path,
        default=Path("data/extracted/normalized_spectra.json"),
        help="Path to normalized spectra JSON",
    )
    parser.add_argument(
        "--corrected",
        type=Path,
        default=Path("data/corrected/corrected_spectra.json"),
        help="Path to ATR-corrected spectra JSON",
    )
    parser.add_argument(
        "--output",
        type=Path,
        # plot_spectra treats this as a directory (mkdir + fixed PNG filenames),
        # so the default must be a directory, not a .png file path.
        default=Path("data/corrected"),
        help="Output directory for the overview figures",
    )

    args = parser.parse_args()

    # plot_spectra takes four paths: raw, normalized, corrected, output directory.
    plot_spectra(args.raw, args.normalized, args.corrected, args.output)


def compute_diff_raw(A_dict, baseline_name):
    """Return raw-mean difference spectra relative to a baseline sample.

    Parameters
    ----------
    A_dict : dict
        Mapping sample name -> record with ``"wavenumber"`` and ``"mean"`` sequences.
    baseline_name : str
        Key of the spectrum subtracted from every other spectrum.

    Returns
    -------
    dict
        ``{name: {"wavenumber": ndarray, "difference": ndarray}}`` for every sample
        except the baseline, with ``difference = mean - baseline mean``.

    Raises
    ------
    KeyError
        If ``baseline_name`` is not a key of ``A_dict``.
    ValueError
        If a spectrum's shape differs from the baseline's.
    """
    if baseline_name not in A_dict:
        raise KeyError(f"baseline spectrum {baseline_name!r} not found")
    # The baseline is loop-invariant: convert it once rather than per sample.
    A_base = np.array(A_dict[baseline_name]["mean"])
    diffs = {}
    for key, data in A_dict.items():
        if key == baseline_name:
            continue
        A = np.array(data["mean"])              # raw mean
        # Require identical shapes: NumPy would otherwise silently broadcast a
        # length-1 spectrum against the whole baseline.
        if A.shape != A_base.shape:
            raise ValueError(
                f"spectrum {key!r} has shape {A.shape}, baseline has {A_base.shape}"
            )
        diffs[key] = {"wavenumber": np.array(data["wavenumber"]), "difference": A - A_base}
    return diffs

def compute_diff_corr(A_dict, baseline_name):
    """Return ATR-corrected (raw) difference spectra relative to a baseline sample.

    Parameters
    ----------
    A_dict : dict
        Mapping sample name -> record with ``"wavenumber"`` and
        ``"corrected_absorbance"`` sequences.
    baseline_name : str
        Key of the spectrum subtracted from every other spectrum.

    Returns
    -------
    dict
        ``{name: {"wavenumber": ndarray, "difference": ndarray}}`` for every sample
        except the baseline, with
        ``difference = corrected_absorbance - baseline corrected_absorbance``.

    Raises
    ------
    KeyError
        If ``baseline_name`` is not a key of ``A_dict``.
    ValueError
        If a spectrum's shape differs from the baseline's.
    """
    if baseline_name not in A_dict:
        raise KeyError(f"baseline spectrum {baseline_name!r} not found")
    # The baseline is loop-invariant: convert it once rather than per sample.
    A_base = np.array(A_dict[baseline_name]["corrected_absorbance"])
    diffs = {}
    for key, data in A_dict.items():
        if key == baseline_name:
            continue
        A = np.array(data["corrected_absorbance"])               # raw ATR-corrected
        # Require identical shapes: NumPy would otherwise silently broadcast a
        # length-1 spectrum against the whole baseline.
        if A.shape != A_base.shape:
            raise ValueError(
                f"spectrum {key!r} has shape {A.shape}, baseline has {A_base.shape}"
            )
        diffs[key] = {"wavenumber": np.array(data["wavenumber"]), "difference": A - A_base}
    return diffs

def compute_diff_norm(A_dict, baseline_name):
    """Return normalized (uncorrected) difference spectra relative to a baseline sample.

    Parameters
    ----------
    A_dict : dict
        Mapping sample name -> record with ``"wavenumber"`` and
        ``"normalized_absorbance"`` sequences.
    baseline_name : str
        Key of the spectrum subtracted from every other spectrum.

    Returns
    -------
    dict
        ``{name: {"wavenumber": ndarray, "difference": ndarray}}`` for every sample
        except the baseline, with
        ``difference = normalized_absorbance - baseline normalized_absorbance``.

    Raises
    ------
    KeyError
        If ``baseline_name`` is not a key of ``A_dict``.
    ValueError
        If a spectrum's shape differs from the baseline's.
    """
    if baseline_name not in A_dict:
        raise KeyError(f"baseline spectrum {baseline_name!r} not found")
    # The baseline is loop-invariant: convert it once rather than per sample.
    A_base = np.array(A_dict[baseline_name]["normalized_absorbance"])
    diffs = {}
    for key, data in A_dict.items():
        if key == baseline_name:
            continue
        A = np.array(data["normalized_absorbance"])              # normalized raw
        # Require identical shapes: NumPy would otherwise silently broadcast a
        # length-1 spectrum against the whole baseline.
        if A.shape != A_base.shape:
            raise ValueError(
                f"spectrum {key!r} has shape {A.shape}, baseline has {A_base.shape}"
            )
        diffs[key] = {"wavenumber": np.array(data["wavenumber"]), "difference": A - A_base}
    return diffs

def compute_diff_norm_corr(A_dict, baseline_name):
    """Return normalized ATR-corrected difference spectra relative to a baseline sample.

    Parameters
    ----------
    A_dict : dict
        Mapping sample name -> record with ``"wavenumber"`` and
        ``"normalized_corrected_absorbance"`` sequences.
    baseline_name : str
        Key of the spectrum subtracted from every other spectrum.

    Returns
    -------
    dict
        ``{name: {"wavenumber": ndarray, "difference": ndarray}}`` for every sample
        except the baseline, with ``difference = normalized_corrected_absorbance
        - baseline normalized_corrected_absorbance``.

    Raises
    ------
    KeyError
        If ``baseline_name`` is not a key of ``A_dict``.
    ValueError
        If a spectrum's shape differs from the baseline's.
    """
    if baseline_name not in A_dict:
        raise KeyError(f"baseline spectrum {baseline_name!r} not found")
    # The baseline is loop-invariant: convert it once rather than per sample.
    A_base = np.array(A_dict[baseline_name]["normalized_corrected_absorbance"])
    diffs = {}
    for key, data in A_dict.items():
        if key == baseline_name:
            continue
        A = np.array(data["normalized_corrected_absorbance"])   # normalized corrected
        # Require identical shapes: NumPy would otherwise silently broadcast a
        # length-1 spectrum against the whole baseline.
        if A.shape != A_base.shape:
            raise ValueError(
                f"spectrum {key!r} has shape {A.shape}, baseline has {A_base.shape}"
            )
        diffs[key] = {"wavenumber": np.array(data["wavenumber"]), "difference": A - A_base}
    return diffs

def _read_difference_json(path: Path) -> dict:
    """Read one spectra JSON file and return the decoded mapping."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def _prompt_baseline_name(sample_names: list) -> str:
    """Print the available spectra and ask until a valid baseline index is entered.

    ``sample_names`` must be non-empty, otherwise no answer could ever be valid.
    """
    print("\nAvailable spectra:")
    for i, name in enumerate(sample_names):
        print(f"{i}: {name}")
    while True:
        reply = input("\nEnter the index of the baseline spectrum: ")
        try:
            idx = int(reply)
        except ValueError:
            print(f"'{reply}' is not an integer index.")
            continue
        # Reject negative indices explicitly: list indexing would otherwise wrap
        # around and silently select a spectrum from the end of the list.
        if 0 <= idx < len(sample_names):
            return sample_names[idx]
        print(f"Index must be between 0 and {len(sample_names) - 1}.")


def _plot_difference_split(ax_left, ax_right, diffs: dict, color: str, alpha: float):
    """Plot every difference spectrum in ``diffs`` across the two broken-axis panels.

    Points with wavenumber >= 2000 go to the left panel, points below 2000 to the right.
    """
    for diff in diffs.values():
        wn = diff["wavenumber"]
        delta = diff["difference"]
        high = wn >= 2000
        low = wn < 2000
        ax_left.plot(wn[high], delta[high], color=color, alpha=alpha)
        ax_right.plot(wn[low], delta[low], color=color, alpha=alpha)


def _finish_difference_axes(ax_left, ax_right, ylim, title: str, legend):
    """Add a legend, invert the x-axes and apply shared y-limits, labels and title.

    ``legend`` is a sequence of ``(label, color)`` pairs, one per series type.
    """
    # One empty proxy line per series type, so the legend has a single entry per
    # colour instead of one entry per plotted spectrum.
    for label, color in legend:
        ax_left.plot([], [], color=color, label=label)
    ax_left.legend()
    # High wavenumbers on the left, as is conventional for FTIR spectra.
    ax_left.invert_xaxis()
    ax_right.invert_xaxis()
    ax_left.set_ylim(*ylim)
    ax_right.set_ylim(*ylim)
    ax_left.set_xlabel("Wavenumber (cm⁻¹)")
    ax_right.set_xlabel("Wavenumber (cm⁻¹)")
    ax_left.set_ylabel("ΔAbsorbance (a.u.)")
    ax_left.set_title(title)


def plot_difference_spectra(raw_json: Path, normalized_json: Path, corrected_json: Path,
                            output_dir: Path, baseline_name: "str | None" = None):
    """
    Plot difference spectra (each sample minus a baseline spectrum).

    Shows and saves two broken-axis figures in ``output_dir`` (created if needed):
    1. ``raw_difference_spectra.png`` – raw (gray) and raw ATR-corrected (blue)
       differences.
    2. ``normalized_difference_spectra.png`` – normalized (gray) and normalized
       ATR-corrected (green) differences.

    Parameters
    ----------
    raw_json, normalized_json, corrected_json : Path
        Spectra JSON files read by ``compute_diff_raw``, ``compute_diff_norm`` and
        ``compute_diff_corr`` / ``compute_diff_norm_corr`` respectively.
    output_dir : Path
        Directory that receives the PNG files.
    baseline_name : str, optional
        Sample to subtract. When omitted, the user is prompted interactively to
        pick one of the samples in ``raw_json``.

    Raises
    ------
    ValueError
        If ``raw_json`` contains no spectra.
    KeyError
        If the baseline (given or chosen) is missing from one of the inputs.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    raw_spectra = _read_difference_json(raw_json)
    norm_spectra = _read_difference_json(normalized_json)
    corrected_spectra = _read_difference_json(corrected_json)

    # --- Choose the baseline spectrum ---
    sample_names = list(raw_spectra.keys())
    if not sample_names:
        raise ValueError(f"No spectra found in {raw_json}")
    if baseline_name is None:
        baseline_name = _prompt_baseline_name(sample_names)
    elif baseline_name not in raw_spectra:
        raise KeyError(f"Baseline {baseline_name!r} not found in {raw_json}")
    print(f"Selected baseline: {baseline_name}")

    raw_diff = compute_diff_raw(raw_spectra, baseline_name)
    raw_corr_diff = compute_diff_corr(corrected_spectra, baseline_name)
    norm_diff = compute_diff_norm(norm_spectra, baseline_name)
    norm_corr_diff = compute_diff_norm_corr(corrected_spectra, baseline_name)

    # --- Raw and raw ATR-corrected differences ---
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), sharey=False)
    _plot_difference_split(ax1, ax2, raw_diff, "gray", 0.5)
    _plot_difference_split(ax1, ax2, raw_corr_diff, "blue", 0.7)
    _finish_difference_axes(
        ax1, ax2,
        ylim=(-0.005, 0.035),
        title=f"Raw difference spectra (sample - {baseline_name})",
        legend=[("Raw", "gray"), ("ATR-corrected", "blue")],
    )
    fig.tight_layout()
    fig.savefig(output_dir / "raw_difference_spectra.png", dpi=300)

    # --- Normalized and normalized ATR-corrected differences ---
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), sharey=False)
    _plot_difference_split(ax1, ax2, norm_diff, "gray", 0.5)
    _plot_difference_split(ax1, ax2, norm_corr_diff, "green", 0.7)
    _finish_difference_axes(
        ax1, ax2,
        ylim=(-0.02, 0.090),
        title=f"Normalized difference spectra (sample - {baseline_name})",
        legend=[("Normalized", "gray"), ("ATR-corrected (normalized)", "green")],
    )
    fig.tight_layout()
    fig.savefig(output_dir / "normalized_difference_spectra.png", dpi=300)
    plt.show()
