from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt

from extract_parameters import PEAK_RANGES

# -----------------------------------------------------------------------------
# Parameter Plotting
# -----------------------------------------------------------------------------

from pathlib import Path
import json
import numpy as np
import matplotlib.pyplot as plt

from extract_parameters import PEAK_RANGES

# -----------------------------------------------------------------------------
# Parameter Plotting with dual y-axis
# -----------------------------------------------------------------------------

def _plot_dual_axis_grid(
    params,
    sample_names,
    left_value,
    right_value,
    left_label,
    right_label,
    right_color,
    output_path,
    show,
):
    """Draw one grid of per-peak dual-y-axis bar charts and save it as a PNG.

    One subplot is drawn per key of ``PEAK_RANGES``, laid out in two columns.
    In each subplot the left y-axis carries grey bars for ``left_value`` and a
    twin right y-axis carries ``right_color`` bars for ``right_value``; the two
    series therefore get independent y-scales.

    Args:
        params: Mapping of sample name -> per-sample parameter dict.
        sample_names: Sample names, in x-axis order.
        left_value: Callable ``(sample_params, peak) -> number`` for the left axis.
        right_value: Callable ``(sample_params, peak) -> number`` for the right axis.
        left_label: Legend label for the left series; y-label is ``"<label> (left)"``.
        right_label: Legend label for the right series; y-label is ``"<label> (right)"``.
        right_color: Bar colour for the right series.
        output_path: Destination PNG path (saved at 300 dpi).
        show: If true, call ``plt.show()`` before closing the figure.
    """
    grey = "gray"
    peaks = list(PEAK_RANGES.keys())
    indices = np.arange(len(sample_names))
    n_cols = 2
    n_rows = int(np.ceil(len(peaks) / n_cols))

    # squeeze=False always yields a 2-D ndarray of Axes, so flatten() is safe
    # for any grid shape.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(12, n_rows * 4), squeeze=False
    )
    axes = axes.flatten()

    for peak, ax in zip(peaks, axes):
        left_vals = [left_value(params[s], peak) for s in sample_names]
        right_vals = [right_value(params[s], peak) for s in sample_names]

        # Shift each series by half a bar width so they sit side by side.
        ax.bar(indices - 0.15, left_vals, width=0.3, color=grey, label=left_label)
        ax.set_ylabel(f"{left_label} (left)")
        ax.set_xticks(indices)
        ax.set_xticklabels(sample_names, rotation=45, ha="right")
        ax.set_title(peak)

        # Twin axis: the right series gets its own y-scale.
        ax2 = ax.twinx()
        ax2.bar(
            indices + 0.15, right_vals, width=0.3, color=right_color, label=right_label
        )
        ax2.set_ylabel(f"{right_label} (right)")

    # Remove grid cells left empty when the number of peaks is odd.
    for ax in axes[len(peaks):]:
        fig.delaxes(ax)

    # Build one figure-level legend from proxy patches, since the bars are
    # split across two Axes per subplot.
    handles = [
        plt.Rectangle((0, 0), 1, 1, color=grey),
        plt.Rectangle((0, 0), 1, 1, color=right_color),
    ]
    fig.legend(handles, [left_label, right_label], loc="upper center", ncol=2)
    # Reserve the top 5% of the figure for the legend.
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    fig.savefig(output_path, dpi=300)
    if show:
        plt.show()
    plt.close(fig)


def plot_extracted_parameters(parameters_json: Path, output_dir: Path, *, show: bool = True):
    """
    Create combined bar plots of extracted parameters with dual y-axis.

    Two figures are written to ``output_dir``:
      1. ``raw_vs_corrected_all_peaks_dual_axis.png``:
         Raw vs Raw corrected (``before_correction[peak]`` vs
         ``after_correction[peak]``).
      2. ``normalized_vs_corrected_all_peaks_dual_axis.png``:
         Normalized vs Normalized corrected
         (``before_correction["<peak>_norm"]`` vs
         ``after_correction["<peak>_norm"]``; missing ``_norm`` keys plot as 0).
    Raw = left y-axis (grey), Corrected = right y-axis (color).

    Args:
        parameters_json: Path to a JSON file with a top-level ``"parameters"``
            mapping of sample name -> {"before_correction": {...},
            "after_correction": {...}}.
        output_dir: Directory for the PNGs; created if missing.
        show: If true (default), display each figure with ``plt.show()``
            before closing it. Pass ``False`` for batch/headless runs.

    Raises:
        KeyError: If ``"parameters"`` is missing, or a sample lacks a
            ``before_correction``/``after_correction`` entry or a raw peak key.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    with open(parameters_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    # With no peaks the subplot grid would have 0 rows, which matplotlib
    # rejects; there is nothing to plot.
    if not PEAK_RANGES:
        return

    params = data["parameters"]
    sample_names = list(params.keys())
    # Fall back to matplotlib's default cycle names if the active prop cycle
    # defines no colours; index modulo length in case it defines only one.
    colors = plt.rcParams["axes.prop_cycle"].by_key().get("color") or ["C0", "C1"]
    raw_color = colors[0]
    norm_color = colors[1 % len(colors)]

    # --- Figure 1: Raw vs Raw Corrected ---
    _plot_dual_axis_grid(
        params,
        sample_names,
        left_value=lambda p, peak: p["before_correction"][peak],
        right_value=lambda p, peak: p["after_correction"][peak],
        left_label="Raw",
        right_label="Corrected",
        right_color=raw_color,
        output_path=output_dir / "raw_vs_corrected_all_peaks_dual_axis.png",
        show=show,
    )

    # --- Figure 2: Normalized vs Normalized Corrected ---
    # Both series read the "<peak>_norm" key so the left axis really shows the
    # normalized raw value (previously it re-plotted the un-normalized raw).
    # NOTE(review): assumes before_correction carries "<peak>_norm" keys just
    # like after_correction — confirm against extract_parameters output.
    _plot_dual_axis_grid(
        params,
        sample_names,
        left_value=lambda p, peak: p["before_correction"].get(f"{peak}_norm", 0),
        right_value=lambda p, peak: p["after_correction"].get(f"{peak}_norm", 0),
        left_label="Normalized Raw",
        right_label="Normalized Corrected",
        right_color=norm_color,
        output_path=output_dir / "normalized_vs_corrected_all_peaks_dual_axis.png",
        show=show,
    )

