"""rnaseq_toolkit — the RNA-seq engine for Module 6 (Differential Expression).

Why this file exists
--------------------
Differential expression is a *bioinformatics* task, not a Python exercise. On a
real machine you run **DESeq2** (R/Bioconductor): it fits a negative-binomial
GLM per gene, shrinks dispersions across genes, and runs a Wald test. You do not
hand-code the statistics — you run the tool and read its output table.

So Module 6's notebook does exactly that: it shows the DESeq2 R script
(`run_deseq2.R`), then *loads* the results DESeq2 produces
(`deseq2_results.tsv`, `deseq2_normalized_counts.tsv`) and spends its cells
*interpreting* and *plotting* them — which is the real day-to-day skill.

This module holds the two things a bioinformatician genuinely does by hand:
loading/parsing those tables, and the matplotlib boilerplate for the standard
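RNA-seq figures (library-size QC, volcano, MA, PCA, heatmaps). Download it and
run the demo, which expects `synthetic_counts.tsv` and `deseq2_results.tsv` in
`data/capstone/` next to this file: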

    python rnaseq_toolkit.py

Dependencies: numpy + matplotlib (matplotlib only for the plot_* helpers).
"""

import math
import numpy as np


# ---------------------------------------------------------------------------
# 1. Loading the count matrix and DESeq2's output tables
# ---------------------------------------------------------------------------
def load_counts(path):
    """Read a gene-by-sample counts TSV.

    The first line is a header: its first field labels the gene column (it may
    be empty, as pandas writes it for an unnamed index) and the remaining
    fields are sample names. Every following non-blank line is
    ``gene<TAB>count<TAB>count...``. Samples are assigned to groups by name
    prefix: names starting with ``"ctrl"`` are controls, names starting with
    ``"treat"`` are treated samples; any other name is in neither group.

    Args:
        path: Path to the TSV file.

    Returns:
        (genes, sample_names, counts_matrix, ctrl_idx, treat_idx) where
        counts_matrix is a float array of shape (n_genes, n_samples) — also
        when the file has no data rows — and ctrl_idx / treat_idx are lists of
        column indices into it.

    Raises:
        ValueError: if a data row has a different number of count columns
            than the header has sample names, or a count is not numeric.
    """
    genes, rows = [], []
    with open(path, encoding="utf-8") as f:
        # rstrip, not strip: a leading empty header field (gene column with no
        # name) must survive, otherwise header[1:] drops the first sample.
        header = f.readline().rstrip().split("\t")
        samples = [s.strip() for s in header[1:]]
        for lineno, line in enumerate(f, start=2):
            line = line.rstrip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing newline
            parts = line.split("\t")
            if len(parts) - 1 != len(samples):
                raise ValueError(
                    f"{path}:{lineno}: expected {len(samples)} count columns "
                    f"(one per header sample), got {len(parts) - 1}"
                )
            genes.append(parts[0].strip())
            rows.append([float(x) for x in parts[1:]])
    # reshape keeps the documented 2-D shape even when there are no data rows.
    counts = np.array(rows, dtype=float).reshape(len(rows), len(samples))
    ctrl_idx = [i for i, s in enumerate(samples) if s.startswith("ctrl")]
    treat_idx = [i for i, s in enumerate(samples) if s.startswith("treat")]
    return genes, samples, counts, ctrl_idx, treat_idx


def load_normalized_counts(path, genes):
    """Load a normalized-counts TSV and return its rows in `genes` order.

    The first line is skipped as a header. Each remaining line is split on
    tabs: the first field is the gene name, the rest are parsed as floats. If
    a gene appears more than once, the last occurrence wins.

    Args:
        path: Path to the TSV (e.g. DESeq2's normalized counts export).
        genes: Gene names giving the desired row order.

    Returns:
        A float numpy array with one row per entry of `genes`.

    Raises:
        KeyError: if a name in `genes` is absent from the file.
    """
    with open(path) as handle:
        handle.readline()  # discard header line
        split_lines = (raw.strip().split("\t") for raw in handle)
        by_gene = {cols[0]: [float(v) for v in cols[1:]] for cols in split_lines}
    ordered = [by_gene[name] for name in genes]
    return np.array(ordered, dtype=float)


def load_deseq2_results(path, genes):
    """Parse DESeq2's results table into a list of per-gene dicts.

    Expects a tab-separated file with one header line followed by rows of
    ``gene baseMean log2FoldChange lfcSE stat pvalue padj``. Any value that is
    not a number becomes ``float("nan")``. This covers DESeq2's ``NA``, an
    empty field, and a trailing column missing from a short row. Blank lines
    and rows with an empty gene field are skipped.

    Args:
        path: Path to the results TSV.
        genes: Gene names in counts-matrix row order (e.g. from load_counts).

    Returns:
        A list of dicts with keys gene, basemean, log2fc, lfcSE, stat,
        p_value, padj, gene_idx. gene_idx is the gene's row in the
        counts/normalized matrices, or -1 if the gene is not in `genes`.
        Rows are returned in the file's order; nothing is re-sorted here.
        DESeq2's results() output is itself unsorted, so any ordering comes
        from the script that wrote the file. padj is often NaN, so callers
        that sort by it need a NaN-aware key.
    """
    fields = ("basemean", "log2fc", "lfcSE", "stat", "p_value", "padj")

    def num(x):
        # Non-numeric tokens (DESeq2 "NA", "") map to NaN.
        try:
            return float(x)
        except ValueError:
            return float("nan")

    idx_of = {g: i for i, g in enumerate(genes)}
    results = []
    with open(path, encoding="utf-8") as f:
        f.readline()  # header: gene baseMean log2FoldChange lfcSE stat pvalue padj
        for line in f:
            # Drop only the line terminator: str.strip() would also remove a
            # trailing empty field (e.g. a blank padj) and shorten the row.
            p = line.rstrip("\r\n").split("\t")
            gene = p[0].strip()
            if not gene:
                continue
            values = p[1:1 + len(fields)]
            values += [""] * (len(fields) - len(values))  # pad short rows -> NaN
            record = {"gene": gene}
            record.update(zip(fields, map(num, values)))
            record["gene_idx"] = idx_of.get(gene, -1)
            results.append(record)
    return results


def split_significant(results, padj_thresh=0.05, fc_thresh=1.0):
    """Partition results into significant up- and down-regulated genes.

    A gene counts as significant when its padj is strictly below
    `padj_thresh` and |log2fc| is at least `fc_thresh`. NaN values fail
    both comparisons and so are never significant. Significant genes with a
    positive log2fc go to the "up" list; all other significant genes go to
    the "down" list. Each list keeps the input order.

    Returns:
        (up, down, up + down)
    """
    hits = [
        rec for rec in results
        if rec["padj"] < padj_thresh and abs(rec["log2fc"]) >= fc_thresh
    ]
    up = [rec for rec in hits if rec["log2fc"] > 0]
    down = [rec for rec in hits if not (rec["log2fc"] > 0)]
    return up, down, up + down


# ---------------------------------------------------------------------------
# 2. Plotting — matplotlib boilerplate lives here so notebook cells stay short.
# ---------------------------------------------------------------------------
# Group colors shared by the plot_* helpers: control vs. treated samples.
CTRL_COLOR, TREAT_COLOR = "steelblue", "darkorange"


def _sample_colors(n_samples, ctrl_idx):
    """Return one color per sample index 0..n_samples-1.

    Indices contained in `ctrl_idx` get CTRL_COLOR; every other index gets
    TREAT_COLOR.
    """
    by_membership = {True: CTRL_COLOR, False: TREAT_COLOR}
    return [by_membership[idx in ctrl_idx] for idx in range(n_samples)]


def plot_library_qc(counts, samples, ctrl_idx, treat_idx):
    """Two-panel pre-normalization QC: library size + count distribution.

    Left panel: total reads per sample in millions (column sums of `counts`).
    Right panel: per-sample boxplot of log2(count + 1). Bars and boxes are
    colored by group: indices in `ctrl_idx` use CTRL_COLOR and all other
    samples use TREAT_COLOR.

    Args:
        counts: raw count matrix of shape (n_genes, n_samples), as returned by
            load_counts.
        samples: sample names, one per column of `counts`.
        ctrl_idx: column indices of control samples.
        treat_idx: accepted for signature symmetry with the other plot
            helpers; not used (every non-control sample is drawn as treated).

    Displays the figure with plt.show(); returns None.
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches

    colors = _sample_colors(len(samples), ctrl_idx)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    axes[0].bar(samples, counts.sum(axis=0) / 1e6, color=colors)
    axes[0].set_xlabel("Sample")
    axes[0].set_ylabel("Million reads")
    axes[0].set_title("Library Size per Sample")
    axes[0].tick_params(axis="x", rotation=30)

    # Label the boxes with set_xticks/set_xticklabels rather than boxplot's
    # `tick_labels=` keyword: that keyword only exists in matplotlib >= 3.9
    # (earlier releases call it `labels=` and raise TypeError otherwise).
    # boxplot places column j at x = j + 1 by default.
    bp = axes[1].boxplot(np.log2(counts + 1), patch_artist=True)
    axes[1].set_xticks(range(1, len(samples) + 1))
    axes[1].set_xticklabels(samples)
    for patch, c in zip(bp["boxes"], colors):
        patch.set_facecolor(c)
        patch.set_alpha(0.7)
    axes[1].set_xlabel("Sample")
    axes[1].set_ylabel("log2(count + 1)")
    axes[1].set_title("Count Distribution (raw)")
    axes[1].tick_params(axis="x", rotation=30)
    axes[1].legend(handles=[mpatches.Patch(color=CTRL_COLOR, label="Control"),
                            mpatches.Patch(color=TREAT_COLOR, label="Treated")])
    plt.tight_layout()
    plt.show()


def plot_volcano(results, padj_thresh=0.05, fc_thresh=1.0, title="Volcano Plot",
                 n_labels=8):
    """Standard volcano: log2FC (x) vs -log10(padj) (y), colored up/down/NS.

    Classification follows the same rule as split_significant, so the legend
    counts match it:
    - "Up": padj < `padj_thresh` and log2FC >= `fc_thresh`.
    - "Down": padj < `padj_thresh` and log2FC <= -`fc_thresh`.

    padj == 0 is clamped to 1e-300, so it plots at y = 300 instead of
    infinity. A gene whose log2FC or padj is NaN (DESeq2 writes NA for
    filtered genes) has no position on the axes. Such genes are left out of
    every category and of the gene labels.

    Args:
        results: per-gene dicts with at least "gene", "log2fc" and "padj"
            (e.g. from load_deseq2_results).
        padj_thresh: adjusted p-value cut-off (strict <).
        fc_thresh: absolute log2 fold-change cut-off (inclusive).
        title: axes title.
        n_labels: number of plottable genes with the smallest padj to label.

    Displays the figure with plt.show(); returns None.
    """
    import matplotlib.pyplot as plt

    x = np.array([r["log2fc"] for r in results], dtype=float)
    padj = np.array([r["padj"] for r in results], dtype=float)
    # np.maximum propagates NaN, so an NA padj stays NaN instead of being
    # clamped to 1e-300 (which would plot it as maximally significant).
    y = -np.log10(np.maximum(padj, 1e-300))
    ythr = -math.log10(padj_thresh)

    plottable = np.isfinite(x) & np.isfinite(y)
    sig = plottable & (padj < padj_thresh)
    up = sig & (x >= fc_thresh)
    down = sig & (x <= -fc_thresh)
    ns = plottable & ~(up | down)

    fig, ax = plt.subplots(figsize=(9, 7))
    ax.scatter(x[ns], y[ns], c="lightgray", alpha=0.5, s=20, label=f"NS ({ns.sum()})")
    ax.scatter(x[up], y[up], c="#e74c3c", alpha=0.8, s=40, label=f"Up ({up.sum()})")
    ax.scatter(x[down], y[down], c="#3498db", alpha=0.8, s=40, label=f"Down ({down.sum()})")
    ax.axhline(ythr, color="gray", linestyle="--", alpha=0.6, label=f"padj={padj_thresh}")
    ax.axvline(fc_thresh, color="gray", linestyle=":", alpha=0.6)
    ax.axvline(-fc_thresh, color="gray", linestyle=":", alpha=0.6)

    # Rank by padj among plottable genes only: NaN keys would make sorted()
    # return an arbitrary order. The stable sort keeps file order on ties.
    candidates = np.flatnonzero(plottable)
    top = candidates[np.argsort(padj[candidates], kind="stable")][:n_labels]
    for i in top:
        ax.annotate(results[i]["gene"], (x[i], y[i]),
                    textcoords="offset points", xytext=(5, 3), fontsize=7, alpha=0.8)

    ax.set_xlabel("log₂ Fold Change (Treated / Control)", fontsize=11)
    ax.set_ylabel("-log₁₀(adjusted p-value)", fontsize=11)
    ax.set_title(title, fontsize=13)
    ax.legend(loc="upper right", fontsize=9)
    plt.tight_layout()
    plt.show()


def plot_ma_and_pca(results, log_norm, samples, ctrl_idx, treat_idx, padj_thresh=0.05):
    """Draw an MA plot (left) and a sample PCA (right) side by side.

    MA panel: x = log2(baseMean + 1), y = log2 fold change. Genes with
    padj < `padj_thresh` are red; all others, including NaN padj, are gray.

    PCA panel: each row of ``log_norm.T`` is one sample. The rows are centered
    per column (gene), decomposed with an SVD, and the first two component
    scores are plotted.

    Args:
        results: per-gene dicts with "basemean", "log2fc" and "padj".
        log_norm: expression matrix whose transpose has one row per sample;
            presumably log-transformed normalized counts of shape
            (n_genes, n_samples) — confirm against the caller.
        samples: sample names in the row order of ``log_norm.T``.
        ctrl_idx: indices of control samples (drawn in CTRL_COLOR).
        treat_idx: unused; kept for signature symmetry.
        padj_thresh: significance cut-off for the MA coloring.

    Returns:
        numpy array: percent of total variance explained by each component.
    """
    import matplotlib.pyplot as plt
    import matplotlib.patches as mpatches

    fig, (ax_ma, ax_pca) = plt.subplots(1, 2, figsize=(14, 5))

    # --- MA panel ---------------------------------------------------------
    mean_expr = np.log2(np.array([rec["basemean"] for rec in results]) + 1)
    fold = np.array([rec["log2fc"] for rec in results])
    is_sig = np.array([rec["padj"] for rec in results]) < padj_thresh
    not_sig = ~is_sig
    ax_ma.scatter(mean_expr[not_sig], fold[not_sig],
                  c="lightgray", alpha=0.5, s=20, label="NS")
    ax_ma.scatter(mean_expr[is_sig], fold[is_sig],
                  c="#e74c3c", alpha=0.8, s=40, label="Significant")
    ax_ma.axhline(0, color="black", linewidth=0.8)
    ax_ma.set_xlabel("log₂(mean expression)")
    ax_ma.set_ylabel("log₂ fold change")
    ax_ma.set_title("MA Plot")
    ax_ma.legend()

    # --- PCA panel --------------------------------------------------------
    per_sample = log_norm.T
    centered = per_sample - per_sample.mean(axis=0)
    u, sing, _ = np.linalg.svd(centered, full_matrices=False)
    scores = u * sing[np.newaxis, :]
    pct = (sing ** 2) / (sing ** 2).sum() * 100
    palette = _sample_colors(len(samples), ctrl_idx)
    for row, (name, colour) in enumerate(zip(samples, palette)):
        point = (scores[row, 0], scores[row, 1])
        ax_pca.scatter(point[0], point[1], c=colour, s=100, zorder=3)
        ax_pca.annotate(name, point,
                        textcoords="offset points", xytext=(6, 3), fontsize=9)
    ax_pca.set_xlabel(f"PC1 ({pct[0]:.1f}% variance)")
    ax_pca.set_ylabel(f"PC2 ({pct[1]:.1f}% variance)")
    ax_pca.set_title("PCA of RNA-seq samples")
    ax_pca.axhline(0, color="gray", linewidth=0.5)
    ax_pca.axvline(0, color="gray", linewidth=0.5)
    ax_pca.legend(handles=[mpatches.Patch(color=CTRL_COLOR, label="Control"),
                           mpatches.Patch(color=TREAT_COLOR, label="Treated")])

    plt.tight_layout()
    plt.show()
    return pct


def plot_heatmap(genes_subset, rows, log_norm, samples,
                 title="Top DE genes (Z-score normalized)"):
    """Z-scored expression heatmap for a chosen set of gene rows.

    Each selected row of `log_norm` is standardized to mean 0 and standard
    deviation 1 (population std, ddof=0). Rows with zero spread are shown as
    all zeros instead of 0/0. Colors are clipped to the range [-2, 2].

    Args:
        genes_subset: y-axis labels, one per selected row.
        rows: row selector into `log_norm` (e.g. a list of gene indices).
        log_norm: expression matrix indexed as ``log_norm[rows, :]``; columns
            must line up with `samples`.
        samples: x-axis labels, one per column.
        title: axes title (defaults to the previous fixed title).

    Raises:
        ValueError: if the number of selected rows differs from
            ``len(genes_subset)``; otherwise the labels would silently
            mismatch the rows.

    Displays the figure with plt.show(); returns None.
    """
    import matplotlib.pyplot as plt

    mat = log_norm[rows, :]
    if mat.shape[0] != len(genes_subset):
        raise ValueError(
            f"rows selects {mat.shape[0]} genes but genes_subset has "
            f"{len(genes_subset)} labels"
        )
    mu = mat.mean(axis=1, keepdims=True)
    sd = mat.std(axis=1, keepdims=True)
    sd[sd == 0] = 1.0  # constant rows -> z = 0 rather than NaN
    z = (mat - mu) / sd

    fig, ax = plt.subplots(figsize=(8, 7))
    im = ax.imshow(z, cmap="RdBu_r", aspect="auto", vmin=-2, vmax=2)
    ax.set_xticks(range(len(samples)))
    ax.set_xticklabels(samples, rotation=30, ha="right")
    ax.set_yticks(range(len(genes_subset)))
    ax.set_yticklabels(genes_subset, fontsize=8)
    cbar = fig.colorbar(im, ax=ax, shrink=0.7)
    cbar.set_label("Z-score (per gene)")
    ax.set_title(title)
    plt.tight_layout()
    plt.show()


# ---------------------------------------------------------------------------
# Demo — runs when you execute this file directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import os

    here = os.path.dirname(os.path.abspath(__file__))
    base = os.path.join(here, "data", "capstone")
    genes, samples, counts, ctrl, treat = load_counts(os.path.join(base, "synthetic_counts.tsv"))
    print(f"Loaded {len(genes)} genes x {len(samples)} samples: {samples}")
    results = load_deseq2_results(os.path.join(base, "deseq2_results.tsv"), genes)
    up, down, sig = split_significant(results)
    print(f"DESeq2 results: {len(results)} genes, {len(sig)} significant "
          f"({len(up)} up, {len(down)} down)")
    print("Top 5 by padj:")
    # DESeq2 reports padj = NA (parsed as NaN) for filtered genes. NaN keys
    # break sorted()'s ordering, so rank only genes that have a real padj.
    ranked = sorted((r for r in results if not math.isnan(r["padj"])),
                    key=lambda r: r["padj"])
    for r in ranked[:5]:
        print(f"  {r['gene']}: log2FC={r['log2fc']:+.2f}  padj={r['padj']:.2e}")
