|
| 1 | +import pandas as pd |
| 2 | +import json |
| 3 | +import gzip |
| 4 | +import os |
| 5 | +import glob |
| 6 | +import re |
| 7 | +import multiprocessing |
| 8 | +import sys |
| 9 | +import datetime |
| 10 | +import string |
| 11 | +import random |
| 12 | +from dateutil.relativedelta import relativedelta |
| 13 | +from functools import partial |
| 14 | +from bson.objectid import ObjectId |
| 15 | + |
| 16 | +from annotations import Annotations |
| 17 | +from expression_writer import ExpressionWriter |
| 18 | +from writer_functions import get_cluster_cells |
| 19 | +from monitor import setup_logger, bypass_mongo_writes |
| 20 | +from mongo_connection import MongoConnection, graceful_auto_reconnect |
| 21 | +import config |
| 22 | + |
| 23 | + |
class DotPlotGenes:
    """
    Build DotPlotGene documents for a single cluster / expression matrix pairing and load them into MongoDB.

    For every gene rendered by ExpressionWriter, and for every qualifying group-based annotation label, a
    [scaled mean expression, fraction of cells expressing] pair is computed. Each gene is written to a gzipped
    JSON file and the rendered files are then inserted into the collection named by COLLECTION_NAME in batches.
    """
    COLLECTION_NAME = "dot_plot_genes"
    # number of documents sent to MongoDB per insert_many call
    BATCH_SIZE = 100
    ALLOWED_FILE_TYPES = ["text/csv", "text/plain", "text/tab-separated-values"]
    EXP_WRITER_SETTINGS = {"output_format": "dict", "sparse": True, "delocalize": False}
    # the reported CPU count is halved on macOS (sys.platform starting with 'darwin')
    denominator = 2 if re.match('darwin', sys.platform) else 1
    # leave one core free, but never drop below one worker: multiprocessing.Pool rejects a process count of 0,
    # which the unclamped arithmetic yields on a 1-core host (or a 2-core macOS host)
    num_cores = max(1, int(multiprocessing.cpu_count() / denominator) - 1)
    dev_logger = setup_logger(__name__, "log.txt", format="support_configs")

    def __init__(
        self,
        study_id,
        study_file_id,  # expression matrix file
        cluster_group_id,
        cluster_file,
        cell_metadata_file,
        matrix_file_path,
        matrix_file_type,
        **kwargs,
    ):
        """
        Store identifiers and file locations, and set up the MongoDB connection and ExpressionWriter

        :param study_id: (str) study identifier, later converted to ObjectId in to_model
        :param study_file_id: (str) identifier of the expression matrix file, later converted to ObjectId
        :param cluster_group_id: (str) cluster identifier, later converted to ObjectId
        :param cluster_file: (str) cluster file, read for both cell names and cluster-scoped annotations
        :param cell_metadata_file: (str) metadata file, read for study-scoped annotations
        :param matrix_file_path: (str) expression matrix handed to ExpressionWriter
        :param matrix_file_type: (str) matrix type; "mtx" requires gene_file and barcode_file in kwargs
        :param kwargs: extra options; gene_file and barcode_file are read when matrix_file_type is "mtx"
        :raises KeyError: if matrix_file_type is "mtx" and gene_file or barcode_file is not supplied
        """
        self.study_id = study_id
        self.study_file_id = study_file_id
        self.cluster_group_id = cluster_group_id
        self.cluster_file = cluster_file
        self.cell_metadata_file = cell_metadata_file
        self.matrix_file_path = matrix_file_path
        self.matrix_file_type = matrix_file_type
        self.kwargs = kwargs

        if matrix_file_type == "mtx":
            self.genes_path = self.kwargs["gene_file"]
            self.barcodes_path = self.kwargs["barcode_file"]
        else:
            self.genes_path = None
            self.barcodes_path = None

        self.mongo_connection = MongoConnection()

        # the cluster name here is not important, it is only used for directory names
        # a random 6-letter slug is appended to the end to avoid directory collisions when running in CI
        random_slug = ''.join(random.sample(string.ascii_letters, 6))
        self.cluster_name = f"cluster_entry_{self.cluster_group_id}_{random_slug}"
        self.output_path = f"{self.cluster_name}/dot_plot_genes"
        self.exp_writer = ExpressionWriter(
            self.matrix_file_path, self.matrix_file_type, self.cluster_file, self.cluster_name, self.genes_path,
            self.barcodes_path
        )

        # annotation_map: {"<name>--group--<scope>": {label: set of cluster cells carrying that label}}
        self.annotation_map = {}
        self.cluster_cells = []

    def set_annotation_map(self):
        """
        Preprocess all associated annotation data to generate the following:
        - list of cluster cells
        - map of all qualifying annotations with the cluster cells found in each annotation label

        Annotations are taken from both the metadata file (scope 'study') and the cluster file (scope 'cluster')
        """
        self.dev_logger.info(f"getting cluster cells from {self.cluster_file}")
        self.cluster_cells = get_cluster_cells(self.cluster_file)
        self.dev_logger.info(f"preprocessing annotation data from {self.cell_metadata_file}")
        cell_metadata = Annotations(self.cell_metadata_file, self.ALLOWED_FILE_TYPES)
        cell_metadata.preprocess(False)
        # each column is a (name, type) pair; only the name is carried forward
        valid_metadata = [
            [column[0], 'study', cell_metadata] for column in cell_metadata.file.columns if
            self.annotation_is_valid(column, cell_metadata)
        ]
        # check for annotations in cluster file
        cluster_annots = Annotations(self.cluster_file, self.ALLOWED_FILE_TYPES)
        cluster_annots.preprocess(False)
        valid_cluster_annots = [
            [column[0], 'cluster', cluster_annots] for column in cluster_annots.file.columns if
            self.annotation_is_valid(column, cluster_annots)
        ]
        all_annotations = valid_metadata + valid_cluster_annots
        for annotation_name, annotation_scope, source_data in all_annotations:
            self.add_annotation_to_map(annotation_name, annotation_scope, source_data)
        self.dev_logger.info(f"Annotation data preprocessing for {self.cell_metadata_file} complete")

    def add_annotation_to_map(self, annotation_name, annotation_scope, source_data):
        """
        Take an individual annotation, filter cells and add it to the annotation_map dictionary
        :param annotation_name: (str) name of annotation
        :param annotation_scope: (str) scope of annotation, either study or cluster
        :param source_data: (Annotations) preprocessed annotation object; its .file dataframe is read here
        """
        annotation_id = f"{annotation_name}--group--{annotation_scope}"
        # cell names (the NAME/TYPE column) grouped by the label each cell carries for this annotation
        groups = dict(source_data.file['NAME']['TYPE'].groupby(source_data.file[annotation_name]['group']).unique())
        self.dev_logger.info(f"reading {annotation_name}, found {len(groups)} labels")
        # build the cluster cell set once rather than once per label
        cluster_cell_set = set(self.cluster_cells)
        self.annotation_map[annotation_id] = {
            label: cluster_cell_set.intersection(label_cells) for label, label_cells in groups.items()
        }

    def render_gene_expression(self):
        """
        Render gene-level expression for all cells in the given cluster, filtering for non-zero expression
        """
        self.dev_logger.info(f"Rendering cluster-filtered gene expression from {self.matrix_file_path}")
        self.exp_writer.render_artifacts(**self.EXP_WRITER_SETTINGS)
        self.dev_logger.info(f"Gene expression rendering for {self.matrix_file_path} complete")

    def preprocess(self):
        """
        Preprocess all data in preparation of creating DotPlotGene entries
        """
        self.set_annotation_map()
        self.render_gene_expression()
        self.dev_logger.info("All data preprocessing complete")

    @staticmethod
    def annotation_is_valid(column, source_data):
        """
        Determine if a given column in an annotation file is valid
        must be group-based and have at least 2 and fewer than 200 unique values
        :param column: (tuple) (column name, annotation type) pair
        :param source_data: (Annotations) preprocessed annotation object; its .file dataframe is read here
        :return: (bool)
        """
        viz_range = range(2, 200, 1)
        column_name, annotation_type = column
        if annotation_type != 'group':
            return False
        return len(source_data.file[column_name][annotation_type].unique()) in viz_range

    @staticmethod
    def process_gene(gene_file, output_path, dot_plot_gene, annotation_map):
        """
        Read gene-level document, compute both mean and percent cells expressing for all applicable annotations,
        and write the result to <output_path>/<gene name>.json.gz
        :param gene_file: (str) path to gene-level gzipped JSON file
        :param output_path: (str) path to write output files to
        :param dot_plot_gene: (dict) empty DotPlotGene with IDs already populated; NOTE: updated in place
        :param annotation_map: (dict) class-level map of all annotations/labels and cells in each
        :return: None, the processed DotPlotGene is only written to disk
        """
        gene_name = DotPlotGenes.get_gene_name(gene_file)
        dot_plot_gene["gene_symbol"] = gene_name
        dot_plot_gene["searchable_gene"] = gene_name.lower()
        gene_dict = DotPlotGenes.get_gene_dict(gene_file)
        dot_plot_gene['exp_scores'] = DotPlotGenes.get_expression_metrics(gene_dict, annotation_map)
        with gzip.open(f"{output_path}/{gene_name}.json.gz", "wt") as file:
            json.dump(dot_plot_gene, file, separators=(',', ':'))

    @staticmethod
    def get_expression_metrics(gene_doc, annotation_map):
        """
        Set the mean expression and percent cells expressing for all available annotations/labels
        :param gene_doc: (dict) gene-level expression dict
        :param annotation_map: (dict) class-level map of all annotations/labels and cells in each
        :return: (dict) {annotation: {label: [scaled mean expression, fraction of cells expressing]}}
        """
        expression_metrics = {}
        for annotation, labels in annotation_map.items():
            label_metrics = {}
            for label, label_cells in labels.items():
                filtered_expression = DotPlotGenes.filter_expression_for_label(gene_doc, label_cells)
                pct_exp = DotPlotGenes.pct_expression(filtered_expression, label_cells)
                scaled_mean = DotPlotGenes.scaled_mean_expression(filtered_expression, pct_exp)
                label_metrics[label] = [scaled_mean, pct_exp]
            expression_metrics[annotation] = label_metrics
        return expression_metrics

    @staticmethod
    def filter_expression_for_label(gene_doc, filter_cells):
        """
        Filter gene expression for cells present in a given label
        :param gene_doc: (dict) gene-level expression dict
        :param filter_cells: (set or list) cells to filter on; annotation_map supplies sets
        :return: (dict) original gene doc filtered by cells from annotation label
        """
        return {cell: exp for cell, exp in gene_doc.items() if cell in filter_cells}

    @staticmethod
    def get_gene_name(gene_file_path):
        """
        Extract gene symbol from filepath
        :param gene_file_path: (str) path to gene JSON file, at any directory depth
        :return: (str) file name with the trailing .json.gz removed
        """
        # use the final path component so the result does not depend on how deep the file is nested
        return re.sub(r'\.json\.gz$', '', os.path.basename(gene_file_path))

    @staticmethod
    def get_gene_dict(gene_path):
        """
        Read a gene document and process as a dict
        :param gene_path: (str) path to gzipped gene doc
        :return: (dict)
        """
        # this runs once per gene, so the handle must be closed explicitly rather than left to the GC
        with gzip.open(gene_path, 'rt') as gene_file:
            return json.load(gene_file)

    @staticmethod
    def to_model(gene_dict):
        """
        Convert a raw dict into a document that can be inserted into MongoDB
        :param gene_dict: (dict) raw processed dot plot gene entry
        :return: (dict) shallow copy of gene_dict with the ID fields converted to ObjectId
        """
        model_dict = gene_dict.copy()
        for id_field in ('study_id', 'study_file_id', 'cluster_group_id'):
            model_dict[id_field] = ObjectId(gene_dict[id_field])
        return model_dict

    @staticmethod
    def scaled_mean_expression(gene_doc, pct_exp):
        """
        Get the scaled mean expression of cells for a given gene
        :param gene_doc: (dict) gene-level significant expression values
        :param pct_exp: (float) fraction of cells expressing for gene to scale mean by
        :return: (float) mean of the expression values multiplied by pct_exp, rounded to 3 places; 0.0 if no values
        """
        exp_values = pd.DataFrame(gene_doc.values())
        if exp_values.empty:
            return 0.0
        # the frame has a single column labelled 0, so [0] selects the mean of the expression values
        raw_exp = exp_values.mean()[0]
        return round(raw_exp * pct_exp, 3)

    @staticmethod
    def pct_expression(gene_doc, cells):
        """
        Get the fraction of cells expressing for a given gene relative to the cells in the annotation label
        :param gene_doc: (dict) gene-level significant expression values, already filtered to the label
        :param cells: (set or list) cells for given annotation label
        :return: (float) value rounded to 4 places; 0.0 when the label has no cells
        """
        if len(cells) == 0:
            return 0.0
        return round(len(gene_doc) / len(cells), 4)

    def process_all_genes(self):
        """
        Parallel function to process all files and render out DotPlotGene dicts

        Every *.json.gz file found directly under the cluster_name directory is processed in a worker pool and
        written to output_path
        """
        # tolerate a pre-existing output directory (or a missing parent) instead of raising
        os.makedirs(self.output_path, exist_ok=True)
        gene_files = glob.glob(f"{self.cluster_name}/*.json.gz")
        blank_dot_plot_gene = {
            "study_id": self.study_id,
            "study_file_id": self.study_file_id,
            "cluster_group_id": self.cluster_group_id,
            "exp_scores": {}
        }
        self.dev_logger.info(f"beginning parallel rendering of {len(gene_files)} DotPlotGene entries")
        processor = partial(
            DotPlotGenes.process_gene,
            dot_plot_gene=blank_dot_plot_gene,
            output_path=self.output_path,
            annotation_map=self.annotation_map
        )
        # map() blocks until every gene is done, so leaving the context afterwards only releases the workers
        with multiprocessing.Pool(self.num_cores) as pool:
            pool.map(processor, gene_files)

    @graceful_auto_reconnect
    def load(self, documents):
        """
        Insert batch of documents into MongoDB, or only log the batch size when bypass_mongo_writes() is truthy
        :param documents: (list) list of rendered documents to insert
        """
        if not bypass_mongo_writes():
            self.mongo_connection._client[self.COLLECTION_NAME].insert_many(documents, ordered=False)
        else:
            dev_msg = f"Extracted {len(documents)} DotPlotGenes for {self.matrix_file_path}"
            self.dev_logger.info(dev_msg)

    def transform(self):
        """
        Main handler to process all data and render/insert DotPlotGenes

        Runs preprocessing, renders every gene to disk, then reads the rendered files back and loads them in
        batches of BATCH_SIZE, logging the total runtime at the end
        """
        start_time = datetime.datetime.now()
        self.dev_logger.info(f"beginning rendering of {self.matrix_file_path} into DotPlotGene entries")
        self.preprocess()
        self.process_all_genes()
        self.dev_logger.info(f"rendering of {self.matrix_file_path} complete, beginning load")
        gene_docs = []
        for gene_path in glob.glob(f"{self.output_path}/*.json.gz"):
            rendered_gene = DotPlotGenes.get_gene_dict(gene_path)
            gene_docs.append(DotPlotGenes.to_model(rendered_gene))
            if len(gene_docs) == self.BATCH_SIZE:
                self.load(gene_docs)
                gene_docs.clear()
        # flush the final partial batch
        if gene_docs:
            self.load(gene_docs)
            gene_docs.clear()
        end_time = datetime.datetime.now()
        time_diff = relativedelta(end_time, start_time)
        self.dev_logger.info(
            f" completed, total runtime: {time_diff.hours}h, {time_diff.minutes}m, {time_diff.seconds}s"
        )
0 commit comments