"""variant_toolkit — the plotting engine for Module 5 (Variant Calling).

Why this file exists
--------------------
Module 5's hands-on core is reading a VCF: the notebook keeps the `VCFRecord`
parser and the QUAL/DP/AF filters inline, because parsing and filtering a tool's
output IS the bioinformatics skill. What does NOT belong inline is ~70 lines of
matplotlib for the QC dashboard and the mutation spectrum. That lives here so the
notebook cells stay short and focused on interpretation.

Download it and run the demo:

    python variant_toolkit.py        # builds tiny demo records and plots

Each plot function takes a list of variant records that expose `.qual`, `.dp`,
`.af`, `.pos`, `.ref`, `.alt` (list), `.is_pass`, and `.is_snv` — i.e. the
`VCFRecord` objects the notebook builds.

Dependencies: matplotlib only (the demo uses just the standard-library `random` module).
"""

# (REF, ALT) base pairs that count as transitions: purine<->purine (A/G) and
# pyrimidine<->pyrimidine (C/T), in both directions. Any other single-base
# substitution is treated as a transversion by plot_mutation_spectrum.
TI_PAIRS = {
    ("A", "G"), ("G", "A"),  # purines
    ("C", "T"), ("T", "C"),  # pyrimidines
}


def _present(records, *attrs):
    """Return the records whose listed attributes are all set (not None).

    Uses ``is not None`` rather than truthiness so legitimate zeros (QUAL 0.0,
    DP 0) are kept; only missing values are dropped (presumably how the
    notebook's parser represents a VCF "." — confirm against VCFRecord).
    """
    return [r for r in records if all(getattr(r, a) is not None for a in attrs)]


def _overlay_hist(ax, variants, passing, attr, bins):
    """Draw "All" and "Passing" histograms of ``attr`` on one axis.

    Both layers get the same bin count AND the same ``range``, so their bin
    edges coincide and the overlaid bars are directly comparable (computing
    ``bins=N`` independently per layer gives different edges).
    """
    all_vals = [getattr(r, attr) for r in _present(variants, attr)]
    pass_vals = [getattr(r, attr) for r in _present(passing, attr)]
    pooled = all_vals + pass_vals
    span = (min(pooled), max(pooled)) if pooled else None
    ax.hist(all_vals, bins=bins, range=span, color="steelblue", alpha=0.6,
            label="All", edgecolor="white")
    ax.hist(pass_vals, bins=bins, range=span, color="mediumseagreen", alpha=0.8,
            label="Passing", edgecolor="white")


def plot_qc_dashboard(variants, passing, *, qual_min=30, dp_min=15, chrom="chr22"):
    """Six-panel variant-calling QC dashboard (QUAL / DP / AF + scatters).

    Top row: overlaid "All" vs "Passing" histograms of QUAL, read depth and
    allele frequency. Bottom row: QUAL vs DP, QUAL vs AF, and QUAL along the
    chromosome.

    Parameters
    ----------
    variants : list of records exposing ``.qual``, ``.dp``, ``.af``, ``.pos``
        and ``.is_pass``.
    passing : the records that survived filtering. It feeds the "Passing"
        histograms. The bottom-row PASS/FAIL colours come from each record's
        own ``is_pass`` flag instead.
    qual_min, dp_min : keyword-only. Positions of the dashed QUAL and DP
        reference lines. The defaults (30, 15) reproduce the original lines.
    chrom : keyword-only. Chromosome name shown in the position-axis label.

    Records with a missing (None) value are skipped only in the panels that
    need that value, instead of crashing matplotlib.
    """
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec

    fig = plt.figure(figsize=(15, 10))
    gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.4, wspace=0.35)

    # --- Top row: distributions, All vs Passing ---------------------------
    ax1 = fig.add_subplot(gs[0, 0])
    _overlay_hist(ax1, variants, passing, "qual", 25)
    ax1.axvline(qual_min, color="red", linestyle="--", label=f"QUAL={qual_min}")
    ax1.set_xlabel("QUAL"); ax1.set_ylabel("Count")
    ax1.set_title("Quality Score Distribution"); ax1.legend(fontsize=8)

    ax2 = fig.add_subplot(gs[0, 1])
    _overlay_hist(ax2, variants, passing, "dp", 25)
    ax2.axvline(dp_min, color="red", linestyle="--", label=f"DP={dp_min}")
    ax2.set_xlabel("Read depth"); ax2.set_ylabel("Count")
    ax2.set_title("Read Depth at Variant Sites"); ax2.legend(fontsize=8)

    ax3 = fig.add_subplot(gs[0, 2])
    _overlay_hist(ax3, variants, passing, "af", 20)
    ax3.axvline(0.5, color="orange", linestyle="--", label="AF=0.5 (het)")
    ax3.set_xlabel("Allele frequency"); ax3.set_ylabel("Count")
    ax3.set_title("Allele Frequency Distribution"); ax3.legend(fontsize=8)

    # --- Bottom row: scatters coloured by each record's is_pass flag ------
    def status_colors(records):
        return ["mediumseagreen" if r.is_pass else "tomato" for r in records]

    ax4 = fig.add_subplot(gs[1, 0])
    dq = _present(variants, "dp", "qual")
    ax4.scatter([r.dp for r in dq], [r.qual for r in dq],
                c=status_colors(dq), alpha=0.6, s=30)
    ax4.axvline(dp_min, color="red", linestyle="--", alpha=0.5)
    ax4.axhline(qual_min, color="red", linestyle="--", alpha=0.5)
    ax4.set_xlabel("Read depth (DP)"); ax4.set_ylabel("QUAL")
    ax4.set_title("QUAL vs DP  (green=PASS, red=FAIL)")

    ax5 = fig.add_subplot(gs[1, 1])
    aq = _present(variants, "af", "qual")
    ax5.scatter([r.af for r in aq], [r.qual for r in aq],
                c=status_colors(aq), alpha=0.6, s=30)
    ax5.axhline(qual_min, color="red", linestyle="--", alpha=0.5)
    ax5.set_xlabel("Allele frequency"); ax5.set_ylabel("QUAL")
    ax5.set_title("QUAL vs AF")

    ax6 = fig.add_subplot(gs[1, 2])
    placed = _present(variants, "pos", "qual")
    fails = [r for r in placed if not r.is_pass]
    passes = [r for r in placed if r.is_pass]
    # FAIL drawn first so PASS points sit on top, as before.
    ax6.scatter([r.pos for r in fails], [r.qual for r in fails],
                c="tomato", alpha=0.6, s=30, label="FAIL")
    ax6.scatter([r.pos for r in passes], [r.qual for r in passes],
                c="mediumseagreen", alpha=0.6, s=30, label="PASS")
    ax6.set_xlabel(f"Genomic position ({chrom})"); ax6.set_ylabel("QUAL")
    ax6.set_title("Variant Quality Along Chromosome"); ax6.legend(fontsize=8)

    plt.suptitle("Variant Calling QC Dashboard", fontsize=14, y=1.01)
    plt.show()


def plot_mutation_spectrum(passing):
    """Mutation-spectrum bar chart + Ti/Tv pie for the passing SNVs.

    Left panel: counts of each REF>ALT substitution among records with
    ``is_snv`` true and exactly one ALT allele. Transitions are blue and
    transversions orange. Right panel: a Ti vs Tv pie titled with the Ti/Tv
    ratio.

    Parameters
    ----------
    passing : records exposing ``.is_snv``, ``.ref`` and ``.alt`` (a list).

    Returns
    -------
    (ti_total, tv_total) : number of transition and transversion SNVs counted.

    Bases are upper-cased before classification because VCF REF/ALT are
    case-insensitive while TI_PAIRS is upper-case only. With no qualifying
    SNVs the pie is replaced by a note. With no transversions the ratio is
    shown as "n/a" rather than a made-up number.
    """
    import matplotlib.pyplot as plt
    from collections import Counter

    spectrum = Counter(
        f"{v.ref.upper()}>{v.alt[0].upper()}"
        for v in passing
        if v.is_snv and len(v.alt) == 1
    )

    types = sorted(spectrum)
    # Keys look like "A>G": index 0 is REF, index 2 is ALT.
    ti_types = {k for k in types if (k[0], k[2]) in TI_PAIRS}
    counts = [spectrum[k] for k in types]
    colors = ["steelblue" if k in ti_types else "darkorange" for k in types]

    _, axes = plt.subplots(1, 2, figsize=(14, 4))
    axes[0].bar(types, counts, color=colors, edgecolor="white")
    axes[0].set_xlabel("Mutation type"); axes[0].set_ylabel("Count")
    axes[0].set_title("Mutation Spectrum (blue=Ti, orange=Tv)")
    axes[0].tick_params(axis="x", rotation=45)

    ti_total = sum(spectrum[k] for k in ti_types)
    tv_total = sum(spectrum.values()) - ti_total

    if ti_total + tv_total:
        axes[1].pie([ti_total, tv_total],
                    labels=[f"Transitions (Ti)\n{ti_total}", f"Transversions (Tv)\n{tv_total}"],
                    colors=["steelblue", "darkorange"], autopct="%1.1f%%", startangle=90)
    else:
        # pie() normalises wedges by their sum, so an all-zero pie would divide by zero.
        axes[1].text(0.5, 0.5, "No passing SNVs", ha="center", va="center",
                     transform=axes[1].transAxes)
        axes[1].axis("off")
    # Ti/Tv is undefined without transversions; don't fake it by dividing by 1.
    ratio = f"{ti_total / tv_total:.2f}" if tv_total else "n/a"
    axes[1].set_title(f"Ti/Tv = {ratio}")
    plt.tight_layout(); plt.show()
    return ti_total, tv_total


if __name__ == "__main__":
    # Minimal stand-in records so the demo runs without the notebook's parser.
    import random

    class _V:
        """Tiny stand-in for the notebook's VCFRecord (one ALT allele per record)."""

        def __init__(self, qual, dp, af, pos, ref, alt, is_pass):
            self.qual, self.dp, self.af, self.pos = qual, dp, af, pos
            self.ref, self.alt, self.is_pass = ref, [alt], is_pass
            self.is_snv = len(ref) == 1 and len(alt) == 1

    random.seed(1)
    demo = []
    for i in range(40):
        ref, alt = random.choice([("A", "G"), ("C", "T"), ("A", "C"), ("G", "T")])
        # Draw QUAL, DP, AF in a fixed order so seed 1 is reproducible.
        q = random.uniform(10, 300)
        dp = random.randint(8, 60)
        af = random.uniform(0.1, 1.0)
        # PASS applies both thresholds the dashboard draws (QUAL>=30, DP>=15),
        # so green points never sit on the wrong side of a dashed line.
        demo.append(_V(q, dp, af, 17_000_000 + i * 10_000, ref, alt,
                       q >= 30 and dp >= 15))
    passing = [v for v in demo if v.is_pass]
    print(f"Demo: {len(demo)} variants, {len(passing)} passing")
    plot_qc_dashboard(demo, passing)
    ti, tv = plot_mutation_spectrum(passing)
    # Ti/Tv is undefined when there are no transversions.
    ratio = f"{ti / tv:.2f}" if tv else "n/a"
    print(f"Ti={ti} Tv={tv} Ti/Tv={ratio}")
