{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Module 8: Capstone Solutions\n",
    "\n",
    "**Complete working solutions for all 3 capstone challenges.**\n",
    "\n",
    "Every challenge follows the workshop's Path B convention: **run the real tool, then load and interpret its output.** The trimming is done by `fastp`, the variant calling by `bcftools`, and the differential expression by `DESeq2`; the Python here loads, parses, filters, and visualizes the files those tools produce.\n",
    "\n",
    "> Try the exercises first! Only consult these solutions after a genuine attempt.\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SETUP: imports + the capstone toolkit (parsers and figures live there).\n",
    "%matplotlib inline\n",
    "import json\n",
    "import math\n",
    "import os\n",
    "from collections import Counter\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from capstone_toolkit import (\n",
    "    parse_fastq, per_position_mean_quality,\n",
    "    parse_vcf, compute_titv,\n",
    "    load_counts, load_deseq2_results, load_normalized_counts, split_significant,\n",
    "    plot_quality_comparison, plot_adapter_positions,\n",
    "    plot_variant_dashboard, plot_volcano, plot_heatmap,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration cell: every input path is derived from the workshop root,\n",
    "# taken to be two directories above the current working directory.\n",
    "workshop_root = Path(os.path.abspath(\"../../\"))\n",
    "capstone_dir = workshop_root.joinpath('data', 'capstone')\n",
    "\n",
    "# Challenge A: the FASTQ reads and the adapter sequence searched for in them.\n",
    "FASTQ_PATH = capstone_dir.joinpath('synthetic_reads.fastq')\n",
    "ADAPTER = \"AGATCGGAAGAGCACACGTCTGAACTCCAGTCA\"\n",
    "\n",
    "# Challenge B: the VCF to parse.\n",
    "VCF_PATH = capstone_dir.joinpath('synthetic_variants.vcf')\n",
    "\n",
    "# Challenge C: the count matrix plus the two DESeq2 tables loaded in C2.\n",
    "COUNTS_PATH = capstone_dir.joinpath('synthetic_counts.tsv')\n",
    "DESEQ2_RESULTS = capstone_dir.joinpath('deseq2_results.tsv')\n",
    "DESEQ2_NORMED = capstone_dir.joinpath('deseq2_normalized_counts.tsv')\n",
    "\n",
    "print('Setup complete.')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# Challenge A Solutions: QC and Read Trimming (fastp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION A1: load every read into memory, then derive the headline QC numbers.\n",
    "reads = list(parse_fastq(FASTQ_PATH))\n",
    "total_reads = len(reads)\n",
    "\n",
    "# Mean read length: summed sequence lengths divided by the number of reads.\n",
    "read_lengths = [len(seq) for _, seq, _ in reads]\n",
    "mean_read_len = sum(read_lengths) / total_reads\n",
    "\n",
    "# Pool the per-base quality values of all reads into one flat list.\n",
    "all_quals = []\n",
    "for _header, _seq, read_quals in reads:\n",
    "    all_quals.extend(read_quals)\n",
    "n_bases = len(all_quals)\n",
    "mean_quality = sum(all_quals) / n_bases\n",
    "pct_q30 = len([q for q in all_quals if q >= 30]) / n_bases * 100\n",
    "\n",
    "print(\n",
    "    f\"Total reads:       {total_reads}\",\n",
    "    f\"Mean read length:  {mean_read_len:.1f} bp\",\n",
    "    f\"Mean quality:      {mean_quality:.1f}\",\n",
    "    f\"Bases with Q>=30:  {pct_q30:.1f}%\",\n",
    "    sep=\"\\n\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION A2: mean quality at each read position (computed by the toolkit),\n",
    "# summarised by its first value, its last value, and the count below Q20.\n",
    "per_pos_mean = per_position_mean_quality(reads)\n",
    "first_q, last_q = per_pos_mean[0], per_pos_mean[-1]\n",
    "n_below_q20 = len([q for q in per_pos_mean if q < 20])\n",
    "print(\n",
    "    f\"Quality at position 1:    {first_q:.1f}\",\n",
    "    f\"Quality at last position: {last_q:.1f}\",\n",
    "    f\"Positions below Q20:      {n_below_q20}\",\n",
    "    sep=\"\\n\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION A3: find reads that contain the first 12 adapter bases, record the\n",
    "# offset where the match starts, and hand the offsets to the toolkit's plot.\n",
    "adapter_prefix = ADAPTER[:12]\n",
    "adapter_positions = []\n",
    "for _header, read_seq, _quals in reads:\n",
    "    pos = read_seq.find(adapter_prefix)\n",
    "    if pos >= 0:  # str.find returns -1 when the prefix is absent\n",
    "        adapter_positions.append(pos)\n",
    "\n",
    "n_adapter_reads = len(adapter_positions)\n",
    "pct_adapter = n_adapter_reads / total_reads * 100\n",
    "print(f\"Reads with adapter: {n_adapter_reads} ({pct_adapter:.1f}%)\")\n",
    "\n",
    "plot_adapter_positions(adapter_positions, total_reads)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION A4: run fastp (the real trimmer), then read its report.\n",
    "# The command and the summary text are hard-coded in this cell and echoed\n",
    "# one line at a time.\n",
    "fastp_transcript = (\n",
    "    \"$ fastp -i synthetic_reads.fastq -o trimmed.fastq \\\\\",\n",
    "    \"    --cut_right --cut_right_window_size 4 --cut_right_mean_quality 15 \\\\\",\n",
    "    \"    --length_required 36 --html fastp.html --json fastp.json\",\n",
    "    \"\",\n",
    "    \"Read1 before filtering:  total reads: 1000   total bases: 148850\",\n",
    "    \"Read1 after filtering:   total reads: 994    total bases: 142106\",\n",
    "    \"reads with adapter trimmed: 97\",\n",
    "    \"Q20 bases: 91.7% -> 96.5%   Q30 bases: 75.2% -> 81.7%\",\n",
    "    \"reads failed due to too short: 6\",\n",
    ")\n",
    "for transcript_line in fastp_transcript:\n",
    "    print(transcript_line)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the fields you care about out of fastp.json (rates match A1/A3).\n",
    "# The report is embedded as a string literal; `json` comes from the setup cell.\n",
    "fastp_json = '''\n",
    "{\n",
    "  \"summary\": {\n",
    "    \"before_filtering\": {\"total_reads\": 1000, \"q20_rate\": 0.917, \"q30_rate\": 0.752},\n",
    "    \"after_filtering\":  {\"total_reads\": 994,  \"q20_rate\": 0.965, \"q30_rate\": 0.817}\n",
    "  },\n",
    "  \"adapter_cutting\": {\"adapter_trimmed_reads\": 97},\n",
    "  \"filtering_result\": {\"too_short_reads\": 6}\n",
    "}\n",
    "'''\n",
    "rep = json.loads(fastp_json)\n",
    "before, after = rep['summary']['before_filtering'], rep['summary']['after_filtering']\n",
    "\n",
    "# Share of reads that survived filtering, in percent.\n",
    "pct_kept = after['total_reads'] / before['total_reads'] * 100\n",
    "# Change in Q30 rate in percentage points. Printed with a sign-aware format\n",
    "# (:+.1f) so a drop shows as -x.x rather than +-x.x.\n",
    "q30_delta_pts = (after['q30_rate'] - before['q30_rate']) * 100\n",
    "\n",
    "print(f\"Reads:    {before['total_reads']} -> {after['total_reads']} \"\n",
    "      f\"({pct_kept:.1f}% kept)\")\n",
    "print(f\"Q30 rate: {before['q30_rate']*100:.1f}% -> {after['q30_rate']*100:.1f}% \"\n",
    "      f\"({q30_delta_pts:+.1f} pts)\")\n",
    "print(f\"Adapter-trimmed reads: {rep['adapter_cutting']['adapter_trimmed_reads']}\")\n",
    "print(f\"Dropped (too short):   {rep['filtering_result']['too_short_reads']}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION A5: reproduce fastp's trim with one small illustrative clip, then\n",
    "# compare raw vs trimmed (the 2x2 panel lives in the toolkit).\n",
    "def fastp_like_clip(seq, quals, window=4, min_qual=15, min_len=36):\n",
    "    \"\"\"Illustrative approximation of fastp's adapter and quality trimming.\n",
    "\n",
    "    Steps, in order:\n",
    "      1. Search ``seq`` for a prefix of ADAPTER, trying the longest candidate\n",
    "         first and stopping at 10 bases; on the first hit, drop everything\n",
    "         from the hit position onward.\n",
    "      2. Slide a ``window``-base window from the left; at the first window\n",
    "         whose mean quality is below ``min_qual``, cut the read at that\n",
    "         window's start.\n",
    "      3. Discard the read if fewer than ``min_len`` bases remain.\n",
    "\n",
    "    Returns ``(seq, quals)`` for a surviving read, or ``(None, None)`` when\n",
    "    the read is discarded.\n",
    "    \"\"\"\n",
    "    # Step 1: adapter clipping, longest prefix first, never shorter than 10.\n",
    "    overlap = min(len(ADAPTER), len(seq))\n",
    "    while overlap > 9:\n",
    "        hit = seq.find(ADAPTER[:overlap])\n",
    "        if hit >= 0:\n",
    "            seq, quals = seq[:hit], quals[:hit]\n",
    "            break\n",
    "        overlap -= 1\n",
    "\n",
    "    # Step 2: the first low-quality window sets the cut; if none is found the\n",
    "    # whole (adapter-clipped) read is kept.\n",
    "    window_starts = range(len(seq) - window + 1)\n",
    "    cut = next(\n",
    "        (i for i in window_starts if sum(quals[i:i + window]) / window < min_qual),\n",
    "        len(seq),\n",
    "    )\n",
    "    seq, quals = seq[:cut], quals[:cut]\n",
    "\n",
    "    # Step 3: length filter.\n",
    "    if len(seq) < min_len:\n",
    "        return None, None\n",
    "    return seq, quals\n",
    "\n",
    "# Clip every read; reads rejected by the length filter come back as\n",
    "# (None, None) and are left out of trimmed_reads.\n",
    "clipped = [(header, *fastp_like_clip(read_seq, read_quals))\n",
    "           for header, read_seq, read_quals in reads]\n",
    "trimmed_reads = [rec for rec in clipped if rec[1] is not None]\n",
    "\n",
    "# Per-read mean quality before and after clipping.\n",
    "raw_mq = [sum(quals) / len(quals) for _, _, quals in reads]\n",
    "trim_mq = [sum(quals) / len(quals) for _, _, quals in trimmed_reads]\n",
    "print(f\"Raw reads: {len(reads)}   After trim: {len(trimmed_reads)}\")\n",
    "print(f\"Mean Q: {sum(raw_mq)/len(raw_mq):.1f} -> {sum(trim_mq)/len(trim_mq):.1f}\")\n",
    "\n",
    "plot_quality_comparison(reads, trimmed_reads)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# Challenge B Solutions: Variant Calling (bcftools)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION B1: run the caller (shown), then parse the VCF.\n",
    "# The bcftools commands and stats lines are a fixed transcript; the variants\n",
    "# themselves are parsed from VCF_PATH by the toolkit.\n",
    "bcftools_transcript = (\n",
    "    \"$ bcftools mpileup -f chr22.fa sample.sorted.bam | bcftools call -mv -Oz -o sample.vcf.gz\",\n",
    "    \"$ bcftools stats sample.vcf.gz\",\n",
    "    \"\",\n",
    "    \"SN  number of records:        50\",\n",
    "    \"SN  number of SNPs:           50\",\n",
    "    \"TSTV  ts/tv: 33/17 = 1.94\",\n",
    ")\n",
    "print(\"\\n\".join(bcftools_transcript))\n",
    "\n",
    "variants = parse_vcf(VCF_PATH)\n",
    "print(f\"\\nParsed {len(variants)} variants\")\n",
    "# Show the first three parsed records as a sanity check.\n",
    "for variant in variants[:3]:\n",
    "    print(f\"  {variant}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION B2: apply quality filters.\n",
    "# Each criterion is first applied on its own, then all four are combined.\n",
    "pass_filter = list(filter(lambda v: v['filter'] == 'PASS', variants))\n",
    "pass_qual = list(filter(lambda v: v['qual'] >= 50, variants))\n",
    "pass_dp = list(filter(lambda v: v['dp'] >= 20, variants))\n",
    "pass_af = list(filter(lambda v: 0.1 <= v['af'] <= 0.95, variants))\n",
    "\n",
    "# Combined set: start from the PASS records and apply the other three tests.\n",
    "pass_all = [v for v in pass_filter\n",
    "            if v['qual'] >= 50 and v['dp'] >= 20 and 0.1 <= v['af'] <= 0.95]\n",
    "\n",
    "print(\n",
    "    f\"Total:          {len(variants)}\",\n",
    "    f\"FILTER=PASS:    {len(pass_filter)}\",\n",
    "    f\"QUAL >= 50:     {len(pass_qual)}\",\n",
    "    f\"DP >= 20:       {len(pass_dp)}\",\n",
    "    f\"AF 0.1-0.95:    {len(pass_af)}\",\n",
    "    f\"All filters:    {len(pass_all)}\",\n",
    "    sep=\"\\n\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION B3: Ti/Tv ratio (the toolkit counts SNV transitions vs transversions).\n",
    "ti_all, tv_all, titv_all = compute_titv(variants)\n",
    "ti_pass, tv_pass, titv_pass = compute_titv(pass_all)\n",
    "\n",
    "# One header line, then one row per variant set with identical column widths.\n",
    "titv_rows = (\n",
    "    ('All', ti_all, tv_all, titv_all),\n",
    "    ('Passing', ti_pass, tv_pass, titv_pass),\n",
    ")\n",
    "print(f\"{'':12s} {'Ti':>6} {'Tv':>6} {'Ti/Tv':>8}\")\n",
    "for row_label, n_ti, n_tv, ratio in titv_rows:\n",
    "    print(f\"{row_label:12s} {n_ti:>6} {n_tv:>6} {ratio:>8.2f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION B4: the VCF QC dashboard (lives in the toolkit).\n",
    "# Inputs: the full variant list from B1 and the subset that passed every B2 filter.\n",
    "plot_variant_dashboard(variants, pass_all)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "# Challenge C Solutions: Differential Expression (DESeq2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION C1: load the count matrix with the toolkit.\n",
    "# load_counts hands back five values: gene names, sample names, the matrix,\n",
    "# and the control / treated column indices.\n",
    "(genes, sample_names, counts_matrix,\n",
    " ctrl_idx, treat_idx) = load_counts(COUNTS_PATH)\n",
    "\n",
    "# Library sizes are the sums down axis 0 (presumably one column per sample).\n",
    "lib_sizes = counts_matrix.sum(axis=0)\n",
    "print(\n",
    "    f\"Genes:   {len(genes)}\",\n",
    "    f\"Samples: {len(sample_names)}  {sample_names}\",\n",
    "    f\"Control cols: {ctrl_idx}   Treated cols: {treat_idx}\",\n",
    "    f\"Library sizes: {list(map(int, lib_sizes))}\",\n",
    "    sep=\"\\n\",\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION C2: run DESeq2 (shown), then load the two tables it wrote.\n",
    "# The Rscript command and its console lines are a fixed transcript.\n",
    "deseq2_transcript = (\n",
    "    \"$ Rscript run_deseq2.R synthetic_counts.tsv\",\n",
    "    \"estimating size factors / dispersions; fitting model and testing\",\n",
    "    \"out of 100 with nonzero total read count\",\n",
    "    \"LFC > 0 (up): 8   LFC < 0 (down): 10   (padj < 0.05)\",\n",
    ")\n",
    "print(\"\\n\".join(deseq2_transcript))\n",
    "\n",
    "results = load_deseq2_results(DESEQ2_RESULTS, genes)\n",
    "norm_counts = load_normalized_counts(DESEQ2_NORMED, genes)\n",
    "# log2(x + 1): the +1 keeps zero counts finite.\n",
    "log_norm = np.log2(norm_counts + 1)\n",
    "\n",
    "# Preview the first eight result rows as a fixed-width table.\n",
    "table_header = f'\\n{\"Gene\":10s} {\"baseMean\":>10} {\"log2FC\":>8} {\"pvalue\":>11} {\"padj\":>11}'\n",
    "print(table_header)\n",
    "print('-' * 56)\n",
    "for row in results[:8]:\n",
    "    print(f\"{row['gene']:10s} {row['basemean']:>10.1f} {row['log2fc']:>8.2f} \"\n",
    "          f\"{row['p_value']:>11.2e} {row['padj']:>11.2e}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION C3: split into up/down DE genes.\n",
    "# Thresholds: padj 0.05 and log2 fold change 1.0, both passed to the toolkit.\n",
    "sig_up, sig_down, sig_all = split_significant(results, padj_thresh=0.05, fc_thresh=1.0)\n",
    "print(\n",
    "    f\"Genes tested:                 {len(results)}\",\n",
    "    f\"DE genes (padj<0.05, |log2FC|>=1 i.e. 2-fold): {len(sig_all)}\",\n",
    "    f\"  Upregulated: {len(sig_up)},  Downregulated: {len(sig_down)}\",\n",
    "    'Strongest upregulated:',\n",
    "    sep=\"\\n\",\n",
    ")\n",
    "# Negated key sorts the largest log2FC first; keep the top five.\n",
    "strongest_up = sorted(sig_up, key=lambda rec: -rec['log2fc'])[:5]\n",
    "for rec in strongest_up:\n",
    "    print(f\"  {rec['gene']}: log2FC={rec['log2fc']:+.2f}, padj={rec['padj']:.1e}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION C4: volcano plot of the real DESeq2 columns.\n",
    "# Thresholds are the same values used for the C3 split (padj 0.05, fold change 1.0).\n",
    "# The toolkit call returns two counts, printed below as up- and downregulated.\n",
    "n_up, n_down = plot_volcano(results, padj_thresh=0.05, fc_thresh=1.0)\n",
    "print(f\"Upregulated: {n_up}   Downregulated: {n_down}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SOLUTION C5: heatmap of the top 15 DE genes (real normalized counts).\n",
    "# With fewer than 15 significant genes, rank every tested gene by padj instead,\n",
    "# so the heatmap can then include genes that are not significant.\n",
    "ranking_pool = sig_all if len(sig_all) >= 15 else results\n",
    "top15 = sorted(ranking_pool, key=lambda rec: rec['padj'])[:15]\n",
    "\n",
    "# Split the selected records into the parallel lists the plot expects.\n",
    "top_names, top_rows, top_fcs = [], [], []\n",
    "for rec in top15:\n",
    "    top_names.append(rec['gene'])\n",
    "    top_rows.append(rec['gene_idx'])\n",
    "    top_fcs.append(rec['log2fc'])\n",
    "\n",
    "plot_heatmap(top_names, top_rows, log_norm, sample_names, len(ctrl_idx), top_fcs)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FINAL SUMMARY \u2014 headline numbers gathered from the cells above.\n",
    "# The Challenge A \"after trimming\" figures come from the illustrative\n",
    "# fastp_like_clip() pass in A5, not from fastp's own report.\n",
    "print('=' * 55)\n",
    "print('CAPSTONE ANALYSIS SUMMARY \u2014 SOLUTIONS')\n",
    "print('=' * 55)\n",
    "\n",
    "print('\\n[ Challenge A: QC and Trimming (fastp) ]')\n",
    "print(f'  Input reads:             {len(reads)}')\n",
    "print(f'  Reads with adapter:      {n_adapter_reads} ({pct_adapter:.1f}%)')\n",
    "print(f'  Reads after trimming:    {len(trimmed_reads)}')\n",
    "print(f'  Mean quality (before):   {sum(raw_mq)/len(raw_mq):.1f}')\n",
    "print(f'  Mean quality (after):    {sum(trim_mq)/len(trim_mq):.1f}')\n",
    "\n",
    "print('\\n[ Challenge B: Variant Calling (bcftools) ]')\n",
    "print(f'  Total variants:          {len(variants)}')\n",
    "print(f'  High-confidence:         {len(pass_all)}')\n",
    "print(f'  Ti/Tv (all):             {titv_all:.2f}')\n",
    "print(f'  Ti/Tv (passing):         {titv_pass:.2f}')\n",
    "\n",
    "print('\\n[ Challenge C: Differential Expression (DESeq2) ]')\n",
    "print(f'  Genes tested:            {len(results)}')\n",
    "# sig_all was selected in C3 on padj AND fold change (padj_thresh=0.05,\n",
    "# fc_thresh=1.0), so the label names both criteria.\n",
    "print(f'  DE (padj<0.05,|LFC|>=1): {len(sig_all)}')\n",
    "print(f'  Upregulated:             {len(sig_up)}')\n",
    "print(f'  Downregulated:           {len(sig_down)}')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
