diff --git a/immuneML/IO/dataset_export/AIRRExporter.py b/immuneML/IO/dataset_export/AIRRExporter.py index a0e7594fa..6f4704343 100644 --- a/immuneML/IO/dataset_export/AIRRExporter.py +++ b/immuneML/IO/dataset_export/AIRRExporter.py @@ -59,7 +59,7 @@ def export_repertoire_dataset(dataset: RepertoireDataset, path: Path, omit_colum else: new_metadata_df = pd.read_csv(path / f"{dataset.name}.csv") - labels = {label: list(new_metadata_df[label].unique()) + labels = {label: new_metadata_df[label].unique().tolist() for label in dataset.get_label_names(refresh=True)} dataset_yaml = RepertoireDataset.create_metadata_dict(metadata_file=path / f"{dataset.name}.csv", diff --git a/immuneML/config/default_params/ml_methods/progen_params.yaml b/immuneML/config/default_params/ml_methods/progen_params.yaml new file mode 100644 index 000000000..57f5744b7 --- /dev/null +++ b/immuneML/config/default_params/ml_methods/progen_params.yaml @@ -0,0 +1,14 @@ +locus: beta +device: cpu +num_frozen_layers: 22 +num_epochs: 2 +learning_rate: 0.00004 +fp16: False +prefix_text: '<|bos|>1' +suffix_text: '2<|eos|>' +max_new_tokens: 1024 +temperature: 1 +top_p: 0.9 +prompt: '1' +num_gen_batches: 1 +per_device_train_batch_size: 2 diff --git a/immuneML/config/default_params/reports/vj_gene_distribution_params.yaml b/immuneML/config/default_params/reports/vj_gene_distribution_params.yaml index 5b1e8e5ff..f7ccbf376 100644 --- a/immuneML/config/default_params/reports/vj_gene_distribution_params.yaml +++ b/immuneML/config/default_params/reports/vj_gene_distribution_params.yaml @@ -1,3 +1,4 @@ -split_by_label: False +split_by_label: false label: null -is_sequence_label: False \ No newline at end of file +is_sequence_label: false +show_joint_dist: true \ No newline at end of file diff --git a/immuneML/data_model/EncodedData.py b/immuneML/data_model/EncodedData.py index 0a777a18a..d6f5bdef7 100644 --- a/immuneML/data_model/EncodedData.py +++ b/immuneML/data_model/EncodedData.py @@ -27,7 +27,7 @@ class EncodedData: def __init__(self, examples, labels: dict = None, example_ids: list = None, feature_names: list = None, feature_annotations: pd.DataFrame = None, encoding: str = None, example_weights: list = None, info: dict = None, - dimensionality_reduced_data: np.ndarray = None): + dimensionality_reduced_data: np.ndarray = None, dim_names: list = None): assert feature_names is None or examples.shape[1] == len(feature_names), \ (f"EncodedData: the length of feature_names ({len(feature_names)}) must match the feature dimension of the " @@ -59,6 +59,7 @@ def __init__(self, examples, labels: dict = None, example_ids: list = None, feat self.encoding = encoding self.example_weights = example_weights self.info = info + self.dim_names = dim_names self.dimensionality_reduced_data = dimensionality_reduced_data def __getstate__(self): diff --git a/immuneML/dsl/instruction_parsers/ClusteringParser.py b/immuneML/dsl/instruction_parsers/ClusteringParser.py index 1ab134d41..ad5e1e1fb 100644 --- a/immuneML/dsl/instruction_parsers/ClusteringParser.py +++ b/immuneML/dsl/instruction_parsers/ClusteringParser.py @@ -1,5 +1,6 @@ import copy import inspect +import logging from pathlib import Path from typing import List @@ -152,7 +153,13 @@ def parse_clustering_settings(key: str, instruction: dict, symbol_table: SymbolT instruction) settings_objs.append(setting_obj) - return settings_objs + unique_objs = list(set(settings_objs)) + + if len(unique_objs) < len(settings_objs): + logging.warning(f"Clustering parser: clustering_settings contains 
{len(settings_objs) - len(unique_objs)} "
+                        f"duplicate settings; keeping the following: {[obj.get_key() for obj in unique_objs]}")
+
+    return unique_objs
 
 
 def make_setting_obj(setting, valid_encodings, valid_clusterings, valid_dim_red, symbol_table, instruction):
diff --git a/immuneML/environment/Constants.py b/immuneML/environment/Constants.py
index 5bdb62fa0..5f110744f 100644
--- a/immuneML/environment/Constants.py
+++ b/immuneML/environment/Constants.py
@@ -1,6 +1,6 @@
 class Constants:
-    VERSION = "3.0.16"
+    VERSION = "3.0.17"
 
     # encoding constants
     FEATURE_DELIMITER = "_"
diff --git a/immuneML/ml_methods/dim_reduction/DimRedMethod.py b/immuneML/ml_methods/dim_reduction/DimRedMethod.py
index 99950146f..f0dc899c2 100644
--- a/immuneML/ml_methods/dim_reduction/DimRedMethod.py
+++ b/immuneML/ml_methods/dim_reduction/DimRedMethod.py
@@ -1,5 +1,6 @@
 import abc
 import logging
+from abc import ABC
 from typing import List
 
 import numpy as np
@@ -7,7 +8,7 @@
 from immuneML.data_model.datasets.Dataset import Dataset
 
 
-class DimRedMethod:
+class DimRedMethod(ABC):
     """
     Dimensionality reduction methods are algorithms which can be used to reduce the dimensionality of encoded
     datasets, in order to uncover and analyze patterns present in the data.
@@ -21,14 +22,12 @@ def __init__(self, name: str = None):
         self.method = None
         self.name = name
 
-    @abc.abstractmethod
     def fit(self, dataset: Dataset = None, design_matrix: np.ndarray = None):
         if dataset is None:
             self.method.fit(design_matrix)
         else:
             self.method.fit(dataset.encoded_data.get_examples_as_np_matrix())
 
-    @abc.abstractmethod
     def transform(self, dataset: Dataset = None, design_matrix: np.ndarray = None):
         if dataset is None:
             return self.method.transform(design_matrix)
diff --git a/immuneML/ml_methods/generative_models/ProGen.py b/immuneML/ml_methods/generative_models/ProGen.py
new file mode 100644
index 000000000..a78cc38e4
--- /dev/null
+++ b/immuneML/ml_methods/generative_models/ProGen.py
@@ -0,0 +1,314 @@
+import shutil
+from pathlib import Path
+from zipfile import ZipFile, ZIP_STORED
+
+import numpy as np
+import pandas as pd
+
+from immuneML.data_model.SequenceParams import RegionType
+from immuneML.data_model.bnp_util import get_sequence_field_name, write_yaml, read_yaml
+from immuneML.data_model.datasets.Dataset import Dataset
+from immuneML.data_model.datasets.ElementDataset import SequenceDataset
+from immuneML.environment.SequenceType import SequenceType
+from immuneML.ml_methods.generative_models.GenerativeModel import GenerativeModel
+from immuneML.ml_methods.generative_models.progen.ProGenConfig import ProGenConfig
+from immuneML.ml_methods.generative_models.progen.ProGenForCausalLM import ProGenForCausalLM
+from immuneML.util.Logger import print_log
+from immuneML.util.PathBuilder import PathBuilder
+
+
+class ProGen(GenerativeModel):
+    """
+    ProGen is a transformer-based language model for protein sequences. This class allows fine-tuning of a pre-trained
+    ProGen model on immune receptor sequences and generating new sequences. It is based on the ProGen2 implementation
+    available at https://github.com/salesforce/progen. It uses the sequences as given in the "junction_aa" field of
+    the input dataset.
+
+    References:
+
+    Nijkamp, E., Ruffolo, J. A., Weinstein, E. N., Naik, N., & Madani, A. (2023).
+    Exploring the boundaries of protein language models. Cell Systems, 14(11), 968–978.e3.
+    https://doi.org/10.1016/j.cels.2023.10.002
+
+    **Specification arguments:**
+
+    - locus (str): which locus the sequences come from, e.g., TRB
+
+    - tokenizer_path (Path): path to the ProGen tokenizer file (tokenizer.json)
+
+    - trained_model_path (Path): path to the pre-trained ProGen model directory
+
+    - num_frozen_layers (int): number of transformer layers to freeze during fine-tuning
+
+    - num_epochs (int): number of epochs for fine-tuning
+
+    - learning_rate (float): learning rate for fine-tuning
+
+    - device (str): device to use for training and inference ("cpu" or "cuda")
+
+    - fp16 (bool): whether to use mixed precision training
+
+    - prefix_text (str): text to prepend to each sequence during fine-tuning
+
+    - suffix_text (str): text to append to each sequence during fine-tuning
+
+    - max_new_tokens (int): maximum number of new tokens to generate
+
+    - temperature (float): sampling temperature for sequence generation
+
+    - top_p (float): nucleus sampling parameter for sequence generation
+
+    - prompt (str): prompt text to start the generation
+
+    - num_gen_batches (int): number of batches to split generation into
+
+    - per_device_train_batch_size (int): batch size per device during fine-tuning
+
+    - remove_affixes (bool): whether to remove prefix and suffix from generated sequences
+
+    - seed (int): random seed for reproducibility
+
+
+    **YAML specification:**
+
+    .. indent with spaces
+    .. code-block:: yaml
+
+        definitions:
+            ml_methods:
+                progen_model:
+                    ProGen:
+                        locus: 'beta'
+                        tokenizer_path: '/path/to/tokenizer.json'
+                        trained_model_path: '/path/to/pretrained/progen/model'
+                        num_frozen_layers: 27
+                        num_epochs: 3
+                        learning_rate: 0.00004
+                        device: 'cuda'
+                        fp16: False
+                        prefix_text: '<|bos|>1'
+                        suffix_text: '2<|eos|>'
+                        max_new_tokens: 1024
+                        temperature: 1.0
+                        top_p: 0.9
+                        prompt: '1'
+                        num_gen_batches: 1
+                        per_device_train_batch_size: 2
+                        remove_affixes: True
+                        name: 'progen_finetuned_model'
+                        region_type: 'IMGT_JUNCTION'
+                        seed: 42
+
+    """
+    @classmethod
+    def load_model(cls, path: Path):
+        import torch
+        assert path.exists(), f"{cls.__name__}: {path} does not exist."
+
+        model_overview_file = path / 'model_overview.yaml'
+        assert model_overview_file.exists(), f"{cls.__name__}: {model_overview_file} does not exist."
+
+        # Uses ProGen weights/tokenizer paths from training. Override in model_overview.yaml if needed.
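+        # model_overview.yaml is written by save_model() and stores the constructor arguments,
+        # so the model can be re-instantiated with the same settings before the fine-tuned weights are loaded.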
+ model_overview = read_yaml(model_overview_file) + progen = ProGen(**{k: v for k, v in model_overview.items() if k != 'type'}) + + config = ProGenConfig.from_pretrained(path) + model = ProGenForCausalLM.from_pretrained(path, + config=config, + dtype=torch.float32 if not progen.fp16 else torch.float16) + + progen.model = model.to(progen.device).eval() + + return progen + + def __init__(self, locus, tokenizer_path: Path, trained_model_path: Path, num_frozen_layers: int, num_epochs: int, + learning_rate: float, device: str, fp16: bool = False, prefix_text: str = "", suffix_text: str = "", + max_new_tokens: int = 1024, temperature: float = 1.0, top_p: float = 0.9, prompt: str = "1", + num_gen_batches: int = 1, per_device_train_batch_size: int = 2, remove_affixes: bool = True, + name: str = None, region_type: str = RegionType.IMGT_JUNCTION.name, seed: int = None, ): + super().__init__(locus, seed=seed, name=name, region_type=RegionType.get_object(region_type)) + self.sequence_type = SequenceType.AMINO_ACID + self.tokenizer_path = tokenizer_path + self.trained_model_path = trained_model_path + self.num_frozen_layers = num_frozen_layers + self.num_epochs = num_epochs + self.learning_rate = learning_rate + self.device = device # "cpu" or "cuda" + self.fp16 = fp16 + self.prefix_text = prefix_text + self.suffix_text = suffix_text + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.top_p = top_p + self.prompt = prompt + self.num_gen_batches = num_gen_batches + self.per_device_train_batch_size = per_device_train_batch_size + self.remove_affixes = remove_affixes + self.model = None + + from tokenizers import Tokenizer + from transformers import PreTrainedTokenizerFast + + tokenizer = Tokenizer.from_file(str(self.tokenizer_path)) + self.hf_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer) + self.hf_tokenizer.pad_token = "<|pad|>" + self.hf_tokenizer.eos_token = "<|eos|>" + self.hf_tokenizer.bos_token = "<|bos|>" + + def fit(self, data, path: Path = None): + assert path is not None, "ProGen.fit requires a target directory path for training outputs." 
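+
+        # Fine-tuning flow: tokenize the junction_aa sequences (wrapped in prefix_text/suffix_text),
+        # load the pre-trained checkpoint, freeze the first num_frozen_layers transformer blocks,
+        # and train the remaining parameters with the Hugging Face Trainer.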
+
+        logs_dir, output_dir = self._prepare_training_paths(path)
+        tokenized_dataset = self._preprocess_dataset(data)
+
+        from transformers import DataCollatorForLanguageModeling
+        data_collator = DataCollatorForLanguageModeling(tokenizer=self.hf_tokenizer, mlm=False)
+        config = ProGenConfig.from_pretrained(self.trained_model_path)
+        model = ProGenForCausalLM.from_pretrained(self.trained_model_path, config=config)
+        self._freeze_model_layers(model)
+
+        from transformers import TrainingArguments
+        training_args = TrainingArguments(
+            output_dir=str(output_dir),
+            per_device_train_batch_size=self.per_device_train_batch_size,
+            num_train_epochs=self.num_epochs,
+            learning_rate=self.learning_rate,
+            fp16=self.fp16,
+            use_cpu=True if self.device == "cpu" else False,
+            save_safetensors=False,
+            logging_dir=str(logs_dir),
+            save_total_limit=1,
+            save_strategy="no"
+        )
+
+        from transformers import Trainer
+        trainer = Trainer(
+            model=model,
+            args=training_args,
+            train_dataset=tokenized_dataset,
+            tokenizer=self.hf_tokenizer,
+            data_collator=data_collator
+        )
+
+        print_log(f"{self.name or ProGen.__name__}: starting ProGen fine-tuning.", True)
+        trainer.train()
+        print_log(f"{self.name or ProGen.__name__}: finished ProGen fine-tuning.", True)
+        self.model = trainer.model.to(self.device).eval()
+
+    def _freeze_model_layers(self, model):
+        for layer in model.transformer.h[:self.num_frozen_layers]:
+            for param in layer.parameters():
+                param.requires_grad = False
+        for param in model.transformer.ln_f.parameters():
+            param.requires_grad = True
+        for param in model.lm_head.parameters():
+            param.requires_grad = True
+
+    def _preprocess_dataset(self, data):
+        from datasets import Dataset as HFDataset
+        data_df = data.data.topandas()
+        # fill missing values before casting to str, otherwise NaN becomes the literal string "nan"
+        data_df["junction_aa"] = self.prefix_text + data_df["junction_aa"].fillna("").astype(str) + self.suffix_text
+        hf_dataset = HFDataset.from_pandas(data_df[["junction_aa"]], preserve_index=False)
+        tokenized_dataset = hf_dataset.map(
+            self.hf_tokenizer,
+            batched=True,
+            input_columns="junction_aa",
+            fn_kwargs={"truncation": True},
+            remove_columns=["junction_aa"],
+        )
+        return tokenized_dataset
+
+    def _prepare_training_paths(self, path):
+        base_path = Path(path)
+        base_path.mkdir(parents=True, exist_ok=True)
+        output_dir = base_path / "model"
+        output_dir.mkdir(parents=True, exist_ok=True)
+        logs_dir = base_path / "logs"
+        logs_dir.mkdir(parents=True, exist_ok=True)
+        return logs_dir, output_dir
+
+    def save_model(self, path: Path) -> Path:
+        model_path = PathBuilder.build(path / 'model')
+        self.model.save_pretrained(model_path, safe_serialization=False)
+        shutil.copy(self.tokenizer_path, model_path / Path(self.tokenizer_path).name)
+
+        skip_export_keys = {"model", "tokenizer", "hf_tokenizer", 'region_type', 'sequence_type'}
+        write_yaml(filename=model_path / 'model_overview.yaml',
+                   yaml_dict={**{k: v for k, v in vars(self).items() if k not in skip_export_keys},
+                              **{'type': self.__class__.__name__,
+                                 'locus': self.locus.name}})
+
+        archive_path = path / f"trained_model_{self.name}.zip"
+        with ZipFile(archive_path, "w", compression=ZIP_STORED) as archive:
+            for file_path in (fp for fp in model_path.rglob("*") if fp.is_file()):
+                archive.write(file_path, file_path.relative_to(model_path))
+
+        return archive_path.resolve()
+
+    def generate_sequences(self, count: int, seed: int, path: Path, sequence_type: SequenceType,
+                           compute_p_gen: bool) -> Dataset:
+        import torch
+        prompt_encoding = self.hf_tokenizer(self.prompt, return_tensors="pt")
+        prompt_input_ids = prompt_encoding.input_ids.to(self.device)
+        prompt_attention_mask = prompt_encoding.attention_mask.to(self.device)
+        gen_sequences = []
+
+        num_sequences_per_batch = count // self.num_gen_batches
+        for i in range(self.num_gen_batches):
+            num_current_sequences = num_sequences_per_batch if i < self.num_gen_batches - 1 else count - (
+                    num_sequences_per_batch * (self.num_gen_batches - 1))
+            with torch.inference_mode():
+                output = self.model.generate(
+                    input_ids=prompt_input_ids,
+                    attention_mask=prompt_attention_mask,
+                    max_new_tokens=self.max_new_tokens,
+                    do_sample=True,
+                    top_p=self.top_p,
+                    temperature=self.temperature,
+                    num_return_sequences=num_current_sequences,
+                    pad_token_id=self.hf_tokenizer.pad_token_id,
+                    return_dict_in_generate=False
+                )
+
+            gen_sequences.extend(self.hf_tokenizer.batch_decode(output, skip_special_tokens=True))
+
+            print_log(f"{self.name or ProGen.__name__}: {len(gen_sequences)} sequences generated.", True)
+
+        if self.remove_affixes:
+            gen_sequences = self._remove_affixes(gen_sequences)
+
+        gen_sequences_df = pd.DataFrame({get_sequence_field_name(self.region_type, self.sequence_type): gen_sequences,
+                                         'locus': [self.locus.to_string() for _ in range(count)],
+                                         'gen_model_name': [self.name for _ in range(count)]})
+
+        return SequenceDataset.build_from_partial_df(gen_sequences_df, PathBuilder.build(path),
+                                                     'synthetic_dataset', {'gen_model_name': [self.name]},
+                                                     {'gen_model_name': str})
+
+    def _remove_affixes(self, gen_sequences):
+        prefix_text = self.hf_tokenizer.decode(self.hf_tokenizer(self.prefix_text).input_ids,
+                                               skip_special_tokens=True)
+        suffix_text = self.hf_tokenizer.decode(self.hf_tokenizer(self.suffix_text).input_ids,
+                                               skip_special_tokens=True)
+        gen_sequences = [seq.replace(prefix_text, '').replace(suffix_text, '') for seq in gen_sequences]
+        return gen_sequences
+
+    def is_same(self, model) -> bool:
+        raise NotImplementedError
+
+    def compute_p_gens(self, sequences, sequence_type: SequenceType) -> np.ndarray:
+        raise RuntimeError
+
+    def compute_p_gen(self, sequence: dict, sequence_type: SequenceType) -> float:
+        raise RuntimeError
+
+    def can_compute_p_gens(self) -> bool:
+        return False
+
+    def can_generate_from_skewed_gene_models(self) -> bool:
+        return False
+
+    def generate_from_skewed_gene_models(self, v_genes: list, j_genes: list, seed: int, path: Path,
+                                         sequence_type: SequenceType, batch_size: int, compute_p_gen: bool):
+        raise RuntimeError
diff --git a/immuneML/ml_methods/generative_models/progen/ProGenConfig.py b/immuneML/ml_methods/generative_models/progen/ProGenConfig.py
new file mode 100644
index 000000000..b2a1be1e5
--- /dev/null
+++ b/immuneML/ml_methods/generative_models/progen/ProGenConfig.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
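+
+# Note: the hyperparameter defaults below match GPT-J; pre-trained ProGen2 checkpoints override them via the
+# config.json that ProGenConfig.from_pretrained() loads from the model directory.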
+ +# Modified configuration implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/configuration_gptj.py +# Modified for ImmuneML. Original code from https://github.com/salesforce/progen + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class ProGenConfig(PretrainedConfig): + model_type = "progen" + + def __init__( + self, + vocab_size=50400, + n_positions=2048, + n_ctx=2048, + n_embd=4096, + n_layer=28, + n_head=16, + rotary_dim=64, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + gradient_checkpointing=False, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + **kwargs + ): + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.rotary_dim = rotary_dim + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.gradient_checkpointing = gradient_checkpointing + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.n_embd + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer diff --git a/immuneML/ml_methods/generative_models/progen/ProGenForCausalLM.py b/immuneML/ml_methods/generative_models/progen/ProGenForCausalLM.py new file mode 100644 index 000000000..72bf1c831 --- /dev/null +++ b/immuneML/ml_methods/generative_models/progen/ProGenForCausalLM.py @@ -0,0 +1,693 @@ +# coding=utf-8 +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Modified forward-pass implementation based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/gptj/modeling_gptj.py +# Modified for ImmuneML. 
Original code from https://github.com/salesforce/progen + +from typing import Tuple + +import numpy as np + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import GenerationMixin + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from .ProGenConfig import ProGenConfig + + +logger = logging.get_logger(__name__) + + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim)) + sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float() + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), axis=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)') + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class ProGenAttention(nn.Module): + # Shared attention mask buffers are deterministic and replicated across layers; skip them when exporting to avoid shared-tensor errors. + _keys_to_ignore_on_save = ["bias", "masked_bias"] + _keys_to_ignore_on_load_missing = ["bias", "masked_bias"] + + def __init__(self, config): + super().__init__() + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + ) + self.register_buffer("masked_bias", torch.tensor(-1e9)) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." 
+ ) + self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype()) + self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False) + + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False) + self.rotary_dim = None + if config.rotary_dim is not None: + self.rotary_dim = config.rotary_dim + + def _split_heads(self, x, n_head, dim_head, mp_num): + reshaped = x.reshape(x.shape[:-1] + (n_head//mp_num, dim_head)) + reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:]) + return reshaped + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into n_ctx + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = query.to(torch.float32) + key = key.to(torch.float32) + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + attn_weights = attn_weights / self.scale_attn + attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype)) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.Softmax(dim=-1)(attn_weights) + attn_weights = attn_weights.to(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + + qkv = self.qkv_proj(hidden_states) + # TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic + # mp_num = 4 + mp_num = 8 + qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) + + local_dim = self.head_dim * self.num_attention_heads // mp_num + query, value, key = torch.split(qkv_split, local_dim, dim=-1) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) + + value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) + value = value.permute(0, 2, 1, 3) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = torch.cat([k_rot, k_pass], 
dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +class ProGenMLP(nn.Module): + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + embed_dim = config.n_embd + + self.fc_in = nn.Linear(embed_dim, intermediate_size) + self.fc_out = nn.Linear(intermediate_size, embed_dim) + + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class ProGenBlock(nn.Module): + def __init__(self, config): + super().__init__() + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = ProGenAttention(config) + self.mlp = ProGenMLP(inner_dim, config) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +class ProGenPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = ProGenConfig + base_model_prefix = "transformer" + is_parallelizable = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear,)): + # Slightly different from Mesh Transformer JAX which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class ProGenModel(ProGenPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads) + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + 
input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + assert batch_size > 0, "batch_size has to be defined and > 0" + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. 
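+            # e.g. a padding mask of [1, 1, 0] becomes an additive bias of [0.0, 0.0, -10000.0]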
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.size(-1),) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting " + "`use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + None, + attention_mask, + head_mask[i], + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ProGenForCausalLM(ProGenPreTrainedModel, GenerationMixin): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = ProGenModel(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size) + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + def deparallelize(self): + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return None + + def set_output_embeddings(self, new_embeddings): + return + + def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past, + "use_cache": kwargs.get("use_cache"), + 
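+            # position_ids were rebuilt above from the attention mask (cumulative sum, with padded
+            # positions pinned to 1) for batched generation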
"position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to + ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]`` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + # make sure sampling in fp16 works correctly and + # compute loss in fp32 to match with mesh-tf version + # https://github.com/EleutherAI/gpt-neo/blob/89ce74164da2fb16179106f54e2269b5da8db333/models/gpt2/gpt2.py#L179 + lm_logits = self.lm_head(hidden_states).to(torch.float32) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) diff --git a/immuneML/ml_methods/generative_models/progen/ProGen_LICENSE.txt b/immuneML/ml_methods/generative_models/progen/ProGen_LICENSE.txt new file mode 100644 index 000000000..8184ccecc --- /dev/null +++ b/immuneML/ml_methods/generative_models/progen/ProGen_LICENSE.txt @@ -0,0 +1,14 @@ +BSD 3-Clause License + +Copyright (c) 2022, Salesforce.com, Inc. +All rights reserved. 
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/immuneML/ml_methods/generative_models/progen/__init__.py b/immuneML/ml_methods/generative_models/progen/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/immuneML/preprocessing/filters/ChainRepertoireFilter.py b/immuneML/preprocessing/filters/ChainRepertoireFilter.py
index de091087c..b8c5d406c 100644
--- a/immuneML/preprocessing/filters/ChainRepertoireFilter.py
+++ b/immuneML/preprocessing/filters/ChainRepertoireFilter.py
@@ -4,6 +4,7 @@
 from immuneML.data_model.SequenceSet import Repertoire
 from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset
 from immuneML.preprocessing.filters.Filter import Filter
+from immuneML.util.ParameterValidator import ParameterValidator
 from immuneML.util.PathBuilder import PathBuilder
 
 
@@ -23,12 +24,11 @@ class ChainRepertoireFilter(Filter):
 
     **Specification arguments:**
 
-    - keep_chain (str): Which chain should be kept, valid values are "TRA", "TRB", "IGH", "IGL", "IGK"
+    - keep_chains (list): Which chains should be kept, valid values are "TRA", "TRB", "IGH", "IGL", "IGK"
 
-    - remove_only_sequences (bool): Whether to remove only sequences with different chain than "keep_chain" (true) in case of repertoire datasets; default is false
+    - remove_only_sequences (bool): Whether to remove only the sequences whose chain is not in "keep_chains" (true), rather than the whole repertoire, in case of repertoire datasets; default is false
 
-
     **YAML specification:**
 
     .. indent with spaces
@@ -38,14 +38,15 @@ class ChainRepertoireFilter(Filter):
         my_preprocessing:
             - my_filter:
                 ChainRepertoireFilter:
-                    keep_chain: TRB
+                    keep_chains: [TRB]
                     remove_only_sequences: true
 
     """
 
-    def __init__(self, keep_chain, remove_only_sequences: bool = False, result_path: Path = None):
+    def __init__(self, keep_chains: list, remove_only_sequences: bool = False, result_path: Path = None):
         super().__init__(result_path)
-        self.keep_chain = Chain.get_chain(keep_chain)
+        ParameterValidator.assert_type_and_value(keep_chains, list, "ChainRepertoireFilter", "keep_chains")
+        self.keep_chains = [Chain.get_chain(keep_chain) for keep_chain in keep_chains]
         self.remove_only_sequences = remove_only_sequences
 
     def process_dataset(self, dataset: RepertoireDataset, result_path: Path, number_of_processes=1):
@@ -58,20 +59,21 @@ def _filter_repertoire_dataset(self, processed_dataset: RepertoireDataset):
 
         repertoires = []
         indices = []
+        valid_chains = [chain.value for chain in self.keep_chains]
         for index, repertoire in enumerate(processed_dataset.get_data()):
             if self.remove_only_sequences:
                 data = repertoire.data
-                data = data[[Chain.get_chain(l).value == self.keep_chain.value for l in data.locus.tolist()]]
+                data = data[[Chain.get_chain(l).value in valid_chains for l in data.locus.tolist()]]
                 if len(data) == 0:
                     continue
 
                 new_repertoire = Repertoire.build_from_dc_object(self.result_path, repertoire.metadata, data=data,
-                                                                 filename_base=repertoire.data_filename.stem + "filtered")
+                                                                 filename_base=repertoire.data_filename.stem)
 
                 repertoires.append(new_repertoire)
                 indices.append(index)
             else:
-                if all(Chain.get_chain(l).value == self.keep_chain.value for l in repertoire.data.locus.tolist()):
+                if all(Chain.get_chain(l).value in valid_chains for l in repertoire.data.locus.tolist()):
                     repertoires.append(repertoire)
                     indices.append(index)
diff --git a/immuneML/presentation/html/Util.py b/immuneML/presentation/html/Util.py
index c18278c5e..fe6d06e39 100644
--- a/immuneML/presentation/html/Util.py
+++ b/immuneML/presentation/html/Util.py
@@ -4,6 +4,7 @@
 from enum import Enum
 from pathlib import Path
 
+import numpy as np
 import pandas as pd
 
 from immuneML.reports.ReportOutput import ReportOutput
@@ -132,7 +133,22 @@ def make_dataset_html_map(dataset, dataset_key="dataset"):
                             f"{dataset_key}_locus": ", ".join(dataset.get_locus()),
                             f"{dataset_key}_size": f"{dataset.get_example_count()} {type(dataset).__name__.replace('Dataset', 's').lower()}",
                             f"{dataset_key}_labels": [{f"{dataset_key}_label_name": label_name,
-                                                       f"{dataset_key}_label_classes": ", ".join(str(class_name)
-                                                                                                 for class_name in dataset.labels.get(label_name, []))}
+                                                       f"{dataset_key}_label_classes": get_label_values_to_show(dataset.labels.get(label_name, []), dataset.get_example_count())}
                                                       for label_name in dataset.get_label_names()],
-                            f"show_{dataset_key}_labels": len(dataset.get_label_names()) > 0}
\ No newline at end of file
+                            f"show_{dataset_key}_labels": len(dataset.get_label_names()) > 0}
+
+def get_label_values_to_show(label_values: list, dataset_size: int):
+    if len(label_values) == dataset_size:
+        return "unique per example"
+    elif any(isinstance(label, float) and not np.isnan(label) for label in label_values):
+        try:
+            return f"{min(label_values):.2f} to {max(label_values):.2f}"
+        except Exception:
+            return 'mixed value types'
+    elif any(isinstance(label, int) for label in label_values):
+        try:
+            return f"{min(label_values)} to {max(label_values)}"
+        except Exception:
+            return 'mixed value types'
+    else:
+        return ", ".join(str(label) for label
in label_values) diff --git a/immuneML/presentation/html/templates/css/custom.css b/immuneML/presentation/html/templates/css/custom.css index 2cd0c3229..aeee50104 100644 --- a/immuneML/presentation/html/templates/css/custom.css +++ b/immuneML/presentation/html/templates/css/custom.css @@ -17,7 +17,7 @@ body { :root { --primary-color: #2c3e50; --secondary-color: #34495e; - --accent-color: #009688; + --accent-color: #708090; --background-color: #f8f9fa; --text-color: #2c3e50; } @@ -299,7 +299,7 @@ code { } .nested-list-container ul li ul li { - border-left: 1px solid #68b5b3; /* Accent for nested items */ + border-left: 1px solid #708090; /* Accent for nested items */ margin: 4px 0; } diff --git a/immuneML/reports/clustering_method_reports/ClusteringVisualization.py b/immuneML/reports/clustering_method_reports/ClusteringVisualization.py index 674f5999f..21b6d08af 100644 --- a/immuneML/reports/clustering_method_reports/ClusteringVisualization.py +++ b/immuneML/reports/clustering_method_reports/ClusteringVisualization.py @@ -1,9 +1,11 @@ +import logging from pathlib import Path import pandas as pd import plotly import plotly.express as px +from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset from immuneML.dsl.definition_parsers.MLParser import MLParser from immuneML.ml_methods.dim_reduction.DimRedMethod import DimRedMethod from immuneML.reports.PlotlyUtil import PlotlyUtil @@ -23,7 +25,9 @@ class ClusteringVisualization(ClusteringMethodReport): - dim_red_method (dict): specification of which dimensionality reduction to perform; valid options are presented under :ref:`**Dimensionality reduction methods**` and should be specified with the name of the method and its - parameters, see the example below + parameters, see the example below; if not specified, the report will use any dimensionality reduced data + present in the dataset's encoded data; if the dataset does not contain dimensionality reduced data, and the + encoded data has more than 2 dimensions, the report will be skipped. YAML specification: @@ -42,6 +46,10 @@ class ClusteringVisualization(ClusteringMethodReport): TSNE: n_components: 2 init: pca + my_report_existing_dim_red: + ClusteringVisualization: + dim_red_method: null + """ def __init__(self, dim_red_method: DimRedMethod = None, name: str = None, @@ -50,7 +58,7 @@ def __init__(self, dim_red_method: DimRedMethod = None, name: str = None, self.dim_red_method = dim_red_method self.result_name = None self.desc = "Clustering Visualization" - self._dimension_names = self.dim_red_method.get_dimension_names() + self._dimension_names = self.dim_red_method.get_dimension_names() if self.dim_red_method else None @classmethod def build_object(cls, **kwargs): @@ -62,7 +70,9 @@ def build_object(cls, **kwargs): method_name = list(kwargs["dim_red_method"].keys())[0] dim_red_method = MLParser.parse_any_model("dim_red_method", kwargs["dim_red_method"], method_name)[0] else: - raise ValueError(f"{location}: dim_red_method must be specified.") + logging.warning(f"{location}: No dimensionality reduction method specified. 
" + "If the encoded dataset includes dimensionality reduction, it will be used.") + dim_red_method = None return cls(dim_red_method=dim_red_method, name=name, result_path=result_path, clustering_item=kwargs['clustering_item'] if 'clustering_item' in kwargs else None,) @@ -77,11 +87,22 @@ def _generate(self) -> ReportResult: f"Clustering visualization for {self.item.cl_setting.get_key()}") return ReportResult(f"{self.desc} ({self.name})", - info=f"{self.dim_red_method.__class__.__name__} visualizations of clustering results", + info=f"Visualizations of clustering results", output_figures=[report_output]) def _make_plot(self, result_path: Path) -> Path: - transformed_data = self.dim_red_method.fit_transform(dataset=self.item.dataset) + if self.dim_red_method is not None: + transformed_data = self.dim_red_method.fit_transform(dataset=self.item.dataset) + elif self.item.dataset.encoded_data.dimensionality_reduced_data is not None: + transformed_data = self.item.dataset.encoded_data.dimensionality_reduced_data + self._dimension_names = self.item.dataset.encoded_data.dim_names if self.item.dataset.encoded_data.dim_names else ['dim1', 'dim2'] + elif self.item.dataset.encoded_data.examples.shape[1] <= 2: + transformed_data = self.item.dataset.encoded_data.get_examples_as_np_matrix() + self._dimension_names = self.item.dataset.encoded_data.feature_names + else: + raise ValueError("ClusteringVisualization: No dimensionality reduction method specified, and the dataset " + "does not contain dimensionality reduced data. Please specify a dimensionality reduction " + "method.") df = pd.DataFrame(transformed_data, columns=self._dimension_names) df['cluster'] = pd.Series(self.item.predictions).astype(str) @@ -94,10 +115,20 @@ def _make_plot(self, result_path: Path) -> Path: fig.update_layout(template="plotly_white") - df.to_csv(result_path / f"clustering_visualization_{self.dim_red_method.name}.csv", index=False) + df.to_csv(result_path / f"clustering_visualization_{self.dim_red_method.name if self.dim_red_method else ''}.csv", index=False) plot_path = PlotlyUtil.write_image_to_file(fig, - result_path / f"clustering_visualization_{self.dim_red_method.name}.html", + result_path / f"clustering_visualization_{self.dim_red_method.name if self.dim_red_method else ''}.html", df.shape[0]) return plot_path + + def get_ids(self): + if isinstance(self.item.dataset, RepertoireDataset): + metadata = self.item.dataset.get_metadata(['subject_id'], return_df=True) + if 'subject_id' in metadata.columns: + return metadata['subject_id'].tolist() + else: + return self.item.dataset.get_example_ids() + else: + return self.item.dataset.get_example_ids() diff --git a/immuneML/reports/clustering_reports/ClusteringStabilityReport.py b/immuneML/reports/clustering_reports/ClusteringStabilityReport.py index ad57fa57c..24b3d3913 100644 --- a/immuneML/reports/clustering_reports/ClusteringStabilityReport.py +++ b/immuneML/reports/clustering_reports/ClusteringStabilityReport.py @@ -115,9 +115,9 @@ def _generate(self) -> ReportResult: def make_figure(self, df: pd.DataFrame) -> ReportOutput: import plotly.express as px - fig = px.box(df, x='clustering_setting', y=self.metric, points='all', - color_discrete_sequence=px.colors.qualitative.Set2) - fig.update_layout(xaxis_title="clustering setting", + fig = px.box(df, x='clustering_setting', y=self.metric, points='all', color='clustering_setting', + color_discrete_sequence=px.colors.qualitative.Vivid) + fig.update_layout(xaxis_title="clustering setting", showlegend=False, 
                          yaxis_title=self.metric, template="plotly_white")
         fig.update_traces(marker=dict(opacity=0.5), jitter=0.3)
diff --git a/immuneML/reports/data_reports/AminoAcidFrequencyDistribution.py b/immuneML/reports/data_reports/AminoAcidFrequencyDistribution.py
index 2ea2de394..23b0f995a 100644
--- a/immuneML/reports/data_reports/AminoAcidFrequencyDistribution.py
+++ b/immuneML/reports/data_reports/AminoAcidFrequencyDistribution.py
@@ -287,7 +287,7 @@ def _plot_distribution(self, freq_dist):
         figure.update_yaxes(tickformat=",.0%", range=[0, 1])
 
         file_path = self.result_path / "amino_acid_frequency_distribution.html"
-        file_path = PlotlyUtil.write_image_to_file(figure, str(file_path), 20)
+        file_path = PlotlyUtil.write_image_to_file(figure, file_path, 20)
 
         return ReportOutput(path=file_path, name="Amino acid frequency distribution")
diff --git a/immuneML/reports/data_reports/LabelDist.py b/immuneML/reports/data_reports/LabelDist.py
index 08d803cfd..7c3ce3634 100644
--- a/immuneML/reports/data_reports/LabelDist.py
+++ b/immuneML/reports/data_reports/LabelDist.py
@@ -55,7 +55,7 @@ def _generate(self) -> ReportResult:
 
         fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=df.columns)
 
-        colors = px.colors.qualitative.Plotly * math.ceil(len(df.columns) / len(px.colors.qualitative.Plotly))
+        colors = px.colors.qualitative.Vivid * math.ceil(len(df.columns) / len(px.colors.qualitative.Vivid))
 
         for i, col in enumerate(df.columns):
             row = i // n_cols + 1
diff --git a/immuneML/reports/data_reports/LabelOverlap.py b/immuneML/reports/data_reports/LabelOverlap.py
index 75676bac4..128d3f36f 100644
--- a/immuneML/reports/data_reports/LabelOverlap.py
+++ b/immuneML/reports/data_reports/LabelOverlap.py
@@ -3,6 +3,7 @@
 
 import pandas as pd
 import plotly.graph_objects as go
+import plotly.express as px
 
 from immuneML.data_model.datasets.Dataset import Dataset
 from immuneML.reports.PlotlyUtil import PlotlyUtil
@@ -73,7 +74,7 @@ def _generate(self) -> ReportResult:
             z=overlap_matrix.values,
             x=overlap_matrix.columns,
             y=overlap_matrix.index,
-            colorscale=[[0, '#e5f6f6'], [0.5, '#66b2b2'], [1, '#006666']],  # Custom teal colorscale
+            colorscale=px.colors.sequential.Viridis,  # Viridis colorscale
             text=overlap_matrix.values,
             texttemplate="%{text}",
             hovertemplate=f"{self.row_label}: " + "%{y}<br>
" + f"{self.column_label}: " @@ -89,7 +90,8 @@ def _generate(self) -> ReportResult: xaxis_title=self.column_label, yaxis_title=self.row_label, template="plotly_white", - font=dict(size=12) + font=dict(size=12), + height=max(600, 40 * len(overlap_matrix.index) + 200), ) # Save plot diff --git a/immuneML/reports/data_reports/RepertoireClonotypeSummary.py b/immuneML/reports/data_reports/RepertoireClonotypeSummary.py index 6d0fd1475..9f8783eb5 100644 --- a/immuneML/reports/data_reports/RepertoireClonotypeSummary.py +++ b/immuneML/reports/data_reports/RepertoireClonotypeSummary.py @@ -77,7 +77,7 @@ def _plot(self) -> ReportResult: fig = px.bar(clonotypes, x='repertoire_index', y='clonotype_count', facet_row=self.facet_label, color=self.color_label, title='Clonotype count per repertoire', - color_discrete_sequence=px.colors.diverging.Tealrose) + color_discrete_sequence=px.colors.qualitative.Vivid) fig.update_layout(template="plotly_white", yaxis_title='clonotype count', xaxis_title='repertoires') diff --git a/immuneML/reports/data_reports/SequenceLengthDistribution.py b/immuneML/reports/data_reports/SequenceLengthDistribution.py index 57e1b2a4d..dfcb90826 100644 --- a/immuneML/reports/data_reports/SequenceLengthDistribution.py +++ b/immuneML/reports/data_reports/SequenceLengthDistribution.py @@ -212,7 +212,7 @@ def _plot(self, df: pd.DataFrame) -> ReportOutput: facet_col=self.label_name if self.label_name in df.columns else None, facet_row="chain" if isinstance(self.dataset, ReceptorDataset) else None) figure.update_layout(template="plotly_white") - figure.update_traces(marker_color=px.colors.diverging.Tealrose[0]) + figure.update_traces(marker_color=px.colors.qualitative.Vivid[1]) for annotation in figure.layout.annotations: annotation['font'] = {'size': 16} diff --git a/immuneML/reports/data_reports/ShannonDiversityOverview.py b/immuneML/reports/data_reports/ShannonDiversityOverview.py index ffdfd333c..619d0cf40 100644 --- a/immuneML/reports/data_reports/ShannonDiversityOverview.py +++ b/immuneML/reports/data_reports/ShannonDiversityOverview.py @@ -115,7 +115,7 @@ def _plot(self, encoded_df) -> ReportOutput: fig = px.bar(encoded_df, x='repertoire_index', y='shannon_diversity', facet_row=self.facet_row_label, color=self.color_label, title='Shannon diversity per repertoire', facet_col=self.facet_col_label, - color_discrete_sequence=px.colors.diverging.Tealrose, hover_data=hover_data_cols) + color_discrete_sequence=px.colors.qualitative.Vivid, hover_data=hover_data_cols) fig.update_layout(template="plotly_white", yaxis_title='Shannon diversity', xaxis_title='Repertoires sorted by Shannon diversity') diff --git a/immuneML/reports/data_reports/SimpleDatasetOverview.py b/immuneML/reports/data_reports/SimpleDatasetOverview.py deleted file mode 100644 index 4d491ab34..000000000 --- a/immuneML/reports/data_reports/SimpleDatasetOverview.py +++ /dev/null @@ -1,108 +0,0 @@ -from pathlib import Path - -from immuneML.data_model.datasets.Dataset import Dataset -from immuneML.data_model.datasets.ElementDataset import ReceptorDataset, SequenceDataset -from immuneML.data_model.datasets.RepertoireDataset import RepertoireDataset -from immuneML.reports.ReportOutput import ReportOutput -from immuneML.reports.ReportResult import ReportResult -from immuneML.reports.data_reports.DataReport import DataReport -from immuneML.util.PathBuilder import PathBuilder - - -class SimpleDatasetOverview(DataReport): - """ - Generates a simple text-based overview of the properties of any dataset, including the dataset name, size, 
and metadata labels. - - **YAML specification:** - - .. indent with spaces - .. code-block:: yaml - - definitions: - reports: - my_overview: SimpleDatasetOverview - - """ - UNKNOWN_CHAIN = "unknown" - - def __init__(self, dataset: Dataset = None, result_path: Path = None, number_of_processes: int = 1, name: str = None): - super().__init__(dataset=dataset, result_path=result_path, number_of_processes=number_of_processes, name=name) - - @classmethod - def build_object(cls, **kwargs): - return SimpleDatasetOverview(**kwargs) - - def _generate(self) -> ReportResult: - PathBuilder.build(self.result_path) - - text_path = self.result_path / "dataset_description.txt" - - dataset_name = self.dataset.name if self.dataset.name is not None else self.dataset.identifier - - output_text = self._get_generic_dataset_text() - - if isinstance(self.dataset, RepertoireDataset): - output_text += self._get_repertoire_dataset_text() - elif isinstance(self.dataset, ReceptorDataset): - output_text += self._get_receptor_dataset_text() - elif isinstance(self.dataset, SequenceDataset): - output_text += self._get_sequence_dataset_text() - - text_path.write_text(output_text) - - return ReportResult(name=self.name, - info=f"A simple overview of the properties of dataset {self.dataset.name}", - output_text=[ReportOutput(text_path, f"Description of dataset {dataset_name}")]) - - def _get_generic_dataset_text(self): - element_name = type(self.dataset).__name__.replace("Dataset", "s").lower() - - output_text = f"Dataset name: {self.dataset.name}\n" \ - f"Dataset identifier: {self.dataset.identifier}\n" \ - f"Dataset type: {type(self.dataset).__name__}\n" \ - f"Dataset size: {self.dataset.get_example_count()} {element_name}\n" \ - f"Labels available for classification:" - - if len(self.dataset.get_label_names()) == 0: - output_text += " None" - else: - for label in self.dataset.get_label_names(): - output_text += "\n - " + label - - return output_text - - def _get_repertoire_dataset_text(self): - output_text = f"\nmetadata file location: {self.dataset.metadata_file}\n" - - output_text += "\n\nProperties per repertoire:\n" - for repertoire in self.dataset.repertoires: - output_text += f"- Name: {repertoire.data_filename.name}\n" - output_text += f" Number of sequences: {repertoire.get_element_count()}\n" - - chains = list(set(repertoire.data.locus.tolist())) - if len(chains) == 1: - output_text += f" Chain type: {chains[0]}\n" - else: - output_text += f" Chain types: {','.join(chains)}\n" - - return output_text - - def _get_receptor_dataset_text(self): - receptor_types = list(set([receptor.chain_pair.name for receptor in self.dataset.get_data()])) - - if len(receptor_types) > 1: - output_text = "\nReceptor types: " + ",".join(receptor_types) - else: - output_text = "\nReceptor type: " + receptor_types[0] - - return output_text - - def _get_sequence_dataset_text(self): - chains = list(set(self.dataset.data.locus.tolist())) - - if len(chains) > 1: - output_text = "\nChain types: " + ",".join(chains) - else: - output_text = "\nChain type: " + chains[0] - - return output_text diff --git a/immuneML/reports/data_reports/TrueMotifsSummaryBarplot.py b/immuneML/reports/data_reports/TrueMotifsSummaryBarplot.py index 9eb7f1677..5d9328252 100644 --- a/immuneML/reports/data_reports/TrueMotifsSummaryBarplot.py +++ b/immuneML/reports/data_reports/TrueMotifsSummaryBarplot.py @@ -191,18 +191,32 @@ def _plot(self, plotting_data, output_name): df_melted = self._prepare_and_sort_plotting_data(plotting_data, signal_names, sorted_data_origins, 
                                                         seq_counts_df)
 
+        df_melted['data_origin_novelty'] = df_melted['data_origin'] + ' - ' + df_melted['novelty_memorization']
 
         figure = px.bar(
             df_melted,
             x='signal',
             y='frequency',
-            color='novelty_memorization',
+            color='data_origin_novelty',
             facet_col='label',
-            color_discrete_sequence=px.colors.diverging.Tealrose,
+            color_discrete_sequence=px.colors.qualitative.Vivid,
             barmode='stack',
             title='Percentage of sequences containing signals across different generated datasets',
         )
 
+        # color traces by data origin; opacity distinguishes memorized from novel sequences
+        color_map = {data_origin: px.colors.qualitative.Vivid[i]
+                     for i, data_origin in enumerate(plotting_data['data_origin'].unique())}
+        opacity_map = {"original": 1.0, "memorized": 1.0, "novel": 0.5}
+        for trace in figure.data:
+            data_origin, novelty = trace.name.split(' - ')
+            trace.opacity = opacity_map[novelty]
+            trace.marker.color = color_map[data_origin]
+
         figure.for_each_annotation(lambda a: a.update(text=a.text.replace("label=", "")))
 
         figure.update_layout(
diff --git a/immuneML/reports/data_reports/VJGeneDistribution.py b/immuneML/reports/data_reports/VJGeneDistribution.py
index 503dd7342..9a21685b6 100644
--- a/immuneML/reports/data_reports/VJGeneDistribution.py
+++ b/immuneML/reports/data_reports/VJGeneDistribution.py
@@ -43,6 +43,8 @@ class VJGeneDistribution(DataReport):
           (e.g., antigen binding versus non-binding across repertoires) or repertoire level (e.g., diseased repertoires
           versus healthy repertoires). By default, is_sequence_label is False. For Sequence- and ReceptorDatasets, this
           parameter is ignored.
+
+        - show_joint_dist (bool): whether to show the combined V and J gene distribution. Default is True.
+ **YAML specification:** @@ -54,6 +56,7 @@ class VJGeneDistribution(DataReport): my_vj_gene_report: VJGeneDistribution: label: ag_binding + show_joint_dist: false """ @@ -66,11 +69,13 @@ def build_object(cls, **kwargs): return VJGeneDistribution(**kwargs) def __init__(self, dataset: Dataset = None, result_path: Path = None, number_of_processes: int = 1, - name: str = None, split_by_label: bool = None, label: str = None, is_sequence_label: bool = None): + name: str = None, split_by_label: bool = None, label: str = None, is_sequence_label: bool = None, + show_joint_dist: bool = True): super().__init__(dataset=dataset, result_path=result_path, number_of_processes=number_of_processes, name=name) self.split_by_label = split_by_label self.label_name = label self.is_sequence_label = is_sequence_label + self.show_joint_dist = show_joint_dist def _generate(self) -> ReportResult: PathBuilder.build(self.result_path) @@ -99,7 +104,10 @@ def _get_sequence_receptor_results(self): v_tables, v_plots = self._get_single_gene_results_from_attributes(dataset_attributes, "v_call") j_tables, j_plots = self._get_single_gene_results_from_attributes(dataset_attributes, "j_call") - vj_tables, vj_plots = self._get_combo_gene_results_from_attributes(dataset_attributes) + if self.show_joint_dist: + vj_tables, vj_plots = self._get_combo_gene_results_from_attributes(dataset_attributes) + else: + vj_tables, vj_plots = [], [] return ReportResult(name=self.name, info="V and J gene distributions", @@ -166,7 +174,7 @@ def _plot_gene_distribution(self, df, title, filename): figure = px.bar(df, x="genes", y="counts", color=self.label_name, labels={"genes": "Gene names", "counts": "Observed frequency"}, - color_discrete_sequence=px.colors.diverging.Tealrose) + color_discrete_sequence=px.colors.qualitative.Vivid) figure.update_layout(xaxis=dict(tickmode='array', tickvals=df["genes"]), yaxis=dict(tickmode='array', tickvals=df["counts"]), template="plotly_white", @@ -182,7 +190,6 @@ def _get_combo_gene_results_from_attributes(self, dataset_attributes): vj_combo_count_df = self._get_vj_combo_count_df(dataset_attributes) - for chain in set(dataset_attributes["locus"]): chain_df = vj_combo_count_df[vj_combo_count_df["locus"] == chain] @@ -233,7 +240,7 @@ def _get_repertoire_results(self): plots = [] - for chain in set(vj_df["locus"]): + for chain in set(v_df["locus"]): plots.append(self._safe_plot(plot_callable="_plot_gene_distribution_across_repertoires", chain_df=v_df[v_df["locus"] == chain], title=f"{chain} V gene distribution per repertoire", @@ -244,13 +251,14 @@ def _get_repertoire_results(self): title=f"{chain} J gene distribution per repertoire", filename=f"{chain}J_gene_distribution.html")) - mean_chain_vj_df = self._average_norm_counts_per_repertoire(chain_vj_df=vj_df[vj_df["locus"] == chain]) + if self.show_joint_dist: + mean_chain_vj_df = self._average_norm_counts_per_repertoire(chain_vj_df=vj_df[vj_df["locus"] == chain]) - tables.append(self._write_output_table(mean_chain_vj_df, - file_path=self.result_path / f"{chain}VJ_gene_distribution_averaged_across_repertoires.tsv", - name=f"Combined {chain} V+J gene distribution averaged across repertoires")) + tables.append(self._write_output_table(mean_chain_vj_df, + file_path=self.result_path / f"{chain}VJ_gene_distribution_averaged_across_repertoires.tsv", + name=f"Combined {chain} V+J gene distribution averaged across repertoires")) - plots.extend(self._get_repertoire_heatmaps(mean_chain_vj_df, chain)) + plots.extend(self._get_repertoire_heatmaps(mean_chain_vj_df, chain)) 
         return ReportResult(name=self.name,
                             info="V and J gene distributions per repertoire",
@@ -295,10 +303,6 @@ def _get_repertoire_count_dfs(self):
         for repertoire in self.dataset.repertoires:
             data = repertoire.data
 
-            if hasattr(data, "locus"):
-                assert len(set(data.locus.tolist())) == 1, (f"{VJGeneDistribution.__name__}: Repertoire {repertoire.identifier} of dataset {self.dataset.name} contained multiple loci: {set(data.locus.tolist())}. "
-                                                            f"This report can only be created for 1 locus per repertoire.")
-
             repertoire_attributes = {"v_call": data.v_call.tolist(),
                                      "j_call": data.j_call.tolist(),
@@ -309,19 +313,20 @@
             v_rep_df = self._get_gene_count_df(repertoire_attributes, "v_call", include_label=self.is_sequence_label)
             j_rep_df = self._get_gene_count_df(repertoire_attributes, "j_call", include_label=self.is_sequence_label)
-            vj_rep_df = self._get_vj_combo_count_df(repertoire_attributes, include_label=self.is_sequence_label)
 
             self._supplement_repertoire_df(v_rep_df, repertoire)
             self._supplement_repertoire_df(j_rep_df, repertoire)
-            self._supplement_repertoire_df(vj_rep_df, repertoire)
 
             v_dfs.append(v_rep_df)
             j_dfs.append(j_rep_df)
-            vj_dfs.append(vj_rep_df)
+
+            if self.show_joint_dist:
+                # only compute the V-J combination counts when the joint distribution is requested
+                vj_rep_df = self._get_vj_combo_count_df(repertoire_attributes, include_label=self.is_sequence_label)
+                self._supplement_repertoire_df(vj_rep_df, repertoire)
+                vj_dfs.append(vj_rep_df)
 
         return pd.concat(v_dfs, ignore_index=True), \
                pd.concat(j_dfs, ignore_index=True), \
-               pd.concat(vj_dfs, ignore_index=True),
+               pd.concat(vj_dfs, ignore_index=True) if self.show_joint_dist else None
 
     def _supplement_repertoire_df(self, rep_df, repertoire):
         rep_df["repertoire_id"] = repertoire.identifier
@@ -344,9 +349,10 @@ def _write_repertoire_tables(self, v_df, j_df, vj_df):
                                                    file_path=self.result_path / f"{chain}J_gene_distribution.tsv",
                                                    name=f"{chain}J gene distribution per repertoire"))
 
-            tables.append(self._write_output_table(vj_df[vj_df['locus'] == chain],
-                                                   file_path=self.result_path / f"{chain}VJ_gene_distribution.tsv",
-                                                   name=f"Combined {chain}V+J gene distribution"))
+            if self.show_joint_dist:
+                tables.append(self._write_output_table(vj_df[vj_df['locus'] == chain],
+                                                       file_path=self.result_path / f"{chain}VJ_gene_distribution.tsv",
+                                                       name=f"Combined {chain}V+J gene distribution"))
 
         return tables
 
@@ -355,7 +361,7 @@ def _plot_gene_distribution_across_repertoires(self, chain_df, title, filename):
                             hover_data=["repertoire_id", "subject_id"],
                             labels={"genes": "Gene names",
                                     "norm_counts": "Fraction of the repertoire"},
-                            color_discrete_sequence=px.colors.diverging.Tealrose)
+                            color_discrete_sequence=px.colors.qualitative.Vivid)
 
         figure.update_layout(template="plotly_white")
         file_path = self.result_path / filename
diff --git a/immuneML/reports/encoding_reports/DimensionalityReduction.py b/immuneML/reports/encoding_reports/DimensionalityReduction.py
index 1490fd0cd..3aef35e9f 100644
--- a/immuneML/reports/encoding_reports/DimensionalityReduction.py
+++ b/immuneML/reports/encoding_reports/DimensionalityReduction.py
@@ -144,10 +144,10 @@ def _plot(self, df: pd.DataFrame) -> List[ReportOutput]:
         elif 'example_id' in df_copy.columns:
             hover_data += ['example_id']
 
-        if len(unique_values) <= 3:
+        if len(unique_values) <= 15:
             df_copy[label] = df_copy[label].astype('category')
             figure = px.scatter(df_copy, x=self._dimension_names[0], y=self._dimension_names[1], color=label,
-                                color_discrete_sequence=px.colors.qualitative.Set1,
+                                color_discrete_sequence=px.colors.qualitative.Vivid,
                                 hover_data=hover_data, category_orders={label: sorted(unique_values)})
         else:
diff 
--git a/immuneML/reports/ml_reports/ConfusionMatrix.py b/immuneML/reports/ml_reports/ConfusionMatrix.py index 4622523a8..1b86a2259 100644 --- a/immuneML/reports/ml_reports/ConfusionMatrix.py +++ b/immuneML/reports/ml_reports/ConfusionMatrix.py @@ -28,12 +28,6 @@ class ConfusionMatrix(MLReport): This may be useful to compare performance across different data subsets (e.g., batches, sources). If specified, separate confusion matrices will be generated for each value of the alternative label. Default is None. - Example output: - - .. image:: ../../_static/images/reports/confusion_matrix_example.png - :alt: Confusion matrix report - :width: 650 - **YAML specification:** .. code-block:: yaml diff --git a/immuneML/reports/multi_dataset_reports/PerformanceOverview.py b/immuneML/reports/multi_dataset_reports/PerformanceOverview.py index d55783c83..e123985ad 100644 --- a/immuneML/reports/multi_dataset_reports/PerformanceOverview.py +++ b/immuneML/reports/multi_dataset_reports/PerformanceOverview.py @@ -4,8 +4,7 @@ import pandas as pd import plotly.express as px import plotly.graph_objects as go -from sklearn import metrics -from sklearn.metrics import precision_recall_curve +from sklearn.metrics import precision_recall_curve, roc_curve, roc_auc_score from immuneML.environment.Constants import Constants from immuneML.environment.Label import Label @@ -93,8 +92,8 @@ def plot_roc(self, optimal_hp_items, label: Label, colors) -> Tuple[ReportOutput df = pd.read_csv(item.test_predictions_path) true_class = df[f"{label.name}_true_class"].values predicted_class = df[f"{label.name}_{label.positive_class}_proba"].values - fpr, tpr, _ = metrics.roc_curve(y_true=true_class, y_score=predicted_class) - auc = metrics.roc_auc_score(true_class, predicted_class) + fpr, tpr, _ = roc_curve(y_true=true_class, y_score=predicted_class) + auc = roc_auc_score(y_true=true_class, y_score=predicted_class) name = self.instruction_states[index].dataset.name + f' (AUC = {round(auc, 2)})' figure.add_trace(go.Scatter(x=fpr, y=tpr, mode='lines', name=name, marker=dict(color=colors[index], line=dict(width=3)), hoverinfo="skip")) diff --git a/immuneML/workflows/instructions/clustering/ClusteringRunner.py b/immuneML/workflows/instructions/clustering/ClusteringRunner.py index 7ab064e63..ffad40ffb 100644 --- a/immuneML/workflows/instructions/clustering/ClusteringRunner.py +++ b/immuneML/workflows/instructions/clustering/ClusteringRunner.py @@ -81,8 +81,8 @@ def run_setting(self, dataset: Dataset, cl_setting: ClusteringSetting, analysis_ predictions=predictions, encoder=encoder, method=method, - external_performance=DataFrameWrapper(path=performance_paths['external']), - internal_performance=DataFrameWrapper(path=performance_paths['internal']) + external_performance=DataFrameWrapper(path=performance_paths['external']) if performance_paths['external'] else None, + internal_performance=DataFrameWrapper(path=performance_paths['internal']) if performance_paths['internal'] else None ) report_results = self.report_handler.run_item_reports(cl_item, analysis_desc, run_id, cl_setting.path, state) @@ -181,5 +181,6 @@ def encode_dataset_internal(dataset: Dataset, cl_setting: ClusteringSetting, num if cl_setting.dim_reduction_method: enc_dataset.encoded_data.dimensionality_reduced_data = cl_setting.dim_reduction_method.fit_transform( enc_dataset) + enc_dataset.encoded_data.dim_names = cl_setting.dim_reduction_method.get_dimension_names() return enc_dataset diff --git a/immuneML/workflows/instructions/clustering/clustering_run_model.py 
b/immuneML/workflows/instructions/clustering/clustering_run_model.py
index 71514c0a2..0f3c99226 100644
--- a/immuneML/workflows/instructions/clustering/clustering_run_model.py
+++ b/immuneML/workflows/instructions/clustering/clustering_run_model.py
@@ -21,7 +21,7 @@ def __init__(self, path: Path, df: pd.DataFrame = None):
             df.to_csv(str(path), index=False)
 
     def get_df(self):
-        if self.df is None and self.path.exists():
+        if self.df is None and isinstance(self.path, Path) and self.path.exists():
             self.df = pd.read_csv(str(self.path))
         return self.df
 
@@ -49,6 +49,12 @@ def get_key(self) -> str:
 
     def __str__(self):
         return self.get_key()
 
+    def __hash__(self):
+        return hash(self.get_key())
+
+    def __eq__(self, other):
+        # compare by key, but only against objects of the same type, so set() can drop duplicate settings safely
+        return isinstance(other, self.__class__) and self.get_key() == other.get_key()
+
 
 @dataclass
 class ClusteringItem:
diff --git a/immuneML/workflows/instructions/exploratory_analysis/ExploratoryAnalysisInstruction.py b/immuneML/workflows/instructions/exploratory_analysis/ExploratoryAnalysisInstruction.py
index a641bcfae..9ae16de14 100644
--- a/immuneML/workflows/instructions/exploratory_analysis/ExploratoryAnalysisInstruction.py
+++ b/immuneML/workflows/instructions/exploratory_analysis/ExploratoryAnalysisInstruction.py
@@ -126,6 +126,7 @@ def run_unit(self, unit: ExploratoryAnalysisUnit, result_path: Path) -> List[Rep
 
     def _run_dimensionality_reduction(self, unit: ExploratoryAnalysisUnit):
         result = unit.dim_reduction.fit_transform(unit.dataset)
         unit.dataset.encoded_data.dimensionality_reduced_data = result
+        unit.dataset.encoded_data.dim_names = unit.dim_reduction.get_dimension_names()
 
     def preprocess_dataset(self, unit: ExploratoryAnalysisUnit, result_path: Path) -> Dataset:
         if unit.preprocessing_sequence is not None and len(unit.preprocessing_sequence) > 0:
diff --git a/immuneML/workflows/instructions/train_gen_model/TrainGenModelInstruction.py b/immuneML/workflows/instructions/train_gen_model/TrainGenModelInstruction.py
index 196fdc925..39a13d9ab 100644
--- a/immuneML/workflows/instructions/train_gen_model/TrainGenModelInstruction.py
+++ b/immuneML/workflows/instructions/train_gen_model/TrainGenModelInstruction.py
@@ -17,6 +17,7 @@
 from immuneML.reports.data_reports.DataReport import DataReport
 from immuneML.reports.ml_reports.MLReport import MLReport
 from immuneML.reports.train_gen_model_reports.TrainGenModelReport import TrainGenModelReport
+from immuneML.reports.gen_model_reports.GenModelReport import GenModelReport
 from immuneML.util.Logger import print_log
 from immuneML.util.PathBuilder import PathBuilder
 from immuneML.workflows.instructions.Instruction import Instruction
@@ -33,7 +34,7 @@ class TrainGenModelState:
     generated_dataset: Dataset = None
     exported_datasets: Dict[str, Path] = field(default_factory=dict)
     report_results: Dict[str, List[ReportResult]] = field(
-        default_factory=lambda: {'data_reports': [], 'ml_reports': [], 'instruction_reports': []})
+        default_factory=lambda: {'data_reports': [], 'ml_reports': [], 'gen_ml_reports': [], 'instruction_reports': []})
     combined_dataset: Dataset = None
     train_dataset: Dataset = None
     test_dataset: Dataset = None
@@ -252,6 +253,14 @@ def _run_reports(self):
                     rep.dataset = original_dataset
                     rep.name = rep.name + " (original dataset)"
                     self.state.report_results['data_reports'].append(rep.generate_report())
+                elif isinstance(report, GenModelReport):
+                    for method in self.methods:
+                        rep = copy.deepcopy(report)
+                        rep.result_path = PathBuilder.build(rep.result_path.parent / f"{rep.result_path.name}_original_dataset")
+                        rep.dataset = original_dataset
+                        rep.model = method
+                        rep.name = rep.name + " (original dataset)"
+                        self.state.report_results['gen_ml_reports'].append(rep.generate_report())
 
         self._print_report_summary_log()
@@ -293,6 +302,8 @@ def _run_reports_main(self):
                 for method in self.methods:
                     rep = copy.deepcopy(report)
                     rep.dataset = self.generated_datasets[method.name]
+                    rep.result_path = PathBuilder.build(
+                        rep.result_path.parent / f"{rep.result_path.name}_{method.name}")
                     rep.name = rep.name + " (generated dataset from " + method.name + ")"
                     self.state.report_results['data_reports'].append(rep.generate_report())
             elif isinstance(report, MLReport):
@@ -305,7 +316,7 @@ def _run_reports_main(self):
     def _print_report_summary_log(self):
         if len(self.reports) > 0:
             gen_rep_count = len(self.state.report_results['ml_reports']) + len(
-                self.state.report_results['data_reports'])
+                self.state.report_results['data_reports']) + len(self.state.report_results['gen_ml_reports'])
             print_log(f"{self.state.name}: generated {gen_rep_count} reports.", True)
 
     def _get_reports_path(self) -> Path:
diff --git a/pyproject.toml b/pyproject.toml
index 8c3eacfdf..2e5558751 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
 word2vec = ["gensim>=4"]
 fisher = ["fisher>=0.1.9", "fishersapi"]
 TCRdist = ["tcrdist3>=0.1.6"]
-gen_models = ["sonnia", "torch"]
+gen_models = ["sonnia", "torch", "transformers", "datasets", "tokenizers", "tensorflow<=2.15.0", "accelerate>=0.26.0"]
 embeddings = ["transformers", "torch", "sentencepiece", "esm", 'httpx']
 ligo = ["stitchr", "IMGTgeneDL"]
 DL = ["torch", "keras", "tensorflow", "logomaker", "gensim"]
@@ -64,7 +64,11 @@ all = [
     "gensim>=4",
     "transformers",
     "sentencepiece",
-    'httpx'
+    "httpx",
+    "datasets",
+    "tokenizers",
+    "tensorflow<=2.15.0",
+    "accelerate>=0.26.0"
 ]
 
 [project.urls]
diff --git a/requirements_generative_models.txt b/requirements_generative_models.txt
index ef296207e..a368a13fe 100644
--- a/requirements_generative_models.txt
+++ b/requirements_generative_models.txt
@@ -3,7 +3,10 @@
 sonia<=0.1.3
 sonnia<=0.1.5
 sentencepiece
 keras
-tensorflow
+tensorflow<=2.15.0
 torch
 transformers
 esm
+datasets
+tokenizers
+accelerate>=0.26.0
diff --git a/test/dsl/test_preprocessingParser.py b/test/dsl/test_preprocessingParser.py
index 9752a889c..cfc13b703 100644
--- a/test/dsl/test_preprocessingParser.py
+++ b/test/dsl/test_preprocessingParser.py
@@ -10,14 +10,14 @@ def test_parse(self):
             "seq1": [
                 {"filter_chain_B": {
                     "ChainRepertoireFilter": {
-                        "keep_chain": "A"
+                        "keep_chains": ["A"]
                     }
                 }}
             ],
             "seq2": [
                 {"filter_chain_A": {
                     "ChainRepertoireFilter": {
-                        "keep_chain": "B"
+                        "keep_chains": ["B"]
                     }
                 }}
             ]
diff --git a/test/ml_methods/test_progen.py b/test/ml_methods/test_progen.py
new file mode 100644
index 000000000..6ec0305d6
--- /dev/null
+++ b/test/ml_methods/test_progen.py
@@ -0,0 +1,85 @@
+import os
+import shutil
+import urllib.request
+from pathlib import Path
+
+import pandas as pd
+from transformers import PreTrainedTokenizerFast
+
+from immuneML.data_model.SequenceParams import Chain
+from immuneML.environment.EnvironmentSettings import EnvironmentSettings
+from immuneML.environment.SequenceType import SequenceType
+from immuneML.ml_methods.generative_models.ProGen import ProGen
+from immuneML.ml_methods.generative_models.progen.ProGenConfig import ProGenConfig
+from immuneML.ml_methods.generative_models.progen.ProGenForCausalLM import ProGenForCausalLM
+from immuneML.simulation.dataset_generation.RandomDatasetGenerator import RandomDatasetGenerator
+from immuneML.util.PathBuilder import PathBuilder
+
+
+def 
download_file(url: str, dest: Path): + dest.parent.mkdir(parents=True, exist_ok=True) + print(f"Downloading {url} ...") + with urllib.request.urlopen(url) as r, open(dest, "wb") as f: + f.write(r.read()) + print(f"Saved to {dest}") + return dest + + +def make_dummy_progen_model(model_path: Path) -> Path: + model_path.mkdir(parents=True, exist_ok=True) + + tokenizer_json = model_path / "tokenizer.json" + download_file( + "https://raw.githubusercontent.com/enijkamp/progen2/main/tokenizer.json", + tokenizer_json + ) + tokenizer = PreTrainedTokenizerFast(tokenizer_file=str(tokenizer_json)) + tokenizer.pad_token = tokenizer.eos_token + + config = ProGenConfig( + vocab_size=len(tokenizer), + n_positions=256, + n_ctx=256, + n_embd=16, + n_layer=1, + n_head=8, + ) + + model = ProGenForCausalLM(config) + tokenizer.save_pretrained(model_path) + model.save_pretrained(model_path) + + return model_path + + +def test_progen(): + path = PathBuilder.remove_old_and_build(EnvironmentSettings.tmp_test_path / 'progen') + os.makedirs(path, exist_ok=True) + + model_dir = make_dummy_progen_model(path / "dummy_model") + tokenizer_json = model_dir / "tokenizer.json" + + dataset = RandomDatasetGenerator.generate_sequence_dataset(20, {10: 1.}, + {}, path / 'dataset', region_type="IMGT_JUNCTION") + + progen = ProGen(Chain.get_chain('beta'), + tokenizer_json, + model_dir, + 1, + 1, + 5e-5, + 'cpu', + prefix_text='', + suffix_text='', + name='progen_test', + region_type="IMGT_JUNCTION", + seed=42) + + progen.fit(dataset, path / 'model') + progen.generate_sequences(7, 42, path / 'generated_dataset', SequenceType.AMINO_ACID, False) + + assert (path / 'generated_dataset').exists() + assert (path / 'generated_dataset/synthetic_dataset.tsv').exists() + assert pd.read_csv(str(path / 'generated_dataset/synthetic_dataset.tsv'), sep='\t').shape[0] == 7 + + shutil.rmtree(path) diff --git a/test/preprocessing/filters/test_chainRepertoireFilter.py b/test/preprocessing/filters/test_chainRepertoireFilter.py index aa0863616..2b07624b8 100644 --- a/test/preprocessing/filters/test_chainRepertoireFilter.py +++ b/test/preprocessing/filters/test_chainRepertoireFilter.py @@ -32,7 +32,7 @@ def test_process(self): dataset = RepertoireDataset(repertoires=[rep1, rep2], metadata_file=path / "metadata.csv") - dataset2 = ChainRepertoireFilter(**{"keep_chain": "ALPHA"}).process_dataset(dataset, path / "results") + dataset2 = ChainRepertoireFilter(**{"keep_chains": ["ALPHA"]}).process_dataset(dataset, path / "results") self.assertEqual(1, len(dataset2.get_data())) self.assertEqual(2, len(dataset.get_data())) @@ -44,7 +44,7 @@ def test_process(self): for rep in dataset2.get_data(): self.assertEqual("AAA", rep.sequences()[0].get_sequence()) - self.assertRaises(AssertionError, ChainRepertoireFilter(**{"keep_chain": "GAMMA"}).process_dataset, dataset, + self.assertRaises(AssertionError, ChainRepertoireFilter(**{"keep_chains": ["GAMMA"]}).process_dataset, dataset, path / "results") rep1 = Repertoire.build_from_sequences([ReceptorSequence(sequence_aa="AAA", locus="ALPHA", @@ -57,7 +57,7 @@ def test_process(self): new_input_dataset = RepertoireDataset(repertoires=[rep1, rep2], metadata_file=path / "metadata.csv") - dataset3 = ChainRepertoireFilter(**{"keep_chain": "BETA", "remove_only_sequences": True}).process_dataset(new_input_dataset, path / "results2") + dataset3 = ChainRepertoireFilter(**{"keep_chains": ["BETA"], "remove_only_sequences": True}).process_dataset(new_input_dataset, path / "results2") self.assertEqual(2, len(dataset3.get_data())) 
         shutil.rmtree(path)
diff --git a/test/reports/data_reports/test_SimpleDatasetOverview.py b/test/reports/data_reports/test_SimpleDatasetOverview.py
deleted file mode 100644
index e513fa197..000000000
--- a/test/reports/data_reports/test_SimpleDatasetOverview.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import os
-import shutil
-from unittest import TestCase
-
-from immuneML.environment.EnvironmentSettings import EnvironmentSettings
-from immuneML.reports.ReportResult import ReportResult
-from immuneML.reports.data_reports.SimpleDatasetOverview import SimpleDatasetOverview
-from immuneML.simulation.dataset_generation.RandomDatasetGenerator import RandomDatasetGenerator
-from immuneML.util.PathBuilder import PathBuilder
-
-
-class TestSimpleDatasetOverview(TestCase):
-    def test_generate_sequence_dataset(self):
-        path = PathBuilder.remove_old_and_build(EnvironmentSettings.tmp_test_path / "overview_sequence_dataset/")
-
-        dataset = RandomDatasetGenerator.generate_sequence_dataset(100, {10: 0.5, 11: 0.25, 20: 0.25},
-                                                                   {"l1": {"a": 0.5, "b": 0.5}}, path / "dataset")
-
-        params = {"dataset": dataset, "result_path": path / "result"}
-
-        report = SimpleDatasetOverview.build_object(**params)
-        self.assertTrue(report.check_prerequisites())
-
-        result = report._generate()
-
-        self.assertIsInstance(result, ReportResult)
-
-        self.assertTrue(os.path.isfile(path / "result/dataset_description.txt"))
-
-        shutil.rmtree(path)
-
-    def test_generate_receptor_dataset(self):
-        path = PathBuilder.remove_old_and_build(EnvironmentSettings.tmp_test_path / "overview_receptor_dataset/")
-
-        dataset = RandomDatasetGenerator.generate_receptor_dataset(100, chain_1_length_probabilities={10: 0.5, 11: 0.25,
-                                                                                                      20: 0.25},
-                                                                   chain_2_length_probabilities={10: 0.5, 11: 0.25,
-                                                                                                 15: 0.25},
-                                                                   labels={"l1": {"a": 0.5, "b": 0.5}},
-                                                                   path=path / "dataset")
-
-        params = {"dataset": dataset, "result_path": path / "result"}
-
-        report = SimpleDatasetOverview.build_object(**params)
-        self.assertTrue(report.check_prerequisites())
-
-        result = report._generate()
-
-        self.assertIsInstance(result, ReportResult)
-
-        self.assertTrue(os.path.isfile(path / "result/dataset_description.txt"))
-
-        shutil.rmtree(path)
-
-    def test_generate_repertoire_dataset(self):
-        path = PathBuilder.remove_old_and_build(EnvironmentSettings.tmp_test_path / "overview_repertoire_dataset/")
-
-        dataset = RandomDatasetGenerator.generate_repertoire_dataset(repertoire_count=5,
-                                                                     sequence_count_probabilities={20: 1},
-                                                                     sequence_length_probabilities={10: 1},
-                                                                     labels={"l1": {"a": 0.5, "b": 0.5}},
-                                                                     path=path / "dataset")
-
-        params = {"dataset": dataset, "result_path": path / "result"}
-
-        report = SimpleDatasetOverview.build_object(**params)
-        self.assertTrue(report.check_prerequisites())
-
-        result = report._generate()
-
-        self.assertIsInstance(result, ReportResult)
-        self.assertTrue(os.path.isfile(path / "result/dataset_description.txt"))
-
-        shutil.rmtree(path)
diff --git 
a/test/reports/tool_reports/test_PerformanceOverview.py b/test/reports/tool_reports/test_PerformanceOverview.py index fbea8667f..de42465f9 100644 --- a/test/reports/tool_reports/test_PerformanceOverview.py +++ b/test/reports/tool_reports/test_PerformanceOverview.py @@ -32,8 +32,8 @@ def _prepare_specs(self, path) -> Path: "d1": { "format": "RandomRepertoireDataset", "params": { - "repertoire_count": 50, - "sequence_count_probabilities": {50: 1}, + "repertoire_count": 30, + "sequence_count_probabilities": {5: 1}, "sequence_length_probabilities": {2: 1}, "result_path": str(Path(path / "d1")), "labels": { @@ -47,8 +47,8 @@ def _prepare_specs(self, path) -> Path: "d2": { "format": "RandomRepertoireDataset", "params": { - "repertoire_count": 50, - "sequence_count_probabilities": {50: 1}, + "repertoire_count": 30, + "sequence_count_probabilities": {5: 1}, "sequence_length_probabilities": {2: 1}, "result_path": str(Path(path / "d2")), "labels": {