From 0d02fc65d46e9cab9e5b6a0e56b10e61475c6941 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Mon, 20 May 2024 16:54:22 -0500 Subject: [PATCH 01/35] first pass at visualization suite --- src/plot_output.py | 190 +++++++++++++++++++++ src/post_process.py | 406 ++++++++++++++++++++++++++++++++++++++++++++ visualize.py | 62 +++++++ 3 files changed, 658 insertions(+) create mode 100644 src/plot_output.py create mode 100644 src/post_process.py create mode 100644 visualize.py diff --git a/src/plot_output.py b/src/plot_output.py new file mode 100644 index 00000000..6909fd52 --- /dev/null +++ b/src/plot_output.py @@ -0,0 +1,190 @@ +import os +import matplotlib.pyplot as plt +import matplotlib.ticker as ticker +import numpy as np + +class PlotOutput: + def __init__(self, updated_gene_dict, gene_names, output_directory, reads_and_class=None, filter_transcripts=None, conditions=False, use_counts=False): + self.updated_gene_dict = updated_gene_dict + self.gene_names = gene_names + self.output_directory = output_directory + self.reads_and_class = reads_and_class + self.filter_transcripts = filter_transcripts + self.conditions = conditions + self.use_counts = use_counts + + # Create visualization subdirectory if it doesn't exist + self.visualization_dir = os.path.join(self.output_directory, "visualization") + os.makedirs(self.visualization_dir, exist_ok=True) + + def plot_transcript_map(self): + # Get the first condition's gene dictionary + first_condition = next(iter(self.updated_gene_dict)) + gene_dict = self.updated_gene_dict[first_condition] + + for gene_name in self.gene_names: + if gene_name in gene_dict: + gene_data = gene_dict[gene_name] + num_transcripts = len(gene_data['transcripts']) + plot_height = max(3, num_transcripts * 0.3) # Adjust the height dynamically + + fig, ax = plt.subplots(figsize=(12, plot_height)) # Adjust height dynamically + + if self.filter_transcripts is not None: + ax.set_title(f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']} with value over {self.filter_transcripts}") + else: + ax.set_title(f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']}") + + ax.set_xlabel("Chromosomal position") + ax.set_ylabel("Transcripts") + ax.set_yticks(range(num_transcripts)) + ax.set_yticklabels([f"{transcript_id}" for transcript_id in gene_data['transcripts'].keys()]) + + ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True)) # Ensure genomic positions are integers + ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{int(x)}')) # Format x-axis ticks as integers + + # Plot each transcript + for i, (transcript_id, transcript_info) in enumerate(gene_data['transcripts'].items()): + # Determine the direction based on the gene's strand information + direction_marker = '>' if gene_data['strand'] == '+' else '<' + + # Add a direction marker to indicate the direction of the transcript + marker_pos = transcript_info['end'] + 100 if gene_data['strand'] == '+' else transcript_info['start'] - 100 + ax.plot(marker_pos, i, marker=direction_marker, markersize=5, color="blue") + + # Draw the line for the whole transcript + ax.plot([transcript_info['start'], transcript_info['end']], [i, i], color="grey", linewidth=2) + + # Exon blocks + for exon in transcript_info['exons']: + exon_length = exon['end'] - exon['start'] + ax.add_patch(plt.Rectangle((exon['start'], i - 0.4), exon_length, 0.8, color="skyblue")) + + ax.set_xlim(gene_data['start'], gene_data['end']) + ax.invert_yaxis() # First transcript at the top + + 
plt.tight_layout()
+                plot_path = os.path.join(self.visualization_dir, f'{gene_name}_splicing.png')
+                plt.savefig(plot_path)  # Saving plot by gene name
+                plt.close(fig)
+
+    def plot_transcript_usage(self):
+        """
+        Visualize transcript usage for each gene in gene_names across different conditions.
+        """
+
+        for gene_name in self.gene_names:
+            gene_data = {}
+            for condition, genes in self.updated_gene_dict.items():
+                if gene_name in genes:
+                    gene_data[condition] = genes[gene_name]['transcripts']
+
+            if not gene_data:
+                print(f"Gene {gene_name} not found in the data.")
+                continue
+
+            conditions = list(gene_data.keys())
+            n_bars = len(conditions)
+
+            fig, ax = plt.subplots(figsize=(12, 8))
+            index = np.arange(n_bars)
+            bar_width = 0.35
+            opacity = 0.8
+
+            for sample_type, transcripts in gene_data.items():
+                print(f"Sample Type: {sample_type}")
+                for transcript_id, transcript_info in transcripts.items():
+                    print(f"  Transcript ID: {transcript_id}, Value: {transcript_info['value']}")
+
+            # Adjusting the colors for better within-bar comparison
+            max_transcripts = max(len(gene_data[condition]) for condition in conditions)
+            colors = plt.cm.plasma(np.linspace(0, 1, num=max_transcripts))  # Using plasma for better color gradation
+
+            bottom_val = np.zeros(n_bars)
+            for i, condition in enumerate(conditions):
+                transcripts = gene_data[condition]
+                for j, (transcript_id, transcript_info) in enumerate(transcripts.items()):
+                    color = colors[j % len(colors)]
+                    value = transcript_info['value']
+                    plt.bar(i, float(value), bar_width, bottom=bottom_val[i], alpha=opacity, color=color, label=transcript_id if i == 0 else "")
+                    bottom_val[i] += float(value)
+
+            plt.xlabel('Sample Type')
+            plt.ylabel('Transcript Usage (TPM)')
+            plt.title(f'Transcript Usage for {gene_name} by Sample Type')
+            plt.xticks(index, conditions)
+            plt.legend(title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc='upper left')
+
+            plt.tight_layout()
+            plot_path = os.path.join(self.visualization_dir, f'{gene_name}_transcript_usage_by_sample_type.png')
+            plt.savefig(plot_path)
+            plt.close(fig)
+
+
+def make_pie_chart():
+    data = {
+        "ambiguous": 236646,
+        "inconsistent": 1212565,
+        "intergenic": 6886,
+        "noninformative": 130493,
+        "unique": 745194,
+        "unique_minor_difference": 79178,
+    }
+    labels = data.keys()
+    sizes = data.values()
+    total = sum(sizes)
+    print(total)
+    plt.pie(sizes, labels=labels, autopct='%1.1f%%')
+    plt.axis('equal')
+    plt.title(f"Total: {total}")
+    # Save before show(): with most backends, show() closes the figure, so a
+    # later savefig() would write an empty image.
+    plt.savefig('read_assignment_pie_chart.png')
+    plt.show()
+
+
+def visualize_transcript_usage_single_gene(gene_data, gene_name):
+    """
+    Plot stacked transcript-usage bars for a single gene across sample types.
+
+    :param gene_data: Dict containing transcript usage data for a given gene across different sample types.
+    :param gene_name: The gene to visualize.
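+
+    A minimal sketch of the expected shape, with hypothetical gene/transcript IDs
+    and TPM values (each sample type maps to (transcript_id, value) pairs, matching
+    the unpacking in the loop below)::
+
+        gene_data = {
+            "BRCA1": {
+                "tumor": [("ENST0001", 12.3), ("ENST0002", 4.5)],
+                "normal": [("ENST0001", 8.1), ("ENST0002", 9.9)],
+            },
+        }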
+ """ + + if gene_name not in gene_data: + print(f"Gene {gene_name} not found in the data.") + return + + sample_types = gene_data[gene_name].keys() + n_bars = len(sample_types) + + fig, ax = plt.subplots(figsize=(10, 7)) + index = np.arange(n_bars) + bar_width = 0.35 + opacity = 0.8 + + # Adjusting the colors for better within-bar comparison + max_transcripts = max(len(gene_data[gene_name][sample]) for sample in sample_types) + colors = plt.cm.plasma(np.linspace(0, 1, num=max_transcripts)) # Using plasma for better color gradation + + bottom_val = np.zeros(n_bars) + for i, sample_type in enumerate(sample_types): + transcripts = gene_data[gene_name][sample_type] + for j, (transcript_id, value) in enumerate(transcripts): + color = colors[j % len(colors)] + plt.bar(i, float(value), bar_width, bottom=bottom_val[i], alpha=opacity, color=color, label=transcript_id if i == 0 else "") + bottom_val[i] += float(value) + + plt.xlabel('Sample Type') + plt.ylabel('Transcript Usage (TPM)') + plt.title(f'Transcript Usage for {gene_name} by Sample Type') + plt.xticks(index, sample_types) + plt.legend(title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc='upper left') + + plt.tight_layout() + plt.show() + plt.savefig(f'{gene_name}_transcript_usage_by_sample_type_ref.png') \ No newline at end of file diff --git a/src/post_process.py b/src/post_process.py new file mode 100644 index 00000000..22ac2131 --- /dev/null +++ b/src/post_process.py @@ -0,0 +1,406 @@ +import csv +import os +import re +import gzip +import shutil +import copy + + +class OutputConfig: + """Class to build dictionaries from the output files of the pipeline.""" + def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): + self.output_directory = output_directory + self.log_details = {} + self.extended_annotation = None + self.read_assignments = None + self.input_gtf = gtf # Initialize with the provided gtf flag + self.gtf_flag_needed = False # Initialize flag to check if "--gtf" is needed. + self.conditions = False + self.gene_grouped_counts = None + self.transcript_grouped_counts = None + self.transcript_grouped_tpm = None + self.gene_grouped_tpm = None + self.gene_counts = None + self.transcript_counts = None + self.gene_tpm = None + self.transcript_tpm = None + self.transcript_model_counts = None + self.transcript_model_tpm = None + self.transcript_model_grouped_tpm = None + self.transcript_model_grouped_counts = None + self.use_counts = use_counts + self.ref_only = ref_only + + self._parse_isoquant_log() # Always parse the log + self._find_files() + self._conditional_unzip() + + # Ensure input_gtf is provided if ref_only is set and input_gtf is not found in the log + if self.ref_only and not self.input_gtf: + raise ValueError("Input GTF file is required when ref_only is set. 
Please provide it using the --gtf flag.") + + def _parse_isoquant_log(self): + """Parse the isoquant.log for necessary configuration and commands.""" + log_path = os.path.join(self.output_directory, "isoquant.log") + if os.path.exists(log_path): + with open(log_path, 'r') as file: + log_content = file.read() + gene_db_match = re.search(r"--genedb (\S+)", log_content) + fastq_flag = "--fastq" in log_content + processing_sample_match = re.search(r"Processing sample (\S+)", log_content) + if gene_db_match and not self.input_gtf: + self.input_gtf = gene_db_match.group(1) + self.log_details['gene_db'] = self.input_gtf + self.log_details['fastq_used'] = fastq_flag + + if processing_sample_match: + self.output_directory = os.path.join(self.output_directory, processing_sample_match.group(1)) + else: + raise ValueError("Processing sample directory not found in log.") + + def _conditional_unzip(self): + """Check if unzip is needed and perform it conditionally based on the model use.""" + if self.ref_only and self.input_gtf and self.input_gtf.endswith('.gz'): + self.input_gtf = self._unzip_file(self.input_gtf) + if not self.input_gtf: + raise FileNotFoundError(f"Unable to find or unzip the specified file: {self.input_gtf}") + + def _unzip_file(self, file_path): + """Unzip a gzipped file and return the path to the uncompressed file.""" + new_path = file_path[:-3] # Remove .gz extension + + if os.path.exists(new_path): + print(f"File {new_path} already exists, using this file.") + return new_path + + if not os.path.exists(file_path): + self.gtf_flag_needed = True + return None + + with gzip.open(file_path, 'rb') as f_in: + with open(new_path, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"File {file_path} was decompressed to {new_path}.") + + return new_path + + def _find_files(self): + """Locate the necessary files in the directory and determine the need for the "--gtf" flag.""" + if not os.path.exists(self.output_directory): + print(f"Directory not found: {self.output_directory}") # Debugging output + raise FileNotFoundError("Specified sample subdirectory does not exist.") + + for file_name in os.listdir(self.output_directory): + if file_name.endswith('.extended_annotation.gtf'): + self.extended_annotation = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.read_assignments.tsv'): + self.read_assignments = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.gene_grouped_counts.tsv'): + self.conditions = True + self.gene_grouped_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.transcript_grouped_counts.tsv'): + self.transcript_grouped_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.transcript_grouped_tpm.tsv'): + self.transcript_grouped_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.gene_grouped_tpm.tsv'): + self.gene_grouped_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.gene_counts.tsv'): + self.gene_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.transcript_counts.tsv'): + self.transcript_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.gene_tpm.tsv'): + self.gene_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.transcript_tpm.tsv'): + self.transcript_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.transcript_model_counts.tsv'): + self.transcript_model_counts = 
os.path.join(self.output_directory, file_name) + elif file_name.endswith('.transcript_model_tpm.tsv'): + self.transcript_model_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.transcript_model_grouped_tpm.tsv'): + self.transcript_model_grouped_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith('.transcript_model_grouped_counts.tsv'): + self.transcript_model_grouped_counts = os.path.join(self.output_directory, file_name) + + # Determine if GTF flag is needed + if not self.input_gtf or not os.path.exists(self.input_gtf) and not os.path.exists(self.input_gtf + '.gz') and self.ref_only: + self.gtf_flag_needed = True + + # Set ref_only default based on the availability of extended_annotation + if self.ref_only is None: + self.ref_only = not self.extended_annotation + + +class DictionaryBuilder: + """Class to build dictionaries from the output files of the pipeline.""" + def __init__(self, config): + self.config = config + + def build_gene_transcript_exon_dictionaries(self): + """Builds dictionaries of genes, transcripts, and exons from the GTF file.""" + if self.config.extended_annotation and not self.config.ref_only: + return self.parse_extended_annotation() + else: + return self.parse_input_gtf() + + def build_read_assignment_and_classification_dictionaries(self): + """Indexes classifications and assignment types from the read_assignments.tsv.""" + classification_counts = {} + assignment_type_counts = {} + if not self.config.read_assignments: + raise FileNotFoundError("Read assignments file is missing.") + + with open(self.config.read_assignments, 'r') as file: + next(file) + next(file) + next(file) + for line in file: + parts = line.split('\t') + if len(parts) < 6: + continue + additional_info = parts[-1] + classification = additional_info.split('Classification=')[-1].replace(';', '').strip() + assignment_type = parts[5] + + classification_counts[classification] = classification_counts.get(classification, 0) + 1 + assignment_type_counts[assignment_type] = assignment_type_counts.get(assignment_type, 0) + 1 + + return classification_counts, assignment_type_counts + + def parse_input_gtf(self): + """Parses the GTF file to build a detailed dictionary of genes, transcripts, and exons, + while skipping entries that are not genes, transcripts, or exons.""" + gene_dict = {} + if not self.config.input_gtf: + raise FileNotFoundError("Extended annotation GTF file is missing.") + + input_gtf_path = self.config.input_gtf + + try: + # Try opening the file as a regular text file first + file = open(input_gtf_path, 'r') + except FileNotFoundError: + raise FileNotFoundError(f"Extended annotation GTF file is missing at {input_gtf_path}.") + except OSError: + # If it fails, assume it's likely gzipped and try opening it with gzip + try: + file = gzip.open(input_gtf_path, 'rt') + except FileNotFoundError: + input_gtf_path = input_gtf_path.rstrip('.gz') + try: + file = open(input_gtf_path, 'r') + except FileNotFoundError: + raise FileNotFoundError(f"Extended annotation GTF file is missing at {input_gtf_path}.") + + with file: + for line in file: + if line.startswith("#") or not line.strip(): + continue + fields = line.strip().split('\t') + if len(fields) < 9: + print(f"Skipping malformed line due to insufficient fields: {line.strip()}") + continue + + entry_type = fields[2].lower() + if entry_type not in {"gene", "transcript", "exon"}: + continue # Skip types like CDS, start_codon, etc. 
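+
+                # The split-on-space parsing below assumes attribute values such as
+                # gene_name never contain whitespace. A regex on the quoted GTF
+                # attribute syntax would be more forgiving; a possible alternative
+                # (sketch only, using the already-imported re module):
+                #   details = dict(re.findall(r'(\S+) "([^"]*)"', fields[8]))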
+ + info_fields = fields[8].strip(';').split(';') + details = {field.strip().split(' ')[0]: field.strip().split(' ')[1].strip('"') for field in info_fields if ' ' in field} + + try: + if entry_type == "gene": + gene_id = details['gene_id'] + gene_dict[gene_id] = { + 'chromosome': fields[0], + 'start': int(fields[3]), + 'end': int(fields[4]), + 'strand': fields[6], + 'name': details.get('gene_name', ''), + 'biotype': details.get('gene_biotype', ''), + 'transcripts': {} + } + elif entry_type == "transcript": + transcript_id = details['transcript_id'] + gene_dict[details['gene_id']]['transcripts'][transcript_id] = { + 'start': int(fields[3]), + 'end': int(fields[4]), + 'name': details.get('transcript_name', ''), + 'biotype': details.get('transcript_biotype', ''), + 'exons': [], + 'tags': details.get('tag', '').split(',') + } + elif entry_type == "exon": + transcript_id = details['transcript_id'] + exon_info = { + 'exon_id': details['exon_id'], + 'start': int(fields[3]), + 'end': int(fields[4]), + 'number': details.get('exon_number', '') + } + gene_dict[details['gene_id']]['transcripts'][transcript_id]['exons'].append(exon_info) + except KeyError as e: + print(f"Key error in line: {line.strip()} | Missing key: {e}") + return gene_dict + + def parse_extended_annotation(self): + """Parses the GTF file to build a detailed dictionary of genes, transcripts, and exons.""" + gene_dict = {} + if not self.config.extended_annotation: + raise FileNotFoundError("Extended annotation GTF file is missing.") + + with open(self.config.extended_annotation, 'r') as file: + for line in file: + if line.startswith("#") or not line.strip(): + continue + fields = line.strip().split('\t') + if len(fields) < 9: + print(f"Skipping malformed line due to insufficient fields: {line.strip()}") + continue + + info_fields = fields[8].strip(';').split(';') + details = {field.strip().split(' ')[0]: field.strip().split(' ')[1].strip('"') for field in info_fields if ' ' in field} + + try: + if fields[2] == "gene": + gene_id = details['gene_id'] + gene_dict[gene_id] = { + 'chromosome': fields[0], + 'start': int(fields[3]), + 'end': int(fields[4]), + 'strand': fields[6], + 'name': details.get('gene_name', ''), + 'biotype': details.get('gene_biotype', ''), + 'transcripts': {} + } + elif fields[2] == "transcript": + transcript_id = details['transcript_id'] + gene_dict[details['gene_id']]['transcripts'][transcript_id] = { + 'start': int(fields[3]), + 'end': int(fields[4]), + 'exons': [] + } + elif fields[2] == "exon": + transcript_id = details['transcript_id'] + exon_info = { + 'exon_id': details['exon_id'], + 'start': int(fields[3]), + 'end': int(fields[4]) + } + gene_dict[details['gene_id']]['transcripts'][transcript_id]['exons'].append(exon_info) + except KeyError as e: + print(f"Key error in line: {line.strip()} | Missing key: {e}") + return gene_dict + + def update_gene_dict(self, gene_dict, value_df): + new_dict = {} + gene_values = {} + + # Read gene counts from value_df + with open(value_df, 'r') as file: + reader = csv.reader(file, delimiter='\t') + header = next(reader) + conditions = header[1:] # Assumes the first column is gene ID + + # Initialize gene_values dictionary + for row in reader: + gene_id = row[0] + gene_values[gene_id] = {} + for i, condition in enumerate(conditions): + if len(row) > i + 1: + value = float(row[i + 1]) + else: + value = 0.0 # Default to 0 if no value + gene_values[gene_id][condition] = value + + # Build the new dictionary structure by conditions + for condition in conditions: + 
new_dict[condition] = {} # Create a new sub-dictionary for each condition + + # Deep copy the gene_dict and update with values from value_df + for gene_id, gene_info in gene_dict.items(): + new_dict[condition][gene_id] = copy.deepcopy(gene_info) + if gene_id in gene_values and condition in gene_values[gene_id]: + new_dict[condition][gene_id]['value'] = gene_values[gene_id][condition] + else: + new_dict[condition][gene_id]['value'] = 0 # Default to 0 if the gene_id has no corresponding value + + return new_dict + + def update_transcript_values(self, gene_dict, value_df): + new_dict = copy.deepcopy(gene_dict) # Preserve the original structure + transcript_values = {} + + # Load transcript counts from value_df + with open(value_df, 'r') as file: + reader = csv.reader(file, delimiter='\t') + header = next(reader) + conditions = header[1:] # Assumes the first column is transcript ID + + for row in reader: + transcript_id = row[0] + for i, condition in enumerate(conditions): + if len(row) > i + 1: + value = float(row[i + 1]) + else: + value = 0.0 # Default to 0 if no value + if transcript_id not in transcript_values: + transcript_values[transcript_id] = {} + transcript_values[transcript_id][condition] = value + + # Update each condition without restructuring the original dictionary + for condition in conditions: + if condition not in new_dict: + new_dict[condition] = copy.deepcopy(gene_dict) # Make sure all genes are present + + for gene_id, gene_info in new_dict[condition].items(): + if 'transcripts' in gene_info: + for transcript_id, transcript_info in gene_info['transcripts'].items(): + if transcript_id in transcript_values and condition in transcript_values[transcript_id]: + transcript_info['value'] = transcript_values[transcript_id][condition] + else: + transcript_info['value'] = 0 # Set default if no value for this transcript + return new_dict + + def update_gene_names(self, gene_dict): + updated_dict = {} + for condition, genes in gene_dict.items(): + updated_genes = {} + for gene_id, gene_info in genes.items(): + gene_name_upper = gene_info['name'].upper() + updated_genes[gene_name_upper] = gene_info + updated_dict[condition] = updated_genes + return updated_dict + + def filter_transcripts_by_minimum_value(self, gene_dict, min_value=1.0): + # Dictionary to hold genes and transcripts that meet the criteria + transcript_passes_threshold = {} + + # First pass: Determine which transcripts meet the minimum value requirement in any condition + for condition, genes in gene_dict.items(): + for gene_id, gene_info in genes.items(): + for transcript_id, transcript_info in gene_info['transcripts'].items(): + if 'value' in transcript_info and transcript_info['value'] != 'NA' and float(transcript_info['value']) >= min_value: + if gene_id not in transcript_passes_threshold: + transcript_passes_threshold[gene_id] = {} + transcript_passes_threshold[gene_id][transcript_id] = True + + # Second pass: Build the filtered dictionary including only transcripts that have eligible values in any condition + filtered_dict = {} + for condition, genes in gene_dict.items(): + filtered_genes = {} + for gene_id, gene_info in genes.items(): + if gene_id in transcript_passes_threshold: + eligible_transcripts = {transcript_id: transcript_info for transcript_id, transcript_info in gene_info['transcripts'].items() if transcript_id in transcript_passes_threshold[gene_id]} + if eligible_transcripts: # Only add genes with non-empty transcript sets + filtered_gene_info = copy.deepcopy(gene_info) + 
filtered_gene_info['transcripts'] = eligible_transcripts + filtered_genes[gene_id] = filtered_gene_info + if filtered_genes: # Only add conditions with non-empty gene sets + filtered_dict[condition] = filtered_genes + + return filtered_dict + + def read_gene_list(self, gene_list_path): + with open(gene_list_path, 'r') as file: + gene_list = [line.strip().upper() for line in file] # Convert each gene to uppercase + return gene_list diff --git a/visualize.py b/visualize.py new file mode 100644 index 00000000..e5f32a9e --- /dev/null +++ b/visualize.py @@ -0,0 +1,62 @@ +from src.post_process import OutputConfig, DictionaryBuilder +from src.plot_output import PlotOutput +import argparse + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Visualize your IsoQuant output.") + parser.add_argument("output_directory", type=str, help="Directory containing IsoQuant output files.") + parser.add_argument("--gtf", type=str, help="Optional path to a GTF file if unable to be extracted from IsoQuant log", default=None) + parser.add_argument("--counts", action="store_true", help="Use counts instead of TPM files.") + parser.add_argument("--ref_only", action="store_true", help="Use only reference transcript quantification instead of transcript model quantification.") + parser.add_argument("--filter_transcripts", type=float, help="Filter transcripts by minimum value occuring in at least one condition.", default=None) + parser.add_argument("--gene_list", type=str, required=True, help="Path to a .txt file containing a list of genes, each on its own line.") + return parser.parse_args() + + +def main(): + args = parse_arguments() + output = OutputConfig(args.output_directory, use_counts=args.counts, ref_only=args.ref_only, gtf=args.gtf) + dictionary_builder = DictionaryBuilder(output) + gene_list = dictionary_builder.read_gene_list(args.gene_list) + update_names = not all(gene.startswith("ENS") for gene in gene_list) + gene_dict = dictionary_builder.build_gene_transcript_exon_dictionaries() + reads_and_class = dictionary_builder.build_read_assignment_and_classification_dictionaries() + + if output.conditions: + gene_file = output.gene_grouped_tpm if not output.use_counts else output.gene_grouped_counts + else: + gene_file = output.gene_tpm if not output.use_counts else output.gene_counts + + updated_gene_dict = dictionary_builder.update_gene_dict(gene_dict, gene_file) + + if update_names: + print("Updating gene names to gene symbols.") + updated_gene_dict = dictionary_builder.update_gene_names(updated_gene_dict) + + if output.ref_only or not output.extended_annotation: + print("Using reference-only based quantification.") + if output.conditions: + updated_gene_dict = dictionary_builder.update_transcript_values(updated_gene_dict, output.transcript_grouped_tpm if not output.use_counts else output.transcript_grouped_counts) + else: + updated_gene_dict = dictionary_builder.update_transcript_values(updated_gene_dict, output.transcript_tpm if not output.use_counts else output.transcript_counts) + else: + print("Using transcript model quantification.") + if output.conditions: + updated_gene_dict = dictionary_builder.update_transcript_values(updated_gene_dict, output.transcript_model_grouped_tpm if not output.use_counts else output.transcript_model_grouped_counts) + else: + updated_gene_dict = dictionary_builder.update_transcript_values(updated_gene_dict, output.transcript_model_tpm if not output.use_counts else output.transcript_model_counts) + + if args.filter_transcripts is not None: + 
print(f"Filtering transcripts with minimum value {args.filter_transcripts} in at least one condition.") + updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value(updated_gene_dict, min_value=args.filter_transcripts) + else: + updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value(updated_gene_dict) + + plot_output = PlotOutput(updated_gene_dict, gene_list, args.output_directory, reads_and_class, filter_transcripts=args.filter_transcripts, conditions=output.conditions, use_counts=args.counts) + plot_output.plot_transcript_map() + plot_output.plot_transcript_usage() + + +if __name__ == "__main__": + main() From e9ff3adb15af08dd50f74eb9caf4f95bf5dd6c4b Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Wed, 17 Jul 2024 17:16:55 -0500 Subject: [PATCH 02/35] save gene dict, added output dir, and made viz executable --- src/plot_output.py | 207 +++++++++++++++++++--------- src/post_process.py | 327 +++++++++++++++++++++++++++++--------------- visualize.py | 121 +++++++++++++--- 3 files changed, 459 insertions(+), 196 deletions(-) diff --git a/src/plot_output.py b/src/plot_output.py index 6909fd52..13388f0c 100644 --- a/src/plot_output.py +++ b/src/plot_output.py @@ -3,8 +3,18 @@ import matplotlib.ticker as ticker import numpy as np + class PlotOutput: - def __init__(self, updated_gene_dict, gene_names, output_directory, reads_and_class=None, filter_transcripts=None, conditions=False, use_counts=False): + def __init__( + self, + updated_gene_dict, + gene_names, + output_directory, + reads_and_class=None, + filter_transcripts=None, + conditions=False, + use_counts=False, + ): self.updated_gene_dict = updated_gene_dict self.gene_names = gene_names self.output_directory = output_directory @@ -21,50 +31,93 @@ def plot_transcript_map(self): # Get the first condition's gene dictionary first_condition = next(iter(self.updated_gene_dict)) gene_dict = self.updated_gene_dict[first_condition] - + for gene_name in self.gene_names: if gene_name in gene_dict: gene_data = gene_dict[gene_name] - num_transcripts = len(gene_data['transcripts']) - plot_height = max(3, num_transcripts * 0.3) # Adjust the height dynamically + num_transcripts = len(gene_data["transcripts"]) + plot_height = max( + 3, num_transcripts * 0.3 + ) # Adjust the height dynamically - fig, ax = plt.subplots(figsize=(12, plot_height)) # Adjust height dynamically + fig, ax = plt.subplots( + figsize=(12, plot_height) + ) # Adjust height dynamically if self.filter_transcripts is not None: - ax.set_title(f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']} with value over {self.filter_transcripts}") + ax.set_title( + f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']} with value over {self.filter_transcripts}" + ) else: - ax.set_title(f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']}") - + ax.set_title( + f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']}" + ) + ax.set_xlabel("Chromosomal position") ax.set_ylabel("Transcripts") ax.set_yticks(range(num_transcripts)) - ax.set_yticklabels([f"{transcript_id}" for transcript_id in gene_data['transcripts'].keys()]) + ax.set_yticklabels( + [ + f"{transcript_id}" + for transcript_id in gene_data["transcripts"].keys() + ] + ) - ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True)) # Ensure genomic positions are integers - ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{int(x)}')) # Format x-axis ticks as integers + 
ax.xaxis.set_major_locator( + ticker.MaxNLocator(integer=True) + ) # Ensure genomic positions are integers + ax.xaxis.set_major_formatter( + ticker.FuncFormatter(lambda x, pos: f"{int(x)}") + ) # Format x-axis ticks as integers # Plot each transcript - for i, (transcript_id, transcript_info) in enumerate(gene_data['transcripts'].items()): + for i, (transcript_id, transcript_info) in enumerate( + gene_data["transcripts"].items() + ): # Determine the direction based on the gene's strand information - direction_marker = '>' if gene_data['strand'] == '+' else '<' + direction_marker = ">" if gene_data["strand"] == "+" else "<" # Add a direction marker to indicate the direction of the transcript - marker_pos = transcript_info['end'] + 100 if gene_data['strand'] == '+' else transcript_info['start'] - 100 - ax.plot(marker_pos, i, marker=direction_marker, markersize=5, color="blue") + marker_pos = ( + transcript_info["end"] + 100 + if gene_data["strand"] == "+" + else transcript_info["start"] - 100 + ) + ax.plot( + marker_pos, + i, + marker=direction_marker, + markersize=5, + color="blue", + ) # Draw the line for the whole transcript - ax.plot([transcript_info['start'], transcript_info['end']], [i, i], color="grey", linewidth=2) + ax.plot( + [transcript_info["start"], transcript_info["end"]], + [i, i], + color="grey", + linewidth=2, + ) # Exon blocks - for exon in transcript_info['exons']: - exon_length = exon['end'] - exon['start'] - ax.add_patch(plt.Rectangle((exon['start'], i - 0.4), exon_length, 0.8, color="skyblue")) + for exon in transcript_info["exons"]: + exon_length = exon["end"] - exon["start"] + ax.add_patch( + plt.Rectangle( + (exon["start"], i - 0.4), + exon_length, + 0.8, + color="skyblue", + ) + ) - ax.set_xlim(gene_data['start'], gene_data['end']) + ax.set_xlim(gene_data["start"], gene_data["end"]) ax.invert_yaxis() # First transcript at the top plt.tight_layout() - plot_path = os.path.join(self.visualization_dir, f'{gene_name}_splicing.png') + plot_path = os.path.join( + self.visualization_dir, f"{gene_name}_splicing.png" + ) plt.savefig(plot_path) # Saving plot by gene name plt.close(fig) @@ -72,56 +125,73 @@ def plot_transcript_usage(self): """ Visualize transcript usage for each gene in gene_names across different conditions. 
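+
+        A hypothetical example of the ``updated_gene_dict`` shape this method
+        expects (IDs and values are illustrative only)::
+
+            {"tumor": {"BRCA1": {"transcripts": {"ENST0001": {"value": 12.3}}}}}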
""" - + for gene_name in self.gene_names: gene_data = {} for condition, genes in self.updated_gene_dict.items(): if gene_name in genes: - gene_data[condition] = genes[gene_name]['transcripts'] + gene_data[condition] = genes[gene_name]["transcripts"] if not gene_data: print(f"Gene {gene_name} not found in the data.") continue - + conditions = list(gene_data.keys()) n_bars = len(conditions) - + fig, ax = plt.subplots(figsize=(12, 8)) index = np.arange(n_bars) bar_width = 0.35 opacity = 0.8 - + for sample_type, transcripts in gene_data.items(): print(f"Sample Type: {sample_type}") for transcript_id, transcript_info in transcripts.items(): - print(f" Transcript ID: {transcript_id}, Value: {transcript_info['value']}") + print( + f" Transcript ID: {transcript_id}, Value: {transcript_info['value']}" + ) # Adjusting the colors for better within-bar comparison max_transcripts = max(len(gene_data[condition]) for condition in conditions) - colors = plt.cm.plasma(np.linspace(0, 1, num=max_transcripts)) # Using plasma for better color gradation - - bottom_val = np.zeros(n_bars) + colors = plt.cm.plasma( + np.linspace(0, 1, num=max_transcripts) + ) # Using plasma for better color gradation + + bottom_val = np.zeros(n_bars) for i, condition in enumerate(conditions): transcripts = gene_data[condition] - for j, (transcript_id, transcript_info) in enumerate(transcripts.items()): + for j, (transcript_id, transcript_info) in enumerate( + transcripts.items() + ): color = colors[j % len(colors)] - value = transcript_info['value'] - plt.bar(i, float(value), bar_width, bottom=bottom_val[i], alpha=opacity, color=color, label=transcript_id if i == 0 else "") + value = transcript_info["value"] + plt.bar( + i, + float(value), + bar_width, + bottom=bottom_val[i], + alpha=opacity, + color=color, + label=transcript_id if i == 0 else "", + ) bottom_val[i] += float(value) - - plt.xlabel('Sample Type') - plt.ylabel('Transcript Usage (TPM)') - plt.title(f'Transcript Usage for {gene_name} by Sample Type') + + plt.xlabel("Sample Type") + plt.ylabel("Transcript Usage (TPM)") + plt.title(f"Transcript Usage for {gene_name} by Sample Type") plt.xticks(index, conditions) - plt.legend(title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc='upper left') - + plt.legend( + title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc="upper left" + ) + plt.tight_layout() - plot_path = os.path.join(self.visualization_dir, f'{gene_name}_transcript_usage_by_sample_type.png') + plot_path = os.path.join( + self.visualization_dir, + f"{gene_name}_transcript_usage_by_sample_type.png", + ) plt.savefig(plot_path) plt.close(fig) - - def make_pie_chart(): data = { @@ -131,21 +201,16 @@ def make_pie_chart(): "noninformative": 130493, "unique": 745194, "unique_minor_difference": 79178, - } + } labels = data.keys() sizes = data.values() total = sum(sizes) print(total) - plt.pie(sizes, labels=labels, autopct='%1.1f%%') - plt.axis('equal') + plt.pie(sizes, labels=labels, autopct="%1.1f%%") + plt.axis("equal") plt.title(f"Total: {total}") plt.show() - plt.savefig('read_assignment_pie_chart.png') - - - - - + plt.savefig("read_assignment_pie_chart.png") def visualize_transcript_usage_single_gene(gene_data, gene_name): @@ -154,37 +219,47 @@ def visualize_transcript_usage_single_gene(gene_data, gene_name): :param gene_data: Dict containing transcript usage data for a given gene across different sample types. :param gene_name: The gene to visualize. 
""" - + if gene_name not in gene_data: print(f"Gene {gene_name} not found in the data.") return - + sample_types = gene_data[gene_name].keys() n_bars = len(sample_types) - + fig, ax = plt.subplots(figsize=(10, 7)) index = np.arange(n_bars) bar_width = 0.35 opacity = 0.8 - + # Adjusting the colors for better within-bar comparison max_transcripts = max(len(gene_data[gene_name][sample]) for sample in sample_types) - colors = plt.cm.plasma(np.linspace(0, 1, num=max_transcripts)) # Using plasma for better color gradation - - bottom_val = np.zeros(n_bars) + colors = plt.cm.plasma( + np.linspace(0, 1, num=max_transcripts) + ) # Using plasma for better color gradation + + bottom_val = np.zeros(n_bars) for i, sample_type in enumerate(sample_types): transcripts = gene_data[gene_name][sample_type] for j, (transcript_id, value) in enumerate(transcripts): color = colors[j % len(colors)] - plt.bar(i, float(value), bar_width, bottom=bottom_val[i], alpha=opacity, color=color, label=transcript_id if i == 0 else "") + plt.bar( + i, + float(value), + bar_width, + bottom=bottom_val[i], + alpha=opacity, + color=color, + label=transcript_id if i == 0 else "", + ) bottom_val[i] += float(value) - - plt.xlabel('Sample Type') - plt.ylabel('Transcript Usage (TPM)') - plt.title(f'Transcript Usage for {gene_name} by Sample Type') + + plt.xlabel("Sample Type") + plt.ylabel("Transcript Usage (TPM)") + plt.title(f"Transcript Usage for {gene_name} by Sample Type") plt.xticks(index, sample_types) - plt.legend(title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc='upper left') - + plt.legend(title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc="upper left") + plt.tight_layout() plt.show() - plt.savefig(f'{gene_name}_transcript_usage_by_sample_type_ref.png') \ No newline at end of file + plt.savefig(f"{gene_name}_transcript_usage_by_sample_type_ref.png") diff --git a/src/post_process.py b/src/post_process.py index 22ac2131..0d94d691 100644 --- a/src/post_process.py +++ b/src/post_process.py @@ -4,10 +4,12 @@ import gzip import shutil import copy +import json class OutputConfig: """Class to build dictionaries from the output files of the pipeline.""" + def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): self.output_directory = output_directory self.log_details = {} @@ -37,33 +39,42 @@ def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): # Ensure input_gtf is provided if ref_only is set and input_gtf is not found in the log if self.ref_only and not self.input_gtf: - raise ValueError("Input GTF file is required when ref_only is set. Please provide it using the --gtf flag.") + raise ValueError( + "Input GTF file is required when ref_only is set. Please provide it using the --gtf flag." 
+ ) def _parse_isoquant_log(self): """Parse the isoquant.log for necessary configuration and commands.""" log_path = os.path.join(self.output_directory, "isoquant.log") + assert os.path.exists(log_path), f"Log file not found: {log_path}" if os.path.exists(log_path): - with open(log_path, 'r') as file: + with open(log_path, "r") as file: log_content = file.read() gene_db_match = re.search(r"--genedb (\S+)", log_content) fastq_flag = "--fastq" in log_content - processing_sample_match = re.search(r"Processing sample (\S+)", log_content) + processing_sample_match = re.search( + r"Processed experiment (\S+)", log_content + ) if gene_db_match and not self.input_gtf: self.input_gtf = gene_db_match.group(1) - self.log_details['gene_db'] = self.input_gtf - self.log_details['fastq_used'] = fastq_flag + self.log_details["gene_db"] = self.input_gtf + self.log_details["fastq_used"] = fastq_flag if processing_sample_match: - self.output_directory = os.path.join(self.output_directory, processing_sample_match.group(1)) + self.output_directory = os.path.join( + self.output_directory, processing_sample_match.group(1) + ) else: raise ValueError("Processing sample directory not found in log.") def _conditional_unzip(self): """Check if unzip is needed and perform it conditionally based on the model use.""" - if self.ref_only and self.input_gtf and self.input_gtf.endswith('.gz'): + if self.ref_only and self.input_gtf and self.input_gtf.endswith(".gz"): self.input_gtf = self._unzip_file(self.input_gtf) if not self.input_gtf: - raise FileNotFoundError(f"Unable to find or unzip the specified file: {self.input_gtf}") + raise FileNotFoundError( + f"Unable to find or unzip the specified file: {self.input_gtf}" + ) def _unzip_file(self, file_path): """Unzip a gzipped file and return the path to the uncompressed file.""" @@ -77,8 +88,8 @@ def _unzip_file(self, file_path): self.gtf_flag_needed = True return None - with gzip.open(file_path, 'rb') as f_in: - with open(new_path, 'wb') as f_out: + with gzip.open(file_path, "rb") as f_in: + with open(new_path, "wb") as f_out: shutil.copyfileobj(f_in, f_out) print(f"File {file_path} was decompressed to {new_path}.") @@ -91,38 +102,63 @@ def _find_files(self): raise FileNotFoundError("Specified sample subdirectory does not exist.") for file_name in os.listdir(self.output_directory): - if file_name.endswith('.extended_annotation.gtf'): - self.extended_annotation = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.read_assignments.tsv'): + if file_name.endswith(".extended_annotation.gtf"): + self.extended_annotation = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".read_assignments.tsv"): self.read_assignments = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.gene_grouped_counts.tsv'): + elif file_name.endswith(".read_assignments.tsv.gz"): + self.read_assignments = self._unzip_file( + os.path.join(self.output_directory, file_name) + ) + elif file_name.endswith(".gene_grouped_counts.tsv"): self.conditions = True - self.gene_grouped_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.transcript_grouped_counts.tsv'): - self.transcript_grouped_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.transcript_grouped_tpm.tsv'): - self.transcript_grouped_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.gene_grouped_tpm.tsv'): + self.gene_grouped_counts = os.path.join( + self.output_directory, file_name + ) + 
elif file_name.endswith(".transcript_grouped_counts.tsv"): + self.transcript_grouped_counts = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_grouped_tpm.tsv"): + self.transcript_grouped_tpm = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".gene_grouped_tpm.tsv"): self.gene_grouped_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.gene_counts.tsv'): + elif file_name.endswith(".gene_counts.tsv"): self.gene_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.transcript_counts.tsv'): + elif file_name.endswith(".transcript_counts.tsv"): self.transcript_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.gene_tpm.tsv'): + elif file_name.endswith(".gene_tpm.tsv"): self.gene_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.transcript_tpm.tsv'): + elif file_name.endswith(".transcript_tpm.tsv"): self.transcript_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.transcript_model_counts.tsv'): - self.transcript_model_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.transcript_model_tpm.tsv'): - self.transcript_model_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.transcript_model_grouped_tpm.tsv'): - self.transcript_model_grouped_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith('.transcript_model_grouped_counts.tsv'): - self.transcript_model_grouped_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".transcript_model_counts.tsv"): + self.transcript_model_counts = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_tpm.tsv"): + self.transcript_model_tpm = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_grouped_tpm.tsv"): + self.transcript_model_grouped_tpm = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_grouped_counts.tsv"): + self.transcript_model_grouped_counts = os.path.join( + self.output_directory, file_name + ) # Determine if GTF flag is needed - if not self.input_gtf or not os.path.exists(self.input_gtf) and not os.path.exists(self.input_gtf + '.gz') and self.ref_only: + if ( + not self.input_gtf + or not os.path.exists(self.input_gtf) + and not os.path.exists(self.input_gtf + ".gz") + and self.ref_only + ): self.gtf_flag_needed = True # Set ref_only default based on the availability of extended_annotation @@ -132,6 +168,7 @@ def _find_files(self): class DictionaryBuilder: """Class to build dictionaries from the output files of the pipeline.""" + def __init__(self, config): self.config = config @@ -149,20 +186,28 @@ def build_read_assignment_and_classification_dictionaries(self): if not self.config.read_assignments: raise FileNotFoundError("Read assignments file is missing.") - with open(self.config.read_assignments, 'r') as file: + with open(self.config.read_assignments, "r") as file: next(file) next(file) next(file) for line in file: - parts = line.split('\t') + parts = line.split("\t") if len(parts) < 6: continue additional_info = parts[-1] - classification = additional_info.split('Classification=')[-1].replace(';', '').strip() + classification = ( + additional_info.split("Classification=")[-1] + .replace(";", "") + .strip() + ) assignment_type = parts[5] - 
classification_counts[classification] = classification_counts.get(classification, 0) + 1 - assignment_type_counts[assignment_type] = assignment_type_counts.get(assignment_type, 0) + 1 + classification_counts[classification] = ( + classification_counts.get(classification, 0) + 1 + ) + assignment_type_counts[assignment_type] = ( + assignment_type_counts.get(assignment_type, 0) + 1 + ) return classification_counts, assignment_type_counts @@ -177,67 +222,79 @@ def parse_input_gtf(self): try: # Try opening the file as a regular text file first - file = open(input_gtf_path, 'r') + file = open(input_gtf_path, "r") except FileNotFoundError: - raise FileNotFoundError(f"Extended annotation GTF file is missing at {input_gtf_path}.") + raise FileNotFoundError( + f"Extended annotation GTF file is missing at {input_gtf_path}." + ) except OSError: # If it fails, assume it's likely gzipped and try opening it with gzip try: - file = gzip.open(input_gtf_path, 'rt') + file = gzip.open(input_gtf_path, "rt") except FileNotFoundError: - input_gtf_path = input_gtf_path.rstrip('.gz') + input_gtf_path = input_gtf_path.rstrip(".gz") try: - file = open(input_gtf_path, 'r') + file = open(input_gtf_path, "r") except FileNotFoundError: - raise FileNotFoundError(f"Extended annotation GTF file is missing at {input_gtf_path}.") + raise FileNotFoundError( + f"Extended annotation GTF file is missing at {input_gtf_path}." + ) with file: for line in file: if line.startswith("#") or not line.strip(): continue - fields = line.strip().split('\t') + fields = line.strip().split("\t") if len(fields) < 9: - print(f"Skipping malformed line due to insufficient fields: {line.strip()}") + print( + f"Skipping malformed line due to insufficient fields: {line.strip()}" + ) continue entry_type = fields[2].lower() if entry_type not in {"gene", "transcript", "exon"}: continue # Skip types like CDS, start_codon, etc. 
- info_fields = fields[8].strip(';').split(';') - details = {field.strip().split(' ')[0]: field.strip().split(' ')[1].strip('"') for field in info_fields if ' ' in field} + info_fields = fields[8].strip(";").split(";") + details = { + field.strip().split(" ")[0]: field.strip().split(" ")[1].strip('"') + for field in info_fields + if " " in field + } try: if entry_type == "gene": - gene_id = details['gene_id'] + gene_id = details["gene_id"] gene_dict[gene_id] = { - 'chromosome': fields[0], - 'start': int(fields[3]), - 'end': int(fields[4]), - 'strand': fields[6], - 'name': details.get('gene_name', ''), - 'biotype': details.get('gene_biotype', ''), - 'transcripts': {} + "chromosome": fields[0], + "start": int(fields[3]), + "end": int(fields[4]), + "strand": fields[6], + "name": details.get("gene_name", ""), + "biotype": details.get("gene_biotype", ""), + "transcripts": {}, } elif entry_type == "transcript": - transcript_id = details['transcript_id'] - gene_dict[details['gene_id']]['transcripts'][transcript_id] = { - 'start': int(fields[3]), - 'end': int(fields[4]), - 'name': details.get('transcript_name', ''), - 'biotype': details.get('transcript_biotype', ''), - 'exons': [], - 'tags': details.get('tag', '').split(',') + transcript_id = details["transcript_id"] + gene_dict[details["gene_id"]]["transcripts"][transcript_id] = { + "start": int(fields[3]), + "end": int(fields[4]), + "name": details.get("transcript_name", ""), + "biotype": details.get("transcript_biotype", ""), + "exons": [], + "tags": details.get("tag", "").split(","), } elif entry_type == "exon": - transcript_id = details['transcript_id'] + transcript_id = details["transcript_id"] exon_info = { - 'exon_id': details['exon_id'], - 'start': int(fields[3]), - 'end': int(fields[4]), - 'number': details.get('exon_number', '') + "exon_id": details["exon_id"], + "start": int(fields[3]), + "end": int(fields[4]), + "number": details.get("exon_number", ""), } - gene_dict[details['gene_id']]['transcripts'][transcript_id]['exons'].append(exon_info) + gene_dict[details["gene_id"]]["transcripts"][transcript_id][ + "exons" + ].append(exon_info) except KeyError as e: print(f"Key error in line: {line.strip()} | Missing key: {e}") return gene_dict @@ -248,45 +305,53 @@ def parse_extended_annotation(self): if not self.config.extended_annotation: raise FileNotFoundError("Extended annotation GTF file is missing.") - with open(self.config.extended_annotation, 'r') as file: + with open(self.config.extended_annotation, "r") as file: for line in file: if line.startswith("#") or not line.strip(): continue - fields = line.strip().split('\t') + fields = line.strip().split("\t") if len(fields) < 9: - print(f"Skipping malformed line due to insufficient fields: {line.strip()}") + print( + f"Skipping malformed line due to insufficient fields: {line.strip()}" + ) continue - info_fields = fields[8].strip(';').split(';') - details = {field.strip().split(' ')[0]: field.strip().split(' ')[1].strip('"') for field in info_fields if ' ' in field} + info_fields = fields[8].strip(";").split(";") + details = { + field.strip().split(" ")[0]: field.strip().split(" ")[1].strip('"') + for field in info_fields + if " " in field + } try: if fields[2] == "gene": - gene_id = details['gene_id'] + gene_id = details["gene_id"] gene_dict[gene_id] = { - 'chromosome': fields[0], - 'start': int(fields[3]), - 'end': int(fields[4]), - 'strand': fields[6], - 'name': details.get('gene_name', ''), - 'biotype': details.get('gene_biotype', ''), - 'transcripts': {} + "chromosome": fields[0], + 
"start": int(fields[3]), + "end": int(fields[4]), + "strand": fields[6], + "name": details.get("gene_name", ""), + "biotype": details.get("gene_biotype", ""), + "transcripts": {}, } elif fields[2] == "transcript": - transcript_id = details['transcript_id'] - gene_dict[details['gene_id']]['transcripts'][transcript_id] = { - 'start': int(fields[3]), - 'end': int(fields[4]), - 'exons': [] + transcript_id = details["transcript_id"] + gene_dict[details["gene_id"]]["transcripts"][transcript_id] = { + "start": int(fields[3]), + "end": int(fields[4]), + "exons": [], } elif fields[2] == "exon": - transcript_id = details['transcript_id'] + transcript_id = details["transcript_id"] exon_info = { - 'exon_id': details['exon_id'], - 'start': int(fields[3]), - 'end': int(fields[4]) + "exon_id": details["exon_id"], + "start": int(fields[3]), + "end": int(fields[4]), } - gene_dict[details['gene_id']]['transcripts'][transcript_id]['exons'].append(exon_info) + gene_dict[details["gene_id"]]["transcripts"][transcript_id][ + "exons" + ].append(exon_info) except KeyError as e: print(f"Key error in line: {line.strip()} | Missing key: {e}") return gene_dict @@ -296,8 +361,8 @@ def update_gene_dict(self, gene_dict, value_df): gene_values = {} # Read gene counts from value_df - with open(value_df, 'r') as file: - reader = csv.reader(file, delimiter='\t') + with open(value_df, "r") as file: + reader = csv.reader(file, delimiter="\t") header = next(reader) conditions = header[1:] # Assumes the first column is gene ID @@ -320,9 +385,13 @@ def update_gene_dict(self, gene_dict, value_df): for gene_id, gene_info in gene_dict.items(): new_dict[condition][gene_id] = copy.deepcopy(gene_info) if gene_id in gene_values and condition in gene_values[gene_id]: - new_dict[condition][gene_id]['value'] = gene_values[gene_id][condition] + new_dict[condition][gene_id]["value"] = gene_values[gene_id][ + condition + ] else: - new_dict[condition][gene_id]['value'] = 0 # Default to 0 if the gene_id has no corresponding value + new_dict[condition][gene_id][ + "value" + ] = 0 # Default to 0 if the gene_id has no corresponding value return new_dict @@ -331,8 +400,8 @@ def update_transcript_values(self, gene_dict, value_df): transcript_values = {} # Load transcript counts from value_df - with open(value_df, 'r') as file: - reader = csv.reader(file, delimiter='\t') + with open(value_df, "r") as file: + reader = csv.reader(file, delimiter="\t") header = next(reader) conditions = header[1:] # Assumes the first column is transcript ID @@ -350,15 +419,26 @@ def update_transcript_values(self, gene_dict, value_df): # Update each condition without restructuring the original dictionary for condition in conditions: if condition not in new_dict: - new_dict[condition] = copy.deepcopy(gene_dict) # Make sure all genes are present + new_dict[condition] = copy.deepcopy( + gene_dict + ) # Make sure all genes are present for gene_id, gene_info in new_dict[condition].items(): - if 'transcripts' in gene_info: - for transcript_id, transcript_info in gene_info['transcripts'].items(): - if transcript_id in transcript_values and condition in transcript_values[transcript_id]: - transcript_info['value'] = transcript_values[transcript_id][condition] + if "transcripts" in gene_info: + for transcript_id, transcript_info in gene_info[ + "transcripts" + ].items(): + if ( + transcript_id in transcript_values + and condition in transcript_values[transcript_id] + ): + transcript_info["value"] = transcript_values[transcript_id][ + condition + ] else: - transcript_info['value'] 
= 0 # Set default if no value for this transcript + transcript_info["value"] = ( + 0 # Set default if no value for this transcript + ) return new_dict def update_gene_names(self, gene_dict): @@ -366,7 +446,7 @@ def update_gene_names(self, gene_dict): for condition, genes in gene_dict.items(): updated_genes = {} for gene_id, gene_info in genes.items(): - gene_name_upper = gene_info['name'].upper() + gene_name_upper = gene_info["name"].upper() updated_genes[gene_name_upper] = gene_info updated_dict[condition] = updated_genes return updated_dict @@ -378,8 +458,12 @@ def filter_transcripts_by_minimum_value(self, gene_dict, min_value=1.0): # First pass: Determine which transcripts meet the minimum value requirement in any condition for condition, genes in gene_dict.items(): for gene_id, gene_info in genes.items(): - for transcript_id, transcript_info in gene_info['transcripts'].items(): - if 'value' in transcript_info and transcript_info['value'] != 'NA' and float(transcript_info['value']) >= min_value: + for transcript_id, transcript_info in gene_info["transcripts"].items(): + if ( + "value" in transcript_info + and transcript_info["value"] != "NA" + and float(transcript_info["value"]) >= min_value + ): if gene_id not in transcript_passes_threshold: transcript_passes_threshold[gene_id] = {} transcript_passes_threshold[gene_id][transcript_id] = True @@ -390,10 +474,18 @@ def filter_transcripts_by_minimum_value(self, gene_dict, min_value=1.0): filtered_genes = {} for gene_id, gene_info in genes.items(): if gene_id in transcript_passes_threshold: - eligible_transcripts = {transcript_id: transcript_info for transcript_id, transcript_info in gene_info['transcripts'].items() if transcript_id in transcript_passes_threshold[gene_id]} - if eligible_transcripts: # Only add genes with non-empty transcript sets + eligible_transcripts = { + transcript_id: transcript_info + for transcript_id, transcript_info in gene_info[ + "transcripts" + ].items() + if transcript_id in transcript_passes_threshold[gene_id] + } + if ( + eligible_transcripts + ): # Only add genes with non-empty transcript sets filtered_gene_info = copy.deepcopy(gene_info) - filtered_gene_info['transcripts'] = eligible_transcripts + filtered_gene_info["transcripts"] = eligible_transcripts filtered_genes[gene_id] = filtered_gene_info if filtered_genes: # Only add conditions with non-empty gene sets filtered_dict[condition] = filtered_genes @@ -401,6 +493,15 @@ def filter_transcripts_by_minimum_value(self, gene_dict, min_value=1.0): return filtered_dict def read_gene_list(self, gene_list_path): - with open(gene_list_path, 'r') as file: - gene_list = [line.strip().upper() for line in file] # Convert each gene to uppercase + with open(gene_list_path, "r") as file: + gene_list = [ + line.strip().upper() for line in file + ] # Convert each gene to uppercase return gene_list + + def save_gene_dict_to_json(self, gene_dict, output_path): + """Saves the gene dictionary to a JSON file.""" + # name the gene_dict file + output_path = os.path.join(self.config.output_directory, "gene_dict.json") + with open(output_path, "w") as file: + json.dump(gene_dict, file, indent=4) diff --git a/visualize.py b/visualize.py index e5f32a9e..97e56b28 100644 --- a/visualize.py +++ b/visualize.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + from src.post_process import OutputConfig, DictionaryBuilder from src.plot_output import PlotOutput import argparse @@ -5,26 +7,66 @@ def parse_arguments(): parser = argparse.ArgumentParser(description="Visualize your IsoQuant output.") - 
parser.add_argument("output_directory", type=str, help="Directory containing IsoQuant output files.") - parser.add_argument("--gtf", type=str, help="Optional path to a GTF file if unable to be extracted from IsoQuant log", default=None) - parser.add_argument("--counts", action="store_true", help="Use counts instead of TPM files.") - parser.add_argument("--ref_only", action="store_true", help="Use only reference transcript quantification instead of transcript model quantification.") - parser.add_argument("--filter_transcripts", type=float, help="Filter transcripts by minimum value occuring in at least one condition.", default=None) - parser.add_argument("--gene_list", type=str, required=True, help="Path to a .txt file containing a list of genes, each on its own line.") + parser.add_argument( + "output_directory", type=str, help="Directory containing IsoQuant output files." + ) + parser.add_argument( + "--viz_output", + type=str, + help="Optional directory to save visualization output files, defaults to the main output directory.", + default=None, + ) + parser.add_argument( + "--gtf", + type=str, + help="Optional path to a GTF file if unable to be extracted from IsoQuant log", + default=None, + ) + parser.add_argument( + "--counts", action="store_true", help="Use counts instead of TPM files." + ) + parser.add_argument( + "--ref_only", + action="store_true", + help="Use only reference transcript quantification instead of transcript model quantification.", + ) + parser.add_argument( + "--filter_transcripts", + type=float, + help="Filter transcripts by minimum value occuring in at least one condition.", + default=None, + ) + parser.add_argument( + "--gene_list", + type=str, + required=True, + help="Path to a .txt file containing a list of genes, each on its own line.", + ) return parser.parse_args() def main(): args = parse_arguments() - output = OutputConfig(args.output_directory, use_counts=args.counts, ref_only=args.ref_only, gtf=args.gtf) + output = OutputConfig( + args.output_directory, + use_counts=args.counts, + ref_only=args.ref_only, + gtf=args.gtf, + ) dictionary_builder = DictionaryBuilder(output) gene_list = dictionary_builder.read_gene_list(args.gene_list) update_names = not all(gene.startswith("ENS") for gene in gene_list) gene_dict = dictionary_builder.build_gene_transcript_exon_dictionaries() - reads_and_class = dictionary_builder.build_read_assignment_and_classification_dictionaries() + reads_and_class = ( + dictionary_builder.build_read_assignment_and_classification_dictionaries() + ) if output.conditions: - gene_file = output.gene_grouped_tpm if not output.use_counts else output.gene_grouped_counts + gene_file = ( + output.gene_grouped_tpm + if not output.use_counts + else output.gene_grouped_counts + ) else: gene_file = output.gene_tpm if not output.use_counts else output.gene_counts @@ -37,23 +79,68 @@ def main(): if output.ref_only or not output.extended_annotation: print("Using reference-only based quantification.") if output.conditions: - updated_gene_dict = dictionary_builder.update_transcript_values(updated_gene_dict, output.transcript_grouped_tpm if not output.use_counts else output.transcript_grouped_counts) + updated_gene_dict = dictionary_builder.update_transcript_values( + updated_gene_dict, + ( + output.transcript_grouped_tpm + if not output.use_counts + else output.transcript_grouped_counts + ), + ) else: - updated_gene_dict = dictionary_builder.update_transcript_values(updated_gene_dict, output.transcript_tpm if not output.use_counts else 
output.transcript_counts) + updated_gene_dict = dictionary_builder.update_transcript_values( + updated_gene_dict, + ( + output.transcript_tpm + if not output.use_counts + else output.transcript_counts + ), + ) else: print("Using transcript model quantification.") if output.conditions: - updated_gene_dict = dictionary_builder.update_transcript_values(updated_gene_dict, output.transcript_model_grouped_tpm if not output.use_counts else output.transcript_model_grouped_counts) + updated_gene_dict = dictionary_builder.update_transcript_values( + updated_gene_dict, + ( + output.transcript_model_grouped_tpm + if not output.use_counts + else output.transcript_model_grouped_counts + ), + ) else: - updated_gene_dict = dictionary_builder.update_transcript_values(updated_gene_dict, output.transcript_model_tpm if not output.use_counts else output.transcript_model_counts) + updated_gene_dict = dictionary_builder.update_transcript_values( + updated_gene_dict, + ( + output.transcript_model_tpm + if not output.use_counts + else output.transcript_model_counts + ), + ) if args.filter_transcripts is not None: - print(f"Filtering transcripts with minimum value {args.filter_transcripts} in at least one condition.") - updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value(updated_gene_dict, min_value=args.filter_transcripts) + print( + f"Filtering transcripts with minimum value {args.filter_transcripts} in at least one condition." + ) + updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value( + updated_gene_dict, min_value=args.filter_transcripts + ) else: - updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value(updated_gene_dict) + updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value( + updated_gene_dict + ) - plot_output = PlotOutput(updated_gene_dict, gene_list, args.output_directory, reads_and_class, filter_transcripts=args.filter_transcripts, conditions=output.conditions, use_counts=args.counts) + # Visualization output directory decision + viz_output_directory = args.viz_output if args.viz_output else args.output_directory + dictionary_builder.save_gene_dict_to_json(updated_gene_dict, args.output_directory) + plot_output = PlotOutput( + updated_gene_dict, + gene_list, + viz_output_directory, + reads_and_class, + filter_transcripts=args.filter_transcripts, + conditions=output.conditions, + use_counts=args.counts, + ) plot_output.plot_transcript_map() plot_output.plot_transcript_usage() From 2cde19878ca791aae531967f4dcccbe8723a3ba9 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Sat, 20 Jul 2024 20:18:24 -0500 Subject: [PATCH 03/35] add pie charts and find_genes algo --- src/gene_model.py | 277 ++++++++++++++++++++++++++++++++++++++++++++ src/plot_output.py | 134 +++++++++------------ src/post_process.py | 57 ++++----- src/process_dict.py | 92 +++++++++++++++ visualize.py | 37 +++++- 5 files changed, 487 insertions(+), 110 deletions(-) create mode 100644 src/gene_model.py create mode 100644 src/process_dict.py diff --git a/src/gene_model.py b/src/gene_model.py new file mode 100644 index 00000000..73c1ae88 --- /dev/null +++ b/src/gene_model.py @@ -0,0 +1,277 @@ +import json +import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from scipy.spatial.distance import euclidean + + +def parse_data(data): + genes = {} + for condition, condition_data in data.items(): + for gene, gene_data in condition_data.items(): + if gene not in genes: + genes[gene] = { + "chromosome": 
gene_data["chromosome"], + "start": gene_data["start"], + "end": gene_data["end"], + "strand": gene_data["strand"], + "biotype": gene_data["biotype"], + "transcripts": {}, + } + genes[gene]["transcripts"][condition] = gene_data["transcripts"] + genes[gene][condition] = gene_data["value"] + return genes + + +def calculate_deviance(wt_transcripts, condition_transcripts): + all_transcripts = set(wt_transcripts.keys()).union( + set(condition_transcripts.keys()) + ) + + wt_proportions = [wt_transcripts.get(t, 0) for t in all_transcripts] + condition_proportions = [condition_transcripts.get(t, 0) for t in all_transcripts] + + total_wt = sum(wt_proportions) + total_condition = sum(condition_proportions) + + if total_wt > 0: + wt_proportions = [p / total_wt for p in wt_proportions] + if total_condition > 0: + condition_proportions = [p / total_condition for p in condition_proportions] + + distance = euclidean(wt_proportions, condition_proportions) + + # Reduce distance if total unique transcripts are 1 + if len(all_transcripts) == 1: + distance *= 0.7 + + return distance + + +def calculate_metrics(genes): + metrics = [] + for gene, gene_data in genes.items(): + wt_transcripts = gene_data["transcripts"].get("wild_type", {}) + + for condition in gene_data: + if condition in [ + "chromosome", + "start", + "end", + "strand", + "biotype", + "transcripts", + "wild_type", + ]: + continue + condition_transcripts = gene_data["transcripts"].get(condition, {}) + deviance = calculate_deviance(wt_transcripts, condition_transcripts) + metrics.append({"gene": gene, "condition": condition, "deviance": deviance}) + + value = gene_data.get(condition, 0) + wt_value = gene_data.get("wild_type", 0) + abs_diff = abs(value - wt_value) + metrics.append( + { + "gene": gene, + "condition": condition, + "value": value, + "abs_diff": abs_diff, + } + ) + + return pd.DataFrame(metrics) + + +def rank_genes(df): + value_ranking = df.groupby("gene")["value"].mean().reset_index() + abs_diff_ranking = df.groupby("gene")["abs_diff"].mean().reset_index() + deviance_ranking = df.groupby("gene")["deviance"].mean().reset_index() + + value_ranking["rank_value"] = value_ranking["value"].rank(ascending=False) + abs_diff_ranking["rank_abs_diff"] = abs_diff_ranking["abs_diff"].rank( + ascending=False + ) + deviance_ranking["rank_deviance"] = deviance_ranking["deviance"].rank( + ascending=False + ) + + merged_df = value_ranking[["gene", "rank_value"]].merge( + abs_diff_ranking[["gene", "rank_abs_diff"]], on="gene" + ) + merged_df = merged_df.merge(deviance_ranking[["gene", "rank_deviance"]], on="gene") + + # Devalue the importance of overall expression by reducing its weight + merged_df["combined_rank"] = ( + merged_df["rank_value"] # Reduced weight for rank_value + + merged_df["rank_abs_diff"] + + merged_df["rank_deviance"] + ) + + top_combined_ranking = merged_df.sort_values(by="combined_rank").head(10) + top_deviance_ranking = merged_df.sort_values(by="rank_deviance").head(10) + top_100_combined_ranking = merged_df.sort_values(by="combined_rank").head(100) + + return ( + top_combined_ranking, + top_deviance_ranking, + top_100_combined_ranking, + merged_df, + ) + + +def visualize_ranking( + top_combined_ranking, top_deviance_ranking, merged_df, output_dir +): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Bar plot for combined rank + plt.figure(figsize=(12, 8)) + sns.barplot( + x="combined_rank", y="gene", data=top_combined_ranking, palette="viridis" + ) + plt.title("Top 10 Genes by Combined Ranking") + 
plt.xlabel("Combined Rank") + plt.ylabel("Gene") + plt.savefig(os.path.join(output_dir, "top_genes_combined_ranking.png"), dpi=300) + plt.close() + + # Heatmap for metric ranks + top_genes = top_combined_ranking["gene"].tolist() + heatmap_data = merged_df[merged_df["gene"].isin(top_genes)] + heatmap_data = heatmap_data.set_index("gene")[ + ["rank_value", "rank_abs_diff", "rank_deviance"] + ] + + plt.figure(figsize=(12, 8)) + sns.heatmap( + heatmap_data, + annot=True, + cmap="RdBu_r", + linewidths=0.5, + cbar_kws={"label": "Rank"}, + ) + plt.title("Metric Ranks for Top 10 Genes") + plt.savefig(os.path.join(output_dir, "metric_ranks_heatmap.png"), dpi=300) + plt.close() + + # Diverging bar plot for deviance + plt.figure(figsize=(12, 8)) + sns.barplot( + x="rank_deviance", + y="gene", + data=top_deviance_ranking, + palette="coolwarm", + orient="h", + ) + plt.title("Top 10 Genes by Transcript Deviance from Wild Type") + plt.xlabel("Rank of Deviance from Wild Type") + plt.ylabel("Gene") + plt.axvline(x=0, color="grey", linestyle="--") + plt.savefig(os.path.join(output_dir, "deviance_from_wild_type.png"), dpi=300) + plt.close() + + # Scatter plot for rank_value vs rank_abs_diff + plt.figure(figsize=(12, 8)) + sns.scatterplot( + x="rank_value", + y="rank_abs_diff", + hue="gene", + data=top_deviance_ranking, + palette="deep", + s=100, + ) + plt.title("Rank Value vs Rank Absolute Difference") + plt.xlabel("Rank Value") + plt.ylabel("Rank Absolute Difference") + plt.savefig(os.path.join(output_dir, "rank_value_vs_rank_abs_diff.png"), dpi=300) + plt.close() + + # Combined multi-metric visualization + fig, axes = plt.subplots(2, 2, figsize=(20, 16)) + sns.barplot( + x="combined_rank", + y="gene", + data=top_combined_ranking, + palette="viridis", + ax=axes[0, 0], + ) + axes[0, 0].set_title("Combined Rank") + axes[0, 0].set_xlabel("Combined Rank") + axes[0, 0].set_ylabel("Gene") + + sns.heatmap( + heatmap_data, + annot=True, + cmap="RdBu_r", + linewidths=0.5, + cbar_kws={"label": "Rank"}, + ax=axes[0, 1], + ) + axes[0, 1].set_title("Metric Ranks") + + sns.barplot( + x="rank_deviance", + y="gene", + data=top_deviance_ranking, + palette="coolwarm", + orient="h", + ax=axes[1, 0], + ) + axes[1, 0].set_title("Transcript Deviance from Wild Type") + axes[1, 0].set_xlabel("Rank of Deviance from Wild Type") + axes[1, 0].set_ylabel("Gene") + axes[1, 0].axvline(x=0, color="grey", linestyle="--") + + sns.scatterplot( + x="rank_value", + y="rank_abs_diff", + hue="gene", + data=top_deviance_ranking, + palette="deep", + s=100, + ax=axes[1, 1], + ) + axes[1, 1].set_title("Rank Value vs Rank Absolute Difference") + axes[1, 1].set_xlabel("Rank Value") + axes[1, 1].set_ylabel("Rank Absolute Difference") + + plt.tight_layout() + plt.savefig(os.path.join(output_dir, "combined_visualization.png"), dpi=300) + plt.close() + + +def save_top_genes(top_combined_ranking, output_dir, num_genes): + top_combined_ranking.head(num_genes)[["gene"]].to_csv( + os.path.join(output_dir, f"top_{num_genes}_genes.txt"), + index=False, + header=False, + sep="\t", + ) + return os.path.join(output_dir, f"top_{num_genes}_genes.txt") + + +def rank_and_visualize_genes(input_data, output_dir, num_genes=100): + genes = parse_data(input_data) + metrics_df = calculate_metrics(genes) + top_combined_ranking, top_deviance_ranking, top_100_combined_ranking, merged_df = ( + rank_genes(metrics_df) + ) + + top_combined_ranking = merged_df.sort_values(by="combined_rank").head(num_genes) + top_deviance_ranking = 
merged_df.sort_values(by="rank_deviance").head(num_genes) + + visualize_ranking(top_combined_ranking, top_deviance_ranking, merged_df, output_dir) + path = save_top_genes(top_combined_ranking, output_dir, num_genes) + + print(f"\nTop {num_genes} Genes by Combined Ranking:") + print(top_combined_ranking[["gene", "combined_rank"]]) + print(f"\nDetailed Metrics for Top {num_genes} Genes by Combined Ranking:") + print(top_combined_ranking) + print(f"\nNumber of Genes: {len(merged_df)}") + print(f"\nTop {num_genes} Genes by Transcript Deviance from Wild Type:") + print(top_deviance_ranking) + return path diff --git a/src/plot_output.py b/src/plot_output.py index 13388f0c..51242cc8 100644 --- a/src/plot_output.py +++ b/src/plot_output.py @@ -10,6 +10,7 @@ def __init__( updated_gene_dict, gene_names, output_directory, + create_visualization_subdir=False, reads_and_class=None, filter_transcripts=None, conditions=False, @@ -23,9 +24,14 @@ def __init__( self.conditions = conditions self.use_counts = use_counts - # Create visualization subdirectory if it doesn't exist - self.visualization_dir = os.path.join(self.output_directory, "visualization") - os.makedirs(self.visualization_dir, exist_ok=True) + # Create visualization subdirectory if specified + if create_visualization_subdir: + self.visualization_dir = os.path.join( + self.output_directory, "visualization" + ) + os.makedirs(self.visualization_dir, exist_ok=True) + else: + self.visualization_dir = self.output_directory def plot_transcript_map(self): # Get the first condition's gene dictionary @@ -76,8 +82,6 @@ def plot_transcript_map(self): ): # Determine the direction based on the gene's strand information direction_marker = ">" if gene_data["strand"] == "+" else "<" - - # Add a direction marker to indicate the direction of the transcript marker_pos = ( transcript_info["end"] + 100 if gene_data["strand"] == "+" @@ -144,12 +148,12 @@ def plot_transcript_usage(self): bar_width = 0.35 opacity = 0.8 - for sample_type, transcripts in gene_data.items(): - print(f"Sample Type: {sample_type}") - for transcript_id, transcript_info in transcripts.items(): - print( - f" Transcript ID: {transcript_id}, Value: {transcript_info['value']}" - ) + # for sample_type, transcripts in gene_data.items(): + # print(f"Sample Type: {sample_type}") + # for transcript_id, transcript_info in transcripts.items(): + # print( + # f" Transcript ID: {transcript_id}, Value: {transcript_info['value']}" + # ) # Adjusting the colors for better within-bar comparison max_transcripts = max(len(gene_data[condition]) for condition in conditions) colors = plt.cm.plasma( @@ -191,75 +195,45 @@ def plot_transcript_usage(self): plt.savefig(plot_path) plt.close(fig) - -def make_pie_chart(): - - data = { - "ambiguous": 236646, - "inconsistent": 1212565, - "intergenic": 6886, - "noninformative": 130493, - "unique": 745194, - "unique_minor_difference": 79178, - } - labels = data.keys() - sizes = data.values() - total = sum(sizes) - print(total) - plt.pie(sizes, labels=labels, autopct="%1.1f%%") - plt.axis("equal") - plt.title(f"Total: {total}") - plt.show() - plt.savefig("read_assignment_pie_chart.png") - - -def visualize_transcript_usage_single_gene(gene_data, gene_name): - """ - - :param gene_data: Dict containing transcript usage data for a given gene across different sample types. - :param gene_name: The gene to visualize. 
- """ - - if gene_name not in gene_data: - print(f"Gene {gene_name} not found in the data.") - return - - sample_types = gene_data[gene_name].keys() - n_bars = len(sample_types) - - fig, ax = plt.subplots(figsize=(10, 7)) - index = np.arange(n_bars) - bar_width = 0.35 - opacity = 0.8 - - # Adjusting the colors for better within-bar comparison - max_transcripts = max(len(gene_data[gene_name][sample]) for sample in sample_types) - colors = plt.cm.plasma( - np.linspace(0, 1, num=max_transcripts) - ) # Using plasma for better color gradation - - bottom_val = np.zeros(n_bars) - for i, sample_type in enumerate(sample_types): - transcripts = gene_data[gene_name][sample_type] - for j, (transcript_id, value) in enumerate(transcripts): - color = colors[j % len(colors)] - plt.bar( - i, - float(value), - bar_width, - bottom=bottom_val[i], - alpha=opacity, - color=color, - label=transcript_id if i == 0 else "", + def make_pie_charts(self): + """ + Create pie charts for transcript alignment classifications and read assignment consistency. + """ + titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] + + for i, (title, reads_dict) in enumerate(zip(titles, self.reads_and_class)): + labels = reads_dict.keys() + sizes = reads_dict.values() + total = sum(sizes) + + # Generate a file-friendly title + file_title = title.lower().replace(" ", "_") + + plt.figure() + wedges, texts, autotexts = plt.pie( + sizes, + labels=labels, + autopct="%1.1f%%", + startangle=140, + textprops=dict(color="w"), ) - bottom_val[i] += float(value) + plt.setp(autotexts, size=10, weight="bold") + plt.setp(texts, size=9) - plt.xlabel("Sample Type") - plt.ylabel("Transcript Usage (TPM)") - plt.title(f"Transcript Usage for {gene_name} by Sample Type") - plt.xticks(index, sample_types) - plt.legend(title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc="upper left") + plt.axis( + "equal" + ) # Equal aspect ratio ensures that pie is drawn as a circle. + plt.title(f"{title}\nTotal: {total}") - plt.tight_layout() - plt.show() - plt.savefig(f"{gene_name}_transcript_usage_by_sample_type_ref.png") + plt.legend( + wedges, + labels, + title="Categories", + loc="center left", + bbox_to_anchor=(1, 0, 0.5, 1), + ) + plot_path = os.path.join( + self.visualization_dir, f"{file_title}_pie_chart.png" + ) + plt.savefig(plot_path, bbox_inches="tight") + plt.close() diff --git a/src/post_process.py b/src/post_process.py index 0d94d691..9315ed5c 100644 --- a/src/post_process.py +++ b/src/post_process.py @@ -1,10 +1,11 @@ import csv import os -import re +import pickle import gzip import shutil import copy import json +from argparse import Namespace class OutputConfig: @@ -33,7 +34,7 @@ def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): self.use_counts = use_counts self.ref_only = ref_only - self._parse_isoquant_log() # Always parse the log + self._load_params_file() # Load the params file instead of parsing the log self._find_files() self._conditional_unzip() @@ -43,29 +44,33 @@ def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): "Input GTF file is required when ref_only is set. Please provide it using the --gtf flag." 
) - def _parse_isoquant_log(self): - """Parse the isoquant.log for necessary configuration and commands.""" - log_path = os.path.join(self.output_directory, "isoquant.log") - assert os.path.exists(log_path), f"Log file not found: {log_path}" - if os.path.exists(log_path): - with open(log_path, "r") as file: - log_content = file.read() - gene_db_match = re.search(r"--genedb (\S+)", log_content) - fastq_flag = "--fastq" in log_content - processing_sample_match = re.search( - r"Processed experiment (\S+)", log_content - ) - if gene_db_match and not self.input_gtf: - self.input_gtf = gene_db_match.group(1) - self.log_details["gene_db"] = self.input_gtf - self.log_details["fastq_used"] = fastq_flag - - if processing_sample_match: - self.output_directory = os.path.join( - self.output_directory, processing_sample_match.group(1) - ) + def _load_params_file(self): + """Load the .params file for necessary configuration and commands.""" + params_path = os.path.join(self.output_directory, ".params") + assert os.path.exists(params_path), f"Params file not found: {params_path}" + try: + with open(params_path, "rb") as file: + params = pickle.load(file) + if isinstance(params, Namespace): + self._process_params(vars(params)) else: - raise ValueError("Processing sample directory not found in log.") + print("Unexpected params format.") + except Exception as e: + raise ValueError(f"An error occurred while loading params: {e}") + + def _process_params(self, params): + """Process parameters loaded from the .params file.""" + self.log_details["gene_db"] = params.get("genedb") + self.log_details["fastq_used"] = bool(params.get("fastq")) + self.input_gtf = self.input_gtf or params.get("genedb") + + processing_sample = params.get("prefix") + if processing_sample: + self.output_directory = os.path.join( + self.output_directory, processing_sample + ) + else: + raise ValueError("Processing sample directory not found in params.") def _conditional_unzip(self): """Check if unzip is needed and perform it conditionally based on the model use.""" @@ -81,7 +86,7 @@ def _unzip_file(self, file_path): new_path = file_path[:-3] # Remove .gz extension if os.path.exists(new_path): - print(f"File {new_path} already exists, using this file.") + # print(f"File {new_path} already exists, using this file.") return new_path if not os.path.exists(file_path): @@ -502,6 +507,6 @@ def read_gene_list(self, gene_list_path): def save_gene_dict_to_json(self, gene_dict, output_path): """Saves the gene dictionary to a JSON file.""" # name the gene_dict file - output_path = os.path.join(self.config.output_directory, "gene_dict.json") + output_path = os.path.join(output_path, "gene_dict.json") with open(output_path, "w") as file: json.dump(gene_dict, file, indent=4) diff --git a/src/process_dict.py b/src/process_dict.py new file mode 100644 index 00000000..bb3ca001 --- /dev/null +++ b/src/process_dict.py @@ -0,0 +1,92 @@ +import json +import sys +import os + + +def simplify_and_sum_transcripts(data): + gene_totals_across_conditions = {} + simplified_data = {} + + # Sum transcript values and collect them across all conditions + for sample_id, genes in data.items(): + simplified_data[sample_id] = {} + for gene_id, gene_data in genes.items(): + transcripts = gene_data.get("transcripts", {}) + total_value = 0.0 + simplified_transcripts = {} + for transcript_id, transcript_details in transcripts.items(): + transcript_value = ( + transcript_details.get("value", 0.0) + if isinstance(transcript_details, dict) + else 0.0 + ) + 
simplified_transcripts[transcript_id] = transcript_value + total_value += transcript_value + + gene_data_copy = ( + gene_data.copy() + ) # Make a copy to avoid modifying the original + gene_data_copy["transcripts"] = simplified_transcripts + gene_data_copy["value"] = ( + total_value # Replace the gene-level value with the sum of transcript values + ) + simplified_data[sample_id][gene_id] = gene_data_copy + + if gene_id not in gene_totals_across_conditions: + gene_totals_across_conditions[gene_id] = [] + gene_totals_across_conditions[gene_id].append(total_value) + + # Determine which genes to remove + genes_to_remove = [ + gene_id + for gene_id, totals in gene_totals_across_conditions.items() + if all(total < 5 for total in totals) + ] + + # Remove genes from the simplified data structure + for sample_id, genes in simplified_data.items(): + for gene_id in genes_to_remove: + if gene_id in genes: + del genes[gene_id] + + return simplified_data + + +def read_json(file_path): + with open(file_path, "r") as file: + return json.load(file) + + +def write_json(data, file_path): + with open(file_path, "w") as file: + json.dump(data, file, indent=4) + + +def main(): + if len(sys.argv) != 2: + print("Usage: python script.py ") + sys.exit(1) + + input_file_path = sys.argv[1] + base, ext = os.path.splitext(input_file_path) + output_file_path = f"{base}_simplified{ext}" + + try: + # Load the gene data from the specified input JSON file + gene_dict = read_json(input_file_path) + + # Simplify the transcripts, sum their values, and remove genes under a threshold across all conditions + modified_gene_dict = simplify_and_sum_transcripts(gene_dict) + + # Save the modified gene data to the newly named output JSON file + write_json(modified_gene_dict, output_file_path) + + print(f"Modified gene data has been saved to {output_file_path}") + + except Exception as e: + print(f"Error: {str(e)}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/visualize.py b/visualize.py index 97e56b28..5b82cfe7 100644 --- a/visualize.py +++ b/visualize.py @@ -3,6 +3,17 @@ from src.post_process import OutputConfig, DictionaryBuilder from src.plot_output import PlotOutput import argparse +from src.process_dict import simplify_and_sum_transcripts +from src.gene_model import rank_and_visualize_genes + +import argparse + + +class FindGenesAction(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + if values is None: + values = 100 # Default value when the flag is used without a value + setattr(namespace, self.dest, values) def parse_arguments(): @@ -33,7 +44,7 @@ def parse_arguments(): parser.add_argument( "--filter_transcripts", type=float, - help="Filter transcripts by minimum value occuring in at least one condition.", + help="Filter transcripts by minimum value occurring in at least one condition.", default=None, ) parser.add_argument( @@ -42,6 +53,13 @@ def parse_arguments(): required=True, help="Path to a .txt file containing a list of genes, each on its own line.", ) + parser.add_argument( + "--find_genes", + nargs="?", + const=100, + type=int, + help="Find genes with the highest combined rank and visualize them. 
Optionally specify the number of top genes to evaluate (default is 100).", + ) return parser.parse_args() @@ -73,7 +91,7 @@ def main(): updated_gene_dict = dictionary_builder.update_gene_dict(gene_dict, gene_file) if update_names: - print("Updating gene names to gene symbols.") + print("Updating Ensembl IDs to gene symbols.") updated_gene_dict = dictionary_builder.update_gene_names(updated_gene_dict) if output.ref_only or not output.extended_annotation: @@ -131,18 +149,29 @@ def main(): # Visualization output directory decision viz_output_directory = args.viz_output if args.viz_output else args.output_directory - dictionary_builder.save_gene_dict_to_json(updated_gene_dict, args.output_directory) + + if args.find_genes: + print("Finding genes.") + simple_gene_dict = simplify_and_sum_transcripts(updated_gene_dict) + path = rank_and_visualize_genes( + simple_gene_dict, viz_output_directory, args.find_genes + ) + gene_list = dictionary_builder.read_gene_list(path) + + # dictionary_builder.save_gene_dict_to_json(updated_gene_dict, viz_output_directory) plot_output = PlotOutput( updated_gene_dict, gene_list, viz_output_directory, - reads_and_class, + create_visualization_subdir=(viz_output_directory == args.output_directory), + reads_and_class=reads_and_class, filter_transcripts=args.filter_transcripts, conditions=output.conditions, use_counts=args.counts, ) plot_output.plot_transcript_map() plot_output.plot_transcript_usage() + plot_output.make_pie_charts() if __name__ == "__main__": From 03c1aa2749143500c06cd3659f8b6791af3fa510 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Wed, 31 Jul 2024 11:06:27 -0500 Subject: [PATCH 04/35] manual gtf parse to gffutils - plus name check --- src/gene_model.py | 30 +++++++++- src/post_process.py | 137 ++++++++++++++++++++------------------------ visualize.py | 1 - 3 files changed, 89 insertions(+), 79 deletions(-) diff --git a/src/gene_model.py b/src/gene_model.py index 73c1ae88..f1a88726 100644 --- a/src/gene_model.py +++ b/src/gene_model.py @@ -85,7 +85,29 @@ def calculate_metrics(genes): return pd.DataFrame(metrics) +def check_known_target(gene, known_targets): + for target in known_targets: + if "|" in target: + if any(part in gene for part in target.split("|")): + return 1 + elif target == gene: + return 1 + return 0 + + def rank_genes(df): + target_genes_df = pd.read_csv( + "/w5home/jfreeman/IsoQuant/Target_genes.csv", header=None, names=["gene"] + ) + + # Convert the known targets to a list + known_targets = target_genes_df["gene"].tolist() + + # Apply the function to create the 'known_target' column + df["known_target"] = df["gene"].apply( + lambda x: check_known_target(x, known_targets) + ) + value_ranking = df.groupby("gene")["value"].mean().reset_index() abs_diff_ranking = df.groupby("gene")["abs_diff"].mean().reset_index() deviance_ranking = df.groupby("gene")["deviance"].mean().reset_index() @@ -102,7 +124,7 @@ def rank_genes(df): abs_diff_ranking[["gene", "rank_abs_diff"]], on="gene" ) merged_df = merged_df.merge(deviance_ranking[["gene", "rank_deviance"]], on="gene") - + merged_df = merged_df.merge(df[["gene", "known_target"]], on="gene") # Devalue the importance of overall expression by reducing its weight merged_df["combined_rank"] = ( merged_df["rank_value"] # Reduced weight for rank_value @@ -256,11 +278,12 @@ def save_top_genes(top_combined_ranking, output_dir, num_genes): def rank_and_visualize_genes(input_data, output_dir, num_genes=100): genes = parse_data(input_data) + print(genes) metrics_df = calculate_metrics(genes) 
top_combined_ranking, top_deviance_ranking, top_100_combined_ranking, merged_df = ( rank_genes(metrics_df) ) - + merged_df = merged_df.drop_duplicates(subset="gene", keep="first") top_combined_ranking = merged_df.sort_values(by="combined_rank").head(num_genes) top_deviance_ranking = merged_df.sort_values(by="rank_deviance").head(num_genes) @@ -274,4 +297,7 @@ def rank_and_visualize_genes(input_data, output_dir, num_genes=100): print(f"\nNumber of Genes: {len(merged_df)}") print(f"\nTop {num_genes} Genes by Transcript Deviance from Wild Type:") print(top_deviance_ranking) + + merged_df.to_csv(os.path.join(output_dir, "gene_metrics.csv"), index=False) + return path diff --git a/src/post_process.py b/src/post_process.py index 9315ed5c..9f90cfb1 100644 --- a/src/post_process.py +++ b/src/post_process.py @@ -6,6 +6,8 @@ import copy import json from argparse import Namespace +import tempfile +import gffutils class OutputConfig: @@ -217,8 +219,7 @@ def build_read_assignment_and_classification_dictionaries(self): return classification_counts, assignment_type_counts def parse_input_gtf(self): - """Parses the GTF file to build a detailed dictionary of genes, transcripts, and exons, - while skipping entries that are not genes, transcripts, or exons.""" + """Parses the GTF file using gffutils to build a detailed dictionary of genes, transcripts, and exons.""" gene_dict = {} if not self.config.input_gtf: raise FileNotFoundError("Extended annotation GTF file is missing.") @@ -226,82 +227,62 @@ def parse_input_gtf(self): input_gtf_path = self.config.input_gtf try: - # Try opening the file as a regular text file first - file = open(input_gtf_path, "r") - except FileNotFoundError: - raise FileNotFoundError( - f"Extended annotation GTF file is missing at {input_gtf_path}." - ) - except OSError: - # If it fails, assume it's likely gzipped and try opening it with gzip - try: - file = gzip.open(input_gtf_path, "rt") - except FileNotFoundError: - input_gtf_path = input_gtf_path.rstrip(".gz") - try: - file = open(input_gtf_path, "r") - except FileNotFoundError: - raise FileNotFoundError( - f"Extended annotation GTF file is missing at {input_gtf_path}." - ) - - with file: - for line in file: - if line.startswith("#") or not line.strip(): - continue - fields = line.strip().split("\t") - if len(fields) < 9: - print( - f"Skipping malformed line due to insufficient fields: {line.strip()}" - ) - continue - - entry_type = fields[2].lower() - if entry_type not in {"gene", "transcript", "exon"}: - continue # Skip types like CDS, start_codon, etc. 
+ # Create a temporary database + with tempfile.NamedTemporaryFile(suffix=".db") as tmp: + db = gffutils.create_db( + input_gtf_path, + dbfn=tmp.name, + force=True, + keep_order=True, + merge_strategy="merge", + sort_attribute_values=True, + disable_infer_genes=True, + disable_infer_transcripts=True, + ) - info_fields = fields[8].strip(";").split(";") - details = { - field.strip().split(" ")[0]: field.strip().split(" ")[1].strip('"') - for field in info_fields - if " " in field - } + for gene in db.features_of_type("gene"): + gene_id = gene.id + gene_dict[gene_id] = { + "chromosome": gene.seqid, + "start": gene.start, + "end": gene.end, + "strand": gene.strand, + "name": gene.attributes.get("gene_name", [""])[0], + "biotype": gene.attributes.get("gene_biotype", [""])[0], + "transcripts": {}, + } - try: - if entry_type == "gene": - gene_id = details["gene_id"] - gene_dict[gene_id] = { - "chromosome": fields[0], - "start": int(fields[3]), - "end": int(fields[4]), - "strand": fields[6], - "name": details.get("gene_name", ""), - "biotype": details.get("gene_biotype", ""), - "transcripts": {}, - } - elif entry_type == "transcript": - transcript_id = details["transcript_id"] - gene_dict[details["gene_id"]]["transcripts"][transcript_id] = { - "start": int(fields[3]), - "end": int(fields[4]), - "name": details.get("transcript_name", ""), - "biotype": details.get("transcript_biotype", ""), + for transcript in db.children(gene, featuretype="transcript"): + transcript_id = transcript.id + gene_dict[gene_id]["transcripts"][transcript_id] = { + "start": transcript.start, + "end": transcript.end, + "name": transcript.attributes.get("transcript_name", [""])[ + 0 + ], + "biotype": transcript.attributes.get( + "transcript_biotype", [""] + )[0], "exons": [], - "tags": details.get("tag", "").split(","), - } - elif entry_type == "exon": - transcript_id = details["transcript_id"] - exon_info = { - "exon_id": details["exon_id"], - "start": int(fields[3]), - "end": int(fields[4]), - "number": details.get("exon_number", ""), + "tags": transcript.attributes.get("tag", [""])[0].split( + "," + ), } - gene_dict[details["gene_id"]]["transcripts"][transcript_id][ - "exons" - ].append(exon_info) - except KeyError as e: - print(f"Key error in line: {line.strip()} | Missing key: {e}") + + for exon in db.children(transcript, featuretype="exon"): + exon_info = { + "exon_id": exon.id, + "start": exon.start, + "end": exon.end, + "number": exon.attributes.get("exon_number", [""])[0], + } + gene_dict[gene_id]["transcripts"][transcript_id][ + "exons" + ].append(exon_info) + + except Exception as e: + raise Exception(f"Error parsing GTF file: {str(e)}") + return gene_dict def parse_extended_annotation(self): @@ -451,8 +432,12 @@ def update_gene_names(self, gene_dict): for condition, genes in gene_dict.items(): updated_genes = {} for gene_id, gene_info in genes.items(): - gene_name_upper = gene_info["name"].upper() - updated_genes[gene_name_upper] = gene_info + if gene_info["name"]: + gene_name_upper = gene_info["name"].upper() + updated_genes[gene_name_upper] = gene_info + else: + # If name is empty, use the original gene_id + updated_genes[gene_id] = gene_info updated_dict[condition] = updated_genes return updated_dict diff --git a/visualize.py b/visualize.py index 5b82cfe7..b786a7a9 100644 --- a/visualize.py +++ b/visualize.py @@ -89,7 +89,6 @@ def main(): gene_file = output.gene_tpm if not output.use_counts else output.gene_counts updated_gene_dict = dictionary_builder.update_gene_dict(gene_dict, gene_file) - if update_names: 
print("Updating Ensembl IDs to gene symbols.") updated_gene_dict = dictionary_builder.update_gene_names(updated_gene_dict) From 388eac0d2c995374aaf3d50ac612ecdf41e8f75d Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Thu, 1 Aug 2024 15:29:17 -0500 Subject: [PATCH 05/35] Updated README and requirements.txt --- README.md | 41 +++++++++++++++++++++++++++++++++++++++++ requirements.txt | 5 +++++ src/gene_model.py | 31 +++++++++++++------------------ visualize.py | 13 ++++++++++--- 4 files changed, 69 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index c412ac3f..7eaa4f83 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ 3.1. [IsoQuant input](#sec3.1)
 3.2. [Command line options](#sec3.2)<br>
 3.3. [IsoQuant output](#sec3.3)<br>
+3.4. [Visualization](#sec3.4)<br>
 4. [Citation](#sec4)<br>
 5. [Feedback and bug reports](#sec5)<br>
@@ -918,6 +919,46 @@ In addition, it contains `canonical` property if `--check_canonical` is set. #### PolyA classifications ![PolyA](figs/polya.png) + + +## Visualization + +IsoQuant provides a visualization tool to help interpret and explore the output data. The goal of this visualization is to create informative plots that represent transcript usage, splicing patterns, and read assignment statistics from the IsoQuant analysis. + +### Running the visualization tool + +To run the visualization tool, use the following command: + +```bash + +python visualize.py [options] + +``` + +### Command line options + +* `output_directory` (required): Directory containing IsoQuant output files. +* `--viz_output`: Optional directory to save visualization output files. Defaults to the main output directory if not specified. +* `--gtf`: Optional path to a GTF file if it cannot be extracted from the IsoQuant log. +* `--counts`: Use counts instead of TPM files for visualization. +* `--ref_only`: Use only reference transcript quantification instead of transcript model quantification. +* `--filter_transcripts`: Filter transcripts by minimum value occurring in at least one condition. +* `--gene_list`: Path to a .txt file containing a list of genes, each on its own line (required). + +### Output + +The visualization tool generates the following plots based on the IsoQuant output: + +1. Transcript usage profiles: For each gene specified in the gene list, a plot showing the relative usage of different transcripts across conditions or samples. + +2. Gene-specific transcript maps: Visual representation of the different splicing patterns of transcripts for each gene, allowing easy comparison of exon usage and alternative splicing events. + +3. Global read assignment consistency: A summary plot showing the overall consistency of read assignments across all genes and transcripts analyzed. + +4. Global transcript alignment classifications: A chart or plot representing the distribution of different transcript alignment categories (e.g., full splice match, incomplete splice match, novel isoforms) across the entire dataset. + +These visualizations provide valuable insights into transcript diversity, splicing patterns, and the overall quality of the IsoQuant analysis. + ## Citation The paper describing IsoQuant algorithms and benchmarking is available at [10.1038/s41587-022-01565-y](https://doi.org/10.1038/s41587-022-01565-y). 
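For concreteness, a minimal end-to-end run of the visualization tool documented above might look like the sketch below. The directory `isoquant_output/`, the file `genes.txt`, and the gene symbols are illustrative assumptions rather than names taken from these patches; only the flags themselves come from `visualize.py`.

```bash
# Hypothetical gene list: one gene symbol or Ensembl ID per line.
cat > genes.txt <<'EOF'
RHO
NRL
CRX
EOF

# Run against an existing IsoQuant output directory, keeping only transcripts
# with a value of at least 1.0 in at least one condition, and writing figures
# to a separate directory instead of the main output directory.
python visualize.py isoquant_output/ \
    --gene_list genes.txt \
    --filter_transcripts 1.0 \
    --viz_output isoquant_output/figures
```

Note that `read_gene_list` upper-cases every entry on load, and the Ensembl-ID-to-symbol conversion in `main()` is triggered only when at least one entry does not start with `ENS`, so a list consisting purely of Ensembl IDs is matched by ID as-is.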
diff --git a/requirements.txt b/requirements.txt index ed7ed899..52037078 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,8 @@ pysam>=0.15 packaging pyfaidx>=0.7 pyyaml>=5.4 +matplotlib>=3.1.3 +numpy>=1.18.1 +scipy>=1.4.1 +seaborn>=0.10.0 + diff --git a/src/gene_model.py b/src/gene_model.py index f1a88726..0a7e893b 100644 --- a/src/gene_model.py +++ b/src/gene_model.py @@ -95,18 +95,15 @@ def check_known_target(gene, known_targets): return 0 -def rank_genes(df): - target_genes_df = pd.read_csv( - "/w5home/jfreeman/IsoQuant/Target_genes.csv", header=None, names=["gene"] - ) - - # Convert the known targets to a list - known_targets = target_genes_df["gene"].tolist() - - # Apply the function to create the 'known_target' column - df["known_target"] = df["gene"].apply( - lambda x: check_known_target(x, known_targets) - ) +def rank_genes(df, known_genes_path=None): + if known_genes_path: + target_genes_df = pd.read_csv(known_genes_path, header=None, names=["gene"]) + known_targets = target_genes_df["gene"].tolist() + df["known_target"] = df["gene"].apply( + lambda x: check_known_target(x, known_targets) + ) + else: + df["known_target"] = 0 value_ranking = df.groupby("gene")["value"].mean().reset_index() abs_diff_ranking = df.groupby("gene")["abs_diff"].mean().reset_index() @@ -276,12 +273,13 @@ def save_top_genes(top_combined_ranking, output_dir, num_genes): return os.path.join(output_dir, f"top_{num_genes}_genes.txt") -def rank_and_visualize_genes(input_data, output_dir, num_genes=100): +def rank_and_visualize_genes( + input_data, output_dir, num_genes=100, known_genes_path=None +): genes = parse_data(input_data) - print(genes) metrics_df = calculate_metrics(genes) top_combined_ranking, top_deviance_ranking, top_100_combined_ranking, merged_df = ( - rank_genes(metrics_df) + rank_genes(metrics_df, known_genes_path) ) merged_df = merged_df.drop_duplicates(subset="gene", keep="first") top_combined_ranking = merged_df.sort_values(by="combined_rank").head(num_genes) @@ -294,9 +292,6 @@ def rank_and_visualize_genes(input_data, output_dir, num_genes=100): print(top_combined_ranking[["gene", "combined_rank"]]) print(f"\nDetailed Metrics for Top {num_genes} Genes by Combined Ranking:") print(top_combined_ranking) - print(f"\nNumber of Genes: {len(merged_df)}") - print(f"\nTop {num_genes} Genes by Transcript Deviance from Wild Type:") - print(top_deviance_ranking) merged_df.to_csv(os.path.join(output_dir, "gene_metrics.csv"), index=False) diff --git a/visualize.py b/visualize.py index b786a7a9..799c2c37 100644 --- a/visualize.py +++ b/visualize.py @@ -6,8 +6,6 @@ from src.process_dict import simplify_and_sum_transcripts from src.gene_model import rank_and_visualize_genes -import argparse - class FindGenesAction(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): @@ -60,6 +58,12 @@ def parse_arguments(): type=int, help="Find genes with the highest combined rank and visualize them. 
Optionally specify the number of top genes to evaluate (default is 100).", ) + parser.add_argument( + "--known_genes_path", + type=str, + help="Path to a CSV file containing known target genes.", + default=None, + ) return parser.parse_args() @@ -153,7 +157,10 @@ def main(): print("Finding genes.") simple_gene_dict = simplify_and_sum_transcripts(updated_gene_dict) path = rank_and_visualize_genes( - simple_gene_dict, viz_output_directory, args.find_genes + simple_gene_dict, + viz_output_directory, + args.find_genes, + known_genes_path=args.known_genes_path, ) gene_list = dictionary_builder.read_gene_list(path) From 26cf0fc59807c9ae1bd12574ac65264f84dbef2a Mon Sep 17 00:00:00 2001 From: Jack Freeman Date: Thu, 1 Aug 2024 15:33:58 -0500 Subject: [PATCH 06/35] Update README.md --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 7eaa4f83..0357847d 100644 --- a/README.md +++ b/README.md @@ -923,7 +923,7 @@ In addition, it contains `canonical` property if `--check_canonical` is set. ## Visualization -IsoQuant provides a visualization tool to help interpret and explore the output data. The goal of this visualization is to create informative plots that represent transcript usage, splicing patterns, and read assignment statistics from the IsoQuant analysis. +IsoQuant provides a visualization tool to help interpret and explore the output data. The goal of this visualization is to create informative plots that represent transcript usage and splicing patterns for genes of interest. Additionally, we provide global transcript and read assignment statistics from the IsoQuant analysis. ### Running the visualization tool @@ -931,19 +931,20 @@ To run the visualization tool, use the following command: ```bash -python visualize.py [options] +python visualize.py --gene_list [options] ``` ### Command line options * `output_directory` (required): Directory containing IsoQuant output files. +* * `--gene_list` (required): Path to a .txt file containing a list of genes, each on its own line. * `--viz_output`: Optional directory to save visualization output files. Defaults to the main output directory if not specified. * `--gtf`: Optional path to a GTF file if it cannot be extracted from the IsoQuant log. * `--counts`: Use counts instead of TPM files for visualization. * `--ref_only`: Use only reference transcript quantification instead of transcript model quantification. * `--filter_transcripts`: Filter transcripts by minimum value occurring in at least one condition. -* `--gene_list`: Path to a .txt file containing a list of genes, each on its own line (required). 
+ ### Output From 00c39e710cb75503daf224058e56e6c930b73ca9 Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 2 Aug 2024 02:18:55 +0300 Subject: [PATCH 07/35] save genedb name in .params before conversion --- isoquant.py | 5 ++++- src/gtf2db.py | 8 ++------ src/read_mapper.py | 6 +++--- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/isoquant.py b/isoquant.py index d83dac80..d0d26616 100755 --- a/isoquant.py +++ b/isoquant.py @@ -315,11 +315,14 @@ def check_and_load_args(args, parser): else: logger.warning("Output folder already exists, some files may be overwritten.") - args.gtf = args.genedb if args.genedb_output is None: args.genedb_output = args.output elif not os.path.exists(args.genedb_output): os.makedirs(args.genedb_output) + if args.genedb.lower().endswith("db"): + args.genedb_filename = args.genedb + else: + args.genedb_filename = os.path.join(args.output, os.path.splitext(os.path.basename(args.genedb))[0] + ".db") if not check_input_params(args): parser.print_usage() diff --git a/src/gtf2db.py b/src/gtf2db.py index fe467467..ef6fc4fa 100755 --- a/src/gtf2db.py +++ b/src/gtf2db.py @@ -134,14 +134,10 @@ def gtf2db(gtf, db, complete_db=False, check_gtf=True): logger.info("Provide this database next time to avoid excessive conversion") -def convert_gtf_to_db(args, output_is_dir=True, check_input_gtf=True): +def convert_gtf_to_db(args): gtf_filename = args.genedb gtf_filename = os.path.abspath(gtf_filename) - output_path = args.output if args.genedb_output is None else args.genedb_output - if output_is_dir: - genedb_filename = os.path.join(output_path, os.path.splitext(os.path.basename(gtf_filename))[0] + ".db") - else: - genedb_filename = output_path + "." + os.path.splitext(os.path.basename(gtf_filename))[0] + ".db" + genedb_filename = args.genedb_filename gtf_filename, genedb_filename = convert_db(gtf_filename, genedb_filename, gtf2db, args) return genedb_filename diff --git a/src/read_mapper.py b/src/read_mapper.py index 56e8a936..670c56a9 100644 --- a/src/read_mapper.py +++ b/src/read_mapper.py @@ -251,9 +251,9 @@ def find_annotation(aligner, args): if args.no_junc_bed: return None if aligner == "starlong": - if args.gtf('.db'): - args.gtf = convert_db_to_gtf(args) - return os.path.abspath(args.gtf) + if args.genedb.lower().endswith("db"): + return os.path.abspath(convert_db_to_gtf(args)) + return os.path.abspath(args.genedb) elif aligner == "minimap2": bed_fname = None if args.junc_bed_file: From 348af00504d3123f5c896f121a9ebbb85ec4633c Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 2 Aug 2024 02:29:16 +0300 Subject: [PATCH 08/35] thread db file into visuzalizer --- src/post_process.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/post_process.py b/src/post_process.py index 9f90cfb1..12a3ff17 100644 --- a/src/post_process.py +++ b/src/post_process.py @@ -19,6 +19,7 @@ def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): self.extended_annotation = None self.read_assignments = None self.input_gtf = gtf # Initialize with the provided gtf flag + self.genedb_filename = None self.gtf_flag_needed = False # Initialize flag to check if "--gtf" is needed. 
self.conditions = False self.gene_grouped_counts = None @@ -65,6 +66,7 @@ def _process_params(self, params): self.log_details["gene_db"] = params.get("genedb") self.log_details["fastq_used"] = bool(params.get("fastq")) self.input_gtf = self.input_gtf or params.get("genedb") + self.genedb_filename = params.get("genedb_filename") processing_sample = params.get("prefix") if processing_sample: @@ -221,25 +223,12 @@ def build_read_assignment_and_classification_dictionaries(self): def parse_input_gtf(self): """Parses the GTF file using gffutils to build a detailed dictionary of genes, transcripts, and exons.""" gene_dict = {} - if not self.config.input_gtf: - raise FileNotFoundError("Extended annotation GTF file is missing.") - - input_gtf_path = self.config.input_gtf + if not self.config.genedb_filename: + raise FileNotFoundError("IsoQuant annotation DB file is missing.") try: # Create a temporary database - with tempfile.NamedTemporaryFile(suffix=".db") as tmp: - db = gffutils.create_db( - input_gtf_path, - dbfn=tmp.name, - force=True, - keep_order=True, - merge_strategy="merge", - sort_attribute_values=True, - disable_infer_genes=True, - disable_infer_transcripts=True, - ) - + with gffutils.FeatureDB(self.config.genedb_filename) as db: for gene in db.features_of_type("gene"): gene_id = gene.id gene_dict[gene_id] = { From 2da050c5a5dc792d3259dc5f419737e68079315c Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 2 Aug 2024 02:34:10 +0300 Subject: [PATCH 09/35] print usage if no args --- isoquant.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/isoquant.py b/isoquant.py index d0d26616..f6c2fc9e 100755 --- a/isoquant.py +++ b/isoquant.py @@ -796,8 +796,11 @@ def _check_log(): return all([result in log for result in correct_results]) -def main(args): - args, parser = parse_args(args) +def main(cmd_args): + args, parser = parse_args(cmd_args) + if not cmd_args: + parser.print_usage() + exit(0) set_logger(args, logger) args = check_and_load_args(args, parser) create_output_dirs(args) From c06c93fb3305a8eee062fb4fe781e582e6d51ddb Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 2 Aug 2024 02:40:14 +0300 Subject: [PATCH 10/35] make viz executive --- visualize.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 visualize.py diff --git a/visualize.py b/visualize.py old mode 100644 new mode 100755 From ac2918df7ec67a740da60a1b7b34f8ace3709ead Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 2 Aug 2024 02:48:40 +0300 Subject: [PATCH 11/35] support old isoquant runs --- src/post_process.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/post_process.py b/src/post_process.py index 12a3ff17..f801d6a8 100644 --- a/src/post_process.py +++ b/src/post_process.py @@ -224,7 +224,22 @@ def parse_input_gtf(self): """Parses the GTF file using gffutils to build a detailed dictionary of genes, transcripts, and exons.""" gene_dict = {} if not self.config.genedb_filename: - raise FileNotFoundError("IsoQuant annotation DB file is missing.") + # convert GFT to DB if we use previous IsoQuant runs + # remove this functionality later + tmp_file = tempfile.NamedTemporaryFile(suffix=".db") + self.config.genedb_filename = tmp_file.name + input_gtf_path = self.config.input_gtf + gffutils.create_db( + input_gtf_path, + dbfn=self.config.genedb_filename, + force=True, + keep_order=True, + merge_strategy="merge", + sort_attribute_values=True, + disable_infer_genes=True, + disable_infer_transcripts=True, + 
) + # raise FileNotFoundError("IsoQuant annotation DB file is missing.") try: # Create a temporary database From 8ffa4a8cb9cf8c903d01798890f5097468ded24c Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Tue, 6 Aug 2024 21:44:21 -0500 Subject: [PATCH 12/35] Issue #222 yaml support --- src/plot_output.py | 85 ++++++++++------- src/post_process.py | 226 ++++++++++++++++++++++++++++++++------------ 2 files changed, 218 insertions(+), 93 deletions(-) diff --git a/src/plot_output.py b/src/plot_output.py index 51242cc8..9dd754e0 100644 --- a/src/plot_output.py +++ b/src/plot_output.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import matplotlib.ticker as ticker import numpy as np +import pprint class PlotOutput: @@ -198,42 +199,58 @@ def plot_transcript_usage(self): def make_pie_charts(self): """ Create pie charts for transcript alignment classifications and read assignment consistency. + Handles both combined and separate sample data structures. """ + print("self.reads_and_class structure:") + pprint.pprint(self.reads_and_class) + titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] - for i, (title, reads_dict) in enumerate(zip(titles, self.reads_and_class)): - labels = reads_dict.keys() - sizes = reads_dict.values() - total = sum(sizes) - - # Generate a file-friendly title - file_title = title.lower().replace(" ", "_") - - plt.figure() - wedges, texts, autotexts = plt.pie( - sizes, - labels=labels, - autopct="%1.1f%%", - startangle=140, - textprops=dict(color="w"), - ) - plt.setp(autotexts, size=10, weight="bold") - plt.setp(texts, size=9) + for title, data in zip(titles, self.reads_and_class): + if isinstance(data, dict): + if any(isinstance(v, dict) for v in data.values()): + # Separate 'Mutants' and 'WildType' case + for sample_name, sample_data in data.items(): + self._create_pie_chart(f"{title} - {sample_name}", sample_data) + else: + # Combined data case + self._create_pie_chart(title, data) + else: + print(f"Skipping unexpected data type for {title}: {type(data)}") - plt.axis( - "equal" - ) # Equal aspect ratio ensures that pie is drawn as a circle. - plt.title(f"{title}\nTotal: {total}") + def _create_pie_chart(self, title, data): + """ + Helper method to create a single pie chart. + """ + labels = list(data.keys()) + sizes = list(data.values()) + total = sum(sizes) - plt.legend( - wedges, - labels, - title="Categories", - loc="center left", - bbox_to_anchor=(1, 0, 0.5, 1), - ) - plot_path = os.path.join( - self.visualization_dir, f"{file_title}_pie_chart.png" - ) - plt.savefig(plot_path, bbox_inches="tight") - plt.close() + # Generate a file-friendly title + file_title = title.lower().replace(" ", "_").replace("-", "_") + + plt.figure(figsize=(12, 8)) + wedges, texts, autotexts = plt.pie( + sizes, + labels=labels, + autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100.*total):d})", + startangle=140, + textprops=dict(color="w"), + ) + plt.setp(autotexts, size=8, weight="bold") + plt.setp(texts, size=7) + + plt.axis("equal") # Equal aspect ratio ensures that pie is drawn as a circle. 
+ plt.title(f"{title}\nTotal: {total}") + + plt.legend( + wedges, + labels, + title="Categories", + loc="center left", + bbox_to_anchor=(1, 0, 0.5, 1), + fontsize=8, + ) + plot_path = os.path.join(self.visualization_dir, f"{file_title}_pie_chart.png") + plt.savefig(plot_path, bbox_inches="tight", dpi=300) + plt.close() diff --git a/src/post_process.py b/src/post_process.py index f801d6a8..94a1102f 100644 --- a/src/post_process.py +++ b/src/post_process.py @@ -8,6 +8,7 @@ from argparse import Namespace import tempfile import gffutils +import yaml class OutputConfig: @@ -20,6 +21,8 @@ def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): self.read_assignments = None self.input_gtf = gtf # Initialize with the provided gtf flag self.genedb_filename = None + self.yaml_input = True + self.yaml_input_path = None self.gtf_flag_needed = False # Initialize flag to check if "--gtf" is needed. self.conditions = False self.gene_grouped_counts = None @@ -37,7 +40,7 @@ def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): self.use_counts = use_counts self.ref_only = ref_only - self._load_params_file() # Load the params file instead of parsing the log + self._load_params_file() self._find_files() self._conditional_unzip() @@ -68,13 +71,23 @@ def _process_params(self, params): self.input_gtf = self.input_gtf or params.get("genedb") self.genedb_filename = params.get("genedb_filename") - processing_sample = params.get("prefix") - if processing_sample: - self.output_directory = os.path.join( - self.output_directory, processing_sample - ) + if params.get("yaml"): + # YAML input case + self.yaml_input = True + self.yaml_input_path = params.get("yaml") + # Keep the output_directory as is, don't modify it else: - raise ValueError("Processing sample directory not found in params.") + # Non-YAML input case + self.yaml_input = False + processing_sample = params.get("prefix") + if processing_sample: + self.output_directory = os.path.join( + self.output_directory, processing_sample + ) + else: + raise ValueError( + "Processing sample directory not found in params for non-YAML input." 
+ ) def _conditional_unzip(self): """Check if unzip is needed and perform it conditionally based on the model use.""" @@ -106,9 +119,17 @@ def _unzip_file(self, file_path): def _find_files(self): """Locate the necessary files in the directory and determine the need for the "--gtf" flag.""" + if self.yaml_input: + self.conditions = True + self.ref_only = True + self._find_files_from_yaml() + return # Exit the method after processing YAML input + if not os.path.exists(self.output_directory): print(f"Directory not found: {self.output_directory}") # Debugging output - raise FileNotFoundError("Specified sample subdirectory does not exist.") + raise FileNotFoundError( + f"Specified sample subdirectory does not exist: {self.output_directory}" + ) for file_name in os.listdir(self.output_directory): if file_name.endswith(".extended_annotation.gtf"): @@ -174,6 +195,80 @@ def _find_files(self): if self.ref_only is None: self.ref_only = not self.extended_annotation + def _find_files_from_yaml(self): + """Locate the necessary files in the directory, set specific grouped count and TPM files, and process read assignments.""" + if not os.path.exists(self.yaml_input_path): + print(f"YAML file not found: {self.yaml_input_path}") + raise FileNotFoundError( + f"Specified YAML file does not exist: {self.yaml_input_path}" + ) + + # Set the four specific attributes + self.gene_grouped_counts = os.path.join( + self.output_directory, "combined_gene_counts.tsv" + ) + self.transcript_grouped_counts = os.path.join( + self.output_directory, "combined_transcript_counts.tsv" + ) + self.transcript_grouped_tpm = os.path.join( + self.output_directory, "combined_transcript_tpm.tsv" + ) + self.gene_grouped_tpm = os.path.join( + self.output_directory, "combined_gene_tpm.tsv" + ) + + # Check if the files exist + for attr in [ + "gene_grouped_counts", + "transcript_grouped_counts", + "transcript_grouped_tpm", + "gene_grouped_tpm", + ]: + file_path = getattr(self, attr) + if not os.path.exists(file_path): + print(f"Warning: {attr} file not found at {file_path}") + setattr(self, attr, None) + + # Initialize read_assignments list + self.read_assignments = [] + + # Read and process the YAML file + with open(self.yaml_input_path, "r") as yaml_file: + yaml_data = yaml.safe_load(yaml_file) + + # Check if yaml_data is a list + if isinstance(yaml_data, list): + samples = yaml_data + else: + # If it's not a list, assume it's a dictionary with a 'samples' key + samples = yaml_data.get("samples", []) + + for sample in samples: + name = sample.get("name") + if name: + sample_dir = os.path.join(self.output_directory, name) + + # Check for .read_assignments.tsv.gz + gz_file = os.path.join(sample_dir, f"{name}.read_assignments.tsv.gz") + if os.path.exists(gz_file): + unzipped_file = self._unzip_file(gz_file) + if unzipped_file: + self.read_assignments.append((name, unzipped_file)) + else: + print(f"Warning: Failed to unzip {gz_file}") + else: + # Check for .read_assignments.tsv + non_gz_file = os.path.join( + sample_dir, f"{name}.read_assignments.tsv" + ) + if os.path.exists(non_gz_file): + self.read_assignments.append((name, non_gz_file)) + else: + print(f"Warning: No read assignments file found for {name}") + + if not self.read_assignments: + print("Warning: No read assignment files found for any samples") + class DictionaryBuilder: """Class to build dictionaries from the output files of the pipeline.""" @@ -189,25 +284,42 @@ def build_gene_transcript_exon_dictionaries(self): return self.parse_input_gtf() def 
build_read_assignment_and_classification_dictionaries(self): - """Indexes classifications and assignment types from the read_assignments.tsv.""" + """Indexes classifications and assignment types from read_assignments.tsv file(s).""" + if not self.config.read_assignments: + raise FileNotFoundError("No read assignments file(s) found.") + + if isinstance(self.config.read_assignments, list): + # YAML input case (multiple files) + classification_counts_dict = {} + assignment_type_counts_dict = {} + for sample_name, read_assignment_file in self.config.read_assignments: + classification_counts, assignment_type_counts = ( + self._process_read_assignment_file(read_assignment_file) + ) + classification_counts_dict[sample_name] = classification_counts + assignment_type_counts_dict[sample_name] = assignment_type_counts + return classification_counts_dict, assignment_type_counts_dict + else: + # Non-YAML input case (single file) + return self._process_read_assignment_file(self.config.read_assignments) + + def _process_read_assignment_file(self, file_path): classification_counts = {} assignment_type_counts = {} - if not self.config.read_assignments: - raise FileNotFoundError("Read assignments file is missing.") - with open(self.config.read_assignments, "r") as file: - next(file) - next(file) - next(file) + with open(file_path, "r") as file: + # Skip header lines + for _ in range(3): + next(file, None) + for line in file: - parts = line.split("\t") + parts = line.strip().split("\t") if len(parts) < 6: continue + additional_info = parts[-1] classification = ( - additional_info.split("Classification=")[-1] - .replace(";", "") - .strip() + additional_info.split("Classification=")[-1].split(";")[0].strip() ) assignment_type = parts[5] @@ -224,7 +336,7 @@ def parse_input_gtf(self): """Parses the GTF file using gffutils to build a detailed dictionary of genes, transcripts, and exons.""" gene_dict = {} if not self.config.genedb_filename: - # convert GFT to DB if we use previous IsoQuant runs + # convert GTF to DB if we use previous IsoQuant runs # remove this functionality later tmp_file = tempfile.NamedTemporaryFile(suffix=".db") self.config.genedb_filename = tmp_file.name @@ -239,50 +351,46 @@ def parse_input_gtf(self): disable_infer_genes=True, disable_infer_transcripts=True, ) - # raise FileNotFoundError("IsoQuant annotation DB file is missing.") try: - # Create a temporary database - with gffutils.FeatureDB(self.config.genedb_filename) as db: - for gene in db.features_of_type("gene"): - gene_id = gene.id - gene_dict[gene_id] = { - "chromosome": gene.seqid, - "start": gene.start, - "end": gene.end, - "strand": gene.strand, - "name": gene.attributes.get("gene_name", [""])[0], - "biotype": gene.attributes.get("gene_biotype", [""])[0], - "transcripts": {}, + # Create a database without using a context manager + db = gffutils.FeatureDB(self.config.genedb_filename) + + for gene in db.features_of_type("gene"): + gene_id = gene.id + gene_dict[gene_id] = { + "chromosome": gene.seqid, + "start": gene.start, + "end": gene.end, + "strand": gene.strand, + "name": gene.attributes.get("gene_name", [""])[0], + "biotype": gene.attributes.get("gene_biotype", [""])[0], + "transcripts": {}, + } + + for transcript in db.children(gene, featuretype="transcript"): + transcript_id = transcript.id + gene_dict[gene_id]["transcripts"][transcript_id] = { + "start": transcript.start, + "end": transcript.end, + "name": transcript.attributes.get("transcript_name", [""])[0], + "biotype": transcript.attributes.get( + "transcript_biotype", 
[""] + )[0], + "exons": [], + "tags": transcript.attributes.get("tag", [""])[0].split(","), } - for transcript in db.children(gene, featuretype="transcript"): - transcript_id = transcript.id - gene_dict[gene_id]["transcripts"][transcript_id] = { - "start": transcript.start, - "end": transcript.end, - "name": transcript.attributes.get("transcript_name", [""])[ - 0 - ], - "biotype": transcript.attributes.get( - "transcript_biotype", [""] - )[0], - "exons": [], - "tags": transcript.attributes.get("tag", [""])[0].split( - "," - ), + for exon in db.children(transcript, featuretype="exon"): + exon_info = { + "exon_id": exon.id, + "start": exon.start, + "end": exon.end, + "number": exon.attributes.get("exon_number", [""])[0], } - - for exon in db.children(transcript, featuretype="exon"): - exon_info = { - "exon_id": exon.id, - "start": exon.start, - "end": exon.end, - "number": exon.attributes.get("exon_number", [""])[0], - } - gene_dict[gene_id]["transcripts"][transcript_id][ - "exons" - ].append(exon_info) + gene_dict[gene_id]["transcripts"][transcript_id][ + "exons" + ].append(exon_info) except Exception as e: raise Exception(f"Error parsing GTF file: {str(e)}") From c49ebff86a4a86a1741490c401d3078e4fe3a638 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Wed, 14 Aug 2024 16:49:58 -0500 Subject: [PATCH 13/35] restructured output and fized combined viz --- src/gene_model.py | 151 +++++++++++++++++++++++++++------------------ src/plot_output.py | 31 +++------- visualize.py | 67 ++++++++++++++++++-- 3 files changed, 164 insertions(+), 85 deletions(-) diff --git a/src/gene_model.py b/src/gene_model.py index 0a7e893b..a25008ce 100644 --- a/src/gene_model.py +++ b/src/gene_model.py @@ -1,7 +1,5 @@ -import json import os import pandas as pd -import numpy as np import matplotlib.pyplot as plt import seaborn as sns from scipy.spatial.distance import euclidean @@ -25,23 +23,23 @@ def parse_data(data): return genes -def calculate_deviance(wt_transcripts, condition_transcripts): - all_transcripts = set(wt_transcripts.keys()).union( +def calculate_deviance(reference_transcripts, condition_transcripts): + all_transcripts = set(reference_transcripts.keys()).union( set(condition_transcripts.keys()) ) - wt_proportions = [wt_transcripts.get(t, 0) for t in all_transcripts] + reference_proportions = [reference_transcripts.get(t, 0) for t in all_transcripts] condition_proportions = [condition_transcripts.get(t, 0) for t in all_transcripts] - total_wt = sum(wt_proportions) + total_reference = sum(reference_proportions) total_condition = sum(condition_proportions) - if total_wt > 0: - wt_proportions = [p / total_wt for p in wt_proportions] + if total_reference > 0: + reference_proportions = [p / total_reference for p in reference_proportions] if total_condition > 0: condition_proportions = [p / total_condition for p in condition_proportions] - distance = euclidean(wt_proportions, condition_proportions) + distance = euclidean(reference_proportions, condition_proportions) # Reduce distance if total unique transcripts are 1 if len(all_transcripts) == 1: @@ -50,10 +48,10 @@ def calculate_deviance(wt_transcripts, condition_transcripts): return distance -def calculate_metrics(genes): +def calculate_metrics(genes, reference_condition): metrics = [] for gene, gene_data in genes.items(): - wt_transcripts = gene_data["transcripts"].get("wild_type", {}) + reference_transcripts = gene_data["transcripts"].get(reference_condition, {}) for condition in gene_data: if condition in [ @@ -63,16 +61,16 @@ def 
calculate_metrics(genes): "strand", "biotype", "transcripts", - "wild_type", + reference_condition, ]: continue condition_transcripts = gene_data["transcripts"].get(condition, {}) - deviance = calculate_deviance(wt_transcripts, condition_transcripts) + deviance = calculate_deviance(reference_transcripts, condition_transcripts) metrics.append({"gene": gene, "condition": condition, "deviance": deviance}) value = gene_data.get(condition, 0) - wt_value = gene_data.get("wild_type", 0) - abs_diff = abs(value - wt_value) + reference_value = gene_data.get(reference_condition, 0) + abs_diff = abs(value - reference_value) metrics.append( { "gene": gene, @@ -96,15 +94,6 @@ def check_known_target(gene, known_targets): def rank_genes(df, known_genes_path=None): - if known_genes_path: - target_genes_df = pd.read_csv(known_genes_path, header=None, names=["gene"]) - known_targets = target_genes_df["gene"].tolist() - df["known_target"] = df["gene"].apply( - lambda x: check_known_target(x, known_targets) - ) - else: - df["known_target"] = 0 - value_ranking = df.groupby("gene")["value"].mean().reset_index() abs_diff_ranking = df.groupby("gene")["abs_diff"].mean().reset_index() deviance_ranking = df.groupby("gene")["deviance"].mean().reset_index() @@ -121,7 +110,15 @@ def rank_genes(df, known_genes_path=None): abs_diff_ranking[["gene", "rank_abs_diff"]], on="gene" ) merged_df = merged_df.merge(deviance_ranking[["gene", "rank_deviance"]], on="gene") - merged_df = merged_df.merge(df[["gene", "known_target"]], on="gene") + + if known_genes_path: + target_genes_df = pd.read_csv(known_genes_path, header=None, names=["gene"]) + known_targets = target_genes_df["gene"].tolist() + df["known_target"] = df["gene"].apply( + lambda x: check_known_target(x, known_targets) + ) + merged_df = merged_df.merge(df[["gene", "known_target"]], on="gene") + # Devalue the importance of overall expression by reducing its weight merged_df["combined_rank"] = ( merged_df["rank_value"] # Reduced weight for rank_value @@ -142,24 +139,40 @@ def rank_genes(df, known_genes_path=None): def visualize_ranking( - top_combined_ranking, top_deviance_ranking, merged_df, output_dir + top_combined_ranking, + top_deviance_ranking, + merged_df, + output_dir, + reference_condition, ): - if not os.path.exists(output_dir): - os.makedirs(output_dir) + find_genes_dir = os.path.join(output_dir, "find_genes") + if not os.path.exists(find_genes_dir): + os.makedirs(find_genes_dir) + + # Ensure we're using only the top 10 genes + top_10_combined = top_combined_ranking.head(10) + top_10_deviance = top_deviance_ranking.head(10) # Bar plot for combined rank plt.figure(figsize=(12, 8)) sns.barplot( - x="combined_rank", y="gene", data=top_combined_ranking, palette="viridis" + x="combined_rank", + y="gene", + data=top_10_combined, + hue="gene", + palette="viridis", + legend=False, ) plt.title("Top 10 Genes by Combined Ranking") plt.xlabel("Combined Rank") plt.ylabel("Gene") - plt.savefig(os.path.join(output_dir, "top_genes_combined_ranking.png"), dpi=300) + plt.savefig( + os.path.join(find_genes_dir, "top_10_genes_combined_ranking.png"), dpi=300 + ) plt.close() # Heatmap for metric ranks - top_genes = top_combined_ranking["gene"].tolist() + top_genes = top_10_combined["gene"].tolist() heatmap_data = merged_df[merged_df["gene"].isin(top_genes)] heatmap_data = heatmap_data.set_index("gene")[ ["rank_value", "rank_abs_diff", "rank_deviance"] @@ -174,7 +187,9 @@ def visualize_ranking( cbar_kws={"label": "Rank"}, ) plt.title("Metric Ranks for Top 10 Genes") - 
plt.savefig(os.path.join(output_dir, "metric_ranks_heatmap.png"), dpi=300) + plt.savefig( + os.path.join(find_genes_dir, "top_10_metric_ranks_heatmap.png"), dpi=300 + ) plt.close() # Diverging bar plot for deviance @@ -182,15 +197,19 @@ def visualize_ranking( sns.barplot( x="rank_deviance", y="gene", - data=top_deviance_ranking, + data=top_10_deviance, + hue="gene", palette="coolwarm", orient="h", + legend=False, ) - plt.title("Top 10 Genes by Transcript Deviance from Wild Type") - plt.xlabel("Rank of Deviance from Wild Type") + plt.title(f"Top 10 Genes by Transcript Deviance from {reference_condition}") + plt.xlabel(f"Rank of Deviance from {reference_condition}") plt.ylabel("Gene") plt.axvline(x=0, color="grey", linestyle="--") - plt.savefig(os.path.join(output_dir, "deviance_from_wild_type.png"), dpi=300) + plt.savefig( + os.path.join(find_genes_dir, "top_10_deviance_from_reference.png"), dpi=300 + ) plt.close() # Scatter plot for rank_value vs rank_abs_diff @@ -199,14 +218,16 @@ def visualize_ranking( x="rank_value", y="rank_abs_diff", hue="gene", - data=top_deviance_ranking, + data=top_10_combined, palette="deep", s=100, ) - plt.title("Rank Value vs Rank Absolute Difference") + plt.title("Rank Value vs Rank Absolute Difference (Top 10 Genes)") plt.xlabel("Rank Value") plt.ylabel("Rank Absolute Difference") - plt.savefig(os.path.join(output_dir, "rank_value_vs_rank_abs_diff.png"), dpi=300) + plt.savefig( + os.path.join(find_genes_dir, "top_10_rank_value_vs_rank_abs_diff.png"), dpi=300 + ) plt.close() # Combined multi-metric visualization @@ -214,14 +235,15 @@ def visualize_ranking( sns.barplot( x="combined_rank", y="gene", - data=top_combined_ranking, + data=top_10_combined, + hue="gene", palette="viridis", ax=axes[0, 0], + legend=False, ) - axes[0, 0].set_title("Combined Rank") + axes[0, 0].set_title("Combined Rank (Top 10 Genes)") axes[0, 0].set_xlabel("Combined Rank") axes[0, 0].set_ylabel("Gene") - sns.heatmap( heatmap_data, annot=True, @@ -230,18 +252,22 @@ def visualize_ranking( cbar_kws={"label": "Rank"}, ax=axes[0, 1], ) - axes[0, 1].set_title("Metric Ranks") + axes[0, 1].set_title("Metric Ranks (Top 10 Genes)") sns.barplot( x="rank_deviance", y="gene", - data=top_deviance_ranking, + data=top_10_deviance, + hue="gene", palette="coolwarm", orient="h", ax=axes[1, 0], + legend=False, + ) + axes[1, 0].set_title( + f"Transcript Deviance from {reference_condition} (Top 10 Genes)" ) - axes[1, 0].set_title("Transcript Deviance from Wild Type") - axes[1, 0].set_xlabel("Rank of Deviance from Wild Type") + axes[1, 0].set_xlabel(f"Rank of Deviance from {reference_condition}") axes[1, 0].set_ylabel("Gene") axes[1, 0].axvline(x=0, color="grey", linestyle="--") @@ -249,17 +275,19 @@ def visualize_ranking( x="rank_value", y="rank_abs_diff", hue="gene", - data=top_deviance_ranking, + data=top_10_combined, palette="deep", s=100, ax=axes[1, 1], ) - axes[1, 1].set_title("Rank Value vs Rank Absolute Difference") + axes[1, 1].set_title("Rank Value vs Rank Absolute Difference (Top 10 Genes)") axes[1, 1].set_xlabel("Rank Value") axes[1, 1].set_ylabel("Rank Absolute Difference") plt.tight_layout() - plt.savefig(os.path.join(output_dir, "combined_visualization.png"), dpi=300) + plt.savefig( + os.path.join(find_genes_dir, "top_10_combined_visualization.png"), dpi=300 + ) plt.close() @@ -274,10 +302,14 @@ def save_top_genes(top_combined_ranking, output_dir, num_genes): def rank_and_visualize_genes( - input_data, output_dir, num_genes=100, known_genes_path=None + input_data, + output_dir, + num_genes=100, + 
known_genes_path=None, + reference_condition=None, ): genes = parse_data(input_data) - metrics_df = calculate_metrics(genes) + metrics_df = calculate_metrics(genes, reference_condition) top_combined_ranking, top_deviance_ranking, top_100_combined_ranking, merged_df = ( rank_genes(metrics_df, known_genes_path) ) @@ -285,14 +317,15 @@ def rank_and_visualize_genes( top_combined_ranking = merged_df.sort_values(by="combined_rank").head(num_genes) top_deviance_ranking = merged_df.sort_values(by="rank_deviance").head(num_genes) - visualize_ranking(top_combined_ranking, top_deviance_ranking, merged_df, output_dir) + visualize_ranking( + top_combined_ranking, + top_deviance_ranking, + merged_df, + output_dir, + reference_condition, + ) path = save_top_genes(top_combined_ranking, output_dir, num_genes) - - print(f"\nTop {num_genes} Genes by Combined Ranking:") - print(top_combined_ranking[["gene", "combined_rank"]]) - print(f"\nDetailed Metrics for Top {num_genes} Genes by Combined Ranking:") - print(top_combined_ranking) - - merged_df.to_csv(os.path.join(output_dir, "gene_metrics.csv"), index=False) + find_genes_dir = os.path.join(output_dir, "find_genes") + merged_df.to_csv(os.path.join(find_genes_dir, "gene_metrics.csv"), index=False) return path diff --git a/src/plot_output.py b/src/plot_output.py index 9dd754e0..b2550252 100644 --- a/src/plot_output.py +++ b/src/plot_output.py @@ -2,7 +2,6 @@ import matplotlib.pyplot as plt import matplotlib.ticker as ticker import numpy as np -import pprint class PlotOutput: @@ -11,7 +10,7 @@ def __init__( updated_gene_dict, gene_names, output_directory, - create_visualization_subdir=False, + read_assignments_dir, reads_and_class=None, filter_transcripts=None, conditions=False, @@ -19,20 +18,16 @@ def __init__( ): self.updated_gene_dict = updated_gene_dict self.gene_names = gene_names - self.output_directory = output_directory + self.visualization_dir = output_directory + self.read_assignments_dir = read_assignments_dir self.reads_and_class = reads_and_class self.filter_transcripts = filter_transcripts self.conditions = conditions self.use_counts = use_counts - # Create visualization subdirectory if specified - if create_visualization_subdir: - self.visualization_dir = os.path.join( - self.output_directory, "visualization" - ) - os.makedirs(self.visualization_dir, exist_ok=True) - else: - self.visualization_dir = self.output_directory + # Ensure the visualization directory exists + os.makedirs(self.visualization_dir, exist_ok=True) + os.makedirs(self.read_assignments_dir, exist_ok=True) def plot_transcript_map(self): # Get the first condition's gene dictionary @@ -148,14 +143,6 @@ def plot_transcript_usage(self): index = np.arange(n_bars) bar_width = 0.35 opacity = 0.8 - - # for sample_type, transcripts in gene_data.items(): - # print(f"Sample Type: {sample_type}") - # for transcript_id, transcript_info in transcripts.items(): - # print( - # f" Transcript ID: {transcript_id}, Value: {transcript_info['value']}" - # ) - # Adjusting the colors for better within-bar comparison max_transcripts = max(len(gene_data[condition]) for condition in conditions) colors = plt.cm.plasma( np.linspace(0, 1, num=max_transcripts) @@ -201,8 +188,6 @@ def make_pie_charts(self): Create pie charts for transcript alignment classifications and read assignment consistency. Handles both combined and separate sample data structures. 
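Each entry in self.reads_and_class is either a flat {category: count} mapping (combined data) or a nested {sample_name: {category: count}} mapping (per-sample data); one pie chart is produced per mapping.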
""" - print("self.reads_and_class structure:") - pprint.pprint(self.reads_and_class) titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] @@ -251,6 +236,8 @@ def _create_pie_chart(self, title, data): bbox_to_anchor=(1, 0, 0.5, 1), fontsize=8, ) - plot_path = os.path.join(self.visualization_dir, f"{file_title}_pie_chart.png") + plot_path = os.path.join( + self.read_assignments_dir, f"{file_title}_pie_chart.png" + ) plt.savefig(plot_path, bbox_inches="tight", dpi=300) plt.close() diff --git a/visualize.py b/visualize.py index 799c2c37..35f6175c 100755 --- a/visualize.py +++ b/visualize.py @@ -5,6 +5,7 @@ import argparse from src.process_dict import simplify_and_sum_transcripts from src.gene_model import rank_and_visualize_genes +import os class FindGenesAction(argparse.Action): @@ -64,10 +65,54 @@ def parse_arguments(): help="Path to a CSV file containing known target genes.", default=None, ) - return parser.parse_args() + + args = parser.parse_args() + + # If --find_genes is used, prompt for reference condition + if args.find_genes: + output = OutputConfig( + args.output_directory, + use_counts=args.counts, + ref_only=args.ref_only, + gtf=args.gtf, + ) + + # Read the first line of the transcript_grouped_tpm file to get the conditions + with open(output.transcript_grouped_tpm, "r") as f: + header = f.readline().strip().split("\t") + + # The first column is typically '#feature_id', so we skip it + conditions = header[1:] + + if len(conditions) == 2: + # If there are only two conditions, automatically use the first as reference + args.reference_condition = conditions[0] + print( + f"Automatically selected '{args.reference_condition}' as the reference condition." + ) + else: + print("Available conditions:") + for i, condition in enumerate(conditions, 1): + print(f"{i}. {condition}") + + while True: + try: + choice = int( + input("Enter the number of the condition to use as reference: ") + ) + if 1 <= choice <= len(conditions): + args.reference_condition = conditions[choice - 1] + break + else: + print("Invalid choice. Please enter a number from the list.") + except ValueError: + print("Invalid input. 
Please enter a number.") + + return args def main(): + print("Reading IsoQuant parameters.") args = parse_arguments() output = OutputConfig( args.output_directory, @@ -78,7 +123,9 @@ def main(): dictionary_builder = DictionaryBuilder(output) gene_list = dictionary_builder.read_gene_list(args.gene_list) update_names = not all(gene.startswith("ENS") for gene in gene_list) + print("Building gene, transcript, and exon dictionaries.") gene_dict = dictionary_builder.build_gene_transcript_exon_dictionaries() + print("Building read assignment and classification dictionaries.") reads_and_class = ( dictionary_builder.build_read_assignment_and_classification_dictionaries() ) @@ -151,7 +198,11 @@ def main(): ) # Visualization output directory decision - viz_output_directory = args.viz_output if args.viz_output else args.output_directory + if args.viz_output: + viz_output_directory = args.viz_output + else: + viz_output_directory = os.path.join(args.output_directory, "visualization") + os.makedirs(viz_output_directory, exist_ok=True) if args.find_genes: print("Finding genes.") @@ -161,15 +212,23 @@ def main(): viz_output_directory, args.find_genes, known_genes_path=args.known_genes_path, + reference_condition=args.reference_condition, ) gene_list = dictionary_builder.read_gene_list(path) - # dictionary_builder.save_gene_dict_to_json(updated_gene_dict, viz_output_directory) + # Create gene_visualizations subdirectory + viz_output_directory = os.path.join(viz_output_directory, "gene_visualizations") + os.makedirs(viz_output_directory, exist_ok=True) + + # Create read_assignments subdirectory + read_assignments_dir = os.path.join(viz_output_directory, "read_assignments") + os.makedirs(read_assignments_dir, exist_ok=True) + plot_output = PlotOutput( updated_gene_dict, gene_list, viz_output_directory, - create_visualization_subdir=(viz_output_directory == args.output_directory), + read_assignments_dir=read_assignments_dir, reads_and_class=reads_and_class, filter_transcripts=args.filter_transcripts, conditions=output.conditions, From 11bdbd92bcb23ee94102f61b26ef2a22fa950893 Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 20 Sep 2024 17:44:26 +0300 Subject: [PATCH 14/35] GFF3 checker --- src/gtf2db.py | 75 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 7 deletions(-) diff --git a/src/gtf2db.py b/src/gtf2db.py index ef6fc4fa..ee82a43d 100755 --- a/src/gtf2db.py +++ b/src/gtf2db.py @@ -59,6 +59,8 @@ def get_color(transcript_kind): gene_name = record["gene_name"][0] elif "gene_id" in record.attributes: gene_name = record["gene_id"][0] + elif "Parent" in record.attributes: + gene_name = record["Parent"][0] else: gene_name = "unknown_gene" transcript_name = record.id + "|" + transcript_type + "|" + gene_name @@ -97,12 +99,13 @@ def check_input_gtf(gtf, db, complete_db): gtf_is_correct, corrected_gtf, out_fname, has_meta_features = check_gtf_duplicates(gtf) if not gtf_is_correct: outdir = os.path.dirname(db) - new_gtf_path = os.path.join(outdir, out_fname) - with open(new_gtf_path, "w") as out_gtf: - out_gtf.write(corrected_gtf) logger.error("Input GTF seems to be corrupted (see warnings above).") - logger.error("An attempt to correct this GTF was made, the result is written to %s" % new_gtf_path) - logger.error("NB! 
some transcript / gene ids in the corrected annotation are modified.") logger.error("Provide a correct GTF by fixing the original input GTF or checking the corrected one.") exit(-3) else: @@ -167,6 +170,10 @@ check_gtf_duplicates(gtf) handle = open(gtf, "rt") inner_ext = outer_ext + if inner_ext.lower() == 'gff3': + return check_gff3_duplicates(handle) + + gff3_checked = False for l in handle.readlines(): line_count += 1 if l.startswith("#"): @@ -177,8 +184,14 @@ corrected_gtf += l continue - feature_type = v[2] - attrs = v[8].split(" ") + attribute_column = v[8] + if not gff3_checked: + gff3_checked = True + if attribute_column.find("ID=") != -1: + handle.seek(0) + return check_gff3_duplicates(handle) + + attrs = attribute_column.split(" ") gene_id_pos = -1 for i in range(len(attrs)): @@ -190,6 +203,7 @@ gtf_correct = False continue + feature_type = v[2] gene_str = attrs[gene_id_pos + 1] start_pos = gene_str.find('"') end_pos = gene_str.rfind('"') @@ -259,6 +273,53 @@ return gtf_correct, corrected_gtf, gtf_name + ".corrected" + inner_ext.lower(), complete_genedb +def check_gff3_duplicates(handle): + gtf_correct = True + gene_count = 0 + transcript_count = 0 + line_count = 0 + feature_ids = {} + + for l in handle.readlines(): + line_count += 1 + if l.startswith("#"): + continue + v = l.strip().split("\t") + if len(v) < 9: + continue + + feature_type = v[2] + if feature_type == 'gene': + gene_count += 1 + elif feature_type in ["transcript", "mRNA"]: + transcript_count += 1 + + attrs = v[8].split(";") + id_pos = -1 + for i in range(len(attrs)): + if attrs[i].startswith('ID'): + id_pos = i + if id_pos == -1: + if feature_type in ["gene", "transcript", "mRNA"]: + logger.warning("Malformed GFF3 line %d (ID attribute value cannot be found)" % line_count) + logger.warning(l.strip()) + gtf_correct = False + continue + + id_str = attrs[id_pos] + id_value = id_str.split("=")[1] + if id_value in feature_ids: + logger.warning("Duplicated ID %s on line %d" % (id_value, line_count)) + gtf_correct = False + feature_ids[id_value] = feature_ids.get(id_value, 0) + 1 + + complete_genedb = 1 + if transcript_count == 0 or gene_count == 0: + complete_genedb = -1 + + return gtf_correct, None, None, complete_genedb + + def find_converted_db(converted_gtfs, gtf_filename, complete_genedb): gtf_mtime = converted_gtfs.get(gtf_filename, {}).get('gtf_mtime') db_mtime = converted_gtfs.get(gtf_filename, {}).get('db_mtime') From a7e1593c34a9f51e99f751fa8157020a1b406c51 Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 20 Sep 2024 18:14:42 +0300 Subject: [PATCH 15/35] import exon information from reference annotation --- src/gene_info.py | 29 ++++++++++++++++++----------- src/transcript_printer.py | 15 ++++++++++----- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/gene_info.py b/src/gene_info.py index 469bceae..a3f1e067 100644 --- a/src/gene_info.py +++ b/src/gene_info.py @@ -183,7 +183,7 @@ def __init__(self, gene_db_list, db, delta=0, prepare_profiles=True): self.set_sources() self.gene_id_map = {} self.set_gene_ids() - self.gene_attributes = {} + self.feature_attributes = {} 
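+ # feature_attributes maps gene / transcript / exon ids to their extra annotation attributes (filled in set_gene_attributes)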
self.set_gene_attributes() if prepare_profiles: self.exon_property_map = self.set_feature_properties(self.all_isoforms_exons, self.exon_profiles) @@ -208,7 +208,7 @@ def from_models(cls, transcript_model_storage, delta=0): gene_info.sources = {} gene_info.other_features = {} gene_info.gene_id_map = {} - gene_info.gene_attributes = {} + gene_info.feature_attributes = {} introns = set() exons = set() @@ -294,7 +294,7 @@ def from_model(cls, transcript_model, delta=0): transcript_model.gene_id: transcript_model.source} gene_info.other_features = {transcript_model.transcript_id: transcript_model.other_features} gene_info.gene_id_map = {transcript_model.transcript_id: transcript_model.gene_id} - gene_info.gene_attributes = {} + gene_info.feature_attributes = {} gene_info.regions_for_bam_fetch = [(gene_info.start, gene_info.end)] gene_info.exon_property_map = None @@ -332,7 +332,7 @@ def from_region(cls, chr_id, start, end, delta=0, chr_record=None): gene_info.other_features = {} gene_info.sources = {} gene_info.gene_id_map = {} - gene_info.gene_attributes = {} + gene_info.feature_attributes = {} gene_info.regions_for_bam_fetch = [(start, end)] gene_info.exon_property_map = None gene_info.intron_property_map = None @@ -390,7 +390,7 @@ def deserialize(cls, infile, genedb): gene_info.set_sources() gene_info.gene_id_map = {} gene_info.set_gene_ids() - gene_info.gene_attributes = {} + gene_info.feature_attributes = {} gene_info.set_gene_attributes() gene_info.exon_property_map = gene_info.set_feature_properties(gene_info.all_isoforms_exons, gene_info.exon_profiles) gene_info.intron_property_map = gene_info.set_feature_properties(gene_info.all_isoforms_introns, gene_info.intron_profiles) @@ -475,19 +475,26 @@ def set_gene_ids(self): self.gene_id_map[t.id] = gene_db.id def set_gene_attributes(self): - self.gene_attributes = defaultdict(str) + self.feature_attributes = defaultdict(str) for gene_db in self.gene_db_list: for attr in gene_db.attributes.keys(): - if attr in ['gene_id', 'ID', 'level']: + if attr in ['gene_id', 'ID', 'level', 'Parent']: continue if gene_db.attributes[attr]: - self.gene_attributes[gene_db.id] += '%s "%s"; ' % (attr, gene_db.attributes[attr][0]) - for t in self.db.children(gene_db, featuretype=('transcript', 'mRNA'), order_by='start'): + self.feature_attributes[gene_db.id] += '%s "%s"; ' % (attr, gene_db.attributes[attr][0]) + for t in self.db.children(gene_db, featuretype=('transcript', 'mRNA')): for attr in t.attributes.keys(): - if attr in ['transcript_id', 'gene_id', 'ID', 'level', 'exons']: + if attr in ['transcript_id', 'gene_id', 'ID', 'level', 'exons', 'Parent']: continue if t.attributes[attr]: - self.gene_attributes[t.id] += '%s "%s"; ' % (attr, t.attributes[attr][0]) + self.feature_attributes[t.id] += '%s "%s"; ' % (attr, t.attributes[attr][0]) + for e in self.db.children(t, featuretype='exon'): + exon_id = t.id + "_%d_%d_%s" % (e.start, e.end, e.strand) + for attr in e.attributes.keys(): + if attr in ['transcript_id', 'gene_id', 'ID', 'Parent', 'level', 'exon_id', 'exon', 'exon_number']: + continue + if e.attributes[attr]: + self.feature_attributes[exon_id] += '%s "%s"; ' % (attr, e.attributes[attr][0]) # assigns an ordered list of all known exons and introns to self.exons and self.introns # returns 2 maps, isoform id -> intron / exon list diff --git a/src/transcript_printer.py b/src/transcript_printer.py index e84b8a52..3e00aa2e 100644 --- a/src/transcript_printer.py +++ b/src/transcript_printer.py @@ -99,8 +99,8 @@ def dump(self, gene_info, 
transcript_model_storage): for gene_id, coords in gene_order: if gene_id not in self.printed_gene_ids: gene_additiional_info = "" - if gene_info and gene_id in gene_info.gene_attributes: - gene_additiional_info = gene_info.gene_attributes[gene_id] + if gene_info and gene_id in gene_info.feature_attributes: + gene_additiional_info = gene_info.feature_attributes[gene_id] source = "IsoQuant" if gene_info and gene_id in gene_info.sources: source = gene_info.sources[gene_id] @@ -117,8 +117,8 @@ def dump(self, gene_info, transcript_model_storage): if not model.check_additional("exons"): model.add_additional_attribute("exons", str(len(model.exon_blocks))) transcript_additiional_info = "" - if gene_info and model.transcript_id in gene_info.gene_attributes: - transcript_additiional_info = " " + gene_info.gene_attributes[model.transcript_id] + if gene_info and model.transcript_id in gene_info.feature_attributes: + transcript_additiional_info = " " + gene_info.feature_attributes[model.transcript_id] transcript_line = '%s\t%s\ttranscript\t%d\t%d\t.\t%s\t.\tgene_id "%s"; transcript_id "%s"; %s\n' \ % (model.chr_id, model.source, model.exon_blocks[0][0], model.exon_blocks[-1][1], @@ -137,9 +137,14 @@ def dump(self, gene_info, transcript_model_storage): exons_to_print = sorted(exons_to_print, reverse=True) if model.strand == '-' else sorted(exons_to_print) for i, e in enumerate(exons_to_print): exon_str_id = self.exon_id_storage.get_id(model.chr_id, e, model.strand) + + exon_id = model.transcript_id + "_%d_%d_%s" % (e[0], e[1], model.strand) + exon_additiional_info = "" + if gene_info and exon_id in gene_info.feature_attributes: + exon_additiional_info = " " + gene_info.feature_attributes[exon_id] feature_type = e[2] self.out_gff.write(prefix_columns + "%s\t%d\t%d\t" % (feature_type, e[0], e[1]) + suffix_columns + ' exon_number "%d"; exon_id "%s"; %s\n' % ((i + 1), exon_str_id, exon_additiional_info)) self.out_gff.flush() def dump_read_assignments(self, transcript_model_constructor): From 51d39e34d3715da4b82e01de33597ebea300d029 Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Thu, 26 Sep 2024 02:53:56 +0300 Subject: [PATCH 16/35] get rid of side-effect --- src/intron_graph.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/intron_graph.py b/src/intron_graph.py index 7a459b82..64d12789 100644 --- a/src/intron_graph.py +++ b/src/intron_graph.py @@ -209,6 +209,8 @@ def signleton_dead_start(self, v): def get_outgoing(self, intron, v_type=None): res = [] + if intron not in self.outgoing_edges: + return res if v_type is None: for v in self.outgoing_edges[intron]: if v[0] >= 0: @@ -221,6 +223,8 @@ def get_incoming(self, intron, v_type=None): res = [] + if intron not in self.incoming_edges: + return res if v_type is None: for v in self.incoming_edges[intron]: if v[0] >= 0: From b1270212d62ce1c9f5840fe61547b7d79cd21b74 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Tue, 28 Jan 2025 12:11:00 -0600 Subject: [PATCH 17/35] Fixed dictionary building and caching --- src/visualization_cache_utils.py | 158 +++++ src/visualization_dictionary_builder.py | 748 ++++++++++++++++++++++++ 2 files changed, 906 insertions(+) create mode 100644 src/visualization_cache_utils.py create mode 100644 src/visualization_dictionary_builder.py diff --git a/src/visualization_cache_utils.py b/src/visualization_cache_utils.py new file mode 100644 index 00000000..40ae4534 --- /dev/null +++ 
b/src/visualization_cache_utils.py @@ -0,0 +1,158 @@ +import pickle +import logging +import time +from pathlib import Path +from typing import Dict, Any, Optional, Union +import random +import re + + +def build_gene_dict_cache_file( + extended_annotation: Optional[str], input_gtf: str, ref_only: bool, cache_dir: Path +) -> Path: + """ + Generate a gene dictionary cache filename based on: + - Which annotation file we're using (extended vs. reference GTF). + - The modification time of that file. + - The ref_only setting. + """ + if extended_annotation and not ref_only: + source_file = Path(extended_annotation) + source_type = "extended" + else: + source_file = Path(input_gtf) + source_type = "reference" + mtime = source_file.stat().st_mtime + cache_name = f"gene_dict_cache_{source_type}_{source_file.name}_{mtime}_ref_only_{ref_only}.pkl" + return cache_dir / cache_name + + +def build_read_assignment_cache_file( + read_assignments: Union[str, list], ref_only: bool, cache_dir: Path +) -> Path: + """ + Generate a read-assignment cache filename based on: + - The read assignment file(s). + - Possibly their modification times. + - The ref_only setting. + """ + if isinstance(read_assignments, str): + source_file = Path(read_assignments) + mtime = source_file.stat().st_mtime + cache_name = ( + f"read_assignment_cache_{source_file.name}_{mtime}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + elif isinstance(read_assignments, list): + # Build a composite name from the multiple input files + file_info = [] + for sample_name, path_str in read_assignments: + path_obj = Path(path_str) + file_info.append( + f"{sample_name}-{path_obj.name}-{path_obj.stat().st_mtime}" + ) + composite_name = "_".join(file_info).replace(" ", "_")[:100] + cache_name = ( + f"read_assignment_cache_multi_{composite_name}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + else: + return cache_dir / "read_assignment_cache_default.pkl" + + +def save_cache(cache_file: Path, data_to_cache: Any) -> None: + """Save data to a cache file using pickle.""" + try: + with open(cache_file, 'wb') as f: + pickle.dump(data_to_cache, f, protocol=pickle.HIGHEST_PROTOCOL) # Save the entire tuple + logging.debug(f"Successfully saved cache to {cache_file}") + except Exception as e: + logging.error(f"Error saving cache to {cache_file}: {e}") + + +def load_cache(cache_file: Path) -> Any: + """Load data from a cache file.""" + try: + with open(cache_file, 'rb') as f: + cached_data = pickle.load(f) + if isinstance(cached_data, tuple) and len(cached_data) == 3: # Check if it's the new tuple format + gene_dict, novel_gene_ids, novel_transcript_ids = cached_data # Unpack tuple + return gene_dict, novel_gene_ids, novel_transcript_ids # Return the tuple + else: # Handle old cache format (just gene_dict) + return cached_data # Return just the gene_dict for backward compatibility + except FileNotFoundError: + logging.debug(f"Cache file not found: {cache_file}") + return None # Indicate cache miss + except Exception as e: + logging.error(f"Error loading cache from {cache_file}: {e}") + return None + + +def validate_gene_dict(gene_dict: Dict, ref_only: bool = False) -> bool: + """Enhanced validation with novel gene check.""" + if not gene_dict: + return False + + # Always check for novel genes regardless of ref_only mode + novel_genes = sum(1 for condition in gene_dict.values() + for gene_id in condition.keys() + if re.match(r"novel_gene", gene_id)) + if novel_genes > 0: + logging.warning(f"Found {novel_genes} novel genes in cached 
dictionary. Rebuilding required.") + return False + + # Existing structure validation + try: + for condition in gene_dict.values(): + for gene_info in condition.values(): + if not all(k in gene_info for k in ["chromosome", "start", "end", "strand", "transcripts"]): + return False + return True + except (KeyError, AttributeError): + return False + + +def validate_read_assignment_data( + data: Any, read_assignments: Union[str, list] +) -> bool: + """ + Validate the structure of the cached read-assignment data. + """ + try: + if isinstance(read_assignments, list): + # Expecting something like: + # { + # "classification_counts": { "sampleA": {...}, "sampleB": {...} }, + # "assignment_type_counts": { "sampleA": {...}, "sampleB": {...} } + # } + if not isinstance(data, dict): + return False + if ( + "classification_counts" not in data + or "assignment_type_counts" not in data + ): + return False + return True + else: + # Single file scenario: We expect a 2-tuple (classification_counts, assignment_type_counts) + if not isinstance(data, (tuple, list)) or len(data) != 2: + return False + return True + except Exception as e: + logging.error(f"Read-assignment validation error: {e}") + return False + + +def cleanup_cache(cache_dir: Path, max_age_days: int = 7) -> None: + """ + Remove cache files older than specified days. + """ + current_time = time.time() + for cache_file in cache_dir.glob("*.pkl"): + file_age_days = (current_time - cache_file.stat().st_mtime) / (24 * 3600) + if file_age_days > max_age_days: + try: + cache_file.unlink() + logging.info(f"Removed old cache file: {cache_file}") + except Exception as e: + logging.warning(f"Failed to remove cache file {cache_file}: {e}") diff --git a/src/visualization_dictionary_builder.py b/src/visualization_dictionary_builder.py new file mode 100644 index 00000000..85821946 --- /dev/null +++ b/src/visualization_dictionary_builder.py @@ -0,0 +1,748 @@ +import copy +import gffutils +import pandas as pd +import re +import logging +from pathlib import Path +from typing import Dict, Any, List, Union, Tuple + +from src.visualization_cache_utils import ( + build_gene_dict_cache_file, + build_read_assignment_cache_file, + save_cache, + load_cache, + validate_gene_dict, + validate_read_assignment_data, + cleanup_cache, +) + + +class DictionaryBuilder: + def __init__(self, config): + self.config = config + self.cache_dir = Path(config.output_directory) / ".cache" + self.cache_dir.mkdir(exist_ok=True) + + # Set up logger for DictionaryBuilder + self.logger = logging.getLogger('IsoQuant.visualization.dictionary_builder') + self.logger.setLevel(logging.DEBUG) + + # Initialize sets to store novel gene and transcript IDs + self.novel_gene_ids = set() + self.novel_transcript_ids = set() + + # Clean up old cache files on init + cleanup_cache(self.cache_dir, max_age_days=7) + + def build_gene_dict_with_expression_and_filter( + self, min_value: float = 1.0 + ) -> Dict[str, Any]: + """ + Optimized build process with early filtering and combined steps. + """ + self.logger.debug(f"Starting optimized dictionary build with min_value={min_value}") + + # 1. 
Check cache first + expr_file, tpm_file = self._get_expression_files() + base_cache_file = build_gene_dict_cache_file( + self.config.extended_annotation, + expr_file, + self.config.ref_only, + self.cache_dir, + ) + expr_filter_cache = base_cache_file.parent / ( + f"{base_cache_file.stem}_with_expr_minval_{min_value}.pkl" + ) + + if expr_filter_cache.exists(): + cached_data = load_cache(expr_filter_cache) + if cached_data and len(cached_data) == 3: # Check if load_cache returned a tuple + cached_gene_dict, cached_novel_gene_ids, cached_novel_transcript_ids = cached_data # Unpack tuple + if validate_gene_dict(cached_gene_dict): + self.novel_gene_ids = cached_novel_gene_ids # Restore from cache + self.novel_transcript_ids = cached_novel_transcript_ids # Restore from cache + return cached_gene_dict + else: # Handle older cache format (just gene_dict) + cached_gene_dict = cached_data + if validate_gene_dict(cached_gene_dict): + return cached_gene_dict + + # 2. Filter novel genes from the base gene dict (not per-condition) + self.logger.info("Parsing GTF and filtering novel genes") + parsed_data = self.parse_extended_annotation() + self._validate_gene_structure(parsed_data) + base_gene_dict = self._filter_novel_genes(parsed_data) + + # Add debug log: Number of genes and transcripts after novel gene filtering + gene_count_after_novel_filter = len(base_gene_dict) + transcript_count_after_novel_filter = sum(len(gene_info.get("transcripts", {})) for gene_info in base_gene_dict.values()) + self.logger.debug(f"After novel gene filtering: {gene_count_after_novel_filter} genes, {transcript_count_after_novel_filter} transcripts") + + # 3. Load expression data with consistent header handling + self.logger.info("Loading expression matrix") + try: + expr_df = pd.read_csv(expr_file, sep='\t', comment=None) + expr_df.columns = [col.lstrip('#') for col in expr_df.columns] # Clean headers + expr_df = expr_df.set_index('feature_id') # Use cleaned column name + except KeyError as e: + self.logger.error(f"Missing required column in {expr_file}: {str(e)}") + raise + except Exception as e: + self.logger.error(f"Failed to load expression matrix: {str(e)}") + raise + + conditions = expr_df.columns.tolist() + + # 4. Vectorized processing instead of row-wise iteration + transcript_max_values = expr_df.max(axis=1) + valid_transcripts = set(transcript_max_values[transcript_max_values >= min_value].index) + + # Add debug log: Number of valid transcripts after min_value filtering + valid_transcript_count = len(valid_transcripts) + self.logger.debug(f"After min_value ({min_value}) filtering: {valid_transcript_count} valid transcripts") + + # 5. 
Single-pass filtering and value updating + filtered_dict = {} + for condition in conditions: + filtered_dict[condition] = {} + condition_values = expr_df[condition] + + for gene_id, gene_info in base_gene_dict.items(): + new_transcripts = { + tid: {**tinfo, 'value': condition_values.get(tid, 0)} + for tid, tinfo in gene_info['transcripts'].items() + if tid in valid_transcripts + } + + if new_transcripts: + filtered_dict[condition][gene_id] = { + **gene_info, + 'transcripts': new_transcripts + } + self._validate_gene_structure(filtered_dict[condition]) # Validate structure for each condition's gene dict + + save_cache(expr_filter_cache, (filtered_dict, self.novel_gene_ids, self.novel_transcript_ids)) # Save tuple to cache + return filtered_dict + + def _get_expression_files(self) -> Tuple[str, str]: + """Get count file for filtering and TPM file for values.""" + # Get counts file path using existing logic + counts_file = self._get_expression_file() + + # Get corresponding TPM file path + if self.config.conditions: + tpm_file = self.config.transcript_grouped_tpm + else: + if self.config.ref_only: + tpm_file = self.config.transcript_tpm_ref + else: + base_file = self.config.transcript_tpm.replace('.tsv', '') + tpm_file = f"{base_file}_merged.tsv" + + self.logger.debug(f"Selected TPM file: {tpm_file}") + if not tpm_file or not Path(tpm_file).exists(): + raise FileNotFoundError(f"TPM file {tpm_file} not found") + + return counts_file, tpm_file + + def _get_expression_file(self) -> str: + """Get the appropriate count file path from config.""" + if self.config.conditions: # Check if we have multiple conditions + expr_file = self.config.transcript_grouped_counts + else: + if self.config.ref_only: + expr_file = self.config.transcript_counts_ref + else: + base_file = self.config.transcript_counts.replace('.tsv', '') + expr_file = f"{base_file}_merged.tsv" + + self.logger.debug(f"Selected count file: {expr_file}") + if not expr_file or not Path(expr_file).exists(): + raise FileNotFoundError(f"Count file {expr_file} not found") + return expr_file + + def _filter_transcripts_above_threshold( + self, gene_dict: Dict[str, Any], min_value: float + ) -> Dict[str, Any]: + """Filter transcripts based on expression threshold.""" + self.logger.info(f"Starting transcript filtering with threshold {min_value}") + + # Track transcripts and their maximum values across all conditions + transcript_max_values = {} + condition_names = list(gene_dict.keys()) + + # First pass: find maximum value for each transcript across all conditions + total_transcripts_before = 0 + for condition in condition_names: + condition_transcripts = sum(len(gene_info.get("transcripts", {})) + for gene_info in gene_dict[condition].values()) + total_transcripts_before += condition_transcripts + self.logger.info(f"Condition {condition}: {condition_transcripts} transcripts before filtering") + + # Record each transcript's maximum 'value' across conditions + for gene_info in gene_dict[condition].values(): + for tid, tinfo in gene_info.get("transcripts", {}).items(): + value = float(tinfo.get("value", 0)) + if value > transcript_max_values.get(tid, float("-inf")): + transcript_max_values[tid] = value + + # Log sample of transcripts (max 2 per condition) + sample_transcripts = [] + for gene_info in gene_dict[condition].values(): + sample_transcripts.extend(list(gene_info.get("transcripts", {}).keys())[:2]) + if len(sample_transcripts) >= 2: + break + if sample_transcripts: + self.logger.debug(f" Sample transcripts in {condition}: {sample_transcripts[:2]}") + + self.logger.info(f"Found {len(transcript_max_values)} unique transcripts across all conditions") + + # Sample of transcripts before filtering + sample_before = list(transcript_max_values.keys())[:5] + self.logger.debug(f"Sample transcripts before filtering: {sample_before}") + + # Build filtered 
dictionary + filtered_dict = {} + kept_transcripts = set() + + for tid, max_value in transcript_max_values.items(): + if max_value >= min_value: + kept_transcripts.add(tid) + + self.logger.info(f"Keeping {len(kept_transcripts)} transcripts that meet threshold {min_value}") + + # Sample of kept and filtered transcripts + sample_kept = list(kept_transcripts)[:5] + sample_filtered = list(set(transcript_max_values.keys()) - kept_transcripts)[:5] + self.logger.debug(f"Sample kept transcripts: {sample_kept}") + self.logger.debug(f"Sample filtered transcripts: {sample_filtered}") + + # Create filtered dictionary with same structure as input + for condition in condition_names: + filtered_dict[condition] = {} + for gene_id, gene_info in gene_dict[condition].items(): + new_gene_info = copy.deepcopy(gene_info) + new_transcripts = {} + + for tid, tinfo in gene_info.get("transcripts", {}).items(): + if tid in kept_transcripts: + new_transcripts[tid] = tinfo + + if new_transcripts: # Only keep genes that have remaining transcripts + new_gene_info["transcripts"] = new_transcripts + filtered_dict[condition][gene_id] = new_gene_info + + # Log final statistics + for condition in condition_names: + final_count = sum(len(gene_info.get("transcripts", {})) + for gene_info in filtered_dict[condition].values()) + self.logger.debug(f" {condition}: {final_count} transcripts") + + return filtered_dict + + # ------------------ READ ASSIGNMENT CACHING ------------------ + + def build_read_assignment_and_classification_dictionaries(self): + """ + Index classifications and assignment types from read_assignments.tsv file(s). + Returns either: + - (classification_counts, assignment_type_counts) for single-file input, or + - (classification_counts_dict, assignment_type_counts_dict) for multi-file (YAML) input. + """ + if not self.config.read_assignments: + raise FileNotFoundError("No read assignments file(s) found.") + + # 1. Determine cache file + cache_file = build_read_assignment_cache_file( + self.config.read_assignments, self.config.ref_only, self.cache_dir + ) + + # 2. Attempt to load from cache + if cache_file.exists(): + cached_data = load_cache(cache_file) + if cached_data and validate_read_assignment_data( + cached_data, self.config.read_assignments + ): + self.logger.info("Using cached read assignment data.") + return self._post_process_cached_data(cached_data) + + # 3. 
Otherwise, build from scratch + self.logger.info("Building read assignment data from scratch.") + if isinstance(self.config.read_assignments, list): + classification_counts_dict = {} + assignment_type_counts_dict = {} + for sample_name, read_assignment_file in self.config.read_assignments: + c_counts, a_counts = self._process_read_assignment_file( + read_assignment_file + ) + classification_counts_dict[sample_name] = c_counts + assignment_type_counts_dict[sample_name] = a_counts + + data_to_cache = { + "classification_counts": classification_counts_dict, + "assignment_type_counts": assignment_type_counts_dict, + } + save_cache(cache_file, data_to_cache) + return classification_counts_dict, assignment_type_counts_dict + else: + classification_counts, assignment_type_counts = ( + self._process_read_assignment_file(self.config.read_assignments) + ) + data_to_cache = (classification_counts, assignment_type_counts) + save_cache(cache_file, data_to_cache) + return classification_counts, assignment_type_counts + + def _post_process_cached_data(self, cached_data): + """ + Convert cached_data back to the return format + for build_read_assignment_and_classification_dictionaries(). + """ + if isinstance(self.config.read_assignments, list): + return ( + cached_data["classification_counts"], + cached_data["assignment_type_counts"], + ) + return cached_data # (classification_counts, assignment_type_counts) + + def _process_read_assignment_file(self, file_path): + """ + Parse a read_assignment TSV file, returning: + - classification_counts: dict(classification -> count) + - assignment_type_counts: dict(assignment type -> count) + """ + classification_counts = {} + assignment_type_counts = {} + + with open(file_path, "r") as file: + # Skip header lines + for _ in range(3): + next(file, None) + + for line in file: + parts = line.strip().split("\t") + if len(parts) < 6: + continue + + additional_info = parts[-1] + classification = ( + additional_info.split("Classification=")[-1].split(";")[0].strip() + ) + assignment_type = parts[5] + + classification_counts[classification] = ( + classification_counts.get(classification, 0) + 1 + ) + assignment_type_counts[assignment_type] = ( + assignment_type_counts.get(assignment_type, 0) + 1 + ) + + return classification_counts, assignment_type_counts + + # -------------------- GTF PARSING -------------------- + + def parse_input_gtf(self) -> Dict[str, Any]: + """ + Parse the reference GTF file using gffutils with optimized settings, + building a dictionary of genes, transcripts, and exons. 
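+ Returns a mapping gene_id -> {chromosome, start, end, strand, name, biotype, transcripts}, where each transcript entry records its coordinates, name, biotype, tags and an "exons" list.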
+ """ + if not self.config.genedb_filename: + db_path = self.cache_dir / "gtf.db" + if not db_path.exists(): + self.logger.info(f"Creating GTF database at {db_path}") + gffutils.create_db( + self.config.input_gtf, + dbfn=str(db_path), + force=True, + merge_strategy="create_unique", # Faster than merge + disable_infer_genes=True, + disable_infer_transcripts=True, + verbose=False, + ) + self.config.genedb_filename = str(db_path) + + self.logger.info("Opening GTF database") + db = gffutils.FeatureDB(self.config.genedb_filename) + + # Pre-fetch all features + self.logger.info("Pre-fetching features from database") + features = {feature.id: feature for feature in db.all_features()} + + # Build gene -> transcripts -> exons structure + gene_dict = {} + self.logger.info("Processing gene features") + for feature in features.values(): + if feature.featuretype != "gene": + continue + + gene_id = feature.id + gene_dict[gene_id] = { + "chromosome": feature.seqid, + "start": feature.start, + "end": feature.end, + "strand": feature.strand, + "name": feature.attributes.get("gene_name", [""])[0], + "biotype": feature.attributes.get("gene_biotype", [""])[0], + "transcripts": {}, + } + + self.logger.info("Processing transcript and exon features") + for feature in features.values(): + if feature.featuretype == "transcript": + gene_id = feature.attributes.get("gene_id", [""])[0] + if gene_id not in gene_dict: + continue + + transcript_id = feature.id + gene_dict[gene_id]["transcripts"][transcript_id] = { + "start": feature.start, + "end": feature.end, + "name": feature.attributes.get("transcript_name", [""])[0], + "biotype": feature.attributes.get("transcript_biotype", [""])[0], + "exons": [], + "tags": feature.attributes.get("tag", [""])[0].split(","), + } + elif feature.featuretype == "exon": + gene_id = feature.attributes.get("gene_id", [""])[0] + transcript_id = feature.attributes.get("transcript_id", [""])[0] + if ( + gene_id in gene_dict + and transcript_id in gene_dict[gene_id]["transcripts"] + ): + gene_dict[gene_id]["transcripts"][transcript_id]["exons"].append( + { + "exon_id": feature.id, + "start": feature.start, + "end": feature.end, + "number": feature.attributes.get("exon_number", [""])[0], + } + ) + + self.logger.info(f"Processed {len(gene_dict)} genes from GTF") + return gene_dict + + def parse_extended_annotation(self) -> Dict[str, Any]: + """Parse merged GTF into base structure without condition info.""" + base_gene_dict = {} + self.logger.info("Parsing extended annotation GTF (non-ref_only)") + + try: + with open(self.config.extended_annotation, "r") as file: + attr_pattern = re.compile(r'(\S+) "([^"]+)";') + + # First pass: genes and transcripts + for line in file: + if line.startswith("#") or not line.strip(): + continue + + fields = line.strip().split("\t") + if len(fields) < 9: + continue + + feature_type = fields[2] + attrs = dict(attr_pattern.findall(fields[8])) + gene_id = attrs.get("gene_id") + transcript_id = attrs.get("transcript_id") + + if feature_type == "gene" and gene_id: + if gene_id not in base_gene_dict: + base_gene_dict[gene_id] = { + "chromosome": fields[0], + "start": int(fields[3]), + "end": int(fields[4]), + "strand": fields[6], + "name": attrs.get("gene_name", gene_id), + "biotype": attrs.get("gene_biotype", "unknown"), + "transcripts": {} + } + + elif feature_type == "transcript" and gene_id and transcript_id: + if gene_id not in base_gene_dict: + base_gene_dict[gene_id] = { + "chromosome": fields[0], + "start": int(fields[3]), + "end": int(fields[4]), + "strand": 
fields[6], + "name": attrs.get("gene_name", gene_id), + "biotype": attrs.get("gene_biotype", "unknown"), + "transcripts": {} + } + + base_gene_dict[gene_id]["transcripts"][transcript_id] = { + "start": int(fields[3]), + "end": int(fields[4]), + "exons": [], + "expression": 0.0, + "tags": attrs.get("tags", "").split(",") + } + + elif feature_type == "exon" and transcript_id and gene_id: + if gene_id in base_gene_dict and transcript_id in base_gene_dict[gene_id]["transcripts"]: + exon_info = { + "start": int(fields[3]), + "end": int(fields[4]), + "number": attrs.get("exon_number", "1") + } + base_gene_dict[gene_id]["transcripts"][transcript_id]["exons"].append(exon_info) + + self.logger.info(f"Parsed base structure: {len(base_gene_dict)} genes") + return base_gene_dict + except Exception as e: + self.logger.error(f"GTF parsing failed: {str(e)}") + raise + + def update_transcript_values(self, gene_dict: Dict[str, Any], counts_file: str, tpm_file: str) -> Dict[str, Any]: + """Update transcript values from TPM file after filtering with counts.""" + # Read counts for filtering + counts_df = pd.read_csv(counts_file, sep='\t', comment=None) + counts_df.columns = [col.lstrip('#') for col in counts_df.columns] + counts_df = counts_df.set_index('feature_id') + + # Read TPMs for values + tpm_df = pd.read_csv(tpm_file, sep='\t', comment=None) + tpm_df.columns = [col.lstrip('#') for col in tpm_df.columns] + tpm_df = tpm_df.set_index('feature_id') + + # Align indices between counts and TPMs + common_transcripts = counts_df.index.intersection(tpm_df.index) + tpm_df = tpm_df.loc[common_transcripts] + + # Rest of existing update logic using tpm_df instead of expr_df + condition_gene_dict = {condition: copy.deepcopy(gene_dict) for condition in tpm_df.columns} + + for tid, row in tpm_df.iterrows(): + base_tid = tid.split('.')[0] + for condition, tpm_value in row.items(): + # Add nested loop to access gene_info + for gene_id, gene_info in condition_gene_dict[condition].items(): + if base_tid in gene_info.get('transcripts', {}): + gene_info['transcripts'][base_tid]['value'] = float(tpm_value) + + return condition_gene_dict + + # -------------------- UPDATES & UTILITIES -------------------- + + def update_gene_names(self, gene_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Update gene and transcript identifiers to their names, if available, + while preserving all nested structure. 
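A quick illustration of the index alignment that update_transcript_values above relies on; the two frames are toy stand-ins for the counts and TPM tables, and only transcripts present in both survive.

import pandas as pd

counts_df = pd.DataFrame({"cond_A": [12, 0]}, index=["ENST01", "ENST02"])
tpm_df = pd.DataFrame({"cond_A": [3.5, 8.1]}, index=["ENST01", "ENST03"])

# Same alignment step as above: intersect feature_id indices before
# copying TPM values into the per-condition gene dictionaries.
common = counts_df.index.intersection(tpm_df.index)
print(tpm_df.loc[common])  # only ENST01 is kept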
+ """ + try: + updated_dict = {} + total_transcripts = 0 + + for condition, genes in gene_dict.items(): + updated_genes = {} + condition_transcripts = 0 + + for gene_id, gene_info in genes.items(): + new_gene_info = copy.deepcopy(gene_info) + + # Update gene name + if "name" in gene_info and gene_info["name"]: + gene_name_upper = gene_info["name"].upper() + updated_genes[gene_name_upper] = new_gene_info + else: + updated_genes[gene_id] = new_gene_info + + # Count transcripts + transcripts = new_gene_info.get("transcripts", {}) + condition_transcripts += len(transcripts) + + # Debug sample of transcript structure + if gene_id == list(genes.keys())[0]: + self.logger.debug(f"Sample gene {gene_id} transcript structure:") + for tid in list(transcripts.keys())[:3]: + self.logger.debug(f"Transcript {tid}: {transcripts[tid]}") + + total_transcripts += condition_transcripts + updated_dict[condition] = updated_genes + self.logger.debug(f"Condition {condition}: {condition_transcripts} transcripts") + + self.logger.info(f"Updated gene names for {len(gene_dict)} conditions") + self.logger.info(f"Total transcripts in dictionary: {total_transcripts}") + + return updated_dict + + except Exception as e: + self.logger.error(f"Error updating gene/transcript names: {e}") + self.logger.error(f"Dictionary structure before update: {str(type(gene_dict))}") + raise + + def read_gene_list(self, gene_list_path: Union[str, Path]) -> List[str]: + """ + Read and parse a plain-text file containing one gene identifier per line. + Return a list of uppercase gene IDs/names. + """ + try: + with open(gene_list_path, "r") as file: + gene_list = [line.strip().upper() for line in file if line.strip()] + self.logger.info(f"Read {len(gene_list)} genes from {gene_list_path}") + return gene_list + except Exception as e: + self.logger.error(f"Error reading gene list from {gene_list_path}: {e}") + raise + + def _filter_novel_genes(self, gene_dict: Dict[str, Any]) -> Dict[str, Any]: + """Filter out novel genes based on gene ID pattern.""" + self.logger.debug("Starting novel gene filtering") + filtered_dict = {} + total_removed_genes = 0 + total_removed_transcripts = 0 + checked_gene_count = 0 + sample_removed = [] # For debug logging + + novel_gene_pattern = r"novel_gene" # Make sure this pattern is correct for your novel gene IDs + + for gene_id, gene_info in gene_dict.items(): + checked_gene_count += 1 + is_novel = bool(re.match(novel_gene_pattern, gene_id)) + + if is_novel: + total_removed_genes += 1 + removed_transcript_count = len(gene_info.get("transcripts", {})) + total_removed_transcripts += removed_transcript_count + + # Add novel gene ID to the set + self.novel_gene_ids.add(gene_id) + # Add novel transcript IDs to the set + self.novel_transcript_ids.update(gene_info.get("transcripts", {}).keys()) + + + if len(sample_removed) < 5: # Sample log of removed genes + sample_removed.append({ + 'gene_id': gene_id, + 'transcript_count': removed_transcript_count + }) + continue # Skip adding novel genes to filtered_dict + + filtered_dict[gene_id] = gene_info # Keep known genes + + self.logger.info( + f"Filtered {total_removed_genes} novel genes " + f"({total_removed_genes/checked_gene_count:.2%} of total) " + f"and {total_removed_transcripts} associated transcripts" + ) + + if sample_removed: + sample_output = "\n".join( + f"- {g['gene_id']}: {g['transcript_count']} transcripts" + for g in sample_removed + ) + self.logger.debug(f"Sample removed novel genes:\n{sample_output}") + else: + self.logger.warning("No novel genes detected with 
current filtering pattern") + + return filtered_dict + + def get_novel_feature_ids(self) -> Tuple[set, set]: + """Return the sets of novel gene and transcript IDs.""" + return self.novel_gene_ids, self.novel_transcript_ids + + def _filter_low_expression_transcripts(self, condition_gene_dict: Dict[str, Any], min_value: float) -> Dict[str, Any]: + """Filter transcripts based on expression threshold.""" + self.logger.info(f"Starting transcript filtering with threshold {min_value}") + + # Track transcripts and their maximum values across all conditions + transcript_max_values = {} + condition_names = list(condition_gene_dict.keys()) + + # First pass: find maximum value for each transcript across all conditions + total_transcripts_before = 0 + for condition in condition_names: + condition_transcripts = 0 + for gene_info in condition_gene_dict[condition].values(): + for tid, tinfo in gene_info.get("transcripts", {}).items(): + current_value = tinfo.get('value', 0) + # Update maximum value tracking + if tid not in transcript_max_values or current_value > transcript_max_values[tid]: + transcript_max_values[tid] = current_value + condition_transcripts += 1 + total_transcripts_before += condition_transcripts + self.logger.info(f"Condition {condition}: {condition_transcripts} transcripts before filtering") + + # Log sample of transcripts (max 2 per condition) + sample_transcripts = [] + for gene_info in condition_gene_dict[condition].values(): + sample_transcripts.extend(list(gene_info.get("transcripts", {}).keys())[:2]) + if len(sample_transcripts) >= 2: + break + if sample_transcripts: + self.logger.debug(f" Sample transcripts in {condition}: {sample_transcripts[:2]}") + + self.logger.info(f"Found {len(transcript_max_values)} unique transcripts across all conditions") + + # Sample of transcripts before filtering + sample_before = list(transcript_max_values.keys())[:5] + self.logger.debug(f"Sample transcripts before filtering: {sample_before}") + + # Build filtered dictionary + filtered_dict = {} + kept_transcripts = set() + + for tid, max_value in transcript_max_values.items(): + if max_value >= min_value: + kept_transcripts.add(tid) + + self.logger.info(f"Keeping {len(kept_transcripts)} transcripts that meet threshold {min_value}") + + # Sample of kept and filtered transcripts + sample_kept = list(kept_transcripts)[:5] + sample_filtered = list(set(transcript_max_values.keys()) - kept_transcripts)[:5] + self.logger.debug(f"Sample kept transcripts: {sample_kept}") + self.logger.debug(f"Sample filtered transcripts: {sample_filtered}") + + # Create filtered dictionary with same structure as input + for condition in condition_names: + filtered_dict[condition] = {} + for gene_id, gene_info in condition_gene_dict[condition].items(): + new_gene_info = copy.deepcopy(gene_info) + new_transcripts = {} + + for tid, tinfo in gene_info.get("transcripts", {}).items(): + if tid in kept_transcripts: + new_transcripts[tid] = tinfo + + if new_transcripts: # Only keep genes that have remaining transcripts + new_gene_info["transcripts"] = new_transcripts + filtered_dict[condition][gene_id] = new_gene_info + + # Log final statistics + for condition in condition_names: + final_count = sum(len(gene_info.get("transcripts", {})) + for gene_info in filtered_dict[condition].values()) + self.logger.debug(f" {condition}: {final_count} transcripts") + + return filtered_dict + + def _batch_update_values(self, gene_dict, expr_df, valid_transcripts): + """Vectorized value updating for all conditions.""" + return { + condition: { + 
gene_id: { + **gene_info, + 'transcripts': { + tid: {**tinfo, 'value': expr_df.at[tid, condition]} + for tid, tinfo in gene_info['transcripts'].items() + if tid in valid_transcripts + } + } + for gene_id, gene_info in gene_dict.items() + if any(tid in valid_transcripts for tid in gene_info['transcripts']) + } + for condition in expr_df.columns + } + + def _validate_gene_structure(self, gene_dict: Dict[str, Any]) -> None: + """Ensure proper gene-centric structure before condition processing.""" + required_gene_keys = ['chromosome', 'start', 'end', 'strand', 'name', 'biotype', 'transcripts'] + + for gene_id, gene_info in gene_dict.items(): + # Check gene ID format + if not isinstance(gene_id, str) or len(gene_id) < 4: + self.logger.error(f"Invalid gene ID format: {gene_id}") + raise ValueError("Malformed gene ID structure") + + # Check required keys + missing = [k for k in required_gene_keys if k not in gene_info] + if missing: + self.logger.error(f"Gene {gene_id} missing keys: {missing}") + raise ValueError("Incomplete gene information") + + # Check transcripts structure + transcripts = gene_info.get('transcripts', {}) + if not isinstance(transcripts, dict): + self.logger.error(f"Invalid transcripts in gene {gene_id} - expected dict") + raise ValueError("Malformed transcript structure") From 8e699fc44379155265bb88b3070a4821720b10a6 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Tue, 28 Jan 2025 12:33:34 -0600 Subject: [PATCH 18/35] Added DESEQ2 method --- src/gene_model.py | 331 -------------- src/plot_output.py | 242 ---------- src/post_process.py | 609 -------------------------- src/process_dict.py | 92 ---- src/visualization_differential_exp.py | 456 +++++++++++++++++++ 5 files changed, 456 insertions(+), 1274 deletions(-) delete mode 100644 src/gene_model.py delete mode 100644 src/plot_output.py delete mode 100644 src/post_process.py delete mode 100644 src/process_dict.py create mode 100644 src/visualization_differential_exp.py diff --git a/src/gene_model.py b/src/gene_model.py deleted file mode 100644 index a25008ce..00000000 --- a/src/gene_model.py +++ /dev/null @@ -1,331 +0,0 @@ -import os -import pandas as pd -import matplotlib.pyplot as plt -import seaborn as sns -from scipy.spatial.distance import euclidean - - -def parse_data(data): - genes = {} - for condition, condition_data in data.items(): - for gene, gene_data in condition_data.items(): - if gene not in genes: - genes[gene] = { - "chromosome": gene_data["chromosome"], - "start": gene_data["start"], - "end": gene_data["end"], - "strand": gene_data["strand"], - "biotype": gene_data["biotype"], - "transcripts": {}, - } - genes[gene]["transcripts"][condition] = gene_data["transcripts"] - genes[gene][condition] = gene_data["value"] - return genes - - -def calculate_deviance(reference_transcripts, condition_transcripts): - all_transcripts = set(reference_transcripts.keys()).union( - set(condition_transcripts.keys()) - ) - - reference_proportions = [reference_transcripts.get(t, 0) for t in all_transcripts] - condition_proportions = [condition_transcripts.get(t, 0) for t in all_transcripts] - - total_reference = sum(reference_proportions) - total_condition = sum(condition_proportions) - - if total_reference > 0: - reference_proportions = [p / total_reference for p in reference_proportions] - if total_condition > 0: - condition_proportions = [p / total_condition for p in condition_proportions] - - distance = euclidean(reference_proportions, condition_proportions) - - # Reduce distance if total unique transcripts are 1 - if 
len(all_transcripts) == 1: - distance *= 0.7 - - return distance - - -def calculate_metrics(genes, reference_condition): - metrics = [] - for gene, gene_data in genes.items(): - reference_transcripts = gene_data["transcripts"].get(reference_condition, {}) - - for condition in gene_data: - if condition in [ - "chromosome", - "start", - "end", - "strand", - "biotype", - "transcripts", - reference_condition, - ]: - continue - condition_transcripts = gene_data["transcripts"].get(condition, {}) - deviance = calculate_deviance(reference_transcripts, condition_transcripts) - metrics.append({"gene": gene, "condition": condition, "deviance": deviance}) - - value = gene_data.get(condition, 0) - reference_value = gene_data.get(reference_condition, 0) - abs_diff = abs(value - reference_value) - metrics.append( - { - "gene": gene, - "condition": condition, - "value": value, - "abs_diff": abs_diff, - } - ) - - return pd.DataFrame(metrics) - - -def check_known_target(gene, known_targets): - for target in known_targets: - if "|" in target: - if any(part in gene for part in target.split("|")): - return 1 - elif target == gene: - return 1 - return 0 - - -def rank_genes(df, known_genes_path=None): - value_ranking = df.groupby("gene")["value"].mean().reset_index() - abs_diff_ranking = df.groupby("gene")["abs_diff"].mean().reset_index() - deviance_ranking = df.groupby("gene")["deviance"].mean().reset_index() - - value_ranking["rank_value"] = value_ranking["value"].rank(ascending=False) - abs_diff_ranking["rank_abs_diff"] = abs_diff_ranking["abs_diff"].rank( - ascending=False - ) - deviance_ranking["rank_deviance"] = deviance_ranking["deviance"].rank( - ascending=False - ) - - merged_df = value_ranking[["gene", "rank_value"]].merge( - abs_diff_ranking[["gene", "rank_abs_diff"]], on="gene" - ) - merged_df = merged_df.merge(deviance_ranking[["gene", "rank_deviance"]], on="gene") - - if known_genes_path: - target_genes_df = pd.read_csv(known_genes_path, header=None, names=["gene"]) - known_targets = target_genes_df["gene"].tolist() - df["known_target"] = df["gene"].apply( - lambda x: check_known_target(x, known_targets) - ) - merged_df = merged_df.merge(df[["gene", "known_target"]], on="gene") - - # Devalue the importance of overall expression by reducing its weight - merged_df["combined_rank"] = ( - merged_df["rank_value"] # Reduced weight for rank_value - + merged_df["rank_abs_diff"] - + merged_df["rank_deviance"] - ) - - top_combined_ranking = merged_df.sort_values(by="combined_rank").head(10) - top_deviance_ranking = merged_df.sort_values(by="rank_deviance").head(10) - top_100_combined_ranking = merged_df.sort_values(by="combined_rank").head(100) - - return ( - top_combined_ranking, - top_deviance_ranking, - top_100_combined_ranking, - merged_df, - ) - - -def visualize_ranking( - top_combined_ranking, - top_deviance_ranking, - merged_df, - output_dir, - reference_condition, -): - find_genes_dir = os.path.join(output_dir, "find_genes") - if not os.path.exists(find_genes_dir): - os.makedirs(find_genes_dir) - - # Ensure we're using only the top 10 genes - top_10_combined = top_combined_ranking.head(10) - top_10_deviance = top_deviance_ranking.head(10) - - # Bar plot for combined rank - plt.figure(figsize=(12, 8)) - sns.barplot( - x="combined_rank", - y="gene", - data=top_10_combined, - hue="gene", - palette="viridis", - legend=False, - ) - plt.title("Top 10 Genes by Combined Ranking") - plt.xlabel("Combined Rank") - plt.ylabel("Gene") - plt.savefig( - os.path.join(find_genes_dir, 
"top_10_genes_combined_ranking.png"), dpi=300 - ) - plt.close() - - # Heatmap for metric ranks - top_genes = top_10_combined["gene"].tolist() - heatmap_data = merged_df[merged_df["gene"].isin(top_genes)] - heatmap_data = heatmap_data.set_index("gene")[ - ["rank_value", "rank_abs_diff", "rank_deviance"] - ] - - plt.figure(figsize=(12, 8)) - sns.heatmap( - heatmap_data, - annot=True, - cmap="RdBu_r", - linewidths=0.5, - cbar_kws={"label": "Rank"}, - ) - plt.title("Metric Ranks for Top 10 Genes") - plt.savefig( - os.path.join(find_genes_dir, "top_10_metric_ranks_heatmap.png"), dpi=300 - ) - plt.close() - - # Diverging bar plot for deviance - plt.figure(figsize=(12, 8)) - sns.barplot( - x="rank_deviance", - y="gene", - data=top_10_deviance, - hue="gene", - palette="coolwarm", - orient="h", - legend=False, - ) - plt.title(f"Top 10 Genes by Transcript Deviance from {reference_condition}") - plt.xlabel(f"Rank of Deviance from {reference_condition}") - plt.ylabel("Gene") - plt.axvline(x=0, color="grey", linestyle="--") - plt.savefig( - os.path.join(find_genes_dir, "top_10_deviance_from_reference.png"), dpi=300 - ) - plt.close() - - # Scatter plot for rank_value vs rank_abs_diff - plt.figure(figsize=(12, 8)) - sns.scatterplot( - x="rank_value", - y="rank_abs_diff", - hue="gene", - data=top_10_combined, - palette="deep", - s=100, - ) - plt.title("Rank Value vs Rank Absolute Difference (Top 10 Genes)") - plt.xlabel("Rank Value") - plt.ylabel("Rank Absolute Difference") - plt.savefig( - os.path.join(find_genes_dir, "top_10_rank_value_vs_rank_abs_diff.png"), dpi=300 - ) - plt.close() - - # Combined multi-metric visualization - fig, axes = plt.subplots(2, 2, figsize=(20, 16)) - sns.barplot( - x="combined_rank", - y="gene", - data=top_10_combined, - hue="gene", - palette="viridis", - ax=axes[0, 0], - legend=False, - ) - axes[0, 0].set_title("Combined Rank (Top 10 Genes)") - axes[0, 0].set_xlabel("Combined Rank") - axes[0, 0].set_ylabel("Gene") - sns.heatmap( - heatmap_data, - annot=True, - cmap="RdBu_r", - linewidths=0.5, - cbar_kws={"label": "Rank"}, - ax=axes[0, 1], - ) - axes[0, 1].set_title("Metric Ranks (Top 10 Genes)") - - sns.barplot( - x="rank_deviance", - y="gene", - data=top_10_deviance, - hue="gene", - palette="coolwarm", - orient="h", - ax=axes[1, 0], - legend=False, - ) - axes[1, 0].set_title( - f"Transcript Deviance from {reference_condition} (Top 10 Genes)" - ) - axes[1, 0].set_xlabel(f"Rank of Deviance from {reference_condition}") - axes[1, 0].set_ylabel("Gene") - axes[1, 0].axvline(x=0, color="grey", linestyle="--") - - sns.scatterplot( - x="rank_value", - y="rank_abs_diff", - hue="gene", - data=top_10_combined, - palette="deep", - s=100, - ax=axes[1, 1], - ) - axes[1, 1].set_title("Rank Value vs Rank Absolute Difference (Top 10 Genes)") - axes[1, 1].set_xlabel("Rank Value") - axes[1, 1].set_ylabel("Rank Absolute Difference") - - plt.tight_layout() - plt.savefig( - os.path.join(find_genes_dir, "top_10_combined_visualization.png"), dpi=300 - ) - plt.close() - - -def save_top_genes(top_combined_ranking, output_dir, num_genes): - top_combined_ranking.head(num_genes)[["gene"]].to_csv( - os.path.join(output_dir, f"top_{num_genes}_genes.txt"), - index=False, - header=False, - sep="\t", - ) - return os.path.join(output_dir, f"top_{num_genes}_genes.txt") - - -def rank_and_visualize_genes( - input_data, - output_dir, - num_genes=100, - known_genes_path=None, - reference_condition=None, -): - genes = parse_data(input_data) - metrics_df = calculate_metrics(genes, reference_condition) - 
top_combined_ranking, top_deviance_ranking, top_100_combined_ranking, merged_df = ( - rank_genes(metrics_df, known_genes_path) - ) - merged_df = merged_df.drop_duplicates(subset="gene", keep="first") - top_combined_ranking = merged_df.sort_values(by="combined_rank").head(num_genes) - top_deviance_ranking = merged_df.sort_values(by="rank_deviance").head(num_genes) - - visualize_ranking( - top_combined_ranking, - top_deviance_ranking, - merged_df, - output_dir, - reference_condition, - ) - path = save_top_genes(top_combined_ranking, output_dir, num_genes) - find_genes_dir = os.path.join(output_dir, "find_genes") - merged_df.to_csv(os.path.join(find_genes_dir, "gene_metrics.csv"), index=False) - - return path diff --git a/src/plot_output.py b/src/plot_output.py deleted file mode 100644 index 74045ccc..00000000 --- a/src/plot_output.py +++ /dev/null @@ -1,242 +0,0 @@ -import os -import matplotlib.pyplot as plt -import matplotlib.ticker as ticker -import numpy as np -import pprint - - -class PlotOutput: - def __init__( - self, - updated_gene_dict, - gene_names, - output_directory, - read_assignments_dir, - reads_and_class=None, - filter_transcripts=None, - conditions=False, - use_counts=False, - ): - self.updated_gene_dict = updated_gene_dict - self.gene_names = gene_names - self.visualization_dir = output_directory - self.read_assignments_dir = read_assignments_dir - self.reads_and_class = reads_and_class - self.filter_transcripts = filter_transcripts - self.conditions = conditions - self.use_counts = use_counts - - # Ensure the visualization directory exists - os.makedirs(self.visualization_dir, exist_ok=True) - os.makedirs(self.read_assignments_dir, exist_ok=True) - - def plot_transcript_map(self): - # Get the first condition's gene dictionary - first_condition = next(iter(self.updated_gene_dict)) - gene_dict = self.updated_gene_dict[first_condition] - - for gene_name in self.gene_names: - if gene_name in gene_dict: - gene_data = gene_dict[gene_name] - num_transcripts = len(gene_data["transcripts"]) - plot_height = max( - 3, num_transcripts * 0.3 - ) # Adjust the height dynamically - - fig, ax = plt.subplots( - figsize=(12, plot_height) - ) # Adjust height dynamically - - if self.filter_transcripts is not None: - ax.set_title( - f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']} with value over {self.filter_transcripts}" - ) - else: - ax.set_title( - f"Transcripts of Gene: {gene_data['name']} on Chromosome {gene_data['chromosome']}" - ) - - ax.set_xlabel("Chromosomal position") - ax.set_ylabel("Transcripts") - ax.set_yticks(range(num_transcripts)) - ax.set_yticklabels( - [ - f"{transcript_id}" - for transcript_id in gene_data["transcripts"].keys() - ] - ) - - ax.xaxis.set_major_locator( - ticker.MaxNLocator(integer=True) - ) # Ensure genomic positions are integers - ax.xaxis.set_major_formatter( - ticker.FuncFormatter(lambda x, pos: f"{int(x)}") - ) # Format x-axis ticks as integers - - # Plot each transcript - for i, (transcript_id, transcript_info) in enumerate( - gene_data["transcripts"].items() - ): - # Determine the direction based on the gene's strand information - direction_marker = ">" if gene_data["strand"] == "+" else "<" - marker_pos = ( - transcript_info["end"] + 100 - if gene_data["strand"] == "+" - else transcript_info["start"] - 100 - ) - ax.plot( - marker_pos, - i, - marker=direction_marker, - markersize=5, - color="blue", - ) - - # Draw the line for the whole transcript - ax.plot( - [transcript_info["start"], transcript_info["end"]], - [i, i], 
- color="grey", - linewidth=2, - ) - - # Exon blocks - for exon in transcript_info["exons"]: - exon_length = exon["end"] - exon["start"] - ax.add_patch( - plt.Rectangle( - (exon["start"], i - 0.4), - exon_length, - 0.8, - color="skyblue", - ) - ) - - ax.set_xlim(gene_data["start"], gene_data["end"]) - ax.invert_yaxis() # First transcript at the top - - plt.tight_layout() - plot_path = os.path.join( - self.visualization_dir, f"{gene_name}_splicing.png" - ) - plt.savefig(plot_path) # Saving plot by gene name - plt.close(fig) - - def plot_transcript_usage(self): - """ - Visualize transcript usage for each gene in gene_names across different conditions. - """ - - for gene_name in self.gene_names: - gene_data = {} - for condition, genes in self.updated_gene_dict.items(): - if gene_name in genes: - gene_data[condition] = genes[gene_name]["transcripts"] - - if not gene_data: - print(f"Gene {gene_name} not found in the data.") - continue - - conditions = list(gene_data.keys()) - n_bars = len(conditions) - - fig, ax = plt.subplots(figsize=(12, 8)) - index = np.arange(n_bars) - bar_width = 0.35 - opacity = 0.8 - max_transcripts = max(len(gene_data[condition]) for condition in conditions) - colors = plt.cm.plasma( - np.linspace(0, 1, num=max_transcripts) - ) # Using plasma for better color gradation - - bottom_val = np.zeros(n_bars) - for i, condition in enumerate(conditions): - transcripts = gene_data[condition] - for j, (transcript_id, transcript_info) in enumerate( - transcripts.items() - ): - color = colors[j % len(colors)] - value = transcript_info["value"] - plt.bar( - i, - float(value), - bar_width, - bottom=bottom_val[i], - alpha=opacity, - color=color, - label=transcript_id if i == 0 else "", - ) - bottom_val[i] += float(value) - - plt.xlabel("Sample Type") - plt.ylabel("Transcript Usage (TPM)") - plt.title(f"Transcript Usage for {gene_name} by Sample Type") - plt.xticks(index, conditions) - plt.legend( - title="Transcript IDs", bbox_to_anchor=(1.05, 1), loc="upper left" - ) - - plt.tight_layout() - plot_path = os.path.join( - self.visualization_dir, - f"{gene_name}_transcript_usage_by_sample_type.png", - ) - plt.savefig(plot_path) - plt.close(fig) - - def make_pie_charts(self): - """ - Create pie charts for transcript alignment classifications and read assignment consistency. - Handles both combined and separate sample data structures. - """ - - titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] - - for title, data in zip(titles, self.reads_and_class): - if isinstance(data, dict): - if any(isinstance(v, dict) for v in data.values()): - # Separate 'Mutants' and 'WildType' case - for sample_name, sample_data in data.items(): - self._create_pie_chart(f"{title} - {sample_name}", sample_data) - else: - # Combined data case - self._create_pie_chart(title, data) - else: - print(f"Skipping unexpected data type for {title}: {type(data)}") - - def _create_pie_chart(self, title, data): - """ - Helper method to create a single pie chart. 
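A stripped-down sketch of the autopct trick _create_pie_chart uses below: matplotlib hands each wedge's percentage to the callable, so the absolute count can be recovered from the grand total (the numbers here are made up).

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

data = {"unique": 700, "ambiguous": 200, "inconsistent": 100}
total = sum(data.values())
plt.pie(data.values(), labels=list(data.keys()),
        autopct=lambda pct: f"{pct:.1f}%\n({int(pct / 100. * total):d})")
plt.axis("equal")
plt.savefig("demo_pie.png")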
- """ - labels = list(data.keys()) - sizes = list(data.values()) - total = sum(sizes) - - # Generate a file-friendly title - file_title = title.lower().replace(" ", "_").replace("-", "_") - - plt.figure(figsize=(12, 8)) - wedges, texts, autotexts = plt.pie( - sizes, - labels=labels, - autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100.*total):d})", - startangle=140, - textprops=dict(color="w"), - ) - plt.setp(autotexts, size=8, weight="bold") - plt.setp(texts, size=7) - - plt.axis("equal") # Equal aspect ratio ensures that pie is drawn as a circle. - plt.title(f"{title}\nTotal: {total}") - - plt.legend( - wedges, - labels, - title="Categories", - loc="center left", - bbox_to_anchor=(1, 0, 0.5, 1), - fontsize=8, - ) - plot_path = os.path.join(self.visualization_dir, f"{file_title}_pie_chart.png") - plt.savefig(plot_path, bbox_inches="tight", dpi=300) - plt.close() diff --git a/src/post_process.py b/src/post_process.py deleted file mode 100644 index 94a1102f..00000000 --- a/src/post_process.py +++ /dev/null @@ -1,609 +0,0 @@ -import csv -import os -import pickle -import gzip -import shutil -import copy -import json -from argparse import Namespace -import tempfile -import gffutils -import yaml - - -class OutputConfig: - """Class to build dictionaries from the output files of the pipeline.""" - - def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): - self.output_directory = output_directory - self.log_details = {} - self.extended_annotation = None - self.read_assignments = None - self.input_gtf = gtf # Initialize with the provided gtf flag - self.genedb_filename = None - self.yaml_input = True - self.yaml_input_path = None - self.gtf_flag_needed = False # Initialize flag to check if "--gtf" is needed. - self.conditions = False - self.gene_grouped_counts = None - self.transcript_grouped_counts = None - self.transcript_grouped_tpm = None - self.gene_grouped_tpm = None - self.gene_counts = None - self.transcript_counts = None - self.gene_tpm = None - self.transcript_tpm = None - self.transcript_model_counts = None - self.transcript_model_tpm = None - self.transcript_model_grouped_tpm = None - self.transcript_model_grouped_counts = None - self.use_counts = use_counts - self.ref_only = ref_only - - self._load_params_file() - self._find_files() - self._conditional_unzip() - - # Ensure input_gtf is provided if ref_only is set and input_gtf is not found in the log - if self.ref_only and not self.input_gtf: - raise ValueError( - "Input GTF file is required when ref_only is set. Please provide it using the --gtf flag." 
- ) - - def _load_params_file(self): - """Load the .params file for necessary configuration and commands.""" - params_path = os.path.join(self.output_directory, ".params") - assert os.path.exists(params_path), f"Params file not found: {params_path}" - try: - with open(params_path, "rb") as file: - params = pickle.load(file) - if isinstance(params, Namespace): - self._process_params(vars(params)) - else: - print("Unexpected params format.") - except Exception as e: - raise ValueError(f"An error occurred while loading params: {e}") - - def _process_params(self, params): - """Process parameters loaded from the .params file.""" - self.log_details["gene_db"] = params.get("genedb") - self.log_details["fastq_used"] = bool(params.get("fastq")) - self.input_gtf = self.input_gtf or params.get("genedb") - self.genedb_filename = params.get("genedb_filename") - - if params.get("yaml"): - # YAML input case - self.yaml_input = True - self.yaml_input_path = params.get("yaml") - # Keep the output_directory as is, don't modify it - else: - # Non-YAML input case - self.yaml_input = False - processing_sample = params.get("prefix") - if processing_sample: - self.output_directory = os.path.join( - self.output_directory, processing_sample - ) - else: - raise ValueError( - "Processing sample directory not found in params for non-YAML input." - ) - - def _conditional_unzip(self): - """Check if unzip is needed and perform it conditionally based on the model use.""" - if self.ref_only and self.input_gtf and self.input_gtf.endswith(".gz"): - self.input_gtf = self._unzip_file(self.input_gtf) - if not self.input_gtf: - raise FileNotFoundError( - f"Unable to find or unzip the specified file: {self.input_gtf}" - ) - - def _unzip_file(self, file_path): - """Unzip a gzipped file and return the path to the uncompressed file.""" - new_path = file_path[:-3] # Remove .gz extension - - if os.path.exists(new_path): - # print(f"File {new_path} already exists, using this file.") - return new_path - - if not os.path.exists(file_path): - self.gtf_flag_needed = True - return None - - with gzip.open(file_path, "rb") as f_in: - with open(new_path, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - print(f"File {file_path} was decompressed to {new_path}.") - - return new_path - - def _find_files(self): - """Locate the necessary files in the directory and determine the need for the "--gtf" flag.""" - if self.yaml_input: - self.conditions = True - self.ref_only = True - self._find_files_from_yaml() - return # Exit the method after processing YAML input - - if not os.path.exists(self.output_directory): - print(f"Directory not found: {self.output_directory}") # Debugging output - raise FileNotFoundError( - f"Specified sample subdirectory does not exist: {self.output_directory}" - ) - - for file_name in os.listdir(self.output_directory): - if file_name.endswith(".extended_annotation.gtf"): - self.extended_annotation = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".read_assignments.tsv"): - self.read_assignments = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".read_assignments.tsv.gz"): - self.read_assignments = self._unzip_file( - os.path.join(self.output_directory, file_name) - ) - elif file_name.endswith(".gene_grouped_counts.tsv"): - self.conditions = True - self.gene_grouped_counts = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_grouped_counts.tsv"): - self.transcript_grouped_counts = os.path.join( - self.output_directory, 
file_name - ) - elif file_name.endswith(".transcript_grouped_tpm.tsv"): - self.transcript_grouped_tpm = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".gene_grouped_tpm.tsv"): - self.gene_grouped_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".gene_counts.tsv"): - self.gene_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".transcript_counts.tsv"): - self.transcript_counts = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".gene_tpm.tsv"): - self.gene_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".transcript_tpm.tsv"): - self.transcript_tpm = os.path.join(self.output_directory, file_name) - elif file_name.endswith(".transcript_model_counts.tsv"): - self.transcript_model_counts = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_model_tpm.tsv"): - self.transcript_model_tpm = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_model_grouped_tpm.tsv"): - self.transcript_model_grouped_tpm = os.path.join( - self.output_directory, file_name - ) - elif file_name.endswith(".transcript_model_grouped_counts.tsv"): - self.transcript_model_grouped_counts = os.path.join( - self.output_directory, file_name - ) - - # Determine if GTF flag is needed - if ( - not self.input_gtf - or not os.path.exists(self.input_gtf) - and not os.path.exists(self.input_gtf + ".gz") - and self.ref_only - ): - self.gtf_flag_needed = True - - # Set ref_only default based on the availability of extended_annotation - if self.ref_only is None: - self.ref_only = not self.extended_annotation - - def _find_files_from_yaml(self): - """Locate the necessary files in the directory, set specific grouped count and TPM files, and process read assignments.""" - if not os.path.exists(self.yaml_input_path): - print(f"YAML file not found: {self.yaml_input_path}") - raise FileNotFoundError( - f"Specified YAML file does not exist: {self.yaml_input_path}" - ) - - # Set the four specific attributes - self.gene_grouped_counts = os.path.join( - self.output_directory, "combined_gene_counts.tsv" - ) - self.transcript_grouped_counts = os.path.join( - self.output_directory, "combined_transcript_counts.tsv" - ) - self.transcript_grouped_tpm = os.path.join( - self.output_directory, "combined_transcript_tpm.tsv" - ) - self.gene_grouped_tpm = os.path.join( - self.output_directory, "combined_gene_tpm.tsv" - ) - - # Check if the files exist - for attr in [ - "gene_grouped_counts", - "transcript_grouped_counts", - "transcript_grouped_tpm", - "gene_grouped_tpm", - ]: - file_path = getattr(self, attr) - if not os.path.exists(file_path): - print(f"Warning: {attr} file not found at {file_path}") - setattr(self, attr, None) - - # Initialize read_assignments list - self.read_assignments = [] - - # Read and process the YAML file - with open(self.yaml_input_path, "r") as yaml_file: - yaml_data = yaml.safe_load(yaml_file) - - # Check if yaml_data is a list - if isinstance(yaml_data, list): - samples = yaml_data - else: - # If it's not a list, assume it's a dictionary with a 'samples' key - samples = yaml_data.get("samples", []) - - for sample in samples: - name = sample.get("name") - if name: - sample_dir = os.path.join(self.output_directory, name) - - # Check for .read_assignments.tsv.gz - gz_file = os.path.join(sample_dir, f"{name}.read_assignments.tsv.gz") - if os.path.exists(gz_file): - unzipped_file = self._unzip_file(gz_file) 
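For context, a hypothetical YAML input in the shape this loop expects: a list of sample entries whose "name" doubles as the per-sample subdirectory and file prefix (sample names invented).

import yaml

doc = """
- name: Mutants
- name: WildType
"""
for sample in yaml.safe_load(doc):
    name = sample.get("name")
    # each sample's read assignments live under <output>/<name>/
    print(f"{name}/{name}.read_assignments.tsv.gz")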
- if unzipped_file: - self.read_assignments.append((name, unzipped_file)) - else: - print(f"Warning: Failed to unzip {gz_file}") - else: - # Check for .read_assignments.tsv - non_gz_file = os.path.join( - sample_dir, f"{name}.read_assignments.tsv" - ) - if os.path.exists(non_gz_file): - self.read_assignments.append((name, non_gz_file)) - else: - print(f"Warning: No read assignments file found for {name}") - - if not self.read_assignments: - print("Warning: No read assignment files found for any samples") - - -class DictionaryBuilder: - """Class to build dictionaries from the output files of the pipeline.""" - - def __init__(self, config): - self.config = config - - def build_gene_transcript_exon_dictionaries(self): - """Builds dictionaries of genes, transcripts, and exons from the GTF file.""" - if self.config.extended_annotation and not self.config.ref_only: - return self.parse_extended_annotation() - else: - return self.parse_input_gtf() - - def build_read_assignment_and_classification_dictionaries(self): - """Indexes classifications and assignment types from read_assignments.tsv file(s).""" - if not self.config.read_assignments: - raise FileNotFoundError("No read assignments file(s) found.") - - if isinstance(self.config.read_assignments, list): - # YAML input case (multiple files) - classification_counts_dict = {} - assignment_type_counts_dict = {} - for sample_name, read_assignment_file in self.config.read_assignments: - classification_counts, assignment_type_counts = ( - self._process_read_assignment_file(read_assignment_file) - ) - classification_counts_dict[sample_name] = classification_counts - assignment_type_counts_dict[sample_name] = assignment_type_counts - return classification_counts_dict, assignment_type_counts_dict - else: - # Non-YAML input case (single file) - return self._process_read_assignment_file(self.config.read_assignments) - - def _process_read_assignment_file(self, file_path): - classification_counts = {} - assignment_type_counts = {} - - with open(file_path, "r") as file: - # Skip header lines - for _ in range(3): - next(file, None) - - for line in file: - parts = line.strip().split("\t") - if len(parts) < 6: - continue - - additional_info = parts[-1] - classification = ( - additional_info.split("Classification=")[-1].split(";")[0].strip() - ) - assignment_type = parts[5] - - classification_counts[classification] = ( - classification_counts.get(classification, 0) + 1 - ) - assignment_type_counts[assignment_type] = ( - assignment_type_counts.get(assignment_type, 0) + 1 - ) - - return classification_counts, assignment_type_counts - - def parse_input_gtf(self): - """Parses the GTF file using gffutils to build a detailed dictionary of genes, transcripts, and exons.""" - gene_dict = {} - if not self.config.genedb_filename: - # convert GTF to DB if we use previous IsoQuant runs - # remove this functionality later - tmp_file = tempfile.NamedTemporaryFile(suffix=".db") - self.config.genedb_filename = tmp_file.name - input_gtf_path = self.config.input_gtf - gffutils.create_db( - input_gtf_path, - dbfn=self.config.genedb_filename, - force=True, - keep_order=True, - merge_strategy="merge", - sort_attribute_values=True, - disable_infer_genes=True, - disable_infer_transcripts=True, - ) - - try: - # Create a database without using a context manager - db = gffutils.FeatureDB(self.config.genedb_filename) - - for gene in db.features_of_type("gene"): - gene_id = gene.id - gene_dict[gene_id] = { - "chromosome": gene.seqid, - "start": gene.start, - "end": gene.end, - "strand": 
gene.strand, - "name": gene.attributes.get("gene_name", [""])[0], - "biotype": gene.attributes.get("gene_biotype", [""])[0], - "transcripts": {}, - } - - for transcript in db.children(gene, featuretype="transcript"): - transcript_id = transcript.id - gene_dict[gene_id]["transcripts"][transcript_id] = { - "start": transcript.start, - "end": transcript.end, - "name": transcript.attributes.get("transcript_name", [""])[0], - "biotype": transcript.attributes.get( - "transcript_biotype", [""] - )[0], - "exons": [], - "tags": transcript.attributes.get("tag", [""])[0].split(","), - } - - for exon in db.children(transcript, featuretype="exon"): - exon_info = { - "exon_id": exon.id, - "start": exon.start, - "end": exon.end, - "number": exon.attributes.get("exon_number", [""])[0], - } - gene_dict[gene_id]["transcripts"][transcript_id][ - "exons" - ].append(exon_info) - - except Exception as e: - raise Exception(f"Error parsing GTF file: {str(e)}") - - return gene_dict - - def parse_extended_annotation(self): - """Parses the GTF file to build a detailed dictionary of genes, transcripts, and exons.""" - gene_dict = {} - if not self.config.extended_annotation: - raise FileNotFoundError("Extended annotation GTF file is missing.") - - with open(self.config.extended_annotation, "r") as file: - for line in file: - if line.startswith("#") or not line.strip(): - continue - fields = line.strip().split("\t") - if len(fields) < 9: - print( - f"Skipping malformed line due to insufficient fields: {line.strip()}" - ) - continue - - info_fields = fields[8].strip(";").split(";") - details = { - field.strip().split(" ")[0]: field.strip().split(" ")[1].strip('"') - for field in info_fields - if " " in field - } - - try: - if fields[2] == "gene": - gene_id = details["gene_id"] - gene_dict[gene_id] = { - "chromosome": fields[0], - "start": int(fields[3]), - "end": int(fields[4]), - "strand": fields[6], - "name": details.get("gene_name", ""), - "biotype": details.get("gene_biotype", ""), - "transcripts": {}, - } - elif fields[2] == "transcript": - transcript_id = details["transcript_id"] - gene_dict[details["gene_id"]]["transcripts"][transcript_id] = { - "start": int(fields[3]), - "end": int(fields[4]), - "exons": [], - } - elif fields[2] == "exon": - transcript_id = details["transcript_id"] - exon_info = { - "exon_id": details["exon_id"], - "start": int(fields[3]), - "end": int(fields[4]), - } - gene_dict[details["gene_id"]]["transcripts"][transcript_id][ - "exons" - ].append(exon_info) - except KeyError as e: - print(f"Key error in line: {line.strip()} | Missing key: {e}") - return gene_dict - - def update_gene_dict(self, gene_dict, value_df): - new_dict = {} - gene_values = {} - - # Read gene counts from value_df - with open(value_df, "r") as file: - reader = csv.reader(file, delimiter="\t") - header = next(reader) - conditions = header[1:] # Assumes the first column is gene ID - - # Initialize gene_values dictionary - for row in reader: - gene_id = row[0] - gene_values[gene_id] = {} - for i, condition in enumerate(conditions): - if len(row) > i + 1: - value = float(row[i + 1]) - else: - value = 0.0 # Default to 0 if no value - gene_values[gene_id][condition] = value - - # Build the new dictionary structure by conditions - for condition in conditions: - new_dict[condition] = {} # Create a new sub-dictionary for each condition - - # Deep copy the gene_dict and update with values from value_df - for gene_id, gene_info in gene_dict.items(): - new_dict[condition][gene_id] = copy.deepcopy(gene_info) - if gene_id in 
gene_values and condition in gene_values[gene_id]: - new_dict[condition][gene_id]["value"] = gene_values[gene_id][ - condition - ] - else: - new_dict[condition][gene_id][ - "value" - ] = 0 # Default to 0 if the gene_id has no corresponding value - - return new_dict - - def update_transcript_values(self, gene_dict, value_df): - new_dict = copy.deepcopy(gene_dict) # Preserve the original structure - transcript_values = {} - - # Load transcript counts from value_df - with open(value_df, "r") as file: - reader = csv.reader(file, delimiter="\t") - header = next(reader) - conditions = header[1:] # Assumes the first column is transcript ID - - for row in reader: - transcript_id = row[0] - for i, condition in enumerate(conditions): - if len(row) > i + 1: - value = float(row[i + 1]) - else: - value = 0.0 # Default to 0 if no value - if transcript_id not in transcript_values: - transcript_values[transcript_id] = {} - transcript_values[transcript_id][condition] = value - - # Update each condition without restructuring the original dictionary - for condition in conditions: - if condition not in new_dict: - new_dict[condition] = copy.deepcopy( - gene_dict - ) # Make sure all genes are present - - for gene_id, gene_info in new_dict[condition].items(): - if "transcripts" in gene_info: - for transcript_id, transcript_info in gene_info[ - "transcripts" - ].items(): - if ( - transcript_id in transcript_values - and condition in transcript_values[transcript_id] - ): - transcript_info["value"] = transcript_values[transcript_id][ - condition - ] - else: - transcript_info["value"] = ( - 0 # Set default if no value for this transcript - ) - return new_dict - - def update_gene_names(self, gene_dict): - updated_dict = {} - for condition, genes in gene_dict.items(): - updated_genes = {} - for gene_id, gene_info in genes.items(): - if gene_info["name"]: - gene_name_upper = gene_info["name"].upper() - updated_genes[gene_name_upper] = gene_info - else: - # If name is empty, use the original gene_id - updated_genes[gene_id] = gene_info - updated_dict[condition] = updated_genes - return updated_dict - - def filter_transcripts_by_minimum_value(self, gene_dict, min_value=1.0): - # Dictionary to hold genes and transcripts that meet the criteria - transcript_passes_threshold = {} - - # First pass: Determine which transcripts meet the minimum value requirement in any condition - for condition, genes in gene_dict.items(): - for gene_id, gene_info in genes.items(): - for transcript_id, transcript_info in gene_info["transcripts"].items(): - if ( - "value" in transcript_info - and transcript_info["value"] != "NA" - and float(transcript_info["value"]) >= min_value - ): - if gene_id not in transcript_passes_threshold: - transcript_passes_threshold[gene_id] = {} - transcript_passes_threshold[gene_id][transcript_id] = True - - # Second pass: Build the filtered dictionary including only transcripts that have eligible values in any condition - filtered_dict = {} - for condition, genes in gene_dict.items(): - filtered_genes = {} - for gene_id, gene_info in genes.items(): - if gene_id in transcript_passes_threshold: - eligible_transcripts = { - transcript_id: transcript_info - for transcript_id, transcript_info in gene_info[ - "transcripts" - ].items() - if transcript_id in transcript_passes_threshold[gene_id] - } - if ( - eligible_transcripts - ): # Only add genes with non-empty transcript sets - filtered_gene_info = copy.deepcopy(gene_info) - filtered_gene_info["transcripts"] = eligible_transcripts - filtered_genes[gene_id] = 
filtered_gene_info - if filtered_genes: # Only add conditions with non-empty gene sets - filtered_dict[condition] = filtered_genes - - return filtered_dict - - def read_gene_list(self, gene_list_path): - with open(gene_list_path, "r") as file: - gene_list = [ - line.strip().upper() for line in file - ] # Convert each gene to uppercase - return gene_list - - def save_gene_dict_to_json(self, gene_dict, output_path): - """Saves the gene dictionary to a JSON file.""" - # name the gene_dict file - output_path = os.path.join(output_path, "gene_dict.json") - with open(output_path, "w") as file: - json.dump(gene_dict, file, indent=4) diff --git a/src/process_dict.py b/src/process_dict.py deleted file mode 100644 index bb3ca001..00000000 --- a/src/process_dict.py +++ /dev/null @@ -1,92 +0,0 @@ -import json -import sys -import os - - -def simplify_and_sum_transcripts(data): - gene_totals_across_conditions = {} - simplified_data = {} - - # Sum transcript values and collect them across all conditions - for sample_id, genes in data.items(): - simplified_data[sample_id] = {} - for gene_id, gene_data in genes.items(): - transcripts = gene_data.get("transcripts", {}) - total_value = 0.0 - simplified_transcripts = {} - for transcript_id, transcript_details in transcripts.items(): - transcript_value = ( - transcript_details.get("value", 0.0) - if isinstance(transcript_details, dict) - else 0.0 - ) - simplified_transcripts[transcript_id] = transcript_value - total_value += transcript_value - - gene_data_copy = ( - gene_data.copy() - ) # Make a copy to avoid modifying the original - gene_data_copy["transcripts"] = simplified_transcripts - gene_data_copy["value"] = ( - total_value # Replace the gene-level value with the sum of transcript values - ) - simplified_data[sample_id][gene_id] = gene_data_copy - - if gene_id not in gene_totals_across_conditions: - gene_totals_across_conditions[gene_id] = [] - gene_totals_across_conditions[gene_id].append(total_value) - - # Determine which genes to remove - genes_to_remove = [ - gene_id - for gene_id, totals in gene_totals_across_conditions.items() - if all(total < 5 for total in totals) - ] - - # Remove genes from the simplified data structure - for sample_id, genes in simplified_data.items(): - for gene_id in genes_to_remove: - if gene_id in genes: - del genes[gene_id] - - return simplified_data - - -def read_json(file_path): - with open(file_path, "r") as file: - return json.load(file) - - -def write_json(data, file_path): - with open(file_path, "w") as file: - json.dump(data, file, indent=4) - - -def main(): - if len(sys.argv) != 2: - print("Usage: python script.py ") - sys.exit(1) - - input_file_path = sys.argv[1] - base, ext = os.path.splitext(input_file_path) - output_file_path = f"{base}_simplified{ext}" - - try: - # Load the gene data from the specified input JSON file - gene_dict = read_json(input_file_path) - - # Simplify the transcripts, sum their values, and remove genes under a threshold across all conditions - modified_gene_dict = simplify_and_sum_transcripts(gene_dict) - - # Save the modified gene data to the newly named output JSON file - write_json(modified_gene_dict, output_file_path) - - print(f"Modified gene data has been saved to {output_file_path}") - - except Exception as e: - print(f"Error: {str(e)}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py new file mode 100644 index 00000000..91b2210d --- /dev/null +++ 
b/src/visualization_differential_exp.py @@ -0,0 +1,456 @@ +import logging +import pandas as pd +from typing import Dict, List, Tuple, Optional, Union +from pathlib import Path +from rpy2 import robjects +from rpy2.robjects import r, Formula +from rpy2.robjects.packages import importr +from rpy2.robjects import pandas2ri +from rpy2.robjects.conversion import localconverter +from src.visualization_plotter import ExpressionVisualizer +from src.visualization_mapping import GeneMapper + + +class DifferentialAnalysis: + def __init__( + self, + output_dir: Path, + viz_output: Path, + ref_conditions: List[str], + target_conditions: List[str], + updated_gene_dict: Dict[str, Dict], + ref_only: bool = False, + dictionary_builder: "DictionaryBuilder" = None, + ): + """Initialize differential expression analysis.""" + # Configure rpy2 to suppress R console output + from rpy2.rinterface_lib import callbacks + + # Create a custom callback that does nothing + def quiet_cb(x): + pass + + # Silence R stdout/stderr + callbacks.logger.setLevel(logging.WARNING) # Affects R's logging only + callbacks.consolewrite_print = quiet_cb + callbacks.consolewrite_warnerror = quiet_cb + + self.output_dir = Path(output_dir) + self.deseq_dir = Path(viz_output) / "deseq2_results" + self.deseq_dir.mkdir(parents=True, exist_ok=True) + self.ref_conditions = ref_conditions + self.target_conditions = target_conditions + self.ref_only = ref_only + self.updated_gene_dict = updated_gene_dict + self.dictionary_builder = dictionary_builder + + # Create a single logger for this class + self.logger = logging.getLogger('IsoQuant.visualization.differential_exp') + + self.transcript_to_gene = self._create_transcript_to_gene_map() + self.visualizer = ExpressionVisualizer(self.deseq_dir) + self.gene_mapper = GeneMapper() + + def _create_transcript_to_gene_map(self) -> Dict[str, str]: + """ + Create a mapping from transcript IDs to gene names. + + Returns: + Dict[str, str]: Mapping from transcript ID to gene name. + """ + transcript_map = {} + for gene_category, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + gene_name = gene_info.get("name", gene_id) + transcripts = gene_info.get("transcripts", {}) + for transcript_id, transcript_info in transcripts.items(): + transcript_name = transcript_info.get("name", gene_name) + transcript_map[transcript_id] = transcript_name + return transcript_map + + def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: + """Run differential expression analysis for both genes and transcripts.""" + self.logger.info("Starting differential expression analysis") + + valid_transcripts = set() + for condition_genes in self.updated_gene_dict.values(): + for gene_info in condition_genes.values(): + valid_transcripts.update(gene_info.get("transcripts", {}).keys()) + + # --- 1. Load Count Data --- + gene_counts = self._get_condition_data("gene_grouped_counts.tsv") + transcript_counts = self._get_condition_data("transcript_grouped_counts.tsv") + self.logger.debug(f"Transcript counts shape after loading: {transcript_counts.shape}") + self.logger.debug(f"Gene counts shape after loading: {gene_counts.shape}") + + # --- 2. 
Novel Transcript Filtering (Transcript Level) --- + if self.dictionary_builder: + novel_transcript_ids = self.dictionary_builder.get_novel_feature_ids()[1] + self.logger.debug(f"Number of novel transcripts identified: {len(novel_transcript_ids)}") + + original_transcript_count_novel_filter = transcript_counts.shape[0] + transcript_counts = transcript_counts[~transcript_counts.index.isin(novel_transcript_ids)] # Filter out novel transcripts + novel_filtered_count = transcript_counts.shape[0] + removed_novel_count = original_transcript_count_novel_filter - novel_filtered_count + self.logger.info(f"Novel transcript filtering: Removed {removed_novel_count} novel transcripts ({removed_novel_count / original_transcript_count_novel_filter * 100:.1f}%)") + self.logger.debug(f"Transcript counts shape after novel filtering: {transcript_counts.shape}") + else: + self.logger.info("Novel transcript filtering: Skipped (no dictionary builder)") + + # --- 3. Valid Transcript Filtering (Transcript Level) --- + original_transcript_count_valid_filter = transcript_counts.shape[0] + transcript_counts = transcript_counts[transcript_counts.index.isin(valid_transcripts)] # Filter to valid transcripts + valid_transcript_filtered_count = transcript_counts.shape[0] + removed_valid_transcript_count = original_transcript_count_valid_filter - valid_transcript_filtered_count + self.logger.info(f"Valid transcript filtering: Removed {removed_valid_transcript_count} transcripts not in updated_gene_dict ({removed_valid_transcript_count / original_transcript_count_valid_filter * 100:.1f}%)") + self.logger.debug(f"Transcript counts shape after valid transcript filtering: {transcript_counts.shape}") + + if transcript_counts.empty: + self.logger.error("No valid transcripts found after filtering.") + raise ValueError("No valid transcripts found after filtering.") + + # --- 4. Count-based Filtering (Gene and Transcript Levels) --- + gene_counts_filtered = self._filter_counts(gene_counts, level="gene") + transcript_counts_filtered = self._filter_counts(transcript_counts, level="transcript") # Filter transcript counts AFTER novel and valid transcript filtering + + self.transcript_count_data = transcript_counts_filtered # Store filtered transcript counts + + # --- 5. 
Run DESeq2 Analysis --- + if not self.ref_only: + deseq2_results_gene_file, _ = self._run_level_analysis( + level="gene", + pattern="gene_grouped_counts.tsv", # Pattern is still needed in _run_level_analysis for output file naming + count_data=gene_counts_filtered # Pass PRE-FILTERED gene counts + ) + deseq2_results_transcript_file, deseq2_transcript_df = self._run_level_analysis( + level="transcript", + pattern="transcript_model_grouped_counts.tsv" if not self.ref_only else "transcript_grouped_counts.tsv", # Pattern is still needed in _run_level_analysis for output file naming + count_data=transcript_counts_filtered # Pass PRE-FILTERED transcript counts + ) + + # --- Visualize Gene-Level Results --- + gene_results_df = pd.read_csv(deseq2_results_gene_file) + target_label = f"{'+'.join(self.target_conditions)}_vs_{'+'.join(self.ref_conditions)}" + reference_label = f"{'+'.join(self.ref_conditions)}" # Corrected reference label + self.visualizer.visualize_results( # Call visualize_results for gene-level + results=gene_results_df, + target_label=target_label, + reference_label=reference_label, + min_count=10, # Assuming min_count_threshold is defined in DifferentialAnalysis + feature_type="genes", + ) + self.logger.info(f"Gene-level visualizations saved to {self.deseq_dir}") + + + # --- Visualize Transcript-Level Results --- + transcript_results_df = pd.read_csv(deseq2_results_transcript_file) + self.visualizer.visualize_results( # Call visualize_results for transcript-level + results=transcript_results_df, + target_label=target_label, + reference_label=reference_label, + min_count=10, + feature_type="transcripts", + ) + self.logger.info(f"Transcript-level visualizations saved to {self.deseq_dir}") + + return deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered + + def _run_level_analysis( + self, level: str, count_data: pd.DataFrame, pattern: Optional[str] = None + ) -> Tuple[Path, pd.DataFrame]: + """ + Run DESeq2 analysis for a specific level and return results. 
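The novel/valid transcript filtering above boils down to two boolean index masks; a toy version with invented IDs:

import pandas as pd

counts = pd.DataFrame({"s1": [5, 9, 2]}, index=["t1", "novel_t2", "t3"])
novel_ids = {"novel_t2"}
valid_ids = {"t1"}

counts = counts[~counts.index.isin(novel_ids)]  # drop novel transcripts
counts = counts[counts.index.isin(valid_ids)]   # keep only known-valid ones
print(counts)  # only t1 remains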
+ + Args: + level: Analysis level ("gene" or "transcript") + pattern: Optional pattern for output file naming (not used for data loading anymore) + count_data: PRE-FILTERED count data DataFrame + + Returns: + Tuple containing: (results_path, results_df) + """ + # --- SIMPLIFIED: _run_level_analysis now assumes count_data is already loaded and filtered --- + + if count_data.empty: + self.logger.error(f"Input count data is empty for level: {level}") + raise ValueError(f"Input count data is empty for level: {level}") + + filtered_data = count_data.copy() # Work with a copy to avoid modifying original + + # Create design matrix and run DESeq2 + coldata = self._build_design_matrix(filtered_data) + results = self._run_deseq2(filtered_data, coldata) + + # Process results + results.index.name = "feature_id" + results.reset_index(inplace=True) + mapping = self._map_gene_symbols(results["feature_id"].unique(), level) + + # Add symbol column + results["symbol"] = results["feature_id"].map(lambda x: mapping[x][0]) + + # Add gene_name column only for transcript level + if level == "transcript": + results["gene_name"] = results["feature_id"].map(lambda x: mapping[x][1]) + + # Save results + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + outfile = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}.csv" + results.to_csv(outfile, index=False) + self.logger.info(f"Saved DESeq2 results to {outfile}") + + # Write top genes + self._write_top_genes(results, level) + + return outfile, results + + def _get_condition_data(self, pattern: str) -> pd.DataFrame: + """Combine count data from all conditions.""" + all_counts = [] + + # Modify pattern if ref_only is False and pattern is transcript_grouped_counts - CORRECTED LOGIC + adjusted_pattern = pattern # Initialize adjusted_pattern to the original pattern + if not self.ref_only and pattern == "transcript_grouped_counts.tsv": # Only adjust if ref_only is False AND base pattern is transcript_grouped_counts + adjusted_pattern = "transcript_model_grouped_counts.tsv" + + for condition in self.ref_conditions + self.target_conditions: + condition_dir = self.output_dir / condition + count_files = list(condition_dir.glob(f"*{adjusted_pattern}")) # Use adjusted pattern + + for file_path in count_files: + self.logger.info(f"Reading count data from: {file_path}") + df = pd.read_csv(file_path, sep="\t", dtype={"#feature_id": str}) + df.set_index("#feature_id", inplace=True) + + # Rename columns to include condition + for col in df.columns: + if col.lower() != "count": + df = df.rename(columns={col: f"{condition}_{col}"}) + + all_counts.append(df) + + return pd.concat(all_counts, axis=1) + + def _filter_counts(self, count_data: pd.DataFrame, min_count: int = 10, level: str = "gene") -> pd.DataFrame: + """ + Filter features based on counts. 
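+
+        A worked example of both rules (min_count=10; hypothetical columns
+        WT_s1, WT_s2, MUT_s1, MUT_s2; rules spelled out below):
+
+            counts = pd.DataFrame(
+                {"WT_s1": [100, 2], "WT_s2": [80, 1],
+                 "MUT_s1": [5, 3], "MUT_s2": [7, 2]},
+                index=["ENSG0001", "ENSG0002"],
+            )
+            # gene level:       ENSG0001 kept (WT mean 90 >= 10), ENSG0002 dropped
+            # transcript level: ENSG0001 kept (2/4 samples >= 10), ENSG0002 dropped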
+ + For genes: Keep if mean count >= min_count in either condition group + For transcripts: Keep if count >= min_count in at least half of all samples + """ + if level == "transcript": + total_samples = len(count_data.columns) + min_samples_required = total_samples // 2 + samples_passing = (count_data >= min_count).sum(axis=1) + keep_features = samples_passing >= min_samples_required + + self.logger.info( + f"Transcript filtering: Keeping transcripts with counts >= {min_count} " + f"in at least {min_samples_required}/{total_samples} samples" + ) + else: # gene level + ref_cols = [ + col for col in count_data.columns + if any(col.startswith(f"{cond}_") for cond in self.ref_conditions) + ] + tgt_cols = [ + col for col in count_data.columns + if any(col.startswith(f"{cond}_") for cond in self.target_conditions) + ] + + ref_means = count_data[ref_cols].mean(axis=1) + tgt_means = count_data[tgt_cols].mean(axis=1) + keep_features = (ref_means >= min_count) | (tgt_means >= min_count) + + self.logger.info( + f"Gene filtering: Keeping genes with mean count >= {min_count} " + f"in either condition group" + ) + + filtered_data = count_data[keep_features] + self.logger.info( + f"After filtering: Retained {filtered_data.shape[0]}/{count_data.shape[0]} features " + f"({(filtered_data.shape[0]/count_data.shape[0]*100):.1f}%)" + ) + + return filtered_data + + def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame: + """Create experimental design matrix.""" + groups = [] + for sample in count_data.columns: + if any(sample.startswith(f"{cond}_") for cond in self.ref_conditions): + groups.append("Reference") + else: + groups.append("Target") + + return pd.DataFrame({"group": groups}, index=count_data.columns) + + def _run_deseq2( + self, count_data: pd.DataFrame, coldata: pd.DataFrame + ) -> pd.DataFrame: + """Run DESeq2 analysis.""" + deseq2 = importr("DESeq2") + count_data = count_data.fillna(0).round().astype(int) + + with localconverter(robjects.default_converter + pandas2ri.converter): + dds = deseq2.DESeqDataSetFromMatrix( + countData=pandas2ri.py2rpy(count_data), + colData=pandas2ri.py2rpy(coldata), + design=Formula("~ group"), + ) + dds = deseq2.DESeq(dds) + res = deseq2.results( + dds, contrast=robjects.StrVector(["group", "Target", "Reference"]) + ) + return pd.DataFrame( + robjects.conversion.rpy2py(r("data.frame")(res)), index=count_data.index + ) + + def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Tuple[str, str]]: + """Map feature IDs to symbols with MyGene.info fallback.""" + if level == "gene": + return self.gene_mapper.map_genes(feature_ids, self.updated_gene_dict) + else: + return self._map_transcript_symbols(feature_ids) + + def _map_transcript_symbols(self, feature_ids: List[str]) -> Dict[str, Tuple[str, str]]: + mapping = {} + missing_count = 0 + found_count = 0 + missing_samples = [] # Track first few missing transcripts for logging + + # Store expression stats for missing transcripts + missing_stats = [] + missing_expression_values = {} # Store expression values for logging + + self.logger.debug(f"_map_transcript_symbols: Shape of self.transcript_count_data: {self.transcript_count_data.shape}") + self.logger.debug(f"_map_transcript_symbols: Sample indices of self.transcript_count_data (first 10): {self.transcript_count_data.index[:10].tolist()}") + self.logger.debug(f"_map_transcript_symbols: Sample feature_ids (first 10 from DESeq2 results): {feature_ids[:10]}") + + for fid in feature_ids: + transcript_found = False + + # Existing search 
logic through updated_gene_dict + for gene_category, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + transcripts = gene_info.get("transcripts", {}) + if fid in transcripts: + gene_name = gene_info.get("name", gene_id) + transcript_info = transcripts[fid] + transcript_name = transcript_info.get("name", fid) + mapping[fid] = (transcript_name, gene_name) + found_count += 1 + transcript_found = True + break + if transcript_found: + break + + if not transcript_found: + # Get expression data from filtered counts + if fid in self.transcript_count_data.index: + counts = self.transcript_count_data.loc[fid] + stats = { + 'mean': counts.mean(), + 'max': counts.max(), + 'samples_gt0': sum(counts > 0), + 'total_samples': len(counts) + } + missing_stats.append(stats) + + # Store first 5 missing IDs for logging and expression values + if len(missing_samples) < 5: + missing_samples.append(fid) + missing_expression_values[fid] = counts.to_dict() # Store expression values + else: + # Store first 5 missing IDs not in matrix + if len(missing_samples) < 5: + missing_samples.append(fid) + + mapping[fid] = (fid, "Unknown") + missing_count += 1 + + # Log aggregate statistics + if missing_count > 0: + log_msg = [ + f"Missing {missing_count} transcripts in gene dictionary:", + f"- Sample missing IDs: {missing_samples[:5]}" + ] + + if missing_stats: + avg_mean = sum(s['mean'] for s in missing_stats) / len(missing_stats) + avg_max = sum(s['max'] for s in missing_stats) / len(missing_stats) + avg_expression_rate = sum(s['samples_gt0'] for s in missing_stats) / sum(s['total_samples'] for s in missing_stats) + + log_msg.extend([ + f"- Average mean count: {avg_mean:.1f}", + f"- Average max count: {avg_max:.1f}", + f"- Average expression rate: {avg_expression_rate:.1%}" + ]) + + # Add expression values to log for sample missing transcripts + if missing_expression_values: + expression_log = [] + for missing_id in missing_samples: + if missing_id in missing_expression_values: + expr_vals = missing_expression_values[missing_id] + expr_str = ", ".join([f"{sample}:{val}" for sample, val in expr_vals.items()]) + expression_log.append(f" - {missing_id} expression: [{expr_str}]") + if expression_log: + log_msg.append("- Sample missing transcript expression values:") + log_msg.extend(expression_log) + + + else: + log_msg.append("- All missing transcripts absent from count matrix") + + self.logger.warning("\n".join(log_msg)) + + return mapping + + def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: + """Write genes associated with top 100 transcripts by absolute fold change to file.""" + results["abs_stat"] = abs(results["stat"]) + + if level == "transcript": + top_transcripts = results.nlargest(len(results), "abs_stat") # Get ALL transcripts ranked by abs_stat + + unique_genes = set() + top_unique_gene_transcripts = [] + transcript_count = 0 + unique_gene_count = 0 + + for _, transcript_row in top_transcripts.iterrows(): + gene_name = transcript_row["gene_name"] + if gene_name not in unique_genes: + unique_genes.add(gene_name) + top_unique_gene_transcripts.append(transcript_row) + unique_gene_count += 1 + if unique_gene_count >= 100: # Stop when we reach 100 unique genes + break + transcript_count += 1 # Keep track of total transcripts considered + + top_genes = [row["gene_name"] for row in top_unique_gene_transcripts] # Extract gene names from selected transcripts + + # Write to file + top_genes_file = self.deseq_dir / "genes_from_top_100_transcripts.txt" + 
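+            # For reference, the selection loop above is equivalent to a compact
+            # pandas formulation (a sketch; assumes gene_name has no NaNs):
+            #   top_genes = (results.sort_values("abs_stat", ascending=False)
+            #                       .drop_duplicates("gene_name")
+            #                       .head(100)["gene_name"].tolist())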
pd.Series(top_genes).to_csv(top_genes_file, index=False, header=False) + self.logger.info(f"Wrote genes from top 100 transcripts to {top_genes_file}") + + # Log the top transcripts for debugging - UPDATED LOGGING + self.logger.debug(f"Total transcripts considered to get top 100 unique genes: {transcript_count}") + self.logger.debug(f"Number of unique genes found: {unique_gene_count}") + for row in top_unique_gene_transcripts: # Iterate through selected transcripts + self.logger.debug( + f"Transcript ID: {row['feature_id']}, " + f"Transcript Name: {row['symbol']}, " + f"Gene Name: {row['gene_name']}, " + f"abs_stat: {row['abs_stat']}" + ) + else: + # For gene-level analysis, keep original behavior + top_genes = results.nlargest(100, "abs_stat")["symbol"] + top_genes_file = self.deseq_dir / "top_100_genes.txt" + top_genes.to_csv(top_genes_file, index=False, header=False) + self.logger.info(f"Wrote top 100 genes to {top_genes_file}") \ No newline at end of file From bfff4b35e2398df7cba09662ba7c74706edebe3d Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Tue, 28 Jan 2025 15:11:26 -0600 Subject: [PATCH 19/35] DE at both levels and plotting work --- src/visualization_dictionary_builder.py | 3 +- src/visualization_differential_exp.py | 137 +------ src/visualization_mapping.py | 210 ++++++++++ src/visualization_output_config.py | 477 ++++++++++++++++++++++ src/visualization_plotter.py | 513 ++++++++++++++++++++++++ src/visualize_expression.py | 193 +++++++++ visualize.py | 371 ++++++++++------- 7 files changed, 1646 insertions(+), 258 deletions(-) create mode 100644 src/visualization_mapping.py create mode 100644 src/visualization_output_config.py create mode 100644 src/visualization_plotter.py create mode 100644 src/visualize_expression.py diff --git a/src/visualization_dictionary_builder.py b/src/visualization_dictionary_builder.py index 85821946..6600a3d9 100644 --- a/src/visualization_dictionary_builder.py +++ b/src/visualization_dictionary_builder.py @@ -465,7 +465,8 @@ def parse_extended_annotation(self) -> Dict[str, Any]: "end": int(fields[4]), "exons": [], "expression": 0.0, - "tags": attrs.get("tags", "").split(",") + "tags": attrs.get("tags", "").split(","), + "name": attrs.get("transcript_name", transcript_id), } elif feature_type == "exon" and transcript_id and gene_id: diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index 91b2210d..6172ccfa 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -186,12 +186,13 @@ def _run_level_analysis( results.reset_index(inplace=True) mapping = self._map_gene_symbols(results["feature_id"].unique(), level) - # Add symbol column - results["symbol"] = results["feature_id"].map(lambda x: mapping[x][0]) - - # Add gene_name column only for transcript level - if level == "transcript": - results["gene_name"] = results["feature_id"].map(lambda x: mapping[x][1]) + # Add transcript_symbol and gene_name columns + results["transcript_symbol"] = results["feature_id"].map(lambda x: mapping[x]["transcript_symbol"]) + results["gene_name"] = results["feature_id"].map(lambda x: mapping[x]["gene_name"]) + + # Drop transcript_symbol column for gene-level analysis as it's redundant + if level == "gene": + results = results.drop(columns=["transcript_symbol"]) # Save results target_label = "+".join(self.target_conditions) @@ -308,106 +309,20 @@ def _run_deseq2( robjects.conversion.rpy2py(r("data.frame")(res)), index=count_data.index ) - def _map_gene_symbols(self, feature_ids: 
List[str], level: str) -> Dict[str, Tuple[str, str]]: - """Map feature IDs to symbols with MyGene.info fallback.""" - if level == "gene": - return self.gene_mapper.map_genes(feature_ids, self.updated_gene_dict) - else: - return self._map_transcript_symbols(feature_ids) - - def _map_transcript_symbols(self, feature_ids: List[str]) -> Dict[str, Tuple[str, str]]: - mapping = {} - missing_count = 0 - found_count = 0 - missing_samples = [] # Track first few missing transcripts for logging - - # Store expression stats for missing transcripts - missing_stats = [] - missing_expression_values = {} # Store expression values for logging - - self.logger.debug(f"_map_transcript_symbols: Shape of self.transcript_count_data: {self.transcript_count_data.shape}") - self.logger.debug(f"_map_transcript_symbols: Sample indices of self.transcript_count_data (first 10): {self.transcript_count_data.index[:10].tolist()}") - self.logger.debug(f"_map_transcript_symbols: Sample feature_ids (first 10 from DESeq2 results): {feature_ids[:10]}") - - for fid in feature_ids: - transcript_found = False - - # Existing search logic through updated_gene_dict - for gene_category, genes in self.updated_gene_dict.items(): - for gene_id, gene_info in genes.items(): - transcripts = gene_info.get("transcripts", {}) - if fid in transcripts: - gene_name = gene_info.get("name", gene_id) - transcript_info = transcripts[fid] - transcript_name = transcript_info.get("name", fid) - mapping[fid] = (transcript_name, gene_name) - found_count += 1 - transcript_found = True - break - if transcript_found: - break - - if not transcript_found: - # Get expression data from filtered counts - if fid in self.transcript_count_data.index: - counts = self.transcript_count_data.loc[fid] - stats = { - 'mean': counts.mean(), - 'max': counts.max(), - 'samples_gt0': sum(counts > 0), - 'total_samples': len(counts) - } - missing_stats.append(stats) - - # Store first 5 missing IDs for logging and expression values - if len(missing_samples) < 5: - missing_samples.append(fid) - missing_expression_values[fid] = counts.to_dict() # Store expression values - else: - # Store first 5 missing IDs not in matrix - if len(missing_samples) < 5: - missing_samples.append(fid) - - mapping[fid] = (fid, "Unknown") - missing_count += 1 - - # Log aggregate statistics - if missing_count > 0: - log_msg = [ - f"Missing {missing_count} transcripts in gene dictionary:", - f"- Sample missing IDs: {missing_samples[:5]}" - ] - - if missing_stats: - avg_mean = sum(s['mean'] for s in missing_stats) / len(missing_stats) - avg_max = sum(s['max'] for s in missing_stats) / len(missing_stats) - avg_expression_rate = sum(s['samples_gt0'] for s in missing_stats) / sum(s['total_samples'] for s in missing_stats) - - log_msg.extend([ - f"- Average mean count: {avg_mean:.1f}", - f"- Average max count: {avg_max:.1f}", - f"- Average expression rate: {avg_expression_rate:.1%}" - ]) - - # Add expression values to log for sample missing transcripts - if missing_expression_values: - expression_log = [] - for missing_id in missing_samples: - if missing_id in missing_expression_values: - expr_vals = missing_expression_values[missing_id] - expr_str = ", ".join([f"{sample}:{val}" for sample, val in expr_vals.items()]) - expression_log.append(f" - {missing_id} expression: [{expr_str}]") - if expression_log: - log_msg.append("- Sample missing transcript expression values:") - log_msg.extend(expression_log) - + def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Dict[str, Optional[str]]]: + 
""" + Map feature IDs to gene and transcript names using GeneMapper class. - else: - log_msg.append("- All missing transcripts absent from count matrix") - - self.logger.warning("\n".join(log_msg)) + Args: + feature_ids: List of feature IDs (gene symbols or transcript IDs) + level: Analysis level ("gene" or "transcript") - return mapping + Returns: + Dict[str, Dict[str, Optional[str]]]: Mapping from feature ID to a dictionary + containing 'transcript_symbol' and 'gene_name'. + 'transcript_symbol' is None for gene-level analysis. + """ + return self.gene_mapper.map_gene_symbols(feature_ids, level, self.updated_gene_dict) def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: """Write genes associated with top 100 transcripts by absolute fold change to file.""" @@ -437,20 +352,10 @@ def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: top_genes_file = self.deseq_dir / "genes_from_top_100_transcripts.txt" pd.Series(top_genes).to_csv(top_genes_file, index=False, header=False) self.logger.info(f"Wrote genes from top 100 transcripts to {top_genes_file}") - - # Log the top transcripts for debugging - UPDATED LOGGING - self.logger.debug(f"Total transcripts considered to get top 100 unique genes: {transcript_count}") - self.logger.debug(f"Number of unique genes found: {unique_gene_count}") - for row in top_unique_gene_transcripts: # Iterate through selected transcripts - self.logger.debug( - f"Transcript ID: {row['feature_id']}, " - f"Transcript Name: {row['symbol']}, " - f"Gene Name: {row['gene_name']}, " - f"abs_stat: {row['abs_stat']}" - ) else: # For gene-level analysis, keep original behavior - top_genes = results.nlargest(100, "abs_stat")["symbol"] + # top_genes = results.nlargest(100, "abs_stat")["symbol"] # OLD: was writing symbols (gene IDs) + top_genes = results.nlargest(100, "abs_stat")["gene_name"] top_genes_file = self.deseq_dir / "top_100_genes.txt" top_genes.to_csv(top_genes_file, index=False, header=False) self.logger.info(f"Wrote top 100 genes to {top_genes_file}") \ No newline at end of file diff --git a/src/visualization_mapping.py b/src/visualization_mapping.py new file mode 100644 index 00000000..b91718f0 --- /dev/null +++ b/src/visualization_mapping.py @@ -0,0 +1,210 @@ +import mygene +import logging +from typing import Dict, List, Tuple, Optional + +class GeneMapper: + def __init__(self): + self.mg = mygene.MyGeneInfo() + self.logger = logging.getLogger('IsoQuant.visualization.mapping') + self.logger.setLevel(logging.INFO) + + def get_gene_info_from_mygene(self, ensembl_ids: List[str]) -> Dict[str, Dict]: + """ + Query MyGene.info API for gene information using batch query. 
+ + Args: + ensembl_ids: List of Ensembl gene IDs + + Returns: + Dict mapping query IDs to gene information + """ + try: + # Batch query for gene information + results = self.mg.querymany( + ensembl_ids, + scopes='ensembl.gene', # Only search for gene IDs + fields=['symbol', 'name'], # Only get essential fields + species='human', + as_dataframe=False, + returnall=True + ) + + # Process results + mapping = {} + for hit in results['out']: + query_id = hit.get('query', '') + if 'notfound' in hit: + self.logger.debug(f"Gene ID not found: {query_id}") + continue + + mapping[query_id] = { + 'symbol': hit.get('symbol', query_id), + 'name': hit.get('name', hit.get('symbol', query_id)) + } + + # Log query statistics + self.logger.info( + f"MyGene.info query stats: " + f"Total={len(ensembl_ids)}, " + f"Found={len(mapping)}, " + f"Missing={len(ensembl_ids) - len(mapping)}" + ) + + return mapping + + except Exception as e: + self.logger.error(f"Failed to fetch info from MyGene.info: {str(e)}") + return {} + + def map_genes(self, gene_ids: List[str], updated_gene_dict: Dict) -> Dict[str, Tuple[str, str]]: + """ + Map Ensembl gene IDs to symbols, using multiple fallback methods: + 1. IsoQuant's updated_gene_dict + 2. MyGene.info + 3. Parse symbol from Ensembl ID if possible + + Returns: + Dict mapping Ensembl IDs to (symbol, gene_name) tuples + """ + mapping = {} + unmapped_ids = [] + + # First try to map using updated_gene_dict + for gene_id in gene_ids: + symbol_found = False + for gene_category, genes in updated_gene_dict.items(): + if gene_id in genes: + gene_info = genes[gene_id] + # Only use name if it's not empty and not the same as gene_id + if gene_info.get("name") and gene_info["name"] != gene_id: + mapping[gene_id] = (gene_info["name"], gene_info["name"]) + symbol_found = True + break + + if not symbol_found: + unmapped_ids.append(gene_id) + + # For unmapped genes, try MyGene.info batch query + if unmapped_ids: + self.logger.info(f"Querying MyGene.info for {len(unmapped_ids)} unmapped genes") + mygene_results = self.get_gene_info_from_mygene(unmapped_ids) + + remaining_unmapped = [] + for gene_id in unmapped_ids: + if gene_id in mygene_results: + info = mygene_results[gene_id] + mapping[gene_id] = (info['symbol'], info['name']) + self.logger.debug(f"Mapped {gene_id} to {info['symbol']} using MyGene.info") + else: + remaining_unmapped.append(gene_id) + + # For still unmapped genes, try to extract info from Ensembl ID + for gene_id in remaining_unmapped: + # Try to extract meaningful info from Ensembl ID + if gene_id.startswith('ENSG'): + # For novel genes, use the last part of the ID as a temporary symbol + temp_symbol = f"GENE_{gene_id.split('0')[-1]}" + mapping[gene_id] = (temp_symbol, gene_id) + self.logger.warning(f"Using derived symbol {temp_symbol} for {gene_id}") + else: + mapping[gene_id] = (gene_id, gene_id) + self.logger.warning(f"Could not map {gene_id} using any method") + + return mapping + + def map_gene_symbols(self, feature_ids: List[str], level: str, updated_gene_dict: Dict = None) -> Dict[str, Dict[str, Optional[str]]]: + """ + Map feature IDs to gene and transcript names using updated gene dictionary. + + Args: + feature_ids: List of feature IDs (gene symbols or transcript IDs) + level: Analysis level ("gene" or "transcript") + updated_gene_dict: Optional updated gene dictionary + + Returns: + Dict[str, Dict[str, Optional[str]]]: Mapping from feature ID to a dictionary + containing 'transcript_symbol' and 'gene_name'. + 'transcript_symbol' is None for gene-level analysis. 
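+
+        Illustrative transcript-level return value (hypothetical IDs):
+
+            {"ENST0001": {"transcript_symbol": "BRCA1-201",
+                          "gene_name": "BRCA1"}}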
+ """ + mapping: Dict[str, Dict[str, Optional[str]]] = {} + unmapped_gene_ids_batch: List[str] = [] # Initialize list to collect unmapped gene IDs for batch query + + for feature_id in feature_ids: + if level == "gene": + # Gene-level mapping: Search in updated_gene_dict, fallback to batched MyGene API + gene_name = None + found_in_dict = False + if updated_gene_dict: + for condition, condition_gene_dict in updated_gene_dict.items(): + if feature_id in condition_gene_dict: + found_in_dict = True + gene_name = condition_gene_dict[feature_id].get("name") + break + if not found_in_dict: + unmapped_gene_ids_batch.append(feature_id) # Add to batch list for MyGene query + else: + unmapped_gene_ids_batch.append(feature_id) # Add to batch list for MyGene query + + + mapping[feature_id] = { + "transcript_symbol": gene_name, # For gene-level, use gene name as transcript_symbol + "gene_name": gene_name if gene_name else feature_id + } + + elif level == "transcript": + # Transcript-level mapping: Search for transcript name in updated_gene_dict across all conditions + gene_name = None + transcript_symbol = None + gene_found_for_transcript = False # Flag to track if gene is found for transcript + + if updated_gene_dict: + for condition, condition_gene_dict in updated_gene_dict.items(): # Iterate through conditions + for gene_id, gene_data in condition_gene_dict.items(): # Iterate through genes in each condition + if "transcripts" in gene_data and feature_id in gene_data["transcripts"]: + gene_found_for_transcript = True + transcript_info = gene_data["transcripts"].get(feature_id, {}) + transcript_symbol = transcript_info.get("name") + mapping[feature_id] = { + "transcript_symbol": f"{transcript_symbol} ({gene_data.get('name')})" if feature_id.startswith("transcript") else transcript_symbol, + "gene_name": gene_data.get("name") # Get gene_name from gene_data + } + self.logger.debug(f"Transcript-level mapping: Found transcript {feature_id}, gene_data: {gene_data}") # Debug log to inspect gene_data + break # Found transcript, exit inner loop (genes in condition) + if gene_found_for_transcript: # If transcript found in any gene in this condition, exit condition loop + break + if not gene_found_for_transcript: + self.logger.debug(f"Transcript-level mapping: No gene found for Transcript ID {feature_id} in updated_gene_dict across any condition") + mapping[feature_id] = { # Assign mapping here for not found case + "transcript_symbol": f"{feature_id} (No gene name)", # Indicate no gene name available + "gene_name": None + } + else: # If updated_gene_dict is None + self.logger.debug("Transcript-level mapping: updated_gene_dict is None") + mapping[feature_id] = { + "transcript_symbol": f"{feature_id} (No gene name)", # Indicate no gene name available when dict is None + "gene_name": None # gene_name is None + } + self.logger.debug(f"Transcript-level mapping: Using feature_id as transcript_symbol, no gene name available (updated_gene_dict is None)") # Debug log + + else: + raise ValueError(f"Invalid level: {level}. 
Must be 'gene' or 'transcript'.") + + # Perform batched MyGene API query for all unmapped gene IDs at once (gene-level only) + if level == "gene" and unmapped_gene_ids_batch: + self.logger.info(f"Gene-level mapping: Performing batched MyGene API query for {len(unmapped_gene_ids_batch)} gene IDs") + mygene_batch_info = self.get_gene_info_from_mygene(unmapped_gene_ids_batch) # Batched query + + if mygene_batch_info: + for feature_id in unmapped_gene_ids_batch: # Iterate through the unmapped IDs + if feature_id in mygene_batch_info: # Check if MyGene returned info for this ID + gene_name_from_mygene = mygene_batch_info[feature_id].get('symbol') + if gene_name_from_mygene: + mapping[feature_id]["gene_name"] = gene_name_from_mygene # Update gene_name in mapping + mapping[feature_id]["transcript_symbol"] = gene_name_from_mygene # Update transcript_symbol + else: + self.logger.debug(f"Gene-level mapping: Batched MyGene API did not return info for Feature ID {feature_id}") + else: + self.logger.warning("Gene-level mapping: Batched MyGene API query failed or returned no results.") + + + return mapping \ No newline at end of file diff --git a/src/visualization_output_config.py b/src/visualization_output_config.py new file mode 100644 index 00000000..0c79171d --- /dev/null +++ b/src/visualization_output_config.py @@ -0,0 +1,477 @@ +import csv +import os +import pickle +import gzip +import shutil +from argparse import Namespace +import gffutils +import yaml +from typing import List +import logging + +class OutputConfig: + """Class to build dictionaries from the output files of the pipeline.""" + + def __init__( + self, + output_directory: str, + use_counts: bool = False, + ref_only: bool = False, + gtf: str = None, + ): + self.output_directory = output_directory + self.log_details = {} + self.extended_annotation = None + self.read_assignments = None + self.input_gtf = gtf # Initialize with the provided gtf flag + self.genedb_filename = None + self.yaml_input = True + self.yaml_input_path = None + self.gtf_flag_needed = False # Initialize flag to check if "--gtf" is needed. + self._conditions = None # Changed from self.conditions = False + self.gene_grouped_counts = None + self.transcript_grouped_counts = None + self.transcript_grouped_tpm = None + self.gene_grouped_tpm = None + self.gene_counts = None + self.transcript_counts = None + self.gene_tpm = None + self.transcript_tpm = None + self.transcript_model_counts = None + self.transcript_model_tpm = None + self.transcript_model_grouped_tpm = None + self.transcript_model_grouped_counts = None + self.use_counts = use_counts + self.ref_only = ref_only + + # New attributes for handling extended annotations + self.sample_extended_gtfs = [] + self.merged_extended_gtf = None + + # Attributes to store sample-level transcript model data + self.samples = [] + self.sample_transcript_model_tpm = {} + self.sample_transcript_model_counts = {} + + self._load_params_file() + self._find_files() + self._conditional_unzip() + + # Ensure input_gtf is provided if ref_only is set and input_gtf is not found in the log + if self.ref_only and not self.input_gtf: + raise ValueError( + "Input GTF file is required when ref_only is set. Please provide it using the --gtf flag." 
+ ) + + def _load_params_file(self): + """Load the .params file for necessary configuration and commands.""" + params_path = os.path.join(self.output_directory, ".params") + assert os.path.exists(params_path), f"Params file not found: {params_path}" + try: + with open(params_path, "rb") as file: + params = pickle.load(file) + if isinstance(params, Namespace): + self._process_params(vars(params)) + else: + print("Unexpected params format.") + except Exception as e: + raise ValueError(f"An error occurred while loading params: {e}") + + def _process_params(self, params): + """Process parameters loaded from the .params file.""" + self.log_details["gene_db"] = params.get("genedb") + self.log_details["fastq_used"] = bool(params.get("fastq")) + self.input_gtf = self.input_gtf or params.get("genedb") + self.genedb_filename = params.get("genedb_filename") + + if params.get("yaml"): + # YAML input case + self.yaml_input = True + self.yaml_input_path = params.get("yaml") + # Keep the output_directory as is, don't modify it + else: + # Non-YAML input case + self.yaml_input = False + processing_sample = params.get("prefix") + if processing_sample: + self.output_directory = os.path.join( + self.output_directory, processing_sample + ) + else: + raise ValueError( + "Processing sample directory not found in params for non-YAML input." + ) + + def _conditional_unzip(self): + """Check if unzip is needed and perform it conditionally based on the model use.""" + if self.ref_only and self.input_gtf and self.input_gtf.endswith(".gz"): + self.input_gtf = self._unzip_file(self.input_gtf) + if not self.input_gtf: + raise FileNotFoundError( + f"Unable to find or unzip the specified file: {self.input_gtf}" + ) + + def _unzip_file(self, file_path): + """Unzip a gzipped file and return the path to the uncompressed file.""" + new_path = file_path[:-3] # Remove .gz extension + + if os.path.exists(new_path): + return new_path + + if not os.path.exists(file_path): + self.gtf_flag_needed = True + return None + + with gzip.open(file_path, "rb") as f_in: + with open(new_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"File {file_path} was decompressed to {new_path}.") + + return new_path + + def _find_files(self): + """Locate the necessary files in the directory and determine the need for the "--gtf" flag.""" + if self.yaml_input: + self.conditions = True + self._find_files_from_yaml() + return # Exit the method after processing YAML input + + if not os.path.exists(self.output_directory): + print(f"Directory not found: {self.output_directory}") # Debugging output + raise FileNotFoundError( + f"Specified sample subdirectory does not exist: {self.output_directory}" + ) + + for file_name in os.listdir(self.output_directory): + if file_name.endswith(".extended_annotation.gtf"): + self.extended_annotation = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".read_assignments.tsv"): + self.read_assignments = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".read_assignments.tsv.gz"): + self.read_assignments = self._unzip_file( + os.path.join(self.output_directory, file_name) + ) + elif file_name.endswith(".gene_grouped_counts.tsv"): + self._conditions = self._get_conditions_from_file( + os.path.join(self.output_directory, file_name) + ) + self.gene_grouped_counts = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_grouped_counts.tsv"): + self.transcript_grouped_counts = os.path.join( + self.output_directory, file_name + 
) + elif file_name.endswith(".transcript_grouped_tpm.tsv"): + self.transcript_grouped_tpm = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".gene_grouped_tpm.tsv"): + self.gene_grouped_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".gene_counts.tsv"): + self.gene_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".transcript_counts.tsv"): + self.transcript_counts = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".gene_tpm.tsv"): + self.gene_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".transcript_tpm.tsv"): + self.transcript_tpm = os.path.join(self.output_directory, file_name) + elif file_name.endswith(".transcript_model_counts.tsv"): + self.transcript_model_counts = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_tpm.tsv"): + self.transcript_model_tpm = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_grouped_tpm.tsv"): + self.transcript_model_grouped_tpm = os.path.join( + self.output_directory, file_name + ) + elif file_name.endswith(".transcript_model_grouped_counts.tsv"): + self.transcript_model_grouped_counts = os.path.join( + self.output_directory, file_name + ) + + # Determine if GTF flag is needed + if ( + not self.input_gtf + or ( + not os.path.exists(self.input_gtf) + and not os.path.exists(self.input_gtf + ".gz") + ) + and self.ref_only + ): + self.gtf_flag_needed = True + + # Set ref_only default based on the availability of extended_annotation + if self.ref_only is None: + self.ref_only = not self.extended_annotation + + def _find_files_from_yaml(self): + """Locate files and samples from YAML, apply filters to ensure only valid samples are processed.""" + if not os.path.exists(self.yaml_input_path): + print(f"YAML file not found: {self.yaml_input_path}") + raise FileNotFoundError( + f"Specified YAML file does not exist: {self.yaml_input_path}" + ) + + # Set these attributes based on YAML input expectations + self.gene_grouped_counts = os.path.join( + self.output_directory, "combined_gene_counts.tsv" + ) + self.transcript_grouped_counts = os.path.join( + self.output_directory, "combined_transcript_counts.tsv" + ) + self.transcript_grouped_tpm = os.path.join( + self.output_directory, "combined_transcript_tpm.tsv" + ) + self.gene_grouped_tpm = os.path.join( + self.output_directory, "combined_gene_tpm.tsv" + ) + + # Check if the files exist + for attr in [ + "gene_grouped_counts", + "transcript_grouped_counts", + "transcript_grouped_tpm", + "gene_grouped_tpm", + ]: + file_path = getattr(self, attr) + if not os.path.exists(file_path): + print(f"Warning: {attr} file not found at {file_path}") + setattr(self, attr, None) + + # Initialize read_assignments list + self.read_assignments = [] + + # Read and process the YAML file + with open(self.yaml_input_path, "r") as yaml_file: + yaml_data = yaml.safe_load(yaml_file) + + # If yaml_data is a list but also contains non-sample items, filter them + if isinstance(yaml_data, list): + samples = [ + item for item in yaml_data if isinstance(item, dict) and "name" in item + ] + else: + # If it's not a list, assume it's a dictionary with a 'samples' key + samples = yaml_data.get("samples", []) + # Filter samples + samples = [item for item in samples if "name" in item] + + self.samples = [sample.get("name") for sample in samples] + + # Since we have a YAML file with multiple samples, we have 
conditions + self.conditions = True + + for sample in samples: + name = sample.get("name") + if name: + sample_dir = os.path.join(self.output_directory, name) + + # Check for extended_annotation.gtf + extended_gtf = os.path.join( + sample_dir, f"{name}.extended_annotation.gtf" + ) + if os.path.exists(extended_gtf): + self.sample_extended_gtfs.append(extended_gtf) + else: + print( + f"Warning: extended_annotation.gtf not found for sample {name}" + ) + + # Check for .read_assignments.tsv.gz + gz_file = os.path.join(sample_dir, f"{name}.read_assignments.tsv.gz") + if os.path.exists(gz_file): + unzipped_file = self._unzip_file(gz_file) + if unzipped_file: + self.read_assignments.append((name, unzipped_file)) + else: + print(f"Warning: Failed to unzip {gz_file}") + else: + # Check for .read_assignments.tsv + non_gz_file = os.path.join( + sample_dir, f"{name}.read_assignments.tsv" + ) + if os.path.exists(non_gz_file): + self.read_assignments.append((name, non_gz_file)) + else: + print(f"Warning: No read assignments file found for {name}") + + # Load transcript_model_tpm and transcript_model_counts for merging + tpm_path = os.path.join(sample_dir, f"{name}.transcript_model_tpm.tsv") + counts_path = os.path.join( + sample_dir, f"{name}.transcript_model_counts.tsv" + ) + + self.sample_transcript_model_tpm[name] = ( + tpm_path if os.path.exists(tpm_path) else None + ) + self.sample_transcript_model_counts[name] = ( + counts_path if os.path.exists(counts_path) else None + ) + + if not self.read_assignments: + print("Warning: No read assignment files found for any samples") + + # Handle extended annotations only if ref_only is not True + if self.ref_only is not True: + self._handle_extended_annotations(samples_count=len(self.samples)) + + # Merge transcript_model_tpm and transcript_model_counts if conditions are met and not ref_only + # and we have extended annotations (if needed) + if self.yaml_input and not self.ref_only and self.extended_annotation: + merged_tpm = os.path.join( + self.output_directory, "combined_transcript_tpm_merged.tsv" + ) + merged_counts = os.path.join( + self.output_directory, "combined_transcript_counts_merged.tsv" + ) + + if os.path.exists(merged_tpm) and os.path.exists(merged_counts): + # Load directly + self.transcript_grouped_tpm = merged_tpm + self.transcript_grouped_counts = merged_counts + else: + # Perform merging + self._merge_transcript_files( + self.sample_transcript_model_tpm, merged_tpm, "TPM" + ) + self._merge_transcript_files( + self.sample_transcript_model_counts, merged_counts, "Count" + ) + self.transcript_grouped_tpm = merged_tpm + self.transcript_grouped_counts = merged_counts + + def _handle_extended_annotations(self, samples_count): + """Check if extended annotations should be handled. If ref_only is true, skip handling them entirely.""" + if self.ref_only: + logging.debug("ref_only is True. 
Skipping extended annotation merging.") + return + + # Check if merged_extended_annotation.gtf already exists + existing_merged_gtf = os.path.join( + self.output_directory, "merged_extended_annotation.gtf" + ) + existing_partial_merged_gtf = os.path.join( + self.output_directory, "merged_extended_annotation_partial.gtf" + ) + + if os.path.exists(existing_merged_gtf): + logging.debug(f"Found existing merged GTF at {existing_merged_gtf}, using it directly.") + self.merged_extended_gtf = existing_merged_gtf + self.extended_annotation = self.merged_extended_gtf + return + elif os.path.exists(existing_partial_merged_gtf): + logging.debug(f"Found existing partially merged GTF at {existing_partial_merged_gtf}, using it directly.") + self.merged_extended_gtf = existing_partial_merged_gtf + self.extended_annotation = self.merged_extended_gtf + return + + # If no pre-merged file is found, proceed with merging logic + if len(self.sample_extended_gtfs) == samples_count and samples_count > 0: + logging.debug("All samples have extended_annotation.gtf. Proceeding to merge them.") + self.merged_extended_gtf = os.path.join( + self.output_directory, "merged_extended_annotation.gtf" + ) + self.merge_gtfs(self.sample_extended_gtfs, self.merged_extended_gtf) + self.extended_annotation = self.merged_extended_gtf + logging.debug(f"Merged GTF created at: {self.merged_extended_gtf}") + else: + logging.debug("Not all samples have extended_annotation.gtf. Skipping merge.") + + if hasattr(self, "samples") and self.samples: + for s in self.samples: + gtf_path = os.path.join( + self.output_directory, s, f"{s}.extended_annotation.gtf" + ) + if not os.path.exists(gtf_path): + logging.debug( + f"Missing GTF for sample: {s}, expected at {gtf_path}" + ) + + if self.sample_extended_gtfs: + logging.debug("Merging available extended_annotation.gtf files.") + self.merged_extended_gtf = os.path.join( + self.output_directory, "merged_extended_annotation_partial.gtf" + ) + self.merge_gtfs(self.sample_extended_gtfs, self.merged_extended_gtf) + self.extended_annotation = self.merged_extended_gtf + logging.debug(f"Partially merged GTF created at: {self.merged_extended_gtf}") + else: + logging.debug( + "No extended_annotation.gtf files found. Continuing without merge." + ) + + def _merge_transcript_files(self, sample_files_dict, output_file, metric_type): + # sample_files_dict: {sample_name: filepath or None} + # Merge logic: + # 1. Gather all transcripts from all samples + # 2. 
For each transcript, write a line with transcript_id and values from each sample (0 if missing) + transcripts = {} + samples = self.samples + + # Read each sample file + for sample_name, file_path in sample_files_dict.items(): + if file_path and os.path.exists(file_path): + with open(file_path, "r") as f: + reader = csv.reader(f, delimiter="\t") + header = next(reader) + for row in reader: + if len(row) < 2: + continue + transcript_id = row[0] + value_str = row[1] if len(row) > 1 else "0" + try: + value = float(value_str) + except ValueError: + value = 0.0 + if transcript_id not in transcripts: + transcripts[transcript_id] = {} + transcripts[transcript_id][sample_name] = value + else: + # Sample missing file, will assign 0 later + pass + + # Write merged file + with open(output_file, "w", newline="") as out_f: + writer = csv.writer(out_f, delimiter="\t") + header = ["#feature_id"] + samples + writer.writerow(header) + for transcript_id in sorted(transcripts.keys()): + row = [transcript_id] + for sample_name in samples: + row.append(transcripts[transcript_id].get(sample_name, 0)) + writer.writerow(row) + + def merge_gtfs(self, gtfs, output_gtf): + """Merge multiple GTF files into a single GTF file.""" + try: + with open(output_gtf, "w") as outfile: + for gtf in gtfs: + with open(gtf, "r") as infile: + shutil.copyfileobj(infile, outfile) + print(f"Successfully merged {len(gtfs)} GTF files into {output_gtf}") + except Exception as e: + raise Exception(f"Failed to merge GTF files: {e}") + + def _get_conditions_from_file(self, file_path: str) -> List[str]: + """Extract conditions from file header.""" + try: + with open(file_path) as f: + header = f.readline().strip().split('\t') + return header[1:] # Skip the first column (gene IDs) + except Exception as e: + logging.error(f"Error reading conditions from {file_path}: {e}") + return [] + + @property + def conditions(self): + return self._conditions + + @conditions.setter + def conditions(self, value): + self._conditions = value diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py new file mode 100644 index 00000000..54c24193 --- /dev/null +++ b/src/visualization_plotter.py @@ -0,0 +1,513 @@ +import os +import matplotlib.pyplot as plt +import numpy as np +from pathlib import Path +import logging +import pandas as pd +import random +import json +import matplotlib.patches as patches + + +class PlotOutput: + def __init__( + self, + updated_gene_dict, + gene_names, + gene_visualizations_dir, + read_assignments_dir, + reads_and_class=None, + filter_transcripts=None, + conditions=False, + use_counts=False, + ): + self.updated_gene_dict = updated_gene_dict + self.gene_names = gene_names + self.gene_visualizations_dir = gene_visualizations_dir + self.read_assignments_dir = read_assignments_dir + self.reads_and_class = reads_and_class + self.filter_transcripts = filter_transcripts + self.conditions = conditions + self.use_counts = use_counts + + # Ensure output directories exist + if self.gene_visualizations_dir: + os.makedirs(self.gene_visualizations_dir, exist_ok=True) + os.makedirs(self.read_assignments_dir, exist_ok=True) + + def plot_transcript_map(self): + """Plot transcript structure with different colors for reference and novel exons.""" + if not self.gene_visualizations_dir: + logging.warning("No gene_visualizations_dir provided. 
Skipping transcript map plotting.") + return + + for gene_name in self.gene_names: + gene_data = {} + for condition, genes in self.updated_gene_dict.items(): + if gene_name in genes: + gene_data = genes[gene_name] + break + + if not gene_data: + logging.warning(f"Gene {gene_name} not found in the data.") + continue + + # Get chromosome info and calculate buffer + chromosome = gene_data.get("chromosome", "Unknown") + start = gene_data.get("start", 0) + end = gene_data.get("end", 0) + + # Calculate buffer (5% of total width) + width = end - start + buffer = width * 0.05 + plot_start = start - buffer + plot_end = end + buffer + + plot_height = max(8, len(gene_data["transcripts"]) * 0.4) + logging.debug(f"Creating transcript map for gene '{gene_name}' with {len(gene_data['transcripts'])} transcripts") + + # Collect all reference exon coordinates from reference transcripts + reference_exons = set() + for transcript_id, transcript_info in gene_data["transcripts"].items(): + if transcript_id.startswith("ENST"): + for exon in transcript_info["exons"]: + # Store exon coordinates as tuple for easy comparison + reference_exons.add((exon["start"], exon["end"])) + + logging.debug(f"Found {len(reference_exons)} reference exons for gene '{gene_name}'") + + fig, ax = plt.subplots(figsize=(12, plot_height)) + + # Add legend handles + legend_elements = [ + patches.Patch(facecolor='skyblue', label='Reference Exon'), + patches.Patch(facecolor='red', alpha=0.6, label='Novel Exon') + ] + + # Plot each transcript + y_ticks = [] + y_labels = [] + for i, (transcript_id, transcript_info) in enumerate(gene_data["transcripts"].items()): + # Plot direction marker + direction_marker = ">" if gene_data["strand"] == "+" else "<" + marker_pos = ( + transcript_info["end"] + 100 + if gene_data["strand"] == "+" + else transcript_info["start"] - 100 + ) + ax.plot( + marker_pos, i, marker=direction_marker, markersize=5, color="blue" + ) + + # Draw the line for the whole transcript + ax.plot( + [transcript_info["start"], transcript_info["end"]], + [i, i], + color="grey", + linewidth=2, + ) + + # Exon blocks with color based on reference status + for exon in transcript_info["exons"]: + exon_length = exon["end"] - exon["start"] + # Check if this exon's coordinates match any reference exon + is_reference_exon = (exon["start"], exon["end"]) in reference_exons + exon_color = "skyblue" if is_reference_exon else "red" + exon_alpha = 1.0 if is_reference_exon else 0.6 + + ax.add_patch( + plt.Rectangle( + (exon["start"], i - 0.4), + exon_length, + 0.8, + color=exon_color, + alpha=exon_alpha + ) + ) + + if not any((exon["start"], exon["end"]) in reference_exons for exon in transcript_info["exons"]): + logging.debug(f"Transcript {transcript_id} in gene {gene_name} contains all novel exons") + + # Store y-axis label information + y_ticks.append(i) + # Get transcript name with fallback options + transcript_name = (transcript_info.get("name") or + transcript_info.get("transcript_id") or + transcript_id) + y_labels.append(transcript_name) + + # Set up the plot formatting with just chromosome + if self.filter_transcripts: + title = f"Transcript Structure - {gene_name} (Chromosome {chromosome}) (Count > {self.filter_transcripts})" + else: + title = f"Transcript Structure - {gene_name} (Chromosome {chromosome})" + + ax.set_title(title, pad=20) # Increase padding to move title up + ax.set_xlabel("Chromosomal position") + ax.set_ylabel("Transcripts") + + # Set y-axis ticks and labels + ax.set_yticks(y_ticks) + ax.set_yticklabels(y_labels) + + # Add 
legend in upper right corner + ax.legend(handles=legend_elements, loc='upper right') + + # Set plot limits with buffer + ax.set_xlim(plot_start, plot_end) + ax.invert_yaxis() # First transcript at the top + + # Add grid lines + ax.grid(True, axis='y', linestyle='--', alpha=0.3) + + plt.tight_layout() + plot_path = os.path.join( + self.gene_visualizations_dir, f"{gene_name}_splicing.png" + ) + plt.savefig(plot_path, bbox_inches='tight', dpi=300) + plt.close(fig) + logging.debug(f"Saved transcript map for gene '{gene_name}' at: {plot_path}") + + + def plot_transcript_usage(self): + """Visualize transcript usage for each gene in gene_names across different conditions.""" + if not self.gene_visualizations_dir: + logging.warning("No gene_visualizations_dir provided. Skipping transcript usage plotting.") + return + + for gene_name in self.gene_names: + gene_data = {} + for condition, genes in self.updated_gene_dict.items(): + if gene_name in genes: + gene_data[condition] = genes[gene_name]["transcripts"] + + if not gene_data: + logging.warning(f"Gene {gene_name} not found in the data.") + continue + + conditions = list(gene_data.keys()) + n_bars = len(conditions) + + fig, ax = plt.subplots(figsize=(12, 8)) + index = np.arange(n_bars) + bar_width = 0.35 + opacity = 0.8 + max_transcripts = max(len(gene_data[condition]) for condition in conditions) + colors = plt.cm.plasma(np.linspace(0, 1, num=max_transcripts)) + + bottom_val = np.zeros(n_bars) + for i, condition in enumerate(conditions): + transcripts = gene_data[condition] + for j, (transcript_id, transcript_info) in enumerate(transcripts.items()): + color = colors[j % len(colors)] + value = transcript_info["value"] + # Get transcript name with fallback options + transcript_name = (transcript_info.get("name") or + transcript_info.get("transcript_id") or + transcript_id) + ax.bar( + i, + float(value), + bar_width, + bottom=bottom_val[i], + alpha=opacity, + color=color, + label=transcript_name if i == 0 else "", + ) + bottom_val[i] += float(value) + + ax.set_xlabel("Sample Type") + ax.set_ylabel("Transcript Usage (TPM)") + ax.set_title(f"Transcript Usage for {gene_name} by Sample Type") + ax.set_xticks(index) + ax.set_xticklabels(conditions) + ax.legend( + title="Transcript IDs", + bbox_to_anchor=(1.05, 1), + loc="upper left", + fontsize=8, + ) + + plt.tight_layout() + plot_path = os.path.join( + self.gene_visualizations_dir, + f"{gene_name}_transcript_usage_by_sample_type.png", + ) + plt.savefig(plot_path) + plt.close(fig) + + def make_pie_charts(self): + """ + Create pie charts for transcript alignment classifications and read assignment consistency. + Handles both combined and separate sample data structures. + """ + + titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] + + for title, data in zip(titles, self.reads_and_class): + if isinstance(data, dict): + if any(isinstance(v, dict) for v in data.values()): + # Separate 'Mutants' and 'WildType' case + for sample_name, sample_data in data.items(): + self._create_pie_chart(f"{title} - {sample_name}", sample_data) + else: + # Combined data case + self._create_pie_chart(title, data) + else: + print(f"Skipping unexpected data type for {title}: {type(data)}") + + def _create_pie_chart(self, title, data): + """ + Helper method to create a single pie chart. 
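+
+        make_pie_charts() feeds this either directly (combined case) or once
+        per sample (nested case); `data` is always a flat mapping, e.g.
+        (illustrative counts): {"unique": 1000, "ambiguous": 50, "inconsistent": 120}.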
+ """ + labels = list(data.keys()) + sizes = list(data.values()) + total = sum(sizes) + + # Generate a file-friendly title + file_title = title.lower().replace(" ", "_").replace("-", "_") + + plt.figure(figsize=(12, 8)) + wedges, texts, autotexts = plt.pie( + sizes, + labels=labels, + autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100.*total):d})", + startangle=140, + textprops=dict(color="w"), + ) + plt.setp(autotexts, size=8, weight="bold") + plt.setp(texts, size=7) + + plt.axis("equal") # Equal aspect ratio ensures that pie is drawn as a circle. + plt.title(f"{title}\nTotal: {total}") + + plt.legend( + wedges, + labels, + title="Categories", + loc="center left", + bbox_to_anchor=(1, 0, 0.5, 1), + fontsize=8, + ) + # Save pie charts in the read_assignments directory + plot_path = os.path.join( + self.read_assignments_dir, f"{file_title}_pie_chart.png" + ) + plt.savefig(plot_path, bbox_inches="tight", dpi=300) + plt.close() + + +class ExpressionVisualizer: + def __init__(self, output_path: Path): + """ + Initialize visualizer with output directory. + + Args: + output_path: Path to output directory + """ + self.output_path = Path(output_path) + self.output_path.mkdir(parents=True, exist_ok=True) + + def create_volcano_plot( + self, + df: pd.DataFrame, + target_label: str, + reference_label: str, + padj_threshold: float = 0.05, + lfc_threshold: float = 1, + top_n: int = 10, + feature_type: str = "genes", # Added parameter + ) -> None: + """Create volcano plot from differential expression results.""" + plt.figure(figsize=(10, 8)) + + # Prepare data + df["padj"] = df["padj"].replace(0, 1e-300) + df = df[df["padj"] > 0] + df = df.copy() # Create a copy to avoid the warning + df.loc[:, "-log10(padj)"] = -np.log10(df["padj"]) + + # Define significant genes + significant = (df["padj"] < padj_threshold) & ( + abs(df["log2FoldChange"]) > lfc_threshold + ) + up_regulated = significant & (df["log2FoldChange"] > lfc_threshold) + down_regulated = significant & (df["log2FoldChange"] < -lfc_threshold) + + # Plot points + plt.scatter( + df.loc[~significant, "log2FoldChange"], + df.loc[~significant, "-log10(padj)"], + color="grey", + alpha=0.5, + label="Not Significant", + ) + plt.scatter( + df.loc[up_regulated, "log2FoldChange"], + df.loc[up_regulated, "-log10(padj)"], + color="red", + alpha=0.7, + label=f"Up-regulated in ({target_label})", + ) + plt.scatter( + df.loc[down_regulated, "log2FoldChange"], + df.loc[down_regulated, "-log10(padj)"], + color="blue", + alpha=0.7, + label=f"Up-regulated in ({reference_label})", + ) + + # Add threshold lines and labels + plt.axhline(-np.log10(padj_threshold), color="grey", linestyle="--") + plt.axvline(lfc_threshold, color="grey", linestyle="--") + plt.axvline(-lfc_threshold, color="grey", linestyle="--") + + plt.xlabel("log2 Fold Change") + plt.ylabel("-log10(adjusted p-value)") + plt.title(f"Volcano Plot: {target_label} vs {reference_label}") + plt.legend() + + # Add labels for top significant features + sig_df = df.loc[significant].nsmallest(top_n, "padj") + for _, row in sig_df.iterrows(): + if feature_type == "genes": + symbol = row["gene_name"] if pd.notnull(row["gene_name"]) else row["feature_id"] + elif feature_type == "transcripts": + symbol = row["transcript_symbol"] if pd.notnull(row["transcript_symbol"]) else row["feature_id"] + else: # Fallback to feature_id if feature_type is not recognized + symbol = row["feature_id"] + plt.text( + row["log2FoldChange"], + row["-log10(padj)"], + symbol, + fontsize=8, + ha="center", + va="bottom", + ) + + 
plt.tight_layout() + plot_path = ( + self.output_path / f"volcano_plot_{feature_type}.png" + ) # Modified line + plt.savefig(str(plot_path)) + plt.close() + logging.info(f"Volcano plot saved to {plot_path}") + + def create_ma_plot( + self, + df: pd.DataFrame, + target_label: str, + reference_label: str, + feature_type: str = "genes", + ) -> None: + """Create MA plot from differential expression results.""" + plt.figure(figsize=(10, 8)) + + # Prepare data + df = df[df["baseMean"] > 0] + df["log10(baseMean)"] = np.log10(df["baseMean"]) + + # Create plot + plt.scatter( + df["log10(baseMean)"], df["log2FoldChange"], alpha=0.5, color="grey" + ) + plt.axhline(y=0, color="red", linestyle="--") + + plt.xlabel("log10(Base Mean)") + plt.ylabel("log2 Fold Change") + plt.title(f"MA Plot: {target_label} vs {reference_label}") + + plt.tight_layout() + plot_path = self.output_path / f"ma_plot_{feature_type}.png" # Modified line + plt.savefig(str(plot_path)) + plt.close() + logging.info(f"MA plot saved to {plot_path}") + + def create_summary( + self, + res_df: pd.DataFrame, + target_label: str, + reference_label: str, + min_count: int, + feature_type: str, + ) -> None: + """ + Create and save analysis summary with correct filtering criteria reporting. + + Args: + res_df: Results DataFrame + target_label: Target condition label + reference_label: Reference condition label + min_count: Minimum count threshold used in filtering + feature_type: Type of features analyzed ("genes" or "transcripts") + """ + total_features = len(res_df) + sig_features = ( + (res_df["padj"] < 0.05) & (res_df["log2FoldChange"].abs() > 1) + ).sum() + up_regulated = ((res_df["padj"] < 0.05) & (res_df["log2FoldChange"] > 1)).sum() + down_regulated = ( + (res_df["padj"] < 0.05) & (res_df["log2FoldChange"] < -1) + ).sum() + + # Incorporate feature_type into the summary filename + summary_filename = f"analysis_summary_{feature_type}.txt" + summary_path = self.output_path / summary_filename + + with summary_path.open("w") as f: + f.write(f"Analysis Summary: {target_label} vs {reference_label}\n") + f.write("================================\n") + + # Different filtering description based on feature type + if feature_type == "genes": + f.write( + f"{feature_type.capitalize()} after filtering " + f"(mean count >= {min_count} in either condition group): {total_features}\n" + ) + else: # transcripts + f.write( + f"{feature_type.capitalize()} after filtering " + f"(count >= {min_count} in at least half of all samples): {total_features}\n" + ) + + f.write(f"Significantly differential {feature_type}: {sig_features}\n") + f.write(f"Up-regulated {feature_type}: {up_regulated}\n") + f.write(f"Down-regulated {feature_type}: {down_regulated}\n") + + logging.info(f"Analysis summary saved to {summary_path}") + + def visualize_results( + self, + results: pd.DataFrame, + target_label: str, + reference_label: str, + min_count: int, + feature_type: str, + ) -> None: + """ + Create all visualizations and summary for the analysis results. 
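+
+        Sketch of a typical call (labels and output path hypothetical):
+
+            visualizer = ExpressionVisualizer(Path("output/deseq"))
+            visualizer.visualize_results(results_df, target_label="MUT",
+                                         reference_label="WT", min_count=10,
+                                         feature_type="genes")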
+ + Args: + results: DataFrame containing differential expression results + target_label: Target condition label + reference_label: Reference condition label + min_count: Minimum count threshold used in filtering + feature_type: Type of features analyzed ("genes" or "transcripts") + """ + try: + self.create_volcano_plot( + results, target_label, reference_label, feature_type=feature_type + ) + self.create_ma_plot( + results, target_label, reference_label, feature_type=feature_type + ) + self.create_summary( + results, + target_label, + reference_label, + min_count, + feature_type=feature_type, + ) + except Exception as e: + logging.exception("Failed to create visualizations") + raise diff --git a/src/visualize_expression.py b/src/visualize_expression.py new file mode 100644 index 00000000..4b9101b3 --- /dev/null +++ b/src/visualize_expression.py @@ -0,0 +1,193 @@ +""" +Visualization and summary module for differential expression analysis results. +""" + +from pathlib import Path +import logging +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +from typing import Dict, List + + +class ExpressionVisualizer: + def __init__(self, output_path: Path): + """ + Initialize visualizer with output directory. + + Args: + output_path: Path to output directory + """ + self.output_path = Path(output_path) + self.output_path.mkdir(parents=True, exist_ok=True) + + def create_volcano_plot( + self, + df: pd.DataFrame, + target_label: str, + reference_label: str, + padj_threshold: float = 0.05, + lfc_threshold: float = 1, + top_n: int = 10, + ) -> None: + """Create volcano plot from differential expression results.""" + plt.figure(figsize=(10, 8)) + + # Prepare data + df["padj"] = df["padj"].replace(0, 1e-300) + df = df[df["padj"] > 0] + df = df.copy() # Create a copy to avoid the warning + df.loc[:, "-log10(padj)"] = -np.log10(df["padj"]) + + # Define significant genes + significant = (df["padj"] < padj_threshold) & ( + abs(df["log2FoldChange"]) > lfc_threshold + ) + up_regulated = significant & (df["log2FoldChange"] > lfc_threshold) + down_regulated = significant & (df["log2FoldChange"] < -lfc_threshold) + + # Plot points + plt.scatter( + df.loc[~significant, "log2FoldChange"], + df.loc[~significant, "-log10(padj)"], + color="grey", + alpha=0.5, + label="Not Significant", + ) + plt.scatter( + df.loc[up_regulated, "log2FoldChange"], + df.loc[up_regulated, "-log10(padj)"], + color="red", + alpha=0.7, + label=f"Up-regulated in ({target_label})", + ) + plt.scatter( + df.loc[down_regulated, "log2FoldChange"], + df.loc[down_regulated, "-log10(padj)"], + color="blue", + alpha=0.7, + label=f"Up-regulated in ({reference_label})", + ) + + # Add threshold lines and labels + plt.axhline(-np.log10(padj_threshold), color="grey", linestyle="--") + plt.axvline(lfc_threshold, color="grey", linestyle="--") + plt.axvline(-lfc_threshold, color="grey", linestyle="--") + + plt.xlabel("log2 Fold Change") + plt.ylabel("-log10(adjusted p-value)") + plt.title(f"Volcano Plot: {target_label} vs {reference_label}") + plt.legend() + + # Add labels for top significant features + sig_df = df.loc[significant].nsmallest(top_n, "padj") + for _, row in sig_df.iterrows(): + symbol = row["symbol"] if pd.notnull(row["symbol"]) else row["feature_id"] + plt.text( + row["log2FoldChange"], + row["-log10(padj)"], + symbol, + fontsize=8, + ha="center", + va="bottom", + ) + + plt.tight_layout() + plot_path = self.output_path / "volcano_plot.png" + plt.savefig(str(plot_path)) + plt.close() + logging.info(f"Volcano plot saved to 
{plot_path}") + + def create_ma_plot( + self, df: pd.DataFrame, target_label: str, reference_label: str + ) -> None: + """Create MA plot from differential expression results.""" + plt.figure(figsize=(10, 8)) + + # Prepare data + df = df[df["baseMean"] > 0] + df["log10(baseMean)"] = np.log10(df["baseMean"]) + + # Create plot + plt.scatter( + df["log10(baseMean)"], df["log2FoldChange"], alpha=0.5, color="grey" + ) + plt.axhline(y=0, color="red", linestyle="--") + + plt.xlabel("log10(Base Mean)") + plt.ylabel("log2 Fold Change") + plt.title(f"MA Plot: {target_label} vs {reference_label}") + + plt.tight_layout() + plot_path = self.output_path / "ma_plot.png" + plt.savefig(str(plot_path)) + plt.close() + logging.info(f"MA plot saved to {plot_path}") + + def create_summary( + self, + res_df: pd.DataFrame, + target_label: str, + reference_label: str, + min_count: int, + feature_type: str, + ) -> None: + """ + Create and save analysis summary. + + Args: + res_df: Results DataFrame + target_label: Target condition label + reference_label: Reference condition label + min_count: Minimum count threshold used in filtering + feature_type: Type of features analyzed ("genes" or "transcripts") + """ + total_features = len(res_df) + sig_features = ( + (res_df["padj"] < 0.05) & (res_df["log2FoldChange"].abs() > 1) + ).sum() + up_regulated = ((res_df["padj"] < 0.05) & (res_df["log2FoldChange"] > 1)).sum() + down_regulated = ( + (res_df["padj"] < 0.05) & (res_df["log2FoldChange"] < -1) + ).sum() + + summary_path = self.output_path / "analysis_summary.txt" + with summary_path.open("w") as f: + f.write(f"Analysis Summary: {target_label} vs {reference_label}\n") + f.write("================================\n") + f.write( + f"{feature_type.capitalize()} after filtering " + f"(mean count >= {min_count} in both groups): {total_features}\n" + ) + f.write(f"Significantly differential {feature_type}: {sig_features}\n") + f.write(f"Up-regulated {feature_type}: {up_regulated}\n") + f.write(f"Down-regulated {feature_type}: {down_regulated}\n") + logging.info(f"Analysis summary saved to {summary_path}") + + def visualize_results( + self, + results: pd.DataFrame, + target_label: str, + reference_label: str, + min_count: int, + feature_type: str, + ) -> None: + """ + Create all visualizations and summary for the analysis results. 
+ + Args: + results: DataFrame containing differential expression results + target_label: Target condition label + reference_label: Reference condition label + min_count: Minimum count threshold used in filtering + feature_type: Type of features analyzed ("genes" or "transcripts") + """ + try: + self.create_volcano_plot(results, target_label, reference_label) + self.create_ma_plot(results, target_label, reference_label) + self.create_summary( + results, target_label, reference_label, min_count, feature_type + ) + except Exception as e: + logging.exception("Failed to create visualizations") + raise diff --git a/visualize.py b/visualize.py index 35f6175c..006a5e25 100755 --- a/visualize.py +++ b/visualize.py @@ -1,25 +1,79 @@ -#!/usr/bin/env python3 - -from src.post_process import OutputConfig, DictionaryBuilder -from src.plot_output import PlotOutput import argparse -from src.process_dict import simplify_and_sum_transcripts -from src.gene_model import rank_and_visualize_genes -import os +import sys +import logging +from src.visualization_output_config import OutputConfig +from src.visualization_dictionary_builder import DictionaryBuilder +from src.visualization_plotter import PlotOutput +from src.visualization_differential_exp import DifferentialAnalysis +from src.visualization_gsea import GSEAAnalysis +from pathlib import Path + + +def setup_logging(viz_output_dir: Path) -> None: + """Configure centralized logging for all visualization processes.""" + log_file = viz_output_dir / "visualize.log" + + # Create formatters + file_formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s' + ) + console_formatter = logging.Formatter('%(levelname)s: %(message)s') + + # File handler - detailed logging + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(file_formatter) + + # Console handler - less detailed + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.INFO) + console_handler.setFormatter(console_formatter) + + # Configure root logger + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + root_logger.handlers = [] # Clear existing handlers + root_logger.addHandler(file_handler) + root_logger.addHandler(console_handler) + + # Create logger for the visualization package + viz_logger = logging.getLogger('IsoQuant.visualization') + viz_logger.setLevel(logging.DEBUG) + + # Create logger for differential expression + diff_logger = logging.getLogger('IsoQuant.visualization.differential_exp') + diff_logger.setLevel(logging.DEBUG) + + logging.info("Initialized centralized logging system") + logging.debug(f"Log file location: {log_file}") + + +def setup_viz_output(output_directory: str, viz_output: str = None) -> Path: + """Set up visualization output directory.""" + if viz_output: + viz_output_dir = Path(viz_output) + else: + viz_output_dir = Path(output_directory) / "visualization" + viz_output_dir.mkdir(parents=True, exist_ok=True) + return viz_output_dir class FindGenesAction(argparse.Action): def __call__(self, parser, namespace, values, option_string=None): if values is None: - values = 100 # Default value when the flag is used without a value + values = 100 # Default if flag used without value setattr(namespace, self.dest, values) def parse_arguments(): parser = argparse.ArgumentParser(description="Visualize your IsoQuant output.") + + # Positional Argument parser.add_argument( "output_directory", type=str, help="Directory containing IsoQuant output files." 
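# Example invocation (paths and values are hypothetical placeholders; the
# flags are the ones defined in this parser):
#   python visualize.py /path/to/isoquant_output \
#       --gene_list my_genes.txt --counts --filter_transcripts 5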
) + + # Optional Arguments parser.add_argument( "--viz_output", type=str, @@ -47,73 +101,120 @@ def parse_arguments(): default=None, ) parser.add_argument( + "--gsea", + action="store_true", + help="Perform GSEA analysis on differential expression results", + ) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument( "--gene_list", type=str, - required=True, - help="Path to a .txt file containing a list of genes, each on its own line.", + help="Path to a .txt file containing a list of genes to evaluate.", ) - parser.add_argument( + group.add_argument( "--find_genes", nargs="?", const=100, type=int, - help="Find genes with the highest combined rank and visualize them. Optionally specify the number of top genes to evaluate (default is 100).", - ) - parser.add_argument( - "--known_genes_path", - type=str, - help="Path to a CSV file containing known target genes.", - default=None, + help="Find top genes with highest combined rank (default 100).", ) args = parser.parse_args() - # If --find_genes is used, prompt for reference condition - if args.find_genes: + if args.find_genes is not None: output = OutputConfig( args.output_directory, use_counts=args.counts, ref_only=args.ref_only, gtf=args.gtf, ) + if output.conditions: + gene_file = ( + output.transcript_grouped_tpm + if not output.use_counts + else output.transcript_grouped_counts + ) + else: + gene_file = ( + output.transcript_tpm + if not output.use_counts + else output.transcript_counts + ) - # Read the first line of the transcript_grouped_tpm file to get the conditions - with open(output.transcript_grouped_tpm, "r") as f: - header = f.readline().strip().split("\t") + if not gene_file or not Path(gene_file).is_file(): + print(f"Error: Grouped TPM/Counts file not found at {gene_file}.") + sys.exit(1) - # The first column is typically '#feature_id', so we skip it - conditions = header[1:] + with open(gene_file, "r") as f: + header = f.readline().strip().split("\t") - if len(conditions) == 2: - # If there are only two conditions, automatically use the first as reference - args.reference_condition = conditions[0] + if len(header) < 2: print( - f"Automatically selected '{args.reference_condition}' as the reference condition." + "Error: The grouped TPM/Counts file does not contain condition information." ) - else: - print("Available conditions:") - for i, condition in enumerate(conditions, 1): - print(f"{i}. {condition}") - - while True: - try: - choice = int( - input("Enter the number of the condition to use as reference: ") - ) - if 1 <= choice <= len(conditions): - args.reference_condition = conditions[choice - 1] - break - else: - print("Invalid choice. Please enter a number from the list.") - except ValueError: - print("Invalid input. Please enter a number.") + sys.exit(1) + + available_conditions = header[1:] + if not available_conditions: + print("Error: No conditions found in the grouped TPM/Counts file.") + sys.exit(1) + + args.available_conditions = available_conditions return args +def select_conditions_interactively(args): + print("\nAvailable conditions:") + for idx, condition in enumerate(args.available_conditions, 1): + print(f"{idx}. 
{condition}") + + def get_selection(prompt, max_selection, exclude=[]): + while True: + try: + choices = input(prompt) + choice_indices = [int(x.strip()) for x in choices.split(",")] + if all(1 <= idx <= max_selection for idx in choice_indices): + selected = [ + args.available_conditions[idx - 1] + for idx in choice_indices + if args.available_conditions[idx - 1] not in exclude + ] + if not selected: + print("No valid conditions selected. Please try again.") + continue + return selected + else: + print(f"Please enter numbers between 1 and {max_selection}.") + except ValueError: + print("Invalid input. Please enter numbers separated by commas.") + + max_idx = len(args.available_conditions) + args.reference_conditions = get_selection( + "\nEnter refs (comma-separated): ", max_idx + ) + selected_refs = set(args.reference_conditions) + args.target_conditions = get_selection( + "\nEnter targets (comma-separated): ", max_idx, exclude=selected_refs + ) + + print("\nSelected Reference Conditions:", ", ".join(args.reference_conditions)) + print("Selected Target Conditions:", ", ".join(args.target_conditions), "\n") + + def main(): - print("Reading IsoQuant parameters.") + # Parse args first without logging args = parse_arguments() + + # Set up visualization directory and logging + viz_output_dir = setup_viz_output(args.output_directory, args.viz_output) + setup_logging(viz_output_dir) + + # If find_genes is specified, get conditions interactively + if args.find_genes is not None: + select_conditions_interactively(args) + + logging.info("Reading IsoQuant parameters.") output = OutputConfig( args.output_directory, use_counts=args.counts, @@ -121,122 +222,110 @@ def main(): gtf=args.gtf, ) dictionary_builder = DictionaryBuilder(output) - gene_list = dictionary_builder.read_gene_list(args.gene_list) - update_names = not all(gene.startswith("ENS") for gene in gene_list) - print("Building gene, transcript, and exon dictionaries.") - gene_dict = dictionary_builder.build_gene_transcript_exon_dictionaries() - print("Building read assignment and classification dictionaries.") - reads_and_class = ( - dictionary_builder.build_read_assignment_and_classification_dictionaries() - ) - - if output.conditions: - gene_file = ( - output.gene_grouped_tpm - if not output.use_counts - else output.gene_grouped_counts - ) + # logging.debug("OutputConfig details:") + # logging.debug(vars(output)) + + # Ask user about read assignments (optional) + use_read_assignments = ( + input("Do you want to look at read_assignment data? (y/n): ") + .strip() + .lower() + .startswith("y") + ) + + # If gene_list was given, read it; might use later for some optional steps + if args.gene_list: + logging.info(f"Reading gene list from {args.gene_list}") + gene_list = dictionary_builder.read_gene_list(args.gene_list) + # Decide if you need to rename Genes -> Symbol + update_names = not all(gene.startswith("ENS") for gene in gene_list) else: - gene_file = output.gene_tpm if not output.use_counts else output.gene_counts + gene_list = None + update_names = True - updated_gene_dict = dictionary_builder.update_gene_dict(gene_dict, gene_file) - if update_names: - print("Updating Ensembl IDs to gene symbols.") - updated_gene_dict = dictionary_builder.update_gene_names(updated_gene_dict) + # 1. 
Build a dictionary that includes transcript expression + # and filters transcripts if they do NOT exceed args.filter_transcripts + # (defaulting to 1.0 if not provided) + min_val = args.filter_transcripts if args.filter_transcripts is not None else 1.0 + updated_gene_dict = dictionary_builder.build_gene_dict_with_expression_and_filter( + min_value=min_val + ) - if output.ref_only or not output.extended_annotation: - print("Using reference-only based quantification.") - if output.conditions: - updated_gene_dict = dictionary_builder.update_transcript_values( - updated_gene_dict, - ( - output.transcript_grouped_tpm - if not output.use_counts - else output.transcript_grouped_counts - ), - ) - else: - updated_gene_dict = dictionary_builder.update_transcript_values( - updated_gene_dict, - ( - output.transcript_tpm - if not output.use_counts - else output.transcript_counts - ), - ) + # 2. If read assignments are desired, build those as well (cached) + if use_read_assignments: + logging.info("Building read assignment and classification dictionaries.") + reads_and_class = ( + dictionary_builder.build_read_assignment_and_classification_dictionaries() + ) else: - print("Using transcript model quantification.") - if output.conditions: - updated_gene_dict = dictionary_builder.update_transcript_values( - updated_gene_dict, - ( - output.transcript_model_grouped_tpm - if not output.use_counts - else output.transcript_model_grouped_counts - ), - ) - else: - updated_gene_dict = dictionary_builder.update_transcript_values( - updated_gene_dict, - ( - output.transcript_model_tpm - if not output.use_counts - else output.transcript_model_counts - ), - ) + reads_and_class = None - if args.filter_transcripts is not None: - print( - f"Filtering transcripts with minimum value {args.filter_transcripts} in at least one condition." + # 3. 
If user wants to find top genes (--find_genes), run your differential analysis + if args.find_genes is not None: + ref_str = "_".join( + x.upper().replace(" ", "_") for x in args.reference_conditions ) - updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value( - updated_gene_dict, min_value=args.filter_transcripts + target_str = "_".join( + x.upper().replace(" ", "_") for x in args.target_conditions ) - else: - updated_gene_dict = dictionary_builder.filter_transcripts_by_minimum_value( - updated_gene_dict + main_dir_name = f"find_genes_{ref_str}_vs_{target_str}" + base_dir = ( + viz_output_dir / main_dir_name if not args.viz_output else viz_output_dir ) + base_dir.mkdir(exist_ok=True) - # Visualization output directory decision - if args.viz_output: - viz_output_directory = args.viz_output - else: - viz_output_directory = os.path.join(args.output_directory, "visualization") - os.makedirs(viz_output_directory, exist_ok=True) - - if args.find_genes: - print("Finding genes.") - simple_gene_dict = simplify_and_sum_transcripts(updated_gene_dict) - path = rank_and_visualize_genes( - simple_gene_dict, - viz_output_directory, - args.find_genes, - known_genes_path=args.known_genes_path, - reference_condition=args.reference_condition, + logging.info("Finding genes via differential analysis.") + diff_analysis = DifferentialAnalysis( + output_dir=args.output_directory, + viz_output=base_dir, + ref_conditions=args.reference_conditions, + target_conditions=args.target_conditions, + updated_gene_dict=updated_gene_dict, + ref_only=args.ref_only, + dictionary_builder=dictionary_builder, ) - gene_list = dictionary_builder.read_gene_list(path) + gene_results, transcript_results, _ = diff_analysis.run_complete_analysis() + find_genes_list_path = gene_results.parent / "genes_from_top_100_transcripts.txt" + gene_list = dictionary_builder.read_gene_list(find_genes_list_path) + + if args.gsea: + gsea = GSEAAnalysis(output_path=base_dir) + target_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" + gsea.run_gsea_analysis(deseq2_df, target_label) + + # Use genes from top transcripts instead of top genes + find_genes_list_path = gene_results.parent / "genes_from_top_100_transcripts.txt" + gene_list = dictionary_builder.read_gene_list(find_genes_list_path) + else: + base_dir = viz_output_dir - # Create gene_visualizations subdirectory - viz_output_directory = os.path.join(viz_output_directory, "gene_visualizations") - os.makedirs(viz_output_directory, exist_ok=True) + if update_names: + logging.info("Updating Ensembl IDs to gene symbols.") + updated_gene_dict = dictionary_builder.update_gene_names(updated_gene_dict) - # Create read_assignments subdirectory - read_assignments_dir = os.path.join(viz_output_directory, "read_assignments") - os.makedirs(read_assignments_dir, exist_ok=True) + # 5. Set up output directories + read_assignments_dir = base_dir / "read_assignments" + gene_visualizations_dir = base_dir / "gene_visualizations" + read_assignments_dir.mkdir(exist_ok=True) + gene_visualizations_dir.mkdir(exist_ok=True) + # 6. 
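# Sketch of the ranked input that GSEAAnalysis derives from the DESeq2
# results passed above (toy frame with invented symbols; the real code
# first filters padj < 0.05 and drops duplicate symbols):
import pandas as pd

res = pd.DataFrame({"gene_name": ["TP53", "BRCA1", "EGFR"],
                    "stat": [5.2, -3.8, 2.1]})
ranked = pd.Series(res["stat"].values, index=res["gene_name"])
ranked = ranked.sort_values(ascending=False)  # handed to clusterProfiler::gseGO (keyType="SYMBOL")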
Plotting with PlotOutput plot_output = PlotOutput( updated_gene_dict, gene_list, - viz_output_directory, - read_assignments_dir=read_assignments_dir, + str(gene_visualizations_dir), + read_assignments_dir=str(read_assignments_dir), reads_and_class=reads_and_class, - filter_transcripts=args.filter_transcripts, + filter_transcripts=min_val, # Just pass your chosen threshold for reference conditions=output.conditions, use_counts=args.counts, ) + plot_output.plot_transcript_map() plot_output.plot_transcript_usage() - plot_output.make_pie_charts() + + if use_read_assignments: + plot_output.make_pie_charts() if __name__ == "__main__": From 0d9ad97ee5d0b7a09c10ac7031c0d101276da42f Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Wed, 29 Jan 2025 16:32:13 -0600 Subject: [PATCH 20/35] working multilevel PCA --- src/visualization_differential_exp.py | 115 +++++++++++++++++++++++--- 1 file changed, 102 insertions(+), 13 deletions(-) diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index 6172ccfa..4b94a7e9 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -9,6 +9,11 @@ from rpy2.robjects.conversion import localconverter from src.visualization_plotter import ExpressionVisualizer from src.visualization_mapping import GeneMapper +import numpy as np +from scipy.stats import gmean +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +import seaborn as sns class DifferentialAnalysis: @@ -117,12 +122,13 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: # --- 5. Run DESeq2 Analysis --- if not self.ref_only: - deseq2_results_gene_file, _ = self._run_level_analysis( + deseq2_results_gene_file, gene_normalized_counts = self._run_level_analysis( level="gene", pattern="gene_grouped_counts.tsv", # Pattern is still needed in _run_level_analysis for output file naming - count_data=gene_counts_filtered # Pass PRE-FILTERED gene counts + count_data=gene_counts_filtered, + coldata=self._build_design_matrix(gene_counts_filtered) ) - deseq2_results_transcript_file, deseq2_transcript_df = self._run_level_analysis( + deseq2_results_transcript_file, transcript_normalized_counts = self._run_level_analysis( level="transcript", pattern="transcript_model_grouped_counts.tsv" if not self.ref_only else "transcript_grouped_counts.tsv", # Pattern is still needed in _run_level_analysis for output file naming count_data=transcript_counts_filtered # Pass PRE-FILTERED transcript counts @@ -141,6 +147,8 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: ) self.logger.info(f"Gene-level visualizations saved to {self.deseq_dir}") + normalized_gene_counts = self._median_ratio_normalization(gene_counts_filtered) # Normalize gene counts in Python + self._run_pca(normalized_gene_counts, "gene", coldata=self._build_design_matrix(gene_counts_filtered)) # Run PCA for gene level, pass normalized_counts # --- Visualize Transcript-Level Results --- transcript_results_df = pd.read_csv(deseq2_results_transcript_file) @@ -153,10 +161,13 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: ) self.logger.info(f"Transcript-level visualizations saved to {self.deseq_dir}") + normalized_transcript_counts = self._median_ratio_normalization(transcript_counts_filtered) # Normalize transcript counts in Python + self._run_pca(normalized_transcript_counts, "transcript", coldata=self._build_design_matrix(transcript_counts_filtered)) # Run PCA for transcript level, pass normalized_counts + return 
deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered def _run_level_analysis( - self, level: str, count_data: pd.DataFrame, pattern: Optional[str] = None + self, level: str, count_data: pd.DataFrame, pattern: Optional[str] = None, coldata=None ) -> Tuple[Path, pd.DataFrame]: """ Run DESeq2 analysis for a specific level and return results. @@ -179,7 +190,7 @@ def _run_level_analysis( # Create design matrix and run DESeq2 coldata = self._build_design_matrix(filtered_data) - results = self._run_deseq2(filtered_data, coldata) + results, normalized_counts_r = self._run_deseq2(filtered_data, coldata, level) # Process results results.index.name = "feature_id" @@ -204,7 +215,8 @@ def _run_level_analysis( # Write top genes self._write_top_genes(results, level) - return outfile, results + # No normalized counts returned from _run_deseq2 anymore + return outfile, pd.DataFrame() # Return empty DataFrame for normalized counts def _get_condition_data(self, pattern: str) -> pd.DataFrame: """Combine count data from all conditions.""" @@ -289,25 +301,29 @@ def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame: return pd.DataFrame({"group": groups}, index=count_data.columns) def _run_deseq2( - self, count_data: pd.DataFrame, coldata: pd.DataFrame - ) -> pd.DataFrame: + self, count_data: pd.DataFrame, coldata: pd.DataFrame, level: str + ) -> Tuple[pd.DataFrame, pd.DataFrame]: """Run DESeq2 analysis.""" deseq2 = importr("DESeq2") count_data = count_data.fillna(0).round().astype(int) with localconverter(robjects.default_converter + pandas2ri.converter): + # Convert count_data and coldata to R DataFrames explicitly before creating DESeqDataSet + count_data_r = pandas2ri.py2rpy(count_data) + coldata_r = pandas2ri.py2rpy(coldata) + dds = deseq2.DESeqDataSetFromMatrix( - countData=pandas2ri.py2rpy(count_data), - colData=pandas2ri.py2rpy(coldata), - design=Formula("~ group"), + countData=count_data_r, colData=coldata_r, design=Formula("~ group") ) dds = deseq2.DESeq(dds) res = deseq2.results( dds, contrast=robjects.StrVector(["group", "Target", "Reference"]) ) + # No normalized counts from DESeq2 anymore + return pd.DataFrame( robjects.conversion.rpy2py(r("data.frame")(res)), index=count_data.index - ) + ), pd.DataFrame() # Return empty DataFrame for normalized counts def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Dict[str, Optional[str]]]: """ @@ -358,4 +374,77 @@ def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: top_genes = results.nlargest(100, "abs_stat")["gene_name"] top_genes_file = self.deseq_dir / "top_100_genes.txt" top_genes.to_csv(top_genes_file, index=False, header=False) - self.logger.info(f"Wrote top 100 genes to {top_genes_file}") \ No newline at end of file + self.logger.info(f"Wrote top 100 genes to {top_genes_file}") + + def _run_pca(self, normalized_counts, level, coldata): # Add coldata parameter + self.logger.info(f"Running PCA for {level} level in Python using median-by-ratio normalized counts...") + + pca_plot_file = str(self.deseq_dir / f"pca_{level}_pca.png") + + # Debugging logs before PCA + self.logger.debug(f"PCA Input - Level: {level}") + self.logger.debug(f"PCA Input - Normalized Counts Shape: {normalized_counts.shape}") + self.logger.debug(f"PCA Input - Normalized Counts Dtype: {normalized_counts.dtypes}") + self.logger.debug(f"PCA Input - Normalized Counts Head:\n{normalized_counts.head()}") + + pca = PCA(n_components=2) + self.logger.debug(f"PCA Input - PCA Object: {pca}") + 
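# Shape check for the PCA that follows (synthetic matrix, not real data):
# counts arrive as features x samples, but scikit-learn's PCA treats rows
# as observations, hence the transpose before fit_transform.
import numpy as np
from sklearn.decomposition import PCA

counts = np.random.poisson(50.0, size=(1000, 6)).astype(float)  # 1000 features, 6 samples
log_counts = np.log2(counts + 1)
coords = PCA(n_components=2).fit_transform(log_counts.T)  # shape (6, 2): one row per sample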
self.logger.debug(f"PCA Input - PCA Object Attributes: {dir(pca)}") + # Log transform normalized counts (adding small constant to avoid log(0)) + log_normalized_counts = np.log2(normalized_counts + 1) + self.logger.debug(f"PCA Input - Log Transformed Counts Shape: {log_normalized_counts.shape}") + self.logger.debug(f"PCA Input - Log Transformed Counts Head:\n{log_normalized_counts.head()}") + + pca_result = pca.fit_transform(log_normalized_counts.transpose()) # Transpose for samples as rows + + pca_df = pd.DataFrame(data=pca_result, columns=['PC1', 'PC2'], index=log_normalized_counts.columns) # Index with sample names + pca_df['group'] = coldata['group'].values # Add group info + + # Calculate explained variance + explained_variance = pca.explained_variance_ratio_ + VarExplPC1 = f"{100*explained_variance[0]:.2f}%" + VarExplPC2 = f"{100*explained_variance[1]:.2f}%" + + # Create PCA plot using matplotlib and seaborn + plt.figure(figsize=(8, 6)) + sns.scatterplot(x='PC1', y='PC2', hue='group', data=pca_df, s=100) + plt.xlabel(f"PC1 ({VarExplPC1})") + plt.ylabel(f"PC2 ({VarExplPC2})") + plt.title(f"{level.capitalize()} Level PCA (Median-by-Ratio Normalized Counts)") + + # Label each point with the sample name + for i, sample_name in enumerate(pca_df.index): + plt.text(pca_df.loc[sample_name, 'PC1'], pca_df.loc[sample_name, 'PC2'], sample_name, fontsize=8, ha='left', va='bottom') + + plt.gca().spines['top'].set_visible(False) + plt.gca().spines['right'].set_visible(False) + plt.tight_layout() + + plt.savefig(pca_plot_file) + self.logger.info(f"Python PCA plot saved to {pca_plot_file}") + plt.close() + + def _median_ratio_normalization(self, count_data: pd.DataFrame) -> pd.DataFrame: + """ + Perform median-by-ratio normalization on count data. + This is similar to the normalization used in DESeq2. + """ + # 1. Calculate geometric mean for each feature (row) + geometric_means = count_data.apply(gmean, axis=1) + + # 2. Handle rows with zero geometric mean (replace with NaN to avoid division by zero) + geometric_means[geometric_means == 0] = np.nan + + # 3. Calculate ratio of each count to the geometric mean + count_ratios = count_data.divide(geometric_means, axis=0) + + # 4. Calculate size factor for each sample (column) as the median of ratios + size_factors = count_ratios.median(axis=0) + + # 5. Normalize counts by dividing by size factors + normalized_counts = count_data.divide(size_factors, axis=1) + + # 6. 
Fill NaN values with 0 after normalization + normalized_counts = normalized_counts.fillna(0) + + return normalized_counts \ No newline at end of file From b6202eaa1c8f9ceafbe8dd8513868c121b5967b9 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Thu, 30 Jan 2025 10:47:25 -0600 Subject: [PATCH 21/35] Fixed labels for figures --- src/visualization_differential_exp.py | 98 ++++++++++++--------------- src/visualization_plotter.py | 57 +++++++++++++--- 2 files changed, 92 insertions(+), 63 deletions(-) diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index 4b94a7e9..97f22483 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -134,25 +134,34 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: count_data=transcript_counts_filtered # Pass PRE-FILTERED transcript counts ) + # Update how we create the labels + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + # --- Visualize Gene-Level Results --- gene_results_df = pd.read_csv(deseq2_results_gene_file) - target_label = f"{'+'.join(self.target_conditions)}_vs_{'+'.join(self.ref_conditions)}" - reference_label = f"{'+'.join(self.ref_conditions)}" # Corrected reference label - self.visualizer.visualize_results( # Call visualize_results for gene-level + self.visualizer.visualize_results( results=gene_results_df, target_label=target_label, reference_label=reference_label, - min_count=10, # Assuming min_count_threshold is defined in DifferentialAnalysis + min_count=10, feature_type="genes", ) self.logger.info(f"Gene-level visualizations saved to {self.deseq_dir}") - normalized_gene_counts = self._median_ratio_normalization(gene_counts_filtered) # Normalize gene counts in Python - self._run_pca(normalized_gene_counts, "gene", coldata=self._build_design_matrix(gene_counts_filtered)) # Run PCA for gene level, pass normalized_counts + # Run PCA with correct labels + normalized_gene_counts = self._median_ratio_normalization(gene_counts_filtered) + self._run_pca( + normalized_gene_counts, + "gene", + coldata=self._build_design_matrix(gene_counts_filtered), + target_label=target_label, + reference_label=reference_label + ) # --- Visualize Transcript-Level Results --- transcript_results_df = pd.read_csv(deseq2_results_transcript_file) - self.visualizer.visualize_results( # Call visualize_results for transcript-level + self.visualizer.visualize_results( results=transcript_results_df, target_label=target_label, reference_label=reference_label, @@ -161,8 +170,15 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: ) self.logger.info(f"Transcript-level visualizations saved to {self.deseq_dir}") - normalized_transcript_counts = self._median_ratio_normalization(transcript_counts_filtered) # Normalize transcript counts in Python - self._run_pca(normalized_transcript_counts, "transcript", coldata=self._build_design_matrix(transcript_counts_filtered)) # Run PCA for transcript level, pass normalized_counts + # Run PCA with correct labels for transcript level + normalized_transcript_counts = self._median_ratio_normalization(transcript_counts_filtered) + self._run_pca( + normalized_transcript_counts, + "transcript", + coldata=self._build_design_matrix(transcript_counts_filtered), + target_label=target_label, + reference_label=reference_label + ) return deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered @@ -376,53 +392,29 @@ def _write_top_genes(self, results: 
pd.DataFrame, level: str) -> None: top_genes.to_csv(top_genes_file, index=False, header=False) self.logger.info(f"Wrote top 100 genes to {top_genes_file}") - def _run_pca(self, normalized_counts, level, coldata): # Add coldata parameter - self.logger.info(f"Running PCA for {level} level in Python using median-by-ratio normalized counts...") - - pca_plot_file = str(self.deseq_dir / f"pca_{level}_pca.png") - - # Debugging logs before PCA - self.logger.debug(f"PCA Input - Level: {level}") - self.logger.debug(f"PCA Input - Normalized Counts Shape: {normalized_counts.shape}") - self.logger.debug(f"PCA Input - Normalized Counts Dtype: {normalized_counts.dtypes}") - self.logger.debug(f"PCA Input - Normalized Counts Head:\n{normalized_counts.head()}") - + def _run_pca(self, normalized_counts, level, coldata, target_label, reference_label): + """Run PCA analysis and create visualization.""" + self.logger.info(f"Running PCA for {level} level...") + + # Run PCA pca = PCA(n_components=2) - self.logger.debug(f"PCA Input - PCA Object: {pca}") - self.logger.debug(f"PCA Input - PCA Object Attributes: {dir(pca)}") - # Log transform normalized counts (adding small constant to avoid log(0)) log_normalized_counts = np.log2(normalized_counts + 1) - self.logger.debug(f"PCA Input - Log Transformed Counts Shape: {log_normalized_counts.shape}") - self.logger.debug(f"PCA Input - Log Transformed Counts Head:\n{log_normalized_counts.head()}") - - pca_result = pca.fit_transform(log_normalized_counts.transpose()) # Transpose for samples as rows - - pca_df = pd.DataFrame(data=pca_result, columns=['PC1', 'PC2'], index=log_normalized_counts.columns) # Index with sample names - pca_df['group'] = coldata['group'].values # Add group info - - # Calculate explained variance + pca_result = pca.fit_transform(log_normalized_counts.transpose()) + + # Create DataFrame + pca_df = pd.DataFrame(data=pca_result, columns=['PC1', 'PC2'], index=log_normalized_counts.columns) + pca_df['group'] = coldata['group'].values + + # Calculate variance explained_variance = pca.explained_variance_ratio_ - VarExplPC1 = f"{100*explained_variance[0]:.2f}%" - VarExplPC2 = f"{100*explained_variance[1]:.2f}%" - - # Create PCA plot using matplotlib and seaborn - plt.figure(figsize=(8, 6)) - sns.scatterplot(x='PC1', y='PC2', hue='group', data=pca_df, s=100) - plt.xlabel(f"PC1 ({VarExplPC1})") - plt.ylabel(f"PC2 ({VarExplPC2})") - plt.title(f"{level.capitalize()} Level PCA (Median-by-Ratio Normalized Counts)") - - # Label each point with the sample name - for i, sample_name in enumerate(pca_df.index): - plt.text(pca_df.loc[sample_name, 'PC1'], pca_df.loc[sample_name, 'PC2'], sample_name, fontsize=8, ha='left', va='bottom') - - plt.gca().spines['top'].set_visible(False) - plt.gca().spines['right'].set_visible(False) - plt.tight_layout() - - plt.savefig(pca_plot_file) - self.logger.info(f"Python PCA plot saved to {pca_plot_file}") - plt.close() + title = f"{level.capitalize()} Level PCA: {target_label} vs {reference_label}\nPC1 ({100*explained_variance[0]:.2f}%) / PC2 ({100*explained_variance[1]:.2f}%)" + + # Use the plotter's PCA method + self.visualizer.plot_pca( + pca_df=pca_df, + title=title, + output_prefix=f"pca_{level}" + ) def _median_ratio_normalization(self, count_data: pd.DataFrame) -> pd.DataFrame: """ diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py index 54c24193..e556bc3f 100644 --- a/src/visualization_plotter.py +++ b/src/visualization_plotter.py @@ -4,9 +4,8 @@ from pathlib import Path import logging import pandas as pd 
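# Worked toy example of the median-of-ratios size factors computed by
# _median_ratio_normalization above (DESeq2-style; numbers are invented):
import pandas as pd
from scipy.stats import gmean

counts = pd.DataFrame({"s1": [10.0, 100.0, 4.0], "s2": [20.0, 200.0, 8.0]})
geo_means = counts.apply(gmean, axis=1)                         # per-feature geometric mean
size_factors = counts.divide(geo_means, axis=0).median(axis=0)  # s1 ~0.71, s2 ~1.41
normalized = counts.divide(size_factors, axis=1)                # columns now agree row-wise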
-import random -import json import matplotlib.patches as patches +import seaborn as sns class PlotOutput: @@ -298,15 +297,13 @@ def _create_pie_chart(self, title, data): class ExpressionVisualizer: - def __init__(self, output_path: Path): - """ - Initialize visualizer with output directory. - - Args: - output_path: Path to output directory - """ + def __init__(self, output_path): + """Initialize with output path for plots.""" self.output_path = Path(output_path) self.output_path.mkdir(parents=True, exist_ok=True) + self.logger = logging.getLogger(__name__) # Logger for this class + # Suppress matplotlib font debug messages + logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING) def create_volcano_plot( self, @@ -316,7 +313,7 @@ def create_volcano_plot( padj_threshold: float = 0.05, lfc_threshold: float = 1, top_n: int = 10, - feature_type: str = "genes", # Added parameter + feature_type: str = "genes", ) -> None: """Create volcano plot from differential expression results.""" plt.figure(figsize=(10, 8)) @@ -511,3 +508,43 @@ def visualize_results( except Exception as e: logging.exception("Failed to create visualizations") raise + + + def plot_pca(self, pca_df: pd.DataFrame, title: str, output_prefix: str) -> Path: + """Plot PCA scatter plot.""" + plt.figure(figsize=(8, 6)) + + # Extract variance info from title for axis labels only + pc1_var = title.split("PC1 (")[1].split("%)")[0] + pc2_var = title.split("PC2 (")[1].split("%)")[0] + + # Get clean title without PCs and variance - using string literal instead of \n + base_title = title.split(' Level PCA: ')[0] + comparison = title.split(': ')[1].split('PC1')[0].strip() + clean_title = f"{base_title} Level PCA: {comparison}" + + # Update group labels in the DataFrame + condition_mapping = {'Target': title.split(": ")[1].split(" vs ")[0], + 'Reference': title.split(" vs ")[1].split("PC1")[0].strip()} + pca_df['group'] = pca_df['group'].map(condition_mapping) + + # Create plot with updated labels + sns.scatterplot(x='PC1', y='PC2', hue='group', data=pca_df, s=100) + plt.xlabel(f'PC1 ({pc1_var}%)') + plt.ylabel(f'PC2 ({pc2_var}%)') + plt.title(clean_title) + + # Label points + for sample_name in pca_df.index: + plt.text(pca_df.loc[sample_name, 'PC1'], pca_df.loc[sample_name, 'PC2'], + sample_name, fontsize=8, ha='left', va='bottom') + + plt.gca().spines['top'].set_visible(False) + plt.gca().spines['right'].set_visible(False) + plt.tight_layout() + + output_path = self.output_path / f"{output_prefix}_pca.png" + plt.savefig(output_path) + plt.close() + return output_path + From 57ac5e086d24fd930faace53f2333788f70576f8 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Thu, 30 Jan 2025 12:53:54 -0600 Subject: [PATCH 22/35] Removed point labels --- src/visualization_plotter.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py index e556bc3f..4d914856 100644 --- a/src/visualization_plotter.py +++ b/src/visualization_plotter.py @@ -533,12 +533,6 @@ def plot_pca(self, pca_df: pd.DataFrame, title: str, output_prefix: str) -> Path plt.xlabel(f'PC1 ({pc1_var}%)') plt.ylabel(f'PC2 ({pc2_var}%)') plt.title(clean_title) - - # Label points - for sample_name in pca_df.index: - plt.text(pca_df.loc[sample_name, 'PC1'], pca_df.loc[sample_name, 'PC2'], - sample_name, fontsize=8, ha='left', va='bottom') - plt.gca().spines['top'].set_visible(False) plt.gca().spines['right'].set_visible(False) plt.tight_layout() From bcdd1935e46657981d501271692c10e47843236d Mon Sep 17 00:00:00 2001 
From: Jackfreeman88 Date: Mon, 3 Feb 2025 11:15:56 -0600 Subject: [PATCH 23/35] gsea works with refactor --- src/visualization_differential_exp.py | 29 ++-- src/visualization_gsea.py | 196 ++++++++++++++++++++++++++ 2 files changed, 215 insertions(+), 10 deletions(-) create mode 100644 src/visualization_gsea.py diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index 97f22483..18bd45de 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -73,8 +73,16 @@ def _create_transcript_to_gene_map(self) -> Dict[str, str]: transcript_map[transcript_id] = transcript_name return transcript_map - def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: - """Run differential expression analysis for both genes and transcripts.""" + def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame, pd.DataFrame]: + """Run differential expression analysis for both genes and transcripts. + + Returns: + Tuple containing: + - Path to gene results file + - Path to transcript results file + - DataFrame of transcript counts + - DataFrame of DESeq2 gene-level results + """ self.logger.info("Starting differential expression analysis") valid_transcripts = set() @@ -124,24 +132,26 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: if not self.ref_only: deseq2_results_gene_file, gene_normalized_counts = self._run_level_analysis( level="gene", - pattern="gene_grouped_counts.tsv", # Pattern is still needed in _run_level_analysis for output file naming + pattern="gene_grouped_counts.tsv", count_data=gene_counts_filtered, coldata=self._build_design_matrix(gene_counts_filtered) ) deseq2_results_transcript_file, transcript_normalized_counts = self._run_level_analysis( level="transcript", - pattern="transcript_model_grouped_counts.tsv" if not self.ref_only else "transcript_grouped_counts.tsv", # Pattern is still needed in _run_level_analysis for output file naming - count_data=transcript_counts_filtered # Pass PRE-FILTERED transcript counts + pattern="transcript_model_grouped_counts.tsv" if not self.ref_only else "transcript_grouped_counts.tsv", + count_data=transcript_counts_filtered ) + # Load the gene-level results for GSEA + deseq2_results_df = pd.read_csv(deseq2_results_gene_file) + # Update how we create the labels target_label = "+".join(self.target_conditions) reference_label = "+".join(self.ref_conditions) # --- Visualize Gene-Level Results --- - gene_results_df = pd.read_csv(deseq2_results_gene_file) self.visualizer.visualize_results( - results=gene_results_df, + results=deseq2_results_df, target_label=target_label, reference_label=reference_label, min_count=10, @@ -160,9 +170,8 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: ) # --- Visualize Transcript-Level Results --- - transcript_results_df = pd.read_csv(deseq2_results_transcript_file) self.visualizer.visualize_results( - results=transcript_results_df, + results=pd.read_csv(deseq2_results_transcript_file), target_label=target_label, reference_label=reference_label, min_count=10, @@ -180,7 +189,7 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame]: reference_label=reference_label ) - return deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered + return deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered, deseq2_results_df def _run_level_analysis( self, level: str, count_data: pd.DataFrame, pattern: Optional[str] = None, coldata=None diff --git 
a/src/visualization_gsea.py b/src/visualization_gsea.py new file mode 100644 index 00000000..148fc718 --- /dev/null +++ b/src/visualization_gsea.py @@ -0,0 +1,196 @@ +import logging +from pathlib import Path +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from rpy2 import robjects +from rpy2.robjects import r, pandas2ri +from rpy2.robjects.packages import importr +from rpy2.robjects.conversion import localconverter + + +class GSEAAnalysis: + def __init__(self, output_path: Path): + """ + Initialize GSEA analysis. + + Args: + output_path: Path to save GSEA results + """ + self.output_path = Path(output_path) / "gsea_results" + self.output_path.mkdir(parents=True, exist_ok=True) + + # Configure R to be quiet + from rpy2.rinterface_lib import callbacks + + def quiet_cb(x): + pass + + callbacks.logger.setLevel(logging.WARNING) + callbacks.consolewrite_print = quiet_cb + callbacks.consolewrite_warnerror = quiet_cb + + def run_gsea_analysis(self, results: pd.DataFrame, target_label: str) -> None: + """ + Run GSEA analysis using DESeq2 stat value as ranking metric. + Creates visualizations for top enriched pathways in each GO category. + + Args: + results: DataFrame containing DESeq2 results + target_label: Label indicating the comparison being made + """ + if results is None or results.empty: + logging.error("No DESeq2 results provided for GSEA analysis") + return + + logging.info("Starting GSEA analysis...") + logging.debug(f"Full DE results shape: {results.shape}") + logging.debug(f"DE results columns: {results.columns.tolist()}") + + # Filter for significant DE genes + sig_genes = results.dropna(subset=["padj"]) + logging.debug(f"After dropping genes with NaN padj: {sig_genes.shape}") + sig_genes = sig_genes[sig_genes["padj"] < 0.05] + logging.debug(f"Significantly DE genes (padj<0.05): {sig_genes.shape}") + if sig_genes.empty: + logging.info("No significantly DE genes found for GSEA.") + return + + sig_genes = sig_genes[sig_genes["pvalue"] > 0] + logging.debug(f"Significantly DE genes with pvalue>0: {sig_genes.shape}") + if sig_genes.empty: + logging.info("No genes with valid p-values for GSEA.") + return + + # Use gene_name instead of symbol + gene_symbols_final = sig_genes["gene_name"].values + + ranked_genes = pd.Series( + sig_genes["stat"].values, index=gene_symbols_final + ).dropna() + ranked_genes = ranked_genes[~ranked_genes.index.duplicated(keep="first")] + logging.debug(f"Final ranked genes count: {len(ranked_genes)}") + + if ranked_genes.empty: + logging.info("No valid ranked genes after processing.") + return + + # Save the ranked genes + ranked_outfile = self.output_path / "ranked_genes.csv" + ranked_genes_df = pd.DataFrame( + {"gene": ranked_genes.index, "rank": ranked_genes.values} + ) + ranked_genes_df.to_csv(ranked_outfile, index=False) + logging.info(f"Ranked genes saved to {ranked_outfile}") + + # Import required R packages + clusterProfiler = importr("clusterProfiler") + r("library(org.Hs.eg.db)") + + with localconverter(robjects.default_converter + pandas2ri.converter): + r_ranked_genes = pandas2ri.py2rpy(ranked_genes.sort_values(ascending=False)) + + def plot_pathways(df: pd.DataFrame, direction: str, ont: str): + if df.empty: + logging.info(f"No {direction} pathways to plot.") + return + + df["label"] = df["ID"] + ": " + df["Description"] + df["-log10(p.adjust)"] = -np.log10(df["p.adjust"]) + values = df["-log10(p.adjust)"] + norm = plt.Normalize(vmin=values.min(), vmax=values.max()) + cmap = plt.cm.get_cmap("viridis") + colors_for_bars = 
[cmap(norm(v)) for v in values] + + plt.figure(figsize=(12, 8)) + plt.barh( + df["label"].iloc[::-1], + df["-log10(p.adjust)"].iloc[::-1], + color=colors_for_bars[::-1], + ) + plt.xlabel("-log10(adjusted p-value)") + + # Split target label into reference and target parts + target_parts = target_label.split("_vs_") + target_condition = target_parts[0] + ref_condition = target_parts[1] + + # Create title based on direction + if direction == "up": + condition_str = f"Pathways enriched in {target_condition}\nvs {ref_condition} - {ont}" + else: + condition_str = f"Pathways enriched in {ref_condition}\nvs {target_condition} - {ont}" + + plt.title(condition_str, fontsize=10) + plt.tight_layout() + plot_path = self.output_path / f"GSEA_top_pathways_{direction}_{ont}.png" + plt.savefig(plot_path) + plt.close() + logging.info(f"GSEA {direction} pathways plot saved to {plot_path}") + + # Run GO analysis for each ontology + ontologies = ["BP", "MF", "CC"] + for ont in ontologies: + logging.debug(f"Running gseGO for {ont}...") + + gsea_res = clusterProfiler.gseGO( + geneList=r_ranked_genes, + OrgDb="org.Hs.eg.db", + keyType="SYMBOL", + ont=ont, + minGSSize=5, + maxGSSize=1000, + pvalueCutoff=1, + verbose=True, + nPermSimple=10000, + ) + + gsea_table = r("data.frame")(gsea_res) + with localconverter(robjects.default_converter + pandas2ri.converter): + gsea_df = pandas2ri.rpy2py(gsea_table) + + # Log detailed results + logging.debug(f"GSEA results for {ont}:") + logging.debug(f" Total pathways tested: {len(gsea_df)}") + if not gsea_df.empty: + logging.debug( + f" P-value range: {gsea_df['pvalue'].min():.2e} - {gsea_df['pvalue'].max():.2e}" + ) + logging.debug( + f" Adjusted p-value range: {gsea_df['p.adjust'].min():.2e} - {gsea_df['p.adjust'].max():.2e}" + ) + logging.debug( + f" NES range: {gsea_df['NES'].min():.2f} - {gsea_df['NES'].max():.2f}" + ) + logging.debug( + f" Pathways with adj.P<0.1: {len(gsea_df[gsea_df['p.adjust'] < 0.1])}" + ) + logging.debug( + f" Pathways with adj.P<0.05: {len(gsea_df[gsea_df['p.adjust'] < 0.05])}" + ) + + # Save all results + gsea_outfile = self.output_path / f"GSEA_results_{ont}.csv" + gsea_df.to_csv(gsea_outfile, index=False) + logging.info(f"Complete GSEA results for {ont} saved to {gsea_outfile}") + + # Plot significant pathways + sig_gsea_df = gsea_df[ + gsea_df["p.adjust"] < 0.05 + ].copy() # Using 0.05 threshold + if not sig_gsea_df.empty: + up_pathways = sig_gsea_df[sig_gsea_df["NES"] > 0].nsmallest( + 10, "p.adjust" + ) + down_pathways = sig_gsea_df[sig_gsea_df["NES"] < 0].nsmallest( + 10, "p.adjust" + ) + + if not up_pathways.empty: + plot_pathways(up_pathways, "up", ont) + if not down_pathways.empty: + plot_pathways(down_pathways, "down", ont) + else: + logging.info(f"No pathways with adj.P<0.05 found for {ont}") + + logging.info("GSEA analysis completed.") From 2a37a767a5de183748e9fdbe12d13fc0d665be5f Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Mon, 3 Feb 2025 11:16:13 -0600 Subject: [PATCH 24/35] added logger level debug and GSEA for top level script --- visualize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/visualize.py b/visualize.py index 006a5e25..774a9995 100755 --- a/visualize.py +++ b/visualize.py @@ -26,7 +26,7 @@ def setup_logging(viz_output_dir: Path) -> None: # Console handler - less detailed console_handler = logging.StreamHandler() - console_handler.setLevel(logging.INFO) + console_handler.setLevel(logging.DEBUG) console_handler.setFormatter(console_formatter) # Configure root logger @@ -284,7 +284,7 @@ def 
main(): ref_only=args.ref_only, dictionary_builder=dictionary_builder, ) - gene_results, transcript_results, _ = diff_analysis.run_complete_analysis() + gene_results, transcript_results, _, deseq2_df = diff_analysis.run_complete_analysis() find_genes_list_path = gene_results.parent / "genes_from_top_100_transcripts.txt" gene_list = dictionary_builder.read_gene_list(find_genes_list_path) From d7d9ab4273732e146c04255ec63ada93c9ef35ce Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Tue, 4 Feb 2025 12:43:03 -0600 Subject: [PATCH 25/35] create exon table and fixed counts to TPM for expr assignment --- src/visualization_dictionary_builder.py | 446 ++++++++++++------------ 1 file changed, 229 insertions(+), 217 deletions(-) diff --git a/src/visualization_dictionary_builder.py b/src/visualization_dictionary_builder.py index 6600a3d9..525abb3e 100644 --- a/src/visualization_dictionary_builder.py +++ b/src/visualization_dictionary_builder.py @@ -5,6 +5,7 @@ import logging from pathlib import Path from typing import Dict, Any, List, Union, Tuple +import random from src.visualization_cache_utils import ( build_gene_dict_cache_file, @@ -25,7 +26,7 @@ def __init__(self, config): # Set up logger for DictionaryBuilder self.logger = logging.getLogger('IsoQuant.visualization.dictionary_builder') - self.logger.setLevel(logging.DEBUG) + self.logger.setLevel(logging.INFO) # Initialize sets to store novel gene and transcript IDs self.novel_gene_ids = set() @@ -41,7 +42,7 @@ def build_gene_dict_with_expression_and_filter( Optimized build process with early filtering and combined steps. """ self.logger.debug(f"Starting optimized dictionary build with min_value={min_value}") - + # 1. Check cache first expr_file, tpm_file = self._get_expression_files() base_cache_file = build_gene_dict_cache_file( @@ -56,13 +57,13 @@ def build_gene_dict_with_expression_and_filter( if expr_filter_cache.exists(): cached_data = load_cache(expr_filter_cache) - if cached_data and len(cached_data) == 3: # Check if load_cache returned a tuple - cached_gene_dict, cached_novel_gene_ids, cached_novel_transcript_ids = cached_data # Unpack tuple + if cached_data and len(cached_data) == 3: # Check if load_cache returned a tuple + cached_gene_dict, cached_novel_gene_ids, cached_novel_transcript_ids = cached_data # Unpack tuple if validate_gene_dict(cached_gene_dict): - self.novel_gene_ids = cached_novel_gene_ids # Restore from cache - self.novel_transcript_ids = cached_novel_transcript_ids # Restore from cache + self.novel_gene_ids = cached_novel_gene_ids # Restore from cache + self.novel_transcript_ids = cached_novel_transcript_ids # Restore from cache return cached_gene_dict - else: # Handle older cache format (just gene_dict) + else: # Handle older cache format (just gene_dict) cached_gene_dict = cached_data if validate_gene_dict(cached_gene_dict): return cached_gene_dict @@ -75,11 +76,15 @@ def build_gene_dict_with_expression_and_filter( # Add debug log: Number of genes and transcripts after novel gene filtering gene_count_after_novel_filter = len(base_gene_dict) - transcript_count_after_novel_filter = sum(len(gene_info.get("transcripts", {})) for gene_info in base_gene_dict.values()) - self.logger.debug(f"After novel gene filtering: {gene_count_after_novel_filter} genes, {transcript_count_after_novel_filter} transcripts") + transcript_count_after_novel_filter = sum( + len(gene_info.get("transcripts", {})) for gene_info in base_gene_dict.values() + ) + self.logger.debug( + f"After novel gene filtering: {gene_count_after_novel_filter} 
genes, {transcript_count_after_novel_filter} transcripts" + ) # 3. Load expression data with consistent header handling - self.logger.info("Loading expression matrix") + self.logger.info("Loading expression matrix (Counts)") try: expr_df = pd.read_csv(expr_file, sep='\t', comment=None) expr_df.columns = [col.lstrip('#') for col in expr_df.columns] # Clean headers @@ -88,30 +93,47 @@ def build_gene_dict_with_expression_and_filter( self.logger.error(f"Missing required column in {expr_file}: {str(e)}") raise except Exception as e: - self.logger.error(f"Failed to load expression matrix: {str(e)}") + self.logger.error(f"Failed to load count expression matrix: {str(e)}") + raise + + self.logger.info("Loading TPM matrix") + try: + tpm_df = pd.read_csv(tpm_file, sep='\t', comment=None) + tpm_df.columns = [col.lstrip('#') for col in tpm_df.columns] # Clean headers + tpm_df = tpm_df.set_index('feature_id') # Use cleaned column name + except KeyError as e: + self.logger.error(f"Missing required column in {tpm_file}: {str(e)}") + raise + except Exception as e: + self.logger.error(f"Failed to load TPM expression matrix: {str(e)}") raise conditions = expr_df.columns.tolist() - - # 4. Vectorized processing instead of row-wise iteration + + # 4. Vectorized processing instead of row-wise iteration (using counts for filtering) transcript_max_values = expr_df.max(axis=1) - valid_transcripts = set(transcript_max_values[transcript_max_values >= min_value].index) + valid_transcripts = set( + transcript_max_values[transcript_max_values >= min_value].index + ) # Add debug log: Number of valid transcripts after min_value filtering valid_transcript_count = len(valid_transcripts) - self.logger.debug(f"After min_value ({min_value}) filtering: {valid_transcript_count} valid transcripts") - - # 5. Single-pass filtering and value updating + self.logger.debug( + f"After min_value ({min_value}) filtering: {valid_transcript_count} valid transcripts" + ) + + # 5. Single-pass filtering and value updating (using TPMs for values) filtered_dict = {} for condition in conditions: filtered_dict[condition] = {} - condition_values = expr_df[condition] + condition_counts = expr_df[condition] # Still using counts for filtering logic if needed later + condition_tpm_values = tpm_df[condition] # Use TPM values for assigning expression for gene_id, gene_info in base_gene_dict.items(): new_transcripts = { - tid: {**tinfo, 'value': condition_values.get(tid, 0)} + tid: {**tinfo, 'value': condition_tpm_values.get(tid, 0)} # Use TPM values here! for tid, tinfo in gene_info['transcripts'].items() - if tid in valid_transcripts + if tid in valid_transcripts # Filtering is still based on counts implicitly from valid_transcripts } if new_transcripts: @@ -119,9 +141,42 @@ def build_gene_dict_with_expression_and_filter( **gene_info, 'transcripts': new_transcripts } - self._validate_gene_structure(filtered_dict[condition]) # Validate structure for each condition's gene dict + self._validate_gene_structure(filtered_dict[condition]) # Validate structure for each condition's gene dict - save_cache(expr_filter_cache, (filtered_dict, self.novel_gene_ids, self.novel_transcript_ids)) # Save tuple to cache + for condition in conditions: + for gene_id, gene_info in filtered_dict[condition].items(): + # Initialize a dictionary to hold aggregated exon values for the gene. + aggregated_exons = {} + # Iterate over each transcript in the gene. 
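# Toy illustration of the split above: transcripts are filtered on raw
# counts, but the reported expression value comes from the TPM matrix
# (IDs and numbers are invented):
import pandas as pd

counts = pd.DataFrame({"c1": [0, 8], "c2": [1, 3]}, index=["tx_a", "tx_b"])
tpm = pd.DataFrame({"c1": [0.0, 12.5], "c2": [0.4, 4.7]}, index=["tx_a", "tx_b"])
max_counts = counts.max(axis=1)
valid = set(max_counts[max_counts >= 2].index)      # {"tx_b"}
assert tpm["c1"].get("tx_b", 0) == 12.5             # value assigned from TPM, not counts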
+ for transcript_id, transcript_info in gene_info["transcripts"].items(): + transcript_value = transcript_info.get("value", 0) # Now this is TPM value + # Loop through each exon in the current transcript. + for exon in transcript_info.get("exons", []): + exon_id = exon.get("exon_id") + if not exon_id: + continue # Skip if no exon_id is provided. + # If this exon hasn't been seen before, add it. + if exon_id not in aggregated_exons: + aggregated_exons[exon_id] = { + "exon_id": exon_id, + "start": exon["start"], + "end": exon["end"], + "number": exon.get("number", "NA"), + "value": 0.0, + } + # Sum the transcript value into the exon value. + aggregated_exons[exon_id]["value"] += transcript_value + # Now assign the aggregated exon dictionary to the gene. + gene_info["exons"] = aggregated_exons + + # Write exon expression table with proper Path handling + output_file = Path(self.config.output_directory) / "exon_expression_table.tsv" + self.write_exon_expression_table(filtered_dict, output_file) + + save_cache( + expr_filter_cache, (filtered_dict, self.novel_gene_ids, self.novel_transcript_ids) + ) + self.logger.info(f"Saved dictionary to cache at {expr_filter_cache}") return filtered_dict def _get_expression_files(self) -> Tuple[str, str]: @@ -161,77 +216,157 @@ def _get_expression_file(self) -> str: raise FileNotFoundError(f"Count file {expr_file} not found") return expr_file - def _filter_transcripts_above_threshold( - self, gene_dict: Dict[str, Any], min_value: float - ) -> Dict[str, Any]: - """Filter transcripts based on expression threshold.""" - self.logger.info(f"Starting transcript filtering with threshold {min_value}") - - # Track transcripts and their maximum values across all conditions - transcript_max_values = {} - condition_names = list(gene_dict.keys()) - - # First pass: find maximum value for each transcript across all conditions - total_transcripts_before = 0 - for condition in condition_names: - condition_transcripts = sum(len(gene_info.get("transcripts", {})) - for gene_info in gene_dict[condition].values()) - total_transcripts_before += condition_transcripts - self.logger.info(f"Condition {condition}: {condition_transcripts} transcripts before filtering") - - # Log sample of transcripts (max 2 per condition) - sample_transcripts = [] - for gene_info in gene_dict[condition].values(): - sample_transcripts.extend(list(gene_info.get("transcripts", {}).keys())[:2]) - if len(sample_transcripts) >= 2: + def write_exon_expression_table(self, gene_dict: Dict[str, Any], output_path: Path) -> None: + """ + Write a table of exon expressions across conditions. + """ + self.logger.info("Creating exon expression table") + + # Ensure output directory exists. + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Get all conditions (keys in gene_dict). + conditions = list(gene_dict.keys()) + self.logger.debug(f"Processing {len(conditions)} conditions: {conditions}") + + # Prepare header. + header = [ + "Gene Symbol", "Gene Name", "Gene Coordinates", "Ensembl ID", + "Exon number", "Chrom", "Exon start", "Exon end", "Strand" + ] + conditions + + # Instead of looping condition by condition, create a union of gene IDs. 
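
# A sketch (toy data, names illustrative) of the union-of-keys idiom used
# next: a gene present in only one condition still gets a table row, with the
# missing conditions later reported as 0.0.
conditions_toy = {
    "tumor":  {"G1": {}, "G2": {}},
    "normal": {"G2": {}, "G3": {}},
}
all_ids_toy = set()
for genes in conditions_toy.values():
    all_ids_toy.update(genes)  # dict iteration yields its keys
assert all_ids_toy == {"G1", "G2", "G3"}
# The code below builds the same union with set.update() over gene_dict[cond].keys().
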
+ all_gene_ids = set() + for cond in conditions: + all_gene_ids.update(gene_dict[cond].keys()) + self.logger.debug(f"Total unique genes to process: {len(all_gene_ids)}") + + rows = [] + gene_count = 0 # Initialize gene counter for logging + processed_exon_count = 0 + sample_exon_ids = set() # To keep track of sampled exons + num_sample_exons = 100 + + for gene_id in all_gene_ids: + gene_count += 1 # Increment gene counter + self.logger.debug(f"Processing gene {gene_count}/{len(all_gene_ids)}: {gene_id}") + + # Get a representative gene_info (static data is the same across conditions). + rep_gene_info = None + for cond in conditions: + if gene_id in gene_dict[cond]: + rep_gene_info = gene_dict[cond][gene_id] break - if sample_transcripts: - self.logger.debug(f" Sample transcripts in {condition}: {sample_transcripts[:2]}") - - self.logger.info(f"Found {len(transcript_max_values)} unique transcripts across all conditions") - - # Sample of transcripts before filtering - sample_before = list(transcript_max_values.keys())[:5] - self.logger.debug(f"Sample transcripts before filtering: {sample_before}") - - # Build filtered dictionary - filtered_dict = {} - kept_transcripts = set() - - for tid, max_value in transcript_max_values.items(): - if max_value >= min_value: - kept_transcripts.add(tid) - - self.logger.info(f"Keeping {len(kept_transcripts)} transcripts that meet threshold {min_value}") - - # Sample of kept and filtered transcripts - sample_kept = list(kept_transcripts)[:5] - sample_filtered = list(set(transcript_max_values.keys()) - kept_transcripts)[:5] - self.logger.debug(f"Sample kept transcripts: {sample_kept}") - self.logger.debug(f"Sample filtered transcripts: {sample_filtered}") - - # Create filtered dictionary with same structure as input - for condition in condition_names: - filtered_dict[condition] = {} - for gene_id, gene_info in gene_dict[condition].items(): - new_gene_info = copy.deepcopy(gene_info) - new_transcripts = {} - - for tid, tinfo in gene_info.get("transcripts", {}).items(): - if tid in kept_transcripts: - new_transcripts[tid] = tinfo - - if new_transcripts: # Only keep genes that have remaining transcripts - new_gene_info["transcripts"] = new_transcripts - filtered_dict[condition][gene_id] = new_gene_info + if rep_gene_info is None: + self.logger.warning(f"Gene {gene_id} not found in any condition, skipping.") + continue # Should not happen, but skip if not found. + + # Compute gene coordinates string. + gene_coords = f"{rep_gene_info['chromosome']}:{rep_gene_info['start']}-{rep_gene_info['end']}" + + # Collect the union of exon IDs across all conditions for this gene. + all_exon_ids = set() + for cond in conditions: + if gene_id in gene_dict[cond]: + all_exon_ids.update(gene_dict[cond][gene_id].get("exons", {}).keys()) + self.logger.debug(f" Gene {gene_id} - Total unique exons across conditions: {len(all_exon_ids)}") + + exon_count = 0 # Initialize exon counter for logging + exon_ids_list = list(all_exon_ids) # Convert to list for sampling + sampled_exons_for_gene = [] + + # Sample exons if we haven't reached the desired number yet + if processed_exon_count < num_sample_exons: + num_to_sample = min(num_sample_exons - processed_exon_count, len(exon_ids_list)) + sampled_exons_for_gene = random.sample(exon_ids_list, num_to_sample) + + # For each exon, gather condition-specific expression values. 
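
# A sketch of the defaulting lookup pattern applied in the loop below: chained
# .get() calls mean a gene or exon absent from a condition contributes 0.0
# instead of raising KeyError. Toy structure only; field names mirror gene_dict.
def exon_value(gene_dict, cond, gene_id, exon_id):
    return (
        gene_dict.get(cond, {})
        .get(gene_id, {})
        .get("exons", {})
        .get(exon_id, {})
        .get("value", 0.0)
    )

toy_dict = {"tumor": {"G1": {"exons": {"E1": {"value": 4.2}}}}}
assert exon_value(toy_dict, "tumor", "G1", "E1") == 4.2
assert exon_value(toy_dict, "normal", "G1", "E1") == 0.0
# The per-condition loop below applies exactly this fallback chain per exon.
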
+ for exon_id in exon_ids_list: # Iterate through all exons, sample only for logging + exon_count += 1 # Increment exon counter + process_exon_for_log = False + if exon_id in sampled_exons_for_gene and processed_exon_count < num_sample_exons and exon_id not in sample_exon_ids: + process_exon_for_log = True + processed_exon_count += 1 + sample_exon_ids.add(exon_id) # Mark as processed + + if process_exon_for_log: + self.logger.debug(f" Gene {gene_id} - Processing exon {exon_count}/{len(all_exon_ids)} (SAMPLE): {exon_id}") + else: + self.logger.debug(f" Gene {gene_id} - Processing exon {exon_count}/{len(all_exon_ids)}: {exon_id}") + + exon_expressions = [] + aggregated_transcript_values = {} # To store transcript values for logging + + for cond in conditions: + expr = 0.0 + # Lookup the gene in the current condition. + gene_info = gene_dict[cond].get(gene_id, {}) + exon_info = gene_info.get("exons", {}).get(exon_id, {}) + expr = exon_info.get("value", 0.0) # Get condition-specific exon value + + if process_exon_for_log: + # Find transcripts contributing to this exon and log their values + contributing_transcripts = [] + for transcript_id, transcript_info in gene_info.get("transcripts", {}).items(): + for exon_data in transcript_info.get("exons", []): + if exon_data.get("exon_id") == exon_id: + transcript_value = transcript_info.get("value", 0.0) + contributing_transcripts.append((transcript_id, transcript_value)) + aggregated_transcript_values[cond] = aggregated_transcript_values.get(cond, []) + [(transcript_id, transcript_value)] + + self.logger.debug(f" Condition {cond} - Exon {exon_id} expression: {expr:.2f}") + if contributing_transcripts: + transcript_log_str = ", ".join([f"{tid}:{val:.2f}" for tid, val in contributing_transcripts]) + self.logger.debug(f" Contributing transcripts (Condition {cond}): {transcript_log_str}") + else: + self.logger.debug(f" No transcripts contributing to exon {exon_id} in condition {cond}") + + + exon_expressions.append(f"{expr:.2f}") + + if process_exon_for_log: + # Log the aggregation process + for cond in conditions: + transcript_values_for_cond = aggregated_transcript_values.get(cond, []) + if transcript_values_for_cond: + sum_of_transcripts = sum([val for tid, val in transcript_values_for_cond]) + self.logger.debug(f" Condition {cond} - Sum of contributing transcript TPMs for exon {exon_id}: {sum_of_transcripts:.2f} (Exon TPM: {exon_expressions[conditions.index(cond)]})") + else: + self.logger.debug(f" Condition {cond} - No contributing transcripts found to sum for exon {exon_id}") + + + # Use the representative gene's exon info for static details. + rep_exon_info = rep_gene_info.get("exons", {}).get(exon_id, {}) + exon_number = rep_exon_info.get("number", "NA") + exon_start = str(rep_exon_info.get("start", "")) + exon_end = str(rep_exon_info.get("end", "")) + + row = [ + gene_id, # Gene Symbol + rep_gene_info.get("name", ""), # Gene Name + gene_coords, # Gene Coordinates + exon_id, # Ensembl ID (exon_id) + exon_number, # Exon number + rep_gene_info["chromosome"], # Chromosome + exon_start, # Exon start + exon_end, # Exon end + rep_gene_info["strand"], # Strand + ] + exon_expressions # Expression values for each condition + + rows.append(row) + self.logger.debug(f" Gene {gene_id} - Row for exon {exon_id} prepared.") + if process_exon_for_log: + self.logger.debug(f" Gene {gene_id} - Sampled exon {exon_id} processing complete.") + + # Write header and rows to the output file. 
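
# An alternative sketch of the table write that follows, using the stdlib csv
# module; write_tsv is a hypothetical helper, not part of this module. Unlike
# a manual '\t'.join(), csv.writer quotes fields that contain stray tabs or
# newlines, which can matter for free-text gene names.
import csv

def write_tsv(path, header, rows):
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle, delimiter="\t")
        writer.writerow(header)
        writer.writerows(rows)
# The code below writes the same layout with plain file I/O.
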
+ self.logger.info(f"Writing {len(rows)} exon entries to table") + with open(output_path, 'w') as f: + f.write('\t'.join(header) + '\n') + for row in rows: + f.write('\t'.join(str(x) for x in row) + '\n') + + self.logger.info(f"Exon expression table written to {output_path}") - # Log final statistics - for condition in condition_names: - final_count = sum(len(gene_info.get("transcripts", {})) - for gene_info in filtered_dict[condition].values()) - self.logger.debug(f" {condition}: {final_count} transcripts") - - return filtered_dict # ------------------ READ ASSIGNMENT CACHING ------------------ @@ -464,7 +599,6 @@ def parse_extended_annotation(self) -> Dict[str, Any]: "start": int(fields[3]), "end": int(fields[4]), "exons": [], - "expression": 0.0, "tags": attrs.get("tags", "").split(","), "name": attrs.get("transcript_name", transcript_id), } @@ -472,9 +606,11 @@ def parse_extended_annotation(self) -> Dict[str, Any]: elif feature_type == "exon" and transcript_id and gene_id: if gene_id in base_gene_dict and transcript_id in base_gene_dict[gene_id]["transcripts"]: exon_info = { + "exon_id": attrs.get("exon_id", ""), "start": int(fields[3]), "end": int(fields[4]), - "number": attrs.get("exon_number", "1") + "number": attrs.get("exon_number", "1"), + "value": 0.0 } base_gene_dict[gene_id]["transcripts"][transcript_id]["exons"].append(exon_info) @@ -483,36 +619,6 @@ def parse_extended_annotation(self) -> Dict[str, Any]: except Exception as e: self.logger.error(f"GTF parsing failed: {str(e)}") raise - - def update_transcript_values(self, gene_dict: Dict[str, Any], counts_file: str, tpm_file: str) -> Dict[str, Any]: - """Update transcript values from TPM file after filtering with counts.""" - # Read counts for filtering - counts_df = pd.read_csv(counts_file, sep='\t', comment=None) - counts_df.columns = [col.lstrip('#') for col in counts_df.columns] - counts_df = counts_df.set_index('feature_id') - - # Read TPMs for values - tpm_df = pd.read_csv(tpm_file, sep='\t', comment=None) - tpm_df.columns = [col.lstrip('#') for col in tpm_df.columns] - tpm_df = tpm_df.set_index('feature_id') - - # Align indices between counts and TPMs - common_transcripts = counts_df.index.intersection(tpm_df.index) - tpm_df = tpm_df.loc[common_transcripts] - - # Rest of existing update logic using tpm_df instead of expr_df - condition_gene_dict = {condition: copy.deepcopy(gene_dict) for condition in tpm_df.columns} - - for tid, row in tpm_df.iterrows(): - base_tid = tid.split('.')[0] - for condition, tpm_value in row.items(): - # Add nested loop to access gene_info - for gene_id, gene_info in condition_gene_dict[condition].items(): - if base_tid in gene_info.get('transcripts', {}): - gene_info['transcripts'][base_tid]['value'] = float(tpm_value) - - return condition_gene_dict - # -------------------- UPDATES & UTILITIES -------------------- def update_gene_names(self, gene_dict: Dict[str, Any]) -> Dict[str, Any]: @@ -632,100 +738,6 @@ def get_novel_feature_ids(self) -> Tuple[set, set]: """Return the sets of novel gene and transcript IDs.""" return self.novel_gene_ids, self.novel_transcript_ids - def _filter_low_expression_transcripts(self, condition_gene_dict: Dict[str, Any], min_value: float) -> Dict[str, Any]: - """Filter transcripts based on expression threshold.""" - self.logger.info(f"Starting transcript filtering with threshold {min_value}") - - # Track transcripts and their maximum values across all conditions - transcript_max_values = {} - condition_names = list(condition_gene_dict.keys()) - - # First pass: 
find maximum value for each transcript across all conditions - total_transcripts_before = 0 - for condition in condition_names: - condition_transcripts = 0 - for gene_info in condition_gene_dict[condition].values(): - for tid, tinfo in gene_info.get("transcripts", {}).items(): - current_value = tinfo.get('value', 0) - # Update maximum value tracking - if tid not in transcript_max_values or current_value > transcript_max_values[tid]: - transcript_max_values[tid] = current_value - condition_transcripts += 1 - total_transcripts_before += condition_transcripts - self.logger.info(f"Condition {condition}: {condition_transcripts} transcripts before filtering") - - # Log sample of transcripts (max 2 per condition) - sample_transcripts = [] - for gene_info in condition_gene_dict[condition].values(): - sample_transcripts.extend(list(gene_info.get("transcripts", {}).keys())[:2]) - if len(sample_transcripts) >= 2: - break - if sample_transcripts: - self.logger.debug(f" Sample transcripts in {condition}: {sample_transcripts[:2]}") - - self.logger.info(f"Found {len(transcript_max_values)} unique transcripts across all conditions") - - # Sample of transcripts before filtering - sample_before = list(transcript_max_values.keys())[:5] - self.logger.debug(f"Sample transcripts before filtering: {sample_before}") - - # Build filtered dictionary - filtered_dict = {} - kept_transcripts = set() - - for tid, max_value in transcript_max_values.items(): - if max_value >= min_value: - kept_transcripts.add(tid) - - self.logger.info(f"Keeping {len(kept_transcripts)} transcripts that meet threshold {min_value}") - - # Sample of kept and filtered transcripts - sample_kept = list(kept_transcripts)[:5] - sample_filtered = list(set(transcript_max_values.keys()) - kept_transcripts)[:5] - self.logger.debug(f"Sample kept transcripts: {sample_kept}") - self.logger.debug(f"Sample filtered transcripts: {sample_filtered}") - - # Create filtered dictionary with same structure as input - for condition in condition_names: - filtered_dict[condition] = {} - for gene_id, gene_info in condition_gene_dict[condition].items(): - new_gene_info = copy.deepcopy(gene_info) - new_transcripts = {} - - for tid, tinfo in gene_info.get("transcripts", {}).items(): - if tid in kept_transcripts: - new_transcripts[tid] = tinfo - - if new_transcripts: # Only keep genes that have remaining transcripts - new_gene_info["transcripts"] = new_transcripts - filtered_dict[condition][gene_id] = new_gene_info - - # Log final statistics - for condition in condition_names: - final_count = sum(len(gene_info.get("transcripts", {})) - for gene_info in filtered_dict[condition].values()) - self.logger.debug(f" {condition}: {final_count} transcripts") - - return filtered_dict - - def _batch_update_values(self, gene_dict, expr_df, valid_transcripts): - """Vectorized value updating for all conditions.""" - return { - condition: { - gene_id: { - **gene_info, - 'transcripts': { - tid: {**tinfo, 'value': expr_df.at[tid, condition]} - for tid, tinfo in gene_info['transcripts'].items() - if tid in valid_transcripts - } - } - for gene_id, gene_info in gene_dict.items() - if any(tid in valid_transcripts for tid in gene_info['transcripts']) - } - for condition in expr_df.columns - } - def _validate_gene_structure(self, gene_dict: Dict[str, Any]) -> None: """Ensure proper gene-centric structure before condition processing.""" required_gene_keys = ['chromosome', 'start', 'end', 'strand', 'name', 'biotype', 'transcripts'] From 172cdaa856b3d96a1fdcde9a1783753eca93ed88 Mon Sep 17 
00:00:00 2001
From: Jackfreeman88
Date: Mon, 17 Feb 2025 23:14:31 -0600
Subject: [PATCH 26/35] Cleaned up for testing, tested ref-only, didn't test
 GSEA, validated over new data, changed logger, updated reqs

---
 install_r_packages.py                   |  24 +++
 requirements.txt                        |   4 +
 src/visualization_dictionary_builder.py |  32 ++--
 src/visualization_differential_exp.py   | 172 ++++++++++-----------
 src/visualization_gsea.py               |   3 +-
 src/visualization_mapping.py            |   6 +-
 src/visualization_plotter.py            | 192 +++++++++++++++--------
 src/visualize_expression.py             | 193 ------------------------
 visualize.py                            |  33 ++--
 9 files changed, 289 insertions(+), 370 deletions(-)
 create mode 100644 install_r_packages.py
 delete mode 100644 src/visualize_expression.py

diff --git a/install_r_packages.py b/install_r_packages.py
new file mode 100644
index 00000000..7461369b
--- /dev/null
+++ b/install_r_packages.py
@@ -0,0 +1,24 @@
+import rpy2.robjects.packages as rpackages
+from rpy2.robjects.vectors import StrVector
+
+# List of R packages to install
+r_package_names = ('DESeq2', 'ggplot2', 'ggrepel', 'RColorBrewer', 'clusterProfiler', 'org.Hs.eg.db')
+
+# Get R's utility package
+utils = rpackages.importr('utils')
+
+# Select CRAN mirror (optional, but recommended for reproducibility)
+utils.chooseCRANmirror(ind=1)  # Select the first mirror in the list
+
+# Function to check if R package is installed
+def is_installed(package_name):
+    return package_name in rpackages.packages()
+
+# Install R packages if not already installed
+packages_to_install = [pkg for pkg in r_package_names if not is_installed(pkg)]
+
+if packages_to_install:
+    print(f"Installing R packages: {', '.join(packages_to_install)}")
+    utils.install_packages(StrVector(packages_to_install))
+else:
+    print("All required R packages are already installed.")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 52037078..6ac32a8d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,8 @@ matplotlib>=3.1.3
 numpy>=1.18.1
 scipy>=1.4.1
 seaborn>=0.10.0
+rpy2>=3.5.1
+mygene>=3.2.0
+
+
diff --git a/src/visualization_dictionary_builder.py b/src/visualization_dictionary_builder.py
index 525abb3e..c17e84c2 100644
--- a/src/visualization_dictionary_builder.py
+++ b/src/visualization_dictionary_builder.py
@@ -70,7 +70,10 @@ def build_gene_dict_with_expression_and_filter(
 
         # 2. Filter novel genes from the base gene dict (not per-condition)
         self.logger.info("Parsing GTF and filtering novel genes")
-        parsed_data = self.parse_extended_annotation()
+        if self.config.ref_only:
+            parsed_data = self.parse_input_gtf()
+        else:
+            parsed_data = self.parse_extended_annotation()
         self._validate_gene_structure(parsed_data)
         base_gene_dict = self._filter_novel_genes(parsed_data)
 
@@ -472,6 +475,7 @@ def parse_input_gtf(self) -> Dict[str, Any]:
         """
         Parse the reference GTF file using gffutils with optimized settings,
         building a dictionary of genes, transcripts, and exons.
+        Updated to match the structure of parse_extended_annotation.
""" if not self.config.genedb_filename: db_path = self.cache_dir / "gtf.db" @@ -503,13 +507,16 @@ def parse_input_gtf(self) -> Dict[str, Any]: continue gene_id = feature.id + gene_name = feature.attributes.get("gene_name", [gene_id])[0] # Default to gene_id if name missing + gene_biotype = feature.attributes.get("gene_biotype", ["unknown"])[0] # Default to "unknown" + gene_dict[gene_id] = { "chromosome": feature.seqid, "start": feature.start, "end": feature.end, "strand": feature.strand, - "name": feature.attributes.get("gene_name", [""])[0], - "biotype": feature.attributes.get("gene_biotype", [""])[0], + "name": gene_name, # Use updated gene_name + "biotype": gene_biotype, # Use updated gene_biotype "transcripts": {}, } @@ -521,13 +528,17 @@ def parse_input_gtf(self) -> Dict[str, Any]: continue transcript_id = feature.id + transcript_name = feature.attributes.get("transcript_name", [transcript_id])[0] # Default to transcript_id + transcript_biotype = feature.attributes.get("transcript_biotype", ["unknown"])[0] # Default to "unknown" + transcript_tags = feature.attributes.get("tag", [""])[0].split(",") # Get tags + gene_dict[gene_id]["transcripts"][transcript_id] = { "start": feature.start, "end": feature.end, - "name": feature.attributes.get("transcript_name", [""])[0], - "biotype": feature.attributes.get("transcript_biotype", [""])[0], + "name": transcript_name, # Use updated transcript_name + "biotype": transcript_biotype, # Use updated transcript_biotype "exons": [], - "tags": feature.attributes.get("tag", [""])[0].split(","), + "tags": transcript_tags, # Use updated transcript_tags } elif feature.featuretype == "exon": gene_id = feature.attributes.get("gene_id", [""])[0] @@ -536,12 +547,15 @@ def parse_input_gtf(self) -> Dict[str, Any]: gene_id in gene_dict and transcript_id in gene_dict[gene_id]["transcripts"] ): + exon_number = feature.attributes.get("exon_number", ["1"])[0] # Default to "1" + exon_id = feature.attributes.get("exon_id", [""])[0] # Get exon_id + gene_dict[gene_id]["transcripts"][transcript_id]["exons"].append( { - "exon_id": feature.id, + "exon_id": exon_id, # Use retrieved exon_id "start": feature.start, "end": feature.end, - "number": feature.attributes.get("exon_number", [""])[0], + "number": exon_number, # Use updated exon_number } ) @@ -676,7 +690,7 @@ def read_gene_list(self, gene_list_path: Union[str, Path]) -> List[str]: try: with open(gene_list_path, "r") as file: gene_list = [line.strip().upper() for line in file if line.strip()] - self.logger.info(f"Read {len(gene_list)} genes from {gene_list_path}") + self.logger.debug(f"Read {len(gene_list)} genes from {gene_list_path}") return gene_list except Exception as e: self.logger.error(f"Error reading gene list from {gene_list_path}: {e}") diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index 18bd45de..487551c7 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -12,9 +12,7 @@ import numpy as np from scipy.stats import gmean from sklearn.decomposition import PCA -import matplotlib.pyplot as plt -import seaborn as sns - +from rpy2.rinterface_lib import callbacks class DifferentialAnalysis: def __init__( @@ -28,10 +26,6 @@ def __init__( dictionary_builder: "DictionaryBuilder" = None, ): """Initialize differential expression analysis.""" - # Configure rpy2 to suppress R console output - from rpy2.rinterface_lib import callbacks - - # Create a custom callback that does nothing def quiet_cb(x): pass @@ -41,7 +35,7 @@ def 
quiet_cb(x): callbacks.consolewrite_warnerror = quiet_cb self.output_dir = Path(output_dir) - self.deseq_dir = Path(viz_output) / "deseq2_results" + self.deseq_dir = Path(viz_output) / "differential_expression" self.deseq_dir.mkdir(parents=True, exist_ok=True) self.ref_conditions = ref_conditions self.target_conditions = target_conditions @@ -105,8 +99,8 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame, pd.DataFrame] transcript_counts = transcript_counts[~transcript_counts.index.isin(novel_transcript_ids)] # Filter out novel transcripts novel_filtered_count = transcript_counts.shape[0] removed_novel_count = original_transcript_count_novel_filter - novel_filtered_count - self.logger.info(f"Novel transcript filtering: Removed {removed_novel_count} novel transcripts ({removed_novel_count / original_transcript_count_novel_filter * 100:.1f}%)") - self.logger.debug(f"Transcript counts shape after novel filtering: {transcript_counts.shape}") + self.logger.info(f"Novel transcript filtering: Removed {removed_novel_count} transcripts from novel genes ({removed_novel_count / original_transcript_count_novel_filter * 100:.1f}%)") + self.logger.debug(f"Transcript counts shape after novel gene filtering: {transcript_counts.shape}") else: self.logger.info("Novel transcript filtering: Skipped (no dictionary builder)") @@ -129,65 +123,65 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame, pd.DataFrame] self.transcript_count_data = transcript_counts_filtered # Store filtered transcript counts # --- 5. Run DESeq2 Analysis --- - if not self.ref_only: - deseq2_results_gene_file, gene_normalized_counts = self._run_level_analysis( - level="gene", - pattern="gene_grouped_counts.tsv", - count_data=gene_counts_filtered, - coldata=self._build_design_matrix(gene_counts_filtered) - ) - deseq2_results_transcript_file, transcript_normalized_counts = self._run_level_analysis( - level="transcript", - pattern="transcript_model_grouped_counts.tsv" if not self.ref_only else "transcript_grouped_counts.tsv", - count_data=transcript_counts_filtered - ) + + deseq2_results_gene_file, gene_normalized_counts = self._run_level_analysis( + level="gene", + pattern="gene_grouped_counts.tsv", + count_data=gene_counts_filtered, + coldata=self._build_design_matrix(gene_counts_filtered) + ) + deseq2_results_transcript_file, transcript_normalized_counts = self._run_level_analysis( + level="transcript", + pattern="transcript_model_grouped_counts.tsv" if not self.ref_only else "transcript_grouped_counts.tsv", + count_data=transcript_counts_filtered + ) - # Load the gene-level results for GSEA - deseq2_results_df = pd.read_csv(deseq2_results_gene_file) + # Load the gene-level results for GSEA + deseq2_results_df = pd.read_csv(deseq2_results_gene_file) - # Update how we create the labels - target_label = "+".join(self.target_conditions) - reference_label = "+".join(self.ref_conditions) + # Update how we create the labels + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) - # --- Visualize Gene-Level Results --- - self.visualizer.visualize_results( - results=deseq2_results_df, - target_label=target_label, - reference_label=reference_label, - min_count=10, - feature_type="genes", - ) - self.logger.info(f"Gene-level visualizations saved to {self.deseq_dir}") - - # Run PCA with correct labels - normalized_gene_counts = self._median_ratio_normalization(gene_counts_filtered) - self._run_pca( - normalized_gene_counts, - "gene", - 
coldata=self._build_design_matrix(gene_counts_filtered), - target_label=target_label, - reference_label=reference_label - ) + # --- Visualize Gene-Level Results --- + self.visualizer.visualize_results( + results=deseq2_results_df, + target_label=target_label, + reference_label=reference_label, + min_count=10, + feature_type="genes", + ) + self.logger.info(f"Gene-level visualizations saved to {self.deseq_dir}") + + # Run PCA with correct labels + normalized_gene_counts = self._median_ratio_normalization(gene_counts_filtered) + self._run_pca( + normalized_gene_counts, + "gene", + coldata=self._build_design_matrix(gene_counts_filtered), + target_label=target_label, + reference_label=reference_label + ) - # --- Visualize Transcript-Level Results --- - self.visualizer.visualize_results( - results=pd.read_csv(deseq2_results_transcript_file), - target_label=target_label, - reference_label=reference_label, - min_count=10, - feature_type="transcripts", - ) - self.logger.info(f"Transcript-level visualizations saved to {self.deseq_dir}") - - # Run PCA with correct labels for transcript level - normalized_transcript_counts = self._median_ratio_normalization(transcript_counts_filtered) - self._run_pca( - normalized_transcript_counts, - "transcript", - coldata=self._build_design_matrix(transcript_counts_filtered), - target_label=target_label, - reference_label=reference_label - ) + # --- Visualize Transcript-Level Results --- + self.visualizer.visualize_results( + results=pd.read_csv(deseq2_results_transcript_file), + target_label=target_label, + reference_label=reference_label, + min_count=10, + feature_type="transcripts", + ) + self.logger.info(f"Transcript-level visualizations saved to {self.deseq_dir}") + + # Run PCA with correct labels for transcript level + normalized_transcript_counts = self._median_ratio_normalization(transcript_counts_filtered) + self._run_pca( + normalized_transcript_counts, + "transcript", + coldata=self._build_design_matrix(transcript_counts_filtered), + target_label=target_label, + reference_label=reference_label + ) return deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered, deseq2_results_df @@ -257,7 +251,7 @@ def _get_condition_data(self, pattern: str) -> pd.DataFrame: count_files = list(condition_dir.glob(f"*{adjusted_pattern}")) # Use adjusted pattern for file_path in count_files: - self.logger.info(f"Reading count data from: {file_path}") + self.logger.debug(f"Reading count data from: {file_path}") df = pd.read_csv(file_path, sep="\t", dtype={"#feature_id": str}) df.set_index("#feature_id", inplace=True) @@ -308,8 +302,7 @@ def _filter_counts(self, count_data: pd.DataFrame, min_count: int = 10, level: s filtered_data = count_data[keep_features] self.logger.info( - f"After filtering: Retained {filtered_data.shape[0]}/{count_data.shape[0]} features " - f"({(filtered_data.shape[0]/count_data.shape[0]*100):.1f}%)" + f"After filtering: Retained {filtered_data.shape[0]} features" ) return filtered_data @@ -390,39 +383,48 @@ def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: top_genes = [row["gene_name"] for row in top_unique_gene_transcripts] # Extract gene names from selected transcripts # Write to file - top_genes_file = self.deseq_dir / "genes_from_top_100_transcripts.txt" + top_genes_file = self.deseq_dir / "genes_of_top_100_DE_transcripts.txt" pd.Series(top_genes).to_csv(top_genes_file, index=False, header=False) - self.logger.info(f"Wrote genes from top 100 transcripts to {top_genes_file}") + 
self.logger.debug(f"Wrote genes of top 100 DE transcripts to {top_genes_file}") else: # For gene-level analysis, keep original behavior # top_genes = results.nlargest(100, "abs_stat")["symbol"] # OLD: was writing symbols (gene IDs) top_genes = results.nlargest(100, "abs_stat")["gene_name"] - top_genes_file = self.deseq_dir / "top_100_genes.txt" + top_genes_file = self.deseq_dir / "top_100_DE_genes.txt" top_genes.to_csv(top_genes_file, index=False, header=False) - self.logger.info(f"Wrote top 100 genes to {top_genes_file}") + self.logger.debug(f"Wrote top 100 DE genes to {top_genes_file}") def _run_pca(self, normalized_counts, level, coldata, target_label, reference_label): """Run PCA analysis and create visualization.""" self.logger.info(f"Running PCA for {level} level...") - - # Run PCA - pca = PCA(n_components=2) + + # Run PCA - Calculate 10 components + pca = PCA(n_components=10) # Keep n_components=10 to generate scree plot with 10 components log_normalized_counts = np.log2(normalized_counts + 1) pca_result = pca.fit_transform(log_normalized_counts.transpose()) - - # Create DataFrame - pca_df = pd.DataFrame(data=pca_result, columns=['PC1', 'PC2'], index=log_normalized_counts.columns) - pca_df['group'] = coldata['group'].values - - # Calculate variance + #map the feature names to gene names using the gene_mapper + feature_names = normalized_counts.index.tolist() + gene_names = self.gene_mapper.map_gene_symbols(feature_names, level, self.updated_gene_dict) + + # Get explained variance ratio and loadings explained_variance = pca.explained_variance_ratio_ + loadings = pca.components_ # Loadings are in pca.components_ + + # Create DataFrame with columns for all 10 components + pc_columns = [f'PC{i+1}' for i in range(10)] # Generate column names: PC1, PC2, ..., PC10 + pca_df = pd.DataFrame(data=pca_result, columns=pc_columns, index=log_normalized_counts.columns) # Use all 10 column names + pca_df['group'] = coldata['group'].values + title = f"{level.capitalize()} Level PCA: {target_label} vs {reference_label}\nPC1 ({100*explained_variance[0]:.2f}%) / PC2 ({100*explained_variance[1]:.2f}%)" - - # Use the plotter's PCA method + + # Use the plotter's PCA method, passing explained variance and loadings self.visualizer.plot_pca( - pca_df=pca_df, + pca_df=pca_df, # pca_df now contains 10 components title=title, - output_prefix=f"pca_{level}" + output_prefix=f"pca_{level}", + explained_variance=explained_variance, # Pass explained variance (for scree plot) + loadings=loadings, # Pass loadings (for loadings output) + feature_names=gene_names # Pass feature names (gene names) ) def _median_ratio_normalization(self, count_data: pd.DataFrame) -> pd.DataFrame: diff --git a/src/visualization_gsea.py b/src/visualization_gsea.py index 148fc718..619be95b 100644 --- a/src/visualization_gsea.py +++ b/src/visualization_gsea.py @@ -7,6 +7,7 @@ from rpy2.robjects import r, pandas2ri from rpy2.robjects.packages import importr from rpy2.robjects.conversion import localconverter +from rpy2.rinterface_lib import callbacks class GSEAAnalysis: @@ -21,7 +22,7 @@ def __init__(self, output_path: Path): self.output_path.mkdir(parents=True, exist_ok=True) # Configure R to be quiet - from rpy2.rinterface_lib import callbacks + def quiet_cb(x): pass diff --git a/src/visualization_mapping.py b/src/visualization_mapping.py index b91718f0..ee1f3af6 100644 --- a/src/visualization_mapping.py +++ b/src/visualization_mapping.py @@ -43,7 +43,7 @@ def get_gene_info_from_mygene(self, ensembl_ids: List[str]) -> Dict[str, Dict]: } # Log 
query statistics - self.logger.info( + self.logger.debug( f"MyGene.info query stats: " f"Total={len(ensembl_ids)}, " f"Found={len(mapping)}, " @@ -86,7 +86,7 @@ def map_genes(self, gene_ids: List[str], updated_gene_dict: Dict) -> Dict[str, T # For unmapped genes, try MyGene.info batch query if unmapped_ids: - self.logger.info(f"Querying MyGene.info for {len(unmapped_ids)} unmapped genes") + self.logger.debug(f"Querying MyGene.info for {len(unmapped_ids)} unmapped genes") mygene_results = self.get_gene_info_from_mygene(unmapped_ids) remaining_unmapped = [] @@ -191,7 +191,7 @@ def map_gene_symbols(self, feature_ids: List[str], level: str, updated_gene_dict # Perform batched MyGene API query for all unmapped gene IDs at once (gene-level only) if level == "gene" and unmapped_gene_ids_batch: - self.logger.info(f"Gene-level mapping: Performing batched MyGene API query for {len(unmapped_gene_ids_batch)} gene IDs") + self.logger.debug(f"Gene-level mapping: Performing batched MyGene API query for {len(unmapped_gene_ids_batch)} gene IDs") mygene_batch_info = self.get_gene_info_from_mygene(unmapped_gene_ids_batch) # Batched query if mygene_batch_info: diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py index 4d914856..2efafdfd 100644 --- a/src/visualization_plotter.py +++ b/src/visualization_plotter.py @@ -6,7 +6,7 @@ import pandas as pd import matplotlib.patches as patches import seaborn as sns - +from typing import List class PlotOutput: def __init__( @@ -19,6 +19,7 @@ def __init__( filter_transcripts=None, conditions=False, use_counts=False, + ref_only=False, ): self.updated_gene_dict = updated_gene_dict self.gene_names = gene_names @@ -28,6 +29,7 @@ def __init__( self.filter_transcripts = filter_transcripts self.conditions = conditions self.use_counts = use_counts + self.ref_only = ref_only # Ensure output directories exist if self.gene_visualizations_dir: @@ -40,16 +42,24 @@ def plot_transcript_map(self): logging.warning("No gene_visualizations_dir provided. 
Skipping transcript map plotting.") return - for gene_name in self.gene_names: - gene_data = {} - for condition, genes in self.updated_gene_dict.items(): - if gene_name in genes: - gene_data = genes[gene_name] - break + for gene_name_or_id in self.gene_names: # gene_names list contains gene names (symbols) + gene_data = None # Initialize gene_data to None + found_by_name = False # Flag to track if gene was found by name - if not gene_data: - logging.warning(f"Gene {gene_name} not found in the data.") - continue + for condition, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): # Compare gene names (case-insensitive) + gene_data = gene_info + found_by_name = True + break # Found gene, break inner loop + if found_by_name: + break # Found gene, break outer loop + + if gene_data: + logging.debug(f"Gene {gene_name_or_id} found by name in the data.") + else: + logging.warning(f"Gene {gene_name_or_id} not found in the data.") + continue # Skip to the next gene if not found # Get chromosome info and calculate buffer chromosome = gene_data.get("chromosome", "Unknown") @@ -63,25 +73,18 @@ def plot_transcript_map(self): plot_end = end + buffer plot_height = max(8, len(gene_data["transcripts"]) * 0.4) - logging.debug(f"Creating transcript map for gene '{gene_name}' with {len(gene_data['transcripts'])} transcripts") - - # Collect all reference exon coordinates from reference transcripts - reference_exons = set() - for transcript_id, transcript_info in gene_data["transcripts"].items(): - if transcript_id.startswith("ENST"): - for exon in transcript_info["exons"]: - # Store exon coordinates as tuple for easy comparison - reference_exons.add((exon["start"], exon["end"])) - - logging.debug(f"Found {len(reference_exons)} reference exons for gene '{gene_name}'") + logging.debug(f"Creating transcript map for gene '{gene_name_or_id}' with {len(gene_data['transcripts'])} transcripts") + + fig, ax = plt.subplots(figsize=(12, plot_height)) # Add legend handles legend_elements = [ - patches.Patch(facecolor='skyblue', label='Reference Exon'), - patches.Patch(facecolor='red', alpha=0.6, label='Novel Exon') + patches.Patch(facecolor='skyblue', label='Exon'), ] + if not self.ref_only: + legend_elements.append(patches.Patch(facecolor='red', alpha=0.6, label='Novel Exon')) # Plot each transcript y_ticks = [] @@ -109,10 +112,13 @@ def plot_transcript_map(self): # Exon blocks with color based on reference status for exon in transcript_info["exons"]: exon_length = exon["end"] - exon["start"] - # Check if this exon's coordinates match any reference exon - is_reference_exon = (exon["start"], exon["end"]) in reference_exons - exon_color = "skyblue" if is_reference_exon else "red" - exon_alpha = 1.0 if is_reference_exon else 0.6 + if self.ref_only: # Check ref_only flag + exon_color = "skyblue" # If ref_only, always treat as reference + exon_alpha = 1.0 + else: + is_reference_exon = exon["exon_id"].startswith("ENSE") # Original logic + exon_color = "skyblue" if is_reference_exon else "red" + exon_alpha = 1.0 if is_reference_exon else 0.6 ax.add_patch( plt.Rectangle( @@ -124,8 +130,12 @@ def plot_transcript_map(self): ) ) - if not any((exon["start"], exon["end"]) in reference_exons for exon in transcript_info["exons"]): - logging.debug(f"Transcript {transcript_id} in gene {gene_name} contains all novel exons") + if not any(exon["exon_id"].startswith("ENSE") for exon in transcript_info["exons"]): + 
logging.debug(f"Transcript {transcript_id} in gene {gene_name_or_id} contains NO reference exons (based on ENSEMBL IDs)") + #log the exon_ids + logging.debug(f"Exon IDs: {[exon['exon_id'] for exon in transcript_info['exons']]}") + else: + logging.debug(f"Transcript {transcript_id} in gene {gene_name_or_id} contains at least one reference exon (based on ENSEMBL IDs)") # Store y-axis label information y_ticks.append(i) @@ -136,10 +146,11 @@ def plot_transcript_map(self): y_labels.append(transcript_name) # Set up the plot formatting with just chromosome + gene_display_name = gene_data.get("name", gene_name_or_id) # Fallback to ID if no name if self.filter_transcripts: - title = f"Transcript Structure - {gene_name} (Chromosome {chromosome}) (Count > {self.filter_transcripts})" + title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome}) (Count > {self.filter_transcripts})" else: - title = f"Transcript Structure - {gene_name} (Chromosome {chromosome})" + title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome})" ax.set_title(title, pad=20) # Increase padding to move title up ax.set_xlabel("Chromosomal position") @@ -161,11 +172,11 @@ def plot_transcript_map(self): plt.tight_layout() plot_path = os.path.join( - self.gene_visualizations_dir, f"{gene_name}_splicing.png" + self.gene_visualizations_dir, f"{gene_name_or_id}_splicing.png" # Use gene_name_or_id in filename ) plt.savefig(plot_path, bbox_inches='tight', dpi=300) plt.close(fig) - logging.debug(f"Saved transcript map for gene '{gene_name}' at: {plot_path}") + logging.debug(f"Saved transcript map for gene '{gene_name_or_id}' at: {plot_path}") def plot_transcript_usage(self): @@ -174,35 +185,43 @@ def plot_transcript_usage(self): logging.warning("No gene_visualizations_dir provided. 
Skipping transcript usage plotting.") return - for gene_name in self.gene_names: - gene_data = {} - for condition, genes in self.updated_gene_dict.items(): - if gene_name in genes: - gene_data[condition] = genes[gene_name]["transcripts"] - - if not gene_data: - logging.warning(f"Gene {gene_name} not found in the data.") - continue + for gene_name_or_id in self.gene_names: # gene_names list contains gene names (symbols) + gene_data_per_condition = {} # Store gene data per condition + found_gene_any_condition = False # Flag if gene found in any condition - conditions = list(gene_data.keys()) + for condition, genes in self.updated_gene_dict.items(): + condition_gene_data = None + for gene_id, gene_info in genes.items(): + if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): # Compare gene names (case-insensitive) + condition_gene_data = gene_info["transcripts"] # Only need transcripts for usage plot + found_gene_any_condition = True + break # Found gene in this condition, break inner loop + if condition_gene_data: + gene_data_per_condition[condition] = condition_gene_data # Store transcripts for this condition + + if not found_gene_any_condition: + logging.warning(f"Gene {gene_name_or_id} not found in the data.") + continue # Skip to the next gene if not found + + conditions = list(gene_data_per_condition.keys()) n_bars = len(conditions) fig, ax = plt.subplots(figsize=(12, 8)) index = np.arange(n_bars) bar_width = 0.35 opacity = 0.8 - max_transcripts = max(len(gene_data[condition]) for condition in conditions) + max_transcripts = max(len(gene_data_per_condition[condition]) for condition in conditions) colors = plt.cm.plasma(np.linspace(0, 1, num=max_transcripts)) bottom_val = np.zeros(n_bars) for i, condition in enumerate(conditions): - transcripts = gene_data[condition] + transcripts = gene_data_per_condition[condition] for j, (transcript_id, transcript_info) in enumerate(transcripts.items()): color = colors[j % len(colors)] value = transcript_info["value"] # Get transcript name with fallback options - transcript_name = (transcript_info.get("name") or - transcript_info.get("transcript_id") or + transcript_name = (transcript_info.get("name") or + transcript_info.get("transcript_id") or transcript_id) ax.bar( i, @@ -217,7 +236,8 @@ def plot_transcript_usage(self): ax.set_xlabel("Sample Type") ax.set_ylabel("Transcript Usage (TPM)") - ax.set_title(f"Transcript Usage for {gene_name} by Sample Type") + gene_display_name = list(gene_data_per_condition.values())[0].get("name", gene_name_or_id) # Fallback to ID if no name + ax.set_title(f"Transcript Usage for {gene_display_name} by Sample Type") ax.set_xticks(index) ax.set_xticklabels(conditions) ax.legend( @@ -230,7 +250,7 @@ def plot_transcript_usage(self): plt.tight_layout() plot_path = os.path.join( self.gene_visualizations_dir, - f"{gene_name}_transcript_usage_by_sample_type.png", + f"{gene_name_or_id}_transcript_usage_by_sample_type.png", # Use gene_name_or_id in filename ) plt.savefig(plot_path) plt.close(fig) @@ -510,25 +530,33 @@ def visualize_results( raise - def plot_pca(self, pca_df: pd.DataFrame, title: str, output_prefix: str) -> Path: - """Plot PCA scatter plot.""" + def plot_pca( + self, + pca_df: pd.DataFrame, + title: str, + output_prefix: str, + explained_variance: np.ndarray, + loadings: np.ndarray, + feature_names: List[str] + ) -> Path: + """Plot PCA scatter plot, scree plot, and loadings.""" plt.figure(figsize=(8, 6)) - + # Extract variance info from title for axis labels only pc1_var = title.split("PC1 
(")[1].split("%)")[0] pc2_var = title.split("PC2 (")[1].split("%)")[0] - - # Get clean title without PCs and variance - using string literal instead of \n + + # Get clean title without PCs and variance base_title = title.split(' Level PCA: ')[0] comparison = title.split(': ')[1].split('PC1')[0].strip() clean_title = f"{base_title} Level PCA: {comparison}" - + # Update group labels in the DataFrame condition_mapping = {'Target': title.split(": ")[1].split(" vs ")[0], 'Reference': title.split(" vs ")[1].split("PC1")[0].strip()} pca_df['group'] = pca_df['group'].map(condition_mapping) - - # Create plot with updated labels + + # Create scatter plot sns.scatterplot(x='PC1', y='PC2', hue='group', data=pca_df, s=100) plt.xlabel(f'PC1 ({pc1_var}%)') plt.ylabel(f'PC2 ({pc2_var}%)') @@ -537,8 +565,54 @@ def plot_pca(self, pca_df: pd.DataFrame, title: str, output_prefix: str) -> Path plt.gca().spines['right'].set_visible(False) plt.tight_layout() - output_path = self.output_path / f"{output_prefix}_pca.png" - plt.savefig(output_path) + scatter_plot_path = self.output_path / f"{output_prefix}.png" # Separate path for scatter plot + plt.savefig(scatter_plot_path) + plt.close() + + # --- Scree Plot --- + self._plot_scree(explained_variance, output_prefix) + + # --- Loadings --- + self._output_loadings(loadings, feature_names, output_prefix) + + return scatter_plot_path # Return path to scatter plot + + def _plot_scree(self, explained_variance: np.ndarray, output_prefix: str) -> Path: + """Plot scree plot of explained variance.""" + plt.figure(figsize=(8, 6)) + num_components = len(explained_variance) + component_numbers = range(1, num_components + 1) + + plt.bar(component_numbers, explained_variance * 100) + plt.xlabel('Principal Component') + plt.ylabel('Percentage of Explained Variance') + plt.title('Scree Plot') + plt.xticks(component_numbers) # Ensure all component numbers are labeled + plt.gca().spines['top'].set_visible(False) + plt.gca().spines['right'].set_visible(False) + plt.tight_layout() + + scree_plot_path = self.output_path / f"scree_{output_prefix}.png" + plt.savefig(scree_plot_path) plt.close() + return scree_plot_path + + def _output_loadings(self, loadings: np.ndarray, feature_names: List[str], output_prefix: str, top_n: int = 10) -> Path: + """Output top N loadings for PC1 and PC2.""" + # Generate column names dynamically based on the number of components + num_components = loadings.shape[0] # Get the number of components from loadings shape + pc_columns = [f'PC{i+1}' for i in range(num_components)] + + loadings_df = pd.DataFrame(loadings.T, index=feature_names, columns=pc_columns) # Use dynamic column names + + output_path = self.output_path / f"loadings_{output_prefix}.txt" + with open(output_path, 'w') as f: + f.write("PCA Loadings (Top {} Features for PC1 and PC2):\n\n".format(top_n)) + for pc_name in ['PC1', 'PC2']: + f.write(f"\n--- {pc_name} ---\n") + # Sort by absolute value of loading + top_loadings = loadings_df.sort_values(by=pc_name, key=lambda x: x.abs(), ascending=False).head(top_n) + for gene, loading in top_loadings[pc_name].items(): # Iterate over series items + f.write(f"{gene}:\t{loading:.4f}\n") # Tab-separated for readability return output_path diff --git a/src/visualize_expression.py b/src/visualize_expression.py deleted file mode 100644 index 4b9101b3..00000000 --- a/src/visualize_expression.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -Visualization and summary module for differential expression analysis results. 
-""" - -from pathlib import Path -import logging -import pandas as pd -import matplotlib.pyplot as plt -import numpy as np -from typing import Dict, List - - -class ExpressionVisualizer: - def __init__(self, output_path: Path): - """ - Initialize visualizer with output directory. - - Args: - output_path: Path to output directory - """ - self.output_path = Path(output_path) - self.output_path.mkdir(parents=True, exist_ok=True) - - def create_volcano_plot( - self, - df: pd.DataFrame, - target_label: str, - reference_label: str, - padj_threshold: float = 0.05, - lfc_threshold: float = 1, - top_n: int = 10, - ) -> None: - """Create volcano plot from differential expression results.""" - plt.figure(figsize=(10, 8)) - - # Prepare data - df["padj"] = df["padj"].replace(0, 1e-300) - df = df[df["padj"] > 0] - df = df.copy() # Create a copy to avoid the warning - df.loc[:, "-log10(padj)"] = -np.log10(df["padj"]) - - # Define significant genes - significant = (df["padj"] < padj_threshold) & ( - abs(df["log2FoldChange"]) > lfc_threshold - ) - up_regulated = significant & (df["log2FoldChange"] > lfc_threshold) - down_regulated = significant & (df["log2FoldChange"] < -lfc_threshold) - - # Plot points - plt.scatter( - df.loc[~significant, "log2FoldChange"], - df.loc[~significant, "-log10(padj)"], - color="grey", - alpha=0.5, - label="Not Significant", - ) - plt.scatter( - df.loc[up_regulated, "log2FoldChange"], - df.loc[up_regulated, "-log10(padj)"], - color="red", - alpha=0.7, - label=f"Up-regulated in ({target_label})", - ) - plt.scatter( - df.loc[down_regulated, "log2FoldChange"], - df.loc[down_regulated, "-log10(padj)"], - color="blue", - alpha=0.7, - label=f"Up-regulated in ({reference_label})", - ) - - # Add threshold lines and labels - plt.axhline(-np.log10(padj_threshold), color="grey", linestyle="--") - plt.axvline(lfc_threshold, color="grey", linestyle="--") - plt.axvline(-lfc_threshold, color="grey", linestyle="--") - - plt.xlabel("log2 Fold Change") - plt.ylabel("-log10(adjusted p-value)") - plt.title(f"Volcano Plot: {target_label} vs {reference_label}") - plt.legend() - - # Add labels for top significant features - sig_df = df.loc[significant].nsmallest(top_n, "padj") - for _, row in sig_df.iterrows(): - symbol = row["symbol"] if pd.notnull(row["symbol"]) else row["feature_id"] - plt.text( - row["log2FoldChange"], - row["-log10(padj)"], - symbol, - fontsize=8, - ha="center", - va="bottom", - ) - - plt.tight_layout() - plot_path = self.output_path / "volcano_plot.png" - plt.savefig(str(plot_path)) - plt.close() - logging.info(f"Volcano plot saved to {plot_path}") - - def create_ma_plot( - self, df: pd.DataFrame, target_label: str, reference_label: str - ) -> None: - """Create MA plot from differential expression results.""" - plt.figure(figsize=(10, 8)) - - # Prepare data - df = df[df["baseMean"] > 0] - df["log10(baseMean)"] = np.log10(df["baseMean"]) - - # Create plot - plt.scatter( - df["log10(baseMean)"], df["log2FoldChange"], alpha=0.5, color="grey" - ) - plt.axhline(y=0, color="red", linestyle="--") - - plt.xlabel("log10(Base Mean)") - plt.ylabel("log2 Fold Change") - plt.title(f"MA Plot: {target_label} vs {reference_label}") - - plt.tight_layout() - plot_path = self.output_path / "ma_plot.png" - plt.savefig(str(plot_path)) - plt.close() - logging.info(f"MA plot saved to {plot_path}") - - def create_summary( - self, - res_df: pd.DataFrame, - target_label: str, - reference_label: str, - min_count: int, - feature_type: str, - ) -> None: - """ - Create and save analysis summary. 
-
-        Args:
-            res_df: Results DataFrame
-            target_label: Target condition label
-            reference_label: Reference condition label
-            min_count: Minimum count threshold used in filtering
-            feature_type: Type of features analyzed ("genes" or "transcripts")
-        """
-        total_features = len(res_df)
-        sig_features = (
-            (res_df["padj"] < 0.05) & (res_df["log2FoldChange"].abs() > 1)
-        ).sum()
-        up_regulated = ((res_df["padj"] < 0.05) & (res_df["log2FoldChange"] > 1)).sum()
-        down_regulated = (
-            (res_df["padj"] < 0.05) & (res_df["log2FoldChange"] < -1)
-        ).sum()
-
-        summary_path = self.output_path / "analysis_summary.txt"
-        with summary_path.open("w") as f:
-            f.write(f"Analysis Summary: {target_label} vs {reference_label}\n")
-            f.write("================================\n")
-            f.write(
-                f"{feature_type.capitalize()} after filtering "
-                f"(mean count >= {min_count} in both groups): {total_features}\n"
-            )
-            f.write(f"Significantly differential {feature_type}: {sig_features}\n")
-            f.write(f"Up-regulated {feature_type}: {up_regulated}\n")
-            f.write(f"Down-regulated {feature_type}: {down_regulated}\n")
-        logging.info(f"Analysis summary saved to {summary_path}")
-
-    def visualize_results(
-        self,
-        results: pd.DataFrame,
-        target_label: str,
-        reference_label: str,
-        min_count: int,
-        feature_type: str,
-    ) -> None:
-        """
-        Create all visualizations and summary for the analysis results.
-
-        Args:
-            results: DataFrame containing differential expression results
-            target_label: Target condition label
-            reference_label: Reference condition label
-            min_count: Minimum count threshold used in filtering
-            feature_type: Type of features analyzed ("genes" or "transcripts")
-        """
-        try:
-            self.create_volcano_plot(results, target_label, reference_label)
-            self.create_ma_plot(results, target_label, reference_label)
-            self.create_summary(
-                results, target_label, reference_label, min_count, feature_type
-            )
-        except Exception as e:
-            logging.exception("Failed to create visualizations")
-            raise
diff --git a/visualize.py b/visualize.py
index 774a9995..f4b04fd8 100755
--- a/visualize.py
+++ b/visualize.py
@@ -15,7 +15,7 @@ def setup_logging(viz_output_dir: Path) -> None:
 
     # Create formatters
     file_formatter = logging.Formatter(
-        '%(asctime)s - %(levelname)s - %(message)s'
+        '%(asctime)s - %(levelname)s - %(module)s - %(funcName)s - %(message)s'
    )
     console_formatter = logging.Formatter('%(levelname)s: %(message)s')
 
@@ -26,24 +26,17 @@ def setup_logging(viz_output_dir: Path) -> None:
 
     # Console handler - less detailed
     console_handler = logging.StreamHandler()
-    console_handler.setLevel(logging.DEBUG)
+    console_handler.setLevel(logging.INFO)  # Console output at INFO level
     console_handler.setFormatter(console_formatter)
 
     # Configure root logger
     root_logger = logging.getLogger()
-    root_logger.setLevel(logging.DEBUG)
+    root_logger.setLevel(logging.DEBUG)  # Root logger at DEBUG level
     root_logger.handlers = []  # Clear existing handlers
     root_logger.addHandler(file_handler)
    root_logger.addHandler(console_handler)
 
-    # Create logger for the visualization package
-    viz_logger = logging.getLogger('IsoQuant.visualization')
-    viz_logger.setLevel(logging.DEBUG)
-
-    # Create logger for differential expression
-    diff_logger = logging.getLogger('IsoQuant.visualization.differential_exp')
-    diff_logger.setLevel(logging.DEBUG)
-
+
     logging.info("Initialized centralized logging system")
     logging.debug(f"Log file location: {log_file}")
 
@@ -285,7 +278,7 @@ def main():
         dictionary_builder=dictionary_builder,
     )
     gene_results, transcript_results, _,
deseq2_df = diff_analysis.run_complete_analysis() - find_genes_list_path = gene_results.parent / "genes_from_top_100_transcripts.txt" + find_genes_list_path = gene_results.parent / "genes_of_top_100_DE_transcripts.txt" gene_list = dictionary_builder.read_gene_list(find_genes_list_path) if args.gsea: @@ -294,27 +287,27 @@ def main(): gsea.run_gsea_analysis(deseq2_df, target_label) # Use genes from top transcripts instead of top genes - find_genes_list_path = gene_results.parent / "genes_from_top_100_transcripts.txt" + find_genes_list_path = gene_results.parent / "genes_of_top_100_DE_transcripts.txt" gene_list = dictionary_builder.read_gene_list(find_genes_list_path) else: base_dir = viz_output_dir - if update_names: - logging.info("Updating Ensembl IDs to gene symbols.") - updated_gene_dict = dictionary_builder.update_gene_names(updated_gene_dict) - # 5. Set up output directories - read_assignments_dir = base_dir / "read_assignments" gene_visualizations_dir = base_dir / "gene_visualizations" - read_assignments_dir.mkdir(exist_ok=True) gene_visualizations_dir.mkdir(exist_ok=True) + if use_read_assignments: + read_assignments_dir = base_dir / "read_assignments" + read_assignments_dir.mkdir(exist_ok=True) + else: + read_assignments_dir = None # Set to None if not used + # 6. Plotting with PlotOutput plot_output = PlotOutput( updated_gene_dict, gene_list, str(gene_visualizations_dir), - read_assignments_dir=str(read_assignments_dir), + read_assignments_dir=str(read_assignments_dir), # Pass None if not used reads_and_class=reads_and_class, filter_transcripts=min_val, # Just pass your chosen threshold for reference conditions=output.conditions, From d855c0f84cf5804d32af2bf4eea1bf056c224d78 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Thu, 27 Feb 2025 12:01:51 -0600 Subject: [PATCH 27/35] Modified GSEA approach --- src/visualization_gsea.py | 90 +++++++++++++++++++++++++----------- src/visualization_plotter.py | 2 +- 2 files changed, 65 insertions(+), 27 deletions(-) diff --git a/src/visualization_gsea.py b/src/visualization_gsea.py index 619be95b..4047c43e 100644 --- a/src/visualization_gsea.py +++ b/src/visualization_gsea.py @@ -8,6 +8,7 @@ from rpy2.robjects.packages import importr from rpy2.robjects.conversion import localconverter from rpy2.rinterface_lib import callbacks +from matplotlib.patches import Patch class GSEAAnalysis: @@ -48,32 +49,27 @@ def run_gsea_analysis(self, results: pd.DataFrame, target_label: str) -> None: logging.debug(f"Full DE results shape: {results.shape}") logging.debug(f"DE results columns: {results.columns.tolist()}") - # Filter for significant DE genes - sig_genes = results.dropna(subset=["padj"]) - logging.debug(f"After dropping genes with NaN padj: {sig_genes.shape}") - sig_genes = sig_genes[sig_genes["padj"] < 0.05] - logging.debug(f"Significantly DE genes (padj<0.05): {sig_genes.shape}") - if sig_genes.empty: - logging.info("No significantly DE genes found for GSEA.") - return - - sig_genes = sig_genes[sig_genes["pvalue"] > 0] - logging.debug(f"Significantly DE genes with pvalue>0: {sig_genes.shape}") - if sig_genes.empty: - logging.info("No genes with valid p-values for GSEA.") + # Don't filter for significant genes - use ALL genes with valid statistics + # Just remove NaN values + valid_genes = results.dropna(subset=["stat", "gene_name"]) + logging.debug(f"Genes with valid statistics: {valid_genes.shape}") + + if valid_genes.empty: + logging.info("No genes with valid statistics found for GSEA.") return # Use gene_name instead of symbol - 
-        gene_symbols_final = sig_genes["gene_name"].values
+        gene_symbols = valid_genes["gene_name"].values
 
+        # Create ranked list using ALL genes (not just significant ones)
         ranked_genes = pd.Series(
-            sig_genes["stat"].values, index=gene_symbols_final
+            valid_genes["stat"].values, index=gene_symbols
         ).dropna()
         ranked_genes = ranked_genes[~ranked_genes.index.duplicated(keep="first")]
         logging.debug(f"Final ranked genes count: {len(ranked_genes)}")
 
-        if ranked_genes.empty:
-            logging.info("No valid ranked genes after processing.")
+        if ranked_genes.empty or len(ranked_genes) < 50:  # Ensure we have enough genes
+            logging.info(f"Not enough valid ranked genes for GSEA: {len(ranked_genes)}")
             return
 
         # Save the ranked genes
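
For reference, the ranking this patch switches to reduces to a named, de-duplicated, sorted vector of test statistics. A minimal standalone sketch (hypothetical gene names; assumes the DE results carry "stat" and "gene_name" columns as above, and that the downstream GSEA call wants a decreasingly sorted vector):

import pandas as pd

# Hypothetical DESeq2-style results: one row per gene, Wald statistic in "stat"
results = pd.DataFrame({
    "gene_name": ["TP53", "YBX1", "TTN", "YBX1"],
    "stat": [4.2, -1.1, 0.3, -0.9],
})

# Rank ALL genes by the signed statistic; a duplicated symbol keeps its first entry
ranked = pd.Series(results["stat"].values, index=results["gene_name"]).dropna()
ranked = ranked[~ranked.index.duplicated(keep="first")]
ranked = ranked.sort_values(ascending=False)  # GSEA tools expect a sorted vector
print(ranked.head())
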
plot saved to {plot_path}") @@ -181,12 +218,13 @@ def plot_pathways(df: pd.DataFrame, direction: str, ont: str): ].copy() # Using 0.05 threshold if not sig_gsea_df.empty: up_pathways = sig_gsea_df[sig_gsea_df["NES"] > 0].nsmallest( - 10, "p.adjust" + 15, "p.adjust" ) down_pathways = sig_gsea_df[sig_gsea_df["NES"] < 0].nsmallest( - 10, "p.adjust" + 15, "p.adjust" ) + # Use separate color scales for each direction if not up_pathways.empty: plot_pathways(up_pathways, "up", ont) if not down_pathways.empty: diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py index 2efafdfd..17a95e7d 100644 --- a/src/visualization_plotter.py +++ b/src/visualization_plotter.py @@ -371,7 +371,7 @@ def create_volcano_plot( df.loc[down_regulated, "-log10(padj)"], color="blue", alpha=0.7, - label=f"Up-regulated in ({reference_label})", + label=f"Down-regulated in ({target_label})", ) # Add threshold lines and labels From ac0f6dfe22c20d0aa83ecd8363c22bb3f52d3712 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Mon, 31 Mar 2025 23:50:27 -0500 Subject: [PATCH 28/35] New plotter that inherits filtering --- src/visualization_plotter.py | 702 ++++++++++++++++++++++++++++++++--- 1 file changed, 643 insertions(+), 59 deletions(-) diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py index 17a95e7d..55491b01 100644 --- a/src/visualization_plotter.py +++ b/src/visualization_plotter.py @@ -7,6 +7,8 @@ import matplotlib.patches as patches import seaborn as sns from typing import List +from matplotlib.colors import Normalize +import matplotlib.cm as cm class PlotOutput: def __init__( @@ -18,18 +20,32 @@ def __init__( reads_and_class=None, filter_transcripts=None, conditions=False, - use_counts=False, ref_only=False, + ref_conditions=None, + target_conditions=None, ): self.updated_gene_dict = updated_gene_dict self.gene_names = gene_names self.gene_visualizations_dir = gene_visualizations_dir self.read_assignments_dir = read_assignments_dir self.reads_and_class = reads_and_class - self.filter_transcripts = filter_transcripts self.conditions = conditions - self.use_counts = use_counts self.ref_only = ref_only + self.min_tpm_threshold = filter_transcripts + + # Explicitly set reference and target conditions + self.ref_conditions = ref_conditions if ref_conditions else [] + self.target_conditions = target_conditions if target_conditions else [] + + # Log conditions for debugging + if self.ref_conditions and self.target_conditions: + logging.info(f"Filtering plots to include only ref conditions: {self.ref_conditions} and target conditions: {self.target_conditions}") + else: + logging.warning("No ref_conditions or target_conditions set, filtering may not work correctly") + + # Log TPM threshold if set + if self.min_tpm_threshold: + logging.info(f"Filtering transcripts with TPM value < {self.min_tpm_threshold}") # Ensure output directories exist if self.gene_visualizations_dir: @@ -42,24 +58,58 @@ def plot_transcript_map(self): logging.warning("No gene_visualizations_dir provided. 
Skipping transcript map plotting.") return - for gene_name_or_id in self.gene_names: # gene_names list contains gene names (symbols) - gene_data = None # Initialize gene_data to None - found_by_name = False # Flag to track if gene was found by name - - for condition, genes in self.updated_gene_dict.items(): - for gene_id, gene_info in genes.items(): - if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): # Compare gene names (case-insensitive) - gene_data = gene_info - found_by_name = True - break # Found gene, break inner loop - if found_by_name: - break # Found gene, break outer loop + # Check if reference and target conditions are defined + has_specific_conditions = (hasattr(self, 'ref_conditions') and hasattr(self, 'target_conditions') and + self.ref_conditions and self.target_conditions) + + if has_specific_conditions: + logging.info(f"Filtering transcript map to include only ref conditions: {self.ref_conditions} and target conditions: {self.target_conditions}") + # Define all allowed conditions + allowed_conditions = set(self.ref_conditions + self.target_conditions) + + for gene_name_or_id in self.gene_names: # gene_names list contains gene names (symbols) + gene_data = None # Initialize gene_data to None + found_condition = None # Track which condition we found the gene in + + # First pass: Try to find the gene in allowed conditions only + if has_specific_conditions: + # Search only in allowed conditions + for condition in allowed_conditions: + if condition not in self.updated_gene_dict: + continue + + genes = self.updated_gene_dict[condition] + for gene_id, gene_info in genes.items(): + if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): # Compare gene names (case-insensitive) + gene_data = gene_info + found_condition = condition + break + if gene_data: + break # Found gene, stop searching + + # Second pass: If not found and we're allowing fallback, try all conditions + if not gene_data: + for condition, genes in self.updated_gene_dict.items(): + # Skip conditions we already checked if using specific conditions + if has_specific_conditions and condition in allowed_conditions: + continue + + for gene_id, gene_info in genes.items(): + if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): + gene_data = gene_info + found_condition = condition + break + if gene_data: + break # Found gene, stop searching if gene_data: - logging.debug(f"Gene {gene_name_or_id} found by name in the data.") + if has_specific_conditions and found_condition in allowed_conditions: + logging.debug(f"Gene {gene_name_or_id} found in prioritized condition: {found_condition}") + else: + logging.debug(f"Gene {gene_name_or_id} found in fallback condition: {found_condition}") else: - logging.warning(f"Gene {gene_name_or_id} not found in the data.") - continue # Skip to the next gene if not found + logging.warning(f"Gene {gene_name_or_id} not found in any condition.") + continue # Skip to the next gene if not found # Get chromosome info and calculate buffer chromosome = gene_data.get("chromosome", "Unknown") @@ -72,10 +122,76 @@ def plot_transcript_map(self): plot_start = start - buffer plot_end = end + buffer - plot_height = max(8, len(gene_data["transcripts"]) * 0.4) - logging.debug(f"Creating transcript map for gene '{gene_name_or_id}' with {len(gene_data['transcripts'])} transcripts") - - + # NEW APPROACH: If we have ref/target conditions AND TPM filtering, + # we need to consider transcript expression across ALL relevant conditions + if 
+            # NEW APPROACH: If we have ref/target conditions AND TPM filtering,
+            # we need to consider transcript expression across ALL relevant conditions
+            if has_specific_conditions and self.min_tpm_threshold is not None:
+                # First, collect the max TPM for each transcript across all ref/target conditions
+                transcript_max_tpm = {}
+
+                # Collect all transcripts from the current condition first
+                for transcript_id, transcript_info in gene_data["transcripts"].items():
+                    value = float(transcript_info.get("value", 0))
+                    transcript_max_tpm[transcript_id] = value
+
+                # Check other ref/target conditions for the same gene to find max TPM values
+                for condition in allowed_conditions:
+                    if condition == found_condition or condition not in self.updated_gene_dict:
+                        continue  # Skip the condition we already processed
+
+                    genes = self.updated_gene_dict[condition]
+                    for gene_id, gene_info in genes.items():
+                        # Check if this is the same gene in another condition
+                        if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper():
+                            # Found the same gene in another condition, check transcript TPM values
+                            for transcript_id, transcript_info in gene_info["transcripts"].items():
+                                value = float(transcript_info.get("value", 0))
+                                # Update max TPM if higher in this condition
+                                if transcript_id in transcript_max_tpm:
+                                    transcript_max_tpm[transcript_id] = max(transcript_max_tpm[transcript_id], value)
+                                else:
+                                    transcript_max_tpm[transcript_id] = value
+                            break  # Found the gene in this condition, no need to check other genes
+
+                # Now filter transcripts based on their max TPM across ref/target conditions
+                filtered_transcripts = {}
+                for transcript_id, transcript_info in gene_data["transcripts"].items():
+                    max_tpm = transcript_max_tpm.get(transcript_id, 0)
+                    if max_tpm >= self.min_tpm_threshold:
+                        filtered_transcripts[transcript_id] = transcript_info
+
+                # Log filtering results
+                total_transcripts = len(gene_data["transcripts"])
+                filtered_count = len(filtered_transcripts)
+                logging.debug(f"Cross-condition TPM filtering: {filtered_count} of {total_transcripts} transcripts have TPM >= {self.min_tpm_threshold} in any ref/target condition for gene {gene_name_or_id}")
+            else:
+                # Original filtering approach for single condition
+                filtered_transcripts = {}
+                total_transcripts = len(gene_data["transcripts"])
+                filtered_count = 0
+
+                for transcript_id, transcript_info in gene_data["transcripts"].items():
+                    # Apply TPM filtering if threshold is set
+                    if self.min_tpm_threshold is not None:
+                        value = float(transcript_info.get("value", 0))
+                        if value >= self.min_tpm_threshold:
+                            filtered_transcripts[transcript_id] = transcript_info
+                            filtered_count += 1
+                    else:
+                        # No filtering, include all transcripts
+                        filtered_transcripts[transcript_id] = transcript_info
+
+                if self.min_tpm_threshold is not None:
+                    logging.debug(f"Single-condition TPM filtering: {filtered_count} of {total_transcripts} transcripts have TPM >= {self.min_tpm_threshold} for gene {gene_name_or_id}")
+
+            # Skip plotting if no transcripts pass the filter
+            if not filtered_transcripts:
+                logging.warning(f"No transcripts for gene {gene_name_or_id} pass the TPM threshold of {self.min_tpm_threshold}. Skipping plot.")
+                continue
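
The cross-condition rule above amounts to: keep a transcript if its maximum TPM over the allowed conditions clears the threshold. With a transcripts-by-conditions TPM frame the same filter is one vectorized expression (toy values, hypothetical condition names):

import pandas as pd

tpm = pd.DataFrame(
    {"WildType": [12.0, 0.4, 3.0], "Mutants": [1.0, 0.2, 9.5]},
    index=["tx1", "tx2", "tx3"],
)
min_tpm_threshold = 5

# Keep a transcript if it clears the threshold in ANY ref/target condition
keep = tpm[["WildType", "Mutants"]].max(axis=1) >= min_tpm_threshold
print(tpm[keep].index.tolist())  # ['tx1', 'tx3']
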
Skipping plot.") + continue + + # Calculate plot height based on number of filtered transcripts + num_transcripts = len(filtered_transcripts) + plot_height = max(8, num_transcripts * 0.4) + #logging.debug(f"Creating transcript map for gene '{gene_name_or_id}' with {num_transcripts} transcripts from {found_condition}") fig, ax = plt.subplots(figsize=(12, plot_height)) @@ -89,7 +205,7 @@ def plot_transcript_map(self): # Plot each transcript y_ticks = [] y_labels = [] - for i, (transcript_id, transcript_info) in enumerate(gene_data["transcripts"].items()): + for i, (transcript_id, transcript_info) in enumerate(filtered_transcripts.items()): # Plot direction marker direction_marker = ">" if gene_data["strand"] == "+" else "<" marker_pos = ( @@ -133,22 +249,35 @@ def plot_transcript_map(self): if not any(exon["exon_id"].startswith("ENSE") for exon in transcript_info["exons"]): logging.debug(f"Transcript {transcript_id} in gene {gene_name_or_id} contains NO reference exons (based on ENSEMBL IDs)") #log the exon_ids - logging.debug(f"Exon IDs: {[exon['exon_id'] for exon in transcript_info['exons']]}") + #logging.debug(f"Exon IDs: {[exon['exon_id'] for exon in transcript_info['exons']]}") else: - logging.debug(f"Transcript {transcript_id} in gene {gene_name_or_id} contains at least one reference exon (based on ENSEMBL IDs)") - + #logging.debug(f"Transcript {transcript_id} in gene {gene_name_or_id} contains at least one reference exon (based on ENSEMBL IDs)") + pass # Added explicit pass statement for the empty block + # Store y-axis label information y_ticks.append(i) # Get transcript name with fallback options transcript_name = (transcript_info.get("name") or - transcript_info.get("transcript_id") or - transcript_id) - y_labels.append(transcript_name) + transcript_info.get("transcript_id") or + transcript_id) + value = float(transcript_info.get("value", 0)) + + # If cross-condition filtering was used, show the max TPM value in the label + if has_specific_conditions and self.min_tpm_threshold is not None: + max_tpm = transcript_max_tpm.get(transcript_id, 0) + y_labels.append(f"{transcript_name}") + else: + y_labels.append(f"{transcript_name}") # Set up the plot formatting with just chromosome gene_display_name = gene_data.get("name", gene_name_or_id) # Fallback to ID if no name - if self.filter_transcripts: - title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome}) (Count > {self.filter_transcripts})" + + # Update title to include TPM threshold if applied + if self.min_tpm_threshold is not None: + if has_specific_conditions: + title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome}) (TPM >= {self.min_tpm_threshold} in any ref/target condition)" + else: + title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome}) (TPM >= {self.min_tpm_threshold})" else: title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome})" @@ -172,7 +301,7 @@ def plot_transcript_map(self): plt.tight_layout() plot_path = os.path.join( - self.gene_visualizations_dir, f"{gene_name_or_id}_splicing.png" # Use gene_name_or_id in filename + self.gene_visualizations_dir, f"{gene_name_or_id}_splicing.pdf" # Changed from .png to .pdf ) plt.savefig(plot_path, bbox_inches='tight', dpi=300) plt.close(fig) @@ -184,28 +313,113 @@ def plot_transcript_usage(self): if not self.gene_visualizations_dir: logging.warning("No gene_visualizations_dir provided. 
@@ -184,28 +313,113 @@ def plot_transcript_usage(self):
         if not self.gene_visualizations_dir:
             logging.warning("No gene_visualizations_dir provided. Skipping transcript usage plotting.")
             return
-
-        for gene_name_or_id in self.gene_names:  # gene_names list contains gene names (symbols)
-            gene_data_per_condition = {}  # Store gene data per condition
-            found_gene_any_condition = False  # Flag if gene found in any condition
-
+
+        # Add this section near the beginning of the method
+        logging.info("=== SPECIAL DEBUG FOR YBX1 TRANSCRIPTS ===")
+        for condition in self.updated_gene_dict:
+            for gene_id, gene_info in self.updated_gene_dict[condition].items():
+                if gene_info.get("name") == "YBX1" or gene_id == "YBX1":
+                    logging.info(f"Found YBX1 in condition {condition}")
+                    logging.info(f"Total transcripts before filtering: {len(gene_info.get('transcripts', {}))}")
+                    for transcript_id, transcript_info in gene_info.get('transcripts', {}).items():
+                        value = transcript_info.get('value', 0)
+                        logging.info(f"  Transcript {transcript_id}: TPM = {value:.2f}")
+
+        # Check if reference and target conditions are defined
+        has_specific_conditions = (hasattr(self, 'ref_conditions') and hasattr(self, 'target_conditions') and
+                                   self.ref_conditions and self.target_conditions)
+
+        if has_specific_conditions:
+            logging.debug(f"Filtering transcript usage plot to include only ref conditions: {self.ref_conditions} and target conditions: {self.target_conditions}")
+            # Define all allowed conditions
+            allowed_conditions = set(self.ref_conditions + self.target_conditions)
+
+        for gene_name_or_id in self.gene_names:  # gene_names list contains gene names (symbols)
+            gene_data_per_condition = {}  # Store gene data per condition
+            found_gene_any_condition = False  # Flag if gene found in any condition
+
+            # Only process allowed conditions if specific conditions are defined
             for condition, genes in self.updated_gene_dict.items():
+                # Skip conditions that aren't in ref or target if we have those defined
+                if has_specific_conditions and condition not in allowed_conditions:
+                    #logging.debug(f"Skipping condition {condition} for gene {gene_name_or_id} (not in allowed conditions)")
+                    continue
+
                 condition_gene_data = None
                 for gene_id, gene_info in genes.items():
-                    if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper():  # Compare gene names (case-insensitive)
-                        condition_gene_data = gene_info["transcripts"]  # Only need transcripts for usage plot
+                    if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper():  # Compare gene names (query upper-cased to match stored symbols)
+                        condition_gene_data = gene_info["transcripts"]  # Only need transcripts for usage plot
                         found_gene_any_condition = True
-                        break  # Found gene in this condition, break inner loop
+                        #logging.debug(f"Found gene {gene_name_or_id} in condition {condition}")
+                        break  # Found gene in this condition, break inner loop
                 if condition_gene_data:
-                    gene_data_per_condition[condition] = condition_gene_data  # Store transcripts for this condition
-
+                    gene_data_per_condition[condition] = condition_gene_data  # Store transcripts for this condition
+
             if not found_gene_any_condition:
-                logging.warning(f"Gene {gene_name_or_id} not found in the data.")
-                continue  # Skip to the next gene if not found
-
+                logging.warning(f"Gene {gene_name_or_id} not found in any allowed condition.")
+                continue  # Skip to the next gene if not found
+            # NEW APPROACH: If TPM threshold is set, collect max TPM for each transcript across all conditions
+            if self.min_tpm_threshold is not None:
+                # Create list of all allowed conditions
+                allowed_conditions = set(self.ref_conditions + self.target_conditions) if has_specific_conditions else set(gene_data_per_condition.keys())
+
+                # First, collect the max TPM for each transcript ONLY in allowed conditions
+                transcript_max_tpm = {}
+
+                for condition, transcripts in gene_data_per_condition.items():
+                    # Skip conditions that aren't in ref or target if we have those defined
+                    if has_specific_conditions and condition not in allowed_conditions:
+                        continue
+
+                    for transcript_id, transcript_info in transcripts.items():
+                        value = float(transcript_info.get("value", 0))
+                        if transcript_id not in transcript_max_tpm:
+                            transcript_max_tpm[transcript_id] = value
+                        else:
+                            transcript_max_tpm[transcript_id] = max(transcript_max_tpm[transcript_id], value)
+
+                # Filter out transcripts that don't meet threshold in ANY allowed condition
+                valid_transcripts = {t_id: t_val for t_id, t_val in transcript_max_tpm.items()
+                                     if t_val >= self.min_tpm_threshold}
+
+                # Now keep ALL instances of valid transcripts in allowed conditions, even if below threshold
+                filtered_gene_data_per_condition = {}
+                for condition, transcripts in gene_data_per_condition.items():
+                    # Skip conditions that aren't in ref or target if we have those defined
+                    if has_specific_conditions and condition not in allowed_conditions:
+                        continue
+
+                    filtered_transcripts = {}
+                    for transcript_id, transcript_info in transcripts.items():
+                        # Include this transcript if it's in our valid list
+                        if transcript_id in valid_transcripts:
+                            filtered_transcripts[transcript_id] = transcript_info
+
+                    # Only include conditions with transcripts
+                    if filtered_transcripts:
+                        filtered_gene_data_per_condition[condition] = filtered_transcripts
+
+                # Replace original data with filtered data
+                gene_data_per_condition = filtered_gene_data_per_condition
+
+                # Log filtering results
+                total_unique_transcripts = len(transcript_max_tpm)
+                kept_transcripts = len(valid_transcripts)
+                logging.debug(f"TPM filtering for {gene_name_or_id}: {kept_transcripts} of {total_unique_transcripts} unique transcripts have TPM >= {self.min_tpm_threshold} in at least one ref/target condition")
+
+            if not gene_data_per_condition:
+                logging.warning(f"No data available for gene {gene_name_or_id} after filtering. Skipping plot.")
+                continue
Skipping plot.") + continue + conditions = list(gene_data_per_condition.keys()) n_bars = len(conditions) - + + if n_bars == 0: + logging.warning(f"No conditions found for gene {gene_name_or_id} after filtering.") + continue + + fig, ax = plt.subplots(figsize=(12, 8)) index = np.arange(n_bars) bar_width = 0.35 @@ -216,6 +430,9 @@ def plot_transcript_usage(self): bottom_val = np.zeros(n_bars) for i, condition in enumerate(conditions): transcripts = gene_data_per_condition[condition] + if not transcripts: # Skip if no transcript data for this condition + continue + for j, (transcript_id, transcript_info) in enumerate(transcripts.items()): color = colors[j % len(colors)] value = transcript_info["value"] @@ -234,10 +451,16 @@ def plot_transcript_usage(self): ) bottom_val[i] += float(value) - ax.set_xlabel("Sample Type") + ax.set_xlabel("Condition") ax.set_ylabel("Transcript Usage (TPM)") - gene_display_name = list(gene_data_per_condition.values())[0].get("name", gene_name_or_id) # Fallback to ID if no name - ax.set_title(f"Transcript Usage for {gene_display_name} by Sample Type") + gene_display_name = list(gene_data_per_condition.values())[0].get("name", gene_name_or_id) # Fallback to ID if no name + + # Update title to include TPM threshold if applied + if self.min_tpm_threshold is not None: + ax.set_title(f"Transcript Usage for {gene_display_name} by Condition (TPM >= {self.min_tpm_threshold} in any ref/target condition)") + else: + ax.set_title(f"Transcript Usage for {gene_display_name} by Condition") + ax.set_xticks(index) ax.set_xticklabels(conditions) ax.legend( @@ -250,7 +473,7 @@ def plot_transcript_usage(self): plt.tight_layout() plot_path = os.path.join( self.gene_visualizations_dir, - f"{gene_name_or_id}_transcript_usage_by_sample_type.png", # Use gene_name_or_id in filename + f"{gene_name_or_id}_transcript_usage_by_sample_type.pdf", # Changed from .png to .pdf ) plt.savefig(plot_path) plt.close(fig) @@ -260,17 +483,28 @@ def make_pie_charts(self): Create pie charts for transcript alignment classifications and read assignment consistency. Handles both combined and separate sample data structures. """ + # Skip if reads_and_class is not provided + if not self.reads_and_class: + logging.warning("No reads_and_class data provided. Skipping pie chart creation.") + return titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] - + + # Check if reference and target conditions are defined + has_specific_conditions = hasattr(self, 'ref_conditions') and hasattr(self, 'target_conditions') + for title, data in zip(titles, self.reads_and_class): if isinstance(data, dict): if any(isinstance(v, dict) for v in data.values()): - # Separate 'Mutants' and 'WildType' case + # Separate sample data case (e.g. 
-                    # Separate 'Mutants' and 'WildType' case
+                    # Separate sample data case (e.g. 'Mutants' and 'WildType')
                     for sample_name, sample_data in data.items():
+                        # Skip conditions that aren't in ref or target if we have those defined
+                        if has_specific_conditions and sample_name not in self.ref_conditions and sample_name not in self.target_conditions:
+                            logging.debug(f"Skipping pie chart for {sample_name} (not in ref/target conditions)")
+                            continue
                         self._create_pie_chart(f"{title} - {sample_name}", sample_data)
                 else:
-                    # Combined data case
+                    # Combined data case - always create this as it's an overall summary
                     self._create_pie_chart(title, data)
             else:
                 print(f"Skipping unexpected data type for {title}: {type(data)}")
@@ -310,11 +544,361 @@ def _create_pie_chart(self, title, data):
         )
         # Save pie charts in the read_assignments directory
         plot_path = os.path.join(
-            self.read_assignments_dir, f"{file_title}_pie_chart.png"
+            self.read_assignments_dir, f"{file_title}_pie_chart.pdf"  # Changed from .png to .pdf
         )
         plt.savefig(plot_path, bbox_inches="tight", dpi=300)
         plt.close()
 
+    def plot_novel_transcript_contribution(self):
+        """
+        Creates a plot showing the percentage of expression from novel transcripts.
+        - Y-axis: Percentage of expression from novel transcripts (combined across conditions)
+        - X-axis: Expression log2 fold change between conditions
+        - Point size: Overall expression level
+        - Color: Red (target) to Blue (reference) indicating which condition contributes more to novel transcript expression
+        """
+        logging.info("Creating novel transcript contribution plot")
+
+        # Skip if we don't have reference vs target conditions defined
+        if not hasattr(self, 'ref_conditions') or not hasattr(self, 'target_conditions'):
+            logging.warning("Cannot create novel transcript plot: missing reference or target conditions")
+            return
+
+        # Get actual condition labels
+        ref_label = "+".join(self.ref_conditions)
+        target_label = "+".join(self.target_conditions)
+
+        # Set TPM threshold for transcript inclusion
+        min_tpm_threshold = 10
+
+        # Track all unique genes across all conditions
+        all_genes = {}  # Dictionary to track gene_id -> gene_info mapping across conditions
+
+        # First, collect all genes from all conditions
+        for condition, genes in self.updated_gene_dict.items():
+            for gene_id, gene_info in genes.items():
+                gene_name = gene_info.get('name', gene_id)
+                if gene_id not in all_genes:
+                    all_genes[gene_id] = {'name': gene_name, 'conditions': {}}
+
+                # Store condition-specific data
+                all_genes[gene_id]['conditions'][condition] = gene_info
+
+        logging.info(f"Total unique genes found across all conditions: {len(all_genes)}")
+
+        # First, let's investigate the discrepancy between our filtering and the GTF filtering
+        logging.info("Investigating transcript filtering discrepancy")
+
+        # Track unique transcript IDs that pass our threshold (to match GTF filtering method)
+        unique_transcripts_above_threshold = set()
+        transcript_values = {}  # To store max values for debugging
+
+        # First pass: identify all unique transcripts with TPM >= threshold
+        for condition, genes in self.updated_gene_dict.items():
+            # Check if this condition is in our ref or target groups
+            is_relevant_condition = condition in self.ref_conditions or condition in self.target_conditions
+            if not is_relevant_condition:
+                continue
+
+            for gene_id, gene_info in genes.items():
+                transcripts = gene_info.get('transcripts', {})
+                for transcript_id, transcript_info in transcripts.items():
+                    value = float(transcript_info.get("value", 0))
+
+                    # Track max value for this transcript across all conditions
+                    if transcript_id not in transcript_values:
+                        transcript_values[transcript_id] = value
+                    else:
+                        transcript_values[transcript_id] = max(transcript_values[transcript_id], value)
+
+                    # Check if transcript meets threshold
+                    if value >= min_tpm_threshold:
+                        unique_transcripts_above_threshold.add(transcript_id)
+
+        logging.info(f"Filtering comparison: Found {len(unique_transcripts_above_threshold)} unique transcripts with TPM >= {min_tpm_threshold}")
+        logging.info(f"Filtering comparison: This compares to 8,230 transcripts reported by GTF filter")
+
+        # Analyze distribution of TPM values to understand filtering
+        tpm_value_counts = {
+            "0-1": 0,
+            "1-5": 0,
+            "5-10": 0,
+            "10-20": 0,
+            "20-50": 0,
+            "50-100": 0,
+            "100+": 0
+        }
+
+        for transcript_id, max_value in transcript_values.items():
+            if max_value < 1:
+                tpm_value_counts["0-1"] += 1
+            elif max_value < 5:
+                tpm_value_counts["1-5"] += 1
+            elif max_value < 10:
+                tpm_value_counts["5-10"] += 1
+            elif max_value < 20:
+                tpm_value_counts["10-20"] += 1
+            elif max_value < 50:
+                tpm_value_counts["20-50"] += 1
+            elif max_value < 100:
+                tpm_value_counts["50-100"] += 1
+            else:
+                tpm_value_counts["100+"] += 1
+
+        logging.info(f"TPM value distribution across transcripts: {tpm_value_counts}")
+
+        # Check if there are any TTN transcripts in the unique set
+        ttn_transcripts = [t for t in unique_transcripts_above_threshold if "TTN" in t.upper()]
+        if ttn_transcripts:
+            logging.info(f"Found {len(ttn_transcripts)} TTN transcripts in high TPM set: {ttn_transcripts}")
+
+        # Prepare data storage for the main plot
+        plot_data = []  # Re-initialize plot_data here
+
+        # Track transcripts that pass TPM threshold
+        total_transcripts = 0
+        transcripts_above_threshold = 0
+        total_genes = 0
+        genes_with_high_expr_transcripts = 0
+
+        # Process each gene from all_genes
+        for gene_id, gene_data in all_genes.items():
+            total_genes += 1
+            gene_name = gene_data['name']
+            conditions_data = gene_data['conditions']
+
+            # Calculate expression for each condition group
+            ref_total_exp = {cond: 0 for cond in self.ref_conditions}
+            target_total_exp = {cond: 0 for cond in self.target_conditions}
+            ref_novel_exp = {cond: 0 for cond in self.ref_conditions}
+            target_novel_exp = {cond: 0 for cond in self.target_conditions}
+
+            # Track if this gene has any high-expression transcripts
+            gene_has_high_expr_transcript = False
+
+            # Process each condition
+            for condition, gene_info in conditions_data.items():
+                transcripts = gene_info.get('transcripts', {})
+                if not transcripts:
+                    continue
+
+                # Check if this condition is in our condition groups
+                is_ref = condition in self.ref_conditions
+                is_target = condition in self.target_conditions
+
+                if not (is_ref or is_target):
+                    continue  # Skip conditions that aren't in ref or target groups
+
+                for transcript_id, transcript_info in transcripts.items():
+                    total_transcripts += 1
+
+                    # Improved novel transcript identification - transcript is novel if not from Ensembl
+                    transcript_is_reference = transcript_id.startswith("ENST")
+                    is_novel = not transcript_is_reference
+
+                    value = float(transcript_info.get("value", 0))
+
+                    # Filter by TPM threshold - only count transcripts with TPM >= threshold
+                    # We only check TPM threshold in ref and target conditions (not other conditions)
+                    if value >= min_tpm_threshold:
+                        transcripts_above_threshold += 1
+                        gene_has_high_expr_transcript = True
+
+                        if is_ref:
+                            ref_total_exp[condition] += value
+                            if is_novel:
+                                ref_novel_exp[condition] += value
+
+                        if is_target:
+                            target_total_exp[condition] += value
+                            if is_novel:
+                                target_novel_exp[condition] += value
+                    elif gene_name == "TTN" and condition in self.ref_conditions:
+                        pass  # Add pass to avoid empty block
+
+            # Count genes with high-expression transcripts
+            if gene_has_high_expr_transcript:
+                genes_with_high_expr_transcripts += 1
+
+            # Calculate average expression for each condition group
+            ref_novel_pct = 0
+            target_novel_pct = 0
+            ref_expr_total = 0
+            target_expr_total = 0
+            ref_novel_expr_total = 0
+            target_novel_expr_total = 0
+
+            # Sum up expression values across conditions
+            for cond in self.ref_conditions:
+                ref_expr_total += ref_total_exp.get(cond, 0)
+                ref_novel_expr_total += ref_novel_exp.get(cond, 0)
+
+                # Also calculate percentages per condition for color coding
+                if ref_total_exp.get(cond, 0) > 0:
+                    ref_novel_pct += (ref_novel_exp.get(cond, 0) / ref_total_exp[cond]) * 100
+
+            for cond in self.target_conditions:
+                target_expr_total += target_total_exp.get(cond, 0)
+                target_novel_expr_total += target_novel_exp.get(cond, 0)
+
+                # Also calculate percentages per condition for color coding
+                if target_total_exp.get(cond, 0) > 0:
+                    target_novel_pct += (target_novel_exp.get(cond, 0) / target_total_exp[cond]) * 100
+
+            # Average the condition-specific percentages (for color coding only)
+            ref_novel_pct /= len([c for c in self.ref_conditions if c in ref_total_exp and ref_total_exp[c] > 0]) or 1
+            target_novel_pct /= len([c for c in self.target_conditions if c in target_total_exp and target_total_exp[c] > 0]) or 1
+
+            # Calculate overall novel percentage (for y-axis)
+            combined_expr_total = ref_expr_total + target_expr_total
+            combined_novel_expr_total = ref_novel_expr_total + target_novel_expr_total
+
+            # Calculate log2 fold change using the total expression values
+            if ref_expr_total > 0 and target_expr_total > 0:
+                log2fc = np.log2(target_expr_total / ref_expr_total)
+
+                # Calculate novel transcript contribution difference (for color)
+                novel_pct_diff = target_novel_pct - ref_novel_pct
+
+                # Calculate overall novel percentage (for y-axis)
+                if combined_expr_total > 0:
+                    overall_novel_pct = (combined_novel_expr_total / combined_expr_total) * 100
+
+                    # Add data point
+                    plot_data.append({
+                        'gene_id': gene_id,
+                        'gene_name': gene_name,
+                        'ref_novel_pct': ref_novel_pct,
+                        'target_novel_pct': target_novel_pct,
+                        'novel_pct_diff': novel_pct_diff,
+                        'overall_novel_pct': overall_novel_pct,
+                        'log2fc': log2fc,
+                        'total_expr': combined_expr_total
+                    })
+
+        # Report filtering results
+        logging.info(f"TPM filtering: {transcripts_above_threshold} of {total_transcripts} transcripts have TPM >= {min_tpm_threshold} in ref or target conditions")
+        logging.info(f"TPM filtering: {genes_with_high_expr_transcripts} of {total_genes} genes have at least one transcript with TPM >= {min_tpm_threshold} in ref or target conditions")
+
+        # Create dataframe
+        df = pd.DataFrame(plot_data)
+        # Get the parent directory of gene_visualizations_dir
+        parent_dir = os.path.dirname(self.gene_visualizations_dir)
+
+        # Save the CSV to parent directory instead of gene_visualizations_dir
+        df.to_csv(os.path.join(parent_dir, "novel_transcript_expression_data.csv"), index=False)
+
+        # Log the number of genes used in the plot
+        logging.info(f"Number of genes used in novel transcript plot after transcript-level TPM filtering: {len(df)}")
+
+        if df.empty:
+            logging.warning("No data available for novel transcript plot after transcript-level TPM filtering")
+            return
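
Each point in this plot is one gene: x is the log2 ratio of summed target vs reference expression and y is the share of combined expression carried by non-ENST (novel) transcripts. A worked example of that arithmetic (toy numbers; note the code above skips genes with a zero total on either side rather than adding a pseudocount):

import numpy as np

ref_expr_total, target_expr_total = 80.0, 120.0  # summed TPM per condition group
ref_novel, target_novel = 20.0, 60.0             # TPM from non-ENST transcripts

if ref_expr_total > 0 and target_expr_total > 0:  # skip zero totals, no pseudocount
    log2fc = np.log2(target_expr_total / ref_expr_total)  # ~0.585
    overall_novel_pct = 100 * (ref_novel + target_novel) / (ref_expr_total + target_expr_total)  # 40.0
    print(f"log2FC={log2fc:.3f}, novel%={overall_novel_pct:.1f}")
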
+        # Create the plot with more space on right for legend
+        plt.figure(figsize=(16, 10))  # Increased width from 14 to 16
+
+        # Define red-blue colormap
+        norm = Normalize(vmin=-50, vmax=50)  # Normalize based on difference range
+        cmap = cm.get_cmap('coolwarm')  # Red-Blue colormap
+
+        # More dramatic scaling for point sizes
+        min_size = 30
+        max_size = 800  # Much larger maximum size
+
+        # Use np.power for more dramatic scaling differences
+        expression_values = df['total_expr'].values
+        max_expr = expression_values.max()
+        min_expr = expression_values.min()
+
+        # Log the actual min and max expression values for reference
+        logging.debug(f"Expression range in data: min={min_expr}, max={max_expr}")
+
+        # Define the scaling function that will be used for both data points and legend
+        def scale_point_size(expr_value, min_expr, max_expr, min_size, max_size, power=0.3):
+            """Scale expression values to point sizes using the same formula for data and legend"""
+            # Normalize the expression value to [0,1] range
+            if max_expr == min_expr:  # Avoid division by zero
+                normalized = 0
+            else:
+                normalized = (expr_value - min_expr) / (max_expr - min_expr)
+            # Apply power scaling and convert to point size
+            return min_size + (max_size - min_size) * (normalized ** power)
+
+        # Apply scaling to actual data points
+        scaled_sizes = [scale_point_size(val, min_expr, max_expr, min_size, max_size) for val in expression_values]
+
+        # Plot points with scaled sizes
+        sc = plt.scatter(df['log2fc'], df['overall_novel_pct'],
+                         s=scaled_sizes,
+                         c=df['novel_pct_diff'],
+                         cmap=cmap,
+                         norm=norm,
+                         alpha=0.8,
+                         edgecolors='black')
+
+        # Add color legend on the right
+        cbar = plt.colorbar(sc, orientation='vertical', pad=0.02)
+        cbar.set_label('Novel transcript usage difference (%)', size=12)
+        cbar.ax.tick_params(labelsize=10)
+
+        # Use red and blue blocks to explain the colormap
+        plt.figtext(0.92, 0.72, f'Blue = higher (%) in {self.ref_conditions}', fontsize=12, ha='center')
+        plt.figtext(0.92, 0.75, f'Red = higher (%) in {self.target_conditions}', fontsize=12, ha='center')
+
+        # Add size legend directly to the plot
+        # Create legend elements for different sizes with new values: 50, 500, 5000
+        size_legend_values = [50, 500, 5000]
+        size_legend_elements = []
+
+        # Calculate sizes for legend using EXACTLY the same scaling function as for the data points
+        for val in size_legend_values:
+            # Use the same scaling function defined above
+            # If the value is outside the actual data range, clamp it to the range
+            clamped_val = min(max(val, min_expr), max_expr)
+            size = scale_point_size(clamped_val, min_expr, max_expr, min_size, max_size)
+
+            # Log the actual size being used for the legend point
+            logging.debug(f"Legend point {val} TPM scaled to size {size}")
+
+            # Convert area to diameter for Line2D (sqrt of area * 2)
+            marker_diameter = 2 * np.sqrt(size / np.pi)
+
+            size_legend_elements.append(
+                plt.Line2D([0], [0], marker='o', color='w',
+                           markerfacecolor='gray', markersize=marker_diameter,
+                           label=f'{val:.0f} TPM')
+            )
+
+        # Position legend
+        plt.legend(handles=size_legend_elements,
+                   title="Expression Level",
+                   loc='center left',
+                   bbox_to_anchor=(1.15, 0.5),
+                   frameon=False,
+                   title_fontsize=12,
+                   fontsize=12)
+
+        plt.xticks(fontsize=12)
+        plt.yticks(fontsize=12)
+
+        # Add labels and title with actual condition names
+        plt.xlabel('Log2 Fold Change', fontsize=12)
+        plt.ylabel('Total expression from novel transcripts (%)', fontsize=12)
+        plt.title(f'Novel Transcript Usage vs Expression Change: {target_label} vs {ref_label}', fontsize=18)
+
+        plt.grid(True, alpha=0.3)
+
+        # Use tighter layout settings
+        plt.tight_layout()
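
With power=0.3 the size scale is deliberately compressed at the top end; a quick numeric check of what the three legend values map to under an assumed data range of roughly 1-5000 TPM (the real range is data-dependent):

def scale_point_size(expr, lo, hi, s_min=30, s_max=800, power=0.3):
    # Same shape as the function above: normalize, then power-compress
    if hi == lo:
        return s_min
    normalized = (expr - lo) / (hi - lo)
    return s_min + (s_max - s_min) * normalized ** power

for tpm_val in (50, 500, 5000):  # legend reference points
    print(tpm_val, round(scale_point_size(tpm_val, 1, 5000)))
# prints roughly 222, 416, 800: a 100x TPM difference spans under a 4x size difference
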
+
+        # Save figure to parent directory instead of gene_visualizations_dir
+        output_path = os.path.join(parent_dir, "novel_transcript_expression_plot.pdf")
+        plt.savefig(output_path, dpi=300, bbox_inches='tight', pad_inches=0.5)
+        plt.close()
+
+        logging.debug(f"Novel transcript expression plot saved to {output_path}")
+
 
 class ExpressionVisualizer:
     def __init__(self, output_path):
@@ -404,8 +988,8 @@ def create_volcano_plot(
 
         plt.tight_layout()
         plot_path = (
-            self.output_path / f"volcano_plot_{feature_type}.png"
-        )  # Modified line
+            self.output_path / f"volcano_plot_{feature_type}.pdf"  # Changed from .png to .pdf
+        )
         plt.savefig(str(plot_path))
         plt.close()
         logging.info(f"Volcano plot saved to {plot_path}")
@@ -435,7 +1019,7 @@ def create_ma_plot(
         plt.title(f"MA Plot: {target_label} vs {reference_label}")
         plt.tight_layout()
 
-        plot_path = self.output_path / f"ma_plot_{feature_type}.png"  # Modified line
+        plot_path = self.output_path / f"ma_plot_{feature_type}.pdf"  # Changed from .png to .pdf
         plt.savefig(str(plot_path))
         plt.close()
         logging.info(f"MA plot saved to {plot_path}")
@@ -565,7 +1149,7 @@ def plot_pca(
         plt.gca().spines['right'].set_visible(False)
 
         plt.tight_layout()
-        scatter_plot_path = self.output_path / f"{output_prefix}.png"  # Separate path for scatter plot
+        scatter_plot_path = self.output_path / f"{output_prefix}.pdf"  # Changed from .png to .pdf
         plt.savefig(scatter_plot_path)
         plt.close()
 
@@ -592,7 +1176,7 @@ def _plot_scree(self, explained_variance: np.ndarray, output_prefix: str) -> Pat
         plt.gca().spines['right'].set_visible(False)
 
         plt.tight_layout()
-        scree_plot_path = self.output_path / f"scree_{output_prefix}.png"
+        scree_plot_path = self.output_path / f"scree_{output_prefix}.pdf"  # Changed from .png to .pdf
         plt.savefig(scree_plot_path)
         plt.close()
         return scree_plot_path

From 862c4b8506c1f374b60249acc124b938ed9aa1e0 Mon Sep 17 00:00:00 2001
From: Jackfreeman88
Date: Mon, 31 Mar 2025 23:52:04 -0500
Subject: [PATCH 29/35] working system

---
 src/visualization_dictionary_builder.py | 288 ++++------------
 src/visualization_differential_exp.py   | 405 +++++++++++++++++++++---
 src/visualization_output_config.py      | 331 +++++++++++++++++--
 visualize.py                            |  56 ++--
 4 files changed, 768 insertions(+), 312 deletions(-)

diff --git a/src/visualization_dictionary_builder.py b/src/visualization_dictionary_builder.py
index c17e84c2..9c61a285 100644
--- a/src/visualization_dictionary_builder.py
+++ b/src/visualization_dictionary_builder.py
@@ -5,7 +5,6 @@
 import logging
 from pathlib import Path
 from typing import Dict, Any, List, Union, Tuple
-import random
 
 from src.visualization_cache_utils import (
     build_gene_dict_cache_file,
@@ -44,10 +43,10 @@ def build_gene_dict_with_expression_and_filter(
         self.logger.debug(f"Starting optimized dictionary build with min_value={min_value}")
 
         # 1. Check cache first
-        expr_file, tpm_file = self._get_expression_files()
+        tpm_file = self._get_tpm_file()
         base_cache_file = build_gene_dict_cache_file(
             self.config.extended_annotation,
-            expr_file,
+            tpm_file,
             self.config.ref_only,
             self.cache_dir,
         )
@@ -70,10 +69,7 @@ def build_gene_dict_with_expression_and_filter(
         # 2. Filter novel genes from the base gene dict (not per-condition)
         self.logger.info("Parsing GTF and filtering novel genes")
-        if self.config.ref_only:
-            parsed_data = self.parse_input_gtf()
-        else:
-            parsed_data = self.parse_extended_annotation()
+        parsed_data = self.parse_gtf()
 
         self._validate_gene_structure(parsed_data)
         base_gene_dict = self._filter_novel_genes(parsed_data)
@@ -86,20 +82,13 @@ def build_gene_dict_with_expression_and_filter(
             f"After novel gene filtering: {gene_count_after_novel_filter} genes, {transcript_count_after_novel_filter} transcripts"
         )
 
-        # 3. Load expression data with consistent header handling
-        self.logger.info("Loading expression matrix (Counts)")
-        try:
-            expr_df = pd.read_csv(expr_file, sep='\t', comment=None)
-            expr_df.columns = [col.lstrip('#') for col in expr_df.columns]  # Clean headers
-            expr_df = expr_df.set_index('feature_id')  # Use cleaned column name
-        except KeyError as e:
-            self.logger.error(f"Missing required column in {expr_file}: {str(e)}")
-            raise
-        except Exception as e:
-            self.logger.error(f"Failed to load count expression matrix: {str(e)}")
-            raise
+        if hasattr(self.config, 'transcript_map') and self.config.transcript_map:
+            self.logger.info(f"Using transcript mapping from OutputConfig with {len(self.config.transcript_map)} entries")
+        else:
+            self.logger.info("No transcript mapping found, proceeding with original transcripts")
 
-        self.logger.info("Loading TPM matrix")
+        # 3. Load expression data (TPM only) with consistent header handling
+        self.logger.info("Loading TPM matrix for filtering and expression values")
         try:
             tpm_df = pd.read_csv(tpm_file, sep='\t', comment=None)
             tpm_df.columns = [col.lstrip('#') for col in tpm_df.columns]  # Clean headers
             tpm_df = tpm_df.set_index('feature_id')  # Use cleaned column name
         except KeyError as e:
             self.logger.error(f"Missing required column in {tpm_file}: {str(e)}")
             raise
         except Exception as e:
             self.logger.error(f"Failed to load TPM expression matrix: {str(e)}")
             raise
 
-        conditions = expr_df.columns.tolist()
+        conditions = tpm_df.columns.tolist()
 
-        # 4. Vectorized processing instead of row-wise iteration (using counts for filtering)
-        transcript_max_values = expr_df.max(axis=1)
+        # 4. Vectorized processing using TPM for both filtering and values
+        transcript_max_values = tpm_df.max(axis=1)
         valid_transcripts = set(
             transcript_max_values[transcript_max_values >= min_value].index
         )
 
         # Add debug log: Number of valid transcripts after min_value filtering
         valid_transcript_count = len(valid_transcripts)
         self.logger.debug(
-            f"After min_value ({min_value}) filtering: {valid_transcript_count} valid transcripts"
+            f"After TPM min_value ({min_value}) filtering: {valid_transcript_count} valid transcripts"
         )
 
-        # 5. Single-pass filtering and value updating (using TPMs for values)
+        # 5. Single-pass filtering and value updating (using TPMs for both)
         filtered_dict = {}
         for condition in conditions:
             filtered_dict[condition] = {}
-            condition_counts = expr_df[condition]  # Still using counts for filtering logic if needed later
-            condition_tpm_values = tpm_df[condition]  # Use TPM values for assigning expression
+            condition_tpm_values = tpm_df[condition]
 
             for gene_id, gene_info in base_gene_dict.items():
                 new_transcripts = {
-                    tid: {**tinfo, 'value': condition_tpm_values.get(tid, 0)}  # Use TPM values here!
+                    tid: {**tinfo, 'value': condition_tpm_values.get(tid, 0)}
                     for tid, tinfo in gene_info['transcripts'].items()
-                    if tid in valid_transcripts  # Filtering is still based on counts implicitly from valid_transcripts
+                    if tid in valid_transcripts
                 }
 
                 if new_transcripts:
@@ -152,7 +140,7 @@ def build_gene_dict_with_expression_and_filter(
             aggregated_exons = {}
             # Iterate over each transcript in the gene.
             for transcript_id, transcript_info in gene_info["transcripts"].items():
-                transcript_value = transcript_info.get("value", 0)  # Now this is TPM value
+                transcript_value = transcript_info.get("value", 0)  # TPM value
                 # Loop through each exon in the current transcript.
                 for exon in transcript_info.get("exons", []):
                     exon_id = exon.get("exon_id")
@@ -172,24 +160,22 @@ def build_gene_dict_with_expression_and_filter(
             # Now assign the aggregated exon dictionary to the gene.
             gene_info["exons"] = aggregated_exons
 
-        # Write exon expression table with proper Path handling
-        output_file = Path(self.config.output_directory) / "exon_expression_table.tsv"
-        self.write_exon_expression_table(filtered_dict, output_file)
-
         save_cache(
             expr_filter_cache, (filtered_dict, self.novel_gene_ids, self.novel_transcript_ids)
         )
         self.logger.info(f"Saved dictionary to cache at {expr_filter_cache}")
         return filtered_dict
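
The exon aggregation above sums the TPM of every transcript an exon appears in. A stripped-down sketch over the same dict shape ("value"/"exons"/"exon_id" fields as in the parser; numbers hypothetical):

transcripts = {
    "tx1": {"value": 10.0, "exons": [{"exon_id": "E1"}, {"exon_id": "E2"}]},
    "tx2": {"value": 5.0,  "exons": [{"exon_id": "E2"}]},
}

aggregated = {}
for info in transcripts.values():
    for exon in info.get("exons", []):
        eid = exon.get("exon_id")
        # Each exon accumulates the TPM of every transcript it appears in
        aggregated[eid] = aggregated.get(eid, 0.0) + info.get("value", 0.0)

print(aggregated)  # {'E1': 10.0, 'E2': 15.0}
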
-    def _get_expression_files(self) -> Tuple[str, str]:
-        """Get count file for filtering and TPM file for values."""
-        # Get counts file path using existing logic
-        counts_file = self._get_expression_file()
-
-        # Get corresponding TPM file path
-        if self.config.conditions:
-            tpm_file = self.config.transcript_grouped_tpm
+    def _get_tpm_file(self) -> str:
+        """Get the appropriate TPM file path from config."""
+        if self.config.conditions:  # Check if we have multiple conditions
+            # For multi-condition data, prioritize merged files if available
+            merged_tpm = self.config.transcript_grouped_tpm
+            if merged_tpm and "_merged.tsv" in merged_tpm:
+                self.logger.info("Using merged TPM file with transcript mapping already applied")
+                tpm_file = merged_tpm
+            else:
+                tpm_file = self.config.transcript_grouped_tpm
         else:
             if self.config.ref_only:
                 tpm_file = self.config.transcript_tpm_ref
@@ -201,175 +187,7 @@ def _get_tpm_file(self) -> str:
         if not tpm_file or not Path(tpm_file).exists():
             raise FileNotFoundError(f"TPM file {tpm_file} not found")
-        return counts_file, tpm_file
-
-    def _get_expression_file(self) -> str:
-        """Get the appropriate count file path from config."""
-        if self.config.conditions:  # Check if we have multiple conditions
-            expr_file = self.config.transcript_grouped_counts
-        else:
-            if self.config.ref_only:
-                expr_file = self.config.transcript_counts_ref
-            else:
-                base_file = self.config.transcript_counts.replace('.tsv', '')
-                expr_file = f"{base_file}_merged.tsv"
-
-        self.logger.debug(f"Selected count file: {expr_file}")
-        if not expr_file or not Path(expr_file).exists():
-            raise FileNotFoundError(f"Count file {expr_file} not found")
-        return expr_file
-
-    def write_exon_expression_table(self, gene_dict: Dict[str, Any], output_path: Path) -> None:
-        """
-        Write a table of exon expressions across conditions.
-        """
-        self.logger.info("Creating exon expression table")
-
-        # Ensure output directory exists.
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
-        # Get all conditions (keys in gene_dict).
-        conditions = list(gene_dict.keys())
-        self.logger.debug(f"Processing {len(conditions)} conditions: {conditions}")
-
-        # Prepare header.
-        header = [
-            "Gene Symbol", "Gene Name", "Gene Coordinates", "Ensembl ID",
-            "Exon number", "Chrom", "Exon start", "Exon end", "Strand"
-        ] + conditions
-
-        # Instead of looping condition by condition, create a union of gene IDs.
-        all_gene_ids = set()
-        for cond in conditions:
-            all_gene_ids.update(gene_dict[cond].keys())
-        self.logger.debug(f"Total unique genes to process: {len(all_gene_ids)}")
-
-        rows = []
-        gene_count = 0  # Initialize gene counter for logging
-        processed_exon_count = 0
-        sample_exon_ids = set()  # To keep track of sampled exons
-        num_sample_exons = 100
-
-        for gene_id in all_gene_ids:
-            gene_count += 1  # Increment gene counter
-            self.logger.debug(f"Processing gene {gene_count}/{len(all_gene_ids)}: {gene_id}")
-
-            # Get a representative gene_info (static data is the same across conditions).
-            rep_gene_info = None
-            for cond in conditions:
-                if gene_id in gene_dict[cond]:
-                    rep_gene_info = gene_dict[cond][gene_id]
-                    break
-            if rep_gene_info is None:
-                self.logger.warning(f"Gene {gene_id} not found in any condition, skipping.")
-                continue  # Should not happen, but skip if not found.
-
-            # Compute gene coordinates string.
-            gene_coords = f"{rep_gene_info['chromosome']}:{rep_gene_info['start']}-{rep_gene_info['end']}"
-
-            # Collect the union of exon IDs across all conditions for this gene.
-            all_exon_ids = set()
-            for cond in conditions:
-                if gene_id in gene_dict[cond]:
-                    all_exon_ids.update(gene_dict[cond][gene_id].get("exons", {}).keys())
-            self.logger.debug(f"  Gene {gene_id} - Total unique exons across conditions: {len(all_exon_ids)}")
-
-            exon_count = 0  # Initialize exon counter for logging
-            exon_ids_list = list(all_exon_ids)  # Convert to list for sampling
-            sampled_exons_for_gene = []
-
-            # Sample exons if we haven't reached the desired number yet
-            if processed_exon_count < num_sample_exons:
-                num_to_sample = min(num_sample_exons - processed_exon_count, len(exon_ids_list))
-                sampled_exons_for_gene = random.sample(exon_ids_list, num_to_sample)
-
-            # For each exon, gather condition-specific expression values.
-            for exon_id in exon_ids_list:  # Iterate through all exons, sample only for logging
-                exon_count += 1  # Increment exon counter
-                process_exon_for_log = False
-                if exon_id in sampled_exons_for_gene and processed_exon_count < num_sample_exons and exon_id not in sample_exon_ids:
-                    process_exon_for_log = True
-                    processed_exon_count += 1
-                    sample_exon_ids.add(exon_id)  # Mark as processed
-
-                if process_exon_for_log:
-                    self.logger.debug(f"  Gene {gene_id} - Processing exon {exon_count}/{len(all_exon_ids)} (SAMPLE): {exon_id}")
-                else:
-                    self.logger.debug(f"  Gene {gene_id} - Processing exon {exon_count}/{len(all_exon_ids)}: {exon_id}")
-
-                exon_expressions = []
-                aggregated_transcript_values = {}  # To store transcript values for logging
-
-                for cond in conditions:
-                    expr = 0.0
-                    # Lookup the gene in the current condition.
-                    gene_info = gene_dict[cond].get(gene_id, {})
-                    exon_info = gene_info.get("exons", {}).get(exon_id, {})
-                    expr = exon_info.get("value", 0.0)  # Get condition-specific exon value
-
-                    if process_exon_for_log:
-                        # Find transcripts contributing to this exon and log their values
-                        contributing_transcripts = []
-                        for transcript_id, transcript_info in gene_info.get("transcripts", {}).items():
-                            for exon_data in transcript_info.get("exons", []):
-                                if exon_data.get("exon_id") == exon_id:
-                                    transcript_value = transcript_info.get("value", 0.0)
-                                    contributing_transcripts.append((transcript_id, transcript_value))
-                                    aggregated_transcript_values[cond] = aggregated_transcript_values.get(cond, []) + [(transcript_id, transcript_value)]
-
-                        self.logger.debug(f"    Condition {cond} - Exon {exon_id} expression: {expr:.2f}")
-                        if contributing_transcripts:
-                            transcript_log_str = ", ".join([f"{tid}:{val:.2f}" for tid, val in contributing_transcripts])
-                            self.logger.debug(f"      Contributing transcripts (Condition {cond}): {transcript_log_str}")
-                        else:
-                            self.logger.debug(f"      No transcripts contributing to exon {exon_id} in condition {cond}")
-
-                    exon_expressions.append(f"{expr:.2f}")
-
-                if process_exon_for_log:
-                    # Log the aggregation process
-                    for cond in conditions:
-                        transcript_values_for_cond = aggregated_transcript_values.get(cond, [])
-                        if transcript_values_for_cond:
-                            sum_of_transcripts = sum([val for tid, val in transcript_values_for_cond])
-                            self.logger.debug(f"    Condition {cond} - Sum of contributing transcript TPMs for exon {exon_id}: {sum_of_transcripts:.2f} (Exon TPM: {exon_expressions[conditions.index(cond)]})")
-                        else:
-                            self.logger.debug(f"    Condition {cond} - No contributing transcripts found to sum for exon {exon_id}")
-
-                # Use the representative gene's exon info for static details.
-                rep_exon_info = rep_gene_info.get("exons", {}).get(exon_id, {})
-                exon_number = rep_exon_info.get("number", "NA")
-                exon_start = str(rep_exon_info.get("start", ""))
-                exon_end = str(rep_exon_info.get("end", ""))
-
-                row = [
-                    gene_id,  # Gene Symbol
-                    rep_gene_info.get("name", ""),  # Gene Name
-                    gene_coords,  # Gene Coordinates
-                    exon_id,  # Ensembl ID (exon_id)
-                    exon_number,  # Exon number
-                    rep_gene_info["chromosome"],  # Chromosome
-                    exon_start,  # Exon start
-                    exon_end,  # Exon end
-                    rep_gene_info["strand"],  # Strand
-                ] + exon_expressions  # Expression values for each condition
-
-                rows.append(row)
-                self.logger.debug(f"  Gene {gene_id} - Row for exon {exon_id} prepared.")
-                if process_exon_for_log:
-                    self.logger.debug(f"  Gene {gene_id} - Sampled exon {exon_id} processing complete.")
-
-        # Write header and rows to the output file.
-        self.logger.info(f"Writing {len(rows)} exon entries to table")
-        with open(output_path, 'w') as f:
-            f.write('\t'.join(header) + '\n')
-            for row in rows:
-                f.write('\t'.join(str(x) for x in row) + '\n')
-
-        self.logger.info(f"Exon expression table written to {output_path}")
-
+        return tpm_file
 
     # ------------------ READ ASSIGNMENT CACHING ------------------
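
The TPM loaders above rely on IsoQuant-style matrices whose header row starts with "#feature_id", hence the lstrip('#') cleanup before indexing. A self-contained sketch of that load path (inline stand-in for a real TSV file):

import io
import pandas as pd

raw = "#feature_id\tWildType\tMutants\ntx1\t12.0\t1.0\ntx2\t0.4\t0.2\n"  # stand-in for a TPM file

tpm_df = pd.read_csv(io.StringIO(raw), sep="\t")
tpm_df.columns = [c.lstrip("#") for c in tpm_df.columns]  # '#feature_id' -> 'feature_id'
tpm_df = tpm_df.set_index("feature_id")
print(tpm_df.max(axis=1))  # per-transcript max, as used for threshold filtering
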
@@ -471,12 +289,22 @@ def _process_read_assignment_file(self, file_path):
 
     # -------------------- GTF PARSING --------------------
 
-    def parse_input_gtf(self) -> Dict[str, Any]:
+    def parse_gtf(self) -> Dict[str, Any]:
         """
-        Parse the reference GTF file using gffutils with optimized settings,
-        building a dictionary of genes, transcripts, and exons.
-        Updated to match the structure of parse_extended_annotation.
+        Parse GTF file into a dictionary with genes, transcripts, and exons.
+        Handles both reference GTF (with gffutils) and extended annotation GTF.
         """
+        if self.config.ref_only:
+            # Use gffutils for reference GTF (more robust but slower)
+            self.logger.info("Parsing reference GTF using gffutils")
+            return self._parse_reference_gtf()
+        else:
+            # Use faster custom parser for extended annotation
+            self.logger.info("Parsing extended annotation GTF with custom parser")
+            return self._parse_extended_gtf()
+
+    def _parse_reference_gtf(self) -> Dict[str, Any]:
+        """Parse reference GTF using gffutils"""
         if not self.config.genedb_filename:
             db_path = self.cache_dir / "gtf.db"
             if not db_path.exists():
                 gffutils.create_db(
                     self.config.input_gtf,
                     dbfn=str(db_path),
                     force=True,
-                    merge_strategy="create_unique",  # Faster than merge
+                    merge_strategy="create_unique",
                     disable_infer_genes=True,
                     disable_infer_transcripts=True,
                     verbose=False,
@@ -559,13 +387,13 @@ def parse_input_gtf(self) -> Dict[str, Any]:
                 }
             )
 
-        self.logger.info(f"Processed {len(gene_dict)} genes from GTF")
+        self.logger.info(f"Processed {len(gene_dict)} genes from reference GTF")
         return gene_dict
-
-    def parse_extended_annotation(self) -> Dict[str, Any]:
-        """Parse merged GTF into base structure without condition info."""
+
+    def _parse_extended_gtf(self) -> Dict[str, Any]:
+        """Parse extended annotation GTF with custom parser"""
         base_gene_dict = {}
-        self.logger.info("Parsing extended annotation GTF (non-ref_only)")
+        self.logger.info("Parsing extended annotation GTF")
 
         try:
             with open(self.config.extended_annotation, "r") as file:
@@ -628,11 +456,27 @@ def parse_extended_annotation(self) -> Dict[str, Any]:
                     }
                     base_gene_dict[gene_id]["transcripts"][transcript_id]["exons"].append(exon_info)
 
-            self.logger.info(f"Parsed base structure: {len(base_gene_dict)} genes")
+            self.logger.info(f"Processed {len(base_gene_dict)} genes from extended annotation GTF")
             return base_gene_dict
         except Exception as e:
             self.logger.error(f"GTF parsing failed: {str(e)}")
             raise
+
+    # Keep the original functions for backward compatibility, but have them use the new implementation
+    def parse_input_gtf(self) -> Dict[str, Any]:
+        """
+        Parse the reference GTF file using gffutils.
+        This is now a wrapper around _parse_reference_gtf for backward compatibility.
+        """
+        return self._parse_reference_gtf()
+
+    def parse_extended_annotation(self) -> Dict[str, Any]:
+        """
+        Parse extended annotation GTF.
+        This is now a wrapper around _parse_extended_gtf for backward compatibility.
+ """ + return self._parse_extended_gtf() + # -------------------- UPDATES & UTILITIES -------------------- def update_gene_names(self, gene_dict: Dict[str, Any]) -> Dict[str, Any]: diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index 487551c7..f9a6853b 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -46,10 +46,56 @@ def quiet_cb(x): # Create a single logger for this class self.logger = logging.getLogger('IsoQuant.visualization.differential_exp') + # Get transcript mapping if available + self.transcript_map = {} + if hasattr(self.dictionary_builder, 'config') and hasattr(self.dictionary_builder.config, 'transcript_map'): + self.transcript_map = self.dictionary_builder.config.transcript_map + if self.transcript_map: + self.logger.info(f"Using transcript mapping from dictionary_builder with {len(self.transcript_map)} entries for DESeq2 analysis") + else: + # Try to load transcript mapping directly from file + self.logger.info("Transcript mapping from dictionary_builder is empty, trying to load it directly from file") + self._load_transcript_mapping_from_file() + else: + # Try to load transcript mapping directly from file + self.logger.info("No transcript mapping available from dictionary_builder, trying to load it directly from file") + self._load_transcript_mapping_from_file() + self.transcript_to_gene = self._create_transcript_to_gene_map() self.visualizer = ExpressionVisualizer(self.deseq_dir) self.gene_mapper = GeneMapper() + def _load_transcript_mapping_from_file(self): + """Load transcript mapping directly from the transcript_mapping.tsv file.""" + mapping_file = self.output_dir / "transcript_mapping.tsv" + + if not mapping_file.exists(): + self.logger.warning(f"Transcript mapping file not found at {mapping_file}") + return + + try: + # Load the transcript mapping file + self.logger.info(f"Loading transcript mapping from {mapping_file}") + self.transcript_map = {} + + # Skip header and read the mapping + with open(mapping_file, 'r') as f: + header = f.readline() # Skip header + for line in f: + parts = line.strip().split('\t') + if len(parts) == 2: + transcript_id, canonical_id = parts + self.transcript_map[transcript_id] = canonical_id + + self.logger.info(f"Successfully loaded {len(self.transcript_map)} transcript mappings from file") + + # Log some examples for debugging + sample_items = list(self.transcript_map.items())[:5] + for orig, canon in sample_items: + self.logger.debug(f"Mapping sample: {orig} → {canon}") + except Exception as e: + self.logger.error(f"Failed to load transcript mapping: {str(e)}") + def _create_transcript_to_gene_map(self) -> Dict[str, str]: """ Create a mapping from transcript IDs to gene names. 
@@ -237,32 +283,163 @@ def _run_level_analysis( # No normalized counts returned from _run_deseq2 anymore return outfile, pd.DataFrame() # Return empty DataFrame for normalized counts - def _get_condition_data(self, pattern: str) -> pd.DataFrame: - """Combine count data from all conditions.""" - all_counts = [] - - # Modify pattern if ref_only is False and pattern is transcript_grouped_counts - CORRECTED LOGIC - adjusted_pattern = pattern # Initialize adjusted_pattern to the original pattern - if not self.ref_only and pattern == "transcript_grouped_counts.tsv": # Only adjust if ref_only is False AND base pattern is transcript_grouped_counts + def _get_merged_transcript_counts(self, pattern: str) -> pd.DataFrame: + """ + Get transcript count data and apply transcript mapping to create a merged grouped dataframe. + This preserves the individual sample columns needed for DESeq2, but merges identical transcripts. + """ + self.logger.info(f"Creating merged transcript count matrix with pattern: {pattern}") + + # Adjust pattern if needed + adjusted_pattern = pattern + if not self.ref_only and pattern == "transcript_grouped_counts.tsv": adjusted_pattern = "transcript_model_grouped_counts.tsv" - + + self.logger.info(f"Using file pattern: {adjusted_pattern}") + + # Store sample dataframes + all_sample_dfs = [] + + # Process each condition directory for condition in self.ref_conditions + self.target_conditions: - condition_dir = self.output_dir / condition - count_files = list(condition_dir.glob(f"*{adjusted_pattern}")) # Use adjusted pattern - + condition_dir = Path(self.output_dir) / condition + count_files = list(condition_dir.glob(f"*{adjusted_pattern}")) + + if not count_files: + self.logger.error(f"No count files found for condition: {condition}") + raise FileNotFoundError(f"No count files matching {adjusted_pattern} found in {condition_dir}") + + # Load each count file for file_path in count_files: - self.logger.debug(f"Reading count data from: {file_path}") - df = pd.read_csv(file_path, sep="\t", dtype={"#feature_id": str}) + self.logger.info(f"Reading count data from: {file_path}") + + # Load the file + df = pd.read_csv(file_path, sep="\t") + if "#feature_id" not in df.columns and df.columns[0].startswith("#"): + # Rename first column if it's the feature ID column but named differently + df.rename(columns={df.columns[0]: "#feature_id"}, inplace=True) + + # Set feature_id as index df.set_index("#feature_id", inplace=True) - - # Rename columns to include condition + + # For multi-condition data (typical in sample files) + # We need to prefix each column with the condition name for col in df.columns: - if col.lower() != "count": - df = df.rename(columns={col: f"{condition}_{col}"}) - - all_counts.append(df) + df.rename(columns={col: f"{condition}_{col}"}, inplace=True) + + all_sample_dfs.append(df) + + # Concatenate all dataframes to get the full matrix + if not all_sample_dfs: + self.logger.error("No sample data frames found") + raise ValueError("No sample data found") + + # Combine all sample dataframes + combined_df = pd.concat(all_sample_dfs, axis=1) + self.logger.info(f"Combined count data shape before mapping: {combined_df.shape}") + + # Apply transcript mapping if available + if not hasattr(self, 'transcript_map') or not self.transcript_map: + self.logger.info("No transcript mapping available, using raw counts") + return combined_df + + # Log transcript mapping info + self.logger.info(f"Applying transcript mapping with {len(self.transcript_map)} mappings") + + # Get unique transcript IDs 
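The explicit grouping loop that follows is equivalent to a pandas groupby-and-sum keyed on canonical IDs. A compact sketch with invented counts:

    import pandas as pd

    combined_df = pd.DataFrame(
        {"condA_s1": [5, 3, 2], "condB_s1": [0, 7, 1]},
        index=["tx1", "tx2", "tx3"],
    )
    transcript_map = {"tx2": "tx1"}  # tx2 shares its exon chain with tx1

    # Map every transcript to its canonical ID, then sum counts per group.
    canonical = combined_df.index.to_series().map(lambda t: transcript_map.get(t, t))
    merged_df = combined_df.groupby(canonical).sum()
    # merged_df:
    #      condA_s1  condB_s1
    # tx1         8         7
    # tx3         2         1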
and create mapping dictionary + unique_transcripts = combined_df.index.unique() + transcript_groups = {} + + # Group transcripts by their canonical ID + for transcript_id in unique_transcripts: + canonical_id = self.transcript_map.get(transcript_id, transcript_id) + if canonical_id not in transcript_groups: + transcript_groups[canonical_id] = [] + transcript_groups[canonical_id].append(transcript_id) + + # Create the merged dataframe + merged_df = pd.DataFrame(index=list(transcript_groups.keys()), columns=combined_df.columns) + + # Track merge statistics + total_transcripts = len(unique_transcripts) + merged_groups = 0 + merged_transcripts = 0 + + # For each canonical transcript ID, sum the counts from all transcripts that map to it + for canonical_id, transcript_ids in transcript_groups.items(): + if len(transcript_ids) == 1: + # Just one transcript, copy the row directly + merged_df.loc[canonical_id] = combined_df.loc[transcript_ids[0]] + else: + # Multiple transcripts map to this canonical ID, sum their counts + merged_df.loc[canonical_id] = combined_df.loc[transcript_ids].sum() + merged_groups += 1 + merged_transcripts += len(transcript_ids) - 1 # Count transcripts beyond the first one + + # Log details of significant merges (more than 2 transcripts or interesting transcripts) + if len(transcript_ids) > 2 or any("ENST" in t for t in transcript_ids): + self.logger.info(f"Merged transcript group for {canonical_id}: {transcript_ids}") + + # Log merge statistics + self.logger.info(f"Transcript merging complete: {merged_groups} canonical IDs had multiple transcripts") + self.logger.info(f"Merged {merged_transcripts} transcripts into canonical IDs ({merged_transcripts/total_transcripts:.1%} of total)") + self.logger.info(f"Final merged count matrix shape: {merged_df.shape}") + + return merged_df - return pd.concat(all_counts, axis=1) + def _get_condition_data(self, pattern: str) -> pd.DataFrame: + """Get count data for differential expression analysis.""" + if pattern == "transcript_grouped_counts.tsv": + # For transcript data, use our merged function + return self._get_merged_transcript_counts(pattern) + elif pattern == "gene_grouped_counts.tsv": + # For gene data, use a simpler approach (no merging needed) + self.logger.info(f"Loading gene count data with pattern: {pattern}") + + # Store sample dataframes + all_sample_dfs = [] + + # Process each condition directory + for condition in self.ref_conditions + self.target_conditions: + condition_dir = Path(self.output_dir) / condition + count_files = list(condition_dir.glob(f"*{pattern}")) + + if not count_files: + self.logger.error(f"No gene count files found for condition: {condition}") + raise FileNotFoundError(f"No count files matching {pattern} found in {condition_dir}") + + # Load each count file + for file_path in count_files: + self.logger.info(f"Reading gene count data from: {file_path}") + + # Load the file + df = pd.read_csv(file_path, sep="\t") + if "#feature_id" not in df.columns and df.columns[0].startswith("#"): + # Rename first column if it's the feature ID column but named differently + df.rename(columns={df.columns[0]: "#feature_id"}, inplace=True) + + # Set feature_id as index + df.set_index("#feature_id", inplace=True) + + # For multi-condition data (typical in sample files) + # We need to prefix each column with the condition name + for col in df.columns: + df.rename(columns={col: f"{condition}_{col}"}, inplace=True) + + all_sample_dfs.append(df) + + # Concatenate all dataframes to get the full matrix + if not 
all_sample_dfs: + self.logger.error("No gene sample data frames found") + raise ValueError("No gene sample data found") + + # Combine all sample dataframes + combined_df = pd.concat(all_sample_dfs, axis=1) + self.logger.info(f"Combined gene count data shape: {combined_df.shape}") + return combined_df + else: + self.logger.error(f"Unsupported count pattern: {pattern}") + raise ValueError(f"Unsupported count pattern: {pattern}") def _filter_counts(self, count_data: pd.DataFrame, min_count: int = 10, level: str = "gene") -> pd.DataFrame: """ @@ -308,15 +485,54 @@ def _filter_counts(self, count_data: pd.DataFrame, min_count: int = 10, level: s return filtered_data def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame: - """Create experimental design matrix.""" + """Create experimental design matrix for DESeq2. + + Each column in the count data (sample) needs to be assigned to either + the reference or target group for differential expression analysis. + """ groups = [] + condition_assignments = [] + sample_ids = [] + + self.logger.info("Building experimental design matrix") + for sample in count_data.columns: - if any(sample.startswith(f"{cond}_") for cond in self.ref_conditions): + # Extract the condition from the sample name + # Matches pattern: conditionname_sampleid + # The column name should start with the condition name followed by an underscore + condition = None + for cond in self.ref_conditions + self.target_conditions: + if sample.startswith(f"{cond}_"): + condition = cond + # Extract the sample ID (everything after the condition name and underscore) + sample_id = sample[len(condition)+1:] + break + + if condition is None: + self.logger.error(f"Could not determine condition for sample: {sample}") + raise ValueError(f"Sample column '{sample}' does not match any specified condition") + + # Assign to reference or target group + if condition in self.ref_conditions: groups.append("Reference") else: groups.append("Target") - - return pd.DataFrame({"group": groups}, index=count_data.columns) + + # Store the condition and sample ID for additional information + condition_assignments.append(condition) + sample_ids.append(sample) + + # Create the design matrix DataFrame + design_matrix = pd.DataFrame({ + "group": groups, + "condition": condition_assignments, + "sample_id": sample_ids + }, index=count_data.columns) + + # Log the design matrix for debugging + self.logger.info(f"Design matrix:\n{design_matrix}") + + return design_matrix def _run_deseq2( self, count_data: pd.DataFrame, coldata: pd.DataFrame, level: str @@ -346,6 +562,8 @@ def _run_deseq2( def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Dict[str, Optional[str]]]: """ Map feature IDs to gene and transcript names using GeneMapper class. + + For transcripts that have been mapped to canonical IDs, ensure we properly handle the mapping. Args: feature_ids: List of feature IDs (gene symbols or transcript IDs) @@ -356,6 +574,49 @@ def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Dic containing 'transcript_symbol' and 'gene_name'. 'transcript_symbol' is None for gene-level analysis. 
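For reference, `_build_design_matrix` above produces a coldata frame of the following shape (condition and sample names invented); the `group` column drives the DESeq2 contrast, while `condition` and `sample_id` are kept for logging:

    import pandas as pd

    samples = ["ctrl_rep1", "ctrl_rep2", "treat_rep1", "treat_rep2"]
    ref_conditions, target_conditions = ["ctrl"], ["treat"]

    rows = []
    for sample in samples:
        condition = next(
            c for c in ref_conditions + target_conditions
            if sample.startswith(f"{c}_")
        )
        rows.append({
            "group": "Reference" if condition in ref_conditions else "Target",
            "condition": condition,
            "sample_id": sample[len(condition) + 1:],
        })
    coldata = pd.DataFrame(rows, index=samples)
    # coldata:
    #                 group condition sample_id
    # ctrl_rep1   Reference      ctrl      rep1
    # ctrl_rep2   Reference      ctrl      rep2
    # treat_rep1     Target     treat      rep1
    # treat_rep2     Target     treat      rep2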
""" + # Check if we need to handle canonical transcript IDs + if level == "transcript" and self.transcript_map: + # Create a mapping from canonical IDs to original IDs for reverse lookup + canonical_to_original = {} + for original, canonical in self.transcript_map.items(): + if canonical not in canonical_to_original: + canonical_to_original[canonical] = [] + canonical_to_original[canonical].append(original) + + # Process feature_ids that may include canonical IDs + result = {} + for feature_id in feature_ids: + # First try to map directly + direct_map = self.gene_mapper.map_gene_symbols([feature_id], level, self.updated_gene_dict) + + # If direct mapping worked, use it + if feature_id in direct_map and direct_map[feature_id]["gene_name"]: + result[feature_id] = direct_map[feature_id] + continue + + # If this is a canonical ID, try to map using one of its original IDs + if feature_id in canonical_to_original: + for original_id in canonical_to_original[feature_id]: + original_map = self.gene_mapper.map_gene_symbols([original_id], level, self.updated_gene_dict) + if original_id in original_map and original_map[original_id]["gene_name"]: + # Use the original ID's mapping but keep the canonical ID as the transcript symbol + result[feature_id] = { + "transcript_symbol": feature_id, + "gene_name": original_map[original_id]["gene_name"] + } + self.logger.debug(f"Mapped canonical ID {feature_id} using original ID {original_id}") + break + + # If still not mapped, use a default mapping + if feature_id not in result: + result[feature_id] = { + "transcript_symbol": feature_id, + "gene_name": feature_id.split('.')[0] if '.' in feature_id else feature_id + } + + return result + + # For gene level or when no transcript mapping is available, use the original method return self.gene_mapper.map_gene_symbols(feature_ids, level, self.updated_gene_dict) def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: @@ -363,7 +624,8 @@ def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: results["abs_stat"] = abs(results["stat"]) if level == "transcript": - top_transcripts = results.nlargest(len(results), "abs_stat") # Get ALL transcripts ranked by abs_stat + # where baseMean is greater than 500 + top_transcripts = results[results["baseMean"] > 500].nlargest(len(results), "abs_stat") unique_genes = set() top_unique_gene_transcripts = [] @@ -376,7 +638,7 @@ def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: unique_genes.add(gene_name) top_unique_gene_transcripts.append(transcript_row) unique_gene_count += 1 - if unique_gene_count >= 100: # Stop when we reach 100 unique genes + if unique_gene_count >= 100: # Stop when we reach 500 unique genes break transcript_count += 1 # Keep track of total transcripts considered @@ -431,23 +693,76 @@ def _median_ratio_normalization(self, count_data: pd.DataFrame) -> pd.DataFrame: """ Perform median-by-ratio normalization on count data. This is similar to the normalization used in DESeq2. + Handles zeros and potential data type issues safely. """ - # 1. Calculate geometric mean for each feature (row) - geometric_means = count_data.apply(gmean, axis=1) - - # 2. Handle rows with zero geometric mean (replace with NaN to avoid division by zero) - geometric_means[geometric_means == 0] = np.nan - - # 3. Calculate ratio of each count to the geometric mean - count_ratios = count_data.divide(geometric_means, axis=0) - - # 4. 
Calculate size factor for each sample (column) as the median of ratios - size_factors = count_ratios.median(axis=0) - - # 5. Normalize counts by dividing by size factors - normalized_counts = count_data.divide(size_factors, axis=1) - - # 6. Fill NaN values with 0 after normalization - normalized_counts = normalized_counts.fillna(0) - - return normalized_counts \ No newline at end of file + try: + # Convert to numeric and handle any non-numeric values + count_data_numeric = count_data.apply(pd.to_numeric, errors='coerce').fillna(0) + + # Ensure all values are positive or zero + count_data_nonneg = count_data_numeric.clip(lower=0) + + # Add pseudocount to avoid zeros (1 is a common choice) + count_data_safe = count_data_nonneg + 1 + + # Check data types and values + self.logger.debug(f"Count data shape: {count_data_safe.shape}") + self.logger.debug(f"Count data dtype: {count_data_safe.values.dtype}") + self.logger.debug(f"Min value: {count_data_safe.values.min()}, Max value: {count_data_safe.values.max()}") + + # Convert to numpy array + counts_numpy = count_data_safe.values.astype(float) + + # Alternative geometric mean calculation + # Use log1p which is log(1+x) to handle zeros more safely + log_counts = np.log(counts_numpy) + row_means = np.mean(log_counts, axis=1) + geometric_means = np.exp(row_means) + + # Check for any invalid values in geometric means + if np.any(~np.isfinite(geometric_means)): + self.logger.warning("Found non-finite values in geometric means, replacing with 1.0") + geometric_means[~np.isfinite(geometric_means)] = 1.0 + + # Calculate ratio of each count to the geometric mean + # Reshape geometric_means for broadcasting + geo_means_col = geometric_means.reshape(-1, 1) + ratios = counts_numpy / geo_means_col + + # Calculate size factor for each sample (median of ratios) + size_factors = np.median(ratios, axis=0) + + # Check for any invalid values in size factors + if np.any(size_factors <= 0) or np.any(~np.isfinite(size_factors)): + self.logger.warning("Found invalid size factors, replacing with 1.0") + size_factors[~np.isfinite(size_factors) | (size_factors <= 0)] = 1.0 + + # Log size factors + self.logger.info(f"Size factors: {size_factors}") + + # Normalize counts by dividing by size factors + # Use original count data (without pseudocount) for the final normalization + normalized_counts = pd.DataFrame( + count_data_nonneg.values / size_factors, + index=count_data.index, + columns=count_data.columns + ) + + # Fill any NaN values with 0 + normalized_counts = normalized_counts.fillna(0) + + return normalized_counts + + except Exception as e: + self.logger.error(f"Error in median ratio normalization: {str(e)}") + self.logger.error("Falling back to simple TPM-like normalization") + + # Fallback normalization (similar to TPM) + # Sum each column and divide counts by the sum + col_sums = count_data.sum(axis=0) + col_sums = col_sums.replace(0, 1) # Avoid division by zero + + # Normalize each column by its sum and multiply by 1e6 (similar to TPM scaling) + normalized_counts = count_data.div(col_sums, axis=1) * 1e6 + + return normalized_counts \ No newline at end of file diff --git a/src/visualization_output_config.py b/src/visualization_output_config.py index 0c79171d..28fc0930 100644 --- a/src/visualization_output_config.py +++ b/src/visualization_output_config.py @@ -6,8 +6,9 @@ from argparse import Namespace import gffutils import yaml -from typing import List +from typing import List, Dict, Tuple, Set import logging +import re class OutputConfig: """Class to build 
dictionaries from the output files of the pipeline.""" @@ -15,7 +16,6 @@ class OutputConfig: def __init__( self, output_directory: str, - use_counts: bool = False, ref_only: bool = False, gtf: str = None, ): @@ -41,7 +41,6 @@ def __init__( self.transcript_model_tpm = None self.transcript_model_grouped_tpm = None self.transcript_model_grouped_counts = None - self.use_counts = use_counts self.ref_only = ref_only # New attributes for handling extended annotations @@ -52,6 +51,9 @@ def __init__( self.samples = [] self.sample_transcript_model_tpm = {} self.sample_transcript_model_counts = {} + + # New attribute for transcript mapping + self.transcript_map = {} # Maps transcript IDs to canonical transcript ID with same exon structure self._load_params_file() self._find_files() @@ -406,14 +408,293 @@ def _handle_extended_annotations(self, samples_count): "No extended_annotation.gtf files found. Continuing without merge." ) + def merge_gtfs(self, gtfs, output_gtf): + """Merge multiple GTF files into a single GTF file, identifying transcripts with identical exon structures.""" + try: + # First, parse all GTFs to identify transcripts with identical exon structures + print(f"Analyzing {len(gtfs)} GTF files to identify identical transcript structures") + logging.info(f"Starting GTF merging process for {len(gtfs)} files") + + transcript_exon_signatures = {} # {exon_signature: [(sample, transcript_id), ...]} + transcript_info = {} # {transcript_id: {gene_id, sample, lines, exon_signature}} + + # Pass 1: Extract exon signatures for all transcripts across all GTFs + total_transcripts = 0 + for gtf_file in gtfs: + sample_name = os.path.basename(os.path.dirname(gtf_file)) + logging.info(f"Processing GTF file for sample {sample_name}: {gtf_file}") + sample_transcripts = self._extract_transcript_exon_signatures(gtf_file, sample_name, transcript_exon_signatures, transcript_info) + total_transcripts += sample_transcripts + logging.info(f"Extracted {sample_transcripts} transcripts from sample {sample_name}") + + logging.info(f"Total transcripts processed: {total_transcripts}") + logging.info(f"Found {len(transcript_exon_signatures)} unique exon signatures across all samples") + + # Create transcript mapping based on exon signatures + self.transcript_map = self._create_transcript_mapping(transcript_exon_signatures, transcript_info) + logging.info(f"Created mapping for {len(self.transcript_map)} transcripts to {len(set(self.transcript_map.values()))} canonical transcripts") + + # Write the transcript mapping to a file + mapping_file = os.path.join(os.path.dirname(output_gtf), "transcript_mapping.tsv") + self._write_transcript_mapping(mapping_file) + logging.info(f"Wrote transcript mapping to {mapping_file}") + + # Pass 2: Write the merged GTF with canonical transcript IDs + logging.info(f"Writing merged GTF file to {output_gtf}") + self._write_merged_gtf(gtfs, output_gtf) + + print(f"Successfully merged {len(gtfs)} GTF files into {output_gtf}") + print(f"Identified {len(self.transcript_map)} transcripts with identical structures across samples") + logging.info(f"GTF merging complete. 
Output file: {output_gtf}") + + except Exception as e: + logging.error(f"Failed to merge GTF files: {str(e)}") + raise Exception(f"Failed to merge GTF files: {e}") + + def _extract_transcript_exon_signatures(self, gtf_file, sample_name, transcript_exon_signatures, transcript_info): + """Extract exon signatures for all transcripts in a GTF file.""" + current_transcript = None + current_gene = None + current_chromosome = None + current_strand = None + current_exons = [] + current_lines = [] + + transcripts_processed = 0 + reference_transcripts = 0 + novel_transcripts = 0 + single_exon_transcripts = 0 + multi_exon_transcripts = 0 + + logging.debug(f"Starting exon signature extraction for file: {gtf_file}") + + with open(gtf_file, 'r') as f: + for line in f: + if line.startswith('#'): + continue + + fields = line.strip().split('\t') + if len(fields) < 9: + continue + + feature_type = fields[2] + chromosome = fields[0] + strand = fields[6] + attrs_str = fields[8] + + # Extract attributes + attr_pattern = re.compile(r'(\S+) "([^"]+)";') + attrs = dict(attr_pattern.findall(attrs_str)) + + transcript_id = attrs.get('transcript_id') + gene_id = attrs.get('gene_id') + + if feature_type == 'transcript': + # Process previous transcript if exists + if current_transcript and current_exons: + if current_chromosome and current_strand: + transcripts_processed += 1 + + # Count transcript types + if current_transcript.startswith('ENST'): + reference_transcripts += 1 + else: + novel_transcripts += 1 + + # Count by exon count + if len(current_exons) == 1: + single_exon_transcripts += 1 + else: + multi_exon_transcripts += 1 + + exon_signature = self._create_exon_signature(current_exons, current_chromosome, current_strand) + + signature_key = (exon_signature, current_chromosome, current_strand) + if signature_key not in transcript_exon_signatures: + transcript_exon_signatures[signature_key] = [] + transcript_exon_signatures[signature_key].append((sample_name, current_transcript)) + + transcript_info[current_transcript] = { + 'gene_id': current_gene, + 'sample': sample_name, + 'chromosome': current_chromosome, + 'strand': current_strand, + 'exon_count': len(current_exons), + 'lines': current_lines, + 'exon_signature': exon_signature + } + + # Start new transcript + current_transcript = transcript_id + current_gene = gene_id + current_chromosome = chromosome + current_strand = strand + current_exons = [] + current_lines = [line] + + elif feature_type == 'exon' and transcript_id == current_transcript: + # Add exon to current transcript + current_lines.append(line) + exon_start = int(fields[3]) + exon_end = int(fields[4]) + current_exons.append((exon_start, exon_end)) + + # Process the last transcript + if current_transcript and current_exons and current_chromosome and current_strand: + transcripts_processed += 1 + + # Count transcript types for the last one + if current_transcript.startswith('ENST'): + reference_transcripts += 1 + else: + novel_transcripts += 1 + + # Count by exon count for the last one + if len(current_exons) == 1: + single_exon_transcripts += 1 + else: + multi_exon_transcripts += 1 + + exon_signature = self._create_exon_signature(current_exons, current_chromosome, current_strand) + + signature_key = (exon_signature, current_chromosome, current_strand) + if signature_key not in transcript_exon_signatures: + transcript_exon_signatures[signature_key] = [] + transcript_exon_signatures[signature_key].append((sample_name, current_transcript)) + + transcript_info[current_transcript] = { + 'gene_id': 
current_gene, + 'sample': sample_name, + 'chromosome': current_chromosome, + 'strand': current_strand, + 'exon_count': len(current_exons), + 'lines': current_lines, + 'exon_signature': exon_signature + } + + # Log summary for this GTF file + logging.info(f"Sample {sample_name} - Transcripts processed: {transcripts_processed}") + logging.info(f"Sample {sample_name} - Reference transcripts: {reference_transcripts}, Novel transcripts: {novel_transcripts}") + logging.info(f"Sample {sample_name} - Single-exon: {single_exon_transcripts}, Multi-exon: {multi_exon_transcripts}") + + return transcripts_processed + + def _create_exon_signature(self, exons, chromosome=None, strand=None): + """Create a unique signature for a set of exons based on their coordinates.""" + # Sort exons by start position + sorted_exons = sorted(exons) + # Create a string signature + return ';'.join([f"{start}-{end}" for start, end in sorted_exons]) + + def _create_transcript_mapping(self, transcript_exon_signatures, transcript_info): + """Create a mapping of transcripts with identical exon structures.""" + transcript_map = {} + + # Stats for logging + total_signature_groups = 0 + skipped_single_transcript_groups = 0 + skipped_groups = 0 + + logging.info("Starting transcript mapping creation") + + # For each exon signature, find all transcripts with that signature + for signature_key, transcripts in transcript_exon_signatures.items(): + exon_signature, chromosome, strand = signature_key + total_signature_groups += 1 + + # Skip signatures with only one transcript + if len(transcripts) <= 1: + skipped_single_transcript_groups += 1 + continue + + # Group transcripts using filtering logic based on transcript ID prefix + valid_transcripts = [] + + for sample, transcript_id in transcripts: + # Apply filtering logic for transcript selection + if not transcript_id.startswith('ENST'): + valid_transcripts.append((sample, transcript_id)) + + # Skip if not enough valid transcripts + if len(valid_transcripts) <= 1: + skipped_groups += 1 + continue + + # Choose a canonical transcript ID for this structure + canonical_transcript = valid_transcripts[0][1] + + # Map all transcripts to the canonical one (except the canonical itself) + for sample, transcript_id in valid_transcripts: + if transcript_id != canonical_transcript: + transcript_map[transcript_id] = canonical_transcript + + # Logging summary stats + logging.info(f"Total exon signature groups: {total_signature_groups}") + logging.info(f"Skipped single-transcript groups: {skipped_single_transcript_groups}") + logging.info(f"Skipped groups with insufficient valid transcripts: {skipped_groups}") + logging.info(f"Final transcript mapping count: {len(transcript_map)}") + + return transcript_map + + def _write_transcript_mapping(self, output_file): + """Write the transcript mapping to a TSV file.""" + with open(output_file, 'w') as f: + f.write("transcript_id\tcanonical_transcript_id\n") + for transcript_id, canonical_id in self.transcript_map.items(): + f.write(f"{transcript_id}\t{canonical_id}\n") + + print(f"Transcript mapping written to {output_file}") + + def _write_merged_gtf(self, gtfs, output_gtf): + """Write the merged GTF with canonical transcript IDs.""" + with open(output_gtf, 'w') as outfile: + for gtf in gtfs: + with open(gtf, 'r') as infile: + for line in infile: + if line.startswith('#'): + outfile.write(line) + continue + + fields = line.strip().split('\t') + if len(fields) < 9: + outfile.write(line) + continue + + # Extract attributes + attr_pattern = re.compile(r'(\S+) 
"([^"]+)";') + attrs_str = fields[8] + attrs = dict(attr_pattern.findall(attrs_str)) + + transcript_id = attrs.get('transcript_id') + + # Apply transcript mapping selectively based on internal logic + if transcript_id and not transcript_id.startswith('ENST') and transcript_id in self.transcript_map: + canonical_id = self.transcript_map[transcript_id] + + # Update the attribute string + new_attrs_str = attrs_str.replace( + f'transcript_id "{transcript_id}"', + f'transcript_id "{canonical_id}"; original_transcript_id "{transcript_id}"' + ) + fields[8] = new_attrs_str + outfile.write('\t'.join(fields) + '\n') + else: + outfile.write(line) + def _merge_transcript_files(self, sample_files_dict, output_file, metric_type): # sample_files_dict: {sample_name: filepath or None} # Merge logic: # 1. Gather all transcripts from all samples # 2. For each transcript, write a line with transcript_id and values from each sample (0 if missing) + # 3. Apply transcript mapping to merge identical transcripts transcripts = {} samples = self.samples - + + logging.info(f"Creating merged {metric_type} file with transcript mapping applied") + + # First, read all transcripts and their values + all_transcript_data = {} + # Read each sample file for sample_name, file_path in sample_files_dict.items(): if file_path and os.path.exists(file_path): @@ -429,16 +710,34 @@ def _merge_transcript_files(self, sample_files_dict, output_file, metric_type): value = float(value_str) except ValueError: value = 0.0 - if transcript_id not in transcripts: - transcripts[transcript_id] = {} - transcripts[transcript_id][sample_name] = value + + # Apply transcript mapping (silently skips certain transcripts without mentioning why) + if not transcript_id.startswith('ENST'): + canonical_id = self.transcript_map.get(transcript_id, transcript_id) + else: + canonical_id = transcript_id + + if canonical_id not in all_transcript_data: + all_transcript_data[canonical_id] = {} + + # If this sample already has a value for this canonical transcript, add to it + if sample_name in all_transcript_data[canonical_id]: + all_transcript_data[canonical_id][sample_name] += value + else: + all_transcript_data[canonical_id][sample_name] = value else: # Sample missing file, will assign 0 later pass + + # Now consolidate the merged data into the final transcripts dictionary + for canonical_id, sample_values in all_transcript_data.items(): + transcripts[canonical_id] = {} + for sample_name in samples: + transcripts[canonical_id][sample_name] = sample_values.get(sample_name, 0) # Write merged file - with open(output_file, "w", newline="") as out_f: - writer = csv.writer(out_f, delimiter="\t") + with open(output_file, 'w', newline='') as out_f: + writer = csv.writer(out_f, delimiter='\t') header = ["#feature_id"] + samples writer.writerow(header) for transcript_id in sorted(transcripts.keys()): @@ -446,17 +745,9 @@ def _merge_transcript_files(self, sample_files_dict, output_file, metric_type): for sample_name in samples: row.append(transcripts[transcript_id].get(sample_name, 0)) writer.writerow(row) - - def merge_gtfs(self, gtfs, output_gtf): - """Merge multiple GTF files into a single GTF file.""" - try: - with open(output_gtf, "w") as outfile: - for gtf in gtfs: - with open(gtf, "r") as infile: - shutil.copyfileobj(infile, outfile) - print(f"Successfully merged {len(gtfs)} GTF files into {output_gtf}") - except Exception as e: - raise Exception(f"Failed to merge GTF files: {e}") + + logging.info(f"Merged {metric_type} file written to {output_file}") + 
logging.info(f"Included {len(transcripts)} transcripts in the merged file") def _get_conditions_from_file(self, file_path: str) -> List[str]: """Extract conditions from file header.""" diff --git a/visualize.py b/visualize.py index f4b04fd8..52d5b31e 100755 --- a/visualize.py +++ b/visualize.py @@ -26,7 +26,7 @@ def setup_logging(viz_output_dir: Path) -> None: # Console handler - less detailed console_handler = logging.StreamHandler() - console_handler.setLevel(logging.INFO) # Console output at INFO level + console_handler.setLevel(logging.DEBUG) # Console output at DEBUG level console_handler.setFormatter(console_formatter) # Configure root logger @@ -79,9 +79,6 @@ def parse_arguments(): help="Optional path to a GTF file if unable to be extracted from IsoQuant log", default=None, ) - parser.add_argument( - "--counts", action="store_true", help="Use counts instead of TPM files." - ) parser.add_argument( "--ref_only", action="store_true", @@ -117,22 +114,13 @@ def parse_arguments(): if args.find_genes is not None: output = OutputConfig( args.output_directory, - use_counts=args.counts, ref_only=args.ref_only, gtf=args.gtf, ) if output.conditions: - gene_file = ( - output.transcript_grouped_tpm - if not output.use_counts - else output.transcript_grouped_counts - ) + gene_file = output.transcript_grouped_tpm else: - gene_file = ( - output.transcript_tpm - if not output.use_counts - else output.transcript_counts - ) + gene_file = output.transcript_tpm if not gene_file or not Path(gene_file).is_file(): print(f"Error: Grouped TPM/Counts file not found at {gene_file}.") @@ -196,13 +184,29 @@ def get_selection(prompt, max_selection, exclude=[]): def main(): - # Parse args first without logging - args = parse_arguments() - - # Set up visualization directory and logging - viz_output_dir = setup_viz_output(args.output_directory, args.viz_output) + # First, parse just the output directory argument to set up logging + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("output_directory", type=str, nargs='?') + parser.add_argument("--viz_output", type=str, default=None) + + # Parse just these arguments first + first_args, _ = parser.parse_known_args() + + # Initialize output directory early + if not first_args.output_directory: + print("Error: Output directory is required.") + sys.exit(1) + + # Set up visualization directory early + viz_output_dir = setup_viz_output(first_args.output_directory, first_args.viz_output) + + # Set up logging immediately to capture all operations setup_logging(viz_output_dir) - + logging.info("Starting IsoQuant visualization pipeline") + + # Now parse the full arguments with the real parser + args = parse_arguments() + # If find_genes is specified, get conditions interactively if args.find_genes is not None: select_conditions_interactively(args) @@ -210,13 +214,13 @@ def main(): logging.info("Reading IsoQuant parameters.") output = OutputConfig( args.output_directory, - use_counts=args.counts, ref_only=args.ref_only, gtf=args.gtf, ) dictionary_builder = DictionaryBuilder(output) - # logging.debug("OutputConfig details:") - # logging.debug(vars(output)) + logging.debug("OutputConfig details:") + logging.debug(f"Output directory: {output.output_directory}") + logging.debug(f"Reference only: {output.ref_only}") # Ask user about read assignments (optional) use_read_assignments = ( @@ -311,7 +315,9 @@ def main(): reads_and_class=reads_and_class, filter_transcripts=min_val, # Just pass your chosen threshold for reference conditions=output.conditions, - 
use_counts=args.counts, + ref_only=args.ref_only, + ref_conditions=args.reference_conditions if hasattr(args, "reference_conditions") else None, + target_conditions=args.target_conditions if hasattr(args, "target_conditions") else None, ) plot_output.plot_transcript_map() From b9f38583488cc026d5fb24c87b1ddfc8594e612c Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Tue, 1 Apr 2025 02:47:16 -0500 Subject: [PATCH 30/35] Working with new dict --- src/visualization_dictionary_builder.py | 206 ++++--- src/visualization_differential_exp.py | 741 ++++++++++++++---------- src/visualization_mapping.py | 1 - src/visualization_plotter.py | 641 +++++++------------- visualize.py | 19 +- 5 files changed, 798 insertions(+), 810 deletions(-) diff --git a/src/visualization_dictionary_builder.py b/src/visualization_dictionary_builder.py index 9c61a285..9934338b 100644 --- a/src/visualization_dictionary_builder.py +++ b/src/visualization_dictionary_builder.py @@ -35,136 +35,166 @@ def __init__(self, config): cleanup_cache(self.cache_dir, max_age_days=7) def build_gene_dict_with_expression_and_filter( - self, min_value: float = 1.0 + self, + min_value: float = 1.0, + reference_conditions: List[str] = None, + target_conditions: List[str] = None, ) -> Dict[str, Any]: """ - Optimized build process with early filtering and combined steps. + Optimized build process with filtering based on selected conditions. + Filters transcripts based on min_value occurring in at least one of the + selected reference_conditions or target_conditions. + Caches the resulting dictionary based on the specific conditions used. """ - self.logger.debug(f"Starting optimized dictionary build with min_value={min_value}") + self.logger.debug(f"Starting dictionary build: min_value={min_value}, ref={reference_conditions}, target={target_conditions}") - # 1. Check cache first + # 1. Load full TPM matrix to determine available conditions first tpm_file = self._get_tpm_file() - base_cache_file = build_gene_dict_cache_file( - self.config.extended_annotation, - tpm_file, - self.config.ref_only, - self.cache_dir, - ) - expr_filter_cache = base_cache_file.parent / ( - f"{base_cache_file.stem}_with_expr_minval_{min_value}.pkl" - ) - - if expr_filter_cache.exists(): - cached_data = load_cache(expr_filter_cache) - if cached_data and len(cached_data) == 3: # Check if load_cache returned a tuple - cached_gene_dict, cached_novel_gene_ids, cached_novel_transcript_ids = cached_data # Unpack tuple - if validate_gene_dict(cached_gene_dict): - self.novel_gene_ids = cached_novel_gene_ids # Restore from cache - self.novel_transcript_ids = cached_novel_transcript_ids # Restore from cache - return cached_gene_dict - else: # Handle older cache format (just gene_dict) - cached_gene_dict = cached_data - if validate_gene_dict(cached_gene_dict): - return cached_gene_dict - - # 2. 
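`save_cache` and `load_cache` are project helpers that do not appear in this patch; a minimal pickle-based stand-in is sketched here only to make concrete the `(gene_dict, novel_gene_ids, novel_transcript_ids)` tuple convention the format check above relies on.

    import pickle
    from pathlib import Path

    def save_cache(path: Path, payload) -> None:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as fh:
            pickle.dump(payload, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def load_cache(path: Path):
        try:
            with open(path, "rb") as fh:
                return pickle.load(fh)
        except (OSError, pickle.UnpicklingError):
            return None  # unreadable caches are treated as a miss

    cache_file = Path("cache/gene_dict_demo.pkl")
    save_cache(cache_file, ({"GENE1": {}}, set(), {"novel_tx_1"}))
    gene_dict, novel_genes, novel_transcripts = load_cache(cache_file)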
Filter novel genes from the base gene dict (not per-condition) - self.logger.info("Parsing GTF and filtering novel genes") - parsed_data = self.parse_gtf() - self._validate_gene_structure(parsed_data) - base_gene_dict = self._filter_novel_genes(parsed_data) - - # Add debug log: Number of genes and transcripts after novel gene filtering - gene_count_after_novel_filter = len(base_gene_dict) - transcript_count_after_novel_filter = sum( - len(gene_info.get("transcripts", {})) for gene_info in base_gene_dict.values() - ) - self.logger.debug( - f"After novel gene filtering: {gene_count_after_novel_filter} genes, {transcript_count_after_novel_filter} transcripts" - ) - - if hasattr(self.config, 'transcript_map') and self.config.transcript_map: - self.logger.info(f"Using transcript mapping from OutputConfig with {len(self.config.transcript_map)} entries") - else: - self.logger.info("No transcript mapping found, proceeding with original transcripts") - - # 3. Load expression data (TPM only) with consistent header handling - self.logger.info("Loading TPM matrix for filtering and expression values") + self.logger.debug(f"Loading full TPM matrix from {tpm_file}") try: tpm_df = pd.read_csv(tpm_file, sep='\t', comment=None) tpm_df.columns = [col.lstrip('#') for col in tpm_df.columns] # Clean headers tpm_df = tpm_df.set_index('feature_id') # Use cleaned column name except KeyError as e: - self.logger.error(f"Missing required column in {tpm_file}: {str(e)}") + self.logger.error(f"Missing required column ('feature_id' or condition name) in {tpm_file}: {str(e)}") raise except Exception as e: self.logger.error(f"Failed to load TPM expression matrix: {str(e)}") raise - conditions = tpm_df.columns.tolist() + available_conditions = sorted(tpm_df.columns.tolist()) # Sort for consistent cache key + self.logger.debug(f"Available conditions in TPM file: {available_conditions}") + + # 2. Determine the actual conditions to process and create a cache key + requested_conditions = set(reference_conditions or []) | set(target_conditions or []) + if requested_conditions: + conditions_to_process = sorted(list(requested_conditions.intersection(available_conditions))) + missing_conditions = requested_conditions.difference(available_conditions) + if missing_conditions: + self.logger.warning(f"Requested conditions not found in TPM file and will be ignored: {missing_conditions}") + if not conditions_to_process: + self.logger.error("None of the requested conditions were found in the TPM file. Cannot proceed.") + return {} + self.logger.info(f"Processing conditions: {conditions_to_process}") + else: + self.logger.info("No specific conditions requested, processing all available conditions.") + conditions_to_process = available_conditions # Already sorted - # 4. Vectorized processing using TPM for both filtering and values - transcript_max_values = tpm_df.max(axis=1) - valid_transcripts = set( - transcript_max_values[transcript_max_values >= min_value].index + # Create a deterministic cache key based on conditions + condition_key_part = "_".join(c.replace(" ", "_") for c in conditions_to_process) + if len(condition_key_part) > 50: # Avoid excessively long filenames + condition_key_part = f"hash_{hash(condition_key_part)}" + + # 3. 
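One caveat on the cache key built just above: Python's built-in `hash()` for strings is salted per process (PYTHONHASHSEED) since Python 3.3, so `f"hash_{hash(...)}"` names a different cache file on every run, and long condition sets would never hit the cache. A run-stable digest via the stdlib hashlib avoids this; a suggested drop-in sketch:

    import hashlib

    def stable_condition_key(condition_key_part: str, max_len: int = 50) -> str:
        """Collapse long condition strings into a token stable across runs."""
        if len(condition_key_part) <= max_len:
            return condition_key_part
        digest = hashlib.md5(condition_key_part.encode("utf-8")).hexdigest()[:12]
        return f"hash_{digest}"

    # Same input yields the same cache filename in every Python process.
    assert stable_condition_key("c1_" * 30) == stable_condition_key("c1_" * 30)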
Check cache specific to these conditions and min_value + base_cache_file = build_gene_dict_cache_file( # Keep base name generation consistent + self.config.extended_annotation, + tpm_file, # Use original TPM file path for base name consistency + self.config.ref_only, + self.cache_dir, + ) + # Append condition and min_value specifics + condition_specific_cache_file = base_cache_file.parent / ( + f"{base_cache_file.stem}_conditions_{condition_key_part}_minval_{min_value}.pkl" ) + self.logger.debug(f"Looking for cache file: {condition_specific_cache_file}") + + if condition_specific_cache_file.exists(): + self.logger.info(f"Loading data from cache: {condition_specific_cache_file}") + cached_data = load_cache(condition_specific_cache_file) + # Expecting (dict, novel_genes_set, novel_transcripts_set) + if cached_data and isinstance(cached_data, tuple) and len(cached_data) == 3: + cached_gene_dict, cached_novel_gene_ids, cached_novel_transcript_ids = cached_data + # Basic validation - check if it's a dict and has expected top-level keys (conditions) + if isinstance(cached_gene_dict, dict) and all(c in cached_gene_dict for c in conditions_to_process): + # Deeper validation might be needed if structure is complex + if validate_gene_dict(cached_gene_dict): # Reuse existing validation if suitable + self.novel_gene_ids = cached_novel_gene_ids + self.novel_transcript_ids = cached_novel_transcript_ids + self.logger.info("Successfully loaded dictionary from cache.") + return cached_gene_dict + else: + self.logger.warning("Cached dictionary failed validation. Rebuilding.") + else: + self.logger.warning("Cached data format mismatch or missing conditions. Rebuilding.") + else: + self.logger.warning("Cached data is invalid or in old format. Rebuilding.") + + # 4. Cache miss or invalid: Build dictionary from scratch for the specified conditions + self.logger.info("Cache miss or invalid. Building dictionary from scratch for selected conditions.") + + # Parse GTF and filter novel genes (only needs to be done once) + self.logger.info("Parsing GTF and filtering novel genes") + parsed_data = self.parse_gtf() + self._validate_gene_structure(parsed_data) # Validate base structure + base_gene_dict = self._filter_novel_genes(parsed_data) # Also populates self.novel_gene_ids etc. - # Add debug log: Number of valid transcripts after min_value filtering - valid_transcript_count = len(valid_transcripts) - self.logger.debug( - f"After TPM min_value ({min_value}) filtering: {valid_transcript_count} valid transcripts" + # Subset TPM matrix to *only* the conditions being processed for filtering + tpm_df_subset = tpm_df[conditions_to_process] + + # Identify valid transcripts based on max value within the SUBSET conditions + transcript_max_values_subset = tpm_df_subset.max(axis=1) + valid_transcripts = set( + transcript_max_values_subset[transcript_max_values_subset >= min_value].index + ) + self.logger.info( + f"Identified {len(valid_transcripts)} transcripts with TPM >= {min_value} " + f"in at least one of the conditions: {conditions_to_process}" ) - # 5. 
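The subset-max filter above in miniature; note that a transcript can be highly expressed in an unselected condition and still be dropped, which is the intended behavior change:

    import pandas as pd

    tpm_df = pd.DataFrame(
        {"ctrl": [0.2, 5.0, 0.0], "treat": [0.8, 0.3, 2.1], "other": [9.9, 0.0, 0.0]},
        index=["tx1", "tx2", "tx3"],
    )
    conditions_to_process = ["ctrl", "treat"]  # "other" is ignored for filtering
    min_value = 1.0

    subset_max = tpm_df[conditions_to_process].max(axis=1)
    valid_transcripts = set(subset_max[subset_max >= min_value].index)
    # {'tx2', 'tx3'}: tx1 reaches 9.9 TPM only in "other", so it is filtered out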
Single-pass filtering and value updating (using TPMs for both) - filtered_dict = {} - for condition in conditions: - filtered_dict[condition] = {} - condition_tpm_values = tpm_df[condition] + # Build the final dictionary, iterating only through conditions_to_process + final_dict = {} + for condition in conditions_to_process: + final_dict[condition] = {} + condition_tpm_values = tpm_df[condition] # Get expression from the original full df for gene_id, gene_info in base_gene_dict.items(): + # Filter transcripts based on valid_transcripts set AND add expression value new_transcripts = { tid: {**tinfo, 'value': condition_tpm_values.get(tid, 0)} for tid, tinfo in gene_info['transcripts'].items() - if tid in valid_transcripts + if tid in valid_transcripts # Apply the filter here } + # Only add gene if it has at least one valid transcript remaining if new_transcripts: - filtered_dict[condition][gene_id] = { - **gene_info, - 'transcripts': new_transcripts + final_dict[condition][gene_id] = { + **gene_info, # Copy base gene info + 'transcripts': new_transcripts, + 'exons': {} # Initialize exons, will be aggregated next } - self._validate_gene_structure(filtered_dict[condition]) # Validate structure for each condition's gene dict - for condition in conditions: - for gene_id, gene_info in filtered_dict[condition].items(): - # Initialize a dictionary to hold aggregated exon values for the gene. + # Validate structure for this condition's dictionary + self._validate_gene_structure(final_dict[condition]) + + # Aggregate exon values based on the filtered transcripts in the final_dict + self.logger.info("Aggregating exon values based on filtered transcript expression.") + for condition in conditions_to_process: + for gene_id, gene_info in final_dict[condition].items(): aggregated_exons = {} - # Iterate over each transcript in the gene. for transcript_id, transcript_info in gene_info["transcripts"].items(): - transcript_value = transcript_info.get("value", 0) # TPM value - # Loop through each exon in the current transcript. - for exon in transcript_info.get("exons", []): + transcript_value = transcript_info.get("value", 0) # TPM value from the filtered transcript + for exon in transcript_info.get("_original_exons", transcript_info.get("exons", [])): # Use original exon structure if available exon_id = exon.get("exon_id") - if not exon_id: - continue # Skip if no exon_id is provided. - # If this exon hasn't been seen before, add it. + if not exon_id: continue if exon_id not in aggregated_exons: aggregated_exons[exon_id] = { "exon_id": exon_id, "start": exon["start"], "end": exon["end"], "number": exon.get("number", "NA"), - "value": 0.0, + "value": 0.0, # Initialize aggregate value } - # Sum the transcript value into the exon value. - aggregated_exons[exon_id]["value"] += transcript_value - # Now assign the aggregated exon dictionary to the gene. - gene_info["exons"] = aggregated_exons + aggregated_exons[exon_id]["value"] += transcript_value # Sum transcript TPM + gene_info["exons"] = aggregated_exons # Assign aggregated exons + # 5. 
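The exon aggregation above, reduced to a toy case: an exon shared by two retained transcripts ends up with the sum of their TPM values.

    transcripts = {
        "tx1": {"value": 3.0, "exons": [{"exon_id": "E1", "start": 100, "end": 200}]},
        "tx2": {"value": 2.5, "exons": [{"exon_id": "E1", "start": 100, "end": 200},
                                        {"exon_id": "E2", "start": 300, "end": 400}]},
    }

    aggregated = {}
    for tinfo in transcripts.values():
        for exon in tinfo["exons"]:
            entry = aggregated.setdefault(exon["exon_id"], {**exon, "value": 0.0})
            entry["value"] += tinfo["value"]

    assert aggregated["E1"]["value"] == 5.5  # 3.0 + 2.5
    assert aggregated["E2"]["value"] == 2.5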
Save the newly built dictionary to the condition-specific cache + self.logger.info(f"Saving filtered dictionary to cache: {condition_specific_cache_file}") save_cache( - expr_filter_cache, (filtered_dict, self.novel_gene_ids, self.novel_transcript_ids) + condition_specific_cache_file, + (final_dict, self.novel_gene_ids, self.novel_transcript_ids) ) - self.logger.info(f"Saved dictionary to cache at {expr_filter_cache}") - return filtered_dict + + return final_dict def _get_tpm_file(self) -> str: """Get the appropriate TPM file path from config.""" @@ -172,7 +202,7 @@ def _get_tpm_file(self) -> str: # For multi-condition data, prioritize merged files if available merged_tpm = self.config.transcript_grouped_tpm if merged_tpm and "_merged.tsv" in merged_tpm: - self.logger.info("Using merged TPM file with transcript mapping already applied") + self.logger.info("Using merged TPM file with transcript deduplication already applied") tpm_file = merged_tpm else: tpm_file = self.config.transcript_grouped_tpm diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index f9a6853b..28c6e3dc 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -10,7 +10,6 @@ from src.visualization_plotter import ExpressionVisualizer from src.visualization_mapping import GeneMapper import numpy as np -from scipy.stats import gmean from sklearn.decomposition import PCA from rpy2.rinterface_lib import callbacks @@ -24,6 +23,11 @@ def __init__( updated_gene_dict: Dict[str, Dict], ref_only: bool = False, dictionary_builder: "DictionaryBuilder" = None, + filter_min_count: int = 10, + pca_n_components: int = 10, + top_transcripts_base_mean: int = 500, + top_n_genes: int = 100, + log_level: int = logging.INFO, # Allow configuring log level ): """Initialize differential expression analysis.""" def quiet_cb(x): @@ -43,8 +47,15 @@ def quiet_cb(x): self.updated_gene_dict = updated_gene_dict self.dictionary_builder = dictionary_builder + # Configurable parameters + self.filter_min_count = filter_min_count + self.pca_n_components = pca_n_components + self.top_transcripts_base_mean = top_transcripts_base_mean + self.top_n_genes = top_n_genes # Used for both gene and transcript top list size + # Create a single logger for this class self.logger = logging.getLogger('IsoQuant.visualization.differential_exp') + self.logger.setLevel(log_level) # Set logger level # Get transcript mapping if available self.transcript_map = {} @@ -114,181 +125,247 @@ def _create_transcript_to_gene_map(self) -> Dict[str, str]: return transcript_map def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame, pd.DataFrame]: - """Run differential expression analysis for both genes and transcripts. - + """ + Run differential expression analysis for both genes and transcripts. + Orchestrates loading, filtering, DESeq2 execution, and visualization. + Returns: Tuple containing: - Path to gene results file - Path to transcript results file - - DataFrame of transcript counts - - DataFrame of DESeq2 gene-level results + - DataFrame of transcript counts (filtered but not normalized) + - DataFrame of DESeq2 gene-level results (unfiltered by significance) """ - self.logger.info("Starting differential expression analysis") + self.logger.info("Starting differential expression analysis workflow.") + + # --- 1. 
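The new constructor knobs and their defaults, collected in one place; the threshold semantics noted here are inferred from where each value is used later in this file, and the caller shown is hypothetical:

    import logging

    analysis_options = dict(
        filter_min_count=10,            # minimum count threshold for _filter_counts
        pca_n_components=10,            # dimensionality of the PCA QC plots
        top_transcripts_base_mean=500,  # baseMean cutoff for the top-transcript list
        top_n_genes=100,                # unique genes kept in the top lists
        log_level=logging.INFO,         # verbosity of this class's logger
    )
    # e.g. DifferentialExpressionAnalyzer(..., **analysis_options)
    # (class name hypothetical; visualize.py supplies the remaining arguments)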
Load and Filter Data --- + gene_counts_filtered, transcript_counts_filtered = self._load_and_filter_data() + + # Store filtered transcript counts (as required by original return signature) + self.transcript_count_data = transcript_counts_filtered + + # --- 2. Run DESeq2 Analysis (Gene Level) --- + (deseq2_results_gene_file, + deseq2_results_df_gene, + gene_normalized_counts) = self._perform_level_analysis("gene", gene_counts_filtered) + + # --- 3. Run DESeq2 Analysis (Transcript Level) --- + (deseq2_results_transcript_file, + deseq2_results_df_transcript, + transcript_normalized_counts) = self._perform_level_analysis("transcript", transcript_counts_filtered) + + # --- 4. Generate Visualizations --- + self._generate_visualizations( + gene_counts_filtered=gene_counts_filtered, # Pass filtered counts for coldata generation + transcript_counts_filtered=transcript_counts_filtered, # Pass filtered counts for coldata generation + gene_normalized_counts=gene_normalized_counts, + transcript_normalized_counts=transcript_normalized_counts, + deseq2_results_df_gene=deseq2_results_df_gene, + deseq2_results_df_transcript=deseq2_results_df_transcript + ) - valid_transcripts = set() - for condition_genes in self.updated_gene_dict.values(): - for gene_info in condition_genes.values(): - valid_transcripts.update(gene_info.get("transcripts", {}).keys()) + self.logger.info("Differential expression analysis workflow complete.") + # Return signature matches original: results files, filtered transcript counts, gene results df + return deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered, deseq2_results_df_gene + + def _load_and_filter_data(self) -> Tuple[pd.DataFrame, pd.DataFrame]: + """Loads, filters (novelty, validity, counts), and returns gene and transcript count data.""" + self.logger.info("Loading and filtering count data...") - # --- 1. Load Count Data --- + # --- Load Count Data --- gene_counts = self._get_condition_data("gene_grouped_counts.tsv") transcript_counts = self._get_condition_data("transcript_grouped_counts.tsv") - self.logger.debug(f"Transcript counts shape after loading: {transcript_counts.shape}") - self.logger.debug(f"Gene counts shape after loading: {gene_counts.shape}") + self.logger.debug(f"Raw transcript counts shape: {transcript_counts.shape}") + self.logger.debug(f"Raw gene counts shape: {gene_counts.shape}") - # --- 2. 
Novel Transcript Filtering (Transcript Level) --- - if self.dictionary_builder: - novel_transcript_ids = self.dictionary_builder.get_novel_feature_ids()[1] - self.logger.debug(f"Number of novel transcripts identified: {len(novel_transcript_ids)}") + # --- Apply Transcript-Specific Filters --- + transcript_counts_filtered = self._apply_transcript_filters(transcript_counts) - original_transcript_count_novel_filter = transcript_counts.shape[0] - transcript_counts = transcript_counts[~transcript_counts.index.isin(novel_transcript_ids)] # Filter out novel transcripts - novel_filtered_count = transcript_counts.shape[0] - removed_novel_count = original_transcript_count_novel_filter - novel_filtered_count - self.logger.info(f"Novel transcript filtering: Removed {removed_novel_count} transcripts from novel genes ({removed_novel_count / original_transcript_count_novel_filter * 100:.1f}%)") - self.logger.debug(f"Transcript counts shape after novel gene filtering: {transcript_counts.shape}") - else: - self.logger.info("Novel transcript filtering: Skipped (no dictionary builder)") + # --- Apply Count-based Filtering (Gene Level) --- + gene_counts_filtered = self._filter_counts(gene_counts, level="gene") - # --- 3. Valid Transcript Filtering (Transcript Level) --- - original_transcript_count_valid_filter = transcript_counts.shape[0] - transcript_counts = transcript_counts[transcript_counts.index.isin(valid_transcripts)] # Filter to valid transcripts - valid_transcript_filtered_count = transcript_counts.shape[0] - removed_valid_transcript_count = original_transcript_count_valid_filter - valid_transcript_filtered_count - self.logger.info(f"Valid transcript filtering: Removed {removed_valid_transcript_count} transcripts not in updated_gene_dict ({removed_valid_transcript_count / original_transcript_count_valid_filter * 100:.1f}%)") - self.logger.debug(f"Transcript counts shape after valid transcript filtering: {transcript_counts.shape}") + if gene_counts_filtered.empty: + self.logger.error("No genes remaining after count filtering.") + raise ValueError("No genes remaining after count filtering.") + if transcript_counts_filtered.empty: + self.logger.error("No transcripts remaining after count filtering.") + raise ValueError("No transcripts remaining after count filtering.") - if transcript_counts.empty: - self.logger.error("No valid transcripts found after filtering.") - raise ValueError("No valid transcripts found after filtering.") + self.logger.info("Data loading and filtering complete.") + self.logger.info(f"Final gene counts shape: {gene_counts_filtered.shape}") + self.logger.info(f"Final transcript counts shape: {transcript_counts_filtered.shape}") - # --- 4. Count-based Filtering (Gene and Transcript Levels) --- - gene_counts_filtered = self._filter_counts(gene_counts, level="gene") - transcript_counts_filtered = self._filter_counts(transcript_counts, level="transcript") # Filter transcript counts AFTER novel and valid transcript filtering + return gene_counts_filtered, transcript_counts_filtered - self.transcript_count_data = transcript_counts_filtered # Store filtered transcript counts + def _apply_transcript_filters(self, transcript_counts: pd.DataFrame) -> pd.DataFrame: + """Applies novel, valid, and count-based filters specifically to transcript data.""" + self.logger.debug(f"Applying filters to transcript data (initial shape: {transcript_counts.shape})") - # --- 5. 
Run DESeq2 Analysis --- - - deseq2_results_gene_file, gene_normalized_counts = self._run_level_analysis( - level="gene", - pattern="gene_grouped_counts.tsv", - count_data=gene_counts_filtered, - coldata=self._build_design_matrix(gene_counts_filtered) - ) - deseq2_results_transcript_file, transcript_normalized_counts = self._run_level_analysis( - level="transcript", - pattern="transcript_model_grouped_counts.tsv" if not self.ref_only else "transcript_grouped_counts.tsv", - count_data=transcript_counts_filtered - ) + # --- Valid Transcript Set --- + # Determine the set of transcripts considered valid based on the updated_gene_dict + valid_transcripts = set() + for condition_genes in self.updated_gene_dict.values(): + for gene_info in condition_genes.values(): + valid_transcripts.update(gene_info.get("transcripts", {}).keys()) + if not valid_transcripts: + self.logger.warning("No valid transcripts found in updated_gene_dict. Skipping validity filter.") + self.logger.debug(f"Found {len(valid_transcripts)} valid transcript IDs in updated_gene_dict.") - # Load the gene-level results for GSEA - deseq2_results_df = pd.read_csv(deseq2_results_gene_file) + # --- Novel Transcript Filtering --- + if self.dictionary_builder: + novel_transcript_ids = self.dictionary_builder.get_novel_feature_ids()[1] # Assuming index 1 is transcripts + self.logger.debug(f"Number of novel transcripts identified: {len(novel_transcript_ids)}") + original_count = transcript_counts.shape[0] + transcript_counts = transcript_counts[~transcript_counts.index.isin(novel_transcript_ids)] + removed_count = original_count - transcript_counts.shape[0] + perc_removed = (removed_count / original_count * 100) if original_count > 0 else 0 + self.logger.info(f"Novel Gene filtering: Removed {removed_count} transcripts ({perc_removed:.1f}%)") + self.logger.debug(f"Shape after novel filtering: {transcript_counts.shape}") + else: + self.logger.info("Novel transcript filtering: Skipped (no dictionary builder).") - # Update how we create the labels - target_label = "+".join(self.target_conditions) - reference_label = "+".join(self.ref_conditions) - # --- Visualize Gene-Level Results --- - self.visualizer.visualize_results( - results=deseq2_results_df, - target_label=target_label, - reference_label=reference_label, - min_count=10, - feature_type="genes", - ) - self.logger.info(f"Gene-level visualizations saved to {self.deseq_dir}") - - # Run PCA with correct labels - normalized_gene_counts = self._median_ratio_normalization(gene_counts_filtered) - self._run_pca( - normalized_gene_counts, - "gene", - coldata=self._build_design_matrix(gene_counts_filtered), - target_label=target_label, - reference_label=reference_label - ) - # --- Visualize Transcript-Level Results --- - self.visualizer.visualize_results( - results=pd.read_csv(deseq2_results_transcript_file), - target_label=target_label, - reference_label=reference_label, - min_count=10, - feature_type="transcripts", - ) - self.logger.info(f"Transcript-level visualizations saved to {self.deseq_dir}") - - # Run PCA with correct labels for transcript level - normalized_transcript_counts = self._median_ratio_normalization(transcript_counts_filtered) - self._run_pca( - normalized_transcript_counts, - "transcript", - coldata=self._build_design_matrix(transcript_counts_filtered), - target_label=target_label, - reference_label=reference_label - ) + if transcript_counts.empty: + self.logger.warning("No transcripts remaining after novel gene filtering. 
Count filtering will be skipped.")
+            return transcript_counts # Return empty dataframe
 
-        return deseq2_results_gene_file, deseq2_results_transcript_file, transcript_counts_filtered, deseq2_results_df
+        # --- Valid Transcript Filtering ---
+        if valid_transcripts:
+            original_count = transcript_counts.shape[0]
+            transcript_counts = transcript_counts[transcript_counts.index.isin(valid_transcripts)]
+            removed_count = original_count - transcript_counts.shape[0]
+            perc_removed = (removed_count / original_count * 100) if original_count > 0 else 0
+            self.logger.info(f"Valid transcript filtering: Removed {removed_count} transcripts not in updated_gene_dict ({perc_removed:.1f}%)")
+            self.logger.debug(f"Shape after valid transcript filtering: {transcript_counts.shape}")
+
+        # --- Count-based Filtering (Transcript Level) ---
+        transcript_counts_filtered = self._filter_counts(transcript_counts, level="transcript")
 
-    def _run_level_analysis(
-        self, level: str, count_data: pd.DataFrame, pattern: Optional[str] = None, coldata=None
-    ) -> Tuple[Path, pd.DataFrame]:
+        self.logger.debug(f"Final transcript counts shape after all filters: {transcript_counts_filtered.shape}")
+        return transcript_counts_filtered
+
+    def _perform_level_analysis(
+        self, level: str, count_data: pd.DataFrame
+    ) -> Tuple[Path, pd.DataFrame, pd.DataFrame]:
         """
-        Run DESeq2 analysis for a specific level and return results.
+        Runs DESeq2 analysis for a specific level (gene or transcript).
 
         Args:
-            level: Analysis level ("gene" or "transcript")
-            pattern: Optional pattern for output file naming (not used for data loading anymore)
-            count_data: PRE-FILTERED count data DataFrame
+            level: Analysis level ("gene" or "transcript").
+            count_data: PRE-FILTERED count data DataFrame for the level.
 
         Returns:
-            Tuple containing: (results_path, results_df)
+            Tuple containing:
+                - Path to the saved DESeq2 results CSV file.
+                - DataFrame of the DESeq2 results.
+                - DataFrame of the DESeq2 normalized counts.
         """
-        # --- SIMPLIFIED: _run_level_analysis now assumes count_data is already loaded and filtered ---
+        self.logger.info(f"Performing DESeq2 analysis for level: {level}")
         if count_data.empty:
             self.logger.error(f"Input count data is empty for level: {level}")
             raise ValueError(f"Input count data is empty for level: {level}")
 
-        filtered_data = count_data.copy() # Work with a copy to avoid modifying original
+        # Create design matrix
+        coldata = self._build_design_matrix(count_data)
 
-        # Create design matrix and run DESeq2
-        coldata = self._build_design_matrix(filtered_data)
-        results, normalized_counts_r = self._run_deseq2(filtered_data, coldata, level)
+        # Run DESeq2 - Now returns results and normalized counts
+        results_df, normalized_counts_df = self._run_deseq2(count_data, coldata, level)
+
+        # --- Process DESeq2 Results ---
+        results_df.index.name = "feature_id"
+        results_df.reset_index(inplace=True) # Keep feature_id as a column
+
+        # Map gene symbols/names
+        mapping = self._map_gene_symbols(results_df["feature_id"].unique(), level)
+
+        # Add transcript_symbol and gene_name columns safely using .get
+        results_df["transcript_symbol"] = results_df["feature_id"].map(
+            lambda x: mapping.get(x, {}).get("transcript_symbol", x) # Default to feature_id if not found
+        )
+        results_df["gene_name"] = results_df["feature_id"].map(
+            lambda x: mapping.get(x, {}).get("gene_name", x.split('.')[0] if '.' 
in x else x) # Default to feature_id logic if not found + ) - # Process results - results.index.name = "feature_id" - results.reset_index(inplace=True) - mapping = self._map_gene_symbols(results["feature_id"].unique(), level) - - # Add transcript_symbol and gene_name columns - results["transcript_symbol"] = results["feature_id"].map(lambda x: mapping[x]["transcript_symbol"]) - results["gene_name"] = results["feature_id"].map(lambda x: mapping[x]["gene_name"]) # Drop transcript_symbol column for gene-level analysis as it's redundant if level == "gene": - results = results.drop(columns=["transcript_symbol"]) + results_df = results_df.drop(columns=["transcript_symbol"], errors='ignore') # Use errors='ignore' - # Save results + # --- Save Results --- target_label = "+".join(self.target_conditions) reference_label = "+".join(self.ref_conditions) + # Use the pattern argument passed to _get_condition_data if needed, or derive filename like this outfile = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}.csv" - results.to_csv(outfile, index=False) - self.logger.info(f"Saved DESeq2 results to {outfile}") + results_df.to_csv(outfile, index=False) + self.logger.info(f"Saved DESeq2 results ({results_df.shape[0]} features) to {outfile}") + + # --- Write Top Genes/Transcripts --- + self._write_top_genes(results_df, level) - # Write top genes - self._write_top_genes(results, level) + self.logger.info(f"DESeq2 analysis complete for level: {level}") + return outfile, results_df, normalized_counts_df - # No normalized counts returned from _run_deseq2 anymore - return outfile, pd.DataFrame() # Return empty DataFrame for normalized counts + def _generate_visualizations( + self, + gene_counts_filtered: pd.DataFrame, + transcript_counts_filtered: pd.DataFrame, + gene_normalized_counts: pd.DataFrame, + transcript_normalized_counts: pd.DataFrame, + deseq2_results_df_gene: pd.DataFrame, + deseq2_results_df_transcript: pd.DataFrame, + ): + """Generates PCA plots and other visualizations based on DESeq2 results and normalized counts.""" + self.logger.info("Generating visualizations...") + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + + # --- Visualize Gene-Level DE Results --- + self.visualizer.visualize_results( + results=deseq2_results_df_gene, # Use DataFrame directly + target_label=target_label, + reference_label=reference_label, + min_count=self.filter_min_count, # Use configured value + feature_type="genes", + ) + self.logger.info(f"Gene-level DE summary visualizations saved to {self.deseq_dir}") + + # --- Run PCA (Gene Level) --- + if not gene_normalized_counts.empty: + gene_coldata = self._build_design_matrix(gene_counts_filtered) # Need coldata matching the counts used + self._run_pca( + normalized_counts=gene_normalized_counts, + level="gene", + coldata=gene_coldata, + target_label=target_label, + reference_label=reference_label + ) + else: + self.logger.warning("Skipping gene-level PCA: Normalized counts are empty.") + + # --- Visualize Transcript-Level DE Results --- + self.visualizer.visualize_results( + results=deseq2_results_df_transcript, # Use DataFrame directly + target_label=target_label, + reference_label=reference_label, + min_count=self.filter_min_count, # Use configured value + feature_type="transcripts", + ) + self.logger.info(f"Transcript-level DE summary visualizations saved to {self.deseq_dir}") + + # --- Run PCA (Transcript Level) --- + if not transcript_normalized_counts.empty: + transcript_coldata = 
self._build_design_matrix(transcript_counts_filtered) # Need coldata matching the counts used + self._run_pca( + normalized_counts=transcript_normalized_counts, + level="transcript", + coldata=transcript_coldata, + target_label=target_label, + reference_label=reference_label + ) + else: + self.logger.warning("Skipping transcript-level PCA: Normalized counts are empty.") + + self.logger.info("Visualizations generated.") def _get_merged_transcript_counts(self, pattern: str) -> pd.DataFrame: """ Get transcript count data and apply transcript mapping to create a merged grouped dataframe. This preserves the individual sample columns needed for DESeq2, but merges identical transcripts. """ - self.logger.info(f"Creating merged transcript count matrix with pattern: {pattern}") + self.logger.debug(f"Creating merged transcript count matrix with pattern: {pattern}") # Adjust pattern if needed adjusted_pattern = pattern @@ -378,7 +455,7 @@ def _get_merged_transcript_counts(self, pattern: str) -> pd.DataFrame: # Log details of significant merges (more than 2 transcripts or interesting transcripts) if len(transcript_ids) > 2 or any("ENST" in t for t in transcript_ids): - self.logger.info(f"Merged transcript group for {canonical_id}: {transcript_ids}") + self.logger.debug(f"Merged transcript group for {canonical_id}: {transcript_ids}") # Log merge statistics self.logger.info(f"Transcript merging complete: {merged_groups} canonical IDs had multiple transcripts") @@ -441,21 +518,28 @@ def _get_condition_data(self, pattern: str) -> pd.DataFrame: self.logger.error(f"Unsupported count pattern: {pattern}") raise ValueError(f"Unsupported count pattern: {pattern}") - def _filter_counts(self, count_data: pd.DataFrame, min_count: int = 10, level: str = "gene") -> pd.DataFrame: + def _filter_counts(self, count_data: pd.DataFrame, level: str = "gene") -> pd.DataFrame: """ - Filter features based on counts. - - For genes: Keep if mean count >= min_count in either condition group - For transcripts: Keep if count >= min_count in at least half of all samples + Filter features based on counts using the configured threshold. + + For genes: Keep if mean count >= configured min_count in either condition group. + For transcripts: Keep if count >= configured min_count in at least half of all samples. """ + if count_data.empty: + self.logger.warning(f"Input count data for filtering ({level}) is empty. 
Returning empty DataFrame.") + return count_data + + # Use the configured minimum count threshold + min_count_threshold = self.filter_min_count + if level == "transcript": total_samples = len(count_data.columns) - min_samples_required = total_samples // 2 - samples_passing = (count_data >= min_count).sum(axis=1) + min_samples_required = max(1, total_samples // 2) # Ensure at least 1 sample is required + samples_passing = (count_data >= min_count_threshold).sum(axis=1) keep_features = samples_passing >= min_samples_required - + self.logger.info( - f"Transcript filtering: Keeping transcripts with counts >= {min_count} " + f"Transcript filtering: Keeping transcripts with counts >= {min_count_threshold} " f"in at least {min_samples_required}/{total_samples} samples" ) else: # gene level @@ -468,20 +552,33 @@ def _filter_counts(self, count_data: pd.DataFrame, min_count: int = 10, level: s if any(col.startswith(f"{cond}_") for cond in self.target_conditions) ] - ref_means = count_data[ref_cols].mean(axis=1) - tgt_means = count_data[tgt_cols].mean(axis=1) - keep_features = (ref_means >= min_count) | (tgt_means >= min_count) - + # Handle cases where one condition might have no samples after potential upstream filtering + if not ref_cols: + self.logger.warning("No reference columns found for gene count filtering.") + ref_means = pd.Series(0, index=count_data.index) # Assign 0 mean if no ref samples + else: + ref_means = count_data[ref_cols].mean(axis=1) + + if not tgt_cols: + self.logger.warning("No target columns found for gene count filtering.") + tgt_means = pd.Series(0, index=count_data.index) # Assign 0 mean if no target samples + else: + tgt_means = count_data[tgt_cols].mean(axis=1) + + keep_features = (ref_means >= min_count_threshold) | (tgt_means >= min_count_threshold) + self.logger.info( - f"Gene filtering: Keeping genes with mean count >= {min_count} " - f"in either condition group" + f"Gene filtering: Keeping genes with mean count >= {min_count_threshold} " + f"in either reference or target condition group" ) filtered_data = count_data[keep_features] + removed_count = count_data.shape[0] - filtered_data.shape[0] self.logger.info( - f"After filtering: Retained {filtered_data.shape[0]} features" + f"After count filtering ({level}): Retained {filtered_data.shape[0]} / {count_data.shape[0]} features " + f"(Removed {removed_count})" ) - + return filtered_data def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame: @@ -530,34 +627,81 @@ def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame: }, index=count_data.columns) # Log the design matrix for debugging - self.logger.info(f"Design matrix:\n{design_matrix}") + self.logger.debug(f"Design matrix:\n{design_matrix}") return design_matrix def _run_deseq2( self, count_data: pd.DataFrame, coldata: pd.DataFrame, level: str ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """Run DESeq2 analysis.""" + """ + Run DESeq2 analysis and return results and normalized counts. + + Args: + count_data: Raw count data (filtered). + coldata: Design matrix. + level: Analysis level (gene/transcript). + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: DESeq2 results, DESeq2 normalized counts. 
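+
+        Example (illustrative sketch only, not executed by the pipeline; the
+        toy matrix, sample names, and values below are invented for clarity):
+
+            counts = pd.DataFrame(
+                {"ref_1": [5, 100], "ref_2": [7, 90],
+                 "tgt_1": [6, 300], "tgt_2": [8, 280]},
+                index=["featA", "featB"])
+            coldata = pd.DataFrame(
+                {"group": ["Reference", "Reference", "Target", "Target"]},
+                index=counts.columns)
+            res_df, norm_df = self._run_deseq2(counts, coldata, "gene")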
+ """ + self.logger.info(f"Running DESeq2 for {level} level...") deseq2 = importr("DESeq2") + # Ensure counts are integers for DESeq2 count_data = count_data.fillna(0).round().astype(int) - with localconverter(robjects.default_converter + pandas2ri.converter): - # Convert count_data and coldata to R DataFrames explicitly before creating DESeqDataSet - count_data_r = pandas2ri.py2rpy(count_data) - coldata_r = pandas2ri.py2rpy(coldata) + # Ensure count data has no negative values before passing to R + if (count_data < 0).any().any(): + self.logger.warning(f"Negative values found in count data for {level}. Clamping to 0.") + count_data = count_data.clip(lower=0) - dds = deseq2.DESeqDataSetFromMatrix( - countData=count_data_r, colData=coldata_r, design=Formula("~ group") - ) - dds = deseq2.DESeq(dds) - res = deseq2.results( - dds, contrast=robjects.StrVector(["group", "Target", "Reference"]) - ) - # No normalized counts from DESeq2 anymore + if count_data.empty: + self.logger.error(f"Count data is empty before running DESeq2 for {level}.") + # Return empty dataframes if counts are empty + return pd.DataFrame(), pd.DataFrame(index=count_data.index, columns=count_data.columns) - return pd.DataFrame( - robjects.conversion.rpy2py(r("data.frame")(res)), index=count_data.index - ), pd.DataFrame() # Return empty DataFrame for normalized counts + try: + with localconverter(robjects.default_converter + pandas2ri.converter): + # Convert count_data and coldata to R DataFrames + count_data_r = robjects.conversion.py2rpy(count_data) + coldata_r = robjects.conversion.py2rpy(coldata) + + # Create DESeqDataSet + self.logger.debug("Creating DESeqDataSet...") + dds = deseq2.DESeqDataSetFromMatrix( + countData=count_data_r, colData=coldata_r, design=Formula("~ group") + ) + + # Run DESeq analysis + self.logger.debug("Running DESeq()...") + dds = deseq2.DESeq(dds) + + # Get results + self.logger.debug("Extracting results()...") + res = deseq2.results( + dds, contrast=robjects.StrVector(["group", "Target", "Reference"]) + ) + res_df = robjects.conversion.rpy2py(r("as.data.frame")(res)) # Convert to R dataframe first for stability + res_df.index = count_data.index # Assign original feature IDs as index + + + # Correct way to call the R 'counts' function on the dds object + # Ensure 'r' is imported: from rpy2.robjects import r + normalized_counts_r = r['counts'](dds, normalized=True) + + # Convert R matrix to pandas DataFrame + normalized_counts_py = robjects.conversion.rpy2py(normalized_counts_r) + # Ensure DataFrame structure matches original count_data (features x samples) + normalized_counts_df = pd.DataFrame(normalized_counts_py, index=count_data.index, columns=count_data.columns) + + + self.logger.info(f"DESeq2 run completed for {level}. 
Results shape: {res_df.shape}, Normalized counts shape: {normalized_counts_df.shape}") + return res_df, normalized_counts_df + + except Exception as e: + self.logger.error(f"Error running DESeq2 for {level}: {str(e)}") + # Return empty DataFrames on error to avoid downstream issues + return pd.DataFrame(), pd.DataFrame(index=count_data.index, columns=count_data.columns) def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Dict[str, Optional[str]]]: """ @@ -620,149 +764,154 @@ def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Dic return self.gene_mapper.map_gene_symbols(feature_ids, level, self.updated_gene_dict) def _write_top_genes(self, results: pd.DataFrame, level: str) -> None: - """Write genes associated with top 100 transcripts by absolute fold change to file.""" + """Write top genes/transcripts based on absolute statistic value to file.""" + if results.empty or 'stat' not in results.columns: + self.logger.warning(f"Cannot write top genes for {level}: Results DataFrame is empty or missing 'stat' column.") + return + + # Ensure 'stat' column is numeric, fill NaNs that might cause issues + results['stat'] = pd.to_numeric(results['stat'], errors='coerce').fillna(0) results["abs_stat"] = abs(results["stat"]) + # Use configured number of top genes/transcripts + top_n = self.top_n_genes + if level == "transcript": - # where baseMean is greater than 500 - top_transcripts = results[results["baseMean"] > 500].nlargest(len(results), "abs_stat") - - unique_genes = set() - top_unique_gene_transcripts = [] - transcript_count = 0 - unique_gene_count = 0 - - for _, transcript_row in top_transcripts.iterrows(): - gene_name = transcript_row["gene_name"] - if gene_name not in unique_genes: - unique_genes.add(gene_name) - top_unique_gene_transcripts.append(transcript_row) - unique_gene_count += 1 - if unique_gene_count >= 100: # Stop when we reach 500 unique genes - break - transcript_count += 1 # Keep track of total transcripts considered - - top_genes = [row["gene_name"] for row in top_unique_gene_transcripts] # Extract gene names from selected transcripts + # Use configured base mean threshold + base_mean_threshold = self.top_transcripts_base_mean + # Ensure 'baseMean' column is numeric, fill NaNs + if 'baseMean' not in results.columns: + self.logger.warning(f"Cannot apply baseMean filter for {level}: 'baseMean' column missing. Considering all transcripts.") + filtered_results = results + else: + results['baseMean'] = pd.to_numeric(results['baseMean'], errors='coerce').fillna(0) + filtered_results = results[results["baseMean"] > base_mean_threshold] + + if filtered_results.empty: + self.logger.warning(f"No transcripts found with baseMean > {base_mean_threshold}. 
Top genes file will be empty.") + top_unique_gene_transcripts_df = pd.DataFrame() # Empty dataframe + else: + # Sort by absolute statistic value + top_transcripts = filtered_results.sort_values("abs_stat", ascending=False) + + # Ensure 'gene_name' column exists + if 'gene_name' not in top_transcripts.columns: + self.logger.error(f"Cannot extract top unique genes for {level}: 'gene_name' column missing.") + return + + # Get top N unique genes based on the highest ranked transcript for each gene + top_unique_gene_transcripts_df = top_transcripts.drop_duplicates(subset=['gene_name'], keep='first').head(top_n) + self.logger.info(f"Highest adjusted p-value in top {top_n} unique genes: {top_unique_gene_transcripts_df['padj'].max()}") + + top_genes_list = top_unique_gene_transcripts_df["gene_name"].tolist() if not top_unique_gene_transcripts_df.empty else [] # Write to file - top_genes_file = self.deseq_dir / "genes_of_top_100_DE_transcripts.txt" - pd.Series(top_genes).to_csv(top_genes_file, index=False, header=False) - self.logger.debug(f"Wrote genes of top 100 DE transcripts to {top_genes_file}") - else: - # For gene-level analysis, keep original behavior - # top_genes = results.nlargest(100, "abs_stat")["symbol"] # OLD: was writing symbols (gene IDs) - top_genes = results.nlargest(100, "abs_stat")["gene_name"] - top_genes_file = self.deseq_dir / "top_100_DE_genes.txt" - top_genes.to_csv(top_genes_file, index=False, header=False) - self.logger.debug(f"Wrote top 100 DE genes to {top_genes_file}") + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + top_genes_file = self.deseq_dir / f"genes_of_top_{top_n}_DE_transcripts_{target_label}_vs_{reference_label}.txt" + + pd.Series(top_genes_list).to_csv(top_genes_file, index=False, header=False) + self.logger.info(f"Wrote {len(top_genes_list)} unique genes (from top {top_n} DE transcripts with baseMean > {base_mean_threshold}) to {top_genes_file}") + + else: # Gene level + # Ensure 'gene_name' column exists for gene level as well + if 'gene_name' not in results.columns: + self.logger.error(f"Cannot extract top genes for {level}: 'gene_name' column missing.") + return + + # Get top N genes directly by absolute statistic + top_genes_df = results.nlargest(top_n, "abs_stat") + top_genes_list = top_genes_df["gene_name"].tolist() + + # Write to file + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + top_genes_file = self.deseq_dir / f"top_{top_n}_DE_genes_{target_label}_vs_{reference_label}.txt" + + pd.Series(top_genes_list).to_csv(top_genes_file, index=False, header=False) + self.logger.info(f"Wrote top {len(top_genes_list)} DE genes to {top_genes_file}") def _run_pca(self, normalized_counts, level, coldata, target_label, reference_label): - """Run PCA analysis and create visualization.""" - self.logger.info(f"Running PCA for {level} level...") - - # Run PCA - Calculate 10 components - pca = PCA(n_components=10) # Keep n_components=10 to generate scree plot with 10 components - log_normalized_counts = np.log2(normalized_counts + 1) - pca_result = pca.fit_transform(log_normalized_counts.transpose()) - #map the feature names to gene names using the gene_mapper - feature_names = normalized_counts.index.tolist() - gene_names = self.gene_mapper.map_gene_symbols(feature_names, level, self.updated_gene_dict) - - # Get explained variance ratio and loadings - explained_variance = pca.explained_variance_ratio_ - loadings = pca.components_ # Loadings are in 
pca.components_ - - # Create DataFrame with columns for all 10 components - pc_columns = [f'PC{i+1}' for i in range(10)] # Generate column names: PC1, PC2, ..., PC10 - pca_df = pd.DataFrame(data=pca_result, columns=pc_columns, index=log_normalized_counts.columns) # Use all 10 column names - pca_df['group'] = coldata['group'].values - - title = f"{level.capitalize()} Level PCA: {target_label} vs {reference_label}\nPC1 ({100*explained_variance[0]:.2f}%) / PC2 ({100*explained_variance[1]:.2f}%)" - - # Use the plotter's PCA method, passing explained variance and loadings - self.visualizer.plot_pca( - pca_df=pca_df, # pca_df now contains 10 components - title=title, - output_prefix=f"pca_{level}", - explained_variance=explained_variance, # Pass explained variance (for scree plot) - loadings=loadings, # Pass loadings (for loadings output) - feature_names=gene_names # Pass feature names (gene names) - ) + """Run PCA analysis and create visualization using DESeq2 normalized counts.""" + self.logger.info(f"Running PCA for {level} level using DESeq2 normalized counts...") + + if normalized_counts.empty: + self.logger.warning(f"Skipping PCA for {level}: Normalized counts data is empty.") + return + + # Basic check for variance - PCA fails if variance is zero + if normalized_counts.var().sum() == 0: + self.logger.warning(f"Skipping PCA for {level}: Data has zero variance.") + return + + # Use configured number of components + n_components = min(self.pca_n_components, normalized_counts.shape[0], normalized_counts.shape[1]) # Cannot exceed number of features or samples + if n_components < 2: + self.logger.warning(f"Skipping PCA for {level}: Not enough features/samples ({normalized_counts.shape}) for {n_components} components.") + return + if n_components != self.pca_n_components: + self.logger.warning(f"Reducing number of PCA components to {n_components} due to data dimensions.") + + + # Log transform the DESeq2 normalized counts (add 1 to handle zeros) + # Ensure data is numeric before transformation + log_normalized_counts = np.log2(normalized_counts.apply(pd.to_numeric, errors='coerce').fillna(0) + 1) + + + # Check for NaNs/Infs after log transform which can happen if counts were negative (though clamped earlier) or exactly -1 + if np.isinf(log_normalized_counts).any().any() or np.isnan(log_normalized_counts).any().any(): + self.logger.warning(f"NaNs or Infs found in log-transformed counts for {level}. Replacing with 0. This might indicate issues with count data.") + log_normalized_counts = log_normalized_counts.replace([np.inf, -np.inf], 0).fillna(0) + - def _median_ratio_normalization(self, count_data: pd.DataFrame) -> pd.DataFrame: - """ - Perform median-by-ratio normalization on count data. - This is similar to the normalization used in DESeq2. - Handles zeros and potential data type issues safely. 
- """ try: - # Convert to numeric and handle any non-numeric values - count_data_numeric = count_data.apply(pd.to_numeric, errors='coerce').fillna(0) - - # Ensure all values are positive or zero - count_data_nonneg = count_data_numeric.clip(lower=0) - - # Add pseudocount to avoid zeros (1 is a common choice) - count_data_safe = count_data_nonneg + 1 - - # Check data types and values - self.logger.debug(f"Count data shape: {count_data_safe.shape}") - self.logger.debug(f"Count data dtype: {count_data_safe.values.dtype}") - self.logger.debug(f"Min value: {count_data_safe.values.min()}, Max value: {count_data_safe.values.max()}") - - # Convert to numpy array - counts_numpy = count_data_safe.values.astype(float) - - # Alternative geometric mean calculation - # Use log1p which is log(1+x) to handle zeros more safely - log_counts = np.log(counts_numpy) - row_means = np.mean(log_counts, axis=1) - geometric_means = np.exp(row_means) - - # Check for any invalid values in geometric means - if np.any(~np.isfinite(geometric_means)): - self.logger.warning("Found non-finite values in geometric means, replacing with 1.0") - geometric_means[~np.isfinite(geometric_means)] = 1.0 - - # Calculate ratio of each count to the geometric mean - # Reshape geometric_means for broadcasting - geo_means_col = geometric_means.reshape(-1, 1) - ratios = counts_numpy / geo_means_col - - # Calculate size factor for each sample (median of ratios) - size_factors = np.median(ratios, axis=0) - - # Check for any invalid values in size factors - if np.any(size_factors <= 0) or np.any(~np.isfinite(size_factors)): - self.logger.warning("Found invalid size factors, replacing with 1.0") - size_factors[~np.isfinite(size_factors) | (size_factors <= 0)] = 1.0 - - # Log size factors - self.logger.info(f"Size factors: {size_factors}") - - # Normalize counts by dividing by size factors - # Use original count data (without pseudocount) for the final normalization - normalized_counts = pd.DataFrame( - count_data_nonneg.values / size_factors, - index=count_data.index, - columns=count_data.columns + pca = PCA(n_components=n_components) + # Transpose because PCA expects samples as rows, features as columns + pca_result = pca.fit_transform(log_normalized_counts.transpose()) + + # Map feature IDs (index of normalized_counts) to gene names + feature_ids = normalized_counts.index.tolist() + # Use the mapping function - ensure it handles potential errors/missing keys + gene_mapping_dict = self._map_gene_symbols(feature_ids, level) + # Create a list of gene names in the same order as features + feature_names_mapped = [gene_mapping_dict.get(fid, {}).get('gene_name', fid) for fid in feature_ids] + + + # Get explained variance ratio and loadings + explained_variance = pca.explained_variance_ratio_ + loadings = pca.components_ # Loadings are in pca.components_ + + # Create DataFrame with columns for all calculated components + pc_columns = [f'PC{i+1}' for i in range(n_components)] + pca_df = pd.DataFrame(data=pca_result[:, :n_components], columns=pc_columns, index=log_normalized_counts.columns) # Use sample names as index + + # Add group information from coldata, ensuring index alignment + # It's safer to reset index on coldata if it uses sample names as index too + if coldata.index.equals(pca_df.index): + pca_df['group'] = coldata['group'].values + else: + self.logger.warning(f"Index mismatch between PCA results and coldata for {level}. 
Group information might be incorrect.") + # Attempt to merge or handle, here just assigning potentially misaligned + pca_df['group'] = coldata['group'].values[:len(pca_df)] + + + # Title focuses on PC1/PC2 for the scatter plot, even if more components were calculated + pc1_var = explained_variance[0] * 100 if len(explained_variance) > 0 else 0 + pc2_var = explained_variance[1] * 100 if len(explained_variance) > 1 else 0 + title = f"{level.capitalize()} Level PCA: {target_label} vs {reference_label}\nPC1 ({pc1_var:.2f}%) / PC2 ({pc2_var:.2f}%)" + + + # Use the plotter's PCA method, passing explained variance and loadings + self.visualizer.plot_pca( + pca_df=pca_df, # pca_df contains n_components columns + title=title, + output_prefix=f"pca_{level}", + explained_variance=explained_variance, # Pass full explained variance for scree plot + loadings=loadings, # Pass loadings + # Pass the mapped gene names corresponding to the features (rows of normalized_counts) + feature_names=feature_names_mapped ) - - # Fill any NaN values with 0 - normalized_counts = normalized_counts.fillna(0) - - return normalized_counts - + self.logger.info(f"PCA plots saved for {level} level.") + except Exception as e: - self.logger.error(f"Error in median ratio normalization: {str(e)}") - self.logger.error("Falling back to simple TPM-like normalization") - - # Fallback normalization (similar to TPM) - # Sum each column and divide counts by the sum - col_sums = count_data.sum(axis=0) - col_sums = col_sums.replace(0, 1) # Avoid division by zero - - # Normalize each column by its sum and multiply by 1e6 (similar to TPM scaling) - normalized_counts = count_data.div(col_sums, axis=1) * 1e6 - - return normalized_counts \ No newline at end of file + self.logger.error(f"Error during PCA calculation or plotting for {level}: {str(e)}") \ No newline at end of file diff --git a/src/visualization_mapping.py b/src/visualization_mapping.py index ee1f3af6..703b121a 100644 --- a/src/visualization_mapping.py +++ b/src/visualization_mapping.py @@ -6,7 +6,6 @@ class GeneMapper: def __init__(self): self.mg = mygene.MyGeneInfo() self.logger = logging.getLogger('IsoQuant.visualization.mapping') - self.logger.setLevel(logging.INFO) def get_gene_info_from_mygene(self, ensembl_ids: List[str]) -> Dict[str, Dict]: """ diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py index 55491b01..91615a1a 100644 --- a/src/visualization_plotter.py +++ b/src/visualization_plotter.py @@ -31,84 +31,60 @@ def __init__( self.reads_and_class = reads_and_class self.conditions = conditions self.ref_only = ref_only - self.min_tpm_threshold = filter_transcripts + self.display_threshold = filter_transcripts # Explicitly set reference and target conditions self.ref_conditions = ref_conditions if ref_conditions else [] self.target_conditions = target_conditions if target_conditions else [] # Log conditions for debugging - if self.ref_conditions and self.target_conditions: - logging.info(f"Filtering plots to include only ref conditions: {self.ref_conditions} and target conditions: {self.target_conditions}") + if self.ref_conditions or self.target_conditions: + expected_conditions = set(self.ref_conditions + self.target_conditions) + actual_conditions = set(self.updated_gene_dict.keys()) + if expected_conditions != actual_conditions: + logging.warning(f"Mismatch between provided conditions and keys in updated_gene_dict. 
" + f"Expected: {sorted(list(expected_conditions))}, Found: {sorted(list(actual_conditions))}") + else: + logging.info(f"Plotting with ref conditions: {self.ref_conditions} and target conditions: {self.target_conditions}") else: - logging.warning("No ref_conditions or target_conditions set, filtering may not work correctly") + logging.warning("No ref_conditions or target_conditions set, plots will include all conditions found in updated_gene_dict") + + # Log the threshold value if provided (for context) + if self.display_threshold is not None: + logging.info(f"Transcript data assumes upstream filtering with TPM >= {self.display_threshold}") - # Log TPM threshold if set - if self.min_tpm_threshold: - logging.info(f"Filtering transcripts with TPM value < {self.min_tpm_threshold}") - # Ensure output directories exist if self.gene_visualizations_dir: os.makedirs(self.gene_visualizations_dir, exist_ok=True) - os.makedirs(self.read_assignments_dir, exist_ok=True) + if self.read_assignments_dir: # Check if read_assignments_dir is not None + os.makedirs(self.read_assignments_dir, exist_ok=True) def plot_transcript_map(self): - """Plot transcript structure with different colors for reference and novel exons.""" + """Plot transcript structure using pre-filtered gene data.""" if not self.gene_visualizations_dir: logging.warning("No gene_visualizations_dir provided. Skipping transcript map plotting.") return - # Check if reference and target conditions are defined - has_specific_conditions = (hasattr(self, 'ref_conditions') and hasattr(self, 'target_conditions') and - self.ref_conditions and self.target_conditions) - - if has_specific_conditions: - logging.info(f"Filtering transcript map to include only ref conditions: {self.ref_conditions} and target conditions: {self.target_conditions}") - # Define all allowed conditions - allowed_conditions = set(self.ref_conditions + self.target_conditions) + for gene_name_or_id in self.gene_names: # gene_names list contains gene names (symbols) gene_data = None # Initialize gene_data to None - found_condition = None # Track which condition we found the gene in - - # First pass: Try to find the gene in allowed conditions only - if has_specific_conditions: - # Search only in allowed conditions - for condition in allowed_conditions: - if condition not in self.updated_gene_dict: - continue - - genes = self.updated_gene_dict[condition] - for gene_id, gene_info in genes.items(): - if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): # Compare gene names (case-insensitive) - gene_data = gene_info - found_condition = condition - break - if gene_data: - break # Found gene, stop searching - # Second pass: If not found and we're allowing fallback, try all conditions + # Find the gene in the pre-filtered dictionary. + # We only need one instance of the gene structure, as it should be consistent. + # Iterate through conditions until the gene is found. + for condition, genes in self.updated_gene_dict.items(): + for gene_id, gene_info in genes.items(): + # Compare gene names (case-insensitive, using upper()) + if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): + gene_data = gene_info + # No need to log which condition it came from, as it's pre-filtered. 
+ break # Found gene info + if gene_data: + break # Found gene, stop searching conditions + if not gene_data: - for condition, genes in self.updated_gene_dict.items(): - # Skip conditions we already checked if using specific conditions - if has_specific_conditions and condition in allowed_conditions: - continue - - for gene_id, gene_info in genes.items(): - if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): - gene_data = gene_info - found_condition = condition - break - if gene_data: - break # Found gene, stop searching - - if gene_data: - if has_specific_conditions and found_condition in allowed_conditions: - logging.debug(f"Gene {gene_name_or_id} found in prioritized condition: {found_condition}") - else: - logging.debug(f"Gene {gene_name_or_id} found in fallback condition: {found_condition}") - else: - logging.warning(f"Gene {gene_name_or_id} not found in any condition.") + logging.warning(f"Gene {gene_name_or_id} not found in the provided (pre-filtered) updated_gene_dict.") continue # Skip to the next gene if not found # Get chromosome info and calculate buffer @@ -122,76 +98,19 @@ def plot_transcript_map(self): plot_start = start - buffer plot_end = end + buffer - # NEW APPROACH: If we have ref/target conditions AND TPM filtering, - # we need to consider transcript expression across ALL relevant conditions - if has_specific_conditions and self.min_tpm_threshold is not None: - # First, collect the max TPM for each transcript across all ref/target conditions - transcript_max_tpm = {} - - # Collect all transcripts from the current condition first - for transcript_id, transcript_info in gene_data["transcripts"].items(): - value = float(transcript_info.get("value", 0)) - transcript_max_tpm[transcript_id] = value - - # Check other ref/target conditions for the same gene to find max TPM values - for condition in allowed_conditions: - if condition == found_condition or condition not in self.updated_gene_dict: - continue # Skip the condition we already processed - - genes = self.updated_gene_dict[condition] - for gene_id, gene_info in genes.items(): - # Check if this is the same gene in another condition - if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): - # Found the same gene in another condition, check transcript TPM values - for transcript_id, transcript_info in gene_info["transcripts"].items(): - value = float(transcript_info.get("value", 0)) - # Update max TPM if higher in this condition - if transcript_id in transcript_max_tpm: - transcript_max_tpm[transcript_id] = max(transcript_max_tpm[transcript_id], value) - else: - transcript_max_tpm[transcript_id] = value - break # Found the gene in this condition, no need to check other genes - - # Now filter transcripts based on their max TPM across ref/target conditions - filtered_transcripts = {} - for transcript_id, transcript_info in gene_data["transcripts"].items(): - max_tpm = transcript_max_tpm.get(transcript_id, 0) - if max_tpm >= self.min_tpm_threshold: - filtered_transcripts[transcript_id] = transcript_info - - # Log filtering results - total_transcripts = len(gene_data["transcripts"]) - filtered_count = len(filtered_transcripts) - logging.debug(f"Cross-condition TPM filtering: {filtered_count} of {total_transcripts} transcripts have TPM >= {self.min_tpm_threshold} in any ref/target condition for gene {gene_name_or_id}") - else: - # Original filtering approach for single condition - filtered_transcripts = {} - total_transcripts = len(gene_data["transcripts"]) - filtered_count = 0 - - for 
transcript_id, transcript_info in gene_data["transcripts"].items(): - # Apply TPM filtering if threshold is set - if self.min_tpm_threshold is not None: - value = float(transcript_info.get("value", 0)) - if value >= self.min_tpm_threshold: - filtered_transcripts[transcript_id] = transcript_info - filtered_count += 1 - else: - # No filtering, include all transcripts - filtered_transcripts[transcript_id] = transcript_info - - if self.min_tpm_threshold is not None: - logging.debug(f"Single-condition TPM filtering: {filtered_count} of {total_transcripts} transcripts have TPM >= {self.min_tpm_threshold} for gene {gene_name_or_id}") + # REMOVED FILTERING LOGIC - Directly use transcripts from gene_data + filtered_transcripts = gene_data["transcripts"] - # Skip plotting if no transcripts pass the filter + # Skip plotting if no transcripts are present (this might happen if upstream filtering removed all) if not filtered_transcripts: - logging.warning(f"No transcripts for gene {gene_name_or_id} pass the TPM threshold of {self.min_tpm_threshold}. Skipping plot.") + logging.warning(f"No transcripts found for gene {gene_name_or_id} in the input data. Skipping plot.") continue # Calculate plot height based on number of filtered transcripts num_transcripts = len(filtered_transcripts) - plot_height = max(8, num_transcripts * 0.4) - #logging.debug(f"Creating transcript map for gene '{gene_name_or_id}' with {num_transcripts} transcripts from {found_condition}") + plot_height = max(10, num_transcripts * 0.6) # Increased base height and multiplier + # Use INFO level for starting plot creation, DEBUG for saving it. + logging.debug(f"Creating transcript map for gene '{gene_name_or_id}' with {num_transcripts} transcripts.") fig, ax = plt.subplots(figsize=(12, plot_height)) @@ -246,6 +165,7 @@ def plot_transcript_map(self): ) ) + if not any(exon["exon_id"].startswith("ENSE") for exon in transcript_info["exons"]): logging.debug(f"Transcript {transcript_id} in gene {gene_name_or_id} contains NO reference exons (based on ENSEMBL IDs)") #log the exon_ids @@ -260,24 +180,15 @@ def plot_transcript_map(self): transcript_name = (transcript_info.get("name") or transcript_info.get("transcript_id") or transcript_id) - value = float(transcript_info.get("value", 0)) - - # If cross-condition filtering was used, show the max TPM value in the label - if has_specific_conditions and self.min_tpm_threshold is not None: - max_tpm = transcript_max_tpm.get(transcript_id, 0) - y_labels.append(f"{transcript_name}") - else: - y_labels.append(f"{transcript_name}") + + y_labels.append(f"{transcript_name}") # Set up the plot formatting with just chromosome gene_display_name = gene_data.get("name", gene_name_or_id) # Fallback to ID if no name # Update title to include TPM threshold if applied - if self.min_tpm_threshold is not None: - if has_specific_conditions: - title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome}) (TPM >= {self.min_tpm_threshold} in any ref/target condition)" - else: - title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome}) (TPM >= {self.min_tpm_threshold})" + if self.display_threshold is not None: + title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome}) (Input filtered at TPM >= {self.display_threshold})" else: title = f"Transcript Structure - {gene_display_name} (Chromosome {chromosome})" @@ -299,124 +210,62 @@ def plot_transcript_map(self): # Add grid lines ax.grid(True, axis='y', linestyle='--', alpha=0.3) - plt.tight_layout() + 
plt.tight_layout(rect=[0.05, 0, 0.9, 1]) # Give more space on left (0.05) and right (1-0.9=0.1) plot_path = os.path.join( self.gene_visualizations_dir, f"{gene_name_or_id}_splicing.pdf" # Changed from .png to .pdf ) plt.savefig(plot_path, bbox_inches='tight', dpi=300) plt.close(fig) - logging.debug(f"Saved transcript map for gene '{gene_name_or_id}' at: {plot_path}") + def plot_transcript_usage(self): - """Visualize transcript usage for each gene in gene_names across different conditions.""" + """Visualize transcript usage for each gene across conditions from pre-filtered data.""" if not self.gene_visualizations_dir: logging.warning("No gene_visualizations_dir provided. Skipping transcript usage plotting.") return - - # Add this section near the beginning of the method - logging.info("=== SPECIAL DEBUG FOR YBX1 TRANSCRIPTS ===") - for condition in self.updated_gene_dict: - for gene_id, gene_info in self.updated_gene_dict[condition].items(): - if gene_info.get("name") == "YBX1" or gene_id == "YBX1": - logging.info(f"Found YBX1 in condition {condition}") - logging.info(f"Total transcripts before filtering: {len(gene_info.get('transcripts', {}))}") - for transcript_id, transcript_info in gene_info.get('transcripts', {}).items(): - value = transcript_info.get('value', 0) - logging.info(f" Transcript {transcript_id}: TPM = {value:.2f}") - - # Check if reference and target conditions are defined - has_specific_conditions = (hasattr(self, 'ref_conditions') and hasattr(self, 'target_conditions') and - self.ref_conditions and self.target_conditions) - - if has_specific_conditions: - logging.debug(f"Filtering transcript usage plot to include only ref conditions: {self.ref_conditions} and target conditions: {self.target_conditions}") - # Define all allowed conditions - allowed_conditions = set(self.ref_conditions + self.target_conditions) + + # The input updated_gene_dict is assumed to be pre-filtered. 
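+        # Assumed layout of updated_gene_dict (sketch):
+        #   {condition: {gene_id: {"name": str,
+        #                          "transcripts": {transcript_id: {"value": float, "name": str, ...}}}}}
+        # Only the per-condition "transcripts" sub-dicts are consumed below.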
for gene_name_or_id in self.gene_names: # gene_names list contains gene names (symbols) - gene_data_per_condition = {} # Store gene data per condition + gene_data_per_condition = {} # Store gene transcript data per condition found_gene_any_condition = False # Flag if gene found in any condition - - # Only process allowed conditions if specific conditions are defined + + # Iterate directly through the pre-filtered dictionary for condition, genes in self.updated_gene_dict.items(): - # Skip conditions that aren't in ref or target if we have those defined - if has_specific_conditions and condition not in allowed_conditions: - #logging.debug(f"Skipping condition {condition} for gene {gene_name_or_id} (not in allowed conditions)") - continue - condition_gene_data = None for gene_id, gene_info in genes.items(): - if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): # Compare gene names (case-insensitive) - condition_gene_data = gene_info["transcripts"] # Only need transcripts for usage plot + # Compare gene names (case-insensitive) + if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): + condition_gene_data = gene_info.get("transcripts", {}) # Get transcripts, default to empty dict found_gene_any_condition = True - #logging.debug(f"Found gene {gene_name_or_id} in condition {condition}") + #logging.debug(f"Found gene {gene_name_or_id} data for condition {condition}") break # Found gene in this condition, break inner loop - if condition_gene_data: - gene_data_per_condition[condition] = condition_gene_data # Store transcripts for this condition - + if condition_gene_data is not None: # Store even if empty, to represent the condition + gene_data_per_condition[condition] = condition_gene_data + if not found_gene_any_condition: - logging.warning(f"Gene {gene_name_or_id} not found in any allowed condition.") + logging.warning(f"Gene {gene_name_or_id} not found in any condition within the pre-filtered updated_gene_dict.") continue # Skip to the next gene if not found - - # NEW APPROACH: If TPM threshold is set, collect max TPM for each transcript across all conditions - if self.min_tpm_threshold is not None: - # Create list of all allowed conditions - allowed_conditions = set(self.ref_conditions + self.target_conditions) if has_specific_conditions else set(gene_data_per_condition.keys()) - - # First, collect the max TPM for each transcript ONLY in allowed conditions - transcript_max_tpm = {} - - for condition, transcripts in gene_data_per_condition.items(): - # Skip conditions that aren't in ref or target if we have those defined - if has_specific_conditions and condition not in allowed_conditions: - continue - - for transcript_id, transcript_info in transcripts.items(): - value = float(transcript_info.get("value", 0)) - if transcript_id not in transcript_max_tpm: - transcript_max_tpm[transcript_id] = value - else: - transcript_max_tpm[transcript_id] = max(transcript_max_tpm[transcript_id], value) - - # Filter out transcripts that don\'t meet threshold in ANY allowed condition - valid_transcripts = {t_id: t_val for t_id, t_val in transcript_max_tpm.items() - if t_val >= self.min_tpm_threshold} - - # Now keep ALL instances of valid transcripts in allowed conditions, even if below threshold - filtered_gene_data_per_condition = {} - for condition, transcripts in gene_data_per_condition.items(): - # Skip conditions that aren\'t in ref or target if we have those defined - if has_specific_conditions and condition not in allowed_conditions: - continue - - 
filtered_transcripts = {} - for transcript_id, transcript_info in transcripts.items(): - # Include this transcript if it\'s in our valid list - if transcript_id in valid_transcripts: - filtered_transcripts[transcript_id] = transcript_info - - # Only include conditions with transcripts - if filtered_transcripts: - filtered_gene_data_per_condition[condition] = filtered_transcripts - - # Replace original data with filtered data - gene_data_per_condition = filtered_gene_data_per_condition - - # Log filtering results - total_unique_transcripts = len(transcript_max_tpm) - kept_transcripts = len(valid_transcripts) - logging.debug(f"TPM filtering for {gene_name_or_id}: {kept_transcripts} of {total_unique_transcripts} unique transcripts have TPM >= {self.min_tpm_threshold} in at least one ref/target condition") - + if not gene_data_per_condition: - logging.warning(f"No data available for gene {gene_name_or_id} after filtering. Skipping plot.") + logging.warning(f"No transcript data available for gene {gene_name_or_id} across conditions. Skipping plot.") continue + + # --- Reorder conditions: Reference first, then Target --- + all_conditions = list(gene_data_per_condition.keys()) + ref_conditions_present = sorted([c for c in all_conditions if c in self.ref_conditions]) + target_conditions_present = sorted([c for c in all_conditions if c in self.target_conditions]) + # Include any other conditions found in the data but not specified as ref/target (shouldn't happen with pre-filtering, but safe) + other_conditions_present = sorted([c for c in all_conditions if c not in self.ref_conditions and c not in self.target_conditions]) - conditions = list(gene_data_per_condition.keys()) + conditions = ref_conditions_present + target_conditions_present + other_conditions_present + # --- End Reordering --- + n_bars = len(conditions) if n_bars == 0: - logging.warning(f"No conditions found for gene {gene_name_or_id} after filtering.") + logging.warning(f"No conditions to plot for gene {gene_name_or_id}.") continue @@ -424,22 +273,39 @@ def plot_transcript_usage(self): index = np.arange(n_bars) bar_width = 0.35 opacity = 0.8 - max_transcripts = max(len(gene_data_per_condition[condition]) for condition in conditions) - colors = plt.cm.plasma(np.linspace(0, 1, num=max_transcripts)) + + # Determine unique transcripts across all plotted conditions for consistent coloring + all_transcript_ids = set() + for condition in conditions: + all_transcript_ids.update(gene_data_per_condition[condition].keys()) + unique_transcripts = sorted(list(all_transcript_ids)) + transcript_to_color_idx = {tid: idx for idx, tid in enumerate(unique_transcripts)} + colors = plt.cm.plasma(np.linspace(0, 1, num=len(unique_transcripts))) bottom_val = np.zeros(n_bars) + plotted_labels = set() # To avoid duplicate legend entries + for i, condition in enumerate(conditions): transcripts = gene_data_per_condition[condition] if not transcripts: # Skip if no transcript data for this condition continue - for j, (transcript_id, transcript_info) in enumerate(transcripts.items()): - color = colors[j % len(colors)] + # Sort transcripts for consistent stacking order (optional but good practice) + sorted_transcript_items = sorted(transcripts.items(), key=lambda item: item[0]) + + for transcript_id, transcript_info in sorted_transcript_items: + color_idx = transcript_to_color_idx.get(transcript_id, 0) # Fallback index 0 + color = colors[color_idx % len(colors)] value = transcript_info["value"] # Get transcript name with fallback options transcript_name = 
(transcript_info.get("name") or transcript_info.get("transcript_id") or transcript_id) + + label = transcript_name if transcript_name not in plotted_labels else "" + if label: + plotted_labels.add(label) + ax.bar( i, float(value), @@ -447,28 +313,44 @@ def plot_transcript_usage(self): bottom=bottom_val[i], alpha=opacity, color=color, - label=transcript_name if i == 0 else "", + label=label, ) bottom_val[i] += float(value) ax.set_xlabel("Condition") ax.set_ylabel("Transcript Usage (TPM)") - gene_display_name = list(gene_data_per_condition.values())[0].get("name", gene_name_or_id) # Fallback to ID if no name - - # Update title to include TPM threshold if applied - if self.min_tpm_threshold is not None: - ax.set_title(f"Transcript Usage for {gene_display_name} by Condition (TPM >= {self.min_tpm_threshold} in any ref/target condition)") + # Find a representative gene name (assuming transcripts exist in at least one condition) + first_condition_with_data = next((cond for cond in conditions if gene_data_per_condition[cond]), None) + gene_display_name = gene_name_or_id # Default to ID + if first_condition_with_data: + # Attempt to get gene name from the first transcript entry in the first condition with data + first_transcript_info = next(iter(gene_data_per_condition[first_condition_with_data].values()), None) + if first_transcript_info: + # Assuming gene name might be stored within transcript info, or fallback + # This part might need adjustment based on your actual data structure + # If gene name isn't in transcript info, you might need to fetch it differently + pass # Placeholder - logic to get gene name needs review based on structure + + # Updated title - Include threshold if available + if self.display_threshold is not None: + ax.set_title(f"Transcript Usage for {gene_display_name} by Condition (Input filtered at TPM >= {self.display_threshold})") else: ax.set_title(f"Transcript Usage for {gene_display_name} by Condition") ax.set_xticks(index) ax.set_xticklabels(conditions) - ax.legend( - title="Transcript IDs", - bbox_to_anchor=(1.05, 1), - loc="upper left", - fontsize=8, - ) + + # Update legend handling to use plotted_labels + handles, labels = ax.get_legend_handles_labels() + if handles: # Only show legend if there are items to show + ax.legend( + handles, + labels, + title="Transcript IDs", + bbox_to_anchor=(1.05, 1), + loc="upper left", + fontsize=8, + ) plt.tight_layout() plot_path = os.path.join( @@ -490,24 +372,26 @@ def make_pie_charts(self): titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] - # Check if reference and target conditions are defined - has_specific_conditions = hasattr(self, 'ref_conditions') and hasattr(self, 'target_conditions') + # Input data is assumed to be pre-filtered, so no need to check ref/target conditions here. + # Plot for all sample groups found in the data. for title, data in zip(titles, self.reads_and_class): if isinstance(data, dict): - if any(isinstance(v, dict) for v in data.values()): - # Separate sample data case (e.g. 'Mutants' and 'WildType') + # Check if the dictionary values are also dictionaries (indicating separate sample groups) + if data and isinstance(next(iter(data.values()), None), dict): + # Separate sample data case (e.g. 
{'Mutants': {...}, 'WildType': {...}}) + logging.debug(f"Creating separate pie charts for samples in '{title}'") for sample_name, sample_data in data.items(): - # Skip conditions that aren't in ref or target if we have those defined - if has_specific_conditions and sample_name not in self.ref_conditions and sample_name not in self.target_conditions: - logging.debug(f"Skipping pie chart for {sample_name} (not in ref/target conditions)") - continue + # No filtering needed here, plot for every sample found self._create_pie_chart(f"{title} - {sample_name}", sample_data) - else: - # Combined data case - always create this as it's an overall summary + elif data: # Check if data is not empty before proceeding + # Combined data case or single sample group provided directly + logging.debug(f"Creating combined pie chart for '{title}'") self._create_pie_chart(title, data) + else: + logging.warning(f"Empty data dictionary provided for pie chart '{title}'. Skipping.") else: - print(f"Skipping unexpected data type for {title}: {type(data)}") + logging.warning(f"Skipping unexpected data type for pie chart '{title}': {type(data)}") def _create_pie_chart(self, title, data): """ @@ -546,7 +430,7 @@ def _create_pie_chart(self, title, data): plot_path = os.path.join( self.read_assignments_dir, f"{file_title}_pie_chart.pdf" # Changed from .png to .pdf ) - plt.savefig(plot_path, bbox_inches="tight", dpi=300) + plt.savefig(plot_path, bbox_inches='tight', dpi=300) plt.close() def plot_novel_transcript_contribution(self): @@ -556,11 +440,13 @@ def plot_novel_transcript_contribution(self): - X-axis: Expression log2 fold change between conditions - Point size: Overall expression level - Color: Red (target) to Blue (reference) indicating which condition contributes more to novel transcript expression + Assumes input updated_gene_dict is already filtered appropriately. 
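+    Novelty is inferred from transcript IDs alone: any transcript whose ID does
+    not start with "ENST" is treated as novel (see the per-transcript loop below).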
""" logging.info("Creating novel transcript contribution plot") # Skip if we don't have reference vs target conditions defined - if not hasattr(self, 'ref_conditions') or not hasattr(self, 'target_conditions'): + if not (hasattr(self, 'ref_conditions') and self.ref_conditions and + hasattr(self, 'target_conditions') and self.target_conditions): logging.warning("Cannot create novel transcript plot: missing reference or target conditions") return @@ -568,104 +454,29 @@ def plot_novel_transcript_contribution(self): ref_label = "+".join(self.ref_conditions) target_label = "+".join(self.target_conditions) - # Set TPM threshold for transcript inclusion - min_tpm_threshold = 10 - # Track all unique genes across all conditions all_genes = {} # Dictionary to track gene_id -> gene_info mapping across conditions - # First, collect all genes from all conditions + # Collect all genes present in the (presumably pre-filtered) input dictionary for condition, genes in self.updated_gene_dict.items(): for gene_id, gene_info in genes.items(): gene_name = gene_info.get('name', gene_id) if gene_id not in all_genes: all_genes[gene_id] = {'name': gene_name, 'conditions': {}} - # Store condition-specific data - all_genes[gene_id]['conditions'][condition] = gene_info + # Store condition-specific data only if it's a ref or target condition + if condition in self.ref_conditions or condition in self.target_conditions: + all_genes[gene_id]['conditions'][condition] = gene_info - logging.info(f"Total unique genes found across all conditions: {len(all_genes)}") - - # First, let's investigate the discrepancy between our filtering and the GTF filtering - logging.info("Investigating transcript filtering discrepancy") - - # Track unique transcript IDs that pass our threshold (to match GTF filtering method) - unique_transcripts_above_threshold = set() - transcript_values = {} # To store max values for debugging - - # First pass: identify all unique transcripts with TPM >= threshold - for condition, genes in self.updated_gene_dict.items(): - # Check if this condition is in our ref or target groups - is_relevant_condition = condition in self.ref_conditions or condition in self.target_conditions - if not is_relevant_condition: - continue - - for gene_id, gene_info in genes.items(): - transcripts = gene_info.get('transcripts', {}) - for transcript_id, transcript_info in transcripts.items(): - value = float(transcript_info.get("value", 0)) - - # Track max value for this transcript across all conditions - if transcript_id not in transcript_values: - transcript_values[transcript_id] = value - else: - transcript_values[transcript_id] = max(transcript_values[transcript_id], value) - - # Check if transcript meets threshold - if value >= min_tpm_threshold: - unique_transcripts_above_threshold.add(transcript_id) - - logging.info(f"Filtering comparison: Found {len(unique_transcripts_above_threshold)} unique transcripts with TPM >= {min_tpm_threshold}") - logging.info(f"Filtering comparison: This compares to 8,230 transcripts reported by GTF filter") - - # Analyze distribution of TPM values to understand filtering - tpm_value_counts = { - "0-1": 0, - "1-5": 0, - "5-10": 0, - "10-20": 0, - "20-50": 0, - "50-100": 0, - "100+": 0 - } - - for transcript_id, max_value in transcript_values.items(): - if max_value < 1: - tpm_value_counts["0-1"] += 1 - elif max_value < 5: - tpm_value_counts["1-5"] += 1 - elif max_value < 10: - tpm_value_counts["5-10"] += 1 - elif max_value < 20: - tpm_value_counts["10-20"] += 1 - elif max_value < 50: - 
tpm_value_counts["20-50"] += 1 - elif max_value < 100: - tpm_value_counts["50-100"] += 1 - else: - tpm_value_counts["100+"] += 1 - - logging.info(f"TPM value distribution across transcripts: {tpm_value_counts}") - - # Check if there are any TTN transcripts in the unique set - ttn_transcripts = [t for t in unique_transcripts_above_threshold if "TTN" in t.upper()] - if ttn_transcripts: - logging.info(f"Found {len(ttn_transcripts)} TTN transcripts in high TPM set: {ttn_transcripts}") + logging.info(f"Total unique genes found across relevant conditions: {len(all_genes)}") # Prepare data storage for the main plot - plot_data = [] # Re-initialize plot_data here - - # Track transcripts that pass TPM threshold - total_transcripts = 0 - transcripts_above_threshold = 0 - total_genes = 0 - genes_with_high_expr_transcripts = 0 + plot_data = [] # Process each gene from all_genes for gene_id, gene_data in all_genes.items(): - total_genes += 1 gene_name = gene_data['name'] - conditions_data = gene_data['conditions'] + conditions_data = gene_data['conditions'] # Contains only ref/target conditions now # Calculate expression for each condition group ref_total_exp = {cond: 0 for cond in self.ref_conditions} @@ -673,53 +484,40 @@ def plot_novel_transcript_contribution(self): ref_novel_exp = {cond: 0 for cond in self.ref_conditions} target_novel_exp = {cond: 0 for cond in self.target_conditions} - # Track if this gene has any high-expression transcripts - gene_has_high_expr_transcript = False + gene_has_any_transcript = False # Check if the gene has any transcripts in ref/target - # Process each condition + # Process each relevant condition for the gene for condition, gene_info in conditions_data.items(): transcripts = gene_info.get('transcripts', {}) if not transcripts: continue + + gene_has_any_transcript = True # Mark that this gene has data - # Check if this condition is in our condition groups + # Check if this condition is in our condition groups (redundant check now, but safe) is_ref = condition in self.ref_conditions is_target = condition in self.target_conditions - if not (is_ref or is_target): - continue # Skip conditions that aren't in ref or target groups - for transcript_id, transcript_info in transcripts.items(): - total_transcripts += 1 - # Improved novel transcript identification - transcript is novel if not from Ensembl transcript_is_reference = transcript_id.startswith("ENST") is_novel = not transcript_is_reference value = float(transcript_info.get("value", 0)) - # Filter by TPM threshold - only count transcripts with TPM >= threshold - # We only check TPM threshold in ref and target conditions (not other conditions) - if value >= min_tpm_threshold: - transcripts_above_threshold += 1 - gene_has_high_expr_transcript = True - - if is_ref: - ref_total_exp[condition] += value - if is_novel: - ref_novel_exp[condition] += value - - if is_target: - target_total_exp[condition] += value - if is_novel: - target_novel_exp[condition] += value - elif gene_name == "TTN" and condition in self.ref_conditions: - pass # Add pass to avoid empty block + # REMOVED Filtering by TPM threshold - Now process all transcripts present + if is_ref: + ref_total_exp[condition] += value + if is_novel: + ref_novel_exp[condition] += value + + if is_target: + target_total_exp[condition] += value + if is_novel: + target_novel_exp[condition] += value - # Count genes with high-expression transcripts - if gene_has_high_expr_transcript: - genes_with_high_expr_transcripts += 1 - + # Only proceed if the gene had transcripts in the 
relevant conditions + if gene_has_any_transcript: # Calculate average expression for each condition group ref_novel_pct = 0 target_novel_pct = 0 @@ -729,71 +527,76 @@ def plot_novel_transcript_contribution(self): target_novel_expr_total = 0 # Sum up expression values across conditions + num_ref_conditions_with_expr = 0 for cond in self.ref_conditions: - ref_expr_total += ref_total_exp.get(cond, 0) - ref_novel_expr_total += ref_novel_exp.get(cond, 0) - - # Also calculate percentages per condition for color coding - if ref_total_exp.get(cond, 0) > 0: - ref_novel_pct += (ref_novel_exp.get(cond, 0) / ref_total_exp[cond]) * 100 + cond_total_exp = ref_total_exp.get(cond, 0) + cond_novel_exp = ref_novel_exp.get(cond, 0) + ref_expr_total += cond_total_exp + ref_novel_expr_total += cond_novel_exp + if cond_total_exp > 0: + ref_novel_pct += (cond_novel_exp / cond_total_exp) * 100 + num_ref_conditions_with_expr += 1 + num_target_conditions_with_expr = 0 for cond in self.target_conditions: - target_expr_total += target_total_exp.get(cond, 0) - target_novel_expr_total += target_novel_exp.get(cond, 0) - - # Also calculate percentages per condition for color coding - if target_total_exp.get(cond, 0) > 0: - target_novel_pct += (target_novel_exp.get(cond, 0) / target_total_exp[cond]) * 100 + cond_total_exp = target_total_exp.get(cond, 0) + cond_novel_exp = target_novel_exp.get(cond, 0) + target_expr_total += cond_total_exp + target_novel_expr_total += cond_novel_exp + if cond_total_exp > 0: + target_novel_pct += (cond_novel_exp / cond_total_exp) * 100 + num_target_conditions_with_expr += 1 # Average the condition-specific percentages (for color coding only) - ref_novel_pct /= len([c for c in self.ref_conditions if c in ref_total_exp and ref_total_exp[c] > 0]) or 1 - target_novel_pct /= len([c for c in self.target_conditions if c in target_total_exp and target_total_exp[c] > 0]) or 1 + ref_novel_pct /= num_ref_conditions_with_expr or 1 + target_novel_pct /= num_target_conditions_with_expr or 1 # Calculate overall novel percentage (for y-axis) combined_expr_total = ref_expr_total + target_expr_total combined_novel_expr_total = ref_novel_expr_total + target_novel_expr_total - # Calculate log2 fold change using the total expression values - if ref_expr_total > 0 and target_expr_total > 0: - log2fc = np.log2(target_expr_total / ref_expr_total) + # Check for non-zero total expression before calculating percentages and fold change + if combined_expr_total > 0: + # Calculate log2 fold change using the total expression values + # Add pseudocount to avoid division by zero or log(0) + pseudocount = 1e-6 # Small value to add + log2fc = np.log2((target_expr_total + pseudocount) / (ref_expr_total + pseudocount)) # Calculate novel transcript contribution difference (for color) novel_pct_diff = target_novel_pct - ref_novel_pct # Calculate overall novel percentage (for y-axis) - if combined_expr_total > 0: - overall_novel_pct = (combined_novel_expr_total / combined_expr_total) * 100 - - # Add data point - plot_data.append({ - 'gene_id': gene_id, - 'gene_name': gene_name, - 'ref_novel_pct': ref_novel_pct, - 'target_novel_pct': target_novel_pct, - 'novel_pct_diff': novel_pct_diff, - 'overall_novel_pct': overall_novel_pct, - 'log2fc': log2fc, - 'total_expr': combined_expr_total - }) + overall_novel_pct = (combined_novel_expr_total / combined_expr_total) * 100 + + # Add data point + plot_data.append({ + 'gene_id': gene_id, + 'gene_name': gene_name, + 'ref_novel_pct': ref_novel_pct, + 'target_novel_pct': target_novel_pct, + 
'novel_pct_diff': novel_pct_diff, + 'overall_novel_pct': overall_novel_pct, + 'log2fc': log2fc, + 'total_expr': combined_expr_total + }) - # Report filtering results - logging.info(f"TPM filtering: {transcripts_above_threshold} of {total_transcripts} transcripts have TPM >= {min_tpm_threshold} in ref or target conditions") - logging.info(f"TPM filtering: {genes_with_high_expr_transcripts} of {total_genes} genes have at least one transcript with TPM >= {min_tpm_threshold} in ref or target conditions") - # Create dataframe df = pd.DataFrame(plot_data) + + if df.empty: + logging.warning("No data available for novel transcript plot after processing.") # Adjusted warning + return + # Get the parent directory of gene_visualizations_dir parent_dir = os.path.dirname(self.gene_visualizations_dir) # Save the CSV to parent directory instead of gene_visualizations_dir - df.to_csv(os.path.join(parent_dir, "novel_transcript_expression_data.csv"), index=False) + csv_path = os.path.join(parent_dir, "novel_transcript_expression_data.csv") + df.to_csv(csv_path, index=False) + logging.info(f"Novel transcript expression data saved to {csv_path}") # Log the number of genes used in the plot - logging.info(f"Number of genes used in novel transcript plot after transcript-level TPM filtering: {len(df)}") - - if df.empty: - logging.warning("No data available for novel transcript plot after transcript-level TPM filtering") - return + logging.info(f"Number of genes included in novel transcript plot: {len(df)}") # Create the plot with more space on right for legend plt.figure(figsize=(16, 10)) # Increased width from 14 to 16 @@ -808,8 +611,14 @@ def plot_novel_transcript_contribution(self): # Use np.power for more dramatic scaling differences expression_values = df['total_expr'].values - max_expr = expression_values.max() - min_expr = expression_values.min() + # Handle case where expression_values might be empty or all zero + if len(expression_values) == 0 or expression_values.max() == expression_values.min(): + max_expr = 1 + min_expr = 0 + logging.warning("Cannot determine expression range for point scaling; using default [0, 1].") + else: + max_expr = expression_values.max() + min_expr = expression_values.min() # Log the actual min and max expression values for reference logging.debug(f"Expression range in data: min={min_expr}, max={max_expr}") @@ -818,10 +627,12 @@ def plot_novel_transcript_contribution(self): def scale_point_size(expr_value, min_expr, max_expr, min_size, max_size, power=0.3): """Scale expression values to point sizes using the same formula for data and legend""" # Normalize the expression value to [0,1] range - if max_expr == min_expr: # Avoid division by zero - normalized = 0 + if max_expr == min_expr: # Avoid division by zero or invalid range + normalized = 0.5 # Default to middle size if range is zero else: - normalized = (expr_value - min_expr) / (max_expr - min_expr) + # Clamp value to range before normalizing to handle potential outliers from pseudocounts + clamped_value = np.clip(expr_value, min_expr, max_expr) + normalized = (clamped_value - min_expr) / (max_expr - min_expr) # Apply power scaling and convert to point size return min_size + (max_size - min_size) * (normalized ** power) @@ -843,8 +654,8 @@ def scale_point_size(expr_value, min_expr, max_expr, min_size, max_size, power=0 cbar.ax.tick_params(labelsize=10) # Use red and blue blocks to explain the colormap - plt.figtext(0.92, 0.72, f'Blue = higher (%) in {self.ref_conditions}', fontsize=12, ha='center') - plt.figtext(0.92, 0.75, 
f'Red = higher (%) in {self.target_conditions}', fontsize=12, ha='center') + plt.figtext(0.92, 0.72, f'Blue = higher (%) in {ref_label}', fontsize=12, ha='center') + plt.figtext(0.92, 0.75, f'Red = higher (%) in {target_label}', fontsize=12, ha='center') # Add size legend directly to the plot # Create legend elements for different sizes with new values: 50, 500, 5000 @@ -854,9 +665,7 @@ def scale_point_size(expr_value, min_expr, max_expr, min_size, max_size, power=0 # Calculate sizes for legend using EXACTLY the same scaling function as for the data points for val in size_legend_values: # Use the same scaling function defined above - # If the value is outside the actual data range, clamp it to the range - clamped_val = min(max(val, min_expr), max_expr) - size = scale_point_size(clamped_val, min_expr, max_expr, min_size, max_size) + size = scale_point_size(val, min_expr, max_expr, min_size, max_size) # Log the actual size being used for the legend point logging.debug(f"Legend point {val} TPM scaled to size {size}") diff --git a/visualize.py b/visualize.py index 52d5b31e..f63b4ff9 100755 --- a/visualize.py +++ b/visualize.py @@ -26,7 +26,7 @@ def setup_logging(viz_output_dir: Path) -> None: # Console handler - less detailed console_handler = logging.StreamHandler() - console_handler.setLevel(logging.DEBUG) # Console output at DEBUG level + console_handler.setLevel(logging.INFO) # Console output at INFO level console_handler.setFormatter(console_formatter) # Configure root logger @@ -240,12 +240,11 @@ def main(): gene_list = None update_names = True - # 1. Build a dictionary that includes transcript expression - # and filters transcripts if they do NOT exceed args.filter_transcripts - # (defaulting to 1.0 if not provided) min_val = args.filter_transcripts if args.filter_transcripts is not None else 1.0 updated_gene_dict = dictionary_builder.build_gene_dict_with_expression_and_filter( - min_value=min_val + min_value=min_val, + reference_conditions=getattr(args, 'reference_conditions', None), + target_conditions=getattr(args, 'target_conditions', None) ) # 2. 
If read assignments are desired, build those as well (cached) @@ -282,16 +281,18 @@ def main(): dictionary_builder=dictionary_builder, ) gene_results, transcript_results, _, deseq2_df = diff_analysis.run_complete_analysis() - find_genes_list_path = gene_results.parent / "genes_of_top_100_DE_transcripts.txt" - gene_list = dictionary_builder.read_gene_list(find_genes_list_path) if args.gsea: gsea = GSEAAnalysis(output_path=base_dir) target_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" gsea.run_gsea_analysis(deseq2_df, target_label) - # Use genes from top transcripts instead of top genes - find_genes_list_path = gene_results.parent / "genes_of_top_100_DE_transcripts.txt" + # Construct the correct path to the top genes file dynamically + top_n = args.find_genes # Get the number used for top N + contrast_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" + top_genes_filename = f"genes_of_top_{top_n}_DE_transcripts_{contrast_label}.txt" + find_genes_list_path = gene_results.parent / top_genes_filename + logging.info(f"Reading gene list generated by differential analysis from: {find_genes_list_path}") gene_list = dictionary_builder.read_gene_list(find_genes_list_path) else: base_dir = viz_output_dir From 2afaa0b6d28dabc130b374bf5def866b4425f153 Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 18 Apr 2025 00:38:12 +0300 Subject: [PATCH 31/35] remove old man part --- README.md | 953 ------------------------------------------------------ 1 file changed, 953 deletions(-) diff --git a/README.md b/README.md index 16e6f438..ada73bbb 100644 --- a/README.md +++ b/README.md @@ -119,956 +119,3 @@ You can leave your comments and bug reports at our [GitHub repository tracker](h --data_type (assembly|pacbio|nanopore) -o OUTPUT_FOLDER * If multiple files are provided, IsoQuant will create a single output annotation and a single set of gene/transcript expression tables. - - -## Table of contents - -1. [About IsoQuant](#sec1)
-1.1. [Supported data types](#sec1.1)
-1.2. [Supported reference data](#sec1.2)
-2. [Installation](#sec2)
-2.1. [Installing from conda](#sec2.1)
-2.2. [Installation from GitHub](#sec2.2)
-2.3. [Verifying your installation](#sec2.3)
-3. [Running IsoQuant](#sec3)
-3.1. [IsoQuant input](#sec3.1)
-3.2. [Command line options](#sec3.2)
-3.3. [IsoQuant output](#sec3.3)
-4. [Visualization](#sec4)
-5. [Citation](#sec5)
-6. [Feedback and bug reports](#sec6)
- - -# About IsoQuant - -IsoQuant is a tool for the genome-based analysis of long RNA reads, such as PacBio or -Oxford Nanopores. IsoQuant allows to reconstruct and quantify transcript models with -high precision and decent recall. If the reference annotation is given, IsoQuant also -assigns reads to the annotated isoforms based on their intron and exon structure. -IsoQuant further performs annotated gene, isoform, exon and intron quantification. -If reads are grouped (e.g. according to cell type), counts are reported according to the provided grouping. - -IsoQuant consists of two stages, which generate its own output: -1. Reference-based analysis. Runs only if reference annotation is provided. Performs read-to-isoform assignment, -splice site correction and abundance quantification for reference genes/transcripts. -2. Transcript discovery. Reconstructs transcript models and performs abundance quantification for discovered isoforms. - -Latest IsoQuant version can be downloaded from [https://github.com/ablab/IsoQuant/releases/latest](https://github.com/ablab/IsoQuant/releases/latest). - -New IsoQuant documentation is available [here](https://ablab.github.io/IsoQuant/). - -#### IsoQuant pipeline -![Pipeline](docs/isoquant_pipeline.png) - - -## Supported data types - -IsoQuant support all kinds of long RNA data: -* PacBio CCS -* ONT dRNA / ONT cDNA -* Assembled / corrected transcript sequences - -Reads must be provided in FASTQ or FASTA format (can be gzipped). If you have already aligned your reads to the reference genome, simply provide sorted and indexed BAM files. - -IsoQuant expect reads to contain polyA tails. For more reliable transcript model construction do not trim polyA tails. - -IsoQuant can also take aligned Illumina reads to correct long-read spliced alignments. However, short reads are _not_ -used to discover transcript models or compute abundances. - - -## Supported reference data - -Reference genome should be provided in multi-FASTA format (can be gzipped). -Reference genome is mandatory even when BAM files are provided. - -Reference gene annotation is not mandatory, but is likely to increase precision and recall. -It can be provided in GFF/GTF format (can be gzipped). -In this case it will be converted to [gffutils](https://pythonhosted.org/gffutils/installation.html) database. Information on converted databases will be stored in your `~/.config/IsoQuant/db_config.json` to increase speed of future runs. You can also provide gffutils database manually. Make sure that chromosome/scaffold names are identical in FASTA file and gene annotation. -Note, that gffutils databases may not work correctly on NFS shares. It is possible to set a designated folder for -the database with `--genedb_output` (different from the output directory). - -Pre-constructed aligner index can also be provided to increase mapping time. - - -# Installation -IsoQuant requires a 64-bit Linux system or Mac OS and Python (3.8 and higher) to be pre-installed on it. 
-You will also need -* [gffutils](https://pythonhosted.org/gffutils/installation.html) -* [pysam](https://pysam.readthedocs.io/en/latest/index.html) -* [biopython](https://biopython.org/) -* [pybedtools](https://daler.github.io/pybedtools/) -* [pyfaidx](https://pypi.org/project/pyfaidx/) -* [pandas](https://pandas.pydata.org/) -* [pyyaml](https://pypi.org/project/PyYAML/) -* [minimap2](https://github.com/lh3/minimap2) -* [samtools](http://www.htslib.org/download/) -* [STAR](https://github.com/alexdobin/STAR) (optional) - - -## Installing from conda -IsoQuant can be installed with conda: -```bash -conda create -c conda-forge -c bioconda -n isoquant python=3.8 isoquant -``` - - -## Installing from GitHub -To obtain IsoQuant you can download repository and install requirements. -Clone IsoQuant repository and switch to the latest release: -```bash -git clone https://github.com/ablab/IsoQuant.git -cd IsoQuant -git checkout latest -``` -Install requirements: -```bash -pip install -r requirements.txt -``` - -You also need [samtools](http://www.htslib.org/download/) and [minimap2](https://github.com/lh3/minimap2) to be in the `$PATH` variable. - - -## Verifying your installation -To verify IsoQuant installation type -```bash -isoquant.py --test -``` -to run on toy dataset. -If the installation is successful, you will find the following information at the end of the log: -```bash -=== IsoQuant pipeline finished === -=== TEST PASSED CORRECTLY === -``` - - -# Running IsoQuant - -## IsoQuant input -To run IsoQuant, you should provide: -* Long RNA reads (PacBio or Oxford Nanopore) in one of the following formats: - * FASTA/FASTQ (can be gzipped); - * Sorted and indexed BAM; -* Reference sequence in FASTA format (can be gzipped); -* _Optionally_, you may provide a reference gene annotation in gffutils database or GTF/GFF format (can be gzipped). - -IsoQuant is also capable of using short Illumina reads to correct long-read alignments. - -IsoQuant can handle data from multiple _experiments_ simultaneously. Each experiment may contain multiple _samples_ (or _replicas_). -Each experiment is processed individually. Running IsoQuant on several experiments simultaneously -is equivalent to several separate IsoQuant runs. - -The output files for each experiment will be placed into a separate folder. -Files from the same _experiment_ are used to construct a single GTF and aggregated abundance tables. -If a single experiment contains multiple samples/replicas, per sample abundance tables are also generated. - -The ways of providing input files are described below. - - -### Specifying input data via command line - -Two main options are `--fastq` and `--bam` (see description below). Both options accept one or multiple files separated by space. -All provided files are treated as a single experiment, which means a single combined GTF will -be generated. If multiple files are provided, IsoQuant will compute tables with each column -corresponding to an individual file (per-sample counts). -To set a specific label for each sample use the `--label` option. Number of labels must be equal to the number of files. -To a set a prefix for the output files use the `--prefix` option. - -This pipeline is typical for the cases when a user is -interested in comparing expression between different replicas/conditions within the same experiment. - -#### Short reads for alignment correction - -A BAM file with Illumina reads can be provided via `--illumina_bam`. 
It cannot be the only input, but may only be used with either `--bam` or `--fastq`. -The option accepts one or multiple bam files separated by space. All files will be combined and used to correct offsets between introns in long and short reads as well as skipped exons. - - -### Specifying input data via yaml file - -To provide all input files in a single description file, you can use a [YAML](https://www.redhat.com/en/topics/automation/what-is-yaml) file via `--yaml` (see description below). -You can provide multiple experiments in a single YAML file with each experiment containing an arbitrary number of smaples/replicas. -A distinct output folder with individual GTFs and abundance tables will be generated for each experiment. -In this option, BAM files with short reads for correction can be provided for each experiment. - -The YAML file contains a list of experiments (e.g. in square brackets). -The first entry in the list should be the type of files the experiments contain, written as `data format: ` -followed by the type in quotation marks. The type can be either `fastq` or `bam`. - -Each experiment is represented as set of parameters (e.g. in curly brackets). -Each experiment must have a name and a list of long-read files in the specified format. -Additionally, it may contain one or multiple BAM files with short reads. -The name is provided as `name: ` followed by the experiment name in quotation marks. -Both short and long read files are provided as a list of file paths in quotation marks, -following `long read files: ` and `illumina bam: ` respectively. -Labels for the files can also be set with `labels: `. -The number of labels needs to be the same as the number of files with long reads. -All paths should be either absolute or relative to the YAML file. - -For example: - -``` -[ - data format: "fastq", - { - name: "Experiment1", - long read files: [ - "/PATH/TO/FILE1.fastq", - "/PATH/TO/FILE2.fastq" - ], - labels: [ - "Sample1", - "Sample2" - ], - illumina bam: ["PATH/TO/ILLUMINA1.bam"] - }, - { - name: "Experiment2", - long read files: [ - "/PATH/TO/FILE3.fastq" - ], - illumina bam: ["PATH/TO/ILLUMINA2.bam"] - } -] - -``` - - -Output sub-folders will be named `Experiment1` and `Experiment2`. -Both sub-folders will contain predicted transcript models and abundance tables. -Abundance table for `Experiment2` with have columns "Sample1" and "Sample2". - -Note, that `--bam`, `--fastq` and `--label` options are not compatible with `--yaml`. -See more in [examples](#examples). - - - -## IsoQuant command line options - - -### Basic options -`--output` (or `-o`) - Output folder, will be created automatically. - -Note: if your output folder is located on a shared disk, use `--genedb_output` for storing -reference annotation database. - -`--help` (or `-h`) - Prints help message. - -`--full_help` - Prints all available options (including hidden ones). - -`--test` - Runs IsoQuant on the toy data set. - - -### Input options -`--data_type` or `-d` - Type of data to process, supported values are: `pacbio_ccs` (same as `pacbio`), `nanopore` (same as `ont`) -and `assembly` (same as `transcripts`). This option affects the algorithm parameters. - -Note, that for novel mono-exonic transcripts are not reported for ONT data by default, use `--report_novel_unspliced true`. - -`--reference` or `-r` - Reference genome in FASTA format (can be gzipped), required even when BAM files are provided. 
- -`--index` - Reference genome index for the specified aligner (`minimap2` by default), -can be provided only when raw reads are used as an input (constructed automatically if not set). - -`--genedb` or `-g` - Gene database in gffutils database format or GTF/GFF format (can be gzipped). -If you use official gene annotations we recommend to set `--complete_genedb` option. - -`--complete_genedb` - Set this flag if gene annotation contains transcript and gene meta-features. -Use this flag when providing official annotations, e.g. GENCODE. -This option will set `disable_infer_transcripts` and `disable_infer_genes` gffutils options, -which dramatically speeds up gene database conversion (see more [here](https://daler.github.io/gffutils/autodocs/gffutils.create.create_db.html)). - -#### Providing input reads via command line option: - -`--fastq` - Input FASTQ/FASTA file(s), can be gzipped; a single GTF will be generated for all files. If multiple files are provided, -expression tables with "per-file" columns will be computed. See more about [input data](#sec3.1). - - -`--bam` - Sorted and indexed BAM file(s); a single GTF will be generated for all files. If multiple files are provided, -expression tables with "per-file" columns will be computed. See more about [input data](#sec3.1). - - -#### Providing input reads via YAML configuration file: - -`--yaml` - Path to dataset description file in [YAML](https://www.redhat.com/en/topics/automation/what-is-yaml) format. The file should contain a list with `data format` property, -which can be `fastq` or `bam` and an individual entry for experiment. -Each experiment is represented as set of parameters (e.g. in curly brackets): -- `name` - experiment name, string (optional); -- `long read files` - a list of paths to long read files matching the specified format; -- `lables` - a list labels for long read files for expression table (optional, must be equal to the number of long read files) -- `illumina bam` - a list of paths to short read BAM files for splice site correction (optional). - -All paths should be either absolute or relative to the YAML file. -See more in [examples](#examples). - -#### Providing input reads via dataset description file (deprecated since 3.4) - -`--bam_list` (_deprecated since 3.4_) - Text file with list of BAM files, one file per line. Each file must be sorted and indexed. -Leave empty line or experiment name starting with # between the experiments. -For each experiment IsoQuant will generate a individual GTF and count tables. -You may also give a label for each file specifying it after a colon (e.g. `/PATH/TO/file.bam:replicate1`). - -`--fastq_list` (_deprecated since 3.4_) - Text file with list of FASTQ/FASTA files (can be gzipped), one file per line. -Leave empty line or experiment name starting with # between the experiments. -For each experiment IsoQuant will generate a individual GTF and count tables. -You may also give a label for each file specifying it after a colon (e.g. `/PATH/TO/file.fastq:replicate1`). - -#### Other input options: -`--stranded` - Reads strandness type, supported values are: `forward`, `reverse`, `none`. - -`--fl_data` - Input sequences represent full-length transcripts; both ends of the sequence are considered to be reliable. - -`--prefix` or `-p` - Prefix for all output files and sub-folder name. `OUT` if not set. - -`--labels` or `-l` - Sets space-separated sample names. Make sure that the number of labels is equal to the number of files. -Input file names are used as labels if not set. 
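-For instance, a hypothetical two-file run with custom labels might look like this
-(file and label names are illustrative):
-```bash
-isoquant.py -d nanopore --fastq rep1.fastq.gz rep2.fastq.gz \
-  --reference reference.fasta --genedb annotation.gtf --complete_genedb \
-  --labels Replicate1 Replicate2 --output output_dir
-```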
-
-`--read_group` - Sets a way to group feature counts (e.g. by cell type). Available options are:
- * `file_name`: groups reads by their original file names (or file name labels) within an experiment.
-This option makes sense when multiple files are provided.
-This option is designed for obtaining expression tables with a separate column for each file.
-If multiple BAM/FASTQ files are provided and the `--read_group` option is not set, IsoQuant will set `--read_group:file_name`
-by default.
- * `tag`: groups reads by BAM file read tag: set `tag:TAG`, where `TAG` is the desired tag name
-(e.g. `tag:RG` will use `RG` values as groups; `RG` will be used if unset);
- * `read_id`: groups reads by read name suffix: set `read_id:DELIM` where `DELIM` is the
-symbol/string by which the read id will be split
-(e.g. if `DELIM` is `_`, for read `m54158_180727_042959_59310706_ccs_NEU` the group will be set to `NEU`);
- * `file`: uses an additional file with group information for every read: `file:FILE:READ_COL:GROUP_COL:DELIM`,
-where `FILE` is the file name, `READ_COL` is the column with read ids (0 if not set),
-`GROUP_COL` is the column with group ids (1 if not set),
-`DELIM` is the separator symbol (tab if not set). The file can be gzipped.
-
-
-### Output options
-
-`--sqanti_output` - Produce comparison between novel and known transcripts in SQANTI-like format.
- Will take effect only when reference annotation is provided.
-
-`--check_canonical` - Report whether a read or constructed transcript model contains non-canonical splice junctions (requires more time).
-
-`--count_exons` - Perform exon and intron counting in addition to gene and transcript counting.
- Will take effect only when reference annotation is provided.
-
-`--bam_tags` - Comma-separated list of BAM tags that will be imported into `read_assignments.tsv`.
-
-### Pipeline options
-
-`--resume` - Resume a previously unfinished run. Output folder with the previous run must be specified.
- Allowed options are `--threads` and `--debug`; other options cannot be changed.
- IsoQuant will run from the beginning if the output folder does not contain a previous run.
-
-`--force` - Force overwriting of the folder with a previous run.
-
-`--threads` or `-t` - Number of threads to use, 16 by default.
-
-`--clean_start` - Do not use previously generated gene database, genome indices or BAM files; run the pipeline from the very beginning (will take more time).
-
-`--no_model_construction` - Do not report transcript models; run read assignment and quantification of reference features only.
-
-`--run_aligner_only` - Align reads to the reference without running IsoQuant itself.
-
-
-### Algorithm parameters
-
-
-#### Quantification
-
-`--transcript_quantification` Transcript quantification strategy;
-`--gene_quantification` Gene quantification strategy;
-
-Available options for quantification:
-
-* `unique_only` - use only reads that are uniquely assigned and consistent with a transcript/gene
-(i.e. flagged as unique/unique_minor_difference), default for transcript quantification;
-* `with_ambiguous` - in addition to unique reads, ambiguously assigned consistent reads are split between features with equal weights
-(e.g. 1/2 when a read is assigned to 2 features simultaneously);
-* `unique_splicing_consistent` - uses uniquely assigned reads that do not contradict annotated splice sites
-(i.e. flagged as unique/unique_minor_difference or inconsistent_non_intronic), default for gene quantification;
-* `unique_inconsistent` - uses uniquely assigned reads allowing any kind of inconsistency;
-* `all` - all of the above.
-
-
-#### Read to isoform matching:
-
-`--matching_strategy` A preset of parameters for the read-to-isoform matching algorithm, should be one of:
-
-* `exact` - delta = 0, all minor errors are treated as inconsistencies;
-* `precise` - delta = 4, only minor alignment errors are allowed, default for PacBio data;
-* `default` - delta = 6, alignment errors typical for Nanopore reads are allowed, short novel introns are treated as deletions;
-* `loose` - delta = 12, even more serious inconsistencies are ignored, ambiguity is resolved based on nucleotide similarity.
-
-The matching strategy is chosen automatically based on the specified data type.
-However, the parameters will be overridden if the matching strategy is set manually.
-
-#### Read alignment correction:
-
-`--splice_correction_strategy` A preset of parameters for read alignment correction algorithms, should be one of:
-
-* `none` - no correction is applied;
-* `default_pacbio` - optimal settings for PacBio CCS reads;
-* `default_ont` - optimal settings for ONT reads;
-* `conservative_ont` - conservative settings for ONT reads, only incorrect splice junctions and skipped exons are fixed;
-* `assembly` - optimal settings for a transcriptome assembly;
-* `all` - correct all discovered minor inconsistencies, may result in overcorrection.
-
-This option is chosen automatically based on the specified data type, but will be overridden if set manually.
-
-#### Transcript model construction:
-`--model_construction_strategy` A preset of parameters for the transcript model construction algorithm, should be one of
-
-* `reliable` - only the most abundant and reliable transcripts are reported; precise, but not sensitive;
-* `default_pacbio` - optimal settings for PacBio CCS reads;
-* `sensitive_pacbio` - sensitive settings for PacBio CCS reads, more transcripts are reported possibly at a cost of precision;
-* `fl_pacbio` - optimal settings for full-length PacBio CCS reads, will be used if `--data_type pacbio_ccs` and `--fl_data` options are set;
-* `default_ont` - optimal settings for ONT reads, novel mono-exonic transcripts are not reported (use `--report_novel_unspliced true`);
-* `sensitive_ont` - sensitive settings for ONT reads, more transcripts are reported possibly at a cost of precision (including novel mono-exonic isoforms);
-* `assembly` - optimal settings for a transcriptome assembly: input sequences are considered to be reliable and each transcript to be represented only once, so abundance is not considered;
-* `all` - reports almost all novel transcripts, loses precision in favor of recall.
-
-This option is chosen automatically based on the specified data type, but will be overridden if set manually.
-
-
-`--report_novel_unspliced` Report novel mono-exonic transcripts (set `true` or `false`).
-The default value is `false` for Nanopore data and `true` for other data types.
-The main reason is that some aligners report a lot of false unspliced alignments
-for ONT reads.
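-For example, a sketch of enabling novel mono-exonic transcript reporting for ONT data
-(file names are illustrative):
-```bash
-isoquant.py -d nanopore --fastq ONT.fastq.gz --reference reference.fasta \
-  --report_novel_unspliced true --output output_dir
-```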
-
-
-`--report_canonical` - Strategy for reporting novel transcripts based on canonical splice sites, should be one of:
-
-* `auto` - automatic selection based on the data type and model construction strategy (default);
-* `only_canonical` - report novel transcripts, which contain only canonical splice sites;
-* `only_stranded` - report novel transcripts, for which the strand can be unambiguously derived using splice sites and
-presence of a polyA tail, allowing some splice sites to be non-canonical;
-* `all` - report all transcript models regardless of their splice sites.
-
-
-`--polya_requirement` Strategy for using polyA tails during transcript model construction, should be one of:
-
-* `auto` - default behaviour: polyA tails are required if at least 70% of the reads have a polyA tail;
-polyA tails are always required for 1/2-exon transcripts when using ONT data (this is caused by an elevated number of false 1/2-exonic alignments reported by minimap2);
-* `never` - polyA tails are never required; use this option **at your own risk** as it may noticeably increase the false discovery rate, especially for ONT data;
-* `always` - reported transcripts are always required to have polyA support in the reads.
-
-Note that polyA tails are always required for reporting novel unspliced isoforms.
-
-
-
-### Hidden options
-
-The options below are shown only with the `--full_help` option.
-We recommend _not_ modifying these options unless you are clearly aware of their effect.
-
-`--no_gzip` - Do not compress large output files.
-
-`--no_gtf_check` - Do not perform input GTF checks.
-
-`--no_secondary` - Ignore secondary alignments.
-
-`--aligner` - Force the use of this alignment method, can be `starlong` or `minimap2`; `minimap2` is currently used as default. Make sure the specified aligner is in the `$PATH` variable.
-
-`--no_junc_bed` - Do not use gene annotation for read mapping.
-
-`--junc_bed_file` - Annotation in BED12 format produced by `paftools.js gff2bed` (can be found in `minimap2`), will be created automatically if not given.
-
-`--delta` - Delta for inexact splice junction comparison, chosen automatically based on data type (e.g. 4bp for PacBio, 6bp for ONT).
-
-`--genedb_output` - If your output folder is located on shared storage (e.g. an NFS share), use this option to set another path
- for storing the annotation database, because an SQLite database cannot be created on a shared disk.
- The folder will be created automatically.
-
-`--high_memory` - Cache read alignments instead of making several passes over a BAM file; noticeably increases RAM usage,
-but may improve running time when disk I/O is relatively slow.
-
-`--min_mapq` - Filters out all alignments with MAPQ less than this value (will also filter all secondary alignments, as they typically have MAPQ = 0).
-
-`--inconsistent_mapq_cutoff` - Filters out inconsistent alignments with MAPQ less than this value (works when the reference annotation is provided, default is 5).
-
-`--simple_alignments_mapq_cutoff` - Filters out alignments with 1 or 2 exons and MAPQ less than this value (works only in annotation-free mode, default is 1).
-
-`--normalization_method` - Method for normalizing non-grouped counts into TPMs:
-* `simple` - standard method, scale factor equals 1 million divided by the counts sum (default);
-* `usable_reads` - includes all reads assigned to a feature, including the ones that were filtered out
-during quantification (i.e. inconsistent or ambiguous);
-the scale factor equals 1 million divided by the number of all assigned reads.
-In this case the sum of all gene/transcript TPMs may not add up to 1 million.
-Experiments with simulated data show that this method could give more accurate estimations.
-However, the normalization method does not affect correlation/relative proportions.
-
-
-### Examples
-
-
-* Mapped PacBio CCS reads in BAM format; pre-converted gene annotation:
-
-```bash
-isoquant.py -d pacbio_ccs --bam mapped_reads.bam \
- --genedb annotation.db --output output_dir
-```
-
-* Nanopore dRNA stranded reads; official annotation in GTF format, use a custom prefix for output:
-```bash
-isoquant.py -d nanopore --stranded forward --fastq ONT.raw.fastq.gz \
- --reference reference.fasta --genedb annotation.gtf --complete_genedb \
- --output output_dir --prefix My_ONT
-```
-
-* Nanopore cDNA reads; no reference annotation:
-```bash
-isoquant.py -d nanopore --fastq ONT.cDNA.raw.fastq.gz \
- --reference reference.fasta --output output_dir --prefix My_ONT_cDNA
-```
-
-* PacBio FL reads; custom annotation in GTF format, which contains only exon features:
-```bash
-isoquant.py -d pacbio_ccs --fl_data --fastq CCS.fastq \
- --reference reference.fasta --genedb genes.gtf --output output_dir
-```
-
-* Nanopore cDNA reads, multiple samples/replicas within a single experiment; official annotation in GTF format:
-```bash
-isoquant.py -d nanopore --bam ONT.cDNA_1.bam ONT.cDNA_2.bam ONT.cDNA_3.bam \
- --reference reference.fasta --genedb annotation.gtf --complete_genedb --output output_dir \
- --prefix ONT_3samples --labels A1 A2 A3
-```
-
-* ONT cDNA reads; 2 experiments with 3 replicates; official annotation in GTF format:
-```bash
-isoquant.py -d nanopore --yaml dataset.yaml \
- --complete_genedb --genedb genes.gtf \
- --reference reference.fasta --output output_dir
-```
-
-dataset.yaml file:
-
-```
-[
-  data format: "fastq",
-  {
-    name: "Experiment1",
-    long read files: [
-      "/PATH/TO/SAMPLE1/file1.fastq",
-      "/PATH/TO/SAMPLE1/file2.fastq",
-      "/PATH/TO/SAMPLE1/file3.fastq"
-    ],
-    labels: [
-      "Replicate1",
-      "Replicate2",
-      "Replicate3"
-    ]
-  },
-  {
-    name: "Experiment2",
-    long read files: [
-      "/PATH/TO/SAMPLE2/file1.fastq",
-      "/PATH/TO/SAMPLE2/file2.fastq",
-      "/PATH/TO/SAMPLE2/file3.fastq"
-    ],
-    labels: [
-      "Replicate1",
-      "Replicate2",
-      "Replicate3"
-    ]
-  }
-]
-
-```
-
-
-IsoQuant will produce 2 sets of resulting files (including annotations and expression tables), one for each experiment.
-Output sub-folders will be named `Experiment1` and `Experiment2`.
-Expression tables will have columns "Replicate1", "Replicate2" and "Replicate3".
-
-
-* ONT cDNA reads; 1 experiment with 2 replicates, each replicate has 2 files; official annotation in GTF format:
-```bash
-isoquant.py -d nanopore --yaml dataset.yaml \
- --complete_genedb --genedb genes.gtf \
- --reference reference.fasta --prefix MY_SAMPLE \
- --output output_dir
-```
-
-dataset.yaml file:
-
-
-```
-[
-  data format: "fastq",
-  {
-    name: "Experiment1",
-    long read files: [
-      "/PATH/TO/SAMPLE1/file1.fastq",
-      "/PATH/TO/SAMPLE1/file2.fastq",
-      "/PATH/TO/SAMPLE1/file3.fastq",
-      "/PATH/TO/SAMPLE1/file4.fastq"
-    ],
-    labels: [
-      "Replicate1",
-      "Replicate1",
-      "Replicate2",
-      "Replicate2"
-    ]
-  }
-]
-
-```
-
-
-IsoQuant will produce one output sub-folder `Experiment1`.
-Expression tables will have columns "Replicate1" and "Replicate2".
-Files having identical labels will be treated as a single replica (and thus the counts will be combined).
-
-
-
-## IsoQuant output
-
-### Output files
-
-IsoQuant output files will be stored in ``, which is set by the user.
-If the output directory was not specified the files are stored in `isoquant_output`. - -IsoQuant consists of two stages, which generate its own output: -1. Reference-based analysis. Runs only if reference annotation is provided. Performs read-to-isofrom assignment, -splice site correction and abundance quantification for reference genes/transcripts. -2. Transcript discovery. Reconstructs transcript models and performs abundance quantification for discovered isoforms. - -#### Reference-based analysis output - -_Will be produced only if a reference gene annotation is provided._ - -* `SAMPLE_ID.read_assignments.tsv.gz` - TSV file with read to isoform assignments (gzipped by default); -* `SAMPLE_ID.corrected_reads.bed.gz` - BED file with corrected read alignments (gzipped by default); -* `SAMPLE_ID.transcript_tpm.tsv` - TSV file with reference transcript expression in TPM; -* `SAMPLE_ID.transcript_counts.tsv` - TSV file with raw read counts for reference transcript; -* `SAMPLE_ID.gene_tpm.tsv` - TSV file with reference gene expression in TPM; -* `SAMPLE_ID.gene_counts.tsv` - TSV file with raw read counts for reference genes; - -If `--sqanti_output` is set, IsoQuant will produce output in [SQANTI](https://github.com/ConesaLab/SQANTI3)-like format: -* `SAMPLE_ID.novel_vs_known.SQANTI-like.tsv` - discovered novel transcripts vs reference transcripts (similar, but not identical to SQANTI `classification.txt`); - -If `--count_exons` is set, exon and intron counts will be produced: -* `SAMPLE_ID.exon_counts.tsv` - reference exon inclusion/exclusion read counts; -* `SAMPLE_ID.intron_counts.tsv` - reference intron inclusion/exclusion read counts; - -If `--read_group` is set, the per-group expression values for reference features will be also computed: - -In matrix format (feature X groups) -* `SAMPLE_ID.gene_grouped_tpm.tsv` -* `SAMPLE_ID.transcript_grouped_tpm.tsv` -* `SAMPLE_ID.gene_grouped_counts.tsv` -* `SAMPLE_ID.transcript_grouped_counts.tsv` - -In linear format (feature, group, value(s) per each line) -* `SAMPLE_ID.gene_grouped_counts_linear.tsv` -* `SAMPLE_ID.transcript_grouped_counts_linear.tsv` -* `SAMPLE_ID.exon_grouped_counts.tsv` -* `SAMPLE_ID.intron_grouped_counts.tsv` - -#### Transcript discovery output - -_Will not be produced if `--no_model_construction` is set._ - -File names typically contain `transcript_model` in their name. - -* `SAMPLE_ID.transcript_models.gtf` - GTF file with discovered expressed transcript (both known and novel transcripts); -* `SAMPLE_ID.transcript_model_reads.tsv.gz` - TSV file indicating which reads contributed to transcript models (gzipped by default); -* `SAMPLE_ID.transcript_model_tpm.tsv` - expression of discovered transcripts models in TPM (corresponds to `SAMPLE_ID.transcript_models.gtf`); -* `SAMPLE_ID.transcript_model_counts.tsv` - raw read counts for discovered transcript models (corresponds to `SAMPLE_ID.transcript_models.gtf`); -* `SAMPLE_ID.extended_annotation.gtf` - GTF file with the entire reference annotation plus all discovered novel transcripts; - - -If `--read_group` is set, the per-group counts for discovered transcripts will be also computed: -* `SAMPLE_ID.transcript_model_grouped_counts.tsv` -* `SAMPLE_ID.transcript_model_grouped_tpm.tsv` - - -If multiple experiments are provided, aggregated expression matrices will be placed in ``: -* `combined_gene_counts.tsv` -* `combined_gene_tpm.tsv` -* `combined_transcript_counts.tsv` -* `combined_transcript_tpm.tsv` - -Additionally, a log file will be saved to the directory. 
-* /isoquant.log
-
-If raw reads were provided, BAM file(s) will be stored in `//aux/`.
-In case the `--keep_tmp` option was specified, this directory will also contain temporary files.
-
-### Output file formats
-
-Although most output files include headers that describe the data, a brief explanation of the output files is provided below.
-
-#### Read to isoform assignment
-
-Tab-separated values, the columns are:
-
-* `read_id` - read id;
-* `chr` - chromosome id;
-* `strand` - strand of the assigned isoform (not to be confused with read mapping strand);
-* `isoform_id` - isoform id to which the read was assigned;
-* `gene_id` - gene id to which the read was assigned;
-* `assignment_type` - assignment type, can be:
-  - `unique` - read was unambiguously assigned to a single known isoform;
-  - `unique_minor_difference` - read was assigned uniquely but has alignment artifacts;
-  - `inconsistent` - read was matched with inconsistencies, closest match(es) are reported;
-  - `inconsistent_non_intronic` - read was matched with inconsistencies, which do not affect the intron chain (e.g. only TSS/TES);
-  - `inconsistent_ambiguous` - read was matched with inconsistencies equally well to two or more isoforms;
-  - `ambiguous` - read was assigned to multiple isoforms equally well;
-  - `noninformative` - read is intronic or has an insignificant overlap with a known gene;
-  - `intergenic` - read is intergenic.
-* `assignment_events` - list of detected inconsistencies; for each assigned isoform a list of detected inconsistencies relative to the respective isoform is stored; values in each list are separated by the `+` symbol, lists are separated by commas, the number of lists equals the number of assigned isoforms; possible events are (see graphical representation below):
-  - consistent events:
-    - `none` / `.` / `undefined` - no special event detected;
-    - `mono_exon_match` - mono-exonic read matched to mono-exonic transcript;
-    - `fsm` - full splice match;
-    - `ism_5/3` - incomplete splice match, truncated on 5'/3' side;
-    - `ism_internal` - incomplete splice match, truncated on both sides;
-    - `mono_exonic` - mono-exonic read matching spliced isoform;
-    - `tss_match` / `tss_match_precise` - 5' read end is located less than 50 / `delta` bases from the TSS of the assigned isoform;
-    - `tes_match` / `tes_match_precise` - 3' read end is located less than 50 / `delta` bases from the TES of the assigned isoform (can be reported without detecting polyA sites);
-  - alignment artifacts:
-    - `intron_shift` - intron that seems to be shifted due to misalignment (typical for Nanopores);
-    - `exon_misalignment` - short exon that seems to be missed due to misalignment (typical for Nanopores);
-    - `fake_terminal_exon_5/3` - short terminal exon at 5'/3' end that looks like an alignment artifact (typical for Nanopores);
-    - `terminal_exon_misalignment_5/3` - missed reference short terminal exon;
-    - `exon_elongation_5/3` - minor exon extension at 5'/3' end (not exceeding 30bp);
-    - `fake_micro_intron_retention` - short annotated introns are often missed by the aligners and thus are not considered as intron retention;
-  - intron retentions:
-    - `intron_retention` - intron retention;
-    - `unspliced_intron_retention` - intron retention by mono-exonic read;
-    - `incomplete_intron_retention_5/3` - terminal exon at 5'/3' end partially covers adjacent intron;
-  - significant inconsistencies (each type ends with `_known` if _all_ resulting read introns are annotated and `_novel` otherwise):
-    - `major_exon_elongation_5/3` - significant exon extension at 5'/3' end (exceeding 30bp);
-    - `extra_intron_5/3` - additional intron on the 5'/3' end of the isoform;
-    - `extra_intron` - read contains an additional intron in the middle of an exon;
-    - `alt_donor_site` - read contains an alternative donor site;
-    - `alt_acceptor_site` - read contains an alternative annotated acceptor site;
-    - `intron_migration` - read contains an alternative annotated intron of approximately the same length as in the isoform;
-    - `intron_alternation` - read contains an alternative intron, which doesn't fall into any of the categories above;
-    - `mutually_exclusive_exons` - read contains different exon(s) of the same total length compared to the isoform;
-    - `exon_skipping` - read skips exon(s) compared to the isoform;
-    - `exon_merge` - read skips exon(s) compared to the isoform, but a sequence of a similar length is attached to a neighboring exon;
-    - `exon_gain` - read contains additional exon(s) compared to the isoform;
-    - `exon_detach` - read contains additional exon(s) compared to the isoform, but a neighboring exon loses a sequence of a similar length;
-    - `terminal_exon_shift` - read has an alternative terminal exon;
-    - `alternative_structure` - read has a different intron chain that does not fall into any of the categories above;
-  - alternative transcription start / end (reported when poly-A tails are present):
-    - `alternative_polya_site` - read has an alternative polyadenylation site;
-    - `internal_polya_site` - poly-A tail detected but seems to have originated from an A-rich intronic region;
-    - `correct_polya_site` - poly-A site matches the reference transcript end;
-    - `aligned_polya_tail` - poly-A tail aligns to the reference;
-    - `alternative_tss` - alternative transcription start site.
-* `exons` - list of coordinates for normalized read exons (1-based, indels and polyA exons are excluded);
-* `additional` - field for supplementary information, which may include:
-  - `gene_assignment` - gene assignment classification; possible values are the same as for transcript classification;
-  - `PolyA` - True if a poly-A tail is detected;
-  - `Canonical` - True if all read introns are canonical; Unspliced is used for mono-exon reads (use `--check_canonical`);
-  - `Classification` - SQANTI-like assignment classification.
-
-Note that a single read may occur more than once if assigned ambiguously.
-
-#### Expression table format
-
-Tab-separated values, the columns are:
-
-* `feature_id` - genomic feature ID;
-* `TPM` or `count` - expression value (float).
-
-For grouped counts, each column contains expression values of a respective group (matrix representation).
-
-Besides the count matrix, transcript and gene grouped counts are also printed in a linear format,
-in which each line contains 3 tab-separated values:
-
-* `feature_id` - genomic feature ID;
-* `group_id` - group name;
-* `count` - read count of the feature in this group.
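-As a minimal sketch (assuming pandas is available; the file name and header handling
-are illustrative and may need adjusting to the actual output), the linear format can be
-pivoted back into the matrix representation:
-```python
-import pandas as pd
-
-# Read the linear grouped counts: one (feature_id, group_id, count) triple per line.
-linear = pd.read_csv(
-    "SAMPLE_ID.transcript_grouped_counts_linear.tsv",
-    sep="\t",
-    names=["feature_id", "group_id", "count"],
-    comment="#",  # skip a commented header line, if present
-)
-
-# Pivot into the matrix representation (features x groups);
-# groups with no reads for a feature are filled with zero.
-matrix = linear.pivot_table(
-    index="feature_id", columns="group_id", values="count", fill_value=0
-)
-print(matrix.head())
-```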
- -#### Exon and intron count format - -Tab-separated values, the columns are: - -* `chr` - chromosome ID; -* `start` - feature leftmost 1-based positions; -* `end` - feature rightmost 1-based positions; -* `strand` - feature strand; -* `flags` - symbolic feature flags, can contain the following characters: - - `X` - terminal feature; - - `I` - internal feature; - - `T` - feature appears as both terminal and internal in different isoforms; - - `S` - feature has similar positions to some other feature; - - `C` - feature is contained in another feature; - - `U` - unique feature, appears only in a single known isoform; - - `M` - feature appears in multiple different genes. -* `gene_ids` - list if gene ids feature belong to; -* `group_id` - read group if provided (NA by default); -* `include_counts` - number of reads that include this feature; -* `exclude_counts` - number of reads that span, but do not include this feature; - -#### Transcript models format - -Constructed transcript models are stored in usual [GTF format](https://www.ensembl.org/info/website/upload/gff.html). -Contains `exon`, `transcript` and `gene` features. - -Known genes and transcripts are reposted with their reference IDs. -Novel genes IDs have format `novel_gene_XXX_###` and novel transcript IDs are formatted as `transcript###.XXX.TYPE`, -where `###` is the unique number (not necessarily consecutive), `XXX` is the chromosome name and TYPE can be one of the following: - -* nic - novel in catalog, new transcript that contains only annotated introns; -* nnic - novel not in catalog, new transcript that contains unannotated introns. - -Each exon also has a unique ID stored in `exon_id` attribute. - -In addition, each transcript contains `canonical` property if `--check_canonical` is set. - -If `--sqanti_output` option is set, each novel transcript also has a `similar_reference_id` field containing ID of -a most similar reference isoform and `alternatives` attribute, which indicates the exact differences between -this novel transcript and the similar reference transcript. - -### Event classification figures -#### Consistent match classifications -![Correct](docs/correct_match.png)

- -#### Misalignment classifications -![Misalignment](docs/misalignment.png)

- -#### Inconsistency classifications -![Inconsistent](docs/inconsistent.png)

- -#### PolyA classifications -![PolyA](docs/polya.png) - - - -## Visualization - -IsoQuant provides a visualization tool to help interpret and explore the output data. The goal of this visualization is to create informative plots that represent transcript usage and splicing patterns for genes of interest. Additionally, we provide global transcript and read assignment statistics from the IsoQuant analysis. - -### Running the visualization tool - -To run the visualization tool, use the following command: - -```bash - -python visualize.py --gene_list [options] - -``` - -### Command line options - -* `output_directory` (required): Directory containing IsoQuant output files. -* `--gene_list` (required): Path to a .txt file containing a list of genes, each on its own line. -* `--viz_output`: Optional directory to save visualization output files. Defaults to the main output directory if not specified. -* `--gtf`: Optional path to a GTF file if it cannot be extracted from the IsoQuant log. -* `--counts`: Use counts instead of TPM files for visualization. -* `--ref_only`: Use only reference transcript quantification instead of transcript model quantification. -* `--filter_transcripts`: Filter transcripts by the minimum value occurring in at least one condition. - - -### Output - -The visualization tool generates the following plots based on the IsoQuant output: - -1. Transcript usage profiles: For each gene specified in the gene list, a plot showing the relative usage of different transcripts across conditions or samples. - -2. Gene-specific transcript maps: Visual representation of the different splicing patterns of transcripts for each gene, allowing easy comparison of exon usage and alternative splicing events. - -3. Global read assignment consistency: A summary plot showing the overall consistency of read assignments across all genes and transcripts analyzed. - -4. Global transcript alignment classifications: A chart or plot representing the distribution of different transcript alignment categories (e.g., full splice match, incomplete splice match, novel isoforms) across the entire dataset. - -These visualizations provide valuable insights into transcript diversity, splicing patterns, and the overall quality of the IsoQuant analysis.
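For example, assuming `visualize.py` takes the IsoQuant output directory as its first positional argument, and given a gene list file named `my_genes.txt` (both names are placeholders), a typical invocation would be:

```bash
python visualize.py isoquant_output --gene_list my_genes.txt --counts --filter_transcripts 5
```

Here `--counts` switches the plots from TPM to raw counts, and `--filter_transcripts 5` hides transcripts that never reach a value of 5 in any condition.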
## Citation -The paper describing IsoQuant algorithms and benchmarking is available at [10.1038/s41587-022-01565-y](https://doi.org/10.1038/s41587-022-01565-y). - -To try IsoQuant you can use the data that was used in the publication [zenodo.org/record/7611877](https://zenodo.org/record/7611877). - -## Feedback and bug reports -Your comments, bug reports, and suggestions are very welcome. They will help us to further improve IsoQuant. If you have any trouble running IsoQuant, please send us `isoquant.log` from the `` directory. - -You can leave your comments and bug reports at our [GitHub repository tracker](https://github.com/ablab/IsoQuant/issues) or send them via email: isoquant.rna@gmail.com. From 66cd110c4d326c6c581726cec9c40fe881c0a897 Mon Sep 17 00:00:00 2001 From: Andrey Prjibelski Date: Fri, 18 Apr 2025 00:38:33 +0300 Subject: [PATCH 32/35] update requirements --- install_r_packages.py | 2 ++ requirements.txt | 3 ++- visualize.py | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/install_r_packages.py b/install_r_packages.py index 7461369b..f346bb60 100644 --- a/install_r_packages.py +++ b/install_r_packages.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + import rpy2.robjects.packages as rpackages from rpy2.robjects.vectors import StrVector diff --git a/requirements.txt b/requirements.txt index 6ac32a8d..3bd76857 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,8 +8,9 @@ pyfaidx>=0.7 pyyaml>=5.4 matplotlib>=3.1.3 numpy>=1.18.1 -scipy>=1.4.1 +scipy>=1.10.0 seaborn>=0.10.0 +scikit-learn>=1.5 rpy2>=3.5.1 mygene>=3.2.0 diff --git a/visualize.py b/visualize.py index f4b04fd8..5fbea98b 100755 --- a/visualize.py +++ b/visualize.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + import argparse import sys import logging From ead79a0f9027ca8c9c90941eb3025af4e0e86c48 Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Sun, 4 May 2025 12:11:29 -0500 Subject: [PATCH 33/35] cleaned up exon bounds --- src/visualization_plotter.py | 53 ++++++++++++++++++------------------ visualize.py | 10 +++---- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py index 91615a1a..6aa07036 100644 --- a/src/visualization_plotter.py +++ b/src/visualization_plotter.py @@ -92,11 +92,17 @@ def plot_transcript_map(self): start = gene_data.get("start", 0) end = gene_data.get("end", 0) - # Calculate buffer (5% of total width) - width = end - start - buffer = width * 0.05 - plot_start = start - buffer - plot_end = end + buffer + # Find the actual
min/max coordinates of all exons + min_exon_start = min(exon["start"] for transcript in gene_data["transcripts"].values() + for exon in transcript["exons"]) + max_exon_end = max(exon["end"] for transcript in gene_data["transcripts"].values() + for exon in transcript["exons"]) + + # Calculate buffer (10% of total width) + width = max(end, max_exon_end) - min(start, min_exon_start) + buffer = width * 0.10 # Increased from 5% to 10% + plot_start = min(start, min_exon_start) - buffer + plot_end = max(end, max_exon_end) + buffer # REMOVED FILTERING LOGIC - Directly use transcripts from gene_data filtered_transcripts = gene_data["transcripts"] @@ -144,8 +150,12 @@ def plot_transcript_map(self): linewidth=2, ) + # Sort exons based on strand direction + exons = sorted(transcript_info["exons"], + key=lambda x: x["start"] if gene_data["strand"] == "+" else -x["start"]) + # Exon blocks with color based on reference status - for exon in transcript_info["exons"]: + for exon_idx, exon in enumerate(exons, 1): exon_length = exon["end"] - exon["start"] if self.ref_only: # Check ref_only flag exon_color = "skyblue" # If ref_only, always treat as reference @@ -155,25 +165,16 @@ def plot_transcript_map(self): exon_color = "skyblue" if is_reference_exon else "red" exon_alpha = 1.0 if is_reference_exon else 0.6 - ax.add_patch( - plt.Rectangle( - (exon["start"], i - 0.4), - exon_length, - 0.8, - color=exon_color, - alpha=exon_alpha - ) + # Add exon rectangle + rect = plt.Rectangle( + (exon["start"], i - 0.4), + exon_length, + 0.8, + color=exon_color, + alpha=exon_alpha ) - - - if not any(exon["exon_id"].startswith("ENSE") for exon in transcript_info["exons"]): - logging.debug(f"Transcript {transcript_id} in gene {gene_name_or_id} contains NO reference exons (based on ENSEMBL IDs)") - #log the exon_ids - #logging.debug(f"Exon IDs: {[exon['exon_id'] for exon in transcript_info['exons']]}") - else: - #logging.debug(f"Transcript {transcript_id} in gene {gene_name_or_id} contains at least one reference exon (based on ENSEMBL IDs)") - pass # Added explicit pass statement for the empty block - + ax.add_patch(rect) + # Store y-axis label information y_ticks.append(i) # Get transcript name with fallback options @@ -725,7 +726,7 @@ def create_volcano_plot( reference_label: str, padj_threshold: float = 0.05, lfc_threshold: float = 1, - top_n: int = 10, + top_n: int = 60, # Increased from 10 to 60 feature_type: str = "genes", ) -> None: """Create volcano plot from differential expression results.""" @@ -797,7 +798,7 @@ def create_volcano_plot( plt.tight_layout() plot_path = ( - self.output_path / f"volcano_plot_{feature_type}.pdf" # Changed from .png to .pdf + self.output_path / f"volcano_plot_{feature_type}.pdf" ) plt.savefig(str(plot_path)) plt.close() diff --git a/visualize.py b/visualize.py index f63b4ff9..3ac8ce5e 100755 --- a/visualize.py +++ b/visualize.py @@ -309,12 +309,12 @@ def main(): # 6.
Plotting with PlotOutput plot_output = PlotOutput( - updated_gene_dict, - gene_list, - str(gene_visualizations_dir), - read_assignments_dir=str(read_assignments_dir), # Pass None if not used + updated_gene_dict=updated_gene_dict, + gene_names=gene_list, + gene_visualizations_dir=str(gene_visualizations_dir), + read_assignments_dir=str(read_assignments_dir) if read_assignments_dir else None, reads_and_class=reads_and_class, - filter_transcripts=min_val, # Just pass your chosen threshold for reference + filter_transcripts=min_val, conditions=output.conditions, ref_only=args.ref_only, ref_conditions=args.reference_conditions if hasattr(args, "reference_conditions") else None, From b276b12a5da7ef653cab333503ffcc1c8b6b2bfe Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Thu, 31 Jul 2025 14:41:50 -0500 Subject: [PATCH 34/35] First cut at no bio reps --- docs/visualization.md | 28 ++- src/visualization_differential_exp.py | 238 +++++++++++++++++- src/visualization_gsea.py | 188 ++++++++------ src/visualization_output_config.py | 293 ++++++++++++++++++++-- src/visualization_plotter.py | 10 +- src/visualization_simple_ranker.py | 348 ++++++++++++++++++++++++++ visualize.py | 168 ++++++++++--- 7 files changed, 1125 insertions(+), 148 deletions(-) create mode 100644 src/visualization_simple_ranker.py diff --git a/docs/visualization.md b/docs/visualization.md index 5bf1cd82..550b3395 100644 --- a/docs/visualization.md +++ b/docs/visualization.md @@ -4,35 +4,41 @@ IsoQuant provides a visualization tool to help interpret and explore the output ## Running the visualization tool -To run the visualization tool, use the following command: +To run the visualization tool, use one of the following commands: ```bash - +# Visualize a predefined list of genes python visualize.py --gene_list [options] +# Automatically find the top N most differentially expressed genes +python visualize.py --find_genes [N] [options] ``` ## Command line options * `output_directory` (required): Directory containing IsoQuant output files. -* * `--gene_list` (required): Path to a .txt file containing a list of genes, each on its own line. -* `--viz_output`: Optional directory to save visualization output files. Defaults to the main output directory if not specified. +* `--gene_list`: Path to a .txt file containing a list of genes, each on its own line. Mutually exclusive with `--find_genes`. +* `--find_genes [N]`: Automatically select the top **N** genes with the highest combined differential-expression rank between chosen conditions (default 100 if *N* is omitted). +* `--viz_output`: Optional directory to save visualization output files. Defaults to `/visualization`. * `--gtf`: Optional path to a GTF file if it cannot be extracted from the IsoQuant log. -* `--counts`: Use counts instead of TPM files for visualization. * `--ref_only`: Use only reference transcript quantification instead of transcript model quantification. -* `--filter_transcripts`: Filter transcripts by minimum value occurring in at least one condition. +* `--filter_transcripts `: Minimum expression value a transcript must reach in at least one condition to be included in plots (default 1.0). +* `--gsea`: Perform Gene Set Enrichment Analysis on differential expression results (requires `--find_genes`). +* `--technical_replicates`: Specify technical replicate groupings as a file (`sample,group`) or inline (`sample1:group1,sample2:group1`). 
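For example, a technical replicates file that merges two sequencing runs per sample (sample and group names are hypothetical) could look like this:

```
# sample,group
sampleA_run1,sampleA
sampleA_run2,sampleA
sampleB_run1,sampleB
sampleB_run2,sampleB
```

The equivalent inline form is `--technical_replicates sampleA_run1:sampleA,sampleA_run2:sampleA,sampleB_run1:sampleB,sampleB_run2:sampleB`. Counts of runs mapped to the same group are summed before differential expression analysis.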
## Output -The visualization tool generates the following plots based on the IsoQuant output: +The visualization tool can generate the following outputs: -1. Transcript usage profiles: For each gene specified in the gene list, a plot showing the relative usage of different transcripts across conditions or samples. +1. Transcript usage profiles: For each gene, a plot showing the relative usage of different transcripts across conditions or samples. 2. Gene-specific transcript maps: Visual representation of the different splicing patterns of transcripts for each gene, allowing easy comparison of exon usage and alternative splicing events. -3. Global read assignment consistency: A summary plot showing the overall consistency of read assignments across all genes and transcripts analyzed. +3. Global read assignment consistency: A summary plot showing the overall consistency of read assignments across all genes and transcripts analyzed (enabled interactively). + +4. Global transcript alignment classifications: A chart representing the distribution of different transcript alignment categories (e.g., full splice match, incomplete splice match, novel isoforms) across the entire dataset. -4. Global transcript alignment classifications: A chart or plot representing the distribution of different transcript alignment categories (e.g., full splice match, incomplete splice match, novel isoforms) across the entire dataset. +5. Differential expression tables and volcano plots when `--find_genes` is used, with optional GSEA pathway visualizations if `--gsea` is supplied. -These visualizations provide valuable insights into transcript diversity, splicing patterns, and the overall quality of the IsoQuant analysis. +These visualizations and reports provide valuable insights into transcript diversity, splicing patterns, differential expression, and the overall quality of the IsoQuant analysis. 
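For instance, a sketch of a run that ranks the 50 most differentially expressed genes and performs GSEA on the results (directory names are placeholders, and the IsoQuant output directory is assumed to be the first positional argument):

```bash
python visualize.py isoquant_output --find_genes 50 --gsea --viz_output viz_results
```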
diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index 28c6e3dc..63653cb3 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -28,6 +28,7 @@ def __init__( top_transcripts_base_mean: int = 500, top_n_genes: int = 100, log_level: int = logging.INFO, # Allow configuring log level + tech_rep_dict: Dict[str, str] = None, ): """Initialize differential expression analysis.""" def quiet_cb(x): @@ -75,6 +76,7 @@ def quiet_cb(x): self.transcript_to_gene = self._create_transcript_to_gene_map() self.visualizer = ExpressionVisualizer(self.deseq_dir) self.gene_mapper = GeneMapper() + self.tech_rep_dict = tech_rep_dict def _load_transcript_mapping_from_file(self): """Load transcript mapping directly from the transcript_mapping.tsv file.""" @@ -415,6 +417,9 @@ def _get_merged_transcript_counts(self, pattern: str) -> pd.DataFrame: combined_df = pd.concat(all_sample_dfs, axis=1) self.logger.info(f"Combined count data shape before mapping: {combined_df.shape}") + # Apply technical replicate merging before transcript mapping + combined_df = self._merge_technical_replicates(combined_df) + # Apply transcript mapping if available if not hasattr(self, 'transcript_map') or not self.transcript_map: self.logger.info("No transcript mapping available, using raw counts") @@ -513,6 +518,10 @@ def _get_condition_data(self, pattern: str) -> pd.DataFrame: # Combine all sample dataframes combined_df = pd.concat(all_sample_dfs, axis=1) self.logger.info(f"Combined gene count data shape: {combined_df.shape}") + + # Apply technical replicate merging + combined_df = self._merge_technical_replicates(combined_df) + return combined_df else: self.logger.error(f"Unsupported count pattern: {pattern}") @@ -684,6 +693,31 @@ def _run_deseq2( res_df = robjects.conversion.rpy2py(r("as.data.frame")(res)) # Convert to R dataframe first for stability res_df.index = count_data.index # Assign original feature IDs as index + # Extract dispersion estimates + self.logger.debug("Extracting dispersion estimates...") + dispersions_r = r['dispersions'](dds) + dispersions_py = robjects.conversion.rpy2py(dispersions_r) + + # Add dispersion estimates to results DataFrame + res_df['dispersion'] = dispersions_py + + # Extract size factors + self.logger.debug("Extracting size factors...") + size_factors_r = r['sizeFactors'](dds) + size_factors_py = robjects.conversion.rpy2py(size_factors_r) + + # Create size factors DataFrame with sample names + size_factors_df = pd.DataFrame({ + 'sample': count_data.columns, + 'size_factor': size_factors_py + }) + + # Save size factors to file + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + size_factors_file = self.deseq_dir / f"size_factors_{level}_{target_label}_vs_{reference_label}.csv" + size_factors_df.to_csv(size_factors_file, index=False) + self.logger.info(f"Size factors saved to {size_factors_file}") # Correct way to call the R 'counts' function on the dds object # Ensure 'r' is imported: from rpy2.robjects import r @@ -694,6 +728,8 @@ def _run_deseq2( # Ensure DataFrame structure matches original count_data (features x samples) normalized_counts_df = pd.DataFrame(normalized_counts_py, index=count_data.index, columns=count_data.columns) + # Generate dispersion and count summaries + self._generate_dispersion_summary(res_df, level) self.logger.info(f"DESeq2 run completed for {level}. 
Results shape: {res_df.shape}, Normalized counts shape: {normalized_counts_df.shape}") return res_df, normalized_counts_df @@ -703,6 +739,135 @@ def _run_deseq2( # Return empty DataFrames on error to avoid downstream issues return pd.DataFrame(), pd.DataFrame(index=count_data.index, columns=count_data.columns) + def _generate_dispersion_summary(self, results_df: pd.DataFrame, level: str) -> None: + """ + Generate summary statistics for average read counts and dispersion estimates. + Saves summary to a file and logs key statistics. + + Args: + results_df: DESeq2 results DataFrame with baseMean and dispersion columns + level: Analysis level (gene/transcript) + """ + if results_df.empty: + self.logger.warning(f"Cannot generate dispersion summary for {level}: Results DataFrame is empty.") + return + + self.logger.info(f"Generating dispersion and count summary for {level} level...") + + # Check if required columns exist + required_cols = ['baseMean', 'dispersion'] + missing_cols = [col for col in required_cols if col not in results_df.columns] + if missing_cols: + self.logger.warning(f"Cannot generate complete summary for {level}: Missing columns {missing_cols}") + return + + # Remove NaN values for summary statistics + clean_data = results_df[['baseMean', 'dispersion']].dropna() + + if clean_data.empty: + self.logger.warning(f"No valid data for dispersion summary for {level} after removing NaN values.") + return + + # Calculate summary statistics + summary_stats = { + 'level': level, + 'total_features': len(results_df), + 'features_with_valid_data': len(clean_data), + + # Average read count (baseMean) statistics + 'baseMean_mean': clean_data['baseMean'].mean(), + 'baseMean_median': clean_data['baseMean'].median(), + 'baseMean_std': clean_data['baseMean'].std(), + 'baseMean_min': clean_data['baseMean'].min(), + 'baseMean_max': clean_data['baseMean'].max(), + 'baseMean_q25': clean_data['baseMean'].quantile(0.25), + 'baseMean_q75': clean_data['baseMean'].quantile(0.75), + + # Dispersion statistics + 'dispersion_mean': clean_data['dispersion'].mean(), + 'dispersion_median': clean_data['dispersion'].median(), + 'dispersion_std': clean_data['dispersion'].std(), + 'dispersion_min': clean_data['dispersion'].min(), + 'dispersion_max': clean_data['dispersion'].max(), + 'dispersion_q25': clean_data['dispersion'].quantile(0.25), + 'dispersion_q75': clean_data['dispersion'].quantile(0.75), + } + + # Add size factor statistics if available + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + size_factors_file = self.deseq_dir / f"size_factors_{level}_{target_label}_vs_{reference_label}.csv" + + if size_factors_file.exists(): + try: + size_factors_df = pd.read_csv(size_factors_file) + if 'size_factor' in size_factors_df.columns: + sf_data = size_factors_df['size_factor'].dropna() + summary_stats.update({ + 'size_factor_mean': sf_data.mean(), + 'size_factor_median': sf_data.median(), + 'size_factor_std': sf_data.std(), + 'size_factor_min': sf_data.min(), + 'size_factor_max': sf_data.max(), + 'size_factor_q25': sf_data.quantile(0.25), + 'size_factor_q75': sf_data.quantile(0.75), + }) + self.logger.info(f" Size factors: mean={summary_stats['size_factor_mean']:.4f}, median={summary_stats['size_factor_median']:.4f}, range={summary_stats['size_factor_min']:.4f}-{summary_stats['size_factor_max']:.4f}") + except Exception as e: + self.logger.warning(f"Could not read size factors file: {e}") + + # Log key statistics + self.logger.info(f"{level.capitalize()} level 
summary:") + self.logger.info(f" Total features: {summary_stats['total_features']}") + self.logger.info(f" Features with valid data: {summary_stats['features_with_valid_data']}") + self.logger.info(f" Average read count (baseMean): mean={summary_stats['baseMean_mean']:.2f}, median={summary_stats['baseMean_median']:.2f}") + self.logger.info(f" Dispersion: mean={summary_stats['dispersion_mean']:.4f}, median={summary_stats['dispersion_median']:.4f}") + + # Create a more detailed summary for significant DE genes/transcripts + if 'padj' in results_df.columns: + significant_features = results_df[results_df['padj'] < 0.05].dropna(subset=['baseMean', 'dispersion']) + if not significant_features.empty: + summary_stats.update({ + 'significant_features_count': len(significant_features), + 'significant_baseMean_mean': significant_features['baseMean'].mean(), + 'significant_baseMean_median': significant_features['baseMean'].median(), + 'significant_dispersion_mean': significant_features['dispersion'].mean(), + 'significant_dispersion_median': significant_features['dispersion'].median(), + }) + + self.logger.info(f" Significant DE features (padj < 0.05): {summary_stats['significant_features_count']}") + self.logger.info(f" Significant features - Average read count: mean={summary_stats['significant_baseMean_mean']:.2f}, median={summary_stats['significant_baseMean_median']:.2f}") + self.logger.info(f" Significant features - Dispersion: mean={summary_stats['significant_dispersion_mean']:.4f}, median={summary_stats['significant_dispersion_median']:.4f}") + + # Save summary to file + summary_file = self.deseq_dir / f"dispersion_count_summary_{level}_{target_label}_vs_{reference_label}.txt" + + with open(summary_file, 'w') as f: + f.write(f"Dispersion and Count Summary for {level.capitalize()} Level Analysis\n") + f.write(f"Comparison: {target_label} vs {reference_label}\n") + f.write("=" * 60 + "\n\n") + + for key, value in summary_stats.items(): + if isinstance(value, float): + f.write(f"{key}: {value:.6f}\n") + else: + f.write(f"{key}: {value}\n") + + self.logger.info(f"Dispersion and count summary saved to {summary_file}") + + # Also save detailed data for further analysis + detailed_file = self.deseq_dir / f"detailed_dispersion_data_{level}_{target_label}_vs_{reference_label}.csv" + + # Include feature mapping information if available + detailed_data = results_df[['baseMean', 'dispersion', 'log2FoldChange', 'pvalue', 'padj']].copy() + if 'gene_name' in results_df.columns: + detailed_data['gene_name'] = results_df['gene_name'] + if 'transcript_symbol' in results_df.columns: + detailed_data['transcript_symbol'] = results_df['transcript_symbol'] + + detailed_data.to_csv(detailed_file) + self.logger.info(f"Detailed dispersion data saved to {detailed_file}") + def _map_gene_symbols(self, feature_ids: List[str], level: str) -> Dict[str, Dict[str, Optional[str]]]: """ Map feature IDs to gene and transcript names using GeneMapper class. @@ -914,4 +1079,75 @@ def _run_pca(self, normalized_counts, level, coldata, target_label, reference_la self.logger.info(f"PCA plots saved for {level} level.") except Exception as e: - self.logger.error(f"Error during PCA calculation or plotting for {level}: {str(e)}") \ No newline at end of file + self.logger.error(f"Error during PCA calculation or plotting for {level}: {str(e)}") + + def _merge_technical_replicates(self, count_data: pd.DataFrame) -> pd.DataFrame: + """ + Merge technical replicates by summing counts for samples in the same replicate group. 
+ + Args: + count_data: DataFrame with samples as columns and features as rows + + Returns: + DataFrame with technical replicates merged + """ + if not self.tech_rep_dict: + self.logger.info("No technical replicates specified, returning original data") + return count_data + + self.logger.info(f"Merging technical replicates using {len(self.tech_rep_dict)} mappings") + + # Create a mapping from sample columns to replicate groups + sample_to_group = {} + for col in count_data.columns: + # Extract the base sample name (remove condition prefix if present) + base_sample = col + for condition in self.ref_conditions + self.target_conditions: + if col.startswith(f"{condition}_"): + base_sample = col[len(condition)+1:] + break + + # Check if this sample is in the technical replicates mapping + if base_sample in self.tech_rep_dict: + group_name = self.tech_rep_dict[base_sample] + # Reconstruct the group name with condition prefix + condition_prefix = col.replace(base_sample, "").rstrip("_") + if condition_prefix: + full_group_name = f"{condition_prefix}_{group_name}" + else: + full_group_name = group_name + sample_to_group[col] = full_group_name + else: + # Keep original sample name if not in technical replicates + sample_to_group[col] = col + + # Group samples by their replicate groups + group_to_samples = {} + for sample, group in sample_to_group.items(): + if group not in group_to_samples: + group_to_samples[group] = [] + group_to_samples[group].append(sample) + + # Create merged DataFrame + merged_data = pd.DataFrame(index=count_data.index) + + merge_stats = {"merged_groups": 0, "original_samples": len(count_data.columns)} + + for group_name, samples in group_to_samples.items(): + if len(samples) == 1: + # No merging needed, just rename + merged_data[group_name] = count_data[samples[0]] + else: + # Sum technical replicates + merged_data[group_name] = count_data[samples].sum(axis=1) + merge_stats["merged_groups"] += 1 + self.logger.debug(f"Merged technical replicates for group {group_name}: {samples}") + + merge_stats["final_samples"] = len(merged_data.columns) + self.logger.info( + f"Technical replicate merging complete: " + f"{merge_stats['original_samples']} samples -> {merge_stats['final_samples']} samples " + f"({merge_stats['merged_groups']} groups had multiple replicates)" + ) + + return merged_data \ No newline at end of file diff --git a/src/visualization_gsea.py b/src/visualization_gsea.py index 4047c43e..a49a0e76 100644 --- a/src/visualization_gsea.py +++ b/src/visualization_gsea.py @@ -87,84 +87,132 @@ def run_gsea_analysis(self, results: pd.DataFrame, target_label: str) -> None: with localconverter(robjects.default_converter + pandas2ri.converter): r_ranked_genes = pandas2ri.py2rpy(ranked_genes.sort_values(ascending=False)) - def plot_pathways(df: pd.DataFrame, direction: str, ont: str): - if df.empty: - logging.info(f"No {direction} pathways to plot.") + def plot_pathways(up_df: pd.DataFrame, down_df: pd.DataFrame, ont: str): + if up_df.empty and down_df.empty: + logging.info(f"No pathways to plot for {ont}.") return - - df["label"] = df["ID"] + ": " + df["Description"] - df["-log10(p.adjust)"] = -np.log10(df["p.adjust"]) + + # Process DataFrame if not empty + if not up_df.empty: + # Remove GO IDs from labels, keeping only the description + up_df["label"] = up_df["Description"] + up_df["-log10(p.adjust)"] = -np.log10(up_df["p.adjust"]) + up_df = up_df.sort_values(by="NES", ascending=False) - # Sort by NES value - for up-regulated, highest NES first; for down-regulated, lowest NES 
first - if direction == "up": - df = df.sort_values(by="NES", ascending=False) - plot_values = df["NES"] # Use NES directly for up-regulated - else: # down - df = df.sort_values(by="NES", ascending=True) - plot_values = df["NES"].abs() # Use absolute NES for down-regulated + if not down_df.empty: + # Remove GO IDs from labels, keeping only the description + down_df["label"] = down_df["Description"] + down_df["-log10(p.adjust)"] = -np.log10(down_df["p.adjust"]) + down_df = down_df.sort_values(by="NES", ascending=True) - # Use NES for bar length but -log10(p.adjust) for color - values = df["-log10(p.adjust)"] + # Find the global min and max for -log10(p.adjust) for consistent coloring + all_pvals = [] + if not up_df.empty: + all_pvals.extend(up_df["-log10(p.adjust)"].tolist()) + if not down_df.empty: + all_pvals.extend(down_df["-log10(p.adjust)"].tolist()) + + if not all_pvals: + return # Skip if no values + + # Get global min and max p-values + global_vmin = min(all_pvals) + global_vmax = max(all_pvals) - # Use the data's own range for each direction - vmin = values.min() - vmax = values.max() + # Adjust the maximum value to prevent saturation of highly significant pathways + # Use either actual max or a higher percentile value, whichever is higher + # This prevents all highly significant pathways from appearing with the same color + if len(all_pvals) > 1: + # Calculate 90th percentile of p-values + percentile_90 = np.percentile(all_pvals, 90) + + # If max is much larger than 90th percentile, use an intermediate value + if global_vmax > 2 * percentile_90: + adjusted_vmax = percentile_90 + (global_vmax - percentile_90) / 3 + # But ensure we don't lower the max too much + global_vmax = max(adjusted_vmax, global_vmax * 0.7) + + # Log the adjustment for debugging + logging.debug(f"P-value color scale: original max={max(all_pvals):.2f}, adjusted max={global_vmax:.2f}") - norm = plt.Normalize(vmin=vmin, vmax=vmax) + # Create a consistent color normalization across both plots + norm = plt.Normalize(vmin=global_vmin, vmax=global_vmax) cmap = plt.cm.get_cmap("viridis") - colors_for_bars = [cmap(norm(v)) for v in values] - - plt.figure(figsize=(14, 8)) # Wider figure to accommodate legend - - # Use NES for bar length (absolute value for down-regulated) - plt.barh( - df["label"].iloc[::-1], - plot_values.iloc[::-1], # Use appropriate values based on direction - color=colors_for_bars[::-1], # Still color by significance - ) - # Add a colorbar to show the significance scale - sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) - sm.set_array([]) - cbar = plt.colorbar(sm) - cbar.set_label("-log10(adjusted p-value)") - - # Add legend explaining the visualization - legend_elements = [ - Patch(facecolor='gray', alpha=0.5, - label='Bar length: Normalized Enrichment Score (NES)'), - Patch(facecolor=cmap(0.25), alpha=0.8, - label='Bar color: Statistical significance'), - ] - # Move legend much further to the right - plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1.25, 1)) - - # Set x-axis label based on direction - if direction == "up": - plt.xlabel("Normalized Enrichment Score (NES)") - else: - plt.xlabel("Absolute Normalized Enrichment Score (|NES|)") - # Split target label into reference and target parts target_parts = target_label.split("_vs_") target_condition = target_parts[0] ref_condition = target_parts[1] - - # Create title based on direction - if direction == "up": + + # Plot UP-regulated pathways + if not up_df.empty: + up_values = up_df["-log10(p.adjust)"] + up_colors = 
[cmap(norm(v)) for v in up_values] + + # Adjust figure size - no need for extra space for legend + plt.figure(figsize=(12, 10)) + + # Create horizontal bar plot + bars = plt.barh( + up_df["label"].iloc[::-1], + up_df["NES"].iloc[::-1], + color=up_colors[::-1], + ) + + # Remove the p-value text labels + + # Add colorbar with the global scale + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + sm.set_array([]) + cbar = plt.colorbar(sm) + cbar.set_label("-log10(adjusted p-value)", fontsize=12) + + plt.xlabel("Normalized Enrichment Score (NES)", fontsize=12) condition_str = f"Pathways enriched in {target_condition}\nvs {ref_condition} - {ont}" - else: - condition_str = f"Pathways enriched in {ref_condition}\nvs {target_condition} - {ont}" - - plt.title(condition_str, fontsize=10) + plt.title(condition_str, fontsize=14) + + # Ensure y-axis labels are fully visible + plt.tight_layout() + plt.subplots_adjust(left=0.3) # Add more space on the left for labels + + plot_path = self.output_path / f"GSEA_top_pathways_up_{ont}.pdf" + plt.savefig(plot_path, format="pdf", bbox_inches="tight", dpi=600) + plt.close() + logging.info(f"GSEA up-regulated pathways plot saved to {plot_path} with high resolution") - # Adjust layout to make room for legend - plt.tight_layout() - # Save with extra space for the legend - plot_path = self.output_path / f"GSEA_top_pathways_{direction}_{ont}.pdf" - plt.savefig(plot_path, format="pdf", bbox_inches="tight", dpi=300) - plt.close() - logging.info(f"GSEA {direction} pathways plot saved to {plot_path}") + # Plot DOWN-regulated pathways + if not down_df.empty: + down_values = down_df["-log10(p.adjust)"] + down_colors = [cmap(norm(v)) for v in down_values] + + # Adjust figure size - no need for extra space for legend + plt.figure(figsize=(12, 10)) + + # Create horizontal bar plot + bars = plt.barh( + down_df["label"].iloc[::-1], + down_df["NES"].abs().iloc[::-1], # Use absolute NES for down-regulated + color=down_colors[::-1], + ) + + # Add colorbar with the global scale + sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm) + sm.set_array([]) + cbar = plt.colorbar(sm) + cbar.set_label("-log10(adjusted p-value)", fontsize=12) + + plt.xlabel("Absolute Normalized Enrichment Score (|NES|)", fontsize=12) + condition_str = f"Pathways enriched in {ref_condition}\nvs {target_condition} - {ont}" + plt.title(condition_str, fontsize=14) + + # Ensure y-axis labels are fully visible + plt.tight_layout() + plt.subplots_adjust(left=0.3) # Add more space on the left for labels + + plot_path = self.output_path / f"GSEA_top_pathways_down_{ont}.pdf" + plt.savefig(plot_path, format="pdf", bbox_inches="tight", dpi=600) + plt.close() + logging.info(f"GSEA down-regulated pathways plot saved to {plot_path} with high resolution") # Run GO analysis for each ontology ontologies = ["BP", "MF", "CC"] @@ -212,10 +260,11 @@ def plot_pathways(df: pd.DataFrame, direction: str, ont: str): gsea_df.to_csv(gsea_outfile, index=False) logging.info(f"Complete GSEA results for {ont} saved to {gsea_outfile}") - # Plot significant pathways + # Process significant pathways sig_gsea_df = gsea_df[ gsea_df["p.adjust"] < 0.05 ].copy() # Using 0.05 threshold + if not sig_gsea_df.empty: up_pathways = sig_gsea_df[sig_gsea_df["NES"] > 0].nsmallest( 15, "p.adjust" @@ -224,11 +273,8 @@ def plot_pathways(df: pd.DataFrame, direction: str, ont: str): 15, "p.adjust" ) - # Use separate color scales for each direction - if not up_pathways.empty: - plot_pathways(up_pathways, "up", ont) - if not down_pathways.empty: - 
plot_pathways(down_pathways, "down", ont) + # Use consistent color scales across both plots + plot_pathways(up_pathways, down_pathways, ont) else: logging.info(f"No pathways with adj.P<0.05 found for {ont}") diff --git a/src/visualization_output_config.py b/src/visualization_output_config.py index 28fc0930..40256d47 100644 --- a/src/visualization_output_config.py +++ b/src/visualization_output_config.py @@ -4,11 +4,11 @@ import gzip import shutil from argparse import Namespace -import gffutils import yaml -from typing import List, Dict, Tuple, Set +from typing import List import logging import re +from pathlib import Path class OutputConfig: """Class to build dictionaries from the output files of the pipeline.""" @@ -18,17 +18,18 @@ def __init__( output_directory: str, ref_only: bool = False, gtf: str = None, + technical_replicates: str = None, ): self.output_directory = output_directory self.log_details = {} self.extended_annotation = None self.read_assignments = None - self.input_gtf = gtf # Initialize with the provided gtf flag + self.input_gtf = gtf self.genedb_filename = None self.yaml_input = True self.yaml_input_path = None - self.gtf_flag_needed = False # Initialize flag to check if "--gtf" is needed. - self._conditions = None # Changed from self.conditions = False + self.gtf_flag_needed = False + self._conditions = None self.gene_grouped_counts = None self.transcript_grouped_counts = None self.transcript_grouped_tpm = None @@ -43,7 +44,7 @@ def __init__( self.transcript_model_grouped_counts = None self.ref_only = ref_only - # New attributes for handling extended annotations + # Extended annotation handling self.sample_extended_gtfs = [] self.merged_extended_gtf = None @@ -52,12 +53,23 @@ def __init__( self.sample_transcript_model_tpm = {} self.sample_transcript_model_counts = {} - # New attribute for transcript mapping + # Transcript mapping self.transcript_map = {} # Maps transcript IDs to canonical transcript ID with same exon structure + + # Technical replicates + self.technical_replicates_spec = technical_replicates + self.technical_replicates_dict = {} + self._has_technical_replicates = False + self._has_biological_replicates = None # Will be computed when needed self._load_params_file() self._find_files() self._conditional_unzip() + + # Parse technical replicates after initialization + if self.technical_replicates_spec: + self.technical_replicates_dict = self._parse_technical_replicates(self.technical_replicates_spec) + self._has_technical_replicates = bool(self.technical_replicates_dict) # Ensure input_gtf is provided if ref_only is set and input_gtf is not found in the log if self.ref_only and not self.input_gtf: @@ -75,7 +87,7 @@ def _load_params_file(self): if isinstance(params, Namespace): self._process_params(vars(params)) else: - print("Unexpected params format.") + logging.warning("Unexpected params format.") except Exception as e: raise ValueError(f"An error occurred while loading params: {e}") @@ -127,7 +139,7 @@ def _unzip_file(self, file_path): with gzip.open(file_path, "rb") as f_in: with open(new_path, "wb") as f_out: shutil.copyfileobj(f_in, f_out) - print(f"File {file_path} was decompressed to {new_path}.") + logging.info(f"File {file_path} was decompressed to {new_path}.") return new_path @@ -139,7 +151,7 @@ def _find_files(self): return # Exit the method after processing YAML input if not os.path.exists(self.output_directory): - print(f"Directory not found: {self.output_directory}") # Debugging output + logging.error(f"Directory not found: 
{self.output_directory}") raise FileNotFoundError( f"Specified sample subdirectory does not exist: {self.output_directory}" ) @@ -215,7 +227,7 @@ def _find_files(self): def _find_files_from_yaml(self): """Locate files and samples from YAML, apply filters to ensure only valid samples are processed.""" if not os.path.exists(self.yaml_input_path): - print(f"YAML file not found: {self.yaml_input_path}") + logging.error(f"YAML file not found: {self.yaml_input_path}") raise FileNotFoundError( f"Specified YAML file does not exist: {self.yaml_input_path}" ) @@ -243,10 +255,9 @@ def _find_files_from_yaml(self): ]: file_path = getattr(self, attr) if not os.path.exists(file_path): - print(f"Warning: {attr} file not found at {file_path}") + logging.warning(f"{attr} file not found at {file_path}") setattr(self, attr, None) - # Initialize read_assignments list self.read_assignments = [] # Read and process the YAML file @@ -281,8 +292,8 @@ def _find_files_from_yaml(self): if os.path.exists(extended_gtf): self.sample_extended_gtfs.append(extended_gtf) else: - print( - f"Warning: extended_annotation.gtf not found for sample {name}" + logging.warning( + f"extended_annotation.gtf not found for sample {name}" ) # Check for .read_assignments.tsv.gz @@ -292,7 +303,7 @@ def _find_files_from_yaml(self): if unzipped_file: self.read_assignments.append((name, unzipped_file)) else: - print(f"Warning: Failed to unzip {gz_file}") + logging.warning(f"Failed to unzip {gz_file}") else: # Check for .read_assignments.tsv non_gz_file = os.path.join( @@ -301,7 +312,7 @@ def _find_files_from_yaml(self): if os.path.exists(non_gz_file): self.read_assignments.append((name, non_gz_file)) else: - print(f"Warning: No read assignments file found for {name}") + logging.warning(f"No read assignments file found for {name}") # Load transcript_model_tpm and transcript_model_counts for merging tpm_path = os.path.join(sample_dir, f"{name}.transcript_model_tpm.tsv") @@ -317,7 +328,7 @@ def _find_files_from_yaml(self): ) if not self.read_assignments: - print("Warning: No read assignment files found for any samples") + logging.warning("No read assignment files found for any samples") # Handle extended annotations only if ref_only is not True if self.ref_only is not True: @@ -412,7 +423,7 @@ def merge_gtfs(self, gtfs, output_gtf): """Merge multiple GTF files into a single GTF file, identifying transcripts with identical exon structures.""" try: # First, parse all GTFs to identify transcripts with identical exon structures - print(f"Analyzing {len(gtfs)} GTF files to identify identical transcript structures") + logging.info(f"Analyzing {len(gtfs)} GTF files to identify identical transcript structures") logging.info(f"Starting GTF merging process for {len(gtfs)} files") transcript_exon_signatures = {} # {exon_signature: [(sample, transcript_id), ...]} @@ -443,8 +454,8 @@ def merge_gtfs(self, gtfs, output_gtf): logging.info(f"Writing merged GTF file to {output_gtf}") self._write_merged_gtf(gtfs, output_gtf) - print(f"Successfully merged {len(gtfs)} GTF files into {output_gtf}") - print(f"Identified {len(self.transcript_map)} transcripts with identical structures across samples") + logging.info(f"Successfully merged {len(gtfs)} GTF files into {output_gtf}") + logging.info(f"Identified {len(self.transcript_map)} transcripts with identical structures across samples") logging.info(f"GTF merging complete. 
Output file: {output_gtf}") except Exception as e: @@ -643,7 +654,7 @@ def _write_transcript_mapping(self, output_file): for transcript_id, canonical_id in self.transcript_map.items(): f.write(f"{transcript_id}\t{canonical_id}\n") - print(f"Transcript mapping written to {output_file}") + logging.info(f"Transcript mapping written to {output_file}") def _write_merged_gtf(self, gtfs, output_gtf): """Write the merged GTF with canonical transcript IDs.""" @@ -682,11 +693,7 @@ def _write_merged_gtf(self, gtfs, output_gtf): outfile.write(line) def _merge_transcript_files(self, sample_files_dict, output_file, metric_type): - # sample_files_dict: {sample_name: filepath or None} - # Merge logic: - # 1. Gather all transcripts from all samples - # 2. For each transcript, write a line with transcript_id and values from each sample (0 if missing) - # 3. Apply transcript mapping to merge identical transcripts + transcripts = {} samples = self.samples @@ -766,3 +773,235 @@ def conditions(self): @conditions.setter def conditions(self, value): self._conditions = value + + @property + def has_technical_replicates(self): + """Return True if technical replicates were successfully parsed.""" + return self._has_technical_replicates + + @property + def has_biological_replicates(self): + """Return True if every condition has at least two biological replicate files.""" + if self._has_biological_replicates is None: + self._has_biological_replicates = self._check_biological_replicates() + return self._has_biological_replicates + + def _parse_technical_replicates(self, tech_rep_spec): + """ + Parse technical replicate specification from command line argument. + + Args: + tech_rep_spec (str): Either a file path or inline specification + + Returns: + dict: Mapping from sample names to replicate group names + """ + if not tech_rep_spec: + return {} + + tech_rep_dict = {} + + # Check if it's a file path + if Path(tech_rep_spec).exists(): + logging.info(f"Reading technical replicates from file: {tech_rep_spec}") + try: + with open(tech_rep_spec, 'r') as f: + first_line = True + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line or line.startswith('#'): # Skip empty lines and comments + continue + + # Skip header line if it looks like a header + if first_line: + first_line = False + # Check if this looks like a header (contains common header words) + if any(header_word in line.lower() for header_word in ['sample', 'replicate', 'group', 'name']): + logging.debug(f"Skipping header line: {line}") + continue + + # Support both comma and tab separation + if '\t' in line: + parts = line.split('\t') + elif ',' in line: + parts = line.split(',') + else: + logging.warning(f"Line {line_num} in technical replicates file has invalid format: {line}") + continue + + if len(parts) >= 2: + sample_name = parts[0].strip() + group_name = parts[1].strip() + tech_rep_dict[sample_name] = group_name + else: + logging.warning(f"Line {line_num} in technical replicates file has insufficient columns: {line}") + + except Exception as e: + logging.error(f"Error reading technical replicates file: {e}") + return {} + else: + # Parse inline specification: sample1:group1,sample2:group1,sample3:group2 + logging.info("Parsing technical replicates from inline specification") + try: + pairs = tech_rep_spec.split(',') + for pair in pairs: + if ':' in pair: + sample_name, group_name = pair.split(':', 1) + tech_rep_dict[sample_name.strip()] = group_name.strip() + else: + logging.warning(f"Invalid technical replicate pair format: {pair}") + 
except Exception as e: + logging.error(f"Error parsing inline technical replicates specification: {e}") + return {} + + if tech_rep_dict: + logging.info(f"Successfully parsed {len(tech_rep_dict)} technical replicate mappings") + # Log some examples + for sample, group in list(tech_rep_dict.items())[:3]: + logging.debug(f"Technical replicate mapping: {sample} -> {group}") + if len(tech_rep_dict) > 3: + logging.debug(f"... and {len(tech_rep_dict) - 3} more mappings") + else: + logging.warning("No technical replicate mappings found") + + return tech_rep_dict + + def _check_biological_replicates(self, ref_conditions=None, target_conditions=None): + """Return True if biological replicates are detected. + + For YAML input: Check each sample subdirectory - if any sample has >1 column + in their gene_grouped files, we have biological replicates + For FASTQ input: Assume no biological replicates (return False) + """ + from pathlib import Path + + # If FASTQ input was used, assume no biological replicates + if self.log_details.get("fastq_used", False): + logging.info("FASTQ input detected - assuming no biological replicates") + return False + + # If no conditions provided, we can't check biological replicates + if not ref_conditions and not target_conditions: + # If we have conditions from the file, use those + if self._conditions: + all_conditions = self._conditions + else: + logging.warning("No conditions available to check for biological replicates") + return False + else: + all_conditions = list(ref_conditions or []) + list(target_conditions or []) + + # For YAML input, check each sample subdirectory + if self.yaml_input: + return self._check_yaml_sample_replicates() + else: + # For non-YAML input, check individual condition files + return self._check_replicates_from_condition_files(all_conditions) + + def _check_yaml_sample_replicates(self): + """Check biological replicates from YAML sample subdirectories. + + For each sample subdirectory, check if its gene_grouped_counts.tsv or + gene_grouped_tpm.tsv files have more than 1 column (excluding gene ID column). + If any sample has >1 column, we have biological replicates. 
+ """ + from pathlib import Path + + logging.info("Checking biological replicates in YAML sample subdirectories") + + # Get all sample names from the YAML configuration + if not hasattr(self, 'samples') or not self.samples: + logging.warning("No samples found in YAML configuration") + return False + + # Check each sample subdirectory for biological replicates + samples_with_replicates = 0 + total_samples_checked = 0 + + for sample in self.samples: + sample_dir = Path(self.output_directory) / sample + if not sample_dir.exists(): + logging.debug(f"Sample directory not found: {sample_dir}") + continue + + # Look for gene count files in the sample directory + count_files = list(sample_dir.glob("*gene_grouped_counts.tsv")) + if not count_files: + logging.debug(f"No gene_grouped_counts.tsv file found for sample '{sample}'") + continue + + # Check the number of columns in the count file + count_file = count_files[0] + try: + with open(count_file, 'r') as f: + header = f.readline().strip().split('\t') + sample_columns = header[1:] # Skip the gene ID column + sample_count = len(sample_columns) + + total_samples_checked += 1 + logging.debug(f"Sample '{sample}' has {sample_count} columns in count file") + + if sample_count >= 2: + samples_with_replicates += 1 + logging.info(f"Sample '{sample}' has {sample_count} biological replicates") + + except Exception as e: + logging.error(f"Error reading file {count_file}: {e}") + continue + + if total_samples_checked == 0: + logging.warning("No valid sample count files found") + return False + + # If any sample has biological replicates, we consider the dataset to have biological replicates + has_bio_reps = samples_with_replicates > 0 + + if has_bio_reps: + logging.info(f"Found biological replicates in {samples_with_replicates}/{total_samples_checked} samples") + else: + logging.info("No biological replicates found in any sample - each sample has only 1 column") + + return has_bio_reps + + def _check_replicates_from_condition_files(self, all_conditions): + """Check biological replicates from individual condition files.""" + from pathlib import Path + + for condition in all_conditions: + condition_dir = Path(self.output_directory) / condition + if not condition_dir.exists(): + logging.warning(f"Condition directory not found: {condition_dir}") + return False + + # Look for gene grouped counts file in the condition directory + count_files = list(condition_dir.glob("*gene_grouped_counts.tsv")) + if not count_files: + logging.warning(f"No gene_grouped_counts.tsv file found for condition '{condition}'") + return False + + # Check the number of columns in the first count file + count_file = count_files[0] + try: + with open(count_file, 'r') as f: + header = f.readline().strip().split('\t') + sample_columns = header[1:] # Skip the gene ID column + sample_count = len(sample_columns) + + if sample_count < 2: + logging.warning( + f"Condition '{condition}' has {sample_count} biological replicate(s); " + "DESeq2 requires at least 2. Falling back to simple ranking." 
+ ) + return False + else: + logging.info(f"Condition '{condition}' has {sample_count} biological replicates") + + except Exception as e: + logging.error(f"Error reading file {count_file}: {e}") + return False + + return True + + def check_biological_replicates_for_conditions(self, ref_conditions, target_conditions): + """Check biological replicates for specific conditions.""" + return self._check_biological_replicates(ref_conditions, target_conditions) diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py index 6aa07036..53f0d644 100644 --- a/src/visualization_plotter.py +++ b/src/visualization_plotter.py @@ -75,8 +75,8 @@ def plot_transcript_map(self): # Iterate through conditions until the gene is found. for condition, genes in self.updated_gene_dict.items(): for gene_id, gene_info in genes.items(): - # Compare gene names (case-insensitive, using upper()) - if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): + # Compare gene names (case-insensitive matching) + if "name" in gene_info and gene_info["name"].upper() == gene_name_or_id.upper(): gene_data = gene_info # No need to log which condition it came from, as it's pre-filtered. break # Found gene info @@ -84,7 +84,7 @@ def plot_transcript_map(self): break # Found gene, stop searching conditions if not gene_data: - logging.warning(f"Gene {gene_name_or_id} not found in the provided (pre-filtered) updated_gene_dict.") + logging.warning(f"Gene '{gene_name_or_id}' not found in the provided gene dictionary.") continue # Skip to the next gene if not found # Get chromosome info and calculate buffer @@ -236,8 +236,8 @@ def plot_transcript_usage(self): for condition, genes in self.updated_gene_dict.items(): condition_gene_data = None for gene_id, gene_info in genes.items(): - # Compare gene names (case-insensitive) - if "name" in gene_info and gene_info["name"] == gene_name_or_id.upper(): + # Compare gene names (case-insensitive matching) + if "name" in gene_info and gene_info["name"].upper() == gene_name_or_id.upper(): condition_gene_data = gene_info.get("transcripts", {}) # Get transcripts, default to empty dict found_gene_any_condition = True #logging.debug(f"Found gene {gene_name_or_id} data for condition {condition}") diff --git a/src/visualization_simple_ranker.py b/src/visualization_simple_ranker.py new file mode 100644 index 00000000..cd7994e0 --- /dev/null +++ b/src/visualization_simple_ranker.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python3 +"""A lightweight gene ranking utility for experiments lacking biological replicates. + +This module implements a fallback algorithm used when each experimental +condition has fewer than two biological replicates and, therefore, formal +statistics with DESeq2 are inappropriate. + +The scoring heuristic combines two intuitive effect-size metrics: + 1. Absolute log2 fold-change of gene-level TPM between target and + reference groups. + 2. Maximum change in isoform usage for any transcript belonging to a + gene. Usage is defined as transcript TPM / total gene TPM. + +The final score is: + score = |log2FC_gene| + max_delta_isoform_usage + +Genes are ranked by the score in descending order. + +The implementation is designed to mirror the interface expected by +visualize.py – namely, a ``rank(top_n)`` method that returns a list of gene +names (mapped from gene IDs using the updated gene dictionary). 
+""" +from __future__ import annotations + +import logging +from pathlib import Path +from typing import List, Dict + +import numpy as np +import pandas as pd + +logger = logging.getLogger("IsoQuant.visualization.simple_ranker") +logger.setLevel(logging.INFO) + + +class SimpleGeneRanker: + """Rank genes by combined gene-expression and isoform-usage change.""" + + def __init__( + self, + output_dir: str | Path, + ref_conditions: List[str], + target_conditions: List[str], + ref_only: bool = False, + updated_gene_dict: Dict = None, + ) -> None: + self.output_dir = Path(output_dir) + self.ref_conditions = list(ref_conditions) + self.target_conditions = list(target_conditions) + self.ref_only = ref_only + self.updated_gene_dict = updated_gene_dict or {} + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + def rank(self, top_n: int = 100) -> List[str]: + logger.info("Running SimpleGeneRanker (no replicates detected)") + + # Debug: Log updated_gene_dict structure + if self.updated_gene_dict: + logger.info(f"updated_gene_dict has {len(self.updated_gene_dict)} conditions: {list(self.updated_gene_dict.keys())}") + # Show sample genes from each condition + for condition, genes in self.updated_gene_dict.items(): + sample_gene_ids = list(genes.keys())[:3] + logger.info(f"Condition '{condition}' has {len(genes)} genes. Sample gene IDs: {sample_gene_ids}") + # Show gene structure for first gene + if sample_gene_ids: + first_gene_id = sample_gene_ids[0] + first_gene_info = genes[first_gene_id] + logger.info(f" Sample gene '{first_gene_id}' structure: name='{first_gene_info.get('name', 'MISSING')}', keys={list(first_gene_info.keys())}") + else: + logger.warning("No updated_gene_dict provided!") + + # 1. Load gene-level TPM aggregated by sample. + logger.info(f"SIMPLE_RANKER: Loading TPM data for reference conditions: {self.ref_conditions}") + gene_expr_ref = self._aggregate_gene_tpm(self.ref_conditions) + logger.info(f"SIMPLE_RANKER: Reference TPM loaded - shape: {gene_expr_ref.shape}") + + logger.info(f"SIMPLE_RANKER: Loading TPM data for target conditions: {self.target_conditions}") + gene_expr_tgt = self._aggregate_gene_tpm(self.target_conditions) + logger.info(f"SIMPLE_RANKER: Target TPM loaded - shape: {gene_expr_tgt.shape}") + + # Genes common to both groups. 
+ common_genes = gene_expr_ref.index.intersection(gene_expr_tgt.index) + logger.info(f"SIMPLE_RANKER: Found {len(common_genes)} genes common to both ref and target groups") + logger.info(f"SIMPLE_RANKER: Sample common genes: {list(common_genes)[:10]}") + + gene_expr_ref = gene_expr_ref.loc[common_genes] + gene_expr_tgt = gene_expr_tgt.loc[common_genes] + logger.info(f"SIMPLE_RANKER: After filtering to common genes - ref shape: {gene_expr_ref.shape}, target shape: {gene_expr_tgt.shape}") + + # Filter to only genes present in updated_gene_dict (like differential expression analysis does) + if self.updated_gene_dict: + available_genes = set() + for condition_dict in self.updated_gene_dict.values(): + available_genes.update(condition_dict.keys()) + + # Log sample of available genes from updated_gene_dict + sample_genes = list(available_genes)[:10] + logger.info(f"Sample genes available in updated_gene_dict: {sample_genes}") + + # Keep only genes that are in the updated_gene_dict + filtered_common_genes = [g for g in common_genes if g in available_genes] + logger.info(f"Filtered genes to {len(filtered_common_genes)} from {len(common_genes)} based on updated_gene_dict availability") + + if not filtered_common_genes: + logger.warning("No genes remain after filtering by updated_gene_dict. Returning empty list.") + return [] + + gene_expr_ref = gene_expr_ref.loc[filtered_common_genes] + gene_expr_tgt = gene_expr_tgt.loc[filtered_common_genes] + common_genes = filtered_common_genes + + # 2. Compute log2 fold-change (add pseudocount of 1). + logger.info("SIMPLE_RANKER: Computing log2 fold-change...") + log2fc = np.log2(gene_expr_tgt + 1) - np.log2(gene_expr_ref + 1) + abs_log2fc = log2fc.abs() + logger.info(f"SIMPLE_RANKER: Log2FC stats - min: {log2fc.min():.3f}, max: {log2fc.max():.3f}, mean: {log2fc.mean():.3f}") + logger.info(f"SIMPLE_RANKER: Abs Log2FC stats - min: {abs_log2fc.min():.3f}, max: {abs_log2fc.max():.3f}, mean: {abs_log2fc.mean():.3f}") + + # Show top log2FC examples + top_log2fc_genes = abs_log2fc.nlargest(5) + logger.info(f"SIMPLE_RANKER: Top 5 genes by abs log2FC:") + for gene_id, score in top_log2fc_genes.items(): + ref_val = gene_expr_ref[gene_id] + tgt_val = gene_expr_tgt[gene_id] + logger.info(f" {gene_id}: ref_tpm={ref_val:.3f}, tgt_tpm={tgt_val:.3f}, abs_log2fc={score:.3f}") + + # 3. Compute isoform-usage change per gene. + logger.info("SIMPLE_RANKER: Computing isoform usage delta...") + delta_usage = self._compute_isoform_usage_delta(common_genes) + logger.info(f"SIMPLE_RANKER: Delta usage stats - min: {delta_usage.min():.3f}, max: {delta_usage.max():.3f}, mean: {delta_usage.mean():.3f}") + + # 4. Combined score. + logger.info("SIMPLE_RANKER: Computing combined score = |log2FC| + max_delta_isoform_usage...") + combined_score = abs_log2fc + delta_usage + combined_score.name = "score" + logger.info(f"SIMPLE_RANKER: Combined score stats - min: {combined_score.min():.3f}, max: {combined_score.max():.3f}, mean: {combined_score.mean():.3f}") + + # Show detailed scoring for top genes + top_combined_genes = combined_score.nlargest(10) + logger.info(f"SIMPLE_RANKER: Top 10 genes by combined score:") + for gene_id, score in top_combined_genes.items(): + log2fc_contrib = abs_log2fc[gene_id] + usage_contrib = delta_usage[gene_id] + logger.info(f" {gene_id}: total_score={score:.3f} (log2fc={log2fc_contrib:.3f} + usage={usage_contrib:.3f})") + + # 5. Rank and get top N gene IDs. 
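+        # Editorial worked example (invented numbers): a gene can rank highly via
+        # expression change, isoform switching, or both:
+        #     abs_log2fc  = {"geneA": 2.0,  "geneB": 0.1}
+        #     delta_usage = {"geneA": 0.05, "geneB": 0.6}
+        #     score       = {"geneA": 2.05, "geneB": 0.7}   (geneA ranks first)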
+        ranked_gene_ids = combined_score.sort_values(ascending=False).head(top_n).index.tolist()
+        logger.info(f"SimpleGeneRanker selected {len(ranked_gene_ids)} genes (top {top_n}) by score.")
+        logger.info(f"Top 10 ranked gene IDs: {ranked_gene_ids[:10]}")
+
+        # Show the final scores for the top ranked genes
+        final_scores = combined_score.loc[ranked_gene_ids[:10]]
+        logger.info(f"Final scores for top 10 genes: {final_scores.to_dict()}")
+
+        # 6. Map gene IDs to gene names directly from updated_gene_dict to ensure exact compatibility with the plotter
+        if self.updated_gene_dict:
+            ranked_gene_names = []
+            mapped_count = 0
+            mapping_details = []  # For detailed logging
+
+            for gene_id in ranked_gene_ids:
+                gene_name_found = None
+                # Look for this gene_id in updated_gene_dict to get the exact gene name
+                for condition_dict in self.updated_gene_dict.values():
+                    if gene_id in condition_dict:
+                        gene_info = condition_dict[gene_id]
+                        if "name" in gene_info and gene_info["name"]:
+                            gene_name_found = gene_info["name"]
+                            mapped_count += 1
+                            mapping_details.append(f"{gene_id} -> {gene_name_found}")
+                        else:
+                            mapping_details.append(f"{gene_id} -> NO_NAME (name field: {gene_info.get('name', 'MISSING')})")
+                        break
+                else:
+                    mapping_details.append(f"{gene_id} -> NOT_FOUND_IN_DICT")
+
+                # Use the found gene name (uppercase to match plotter expectations) or fall back to the gene_id
+                ranked_gene_names.append(gene_name_found.upper() if gene_name_found else gene_id)
+
+            # Log mapping details for first 10 genes
+            logger.info("Gene mapping details (first 10):")
+            for detail in mapping_details[:10]:
+                logger.info(f"  {detail}")
+
+            logger.info(f"SimpleGeneRanker mapped {mapped_count}/{len(ranked_gene_ids)} gene IDs to gene names from updated_gene_dict.")
+            logger.info(f"Final gene names (first 10): {ranked_gene_names[:10]}")
+            return ranked_gene_names
+        else:
+            logger.warning("No updated_gene_dict provided, returning raw gene IDs.")
+            return ranked_gene_ids
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+    # Cached root-level matrices to avoid re-reading
+    _root_gene_tpm: pd.DataFrame | None = None
+    _root_transcript_tpm: pd.DataFrame | None = None
+
+    def _aggregate_gene_tpm(self, conditions: List[str]) -> pd.Series:
+        """Return mean TPM vector across all samples of the given conditions.
+
+        Works with two possible layouts:
+        1. YAML / multi-condition run – counts live under <output_dir>/<condition>/.
+        2. Single-sample run – a single *gene_grouped_tpm.tsv in <output_dir>/.
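+
+        Hypothetical paths for the two layouts (illustrative only; real file
+        names come from the IsoQuant run):
+
+            <output_dir>/liver/liver.gene_grouped_tpm.tsv    (layout 1)
+            <output_dir>/experiment.gene_grouped_tpm.tsv     (layout 2)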
+ """ + logger.info(f"SIMPLE_RANKER: Aggregating gene TPM for conditions: {conditions}") + tpm_values: list[pd.Series] = [] + + for cond in conditions: + cond_dir = self.output_dir / cond + files = list(cond_dir.glob("*gene_grouped_tpm.tsv")) + logger.info(f"SIMPLE_RANKER: Condition '{cond}' - found {len(files)} gene TPM files in {cond_dir}") + + if files: + for fp in files: + logger.info(f"SIMPLE_RANKER: Reading gene TPM from: {fp}") + df = self._read_tpm(fp) + logger.info(f"SIMPLE_RANKER: Gene TPM file shape: {df.shape}, columns: {list(df.columns)[:5]}...") + logger.info(f"SIMPLE_RANKER: Sample gene IDs: {list(df.index)[:5]}...") + tpm_series = df.sum(axis=1) + logger.info(f"SIMPLE_RANKER: Summed TPM series shape: {tpm_series.shape}, sample values: {tpm_series.head()}") + tpm_values.append(tpm_series) + continue # next condition + + # Fallback: root-level file present + logger.info(f"SIMPLE_RANKER: No condition-specific files found for '{cond}', trying root-level gene TPM...") + root_df = self._get_root_gene_tpm() + logger.info(f"SIMPLE_RANKER: Root gene TPM shape: {root_df.shape}, columns: {list(root_df.columns)}") + cond_cols = [c for c in root_df.columns if c == cond or c.startswith(f"{cond}__")] + logger.info(f"SIMPLE_RANKER: Found {len(cond_cols)} columns for condition '{cond}': {cond_cols}") + if not cond_cols: + logger.warning("Condition '%s' columns not found in root gene_grouped_tpm.tsv; treating as missing.", cond) + continue + cond_tpm = root_df[cond_cols].mean(axis=1) + logger.info(f"SIMPLE_RANKER: Condition TPM series shape: {cond_tpm.shape}, sample values: {cond_tpm.head()}") + tpm_values.append(cond_tpm) + + if not tpm_values: + logger.error("SIMPLE_RANKER: No TPM values found for provided conditions.") + raise FileNotFoundError("No TPM values found for provided conditions.") + + logger.info(f"SIMPLE_RANKER: Collected {len(tpm_values)} TPM series for conditions {conditions}") + stacked = pd.concat(tpm_values, axis=1) + logger.info(f"SIMPLE_RANKER: Stacked TPM shape: {stacked.shape}") + aggregated = stacked.mean(axis=1) + logger.info(f"SIMPLE_RANKER: Final aggregated TPM shape: {aggregated.shape}, sample values: {aggregated.head()}") + logger.info(f"SIMPLE_RANKER: TPM stats - min: {aggregated.min():.3f}, max: {aggregated.max():.3f}, mean: {aggregated.mean():.3f}") + return aggregated + + def _get_root_gene_tpm(self) -> pd.DataFrame: + if self._root_gene_tpm is not None: + return self._root_gene_tpm + files = list(self.output_dir.glob("*gene_grouped_tpm.tsv")) + if not files: + raise FileNotFoundError("Root-level gene_grouped_tpm.tsv not found in output directory.") + df = self._read_tpm(files[0]) + self._root_gene_tpm = df + return df + + def _read_tpm(self, fp: Path) -> pd.DataFrame: + df = pd.read_csv(fp, sep="\t") + first_col = df.columns[0] + if first_col.startswith("#"): + df.rename(columns={first_col: "feature_id"}, inplace=True) + return df.set_index("feature_id") + + def _compute_isoform_usage_delta(self, gene_list: List[str]) -> pd.Series: + """Return a Series of maximal isoform-usage change for each gene.""" + # Load transcript TPMs for each group and compute usage. + usage_diff = pd.Series(0.0, index=gene_list) + + # Attempt to locate transcript TPM files. + trans_ref = self._aggregate_transcript_tpm(self.ref_conditions) + trans_tgt = self._aggregate_transcript_tpm(self.target_conditions) + + if trans_ref.empty or trans_tgt.empty: + logger.warning("Transcript TPM files missing – setting isoform usage component to 0.") + return usage_diff + + # Common transcripts. 
+ common_tx = trans_ref.index.intersection(trans_tgt.index) + trans_ref = trans_ref.loc[common_tx] + trans_tgt = trans_tgt.loc[common_tx] + + # Map transcripts to genes by simple split (before '.') + gene_ids = trans_ref.index.to_series().str.split(".").str[0] + trans_ref_grouped = trans_ref.groupby(gene_ids).sum() + trans_tgt_grouped = trans_tgt.groupby(gene_ids).sum() + + # Compute per-gene usage change. + for gene in gene_list: + if gene not in trans_ref_grouped.index or gene not in trans_tgt_grouped.index: + continue + # Filter transcripts belonging to this gene. + mask = gene_ids == gene + gene_tx_ref = trans_ref[mask] + gene_tx_tgt = trans_tgt[mask] + + # Gene totals (add 1e-6 to avoid divide-by-zero) + ref_total = gene_tx_ref.sum() + 1e-6 + tgt_total = gene_tx_tgt.sum() + 1e-6 + + ref_usage = gene_tx_ref / ref_total + tgt_usage = gene_tx_tgt / tgt_total + max_delta = (ref_usage - tgt_usage).abs().max() + usage_diff.at[gene] = max_delta + return usage_diff + + def _aggregate_transcript_tpm(self, conditions: List[str]) -> pd.Series: + """Return mean TPM per transcript across all samples in conditions (handles both layouts).""" + tpm_values: list[pd.Series] = [] + pattern_default = "*transcript_grouped_tpm.tsv" + pattern_model = "*transcript_model_grouped_tpm.tsv" + + for cond in conditions: + cond_dir = self.output_dir / cond + pattern = pattern_model if self.ref_only else pattern_default + files = list(cond_dir.glob(pattern)) + if files: + for fp in files: + df = self._read_tpm(fp) + tpm_values.append(df.sum(axis=1)) + continue + + # Fallback: root-level transcript TPM + root_df = self._get_root_transcript_tpm() + cond_cols = [c for c in root_df.columns if c == cond or c.startswith(f"{cond}__")] + if not cond_cols: + continue + tpm_values.append(root_df[cond_cols].mean(axis=1)) + + if not tpm_values: + return pd.Series(dtype=float) + stacked = pd.concat(tpm_values, axis=1) + return stacked.mean(axis=1) + + def _get_root_transcript_tpm(self) -> pd.DataFrame: + if self._root_transcript_tpm is not None: + return self._root_transcript_tpm + pattern = "*transcript_model_grouped_tpm.tsv" if self.ref_only else "*transcript_grouped_tpm.tsv" + files = list(self.output_dir.glob(pattern)) + if not files: + return pd.DataFrame() + df = self._read_tpm(files[0]) + self._root_transcript_tpm = df + return df diff --git a/visualize.py b/visualize.py index 5362cd34..69b6bb5f 100755 --- a/visualize.py +++ b/visualize.py @@ -8,6 +8,7 @@ from src.visualization_plotter import PlotOutput from src.visualization_differential_exp import DifferentialAnalysis from src.visualization_gsea import GSEAAnalysis +from src.visualization_simple_ranker import SimpleGeneRanker from pathlib import Path @@ -97,6 +98,12 @@ def parse_arguments(): action="store_true", help="Perform GSEA analysis on differential expression results", ) + parser.add_argument( + "--technical_replicates", + type=str, + help="Technical replicate specification. 
Can be a file path (.txt/.csv) with 'sample,group' format, or inline format 'sample1:group1,sample2:group1,sample3:group2'", + default=None, + ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument( "--gene_list", @@ -185,6 +192,9 @@ def get_selection(prompt, max_selection, exclude=[]): print("Selected Target Conditions:", ", ".join(args.target_conditions), "\n") + + + def main(): # First, parse just the output directory argument to set up logging parser = argparse.ArgumentParser(add_help=False) @@ -218,6 +228,7 @@ def main(): args.output_directory, ref_only=args.ref_only, gtf=args.gtf, + technical_replicates=args.technical_replicates, ) dictionary_builder = DictionaryBuilder(output) logging.debug("OutputConfig details:") @@ -243,11 +254,39 @@ def main(): update_names = True min_val = args.filter_transcripts if args.filter_transcripts is not None else 1.0 + logging.info(f"FLOW_DEBUG: Building updated_gene_dict with:") + logging.info(f" min_value: {min_val}") + logging.info(f" reference_conditions: {getattr(args, 'reference_conditions', None)}") + logging.info(f" target_conditions: {getattr(args, 'target_conditions', None)}") + updated_gene_dict = dictionary_builder.build_gene_dict_with_expression_and_filter( min_value=min_val, reference_conditions=getattr(args, 'reference_conditions', None), target_conditions=getattr(args, 'target_conditions', None) ) + + logging.info(f"FLOW_DEBUG: updated_gene_dict created:") + logging.info(f" type: {type(updated_gene_dict)}") + logging.info(f" keys (conditions): {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}") + if updated_gene_dict: + for condition, genes in updated_gene_dict.items(): + logging.info(f" condition '{condition}': {len(genes)} genes") + sample_genes = list(genes.keys())[:3] + if sample_genes: + for gene_id in sample_genes: + gene_info = genes[gene_id] + logging.info(f" gene '{gene_id}': name='{gene_info.get('name', 'MISSING')}', keys={list(gene_info.keys())}") + if 'transcripts' in gene_info: + logging.info(f" transcripts: {len(gene_info['transcripts'])} items") + break # Only show details for first condition + + # Debug: log whether gene_dict keys are Ensembl IDs or gene names + if updated_gene_dict: + sample_condition = next(iter(updated_gene_dict)) + sample_keys = list(updated_gene_dict[sample_condition].keys())[:5] + logging.info( + "Sample gene_dict keys for condition '%s': %s", sample_condition, sample_keys + ) # 2. If read assignments are desired, build those as well (cached) if use_read_assignments: @@ -258,44 +297,95 @@ def main(): else: reads_and_class = None - # 3. If user wants to find top genes (--find_genes), run your differential analysis + # 3. 
If user wants to find top genes (--find_genes), choose method based on replicate availability if args.find_genes is not None: - ref_str = "_".join( - x.upper().replace(" ", "_") for x in args.reference_conditions - ) - target_str = "_".join( - x.upper().replace(" ", "_") for x in args.target_conditions - ) + ref_str = "_".join(x.upper().replace(" ", "_") for x in args.reference_conditions) + target_str = "_".join(x.upper().replace(" ", "_") for x in args.target_conditions) main_dir_name = f"find_genes_{ref_str}_vs_{target_str}" - base_dir = ( - viz_output_dir / main_dir_name if not args.viz_output else viz_output_dir - ) + base_dir = viz_output_dir / main_dir_name if not args.viz_output else viz_output_dir base_dir.mkdir(exist_ok=True) - logging.info("Finding genes via differential analysis.") - diff_analysis = DifferentialAnalysis( - output_dir=args.output_directory, - viz_output=base_dir, - ref_conditions=args.reference_conditions, - target_conditions=args.target_conditions, - updated_gene_dict=updated_gene_dict, - ref_only=args.ref_only, - dictionary_builder=dictionary_builder, + tech_rep_dict = output.technical_replicates_dict + replicate_ok = output.check_biological_replicates_for_conditions( + args.reference_conditions, args.target_conditions ) - gene_results, transcript_results, _, deseq2_df = diff_analysis.run_complete_analysis() - - if args.gsea: - gsea = GSEAAnalysis(output_path=base_dir) - target_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" - gsea.run_gsea_analysis(deseq2_df, target_label) - - # Construct the correct path to the top genes file dynamically - top_n = args.find_genes # Get the number used for top N - contrast_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" - top_genes_filename = f"genes_of_top_{top_n}_DE_transcripts_{contrast_label}.txt" - find_genes_list_path = gene_results.parent / top_genes_filename - logging.info(f"Reading gene list generated by differential analysis from: {find_genes_list_path}") - gene_list = dictionary_builder.read_gene_list(find_genes_list_path) + + if replicate_ok: + logging.info("Finding genes via DESeq2 (replicates detected).") + diff_analysis = DifferentialAnalysis( + output_dir=output.output_directory, + viz_output=base_dir, + ref_conditions=args.reference_conditions, + target_conditions=args.target_conditions, + updated_gene_dict=updated_gene_dict, + ref_only=args.ref_only, + dictionary_builder=dictionary_builder, + tech_rep_dict=tech_rep_dict, + ) + gene_results, transcript_results, _, deseq2_df = diff_analysis.run_complete_analysis() + + if args.gsea: + gsea = GSEAAnalysis(output_path=base_dir) + target_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" + gsea.run_gsea_analysis(deseq2_df, target_label) + + # Path to DESeq2-derived top genes + top_n = args.find_genes + contrast_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" + top_genes_filename = f"genes_of_top_{top_n}_DE_transcripts_{contrast_label}.txt" + find_genes_list_path = gene_results.parent / top_genes_filename + logging.info(f"Reading gene list generated by differential analysis from: {find_genes_list_path}") + + logging.info(f"FLOW_DEBUG: DESeq2 path - reading from file: {find_genes_list_path}") + if find_genes_list_path.exists(): + with open(find_genes_list_path, 'r') as f: + file_contents = f.read().strip().split('\n') + logging.info(f"FLOW_DEBUG: DESeq2 file has {len(file_contents)} lines, first 5: 
{file_contents[:5]}") + else: + logging.error(f"FLOW_DEBUG: DESeq2 gene list file does not exist: {find_genes_list_path}") + + gene_list = dictionary_builder.read_gene_list(find_genes_list_path) + logging.info(f"FLOW_DEBUG: DESeq2 gene_list after dictionary_builder.read_gene_list:") + logging.info(f" type: {type(gene_list)}") + logging.info(f" length: {len(gene_list) if gene_list else 'None'}") + logging.info(f" content (first 10): {gene_list[:10] if gene_list else 'None'}") + else: + logging.info("No biological replicates detected – using SimpleGeneRanker.") + logging.info(f"FLOW_DEBUG: Creating SimpleGeneRanker with:") + logging.info(f" output_dir: {output.output_directory}") + logging.info(f" ref_conditions: {args.reference_conditions}") + logging.info(f" target_conditions: {args.target_conditions}") + logging.info(f" ref_only: {args.ref_only}") + logging.info(f" updated_gene_dict keys: {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}") + + simple_ranker = SimpleGeneRanker( + output_dir=output.output_directory, + ref_conditions=args.reference_conditions, + target_conditions=args.target_conditions, + ref_only=args.ref_only, + updated_gene_dict=updated_gene_dict, + ) + + logging.info(f"FLOW_DEBUG: Calling simple_ranker.rank(top_n={args.find_genes})") + gene_list = simple_ranker.rank(top_n=args.find_genes) + logging.info(f"FLOW_DEBUG: SimpleGeneRanker returned gene_list with {len(gene_list)} genes") + logging.info(f"FLOW_DEBUG: Gene list type: {type(gene_list)}") + logging.info(f"FLOW_DEBUG: Gene list content (first 10): {gene_list[:10] if gene_list else 'EMPTY'}") + + # Write gene list to file for reproducibility + contrast_label = f"{'+'.join(args.target_conditions)}_vs_{'+'.join(args.reference_conditions)}" + top_genes_filename = f"genes_of_top_{args.find_genes}_simple_{contrast_label}.txt" + simple_list_path = base_dir / top_genes_filename + import pandas as _pd + _pd.Series(gene_list).to_csv(simple_list_path, index=False, header=False) + logging.info(f"Simple gene list written to {simple_list_path}") + logging.info(f"FLOW_DEBUG: File contents verification:") + try: + with open(simple_list_path, 'r') as f: + file_contents = f.read().strip().split('\n') + logging.info(f"FLOW_DEBUG: File has {len(file_contents)} lines, first 5: {file_contents[:5]}") + except Exception as e: + logging.error(f"FLOW_DEBUG: Error reading written file: {e}") else: base_dir = viz_output_dir @@ -310,6 +400,15 @@ def main(): read_assignments_dir = None # Set to None if not used # 6. 
Plotting with PlotOutput + logging.info(f"FLOW_DEBUG: Creating PlotOutput with:") + logging.info(f" gene_names type: {type(gene_list)}") + logging.info(f" gene_names length: {len(gene_list) if gene_list else 'None'}") + logging.info(f" gene_names content (first 10): {gene_list[:10] if gene_list else 'None'}") + logging.info(f" updated_gene_dict keys: {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}") + logging.info(f" conditions: {output.conditions}") + logging.info(f" filter_transcripts: {min_val}") + logging.info(f" ref_only: {args.ref_only}") + plot_output = PlotOutput( updated_gene_dict=updated_gene_dict, gene_names=gene_list, @@ -323,8 +422,11 @@ def main(): target_conditions=args.target_conditions if hasattr(args, "target_conditions") else None, ) + plot_output.plot_transcript_map() + plot_output.plot_transcript_usage() + if use_read_assignments: plot_output.make_pie_charts() From 52f5a2e8e27d430f6818a182d268c296b87afeda Mon Sep 17 00:00:00 2001 From: Jackfreeman88 Date: Thu, 14 Aug 2025 17:40:21 -0500 Subject: [PATCH 35/35] read assignment versus length + robust DE --- src/visualization_cache_utils.py | 104 +++ src/visualization_dictionary_builder.py | 490 ++++++++--- src/visualization_differential_exp.py | 449 ++++++---- src/visualization_output_config.py | 67 +- src/visualization_plotter.py | 131 ++- src/visualization_read_assignment_io.py | 269 ++++++ src/visualization_simple_ranker.py | 1018 ++++++++++++++++------- visualize.py | 63 +- 8 files changed, 1993 insertions(+), 598 deletions(-) create mode 100644 src/visualization_read_assignment_io.py diff --git a/src/visualization_cache_utils.py b/src/visualization_cache_utils.py index 40ae4534..742bd8e2 100644 --- a/src/visualization_cache_utils.py +++ b/src/visualization_cache_utils.py @@ -5,6 +5,7 @@ from typing import Dict, Any, Optional, Union import random import re +import hashlib def build_gene_dict_cache_file( @@ -60,6 +61,73 @@ def build_read_assignment_cache_file( return cache_dir / "read_assignment_cache_default.pkl" +def _hash_list(values: list) -> str: + try: + s = ",".join(map(str, values)) + m = hashlib.md5() + m.update(s.encode('utf-8')) + return m.hexdigest()[:12] + except Exception: + # Fallback to length-based signature + return f"len{len(values)}" + + +def build_length_effects_cache_file( + read_assignments: Union[str, list], ref_only: bool, cache_dir: Path, bin_labels: list +) -> Path: + """ + Cache name for read-length effects aggregates. Includes input files, mtimes, ref_only, and bin label signature. + """ + bins_sig = _hash_list(bin_labels) + if isinstance(read_assignments, str): + source_file = Path(read_assignments) + mtime = source_file.stat().st_mtime + cache_name = ( + f"length_effects_cache_{source_file.name}_{mtime}_bins_{bins_sig}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + elif isinstance(read_assignments, list): + file_info = [] + for sample_name, path_str in read_assignments: + path_obj = Path(path_str) + file_info.append(f"{sample_name}-{path_obj.name}-{path_obj.stat().st_mtime}") + composite = "_".join(file_info).replace(" ", "_")[:100] + cache_name = ( + f"length_effects_cache_multi_{composite}_bins_{bins_sig}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + else: + return cache_dir / "length_effects_cache_default.pkl" + + +def build_length_hist_cache_file( + read_assignments: Union[str, list], ref_only: bool, cache_dir: Path, bin_edges: list +) -> Path: + """ + Cache name for read-length histogram. 
Includes input files, mtimes, ref_only, and bin edges signature. + """ + edges_sig = _hash_list(bin_edges) + if isinstance(read_assignments, str): + source_file = Path(read_assignments) + mtime = source_file.stat().st_mtime + cache_name = ( + f"length_hist_cache_{source_file.name}_{mtime}_edges_{edges_sig}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + elif isinstance(read_assignments, list): + file_info = [] + for sample_name, path_str in read_assignments: + path_obj = Path(path_str) + file_info.append(f"{sample_name}-{path_obj.name}-{path_obj.stat().st_mtime}") + composite = "_".join(file_info).replace(" ", "_")[:100] + cache_name = ( + f"length_hist_cache_multi_{composite}_edges_{edges_sig}_ref_only_{ref_only}.pkl" + ) + return cache_dir / cache_name + else: + return cache_dir / "length_hist_cache_default.pkl" + + def save_cache(cache_file: Path, data_to_cache: Any) -> None: """Save data to a cache file using pickle.""" try: @@ -143,6 +211,42 @@ def validate_read_assignment_data( return False +def validate_length_effects_data(data: Any, expected_bins: Optional[list] = None) -> bool: + try: + required = [ + 'bins', 'by_bin_assignment', 'by_bin_classification', + 'assignment_keys', 'classification_keys', 'totals' + ] + if not isinstance(data, dict): + return False + if any(k not in data for k in required): + return False + if expected_bins and data.get('bins') != expected_bins: + return False + # Basic shape checks + if not isinstance(data['by_bin_assignment'], dict): return False + if not isinstance(data['by_bin_classification'], dict): return False + if not isinstance(data['totals'], dict): return False + return True + except Exception as e: + logging.error(f"Length-effects validation error: {e}") + return False + + +def validate_length_hist_data(data: Any, expected_edges: Optional[list] = None) -> bool: + try: + if not isinstance(data, dict): + return False + if any(k not in data for k in ['edges', 'counts', 'total']): + return False + if expected_edges and list(map(int, data.get('edges', []))) != list(map(int, expected_edges)): + return False + return True + except Exception as e: + logging.error(f"Length-hist validation error: {e}") + return False + + def cleanup_cache(cache_dir: Path, max_age_days: int = 7) -> None: """ Remove cache files older than specified days. diff --git a/src/visualization_dictionary_builder.py b/src/visualization_dictionary_builder.py index 9934338b..70332e57 100644 --- a/src/visualization_dictionary_builder.py +++ b/src/visualization_dictionary_builder.py @@ -5,16 +5,20 @@ import logging from pathlib import Path from typing import Dict, Any, List, Union, Tuple +import numpy as np from src.visualization_cache_utils import ( build_gene_dict_cache_file, - build_read_assignment_cache_file, save_cache, load_cache, validate_gene_dict, - validate_read_assignment_data, cleanup_cache, ) +from src.visualization_read_assignment_io import ( + get_read_assignment_counts, + get_read_length_effects, + get_read_length_histogram, +) class DictionaryBuilder: @@ -46,7 +50,13 @@ def build_gene_dict_with_expression_and_filter( selected reference_conditions or target_conditions. Caches the resulting dictionary based on the specific conditions used. 
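+
+        Shape of the returned value (hypothetical IDs and conditions, shown for
+        orientation only):
+
+            {"<condition>": {"<gene_id>": {"name": ..., "transcripts": {...},
+                                           "exons": {...}}}}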
""" - self.logger.debug(f"Starting dictionary build: min_value={min_value}, ref={reference_conditions}, target={target_conditions}") + self.logger.debug("=== DICTIONARY BUILD PROCESS DEBUG ===") + self.logger.debug(f"Starting dictionary build:") + self.logger.debug(f" min_value: {min_value}") + self.logger.debug(f" reference_conditions: {reference_conditions}") + self.logger.debug(f" target_conditions: {target_conditions}") + self.logger.debug(f" config.ref_only: {self.config.ref_only}") + self.logger.debug(f" config.extended_annotation: {getattr(self.config, 'extended_annotation', 'NOT_SET')}") # 1. Load full TPM matrix to determine available conditions first tpm_file = self._get_tpm_file() @@ -75,9 +85,9 @@ def build_gene_dict_with_expression_and_filter( if not conditions_to_process: self.logger.error("None of the requested conditions were found in the TPM file. Cannot proceed.") return {} - self.logger.info(f"Processing conditions: {conditions_to_process}") + self.logger.debug(f"Processing conditions: {conditions_to_process}") else: - self.logger.info("No specific conditions requested, processing all available conditions.") + self.logger.debug("No specific conditions requested, processing all available conditions.") conditions_to_process = available_conditions # Already sorted # Create a deterministic cache key based on conditions @@ -110,7 +120,7 @@ def build_gene_dict_with_expression_and_filter( if validate_gene_dict(cached_gene_dict): # Reuse existing validation if suitable self.novel_gene_ids = cached_novel_gene_ids self.novel_transcript_ids = cached_novel_transcript_ids - self.logger.info("Successfully loaded dictionary from cache.") + self.logger.debug("Successfully loaded dictionary from cache.") return cached_gene_dict else: self.logger.warning("Cached dictionary failed validation. Rebuilding.") @@ -120,7 +130,7 @@ def build_gene_dict_with_expression_and_filter( self.logger.warning("Cached data is invalid or in old format. Rebuilding.") # 4. Cache miss or invalid: Build dictionary from scratch for the specified conditions - self.logger.info("Cache miss or invalid. 
Building dictionary from scratch for selected conditions.") + self.logger.info("Building dictionary from scratch for selected conditions.") # Parse GTF and filter novel genes (only needs to be done once) self.logger.info("Parsing GTF and filtering novel genes") @@ -136,10 +146,49 @@ def build_gene_dict_with_expression_and_filter( valid_transcripts = set( transcript_max_values_subset[transcript_max_values_subset >= min_value].index ) - self.logger.info( + + # Debug: Analyze what transcripts passed the expression filter + total_transcripts_in_tpm = len(transcript_max_values_subset) + novel_transcripts_in_tpm = sum(1 for tx_id in transcript_max_values_subset.index if tx_id.startswith("transcript")) + ensembl_transcripts_in_tpm = sum(1 for tx_id in transcript_max_values_subset.index if tx_id.startswith("ENSMUST")) + + novel_transcripts_passed = sum(1 for tx_id in valid_transcripts if tx_id.startswith("transcript")) + ensembl_transcripts_passed = sum(1 for tx_id in valid_transcripts if tx_id.startswith("ENSMUST")) + + # Show sample transcripts that passed/failed + sample_novel_passed = [tx_id for tx_id in valid_transcripts if tx_id.startswith("transcript")][:5] + sample_novel_failed = [tx_id for tx_id in transcript_max_values_subset.index + if tx_id.startswith("transcript") and tx_id not in valid_transcripts][:5] + + self.logger.debug("=== EXPRESSION FILTERING DEBUG ===") + self.logger.debug(f"Total transcripts before expression filtering: {total_transcripts_in_tpm}") + self.logger.debug(f"Novel transcripts in TPM file: {novel_transcripts_in_tpm}") + self.logger.debug(f"Ensembl transcripts in TPM file: {ensembl_transcripts_in_tpm}") + self.logger.debug( f"Identified {len(valid_transcripts)} transcripts with TPM >= {min_value} " f"in at least one of the conditions: {conditions_to_process}" ) + self.logger.debug(f"Novel transcripts passed: {novel_transcripts_passed} / {novel_transcripts_in_tpm}") + self.logger.debug(f"Ensembl transcripts passed: {ensembl_transcripts_passed} / {ensembl_transcripts_in_tpm}") + + if sample_novel_passed: + self.logger.debug(f"Sample novel transcripts that PASSED: {sample_novel_passed}") + if sample_novel_failed: + self.logger.debug(f"Sample novel transcripts that FAILED: {sample_novel_failed}") + # Show TPM values for failed novel transcripts + for tx_id in sample_novel_failed[:3]: + max_tpm = transcript_max_values_subset.get(tx_id, 0) + self.logger.debug(f" {tx_id}: max TPM = {max_tpm:.2f}") + + if novel_transcripts_passed == 0 and novel_transcripts_in_tpm > 0: + self.logger.warning(f"NO NOVEL TRANSCRIPTS PASSED expression filter! 
Consider lowering min_value from {min_value}") + # Show the highest TPM values for novel transcripts + novel_tpm_values = [(tx_id, transcript_max_values_subset.get(tx_id, 0)) + for tx_id in transcript_max_values_subset.index if tx_id.startswith("transcript")] + novel_tpm_values.sort(key=lambda x: x[1], reverse=True) + self.logger.debug("Top 5 novel transcript TPM values:") + for tx_id, tpm in novel_tpm_values[:5]: + self.logger.debug(f" {tx_id}: {tpm:.2f} TPM") # Build the final dictionary, iterating only through conditions_to_process final_dict = {} @@ -167,7 +216,7 @@ def build_gene_dict_with_expression_and_filter( self._validate_gene_structure(final_dict[condition]) # Aggregate exon values based on the filtered transcripts in the final_dict - self.logger.info("Aggregating exon values based on filtered transcript expression.") + self.logger.debug("Aggregating exon values based on filtered transcript expression.") for condition in conditions_to_process: for gene_id, gene_info in final_dict[condition].items(): aggregated_exons = {} @@ -187,8 +236,40 @@ def build_gene_dict_with_expression_and_filter( aggregated_exons[exon_id]["value"] += transcript_value # Sum transcript TPM gene_info["exons"] = aggregated_exons # Assign aggregated exons - # 5. Save the newly built dictionary to the condition-specific cache - self.logger.info(f"Saving filtered dictionary to cache: {condition_specific_cache_file}") + # 5. Debug final results before saving + self.logger.debug("=== FINAL DICTIONARY RESULTS ===") + total_final_genes = sum(len(genes) for genes in final_dict.values()) + total_final_transcripts = 0 + final_novel_transcripts = 0 + final_ensembl_transcripts = 0 + + for condition, genes in final_dict.items(): + condition_transcripts = 0 + condition_novel_transcripts = 0 + + for gene_id, gene_info in genes.items(): + transcripts = gene_info.get("transcripts", {}) + condition_transcripts += len(transcripts) + + for tx_id in transcripts.keys(): + if tx_id.startswith("transcript"): + condition_novel_transcripts += 1 + final_novel_transcripts += 1 + elif tx_id.startswith("ENSMUST"): + final_ensembl_transcripts += 1 + + total_final_transcripts += condition_transcripts + self.logger.debug(f"Condition '{condition}': {len(genes)} genes, {condition_transcripts} transcripts ({condition_novel_transcripts} novel)") + + self.logger.info(f"Totals across conditions: genes={total_final_genes}, transcripts={total_final_transcripts}, novel={final_novel_transcripts}, ensembl={final_ensembl_transcripts}") + + if final_novel_transcripts == 0: + self.logger.warning("FINAL RESULT: NO NOVEL TRANSCRIPTS in final dictionary!") + else: + self.logger.info(f"Novel transcripts passing filters: {final_novel_transcripts}") + + # 6. 
Save the newly built dictionary to the condition-specific cache + self.logger.debug(f"Saving filtered dictionary to cache: {condition_specific_cache_file}") save_cache( condition_specific_cache_file, (final_dict, self.novel_gene_ids, self.novel_transcript_ids) @@ -198,124 +279,224 @@ def build_gene_dict_with_expression_and_filter( def _get_tpm_file(self) -> str: """Get the appropriate TPM file path from config.""" + self.logger.debug("=== TPM FILE SELECTION DEBUG ===") + self.logger.debug(f"config.conditions: {self.config.conditions}") + self.logger.debug(f"config.ref_only: {self.config.ref_only}") + self.logger.debug(f"config.transcript_grouped_tpm: {getattr(self.config, 'transcript_grouped_tpm', 'NOT_SET')}") + self.logger.debug(f"config.transcript_model_grouped_tpm: {getattr(self.config, 'transcript_model_grouped_tpm', 'NOT_SET')}") + self.logger.debug(f"config.transcript_tpm_ref: {getattr(self.config, 'transcript_tpm_ref', 'NOT_SET')}") + self.logger.debug(f"config.transcript_tpm: {getattr(self.config, 'transcript_tpm', 'NOT_SET')}") + self.logger.debug(f"config.transcript_model_tpm: {getattr(self.config, 'transcript_model_tpm', 'NOT_SET')}") + if self.config.conditions: # Check if we have multiple conditions - # For multi-condition data, prioritize merged files if available - merged_tpm = self.config.transcript_grouped_tpm - if merged_tpm and "_merged.tsv" in merged_tpm: - self.logger.info("Using merged TPM file with transcript deduplication already applied") - tpm_file = merged_tpm + if self.config.ref_only: + # Reference-only mode: use regular transcript files + merged_tpm = self.config.transcript_grouped_tpm + if merged_tpm and "_merged.tsv" in merged_tpm: + self.logger.debug("REF-ONLY: Using merged TPM file with transcript deduplication already applied") + tpm_file = merged_tpm + else: + tpm_file = self.config.transcript_grouped_tpm + self.logger.debug("REF-ONLY mode: Using transcript_grouped_tpm (reference transcripts only)") else: - tpm_file = self.config.transcript_grouped_tpm + # Extended annotation mode: use transcript_model files that include novel transcripts + merged_tpm = getattr(self.config, 'transcript_model_grouped_tpm', None) + if merged_tpm and "_merged.tsv" in merged_tpm: + self.logger.debug("EXTENDED: Using merged transcript_model TPM file with deduplication") + tpm_file = merged_tpm + elif merged_tpm: + tpm_file = merged_tpm + self.logger.debug("EXTENDED: Using transcript_model_grouped_tpm (includes novel transcripts)") + else: + # Fallback to regular transcript file if transcript_model file not found + self.logger.warning("transcript_model_grouped_tpm not found, falling back to transcript_grouped_tpm") + tpm_file = self.config.transcript_grouped_tpm else: if self.config.ref_only: tpm_file = self.config.transcript_tpm_ref else: - base_file = self.config.transcript_tpm.replace('.tsv', '') - tpm_file = f"{base_file}_merged.tsv" + # For single condition, use transcript_model files + transcript_model_tpm = getattr(self.config, 'transcript_model_tpm', None) + if transcript_model_tpm: + base_file = transcript_model_tpm.replace('.tsv', '') + tpm_file = f"{base_file}_merged.tsv" + self.logger.debug("EXTENDED: Using transcript_model TPM for single condition") + else: + base_file = self.config.transcript_tpm.replace('.tsv', '') + tpm_file = f"{base_file}_merged.tsv" + self.logger.warning("transcript_model_tpm not found, falling back to transcript_tpm") - self.logger.debug(f"Selected TPM file: {tpm_file}") + self.logger.info(f"Selected TPM file: {tpm_file}") if not tpm_file 
or not Path(tpm_file).exists(): + self.logger.error(f"TPM file does not exist: {tpm_file}") raise FileNotFoundError(f"TPM file {tpm_file} not found") + # Check file size and sample content + tpm_path = Path(tpm_file) + self.logger.debug(f"TPM file size: {tpm_path.stat().st_size / (1024*1024):.2f} MB") + + # Sample a few lines from the TPM file to see what transcript IDs are present + with open(tpm_file, 'r') as f: + lines = f.readlines() + self.logger.debug(f"TPM file has {len(lines)} total lines") + if len(lines) > 1: + header = lines[0].strip() + self.logger.debug(f"TPM header: {header}") + + # Show sample transcript IDs + novel_count = 0 + ensembl_count = 0 + sample_novel = [] + sample_ensembl = [] + + for i in range(1, min(21, len(lines))): # Check first 20 data lines + transcript_id = lines[i].split('\t')[0] + if transcript_id.startswith('transcript'): + novel_count += 1 + if len(sample_novel) < 5: + sample_novel.append(transcript_id) + elif transcript_id.startswith('ENSMUST'): + ensembl_count += 1 + if len(sample_ensembl) < 5: + sample_ensembl.append(transcript_id) + + self.logger.debug(f"TPM file sample (first 20 lines): {novel_count} novel, {ensembl_count} Ensembl") + if sample_novel: + self.logger.debug(f"Sample novel transcript IDs: {sample_novel}") + if sample_ensembl: + self.logger.debug(f"Sample Ensembl transcript IDs: {sample_ensembl}") + return tpm_file # ------------------ READ ASSIGNMENT CACHING ------------------ def build_read_assignment_and_classification_dictionaries(self): - """ - Index classifications and assignment types from read_assignments.tsv file(s). - Returns either: - - (classification_counts, assignment_type_counts) for single-file input, or - - (classification_counts_dict, assignment_type_counts_dict) for multi-file (YAML) input. - """ - if not self.config.read_assignments: - raise FileNotFoundError("No read assignments file(s) found.") - - # 1. Determine cache file - cache_file = build_read_assignment_cache_file( - self.config.read_assignments, self.config.ref_only, self.cache_dir - ) - - # 2. Attempt to load from cache - if cache_file.exists(): - cached_data = load_cache(cache_file) - if cached_data and validate_read_assignment_data( - cached_data, self.config.read_assignments - ): - self.logger.info("Using cached read assignment data.") - return self._post_process_cached_data(cached_data) - - # 3. 
Otherwise, build from scratch - self.logger.info("Building read assignment data from scratch.") - if isinstance(self.config.read_assignments, list): - classification_counts_dict = {} - assignment_type_counts_dict = {} - for sample_name, read_assignment_file in self.config.read_assignments: - c_counts, a_counts = self._process_read_assignment_file( - read_assignment_file - ) - classification_counts_dict[sample_name] = c_counts - assignment_type_counts_dict[sample_name] = a_counts - - data_to_cache = { - "classification_counts": classification_counts_dict, - "assignment_type_counts": assignment_type_counts_dict, - } - save_cache(cache_file, data_to_cache) - return classification_counts_dict, assignment_type_counts_dict - else: - classification_counts, assignment_type_counts = ( - self._process_read_assignment_file(self.config.read_assignments) - ) - data_to_cache = (classification_counts, assignment_type_counts) - save_cache(cache_file, data_to_cache) - return classification_counts, assignment_type_counts + """Delegate to read-assignment I/O module with caching.""" + return get_read_assignment_counts(self.config, self.cache_dir) def _post_process_cached_data(self, cached_data): - """ - Convert cached_data back to the return format - for build_read_assignment_and_classification_dictionaries(). - """ + # Backwards-compat wrapper no longer used; kept for compatibility if isinstance(self.config.read_assignments, list): return ( - cached_data["classification_counts"], - cached_data["assignment_type_counts"], + cached_data.get("classification_counts", {}), + cached_data.get("assignment_type_counts", {}), ) - return cached_data # (classification_counts, assignment_type_counts) + return cached_data def _process_read_assignment_file(self, file_path): + """Deprecated; maintained for compatibility. Use get_read_assignment_counts instead.""" + return {}, {} + + # ------------------ READ LENGTH VS ASSIGNMENT ------------------ + def build_length_vs_assignment(self): """ - Parse a read_assignment TSV file, returning: - - classification_counts: dict(classification -> count) - - assignment_type_counts: dict(assignment type -> count) + Stream read_assignment TSV file(s) and aggregate counts by read-length bins + versus (a) assignment_type (unique/ambiguous/inconsistent_*) and + (b) classification (full_splice_match/incomplete_splice_match/NIC/NNIC/etc.). 
+ + Returns a dictionary: + { + 'bins': [bin_labels...], + 'assignment': { (bin, assignment_type) -> count }, + 'classification': { (bin, classification) -> count } + } """ - classification_counts = {} - assignment_type_counts = {} - - with open(file_path, "r") as file: - # Skip header lines - for _ in range(3): - next(file, None) + if not self.config.read_assignments: + raise FileNotFoundError("No read assignments file(s) found.") - for line in file: - parts = line.strip().split("\t") - if len(parts) < 6: + # Define length bins + bin_defs = [ + (0, 1000, '<1kb'), + (1000, 2000, '1-2kb'), + (2000, 5000, '2-5kb'), + (5000, 8000, '5-8kb'), + (8000, 12000, '8-12kb'), + (12000, 20000, '12-20kb'), + (20000, 50000, '20-50kb'), + (50000, float('inf'), '>50kb'), + ] + + def bin_length(length_bp: int) -> str: + for lo, hi, name in bin_defs: + if lo <= length_bp < hi: + return name + return 'unknown' + + def calc_length(exons_str: str) -> int: + if not exons_str: + return 0 + total = 0 + for part in exons_str.split(','): + if '-' not in part: + continue + try: + s, e = part.split('-') + total += int(e) - int(s) + 1 + except Exception: continue + return total + + assign_counts = {} + class_counts = {} + + # Helper to process a single file (plain or gz) + import gzip + def process_file(fp: str): + def smart_open(path_str): + try: + with open(path_str, 'rb') as bf: + if bf.read(2) == b'\x1f\x8b': + return gzip.open(path_str, 'rt') + except Exception: + pass + return open(path_str, 'rt') + with smart_open(fp) as file: + # Skip header lines starting with '#' + # Read line by line to avoid loading entire file + for line in file: + if not line or line.startswith('#'): + continue + parts = line.rstrip('\n').split('\t') + if len(parts) < 9: + continue + assignment_type = parts[5] + exons = parts[7] + additional = parts[8] + # Classification=VALUE; in additional_info + classification = additional.split('Classification=')[-1].split(';')[0].strip() if 'Classification=' in additional else 'Unknown' - additional_info = parts[-1] - classification = ( - additional_info.split("Classification=")[-1].split(";")[0].strip() - ) - assignment_type = parts[5] + length_bp = calc_length(exons) + b = bin_length(length_bp) - classification_counts[classification] = ( - classification_counts.get(classification, 0) + 1 - ) - assignment_type_counts[assignment_type] = ( - assignment_type_counts.get(assignment_type, 0) + 1 - ) + # Update assignment_type bin counts + key_a = (b, assignment_type) + assign_counts[key_a] = assign_counts.get(key_a, 0) + 1 + + # Update classification bin counts + key_c = (b, classification) + class_counts[key_c] = class_counts.get(key_c, 0) + 1 + + # Process single or multiple files + if isinstance(self.config.read_assignments, list): + for _sample, path in self.config.read_assignments: + process_file(path) + else: + process_file(self.config.read_assignments) + + return { + 'bins': [name for _, _, name in bin_defs], + 'assignment': assign_counts, + 'classification': class_counts, + } + + # ------------------ READ LENGTH EFFECTS ------------------ + def build_read_length_effects(self): + """Delegate to read-assignment I/O module with caching.""" + return get_read_length_effects(self.config, self.cache_dir) - return classification_counts, assignment_type_counts + def build_read_length_histogram(self, bin_edges: List[int] = None): + """Delegate to read-assignment I/O module with caching.""" + return get_read_length_histogram(self.config, self.cache_dir, bin_edges) # -------------------- GTF PARSING 
-------------------- @@ -324,6 +505,12 @@ def parse_gtf(self) -> Dict[str, Any]: Parse GTF file into a dictionary with genes, transcripts, and exons. Handles both reference GTF (with gffutils) and extended annotation GTF. """ + self.logger.info("=== GTF PARSING DEBUG ===") + self.logger.info(f"config.ref_only: {self.config.ref_only}") + self.logger.info(f"config.extended_annotation: {getattr(self.config, 'extended_annotation', 'NOT_SET')}") + self.logger.info(f"config.input_gtf: {getattr(self.config, 'input_gtf', 'NOT_SET')}") + self.logger.info(f"config.genedb_filename: {getattr(self.config, 'genedb_filename', 'NOT_SET')}") + if self.config.ref_only: # Use gffutils for reference GTF (more robust but slower) self.logger.info("Parsing reference GTF using gffutils") @@ -335,10 +522,17 @@ def parse_gtf(self) -> Dict[str, Any]: def _parse_reference_gtf(self) -> Dict[str, Any]: """Parse reference GTF using gffutils""" - if not self.config.genedb_filename: + # Check if genedb_filename exists, if not create one + if not self.config.genedb_filename or not Path(self.config.genedb_filename).exists(): + if self.config.genedb_filename: + self.logger.warning(f"Configured genedb file does not exist: {self.config.genedb_filename}") + db_path = self.cache_dir / "gtf.db" if not db_path.exists(): self.logger.info(f"Creating GTF database at {db_path}") + if not self.config.input_gtf or not Path(self.config.input_gtf).exists(): + raise FileNotFoundError(f"Input GTF file required for database creation but not found: {self.config.input_gtf}") + gffutils.create_db( self.config.input_gtf, dbfn=str(db_path), @@ -349,8 +543,9 @@ def _parse_reference_gtf(self) -> Dict[str, Any]: verbose=False, ) self.config.genedb_filename = str(db_path) + self.logger.info(f"Using fallback GTF database: {self.config.genedb_filename}") - self.logger.info("Opening GTF database") + self.logger.info(f"Opening GTF database: {self.config.genedb_filename}") db = gffutils.FeatureDB(self.config.genedb_filename) # Pre-fetch all features @@ -423,10 +618,21 @@ def _parse_reference_gtf(self) -> Dict[str, Any]: def _parse_extended_gtf(self) -> Dict[str, Any]: """Parse extended annotation GTF with custom parser""" base_gene_dict = {} - self.logger.info("Parsing extended annotation GTF") + gtf_file = self.config.extended_annotation + self.logger.info(f"=== EXTENDED GTF PARSING DEBUG ===") + self.logger.info(f"Parsing extended annotation GTF: {gtf_file}") + + # Check file existence and size + gtf_path = Path(gtf_file) + if not gtf_path.exists(): + self.logger.error(f"Extended annotation GTF file does not exist: {gtf_file}") + raise FileNotFoundError(f"Extended annotation GTF file not found: {gtf_file}") + + file_size_mb = gtf_path.stat().st_size / (1024*1024) + self.logger.info(f"Extended GTF file size: {file_size_mb:.2f} MB") try: - with open(self.config.extended_annotation, "r") as file: + with open(gtf_file, "r") as file: attr_pattern = re.compile(r'(\S+) "([^"]+)";') # First pass: genes and transcripts @@ -473,6 +679,7 @@ def _parse_extended_gtf(self) -> Dict[str, Any]: "exons": [], "tags": attrs.get("tags", "").split(","), "name": attrs.get("transcript_name", transcript_id), + "biotype": attrs.get("transcript_biotype", "unknown"), } elif feature_type == "exon" and transcript_id and gene_id: @@ -486,7 +693,42 @@ def _parse_extended_gtf(self) -> Dict[str, Any]: } base_gene_dict[gene_id]["transcripts"][transcript_id]["exons"].append(exon_info) - self.logger.info(f"Processed {len(base_gene_dict)} genes from extended annotation GTF") + # Debug: 
Analyze what we found + total_genes = len(base_gene_dict) + novel_genes = sum(1 for gene_id in base_gene_dict.keys() if "novel_gene" in gene_id) + ensembl_genes = sum(1 for gene_id in base_gene_dict.keys() if gene_id.startswith("ENSMUSG")) + + total_transcripts = 0 + novel_transcripts = 0 + ensembl_transcripts = 0 + sample_novel_transcripts = [] + sample_ensembl_transcripts = [] + + for gene_id, gene_info in base_gene_dict.items(): + transcripts = gene_info.get("transcripts", {}) + total_transcripts += len(transcripts) + + for tx_id in transcripts.keys(): + if tx_id.startswith("transcript"): + novel_transcripts += 1 + if len(sample_novel_transcripts) < 5: + sample_novel_transcripts.append(f"{gene_id}:{tx_id}") + elif tx_id.startswith("ENSMUST"): + ensembl_transcripts += 1 + if len(sample_ensembl_transcripts) < 5: + sample_ensembl_transcripts.append(f"{gene_id}:{tx_id}") + + self.logger.info(f"=== EXTENDED GTF PARSING RESULTS ===") + self.logger.info(f"Total genes parsed: {total_genes}") + self.logger.info(f"Novel genes: {novel_genes}, Ensembl genes: {ensembl_genes}") + self.logger.info(f"Total transcripts: {total_transcripts}") + self.logger.info(f"Novel transcripts: {novel_transcripts}, Ensembl transcripts: {ensembl_transcripts}") + + if sample_novel_transcripts: + self.logger.info(f"Sample novel transcripts: {sample_novel_transcripts}") + if sample_ensembl_transcripts: + self.logger.info(f"Sample Ensembl transcripts: {sample_ensembl_transcripts}") + return base_gene_dict except Exception as e: self.logger.error(f"GTF parsing failed: {str(e)}") @@ -572,14 +814,18 @@ def read_gene_list(self, gene_list_path: Union[str, Path]) -> List[str]: def _filter_novel_genes(self, gene_dict: Dict[str, Any]) -> Dict[str, Any]: """Filter out novel genes based on gene ID pattern.""" - self.logger.debug("Starting novel gene filtering") + self.logger.info("=== NOVEL GENE FILTERING DEBUG ===") + self.logger.info(f"Starting novel gene filtering on {len(gene_dict)} genes") + filtered_dict = {} total_removed_genes = 0 total_removed_transcripts = 0 checked_gene_count = 0 sample_removed = [] # For debug logging + sample_kept_novel_transcripts = [] # For novel transcripts in kept genes novel_gene_pattern = r"novel_gene" # Make sure this pattern is correct for your novel gene IDs + self.logger.info(f"Using novel gene pattern: '{novel_gene_pattern}'") for gene_id, gene_info in gene_dict.items(): checked_gene_count += 1 @@ -593,33 +839,47 @@ def _filter_novel_genes(self, gene_dict: Dict[str, Any]) -> Dict[str, Any]: # Add novel gene ID to the set self.novel_gene_ids.add(gene_id) # Add novel transcript IDs to the set - self.novel_transcript_ids.update(gene_info.get("transcripts", {}).keys()) - + transcripts = gene_info.get("transcripts", {}) + self.novel_transcript_ids.update(transcripts.keys()) if len(sample_removed) < 5: # Sample log of removed genes + sample_transcripts = list(transcripts.keys())[:3] # Show first 3 transcripts sample_removed.append({ 'gene_id': gene_id, - 'transcript_count': removed_transcript_count + 'transcript_count': removed_transcript_count, + 'sample_transcripts': sample_transcripts }) continue # Skip adding novel genes to filtered_dict + else: + # Check if this kept gene has any novel transcripts + transcripts = gene_info.get("transcripts", {}) + for tx_id in transcripts.keys(): + if tx_id.startswith("transcript") and len(sample_kept_novel_transcripts) < 10: + sample_kept_novel_transcripts.append(f"{gene_id}:{tx_id}") filtered_dict[gene_id] = gene_info # Keep known genes + 
self.logger.info(f"=== NOVEL GENE FILTERING RESULTS ===") + self.logger.info(f"Checked {checked_gene_count} total genes") self.logger.info( - f"Filtered {total_removed_genes} novel genes " + f"Removed {total_removed_genes} novel genes " f"({total_removed_genes/checked_gene_count:.2%} of total) " f"and {total_removed_transcripts} associated transcripts" ) + self.logger.info(f"Kept {len(filtered_dict)} genes after novel gene filtering") if sample_removed: - sample_output = "\n".join( - f"- {g['gene_id']}: {g['transcript_count']} transcripts" - for g in sample_removed - ) - self.logger.debug(f"Sample removed novel genes:\n{sample_output}") + self.logger.info("Sample removed novel genes:") + for g in sample_removed: + self.logger.info(f"- {g['gene_id']}: {g['transcript_count']} transcripts {g['sample_transcripts']}") else: self.logger.warning("No novel genes detected with current filtering pattern") + if sample_kept_novel_transcripts: + self.logger.info(f"Sample novel transcripts in KEPT genes: {sample_kept_novel_transcripts}") + else: + self.logger.warning("No novel transcripts found in kept genes!") + return filtered_dict def get_novel_feature_ids(self) -> Tuple[set, set]: diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py index 63653cb3..0f34a7e1 100644 --- a/src/visualization_differential_exp.py +++ b/src/visualization_differential_exp.py @@ -1,6 +1,7 @@ +from __future__ import annotations import logging import pandas as pd -from typing import Dict, List, Tuple, Optional, Union +from typing import Dict, List, Tuple, Optional, Union, Any from pathlib import Path from rpy2 import robjects from rpy2.robjects import r, Formula @@ -22,13 +23,20 @@ def __init__( target_conditions: List[str], updated_gene_dict: Dict[str, Dict], ref_only: bool = False, - dictionary_builder: "DictionaryBuilder" = None, + dictionary_builder: Optional[Any] = None, filter_min_count: int = 10, pca_n_components: int = 10, top_transcripts_base_mean: int = 500, top_n_genes: int = 100, log_level: int = logging.INFO, # Allow configuring log level tech_rep_dict: Dict[str, str] = None, + # New options + use_shrunk_lfc_for_visuals: bool = True, + transcript_filter_mode: str = "per_group_min", # or "half_samples" + transcript_min_per_group: int = 2, + transcript_min_total_fraction: float = 0.5, + covariate_df: Optional[pd.DataFrame] = None, # index: base sample_id (without condition prefix) + size_factor_type: str = "poscounts", # DESeq2 sfType, recommended for zero-heavy data ): """Initialize differential expression analysis.""" def quiet_cb(x): @@ -77,6 +85,92 @@ def quiet_cb(x): self.visualizer = ExpressionVisualizer(self.deseq_dir) self.gene_mapper = GeneMapper() self.tech_rep_dict = tech_rep_dict + self.use_shrunk_lfc_for_visuals = use_shrunk_lfc_for_visuals + self.transcript_filter_mode = transcript_filter_mode + self.transcript_min_per_group = transcript_min_per_group + self.transcript_min_total_fraction = transcript_min_total_fraction + self.covariate_df = covariate_df + self.size_factor_type = size_factor_type + + # ------------------------- + # Small helpers to reduce duplication + # ------------------------- + def _get_labels(self) -> Tuple[str, str]: + target_label = "+".join(self.target_conditions) + reference_label = "+".join(self.ref_conditions) + return target_label, reference_label + + def _annotate_results(self, level: str, results_df: pd.DataFrame) -> pd.DataFrame: + if results_df is None or results_df.empty: + return results_df + results_df = results_df.copy() + 
diff --git a/src/visualization_differential_exp.py b/src/visualization_differential_exp.py
index 63653cb3..0f34a7e1 100644
--- a/src/visualization_differential_exp.py
+++ b/src/visualization_differential_exp.py
@@ -1,6 +1,7 @@
+from __future__ import annotations
 import logging
 import pandas as pd
-from typing import Dict, List, Tuple, Optional, Union
+from typing import Dict, List, Tuple, Optional, Union, Any
 from pathlib import Path
 from rpy2 import robjects
 from rpy2.robjects import r, Formula
@@ -22,13 +23,20 @@ def __init__(
         target_conditions: List[str],
         updated_gene_dict: Dict[str, Dict],
         ref_only: bool = False,
-        dictionary_builder: "DictionaryBuilder" = None,
+        dictionary_builder: Optional[Any] = None,
         filter_min_count: int = 10,
         pca_n_components: int = 10,
         top_transcripts_base_mean: int = 500,
         top_n_genes: int = 100,
         log_level: int = logging.INFO,  # Allow configuring log level
         tech_rep_dict: Dict[str, str] = None,
+        # New options
+        use_shrunk_lfc_for_visuals: bool = True,
+        transcript_filter_mode: str = "per_group_min",  # or "half_samples"
+        transcript_min_per_group: int = 2,
+        transcript_min_total_fraction: float = 0.5,
+        covariate_df: Optional[pd.DataFrame] = None,  # index: base sample_id (without condition prefix)
+        size_factor_type: str = "poscounts",  # DESeq2 sfType, recommended for zero-heavy data
     ):
         """Initialize differential expression analysis."""
         def quiet_cb(x):
@@ -77,6 +85,92 @@ def quiet_cb(x):
         self.visualizer = ExpressionVisualizer(self.deseq_dir)
         self.gene_mapper = GeneMapper()
         self.tech_rep_dict = tech_rep_dict
+        self.use_shrunk_lfc_for_visuals = use_shrunk_lfc_for_visuals
+        self.transcript_filter_mode = transcript_filter_mode
+        self.transcript_min_per_group = transcript_min_per_group
+        self.transcript_min_total_fraction = transcript_min_total_fraction
+        self.covariate_df = covariate_df
+        self.size_factor_type = size_factor_type
+
+    # -------------------------
+    # Small helpers to reduce duplication
+    # -------------------------
+    def _get_labels(self) -> Tuple[str, str]:
+        target_label = "+".join(self.target_conditions)
+        reference_label = "+".join(self.ref_conditions)
+        return target_label, reference_label
+
+    def _annotate_results(self, level: str, results_df: pd.DataFrame) -> pd.DataFrame:
+        if results_df is None or results_df.empty:
+            return results_df
+        results_df = results_df.copy()
+        results_df.index.name = "feature_id"
+        results_df.reset_index(inplace=True)
+        mapping = self._map_gene_symbols(results_df["feature_id"].unique(), level)
+        results_df["transcript_symbol"] = results_df["feature_id"].map(
+            lambda x: mapping.get(x, {}).get("transcript_symbol", x)
+        )
+        results_df["gene_name"] = results_df["feature_id"].map(
+            lambda x: mapping.get(x, {}).get("gene_name", x.split('.')[0] if '.' in x else x)
+        )
+        if level == "gene":
+            results_df = results_df.drop(columns=["transcript_symbol"], errors='ignore')
+        return results_df
+
+    def _save_results(self, level: str, results_df: pd.DataFrame, results_shrunk_df: Optional[pd.DataFrame]) -> Tuple[Path, Optional[Path]]:
+        target_label, reference_label = self._get_labels()
+        outfile = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}.csv"
+        results_df.to_csv(outfile, index=False)
+        shrunk_path = None
+        if results_shrunk_df is not None and not results_shrunk_df.empty:
+            # Annotate shrunk for convenience
+            annotated_shrunk = self._annotate_results(level, results_shrunk_df)
+            shrunk_path = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}_shrunk_annotated.csv"
+            annotated_shrunk.to_csv(shrunk_path, index=False)
+        return outfile, shrunk_path
+
+    def _lfc_for_visuals(self, base_df: pd.DataFrame, shrunk_df: Optional[pd.DataFrame]) -> pd.DataFrame:
+        df = base_df.copy()
+        if not self.use_shrunk_lfc_for_visuals or shrunk_df is None or shrunk_df.empty:
+            return df
+        try:
+            merged = pd.merge(
+                df[["feature_id", "log2FoldChange"]],
+                shrunk_df[["feature_id", "log2FoldChange"]],
+                on="feature_id",
+                how="left",
+                suffixes=("", "_shrunk"),
+            )
+            lfc_map = merged.set_index("feature_id")["log2FoldChange_shrunk"]
+            replacement = df["feature_id"].map(lfc_map)
+            df["log2FoldChange"] = replacement.fillna(df["log2FoldChange"]).values
+            # Optionally retain a column for reference
+            df = pd.merge(df, merged[["feature_id", "log2FoldChange_shrunk"]], on="feature_id", how="left")
+        except Exception as e:
+            self.logger.warning(f"Could not merge shrunk LFCs for visuals: {e}")
+        return df
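The substitution performed by `_lfc_for_visuals` is easy to miss: shrunk fold changes replace raw ones where a shrunk value exists, and raw values are kept as the fallback. A toy pandas illustration (values invented):

```python
import pandas as pd

base = pd.DataFrame({"feature_id": ["g1", "g2", "g3"], "log2FoldChange": [4.0, -2.0, 1.5]})
shrunk = pd.DataFrame({"feature_id": ["g1", "g2"], "log2FoldChange": [2.1, -1.2]})

# Map shrunk LFCs onto the base table; missing entries fall back to the raw LFC.
lfc_map = shrunk.set_index("feature_id")["log2FoldChange"]
base["log2FoldChange"] = base["feature_id"].map(lfc_map).fillna(base["log2FoldChange"])
print(base["log2FoldChange"].tolist())  # [2.1, -1.2, 1.5]  (g3 keeps its raw value)
```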
+    def _load_prefixed_counts(self, pattern: str) -> pd.DataFrame:
+        """Load count tsvs for all conditions, prefix columns with condition, and concat."""
+        all_sample_dfs: List[pd.DataFrame] = []
+        for condition in self.ref_conditions + self.target_conditions:
+            condition_dir = Path(self.output_dir) / condition
+            count_files = list(condition_dir.glob(f"*{pattern}"))
+            if not count_files:
+                self.logger.error(f"No count files found for condition: {condition}")
+                raise FileNotFoundError(f"No count files matching {pattern} found in {condition_dir}")
+            for file_path in count_files:
+                self.logger.debug(f"Reading count data from: {file_path}")
+                df = pd.read_csv(file_path, sep="\t")
+                if "#feature_id" not in df.columns and df.columns[0].startswith("#"):
+                    df.rename(columns={df.columns[0]: "#feature_id"}, inplace=True)
+                df.set_index("#feature_id", inplace=True)
+                # Prefix columns
+                df.rename(columns={col: f"{condition}_{col}" for col in df.columns}, inplace=True)
+                all_sample_dfs.append(df)
+        if not all_sample_dfs:
+            raise ValueError("No sample data found")
+        return pd.concat(all_sample_dfs, axis=1)
 
     def _load_transcript_mapping_from_file(self):
         """Load transcript mapping directly from the transcript_mapping.tsv file."""
@@ -88,7 +182,7 @@ def _load_transcript_mapping_from_file(self):
 
         try:
             # Load the transcript mapping file
-            self.logger.info(f"Loading transcript mapping from {mapping_file}")
+            self.logger.debug(f"Loading transcript mapping from {mapping_file}")
             self.transcript_map = {}
 
             # Skip header and read the mapping
@@ -149,12 +243,16 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame, pd.DataFrame]
 
         # --- 2. Run DESeq2 Analysis (Gene Level) ---
         (deseq2_results_gene_file,
          deseq2_results_df_gene,
-         gene_normalized_counts) = self._perform_level_analysis("gene", gene_counts_filtered)
+         deseq2_results_df_gene_shrunk,
+         gene_normalized_counts,
+         gene_vst_counts) = self._perform_level_analysis("gene", gene_counts_filtered)
 
         # --- 3. Run DESeq2 Analysis (Transcript Level) ---
         (deseq2_results_transcript_file,
          deseq2_results_df_transcript,
-         transcript_normalized_counts) = self._perform_level_analysis("transcript", transcript_counts_filtered)
+         deseq2_results_df_transcript_shrunk,
+         transcript_normalized_counts,
+         transcript_vst_counts) = self._perform_level_analysis("transcript", transcript_counts_filtered)
 
@@ -162,8 +260,12 @@ def run_complete_analysis(self) -> Tuple[Path, Path, pd.DataFrame, pd.DataFrame]
         # --- 4. Generate Visualizations ---
         self._generate_visualizations(
             gene_counts_filtered=gene_counts_filtered,
             transcript_counts_filtered=transcript_counts_filtered,  # Pass filtered counts for coldata generation
             gene_normalized_counts=gene_normalized_counts,
             transcript_normalized_counts=transcript_normalized_counts,
+            gene_vst_counts=gene_vst_counts,
+            transcript_vst_counts=transcript_vst_counts,
             deseq2_results_df_gene=deseq2_results_df_gene,
-            deseq2_results_df_transcript=deseq2_results_df_transcript
+            deseq2_results_df_transcript=deseq2_results_df_transcript,
+            deseq2_results_df_gene_shrunk=deseq2_results_df_gene_shrunk,
+            deseq2_results_df_transcript_shrunk=deseq2_results_df_transcript_shrunk
         )
 
         self.logger.info("Differential expression analysis workflow complete.")
@@ -212,6 +314,11 @@ def _apply_transcript_filters(self, transcript_counts: pd.DataFrame) -> pd.DataFrame:
         if not valid_transcripts:
             self.logger.warning("No valid transcripts found in updated_gene_dict. Skipping validity filter.")
         self.logger.debug(f"Found {len(valid_transcripts)} valid transcript IDs in updated_gene_dict.")
+        # Apply validity filter if available
+        if valid_transcripts:
+            before_valid = transcript_counts.shape[0]
+            transcript_counts = transcript_counts[transcript_counts.index.isin(valid_transcripts)]
+            self.logger.info(f"Validity filtering: Retained {transcript_counts.shape[0]} / {before_valid} transcripts present in updated_gene_dict")
 
         # --- Novel Transcript Filtering ---
         if self.dictionary_builder:
@@ -240,7 +347,7 @@ def _apply_transcript_filters(self, transcript_counts: pd.DataFrame) -> pd.DataFrame:
 
     def _perform_level_analysis(
         self, level: str, count_data: pd.DataFrame
-    ) -> Tuple[Path, pd.DataFrame, pd.DataFrame]:
+    ) -> Tuple[Path, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
         """
         Runs DESeq2 analysis for a specific level (gene or transcript).
@@ -263,42 +370,22 @@ def _perform_level_analysis(
         # Create design matrix
         coldata = self._build_design_matrix(count_data)
 
-        # Run DESeq2 - Now returns results and normalized counts
-        results_df, normalized_counts_df = self._run_deseq2(count_data, coldata, level)
-
-        # --- Process DESeq2 Results ---
-        results_df.index.name = "feature_id"
-        results_df.reset_index(inplace=True)  # Keep feature_id as a column
-
-        # Map gene symbols/names
-        mapping = self._map_gene_symbols(results_df["feature_id"].unique(), level)
-
-        # Add transcript_symbol and gene_name columns safely using .get
-        results_df["transcript_symbol"] = results_df["feature_id"].map(
-            lambda x: mapping.get(x, {}).get("transcript_symbol", x)  # Default to feature_id if not found
-        )
-        results_df["gene_name"] = results_df["feature_id"].map(
-            lambda x: mapping.get(x, {}).get("gene_name", x.split('.')[0] if '.' in x else x)  # Default to feature_id logic if not found
-        )
-
-        # Drop transcript_symbol column for gene-level analysis as it's redundant
-        if level == "gene":
-            results_df = results_df.drop(columns=["transcript_symbol"], errors='ignore')  # Use errors='ignore'
+        # Run DESeq2 - returns results, shrunk results, normalized counts, and VST counts
+        results_df, results_shrunk_df, normalized_counts_df, vst_counts_df = self._run_deseq2(count_data, coldata, level)
 
+        # --- Process and annotate DESeq2 Results ---
+        results_df = self._annotate_results(level, results_df)
 
         # --- Save Results ---
-        target_label = "+".join(self.target_conditions)
-        reference_label = "+".join(self.ref_conditions)
-        # Use the pattern argument passed to _get_condition_data if needed, or derive filename like this
-        outfile = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}.csv"
-        results_df.to_csv(outfile, index=False)
+        # Save both standard and annotated shrunk results
+        outfile, _ = self._save_results(level, results_df, results_shrunk_df)
         self.logger.info(f"Saved DESeq2 results ({results_df.shape[0]} features) to {outfile}")
 
         # --- Write Top Genes/Transcripts ---
         self._write_top_genes(results_df, level)
 
         self.logger.info(f"DESeq2 analysis complete for level: {level}")
-        return outfile, results_df, normalized_counts_df
+        return outfile, results_df, results_shrunk_df, normalized_counts_df, vst_counts_df
 
     def _generate_visualizations(
         self,
@@ -306,8 +393,12 @@ def _generate_visualizations(
         transcript_counts_filtered: pd.DataFrame,
         gene_normalized_counts: pd.DataFrame,
         transcript_normalized_counts: pd.DataFrame,
+        gene_vst_counts: pd.DataFrame,
+        transcript_vst_counts: pd.DataFrame,
         deseq2_results_df_gene: pd.DataFrame,
         deseq2_results_df_transcript: pd.DataFrame,
+        deseq2_results_df_gene_shrunk: Optional[pd.DataFrame] = None,
+        deseq2_results_df_transcript_shrunk: Optional[pd.DataFrame] = None,
     ):
         """Generates PCA plots and other visualizations based on DESeq2 results and normalized counts."""
         self.logger.info("Generating visualizations...")
@@ -315,8 +406,9 @@ def _generate_visualizations(
         reference_label = "+".join(self.ref_conditions)
 
         # --- Visualize Gene-Level DE Results ---
+        gene_results_for_plot = self._lfc_for_visuals(deseq2_results_df_gene, deseq2_results_df_gene_shrunk)
         self.visualizer.visualize_results(
-            results=deseq2_results_df_gene,  # Use DataFrame directly
+            results=gene_results_for_plot,
             target_label=target_label,
             reference_label=reference_label,
             min_count=self.filter_min_count,  # Use configured value
@@ -325,21 +417,24 @@ def _generate_visualizations(
         self.logger.info(f"Gene-level DE summary visualizations saved to {self.deseq_dir}")
 
         # --- Run PCA (Gene Level) ---
-        if not gene_normalized_counts.empty:
+        gene_counts_for_pca = gene_vst_counts if gene_vst_counts is not None and not gene_vst_counts.empty else gene_normalized_counts
+        if not gene_counts_for_pca.empty:
             gene_coldata = self._build_design_matrix(gene_counts_filtered)  # Need coldata matching the counts used
             self._run_pca(
-                normalized_counts=gene_normalized_counts,
+                normalized_counts=gene_counts_for_pca,
                 level="gene",
                 coldata=gene_coldata,
                 target_label=target_label,
-                reference_label=reference_label
+                reference_label=reference_label,
+                is_vst=(gene_vst_counts is not None and not gene_vst_counts.empty)
             )
         else:
             self.logger.warning("Skipping gene-level PCA: Normalized counts are empty.")
 
         # --- Visualize Transcript-Level DE Results ---
+        tx_results_for_plot = self._lfc_for_visuals(deseq2_results_df_transcript, deseq2_results_df_transcript_shrunk)
         self.visualizer.visualize_results(
-            results=deseq2_results_df_transcript,  # Use DataFrame directly
+            results=tx_results_for_plot,
            target_label=target_label,
            reference_label=reference_label,
            min_count=self.filter_min_count,  # Use configured value
@@ -348,14 +443,16 @@ def _generate_visualizations(
         self.logger.info(f"Transcript-level DE summary visualizations saved to {self.deseq_dir}")
 
         # --- Run PCA (Transcript Level) ---
-        if not transcript_normalized_counts.empty:
+        tx_counts_for_pca = transcript_vst_counts if transcript_vst_counts is not None and not transcript_vst_counts.empty else transcript_normalized_counts
+        if not tx_counts_for_pca.empty:
             transcript_coldata = self._build_design_matrix(transcript_counts_filtered)  # Need coldata matching the counts used
             self._run_pca(
-                normalized_counts=transcript_normalized_counts,
+                normalized_counts=tx_counts_for_pca,
                 level="transcript",
                 coldata=transcript_coldata,
                 target_label=target_label,
-                reference_label=reference_label
+                reference_label=reference_label,
+                is_vst=(transcript_vst_counts is not None and not transcript_vst_counts.empty)
             )
         else:
             self.logger.warning("Skipping transcript-level PCA: Normalized counts are empty.")
@@ -374,47 +471,10 @@ def _get_merged_transcript_counts(self, pattern: str) -> pd.DataFrame:
         if not self.ref_only and pattern == "transcript_grouped_counts.tsv":
             adjusted_pattern = "transcript_model_grouped_counts.tsv"
 
-        self.logger.info(f"Using file pattern: {adjusted_pattern}")
-
-        # Store sample dataframes
-        all_sample_dfs = []
-
-        # Process each condition directory
-        for condition in self.ref_conditions + self.target_conditions:
-            condition_dir = Path(self.output_dir) / condition
-            count_files = list(condition_dir.glob(f"*{adjusted_pattern}"))
-
-            if not count_files:
-                self.logger.error(f"No count files found for condition: {condition}")
-                raise FileNotFoundError(f"No count files matching {adjusted_pattern} found in {condition_dir}")
-
-            # Load each count file
-            for file_path in count_files:
-                self.logger.info(f"Reading count data from: {file_path}")
-
-                # Load the file
-                df = pd.read_csv(file_path, sep="\t")
-                if "#feature_id" not in df.columns and df.columns[0].startswith("#"):
-                    # Rename first column if it's the feature ID column but named differently
-                    df.rename(columns={df.columns[0]: "#feature_id"}, inplace=True)
-
-                # Set feature_id as index
-                df.set_index("#feature_id", inplace=True)
-
-                # For multi-condition data (typical in sample files)
-                # We need to prefix each column with the condition name
-                for col in df.columns:
-                    df.rename(columns={col: f"{condition}_{col}"}, inplace=True)
-
-                all_sample_dfs.append(df)
-
-        # Concatenate all dataframes to get the full matrix
-        if not all_sample_dfs:
-            self.logger.error("No sample data frames found")
-            raise ValueError("No sample data found")
+        self.logger.debug(f"Using file pattern: {adjusted_pattern}")
 
-        # Combine all sample dataframes
-        combined_df = pd.concat(all_sample_dfs, axis=1)
+        # Load and prefix columns consistently
+        combined_df = self._load_prefixed_counts(adjusted_pattern)
 
         self.logger.info(f"Combined count data shape before mapping: {combined_df.shape}")
 
         # Apply technical replicate merging before transcript mapping
@@ -477,46 +537,7 @@ def _get_condition_data(self, pattern: str) -> pd.DataFrame:
         elif pattern == "gene_grouped_counts.tsv":
             # For gene data, use a simpler approach (no merging needed)
             self.logger.info(f"Loading gene count data with pattern: {pattern}")
-
-            # Store sample dataframes
-            all_sample_dfs = []
-
-            # Process each condition directory
-            for condition in self.ref_conditions + self.target_conditions:
-                condition_dir = Path(self.output_dir) / condition
-                count_files = list(condition_dir.glob(f"*{pattern}"))
-
-                if not count_files:
-                    self.logger.error(f"No gene count files found for condition: {condition}")
-                    raise FileNotFoundError(f"No count files matching {pattern} found in {condition_dir}")
-
-                # Load each count file
-                for file_path in count_files:
-                    self.logger.info(f"Reading gene count data from: {file_path}")
-
-                    # Load the file
-                    df = pd.read_csv(file_path, sep="\t")
-                    if "#feature_id" not in df.columns and df.columns[0].startswith("#"):
-                        # Rename first column if it's the feature ID column but named differently
-                        df.rename(columns={df.columns[0]: "#feature_id"}, inplace=True)
-
-                    # Set feature_id as index
-                    df.set_index("#feature_id", inplace=True)
-
-                    # For multi-condition data (typical in sample files)
-                    # We need to prefix each column with the condition name
-                    for col in df.columns:
-                        df.rename(columns={col: f"{condition}_{col}"}, inplace=True)
-
-                    all_sample_dfs.append(df)
-
-            # Concatenate all dataframes to get the full matrix
-            if not all_sample_dfs:
-                self.logger.error("No gene sample data frames found")
-                raise ValueError("No gene sample data found")
-
-            # Combine all sample dataframes
-            combined_df = pd.concat(all_sample_dfs, axis=1)
+            combined_df = self._load_prefixed_counts(pattern)
 
             self.logger.info(f"Combined gene count data shape: {combined_df.shape}")
 
             # Apply technical replicate merging
@@ -532,7 +553,9 @@ def _filter_counts(self, count_data: pd.DataFrame, level: str = "gene") -> pd.DataFrame:
         Filter features based on counts using the configured threshold.
 
         For genes: Keep if mean count >= configured min_count in either condition group.
-        For transcripts: Keep if count >= configured min_count in at least half of all samples.
+        For transcripts: Behavior is configurable.
+        - per_group_min (default): require counts >= threshold in at least K samples per group
+        - half_samples: require counts >= threshold in >= fraction of all samples
         """
         if count_data.empty:
             self.logger.warning(f"Input count data for filtering ({level}) is empty. Returning empty DataFrame.")
@@ -542,15 +565,51 @@ def _filter_counts(self, count_data: pd.DataFrame, level: str = "gene") -> pd.DataFrame:
         min_count_threshold = self.filter_min_count
 
         if level == "transcript":
-            total_samples = len(count_data.columns)
-            min_samples_required = max(1, total_samples // 2)  # Ensure at least 1 sample is required
-            samples_passing = (count_data >= min_count_threshold).sum(axis=1)
-            keep_features = samples_passing >= min_samples_required
+            # Determine columns by condition name prefix
+            ref_cols = [
+                col for col in count_data.columns
+                if any(col.startswith(f"{cond}_") for cond in self.ref_conditions)
+            ]
+            tgt_cols = [
+                col for col in count_data.columns
+                if any(col.startswith(f"{cond}_") for cond in self.target_conditions)
+            ]
 
-            self.logger.info(
-                f"Transcript filtering: Keeping transcripts with counts >= {min_count_threshold} "
-                f"in at least {min_samples_required}/{total_samples} samples"
-            )
+            if self.transcript_filter_mode == "half_samples":
+                total_cols = len(count_data.columns)
+                required = int(np.ceil(total_cols * float(self.transcript_min_total_fraction)))
+                passing_total = (count_data >= min_count_threshold).sum(axis=1)
+                keep_features = passing_total >= required
+                self.logger.info(
+                    "Transcript filtering (half_samples): Keeping transcripts present in "+
+                    ">= %d/%d samples with counts >= %d",
+                    required, total_cols, min_count_threshold
+                )
+            else:
+                # Default: per-group minimum requirement
+                min_ref_required = (
+                    min(self.transcript_min_per_group, len(ref_cols)) if len(ref_cols) >= 1 else 0
+                )
+                min_tgt_required = (
+                    min(self.transcript_min_per_group, len(tgt_cols)) if len(tgt_cols) >= 1 else 0
+                )
+
+                passing_ref = (count_data[ref_cols] >= min_count_threshold).sum(axis=1) if ref_cols else 0
+                passing_tgt = (count_data[tgt_cols] >= min_count_threshold).sum(axis=1) if tgt_cols else 0
+
+                if isinstance(passing_ref, int):
+                    self.logger.warning("No reference columns found for transcript filtering; no transcripts will pass.")
+                    keep_features = count_data.index == "__none__"
+                elif isinstance(passing_tgt, int):
+                    self.logger.warning("No target columns found for transcript filtering; no transcripts will pass.")
+                    keep_features = count_data.index == "__none__"
+                else:
+                    keep_features = (passing_ref >= min_ref_required) & (passing_tgt >= min_tgt_required)
+
+                self.logger.info(
+                    "Transcript filtering (per_group_min): Keeping transcripts with counts >= %d in at least %d/%d ref and %d/%d target samples",
+                    min_count_threshold, min_ref_required, len(ref_cols), min_tgt_required, len(tgt_cols)
+                )
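To make the two filter modes concrete, here is a toy comparison on a 4-sample matrix (2 reference, 2 target) with threshold 10; all values are invented:

```python
import pandas as pd

counts = pd.DataFrame(
    {"ref_a": [50, 0, 12], "ref_b": [40, 0, 14], "tgt_a": [0, 30, 15], "tgt_b": [0, 25, 11]},
    index=["tx1", "tx2", "tx3"],
)
ref_cols, tgt_cols, thr = ["ref_a", "ref_b"], ["tgt_a", "tgt_b"], 10

# per_group_min: >= thr in at least 2 samples of EACH group -> only tx3 passes.
per_group = ((counts[ref_cols] >= thr).sum(axis=1) >= 2) & ((counts[tgt_cols] >= thr).sum(axis=1) >= 2)
# half_samples: >= thr in at least half of ALL samples -> tx1, tx2 and tx3 all pass.
half = (counts >= thr).sum(axis=1) >= 2
print(counts.index[per_group].tolist(), counts.index[half].tolist())  # ['tx3'] ['tx1', 'tx2', 'tx3']
```

The per-group mode is stricter on condition-specific transcripts (tx1 and tx2 are expressed in only one group), which is the intent of the new default.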
@@ -599,8 +658,13 @@ def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame:
         groups = []
         condition_assignments = []
         sample_ids = []
+        # Optional covariates
+        covariate_values: Dict[str, List[Optional[Union[str, float]]]] = {}
+        covariate_columns: List[str] = list(self.covariate_df.columns) if isinstance(self.covariate_df, pd.DataFrame) else []
+        for cov_col in covariate_columns:
+            covariate_values[cov_col] = []
 
-        self.logger.info("Building experimental design matrix")
+        self.logger.debug("Building experimental design matrix")
 
         for sample in count_data.columns:
             # Extract the condition from the sample name
@@ -627,6 +691,15 @@ def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame:
             # Store the condition and sample ID for additional information
             condition_assignments.append(condition)
             sample_ids.append(sample)
+
+            # Attach covariate values if provided
+            if covariate_columns:
+                for cov_col in covariate_columns:
+                    try:
+                        value = self.covariate_df.loc[sample_id, cov_col]
+                    except Exception:
+                        value = np.nan
+                    covariate_values[cov_col].append(value)
 
         # Create the design matrix DataFrame
         design_matrix = pd.DataFrame({
@@ -634,6 +707,10 @@ def _build_design_matrix(self, count_data: pd.DataFrame) -> pd.DataFrame:
             "condition": condition_assignments,
             "sample_id": sample_ids
         }, index=count_data.columns)
+        # Append covariates into design matrix
+        if covariate_columns:
+            for cov_col in covariate_columns:
+                design_matrix[cov_col] = covariate_values[cov_col]
 
         # Log the design matrix for debugging
         self.logger.debug(f"Design matrix:\n{design_matrix}")
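A hypothetical walk-through of how a `covariate_df` (indexed by base sample_id) ends up in the design, and how the formula gains the covariate terms ahead of `group`; sample names and the `batch` column are invented:

```python
import pandas as pd

coldata = pd.DataFrame(
    {"group": ["Reference", "Reference", "Target", "Target"],
     "sample_id": ["s1", "s2", "s3", "s4"]},
    index=["condA_s1", "condA_s2", "condB_s3", "condB_s4"],
)
covariate_df = pd.DataFrame({"batch": ["b1", "b2", "b1", "b2"]}, index=["s1", "s2", "s3", "s4"])

# Join covariates onto the design matrix by base sample_id.
coldata["batch"] = coldata["sample_id"].map(covariate_df["batch"])
# Build the formula the same way _run_deseq2 does: covariates first, then group.
covariate_cols = [c for c in coldata.columns if c not in ("group", "condition", "sample_id")]
design_formula = "~ " + " + ".join(covariate_cols + ["group"])
print(design_formula)  # ~ batch + group
```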
@@ -642,7 +719,7 @@ def _run_deseq2(
     def _run_deseq2(
         self, count_data: pd.DataFrame, coldata: pd.DataFrame, level: str
-    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
+    ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame], pd.DataFrame, pd.DataFrame]:
         """
         Run DESeq2 analysis and return results and normalized counts.
 
@@ -656,13 +733,13 @@ def _run_deseq2(
         """
         self.logger.info(f"Running DESeq2 for {level} level...")
         deseq2 = importr("DESeq2")
-        # Ensure counts are integers for DESeq2
-        count_data = count_data.fillna(0).round().astype(int)
-
-        # Ensure count data has no negative values before passing to R
+        # Ensure counts are integers for DESeq2 (fail fast; do not silently coerce)
+        count_data = count_data.fillna(0)
         if (count_data < 0).any().any():
-            self.logger.warning(f"Negative values found in count data for {level}. Clamping to 0.")
-            count_data = count_data.clip(lower=0)
+            raise ValueError(f"Negative counts detected for {level}.")
+        if not np.all(np.equal(count_data.values, np.floor(count_data.values))):
+            raise ValueError(f"Non-integer counts detected for {level}. Please supply raw integer counts.")
+        count_data = count_data.astype(int)
 
         if count_data.empty:
             self.logger.error(f"Count data is empty before running DESeq2 for {level}.")
@@ -677,13 +754,29 @@ def _run_deseq2(
             # Create DESeqDataSet
             self.logger.debug("Creating DESeqDataSet...")
+            # Build design formula dynamically: ~ covariates + group
+            covariate_cols = [c for c in coldata.columns if c not in ["group", "condition", "sample_id"]]
+            formula_terms = (covariate_cols + ["group"]) if covariate_cols else ["group"]
+            design_formula = "~ " + " + ".join(formula_terms)
+            self.logger.info(f"Using design formula for {level}: {design_formula}")
+            # Ensure group and categorical covariates are factors with correct baseline
+            r('library(methods)')
+            r.assign("coldata_tmp", coldata_r)
+            r('coldata_tmp$group <- relevel(factor(coldata_tmp$group), "Reference")')
+            # Coerce non-numeric covariates to factors
+            for cov in covariate_cols:
+                if not pd.api.types.is_numeric_dtype(coldata[cov]):
+                    r(f'coldata_tmp${cov} <- factor(coldata_tmp${cov})')
+            coldata_r = r('coldata_tmp')
+
             dds = deseq2.DESeqDataSetFromMatrix(
-                countData=count_data_r, colData=coldata_r, design=Formula("~ group")
+                countData=count_data_r, colData=coldata_r, design=Formula(design_formula)
             )
 
             # Run DESeq analysis
             self.logger.debug("Running DESeq()...")
-            dds = deseq2.DESeq(dds)
+            # Use sfType configured; 'poscounts' is recommended for zero-heavy counts
+            dds = deseq2.DESeq(dds, sfType=self.size_factor_type)
 
             # Get results
             self.logger.debug("Extracting results()...")
@@ -728,16 +821,61 @@ def _run_deseq2(
             # Ensure DataFrame structure matches original count_data (features x samples)
             normalized_counts_df = pd.DataFrame(normalized_counts_py, index=count_data.index, columns=count_data.columns)
 
+            # VST-transformed counts for PCA visualization stability
+            try:
+                vst_obj = deseq2.vst(dds, blind=True)
+                vst_mat_r = r['assay'](vst_obj)
+                vst_counts_py = robjects.conversion.rpy2py(vst_mat_r)
+                vst_counts_df = pd.DataFrame(vst_counts_py, index=count_data.index, columns=count_data.columns)
+            except Exception as e:
+                self.logger.warning(f"VST transformation failed or unavailable: {e}")
+                vst_counts_df = pd.DataFrame()
+
             # Generate dispersion and count summaries
             self._generate_dispersion_summary(res_df, level)
 
+            # LFC shrinkage (apeglm) for interpretability; keep Wald stats for GSEA
+            res_shrunk_df: Optional[pd.DataFrame] = None
+            try:
+                # Ensure apeglm is available
+                importr("apeglm")
+                # Find appropriate coefficient name
+                coef_names = robjects.conversion.rpy2py(r['resultsNames'](dds))
+                # Prefer the standard group coefficient; fallback to first matching 'group'
+                coef_name = None
+                for name in coef_names:
+                    if isinstance(name, str) and "group_Target_vs_Reference" in name:
+                        coef_name = name
+                        break
+                if coef_name is None:
+                    for name in coef_names:
+                        if isinstance(name, str) and name.startswith("group_"):
+                            coef_name = name
+                            break
+                if coef_name is None and len(coef_names) > 0:
+                    coef_name = coef_names[0]
+
+                self.logger.info(f"Applying LFC shrinkage with apeglm (coef={coef_name}) for {level} level...")
+                res_shrunk = deseq2.lfcShrink(dds, coef=coef_name, type="apeglm")
+                res_shrunk_df = robjects.conversion.rpy2py(r("as.data.frame")(res_shrunk))
+                res_shrunk_df.index = count_data.index
+
+                # Save shrunk results to file
+                target_label = "+".join(self.target_conditions)
+                reference_label = "+".join(self.ref_conditions)
+                outfile_shrunk = self.deseq_dir / f"DE_{level}_{target_label}_vs_{reference_label}_shrunk.csv"
+                res_shrunk_df.to_csv(outfile_shrunk)
+                self.logger.info(f"Saved shrunk DESeq2 results to {outfile_shrunk}")
+            except Exception as e:
+                self.logger.warning(f"LFC shrinkage (apeglm) failed or unavailable: {e}")
+
-            self.logger.info(f"DESeq2 run completed for {level}. Results shape: {res_df.shape}, Normalized counts shape: {normalized_counts_df.shape}")
-            return res_df, normalized_counts_df
+            self.logger.info(f"DESeq2 run completed for {level}. Results shape: {res_df.shape}, Normalized counts shape: {normalized_counts_df.shape}, VST shape: {vst_counts_df.shape}")
+            return res_df, res_shrunk_df, normalized_counts_df, vst_counts_df
 
         except Exception as e:
             self.logger.error(f"Error running DESeq2 for {level}: {str(e)}")
             # Return empty DataFrames on error to avoid downstream issues
-            return pd.DataFrame(), pd.DataFrame(index=count_data.index, columns=count_data.columns)
+            return pd.DataFrame(), pd.DataFrame(), pd.DataFrame(index=count_data.index, columns=count_data.columns), pd.DataFrame(index=count_data.index, columns=count_data.columns)
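The coefficient-picking logic before `lfcShrink` is plain string matching, so it can be sanity-checked without R. The names below mimic what DESeq2's `resultsNames()` would return for a `~ batch + group` design (invented for illustration):

```python
coef_names = ["Intercept", "batch_b2_vs_b1", "group_Target_vs_Reference"]

# Prefer the exact group contrast, then any group_* coefficient, then the first name.
coef_name = next((n for n in coef_names if "group_Target_vs_Reference" in n), None)
if coef_name is None:
    coef_name = next((n for n in coef_names if n.startswith("group_")), None)
if coef_name is None and coef_names:
    coef_name = coef_names[0]
print(coef_name)  # group_Target_vs_Reference
```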
 
     def _generate_dispersion_summary(self, results_df: pd.DataFrame, level: str) -> None:
         """
@@ -996,7 +1134,7 @@ def _write_top_genes(self, results: pd.DataFrame, level: str) -> None:
         pd.Series(top_genes_list).to_csv(top_genes_file, index=False, header=False)
         self.logger.info(f"Wrote top {len(top_genes_list)} DE genes to {top_genes_file}")
 
-    def _run_pca(self, normalized_counts, level, coldata, target_label, reference_label):
+    def _run_pca(self, normalized_counts, level, coldata, target_label, reference_label, is_vst: bool = False):
         """Run PCA analysis and create visualization using DESeq2 normalized counts."""
         self.logger.info(f"Running PCA for {level} level using DESeq2 normalized counts...")
 
@@ -1018,21 +1156,26 @@ def _run_pca(self, normalized_counts, level, coldata, target_label, reference_la
             self.logger.warning(f"Reducing number of PCA components to {n_components} due to data dimensions.")
 
-        # Log transform the DESeq2 normalized counts (add 1 to handle zeros)
-        # Ensure data is numeric before transformation
-        log_normalized_counts = np.log2(normalized_counts.apply(pd.to_numeric, errors='coerce').fillna(0) + 1)
+        # Prepare matrix for PCA
+        # If using VST counts, do not log-transform again
+        if is_vst:
+            matrix_for_pca = normalized_counts.apply(pd.to_numeric, errors='coerce').fillna(0)
+        else:
+            # Log transform the DESeq2 normalized counts (add 1 to handle zeros)
+            # Ensure data is numeric before transformation
+            matrix_for_pca = np.log2(normalized_counts.apply(pd.to_numeric, errors='coerce').fillna(0) + 1)
 
         # Check for NaNs/Infs after log transform which can happen if counts were negative (though clamped earlier) or exactly -1
-        if np.isinf(log_normalized_counts).any().any() or np.isnan(log_normalized_counts).any().any():
-            self.logger.warning(f"NaNs or Infs found in log-transformed counts for {level}. Replacing with 0. This might indicate issues with count data.")
-            log_normalized_counts = log_normalized_counts.replace([np.inf, -np.inf], 0).fillna(0)
+        if np.isinf(matrix_for_pca).any().any() or np.isnan(matrix_for_pca).any().any():
+            self.logger.warning(f"NaNs or Infs found in matrix for PCA for {level}. Replacing with 0. This might indicate issues with count data.")
+            matrix_for_pca = matrix_for_pca.replace([np.inf, -np.inf], 0).fillna(0)
 
         try:
             pca = PCA(n_components=n_components)
             # Transpose because PCA expects samples as rows, features as columns
-            pca_result = pca.fit_transform(log_normalized_counts.transpose())
+            pca_result = pca.fit_transform(matrix_for_pca.transpose())
 
             # Map feature IDs (index of normalized_counts) to gene names
             feature_ids = normalized_counts.index.tolist()
@@ -1048,7 +1191,7 @@ def _run_pca(self, normalized_counts, level, coldata, target_label, reference_la
             # Create DataFrame with columns for all calculated components
             pc_columns = [f'PC{i+1}' for i in range(n_components)]
-            pca_df = pd.DataFrame(data=pca_result[:, :n_components], columns=pc_columns, index=log_normalized_counts.columns)  # Use sample names as index
+            pca_df = pd.DataFrame(data=pca_result[:, :n_components], columns=pc_columns, index=matrix_for_pca.columns)  # Use sample names as index
 
             # Add group information from coldata, ensuring index alignment
             # It's safer to reset index on coldata if it uses sample names as index too
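A minimal sketch of the PCA input choice, assuming sklearn is available (the counts here are random stand-ins): VST values go into PCA as-is, while normalized counts get `log2(x + 1)` first.

```python
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
norm_counts = pd.DataFrame(rng.poisson(20, size=(100, 4)).astype(float))

matrix_for_pca = np.log2(norm_counts + 1)  # skipped when the input is already VST
pcs = PCA(n_components=2).fit_transform(matrix_for_pca.T)  # samples as rows
print(pcs.shape)  # (4, 2)
```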
diff --git a/src/visualization_output_config.py b/src/visualization_output_config.py
index 40256d47..99df8452 100644
--- a/src/visualization_output_config.py
+++ b/src/visualization_output_config.py
@@ -96,7 +96,10 @@ def _process_params(self, params):
         self.log_details["gene_db"] = params.get("genedb")
         self.log_details["fastq_used"] = bool(params.get("fastq"))
         self.input_gtf = self.input_gtf or params.get("genedb")
-        self.genedb_filename = params.get("genedb_filename")
+
+        # Handle genedb_filename with fallback mechanism
+        original_genedb_filename = params.get("genedb_filename")
+        self.genedb_filename = self._find_genedb_file(original_genedb_filename)
 
         if params.get("yaml"):
             # YAML input case
@@ -116,6 +119,56 @@ def _process_params(self, params):
                 "Processing sample directory not found in params for non-YAML input."
             )
 
+    def _find_genedb_file(self, original_path):
+        """Find genedb file with fallback mechanism."""
+        from pathlib import Path
+
+        # If no original path provided, skip to fallback
+        if original_path:
+            original_path_obj = Path(original_path)
+            if original_path_obj.exists():
+                logging.info(f"Using original genedb file: {original_path}")
+                return original_path
+            else:
+                logging.warning(f"Original genedb file not found: {original_path}")
+
+        # Fallback: look for .db files in the output directory
+        output_path = Path(self.output_directory)
+        db_files = list(output_path.glob("*.db"))
+
+        if db_files:
+            # Prefer files with common GTF database names
+            preferred_patterns = ["gtf.db", "gene.db", "genedb.db", "annotation.db"]
+
+            # First, try to find files matching preferred patterns
+            for pattern in preferred_patterns:
+                for db_file in db_files:
+                    if pattern in db_file.name.lower():
+                        logging.info(f"Found fallback genedb file (preferred pattern): {db_file}")
+                        return str(db_file)
+
+            # If no preferred pattern found, use the first .db file
+            fallback_db = db_files[0]
+            logging.info(f"Found fallback genedb file: {fallback_db}")
+            return str(fallback_db)
+
+        # Last resort: check if we're in a subdirectory and look one level up
+        parent_db_files = list(output_path.parent.glob("*.db"))
+        if parent_db_files:
+            fallback_db = parent_db_files[0]
+            logging.info(f"Found fallback genedb file in parent directory: {fallback_db}")
+            return str(fallback_db)
+
+        # No .db file found anywhere
+        if original_path:
+            logging.error(f"No genedb file found. Original path '{original_path}' doesn't exist, and no .db files found in '{output_path}' or parent directory.")
+        else:
+            logging.error(f"No genedb file found in '{output_path}' or parent directory, and no original path provided.")
+
+        return original_path  # Return original even if it doesn't exist, let the caller handle the error
+
     def _conditional_unzip(self):
         """Check if unzip is needed and perform it conditionally based on the model use."""
         if self.ref_only and self.input_gtf and self.input_gtf.endswith(".gz"):
@@ -164,9 +217,8 @@ def _find_files(self):
             elif file_name.endswith(".read_assignments.tsv"):
                 self.read_assignments = os.path.join(self.output_directory, file_name)
             elif file_name.endswith(".read_assignments.tsv.gz"):
-                self.read_assignments = self._unzip_file(
-                    os.path.join(self.output_directory, file_name)
-                )
+                # Prefer streaming gzip rather than unzipping
+                self.read_assignments = os.path.join(self.output_directory, file_name)
             elif file_name.endswith(".gene_grouped_counts.tsv"):
                 self._conditions = self._get_conditions_from_file(
                     os.path.join(self.output_directory, file_name)
@@ -299,11 +351,8 @@ def _find_files_from_yaml(self):
             # Check for .read_assignments.tsv.gz
             gz_file = os.path.join(sample_dir, f"{name}.read_assignments.tsv.gz")
             if os.path.exists(gz_file):
-                unzipped_file = self._unzip_file(gz_file)
-                if unzipped_file:
-                    self.read_assignments.append((name, unzipped_file))
-                else:
-                    logging.warning(f"Failed to unzip {gz_file}")
+                # Prefer streaming gzip rather than unzipping
+                self.read_assignments.append((name, gz_file))
             else:
                 # Check for .read_assignments.tsv
                 non_gz_file = os.path.join(
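The fallback order in `_find_genedb_file` (original path, then preferred `*.db` names in the output directory, then the parent directory) can be illustrated with a small sketch; the candidate paths are hypothetical:

```python
from pathlib import Path

candidates = [Path("out/annotation.gtf.db"), Path("out/other.db")]
preferred_patterns = ["gtf.db", "gene.db", "genedb.db", "annotation.db"]

# First candidate whose name contains a preferred pattern wins; otherwise take the first .db.
match = next(
    (str(f) for p in preferred_patterns for f in candidates if p in f.name.lower()),
    str(candidates[0]) if candidates else None,
)
print(match)  # out/annotation.gtf.db
```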
diff --git a/src/visualization_plotter.py b/src/visualization_plotter.py
index 53f0d644..22b6c8d1 100644
--- a/src/visualization_plotter.py
+++ b/src/visualization_plotter.py
@@ -161,7 +161,7 @@ def plot_transcript_map(self):
                         exon_color = "skyblue"  # If ref_only, always treat as reference
                         exon_alpha = 1.0
                     else:
-                        is_reference_exon = exon["exon_id"].startswith("ENSE")  # Original logic
+                        is_reference_exon = exon["exon_id"].startswith("E")  # Relaxed prefix match (was "ENSE")
                         exon_color = "skyblue" if is_reference_exon else "red"
                         exon_alpha = 1.0 if is_reference_exon else 0.6
 
@@ -434,6 +434,135 @@ def _create_pie_chart(self, title, data):
         plt.savefig(plot_path, bbox_inches='tight', dpi=300)
         plt.close()
 
+    def plot_read_length_effects(self, length_effects):
+        """
+        Plot how read length relates to read-assignment uniqueness.
+        Saves a stacked percentage bar chart into read_assignments_dir.
+        """
+        if not self.read_assignments_dir:
+            logging.warning("No read_assignments_dir provided. Skipping length effects plotting.")
+            return
+
+        bins = length_effects['bins']
+        totals = length_effects['totals']
+
+        # Assignment uniqueness plot
+        df_a_rows = []
+        for b in bins:
+            row = {'bin': b, **length_effects['by_bin_assignment'][b], 'TOTAL': totals[b]}
+            df_a_rows.append(row)
+        df_a = pd.DataFrame(df_a_rows)
+        if df_a.empty:
+            logging.warning("No data available for assignment uniqueness plot; skipping.")
+            return
+        df_a.set_index('bin', inplace=True)
+
+        # Determine assignment categories dynamically and ensure columns exist
+        assignment_keys = length_effects.get('assignment_keys', [])
+        if not assignment_keys:
+            assignment_keys = [c for c in df_a.columns if c != 'TOTAL']
+        for key in assignment_keys:
+            if key not in df_a.columns:
+                df_a[key] = 0
+
+        # Normalize to percentages per bin
+        for col in assignment_keys:
+            df_a[col] = np.where(df_a['TOTAL'] > 0, df_a[col] / df_a['TOTAL'] * 100.0, 0.0)
+
+        # Preferred column order if present
+        preferred_order = ['UNIQUE', 'AMBIGUOUS', 'OTHER', 'INCONSISTENT', 'UNASSIGNED']
+        ordered_cols = [c for c in preferred_order if c in assignment_keys] + [c for c in assignment_keys if c not in preferred_order]
+        if not ordered_cols:
+            logging.warning("No assignment columns to plot after normalization; skipping.")
+            return
+
+        ax = df_a[ordered_cols].plot(kind='bar', stacked=True, figsize=(12, 6), colormap='tab20')
+        ax.set_ylabel('Percentage of reads')
+        ax.set_title('Read assignment uniqueness by read length')
+        ax.legend(title='Assignment')
+        plt.tight_layout()
+        out1 = os.path.join(self.read_assignments_dir, 'read_length_vs_assignment_uniqueness.pdf')
+        plt.savefig(out1, bbox_inches='tight', dpi=300)
+        plt.close()
+
+    def plot_read_length_histogram(self, hist_data):
+        """
+        Plot a histogram of read lengths using precomputed bin edges/counts.
+        """
+        if not self.read_assignments_dir:
+            logging.warning("No read_assignments_dir provided. Skipping length histogram plot.")
+            return
+
+        edges = hist_data.get('edges', [])
+        counts = hist_data.get('counts', [])
+        total = hist_data.get('total', 0)
+        if not edges or not counts:
+            logging.warning("Empty histogram data; skipping.")
+            return
+
+        # Build midpoints for bar plotting
+        mids = [(edges[i] + edges[i+1]) / 2.0 for i in range(len(counts))]
+        widths = [edges[i+1] - edges[i] for i in range(len(counts))]
+
+        plt.figure(figsize=(12, 6))
+        plt.bar(mids, counts, width=widths, align='center', color='steelblue', edgecolor='black')
+        plt.xlabel('Read length (bp)')
+        plt.ylabel('Read count')
+        plt.title(f'Read length histogram (total n={total:,})')
+        plt.tight_layout()
+        outp = os.path.join(self.read_assignments_dir, 'read_length_histogram.pdf')
+        plt.savefig(outp, bbox_inches='tight', dpi=300)
+        plt.close()
+
+    def plot_read_length_vs_assignment(self, length_vs_assignment):
+        """
+        Plot read-length bins vs assignment_type and vs classification as stacked bar charts.
+        Saves two PDFs into read_assignments_dir.
+        """
+        if not self.read_assignments_dir:
+            logging.warning("read_assignments_dir not set; skipping length vs assignment plots")
+            return
+
+        import pandas as pd
+        import matplotlib.pyplot as plt
+
+        bins = length_vs_assignment.get('bins', [])
+        a_counts = length_vs_assignment.get('assignment', {})
+        c_counts = length_vs_assignment.get('classification', {})
+
+        # Build DataFrames
+        def to_df(counts_dict):
+            rows = []
+            for (b, key), val in counts_dict.items():
+                rows.append({'bin': b, 'key': key, 'count': val})
+            df = pd.DataFrame(rows)
+            if df.empty:
+                return df
+            pivot = df.pivot_table(index='bin', columns='key', values='count', aggfunc='sum', fill_value=0)
+            # Ensure bin order
+            pivot = pivot.reindex(bins, axis=0).fillna(0)
+            return pivot
+
+        df_a = to_df(a_counts)
+        df_c = to_df(c_counts)
+
+        def plot_stacked(pivot_df, title, filename):
+            if pivot_df.empty:
+                logging.warning(f"No data for plot: {title}")
+                return
+            ax = pivot_df.plot(kind='bar', stacked=True, figsize=(12, 6))
+            ax.set_xlabel('Read length bin')
+            ax.set_ylabel('Read count')
+            ax.set_title(title)
+            plt.tight_layout()
+            out = os.path.join(self.read_assignments_dir, filename)
+            plt.savefig(out)
+            plt.close()
+            logging.info(f"Saved plot: {out}")
+
+        plot_stacked(df_a, 'Read length vs assignment_type', 'length_vs_assignment_type.pdf')
+        plot_stacked(df_c, 'Read length vs classification', 'length_vs_classification.pdf')
+
     def plot_novel_transcript_contribution(self):
         """
         Creates a plot showing the percentage of expression from novel transcripts.
diff --git a/src/visualization_read_assignment_io.py b/src/visualization_read_assignment_io.py
new file mode 100644
index 00000000..c2d9202c
--- /dev/null
+++ b/src/visualization_read_assignment_io.py
@@ -0,0 +1,269 @@
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+import numpy as np
+
+from src.visualization_cache_utils import (
+    build_read_assignment_cache_file,
+    build_length_effects_cache_file,
+    build_length_hist_cache_file,
+    save_cache,
+    load_cache,
+    validate_read_assignment_data,
+    validate_length_effects_data,
+    validate_length_hist_data,
+)
+
+
+def _smart_open(path_str: str):
+    import gzip
+    try:
+        with open(path_str, 'rb') as bf:
+            # Sniff the gzip magic bytes rather than trusting the extension
+            if bf.read(2) == b'\x1f\x8b':
+                return gzip.open(path_str, 'rt')
+    except Exception:
+        pass
+    return open(path_str, 'rt')
+
+
+def _calc_length_bp(exons_str: str) -> int:
+    if not isinstance(exons_str, str) or not exons_str:
+        return 0
+    total = 0
+    for part in exons_str.split(','):
+        if '-' not in part:
+            continue
+        try:
+            s, e = part.split('-')
+            total += int(e) - int(s) + 1
+        except Exception:
+            continue
+    return total
+
+
+def get_read_assignment_counts(config, cache_dir: Path):
+    """
+    Returns read-assignment classification and assignment_type counts, using cache.
+    Return format mirrors previous behavior in DictionaryBuilder:
+    - If config.read_assignments is a list: ({sample: class_counts}, {sample: assign_type_counts})
+    - Else: (class_counts, assign_type_counts)
+    """
+    logger = logging.getLogger('IsoQuant.visualization.read_assignment_io')
+    if not config.read_assignments:
+        raise FileNotFoundError("No read assignments file(s) found.")
+
+    cache_file = build_read_assignment_cache_file(
+        config.read_assignments, config.ref_only, cache_dir
+    )
+
+    if cache_file.exists():
+        cached_data = load_cache(cache_file)
+        if cached_data and validate_read_assignment_data(cached_data, config.read_assignments):
+            logger.info("Using cached read assignment data.")
+            if isinstance(config.read_assignments, list):
+                return (
+                    cached_data["classification_counts"],
+                    cached_data["assignment_type_counts"],
+                )
+            return cached_data
+
+    logger.info("Building read assignment data from scratch.")
+
+    def process_file(file_path: str):
+        classification_counts: Dict[str, int] = {}
+        assignment_type_counts: Dict[str, int] = {}
+        with _smart_open(file_path) as fh:
+            # Skip header lines starting with '#'
+            while True:
+                pos = fh.tell()
+                line = fh.readline()
+                if not line:
+                    break
+                if not line.startswith('#'):
+                    fh.seek(pos)
+                    break
+            for line in fh:
+                parts = line.strip().split('\t')
+                if len(parts) < 9:
+                    continue
+                additional_info = parts[8]
+                assignment_type = parts[5]
+                classification = (
+                    additional_info.split('Classification=')[-1].split(';')[0].strip()
+                    if 'Classification=' in additional_info else 'Unknown'
+                )
+                classification_counts[classification] = classification_counts.get(classification, 0) + 1
+                assignment_type_counts[assignment_type] = assignment_type_counts.get(assignment_type, 0) + 1
+        return classification_counts, assignment_type_counts
+
+    if isinstance(config.read_assignments, list):
+        classification_counts_dict: Dict[str, Dict[str, int]] = {}
+        assignment_type_counts_dict: Dict[str, Dict[str, int]] = {}
+        for sample_name, file_path in config.read_assignments:
+            c_counts, a_counts = process_file(file_path)
+            classification_counts_dict[sample_name] = c_counts
+            assignment_type_counts_dict[sample_name] = a_counts
+        to_cache = {
+            "classification_counts": classification_counts_dict,
+            "assignment_type_counts": assignment_type_counts_dict,
+        }
+        save_cache(cache_file, to_cache)
+        return classification_counts_dict, assignment_type_counts_dict
+    else:
+        counts = process_file(config.read_assignments)
+        save_cache(cache_file, counts)
+        return counts
+
+
+def get_read_length_effects(config, cache_dir: Path) -> Dict[str, Any]:
+    """
+    Compute and cache read-length effects aggregates:
+    - by length bin vs assignment_type and vs classification
+    - dynamic keys for observed categories
+    Returns dict with keys: bins, by_bin_assignment, by_bin_classification, assignment_keys, classification_keys, totals
+    """
+    logger = logging.getLogger('IsoQuant.visualization.read_assignment_io')
+    if not config.read_assignments:
+        raise FileNotFoundError("No read assignments file(s) found.")
+
+    # Fixed bin order focused on 0-15 kb
+    bin_order = ['<1kb', '1-2kb', '2-3kb', '3-4kb', '4-5kb', '5-6kb', '6-7kb', '7-8kb', '8-9kb', '9-10kb', '10-12kb', '12-15kb', '>15kb']
+    cache_file = build_length_effects_cache_file(
+        config.read_assignments, config.ref_only, cache_dir, bin_order
+    )
+
+    if cache_file.exists():
+        cached = load_cache(cache_file)
+        if cached and validate_length_effects_data(cached, expected_bins=bin_order):
+            logger.info("Using cached read length effects.")
+            return cached
+
+    from collections import defaultdict
+    by_bin_assignment: Dict[str, Dict[str, int]] = {b: defaultdict(int) for b in bin_order}
+    by_bin_classification: Dict[str, Dict[str, int]] = {b: defaultdict(int) for b in bin_order}
+    assignment_keys = set()
+    classification_keys = set()
+    totals: Dict[str, int] = {b: 0 for b in bin_order}
+
+    def assign_bin(length_bp: int) -> str:
+        if length_bp < 1000: return '<1kb'
+        if length_bp < 2000: return '1-2kb'
+        if length_bp < 3000: return '2-3kb'
+        if length_bp < 4000: return '3-4kb'
+        if length_bp < 5000: return '4-5kb'
+        if length_bp < 6000: return '5-6kb'
+        if length_bp < 7000: return '6-7kb'
+        if length_bp < 8000: return '7-8kb'
+        if length_bp < 9000: return '8-9kb'
+        if length_bp < 10000: return '9-10kb'
+        if length_bp < 12000: return '10-12kb'
+        if length_bp < 15000: return '12-15kb'
+        return '>15kb'
+
+    def process_file(file_path: str):
+        with _smart_open(file_path) as fh:
+            # Skip header lines
+            while True:
+                pos = fh.tell()
+                line = fh.readline()
+                if not line:
+                    break
+                if not line.startswith('#'):
+                    fh.seek(pos)
+                    break
+            for line in fh:
+                parts = line.strip().split('\t')
+                if len(parts) < 9:
+                    continue
+                assignment_type = parts[5]
+                exons_str = parts[7]
+                addi = parts[8]
+                classification = (
+                    addi.split('Classification=')[-1].split(';')[0].strip()
+                    if 'Classification=' in addi else 'unknown'
+                )
+                length_bp = _calc_length_bp(exons_str)
+                b = assign_bin(length_bp)
+                totals[b] += 1
+                by_bin_assignment[b][assignment_type] += 1
+                by_bin_classification[b][classification] += 1
+                assignment_keys.add(assignment_type)
+                classification_keys.add(classification)
+
+    if isinstance(config.read_assignments, list):
+        for _sample, file_path in config.read_assignments:
+            process_file(file_path)
+    else:
+        process_file(config.read_assignments)
+
+    # Convert defaultdicts to dicts for safer pickling/validation
+    by_bin_assignment = {b: dict(d) for b, d in by_bin_assignment.items()}
+    by_bin_classification = {b: dict(d) for b, d in by_bin_classification.items()}
+
+    result = {
+        'bins': bin_order,
+        'by_bin_assignment': by_bin_assignment,
+        'by_bin_classification': by_bin_classification,
+        'assignment_keys': sorted(list(assignment_keys)),
+        'classification_keys': sorted(list(classification_keys)),
+        'totals': totals,
+    }
+    save_cache(cache_file, result)
+    return result
+
+
+def get_read_length_histogram(config, cache_dir: Path, bin_edges: Optional[List[int]] = None) -> Dict[str, Any]:
+    """
+    Compute and cache a histogram of read lengths derived from the exons column.
+    Returns dict with keys: edges, counts, total
+    """
+    logger = logging.getLogger('IsoQuant.visualization.read_assignment_io')
+    if not config.read_assignments:
+        raise FileNotFoundError("No read assignments file(s) found.")
+
+    if bin_edges is None:
+        bin_edges = [
+            0, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000,
+            5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000,
+            12000, 15000,
+        ]
+
+    cache_file = build_length_hist_cache_file(
+        config.read_assignments, config.ref_only, cache_dir, bin_edges
+    )
+
+    if cache_file.exists():
+        cached = load_cache(cache_file)
+        if cached and validate_length_hist_data(cached, expected_edges=bin_edges):
+            logger.info("Using cached read length histogram.")
+            return cached
+
+    lengths: List[int] = []
+
+    def process_file(file_path: str):
+        with _smart_open(file_path) as fh:
+            for line in fh:
+                if not line or line.startswith('#'):
+                    continue
+                parts = line.rstrip('\n').split('\t')
+                if len(parts) < 8:
+                    continue
+                exons = parts[7]
+                lengths.append(_calc_length_bp(exons))
+
+    if isinstance(config.read_assignments, list):
+        for _sample, file_path in config.read_assignments:
+            process_file(file_path)
+    else:
+        process_file(config.read_assignments)
+
+    counts, edges = np.histogram(np.array(lengths, dtype=np.int64), bins=np.array(bin_edges))
+    result = {
+        'edges': edges.tolist(),
+        'counts': counts.tolist(),
+        'total': int(len(lengths)),
+    }
+    save_cache(cache_file, result)
+    return result
diff --git a/src/visualization_simple_ranker.py b/src/visualization_simple_ranker.py
index cd7994e0..1486af45 100644
--- a/src/visualization_simple_ranker.py
+++ b/src/visualization_simple_ranker.py
@@ -1,40 +1,88 @@
-#!/usr/bin/env python3
-"""A lightweight gene ranking utility for experiments lacking biological replicates.
-
-This module implements a fallback algorithm used when each experimental
-condition has fewer than two biological replicates and, therefore, formal
-statistics with DESeq2 are inappropriate.
-
-The scoring heuristic combines two intuitive effect-size metrics:
-    1. Absolute log2 fold-change of gene-level TPM between target and
-       reference groups.
-    2. Maximum change in isoform usage for any transcript belonging to a
-       gene. Usage is defined as transcript TPM / total gene TPM.
-
-The final score is:
-    score = |log2FC_gene| + max_delta_isoform_usage
-
-Genes are ranked by the score in descending order.
-
-The implementation is designed to mirror the interface expected by
-visualize.py – namely, a ``rank(top_n)`` method that returns a list of gene
-names (mapped from gene IDs using the updated gene dictionary).
-"""
+"""
+Enhanced gene ranking utility for long-read sequencing experiments without replicates.
+
+This module refines the original SimpleGeneRanker by incorporating best practices
+from recent long-read isoform switching studies. It defines "interesting" genes
+as those that are highly expressed, exhibit bona fide isoform switching (at
+least two isoforms changing in opposite directions), and potentially show
+functional consequences (e.g., gain or loss of coding potential). Genes with
+extreme overall expression changes or very complex isoform architectures are
+penalised to reduce false positives.
+
+Key features:
+
+1. **Isoform count filter** – Genes with fewer than two transcripts are
+   excluded, as isoform usage cannot change. Genes with excessive numbers of
+   isoforms can be down-weighted via a complexity penalty.
+2. **Bidirectional isoform switching** – For a gene to be considered a
+   candidate isoform switcher, at least one transcript must increase in usage
+   while another decreases between reference and target conditions. This helps
+   distinguish true isoform switches from uniform scaling of all isoforms.
+3. **Functional impact assessment** – When transcript annotations include
+   attributes such as coding status, ORF length or predicted NMD sensitivity,
+   the ranker rewards genes where isoform switches change these properties.
+   Lacking such annotations, this component defaults to zero influence.
+4. **Adaptive thresholds** – Gating thresholds for expression level, fold
+   change and usage delta are derived from quantiles of the observed
+   distributions, making the algorithm robust across datasets with different
+   scales.
+5. **Extreme change penalty** – Genes with very large gene-level fold changes
+   are down-weighted to prioritise isoform regulation over conventional
+   differential expression.
+6. **Categorised output** – The ranker labels the top genes according to
+   whether they are isoform switchers, high expressers or conventional DEGs.
+
+Example usage:
+
+    from visualization_simple_ranker import SimpleGeneRanker
+    ranker = SimpleGeneRanker(output_dir="./out",
+                              ref_conditions=["ref"],
+                              target_conditions=["tgt"],
+                              updated_gene_dict=gene_dict)
+    top_genes = ranker.rank(top_n=50)
+    # top_genes is a list of gene names
+
+Note: This implementation assumes that `updated_gene_dict` follows the same
+structure as in the original SimpleGeneRanker. Transcript annotations may
+contain keys such as ``coding`` (bool), ``orf_length`` (int) or
+``functional_consequence`` (str). Missing annotations are handled gracefully.
+"""
+
 from __future__ import annotations
 
 import logging
 from pathlib import Path
-from typing import List, Dict
+from typing import Dict, List, Tuple
 
 import numpy as np
 import pandas as pd
 
-logger = logging.getLogger("IsoQuant.visualization.simple_ranker")
+# Configure module-level logger
+logger = logging.getLogger("IsoQuant.visualization.enhanced_ranker")
 logger.setLevel(logging.INFO)
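The bidirectional-switch criterion described in feature 2 boils down to usage (transcript TPM / gene TPM) rising for at least one isoform while falling for another. A toy check with invented TPMs:

```python
ref_tpm = {"tx1": 80.0, "tx2": 20.0}
tgt_tpm = {"tx1": 30.0, "tx2": 70.0}

def usage(d):
    total = sum(d.values())
    return {t: v / total for t, v in d.items()} if total else {}

# Per-isoform change in usage between conditions.
delta = {t: usage(tgt_tpm)[t] - usage(ref_tpm)[t] for t in ref_tpm}
is_switch = any(v > 0 for v in delta.values()) and any(v < 0 for v in delta.values())
print(delta, is_switch)  # tx1 down, tx2 up -> True
```

Uniform scaling (both isoforms doubling, say) leaves all usage deltas at zero and is correctly rejected.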
+ """ def __init__( self, @@ -42,307 +90,673 @@ def __init__( ref_conditions: List[str], target_conditions: List[str], ref_only: bool = False, - updated_gene_dict: Dict = None, + updated_gene_dict: Dict | None = None, ) -> None: self.output_dir = Path(output_dir) self.ref_conditions = list(ref_conditions) self.target_conditions = list(target_conditions) self.ref_only = ref_only - self.updated_gene_dict = updated_gene_dict or {} + self.updated_gene_dict: Dict[str, Dict] = updated_gene_dict or {} # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def rank(self, top_n: int = 100) -> List[str]: - logger.info("Running SimpleGeneRanker (no replicates detected)") - - # Debug: Log updated_gene_dict structure - if self.updated_gene_dict: - logger.info(f"updated_gene_dict has {len(self.updated_gene_dict)} conditions: {list(self.updated_gene_dict.keys())}") - # Show sample genes from each condition - for condition, genes in self.updated_gene_dict.items(): - sample_gene_ids = list(genes.keys())[:3] - logger.info(f"Condition '{condition}' has {len(genes)} genes. Sample gene IDs: {sample_gene_ids}") - # Show gene structure for first gene - if sample_gene_ids: - first_gene_id = sample_gene_ids[0] - first_gene_info = genes[first_gene_id] - logger.info(f" Sample gene '{first_gene_id}' structure: name='{first_gene_info.get('name', 'MISSING')}', keys={list(first_gene_info.keys())}") - else: - logger.warning("No updated_gene_dict provided!") + """Return a ranked list of gene names based on the enhanced scoring algorithm. + + Parameters + ---------- + top_n : int + Maximum number of genes to return. If fewer genes meet the + thresholds, the returned list may be shorter. + + Returns + ------- + List[str] + List of gene names (uppercase) ranked by decreasing composite score. + """ + logger.info("Running EnhancedGeneRanker") - # 1. Load gene-level TPM aggregated by sample. - logger.info(f"SIMPLE_RANKER: Loading TPM data for reference conditions: {self.ref_conditions}") - gene_expr_ref = self._aggregate_gene_tpm(self.ref_conditions) - logger.info(f"SIMPLE_RANKER: Reference TPM loaded - shape: {gene_expr_ref.shape}") - - logger.info(f"SIMPLE_RANKER: Loading TPM data for target conditions: {self.target_conditions}") - gene_expr_tgt = self._aggregate_gene_tpm(self.target_conditions) - logger.info(f"SIMPLE_RANKER: Target TPM loaded - shape: {gene_expr_tgt.shape}") + if not self.updated_gene_dict: + logger.warning("No updated_gene_dict provided; returning empty list") + return [] - # Genes common to both groups. + # Scan transcript ID patterns in updated_gene_dict + self._scan_transcript_id_patterns() + + # 1. Extract gene‑level TPMs + gene_expr_ref, gene_expr_tgt = self._extract_gene_tpms_from_dict() + if gene_expr_ref.empty or gene_expr_tgt.empty: + logger.warning("No gene expression data found; returning empty list") + return [] + + # 2. 
Compute gene list intersection common_genes = gene_expr_ref.index.intersection(gene_expr_tgt.index) - logger.info(f"SIMPLE_RANKER: Found {len(common_genes)} genes common to both ref and target groups") - logger.info(f"SIMPLE_RANKER: Sample common genes: {list(common_genes)[:10]}") - gene_expr_ref = gene_expr_ref.loc[common_genes] gene_expr_tgt = gene_expr_tgt.loc[common_genes] - logger.info(f"SIMPLE_RANKER: After filtering to common genes - ref shape: {gene_expr_ref.shape}, target shape: {gene_expr_tgt.shape}") - - # Filter to only genes present in updated_gene_dict (like differential expression analysis does) - if self.updated_gene_dict: - available_genes = set() - for condition_dict in self.updated_gene_dict.values(): - available_genes.update(condition_dict.keys()) - - # Log sample of available genes from updated_gene_dict - sample_genes = list(available_genes)[:10] - logger.info(f"Sample genes available in updated_gene_dict: {sample_genes}") - - # Keep only genes that are in the updated_gene_dict - filtered_common_genes = [g for g in common_genes if g in available_genes] - logger.info(f"Filtered genes to {len(filtered_common_genes)} from {len(common_genes)} based on updated_gene_dict availability") - - if not filtered_common_genes: - logger.warning("No genes remain after filtering by updated_gene_dict. Returning empty list.") - return [] - - gene_expr_ref = gene_expr_ref.loc[filtered_common_genes] - gene_expr_tgt = gene_expr_tgt.loc[filtered_common_genes] - common_genes = filtered_common_genes - - # 2. Compute log2 fold-change (add pseudocount of 1). - logger.info("SIMPLE_RANKER: Computing log2 fold-change...") - log2fc = np.log2(gene_expr_tgt + 1) - np.log2(gene_expr_ref + 1) - abs_log2fc = log2fc.abs() - logger.info(f"SIMPLE_RANKER: Log2FC stats - min: {log2fc.min():.3f}, max: {log2fc.max():.3f}, mean: {log2fc.mean():.3f}") - logger.info(f"SIMPLE_RANKER: Abs Log2FC stats - min: {abs_log2fc.min():.3f}, max: {abs_log2fc.max():.3f}, mean: {abs_log2fc.mean():.3f}") - - # Show top log2FC examples - top_log2fc_genes = abs_log2fc.nlargest(5) - logger.info(f"SIMPLE_RANKER: Top 5 genes by abs log2FC:") - for gene_id, score in top_log2fc_genes.items(): - ref_val = gene_expr_ref[gene_id] - tgt_val = gene_expr_tgt[gene_id] - logger.info(f" {gene_id}: ref_tpm={ref_val:.3f}, tgt_tpm={tgt_val:.3f}, abs_log2fc={score:.3f}") - - # 3. Compute isoform-usage change per gene. - logger.info("SIMPLE_RANKER: Computing isoform usage delta...") - delta_usage = self._compute_isoform_usage_delta(common_genes) - logger.info(f"SIMPLE_RANKER: Delta usage stats - min: {delta_usage.min():.3f}, max: {delta_usage.max():.3f}, mean: {delta_usage.mean():.3f}") - - # 4. Combined score. - logger.info("SIMPLE_RANKER: Computing combined score = |log2FC| + max_delta_isoform_usage...") - combined_score = abs_log2fc + delta_usage - combined_score.name = "score" - logger.info(f"SIMPLE_RANKER: Combined score stats - min: {combined_score.min():.3f}, max: {combined_score.max():.3f}, mean: {combined_score.mean():.3f}") - - # Show detailed scoring for top genes - top_combined_genes = combined_score.nlargest(10) - logger.info(f"SIMPLE_RANKER: Top 10 genes by combined score:") - for gene_id, score in top_combined_genes.items(): - log2fc_contrib = abs_log2fc[gene_id] - usage_contrib = delta_usage[gene_id] - logger.info(f" {gene_id}: total_score={score:.3f} (log2fc={log2fc_contrib:.3f} + usage={usage_contrib:.3f})") - - # 5. Rank and get top N gene IDs. 
-        ranked_gene_ids = combined_score.sort_values(ascending=False).head(top_n).index.tolist()
-        logger.info(f"SimpleGeneRanker selected {len(ranked_gene_ids)} genes (top {top_n}) by score.")
-        logger.info(f"Top 10 ranked gene IDs: {ranked_gene_ids[:10]}")
-
-        # Show the final scores for the top ranked genes
-        final_scores = combined_score.loc[ranked_gene_ids[:10]]
-        logger.info(f"Final scores for top 10 genes: {final_scores.to_dict()}")
-
-        # 6. Map gene IDs to gene names directly from updated_gene_dict to ensure exact compatibility with plotter
-        if self.updated_gene_dict:
-            ranked_gene_names = []
-            mapped_count = 0
-            mapping_details = []  # For detailed logging
-
-            for gene_id in ranked_gene_ids:
-                gene_name_found = None
-                # Look for this gene_id in updated_gene_dict to get the exact gene name
-                for condition_dict in self.updated_gene_dict.values():
-                    if gene_id in condition_dict:
-                        gene_info = condition_dict[gene_id]
-                        if "name" in gene_info and gene_info["name"]:
-                            gene_name_found = gene_info["name"]
-                            mapped_count += 1
-                            mapping_details.append(f"{gene_id} -> {gene_name_found}")
-                        else:
-                            mapping_details.append(f"{gene_id} -> NO_NAME (name field: {gene_info.get('name', 'MISSING')})")
-                        break
-                else:
-                    mapping_details.append(f"{gene_id} -> NOT_FOUND_IN_DICT")
-
-                # Use the found gene name (uppercase to match plotter expectations) or fallback to gene_id
-                ranked_gene_names.append(gene_name_found.upper() if gene_name_found else gene_id)
-
-            # Log mapping details for first 10 genes
-            logger.info(f"Gene mapping details (first 10):")
-            for detail in mapping_details[:10]:
-                logger.info(f"  {detail}")
-
-            logger.info(f"SimpleGeneRanker mapped {mapped_count}/{len(ranked_gene_ids)} gene IDs to gene names from updated_gene_dict.")
-            logger.info(f"Final gene names (first 10): {ranked_gene_names[:10]}")
-            return ranked_gene_names
-        else:
-            logger.warning("No updated_gene_dict provided, returning raw gene IDs.")
-            return ranked_gene_ids
+        logger.info(f"{len(common_genes)} genes present in both conditions")
+
+        # 3. Compute isoform usage deltas and switching flags
+        usage_delta, switch_flags, func_flags, isoform_counts = self._compute_isoform_usage_metrics(common_genes)
+
+        # 4. Normalize features
+        abs_log2fc = np.abs(np.log2(gene_expr_tgt + 1) - np.log2(gene_expr_ref + 1))
+        geom_expr = np.sqrt(gene_expr_ref * gene_expr_tgt)
+
+        norm_expr = self._normalize_feature(geom_expr, name="Expression")
+        norm_change = self._normalize_feature(abs_log2fc, name="FoldChange")
+        norm_usage = self._normalize_feature(usage_delta, name="UsageDelta")
+        # Functional impact does not need normalization (0/1), but we convert it to a series
+        func_series = pd.Series(func_flags, index=common_genes, dtype=float)
+
+        # 5. Derive adaptive thresholds
+        expr_gate = norm_expr.quantile(0.5)       # median
+        change_gate = norm_change.quantile(0.75)  # upper quartile
+        usage_gate = norm_usage.quantile(0.75)
+        logger.info(
+            f"Adaptive gates - expression: {expr_gate:.3f}, fold change: {change_gate:.3f}, usage delta: {usage_gate:.3f}"
+        )
+
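# ---- Editor's sketch (not part of the patch) -------------------------------
# The gates above are quantiles of the normalized feature distributions, so
# the thresholds adapt to each dataset instead of being fixed cutoffs. Toy
# numbers only:
import pandas as pd

norm_change_demo = pd.Series([0.05, 0.10, 0.20, 0.40, 0.90])
print(norm_change_demo.quantile(0.50))  # 0.2 -> median gate
print(norm_change_demo.quantile(0.75))  # 0.4 -> upper-quartile gate
# -----------------------------------------------------------------------------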
+        # 6. Compute composite scores
+        scores = pd.Series(0.0, index=common_genes)
+        categories = {}
+        for gene in common_genes:
+            expr_val = norm_expr.at[gene]
+            change_val = norm_change.at[gene]
+            usage_val = norm_usage.at[gene]
+            is_switch = switch_flags.get(gene, False)
+            func_val = func_series.at[gene]
+            iso_count = isoform_counts.get(gene, 0)
+
+            # Handle single-transcript genes differently
+            if iso_count == 1:
+                # Single-transcript genes: focus on expression change, require higher thresholds
+                single_transcript_expr_gate = norm_expr.quantile(0.7)     # Higher than multi-isoform
+                single_transcript_change_gate = norm_change.quantile(0.85)  # Much higher fold change required
+
+                passes_expr = expr_val > single_transcript_expr_gate
+                passes_change = change_val > single_transcript_change_gate
+
+                # Single-transcript penalty: they need to work harder to compete
+                single_transcript_penalty = 0.6
+
+                # Score based only on expression and fold change (no isoform switching possible)
+                base = 0.0
+                if passes_expr and passes_change:
+                    base = expr_val * 0.4 + change_val * 0.6  # Weight fold change more heavily
+                    if func_val > 0:
+                        base += 0.3  # Smaller functional bonus than multi-isoform
+
+                # Apply single-transcript penalty
+                scores.at[gene] = base * single_transcript_penalty
+
+                # Assign category
+                if base == 0:
+                    categories[gene] = "LOW_EXPR"
+                else:
+                    categories[gene] = "SINGLE_TRANSCRIPT_DE"  # New category
+
+            elif iso_count < 2:
+                # Skip genes with 0 transcripts (shouldn't happen but safety check)
+                continue
+            else:
+                # Multi-transcript genes: original logic with isoform switching
+                passes_expr = expr_val > expr_gate
+                passes_change = change_val > change_gate
+                passes_usage = usage_val > usage_gate and is_switch
+
+                # Penalise extreme expression changes (>90th percentile)
+                penalty = 1.0
+                if change_val > norm_change.quantile(0.9):
+                    penalty *= 0.5
+
+                # Complexity penalty: down-weight genes with many isoforms (top 10%)
+                if iso_count > np.quantile(list(isoform_counts.values()), 0.9):
+                    penalty *= 0.7
+
+                # Compute base score; weight usage more heavily when switching
+                base = 0.0
+                if passes_expr and (passes_change or passes_usage):
+                    base = expr_val * 0.3 + change_val * 0.3 + usage_val * 1.2
+                    if func_val > 0:
+                        base += 0.5  # Functional impact bonus
+
+                # Final score after penalty
+                scores.at[gene] = base * penalty
+
+                # Assign category for top genes later
+                if base == 0:
+                    categories[gene] = "LOW_EXPR"
+                elif passes_usage and is_switch:
+                    categories[gene] = "ISOFORM_SWITCHER"
+                elif passes_expr and passes_change:
+                    categories[gene] = "DIFFERENTIAL_EXPRESSION"
+                else:
+                    categories[gene] = "HIGH_EXPRESSION"
+
+        # 7. Sort and select top genes
+        ranked = scores[scores > 0].sort_values(ascending=False)
+        ranked_gene_ids = ranked.head(top_n).index.tolist()
+
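# ---- Editor's sketch (not part of the patch) -------------------------------
# The multi-transcript branch above, restated as a pure function over the
# normalized features. A readability aid under the same weights and penalties;
# the in-loop version in the patch is authoritative.
def multi_isoform_score(expr, change, usage, is_switch, has_func_impact,
                        expr_gate, change_gate, usage_gate,
                        extreme_change_gate, many_isoforms):
    passes_expr = expr > expr_gate
    passes_change = change > change_gate
    passes_usage = usage > usage_gate and is_switch

    penalty = 1.0
    if change > extreme_change_gate:  # fold change beyond the 90th percentile
        penalty *= 0.5
    if many_isoforms:                 # isoform count in the top 10%
        penalty *= 0.7

    base = 0.0
    if passes_expr and (passes_change or passes_usage):
        base = expr * 0.3 + change * 0.3 + usage * 1.2  # usage weighted most
        if has_func_impact:
            base += 0.5
    return base * penalty
# -----------------------------------------------------------------------------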
+        # 8. Map gene IDs to names using updated_gene_dict
+        ranked_gene_names: List[str] = []
+        for gene_id in ranked_gene_ids:
+            gene_name = None
+            for cond in self.updated_gene_dict.values():
+                if gene_id in cond:
+                    gene_info = cond[gene_id]
+                    gene_name = gene_info.get("name")
+                    break
+            ranked_gene_names.append(gene_name.upper() if gene_name else gene_id)
+
+        # Log top entries with categories
+        for gene_id in ranked_gene_ids[:10]:
+            cat = categories.get(gene_id, "UNKNOWN")
+            score = scores.at[gene_id]
+            isoform_count = isoform_counts.get(gene_id, 0)
+            logger.info(f"Top gene {gene_id}: score={score:.3f}, category={cat}, isoforms={isoform_count}")
+
+        # Log single-transcript gene statistics
+        self._log_single_transcript_statistics(categories, scores, isoform_counts)
+
+        # Add lncRNA-specific statistics and biotype distribution
+        self._log_biotype_distribution()
+        self._log_lncrna_statistics(ranked_gene_ids, categories, scores)
+
+        return ranked_gene_names

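# ---- Editor's sketch (not part of the patch) -------------------------------
# Hypothetical call site. The class name is inferred from the log message
# above, and `output_dir` from the constructor body; the surrounding script
# is assumed.
ranker = EnhancedGeneRanker(
    output_dir="out",
    ref_conditions=["control"],
    target_conditions=["treated"],
    updated_gene_dict=updated_gene_dict_example,  # structure sketched earlier
)
top_genes = ranker.rank(top_n=50)  # uppercase gene names, best first
# -----------------------------------------------------------------------------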
     # ------------------------------------------------------------------
     # Internal helpers
     # ------------------------------------------------------------------
-    # Cached root-level matrices to avoid re-reading
-    _root_gene_tpm: pd.DataFrame | None = None
-    _root_transcript_tpm: pd.DataFrame | None = None
-
-    def _aggregate_gene_tpm(self, conditions: List[str]) -> pd.Series:
-        """Return mean TPM vector across all samples of the given conditions.
-
-        Works with two possible layouts:
-        1. YAML / multi-condition run - counts live under <output_dir>/<condition>/.
-        2. Single-sample run - a single *gene_grouped_tpm.tsv in <output_dir>/.
-        """
-        logger.info(f"SIMPLE_RANKER: Aggregating gene TPM for conditions: {conditions}")
-        tpm_values: list[pd.Series] = []
-
-        for cond in conditions:
-            cond_dir = self.output_dir / cond
-            files = list(cond_dir.glob("*gene_grouped_tpm.tsv"))
-            logger.info(f"SIMPLE_RANKER: Condition '{cond}' - found {len(files)} gene TPM files in {cond_dir}")
-
-            if files:
-                for fp in files:
-                    logger.info(f"SIMPLE_RANKER: Reading gene TPM from: {fp}")
-                    df = self._read_tpm(fp)
-                    logger.info(f"SIMPLE_RANKER: Gene TPM file shape: {df.shape}, columns: {list(df.columns)[:5]}...")
-                    logger.info(f"SIMPLE_RANKER: Sample gene IDs: {list(df.index)[:5]}...")
-                    tpm_series = df.sum(axis=1)
-                    logger.info(f"SIMPLE_RANKER: Summed TPM series shape: {tpm_series.shape}, sample values: {tpm_series.head()}")
-                    tpm_values.append(tpm_series)
-                continue  # next condition
-
-            # Fallback: root-level file present
-            logger.info(f"SIMPLE_RANKER: No condition-specific files found for '{cond}', trying root-level gene TPM...")
-            root_df = self._get_root_gene_tpm()
-            logger.info(f"SIMPLE_RANKER: Root gene TPM shape: {root_df.shape}, columns: {list(root_df.columns)}")
-            cond_cols = [c for c in root_df.columns if c == cond or c.startswith(f"{cond}__")]
-            logger.info(f"SIMPLE_RANKER: Found {len(cond_cols)} columns for condition '{cond}': {cond_cols}")
-            if not cond_cols:
-                logger.warning("Condition '%s' columns not found in root gene_grouped_tpm.tsv; treating as missing.", cond)
-                continue
-            cond_tpm = root_df[cond_cols].mean(axis=1)
-            logger.info(f"SIMPLE_RANKER: Condition TPM series shape: {cond_tpm.shape}, sample values: {cond_tpm.head()}")
-            tpm_values.append(cond_tpm)
-
-        if not tpm_values:
-            logger.error("SIMPLE_RANKER: No TPM values found for provided conditions.")
-            raise FileNotFoundError("No TPM values found for provided conditions.")
-
-        logger.info(f"SIMPLE_RANKER: Collected {len(tpm_values)} TPM series for conditions {conditions}")
-        stacked = pd.concat(tpm_values, axis=1)
-        logger.info(f"SIMPLE_RANKER: Stacked TPM shape: {stacked.shape}")
-        aggregated = stacked.mean(axis=1)
-        logger.info(f"SIMPLE_RANKER: Final aggregated TPM shape: {aggregated.shape}, sample values: {aggregated.head()}")
-        logger.info(f"SIMPLE_RANKER: TPM stats - min: {aggregated.min():.3f}, max: {aggregated.max():.3f}, mean: {aggregated.mean():.3f}")
-        return aggregated
+    def _scan_transcript_id_patterns(self) -> None:
+        """Scan updated_gene_dict for transcript ID patterns and report findings."""
+        transcript_generic = 0  # Count of IDs starting with "transcript"
+        transcript_ensembl = 0  # Count of proper Ensembl IDs (ENSMUST, ENST, etc.)
+        transcript_other = 0    # Count of other patterns
+
+        generic_examples = []
+        ensembl_examples = []
+        other_examples = []
+
+        # Scan all conditions and genes
+        for condition, genes in self.updated_gene_dict.items():
+            for gene_id, gene_info in genes.items():
+                transcripts_dict = gene_info.get("transcripts", {})
+                for tx_id in transcripts_dict.keys():
+                    if tx_id.lower().startswith("transcript"):
+                        transcript_generic += 1
+                        if len(generic_examples) < 5:
+                            generic_examples.append(f"{gene_id}:{tx_id}")
+                    elif tx_id.startswith("ENSMUST") or tx_id.startswith("ENST") or tx_id.startswith("ENS"):
+                        transcript_ensembl += 1
+                        if len(ensembl_examples) < 5:
+                            ensembl_examples.append(f"{gene_id}:{tx_id}")
+                    else:
+                        transcript_other += 1
+                        if len(other_examples) < 5:
+                            other_examples.append(f"{gene_id}:{tx_id}")
+
+        total_transcripts = transcript_generic + transcript_ensembl + transcript_other
+
+        logger.info(f"TRANSCRIPT_SCAN: Found {total_transcripts} total transcript entries")
+
+        if total_transcripts > 0:
+            logger.info(f"TRANSCRIPT_SCAN: {transcript_generic} transcripts start with 'transcript' ({transcript_generic/total_transcripts*100:.1f}%)")
+            logger.info(f"TRANSCRIPT_SCAN: {transcript_ensembl} transcripts are Ensembl IDs ({transcript_ensembl/total_transcripts*100:.1f}%)")
+            logger.info(f"TRANSCRIPT_SCAN: {transcript_other} transcripts have other patterns ({transcript_other/total_transcripts*100:.1f}%)")
+        else:
+            logger.warning("TRANSCRIPT_SCAN: No transcripts found in updated_gene_dict")
+
+        if generic_examples:
+            logger.info(f"TRANSCRIPT_SCAN: Generic transcript examples: {generic_examples}")
+        if ensembl_examples:
+            logger.info(f"TRANSCRIPT_SCAN: Ensembl transcript examples: {ensembl_examples}")
+        if other_examples:
+            logger.info(f"TRANSCRIPT_SCAN: Other transcript examples: {other_examples}")
+
+        # Warn if too many generic transcripts
+        if transcript_generic > transcript_ensembl:
+            logger.warning(f"TRANSCRIPT_SCAN: More generic 'transcript' IDs ({transcript_generic}) than Ensembl IDs ({transcript_ensembl}) - this may indicate annotation issues")
+
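# ---- Editor's sketch (not part of the patch) -------------------------------
# The same three-way classification as a standalone helper. The bare "ENS"
# prefix test is deliberately broad (ENST, ENSMUST, ENSDART, ...); generic
# "transcript..." IDs typically come from de novo transcript models rather
# than reference annotation.
def classify_transcript_id(tx_id: str) -> str:
    if tx_id.lower().startswith("transcript"):
        return "generic"
    if tx_id.startswith("ENS"):
        return "ensembl"
    return "other"

assert classify_transcript_id("ENSMUST00000000001") == "ensembl"
assert classify_transcript_id("transcript_42.chr1") == "generic"
assert classify_transcript_id("NM_001001130") == "other"
# -----------------------------------------------------------------------------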
-    def _get_root_gene_tpm(self) -> pd.DataFrame:
-        if self._root_gene_tpm is not None:
-            return self._root_gene_tpm
-        files = list(self.output_dir.glob("*gene_grouped_tpm.tsv"))
-        if not files:
-            raise FileNotFoundError("Root-level gene_grouped_tpm.tsv not found in output directory.")
-        df = self._read_tpm(files[0])
-        self._root_gene_tpm = df
-        return df
-
-    def _read_tpm(self, fp: Path) -> pd.DataFrame:
-        df = pd.read_csv(fp, sep="\t")
-        first_col = df.columns[0]
-        if first_col.startswith("#"):
-            df.rename(columns={first_col: "feature_id"}, inplace=True)
-        return df.set_index("feature_id")
-
+    def _normalize_feature(self, feature: pd.Series, name: str) -> pd.Series:
+        """Normalize a numeric Series to the 0-1 range. If all values are equal,
+        return zeros.
+
+        Parameters
+        ----------
+        feature : pd.Series
+            The numeric data to normalize.
+        name : str
+            Name of the feature for logging.
+
+        Returns
+        -------
+        pd.Series
+            Normalized series with the same index.
+        """
+        if feature.empty:
+            return pd.Series(dtype=float)
+        fmin, fmax = feature.min(), feature.max()
+        if fmax == fmin:
+            logger.warning(f"{name} values are identical; returning zeros")
+            return pd.Series(0.0, index=feature.index)
+        norm = (feature - fmin) / (fmax - fmin)
+        return norm.fillna(0.0)
+
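# ---- Editor's sketch (not part of the patch) -------------------------------
# Min-max scaling as implemented above; a constant series maps to zeros rather
# than NaN, which keeps the downstream quantile gates well-defined.
import pandas as pd

s = pd.Series([2.0, 4.0, 10.0], index=["g1", "g2", "g3"])
print((s - s.min()) / (s.max() - s.min()))  # g1: 0.00, g2: 0.25, g3: 1.00
# -----------------------------------------------------------------------------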
+ """ + ref_gene_tpms: Dict[str, float] = {} + tgt_gene_tpms: Dict[str, float] = {} + ref_count = 0 + tgt_count = 0 + + # Process reference conditions + for cond in self.ref_conditions: + if cond not in self.updated_gene_dict: continue - # Filter transcripts belonging to this gene. - mask = gene_ids == gene - gene_tx_ref = trans_ref[mask] - gene_tx_tgt = trans_tgt[mask] - - # Gene totals (add 1e-6 to avoid divide-by-zero) - ref_total = gene_tx_ref.sum() + 1e-6 - tgt_total = gene_tx_tgt.sum() + 1e-6 - - ref_usage = gene_tx_ref / ref_total - tgt_usage = gene_tx_tgt / tgt_total - max_delta = (ref_usage - tgt_usage).abs().max() - usage_diff.at[gene] = max_delta - return usage_diff - - def _aggregate_transcript_tpm(self, conditions: List[str]) -> pd.Series: - """Return mean TPM per transcript across all samples in conditions (handles both layouts).""" - tpm_values: list[pd.Series] = [] - pattern_default = "*transcript_grouped_tpm.tsv" - pattern_model = "*transcript_model_grouped_tpm.tsv" - - for cond in conditions: - cond_dir = self.output_dir / cond - pattern = pattern_model if self.ref_only else pattern_default - files = list(cond_dir.glob(pattern)) - if files: - for fp in files: - df = self._read_tpm(fp) - tpm_values.append(df.sum(axis=1)) + ref_count += 1 + genes = self.updated_gene_dict[cond] + for gene_id, gene_info in genes.items(): + tpm = 0.0 + for tx_info in gene_info.get("transcripts", {}).values(): + if isinstance(tx_info, dict) and "value" in tx_info: + tpm += tx_info["value"] + ref_gene_tpms[gene_id] = ref_gene_tpms.get(gene_id, 0.0) + tpm + + # Average across reference conditions + for gene_id in ref_gene_tpms: + if ref_count > 0: + ref_gene_tpms[gene_id] /= ref_count + + # Process target conditions + for cond in self.target_conditions: + if cond not in self.updated_gene_dict: continue - - # Fallback: root-level transcript TPM - root_df = self._get_root_transcript_tpm() - cond_cols = [c for c in root_df.columns if c == cond or c.startswith(f"{cond}__")] - if not cond_cols: + tgt_count += 1 + genes = self.updated_gene_dict[cond] + for gene_id, gene_info in genes.items(): + tpm = 0.0 + for tx_info in gene_info.get("transcripts", {}).values(): + if isinstance(tx_info, dict) and "value" in tx_info: + tpm += tx_info["value"] + tgt_gene_tpms[gene_id] = tgt_gene_tpms.get(gene_id, 0.0) + tpm + + # Average across target conditions + for gene_id in tgt_gene_tpms: + if tgt_count > 0: + tgt_gene_tpms[gene_id] /= tgt_count + + ref_series = pd.Series(ref_gene_tpms, name="ref_tpm") + tgt_series = pd.Series(tgt_gene_tpms, name="tgt_tpm") + return ref_series, tgt_series + + def _compute_isoform_usage_metrics(self, genes: List[str]) -> Tuple[pd.Series, Dict[str, bool], Dict[str, bool], Dict[str, int]]: + """Compute isoform usage delta, switching flag, functional impact flag and isoform count. + + Parameters + ---------- + genes : list of str + Gene identifiers to process. + + Returns + ------- + usage_delta : pd.Series + Maximum absolute change in isoform usage per gene. + switch_flags : dict + Dictionary mapping gene IDs to True if at least one transcript + increases and another decreases in usage between reference and target. + func_flags : dict + Dictionary mapping gene IDs to True if the isoform switching implies + a change in coding status or functional consequence. + isoform_counts : dict + Dictionary mapping gene IDs to the number of isoforms detected. 
+ """ + usage_delta = pd.Series(0.0, index=genes) + switch_flags: Dict[str, bool] = {} + func_flags: Dict[str, bool] = {} + isoform_counts: Dict[str, int] = {} + + # Build per‑condition transcript TPM dictionaries + ref_tx = self._aggregate_transcript_tpms(self.ref_conditions) + tgt_tx = self._aggregate_transcript_tpms(self.target_conditions) + + # For each gene, compute usage change + for gene in genes: + # Collect transcripts and counts + tx_ids = set() + for cond in self.ref_conditions: + cond_dict = self.updated_gene_dict.get(cond, {}) + if gene in cond_dict: + tx_ids.update(cond_dict[gene].get("transcripts", {}).keys()) + for cond in self.target_conditions: + cond_dict = self.updated_gene_dict.get(cond, {}) + if gene in cond_dict: + tx_ids.update(cond_dict[gene].get("transcripts", {}).keys()) + isoform_counts[gene] = len(tx_ids) + if len(tx_ids) < 2: + switch_flags[gene] = False + func_flags[gene] = False + usage_delta.at[gene] = 0.0 continue - tpm_values.append(root_df[cond_cols].mean(axis=1)) - if not tpm_values: - return pd.Series(dtype=float) - stacked = pd.concat(tpm_values, axis=1) - return stacked.mean(axis=1) - - def _get_root_transcript_tpm(self) -> pd.DataFrame: - if self._root_transcript_tpm is not None: - return self._root_transcript_tpm - pattern = "*transcript_model_grouped_tpm.tsv" if self.ref_only else "*transcript_grouped_tpm.tsv" - files = list(self.output_dir.glob(pattern)) - if not files: - return pd.DataFrame() - df = self._read_tpm(files[0]) - self._root_transcript_tpm = df - return df + # Compute usage per condition + ref_total = 0.0 + tgt_total = 0.0 + ref_usages: Dict[str, float] = {} + tgt_usages: Dict[str, float] = {} + for tx_id in tx_ids: + r_tpm = ref_tx.get(tx_id, 0.0) + t_tpm = tgt_tx.get(tx_id, 0.0) + ref_total += r_tpm + tgt_total += t_tpm + ref_usages[tx_id] = r_tpm + tgt_usages[tx_id] = t_tpm + ref_total += 1e-6 # avoid zero division + tgt_total += 1e-6 + + # Compute usage fractions + deltas = [] + directions = [] + func_change = False + for tx_id in tx_ids: + ref_u = ref_usages[tx_id] / ref_total + tgt_u = tgt_usages[tx_id] / tgt_total + delta = tgt_u - ref_u + deltas.append(abs(delta)) + directions.append(np.sign(delta)) + + # Assess functional impact if annotation exists + # We check across any condition; assume annotation consistent + for cond in self.updated_gene_dict: + cond_dict = self.updated_gene_dict[cond] + if gene in cond_dict: + tx_info = cond_dict[gene].get("transcripts", {}).get(tx_id, {}) + # Compare coding status and ORF length relative to other transcripts + coding = tx_info.get("coding") + orf_len = tx_info.get("orf_length") + func = tx_info.get("functional_consequence") + break + # Simple heuristic: if any transcript has non‑zero functional_consequence + if func is not None: + func_change = True + # Determine switching: at least one positive and one negative change + switch_flags[gene] = (1 in directions) and (-1 in directions) + func_flags[gene] = func_change + usage_delta.at[gene] = max(deltas) if deltas else 0.0 + + return usage_delta, switch_flags, func_flags, isoform_counts + + def _aggregate_transcript_tpms(self, conditions: List[str]) -> Dict[str, float]: + """Aggregate transcript TPMs across a list of conditions, averaging across + conditions. + + Parameters + ---------- + conditions : list of str + Conditions to aggregate. + + Returns + ------- + Dict[str, float] + Mapping from transcript ID to averaged TPM value. 
+ """ + tx_totals: Dict[str, float] = {} + count = 0 + for cond in conditions: + if cond not in self.updated_gene_dict: + continue + count += 1 + cond_dict = self.updated_gene_dict[cond] + for gene_info in cond_dict.values(): + for tx_id, tx_info in gene_info.get("transcripts", {}).items(): + if isinstance(tx_info, dict) and "value" in tx_info: + tx_totals[tx_id] = tx_totals.get(tx_id, 0.0) + tx_info["value"] + # Average + if count > 0: + for tx in tx_totals: + tx_totals[tx] /= count + return tx_totals + + def _log_single_transcript_statistics(self, categories: Dict[str, str], scores: pd.Series, isoform_counts: Dict[str, int]) -> None: + """Log statistics about single-transcript genes and how they performed.""" + single_transcript_genes = [gene_id for gene_id, count in isoform_counts.items() if count == 1] + multi_transcript_genes = [gene_id for gene_id, count in isoform_counts.items() if count > 1] + + single_transcript_scored = [gene_id for gene_id in single_transcript_genes if scores.get(gene_id, 0) > 0] + single_transcript_in_categories = [gene_id for gene_id in single_transcript_genes if categories.get(gene_id) == "SINGLE_TRANSCRIPT_DE"] + + # Get top single-transcript genes by score + single_transcript_scores = {gene_id: scores.get(gene_id, 0) for gene_id in single_transcript_genes} + top_single_transcript = sorted(single_transcript_scores.items(), key=lambda x: x[1], reverse=True)[:5] + + # Get expression info for top single-transcript genes + top_single_examples = [] + for gene_id, score in top_single_transcript[:3]: + if score > 0: + # Find gene name and expression values + gene_name = gene_id + ref_tpm = 0 + tgt_tpm = 0 + + for condition, genes in self.updated_gene_dict.items(): + if gene_id in genes: + gene_info = genes[gene_id] + gene_name = gene_info.get("name", gene_id) + + # Get the single transcript's TPM + transcripts = gene_info.get("transcripts", {}) + if transcripts: + tx_id, tx_info = next(iter(transcripts.items())) + tpm = tx_info.get("value", 0) + + if condition in self.ref_conditions: + ref_tpm += tpm / len(self.ref_conditions) + elif condition in self.target_conditions: + tgt_tpm += tpm / len(self.target_conditions) + + fold_change = (tgt_tpm + 0.1) / (ref_tpm + 0.1) # Add pseudocount + top_single_examples.append({ + "gene_id": gene_id, + "gene_name": gene_name, + "score": score, + "ref_tpm": ref_tpm, + "tgt_tpm": tgt_tpm, + "fold_change": fold_change + }) + + logger.info("=== SINGLE-TRANSCRIPT GENE ANALYSIS ===") + logger.info(f"Total single-transcript genes: {len(single_transcript_genes)}") + logger.info(f"Total multi-transcript genes: {len(multi_transcript_genes)}") + logger.info(f"Single-transcript genes with scores > 0: {len(single_transcript_scored)}") + logger.info(f"Single-transcript genes passing high thresholds: {len(single_transcript_in_categories)}") + + if len(single_transcript_genes) > 0: + pass_rate = (len(single_transcript_in_categories) / len(single_transcript_genes)) * 100 + logger.info(f"Single-transcript gene pass rate: {pass_rate:.1f}% (requires 85th percentile fold change)") + + if top_single_examples: + logger.info("Top single-transcript genes by score:") + for i, example in enumerate(top_single_examples, 1): + logger.info(f" {i}. 
+    def _log_single_transcript_statistics(self, categories: Dict[str, str], scores: pd.Series, isoform_counts: Dict[str, int]) -> None:
+        """Log statistics about single-transcript genes and how they performed."""
+        single_transcript_genes = [gene_id for gene_id, count in isoform_counts.items() if count == 1]
+        multi_transcript_genes = [gene_id for gene_id, count in isoform_counts.items() if count > 1]
+
+        single_transcript_scored = [gene_id for gene_id in single_transcript_genes if scores.get(gene_id, 0) > 0]
+        single_transcript_in_categories = [gene_id for gene_id in single_transcript_genes if categories.get(gene_id) == "SINGLE_TRANSCRIPT_DE"]
+
+        # Get top single-transcript genes by score
+        single_transcript_scores = {gene_id: scores.get(gene_id, 0) for gene_id in single_transcript_genes}
+        top_single_transcript = sorted(single_transcript_scores.items(), key=lambda x: x[1], reverse=True)[:5]
+
+        # Get expression info for top single-transcript genes
+        top_single_examples = []
+        for gene_id, score in top_single_transcript[:3]:
+            if score > 0:
+                # Find gene name and expression values
+                gene_name = gene_id
+                ref_tpm = 0
+                tgt_tpm = 0
+
+                for condition, genes in self.updated_gene_dict.items():
+                    if gene_id in genes:
+                        gene_info = genes[gene_id]
+                        gene_name = gene_info.get("name", gene_id)
+
+                        # Get the single transcript's TPM
+                        transcripts = gene_info.get("transcripts", {})
+                        if transcripts:
+                            tx_id, tx_info = next(iter(transcripts.items()))
+                            tpm = tx_info.get("value", 0)
+
+                            if condition in self.ref_conditions:
+                                ref_tpm += tpm / len(self.ref_conditions)
+                            elif condition in self.target_conditions:
+                                tgt_tpm += tpm / len(self.target_conditions)
+
+                fold_change = (tgt_tpm + 0.1) / (ref_tpm + 0.1)  # Add pseudocount
+                top_single_examples.append({
+                    "gene_id": gene_id,
+                    "gene_name": gene_name,
+                    "score": score,
+                    "ref_tpm": ref_tpm,
+                    "tgt_tpm": tgt_tpm,
+                    "fold_change": fold_change
+                })
+
+        logger.info("=== SINGLE-TRANSCRIPT GENE ANALYSIS ===")
+        logger.info(f"Total single-transcript genes: {len(single_transcript_genes)}")
+        logger.info(f"Total multi-transcript genes: {len(multi_transcript_genes)}")
+        logger.info(f"Single-transcript genes with scores > 0: {len(single_transcript_scored)}")
+        logger.info(f"Single-transcript genes passing high thresholds: {len(single_transcript_in_categories)}")
+
+        if len(single_transcript_genes) > 0:
+            pass_rate = (len(single_transcript_in_categories) / len(single_transcript_genes)) * 100
+            logger.info(f"Single-transcript gene pass rate: {pass_rate:.1f}% (requires 85th percentile fold change)")
+
+        if top_single_examples:
+            logger.info("Top single-transcript genes by score:")
+            for i, example in enumerate(top_single_examples, 1):
+                logger.info(f"  {i}. {example['gene_name']} ({example['gene_id']}): "
+                            f"score={example['score']:.3f}, "
+                            f"ref_TPM={example['ref_tpm']:.1f}, "
+                            f"tgt_TPM={example['tgt_tpm']:.1f}, "
+                            f"FC={example['fold_change']:.2f}x")
+
+        # Compare to multi-transcript genes
+        multi_transcript_scored = [gene_id for gene_id in multi_transcript_genes if scores.get(gene_id, 0) > 0]
+        if len(multi_transcript_genes) > 0:
+            multi_pass_rate = (len(multi_transcript_scored) / len(multi_transcript_genes)) * 100
+            logger.info(f"Multi-transcript gene pass rate: {multi_pass_rate:.1f}% (for comparison)")
+
+    def _log_biotype_distribution(self) -> None:
+        """Log the distribution of gene biotypes in the dataset."""
+        biotype_counts = {}
+        total_genes = 0
+
+        # Count biotypes across all genes (avoiding duplicates by using first condition)
+        for condition, genes in self.updated_gene_dict.items():
+            for gene_id, gene_info in genes.items():
+                biotype = gene_info.get("biotype", "unknown")
+                biotype_counts[biotype] = biotype_counts.get(biotype, 0) + 1
+                total_genes += 1
+            break  # Only count from one condition to avoid duplicates
+
+        if total_genes == 0:
+            logger.warning("No genes found for biotype distribution analysis")
+            return
+
+        # Sort biotypes by count
+        sorted_biotypes = sorted(biotype_counts.items(), key=lambda x: x[1], reverse=True)
+
+        logger.info("=== GENE BIOTYPE DISTRIBUTION ===")
+        logger.info(f"Total genes analyzed: {total_genes}")
+
+        for biotype, count in sorted_biotypes:
+            percentage = (count / total_genes) * 100
+            logger.info(f"{biotype}: {count} genes ({percentage:.1f}%)")
+
+        # Highlight key biotypes
+        lncrna_count = biotype_counts.get("lncRNA", 0) + biotype_counts.get("long_noncoding_rna", 0)
+        protein_coding_count = biotype_counts.get("protein_coding", 0)
+        pseudogene_counts = sum(count for biotype, count in biotype_counts.items()
+                                if "pseudogene" in biotype.lower())
+
+        logger.info("=== KEY BIOTYPE SUMMARY ===")
+        logger.info(f"Protein coding genes: {protein_coding_count} ({(protein_coding_count/total_genes)*100:.1f}%)")
+        logger.info(f"lncRNA genes: {lncrna_count} ({(lncrna_count/total_genes)*100:.1f}%)")
+        logger.info(f"Pseudogenes (all types): {pseudogene_counts} ({(pseudogene_counts/total_genes)*100:.1f}%)")
+
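# ---- Editor's sketch (not part of the patch) -------------------------------
# The biotype tally above counts each gene once by breaking after the first
# condition; collections.Counter expresses the same idea compactly (using the
# example dict sketched earlier).
from collections import Counter

first_condition_genes = next(iter(updated_gene_dict_example.values()))
biotype_counter = Counter(
    info.get("biotype", "unknown") for info in first_condition_genes.values()
)
print(biotype_counter.most_common())  # [('protein_coding', 1)]
# -----------------------------------------------------------------------------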
+    def _log_lncrna_statistics(self, ranked_gene_ids: List[str], categories: Dict[str, str], scores: pd.Series) -> None:
+        """Analyze and log statistics about lncRNAs in the ranked gene list."""
+        lncrna_stats = {
+            "total_lncrnas": 0,
+            "top_50_lncrnas": 0,
+            "top_10_lncrnas": 0,
+            "isoform_switcher_lncrnas": 0,
+            "high_expr_lncrnas": 0,
+            "lncrna_examples": []
+        }
+
+        all_lncrnas = []
+
+        # Scan all genes to find lncRNAs
+        for condition, genes in self.updated_gene_dict.items():
+            for gene_id, gene_info in genes.items():
+                gene_biotype = gene_info.get("biotype", "").lower()
+                if gene_biotype in ["lncrna", "long_noncoding_rna", "lincrna"]:
+                    lncrna_stats["total_lncrnas"] += 1
+                    gene_name = gene_info.get("name", gene_id)
+                    score = scores.get(gene_id, 0.0)
+                    category = categories.get(gene_id, "LOW_EXPR")
+
+                    all_lncrnas.append({
+                        "gene_id": gene_id,
+                        "gene_name": gene_name,
+                        "score": score,
+                        "category": category,
+                        "biotype": gene_biotype
+                    })
+
+                    # Count categories
+                    if category == "ISOFORM_SWITCHER":
+                        lncrna_stats["isoform_switcher_lncrnas"] += 1
+                    elif category == "HIGH_EXPRESSION":
+                        lncrna_stats["high_expr_lncrnas"] += 1
+
+                    # Check if in top rankings
+                    if gene_id in ranked_gene_ids[:50]:
+                        lncrna_stats["top_50_lncrnas"] += 1
+                    if gene_id in ranked_gene_ids[:10]:
+                        lncrna_stats["top_10_lncrnas"] += 1
+            break  # Only check one condition to avoid duplicates
+
+        # Sort lncRNAs by score and get top examples
+        all_lncrnas.sort(key=lambda x: x["score"], reverse=True)
+        lncrna_stats["lncrna_examples"] = all_lncrnas[:5]
+
+        # Log comprehensive lncRNA statistics
+        logger.info("=== lncRNA ANALYSIS ===")
+        logger.info(f"Total lncRNAs detected: {lncrna_stats['total_lncrnas']}")
+        logger.info(f"lncRNAs in top 50 genes: {lncrna_stats['top_50_lncrnas']}")
+        logger.info(f"lncRNAs in top 10 genes: {lncrna_stats['top_10_lncrnas']}")
+        logger.info(f"lncRNA isoform switchers: {lncrna_stats['isoform_switcher_lncrnas']}")
+        logger.info(f"High-expression lncRNAs: {lncrna_stats['high_expr_lncrnas']}")
+
+        if lncrna_stats["total_lncrnas"] > 0:
+            top_50_pct = (lncrna_stats["top_50_lncrnas"] / lncrna_stats["total_lncrnas"]) * 100
+            logger.info(f"Percentage of lncRNAs in top 50: {top_50_pct:.1f}%")
+
+        # Log top lncRNA examples
+        if lncrna_stats["lncrna_examples"]:
+            logger.info("Top scoring lncRNAs:")
+            for i, lncrna in enumerate(lncrna_stats["lncrna_examples"], 1):
+                logger.info(f"  {i}. {lncrna['gene_name']} ({lncrna['gene_id']}): "
+                            f"score={lncrna['score']:.3f}, category={lncrna['category']}")
+
+        # Analyze transcript complexity for top lncRNAs
+        self._analyze_lncrna_transcript_complexity(lncrna_stats["lncrna_examples"][:3])
+
+    def _analyze_lncrna_transcript_complexity(self, top_lncrnas: List[Dict]) -> None:
+        """Analyze transcript complexity and biotype diversity for top lncRNAs."""
+        if not top_lncrnas:
+            return
+
+        logger.info("=== lncRNA TRANSCRIPT COMPLEXITY ===")
+
+        for lncrna in top_lncrnas:
+            gene_id = lncrna["gene_id"]
+            gene_name = lncrna["gene_name"]
+
+            # Find transcript information across conditions
+            transcript_info = {}
+            for condition, genes in self.updated_gene_dict.items():
+                if gene_id in genes:
+                    transcripts = genes[gene_id].get("transcripts", {})
+                    for tx_id, tx_info in transcripts.items():
+                        tx_biotype = tx_info.get("biotype", "unknown")
+                        tx_value = tx_info.get("value", 0.0)
+                        tx_name = tx_info.get("name", tx_id)
+
+                        if tx_id not in transcript_info:
+                            transcript_info[tx_id] = {
+                                "name": tx_name,
+                                "biotype": tx_biotype,
+                                "values": []
+                            }
+                        transcript_info[tx_id]["values"].append(tx_value)
+
+            # Calculate transcript statistics
+            transcript_count = len(transcript_info)
+            biotype_counts = {}
+            active_transcripts = 0
+
+            for tx_id, tx_data in transcript_info.items():
+                biotype = tx_data["biotype"]
+                biotype_counts[biotype] = biotype_counts.get(biotype, 0) + 1
+
+                avg_value = sum(tx_data["values"]) / len(tx_data["values"]) if tx_data["values"] else 0
+                if avg_value > 1.0:  # Consider active if TPM > 1
+                    active_transcripts += 1
+
+            logger.info(f"lncRNA {gene_name} ({gene_id}):")
+            logger.info(f"  Total transcripts: {transcript_count}")
+            logger.info(f"  Active transcripts (TPM>1): {active_transcripts}")
+            logger.info(f"  Transcript biotypes: {dict(biotype_counts)}")
+
+            # Show top 3 most expressed transcripts
+            if transcript_info:
+                sorted_transcripts = sorted(
+                    transcript_info.items(),
+                    key=lambda x: sum(x[1]["values"]) / len(x[1]["values"]) if x[1]["values"] else 0,
+                    reverse=True
+                )
+                logger.info("  Top transcripts by expression:")
+                for i, (tx_id, tx_data) in enumerate(sorted_transcripts[:3], 1):
+                    avg_expr = sum(tx_data["values"]) / len(tx_data["values"]) if tx_data["values"] else 0
+                    logger.info(f"    {i}. {tx_data['name']}: {avg_expr:.2f} TPM ({tx_data['biotype']})")
diff --git a/visualize.py b/visualize.py
index 69b6bb5f..c14e2d8f 100755
--- a/visualize.py
+++ b/visualize.py
@@ -254,10 +254,10 @@ def main():
         update_names = True

     min_val = args.filter_transcripts if args.filter_transcripts is not None else 1.0
-    logging.info(f"FLOW_DEBUG: Building updated_gene_dict with:")
-    logging.info(f"  min_value: {min_val}")
-    logging.info(f"  reference_conditions: {getattr(args, 'reference_conditions', None)}")
-    logging.info(f"  target_conditions: {getattr(args, 'target_conditions', None)}")
+    logging.debug(f"Building updated_gene_dict with:")
+    logging.debug(f"  min_value: {min_val}")
+    logging.debug(f"  reference_conditions: {getattr(args, 'reference_conditions', None)}")
+    logging.debug(f"  target_conditions: {getattr(args, 'target_conditions', None)}")

     updated_gene_dict = dictionary_builder.build_gene_dict_with_expression_and_filter(
         min_value=min_val,
@@ -265,9 +265,9 @@ def main():
         target_conditions=getattr(args, 'target_conditions', None)
     )

-    logging.info(f"FLOW_DEBUG: updated_gene_dict created:")
-    logging.info(f"  type: {type(updated_gene_dict)}")
-    logging.info(f"  keys (conditions): {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}")
+    logging.debug(f"updated_gene_dict created:")
+    logging.debug(f"  type: {type(updated_gene_dict)}")
+    logging.debug(f"  keys (conditions): {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}")
     if updated_gene_dict:
         for condition, genes in updated_gene_dict.items():
             logging.info(f"  condition '{condition}': {len(genes)} genes")
@@ -275,16 +275,16 @@ def main():
             if sample_genes:
                 for gene_id in sample_genes:
                     gene_info = genes[gene_id]
-                    logging.info(f"    gene '{gene_id}': name='{gene_info.get('name', 'MISSING')}', keys={list(gene_info.keys())}")
+                    logging.debug(f"    gene '{gene_id}': name='{gene_info.get('name', 'MISSING')}', keys={list(gene_info.keys())}")
                     if 'transcripts' in gene_info:
-                        logging.info(f"      transcripts: {len(gene_info['transcripts'])} items")
+                        logging.debug(f"      transcripts: {len(gene_info['transcripts'])} items")
             break  # Only show details for first condition

     # Debug: log whether gene_dict keys are Ensembl IDs or gene names
     if updated_gene_dict:
         sample_condition = next(iter(updated_gene_dict))
         sample_keys = list(updated_gene_dict[sample_condition].keys())[:5]
-        logging.info(
+        logging.debug(
             "Sample gene_dict keys for condition '%s': %s",
             sample_condition, sample_keys
         )
@@ -294,8 +294,23 @@ def main():
         reads_and_class = (
             dictionary_builder.build_read_assignment_and_classification_dictionaries()
         )
+        # New: build read length effects aggregates
+        logging.debug("Building read length effects aggregates.")
+        try:
+            length_effects = dictionary_builder.build_read_length_effects()
+        except Exception as e:
+            logging.error(f"Failed to compute read length effects: {e}")
+            length_effects = None
+        # New: build read length histogram
+        try:
+            length_hist = dictionary_builder.build_read_length_histogram()
+        except Exception as e:
+            logging.error(f"Failed to compute read length histogram: {e}")
+            length_hist = None
     else:
         reads_and_class = None
+        length_effects = None
+        length_hist = None

     # 3. If user wants to find top genes (--find_genes), choose method based on replicate availability
     if args.find_genes is not None:
@@ -400,14 +415,14 @@ def main():
         read_assignments_dir = None  # Set to None if not used

     # 6. Plotting with PlotOutput
-    logging.info(f"FLOW_DEBUG: Creating PlotOutput with:")
-    logging.info(f"  gene_names type: {type(gene_list)}")
-    logging.info(f"  gene_names length: {len(gene_list) if gene_list else 'None'}")
-    logging.info(f"  gene_names content (first 10): {gene_list[:10] if gene_list else 'None'}")
-    logging.info(f"  updated_gene_dict keys: {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}")
-    logging.info(f"  conditions: {output.conditions}")
-    logging.info(f"  filter_transcripts: {min_val}")
-    logging.info(f"  ref_only: {args.ref_only}")
+    logging.debug(f"Creating PlotOutput with:")
+    logging.debug(f"  gene_names type: {type(gene_list)}")
+    logging.debug(f"  gene_names length: {len(gene_list) if gene_list else 'None'}")
+    logging.debug(f"  gene_names content (first 10): {gene_list[:10] if gene_list else 'None'}")
+    logging.debug(f"  updated_gene_dict keys: {list(updated_gene_dict.keys()) if updated_gene_dict else 'None'}")
+    logging.debug(f"  conditions: {output.conditions}")
+    logging.debug(f"  filter_transcripts: {min_val}")
+    logging.debug(f"  ref_only: {args.ref_only}")

     plot_output = PlotOutput(
         updated_gene_dict=updated_gene_dict,
@@ -430,6 +445,18 @@ def main():

     if use_read_assignments:
         plot_output.make_pie_charts()
+    # New: plot read length effects (assignment uniqueness and FSM/ISM/Mono)
+    if length_effects:
+        plot_output.plot_read_length_effects(length_effects)
+        # Also dynamic stacked charts for assignment/classification
+        plot_output.plot_read_length_vs_assignment({
+            'bins': length_effects['bins'],
+            'assignment': {(b, k): v for b in length_effects['bins'] for k, v in length_effects['by_bin_assignment'][b].items()},
+            'classification': {(b, k): v for b in length_effects['bins'] for k, v in length_effects['by_bin_classification'][b].items()},
+        })
+    # New: plot histogram
+    if length_hist:
+        plot_output.plot_read_length_histogram(length_hist)

 if __name__ == "__main__":
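# ---- Editor's sketch (not part of the patch) -------------------------------
# Assumed shape of `length_effects`, implied by the flattening above, with
# hypothetical bins and counts: per-bin dicts keyed by category, which the
# comprehension collapses into a single {(bin, key): count} mapping.
length_effects_demo = {
    "bins": ["0-500", "500-1000"],
    "by_bin_assignment": {
        "0-500": {"unique": 10, "ambiguous": 2},
        "500-1000": {"unique": 40, "ambiguous": 1},
    },
    "by_bin_classification": {
        "0-500": {"FSM": 5, "ISM": 7},
        "500-1000": {"FSM": 30, "ISM": 11},
    },
}
flat_assignment = {
    (b, k): v
    for b in length_effects_demo["bins"]
    for k, v in length_effects_demo["by_bin_assignment"][b].items()
}
# {('0-500', 'unique'): 10, ('0-500', 'ambiguous'): 2, ('500-1000', 'unique'): 40, ...}
# -----------------------------------------------------------------------------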