From b077cce2c9bf39cfb4f43b486701108159d616f8 Mon Sep 17 00:00:00 2001 From: Pablo Rodriguez Mier Date: Tue, 25 Feb 2025 13:07:17 +0100 Subject: [PATCH 1/5] Use scipy for stats in sc_dist, add option for batching --- src/generators/models/sc_dist.py | 82 +++++++++++++++++++------------- 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/src/generators/models/sc_dist.py b/src/generators/models/sc_dist.py index 18dd0cb..7210a96 100644 --- a/src/generators/models/sc_dist.py +++ b/src/generators/models/sc_dist.py @@ -21,12 +21,14 @@ def __init__(self, config: Dict[str, Any]): self.distribution = self.generator_config["distribution"] # Either 'NB' or 'Poisson' self.cell_type_col_name = self.dataset_config["cell_type_col_name"] self.cell_label_col_name = self.dataset_config["cell_label_col_name"] + self.batch_size = self.generator_config.get("batch_size", None) # Optional batch size # Parameters for the data generation self.gene_means = None self.num_samples = None self.X_train_features = None self.cell_type_params = {} + self.max_real_value = None self.initialize_random_seeds() @@ -36,38 +38,44 @@ def initialize_random_seeds(self): def train(self): """Compute gene expression parameters for each cell type from training data.""" X_train_adata = self.load_train_anndata() - - counts = X_train_adata.X.toarray() if isinstance(X_train_adata.X, np.ndarray) else X_train_adata.X.A + counts = X_train_adata.X cell_types = X_train_adata.obs[self.cell_type_col_name].values cell_labels = X_train_adata.obs[self.cell_label_col_name].values self.cell_type_to_label = dict(set(zip(cell_types, cell_labels))) - print("Cell Type to Label Mapping:", self.cell_type_to_label) - # Store the max gene expression value from training - self.max_real_value = counts.max() + # Determine max real expression value without converting sparse data to dense + if sp.issparse(counts): + self.max_real_value = counts.data.max() if counts.data.size > 0 else 0 + else: + self.max_real_value = counts.max() print(f"Max real expression value from training: {self.max_real_value}") - for cell_type in np.unique(cell_types): + unique_cell_types = np.unique(cell_types) + for cell_type in unique_cell_types: print(f"Training on Cell Type: {cell_type}") - cell_type_mask = cell_types == cell_type cell_type_counts = counts[cell_type_mask, :] - means = cell_type_counts.mean(axis=0) + if sp.issparse(cell_type_counts): + # Compute means and variances on sparse matrices: + means = np.array(cell_type_counts.mean(axis=0)).ravel() + # For variance: Var(X)=E[X^2] - (E[X])^2 + sq_means = np.array(cell_type_counts.power(2).mean(axis=0)).ravel() + variances = sq_means - means**2 + else: + means = cell_type_counts.mean(axis=0) + variances = cell_type_counts.var(axis=0) + means = np.clip(means, 1e-6, None) # Avoid zero means if self.distribution == 'NB': - variances = cell_type_counts.var(axis=0) - - # variance >= mean to prevent negative dispersions + # Ensure variance is at least the mean variances = np.maximum(variances, means) - dispersions = (variances - means) / (means ** 2) dispersions = np.clip(dispersions, 1e-3, 10) # Avoid extreme values - # debugging print(f"Dispersion values for {cell_type}: min={dispersions.min()}, max={dispersions.max()}") if np.any(np.isnan(dispersions)): @@ -85,21 +93,20 @@ def train(self): print("Training completed successfully!") - - def generate(self): if self.max_real_value is None: raise ValueError("Training must be completed before generating data!") X_test_adata = self.load_test_anndata() - counts = X_test_adata.X.toarray() if isinstance(X_test_adata.X, np.ndarray) else X_test_adata.X.A - print("Original counts shape:", counts.shape) + counts_shape = X_test_adata.X.shape + print("Original counts shape:", counts_shape) cell_types = X_test_adata.obs[self.cell_type_col_name].values - synthetic_counts = sp.lil_matrix(counts.shape, dtype=np.int64) + synthetic_counts = sp.lil_matrix(counts_shape, dtype=np.int64) synthetic_cell_types = [] - for cell_type in np.unique(cell_types): + unique_cell_types = np.unique(cell_types) + for cell_type in unique_cell_types: print(f"Generating for Cell Type: {cell_type}") if str(cell_type) not in self.cell_type_params: @@ -118,29 +125,38 @@ def generate(self): dispersions = np.clip(dispersions, 1e-3, 10) # Prevent extreme values # Compute Negative Binomial parameters - n_param = np.clip(1 / (dispersions + 1e-6), 1e-2, 10) - p_param = np.clip(means / (means + n_param), 0.01, 0.99) + n_param = np.clip(1 / (dispersions + 1e-6), 1e-2, 10) + p_param = np.clip(means / (means + n_param), 0.01, 0.99) - # Debugging prints print(f"n_param range for {cell_type}: min={n_param.min()}, max={n_param.max()}") print(f"p_param range for {cell_type}: min={p_param.min()}, max={p_param.max()}") expected_variance = means + (means ** 2) / n_param print(f"Expected variance for {cell_type}: min={expected_variance.min()}, max={expected_variance.max()}") - # Generate Negative Binomial samples - generated_data = st.nbinom.rvs(n=n_param, p=p_param, size=(num_cells, means.shape[0])).astype(np.int64) + # Use batch processing if batch_size is specified, otherwise process all cells at once + batch_size = self.batch_size if self.batch_size is not None else num_cells - elif self.distribution == 'Poisson': - generated_data = st.poisson.rvs(means, size=(num_cells, means.shape[0])).astype(np.int64) + for start in range(0, num_cells, batch_size): + end = min(start + batch_size, num_cells) + current_batch_size = end - start - # Limit extreme values to prevent memory explosion - upper_clip = np.percentile(generated_data, 99.5) - generated_data = np.clip(generated_data, 0, min(upper_clip, self.max_real_value * 2)) + if self.distribution == 'NB': + batch_generated = st.nbinom.rvs( + n=n_param, p=p_param, size=(current_batch_size, means.shape[0]) + ).astype(np.int64) + elif self.distribution == 'Poisson': + batch_generated = st.poisson.rvs( + means, size=(current_batch_size, means.shape[0]) + ).astype(np.int64) - # Store generated data - synthetic_counts[cell_indices, :] = generated_data - synthetic_cell_types.extend([cell_type] * num_cells) + # Limit extreme values to prevent memory explosion + upper_clip = np.percentile(batch_generated, 99.5) + batch_generated = np.clip(batch_generated, 0, min(upper_clip, self.max_real_value * 2)) + + indices = cell_indices[start:end] + synthetic_counts[indices, :] = batch_generated + synthetic_cell_types.extend([cell_type] * current_batch_size) synthetic_counts_csr = synthetic_counts.tocsr().astype(np.int64) synthetic_adata = ad.AnnData(X=synthetic_counts_csr) @@ -149,7 +165,5 @@ def generate(self): return synthetic_adata - def load_from_checkpoint(self): pass - From 08dfc847bcc324845af2e001c96cd104dbe07e45 Mon Sep 17 00:00:00 2001 From: Pablo Rodriguez Mier Date: Tue, 25 Feb 2025 13:22:45 +0100 Subject: [PATCH 2/5] Support the use of sparse matrices during evaluation --- src/evaluation/sc_evaluate.py | 132 +++++-------- src/evaluation/utils/sc_metrics.py | 302 +++++++++++++---------------- 2 files changed, 180 insertions(+), 254 deletions(-) diff --git a/src/evaluation/sc_evaluate.py b/src/evaluation/sc_evaluate.py index 2c88ca8..8dc24b2 100644 --- a/src/evaluation/sc_evaluate.py +++ b/src/evaluation/sc_evaluate.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import scanpy as sc +from scipy.sparse import issparse # Import for sparse checks src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(src_dir) @@ -13,9 +14,6 @@ from evaluation.utils.sc_metrics import (filter_low_quality_cells_and_genes, Statistics, VisualizeClassify) - - - def check_dirs(path): if not os.path.exists(path): os.makedirs(path) @@ -32,49 +30,41 @@ def __init__(self, config): self.save_dir = os.path.join(self.home_dir, "data_splits") self.random_seed = config["evaluator_config"]["random_seed"] - ## experiment name self.experiment_name = self.config['generator_config']['experiment_name'] self.generator_name = self.config['generator_config']['name'] - self.res_figures_dir = os.path.join(self.home_dir, - config["dir_list"]["figures"], - self.dataset_name, - self.generator_name, + self.res_figures_dir = os.path.join(self.home_dir, + config["dir_list"]["figures"], + self.dataset_name, + self.generator_name, self.experiment_name ) - self.res_files_dir = os.path.join(self.home_dir, - config["dir_list"]["res_files"], - self.dataset_name, - self.generator_name, + self.res_files_dir = os.path.join(self.home_dir, + config["dir_list"]["res_files"], + self.dataset_name, + self.generator_name, self.experiment_name) - check_dirs( self.res_figures_dir) - check_dirs( self.res_files_dir) + check_dirs(self.res_figures_dir) + check_dirs(self.res_files_dir) - self.synthetic_data_path = os.path.join(self.save_dir, - self.dataset_name, + self.synthetic_data_path = os.path.join(self.save_dir, + self.dataset_name, "synthetic", self.generator_name, - self.experiment_name, - ) + self.experiment_name) self.celltypist_model_path = os.path.join(self.home_dir, self.dataset_config["celltypist_model"]) self.results = {} - @staticmethod def save_split_results(results, output_file): df = pd.DataFrame([results]) df.to_csv(output_file, index=False) - def load_test_anndata(self): try: test_data_pth = os.path.join(self.home_dir, self.dataset_config["test_count_file"]) test_data = sc.read_h5ad(test_data_pth) - #cell_types = test_data.obs[self.cell_type_col].values - #cell_labels = test_data.obs[self.cell_label_col].values - - #self.cell_type_to_label = dict(set(zip(cell_types, cell_labels))) test_data.obs[self.cell_label_col] = ( test_data.obs[self.cell_label_col] @@ -82,43 +72,45 @@ def load_test_anndata(self): .str.replace(" ", "_", regex=True) ) - X_dense = test_data.X.toarray() if hasattr(test_data.X, "toarray") else test_data.X - # Count NaN and Inf values - nan_count = np.isnan(X_dense).sum() - inf_count = np.isinf(X_dense).sum() + # Instead of converting to dense, check for NaN and Inf directly + X = test_data.X + if issparse(X): + nan_count = np.isnan(X.data).sum() + inf_count = np.isinf(X.data).sum() + else: + nan_count = np.isnan(X).sum() + inf_count = np.isinf(X).sum() if nan_count > 0 or inf_count > 0: raise ValueError(f"Test data contains {nan_count} NaN values and {inf_count} Inf values.") print(test_data) return test_data - except: - raise Exception(f"Failed to load test anndata.") - + except Exception as e: + raise Exception(f"Failed to load test anndata: {e}") + def load_synthetic_anndata(self): - try: + try: syn_data_pth = os.path.join(self.synthetic_data_path, "onek1k_annotated_synthetic.h5ad") syn_data = sc.read_h5ad(syn_data_pth) - #syn_data.obs[self.cell_label_col] = ( - # syn_data.obs[self.cell_label_col] - # .astype(str) - # .str.replace(" ", "_", regex=True) - #) - X_dense = syn_data.X.toarray() if hasattr(syn_data.X, "toarray") else syn_data.X - # Count NaN and Inf values - nan_count = np.isnan(X_dense).sum() - inf_count = np.isinf(X_dense).sum() + # Check for NaN and Inf values without converting to dense + X = syn_data.X + if issparse(X): + nan_count = np.isnan(X.data).sum() + inf_count = np.isinf(X.data).sum() + else: + nan_count = np.isnan(X).sum() + inf_count = np.isinf(X).sum() if nan_count > 0 or inf_count > 0: - raise ValueError(f"Test data contains {nan_count} NaN values and {inf_count} Inf values.") + raise ValueError(f"Synthetic data contains {nan_count} NaN values and {inf_count} Inf values.") print(syn_data) return syn_data - except: - raise Exception(f"Failed to load synthetic anndata.") - + except Exception as e: + raise Exception(f"Failed to load synthetic anndata: {e}") def initialize_datasets(self): test_anndata = self.load_test_anndata() @@ -137,7 +129,6 @@ def initialize_datasets(self): print(f"After gene alignment - Real: {real_data.n_vars}, Synthetic: {synthetic_data.n_vars}") return real_data, synthetic_data - def get_statistical_evals(self): real_data, synthetic_data = self.initialize_datasets() @@ -154,110 +145,81 @@ def get_statistical_evals(self): 'ari_real_vs_syn': ari_real_syn, 'ari_gt_vs_comb': ari_gt_comb } - - def get_umap_evals(self, n_hvgs:int): + def get_umap_evals(self, n_hvgs: int): real_data, synthetic_data = self.initialize_datasets() visual = VisualizeClassify(self.res_figures_dir, self.random_seed) visual.plot_umap(real_data, synthetic_data, n_hvgs) - - def get_classification_evals(self): real_data, synthetic_data = self.initialize_datasets() classfier = VisualizeClassify(self.res_figures_dir, self.random_seed) - ari_score, jaccard = classfier.celltypist_classification(real_data, - synthetic_data, + ari_score, jaccard = classfier.celltypist_classification(real_data, + synthetic_data, self.celltypist_model_path) roc_score, _ = classfier.random_forest_eval(real_data, synthetic_data) return { - "celltypist_ari": ari_score, "celltypist_jaccard": jaccard, "randomforest_roc": roc_score, } - - @staticmethod def save_results_to_csv(results, output_file): df = pd.DataFrame([results]) df.to_csv(output_file, index=False) - @click.group() def cli(): pass - - @click.command() def run_statistical_eval(): with open("config.yaml", 'r') as file: config = yaml.safe_load(file) - + evaluator = SingleCellEvaluator(config=config) results = evaluator.get_statistical_evals() - + output_file = os.path.join(evaluator.res_files_dir, f"statistics_evals.csv") - ## evaluator.save_results_to_csv(results, output_file) click.echo(f"Evaluation for classification is completed. Results saved to {output_file}") - - @click.command() -#@click.argument("sc_figures_dir", type=str) @click.argument("n_hvgs", type=int, default=2000) def run_umap_eval(n_hvgs): with open("config.yaml", 'r') as file: config = yaml.safe_load(file) - + evaluator = SingleCellEvaluator(config=config) evaluator.get_umap_evals(n_hvgs) - - @click.command() @click.argument("cell_label", type=str, default="CD4 ET") -def run_qq_eval(cell_label:str): +def run_qq_eval(cell_label: str): with open("config.yaml", 'r') as file: config = yaml.safe_load(file) - + evaluator = SingleCellEvaluator(config=config) evaluator.save_qq_evals(cell_label=cell_label) - - - - -### function runs for an individual split -### results are saved under -### results/files/{dataset_name}/{model_name}/{experiment_name} @click.command() def run_classification_eval(): with open("config.yaml", 'r') as file: config = yaml.safe_load(file) - + evaluator = SingleCellEvaluator(config=config) results = evaluator.get_classification_evals() - + output_file = os.path.join(evaluator.res_files_dir, f"classification_evals.csv") - ## evaluator.save_results_to_csv(results, output_file) click.echo(f"Evaluation for classification is completed. Results saved to {output_file}") - - - cli.add_command(run_classification_eval) cli.add_command(run_umap_eval) cli.add_command(run_statistical_eval) cli.add_command(run_qq_eval) - if __name__ == '__main__': cli() - - diff --git a/src/evaluation/utils/sc_metrics.py b/src/evaluation/utils/sc_metrics.py index 131fe7a..c886383 100644 --- a/src/evaluation/utils/sc_metrics.py +++ b/src/evaluation/utils/sc_metrics.py @@ -1,182 +1,190 @@ import umap -import os +import os import numpy as np import scanpy as sc import scipy.stats as stats -import scipy.sparse +import scipy.sparse import matplotlib.pyplot as plt import seaborn as sns from scipy.spatial.distance import cdist from scipy.stats import spearmanr from sklearn.model_selection import train_test_split -from sklearn.decomposition import PCA +from sklearn.decomposition import PCA, TruncatedSVD from sklearn.ensemble import RandomForestClassifier from sklearn.preprocessing import LabelBinarizer from sklearn.metrics.pairwise import rbf_kernel -from sklearn.metrics import (adjusted_rand_score, roc_auc_score, - jaccard_score) +from sklearn.metrics import adjusted_rand_score, roc_auc_score, jaccard_score from scib.metrics import ilisi_graph import celltypist -#from celltypist import models -#from celltypist.models import Model from scipy.sparse import issparse - - - def filter_low_quality_cells_and_genes(adata, min_counts=10, min_cells=3): - adata = adata.copy() + """ + Filters cells and genes based on minimum counts. + Uses Scanpy’s built-in filtering functions (which are sparse-aware). + """ + adata = adata.copy() sc.pp.filter_cells(adata, min_counts=min_counts) sc.pp.filter_genes(adata, min_cells=min_cells) - return adata - -def is_sparse(adata): - if scipy.sparse.issparse(adata.X): - return adata.X.toarray() # Convert sparse to dense - return adata.X # Already dense - -def to_dense(adata): - return adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X - +def get_dense_column(adata, i): + """ + Returns the i-th column of adata.X as a dense vector. + This avoids converting the entire matrix to dense at once. + """ + X = adata.X + if issparse(X): + return X[:, i].toarray().ravel() + else: + return np.array(X[:, i]).ravel() def check_for_inf_nan(adata, label): - X = adata.X.toarray() if issparse(adata.X) else np.array(adata.X) + """ + Checks for NaN/Inf values in adata.X without converting the whole matrix. + """ + X = adata.X + if issparse(X): + data = X.data + else: + data = np.array(X) print(f"Checking {label} dataset:") - print(f"NaNs? {np.isnan(X).any()}") - print(f"Infs? {np.isinf(X).any()}") - print(f"Min: {np.min(X)}, Max: {np.max(X)}\n") - - + print(f"NaNs? {np.isnan(data).any()}") + print(f"Infs? {np.isinf(data).any()}") + print(f"Min: {data.min()}, Max: {data.max()}\n") def check_missing_genes(real_data, synthetic_data): - # Convert to sets for easy comparison + """ + Compares gene names between real and synthetic datasets. + """ real_genes = set(real_data.var_names) synthetic_genes = set(synthetic_data.var_names) - - # Find missing genes missing_in_real = synthetic_genes - real_genes missing_in_synthetic = real_genes - synthetic_genes print(f"Genes in synthetic but not in real: {len(missing_in_real)}") print(f"Genes in real but not in synthetic: {len(missing_in_synthetic)}") - - # Print some missing genes print(f"Example missing in real: {list(missing_in_real)[:10]}") print(f"Example missing in synthetic: {list(missing_in_synthetic)[:10]}") - print(f"real_data.var_names dtype: {real_data.var_names.dtype}") print(f"synthetic_data.var_names dtype: {synthetic_data.var_names.dtype}") - class Statistics: def __init__(self, random_seed=42): self.random_seed = random_seed np.random.seed(self.random_seed) def compute_scc(self, real_data, synthetic_data, n_hvgs=5000): + """ + Computes the mean Spearman correlation across highly variable genes (HVGs) + between the real and synthetic datasets. Instead of converting the whole + expression matrix to dense, each gene column is converted on the fly. + """ np.random.seed(self.random_seed) - check_missing_genes(real_data, synthetic_data ) - real_data = real_data[:, synthetic_data.var_names] - synthetic_data = synthetic_data[:, real_data.var_names] + check_missing_genes(real_data, synthetic_data) + # Align genes using the gene names from synthetic_data + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] check_for_inf_nan(real_data, "Real") check_for_inf_nan(synthetic_data, "Synthetic") - # Normalize both datasets + # Normalize and log-transform both datasets sc.pp.normalize_total(real_data, target_sum=1e4) sc.pp.log1p(real_data) - sc.pp.normalize_total(synthetic_data, target_sum=1e4) sc.pp.log1p(synthetic_data) check_for_inf_nan(real_data, "Real") check_for_inf_nan(synthetic_data, "Synthetic") - # Identify HVGs + # Identify HVGs using the combined dataset combined_adata = real_data.concatenate(synthetic_data) sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) # Subset to HVGs - real_hvg = real_data[:, combined_adata.var["highly_variable"]] - synth_hvg = synthetic_data[:, combined_adata.var["highly_variable"]] - - # Convert to dense format - real_exp = is_sparse(real_hvg) - synth_exp = is_sparse(synth_hvg) - - # Compute Spearman correlation per gene (HVGs only) - scc_values = np.array([ - stats.spearmanr(real_exp[:, i], synth_exp[:, i], nan_policy='omit')[0] - for i in range(real_exp.shape[1]) - ]) - - # Handle NaNs + hvgs = combined_adata.var["highly_variable"] + real_hvg = real_data[:, hvgs] + synth_hvg = synthetic_data[:, hvgs] + + # Compute Spearman correlation gene-by-gene + scc_values = [] + for i in range(real_hvg.n_vars): + real_vec = get_dense_column(real_hvg, i) + synth_vec = get_dense_column(synth_hvg, i) + corr, _ = stats.spearmanr(real_vec, synth_vec, nan_policy='omit') + scc_values.append(corr) + scc_values = np.array(scc_values) return np.nanmean(scc_values) if not np.all(np.isnan(scc_values)) else np.nan - def compute_mmd_optimized(self, real_data, synthetic_data, sample_size=20000, n_pca=50, gamma=1.0, n_hvgs=5000): - # Ensure both datasets have the same gene order + """ + Computes the Maximum Mean Discrepancy (MMD) between the two datasets. + Uses a sparse-aware subsampling and PCA approach. If the HVG data remains sparse, + TruncatedSVD is used instead of the standard PCA. + """ np.random.seed(self.random_seed) - real_data, synthetic_data = real_data[:, synthetic_data.var_names], synthetic_data[:, real_data.var_names] + # Align genes using synthetic_data's ordering + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] - # Identify HVGs combined_adata = real_data.concatenate(synthetic_data) - sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) - # Subset to HVGs - real_hvg = real_data[:, combined_adata.var["highly_variable"]] - synth_hvg = synthetic_data[:, combined_adata.var["highly_variable"]] - - # Convert sparse to dense - real_dense = real_hvg.X.toarray() if scipy.sparse.issparse(real_hvg.X) else real_hvg.X - synth_dense = synth_hvg.X.toarray() if scipy.sparse.issparse(synth_hvg.X) else synth_hvg.X - - # Subsample - real_idx = np.random.choice(real_dense.shape[0], - min(sample_size, real_dense.shape[0]), - replace=False) - synth_idx = np.random.choice(synth_dense.shape[0], - min(sample_size, synth_dense.shape[0]), - replace=False) - - real_sample = real_dense[real_idx] - synth_sample = synth_dense[synth_idx] - - # Combine for PCA - combined_sample = np.vstack([real_sample, synth_sample]) - pca = PCA(n_components=n_pca, random_state=self.random_seed) - combined_pca = pca.fit_transform(combined_sample) - - # Split PCA results - real_pca = combined_pca[: len(real_sample)] - synth_pca = combined_pca[len(real_sample) :] - - # Compute MMD + hvgs = combined_adata.var["highly_variable"] + real_hvg = real_data[:, hvgs] + synth_hvg = synthetic_data[:, hvgs] + + n_real = real_hvg.n_obs + n_synth = synth_hvg.n_obs + + real_idx = np.random.choice(n_real, min(sample_size, n_real), replace=False) + synth_idx = np.random.choice(n_synth, min(sample_size, n_synth), replace=False) + + # Process sparse or dense data accordingly + if issparse(real_hvg.X): + real_sample = real_hvg.X[real_idx] + synth_sample = synth_hvg.X[synth_idx] + from scipy.sparse import vstack + combined_sample = vstack([real_sample, synth_sample]) + pca_model = TruncatedSVD(n_components=n_pca, random_state=self.random_seed) + combined_pca = pca_model.fit_transform(combined_sample) + else: + real_sample = real_hvg.X[real_idx] + synth_sample = synth_hvg.X[synth_idx] + combined_sample = np.vstack([real_sample, synth_sample]) + pca_model = PCA(n_components=n_pca, random_state=self.random_seed) + combined_pca = pca_model.fit_transform(combined_sample) + + real_pca = combined_pca[:len(real_sample)] + synth_pca = combined_pca[len(real_sample):] + K_xx = rbf_kernel(real_pca, real_pca, gamma=gamma).mean() K_yy = rbf_kernel(synth_pca, synth_pca, gamma=gamma).mean() K_xy = rbf_kernel(real_pca, synth_pca, gamma=gamma).mean() return K_xx + K_yy - 2 * K_xy - - # Goal: Measure the mixing of real and synthetic cells in a shared space. def compute_lisi(self, real_data, synthetic_data, n_hvgs=5000): - # Ensure both datasets have the same genes + """ + Computes the Local Inverse Simpson’s Index (LISI) to measure mixing + of real and synthetic cells in a shared low-dimensional space. + """ np.random.seed(self.random_seed) - real_data = real_data[:, synthetic_data.var_names] - synthetic_data = synthetic_data[:, real_data.var_names] + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] combined_adata = real_data.concatenate( synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] ) - # Assign batch labels (0 = real, 1 = synthetic) + # Create a numeric batch label (0 = real, 1 = synthetic) combined_adata.obs["batch"] = (combined_adata.obs["source"] == "synthetic").astype(int) sc.pp.normalize_total(combined_adata, target_sum=1e4) @@ -184,27 +192,20 @@ def compute_lisi(self, real_data, synthetic_data, n_hvgs=5000): sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) combined_adata = combined_adata[:, combined_adata.var['highly_variable']] - # ---- Downsampling (if enabled) ---- - #if use_downsample: - # sample_size = int(downsample_ratio * combined_adata.shape[0]) # Compute sample size - # sampled_idx = np.random.choice(combined_adata.shape[0], size=sample_size, replace=False) - # combined_adata = combined_adata[sampled_idx, :] - - # Perform PCA - sc.pp.pca(combined_adata, n_comps=100, random_state=self.random_seed) + sc.pp.pca(combined_adata, n_comps=100, random_state=self.random_seed) sc.pp.neighbors(combined_adata, n_neighbors=10, method='umap') - return ilisi_graph(combined_adata, batch_key="batch", type_="knn") - - - # Goal: Measure how well real & synthetic cells cluster into the same types. def compute_ari(self, real_data, synthetic_data, cell_type_col, n_hvgs=5000): - # Ensure both datasets have the same genes + """ + Computes the Adjusted Rand Index (ARI) to measure clustering consistency + between real and synthetic data. Clusters are obtained via Scanpy's Louvain. + """ np.random.seed(self.random_seed) - real_data = real_data[:, synthetic_data.var_names] - synthetic_data = synthetic_data[:, real_data.var_names] + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] combined_adata = real_data.concatenate( synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] ) @@ -214,160 +215,123 @@ def compute_ari(self, real_data, synthetic_data, cell_type_col, n_hvgs=5000): sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) combined_adata = combined_adata[:, combined_adata.var['highly_variable']] - - # Perform PCA - sc.pp.pca(combined_adata, n_comps=100, random_state=self.random_seed) + sc.pp.pca(combined_adata, n_comps=100, random_state=self.random_seed) sc.pp.neighbors(combined_adata, n_neighbors=10, method='umap') sc.tl.louvain(combined_adata) # Convert Louvain clusters to numerical labels combined_adata.obs["louvain"] = combined_adata.obs["louvain"].astype("category").cat.codes - real_clusters = combined_adata.obs.loc[ - combined_adata.obs["source"] == "real", "louvain"].values - synthetic_clusters = combined_adata.obs.loc[ - combined_adata.obs["source"] == "synthetic", "louvain"].values + real_clusters = combined_adata.obs.loc[combined_adata.obs["source"] == "real", "louvain"].values + synthetic_clusters = combined_adata.obs.loc[combined_adata.obs["source"] == "synthetic", "louvain"].values ari_real_vs_syn = adjusted_rand_score(real_clusters, synthetic_clusters) - ari_gt_vs_comb = adjusted_rand_score(combined_adata.obs[cell_type_col], - combined_adata.obs["louvain"]) + ari_gt_vs_comb = adjusted_rand_score(combined_adata.obs[cell_type_col], combined_adata.obs["louvain"]) return ari_real_vs_syn, ari_gt_vs_comb - - - class VisualizeClassify: - ### add figures_dir = def __init__(self, sc_figures_dir, random_seed=42): self.random_seed = random_seed self.sc_figures_dir = sc_figures_dir np.random.seed(self.random_seed) - # self.figures_dir = figures_dir - ## get example name instead of sc_figures dir def plot_umap(self, real_data, synthetic_data, n_hvgs=5000): + """ + Creates and saves a UMAP plot of the combined real and synthetic data. + """ sc.settings.figdir = self.sc_figures_dir np.random.seed(self.random_seed) - # Combine datasets with batch labels check_for_inf_nan(real_data, "Real") check_for_inf_nan(synthetic_data, "Synthetic") combined_adata = real_data.concatenate( synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] ) - sc.pp.normalize_total(combined_adata, target_sum=1e4 ) + sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) combined_adata = combined_adata[:, combined_adata.var['highly_variable']] - # Perform PCA & UMAP sc.pp.pca(combined_adata, random_state=self.random_seed) sc.pp.neighbors(combined_adata) sc.tl.umap(combined_adata, random_state=self.random_seed) - # Plot UMAP - sc.pl.umap(combined_adata, - color=["source"], + sc.pl.umap(combined_adata, + color=["source"], title="UMAP of Real vs Synthetic Data", - save = f"syn_test_PCA_HVG={n_hvgs}.png") - + save=f"syn_test_PCA_HVG={n_hvgs}.png") def celltypist_classification(self, real_data_test, synthetic_data, celltypist_model, n_hvgs=5000): + """ + Uses a CellTypist model to annotate cells from both datasets and then compares + the predicted labels via ARI and Jaccard scores. + """ np.random.seed(self.random_seed) - - # Combine datasets for HVG selection combined_adata = real_data_test.concatenate(synthetic_data) - - # Normalize before selecting HVGs sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) - ### Normalize and logp + # Normalize and log-transform each dataset individually sc.pp.normalize_total(real_data_test, target_sum=1e4) sc.pp.log1p(real_data_test) - - ### Normalize and logp sc.pp.normalize_total(synthetic_data, target_sum=1e4) sc.pp.log1p(synthetic_data) - # Subset both datasets to HVGs real_data_test = real_data_test[:, combined_adata.var['highly_variable']] synthetic_data = synthetic_data[:, combined_adata.var['highly_variable']] - # Load CellTypist model model = celltypist.models.Model.load(celltypist_model) real_predictions = celltypist.annotate(real_data_test, model=model) synthetic_predictions = celltypist.annotate(synthetic_data, model=model) - # Extract predicted labels real_labels = real_predictions.predicted_labels.values.ravel() synthetic_labels = synthetic_predictions.predicted_labels.values.ravel() - # Compute ARI score ari_score = adjusted_rand_score(real_labels, synthetic_labels) - # Compute Jaccard score for multi-class labels lb = LabelBinarizer() real_onehot = lb.fit_transform(real_labels) synthetic_onehot = lb.transform(synthetic_labels) jaccard_scores = [ - jaccard_score(real_onehot[:, i], synthetic_onehot[:, i]) + jaccard_score(real_onehot[:, i], synthetic_onehot[:, i]) for i in range(real_onehot.shape[1]) ] jaccard = np.mean(jaccard_scores) return ari_score, jaccard - - ## whether it can separate synthetic vs real def random_forest_eval(self, real_data, synthetic_data, n_hvgs=5000): + """ + Evaluates how well a Random Forest can separate real vs. synthetic cells. + After batch correction, the expression matrix is converted to dense only once. + """ np.random.seed(self.random_seed) - - # Explicitly label real vs. synthetic real_data.obs["source"] = "real" synthetic_data.obs["source"] = "synthetic" - # Concatenate datasets combined_adata = real_data.concatenate( synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] ) - # Normalize & log transform sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) - - # Select highly variable genes (HVGs) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) combined_adata = combined_adata[:, combined_adata.var['highly_variable']] - # **Batch correction using Combat** - sc.pp.combat(combined_adata, key="source") # Removes batch effect + sc.pp.combat(combined_adata, key="source") - # Convert sparse to dense if needed + # Convert to dense only at this final stage (necessary for RandomForest) X = combined_adata.X.A if hasattr(combined_adata.X, "A") else combined_adata.X - - # Assign labels: 0 = real, 1 = synthetic y = (combined_adata.obs["source"] == "synthetic").astype(int).values - # Train-test split X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=self.random_seed) - # Train Random Forest rf = RandomForestClassifier(n_estimators=1000, max_depth=5, random_state=self.random_seed) rf.fit(X_train, y_train) - # Predict probabilities and compute AUC pred_probs = rf.predict_proba(X_test)[:, 1] auc = roc_auc_score(y_test, pred_probs) return auc, pred_probs - - - - - - - - From 4dad32bafc6221310d8ce61948183045210631f4 Mon Sep 17 00:00:00 2001 From: Pablo Rodriguez Mier Date: Wed, 26 Feb 2025 11:09:14 +0100 Subject: [PATCH 3/5] support experimental eval. with sparse data --- src/evaluation/sc_evaluate.py | 132 +++++---- src/evaluation/sc_evaluate_opt.py | 232 +++++++++++++++ src/evaluation/utils/sc_metrics.py | 302 +++++++++++--------- src/evaluation/utils/sc_metrics_opt.py | 376 +++++++++++++++++++++++++ 4 files changed, 862 insertions(+), 180 deletions(-) create mode 100644 src/evaluation/sc_evaluate_opt.py create mode 100644 src/evaluation/utils/sc_metrics_opt.py diff --git a/src/evaluation/sc_evaluate.py b/src/evaluation/sc_evaluate.py index 8dc24b2..2c88ca8 100644 --- a/src/evaluation/sc_evaluate.py +++ b/src/evaluation/sc_evaluate.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd import scanpy as sc -from scipy.sparse import issparse # Import for sparse checks src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) sys.path.append(src_dir) @@ -14,6 +13,9 @@ from evaluation.utils.sc_metrics import (filter_low_quality_cells_and_genes, Statistics, VisualizeClassify) + + + def check_dirs(path): if not os.path.exists(path): os.makedirs(path) @@ -30,41 +32,49 @@ def __init__(self, config): self.save_dir = os.path.join(self.home_dir, "data_splits") self.random_seed = config["evaluator_config"]["random_seed"] + ## experiment name self.experiment_name = self.config['generator_config']['experiment_name'] self.generator_name = self.config['generator_config']['name'] - self.res_figures_dir = os.path.join(self.home_dir, - config["dir_list"]["figures"], - self.dataset_name, - self.generator_name, + self.res_figures_dir = os.path.join(self.home_dir, + config["dir_list"]["figures"], + self.dataset_name, + self.generator_name, self.experiment_name ) - self.res_files_dir = os.path.join(self.home_dir, - config["dir_list"]["res_files"], - self.dataset_name, - self.generator_name, + self.res_files_dir = os.path.join(self.home_dir, + config["dir_list"]["res_files"], + self.dataset_name, + self.generator_name, self.experiment_name) - check_dirs(self.res_figures_dir) - check_dirs(self.res_files_dir) + check_dirs( self.res_figures_dir) + check_dirs( self.res_files_dir) - self.synthetic_data_path = os.path.join(self.save_dir, - self.dataset_name, + self.synthetic_data_path = os.path.join(self.save_dir, + self.dataset_name, "synthetic", self.generator_name, - self.experiment_name) + self.experiment_name, + ) self.celltypist_model_path = os.path.join(self.home_dir, self.dataset_config["celltypist_model"]) self.results = {} + @staticmethod def save_split_results(results, output_file): df = pd.DataFrame([results]) df.to_csv(output_file, index=False) + def load_test_anndata(self): try: test_data_pth = os.path.join(self.home_dir, self.dataset_config["test_count_file"]) test_data = sc.read_h5ad(test_data_pth) + #cell_types = test_data.obs[self.cell_type_col].values + #cell_labels = test_data.obs[self.cell_label_col].values + + #self.cell_type_to_label = dict(set(zip(cell_types, cell_labels))) test_data.obs[self.cell_label_col] = ( test_data.obs[self.cell_label_col] @@ -72,45 +82,43 @@ def load_test_anndata(self): .str.replace(" ", "_", regex=True) ) - # Instead of converting to dense, check for NaN and Inf directly - X = test_data.X - if issparse(X): - nan_count = np.isnan(X.data).sum() - inf_count = np.isinf(X.data).sum() - else: - nan_count = np.isnan(X).sum() - inf_count = np.isinf(X).sum() + X_dense = test_data.X.toarray() if hasattr(test_data.X, "toarray") else test_data.X + # Count NaN and Inf values + nan_count = np.isnan(X_dense).sum() + inf_count = np.isinf(X_dense).sum() if nan_count > 0 or inf_count > 0: raise ValueError(f"Test data contains {nan_count} NaN values and {inf_count} Inf values.") print(test_data) return test_data - except Exception as e: - raise Exception(f"Failed to load test anndata: {e}") - + except: + raise Exception(f"Failed to load test anndata.") + def load_synthetic_anndata(self): - try: + try: syn_data_pth = os.path.join(self.synthetic_data_path, "onek1k_annotated_synthetic.h5ad") syn_data = sc.read_h5ad(syn_data_pth) + #syn_data.obs[self.cell_label_col] = ( + # syn_data.obs[self.cell_label_col] + # .astype(str) + # .str.replace(" ", "_", regex=True) + #) - # Check for NaN and Inf values without converting to dense - X = syn_data.X - if issparse(X): - nan_count = np.isnan(X.data).sum() - inf_count = np.isinf(X.data).sum() - else: - nan_count = np.isnan(X).sum() - inf_count = np.isinf(X).sum() + X_dense = syn_data.X.toarray() if hasattr(syn_data.X, "toarray") else syn_data.X + # Count NaN and Inf values + nan_count = np.isnan(X_dense).sum() + inf_count = np.isinf(X_dense).sum() if nan_count > 0 or inf_count > 0: - raise ValueError(f"Synthetic data contains {nan_count} NaN values and {inf_count} Inf values.") + raise ValueError(f"Test data contains {nan_count} NaN values and {inf_count} Inf values.") print(syn_data) return syn_data - except Exception as e: - raise Exception(f"Failed to load synthetic anndata: {e}") + except: + raise Exception(f"Failed to load synthetic anndata.") + def initialize_datasets(self): test_anndata = self.load_test_anndata() @@ -129,6 +137,7 @@ def initialize_datasets(self): print(f"After gene alignment - Real: {real_data.n_vars}, Synthetic: {synthetic_data.n_vars}") return real_data, synthetic_data + def get_statistical_evals(self): real_data, synthetic_data = self.initialize_datasets() @@ -145,81 +154,110 @@ def get_statistical_evals(self): 'ari_real_vs_syn': ari_real_syn, 'ari_gt_vs_comb': ari_gt_comb } + - def get_umap_evals(self, n_hvgs: int): + def get_umap_evals(self, n_hvgs:int): real_data, synthetic_data = self.initialize_datasets() visual = VisualizeClassify(self.res_figures_dir, self.random_seed) visual.plot_umap(real_data, synthetic_data, n_hvgs) + + def get_classification_evals(self): real_data, synthetic_data = self.initialize_datasets() classfier = VisualizeClassify(self.res_figures_dir, self.random_seed) - ari_score, jaccard = classfier.celltypist_classification(real_data, - synthetic_data, + ari_score, jaccard = classfier.celltypist_classification(real_data, + synthetic_data, self.celltypist_model_path) roc_score, _ = classfier.random_forest_eval(real_data, synthetic_data) return { + "celltypist_ari": ari_score, "celltypist_jaccard": jaccard, "randomforest_roc": roc_score, } + + @staticmethod def save_results_to_csv(results, output_file): df = pd.DataFrame([results]) df.to_csv(output_file, index=False) + @click.group() def cli(): pass + + @click.command() def run_statistical_eval(): with open("config.yaml", 'r') as file: config = yaml.safe_load(file) - + evaluator = SingleCellEvaluator(config=config) results = evaluator.get_statistical_evals() - + output_file = os.path.join(evaluator.res_files_dir, f"statistics_evals.csv") + ## evaluator.save_results_to_csv(results, output_file) click.echo(f"Evaluation for classification is completed. Results saved to {output_file}") + + @click.command() +#@click.argument("sc_figures_dir", type=str) @click.argument("n_hvgs", type=int, default=2000) def run_umap_eval(n_hvgs): with open("config.yaml", 'r') as file: config = yaml.safe_load(file) - + evaluator = SingleCellEvaluator(config=config) evaluator.get_umap_evals(n_hvgs) + + @click.command() @click.argument("cell_label", type=str, default="CD4 ET") -def run_qq_eval(cell_label: str): +def run_qq_eval(cell_label:str): with open("config.yaml", 'r') as file: config = yaml.safe_load(file) - + evaluator = SingleCellEvaluator(config=config) evaluator.save_qq_evals(cell_label=cell_label) + + + + +### function runs for an individual split +### results are saved under +### results/files/{dataset_name}/{model_name}/{experiment_name} @click.command() def run_classification_eval(): with open("config.yaml", 'r') as file: config = yaml.safe_load(file) - + evaluator = SingleCellEvaluator(config=config) results = evaluator.get_classification_evals() - + output_file = os.path.join(evaluator.res_files_dir, f"classification_evals.csv") + ## evaluator.save_results_to_csv(results, output_file) click.echo(f"Evaluation for classification is completed. Results saved to {output_file}") + + + cli.add_command(run_classification_eval) cli.add_command(run_umap_eval) cli.add_command(run_statistical_eval) cli.add_command(run_qq_eval) + if __name__ == '__main__': cli() + + diff --git a/src/evaluation/sc_evaluate_opt.py b/src/evaluation/sc_evaluate_opt.py new file mode 100644 index 0000000..4a7e86e --- /dev/null +++ b/src/evaluation/sc_evaluate_opt.py @@ -0,0 +1,232 @@ +import os +import click +import yaml +import sys +import fnmatch +import numpy as np +import pandas as pd +import scanpy as sc +from scipy.sparse import issparse # Import for sparse checks + +src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(src_dir) + +from evaluation.utils.sc_metrics_opt import ( + filter_low_quality_cells_and_genes, + Statistics, VisualizeClassify +) + +def check_dirs(path): + if not os.path.exists(path): + os.makedirs(path) + +class SingleCellEvaluator: + def __init__(self, config): + self.config = config + self.home_dir = config["dir_list"]["home"] + self.dataset_config = config["dataset_config"] + self.dataset_name = self.dataset_config["name"] + self.cell_type_col = self.dataset_config["cell_type_col_name"] + self.cell_label_col = self.dataset_config["cell_label_col_name"] + + self.save_dir = os.path.join(self.home_dir, "data_splits") + self.random_seed = config["evaluator_config"]["random_seed"] + + ## experiment name + self.experiment_name = self.config['generator_config']['experiment_name'] + self.generator_name = self.config['generator_config']['name'] + self.res_figures_dir = os.path.join(self.home_dir, + config["dir_list"]["figures"], + self.dataset_name, + self.generator_name, + self.experiment_name + ) + self.res_files_dir = os.path.join(self.home_dir, + config["dir_list"]["res_files"], + self.dataset_name, + self.generator_name, + self.experiment_name) + check_dirs(self.res_figures_dir) + check_dirs(self.res_files_dir) + + self.synthetic_data_path = os.path.join(self.save_dir, + self.dataset_name, + "synthetic", + self.generator_name, + self.experiment_name) + self.celltypist_model_path = os.path.join(self.home_dir, + self.dataset_config["celltypist_model"]) + self.results = {} + + @staticmethod + def save_split_results(results, output_file): + df = pd.DataFrame([results]) + df.to_csv(output_file, index=False) + + def load_test_anndata(self): + try: + test_data_pth = os.path.join(self.home_dir, self.dataset_config["test_count_file"]) + test_data = sc.read_h5ad(test_data_pth) + + test_data.obs[self.cell_label_col] = ( + test_data.obs[self.cell_label_col] + .astype(str) + .str.replace(" ", "_", regex=True) + ) + + # Instead of converting to dense, check for NaN and Inf directly + X = test_data.X + if issparse(X): + nan_count = np.isnan(X.data).sum() + inf_count = np.isinf(X.data).sum() + else: + nan_count = np.isnan(X).sum() + inf_count = np.isinf(X).sum() + + if nan_count > 0 or inf_count > 0: + raise ValueError(f"Test data contains {nan_count} NaN values and {inf_count} Inf values.") + + print(test_data) + return test_data + except Exception as e: + raise Exception(f"Failed to load test anndata: {e}") + + + def load_synthetic_anndata(self): + try: + syn_data_pth = os.path.join(self.synthetic_data_path, "onek1k_annotated_synthetic.h5ad") + syn_data = sc.read_h5ad(syn_data_pth) + + # Check for NaN and Inf values without converting to dense + X = syn_data.X + if issparse(X): + nan_count = np.isnan(X.data).sum() + inf_count = np.isinf(X.data).sum() + else: + nan_count = np.isnan(X).sum() + inf_count = np.isinf(X).sum() + + if nan_count > 0 or inf_count > 0: + raise ValueError(f"Synthetic data contains {nan_count} NaN values and {inf_count} Inf values.") + + print(syn_data) + return syn_data + except Exception as e: + raise Exception(f"Failed to load synthetic anndata: {e}") + + def initialize_datasets(self): + test_anndata = self.load_test_anndata() + synthetic_anndata = self.load_synthetic_anndata() + + print(f"Initial gene count - Real: {test_anndata.n_vars}, Synthetic: {synthetic_anndata.n_vars}") + real_data = filter_low_quality_cells_and_genes(test_anndata) + synthetic_data = filter_low_quality_cells_and_genes(synthetic_anndata) + print(f"After filtering - Real: {real_data.n_vars}, Synthetic: {synthetic_data.n_vars}") + + # make sure both datasets have the same genes after filter + common_genes = real_data.var_names.intersection(synthetic_data.var_names) + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + + print(f"After gene alignment - Real: {real_data.n_vars}, Synthetic: {synthetic_data.n_vars}") + + return real_data, synthetic_data + + def get_statistical_evals(self): + real_data, synthetic_data = self.initialize_datasets() + stats = Statistics(self.random_seed) + print("Computing SCC...") + scc = stats.compute_scc(real_data, synthetic_data) + print("Computing MMD...") + mmd = stats.compute_mmd_optimized(real_data, synthetic_data) + print("Computing LISI...") + lisi = stats.compute_lisi(real_data, synthetic_data) + print("Computing ARI...") + ari_real_syn, ari_gt_comb = stats.compute_ari(real_data, synthetic_data, self.cell_type_col) + print("Done.") + + return { + 'scc': scc, + 'mmd': mmd, + 'lisi': lisi, + 'ari_real_vs_syn': ari_real_syn, + 'ari_gt_vs_comb': ari_gt_comb + } + + def get_umap_evals(self, n_hvgs: int): + real_data, synthetic_data = self.initialize_datasets() + visual = VisualizeClassify(self.res_figures_dir, self.random_seed) + visual.plot_umap(real_data, synthetic_data, n_hvgs) + + def get_classification_evals(self): + real_data, synthetic_data = self.initialize_datasets() + classfier = VisualizeClassify(self.res_figures_dir, self.random_seed) + ari_score, jaccard = classfier.celltypist_classification(real_data, + synthetic_data, + self.celltypist_model_path) + roc_score, _ = classfier.random_forest_eval(real_data, synthetic_data) + + return { + "celltypist_ari": ari_score, + "celltypist_jaccard": jaccard, + "randomforest_roc": roc_score, + } + + @staticmethod + def save_results_to_csv(results, output_file): + df = pd.DataFrame([results]) + df.to_csv(output_file, index=False) + +@click.group() +def cli(): + pass + +@click.command() +def run_statistical_eval(): + with open("config.yaml", 'r') as file: + config = yaml.safe_load(file) + + evaluator = SingleCellEvaluator(config=config) + results = evaluator.get_statistical_evals() + + output_file = os.path.join(evaluator.res_files_dir, f"statistics_evals.csv") + evaluator.save_results_to_csv(results, output_file) + click.echo(f"Evaluation for classification is completed. Results saved to {output_file}") + +@click.command() +@click.argument("n_hvgs", type=int, default=2000) +def run_umap_eval(n_hvgs): + with open("config.yaml", 'r') as file: + config = yaml.safe_load(file) + + evaluator = SingleCellEvaluator(config=config) + evaluator.get_umap_evals(n_hvgs) + +@click.command() +@click.argument("cell_label", type=str, default="CD4 ET") +def run_qq_eval(cell_label: str): + with open("config.yaml", 'r') as file: + config = yaml.safe_load(file) + + evaluator = SingleCellEvaluator(config=config) + evaluator.save_qq_evals(cell_label=cell_label) + +@click.command() +def run_classification_eval(): + with open("config.yaml", 'r') as file: + config = yaml.safe_load(file) + + evaluator = SingleCellEvaluator(config=config) + results = evaluator.get_classification_evals() + + output_file = os.path.join(evaluator.res_files_dir, f"classification_evals.csv") + evaluator.save_results_to_csv(results, output_file) + click.echo(f"Evaluation for classification is completed. Results saved to {output_file}") + +cli.add_command(run_classification_eval) +cli.add_command(run_umap_eval) +cli.add_command(run_statistical_eval) +cli.add_command(run_qq_eval) + +if __name__ == '__main__': + cli() diff --git a/src/evaluation/utils/sc_metrics.py b/src/evaluation/utils/sc_metrics.py index c886383..131fe7a 100644 --- a/src/evaluation/utils/sc_metrics.py +++ b/src/evaluation/utils/sc_metrics.py @@ -1,190 +1,182 @@ import umap -import os +import os import numpy as np import scanpy as sc import scipy.stats as stats -import scipy.sparse +import scipy.sparse import matplotlib.pyplot as plt import seaborn as sns from scipy.spatial.distance import cdist from scipy.stats import spearmanr from sklearn.model_selection import train_test_split -from sklearn.decomposition import PCA, TruncatedSVD +from sklearn.decomposition import PCA from sklearn.ensemble import RandomForestClassifier from sklearn.preprocessing import LabelBinarizer from sklearn.metrics.pairwise import rbf_kernel -from sklearn.metrics import adjusted_rand_score, roc_auc_score, jaccard_score +from sklearn.metrics import (adjusted_rand_score, roc_auc_score, + jaccard_score) from scib.metrics import ilisi_graph import celltypist +#from celltypist import models +#from celltypist.models import Model from scipy.sparse import issparse + + + def filter_low_quality_cells_and_genes(adata, min_counts=10, min_cells=3): - """ - Filters cells and genes based on minimum counts. - Uses Scanpy’s built-in filtering functions (which are sparse-aware). - """ - adata = adata.copy() + adata = adata.copy() sc.pp.filter_cells(adata, min_counts=min_counts) sc.pp.filter_genes(adata, min_cells=min_cells) + return adata -def get_dense_column(adata, i): - """ - Returns the i-th column of adata.X as a dense vector. - This avoids converting the entire matrix to dense at once. - """ - X = adata.X - if issparse(X): - return X[:, i].toarray().ravel() - else: - return np.array(X[:, i]).ravel() + +def is_sparse(adata): + if scipy.sparse.issparse(adata.X): + return adata.X.toarray() # Convert sparse to dense + return adata.X # Already dense + +def to_dense(adata): + return adata.X.toarray() if scipy.sparse.issparse(adata.X) else adata.X + def check_for_inf_nan(adata, label): - """ - Checks for NaN/Inf values in adata.X without converting the whole matrix. - """ - X = adata.X - if issparse(X): - data = X.data - else: - data = np.array(X) + X = adata.X.toarray() if issparse(adata.X) else np.array(adata.X) print(f"Checking {label} dataset:") - print(f"NaNs? {np.isnan(data).any()}") - print(f"Infs? {np.isinf(data).any()}") - print(f"Min: {data.min()}, Max: {data.max()}\n") + print(f"NaNs? {np.isnan(X).any()}") + print(f"Infs? {np.isinf(X).any()}") + print(f"Min: {np.min(X)}, Max: {np.max(X)}\n") + + def check_missing_genes(real_data, synthetic_data): - """ - Compares gene names between real and synthetic datasets. - """ + # Convert to sets for easy comparison real_genes = set(real_data.var_names) synthetic_genes = set(synthetic_data.var_names) + + # Find missing genes missing_in_real = synthetic_genes - real_genes missing_in_synthetic = real_genes - synthetic_genes print(f"Genes in synthetic but not in real: {len(missing_in_real)}") print(f"Genes in real but not in synthetic: {len(missing_in_synthetic)}") + + # Print some missing genes print(f"Example missing in real: {list(missing_in_real)[:10]}") print(f"Example missing in synthetic: {list(missing_in_synthetic)[:10]}") + print(f"real_data.var_names dtype: {real_data.var_names.dtype}") print(f"synthetic_data.var_names dtype: {synthetic_data.var_names.dtype}") + class Statistics: def __init__(self, random_seed=42): self.random_seed = random_seed np.random.seed(self.random_seed) def compute_scc(self, real_data, synthetic_data, n_hvgs=5000): - """ - Computes the mean Spearman correlation across highly variable genes (HVGs) - between the real and synthetic datasets. Instead of converting the whole - expression matrix to dense, each gene column is converted on the fly. - """ np.random.seed(self.random_seed) - check_missing_genes(real_data, synthetic_data) - # Align genes using the gene names from synthetic_data - common_genes = synthetic_data.var_names - real_data = real_data[:, common_genes] - synthetic_data = synthetic_data[:, common_genes] + check_missing_genes(real_data, synthetic_data ) + real_data = real_data[:, synthetic_data.var_names] + synthetic_data = synthetic_data[:, real_data.var_names] check_for_inf_nan(real_data, "Real") check_for_inf_nan(synthetic_data, "Synthetic") - # Normalize and log-transform both datasets + # Normalize both datasets sc.pp.normalize_total(real_data, target_sum=1e4) sc.pp.log1p(real_data) + sc.pp.normalize_total(synthetic_data, target_sum=1e4) sc.pp.log1p(synthetic_data) check_for_inf_nan(real_data, "Real") check_for_inf_nan(synthetic_data, "Synthetic") - # Identify HVGs using the combined dataset + # Identify HVGs combined_adata = real_data.concatenate(synthetic_data) sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) # Subset to HVGs - hvgs = combined_adata.var["highly_variable"] - real_hvg = real_data[:, hvgs] - synth_hvg = synthetic_data[:, hvgs] - - # Compute Spearman correlation gene-by-gene - scc_values = [] - for i in range(real_hvg.n_vars): - real_vec = get_dense_column(real_hvg, i) - synth_vec = get_dense_column(synth_hvg, i) - corr, _ = stats.spearmanr(real_vec, synth_vec, nan_policy='omit') - scc_values.append(corr) - scc_values = np.array(scc_values) + real_hvg = real_data[:, combined_adata.var["highly_variable"]] + synth_hvg = synthetic_data[:, combined_adata.var["highly_variable"]] + + # Convert to dense format + real_exp = is_sparse(real_hvg) + synth_exp = is_sparse(synth_hvg) + + # Compute Spearman correlation per gene (HVGs only) + scc_values = np.array([ + stats.spearmanr(real_exp[:, i], synth_exp[:, i], nan_policy='omit')[0] + for i in range(real_exp.shape[1]) + ]) + + # Handle NaNs return np.nanmean(scc_values) if not np.all(np.isnan(scc_values)) else np.nan + def compute_mmd_optimized(self, real_data, synthetic_data, sample_size=20000, n_pca=50, gamma=1.0, n_hvgs=5000): - """ - Computes the Maximum Mean Discrepancy (MMD) between the two datasets. - Uses a sparse-aware subsampling and PCA approach. If the HVG data remains sparse, - TruncatedSVD is used instead of the standard PCA. - """ + # Ensure both datasets have the same gene order np.random.seed(self.random_seed) - # Align genes using synthetic_data's ordering - common_genes = synthetic_data.var_names - real_data = real_data[:, common_genes] - synthetic_data = synthetic_data[:, common_genes] + real_data, synthetic_data = real_data[:, synthetic_data.var_names], synthetic_data[:, real_data.var_names] + # Identify HVGs combined_adata = real_data.concatenate(synthetic_data) + sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) - hvgs = combined_adata.var["highly_variable"] - real_hvg = real_data[:, hvgs] - synth_hvg = synthetic_data[:, hvgs] - - n_real = real_hvg.n_obs - n_synth = synth_hvg.n_obs - - real_idx = np.random.choice(n_real, min(sample_size, n_real), replace=False) - synth_idx = np.random.choice(n_synth, min(sample_size, n_synth), replace=False) - - # Process sparse or dense data accordingly - if issparse(real_hvg.X): - real_sample = real_hvg.X[real_idx] - synth_sample = synth_hvg.X[synth_idx] - from scipy.sparse import vstack - combined_sample = vstack([real_sample, synth_sample]) - pca_model = TruncatedSVD(n_components=n_pca, random_state=self.random_seed) - combined_pca = pca_model.fit_transform(combined_sample) - else: - real_sample = real_hvg.X[real_idx] - synth_sample = synth_hvg.X[synth_idx] - combined_sample = np.vstack([real_sample, synth_sample]) - pca_model = PCA(n_components=n_pca, random_state=self.random_seed) - combined_pca = pca_model.fit_transform(combined_sample) - - real_pca = combined_pca[:len(real_sample)] - synth_pca = combined_pca[len(real_sample):] - + # Subset to HVGs + real_hvg = real_data[:, combined_adata.var["highly_variable"]] + synth_hvg = synthetic_data[:, combined_adata.var["highly_variable"]] + + # Convert sparse to dense + real_dense = real_hvg.X.toarray() if scipy.sparse.issparse(real_hvg.X) else real_hvg.X + synth_dense = synth_hvg.X.toarray() if scipy.sparse.issparse(synth_hvg.X) else synth_hvg.X + + # Subsample + real_idx = np.random.choice(real_dense.shape[0], + min(sample_size, real_dense.shape[0]), + replace=False) + synth_idx = np.random.choice(synth_dense.shape[0], + min(sample_size, synth_dense.shape[0]), + replace=False) + + real_sample = real_dense[real_idx] + synth_sample = synth_dense[synth_idx] + + # Combine for PCA + combined_sample = np.vstack([real_sample, synth_sample]) + pca = PCA(n_components=n_pca, random_state=self.random_seed) + combined_pca = pca.fit_transform(combined_sample) + + # Split PCA results + real_pca = combined_pca[: len(real_sample)] + synth_pca = combined_pca[len(real_sample) :] + + # Compute MMD K_xx = rbf_kernel(real_pca, real_pca, gamma=gamma).mean() K_yy = rbf_kernel(synth_pca, synth_pca, gamma=gamma).mean() K_xy = rbf_kernel(real_pca, synth_pca, gamma=gamma).mean() return K_xx + K_yy - 2 * K_xy + + # Goal: Measure the mixing of real and synthetic cells in a shared space. def compute_lisi(self, real_data, synthetic_data, n_hvgs=5000): - """ - Computes the Local Inverse Simpson’s Index (LISI) to measure mixing - of real and synthetic cells in a shared low-dimensional space. - """ + # Ensure both datasets have the same genes np.random.seed(self.random_seed) - common_genes = synthetic_data.var_names - real_data = real_data[:, common_genes] - synthetic_data = synthetic_data[:, common_genes] + real_data = real_data[:, synthetic_data.var_names] + synthetic_data = synthetic_data[:, real_data.var_names] combined_adata = real_data.concatenate( synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] ) - # Create a numeric batch label (0 = real, 1 = synthetic) + # Assign batch labels (0 = real, 1 = synthetic) combined_adata.obs["batch"] = (combined_adata.obs["source"] == "synthetic").astype(int) sc.pp.normalize_total(combined_adata, target_sum=1e4) @@ -192,20 +184,27 @@ def compute_lisi(self, real_data, synthetic_data, n_hvgs=5000): sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) combined_adata = combined_adata[:, combined_adata.var['highly_variable']] - sc.pp.pca(combined_adata, n_comps=100, random_state=self.random_seed) + # ---- Downsampling (if enabled) ---- + #if use_downsample: + # sample_size = int(downsample_ratio * combined_adata.shape[0]) # Compute sample size + # sampled_idx = np.random.choice(combined_adata.shape[0], size=sample_size, replace=False) + # combined_adata = combined_adata[sampled_idx, :] + + # Perform PCA + sc.pp.pca(combined_adata, n_comps=100, random_state=self.random_seed) sc.pp.neighbors(combined_adata, n_neighbors=10, method='umap') + return ilisi_graph(combined_adata, batch_key="batch", type_="knn") + + + # Goal: Measure how well real & synthetic cells cluster into the same types. def compute_ari(self, real_data, synthetic_data, cell_type_col, n_hvgs=5000): - """ - Computes the Adjusted Rand Index (ARI) to measure clustering consistency - between real and synthetic data. Clusters are obtained via Scanpy's Louvain. - """ + # Ensure both datasets have the same genes np.random.seed(self.random_seed) - common_genes = synthetic_data.var_names - real_data = real_data[:, common_genes] - synthetic_data = synthetic_data[:, common_genes] + real_data = real_data[:, synthetic_data.var_names] + synthetic_data = synthetic_data[:, real_data.var_names] combined_adata = real_data.concatenate( synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] ) @@ -215,123 +214,160 @@ def compute_ari(self, real_data, synthetic_data, cell_type_col, n_hvgs=5000): sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) combined_adata = combined_adata[:, combined_adata.var['highly_variable']] - sc.pp.pca(combined_adata, n_comps=100, random_state=self.random_seed) + + # Perform PCA + sc.pp.pca(combined_adata, n_comps=100, random_state=self.random_seed) sc.pp.neighbors(combined_adata, n_neighbors=10, method='umap') sc.tl.louvain(combined_adata) # Convert Louvain clusters to numerical labels combined_adata.obs["louvain"] = combined_adata.obs["louvain"].astype("category").cat.codes - real_clusters = combined_adata.obs.loc[combined_adata.obs["source"] == "real", "louvain"].values - synthetic_clusters = combined_adata.obs.loc[combined_adata.obs["source"] == "synthetic", "louvain"].values + real_clusters = combined_adata.obs.loc[ + combined_adata.obs["source"] == "real", "louvain"].values + synthetic_clusters = combined_adata.obs.loc[ + combined_adata.obs["source"] == "synthetic", "louvain"].values ari_real_vs_syn = adjusted_rand_score(real_clusters, synthetic_clusters) - ari_gt_vs_comb = adjusted_rand_score(combined_adata.obs[cell_type_col], combined_adata.obs["louvain"]) + ari_gt_vs_comb = adjusted_rand_score(combined_adata.obs[cell_type_col], + combined_adata.obs["louvain"]) return ari_real_vs_syn, ari_gt_vs_comb + + + class VisualizeClassify: + ### add figures_dir = def __init__(self, sc_figures_dir, random_seed=42): self.random_seed = random_seed self.sc_figures_dir = sc_figures_dir np.random.seed(self.random_seed) + # self.figures_dir = figures_dir + ## get example name instead of sc_figures dir def plot_umap(self, real_data, synthetic_data, n_hvgs=5000): - """ - Creates and saves a UMAP plot of the combined real and synthetic data. - """ sc.settings.figdir = self.sc_figures_dir np.random.seed(self.random_seed) + # Combine datasets with batch labels check_for_inf_nan(real_data, "Real") check_for_inf_nan(synthetic_data, "Synthetic") combined_adata = real_data.concatenate( synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] ) - sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.normalize_total(combined_adata, target_sum=1e4 ) sc.pp.log1p(combined_adata) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + # Perform PCA & UMAP sc.pp.pca(combined_adata, random_state=self.random_seed) sc.pp.neighbors(combined_adata) sc.tl.umap(combined_adata, random_state=self.random_seed) - sc.pl.umap(combined_adata, - color=["source"], + # Plot UMAP + sc.pl.umap(combined_adata, + color=["source"], title="UMAP of Real vs Synthetic Data", - save=f"syn_test_PCA_HVG={n_hvgs}.png") + save = f"syn_test_PCA_HVG={n_hvgs}.png") + def celltypist_classification(self, real_data_test, synthetic_data, celltypist_model, n_hvgs=5000): - """ - Uses a CellTypist model to annotate cells from both datasets and then compares - the predicted labels via ARI and Jaccard scores. - """ np.random.seed(self.random_seed) + + # Combine datasets for HVG selection combined_adata = real_data_test.concatenate(synthetic_data) + + # Normalize before selecting HVGs sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) - # Normalize and log-transform each dataset individually + ### Normalize and logp sc.pp.normalize_total(real_data_test, target_sum=1e4) sc.pp.log1p(real_data_test) + + ### Normalize and logp sc.pp.normalize_total(synthetic_data, target_sum=1e4) sc.pp.log1p(synthetic_data) + # Subset both datasets to HVGs real_data_test = real_data_test[:, combined_adata.var['highly_variable']] synthetic_data = synthetic_data[:, combined_adata.var['highly_variable']] + # Load CellTypist model model = celltypist.models.Model.load(celltypist_model) real_predictions = celltypist.annotate(real_data_test, model=model) synthetic_predictions = celltypist.annotate(synthetic_data, model=model) + # Extract predicted labels real_labels = real_predictions.predicted_labels.values.ravel() synthetic_labels = synthetic_predictions.predicted_labels.values.ravel() + # Compute ARI score ari_score = adjusted_rand_score(real_labels, synthetic_labels) + # Compute Jaccard score for multi-class labels lb = LabelBinarizer() real_onehot = lb.fit_transform(real_labels) synthetic_onehot = lb.transform(synthetic_labels) jaccard_scores = [ - jaccard_score(real_onehot[:, i], synthetic_onehot[:, i]) + jaccard_score(real_onehot[:, i], synthetic_onehot[:, i]) for i in range(real_onehot.shape[1]) ] jaccard = np.mean(jaccard_scores) return ari_score, jaccard + + ## whether it can separate synthetic vs real def random_forest_eval(self, real_data, synthetic_data, n_hvgs=5000): - """ - Evaluates how well a Random Forest can separate real vs. synthetic cells. - After batch correction, the expression matrix is converted to dense only once. - """ np.random.seed(self.random_seed) + + # Explicitly label real vs. synthetic real_data.obs["source"] = "real" synthetic_data.obs["source"] = "synthetic" + # Concatenate datasets combined_adata = real_data.concatenate( synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] ) + # Normalize & log transform sc.pp.normalize_total(combined_adata, target_sum=1e4) sc.pp.log1p(combined_adata) + + # Select highly variable genes (HVGs) sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) combined_adata = combined_adata[:, combined_adata.var['highly_variable']] - sc.pp.combat(combined_adata, key="source") + # **Batch correction using Combat** + sc.pp.combat(combined_adata, key="source") # Removes batch effect - # Convert to dense only at this final stage (necessary for RandomForest) + # Convert sparse to dense if needed X = combined_adata.X.A if hasattr(combined_adata.X, "A") else combined_adata.X + + # Assign labels: 0 = real, 1 = synthetic y = (combined_adata.obs["source"] == "synthetic").astype(int).values + # Train-test split X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=self.random_seed) + # Train Random Forest rf = RandomForestClassifier(n_estimators=1000, max_depth=5, random_state=self.random_seed) rf.fit(X_train, y_train) + # Predict probabilities and compute AUC pred_probs = rf.predict_proba(X_test)[:, 1] auc = roc_auc_score(y_test, pred_probs) return auc, pred_probs + + + + + + + + diff --git a/src/evaluation/utils/sc_metrics_opt.py b/src/evaluation/utils/sc_metrics_opt.py new file mode 100644 index 0000000..a122689 --- /dev/null +++ b/src/evaluation/utils/sc_metrics_opt.py @@ -0,0 +1,376 @@ +import umap +import os +import numpy as np +import scanpy as sc +import scipy.stats as stats +import scipy.sparse +import matplotlib.pyplot as plt +import seaborn as sns +from scipy.spatial.distance import cdist +from scipy.stats import spearmanr +from sklearn.model_selection import train_test_split +from sklearn.decomposition import PCA, TruncatedSVD +from sklearn.ensemble import RandomForestClassifier +from sklearn.preprocessing import LabelBinarizer +from sklearn.metrics.pairwise import rbf_kernel +from sklearn.metrics import adjusted_rand_score, roc_auc_score, jaccard_score +from scib.metrics import ilisi_graph +import celltypist +from scipy.sparse import issparse + +_DEF_N_HVGS = 120 + +def filter_low_quality_cells_and_genes(adata, min_counts=10, min_cells=3): + """ + Filters cells and genes based on minimum counts. + Uses Scanpy’s built-in filtering functions (which are sparse-aware). + """ + adata = adata.copy() + sc.pp.filter_cells(adata, min_counts=min_counts) + sc.pp.filter_genes(adata, min_cells=min_cells) + return adata + +def get_dense_column(adata, i): + """ + Returns the i-th column of adata.X as a dense vector. + This avoids converting the entire matrix to dense at once. + """ + X = adata.X + if issparse(X): + return X[:, i].toarray().ravel() + else: + return np.array(X[:, i]).ravel() + +def check_for_inf_nan(adata, label): + """ + Checks for NaN/Inf values in adata.X without converting the whole matrix. + """ + X = adata.X + if issparse(X): + data = X.data + else: + data = np.array(X) + print(f"==> Checking {label} dataset:") + print(f" NaNs? {np.isnan(data).any()}") + print(f" Infs? {np.isinf(data).any()}") + print(f" Min: {data.min()}, Max: {data.max()}\n") + +def check_missing_genes(real_data, synthetic_data): + """ + Compares gene names between real and synthetic datasets. + """ + real_genes = set(real_data.var_names) + synthetic_genes = set(synthetic_data.var_names) + missing_in_real = synthetic_genes - real_genes + missing_in_synthetic = real_genes - synthetic_genes + + print("==> Checking gene differences:") + print(f" Genes in synthetic but not in real: {len(missing_in_real)}") + print(f" Genes in real but not in synthetic: {len(missing_in_synthetic)}") + print(f" Example missing in real: {list(missing_in_real)[:10]}") + print(f" Example missing in synthetic: {list(missing_in_synthetic)[:10]}") + print(f" real_data.var_names dtype: {real_data.var_names.dtype}") + print(f" synthetic_data.var_names dtype: {synthetic_data.var_names.dtype}\n") + +class Statistics: + def __init__(self, random_seed=42): + self.random_seed = random_seed + np.random.seed(self.random_seed) + + def compute_scc(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): + """ + Computes the mean Spearman correlation across highly variable genes (HVGs) + between the real and synthetic datasets. Instead of converting the whole + expression matrix to dense, each gene column is converted on the fly. + """ + np.random.seed(self.random_seed) + print("=== Starting compute_scc ===") + check_missing_genes(real_data, synthetic_data) + + # Align genes using the gene names from synthetic_data + common_genes = synthetic_data.var_names + print("Aligning real and synthetic data on common genes...") + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + + check_for_inf_nan(real_data, "Real") + check_for_inf_nan(synthetic_data, "Synthetic") + + # Normalize and log-transform both datasets + print("Normalizing and log-transforming real data...") + sc.pp.normalize_total(real_data, target_sum=1e4) + sc.pp.log1p(real_data) + print("Normalizing and log-transforming synthetic data...") + sc.pp.normalize_total(synthetic_data, target_sum=1e4) + sc.pp.log1p(synthetic_data) + + check_for_inf_nan(real_data, "Real") + check_for_inf_nan(synthetic_data, "Synthetic") + + # Identify HVGs using the combined dataset + print("Concatenating datasets to identify highly variable genes...") + combined_adata = real_data.concatenate(synthetic_data) + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + print("Identifying highly variable genes...") + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + + # Subset to HVGs + hvgs = combined_adata.var["highly_variable"] + print(f"Subsetting to HVGs: {hvgs.sum()} genes selected.") + real_hvg = real_data[:, hvgs] + synth_hvg = synthetic_data[:, hvgs] + + # Compute Spearman correlation gene-by-gene + print("Computing Spearman correlation gene-by-gene...") + scc_values = [] + total_genes = real_hvg.n_vars + progress_interval = max(1, total_genes // 100) + for i in range(total_genes): + real_vec = get_dense_column(real_hvg, i) + synth_vec = get_dense_column(synth_hvg, i) + corr, _ = stats.spearmanr(real_vec, synth_vec, nan_policy='omit') + scc_values.append(corr) + # Print progress every 10% + if (i + 1) % progress_interval == 0 or (i + 1) == total_genes: + percent = ((i + 1) / total_genes) * 100 + print(f" Processed {i + 1} / {total_genes} genes ({percent:.0f}%)") + scc_values = np.array(scc_values) + mean_corr = np.nanmean(scc_values) if not np.all(np.isnan(scc_values)) else np.nan + print(f"Finished compute_scc: Mean Spearman correlation = {mean_corr:.4f}\n") + return mean_corr + + def compute_mmd_optimized(self, real_data, synthetic_data, sample_size=20000, + n_pca=50, gamma=1.0, n_hvgs=_DEF_N_HVGS): + np.random.seed(self.random_seed) + # Align genes using synthetic_data's ordering + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + + combined_adata = real_data.concatenate(synthetic_data) + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + + hvgs = combined_adata.var["highly_variable"] + real_hvg = real_data[:, hvgs] + synth_hvg = synthetic_data[:, hvgs] + + n_real = real_hvg.n_obs + n_synth = synth_hvg.n_obs + + real_idx = np.random.choice(n_real, min(sample_size, n_real), replace=False) + synth_idx = np.random.choice(n_synth, min(sample_size, n_synth), replace=False) + + # Process sparse or dense data accordingly + if issparse(real_hvg.X): + real_sample = real_hvg.X[real_idx] + synth_sample = synth_hvg.X[synth_idx] + from scipy.sparse import vstack + combined_sample = vstack([real_sample, synth_sample]) + pca_model = TruncatedSVD(n_components=n_pca, random_state=self.random_seed) + combined_pca = pca_model.fit_transform(combined_sample) + else: + real_sample = real_hvg.X[real_idx] + synth_sample = synth_hvg.X[synth_idx] + combined_sample = np.vstack([real_sample, synth_sample]) + pca_model = PCA(n_components=n_pca, random_state=self.random_seed) + combined_pca = pca_model.fit_transform(combined_sample) + + # Use shape[0] instead of len() to get the number of real samples + num_real = real_sample.shape[0] + real_pca = combined_pca[:num_real] + synth_pca = combined_pca[num_real:] + + K_xx = rbf_kernel(real_pca, real_pca, gamma=gamma).mean() + K_yy = rbf_kernel(synth_pca, synth_pca, gamma=gamma).mean() + K_xy = rbf_kernel(real_pca, synth_pca, gamma=gamma).mean() + + return K_xx + K_yy - 2 * K_xy + + def compute_lisi(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): + np.random.seed(self.random_seed) + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + combined_adata = real_data.concatenate( + synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] + ) + # Create a numeric batch label (0 = real, 1 = synthetic) + combined_adata.obs["batch"] = (combined_adata.obs["source"] == "synthetic").astype(int) + + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + + # Dynamically determine the number of PCA components + n_obs, n_vars = combined_adata.shape + n_comps = min(n_hvgs, n_obs - 1, n_vars - 1) if n_obs > 1 and n_vars > 1 else 1 + print(f"Performing PCA with n_comps={n_comps} (n_obs={n_obs}, n_vars={n_vars})") + + sc.pp.pca(combined_adata, n_comps=n_comps, random_state=self.random_seed) + sc.pp.neighbors(combined_adata, n_neighbors=10, method='umap') + + return ilisi_graph(combined_adata, batch_key="batch", type_="knn") + + + def compute_ari(self, real_data, synthetic_data, cell_type_col, n_hvgs=_DEF_N_HVGS): + """ + Computes the Adjusted Rand Index (ARI) to measure clustering consistency + between real and synthetic data. Clusters are obtained via Scanpy's Louvain. + """ + np.random.seed(self.random_seed) + print("=== Starting compute_ari ===") + common_genes = synthetic_data.var_names + real_data = real_data[:, common_genes] + synthetic_data = synthetic_data[:, common_genes] + combined_adata = real_data.concatenate( + synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] + ) + + print("Normalizing, log-transforming, and selecting HVGs for ARI computation...") + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + + n_obs, n_vars = combined_adata.shape + n_comps = min(n_hvgs, n_obs - 1, n_vars - 1) if n_obs > 1 and n_vars > 1 else 1 + print(f"Performing PCA with n_comps={n_comps} (n_obs={n_obs}, n_vars={n_vars}) and computing neighbors") + sc.pp.pca(combined_adata, n_comps=n_comps, random_state=self.random_seed) + sc.pp.neighbors(combined_adata, n_neighbors=10, method='umap') + print("Clustering with Louvain...") + sc.tl.louvain(combined_adata) + + # Convert Louvain clusters to numerical labels + combined_adata.obs["louvain"] = combined_adata.obs["louvain"].astype("category").cat.codes + real_clusters = combined_adata.obs.loc[combined_adata.obs["source"] == "real", "louvain"].values + synthetic_clusters = combined_adata.obs.loc[combined_adata.obs["source"] == "synthetic", "louvain"].values + ari_real_vs_syn = adjusted_rand_score(real_clusters, synthetic_clusters) + ari_gt_vs_comb = adjusted_rand_score(combined_adata.obs[cell_type_col], combined_adata.obs["louvain"]) + + print(f"Finished compute_ari: ARI (real vs synthetic) = {ari_real_vs_syn:.4f}, ARI (ground truth vs clusters) = {ari_gt_vs_comb:.4f}\n") + return ari_real_vs_syn, ari_gt_vs_comb + +class VisualizeClassify: + def __init__(self, sc_figures_dir, random_seed=42): + self.random_seed = random_seed + self.sc_figures_dir = sc_figures_dir + np.random.seed(self.random_seed) + + def plot_umap(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): + """ + Creates and saves a UMAP plot of the combined real and synthetic data. + """ + print("=== Starting UMAP plotting ===") + sc.settings.figdir = self.sc_figures_dir + np.random.seed(self.random_seed) + check_for_inf_nan(real_data, "Real") + check_for_inf_nan(synthetic_data, "Synthetic") + combined_adata = real_data.concatenate( + synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] + ) + + print("Normalizing, log-transforming, and selecting HVGs for UMAP...") + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + + n_obs, n_vars = combined_adata.shape + n_comps = min(n_hvgs, n_obs - 1, n_vars - 1) if n_obs > 1 and n_vars > 1 else 1 + print(f"Performing PCA with n_comps={n_comps} (n_obs={n_obs}, n_vars={n_vars}), computing neighbors, and generating UMAP...") + sc.pp.pca(combined_adata, n_comps=n_comps, random_state=self.random_seed) + sc.pp.neighbors(combined_adata) + sc.tl.umap(combined_adata, random_state=self.random_seed) + + sc.pl.umap(combined_adata, + color=["source"], + title="UMAP of Real vs Synthetic Data", + save=f"syn_test_PCA_HVG={n_hvgs}.png") + print("UMAP plot saved.\n") + + def celltypist_classification(self, real_data_test, synthetic_data, celltypist_model, n_hvgs=_DEF_N_HVGS): + """ + Uses a CellTypist model to annotate cells from both datasets and then compares + the predicted labels via ARI and Jaccard scores. + """ + np.random.seed(self.random_seed) + print("=== Starting celltypist classification ===") + combined_adata = real_data_test.concatenate(synthetic_data) + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + + # Normalize and log-transform each dataset individually + sc.pp.normalize_total(real_data_test, target_sum=1e4) + sc.pp.log1p(real_data_test) + sc.pp.normalize_total(synthetic_data, target_sum=1e4) + sc.pp.log1p(synthetic_data) + + # Subset both datasets to HVGs + real_data_test = real_data_test[:, combined_adata.var['highly_variable']] + synthetic_data = synthetic_data[:, combined_adata.var['highly_variable']] + + print("Loading CellTypist model and annotating cells...") + model = celltypist.models.Model.load(celltypist_model) + real_predictions = celltypist.annotate(real_data_test, model=model) + synthetic_predictions = celltypist.annotate(synthetic_data, model=model) + + real_labels = real_predictions.predicted_labels.values.ravel() + synthetic_labels = synthetic_predictions.predicted_labels.values.ravel() + + ari_score = adjusted_rand_score(real_labels, synthetic_labels) + + lb = LabelBinarizer() + real_onehot = lb.fit_transform(real_labels) + synthetic_onehot = lb.transform(synthetic_labels) + + jaccard_scores = [ + jaccard_score(real_onehot[:, i], synthetic_onehot[:, i]) + for i in range(real_onehot.shape[1]) + ] + jaccard = np.mean(jaccard_scores) + print(f"Finished celltypist classification: ARI = {ari_score:.4f}, Jaccard = {jaccard:.4f}\n") + return ari_score, jaccard + + def random_forest_eval(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): + """ + Evaluates how well a Random Forest can separate real vs. synthetic cells. + After batch correction, the expression matrix is converted to dense only once. + """ + np.random.seed(self.random_seed) + print("=== Starting Random Forest evaluation ===") + real_data.obs["source"] = "real" + synthetic_data.obs["source"] = "synthetic" + + combined_adata = real_data.concatenate( + synthetic_data, batch_key="source", batch_categories=["real", "synthetic"] + ) + + print("Normalizing, log-transforming, and selecting HVGs for Random Forest...") + sc.pp.normalize_total(combined_adata, target_sum=1e4) + sc.pp.log1p(combined_adata) + sc.pp.highly_variable_genes(combined_adata, flavor="seurat", n_top_genes=n_hvgs) + combined_adata = combined_adata[:, combined_adata.var['highly_variable']] + + print("Applying Combat batch correction...") + sc.pp.combat(combined_adata, key="source") + + print("Converting expression matrix to dense and splitting data...") + X = combined_adata.X.A if hasattr(combined_adata.X, "A") else combined_adata.X + y = (combined_adata.obs["source"] == "synthetic").astype(int).values + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=self.random_seed) + + print("Training Random Forest classifier...") + rf = RandomForestClassifier(n_estimators=1000, max_depth=5, random_state=self.random_seed) + rf.fit(X_train, y_train) + + pred_probs = rf.predict_proba(X_test)[:, 1] + auc = roc_auc_score(y_test, pred_probs) + print(f"Finished Random Forest evaluation: AUC = {auc:.4f}\n") + + return auc, pred_probs From 5bed2eae488255dc7f0ae92a7c504e6092ba48d8 Mon Sep 17 00:00:00 2001 From: Pablo Rodriguez Mier Date: Wed, 26 Feb 2025 11:14:08 +0100 Subject: [PATCH 4/5] add sc_dist_sparse method to blue_team --- src/evaluation/utils/sc_metrics_opt.py | 3 +- src/generators/blue_team.py | 13 ++-- src/generators/models/sc_dist.py | 82 +++++++++++--------------- 3 files changed, 42 insertions(+), 56 deletions(-) diff --git a/src/evaluation/utils/sc_metrics_opt.py b/src/evaluation/utils/sc_metrics_opt.py index a122689..081a7e6 100644 --- a/src/evaluation/utils/sc_metrics_opt.py +++ b/src/evaluation/utils/sc_metrics_opt.py @@ -18,7 +18,7 @@ import celltypist from scipy.sparse import issparse -_DEF_N_HVGS = 120 +_DEF_N_HVGS = 5000 def filter_low_quality_cells_and_genes(adata, min_counts=10, min_cells=3): """ @@ -131,7 +131,6 @@ def compute_scc(self, real_data, synthetic_data, n_hvgs=_DEF_N_HVGS): synth_vec = get_dense_column(synth_hvg, i) corr, _ = stats.spearmanr(real_vec, synth_vec, nan_policy='omit') scc_values.append(corr) - # Print progress every 10% if (i + 1) % progress_interval == 0 or (i + 1) == total_genes: percent = ((i + 1) / total_genes) * 100 print(f" Processed {i + 1} / {total_genes} genes ({percent:.0f}%)") diff --git a/src/generators/blue_team.py b/src/generators/blue_team.py index 10ff16e..d9c2944 100644 --- a/src/generators/blue_team.py +++ b/src/generators/blue_team.py @@ -17,10 +17,11 @@ 'dpcvae': ('models.cvae', 'CVAEDataGenerationPipeline'), "ctgan": ('models.sdv_ctgan', 'CTGANDataGenerationPipeline'), "dpctgan": ('models.dpctgan', 'DPCTGANDataGenerationPipeline'), - "sc_dist": ('models.sc_dist', 'ScDistributionDataGenerator') + "sc_dist": ('models.sc_dist', 'ScDistributionDataGenerator'), + "sc_dist_sparse": ('models.sc_dist_opt', 'ScDistributionDataGenerator') } -## dynamic import to avoid package versioning errors +## dynamic import to avoid package versioning errors def get_generator_class(generator_name): if generator_name in generator_classes: module_name, class_name = generator_classes[generator_name] @@ -42,17 +43,17 @@ def cli(): def generate_split_indices(): configfile = "config.yaml" config = yaml.safe_load(open(configfile)) - rdataloader = RealDataLoader(config) + rdataloader = RealDataLoader(config) rdataloader.save_split_indices() -## the real data will be split into 5 train/test pairs +## the real data will be split into 5 train/test pairs ## based on the above generated {dataset_name}_split.yaml ## the data will be saved under data_splits/{dataset_name}/real/ @click.command() def generate_data_splits(): configfile = "config.yaml" config = yaml.safe_load(open(configfile)) - rdataloader = RealDataLoader(config) + rdataloader = RealDataLoader(config) # Save dataset rdataloader.save_split_data() @@ -143,4 +144,4 @@ def run_singlecell_generator(experiment_name: str = None): # else: # print("CUDA is NOT available.") -#check_cuda_availability() \ No newline at end of file +#check_cuda_availability() diff --git a/src/generators/models/sc_dist.py b/src/generators/models/sc_dist.py index 7210a96..18dd0cb 100644 --- a/src/generators/models/sc_dist.py +++ b/src/generators/models/sc_dist.py @@ -21,14 +21,12 @@ def __init__(self, config: Dict[str, Any]): self.distribution = self.generator_config["distribution"] # Either 'NB' or 'Poisson' self.cell_type_col_name = self.dataset_config["cell_type_col_name"] self.cell_label_col_name = self.dataset_config["cell_label_col_name"] - self.batch_size = self.generator_config.get("batch_size", None) # Optional batch size # Parameters for the data generation self.gene_means = None self.num_samples = None self.X_train_features = None self.cell_type_params = {} - self.max_real_value = None self.initialize_random_seeds() @@ -38,44 +36,38 @@ def initialize_random_seeds(self): def train(self): """Compute gene expression parameters for each cell type from training data.""" X_train_adata = self.load_train_anndata() - counts = X_train_adata.X + + counts = X_train_adata.X.toarray() if isinstance(X_train_adata.X, np.ndarray) else X_train_adata.X.A cell_types = X_train_adata.obs[self.cell_type_col_name].values cell_labels = X_train_adata.obs[self.cell_label_col_name].values self.cell_type_to_label = dict(set(zip(cell_types, cell_labels))) + print("Cell Type to Label Mapping:", self.cell_type_to_label) - # Determine max real expression value without converting sparse data to dense - if sp.issparse(counts): - self.max_real_value = counts.data.max() if counts.data.size > 0 else 0 - else: - self.max_real_value = counts.max() + # Store the max gene expression value from training + self.max_real_value = counts.max() print(f"Max real expression value from training: {self.max_real_value}") - unique_cell_types = np.unique(cell_types) - for cell_type in unique_cell_types: + for cell_type in np.unique(cell_types): print(f"Training on Cell Type: {cell_type}") + cell_type_mask = cell_types == cell_type cell_type_counts = counts[cell_type_mask, :] - if sp.issparse(cell_type_counts): - # Compute means and variances on sparse matrices: - means = np.array(cell_type_counts.mean(axis=0)).ravel() - # For variance: Var(X)=E[X^2] - (E[X])^2 - sq_means = np.array(cell_type_counts.power(2).mean(axis=0)).ravel() - variances = sq_means - means**2 - else: - means = cell_type_counts.mean(axis=0) - variances = cell_type_counts.var(axis=0) - + means = cell_type_counts.mean(axis=0) means = np.clip(means, 1e-6, None) # Avoid zero means if self.distribution == 'NB': - # Ensure variance is at least the mean + variances = cell_type_counts.var(axis=0) + + # variance >= mean to prevent negative dispersions variances = np.maximum(variances, means) + dispersions = (variances - means) / (means ** 2) dispersions = np.clip(dispersions, 1e-3, 10) # Avoid extreme values + # debugging print(f"Dispersion values for {cell_type}: min={dispersions.min()}, max={dispersions.max()}") if np.any(np.isnan(dispersions)): @@ -93,20 +85,21 @@ def train(self): print("Training completed successfully!") + + def generate(self): if self.max_real_value is None: raise ValueError("Training must be completed before generating data!") X_test_adata = self.load_test_anndata() - counts_shape = X_test_adata.X.shape - print("Original counts shape:", counts_shape) + counts = X_test_adata.X.toarray() if isinstance(X_test_adata.X, np.ndarray) else X_test_adata.X.A + print("Original counts shape:", counts.shape) cell_types = X_test_adata.obs[self.cell_type_col_name].values - synthetic_counts = sp.lil_matrix(counts_shape, dtype=np.int64) + synthetic_counts = sp.lil_matrix(counts.shape, dtype=np.int64) synthetic_cell_types = [] - unique_cell_types = np.unique(cell_types) - for cell_type in unique_cell_types: + for cell_type in np.unique(cell_types): print(f"Generating for Cell Type: {cell_type}") if str(cell_type) not in self.cell_type_params: @@ -125,38 +118,29 @@ def generate(self): dispersions = np.clip(dispersions, 1e-3, 10) # Prevent extreme values # Compute Negative Binomial parameters - n_param = np.clip(1 / (dispersions + 1e-6), 1e-2, 10) - p_param = np.clip(means / (means + n_param), 0.01, 0.99) + n_param = np.clip(1 / (dispersions + 1e-6), 1e-2, 10) + p_param = np.clip(means / (means + n_param), 0.01, 0.99) + # Debugging prints print(f"n_param range for {cell_type}: min={n_param.min()}, max={n_param.max()}") print(f"p_param range for {cell_type}: min={p_param.min()}, max={p_param.max()}") expected_variance = means + (means ** 2) / n_param print(f"Expected variance for {cell_type}: min={expected_variance.min()}, max={expected_variance.max()}") - # Use batch processing if batch_size is specified, otherwise process all cells at once - batch_size = self.batch_size if self.batch_size is not None else num_cells - - for start in range(0, num_cells, batch_size): - end = min(start + batch_size, num_cells) - current_batch_size = end - start + # Generate Negative Binomial samples + generated_data = st.nbinom.rvs(n=n_param, p=p_param, size=(num_cells, means.shape[0])).astype(np.int64) - if self.distribution == 'NB': - batch_generated = st.nbinom.rvs( - n=n_param, p=p_param, size=(current_batch_size, means.shape[0]) - ).astype(np.int64) - elif self.distribution == 'Poisson': - batch_generated = st.poisson.rvs( - means, size=(current_batch_size, means.shape[0]) - ).astype(np.int64) + elif self.distribution == 'Poisson': + generated_data = st.poisson.rvs(means, size=(num_cells, means.shape[0])).astype(np.int64) - # Limit extreme values to prevent memory explosion - upper_clip = np.percentile(batch_generated, 99.5) - batch_generated = np.clip(batch_generated, 0, min(upper_clip, self.max_real_value * 2)) + # Limit extreme values to prevent memory explosion + upper_clip = np.percentile(generated_data, 99.5) + generated_data = np.clip(generated_data, 0, min(upper_clip, self.max_real_value * 2)) - indices = cell_indices[start:end] - synthetic_counts[indices, :] = batch_generated - synthetic_cell_types.extend([cell_type] * current_batch_size) + # Store generated data + synthetic_counts[cell_indices, :] = generated_data + synthetic_cell_types.extend([cell_type] * num_cells) synthetic_counts_csr = synthetic_counts.tocsr().astype(np.int64) synthetic_adata = ad.AnnData(X=synthetic_counts_csr) @@ -165,5 +149,7 @@ def generate(self): return synthetic_adata + def load_from_checkpoint(self): pass + From 57168d47eed010dceb1c308cebae0cb7e99a6fbe Mon Sep 17 00:00:00 2001 From: Pablo Rodriguez Mier Date: Wed, 26 Feb 2025 11:20:06 +0100 Subject: [PATCH 5/5] add missing sc_dist_opt method --- src/generators/models/sc_dist_opt.py | 169 +++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 src/generators/models/sc_dist_opt.py diff --git a/src/generators/models/sc_dist_opt.py b/src/generators/models/sc_dist_opt.py new file mode 100644 index 0000000..7210a96 --- /dev/null +++ b/src/generators/models/sc_dist_opt.py @@ -0,0 +1,169 @@ +import os +import sys +import pandas as pd +import numpy as np +import scipy.stats as st +import scipy.sparse as sp +import scanpy as sc +import anndata as ad +from typing import Dict, Any + +src_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) +sys.path.append(src_dir) + +from generators.models.sc_base import BaseSingleCellDataGenerator + +class ScDistributionDataGenerator(BaseSingleCellDataGenerator): + def __init__(self, config: Dict[str, Any]): + super().__init__(config) + self.noise_level = self.generator_config["noise_level"] + self.random_seed = self.generator_config["random_seed"] + self.distribution = self.generator_config["distribution"] # Either 'NB' or 'Poisson' + self.cell_type_col_name = self.dataset_config["cell_type_col_name"] + self.cell_label_col_name = self.dataset_config["cell_label_col_name"] + self.batch_size = self.generator_config.get("batch_size", None) # Optional batch size + + # Parameters for the data generation + self.gene_means = None + self.num_samples = None + self.X_train_features = None + self.cell_type_params = {} + self.max_real_value = None + + self.initialize_random_seeds() + + def initialize_random_seeds(self): + np.random.seed(self.random_seed) + + def train(self): + """Compute gene expression parameters for each cell type from training data.""" + X_train_adata = self.load_train_anndata() + counts = X_train_adata.X + cell_types = X_train_adata.obs[self.cell_type_col_name].values + cell_labels = X_train_adata.obs[self.cell_label_col_name].values + + self.cell_type_to_label = dict(set(zip(cell_types, cell_labels))) + print("Cell Type to Label Mapping:", self.cell_type_to_label) + + # Determine max real expression value without converting sparse data to dense + if sp.issparse(counts): + self.max_real_value = counts.data.max() if counts.data.size > 0 else 0 + else: + self.max_real_value = counts.max() + print(f"Max real expression value from training: {self.max_real_value}") + + unique_cell_types = np.unique(cell_types) + for cell_type in unique_cell_types: + print(f"Training on Cell Type: {cell_type}") + cell_type_mask = cell_types == cell_type + cell_type_counts = counts[cell_type_mask, :] + + if sp.issparse(cell_type_counts): + # Compute means and variances on sparse matrices: + means = np.array(cell_type_counts.mean(axis=0)).ravel() + # For variance: Var(X)=E[X^2] - (E[X])^2 + sq_means = np.array(cell_type_counts.power(2).mean(axis=0)).ravel() + variances = sq_means - means**2 + else: + means = cell_type_counts.mean(axis=0) + variances = cell_type_counts.var(axis=0) + + means = np.clip(means, 1e-6, None) # Avoid zero means + + if self.distribution == 'NB': + # Ensure variance is at least the mean + variances = np.maximum(variances, means) + dispersions = (variances - means) / (means ** 2) + dispersions = np.clip(dispersions, 1e-3, 10) # Avoid extreme values + + print(f"Dispersion values for {cell_type}: min={dispersions.min()}, max={dispersions.max()}") + + if np.any(np.isnan(dispersions)): + raise ValueError(f"NaN detected in dispersions for {cell_type}!") + + self.cell_type_params[str(cell_type)] = { + 'means': means.astype(np.float32), + 'dispersions': dispersions.astype(np.float32) + } + + elif self.distribution == 'Poisson': + self.cell_type_params[str(cell_type)] = { + 'means': means.astype(np.float32) + } + + print("Training completed successfully!") + + def generate(self): + if self.max_real_value is None: + raise ValueError("Training must be completed before generating data!") + + X_test_adata = self.load_test_anndata() + counts_shape = X_test_adata.X.shape + print("Original counts shape:", counts_shape) + + cell_types = X_test_adata.obs[self.cell_type_col_name].values + synthetic_counts = sp.lil_matrix(counts_shape, dtype=np.int64) + synthetic_cell_types = [] + + unique_cell_types = np.unique(cell_types) + for cell_type in unique_cell_types: + print(f"Generating for Cell Type: {cell_type}") + + if str(cell_type) not in self.cell_type_params: + print(f"Cell type {cell_type} not found in training data! Skipping...") + continue + + cell_type_mask = cell_types == cell_type + cell_indices = np.where(cell_type_mask)[0] + num_cells = len(cell_indices) + + means = self.cell_type_params[str(cell_type)]['means'].astype(np.float64) + means = np.clip(means, 1e-6, None) # Avoid zeros + + if self.distribution == 'NB': + dispersions = self.cell_type_params[str(cell_type)]['dispersions'].astype(np.float64) + dispersions = np.clip(dispersions, 1e-3, 10) # Prevent extreme values + + # Compute Negative Binomial parameters + n_param = np.clip(1 / (dispersions + 1e-6), 1e-2, 10) + p_param = np.clip(means / (means + n_param), 0.01, 0.99) + + print(f"n_param range for {cell_type}: min={n_param.min()}, max={n_param.max()}") + print(f"p_param range for {cell_type}: min={p_param.min()}, max={p_param.max()}") + + expected_variance = means + (means ** 2) / n_param + print(f"Expected variance for {cell_type}: min={expected_variance.min()}, max={expected_variance.max()}") + + # Use batch processing if batch_size is specified, otherwise process all cells at once + batch_size = self.batch_size if self.batch_size is not None else num_cells + + for start in range(0, num_cells, batch_size): + end = min(start + batch_size, num_cells) + current_batch_size = end - start + + if self.distribution == 'NB': + batch_generated = st.nbinom.rvs( + n=n_param, p=p_param, size=(current_batch_size, means.shape[0]) + ).astype(np.int64) + elif self.distribution == 'Poisson': + batch_generated = st.poisson.rvs( + means, size=(current_batch_size, means.shape[0]) + ).astype(np.int64) + + # Limit extreme values to prevent memory explosion + upper_clip = np.percentile(batch_generated, 99.5) + batch_generated = np.clip(batch_generated, 0, min(upper_clip, self.max_real_value * 2)) + + indices = cell_indices[start:end] + synthetic_counts[indices, :] = batch_generated + synthetic_cell_types.extend([cell_type] * current_batch_size) + + synthetic_counts_csr = synthetic_counts.tocsr().astype(np.int64) + synthetic_adata = ad.AnnData(X=synthetic_counts_csr) + synthetic_adata.obs[self.cell_type_col_name] = synthetic_cell_types + synthetic_adata.var_names = X_test_adata.var_names + + return synthetic_adata + + def load_from_checkpoint(self): + pass