import matplotlib.pyplot as plt
import numpy as np

from binny import NZTomography

def nested_dict_to_matrix(nested_dict):
    """Convert a square nested mapping into sorted keys and a 2-D float array.

    Parameters
    ----------
    nested_dict : dict
        Mapping ``row_key -> {col_key: value}``. The outer keys are used for
        both the rows and the columns, so every inner mapping must contain
        every outer key.

    Returns
    -------
    keys : list
        The outer keys in sorted order; this is the row and column order of
        ``matrix``.
    matrix : numpy.ndarray
        Float array of shape ``(len(keys), len(keys))`` with
        ``matrix[i, j] == nested_dict[keys[i]][keys[j]]``.

    Raises
    ------
    KeyError
        If an inner mapping lacks one of the outer keys.
    """
    keys = sorted(nested_dict)
    n_keys = len(keys)
    matrix = np.array(
        [[nested_dict[row_key][col_key] for col_key in keys] for row_key in keys],
        dtype=float,
    )
    # np.array([]) is 1-D with shape (0,); reshape so an empty input still
    # yields a 2-D (0, 0) matrix that callers can unpack as (rows, cols).
    # For non-empty input this is a no-op.
    return keys, matrix.reshape(n_keys, n_keys)

# Redshift grid: 500 evenly spaced samples from 0 to 2, endpoints included.
z = np.linspace(0.0, 2.0, num=500)

# Parent distribution from binny's "smail" model (z0=0.2, alpha=2, beta=1),
# evaluated on the grid above with normalize=True.
nz = NZTomography.nz_model(
    "smail", z, z0=0.2, alpha=2.0, beta=1.0, normalize=True
)

# Photo-z tomography spec: 4 equidistant bins over the range (0.2, 1.2).
# The "uncertainties" entries configure a core scatter/offset plus an
# outlier component (fraction, scatter, offset); normalize_bins=True is
# passed through to binny.
photoz_spec = dict(
    kind="photoz",
    bins=dict(scheme="equidistant", n_bins=4, range=(0.2, 1.2)),
    uncertainties=dict(
        scatter_scale=0.05,
        mean_offset=0.01,
        outlier_frac=0.03,
        outlier_scatter_scale=0.20,
        outlier_mean_offset=0.05,
    ),
    normalize_bins=True,
)

# Build the bins; metadata is requested so the bin edges can be read back
# from result.tomo_meta afterwards.
tomo = NZTomography()
result = tomo.build_bins(
    z=z, nz=nz, tomo_spec=photoz_spec, include_tomo_metadata=True
)

# Bin edges recorded in the tomography metadata by build_bins.
bin_edges = result.tomo_meta["bins"]["bin_edges"]

# Pairwise cross-bin diagnostics, each requested with decimal_places=3;
# overlap and leakage are requested in percent, and leakage needs the edges.
stats = tomo.cross_bin_stats(
    overlap=dict(method="min", unit="percent", normalize=True, decimal_places=3),
    leakage=dict(bin_edges=bin_edges, unit="percent", decimal_places=3),
    pearson=dict(normalize=True, decimal_places=3),
)

# Turn each nested {row: {col: value}} result into (sorted keys, 2-D array).
(
    (overlap_keys, overlap_matrix),
    (leakage_keys, leakage_matrix),
    (pearson_keys, pearson_matrix),
) = [nested_dict_to_matrix(stats[name]) for name in ("overlap", "leakage", "pearson")]

fig, axes = plt.subplots(1, 3, figsize=(14.5, 4.6))

# Panels to draw: (bin keys, matrix, title, x-axis label, y-axis label).
# Row index i of each matrix runs along y and column index j along x.
matrices = [
    (
        overlap_keys,
        overlap_matrix,
        "Overlap matrix",
        "Tomographic bin",
        "Tomographic bin",
    ),
    (
        leakage_keys,
        leakage_matrix,
        "Leakage matrix",
        "Nominal interval",
        "Input bin",
    ),
    (
        pearson_keys,
        pearson_matrix,
        "Pearson matrix",
        "Tomographic bin",
        "Tomographic bin",
    ),
]

# Format spec for the number printed in each cell, keyed by panel title
# (panels not listed use ".1f"). Overlap and leakage are requested in
# percent, where one decimal is plenty. Pearson correlation coefficients are
# bounded by 1 in magnitude, so one decimal would collapse the values
# (computed with decimal_places=3) into a handful of indistinguishable labels.
cell_formats = {"Pearson matrix": ".2f"}

for ax, (keys, matrix, title, xlabel, ylabel) in zip(axes, matrices, strict=True):
    n_rows, n_cols = matrix.shape
    value_fmt = cell_formats.get(title, ".1f")

    # Semi-transparent heat map; origin="lower" puts row 0 at the bottom.
    ax.imshow(
        matrix,
        origin="lower",
        aspect="equal",
        cmap="viridis",
        alpha=0.6,
        interpolation="none",
    )

    ax.set_title(title)
    # Major ticks sit at cell centres and are labelled key + 1
    # (assumes integer, 0-based bin keys).
    tick_labels = [f"{key + 1}" for key in keys]
    ax.set_xticks(np.arange(n_cols))
    ax.set_yticks(np.arange(n_rows))
    ax.set_xticklabels(tick_labels)
    ax.set_yticklabels(tick_labels)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Minor ticks on the cell boundaries carry the black grid that outlines
    # each cell; the minor tick marks themselves are hidden.
    ax.set_xticks(np.arange(-0.5, n_cols, 1), minor=True)
    ax.set_yticks(np.arange(-0.5, n_rows, 1), minor=True)
    ax.grid(which="minor", color="k", linestyle="-", linewidth=2)
    ax.tick_params(which="minor", bottom=False, left=False)

    # Annotate every cell with its value; imshow draws element (i, j) at
    # x=j, y=i, so the text is placed at (j, i).
    for (i, j), value in np.ndenumerate(matrix):
        ax.text(
            j,
            i,
            f"{value:{value_fmt}}",
            ha="center",
            va="center",
            fontsize=15,
            color="k",
        )

plt.tight_layout()