Complete working solutions for all 3 capstone challenges.
Try the exercises first! Only consult these solutions after making a genuine attempt.
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
import math
import random
import re
from pathlib import Path
from collections import Counter
import os
# Seed both random number generators so every stochastic step is reproducible.
np.random.seed(42)
random.seed(42)

# Data locations, resolved relative to the notebook's working directory.
workshop_root = Path(os.path.abspath("../../"))
capstone_dir = workshop_root.joinpath('data', 'capstone')
FASTQ_PATH = capstone_dir.joinpath('synthetic_reads.fastq')
VCF_PATH = capstone_dir.joinpath('synthetic_variants.vcf')
COUNTS_PATH = capstone_dir.joinpath('synthetic_counts.tsv')

# Adapter sequence searched for (Challenge A3) and trimmed (Challenge A4).
ADAPTER = "AGATCGGAAGAGCACACGTCTGAACTCCAGTCA"
# Transitions: purine<->purine (A<->G) and pyrimidine<->pyrimidine (C<->T), both directions.
TI_PAIRS = {(a, b) for x, y in (('A', 'G'), ('C', 'T')) for a, b in ((x, y), (y, x))}

print("Setup complete.")
# ── SOLUTION A1: Load FASTQ ───────────────────────────────────────────────────
def parse_fastq(filepath, phred_offset=33):
    """Yield ``(read_id, sequence, qualities)`` for each record of a FASTQ file.

    Records must use the four-line layout: ``@header``, sequence, ``+``
    separator, quality string. Blank lines between or after records are skipped.

    Args:
        filepath: Path (str or ``pathlib.Path``) to an uncompressed FASTQ file.
        phred_offset: ASCII offset of the quality encoding (33 by default).

    Yields:
        tuple[str, str, list[int]]: the header without its leading ``@``, the
        sequence, and one integer Phred score per base.

    Raises:
        ValueError: if a record is truncated, lacks its ``@``/``+`` markers, or
            has a quality string whose length differs from its sequence.
    """
    with open(filepath) as f:
        while True:
            raw_header = f.readline()
            if not raw_header:
                break  # end of file
            header = raw_header.rstrip()
            if not header:
                continue  # tolerate blank lines instead of silently stopping early
            seq = f.readline().rstrip()
            sep = f.readline().rstrip()
            qual = f.readline().rstrip()
            # Validate structure so truncated/corrupt files fail loudly rather
            # than yielding misaligned reads.
            if not header.startswith('@') or not sep.startswith('+'):
                raise ValueError(f"Malformed FASTQ record near header {header!r}")
            if len(qual) != len(seq):
                raise ValueError(
                    f"Sequence/quality length mismatch in FASTQ record {header!r}: "
                    f"{len(seq)} bases vs {len(qual)} quality characters"
                )
            yield header[1:], seq, [ord(c) - phred_offset for c in qual]
# Materialise all reads once; later cells iterate over them repeatedly.
reads = list(parse_fastq(FASTQ_PATH))
total_reads = len(reads)

# Flatten every per-base score; the quality summaries below are computed from it.
all_quals = [score for _, _, read_quals in reads for score in read_quals]

mean_read_len = sum(len(seq) for _, seq, _ in reads) / total_reads
mean_quality = sum(all_quals) / len(all_quals)
# Booleans sum as 0/1, giving the count of bases at or above Q30.
pct_q30 = sum(score >= 30 for score in all_quals) / len(all_quals) * 100

print(f"Total reads: {total_reads}")
print(f"Mean read length: {mean_read_len:.1f} bp")
print(f"Mean quality: {mean_quality:.1f}")
print(f"Bases with Q>=30: {pct_q30:.1f}%")
# ── SOLUTION A2: Per-base quality plot ────────────────────────────────────────
# The longest read determines how many positions are profiled.
read_len = max(len(quals) for _, _, quals in reads)

# One pass over the reads: accumulate per-position totals and counts, then average.
pos_totals = [0] * read_len
pos_counts = [0] * read_len
for _, _, quals in reads:
    for pos, score in enumerate(quals):
        pos_totals[pos] += score
        pos_counts[pos] += 1
per_pos_mean = [tot / cnt if cnt else 0 for tot, cnt in zip(pos_totals, pos_counts)]

x_positions = range(1, read_len + 1)  # 1-based positions for display
below_q20 = [m < 20 for m in per_pos_mean]

fig, ax = plt.subplots(figsize=(12, 4))
ax.plot(x_positions, per_pos_mean, color='steelblue', linewidth=1.5)
for threshold, colour, name in ((30, 'green', 'Q30'), (20, 'orange', 'Q20')):
    ax.axhline(threshold, color=colour, linestyle='--', label=name, alpha=0.7)
# Shade the gap between the curve and Q20 wherever the mean drops below Q20.
ax.fill_between(x_positions, per_pos_mean, 20,
                where=below_q20,
                color='red', alpha=0.2, label='Below Q20')
ax.set_xlabel('Position in read (bp)')
ax.set_ylabel('Mean Phred quality')
ax.set_title('Per-Base Quality Profile (Before Trimming)')
ax.set_ylim(0, 45)
ax.legend()
plt.tight_layout()
plt.show()
# ── SOLUTION A3: Adapter detection ───────────────────────────────────────────
adapter_prefix = ADAPTER[:12]  # 12-mer is specific enough
# str.find returns -1 when the prefix is absent; keep only actual hit offsets.
adapter_positions = [
    hit for hit in (seq.find(adapter_prefix) for _, seq, _ in reads) if hit != -1
]
n_adapter_reads = len(adapter_positions)
pct_adapter = n_adapter_reads / total_reads * 100
print(f"Reads with adapter: {n_adapter_reads} ({pct_adapter:.1f}%)")

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(adapter_positions, bins=20, color='darkorange', edgecolor='white')
ax.set_xlabel('Position of adapter in read (bp)')
ax.set_ylabel('Number of reads')
ax.set_title(f'Adapter Contamination: Position Distribution ({n_adapter_reads} reads, {pct_adapter:.1f}%)')
plt.tight_layout()
plt.show()
# ── SOLUTION A4: Trimming pipeline ───────────────────────────────────────────
def trim_adapter(seq, quals, adapter, min_overlap=10):
    """Cut the read at the first occurrence of the longest matching adapter prefix.

    Prefixes of ``adapter`` are tried from longest (capped at the read length)
    down to ``min_overlap`` bases. The first prefix found anywhere in ``seq``
    decides the cut point, and everything from there on is removed.

    Returns:
        tuple: ``(seq, quals)`` truncated together, or unchanged if no prefix of
        at least ``min_overlap`` bases occurs.
    """
    longest = min(len(adapter), len(seq))
    for overlap in reversed(range(min_overlap, longest + 1)):
        cut = seq.find(adapter[:overlap])
        if cut != -1:
            return seq[:cut], quals[:cut]
    return seq, quals
def trim_sliding_window(seq, quals, window=4, min_qual=20):
trim_pos = len(seq)
for i in range(len(seq) - window + 1):
if sum(quals[i:i+window]) / window < min_qual:
trim_pos = i
break
return seq[:trim_pos], quals[:trim_pos]
def trim_minlen(seq, quals, min_len=36):
    """Discard reads shorter than ``min_len``.

    Returns:
        tuple: ``(None, None)`` for a discarded read, otherwise ``(seq, quals)`` unchanged.
    """
    too_short = len(seq) < min_len
    return (None, None) if too_short else (seq, quals)
# Apply adapter -> sliding-window -> minimum-length trimming to every read.
trimmed_reads = []
n_discarded = 0
for header, seq, quals in reads:
    s, q = trim_minlen(*trim_sliding_window(*trim_adapter(seq, quals, ADAPTER)))
    if s is None:
        # Read ended up shorter than the minimum length: drop it.
        n_discarded += 1
        continue
    trimmed_reads.append((header, s, q))

print(f"Input reads: {len(reads)}")
print(f"Passing reads: {len(trimmed_reads)} ({len(trimmed_reads)/len(reads)*100:.1f}%)")
print(f"Discarded: {n_discarded} ({n_discarded/len(reads)*100:.1f}%)")
# ── SOLUTION A5: Before/after comparison ─────────────────────────────────────
def per_pos_stats(reads_list, max_len=150):
    """Return the mean quality at each read position across ``reads_list``.

    Args:
        reads_list: Iterable of ``(header, sequence, qualities)`` tuples.
        max_len: Number of positions to report. Qualities beyond this index are
            ignored. ``None`` means "use the length of the longest read".

    Returns:
        list[float]: One mean per position (``max_len`` entries). A position that
        no read reaches is ``nan`` (undefined), not ``0``. Matplotlib leaves a
        gap there instead of drawing a spurious drop to Q0.
    """
    if max_len is None:
        reads_list = list(reads_list)  # iterated twice below
        max_len = max((len(q) for _, _, q in reads_list), default=0)
    totals = [0] * max_len
    counts = [0] * max_len
    for _, _, quals in reads_list:
        for i, v in enumerate(quals[:max_len]):
            totals[i] += v
            counts[i] += 1
    return [t / c if c else float('nan') for t, c in zip(totals, counts)]
pp_raw = per_pos_stats(reads)
pp_trimmed = per_pos_stats(trimmed_reads)
raw_lens = [len(s) for _, s, _ in reads]
trimmed_lens = [len(s) for _, s, _ in trimmed_reads]
raw_mq = [sum(q)/len(q) for _, _, q in reads]
trimmed_mq = [sum(q)/len(q) for _, _, q in trimmed_reads]
# Overlaid histograms must share bin edges. Otherwise each series is binned
# independently (e.g. fixed-length raw reads collapse into one very narrow bin)
# and bar heights/widths are not comparable between Raw and Trimmed.
len_bins = np.histogram_bin_edges(raw_lens + trimmed_lens, bins=20)
mq_bins = np.histogram_bin_edges(raw_mq + trimmed_mq, bins=25)
fig, axes = plt.subplots(2, 2, figsize=(14, 8))
for ax, pp, color, label in [
    (axes[0,0], pp_raw, 'steelblue', 'Before'),
    (axes[0,1], pp_trimmed, 'mediumseagreen', 'After'),
]:
    ax.plot(range(1, len(pp)+1), pp, color=color, linewidth=1.5)
    ax.axhline(30, color='green', linestyle='--', alpha=0.6, label='Q30')
    ax.axhline(20, color='orange', linestyle='--', alpha=0.6, label='Q20')
    ax.set_title(f'Per-Base Quality — {label}'); ax.set_ylim(0, 45); ax.legend(fontsize=8)
    ax.set_xlabel('Position (bp)'); ax.set_ylabel('Mean Q')
axes[1,0].hist(raw_lens, bins=len_bins, color='steelblue', alpha=0.7, edgecolor='white', label='Raw')
axes[1,0].hist(trimmed_lens, bins=len_bins, color='mediumseagreen', alpha=0.7, edgecolor='white', label='Trimmed')
axes[1,0].set_xlabel('Read length (bp)'); axes[1,0].set_ylabel('Count')
axes[1,0].set_title('Read Length Distribution'); axes[1,0].legend()
axes[1,1].hist(raw_mq, bins=mq_bins, color='steelblue', alpha=0.7, edgecolor='white', label='Raw')
axes[1,1].hist(trimmed_mq, bins=mq_bins, color='mediumseagreen', alpha=0.7, edgecolor='white', label='Trimmed')
axes[1,1].axvline(30, color='red', linestyle='--', label='Q30')
axes[1,1].set_xlabel('Mean quality per read'); axes[1,1].set_ylabel('Count')
axes[1,1].set_title('Per-Read Mean Quality'); axes[1,1].legend()
plt.suptitle('QC: Before vs After Trimming', fontsize=13)
plt.tight_layout(); plt.show()
print(f"Mean Q before: {sum(raw_mq)/len(raw_mq):.1f}")
print(f"Mean Q after: {sum(trimmed_mq)/len(trimmed_mq):.1f}")
# ── SOLUTION B1: Parse VCF ────────────────────────────────────────────────────
def _vcf_number(raw, cast):
    """Convert a VCF numeric field with ``cast``; a missing value ('.', '', absent) becomes 0.

    For comma-separated (per-ALT) values only the first entry is used.
    """
    if raw is None:
        return cast(0)
    first = raw.split(',', 1)[0]
    if first in ('', '.'):
        return cast(0)
    return cast(first)


def parse_vcf(filepath):
    """Parse the fixed columns of a plain-text VCF into a list of dicts.

    Header/meta lines (starting with ``#``) and blank lines are skipped. Each
    variant dict has keys ``chrom``, ``pos`` (int), ``ref``, ``alt``, ``qual``
    (float), ``filter``, ``dp`` (int, from INFO ``DP``) and ``af`` (float, from
    INFO ``AF``). A missing QUAL/DP/AF value (``.`` or absent key) is stored as 0,
    so it fails any positive threshold. For a multi-value AF, the first entry
    is kept.

    Raises:
        ValueError: if a data line has fewer than the 8 mandatory columns.
    """
    variants = []
    with open(filepath) as f:
        for line in f:
            # Strip only the line terminator; rstrip() would also eat trailing
            # tab-separated empty columns.
            line = line.rstrip('\r\n')
            if not line.strip() or line.startswith('#'):
                continue
            parts = line.split('\t')
            if len(parts) < 8:
                raise ValueError(
                    f"VCF data line has {len(parts)} columns, expected at least 8: {line!r}"
                )
            chrom, pos, _vid, ref, alt, qual, flt, info_str = parts[:8]
            # KEY=VALUE entries only; flag entries without '=' are ignored.
            info = dict(item.split('=', 1) for item in info_str.split(';') if '=' in item)
            variants.append({
                'chrom': chrom,
                'pos': int(pos),
                'ref': ref,
                'alt': alt,
                'qual': _vcf_number(qual, float),
                'filter': flt,
                'dp': _vcf_number(info.get('DP'), int),
                'af': _vcf_number(info.get('AF'), float),
            })
    return variants
# Load all variants and preview the first three parsed records.
variants = parse_vcf(VCF_PATH)
print(f"Parsed {len(variants)} variants")
for preview in variants[:3]:
    print(f" {preview}")
# ── SOLUTION B2: Apply filters ────────────────────────────────────────────────
# Each filter on its own, to show how many variants it removes in isolation.
pass_filter = [v for v in variants if v['filter'] == 'PASS']
pass_qual = [v for v in variants if v['qual'] >= 50]
pass_dp = [v for v in variants if v['dp'] >= 20]
pass_af = [v for v in variants if 0.1 <= v['af'] <= 0.95]
# Combined set: FILTER=PASS variants that also meet the other three thresholds
# (same dict objects, in the same order, as filtering `variants` directly).
pass_all = [v for v in pass_filter
            if v['qual'] >= 50 and v['dp'] >= 20 and 0.1 <= v['af'] <= 0.95]

for label, subset in (
    ("Total:", variants),
    ("FILTER=PASS:", pass_filter),
    ("QUAL >= 50:", pass_qual),
    ("DP >= 20:", pass_dp),
    ("AF 0.1-0.95:", pass_af),
    ("All filters:", pass_all),
):
    print(f"{label} {len(subset)}")
# ── SOLUTION B3: Ti/Tv ratio ──────────────────────────────────────────────────
def compute_titv(variant_list, ti_pairs=None):
    """Count transitions and transversions among single-nucleotide variants.

    Ti/Tv is only defined for SNVs. Indels, multi-base substitutions,
    multi-allelic records and non-ACGT alleles are therefore excluded from both
    counts. Previously every non-transition, including indels, was counted as a
    transversion, which deflated the ratio. Alleles are compared case-insensitively.

    Args:
        variant_list: Iterable of variant dicts with ``'ref'`` and ``'alt'`` keys.
        ti_pairs: Set of uppercase ``(ref, alt)`` pairs that are transitions.
            Defaults to the module-level ``TI_PAIRS``.

    Returns:
        tuple[int, int, float]: ``(ti, tv, ti / tv)``. The ratio is ``nan``
        (undefined) when there are no transversions.
    """
    if ti_pairs is None:
        ti_pairs = TI_PAIRS
    bases = {'A', 'C', 'G', 'T'}
    ti = tv = 0
    for v in variant_list:
        ref, alt = v['ref'].upper(), v['alt'].upper()
        # Set membership also rejects multi-character alleles (indels, 'A,T').
        if ref not in bases or alt not in bases or ref == alt:
            continue
        if (ref, alt) in ti_pairs:
            ti += 1
        else:
            tv += 1
    ratio = ti / tv if tv else float('nan')
    return ti, tv, ratio
# Ti/Tv for the raw call set vs the high-confidence (all-filters) subset.
ti_all, tv_all, titv_all = compute_titv(variants)
ti_pass, tv_pass, titv_pass = compute_titv(pass_all)
print(f"{'':12s} {'Ti':>6} {'Tv':>6} {'Ti/Tv':>8}")
for row_label, n_ti, n_tv, ratio in (
    ("All", ti_all, tv_all, titv_all),
    ("Passing", ti_pass, tv_pass, titv_pass),
):
    print(f"{row_label:12s} {n_ti:>6} {n_tv:>6} {ratio:>8.2f}")
# ── SOLUTION B4: QC dashboard ────────────────────────────────────────────────
fig, axes = plt.subplots(2, 2, figsize=(14, 9))
# The overlaid All/Passing histograms share bin edges, computed from all variants
# (pass_all is a subset), so the bars line up and heights are directly comparable.
qual_bins = np.histogram_bin_edges([v['qual'] for v in variants], bins=20)
af_bins = np.histogram_bin_edges([v['af'] for v in variants], bins=15)
# QUAL distribution
axes[0,0].hist([v['qual'] for v in variants], bins=qual_bins, color='steelblue', alpha=0.6, edgecolor='white', label='All')
axes[0,0].hist([v['qual'] for v in pass_all], bins=qual_bins, color='mediumseagreen', alpha=0.8, edgecolor='white', label='Passing')
axes[0,0].axvline(50, color='red', linestyle='--', label='QUAL=50')
axes[0,0].set_xlabel('QUAL'); axes[0,0].set_ylabel('Count')
axes[0,0].set_title('Quality Score Distribution'); axes[0,0].legend()
# AF distribution
axes[0,1].hist([v['af'] for v in variants], bins=af_bins, color='steelblue', alpha=0.6, edgecolor='white', label='All')
axes[0,1].hist([v['af'] for v in pass_all], bins=af_bins, color='mediumseagreen', alpha=0.8, edgecolor='white', label='Passing')
axes[0,1].set_xlabel('Allele Frequency'); axes[0,1].set_ylabel('Count')
axes[0,1].set_title('Allele Frequency Distribution'); axes[0,1].legend()
# QUAL vs DP scatter: colour by membership in pass_all (object identity, since
# pass_all holds the same dict objects as variants).
pass_set = set(id(v) for v in pass_all)
colors_v = ['mediumseagreen' if id(v) in pass_set else 'tomato' for v in variants]
axes[1,0].scatter([v['dp'] for v in variants], [v['qual'] for v in variants],
                  c=colors_v, alpha=0.7, s=40)
axes[1,0].axvline(20, color='red', linestyle='--', alpha=0.5)
axes[1,0].axhline(50, color='red', linestyle='--', alpha=0.5)
axes[1,0].set_xlabel('DP'); axes[1,0].set_ylabel('QUAL')
axes[1,0].set_title('QUAL vs DP (green=PASS, red=FAIL)')
# Mutation spectrum (single-base substitutions only)
mut_counts = Counter(f"{v['ref']}>{v['alt']}" for v in pass_all
                     if len(v['ref']) == 1 and len(v['alt']) == 1)
mut_labels = sorted(mut_counts.keys())
ti_set = {f"{r}>{a}" for r, a in TI_PAIRS}
bar_cols = ['steelblue' if m in ti_set else 'darkorange' for m in mut_labels]
axes[1,1].bar(mut_labels, [mut_counts[k] for k in mut_labels], color=bar_cols, edgecolor='white')
axes[1,1].set_xlabel('Mutation type'); axes[1,1].set_ylabel('Count')
axes[1,1].set_title('Mutation Spectrum (blue=Ti, orange=Tv)')
axes[1,1].tick_params(axis='x', rotation=45)
plt.suptitle('Variant Calling QC Dashboard', fontsize=13)
plt.tight_layout(); plt.show()
# ── SOLUTION C1: Load count matrix ────────────────────────────────────────────
with open(COUNTS_PATH) as f:
    header_line = f.readline().strip().split('\t')
    genes, counts_list = [], []
    for line_no, line in enumerate(f, start=2):
        if not line.strip():
            continue  # skip blank/trailing lines instead of creating a ragged row
        p = line.strip().split('\t')
        # A header/row width mismatch (e.g. a header without a gene-column name)
        # would otherwise silently shift every sample name by one column.
        if len(p) != len(header_line):
            raise ValueError(
                f"{COUNTS_PATH}:{line_no}: expected {len(header_line)} columns, got {len(p)}"
            )
        genes.append(p[0])
        # float() accepts both integer and non-integer (e.g. estimated) counts;
        # the matrix is stored as float either way.
        counts_list.append([float(x) for x in p[1:]])
sample_names = header_line[1:]
counts_matrix = np.array(counts_list, dtype=float)
# Group membership is inferred from the sample-name prefix.
ctrl_idx = [i for i, s in enumerate(sample_names) if s.startswith('ctrl')]
treat_idx = [i for i, s in enumerate(sample_names) if s.startswith('treat')]
n_genes = counts_matrix.shape[0]
n_samples = counts_matrix.shape[1]
lib_sizes = counts_matrix.sum(axis=0)
print(f"Genes: {n_genes}")
print(f"Samples: {n_samples} {sample_names}")
print(f"Library sizes: {lib_sizes.astype(int)}")
# ── SOLUTION C2: Median-of-ratios normalization ───────────────────────────────
def median_of_ratios(counts):
    """Compute per-sample size factors with the median-of-ratios method.

    Only genes with a non-zero count in every sample are used. For each such
    gene, each sample's count is divided by the gene's geometric mean across
    samples. A sample's size factor is the median of these ratios (computed in
    log space).

    Args:
        counts: 2-D array-like of shape ``(genes, samples)``.

    Returns:
        numpy.ndarray: One size factor per sample (column).

    Raises:
        ValueError: if no gene is non-zero in every sample. The size factors
            would otherwise silently come out as ``nan``.
    """
    counts = np.asarray(counts, dtype=float)
    nonzero = np.all(counts > 0, axis=1)
    if not nonzero.any():
        raise ValueError(
            "median_of_ratios: no gene has non-zero counts in every sample; "
            "size factors are undefined"
        )
    log_c = np.log(counts[nonzero])
    log_gm = log_c.mean(axis=1, keepdims=True)  # log of each gene's geometric mean
    return np.exp(np.median(log_c - log_gm, axis=0))
size_factors = median_of_ratios(counts_matrix)
# One factor per sample (column): broadcasting divides each column by its own factor.
norm_counts = counts_matrix / size_factors
print("Size factors:")
for sample, factor in zip(sample_names, size_factors):
    print(f" {sample}: {factor:.4f}")
# ── SOLUTION C3: Differential expression ─────────────────────────────────────
# SciPy is optional: Welch's t-test is used when it is installed, otherwise the
# DE loop falls back to a normal approximation.
try:
    from scipy.stats import ttest_ind
except ImportError:
    HAS_SCIPY = False
else:
    HAS_SCIPY = True
# log2(x + 1) of the normalised counts; the pseudocount keeps zero counts finite.
log_norm = np.log2(1 + norm_counts)
def bh_correction(p_values):
    """Return Benjamini-Hochberg adjusted p-values, in the input order.

    Each p-value at ascending rank ``k`` (1-based) of ``n`` is scaled by ``n / k``.
    A running minimum is then taken from the largest p downwards, which enforces
    monotonicity. The result is capped at 1.0.
    """
    n_tests = len(p_values)
    ascending = sorted(range(n_tests), key=lambda k: p_values[k])
    adjusted = [0.0] * n_tests
    running_min = 1.0
    # Walk from the largest p-value (rank n) down to the smallest (rank 1).
    for position in range(n_tests - 1, -1, -1):
        j = ascending[position]
        running_min = min(running_min, p_values[j] * n_tests / (position + 1))
        adjusted[j] = running_min
    return adjusted
results = []
for gi, gene in enumerate(genes):
    ctrl_v = log_norm[gi, ctrl_idx]
    treat_v = log_norm[gi, treat_idx]
    log2fc = float(np.mean(treat_v) - np.mean(ctrl_v))
    if HAS_SCIPY:
        _, p = ttest_ind(treat_v, ctrl_v, equal_var=False)
    else:
        # Fallback: Welch-style z-test with a normal approximation. Each group's
        # own size is used (previously hard-coded to 3 replicates).
        se = math.sqrt(np.var(ctrl_v, ddof=1) / len(ctrl_v)
                       + np.var(treat_v, ddof=1) / len(treat_v))
        z_stat = abs(log2fc / se) if se else 0.0
        # Two-sided p-value. erfc avoids 1 - erf(x) cancelling to exactly 0.
        p = math.erfc(z_stat / math.sqrt(2))
    p = float(p)
    # Zero variance in both groups (e.g. a gene with all-zero counts) makes the
    # test return nan. max(1e-300, nan) evaluates to 1e-300, which would mark the
    # gene as maximally significant, so an undefined test is treated as p = 1.
    if not math.isfinite(p):
        p = 1.0
    results.append({'gene': gene, 'log2fc': log2fc, 'p_value': max(1e-300, p),
                    'basemean': float(np.mean(norm_counts[gi])), 'gene_idx': gi})
p_adj = bh_correction([r['p_value'] for r in results])
for r, pa in zip(results, p_adj): r['padj'] = pa
results.sort(key=lambda r: r['padj'])
sig_up = [r for r in results if r['padj'] < 0.05 and r['log2fc'] >= 1]
sig_down = [r for r in results if r['padj'] < 0.05 and r['log2fc'] <= -1]
print(f"DE genes (padj<0.05, |FC|≥2): {len(sig_up)+len(sig_down)}")
print(f" Upregulated: {len(sig_up)}, Downregulated: {len(sig_down)}")
# ── SOLUTION C4: Volcano plot ─────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(9, 7))
x = np.array([r['log2fc'] for r in results])
padj_vals = np.array([r['padj'] for r in results])
y = np.array([-math.log10(max(r['padj'], 1e-300)) for r in results])
# Classify on padj directly, with the same strict threshold as sig_up/sig_down
# (padj < 0.05). Comparing -log10(padj) >= -log10(0.05) also admitted
# padj == 0.05, so the legend counts could disagree with the printed DE totals.
sig = padj_vals < 0.05
up = sig & (x >= 1)
down = sig & (x <= -1)
ns = ~(up | down)
ax.scatter(x[ns], y[ns], c='lightgray', alpha=0.5, s=20, label=f'NS ({ns.sum()})')
ax.scatter(x[up], y[up], c='#e74c3c', alpha=0.8, s=40, label=f'Up ({up.sum()})')
ax.scatter(x[down], y[down], c='#3498db', alpha=0.8, s=40, label=f'Down ({down.sum()})')
ax.axhline(-math.log10(0.05), color='gray', linestyle='--', alpha=0.6)
ax.axvline(1, color='gray', linestyle=':', alpha=0.6)
ax.axvline(-1, color='gray', linestyle=':', alpha=0.6)
# Label the five genes with the smallest padj (results is sorted by padj).
for r in results[:5]:
    xi = r['log2fc']; yi = -math.log10(max(r['padj'], 1e-300))
    ax.annotate(r['gene'], (xi, yi), textcoords='offset points', xytext=(5, 3), fontsize=8)
ax.set_xlabel('log₂ Fold Change'); ax.set_ylabel('-log₁₀(padj)')
ax.set_title('Volcano Plot: Treated vs Control')
ax.legend()
plt.tight_layout(); plt.show()
# ── SOLUTION C5: Heatmap ──────────────────────────────────────────────────────
# The 15 DE genes (up or down) with the smallest padj.
top15 = sorted(sig_up + sig_down, key=lambda r: r['padj'])[:15]
if not top15:
    # With no rows, np.array([]) is 1-D and the per-gene z-scoring would fail.
    print("No DE genes passed padj<0.05 and |log2FC|>=1; heatmap skipped.")
else:
    # Put control columns first, then treated, then anything else. The divider
    # drawn after the control block is then correct whatever the file's column order.
    grouped = ctrl_idx + treat_idx
    col_order = grouped + [i for i in range(len(sample_names)) if i not in grouped]
    mat = np.array([log_norm[r['gene_idx'], col_order] for r in top15])
    # Z-score per gene (row); constant rows get sd=1 so they map to 0, not nan.
    mu = mat.mean(axis=1, keepdims=True)
    sd = mat.std(axis=1, keepdims=True)
    sd[sd == 0] = 1
    z = (mat - mu) / sd
    fig, ax = plt.subplots(figsize=(9, 7))
    im = ax.imshow(z, cmap='RdBu_r', aspect='auto', vmin=-2.5, vmax=2.5)
    plt.colorbar(im, ax=ax, label='Z-score')
    ax.set_xticks(range(len(col_order)))
    ax.set_xticklabels([sample_names[i] for i in col_order], rotation=30, ha='right')
    ax.set_yticks(range(len(top15)))
    ax.set_yticklabels([r['gene'] for r in top15], fontsize=9)
    ax.axvline(len(ctrl_idx) - 0.5, color='black', linewidth=2)
    # Annotate FC direction just right of the last column
    for i, r in enumerate(top15):
        ax.text(len(col_order) - 0.3, i,
                '▲' if r['log2fc'] > 0 else '▼',
                va='center', fontsize=9,
                color='#e74c3c' if r['log2fc'] > 0 else '#3498db')
    ax.set_title('Top 15 DE Genes: Expression Heatmap', fontsize=12)
    plt.tight_layout(); plt.show()
# ── FINAL SUMMARY ─────────────────────────────────────────────────────────────
# One print call with newline separators; the output matches printing each line separately.
print(
    "=" * 55,
    "CAPSTONE ANALYSIS SUMMARY — SOLUTIONS",
    "=" * 55,
    "\n[ Challenge A: QC and Trimming ]",
    f" Input reads: {len(reads)}",
    f" Reads with adapter: {n_adapter_reads} ({pct_adapter:.1f}%)",
    f" Reads after trimming: {len(trimmed_reads)} ({len(trimmed_reads)/len(reads)*100:.1f}%)",
    f" Mean quality (before): {sum(raw_mq)/len(raw_mq):.1f}",
    f" Mean quality (after): {sum(trimmed_mq)/len(trimmed_mq):.1f}",
    "\n[ Challenge B: Variant Calling ]",
    f" Total variants: {len(variants)}",
    f" High-confidence: {len(pass_all)}",
    f" Ti/Tv (all): {titv_all:.2f}",
    f" Ti/Tv (passing): {titv_pass:.2f}",
    "\n[ Challenge C: Differential Expression ]",
    f" Genes tested: {len(results)}",
    f" DE genes (padj<0.05): {len(sig_up)+len(sig_down)}",
    f" Upregulated: {len(sig_up)}",
    f" Downregulated: {len(sig_down)}",
    sep="\n",
)