{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Module 7: Bioinformatics Visualization\n",
    "\n",
    "**Duration:** 50 minutes &nbsp;\u00b7&nbsp; **Day:** 2 of 2\n",
    "\n",
    "## Learning objectives\n",
    "By the end of this module you will be able to:\n",
    "- Read a coverage plot and a log\u2082-ratio track from aligned reads\n",
    "- Build a z-scored expression heatmap\n",
    "- Run PCA on RNA-seq samples and read a scree plot\n",
    "- Turn a gene list into a pathway-enrichment figure\n",
    "- Assemble a publication-style multi-panel summary\n",
    "\n",
    "> **Tip:** this page is interactive. The *data journey* animation up top previews\n",
    "> the whole pipeline, and each section below has its own live explorer embedded\n",
    "> right next to the code that builds the figure \u2014 drag the sliders as you read.\n",
    "\n",
    "> **Why visualization matters:** a result no one can understand is not a result.\n",
    "> Every figure here starts from the same data structures you've built all week."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "We pull in matplotlib (plotting), numpy (numeric arrays) and seaborn (nicer\n",
    "defaults). Fixed random seeds make every figure reproducible, and\n",
    "`capstone_dir` points at the shared sample data."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "import numpy as np\n",
    "import math, random, os\n",
    "from pathlib import Path\n",
    "\n",
    "try:\n",
    "    import seaborn as sns\n",
    "    HAS_SEABORN = True\n",
    "    sns.set_theme(style='whitegrid', palette='muted')\n",
    "except ImportError:\n",
    "    HAS_SEABORN = False\n",
    "    print(\"seaborn not installed \u2014 using matplotlib directly\")\n",
    "\n",
    "random.seed(42); np.random.seed(42)\n",
    "capstone_dir = Path(os.path.abspath(\"../../\")) / 'data' / 'capstone'\n",
    "print(f\"seaborn available: {HAS_SEABORN}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Section 1 \u00b7 Coverage plots\n",
    "\n",
    "A **coverage plot** shows how many reads align to each genomic position \u2014 the\n",
    "first thing you check in IGV when an alignment looks off. We'll build one\n",
    "*without* a real BAM by simulating reads, so the mechanics stay clear.\n",
    "\n",
    "**Step 1.** A helper that drops `n_reads` reads of length `read_len` at random start\n",
    "positions and tallies per-base depth. Optional `peaks` add Gaussian bumps\n",
    "(\u03c3 = 300 bp) where exons are enriched."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def simulate_coverage(start, end, n_reads=500, read_len=150, peaks=None):\n",
    "    \"\"\"Per-base coverage over a region; peaks=[(center, height_mult)] enrich exons.\"\"\"\n",
    "    region_len = end - start\n",
    "    coverage = np.zeros(region_len)\n",
    "    for _ in range(n_reads):\n",
    "        r_start = random.randint(0, max(1, region_len - read_len))\n",
    "        coverage[r_start:min(r_start + read_len, region_len)] += 1\n",
    "    if peaks:\n",
    "        for center, mult in peaks:\n",
    "            c = center - start\n",
    "            if 0 <= c < region_len:\n",
    "                sigma = 300\n",
    "                for pos in range(max(0, c - 3*sigma), min(region_len, c + 3*sigma)):\n",
    "                    coverage[pos] += mult * n_reads * 0.5 * math.exp(-((pos-c)**2) / (2*sigma**2))\n",
    "    return list(range(start, end)), coverage"
   ]
  },
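  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-read loop above is easy to follow but touches every base of every read. A\n",
    "common alternative records only where each read starts and stops in a *difference\n",
    "array*, then recovers per-base depth with one cumulative sum. This is a sketch, not\n",
    "what `simulate_coverage` does. `coverage_from_reads` is a hypothetical name. Read starts\n",
    "are assumed to be 0-based offsets into the region, and reads are clipped at the region end."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def coverage_from_reads(read_starts, read_len, region_len):\n",
    "    \"\"\"Per-base depth from 0-based read starts via a difference array (sketch).\"\"\"\n",
    "    starts = np.asarray(read_starts, dtype=int)\n",
    "    ends = np.minimum(starts + read_len, region_len)   # clip reads at the region end\n",
    "    diff = np.zeros(region_len + 1, dtype=int)\n",
    "    np.add.at(diff, starts, 1)    # +1 where each read begins\n",
    "    np.add.at(diff, ends, -1)     # -1 just past where each read stops\n",
    "    return np.cumsum(diff[:-1])   # running total = reads overlapping each base\n",
    "\n",
    "# Two 4-bp reads on a 10-bp region; the second one is clipped at the end\n",
    "demo_cov = coverage_from_reads([0, 8], read_len=4, region_len=10)\n",
    "print(demo_cov)   # -> [1 1 1 1 0 0 0 0 1 1]"
   ]
  },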
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 2.** Define a toy gene with four exons, then make two tracks: a\n",
    "*normal* sample and a *tumor* sample. The tumor amplifies exons 1\u20132 (mimicking a\n",
    "copy-number gain) and leaves exons 3\u20134 at normal depth."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "gene_start, gene_end = 17_100_000, 17_110_000\n",
    "exons = [\n",
    "    (gene_start + 500,  gene_start + 1200),\n",
    "    (gene_start + 3000, gene_start + 3800),\n",
    "    (gene_start + 6500, gene_start + 7500),\n",
    "    (gene_start + 8800, gene_start + 9700),\n",
    "]\n",
    "mid = lambda ex: ex[0] + (ex[1] - ex[0]) // 2\n",
    "\n",
    "positions, cov_normal = simulate_coverage(gene_start, gene_end, n_reads=800,\n",
    "                                          peaks=[(mid(ex), 3) for ex in exons])\n",
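    "# Peak height is mult * n_reads * 0.5; 2 x 1200 = 3 x 800, so exons 3-4 keep the normal peak depth\n",
    "_, cov_tumor = simulate_coverage(gene_start, gene_end, n_reads=1200,\n",
    "                                 peaks=[(mid(ex), 5) for ex in exons[:2]] + [(mid(ex), 2) for ex in exons[2:]])\n",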
    "\n",
    "print(f\"Coverage arrays: {len(positions)} positions\")\n",
    "print(f\"Normal mean: {cov_normal.mean():.1f}x   Tumor mean: {cov_tumor.mean():.1f}x\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 3.** The normal and tumor panels are drawn the same way \u2014 a filled depth\n",
    "area plus a dashed mean line \u2014 so we factor that into a small helper. Pulling\n",
    "the repetition out keeps the figure code below short and readable."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def coverage_track(ax, x, y, fill, mean_color, ylabel):\n",
    "    \"\"\"Draw one coverage panel: filled depth + dashed mean line.\"\"\"\n",
    "    ax.fill_between(x, y, color=fill, alpha=0.7)\n",
    "    ax.axhline(y.mean(), color=mean_color, linestyle='--', alpha=0.6, label=f'Mean {y.mean():.0f}x')\n",
    "    ax.set_ylabel(ylabel); ax.legend(fontsize=8)\n",
    "    ax.set_xlim(x[0], x[-1]); ax.set_xticks([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 4.** Now stack four aligned panels that share the x-axis:\n",
    "\n",
    "1. **Gene model** \u2014 exons as blue blocks on an intron backbone\n",
    "2. **Normal coverage** \u00b7 3. **Tumor coverage** \u2014 via the helper above\n",
    "4. **log\u2082 ratio** \u2014 `log\u2082((tumor + 1) / (normal + 1))` per base, where the +1\n",
    "   pseudocount avoids dividing by zero; bars above the dashed red `+1` line (a 2\u00d7 gain)\n",
    "   are the amplified (gained) regions\n",
    "\n",
    "The log\u2082-ratio track is the payoff: amplified exons jump above the threshold."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(14, 8))\n",
    "gs  = gridspec.GridSpec(4, 1, height_ratios=[0.8, 2, 2, 0.6], hspace=0.08)\n",
    "pos_mb = np.array(positions) / 1e6\n",
    "\n",
    "ax_ann = fig.add_subplot(gs[0])\n",
    "ax_ann.plot([pos_mb[0], pos_mb[-1]], [0, 0], 'k-', linewidth=1)\n",
    "for ex_s, ex_e in exons:\n",
    "    ax_ann.fill_between([ex_s/1e6, ex_e/1e6], [-0.3, -0.3], [0.3, 0.3], color='steelblue', alpha=0.8)\n",
    "ax_ann.set_xlim(pos_mb[0], pos_mb[-1]); ax_ann.set_ylim(-0.5, 0.5)\n",
    "ax_ann.set_yticks([]); ax_ann.set_xticks([])\n",
    "ax_ann.set_title('Coverage Plot: chr22:17.10-17.11 Mb', fontsize=12, pad=4)\n",
    "\n",
    "coverage_track(fig.add_subplot(gs[1]), pos_mb, cov_normal, 'steelblue', 'navy', 'Coverage (normal)')\n",
    "coverage_track(fig.add_subplot(gs[2]), pos_mb, cov_tumor, '#e74c3c', 'darkred', 'Coverage (tumor)')\n",
    "\n",
    "ax_r = fig.add_subplot(gs[3])\n",
    "log2_ratio = np.log2((cov_tumor + 1.0) / (cov_normal + 1.0))\n",
    "ax_r.bar(pos_mb, log2_ratio, color=np.where(log2_ratio > 0, '#e74c3c', '#3498db'), width=0.0002, alpha=0.7)\n",
    "ax_r.axhline(0, color='black', linewidth=0.8)\n",
    "ax_r.axhline(1, color='red', linestyle='--', alpha=0.4, linewidth=0.8)\n",
    "ax_r.set_ylabel('log\u2082 ratio'); ax_r.set_xlabel('Genomic position (Mb)')\n",
    "ax_r.set_xlim(pos_mb[0], pos_mb[-1])\n",
    "\n",
    "plt.show()"
   ]
  },
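  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ratio track above compares raw depths. The simulated tumor was sequenced with\n",
    "1.5\u00d7 more reads (1,200 vs 800), which lifts even the intronic background above 0.\n",
    "One common fix assumes plain total-depth scaling: rescale both tracks to a shared\n",
    "mean depth before taking the ratio. Below is a minimal sketch of that approach.\n",
    "`normalized_log2_ratio` is a hypothetical helper, and production copy-number callers\n",
    "typically apply further corrections on top of this. In this notebook it could be\n",
    "called as `normalized_log2_ratio(cov_tumor, cov_normal)`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def normalized_log2_ratio(tumor, normal, pseudocount=1.0):\n",
    "    \"\"\"log2(tumor/normal) after scaling both tracks to a shared mean depth (sketch).\"\"\"\n",
    "    tumor = np.asarray(tumor, dtype=float)\n",
    "    normal = np.asarray(normal, dtype=float)\n",
    "    target = (tumor.mean() + normal.mean()) / 2    # common reference depth\n",
    "    t_scaled = tumor * target / tumor.mean()\n",
    "    n_scaled = normal * target / normal.mean()\n",
    "    return np.log2((t_scaled + pseudocount) / (n_scaled + pseudocount))\n",
    "\n",
    "# A 'tumor' that is just the normal track sequenced 1.5x deeper: no real change\n",
    "demo_normal = np.array([10.0, 20.0, 30.0])\n",
    "demo_ratio = normalized_log2_ratio(demo_normal * 1.5, demo_normal)\n",
    "print(np.round(demo_ratio, 6))   # all zeros once depth is normalized"
   ]
  },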
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Play with it.** The explorer below is the same coverage\u2192CNV idea, live. Crank\n",
    "the tumor amplification on exons 1\u20132 and watch them rise above the gain-call\n",
    "threshold; drop the threshold and you start calling noise as gains \u2014 the exact\n",
    "trade-off a real copy-number caller faces."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<!--widget:coverage-->"
   ]
  },
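  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The threshold trade-off in the explorer can be written down in a few lines. This\n",
    "minimal sketch assumes a fixed log\u2082 threshold plus a minimum run length. Both are\n",
    "hypothetical parameters, not taken from any particular caller. The sketch flags bases\n",
    "above the threshold, then keeps only contiguous runs long enough to be a credible gain.\n",
    "The length filter is what stops a short noisy spike from being called."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def call_gains(log2_ratio, threshold=1.0, min_len=200):\n",
    "    \"\"\"(start, end) index pairs of runs with log2_ratio > threshold for >= min_len bases.\"\"\"\n",
    "    above = np.asarray(log2_ratio) > threshold\n",
    "    # pad with False on both sides so every run has a rising and a falling edge\n",
    "    edges = np.diff(np.concatenate(([False], above, [False])).astype(int))\n",
    "    starts = np.flatnonzero(edges == 1)     # first index of each run\n",
    "    ends = np.flatnonzero(edges == -1)      # one past the last index (exclusive)\n",
    "    return [(int(s), int(e)) for s, e in zip(starts, ends) if e - s >= min_len]\n",
    "\n",
    "# Toy track: a 300-base gain and a 50-base noise spike\n",
    "demo_track = np.zeros(1000)\n",
    "demo_track[100:400] = 1.5\n",
    "demo_track[700:750] = 1.2\n",
    "print(call_gains(demo_track, threshold=1.0, min_len=200))   # -> [(100, 400)]"
   ]
  },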
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Section 2 \u00b7 Expression heatmap\n",
    "\n",
    "A heatmap shows expression for many genes across many samples at once. Two\n",
    "preprocessing steps make it readable:\n",
    "\n",
    "- **log\u2082 CPM** \u2014 divide each count by its sample's library size and multiply by\n",
    "  10\u2076 (counts per million), then take `log\u2082(CPM + 1)`. This removes the\n",
    "  \"bigger library \u21d2 bigger numbers\" artefact so samples are comparable.\n",
    "- **z-score per gene** \u2014 center and scale each row so color means *relative*\n",
    "  expression. Without it, a few high-baseline genes wash out everything else.\n",
    "\n",
    "The cell below loads the Module 6 count matrix and applies both. We keep the\n",
    "first 30 genes \u2014 the first 20 are differentially expressed, the last 10 are not."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "with open(capstone_dir / 'synthetic_counts.tsv') as f:\n",
    "    header_line = f.readline().strip().split('\\t')\n",
    "    genes_list, counts_list = [], []\n",
    "    for line in f:\n",
    "        parts = line.strip().split('\\t')\n",
    "        genes_list.append(parts[0])\n",
    "        counts_list.append([int(x) for x in parts[1:]])\n",
    "\n",
    "sample_names = header_line[1:]\n",
    "counts_np = np.array(counts_list, dtype=float)\n",
    "ctrl_idx  = [i for i, s in enumerate(sample_names) if s.startswith('ctrl')]\n",
    "treat_idx = [i for i, s in enumerate(sample_names) if s.startswith('treat')]\n",
    "\n",
    "lib_sizes = counts_np.sum(axis=0)\n",
    "log_cpm = np.log2(counts_np / lib_sizes[np.newaxis, :] * 1e6 + 1)\n",
    "\n",
    "def zscore_rows(mat):\n",
    "    means = mat.mean(axis=1, keepdims=True)\n",
    "    stds  = mat.std(axis=1, keepdims=True); stds[stds == 0] = 1\n",
    "    return (mat - means) / stds\n",
    "\n",
    "top_genes = 30\n",
    "z_scores = zscore_rows(log_cpm[:top_genes])\n",
    "print(f\"Heatmap matrix shape: {z_scores.shape}  (genes x samples)\")"
   ]
  },
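  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the two preprocessing steps do to actual numbers, here is a hand-checkable\n",
    "sketch on a made-up 2-gene \u00d7 2-sample matrix (illustrative counts, not capstone data).\n",
    "It applies the same `log\u2082(CPM + 1)` and per-row z-score arithmetic as the cell above,\n",
    "including `np.std`'s default population formula (`ddof=0`)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "demo_counts = np.array([[100.0, 300.0],    # gene A\n",
    "                        [900.0, 700.0]])   # gene B; each library totals 1,000 reads\n",
    "demo_cpm = demo_counts / demo_counts.sum(axis=0) * 1e6   # counts per million\n",
    "demo_log = np.log2(demo_cpm + 1)\n",
    "\n",
    "# z-score each row: subtract the gene's mean, divide by its population std (ddof=0)\n",
    "demo_mean = demo_log.mean(axis=1, keepdims=True)\n",
    "demo_std = demo_log.std(axis=1, keepdims=True)\n",
    "demo_z = (demo_log - demo_mean) / demo_std\n",
    "\n",
    "print(demo_cpm)               # each column now sums to 1,000,000\n",
    "print(np.round(demo_z, 3))    # with only two samples, every z-score is +1 or -1"
   ]
  },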
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now draw it with `imshow`: rows = genes, columns = samples, a diverging\n",
    "red/blue colormap centered at zero. A black divider separates control from\n",
    "treated columns, and DE genes are flagged on the right. The DE block should\n",
    "light up as a clear red/blue split between the two groups."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(9, 10))\n",
    "im = ax.imshow(z_scores, cmap='RdBu_r', aspect='auto', vmin=-2.5, vmax=2.5, interpolation='nearest')\n",
    "cbar = plt.colorbar(im, ax=ax, fraction=0.03, pad=0.02)\n",
    "cbar.set_label('Z-score (log\u2082 CPM)', fontsize=10)\n",
    "\n",
    "ax.set_xticks(range(len(sample_names)))\n",
    "ax.set_xticklabels(sample_names, rotation=30, ha='right', fontsize=10)\n",
    "ax.set_yticks(range(top_genes))\n",
    "ax.set_yticklabels(genes_list[:top_genes], fontsize=8)\n",
    "ax.axvline(len(ctrl_idx) - 0.5, color='black', linewidth=2)\n",
    "for gi in range(20):\n",
    "    ax.text(len(sample_names) - 0.4, gi, 'DE', va='center', fontsize=7, color='darkred')\n",
    "\n",
    "ax.set_title(f'Gene Expression Heatmap\\n(Z-score normalized log\u2082 CPM, top {top_genes} genes)', fontsize=12)\n",
    "ax.text(len(ctrl_idx)/2 - 0.5,  -1.0, 'Control', ha='center', fontsize=10, color='steelblue')\n",
    "ax.text(len(ctrl_idx) + len(treat_idx)/2 - 0.5, -1.0, 'Treated', ha='center', fontsize=10, color='darkorange')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Section 3 \u00b7 PCA + confidence ellipses\n",
    "\n",
    "PCA compresses thousands of genes into a few axes that capture the most\n",
    "variation, so you can *see* whether samples cluster by condition. It's one of the\n",
    "most useful QC steps in RNA-seq \u2014 batch effects and outliers often jump out immediately.\n",
    "\n",
    "We compute it from scratch with an **SVD** of the mean-centered log\u2082-CPM matrix\n",
    "(no scikit-learn needed). First, a small helper that draws a 2\u03c3 **confidence\n",
    "ellipse** around a group of points \u2014 we'll reuse it on every PCA plot."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def confidence_ellipse(points, ax, n_std=2.0, **kwargs):\n",
    "    \"\"\"Add a covariance ellipse (n_std sigmas) around a cloud of 2-D points.\"\"\"\n",
    "    from matplotlib.patches import Ellipse\n",
    "    if len(points) < 3:\n",
    "        return\n",
    "    mean = points.mean(axis=0)\n",
    "    eigvals, eigvecs = np.linalg.eigh(np.cov(points.T))\n",
    "    angle  = math.degrees(math.atan2(eigvecs[1, 0], eigvecs[0, 0]))\n",
    "    width  = 2 * n_std * math.sqrt(abs(eigvals[0]))\n",
    "    height = 2 * n_std * math.sqrt(abs(eigvals[1]))\n",
    "    ax.add_patch(Ellipse(xy=mean, width=width, height=height, angle=angle, **kwargs))"
   ]
  },
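  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ellipse geometry comes straight from the eigen-decomposition. Each eigenvector of\n",
    "the 2\u00d72 covariance matrix is one ellipse axis, and `n_std` \u00d7 \u221a(eigenvalue) is the\n",
    "half-length along it. The sketch below lifts that arithmetic out of `confidence_ellipse`\n",
    "into a hypothetical, plot-free `ellipse_axes` function. It then checks the function on\n",
    "four hand-picked points whose sample covariance is diagonal."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "\n",
    "def ellipse_axes(points, n_std=2.0):\n",
    "    \"\"\"(width, height, angle in degrees), computed as in confidence_ellipse.\"\"\"\n",
    "    eigvals, eigvecs = np.linalg.eigh(np.cov(np.asarray(points, dtype=float).T))\n",
    "    angle = math.degrees(math.atan2(eigvecs[1, 0], eigvecs[0, 0]))  # direction of 1st eigenvector\n",
    "    width = 2 * n_std * math.sqrt(abs(eigvals[0]))    # full axis along the 1st (smaller) eigenvector\n",
    "    height = 2 * n_std * math.sqrt(abs(eigvals[1]))   # full axis along the 2nd (larger) one\n",
    "    return width, height, angle\n",
    "\n",
    "# Spread of 2 along x and 1 along y: sample variances 8/3 (x) and 2/3 (y), covariance 0\n",
    "demo_w, demo_h, demo_angle = ellipse_axes([(2, 0), (-2, 0), (0, 1), (0, -1)], n_std=2.0)\n",
    "print(f'width={demo_w:.3f}  height={demo_h:.3f}  angle={demo_angle:.0f} deg')"
   ]
  },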
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the PCA itself. The squared singular values, divided by their sum, give each PC's\n",
    "share of variance; `pcs` (`U \u00b7 S`) holds every sample's coordinates along the principal components."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "X_c = log_cpm.T - log_cpm.T.mean(axis=0)\n",
    "U, S, Vt = np.linalg.svd(X_c, full_matrices=False)\n",
    "pcs = U * S[np.newaxis, :]\n",
    "var = (S**2) / (S**2).sum() * 100\n",
    "print(f\"PC1 explains {var[0]:.1f}% of variance, PC2 {var[1]:.1f}%\")"
   ]
  },
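  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why does an SVD give the same answer as the textbook *eigenvectors of the covariance\n",
    "matrix* recipe? Take a mean-centered matrix `X` with `n` rows. Its sample covariance is\n",
    "`X.T @ X / (n - 1)`, and the eigenvalues of that matrix are the squared singular values\n",
    "of `X` divided by `n - 1`. Both routes therefore give the same variance shares. The\n",
    "self-contained check below uses random toy data, not the capstone matrix."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "demo_rng = np.random.default_rng(0)\n",
    "demo_X = demo_rng.normal(size=(6, 4))      # 6 samples x 4 genes of random toy data\n",
    "demo_Xc = demo_X - demo_X.mean(axis=0)     # center each gene (column)\n",
    "\n",
    "# Route 1: singular values of the centered matrix, as in the PCA cell\n",
    "demo_S = np.linalg.svd(demo_Xc, compute_uv=False)\n",
    "var_svd = demo_S**2 / (demo_S**2).sum() * 100\n",
    "\n",
    "# Route 2: eigenvalues of the sample covariance matrix, largest first\n",
    "demo_eig = np.linalg.eigvalsh(np.cov(demo_Xc, rowvar=False))[::-1]\n",
    "var_eig = demo_eig / demo_eig.sum() * 100\n",
    "\n",
    "print(np.round(var_svd, 2))\n",
    "print(np.allclose(var_svd, var_eig), np.allclose(demo_S**2 / (len(demo_X) - 1), demo_eig))"
   ]
  },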
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**The scatter.** Samples on PC1 vs PC2, colored by condition, each group wrapped\n",
    "in its confidence ellipse. Good separation along PC1 means the treatment is the\n",
    "dominant source of variation."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(7, 5.5))\n",
    "group_map = {'ctrl': ('steelblue', 'Control', 'o'), 'treat': ('darkorange', 'Treated', 's')}\n",
    "\n",
    "for prefix, (color, label, marker) in group_map.items():\n",
    "    idx = [i for i, s in enumerate(sample_names) if s.startswith(prefix)]\n",
    "    pts = pcs[idx, :2]\n",
    "    ax.scatter(pts[:, 0], pts[:, 1], c=color, s=100, marker=marker,\n",
    "               zorder=3, label=label, edgecolors='black', linewidths=0.5)\n",
    "    for samp_i in idx:\n",
    "        ax.annotate(sample_names[samp_i], pcs[samp_i, :2],\n",
    "                    textcoords='offset points', xytext=(6, 3), fontsize=8)\n",
    "    confidence_ellipse(pts, ax, n_std=2, facecolor=color, alpha=0.1, edgecolor=color, linewidth=1.5)\n",
    "\n",
    "ax.axhline(0, color='gray', linewidth=0.5); ax.axvline(0, color='gray', linewidth=0.5)\n",
    "ax.set_xlabel(f'PC1 ({var[0]:.1f}%)', fontsize=11)\n",
    "ax.set_ylabel(f'PC2 ({var[1]:.1f}%)', fontsize=11)\n",
    "ax.set_title('PCA of RNA-seq Samples'); ax.legend()\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Play with it.** What decides whether those clusters actually separate? Two\n",
    "things: the size of the treatment effect and the biological noise between\n",
    "replicates. Push them around below and watch the ellipses pull apart or smear\n",
    "together \u2014 and watch PC1's variance share rise and fall with them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<!--widget:pca-->"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**The scree plot.** How many PCs do you actually need? Variance per PC (bars)\n",
    "plus the cumulative curve (line); the dashed line marks 80% explained."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(7, 5))\n",
    "n_pcs = min(10, len(var))\n",
    "cumvar = np.cumsum(var[:n_pcs])\n",
    "ax.bar(range(1, n_pcs+1), var[:n_pcs], color='steelblue', edgecolor='white', label='Per-PC')\n",
    "ax2 = ax.twinx()\n",
    "ax2.plot(range(1, n_pcs+1), cumvar, 'o-', color='darkorange', linewidth=2, label='Cumulative')\n",
    "ax2.axhline(80, color='red', linestyle='--', alpha=0.5, label='80% explained')\n",
    "ax2.set_ylabel('Cumulative variance (%)', color='darkorange')\n",
    "ax.set_xlabel('Principal Component'); ax.set_ylabel('Variance explained (%)')\n",
    "ax.set_title('Scree Plot')\n",
    "l1, lab1 = ax.get_legend_handles_labels()\n",
    "l2, lab2 = ax2.get_legend_handles_labels()\n",
    "ax.legend(l1 + l2, lab1 + lab2, fontsize=9)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
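  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reading the answer off the plot works, but the question can also be asked in code: what\n",
    "is the smallest number of PCs whose cumulative variance reaches the target? Here is a\n",
    "minimal sketch. `n_pcs_for` is a hypothetical helper that takes per-PC percentages such\n",
    "as the `var` array computed earlier, so `n_pcs_for(var)` would answer the question for\n",
    "this dataset."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def n_pcs_for(var_pct, target=80.0):\n",
    "    \"\"\"Smallest k such that the first k PCs explain at least `target` percent.\"\"\"\n",
    "    cum = np.cumsum(var_pct)\n",
    "    k = int(np.searchsorted(cum, target)) + 1   # first position where cum >= target\n",
    "    return min(k, len(var_pct))                 # cap if the target is never reached\n",
    "\n",
    "print(n_pcs_for([50.0, 25.0, 15.0, 10.0]))   # cumulative 50, 75, 90, 100 -> 3"
   ]
  },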
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Section 4 \u00b7 Pathway enrichment\n",
    "\n",
    "After finding DE genes the question becomes: **what biology is enriched?**\n",
    "Over-representation analysis (ORA) asks whether a pathway's genes show up in your\n",
    "DE list more than chance would predict. Tools like clusterProfiler or DAVID\n",
    "return a table like the one we hard-code below \u2014 an adjusted p-value, a gene\n",
    "count, and a background size per pathway.\n",
    "\n",
    "The **enrichment ratio** = (fraction of DE genes in the pathway) \u00f7 (fraction of\n",
    "all genes in the pathway). Above 1 means the pathway is over-represented."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "pathway_results = [\n",
    "    {'pathway': 'Cell cycle',               'p_adj': 2.1e-8, 'gene_count': 18, 'bg': 120, 'direction': 'up'},\n",
    "    {'pathway': 'DNA repair',               'p_adj': 5.3e-7, 'gene_count': 12, 'bg': 85,  'direction': 'up'},\n",
    "    {'pathway': 'Oxidative phosphorylation','p_adj': 1.2e-6, 'gene_count': 15, 'bg': 98,  'direction': 'up'},\n",
    "    {'pathway': 'mTOR signaling',           'p_adj': 8.4e-6, 'gene_count': 9,  'bg': 62,  'direction': 'up'},\n",
    "    {'pathway': 'P53 pathway',              'p_adj': 2.1e-5, 'gene_count': 8,  'bg': 55,  'direction': 'up'},\n",
    "    {'pathway': 'Apoptosis',                'p_adj': 4.5e-5, 'gene_count': 11, 'bg': 82,  'direction': 'down'},\n",
    "    {'pathway': 'Immune response',          'p_adj': 1.1e-4, 'gene_count': 7,  'bg': 67,  'direction': 'down'},\n",
    "    {'pathway': 'Fatty acid metabolism',    'p_adj': 2.3e-4, 'gene_count': 6,  'bg': 48,  'direction': 'down'},\n",
    "    {'pathway': 'Wnt signaling',            'p_adj': 5.6e-4, 'gene_count': 5,  'bg': 45,  'direction': 'down'},\n",
    "    {'pathway': 'Notch signaling',          'p_adj': 9.8e-4, 'gene_count': 4,  'bg': 32,  'direction': 'down'},\n",
    "]\n",
    "\n",
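    "# Hypothetical DE-list and gene-universe sizes (genome-scale, to match the pathway sizes\n",
    "# above); with a real ORA run, use the numbers your enrichment tool reports.\n",
    "n_de, n_bg = 500, 20_000\n",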
    "for p in pathway_results:\n",
    "    p['enrichment'] = (p['gene_count'] / n_de) / (p['bg'] / n_bg)\n",
    "    p['neg_log_p']  = -math.log10(p['p_adj'])\n",
    "pathway_results.sort(key=lambda x: x['p_adj'])\n",
    "\n",
    "print(f\"{'Pathway':30s} {'p_adj':>10} {'Genes':>6} {'Enrich':>8} {'Direction'}\")\n",
    "print(\"-\" * 65)\n",
    "for p in pathway_results:\n",
    "    print(f\"{p['pathway']:30s} {p['p_adj']:>10.2e} {p['gene_count']:>6d} {p['enrichment']:>8.2f} {p['direction']}\")"
   ]
  },
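  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The p-values in this table come from a tool, but the core ORA statistic is usually a\n",
    "one-sided hypergeometric test, equivalent to a one-sided Fisher's exact test. It asks:\n",
    "if you drew `n_de` genes at random from a universe of `n_bg`, how likely would you be to\n",
    "catch at least `k` of a pathway's `K` genes? Below is a standard-library sketch using\n",
    "`math.comb`; the toy numbers are made up. If SciPy is available,\n",
    "`scipy.stats.hypergeom.sf(k - 1, n_bg, K, n_de)` should return the same value."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "from math import comb\n",
    "\n",
    "def ora_pvalue(k, K, n_de, n_bg):\n",
    "    \"\"\"One-sided P(X >= k), X ~ Hypergeometric(population n_bg, K pathway genes, n_de draws).\"\"\"\n",
    "    total = comb(n_bg, n_de)\n",
    "    upper = min(K, n_de)                       # overlap can't exceed either set's size\n",
    "    hits = sum(comb(K, x) * comb(n_bg - K, n_de - x) for x in range(k, upper + 1))\n",
    "    return hits / total                        # exact integer counts, one final division\n",
    "\n",
    "# Toy example: 3 of 5 DE genes fall in a 4-gene pathway, in a universe of 20 genes\n",
    "demo_p = ora_pvalue(k=3, K=4, n_de=5, n_bg=20)\n",
    "print(f'p = {demo_p:.4f}')   # -> p = 0.0320"
   ]
  },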
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First pull the columns we'll plot out of the table once, so both figures below\n",
    "just reference these short lists."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "names    = [p['pathway']    for p in pathway_results]\n",
    "neg_logp = [p['neg_log_p']  for p in pathway_results]\n",
    "enrich   = [p['enrichment'] for p in pathway_results]\n",
    "gcounts  = [p['gene_count'] for p in pathway_results]\n",
    "bar_colors = ['#e74c3c' if p['direction'] == 'up' else '#3498db' for p in pathway_results]\n",
    "y_pos = list(range(len(names)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**View 1 \u2014 the bar chart.** `\u2212log\u2081\u2080(p_adj)` per pathway; taller bars are more\n",
    "significant. The dashed line marks p_adj = 0.05 (red = up-regulated, blue = down-regulated)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(8, 5))\n",
    "ax.barh(y_pos, neg_logp, color=bar_colors, edgecolor='white', alpha=0.85)\n",
    "ax.set_yticks(y_pos); ax.set_yticklabels(names, fontsize=9)\n",
    "ax.axvline(-math.log10(0.05), color='gray', linestyle='--', alpha=0.7, label='padj=0.05')\n",
    "ax.set_xlabel('-log\u2081\u2080(adjusted p-value)')\n",
    "ax.set_title('Pathway Enrichment (red=up, blue=down)'); ax.legend(fontsize=9)\n",
    "ax.invert_yaxis()\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
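  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cutoff line is drawn against adjusted p-values. Many enrichment tools adjust with\n",
    "the Benjamini\u2013Hochberg (BH) false-discovery-rate procedure, but check your tool's\n",
    "documentation before assuming so. Here is a compact sketch of BH. Each raw p-value is\n",
    "scaled by m / rank. A running minimum is then taken from the largest p downward, so\n",
    "adjusted values never decrease with rank, and the result is capped at 1."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def bh_adjust(pvals):\n",
    "    \"\"\"Benjamini-Hochberg adjusted p-values, returned in the input order.\"\"\"\n",
    "    p = np.asarray(pvals, dtype=float)\n",
    "    m = len(p)\n",
    "    order = np.argsort(p)                               # ranks by ascending raw p\n",
    "    scaled = p[order] * m / np.arange(1, m + 1)         # p_(i) * m / i\n",
    "    # running minimum from the largest p downward keeps adjusted values monotone\n",
    "    adj_sorted = np.minimum.accumulate(scaled[::-1])[::-1]\n",
    "    adj = np.empty(m)\n",
    "    adj[order] = np.minimum(adj_sorted, 1.0)            # cap at 1, restore input order\n",
    "    return adj\n",
    "\n",
    "print(bh_adjust([0.01, 0.04, 0.03, 0.20]))"
   ]
  },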
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**View 2 \u2014 the dot plot** (clusterProfiler-style). Same pathways, but now\n",
    "x = enrichment ratio, color = significance, and dot size = gene count \u2014 three\n",
    "numbers in one glance."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(8, 5))\n",
    "scatter = ax.scatter(enrich, y_pos, c=neg_logp, cmap='YlOrRd',\n",
    "                     s=[g * 20 for g in gcounts], vmin=0, vmax=max(neg_logp), zorder=3)\n",
    "plt.colorbar(scatter, ax=ax, label='-log\u2081\u2080(p_adj)')\n",
    "ax.set_yticks(y_pos); ax.set_yticklabels(names, fontsize=9)\n",
    "ax.set_xlabel('Enrichment ratio'); ax.set_title('Pathway Dot Plot (size = gene count)')\n",
    "ax.axvline(1, color='gray', linestyle='--', alpha=0.5); ax.invert_yaxis()\n",
    "for gc in [5, 10, 18]:\n",
    "    ax.scatter([], [], s=gc*20, c='gray', alpha=0.5, label=f'{gc} genes')\n",
    "ax.legend(loc='lower right', fontsize=8)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Section 5 \u00b7 Putting it together\n",
    "\n",
    "Real papers rarely show one plot at a time \u2014 they combine the key views into one\n",
    "labelled figure. Here we reuse the objects already in memory (`pcs`, `var`,\n",
    "`z_scores`, `pathway_results`) and lay out four panels with `GridSpec`:\n",
    "\n",
    "- **A** PCA \u00b7 **B** volcano (Module 6's DESeq2 results) \u00b7 **C** top-5 pathways \u00b7\n",
    "  **D** a wide heatmap of the top-20 DE genes\n",
    "\n",
    "This is the figure to screenshot for a slide."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Panel B is a **volcano plot** \u2014 the one-glance DE summary: x = log\u2082 fold change,\n",
    "y = \u2212log\u2081\u2080(adjusted p), so the genes that moved *and* moved convincingly sit top-left and\n",
    "top-right. The catch students always trip on: \"significant\" is a *choice of two\n",
    "thresholds*. The explorer makes that choice tangible \u2014 drag the fold-change and\n",
    "p-value cutoffs and watch which genes flip color."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<!--widget:volcano-->"
   ]
  },
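  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What happens when you drag a cutoff can be written as a small classifier. This minimal\n",
    "sketch uses the summary figure's two thresholds as defaults: |log\u2082 FC| > 1 and\n",
    "adjusted p < 0.05. `classify_volcano` is a hypothetical name. Loosen either cutoff and\n",
    "genes move between the three labels, the same kind of flipping the explorer shows."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def classify_volcano(log2fc, padj, fc_cut=1.0, p_cut=0.05):\n",
    "    \"\"\"Label each gene 'up', 'down' or 'ns' using a fold-change AND a p-value cutoff.\"\"\"\n",
    "    log2fc = np.asarray(log2fc, dtype=float)\n",
    "    padj = np.asarray(padj, dtype=float)\n",
    "    passes = (padj < p_cut) & (np.abs(log2fc) > fc_cut)\n",
    "    return np.where(passes & (log2fc > 0), 'up',\n",
    "                    np.where(passes & (log2fc < 0), 'down', 'ns'))\n",
    "\n",
    "demo_fc = [2.0, -1.5, 0.5, 3.0]\n",
    "demo_padj = [0.001, 0.01, 0.0001, 0.2]\n",
    "print(classify_volcano(demo_fc, demo_padj))              # -> ['up' 'down' 'ns' 'ns']\n",
    "print(classify_volcano(demo_fc, demo_padj, p_cut=0.5))   # -> ['up' 'down' 'ns' 'up']"
   ]
  },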
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 1.** Panel B is the volcano from **Module 6** \u2014 we load the DESeq2\n",
    "results table (`deseq2_results.tsv`) produced there, so this figure shows the\n",
    "same differential-expression call you made, not a throwaway synthetic set."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Panel B reuses Module 6's real DESeq2 output (continuity across the workshop).\n",
    "fc, padj = [], []\n",
    "with open(capstone_dir / 'deseq2_results.tsv') as f:\n",
    "    f.readline()  # header\n",
    "    for line in f:\n",
    "        p = line.strip().split('\\t')\n",
    "        if len(p) < 7:\n",
    "            continue\n",
    "        try:\n",
    "            fc.append(float(p[2])); padj.append(float(p[6]))\n",
    "        except ValueError:\n",
    "            continue\n",
    "fc = np.array(fc)\n",
    "padj = np.clip(np.array(padj), 1e-300, 1)\n",
    "lp = -np.log10(padj)\n",
    "sig = (padj < 0.05) & (np.abs(fc) > 1)\n",
    "print(f\"{int(sig.sum())} of {len(fc)} genes pass |log2FC|>1 and padj<0.05  (DESeq2, Module 6)\")"
   ]
  },
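  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loader above depends on column positions (`p[2]`, `p[6]`). It silently reads the\n",
    "wrong values if the table was written with a different column order. A more defensive\n",
    "sketch looks columns up by name instead, so a misspelled name raises `KeyError` rather\n",
    "than reading the wrong field. It assumes the standard DESeq2 headers `log2FoldChange`\n",
    "and `padj`, and a tab-separated header row that names every column, including the gene\n",
    "ID. `read_fc_padj` is a hypothetical helper. For this notebook's data it would be called\n",
    "as `read_fc_padj(capstone_dir / 'deseq2_results.tsv')`."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import csv\n",
    "import math\n",
    "\n",
    "def read_fc_padj(path, fc_col='log2FoldChange', p_col='padj'):\n",
    "    \"\"\"(log2fc, padj) lists from a tab-separated results table, matching columns by name.\"\"\"\n",
    "    fcs, ps = [], []\n",
    "    with open(path, newline='') as handle:\n",
    "        for row in csv.DictReader(handle, delimiter='\\t'):\n",
    "            try:\n",
    "                fc_val, p_val = float(row[fc_col]), float(row[p_col])\n",
    "            except (TypeError, ValueError):    # 'NA' text or a missing field\n",
    "                continue\n",
    "            if math.isnan(fc_val) or math.isnan(p_val):\n",
    "                continue\n",
    "            fcs.append(fc_val)\n",
    "            ps.append(p_val)\n",
    "    return fcs, ps"
   ]
  },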
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Step 2.** Now assemble the four panels with `GridSpec`, reusing the objects\n",
    "already in memory (`pcs`, `var`, `z_scores`, `pathway_results`, and the volcano\n",
    "arrays above)."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(16, 12)); fig.patch.set_facecolor('white')\n",
    "gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.4, wspace=0.35)\n",
    "\n",
    "# Panel A: PCA\n",
    "ax_pca = fig.add_subplot(gs[0, 0])\n",
    "for prefix, (color, label) in [('ctrl', ('steelblue', 'Control')), ('treat', ('darkorange', 'Treated'))]:\n",
    "    idx = [i for i, s in enumerate(sample_names) if s.startswith(prefix)]\n",
    "    ax_pca.scatter(pcs[idx, 0], pcs[idx, 1], c=color, s=80, label=label, edgecolors='black', linewidths=0.5, zorder=3)\n",
    "    confidence_ellipse(pcs[idx, :2], ax_pca, n_std=1.5, facecolor=color, alpha=0.1, edgecolor=color, linewidth=1.5)\n",
    "ax_pca.set_xlabel(f'PC1 ({var[0]:.0f}%)'); ax_pca.set_ylabel(f'PC2 ({var[1]:.0f}%)')\n",
    "ax_pca.set_title('A  PCA', loc='left', fontweight='bold'); ax_pca.legend(fontsize=8)\n",
    "\n",
    "# Panel B: Volcano (uses fc/lp/sig from the prep cell)\n",
    "ax_vol = fig.add_subplot(gs[0, 1])\n",
    "ax_vol.scatter(fc[~sig], lp[~sig], c='lightgray', alpha=0.5, s=15)\n",
    "ax_vol.scatter(fc[sig & (fc > 0)], lp[sig & (fc > 0)], c='#e74c3c', alpha=0.8, s=30)\n",
    "ax_vol.scatter(fc[sig & (fc < 0)], lp[sig & (fc < 0)], c='#3498db', alpha=0.8, s=30)\n",
    "ax_vol.axhline(-math.log10(0.05), color='gray', linestyle='--', alpha=0.6)\n",
    "ax_vol.axvline(1, color='gray', linestyle=':', alpha=0.6); ax_vol.axvline(-1, color='gray', linestyle=':', alpha=0.6)\n",
    "ax_vol.set_xlabel('log\u2082 FC'); ax_vol.set_ylabel('-log\u2081\u2080(p_adj)')\n",
    "ax_vol.set_title('B  Volcano', loc='left', fontweight='bold')\n",
    "\n",
    "# Panel C: top-5 pathways\n",
    "ax_path = fig.add_subplot(gs[0, 2])\n",
    "top5 = pathway_results[:5]\n",
    "ax_path.barh(range(5), [-math.log10(p['p_adj']) for p in top5],\n",
    "             color=['#e74c3c' if p['direction'] == 'up' else '#3498db' for p in top5], edgecolor='white')\n",
    "ax_path.set_yticks(range(5)); ax_path.set_yticklabels([p['pathway'] for p in top5], fontsize=8)\n",
    "ax_path.set_xlabel('-log\u2081\u2080(p_adj)'); ax_path.set_title('C  Pathways', loc='left', fontweight='bold')\n",
    "ax_path.invert_yaxis()\n",
    "\n",
    "# Panel D: heatmap of top-20 DE genes\n",
    "ax_heat = fig.add_subplot(gs[1, :])\n",
    "im = ax_heat.imshow(z_scores[:20].T, cmap='RdBu_r', aspect='auto', vmin=-2.5, vmax=2.5)\n",
    "plt.colorbar(im, ax=ax_heat, orientation='vertical', fraction=0.02, label='Z-score')\n",
    "ax_heat.set_xticks(range(20)); ax_heat.set_xticklabels(genes_list[:20], rotation=45, ha='right', fontsize=8)\n",
    "ax_heat.set_yticks(range(len(sample_names))); ax_heat.set_yticklabels(sample_names, fontsize=9)\n",
    "ax_heat.axhline(len(ctrl_idx) - 0.5, color='black', linewidth=2)\n",
    "ax_heat.set_title('D  Top 20 DE Genes Heatmap', loc='left', fontweight='bold')\n",
    "\n",
    "plt.suptitle('Multi-omics Summary Figure', fontsize=15, y=1.01, fontweight='bold')\n",
    "plt.show()"
   ]
  },
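  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Screenshots work for a quick slide, but a saved file stays sharp at any size. This\n",
    "minimal sketch uses matplotlib's `savefig` with a tiny placeholder figure and\n",
    "illustrative `summary_figure.*` filenames in a throwaway folder. On the real summary\n",
    "figure you would call `fig.savefig(...)` before `plt.show()`. PDF keeps text and lines\n",
    "as vectors, and a 300-dpi PNG is a safe raster choice for slides."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "demo_fig, demo_ax = plt.subplots(figsize=(4, 3))\n",
    "demo_ax.plot([0, 1, 2], [0, 1, 4], 'o-')\n",
    "demo_ax.set_title('A  placeholder panel', loc='left', fontweight='bold')\n",
    "\n",
    "demo_dir = tempfile.mkdtemp()     # throwaway output folder\n",
    "for ext in ('pdf', 'png'):\n",
    "    # bbox_inches='tight' trims surplus margins; dpi mainly affects the raster PNG\n",
    "    demo_fig.savefig(os.path.join(demo_dir, f'summary_figure.{ext}'), dpi=300, bbox_inches='tight')\n",
    "plt.close(demo_fig)               # don't also display the placeholder inline\n",
    "print(sorted(os.listdir(demo_dir)))"
   ]
  },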
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Exercises\n",
    "\n",
    "Each exercise has a worked solution you can edit and re-run. Try to predict the\n",
    "output before pressing \u25b6.\n",
    "\n",
    "### Exercise 1 \u00b7 Annotate the coverage plot\n",
    "Add to the coverage view: **(a)** a variant track with allele-frequency lollipops,\n",
    "**(b)** a shaded band over every exon across all tracks, and **(c)** an\n",
    "\"Exon N\" label above each exon."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Variant positions (allele frequency) to mark, plus a reusable exon-shading helper.\n",
    "snp_sites = [\n",
    "    (gene_start + 900,  0.45),   # inside Exon 1\n",
    "    (gene_start + 3400, 0.30),   # inside Exon 2\n",
    "    (gene_start + 5000, 0.15),   # intron\n",
    "    (gene_start + 7000, 0.60),   # inside Exon 3\n",
    "]\n",
    "\n",
    "def shade_exons(ax):\n",
    "    for ex_s, ex_e in exons:\n",
    "        ax.axvspan(ex_s / 1e6, ex_e / 1e6, color='gold', alpha=0.15, zorder=0)"
   ]
  },
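  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The comments in `snp_sites` label each variant as exon or intron by hand. The same labels\n",
    "can be computed with a simple interval lookup. This sketch assumes half-open\n",
    "`[start, end)` intervals and 1-based exon numbers, as in the plot labels. `exon_number`\n",
    "is a hypothetical helper. The demo uses the same exon offsets as `exons` (relative to\n",
    "`gene_start`), so it runs on its own. With absolute coordinates you could call\n",
    "`exon_number(var_pos, exons)` instead."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "def exon_number(pos, exon_list):\n",
    "    \"\"\"1-based index of the exon containing pos (half-open [start, end)), or None if intronic.\"\"\"\n",
    "    for i, (ex_start, ex_end) in enumerate(exon_list, start=1):\n",
    "        if ex_start <= pos < ex_end:\n",
    "            return i\n",
    "    return None\n",
    "\n",
    "demo_exons = [(500, 1200), (3000, 3800), (6500, 7500), (8800, 9700)]   # offsets from gene_start\n",
    "for demo_pos in (900, 3400, 5000, 7000):\n",
    "    demo_hit = exon_number(demo_pos, demo_exons)\n",
    "    print(demo_pos, 'intron' if demo_hit is None else f'Exon {demo_hit}')"
   ]
  },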
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the data and helper ready, the figure itself is just three stacked panels \u2014\n",
    "normal coverage (with Exon labels), tumor coverage, and the variant lollipops \u2014\n",
    "all sharing the exon shading."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "pos_mb = np.array(positions) / 1e6\n",
    "fig, axes = plt.subplots(3, 1, figsize=(14, 8), sharex=True)\n",
    "\n",
    "axes[0].fill_between(pos_mb, cov_normal, color='steelblue', alpha=0.7)\n",
    "shade_exons(axes[0]); axes[0].set_ylabel('Normal\\ncoverage')\n",
    "y_top = cov_normal.max() * 1.05\n",
    "for i, (ex_s, ex_e) in enumerate(exons, start=1):\n",
    "    axes[0].text((ex_s + ex_e) / 2 / 1e6, y_top, f'Exon {i}', ha='center', fontsize=8,\n",
    "                 color='darkgoldenrod', fontweight='bold')\n",
    "\n",
    "axes[1].fill_between(pos_mb, cov_tumor, color='#e74c3c', alpha=0.7)\n",
    "shade_exons(axes[1]); axes[1].set_ylabel('Tumor\\ncoverage')\n",
    "\n",
    "shade_exons(axes[2])\n",
    "for var_pos, allele_freq in snp_sites:\n",
    "    x = var_pos / 1e6\n",
    "    axes[2].vlines(x, 0, allele_freq, color='purple', linewidth=1.5)\n",
    "    axes[2].plot(x, allele_freq, marker='^', color='purple', markersize=8 + allele_freq * 10)\n",
    "    axes[2].text(x, allele_freq + 0.04, f'{allele_freq:.2f}', ha='center', fontsize=7, color='purple')\n",
    "axes[2].set_ylim(0, 0.8); axes[2].set_ylabel('Variant\\nallele freq'); axes[2].set_xlabel('Genomic position (Mb)')\n",
    "\n",
    "axes[0].set_xlim(pos_mb[0], pos_mb[-1])\n",
    "plt.suptitle('Enhanced Coverage Plot'); plt.tight_layout(); plt.show()\n",
    "print(f\"Added {len(snp_sites)} variant markers and shaded {len(exons)} exons.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 2 \u00b7 Clustered heatmap\n",
    "Cluster both genes (rows) and samples (columns) with `scipy.cluster.hierarchy`\n",
    "and draw the dendrograms framing the heatmap. Reordering rows/columns by\n",
    "similarity makes co-regulated gene blocks pop out."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "try:\n",
    "    from scipy.cluster.hierarchy import linkage, dendrogram\n",
    "    HAS_SCIPY = True\n",
    "except ImportError:\n",
    "    HAS_SCIPY = False\n",
    "    print(\"scipy not installed \u2014 skip this exercise\")\n",
    "\n",
    "if HAS_SCIPY:\n",
    "    data = z_scores                       # rows = first 30 genes, cols = samples\n",
    "    gene_labels, sample_labels = genes_list[:data.shape[0]], sample_names\n",
    "    row_link = linkage(data,   method='average', metric='euclidean')  # cluster genes\n",
    "    col_link = linkage(data.T, method='average', metric='euclidean')  # cluster samples\n",
    "    print(f\"Linkage computed for {data.shape[0]} genes x {data.shape[1]} samples\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the two linkages computed, draw the figure: column dendrogram on top, row\n",
    "dendrogram on the left, and the heatmap re-ordered by both so co-regulated blocks\n",
    "line up."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "if HAS_SCIPY:\n",
    "    fig = plt.figure(figsize=(10, 9))\n",
    "    grid = fig.add_gridspec(2, 2, width_ratios=[1, 4], height_ratios=[1, 4], wspace=0.02, hspace=0.02)\n",
    "\n",
    "    ax_top = fig.add_subplot(grid[0, 1])\n",
    "    col_dend = dendrogram(col_link, ax=ax_top, color_threshold=0, above_threshold_color='gray')\n",
    "    ax_top.set_xticks([]); ax_top.set_yticks([])\n",
    "    for s in ax_top.spines.values(): s.set_visible(False)\n",
    "\n",
    "    ax_left = fig.add_subplot(grid[1, 0])\n",
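    "    row_dend = dendrogram(row_link, ax=ax_left, orientation='left', color_threshold=0, above_threshold_color='gray')\n",
    "    ax_left.invert_yaxis()   # dendrogram draws leaf 0 at the bottom; imshow draws row 0 at the top\n",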
    "    ax_left.set_xticks([]); ax_left.set_yticks([])\n",
    "    for s in ax_left.spines.values(): s.set_visible(False)\n",
    "\n",
    "    row_order, col_order = row_dend['leaves'], col_dend['leaves']\n",
    "    ordered = data[np.ix_(row_order, col_order)]\n",
    "\n",
    "    ax_heat = fig.add_subplot(grid[1, 1])\n",
    "    im = ax_heat.imshow(ordered, cmap='RdBu_r', aspect='auto', vmin=-2.5, vmax=2.5, interpolation='nearest')\n",
    "    ax_heat.set_xticks(range(len(col_order)))\n",
    "    ax_heat.set_xticklabels([sample_labels[i] for i in col_order], rotation=45, ha='right', fontsize=8)\n",
    "    ax_heat.set_yticks(range(len(row_order)))\n",
    "    ax_heat.set_yticklabels([gene_labels[i] for i in row_order], fontsize=6)\n",
    "    ax_heat.yaxis.tick_right()\n",
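    "    corner = fig.add_subplot(grid[0, 0]); corner.axis('off')  # use the empty top-left cell\n",
    "    cax = corner.inset_axes([0.35, 0.1, 0.12, 0.8])            # slim colorbar, dendrograms stay aligned\n",
    "    fig.colorbar(im, cax=cax, label='Z-score')\n",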
    "    plt.suptitle('Clustered Heatmap'); plt.show()\n",
    "    print(f\"Clustered {data.shape[0]} genes x {data.shape[1]} samples.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 3 \u00b7 Variant lollipop plot\n",
    "Parse `synthetic_variants.vcf` and plot each variant as a lollipop: x = position,\n",
    "y = QUAL, color by FILTER (red = PASS, gray = filtered), marker size scaled by allele frequency."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "records = []\n",
    "with open(capstone_dir / 'synthetic_variants.vcf') as fh:\n",
    "    for line in fh:\n",
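    "        if line.startswith('#') or not line.strip():   # skip header and blank lines\n",
    "            continue\n",
    "        f = line.rstrip('\\r\\n').split('\\t')   # CHROM POS ID REF ALT QUAL FILTER INFO ...\n",
    "        af = 0.0\n",
    "        for kv in f[7].split(';'):          # INFO column, e.g. \"DP=19;AF=0.392\"\n",
    "            if kv.startswith('AF='):\n",
    "                af = float(kv.split('=')[1].split(',')[0])   # first ALT if multi-allelic\n",
    "        qual = 0.0 if f[5] == '.' else float(f[5])          # QUAL may be '.' (missing)\n",
    "        records.append({'pos': int(f[1]), 'qual': qual, 'filter': f[6], 'af': af})\n",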
    "print(f\"Parsed {len(records)} variants\")\n",
    "\n",
    "positions_v = np.array([r['pos']  for r in records])\n",
    "quals       = np.array([r['qual'] for r in records])\n",
    "afs         = np.array([r['af']   for r in records])\n",
    "is_pass     = np.array([r['filter'] == 'PASS' for r in records])\n",
    "print(f\"{int(is_pass.sum())} PASS, {len(records) - int(is_pass.sum())} filtered\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the lollipop figure: a stem from the axis to each variant's QUAL, a marker\n",
    "sized by allele frequency, and color from the FILTER flag."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(12, 5))\n",
    "for x, q, ok in zip(positions_v, quals, is_pass):\n",
    "    ax.vlines(x, 0, q, color=('#e74c3c' if ok else 'gray'), alpha=0.6, linewidth=1)\n",
    "ax.scatter(positions_v[is_pass], quals[is_pass], s=30 + afs[is_pass] * 300, c='#e74c3c',\n",
    "           edgecolors='black', linewidths=0.4, zorder=3, label='PASS')\n",
    "ax.scatter(positions_v[~is_pass], quals[~is_pass], s=30 + afs[~is_pass] * 300, c='gray',\n",
    "           edgecolors='black', linewidths=0.4, zorder=3, label='Filtered')\n",
    "ax.set_xlabel('Genomic position (chr22)'); ax.set_ylabel('QUAL score'); ax.set_ylim(bottom=0)\n",
    "ax.legend(title='FILTER'); plt.title('Variant Lollipop Plot'); plt.tight_layout(); plt.show()\n",
    "print(f\"PASS: {int(is_pass.sum())}   Filtered: {len(records) - int(is_pass.sum())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 4 \u00b7 PCA biplot\n",
    "Overlay the top gene **loadings** as arrows on the PCA scores to show *which genes\n",
    "drive the separation*. Rows 0 and 1 of `Vt` hold every gene's weight on PC1 and\n",
    "PC2; the longest arrows mark the strongest drivers, which are candidate biomarkers."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {},
   "execution_count": null,
   "outputs": [],
   "source": [
    "loadings_pc1, loadings_pc2 = Vt[0, :], Vt[1, :]\n",
    "loading_mag = np.sqrt(loadings_pc1**2 + loadings_pc2**2)\n",
    "top10 = np.argsort(loading_mag)[-10:]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9, 7))\n",
    "for prefix, color, label in [('ctrl', 'steelblue', 'Control'), ('treat', 'darkorange', 'Treated')]:\n",
    "    idx = [i for i, s in enumerate(sample_names) if s.startswith(prefix)]\n",
    "    ax.scatter(pcs[idx, 0], pcs[idx, 1], c=color, s=90, label=label, edgecolors='black', linewidths=0.5, zorder=3)\n",
    "    for i in idx:\n",
    "        ax.annotate(sample_names[i], (pcs[i, 0], pcs[i, 1]), textcoords='offset points', xytext=(6, 3), fontsize=8)\n",
    "\n",
    "score_radius = np.sqrt(pcs[:, 0]**2 + pcs[:, 1]**2).max()\n",
    "scale = (score_radius * 0.9) / loading_mag[top10].max()\n",
    "for gi in top10:\n",
    "    dx, dy = loadings_pc1[gi] * scale, loadings_pc2[gi] * scale\n",
    "    ax.arrow(0, 0, dx, dy, color='crimson', alpha=0.7, head_width=score_radius * 0.03, length_includes_head=True, zorder=2)\n",
    "    ax.text(dx * 1.12, dy * 1.12, genes_list[gi], color='darkred', fontsize=8, ha='center', va='center')\n",
    "\n",
    "ax.axhline(0, color='gray', linewidth=0.5); ax.axvline(0, color='gray', linewidth=0.5)\n",
    "ax.set_xlabel(f'PC1 ({var[0]:.1f}%)'); ax.set_ylabel(f'PC2 ({var[1]:.1f}%)')\n",
    "ax.legend(); plt.title('PCA Biplot'); plt.show()\n",
    "print(\"Top genes driving the separation:\", \", \".join(genes_list[gi] for gi in reversed(top10)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Summary\n",
    "\n",
    "| Plot | When to use | Key parameters |\n",
    "|------|-------------|----------------|\n",
    "| Coverage | Check alignment quality, find CNVs | smoothing window, log\u2082 ratio |\n",
    "| Heatmap | Expression patterns across samples | z-score, clustering |\n",
    "| Volcano | Summarize DE results | FC + padj thresholds |\n",
    "| PCA | Sample clustering & outliers | n_components, ellipses |\n",
    "| Pathway bar/dot | Communicate the biology | enrichment ratio, gene count |\n",
    "\n",
    "**Key takeaways**\n",
    "- Z-score each gene before drawing a heatmap: on raw counts, a few highly expressed\n",
    "  genes dominate the color scale and hide the patterns.\n",
    "- PCA is your first QC step: it exposes batch effects and outliers at a glance.\n",
    "- A volcano plot needs *both* a fold-change and a significance threshold; replay the\n",
    "  **volcano explorer** in Section 5 to see why.\n",
    "- Watch color accessibility: pair color with marker shape (e.g. red/blue groups with\n",
    "  distinct markers) so colorblind readers never rely on hue alone.\n",
    "- Export at 300 DPI: `plt.savefig('figure.png', dpi=300, bbox_inches='tight')`\n",
    "\n",
    "**Next:** Module 8 \u2014 Capstone Challenge"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
