"""qc_toolkit — the quality-control engine for Module 3 (QC + Read Trimming).

Why this file exists
--------------------
Module 3's notebook used to define a lot of dense Python inline: a read
simulator with realistic artifacts, per-position quality statistics, and the
adapter / sliding-window / length-filter trimmers that mirror Trimmomatic and
fastp. That is more implementation than belongs on a teaching slide. So it lives
here instead — a plain, well-commented script you can **download and run on your
own machine**:

    python qc_toolkit.py        # runs the demo at the bottom

The notebook simply does `from qc_toolkit import ...` and spends its cells on
*running* these functions and *plotting* the results (the FastQC-style panels,
the trade-off curves, the before/after dashboard). Read this file when you want
to see exactly what each QC step does — every function is short and commented.

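Dependencies: the simulation and trimming code uses only the standard library
(`random`); the `plot_*` helpers also need matplotlib, which each of them
imports when it is called.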
"""

import random

# Illumina TruSeq adapter (33 bp). It is the default `adapter` argument of the
# trimming and adapter-content helpers below, and it is spliced into the
# simulated adapter-contaminated reads.
ADAPTER = "AGATCGGAAGAGCACACGTCTGAACTCCAGTCA"
# The DNA alphabet the read simulator samples from.
BASES = 'ACGT'


# ---------------------------------------------------------------------------
# 1. Simulating a realistic FASTQ dataset
# ---------------------------------------------------------------------------
def random_dna(length):
    """Return `length` bases drawn uniformly, with replacement, from BASES."""
    picked = [random.choice(BASES) for _ in range(length)]
    return ''.join(picked)


def low_complexity_seq(length):
    """Return a low-complexity read of `length` bases.

    The first `length // 2` positions are a single repeated base (poly-A or
    poly-G, chosen at random); the remainder is random DNA. This is a common
    sequencing artifact.
    """
    homopolymer = random.choice(['A', 'G']) * (length // 2)
    tail = random_dna(length - len(homopolymer))
    return homopolymer + tail


def quality_profile_good(length):
    """Return `length` Phred scores, each uniform on Q35..Q40 inclusive.

    This models a high-quality run with no positional decay.
    """
    draw = random.randint
    return [draw(35, 40) for _position in range(length)]


def quality_profile_bad(length):
    """Return `length` Phred scores that decay toward the 3' end.

    This is the classic Illumina quality decay:
      * first half of the read:  Q30..Q38
      * third quarter:           Q20..Q32
      * final quarter (3' end):  Q8..Q22  (quality collapse)
    Each score is drawn uniformly (inclusive) from its band.
    """
    def _score_at(position):
        frac = position / length
        if frac < 0.5:
            return random.randint(30, 38)
        if frac < 0.75:
            return random.randint(20, 32)
        return random.randint(8, 22)   # 3' quality collapse

    return [_score_at(position) for position in range(length)]


def _make_one_read(read_len):
    """Return (seq, quals, tag) for a single simulated read.

    The read type is drawn at random:
      * ~5%  'lowcmplx': low-complexity sequence with decaying quality
      * ~15% 'adapter':  a 50-130 bp insert followed by the adapter
      * ~80% 'good':     random DNA; ~30% of these get the decaying quality
                         profile, the rest high quality
    For read_len >= 0, every read has len(seq) == len(quals) == read_len,
    as real fixed-length Illumina reads do.
    """
    rand = random.random()
    if rand < 0.05:                       # 5% low complexity
        return low_complexity_seq(read_len), quality_profile_bad(read_len), 'lowcmplx'
    if rand < 0.20:                       # 15% adapter contamination
        insert_len = random.randint(50, 130)
        # When the insert is shorter than the read, the sequencer reads through
        # the insert, into the adapter and past it. Pad after the adapter with
        # random bases so the read keeps its full length. Without the pad, short
        # inserts produced a seq shorter than its quality list.
        pad_len = max(0, read_len - insert_len - len(ADAPTER))
        seq = (random_dna(insert_len) + ADAPTER + random_dna(pad_len))[:read_len]
        return seq, quality_profile_bad(read_len), 'adapter'
    # 80% real biological read: mostly good quality, sometimes bad.
    quals = quality_profile_bad(read_len) if random.random() < 0.3 else quality_profile_good(read_len)
    return random_dna(read_len), quals, 'good'


def generate_fastq_dataset(n_reads=2000, read_len=150, name='simulated'):
    """Simulate `n_reads` reads and return them as (header, seq, quals) tuples.

    Headers look like '<name>_read_00001 type=<tag>', numbered from 1. The tag
    records which artifact the read was simulated with. Both the sequence and
    the quality list are truncated to `read_len`.
    """
    def _record(number):
        seq, quals, tag = _make_one_read(read_len)
        return (f"{name}_read_{number:05d} type={tag}", seq[:read_len], quals[:read_len])

    return [_record(number) for number in range(1, n_reads + 1)]


# ---------------------------------------------------------------------------
# 2. Per-position quality statistics (FastQC's per-base quality panel)
# ---------------------------------------------------------------------------
def compute_per_position_stats(reads):
    """Summarize quality scores position by position (FastQC's per-base panel).

    reads: iterable of (header, seq, quals) tuples; only `quals` is read.
    Returns one dict per position, from position 0 up to the longest quality
    list, with keys 'mean', 'median', 'q10', 'q90', 'q25', 'q75'.

    The percentiles are simple index picks from the sorted scores, with no
    interpolation. 'median' is the upper median when the count is even.
    An empty input gives [].
    """
    # Materialize so one-shot iterators work: the reads are scanned twice.
    reads = list(reads)
    if not reads:
        return []   # max() below would raise on an empty sequence
    read_len = max(len(quals) for _, _, quals in reads)
    positions = [[] for _ in range(read_len)]
    for _, _, quals in reads:
        for i, q in enumerate(quals):
            positions[i].append(q)

    stats = []
    # Every position < read_len is covered by the longest read, so no
    # column is empty and n >= 1 below.
    for pos_quals in positions:
        arr = sorted(pos_quals)
        n = len(arr)
        stats.append({
            'mean': sum(arr) / n,
            'median': arr[n // 2],
            'q10': arr[n // 10],
            'q90': arr[int(n * 0.9)],
            'q25': arr[n // 4],
            'q75': arr[int(n * 0.75)],
        })
    return stats


# ---------------------------------------------------------------------------
# 3. Trimming: the building blocks Trimmomatic / fastp use
# ---------------------------------------------------------------------------
def trim_sliding_window(seq, quals, window=4, min_qual=20):
    """Quality-trim a read the way Trimmomatic's SLIDINGWINDOW step does.

    A `window`-base window slides from the 5' end. At the first window whose
    mean quality is below `min_qual`, the read is cut at that window's start.
    A read with no failing window is returned unchanged; this includes reads
    shorter than `window`, which contain no full window at all.
    Returns (trimmed_seq, trimmed_quals).
    """
    last_start = len(seq) - window
    failing_starts = (
        start for start in range(last_start + 1)
        if sum(quals[start:start + window]) / window < min_qual
    )
    cut = next(failing_starts, len(seq))   # default: keep the whole read
    return seq[:cut], quals[:cut]


def trim_adapter(seq, quals, adapter, min_overlap=12):
    """Clip adapter sequence, and everything after it, from a read.

    Adapter prefixes are tried from longest to shortest. The longest is the
    whole adapter, or as much of it as fits in the read; the shortest is
    `min_overlap` bases. For the first prefix found anywhere in the read, the
    read is cut at that prefix's last occurrence. This catches full internal
    adapters as well as partial adapters at the 3' end.
    Returns (trimmed_seq, trimmed_quals), or the inputs unchanged when no
    prefix of at least `min_overlap` bases occurs.
    """
    longest = min(len(adapter), len(seq))
    for overlap in range(longest, min_overlap - 1, -1):
        # rfind also covers the "read ends with the prefix" case: that is the
        # rightmost possible hit, at len(seq) - overlap.
        cut = seq.rfind(adapter[:overlap])
        if cut != -1:
            return seq[:cut], quals[:cut]
    return seq, quals


def trim_minlen(seq, quals, min_len=36):
    """Length filter applied after the other trimming steps.

    Keeps a read only if it still has at least `min_len` bases. Returns
    (seq, quals) unchanged for a kept read, or (None, None) for a
    discarded one.
    """
    too_short = len(seq) < min_len
    return (None, None) if too_short else (seq, quals)


def full_trim_pipeline(reads, adapter=ADAPTER, window=4, min_qual=20, min_len=36):
    """Trim every read: adapter clip, then sliding-window quality trim, then length filter.

    Each (header, seq, quals) read goes through trim_adapter,
    trim_sliding_window and trim_minlen with the given settings.
    Returns (trimmed_reads, n_discarded). Surviving reads keep their original
    header; n_discarded counts reads the length filter dropped.
    """
    def _trim_one(seq, quals):
        seq, quals = trim_adapter(seq, quals, adapter)
        seq, quals = trim_sliding_window(seq, quals, window=window, min_qual=min_qual)
        return trim_minlen(seq, quals, min_len=min_len)

    survivors, n_dropped = [], 0
    for header, seq, quals in reads:
        out_seq, out_quals = _trim_one(seq, quals)
        if out_seq is None:
            n_dropped += 1
            continue
        survivors.append((header, out_seq, out_quals))
    return survivors, n_dropped


# ---------------------------------------------------------------------------
# 4. Notebook helpers: GC / adapter metrics, a trimming wrapper, and plots.
# The matplotlib boilerplate lives here so notebook cells stay one or two
# lines. (The standalone demo, `python qc_toolkit.py`, is at the bottom.)
# ---------------------------------------------------------------------------
def gc_content(seq):
    """Return the G+C percentage (0-100) of `seq`, ignoring case.

    An empty sequence gives 0.0.
    """
    upper = seq.upper()
    if not upper:
        return 0.0
    gc = sum(upper.count(base) for base in 'GC')
    return gc / len(upper) * 100


def trim_reads(reads, adapter=ADAPTER, window=4, min_qual=20, min_len=36):
    """Trim a list of reads: adapter clip, sliding-window quality trim, length filter.

    This is the notebook-facing name for full_trim_pipeline. It delegates to
    that function instead of repeating the three steps inline, so the two
    entry points cannot drift apart. A read is kept when its trimmed length
    is >= min_len. Returns (kept_reads, n_discarded).
    """
    return full_trim_pipeline(reads, adapter=adapter, window=window,
                              min_qual=min_qual, min_len=min_len)


def adapter_content(reads, adapter=ADAPTER, read_len=None, *, cumulative=False, stub_len=16):
    """Per-position percentage of reads that contain the adapter stub.

    The stub is the first `stub_len` bases of `adapter`. Only the stub's first
    occurrence in each read is counted.

    cumulative=False (default): a read counts only at the positions the stub
        covers.
    cumulative=True: from the stub's start onward, the read counts at every
        position up to read_len. This is the ever-increasing curve that
        FastQC's Adapter Content module plots.

    read_len defaults to the longest read. Returns a list of read_len
    percentages: all 0.0 when there are no reads or the stub is empty, and
    [] when reads is empty and read_len is None.
    """
    reads = list(reads)   # scanned more than once below
    if read_len is None:
        read_len = max((len(s) for _, s, _ in reads), default=0)
    stub = adapter[:stub_len]
    if not reads or not stub:
        # Avoids dividing by zero reads. An empty stub would "match" at
        # position 0 of every read, so report no adapter instead.
        return [0.0] * read_len
    present = [0] * read_len
    for _, seq, _ in reads:
        start = seq.find(stub)
        if start == -1:
            continue
        end = read_len if cumulative else min(start + len(stub), read_len)
        for i in range(start, end):
            present[i] += 1
    return [count / len(reads) * 100 for count in present]


def plot_per_base_quality(reads, title="Per-base sequence quality"):
    """Draw a FastQC-style per-base quality panel for `reads` and return the figure.

    The panel shows the mean Phred score at each 1-based position, with
    shaded 10-90th and 25-75th percentile bands. Red, orange and green
    background bands mark Q0-20, Q20-28 and Q28-45. Calls plt.show() before
    returning.
    """
    import matplotlib.pyplot as plt

    per_pos = compute_per_position_stats(reads)

    def _column(key):
        return [row[key] for row in per_pos]

    xs = list(range(1, len(per_pos) + 1))
    fig, ax = plt.subplots(figsize=(12, 5))
    # Percentile bands: the wide 10-90th band first, then the darker inter-quartile band.
    for lo_key, hi_key, shade, band_label in (('q10', 'q90', 0.15, '10-90th pct'),
                                              ('q25', 'q75', 0.30, '25-75th pct')):
        ax.fill_between(xs, _column(lo_key), _column(hi_key),
                        alpha=shade, color='steelblue', label=band_label)
    ax.plot(xs, _column('mean'), color='steelblue', linewidth=2, label='Mean Q')
    # Background quality zones, drawn bottom to top.
    for y_lo, y_hi, shade, zone_color in ((0, 20, 0.08, 'red'),
                                          (20, 28, 0.08, 'orange'),
                                          (28, 45, 0.06, 'green')):
        ax.axhspan(y_lo, y_hi, alpha=shade, color=zone_color)
    ax.set_xlabel('Position in read (bp)')
    ax.set_ylabel('Phred quality score')
    ax.set_title(title)
    ax.set_ylim(0, 45)
    ax.legend(loc='lower left', fontsize=9)
    plt.tight_layout()
    plt.show()
    return fig


def plot_gc_and_adapter(reads, adapter=ADAPTER, title_suffix=""):
    """Draw FastQC-style GC-content and adapter-content panels side by side.

    Left panel: a density histogram of per-read GC%, with a dashed reference
    line at 41% labelled as the human average.
    Right panel: the per-position percentage of reads carrying the adapter
    stub, as computed by adapter_content.
    `title_suffix` is appended to both panel titles. Calls plt.show() before
    returning the figure.
    """
    import matplotlib.pyplot as plt

    longest = max(len(seq) for _, seq, _ in reads)
    gc_values = [gc_content(seq) for _, seq, _ in reads]
    adapter_pct = adapter_content(reads, adapter, longest)
    positions = range(1, longest + 1)

    fig, (gc_ax, ad_ax) = plt.subplots(1, 2, figsize=(14, 4))

    gc_ax.hist(gc_values, bins=40, color='mediumseagreen', edgecolor='white', density=True)
    gc_ax.axvline(41, color='red', linestyle='--', label='Human ~41%')
    gc_ax.set_xlabel('GC content (%)')
    gc_ax.set_ylabel('Density')
    gc_ax.set_title('GC Content Distribution' + title_suffix)
    gc_ax.legend()

    ad_ax.plot(positions, adapter_pct, color='darkorange', linewidth=1.5)
    ad_ax.fill_between(positions, adapter_pct, alpha=0.3, color='darkorange')
    ad_ax.set_xlabel('Position in read (bp)')
    ad_ax.set_ylabel('% reads with adapter')
    ad_ax.set_ylim(bottom=0)
    ad_ax.set_title('Adapter Content' + title_suffix)

    plt.tight_layout()
    plt.show()
    return fig


def _shared_hist_range(*samples):
    """Return (min, max) over all values in `samples`, or None if every sample is empty.

    Passed as `range=` to overlaid histograms so each one is binned on the
    same edges.
    """
    values = [value for sample in samples for value in sample]
    if not values:
        return None
    return min(values), max(values)


def plot_before_after(raw_reads, trimmed_reads):
    """Draw a four-panel before/after QC dashboard and return the figure.

    Top row: per-base quality for the raw and the trimmed reads (mean line,
    25-75th percentile band, dashed Q20/Q30 guides).
    Bottom row: overlaid Raw/Trimmed histograms of read length and of
    per-read mean quality. Each overlaid pair is binned over one shared range,
    so a bar position means the same interval for both datasets.
    Calls plt.show() before returning.
    """
    import matplotlib.pyplot as plt
    pp_raw = compute_per_position_stats(raw_reads)
    pp_trim = compute_per_position_stats(trimmed_reads)
    fig, axes = plt.subplots(2, 2, figsize=(14, 9))
    for ax, pp, color, label in [
        (axes[0, 0], pp_raw, 'steelblue', 'Before trimming'),
        (axes[0, 1], pp_trim, 'mediumseagreen', 'After trimming'),
    ]:
        x = list(range(1, len(pp) + 1))
        ax.fill_between(x, [s['q25'] for s in pp], [s['q75'] for s in pp], alpha=0.3, color=color)
        ax.plot(x, [s['mean'] for s in pp], color=color, linewidth=2)
        ax.axhline(30, color='red', linestyle='--', alpha=0.6, label='Q30')
        ax.axhline(20, color='orange', linestyle='--', alpha=0.6, label='Q20')
        ax.set_title('Per-Base Quality - ' + label); ax.set_xlabel('Position (bp)')
        ax.set_ylabel('Phred Q'); ax.set_ylim(0, 45); ax.legend(fontsize=8)

    # With a bare bins=N, each hist() call derives its own edges from its own
    # min/max, so the overlaid Raw and Trimmed bars had different widths and
    # positions. A shared range gives both identical edges.
    raw_lens = [len(s) for _, s, _ in raw_reads]
    trim_lens = [len(s) for _, s, _ in trimmed_reads]
    len_range = _shared_hist_range(raw_lens, trim_lens)
    axes[1, 0].hist(raw_lens, bins=20, range=len_range, color='steelblue',
                    edgecolor='white', alpha=0.7, label='Raw')
    axes[1, 0].hist(trim_lens, bins=20, range=len_range, color='mediumseagreen',
                    edgecolor='white', alpha=0.7, label='Trimmed')
    axes[1, 0].set_xlabel('Read length (bp)'); axes[1, 0].set_ylabel('Count')
    axes[1, 0].set_title('Read Length Distribution'); axes[1, 0].legend()

    # Reads with an empty quality list have no mean; skip them.
    raw_q = [sum(q)/len(q) for _, _, q in raw_reads if q]
    trim_q = [sum(q)/len(q) for _, _, q in trimmed_reads if q]
    q_range = _shared_hist_range(raw_q, trim_q)
    axes[1, 1].hist(raw_q, bins=30, range=q_range, color='steelblue',
                    edgecolor='white', alpha=0.7, label='Raw')
    axes[1, 1].hist(trim_q, bins=30, range=q_range, color='mediumseagreen',
                    edgecolor='white', alpha=0.7, label='Trimmed')
    axes[1, 1].axvline(30, color='red', linestyle='--', label='Q30')
    axes[1, 1].set_xlabel('Mean quality per read'); axes[1, 1].set_ylabel('Count')
    axes[1, 1].set_title('Per-Read Mean Quality'); axes[1, 1].legend()
    plt.suptitle('Quality Control: Before vs After Trimming', y=1.01)
    plt.tight_layout(); plt.show()
    return fig


if __name__ == "__main__":
    # Fixed seed so the demo prints the same numbers on every run.
    random.seed(42)

    reads = generate_fastq_dataset(n_reads=2000, read_len=150)
    print(f"Generated {len(reads)} reads")

    # Compare mean quality at the first and the last position of the read.
    stats = compute_per_position_stats(reads)
    first_q = stats[0]['mean']
    last_q = stats[-1]['mean']
    print(f"Pos 0 mean Q = {first_q:.1f}   "
          f"Pos -1 mean Q = {last_q:.1f}  (note the 3' decay)")

    trimmed, discarded = full_trim_pipeline(reads, window=4, min_qual=20, min_len=36)
    n_kept = len(trimmed)
    n_total = len(reads)
    denom = max(1, n_kept)   # guards the averages if every read was discarded
    mean_len = sum(len(seq) for _, seq, _ in trimmed) / denom
    mean_q = sum(sum(quals) / len(quals) for _, _, quals in trimmed) / denom
    print(f"Trimmomatic Q20: kept {n_kept}/{n_total} "
          f"({n_kept / n_total * 100:.1f}%) | mean len {mean_len:.1f} bp | mean Q {mean_q:.2f}")
