#!/usr/bin/env Rscript
# run_deseq2.R — the real differential-expression step for Module 6.
#
# This is what a bioinformatician actually runs: DESeq2 (R/Bioconductor) reads a
# gene-by-sample count matrix, fits a negative-binomial GLM per gene, shrinks
# dispersions across genes, runs a Wald test, and writes a results table.
#
# Run it on your own machine (DESeq2 is not in the browser kernel):
#   Rscript run_deseq2.R synthetic_counts.tsv
#
# The workshop ships the output of this script (deseq2_results.tsv,
# deseq2_normalized_counts.tsv) so the notebook can load and interpret it.

suppressMessages(library(DESeq2))

# --- command line: optional path to the count table ---
# With no argument, synthetic_counts.tsv is looked up in the working directory.
args <- commandArgs(trailingOnly = TRUE)
counts_file <- "synthetic_counts.tsv"
if (length(args) >= 1) {
  counts_file <- args[1]
}

# --- read the count matrix (genes in rows, samples in columns) ---
# row.names = 1 takes the first column as the row names; check.names = FALSE
# keeps the sample column names exactly as they are written in the file.
counts <- as.matrix(
  read.delim(counts_file, row.names = 1, check.names = FALSE)
)

# --- sample table: which columns are control vs treated ---
# Columns whose name starts with "ctrl" are controls; every other column is
# labelled "treat". "ctrl" is listed first so it is the reference level.
is_ctrl <- grepl("^ctrl", colnames(counts))

# Fail early, with a message about column naming, when the prefix rule leaves
# one of the two groups empty (e.g. controls named "Ctrl_1" or "control_1").
if (!any(is_ctrl) || all(is_ctrl)) {
  stop(sprintf(
    paste0("Need at least one control and one treated sample: found %d ",
           "column(s) starting with 'ctrl' and %d other column(s) in %s"),
    sum(is_ctrl), sum(!is_ctrl), counts_file
  ), call. = FALSE)
}

condition <- factor(ifelse(is_ctrl, "ctrl", "treat"),
                    levels = c("ctrl", "treat"))
coldata <- data.frame(row.names = colnames(counts), condition = condition)

# --- the DESeq2 workflow: three calls do the whole analysis ---
# 1. bundle the counts, the sample table and the design formula
dds <- DESeqDataSetFromMatrix(
  countData = counts,
  colData = coldata,
  design = ~ condition
)

# 2. size factors, dispersions, GLM, Wald test
dds <- DESeq(dds)

# 3. pull out the treat-vs-ctrl comparison ("treat" is the numerator)
treat_vs_ctrl <- c("condition", "treat", "ctrl")
res <- results(dds, contrast = treat_vs_ctrl)

# most significant first; order() puts genes with an NA p-value at the end
pvalue_order <- order(res$pvalue)
res <- res[pvalue_order, ]

# --- write the two tables the notebook loads ---
# Both files are tab-separated and unquoted, and carry the gene ID in an
# explicit "gene" column instead of as row names.
results_table <- data.frame(gene = rownames(res), as.data.frame(res))
write.table(results_table, "deseq2_results.tsv",
            sep = "\t", quote = FALSE, row.names = FALSE)

normed <- counts(dds, normalized = TRUE)
# check.names = FALSE keeps the sample names exactly as they appear in the
# input; by default data.frame() rewrites non-syntactic names (for example
# "ctrl-1" becomes "ctrl.1"), which would not match the names read from the
# count file.
normed_table <- data.frame(gene = rownames(normed), normed,
                           check.names = FALSE)
write.table(normed_table, "deseq2_normalized_counts.tsv",
            sep = "\t", quote = FALSE, row.names = FALSE)

# --- console summary ---
# A single pair of cutoffs drives both the headline count and DESeq2's own
# summary, so the two console reports describe the same FDR level.
padj_cutoff <- 0.05
lfc_cutoff <- 1

# Genes with an NA padj or log2FoldChange are not counted (na.rm = TRUE below).
is_hit <- res$padj < padj_cutoff & abs(res$log2FoldChange) >= lfc_cutoff
cat(sprintf("Tested %d genes; %d significant at padj < %g & |log2FC| >= %g\n",
            nrow(res),
            sum(is_hit, na.rm = TRUE),
            padj_cutoff,
            lfc_cutoff))

# Without an explicit alpha, summary() reports at the alpha stored by
# results() (0.1 unless set there), which would not match the line above.
summary(res, alpha = padj_cutoff)
