"""align_toolkit — the alignment engine for Module 4 (Read Alignment).

Why this file exists
--------------------
Module 4's notebook used to define the heavy algorithms inline: a full
Smith-Waterman dynamic-programming aligner, the traceback, a CIGAR encoder, and
a pileup simulator/variant-caller. That is a lot of dense Python to read on a
slide. So the implementations live here instead — a plain, well-commented script
you can **download and run on your own machine**:

    python align_toolkit.py        # runs the demo at the bottom

The notebook simply does `from align_toolkit import ...` and spends its cells on
*using* and *visualizing* these functions rather than re-deriving them. Read this
file top to bottom when you want to see exactly how an aligner works under the
hood — every function is short and commented.

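Dependencies: numpy (scoring matrix, example coverage data) and the standard
library. The plotting helpers (`plot_*`) additionally need matplotlib, which they
import only when called.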
"""

import random
import numpy as np


# ---------------------------------------------------------------------------
# 1. Smith-Waterman local alignment
# ---------------------------------------------------------------------------
def smith_waterman(query, ref, match=2, mismatch=-1, gap=-2):
    """Smith-Waterman local alignment of `query` against `ref`.

    Builds the dynamic-programming score matrix H, where H[i, j] is the best
    score of any local alignment ending at query[i-1] / ref[j-1] (never below
    0). It also builds an arrow matrix T recording the move that produced each
    cell:
        0 = stop (score hit zero), 1 = diagonal, 2 = up, 3 = left.
    When several moves tie, the earlier one in that list wins.

    Returns (H, T, best_score, best_pos). best_pos is the first cell, in
    row-major order, holding the maximum score, or (0, 0) if nothing scores
    above zero.
    """
    n_rows, n_cols = len(query) + 1, len(ref) + 1
    score = np.zeros((n_rows, n_cols), dtype=float)
    arrow = np.zeros((n_rows, n_cols), dtype=int)

    top_score, top_cell = 0, (0, 0)

    for r in range(1, n_rows):
        q_base = query[r - 1]
        for c in range(1, n_cols):
            pair_score = match if q_base == ref[c - 1] else mismatch
            # Candidates listed in tie-break priority order:
            # stop, diagonal (aligned column), up (gap in reference),
            # left (gap in query).
            candidates = (
                (0, 0),
                (1, score[r - 1, c - 1] + pair_score),
                (2, score[r - 1, c] + gap),
                (3, score[r, c - 1] + gap),
            )
            cell_best = max(value for _, value in candidates)
            move = next(code for code, value in candidates if value == cell_best)

            score[r, c] = cell_best
            arrow[r, c] = move

            # Strict '>' keeps the earliest cell when the maximum repeats.
            if score[r, c] > top_score:
                top_score, top_cell = score[r, c], (r, c)

    return score, arrow, top_score, top_cell


# ---------------------------------------------------------------------------
# 2. Traceback -> aligned strings
# ---------------------------------------------------------------------------
def traceback(query, ref, H, T, best_pos):
    """Rebuild the aligned strings by walking the arrow matrix backwards.

    Starting from `best_pos`, repeatedly follows T (1 = diagonal, 2 = up,
    anything else = left) until it reaches a cell whose score in H is not
    positive.

    Returns (aligned_query, alignment_track, aligned_ref). In the track, '|'
    marks a match, '.' a mismatch and ' ' a gap column; '-' fills the gapped
    side of a gap column.
    """
    row, col = best_pos
    q_cols, track_cols, r_cols = [], [], []

    while H[row, col] > 0:
        move = T[row, col]
        if move == 1:
            # Diagonal: both sequences contribute a base.
            q_char, r_char = query[row - 1], ref[col - 1]
            mark = '|' if q_char == r_char else '.'
            row, col = row - 1, col - 1
        elif move == 2:
            # Up: the query base sits opposite a gap in the reference.
            q_char, r_char, mark = query[row - 1], '-', ' '
            row -= 1
        else:
            # Left: the reference base sits opposite a gap in the query.
            q_char, r_char, mark = '-', ref[col - 1], ' '
            col -= 1
        q_cols.append(q_char)
        track_cols.append(mark)
        r_cols.append(r_char)

    # Columns were collected end -> start, so flip them back.
    return ''.join(q_cols[::-1]), ''.join(track_cols[::-1]), ''.join(r_cols[::-1])


def align_and_show(query, ref, label=""):
    """Align `query` to `ref` with smith_waterman and print a short report.

    The report shows the three alignment rows, then the score together with
    match / mismatch / gap counts and percent identity. Percent identity is
    matches over alignment columns, or 0 when nothing aligned. A blank line
    follows the report. Returns None.
    """
    H, T, score, end_cell = smith_waterman(query, ref)
    q_row, track, r_row = traceback(query, ref, H, T, end_cell)

    n_match = track.count('|')
    n_mismatch = track.count('.')
    n_gap = sum(row.count('-') for row in (q_row, r_row))
    if track:
        identity = n_match / len(track) * 100
    else:
        identity = 0

    report = [
        f"=== {label} ===",
        f"Query:     {q_row}",
        f"Alignment: {track}",
        f"Ref:       {r_row}",
        f"Score={score:.0f}  Matches={n_match}  Mismatches={n_mismatch}  "
        f"Gaps={n_gap}  %Identity={identity:.1f}%",
    ]
    for line in report:
        print(line)
    print()


# ---------------------------------------------------------------------------
# 3. Alignment -> CIGAR string (SAM/BAM column 6)
# ---------------------------------------------------------------------------
def cigar_for(read, ref, **kw):
    """Locally align `read` against `ref` and describe the result as a CIGAR.

    Extra keyword arguments (match / mismatch / gap) are forwarded to
    smith_waterman. Returns (cigar_string, score).

    CIGAR operations emitted:
      M  aligned column (a match OR a mismatch)
      I  insertion to the reference (read base opposite a reference gap)
      D  deletion from the reference (reference base opposite a read gap)
      S  soft-clip: read bases outside the local alignment
    e.g. "4S6M2I4M".

    When nothing aligns (score 0), smith_waterman reports end cell (0, 0), so
    the whole read comes out soft-clipped. An empty read yields "".
    """
    H, T, score, (end_row, end_col) = smith_waterman(read, ref, **kw)

    # Traceback arrow -> CIGAR op; every other arrow (left) is a deletion.
    arrow_to_op = {1: 'M', 2: 'I'}
    row, col = end_row, end_col
    reversed_ops = []
    while H[row, col] > 0:
        op = arrow_to_op.get(T[row, col], 'D')
        reversed_ops.append(op)
        if op != 'D':   # M and I consume a read base
            row -= 1
        if op != 'I':   # M and D consume a reference base
            col -= 1

    # `row` now counts the read bases before the aligned block; the bases after
    # end_row lie past its end. Both flanks are soft-clipped.
    per_base = ['S'] * row + reversed_ops[::-1] + ['S'] * (len(read) - end_row)

    # Collapse runs of identical ops, e.g. M,M,M,I,I -> "3M2I".
    chunks = []
    current, length = None, 0
    for op in per_base:
        if op == current:
            length += 1
            continue
        if current is not None:
            chunks.append(f"{length}{current}")
        current, length = op, 1
    if current is not None:
        chunks.append(f"{length}{current}")
    return ''.join(chunks), score


# ---------------------------------------------------------------------------
# 4. Pileup simulation + variant calling
# ---------------------------------------------------------------------------
def simulate_pileup(ref_len=100, n_reads=80, read_len=30,
                    snp_positions=None, error_rate=0.01):
    """Simulate short reads drawn from a random reference and pile them up.

    A random A/C/G/T reference of length `ref_len` is generated. Then `n_reads`
    reads of length `read_len` are sampled at uniformly random start positions.
    At every read base:
      * a planted SNP replaces the reference base with its alt allele with
        probability 'af';
      * then, with probability `error_rate`, the base is swapped for a different
        random base (a sequencing error).
    The resulting base is both written into the read and counted in the pileup.

    snp_positions maps {pos: {'alt': base, 'af': allele_frequency}}.
    NOTE: it is modified in place. If the random reference already carries a
    planted alt base, that entry's 'alt' is reassigned to a different base, so
    the caller's dict always describes real (alt != ref) variants.

    Returns (reference, pileup, reads) where
      pileup = {pos: {'A':n, 'C':n, 'G':n, 'T':n, 'depth':n}} for every pos, and
      reads  = [(start, read_sequence), ...] with SNPs/errors applied.

    Raises ValueError if reads are requested that are longer than the reference.
    """
    if snp_positions is None:
        snp_positions = {}
    if n_reads > 0 and read_len > ref_len:
        raise ValueError(
            f"read_len ({read_len}) cannot exceed ref_len ({ref_len})")

    reference = ''.join(random.choice('ACGT') for _ in range(ref_len))

    # Guarantee every planted SNP is a real variant: if the random reference
    # happens to carry the planted alt base at that position, reassign the alt
    # to a different base so the SNP is never a silent no-op (alt == ref).
    for pos, info in snp_positions.items():
        if 0 <= pos < ref_len and info['alt'] == reference[pos]:
            info['alt'] = next(b for b in 'ACGT' if b != reference[pos])

    pileup = {i: {'A': 0, 'C': 0, 'G': 0, 'T': 0, 'depth': 0}
              for i in range(ref_len)}

    reads_data = []
    for _ in range(n_reads):
        start = random.randint(0, ref_len - read_len)
        read_seq = list(reference[start:start + read_len])

        for offset in range(read_len):
            pos = start + offset
            base = read_seq[offset]

            # Apply a planted SNP at its allele frequency.
            if pos in snp_positions and random.random() < snp_positions[pos]['af']:
                base = snp_positions[pos]['alt']
            # Random sequencing error.
            if random.random() < error_rate:
                base = random.choice([b for b in 'ACGT' if b != base])

            # Keep the read consistent with the pileup: it carries the base
            # that was actually "sequenced", not the reference base.
            read_seq[offset] = base
            pileup[pos][base] += 1
            pileup[pos]['depth'] += 1

        reads_data.append((start, ''.join(read_seq)))

    return reference, pileup, reads_data


def call_variants(pileup, reference, min_af=0.1, min_depth=10):
    """Call a variant wherever the most common non-reference allele is
    well supported.

    pileup    : {pos: {'A':n, 'C':n, 'G':n, 'T':n, 'depth':n}}
    reference : sequence indexed by the same positions
    min_af    : minimum alt-allele fraction (alt reads / depth) to report
    min_depth : positions with fewer reads than this are skipped

    A position is reported only if at least one read shows a non-reference
    base. This prevents "variants" with zero supporting reads when min_af is
    0, and a division by zero on uncovered positions when min_depth is 0.

    Returns a list, sorted by position, of
    {'pos', 'ref', 'alt', 'af', 'depth'} dicts.
    """
    variants = []
    for pos in sorted(pileup):
        counts = pileup[pos]
        depth = counts['depth']
        if depth < min_depth:
            continue  # too little coverage to trust

        ref_base = reference[pos]
        # Among A/C/G/T, find the most common base that is NOT the reference
        # (ties go to the first in A, C, G, T order).
        alt_base = max('ACGT', key=lambda b: counts[b] if b != ref_base else -1)
        alt_count = counts[alt_base]
        if alt_count == 0:
            continue  # no read supports any alternative allele

        alt_af = alt_count / depth
        if alt_af >= min_af:
            variants.append({'pos': pos, 'ref': ref_base, 'alt': alt_base,
                             'af': alt_af, 'depth': depth})
    return variants


# ---------------------------------------------------------------------------
# 5. Plotting / convenience helpers (added for the real-tool workflow rewrite).
#    The standalone demo (`python align_toolkit.py`) is at the bottom of the file.
# ---------------------------------------------------------------------------
def traceback_path(H, T, best_pos):
    """List the (row, col) cells visited by the Smith-Waterman traceback.

    Starts at `best_pos` and follows T (1 = diagonal, 2 = up, anything else =
    left). It stops on the first row/column or at a cell whose H score is not
    positive. Cells are returned end -> start (best cell first); that last
    stopping cell is not included.
    """
    # (row step, col step) taken backwards for each arrow; left is the fallback.
    step_for = {1: (1, 1), 2: (1, 0)}
    row, col = best_pos
    visited = []
    while row > 0 and col > 0 and H[row, col] > 0:
        visited.append((row, col))
        d_row, d_col = step_for.get(T[row, col], (0, 1))
        row, col = row - d_row, col - d_col
    return visited


def plot_sw_matrix(H, query, ref, best_pos=None, path=None, annotate=True,
                   title="Smith-Waterman scoring matrix"):
    """Draw the Smith-Waterman DP matrix H as a heatmap and return the figure.

    Rows are labelled with the query and columns with the reference; the first
    row/column is the empty-prefix border. Optional overlays:
      annotate : write each cell's score inside it
      best_pos : (row, col) of the best cell, drawn as a blue star
      path     : sequence of (row, col) cells, drawn as a blue line with markers
    A legend is added whenever either overlay is drawn. The figure is shown
    with plt.show() before being returned.
    """
    import matplotlib.pyplot as plt

    # Figure grows with the matrix but is capped so long inputs stay on screen.
    fig_width = min(2 + len(ref) * 0.7, 16)
    fig_height = min(2 + len(query) * 0.7, 8)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    heat = ax.imshow(H, cmap='YlOrRd', aspect='auto', interpolation='nearest')
    plt.colorbar(heat, ax=ax, label='SW score')

    if annotate:
        n_rows, n_cols = H.shape[0], H.shape[1]
        for row in range(n_rows):
            for col in range(n_cols):
                ax.text(col, row, '%.0f' % H[row, col],
                        ha='center', va='center', fontsize=8)

    has_overlay = False
    if best_pos is not None:
        ax.plot(best_pos[1], best_pos[0], 'b*', markersize=16, label='Best score')
        has_overlay = True
    if path is not None:
        # imshow puts columns on x and rows on y.
        path_x = [cell[1] for cell in path]
        path_y = [cell[0] for cell in path]
        ax.plot(path_x, path_y, 'b-o', linewidth=2,
                markersize=7, label='Traceback')
        has_overlay = True
    if has_overlay:
        ax.legend()

    ax.set_xticks(range(len(ref) + 1))
    ax.set_xticklabels([' '] + list(ref), fontfamily='monospace')
    ax.set_yticks(range(len(query) + 1))
    ax.set_yticklabels([' '] + list(query), fontfamily='monospace')
    ax.set_xlabel('Reference')
    ax.set_ylabel('Query')
    ax.set_title(title)

    plt.tight_layout()
    plt.show()
    return fig


def plot_pileup(pileup, ref_seq, snps=None):
    """Draw a three-panel pileup figure (shared x axis) and return it.

    Panels, top to bottom:
      1. coverage depth per position;
      2. fraction of reads NOT showing the reference base (bars above 5% in
         red, with a dashed 5% guide line);
      3. stacked A/C/G/T proportions per position.
    If `snps` is truthy, a dashed vertical line is drawn at each of its
    entries (e.g. the keys of a SNP dict) on every panel. The figure is
    shown with plt.show() before being returned.
    """
    import matplotlib.pyplot as plt

    positions = sorted(pileup.keys())
    depths = [pileup[p]['depth'] for p in positions]

    def non_ref_fraction(p):
        # Share of reads at p that differ from the reference; 0 if uncovered.
        column = pileup[p]
        total = column['depth']
        if total == 0:
            return 0
        return (total - column.get(ref_seq[p], 0)) / total

    afs = [non_ref_fraction(p) for p in positions]

    fig, axes = plt.subplots(3, 1, figsize=(14, 9), sharex=True)
    depth_ax, af_ax, comp_ax = axes

    depth_ax.fill_between(positions, depths, color='steelblue', alpha=0.6)
    depth_ax.plot(positions, depths, color='steelblue', linewidth=0.8)
    depth_ax.set_ylabel('Coverage depth')
    depth_ax.set_title('Coverage Pileup')

    bar_colors = ['red' if a > 0.05 else 'steelblue' for a in afs]
    af_ax.bar(positions, afs, width=1.0, alpha=0.8, color=bar_colors)
    af_ax.axhline(0.05, color='orange', linestyle='--', label='5% AF')
    af_ax.set_ylabel('Non-reference AF')
    af_ax.set_title('Allele Frequency per Position')
    af_ax.legend()

    palette = {'A': '#2ca02c', 'C': '#1f77b4', 'G': '#ff7f0e', 'T': '#d62728'}
    stack_base = np.zeros(len(positions))
    for nucleotide in 'ACGT':
        # max(1, depth) keeps uncovered positions at 0 instead of dividing by 0.
        share = [pileup[p][nucleotide] / max(1, pileup[p]['depth'])
                 for p in positions]
        comp_ax.bar(positions, share, bottom=stack_base, color=palette[nucleotide],
                    label=nucleotide, width=1.0, alpha=0.85)
        stack_base += np.array(share)
    comp_ax.set_ylabel('Base proportion')
    comp_ax.set_xlabel('Reference position')
    comp_ax.set_title('Base Composition per Position (stacked)')
    comp_ax.legend(loc='upper right', ncol=4)

    if snps:
        for pos in snps:
            for ax in axes:
                ax.axvline(pos, color='black', linestyle='--', alpha=0.4)

    plt.tight_layout()
    plt.show()
    return fig


def example_depths(seed=42):
    """Return a synthetic per-position coverage profile as a list of ints.

    This is a stand-in for a parsed `samtools depth`, with 100,000 positions:
      * 95,000 drawn from Poisson(35)   -- most of the genome near 35x,
      *    500 drawn from Poisson(200)  -- a few high-coverage repeats,
      *  4,500 zeros                    -- ~4.5% uncovered,
    shuffled together.

    A private numpy.random.RandomState(seed) is used, so the result is
    reproducible for a given `seed`. It gives the same values as seeding
    NumPy's legacy global generator with `seed`, but does not reseed or
    advance that global generator as a side effect.
    """
    rng = np.random.RandomState(seed)
    d = np.concatenate([
        rng.poisson(35, size=95000),
        rng.poisson(200, size=500),
        np.zeros(4500, dtype=int),
    ])
    rng.shuffle(d)
    return d.tolist()


def plot_coverage(depths):
    """Plot a depth histogram beside a cumulative-coverage curve; return the figure.

    depths : non-empty sequence of per-position read depths (e.g. the list
             returned by example_depths(), or a parsed `samtools depth`).

    Left panel: histogram of depths with the mean marked.
    Right panel: for thresholds t = 0, step, 2*step, ... up to the maximum
    depth (step = max(1, max_depth // 50), i.e. about 50 points), the
    percentage of positions with depth >= t. Reference lines mark 10x and 30x.
    The figure is shown with plt.show() before being returned.

    Raises ValueError if `depths` is empty.
    """
    import matplotlib.pyplot as plt

    n = len(depths)
    if n == 0:
        raise ValueError("plot_coverage() needs at least one depth value")

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    axes[0].hist(depths, bins=40, color='steelblue', edgecolor='white')
    mean_d = sum(depths) / n
    axes[0].axvline(mean_d, color='red', linestyle='--', label='Mean=%.1fx' % mean_d)
    axes[0].set_xlabel('Coverage depth')
    axes[0].set_ylabel('Number of positions')
    axes[0].set_title('Coverage Depth Distribution')
    axes[0].legend()

    # Sort once (ascending). Then "how many positions have depth >= t" is
    # n minus the insertion point of t, so each threshold costs a binary
    # search instead of a full scan of all positions.
    ordered = np.sort(np.asarray(depths))
    max_d = ordered[-1]
    step = max(1, max_d // 50)
    thr = np.arange(0, max_d + 1, step)
    cum = (n - np.searchsorted(ordered, thr, side='left')) / n * 100

    axes[1].plot(thr, cum, color='darkorange', linewidth=2)
    axes[1].axvline(10, color='green', linestyle='--', label='10x (germline)')
    axes[1].axvline(30, color='red', linestyle='--', label='30x (gold standard)')
    axes[1].set_xlabel('Minimum coverage depth')
    axes[1].set_ylabel('% genome covered')
    axes[1].set_title('Cumulative Coverage')
    axes[1].legend()

    plt.tight_layout()
    plt.show()
    return fig


# ---------------------------------------------------------------------------
# Standalone demo: `python align_toolkit.py`
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Seed the RNGs so every run prints the same numbers.
    random.seed(42)
    np.random.seed(42)

    divider = "-" * 40

    # Part 1: one full alignment report, then CIGARs for four toy reads.
    print("Smith-Waterman alignment + CIGAR")
    print(divider)
    align_and_show("ACGTACGT", "TTTACGTACGTGGG",
                   label="exact match in a longer reference")

    demo_ref = "NNNNACGTACGTNN"
    demo_reads = (
        ("perfect", "ACGTACGT"),
        ("1 SNP", "ACGTAGGT"),
        ("insertion", "ACGTTTACGT"),
        ("deletion", "ACGCGT"),
    )
    for case_name, read_seq in demo_reads:
        cigar, aln_score = cigar_for(read_seq, demo_ref)
        print(f"{case_name:10s} read={read_seq:12s} CIGAR={cigar:12s} "
              f"score={aln_score:.0f}")

    # Part 2: simulate a pileup with three planted SNPs and call variants.
    print("\nPileup variant calling")
    print(divider)
    snps = {
        25: {'alt': 'T', 'af': 0.5},
        50: {'alt': 'G', 'af': 0.95},
        75: {'alt': 'C', 'af': 0.2},
    }
    ref_seq, pileup, reads = simulate_pileup(
        ref_len=100, n_reads=100, read_len=30,
        snp_positions=snps, error_rate=0.005)
    for call in call_variants(pileup, ref_seq):
        print(f"  pos={call['pos']:3d}  {call['ref']}->{call['alt']}  "
              f"AF={call['af']:.2f}  depth={call['depth']}")
