Module 8: Capstone Solutions

Complete working solutions for all 3 capstone challenges.

Try the exercises first! Only consult these solutions after making a genuine attempt.


In [ ]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import math
import random
import re
from pathlib import Path
from collections import Counter
import os

# Seed both the NumPy and the stdlib generators so reruns are reproducible.
np.random.seed(42)
random.seed(42)

# All capstone inputs live under <workshop root>/data/capstone, where the
# workshop root is two directories above the current working directory.
workshop_root = Path(os.path.abspath(os.path.join("..", "..")))
capstone_dir = workshop_root.joinpath('data', 'capstone')

FASTQ_PATH, VCF_PATH, COUNTS_PATH = (
    capstone_dir / file_name
    for file_name in (
        'synthetic_reads.fastq',
        'synthetic_variants.vcf',
        'synthetic_counts.tsv',
    )
)

# Adapter sequence searched for in Challenge A.
ADAPTER = "AGATCGGAAGAGCACACGTCTGAACTCCAGTCA"

# (ref, alt) base pairs that count as transitions in Challenge B.
TI_PAIRS = {tuple(pair) for pair in ("AG", "GA", "CT", "TC")}

print("Setup complete.")

Challenge A Solutions: QC and Read Trimming

In [ ]:
# ── SOLUTION A1: Load FASTQ ───────────────────────────────────────────────────
def parse_fastq(filepath):
    """Yield one ``(name, sequence, qualities)`` tuple per 4-line FASTQ record.

    Parameters
    ----------
    filepath : str or path-like
        Path to a FASTQ file.

    Yields
    ------
    tuple
        ``(name, sequence, qualities)`` where ``name`` is the header line
        without its first character, ``sequence`` is the base string and
        ``qualities`` is a list of ints decoded as ``ord(char) - 33``.

    Notes
    -----
    Blank lines between records (or at the end of the file) are skipped.
    Parsing stops only at the real end of the file, so a stray empty line
    no longer silently drops every read that follows it.
    """
    with open(filepath) as f:
        while True:
            raw_header = f.readline()
            if not raw_header:
                # readline() returns '' only at end-of-file; a blank line is '\n'.
                break
            h = raw_header.rstrip()
            if not h:
                # Blank separator line: keep looking for the next header.
                continue
            s = f.readline().rstrip()
            f.readline()  # '+' separator line, content is not used
            q = f.readline().rstrip()
            yield h[1:], s, [ord(c) - 33 for c in q]

# Materialise every record so later cells can iterate over the reads repeatedly.
reads = [record for record in parse_fastq(FASTQ_PATH)]

total_reads = len(reads)
mean_read_len = sum(len(record[1]) for record in reads) / total_reads

# Pool every base quality from every read into one flat list.
all_quals = [score for record in reads for score in record[2]]
mean_quality = sum(all_quals) / len(all_quals)
pct_q30 = len([score for score in all_quals if score >= 30]) / len(all_quals) * 100

print(f"Total reads:       {total_reads}")
print(f"Mean read length:  {mean_read_len:.1f} bp")
print(f"Mean quality:      {mean_quality:.1f}")
print(f"Bases with Q>=30:  {pct_q30:.1f}%")
In [ ]:
# ── SOLUTION A2: Per-base quality plot ────────────────────────────────────────
# The longest read decides how many positions the profile covers.
read_len = max(len(record[2]) for record in reads)

# One column of quality values per position; shorter reads simply do not
# contribute to positions past their end.
quality_columns = [
    [quals[pos] for _, _, quals in reads if pos < len(quals)]
    for pos in range(read_len)
]
per_pos_mean = [sum(col) / len(col) if col else 0 for col in quality_columns]

positions = range(1, read_len + 1)  # 1-based positions for the x axis
below_q20 = [mean_q < 20 for mean_q in per_pos_mean]

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(positions, per_pos_mean, color='steelblue', linewidth=1.5)
ax.axhline(30, color='green',  linestyle='--', label='Q30', alpha=0.7)
ax.axhline(20, color='orange', linestyle='--', label='Q20', alpha=0.7)
# Shade the stretch of the curve that falls under the Q20 line.
ax.fill_between(positions, per_pos_mean, 20,
                where=below_q20,
                color='red', alpha=0.2, label='Below Q20')
ax.set_xlabel('Position in read (bp)')
ax.set_ylabel('Mean Phred quality')
ax.set_title('Per-Base Quality Profile (Before Trimming)')
ax.set_ylim(0, 45)
ax.legend()
plt.tight_layout()
plt.show()
In [ ]:
# ── SOLUTION A3: Adapter detection ───────────────────────────────────────────
# Search for the first 12 bases of the adapter; an exact 12-mer match is
# treated as evidence of adapter read-through.
adapter_prefix = ADAPTER[:12]

# str.find gives -1 for "not present", so keep only the non-negative offsets.
adapter_positions = [
    offset
    for offset in (record[1].find(adapter_prefix) for record in reads)
    if offset >= 0
]

n_adapter_reads = len(adapter_positions)
pct_adapter = n_adapter_reads / total_reads * 100

print(f"Reads with adapter: {n_adapter_reads} ({pct_adapter:.1f}%)")

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(adapter_positions, bins=20, color='darkorange', edgecolor='white')
ax.set_xlabel('Position of adapter in read (bp)')
ax.set_ylabel('Number of reads')
ax.set_title(f'Adapter Contamination: Position Distribution ({n_adapter_reads} reads, {pct_adapter:.1f}%)')
plt.tight_layout()
plt.show()
In [ ]:
# ── SOLUTION A4: Trimming pipeline ───────────────────────────────────────────
def trim_adapter(seq, quals, adapter, min_overlap=10):
    """Cut a read at the first occurrence of the longest adapter prefix found.

    Adapter prefixes are tried from the longest that could fit in the read
    down to ``min_overlap`` bases. The first prefix that occurs anywhere in
    ``seq`` decides the cut point; ``seq`` and ``quals`` are both truncated
    there. If no prefix of at least ``min_overlap`` bases is present the
    inputs are returned unchanged.
    """
    overlap = min(len(adapter), len(seq))
    while overlap >= min_overlap:
        hit = seq.find(adapter[:overlap])
        if hit != -1:
            return seq[:hit], quals[:hit]
        overlap -= 1
    return seq, quals

def trim_sliding_window(seq, quals, window=4, min_qual=20):
    """Truncate a read at the start of the first low-quality window.

    Windows of ``window`` consecutive bases are scanned from the 5' end. The
    read is cut at the start of the first window whose mean quality is below
    ``min_qual``. Reads shorter than ``window`` have no complete window and
    are returned unchanged.
    """
    cut = next(
        (
            start
            for start in range(len(seq) - window + 1)
            if sum(quals[start:start + window]) / window < min_qual
        ),
        len(seq),  # no failing window: keep the whole read
    )
    return seq[:cut], quals[:cut]

def trim_minlen(seq, quals, min_len=36):
    """Return ``(seq, quals)`` unchanged, or ``(None, None)`` if the read is shorter than ``min_len``."""
    return (seq, quals) if len(seq) >= min_len else (None, None)

trimmed_reads = []
n_discarded = 0

# Each read goes through adapter removal, then quality trimming, then the
# length filter; a (None, None) result from the last step means "drop it".
for read_id, raw_seq, raw_quals in reads:
    kept_seq, kept_quals = trim_minlen(
        *trim_sliding_window(*trim_adapter(raw_seq, raw_quals, ADAPTER))
    )
    if kept_seq is not None:
        trimmed_reads.append((read_id, kept_seq, kept_quals))
        continue
    n_discarded += 1

print(f"Input reads:    {len(reads)}")
print(f"Passing reads:  {len(trimmed_reads)} ({len(trimmed_reads)/len(reads)*100:.1f}%)")
print(f"Discarded:      {n_discarded} ({n_discarded/len(reads)*100:.1f}%)")
In [ ]:
# ── SOLUTION A5: Before/after comparison ─────────────────────────────────────
def per_pos_stats(reads_list, max_len=150):
    """Mean quality at each read position, for the first ``max_len`` positions.

    ``reads_list`` holds ``(name, sequence, qualities)`` tuples. The result
    always has ``max_len`` entries; a position that no read reaches is
    reported as ``0``.
    """
    totals = [0] * max_len
    depth = [0] * max_len
    for _, _, quals in reads_list:
        for pos, score in enumerate(quals[:max_len]):
            totals[pos] += score
            depth[pos] += 1
    return [total / n if n else 0 for total, n in zip(totals, depth)]

# Per-position mean quality for the raw and the trimmed read sets.
# NOTE(review): per_pos_stats is called with its default max_len (150) and
# reports 0 for positions no read reaches, so both curves fall to zero past
# the longest read in each set rather than stopping there.
pp_raw     = per_pos_stats(reads)
pp_trimmed = per_pos_stats(trimmed_reads)

# Per-read summaries: length and mean quality, before and after trimming.
# Each mean divides by len(q), so a read with no quality values would raise
# ZeroDivisionError here.
raw_lens     = [len(s) for _, s, _ in reads]
trimmed_lens = [len(s) for _, s, _ in trimmed_reads]
raw_mq       = [sum(q)/len(q) for _, _, q in reads]
trimmed_mq   = [sum(q)/len(q) for _, _, q in trimmed_reads]

fig, axes = plt.subplots(2, 2, figsize=(14, 8))

# Top row: per-base quality profile, raw on the left and trimmed on the right.
for ax, pp, color, label in [
    (axes[0,0], pp_raw,     'steelblue', 'Before'),
    (axes[0,1], pp_trimmed, 'mediumseagreen', 'After'),
]:
    ax.plot(range(1, len(pp)+1), pp, color=color, linewidth=1.5)
    ax.axhline(30, color='green',  linestyle='--', alpha=0.6, label='Q30')
    ax.axhline(20, color='orange', linestyle='--', alpha=0.6, label='Q20')
    ax.set_title(f'Per-Base Quality — {label}'); ax.set_ylim(0, 45); ax.legend(fontsize=8)
    ax.set_xlabel('Position (bp)'); ax.set_ylabel('Mean Q')

# Bottom left: read-length histograms, trimmed drawn over raw.
# Each hist call picks its own 20 bins, so the two sets do not share bin edges.
axes[1,0].hist(raw_lens,     bins=20, color='steelblue',      alpha=0.7, edgecolor='white', label='Raw')
axes[1,0].hist(trimmed_lens, bins=20, color='mediumseagreen', alpha=0.7, edgecolor='white', label='Trimmed')
axes[1,0].set_xlabel('Read length (bp)'); axes[1,0].set_ylabel('Count')
axes[1,0].set_title('Read Length Distribution'); axes[1,0].legend()

# Bottom right: per-read mean quality histograms with a marker line at Q30.
axes[1,1].hist(raw_mq,     bins=25, color='steelblue',      alpha=0.7, edgecolor='white', label='Raw')
axes[1,1].hist(trimmed_mq, bins=25, color='mediumseagreen', alpha=0.7, edgecolor='white', label='Trimmed')
axes[1,1].axvline(30, color='red', linestyle='--', label='Q30')
axes[1,1].set_xlabel('Mean quality per read'); axes[1,1].set_ylabel('Count')
axes[1,1].set_title('Per-Read Mean Quality'); axes[1,1].legend()

plt.suptitle('QC: Before vs After Trimming', fontsize=13)
plt.tight_layout(); plt.show()

# These are means of the per-read means (each read weighted equally), not a
# mean over all bases.
print(f"Mean Q before: {sum(raw_mq)/len(raw_mq):.1f}")
print(f"Mean Q after:  {sum(trimmed_mq)/len(trimmed_mq):.1f}")

Challenge B Solutions: Variant Calling

In [ ]:
# ── SOLUTION B1: Parse VCF ────────────────────────────────────────────────────
def _vcf_number(text, cast, default):
    """Convert one VCF field with ``cast``, returning ``default`` when it is missing.

    VCF writes a missing value as ``.``. Fields with one value per ALT allele
    are comma-separated; only the first value is used.
    """
    first = text.split(',')[0].strip()
    if first in ('', '.'):
        return default
    return cast(first)


def parse_vcf(filepath):
    """Read the variant records of a VCF file into a list of dicts.

    Header lines (starting with ``#``) and blank lines are skipped.

    Parameters
    ----------
    filepath : str or path-like
        Path to a tab-separated VCF file.

    Returns
    -------
    list of dict
        One dict per record with keys ``chrom``, ``pos`` (int), ``ref``,
        ``alt``, ``qual`` (float), ``filter``, ``dp`` (int) and ``af`` (float).
        A missing QUAL, DP or AF (absent, empty or ``.``) is stored as 0, so
        such records fail the ``>=`` thresholds applied downstream. For
        multi-allelic records, DP and AF take the first comma-separated value.

    Raises
    ------
    ValueError
        If a record has fewer than the 8 fixed VCF columns, or a numeric
        field cannot be converted.
    """
    variants = []
    with open(filepath) as f:
        for line_no, line in enumerate(f, start=1):
            line = line.rstrip()
            if line.startswith('#') or not line:
                continue
            parts = line.split('\t')
            if len(parts) < 8:
                raise ValueError(
                    f"{filepath}: line {line_no} has {len(parts)} columns, expected at least 8"
                )
            chrom, pos, vid, ref, alt, qual, flt, info_str = parts[:8]

            # INFO is a ';'-separated list; bare flags (no '=') are ignored.
            info = {}
            for item in info_str.split(';'):
                if '=' in item:
                    k, v = item.split('=', 1)
                    info[k] = v

            variants.append({
                'chrom':  chrom,
                'pos':    int(pos),
                'ref':    ref,
                'alt':    alt,
                'qual':   _vcf_number(qual, float, 0.0),
                'filter': flt,
                'dp':     _vcf_number(info.get('DP', '.'), int, 0),
                'af':     _vcf_number(info.get('AF', '.'), float, 0.0),
            })
    return variants

variants = parse_vcf(VCF_PATH)
print(f"Parsed {len(variants)} variants")
# Show the first three records as a quick sanity check of the parse.
for record in variants[:3]:
    print("  " + str(record))
In [ ]:
# ── SOLUTION B2: Apply filters ────────────────────────────────────────────────
# Each criterion is first applied on its own so its individual effect is visible.
pass_filter = list(filter(lambda rec: rec['filter'] == 'PASS', variants))
pass_qual   = list(filter(lambda rec: rec['qual'] >= 50, variants))
pass_dp     = list(filter(lambda rec: rec['dp'] >= 20, variants))
pass_af     = list(filter(lambda rec: 0.1 <= rec['af'] <= 0.95, variants))

# High-confidence set: start from the FILTER=PASS records and require the
# remaining three criteria as well.
pass_all = [
    rec for rec in pass_filter
    if rec['qual'] >= 50 and rec['dp'] >= 20 and 0.1 <= rec['af'] <= 0.95
]

print(f"Total:          {len(variants)}")
print(f"FILTER=PASS:    {len(pass_filter)}")
print(f"QUAL >= 50:     {len(pass_qual)}")
print(f"DP >= 20:       {len(pass_dp)}")
print(f"AF 0.1-0.95:    {len(pass_af)}")
print(f"All filters:    {len(pass_all)}")
In [ ]:
# ── SOLUTION B3: Ti/Tv ratio ──────────────────────────────────────────────────
def compute_titv(variant_list):
    """Count transitions and transversions among the SNVs in ``variant_list``.

    Only single-base substitutions between A, C, G and T are classified.
    Indels, multi-allelic ALT values (``A,T``), symbolic alleles and records
    where REF equals ALT are skipped, so they no longer count as
    transversions. Bases are compared case-insensitively.

    Parameters
    ----------
    variant_list : iterable of dict
        Records with string values under ``'ref'`` and ``'alt'``.

    Returns
    -------
    tuple
        ``(ti, tv, ratio)`` where ``ratio`` is ``ti / max(1, tv)``; the
        ``max`` keeps the division defined when there are no transversions.
    """
    ti = tv = 0
    for v in variant_list:
        ref, alt = v['ref'].upper(), v['alt'].upper()
        is_snv = (
            len(ref) == 1 and len(alt) == 1
            and ref in 'ACGT' and alt in 'ACGT'
            and ref != alt
        )
        if not is_snv:
            continue
        if (ref, alt) in TI_PAIRS:
            ti += 1
        else:
            tv += 1
    return ti, tv, ti / max(1, tv)

# Ti/Tv for the full call set and for the high-confidence subset.
(ti_all, tv_all, titv_all), (ti_pass, tv_pass, titv_pass) = (
    compute_titv(subset) for subset in (variants, pass_all)
)

print(f"{'':12s} {'Ti':>6} {'Tv':>6} {'Ti/Tv':>8}")
for row_label, n_ti, n_tv, ratio in (
    ('All', ti_all, tv_all, titv_all),
    ('Passing', ti_pass, tv_pass, titv_pass),
):
    print(f"{row_label:12s} {n_ti:>6} {n_tv:>6} {ratio:>8.2f}")
In [ ]:
# ── SOLUTION B4: QC dashboard ────────────────────────────────────────────────
# 2x2 dashboard comparing all parsed variants with the high-confidence subset
# (pass_all). In the histogram panels the passing set is drawn on top of the
# full set.
fig, axes = plt.subplots(2, 2, figsize=(14, 9))

# QUAL distribution
# The dashed line marks the QUAL >= 50 cutoff used to build pass_all.
axes[0,0].hist([v['qual'] for v in variants],  bins=20, color='steelblue',      alpha=0.6, edgecolor='white', label='All')
axes[0,0].hist([v['qual'] for v in pass_all],  bins=20, color='mediumseagreen', alpha=0.8, edgecolor='white', label='Passing')
axes[0,0].axvline(50, color='red', linestyle='--', label='QUAL=50')
axes[0,0].set_xlabel('QUAL'); axes[0,0].set_ylabel('Count')
axes[0,0].set_title('Quality Score Distribution'); axes[0,0].legend()

# AF distribution
axes[0,1].hist([v['af'] for v in variants],  bins=15, color='steelblue',      alpha=0.6, edgecolor='white', label='All')
axes[0,1].hist([v['af'] for v in pass_all],  bins=15, color='mediumseagreen', alpha=0.8, edgecolor='white', label='Passing')
axes[0,1].set_xlabel('Allele Frequency'); axes[0,1].set_ylabel('Count')
axes[0,1].set_title('Allele Frequency Distribution'); axes[0,1].legend()

# QUAL vs DP scatter
# Variant records are dicts (unhashable), so membership in pass_all is tracked
# by object identity. Green points are records in pass_all (all four filters),
# not only those with FILTER == 'PASS'.
pass_set = set(id(v) for v in pass_all)
colors_v = ['mediumseagreen' if id(v) in pass_set else 'tomato' for v in variants]
axes[1,0].scatter([v['dp'] for v in variants], [v['qual'] for v in variants],
                   c=colors_v, alpha=0.7, s=40)
# Dashed guides at the DP >= 20 and QUAL >= 50 thresholds.
axes[1,0].axvline(20, color='red', linestyle='--', alpha=0.5)
axes[1,0].axhline(50, color='red', linestyle='--', alpha=0.5)
axes[1,0].set_xlabel('DP'); axes[1,0].set_ylabel('QUAL')
axes[1,0].set_title('QUAL vs DP (green=PASS, red=FAIL)')

# Mutation spectrum
# Restricted to passing records whose REF and ALT are both a single character;
# bars are ordered alphabetically by "REF>ALT" label.
mut_counts = Counter(f"{v['ref']}>{v['alt']}" for v in pass_all
                     if len(v['ref']) == 1 and len(v['alt']) == 1)
mut_labels = sorted(mut_counts.keys())
# Same transition pairs as TI_PAIRS, rendered as "REF>ALT" strings for lookup.
ti_set = {f"{r}>{a}" for r, a in TI_PAIRS}
bar_cols = ['steelblue' if m in ti_set else 'darkorange' for m in mut_labels]
axes[1,1].bar(mut_labels, [mut_counts[k] for k in mut_labels], color=bar_cols, edgecolor='white')
axes[1,1].set_xlabel('Mutation type'); axes[1,1].set_ylabel('Count')
axes[1,1].set_title('Mutation Spectrum (blue=Ti, orange=Tv)')
axes[1,1].tick_params(axis='x', rotation=45)

plt.suptitle('Variant Calling QC Dashboard', fontsize=13)
plt.tight_layout(); plt.show()

Challenge C Solutions: Differential Expression

In [ ]:
# ── SOLUTION C1: Load count matrix ────────────────────────────────────────────
# Read the tab-separated count table: the first line is the header (first
# column is the gene label, the rest are sample names), and each later line
# is one gene followed by one integer count per sample.
with open(COUNTS_PATH) as f:
    header_line = f.readline().strip().split('\t')
    genes, counts_list = [], []
    for line in f:
        if not line.strip():
            # A blank or trailing line would otherwise add an empty gene with
            # no counts and make the matrix ragged.
            continue
        p = line.strip().split('\t')
        genes.append(p[0])
        counts_list.append([int(x) for x in p[1:]])

sample_names  = header_line[1:]
# Stored as float so later normalisation and log transforms work directly.
counts_matrix = np.array(counts_list, dtype=float)
# Column indices of each group, identified by the sample-name prefix.
ctrl_idx      = [i for i, s in enumerate(sample_names) if s.startswith('ctrl')]
treat_idx     = [i for i, s in enumerate(sample_names) if s.startswith('treat')]

n_genes   = counts_matrix.shape[0]
n_samples = counts_matrix.shape[1]
lib_sizes = counts_matrix.sum(axis=0)  # total counts per sample

print(f"Genes:   {n_genes}")
print(f"Samples: {n_samples}  {sample_names}")
print(f"Library sizes: {lib_sizes.astype(int)}")
In [ ]:
# ── SOLUTION C2: Median-of-ratios normalization ───────────────────────────────
def median_of_ratios(counts):
    """Per-sample size factors by the median-of-ratios method.

    ``counts`` is a genes x samples array. Only genes with a positive count
    in every sample are used. For those genes, each sample's log count is
    compared with the gene's mean log count across samples (the log of its
    geometric mean). A sample's size factor is the exponential of the median
    of those log ratios. Returns a 1-D array with one factor per sample.
    """
    expressed_everywhere = (counts > 0).all(axis=1)
    log_counts = np.log(counts[expressed_everywhere].astype(float))
    log_ratios = log_counts - log_counts.mean(axis=1, keepdims=True)
    return np.exp(np.median(log_ratios, axis=0))

size_factors = median_of_ratios(counts_matrix)
# Divide every sample (column) by its own size factor.
norm_counts = counts_matrix / size_factors.reshape(1, -1)

print("Size factors:")
for sample, factor in zip(sample_names, size_factors):
    print(f"  {sample}: {factor:.4f}")
In [ ]:
# ── SOLUTION C3: Differential expression ─────────────────────────────────────
# SciPy is optional: when it is missing the DE loop falls back to a
# hand-rolled normal approximation.
HAS_SCIPY = True
try:
    from scipy.stats import ttest_ind
except ImportError:
    HAS_SCIPY = False

# log2(count + 1): the pseudocount keeps zero counts finite.
log_norm = np.log2(1 + norm_counts)

def bh_correction(p_values):
    """Benjamini-Hochberg adjusted p-values, returned in the input order.

    Each p-value is scaled by ``n / rank`` (rank 1 = smallest p). Walking
    from the largest p-value down, a running minimum (starting at 1.0) is
    applied so the adjusted values never exceed 1 and never decrease as the
    raw p-value grows.
    """
    total = len(p_values)
    order = sorted(range(total), key=lambda k: p_values[k])  # ascending by p
    adjusted = [0.0] * total
    running_min = 1.0
    for position in range(total - 1, -1, -1):
        gene = order[position]
        scaled = p_values[gene] * total / (position + 1)  # position + 1 is the rank
        running_min = min(running_min, scaled)
        adjusted[gene] = running_min
    return adjusted

# Per-gene test of treated vs control on the log2-normalised counts.
results = []
for gi, gene in enumerate(genes):
    ctrl_v  = log_norm[gi, ctrl_idx]
    treat_v = log_norm[gi, treat_idx]
    # Difference of mean log2 values, i.e. the log2 fold change.
    log2fc  = float(np.mean(treat_v) - np.mean(ctrl_v))

    if HAS_SCIPY:
        # Welch's t-test (unequal variances).
        p = ttest_ind(treat_v, ctrl_v, equal_var=False).pvalue
    else:
        # Fallback: simple z-score approximation. The standard error uses the
        # actual group sizes rather than assuming three replicates per group.
        se = math.sqrt(np.var(ctrl_v, ddof=1) / len(ctrl_v)
                       + np.var(treat_v, ddof=1) / len(treat_v))
        p  = 2 * (1 - 0.5 * (1 + math.erf(abs(log2fc/se if se else 0) / math.sqrt(2))))

    p = float(p)
    if math.isnan(p):
        # The test is undefined (e.g. both groups have zero variance). Without
        # this, max(1e-300, nan) returns 1e-300 and the gene would rank as the
        # most significant; treat it as "no evidence" instead.
        p = 1.0

    # Floor at 1e-300 so a later -log10 never sees exactly zero.
    results.append({'gene': gene, 'log2fc': log2fc, 'p_value': max(1e-300, p),
                    'basemean': float(np.mean(norm_counts[gi])), 'gene_idx': gi})

# Multiple-testing correction, then order genes by adjusted p-value.
p_adj = bh_correction([r['p_value'] for r in results])
for r, pa in zip(results, p_adj):
    r['padj'] = pa
results.sort(key=lambda r: r['padj'])

# Significant = padj < 0.05 and at least a two-fold change (|log2FC| >= 1).
sig_up   = [r for r in results if r['padj'] < 0.05 and r['log2fc'] >= 1]
sig_down = [r for r in results if r['padj'] < 0.05 and r['log2fc'] <= -1]

print(f"DE genes (padj<0.05, |FC|≥2): {len(sig_up)+len(sig_down)}")
print(f"  Upregulated: {len(sig_up)},  Downregulated: {len(sig_down)}")
In [ ]:
# ── SOLUTION C4: Volcano plot ─────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(9, 7))

# x: effect size, y: significance. padj is floored at 1e-300 before the log.
x = np.array([rec['log2fc'] for rec in results])
y = np.array([-math.log10(max(rec['padj'], 1e-300)) for rec in results])

sig_cutoff = -math.log10(0.05)  # height of the padj = 0.05 line

up   = (y >= sig_cutoff) & (x >= 1)
down = (y >= sig_cutoff) & (x <= -1)
ns   = ~(up | down)

# Draw the non-significant cloud first so the coloured points sit on top.
for mask, colour, opacity, size, tag in (
    (ns,   'lightgray', 0.5, 20, 'NS'),
    (up,   '#e74c3c',   0.8, 40, 'Up'),
    (down, '#3498db',   0.8, 40, 'Down'),
):
    ax.scatter(x[mask], y[mask], c=colour, alpha=opacity, s=size,
               label=f'{tag} ({mask.sum()})')

# Threshold guides: padj = 0.05 and a two-fold change in either direction.
ax.axhline(sig_cutoff, color='gray', linestyle='--', alpha=0.6)
for fc_boundary in (1, -1):
    ax.axvline(fc_boundary, color='gray', linestyle=':', alpha=0.6)

# Label the five genes with the smallest adjusted p-values
# (results is already sorted by padj).
for rec in results[:5]:
    point = (rec['log2fc'], -math.log10(max(rec['padj'], 1e-300)))
    ax.annotate(rec['gene'], point, textcoords='offset points', xytext=(5, 3), fontsize=8)

ax.set_xlabel('log₂ Fold Change')
ax.set_ylabel('-log₁₀(padj)')
ax.set_title('Volcano Plot: Treated vs Control')
ax.legend()
plt.tight_layout()
plt.show()
In [ ]:
# ── SOLUTION C5: Heatmap ──────────────────────────────────────────────────────
# Up to 15 significant genes with the smallest adjusted p-values.
top15 = sorted(sig_up + sig_down, key=lambda r: r['padj'])[:15]

if not top15:
    # With no significant genes the matrix below would be empty and the
    # per-gene mean over axis 1 would raise; report it and skip the figure.
    print("No genes passed the DE thresholds; skipping the heatmap.")
else:
    mat = np.array([log_norm[r['gene_idx']] for r in top15])
    # Z-score per gene
    mu = mat.mean(axis=1, keepdims=True)
    sd = mat.std(axis=1, keepdims=True)
    sd[sd == 0] = 1  # constant rows: avoid dividing by zero
    z  = (mat - mu) / sd

    fig, ax = plt.subplots(figsize=(9, 7))
    im = ax.imshow(z, cmap='RdBu_r', aspect='auto', vmin=-2.5, vmax=2.5)
    plt.colorbar(im, ax=ax, label='Z-score')
    ax.set_xticks(range(len(sample_names)))
    ax.set_xticklabels(sample_names, rotation=30, ha='right')
    ax.set_yticks(range(len(top15)))
    ax.set_yticklabels([r['gene'] for r in top15], fontsize=9)
    # Vertical divider drawn after the first len(ctrl_idx) columns.
    ax.axvline(len(ctrl_idx) - 0.5, color='black', linewidth=2)

    # Annotate FC direction just right of the last column.
    for i, r in enumerate(top15):
        ax.text(len(sample_names) - 0.3, i,
                '▲' if r['log2fc'] > 0 else '▼',
                va='center', fontsize=9,
                color='#e74c3c' if r['log2fc'] > 0 else '#3498db')

    # The title reports how many genes are actually shown.
    ax.set_title(f'Top {len(top15)} DE Genes: Expression Heatmap', fontsize=12)
    plt.tight_layout(); plt.show()
In [ ]:
# ── FINAL SUMMARY ─────────────────────────────────────────────────────────────
# Headline numbers collected from the three challenges above.
banner = "=" * 55
print(banner)
print("CAPSTONE ANALYSIS SUMMARY — SOLUTIONS")
print(banner)

# Challenge A: read counts and mean-of-per-read-mean qualities.
n_input = len(reads)
n_kept = len(trimmed_reads)
mean_q_before = sum(raw_mq) / len(raw_mq)
mean_q_after = sum(trimmed_mq) / len(trimmed_mq)

print("\n[ Challenge A: QC and Trimming ]")
print(f"  Input reads:             {n_input}")
print(f"  Reads with adapter:      {n_adapter_reads} ({pct_adapter:.1f}%)")
print(f"  Reads after trimming:    {n_kept} ({n_kept/n_input*100:.1f}%)")
print(f"  Mean quality (before):   {mean_q_before:.1f}")
print(f"  Mean quality (after):    {mean_q_after:.1f}")

print("\n[ Challenge B: Variant Calling ]")
print(f"  Total variants:          {len(variants)}")
print(f"  High-confidence:         {len(pass_all)}")
print(f"  Ti/Tv (all):             {titv_all:.2f}")
print(f"  Ti/Tv (passing):         {titv_pass:.2f}")

# Challenge C: DE genes are the union of the up- and down-regulated lists.
n_up = len(sig_up)
n_down = len(sig_down)

print("\n[ Challenge C: Differential Expression ]")
print(f"  Genes tested:            {len(results)}")
print(f"  DE genes (padj<0.05):    {n_up+n_down}")
print(f"  Upregulated:             {n_up}")
print(f"  Downregulated:           {n_down}")