{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Module 6: RNA-seq — Differential Expression with DESeq2\n",
    "\n",
    "**Duration:** 70 minutes  \n",
    "**Day:** 2 of 2\n",
    "\n",
    "## Learning Objectives\n",
    "\n",
    "By the end of this module you will be able to:\n",
    "- Load and QC a gene-by-sample count matrix\n",
    "- **Run DESeq2** (the real R/Bioconductor workflow) and read its results table\n",
    "- Interpret `log2FoldChange`, `padj`, and what makes a gene \"significant\"\n",
    "- Produce and read the standard RNA-seq figures: volcano, MA, PCA, heatmap\n",
    "\n",
    "---\n",
    "\n",
    "> **The tool is DESeq2.** Differential expression is a bioinformatics task, not a\n",
    "> statistics-coding exercise. DESeq2 fits a negative-binomial model per gene,\n",
    "> shrinks dispersions across genes, and runs a Wald test. You **run the tool and\n",
    "> read its output** — you do not hand-code the GLM. This notebook shows the exact\n",
    "> DESeq2 R script, then loads the table it produces. The matplotlib plotting code lives in\n",
    "> `rnaseq_toolkit.py` (download it from the module page) so each cell stays short."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# matplotlib for figures, numpy for light array work, plus the RNA-seq toolkit.\n",
    "%matplotlib inline\n",
    "import os\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "\n",
    "from rnaseq_toolkit import (\n",
    "    load_counts, load_deseq2_results, load_normalized_counts,\n",
    "    split_significant, plot_library_qc, plot_volcano, plot_ma_and_pca, plot_heatmap,\n",
    ")\n",
    "\n",
    "# Where the workshop data lives: two folders up from the notebook, then data/capstone.\n",
    "data_dir = Path(os.path.abspath('../../')) / 'data' / 'capstone'\n",
    "print('Data dir:', data_dir)\n",
    "print('rnaseq_toolkit loaded.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Section 1: Load and QC the count matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the gene-by-sample counts. The toolkit parses the TSV and tells us\n",
    "# which columns are control vs treated.\n",
    "genes, sample_names, counts_matrix, ctrl_idx, treat_idx = load_counts(\n",
    "    data_dir / 'synthetic_counts.tsv')\n",
    "\n",
    "print(f'Genes:   {len(genes)}')\n",
    "print(f'Samples: {len(sample_names)}  {sample_names}')\n",
    "print(f'Control columns: {ctrl_idx}   Treated columns: {treat_idx}')\n",
    "print(f'Count matrix shape: {counts_matrix.shape}')\n",
    "print(f'Counts range: {int(counts_matrix.min())}-{int(counts_matrix.max())}, '\n",
    "      f'mean {counts_matrix.mean():.0f}, zeros {(counts_matrix==0).mean()*100:.1f}%')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first 10 genes so we can see the raw numbers we are about to model.\n",
    "print(f\"{'Gene':10s}  \" + '  '.join(f'{s:8s}' for s in sample_names))\n",
    "print('-' * 70)\n",
    "for gi in range(10):\n",
    "    row = '  '.join(f'{int(counts_matrix[gi, si]):8d}' for si in range(len(sample_names)))\n",
    "    print(f'{genes[gi]:10s}  {row}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Before normalizing, look at the libraries.** Samples are sequenced to different\n",
    "depths, so raw counts are not comparable across samples until we correct for\n",
    "library size. The QC panel below is what you check first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Library size + per-sample count distribution (pre-normalization QC).\n",
    "plot_library_qc(counts_matrix, sample_names, ctrl_idx, treat_idx)"
   ]
  },
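  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why depth matters, here is a minimal sketch on a hypothetical toy matrix (not the\n",
    "workshop data). Sample B has the same composition as sample A but was sequenced twice as\n",
    "deep. The raw counts differ 2-fold. Scaling each column by its total (counts per million,\n",
    "CPM) makes them identical. CPM is only the simplest correction; DESeq2 uses median-of-ratios\n",
    "size factors instead.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical toy counts: rows are genes, columns are samples A and B.\n",
    "# Sample B has the same composition as A but twice the sequencing depth.\n",
    "toy = np.array([[100, 200],\n",
    "                [300, 600],\n",
    "                [600, 1200]], dtype=float)\n",
    "\n",
    "library_sizes = toy.sum(axis=0)   # total reads per sample\n",
    "cpm = toy / library_sizes * 1e6   # rescale each column to one million reads\n",
    "\n",
    "print('Library sizes:', library_sizes)\n",
    "print('Raw counts, gene 1:', toy[0])\n",
    "print('CPM, gene 1:       ', cpm[0])\n",
    "```"
   ]
  },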
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Section 2: Run DESeq2 — the real differential-expression step\n",
    "\n",
    "On a real machine you run **DESeq2** in R. Three calls do the whole analysis:\n",
    "`DESeqDataSetFromMatrix()` builds the dataset; `DESeq()` estimates size factors and\n",
    "dispersions, fits the negative-binomial GLM, and runs a Wald test; `results()`\n",
    "extracts the per-gene table. Here is the actual script (`run_deseq2.R`, downloadable\n",
    "from the module page):\n",
    "\n",
    "```r\n",
    "suppressMessages(library(DESeq2))\n",
    "\n",
    "counts <- as.matrix(read.delim('synthetic_counts.tsv', row.names = 1))\n",
    "condition <- factor(ifelse(grepl('^ctrl', colnames(counts)), 'ctrl', 'treat'),\n",
    "                    levels = c('ctrl', 'treat'))\n",
    "coldata <- data.frame(row.names = colnames(counts), condition = condition)\n",
    "\n",
    "dds <- DESeqDataSetFromMatrix(counts, coldata, design = ~ condition)\n",
    "dds <- DESeq(dds)                                   # size factors, dispersions, GLM, Wald\n",
    "res <- results(dds, contrast = c('condition', 'treat', 'ctrl'))\n",
    "\n",
    "write.table(data.frame(gene = rownames(res), as.data.frame(res)),\n",
    "            'deseq2_results.tsv', sep = '\\t', quote = FALSE, row.names = FALSE)\n",
    "```\n",
    "\n",
    "DESeq2 (R/Bioconductor) does not run in the browser kernel, so the workshop ships\n",
    "the output of this script. The cell below shows the command and the console summary\n",
    "DESeq2 prints — then we load and interpret the table, exactly as you would in a real\n",
    "pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What you would run on your own machine, and the result summary (the lines after the\n",
    "# first follow the layout of DESeq2's summary(res) output).\n",
    "print('$ Rscript run_deseq2.R synthetic_counts.tsv')\n",
    "print()\n",
    "print('Tested 100 genes; 18 significant at padj < 0.05 & |log2FC| >= 1')\n",
    "print()\n",
    "print('out of 100 with nonzero total read count')\n",
    "print('adjusted p-value < 0.05')\n",
    "print('LFC > 0 (up)    : 8,  8.0%')\n",
    "print('LFC < 0 (down)  : 10, 10.0%')\n",
    "print('outliers [1]    : 0')\n",
    "print('low counts [2]  : 0')"
   ]
  },
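  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DESeq2 models each gene's counts as negative binomial with mean μ and variance\n",
    "μ + α·μ², where α is the gene's dispersion. Replicate noise therefore grows faster than the\n",
    "Poisson variance μ. A quick NumPy simulation makes that overdispersion concrete. This is a\n",
    "sketch, not DESeq2 code, and the μ and α values are arbitrary. NumPy parameterizes the\n",
    "negative binomial by (n, p), so the sketch converts with n = 1/α and p = n/(n + μ).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "mu = 200.0     # illustrative mean count for one gene\n",
    "alpha = 0.1    # illustrative dispersion\n",
    "\n",
    "# Convert (mu, alpha) to NumPy's (n, p): mean mu, variance mu + alpha * mu**2.\n",
    "n = 1.0 / alpha\n",
    "p = n / (n + mu)\n",
    "nb_draws = rng.negative_binomial(n, p, size=200_000)\n",
    "poisson_draws = rng.poisson(mu, size=200_000)\n",
    "\n",
    "expected_nb_var = mu + alpha * mu**2\n",
    "print(f'NB:      mean={nb_draws.mean():.1f}  var={nb_draws.var():.0f}  (theory {expected_nb_var:.0f})')\n",
    "print(f'Poisson: mean={poisson_draws.mean():.1f}  var={poisson_draws.var():.0f}  (theory {mu:.0f})')\n",
    "```"
   ]
  },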
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Section 3: Load DESeq2's output and interpret it\n",
    "\n",
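    "The workshop ships two DESeq2 output tables. Loading and reading them is the day-to-day\n",
    "skill: `deseq2_results.tsv` (per-gene statistics, written by the script above) and\n",
    "`deseq2_normalized_counts.tsv` (size-factor-normalized counts for plotting, the values\n",
    "DESeq2 returns from `counts(dds, normalized = TRUE)`)."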
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse DESeq2's results table and its normalized counts.\n",
    "results = load_deseq2_results(data_dir / 'deseq2_results.tsv', genes)\n",
    "norm_counts = load_normalized_counts(data_dir / 'deseq2_normalized_counts.tsv', genes)\n",
    "log_norm = np.log2(norm_counts + 1)   # log scale for plotting / distances\n",
    "\n",
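    "# The R script writes genes in input order, so sort by padj (NaN last) before\n",
    "# reading the top of the table.\n",
    "results = sorted(results, key=lambda r: np.nan_to_num(r['padj'], nan=np.inf))\n",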
    "print(f'{\"Gene\":10s} {\"baseMean\":>10} {\"log2FC\":>8} {\"pvalue\":>11} {\"padj\":>11}')\n",
    "print('-' * 56)\n",
    "for r in results[:10]:\n",
    "    print(f\"{r['gene']:10s} {r['basemean']:>10.1f} {r['log2fc']:>8.2f} \"\n",
    "          f\"{r['p_value']:>11.2e} {r['padj']:>11.2e}\")"
   ]
  },
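  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`load_deseq2_results` hides the parsing. To give a rough idea of what such a loader has to\n",
    "do, here is a minimal standard-library sketch. `parse_deseq2_lines` is a hypothetical helper,\n",
    "not the toolkit's code, and the two data rows are made up. The columns (`baseMean`,\n",
    "`log2FoldChange`, `lfcSE`, `stat`, `pvalue`, `padj`) are the ones DESeq2's `results()`\n",
    "returns, plus the `gene` column the R script adds. DESeq2 reports `NA` for `padj` when a\n",
    "gene is removed by independent filtering, so the sketch maps `NA` to NaN instead of failing.\n",
    "\n",
    "```python\n",
    "import csv\n",
    "import math\n",
    "\n",
    "NUMERIC_COLS = ('baseMean', 'log2FoldChange', 'lfcSE', 'stat', 'pvalue', 'padj')\n",
    "\n",
    "def parse_deseq2_lines(lines):\n",
    "    \"\"\"Parse tab-separated DESeq2 results (header first) into a list of dicts.\"\"\"\n",
    "    rows = []\n",
    "    for rec in csv.DictReader(lines, delimiter='\\t'):\n",
    "        row = {'gene': rec['gene']}\n",
    "        for col in NUMERIC_COLS:\n",
    "            # write.table() writes missing values as the literal string NA.\n",
    "            row[col] = math.nan if rec[col] == 'NA' else float(rec[col])\n",
    "        rows.append(row)\n",
    "    return rows\n",
    "\n",
    "# Made-up rows in the layout the R script writes.\n",
    "example_lines = [\n",
    "    'gene\\tbaseMean\\tlog2FoldChange\\tlfcSE\\tstat\\tpvalue\\tpadj',\n",
    "    'GENE001\\t523.4\\t2.4\\t0.3\\t8.0\\t1.2e-15\\t6.2e-14',\n",
    "    'GENE099\\t0.4\\t0.1\\t3.0\\t0.03\\t0.97\\tNA',\n",
    "]\n",
    "parsed = parse_deseq2_lines(example_lines)\n",
    "for row in parsed:\n",
    "    print(row['gene'], row['log2FoldChange'], row['padj'])\n",
    "\n",
    "# For a real file: with open(path, newline='') as fh: rows = parse_deseq2_lines(fh)\n",
    "```"
   ]
  },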
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# How many genes are differentially expressed, and which direction?\n",
    "sig_up, sig_down, sig_all = split_significant(results, padj_thresh=0.05, fc_thresh=1.0)\n",
    "\n",
    "print(f'Total genes tested:       {len(results)}')\n",
    "print(f'DE genes (padj<0.05, |log2FC|>=1, i.e. 2-fold): {len(sig_all)}')\n",
    "print(f'  Upregulated in treated: {len(sig_up)}')\n",
    "print(f'  Downregulated:          {len(sig_down)}')\n",
    "print()\n",
    "print('Strongest upregulated:')\n",
    "for r in sorted(sig_up, key=lambda x: -x['log2fc'])[:5]:\n",
    "    print(f\"  {r['gene']}: log2FC={r['log2fc']:+.2f}, padj={r['padj']:.1e}\")\n",
    "print('Strongest downregulated:')\n",
    "for r in sorted(sig_down, key=lambda x: x['log2fc'])[:5]:\n",
    "    print(f\"  {r['gene']}: log2FC={r['log2fc']:+.2f}, padj={r['padj']:.1e}\")"
   ]
  },
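  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `padj` column holds Benjamini-Hochberg (BH) adjusted p-values, which `results()`\n",
    "computes by default. To show what the adjustment does, here is a minimal NumPy sketch of the\n",
    "BH step-up procedure on ten made-up p-values. `bh_adjust` is a hypothetical name, and the\n",
    "sketch skips the NA handling and independent filtering that DESeq2 applies before adjusting.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def bh_adjust(pvalues):\n",
    "    \"\"\"Benjamini-Hochberg adjusted p-values, as in R's p.adjust(p, method='BH').\"\"\"\n",
    "    p = np.asarray(pvalues, dtype=float)\n",
    "    m = p.size\n",
    "    order = np.argsort(p)                          # indices of p-values, smallest first\n",
    "    ranked = p[order] * m / np.arange(1, m + 1)    # p_(i) * m / i\n",
    "    # Step-up: each adjusted value is the minimum over itself and all larger ranks.\n",
    "    ranked = np.minimum.accumulate(ranked[::-1])[::-1]\n",
    "    adjusted = np.empty(m)\n",
    "    adjusted[order] = np.minimum(ranked, 1.0)      # cap at 1 and restore input order\n",
    "    return adjusted\n",
    "\n",
    "raw_p = [0.001, 0.008, 0.039, 0.041, 0.042, 0.060, 0.074, 0.205, 0.212, 0.216]  # made up\n",
    "adj_p = bh_adjust(raw_p)\n",
    "print('adjusted:', np.round(adj_p, 3))\n",
    "print('raw p < 0.05:', sum(p < 0.05 for p in raw_p))\n",
    "print('BH  p < 0.05:', int((adj_p < 0.05).sum()))\n",
    "```\n",
    "\n",
    "With these made-up values, five raw p-values fall below 0.05 but only two adjusted values do."
   ]
  },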
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Section 4: Volcano plot\n",
    "\n",
    "The volcano plot is how you read a DE result at a glance: effect size\n",
    "(`log2FoldChange`) on x, significance (`-log10 padj`) on y. Genes in the upper\n",
    "corners are both large-effect and significant. Use the on-page interactive volcano plot\n",
    "(the **volcano manipulable**) to drag the thresholds and watch which genes are called."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One call — the matplotlib lives in the toolkit.\n",
    "plot_volcano(results, padj_thresh=0.05, fc_thresh=1.0)"
   ]
  },
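  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Behind the figure, a volcano plot needs only two coordinates and a label per gene. Here is a\n",
    "minimal sketch of that bookkeeping. It assumes the same rule as `split_significant` (padj\n",
    "below the cutoff and |log2FC| at or above the fold-change cutoff), though the toolkit's exact\n",
    "rule is not shown here. `volcano_points` is a hypothetical helper. Genes whose `padj` is NaN\n",
    "are labeled not significant, and a tiny floor keeps `-log10(0)` finite when a p-value underflows.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def volcano_points(log2fc, padj, padj_thresh=0.05, fc_thresh=1.0):\n",
    "    \"\"\"Return x, y and an up/down/ns label for each gene.\"\"\"\n",
    "    x = np.asarray(log2fc, dtype=float)\n",
    "    padj = np.asarray(padj, dtype=float)\n",
    "    # Floor padj so a value of exactly 0 still gets a finite height.\n",
    "    y = -np.log10(np.clip(padj, 1e-300, 1.0))\n",
    "    # Comparisons with NaN are False, so NaN padj never counts as significant.\n",
    "    significant = (padj < padj_thresh) & (np.abs(x) >= fc_thresh)\n",
    "    labels = np.where(significant & (x > 0), 'up',\n",
    "                      np.where(significant & (x < 0), 'down', 'ns'))\n",
    "    return x, y, labels\n",
    "\n",
    "# Made-up genes: strong up, strong down, significant but small effect, untested (NaN).\n",
    "x, y, labels = volcano_points([2.5, -1.8, 0.3, 3.0], [1e-6, 0.01, 1e-4, np.nan])\n",
    "print(labels.tolist())\n",
    "print(np.round(y, 2))\n",
    "```"
   ]
  },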
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Section 5: MA plot and PCA\n",
    "\n",
    "Two more standard checks. The **MA plot** shows fold change vs mean expression\n",
    "(are big fold changes only in low-count noise?). **PCA** should separate the\n",
    "conditions. If controls and treated don't split on PC1/PC2, suspect a batch\n",
    "effect or a sample-label swap."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var_explained = plot_ma_and_pca(results, log_norm, sample_names, ctrl_idx, treat_idx)\n",
    "print(f'PCA variance: PC1={var_explained[0]:.1f}%, PC2={var_explained[1]:.1f}%')"
   ]
  },
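  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`plot_ma_and_pca` returns the percent variance per component. One common way to compute\n",
    "that is a singular value decomposition (SVD) of the centered log-expression matrix. Here is a\n",
    "minimal sketch on simulated toy data. We are assuming the toolkit works this way, and\n",
    "`pca_variance` is a hypothetical name. (DESeq2's own `plotPCA()` works on a variance-stabilized\n",
    "object and uses the 500 most variable genes by default.)\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def pca_variance(log_expr, n_components=2):\n",
    "    \"\"\"PCA of the samples in a genes x samples matrix: (scores, percent variance).\"\"\"\n",
    "    X = np.asarray(log_expr, dtype=float).T     # samples x genes\n",
    "    X = X - X.mean(axis=0)                      # center each gene across samples\n",
    "    U, S, Vt = np.linalg.svd(X, full_matrices=False)\n",
    "    scores = U[:, :n_components] * S[:n_components]   # sample coordinates on each PC\n",
    "    pct_var = 100 * S**2 / np.sum(S**2)                # share of total variance per PC\n",
    "    return scores, pct_var[:n_components]\n",
    "\n",
    "# Toy data: 50 genes x 6 samples; 10 genes shift up by 2 in the last 3 samples.\n",
    "rng = np.random.default_rng(1)\n",
    "toy_log = rng.normal(8.0, 0.2, size=(50, 6))\n",
    "toy_log[:10, 3:] += 2.0\n",
    "scores, pct = pca_variance(toy_log)\n",
    "print('PC1 scores:', np.round(scores[:, 0], 2))\n",
    "print(f'PC1={pct[0]:.1f}%  PC2={pct[1]:.1f}%')\n",
    "```"
   ]
  },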
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "### Where does `log2FoldChange` come from? (intuition)\n",
    "\n",
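    "DESeq2's significance test is a Wald test on a negative-binomial GLM. The per-gene\n",
    "dispersions are shrunk toward a trend fitted across all genes, so the test borrows\n",
    "variance information between genes. That makes it more powerful than a per-gene t-test\n",
    "when there are only a few replicates. The **fold change** itself is intuitive. For a\n",
    "two-group design without shrinkage, it is close to the log2 ratio of the group means of\n",
    "normalized counts, which in turn is close to the difference in mean log-expression.\n",
    "Let's check both on the top hit."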
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the most significant gene (smallest padj) and show its normalized counts by group.\n",
    "top = min(results, key=lambda r: np.nan_to_num(r['padj'], nan=np.inf))\n",
    "gi = top['gene_idx']\n",
    "ctrl_vals = norm_counts[gi, ctrl_idx]\n",
    "treat_vals = norm_counts[gi, treat_idx]\n",
    "\n",
    "# Estimate 1: difference of mean log2 expression (pseudocount 1 avoids log2(0)).\n",
    "mean_ctrl = np.log2(ctrl_vals + 1).mean()\n",
    "mean_treat = np.log2(treat_vals + 1).mean()\n",
    "# Estimate 2: log2 ratio of group means, closer to DESeq2's unshrunken GLM coefficient.\n",
    "ratio_lfc = np.log2(treat_vals.mean() / ctrl_vals.mean())\n",
    "\n",
    "print(f\"Gene {top['gene']}\")\n",
    "print(f'  control normalized counts:  {np.round(ctrl_vals, 0)}')\n",
    "print(f'  treated normalized counts:  {np.round(treat_vals, 0)}')\n",
    "print(f'  mean log2 (ctrl):  {mean_ctrl:.2f}')\n",
    "print(f'  mean log2 (treat): {mean_treat:.2f}')\n",
    "print(f'  difference of mean log2:      {mean_treat - mean_ctrl:+.2f}')\n",
    "print(f'  log2(mean treat / mean ctrl): {ratio_lfc:+.2f}')\n",
    "print(f\"  DESeq2 log2FoldChange:        {top['log2fc']:+.2f}   (Wald padj={top['padj']:.1e})\")\n",
    "print('Both hand estimates track DESeq2; DESeq2 adds the NB Wald test for significance.')"
   ]
  },
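  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The significance half can be sketched just as briefly. According to the DESeq2\n",
    "documentation, the `stat` column is the Wald statistic: `log2FoldChange` divided by its\n",
    "standard error `lfcSE`. The `pvalue` column is that statistic's two-sided tail probability\n",
    "under a standard normal. Here is a minimal standard-library sketch with made-up inputs\n",
    "(`wald_pvalue` is a hypothetical helper):\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def wald_pvalue(log2fc, lfc_se):\n",
    "    \"\"\"Two-sided Wald test for one coefficient: returns (statistic, p-value).\"\"\"\n",
    "    z = log2fc / lfc_se                        # the Wald statistic (DESeq2's stat column)\n",
    "    p = math.erfc(abs(z) / math.sqrt(2))       # equals 2 * P(Z > |z|) for a standard normal Z\n",
    "    return z, p\n",
    "\n",
    "for lfc, se in [(2.4, 0.3), (0.5, 0.4), (-1.2, 0.5)]:   # made-up values\n",
    "    z, p = wald_pvalue(lfc, se)\n",
    "    print(f'log2FC={lfc:+.1f}  lfcSE={se:.1f}  stat={z:+.2f}  pvalue={p:.2e}')\n",
    "```\n",
    "\n",
    "`results()` reports unshrunken estimates by default. On real output, recomputing from\n",
    "`log2FoldChange` and `lfcSE` should therefore reproduce `stat` and `pvalue` up to rounding."
   ]
  },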
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Exercises\n",
    "\n",
    "### Exercise 1: Threshold sensitivity\n",
    "How many genes you call DE depends strongly on your thresholds. Sweep them and see."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exercise 1:\n",
    "# For a grid of padj and |log2FC| cutoffs, count how many genes are called DE.\n",
    "# This is a judgement call, and an RNA-seq paper should state its cutoffs explicitly.\n",
    "\n",
    "# Worked solution — edit any threshold and re-run.\n",
    "print(f\"{'padj':>8} {'|log2FC|':>9} {'#DE':>6} {'up':>5} {'down':>6}\")\n",
    "print('-' * 38)\n",
    "for padj_t in (0.01, 0.05, 0.10):\n",
    "    for fc_t in (0.5, 1.0, 2.0):\n",
    "        up, down, allsig = split_significant(results, padj_thresh=padj_t, fc_thresh=fc_t)\n",
    "        print(f'{padj_t:>8} {fc_t:>9} {len(allsig):>6} {len(up):>5} {len(down):>6}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 2: Heatmap of the top DE genes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exercise 2:\n",
    "# Draw a Z-scored expression heatmap of the (up to) 20 most significant DE genes;\n",
    "# if fewer genes pass the cutoffs, all of them are shown.\n",
    "# Treated and control samples should form two visually distinct blocks.\n",
    "\n",
    "# Worked solution — edit and re-run.\n",
    "top20 = sorted(sig_all, key=lambda r: r['padj'])[:20]\n",
    "top20_names = [r['gene'] for r in top20]\n",
    "top20_rows = [r['gene_idx'] for r in top20]\n",
    "\n",
    "plot_heatmap(top20_names, top20_rows, log_norm, sample_names)\n",
    "print('Genes shown:', ', '.join(top20_names))"
   ]
  },
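  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Z-scoring rescales each gene (row) to mean 0 and standard deviation 1 across samples. The\n",
    "heatmap colors then show relative up and down rather than absolute expression level. We are\n",
    "assuming `plot_heatmap` uses exactly this transform. Here is a minimal sketch of it\n",
    "(`row_zscore` is a hypothetical helper):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def row_zscore(matrix):\n",
    "    \"\"\"Z-score each row (gene) across columns (samples); constant rows become 0.\"\"\"\n",
    "    m = np.asarray(matrix, dtype=float)\n",
    "    mean = m.mean(axis=1, keepdims=True)\n",
    "    std = m.std(axis=1, keepdims=True)\n",
    "    std[std == 0] = 1.0          # avoid dividing by zero for flat genes\n",
    "    return (m - mean) / std\n",
    "\n",
    "# Two genes at very different expression levels but with the same pattern.\n",
    "toy = np.array([[2.0, 2.0, 2.0, 4.0, 4.0, 4.0],\n",
    "                [10.0, 10.0, 10.0, 12.0, 12.0, 12.0]])\n",
    "z = row_zscore(toy)\n",
    "print(z)\n",
    "```"
   ]
  },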
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 3: Why normalize? (CPM vs DESeq2)\n",
    "A housekeeping gene should look *constant* across samples. Good normalization makes\n",
    "it so. Compare raw counts, CPM, and DESeq2-normalized counts on one stable gene."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exercise 3:\n",
    "# Coefficient of variation (std/mean) of a non-DE 'housekeeping' gene under three\n",
    "# normalizations. Lower CV = more consistent = better normalization.\n",
    "\n",
    "# Worked solution — edit and re-run.\n",
    "hk = genes.index('GENE050')  # most genes 21-100 are non-DE by construction; GENE050 is a stable one\n",
    "\n",
    "def cv(v):\n",
    "    v = np.asarray(v, float)\n",
    "    return v.std() / v.mean()\n",
    "\n",
    "raw_vals = counts_matrix[hk, :]\n",
    "cpm_vals = (counts_matrix / counts_matrix.sum(axis=0) * 1e6)[hk, :]\n",
    "deseq2_vals = norm_counts[hk, :]\n",
    "\n",
    "print(f'Housekeeping gene {genes[hk]} — coefficient of variation:')\n",
    "print(f'  Raw counts: {cv(raw_vals):.4f}')\n",
    "print(f'  CPM:        {cv(cpm_vals):.4f}')\n",
    "print(f'  DESeq2:     {cv(deseq2_vals):.4f}')\n",
    "print('Lower = better. DESeq2 median-of-ratios is robust to a few very highly expressed')\n",
    "print('or strongly DE genes, which inflate the library totals that CPM divides by.')"
   ]
  },
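  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The median-of-ratios size factors behind DESeq2's normalized counts fit in a few lines.\n",
    "Use each gene's geometric mean across samples as a pseudo-reference and divide each sample\n",
    "by it. Each sample's size factor is then its median ratio, skipping genes with a zero count\n",
    "in any sample. This is a minimal NumPy sketch of the method `estimateSizeFactors()`\n",
    "implements, not DESeq2's code (`median_of_ratios` is a hypothetical name). The toy matrix\n",
    "shows why it resists one hugely induced gene that distorts CPM.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def median_of_ratios(counts):\n",
    "    \"\"\"Size factors by the median-of-ratios method (genes x samples input).\"\"\"\n",
    "    counts = np.asarray(counts, dtype=float)\n",
    "    usable = np.all(counts > 0, axis=1)                    # skip genes with any zero count\n",
    "    log_counts = np.log(counts[usable])\n",
    "    log_geo_mean = log_counts.mean(axis=1, keepdims=True)  # per-gene pseudo-reference\n",
    "    log_ratios = log_counts - log_geo_mean\n",
    "    return np.exp(np.median(log_ratios, axis=0))\n",
    "\n",
    "# Hypothetical toy: 5 stable genes at equal depth, plus one gene induced 50x in sample B.\n",
    "toy = np.array([[100, 100],\n",
    "                [200, 200],\n",
    "                [300, 300],\n",
    "                [400, 400],\n",
    "                [500, 500],\n",
    "                [100, 5000]], dtype=float)\n",
    "\n",
    "size_factors = median_of_ratios(toy)\n",
    "cpm_scale = toy.sum(axis=0) / 1e6\n",
    "print('median-of-ratios size factors:', np.round(size_factors, 3))\n",
    "print('relative library totals (CPM):', np.round(cpm_scale / cpm_scale[0], 3))\n",
    "```\n",
    "\n",
    "Here the size factors stay at 1 for both samples. CPM would divide sample B by a total about\n",
    "four times larger, making every stable gene look about 4-fold lower."
   ]
  },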
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Exercise 4 (optional / advanced): Sample correlation heatmap\n",
    "A quick QC: replicates within a group should correlate highly with each other (a common\n",
    "rule of thumb is r > 0.95 on log-scale counts). Build the 6×6 Pearson correlation matrix\n",
    "of the `log_norm` columns and compare the within-group blocks with the between-group\n",
    "blocks. (Left as an extension: try it on your own machine.)"
   ]
  },
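  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to start the extension: NumPy's `np.corrcoef` with `rowvar=False` treats columns\n",
    "(samples) as the variables and returns the sample-by-sample Pearson matrix. Here is a minimal\n",
    "sketch on simulated data standing in for `log_norm`; `within_between` is a hypothetical\n",
    "helper. On the real data you would call `np.corrcoef(log_norm, rowvar=False)` and pass\n",
    "`ctrl_idx` and `treat_idx` as the groups.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(2)\n",
    "# Simulated stand-in for log_norm: 200 genes x 6 samples (3 ctrl, then 3 treated).\n",
    "base = rng.normal(8.0, 2.0, size=(200, 1))\n",
    "sim = base + rng.normal(0.0, 0.2, size=(200, 6))\n",
    "sim[:40, 3:] += 1.5                         # 40 genes shifted in the treated group\n",
    "\n",
    "corr = np.corrcoef(sim, rowvar=False)       # 6 x 6 sample-by-sample Pearson matrix\n",
    "\n",
    "def within_between(corr, group_a, group_b):\n",
    "    \"\"\"Mean off-diagonal correlation within groups, and mean correlation between groups.\"\"\"\n",
    "    within = [corr[i, j] for g in (group_a, group_b) for i in g for j in g if i < j]\n",
    "    between = [corr[i, j] for i in group_a for j in group_b]\n",
    "    return np.mean(within), np.mean(between)\n",
    "\n",
    "w, b = within_between(corr, [0, 1, 2], [3, 4, 5])\n",
    "print(np.round(corr, 3))\n",
    "print(f'within-group mean r = {w:.3f}, between-group mean r = {b:.3f}')\n",
    "```"
   ]
  },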
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Summary\n",
    "\n",
    "| Step | What it does | DESeq2 call |\n",
    "|------|-------------|-------------|\n",
    "| Size-factor estimation | Correct for library-size differences | `estimateSizeFactors()` |\n",
    "| Dispersion estimation | Model + shrink per-gene variance | `estimateDispersions()` |\n",
    "| Negative-binomial Wald test | Test treated vs control per gene | `nbinomWaldTest()` |\n",
    "| BH adjustment | Control the false discovery rate | `results()` (BH by default; `alpha` sets the FDR target for independent filtering) |\n",
    "| Volcano / MA / PCA | Visualize and QC the result | `plotMA()`, `plotPCA(vst(dds))` |\n",
    "\n",
    "**Key takeaways:**\n",
    "- Run DESeq2 and read its table. Don't hand-code the statistics: the tool shares\n",
    "  information across genes (dispersion shrinkage), which typically outperforms a\n",
    "  per-gene t-test when replicates are few.\n",
    "- Always use the **adjusted** p-value (`padj`), not the raw p-value. With thousands\n",
    "  of genes, raw p<0.05 alone yields a flood of false positives: 20,000 genes with no\n",
    "  true change would give about 1,000 hits by chance.\n",
    "- `log2FoldChange` is the effect size (roughly the log2 ratio of the group means of\n",
    "  normalized counts); `padj` is the FDR-corrected significance. A volcano plot reads both at once.\n",
    "- PCA should separate your conditions. If it doesn't, suspect a batch effect.\n",
    "\n",
    "**Next:** Module 7 — Visualization. It loads this same `deseq2_results.tsv` to build\n",
    "publication-quality figures, so the DE call you produced here carries straight over."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
