From c6141b6b7a9cc37f8e0e4ef6979349bbe2ef3f17 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 6 May 2025 14:25:02 -0400 Subject: [PATCH 1/2] Add option to concat coords. Needed to modify Transformer's d_model management --- src/dwi_ml/cli/l2t_train_model.py | 12 +- src/dwi_ml/cli/tests/test_all_steps_l2t.py | 3 + src/dwi_ml/cli/tt_train_model.py | 12 +- .../data/dataset/multi_subject_containers.py | 4 +- .../dataset/subjectdata_list_containers.py | 2 +- src/dwi_ml/models/main_models.py | 275 ++++++++++++------ .../models/projects/learn2track_model.py | 59 ++-- .../models/projects/learn2track_utils.py | 3 +- .../models/projects/transformer_models.py | 195 ++++++++----- .../models/projects/transformers_utils.py | 3 +- src/dwi_ml/training/batch_loaders.py | 2 +- src/dwi_ml/training/trainers.py | 5 +- .../test_model_prepare_batchOneInput.py | 11 +- .../unit_tests/test_models_transformers.py | 42 +-- .../unit_tests/test_train_batch_loader.py | 10 +- .../utils/data_and_models_for_tests.py | 14 + 16 files changed, 438 insertions(+), 214 deletions(-) diff --git a/src/dwi_ml/cli/l2t_train_model.py b/src/dwi_ml/cli/l2t_train_model.py index b368b700..6537daa5 100755 --- a/src/dwi_ml/cli/l2t_train_model.py +++ b/src/dwi_ml/cli/l2t_train_model.py @@ -71,7 +71,8 @@ def init_from_args(args, sub_loggers_level): dg_args = check_args_direction_getter(args) # (Nb features) input_group_idx = dataset.volume_groups.index(args.input_group_name) - args.nb_features = dataset.nb_features[input_group_idx] + + nb_features = dataset.nb_features[input_group_idx] # Final model with Timer("\n\nPreparing model", newline=True, color='yellow'): # INPUTS: verifying args @@ -84,10 +85,13 @@ def init_from_args(args, sub_loggers_level): nb_previous_dirs=args.nb_previous_dirs, normalize_prev_dirs=args.normalize_prev_dirs, # INPUTS + nb_features_per_point=nb_features, input_embedding_key=args.input_embedding_key, - input_embedded_size=args.input_embedded_size, - nb_features=args.nb_features, kernel_size=args.kernel_size, - nb_cnn_filters=args.nb_cnn_filters, + input_embedding_nn_out_size=args.input_embedded_size, + input_embedding_cnn_kernel_size=args.kernel_size, + input_embedding_cnn_nb_filters=args.nb_cnn_filters, + add_raw_coords_to_input=args.add_raw_coords_to_input, + add_relative_coords_to_input=args.add_relative_coords_to_input, # RNN rnn_key=args.rnn_key, rnn_layer_sizes=args.rnn_layer_sizes, dropout=args.dropout, diff --git a/src/dwi_ml/cli/tests/test_all_steps_l2t.py b/src/dwi_ml/cli/tests/test_all_steps_l2t.py index 71544030..9ec0ea21 100644 --- a/src/dwi_ml/cli/tests/test_all_steps_l2t.py +++ b/src/dwi_ml/cli/tests/test_all_steps_l2t.py @@ -48,6 +48,9 @@ def test_training(script_runner, experiments_path): ret = script_runner.run('l2t_train_model', experiments_path, experiment_name, hdf5_file, input_group_name, streamline_group_name, + '--add_relative_coords_to_input', + '--input_embedding_key', 'nn_embedding', + '--input_embedded_size', '8', '--max_epochs', '1', '--batch_size_training', '5', '--batch_size_validation', '5', '--dg_key', 'gaussian', diff --git a/src/dwi_ml/cli/tt_train_model.py b/src/dwi_ml/cli/tt_train_model.py index e4ed5ea2..e4b644b2 100755 --- a/src/dwi_ml/cli/tt_train_model.py +++ b/src/dwi_ml/cli/tt_train_model.py @@ -107,7 +107,7 @@ def init_from_args(args, sub_loggers_level): # (nb features) input_group_idx = dataset.volume_groups.index(args.input_group_name) - args.nb_features = dataset.nb_features[input_group_idx] + nb_features = dataset.nb_features[input_group_idx] # Final model with 
Timer("\n\nPreparing model", newline=True, color='yellow'): @@ -115,11 +115,15 @@ def init_from_args(args, sub_loggers_level): experiment_name=args.experiment_name, step_size=args.step_size, compress_lines=args.compress_th, # Concerning inputs: - max_len=args.max_len, nb_features=args.nb_features, + max_len=args.max_len, + nb_features_per_point=nb_features, + add_raw_coords_to_input=args.add_raw_coords_to_input, + add_relative_coords_to_input=args.add_relative_coords_to_input, positional_encoding_key=args.position_encoding, input_embedding_key=args.input_embedding_key, - input_embedded_size=args.input_embedded_size, - nb_cnn_filters=args.nb_cnn_filters, kernel_size=args.kernel_size, + input_embedding_nn_out_size=args.input_embedded_size, + input_embedding_cnn_nb_filters=args.nb_cnn_filters, + input_embedding_cnn_kernel_size=args.kernel_size, # Torch's transformer parameters ffnn_hidden_size=args.ffnn_hidden_size, nheads=args.nheads, dropout_rate=args.dropout_rate, diff --git a/src/dwi_ml/data/dataset/multi_subject_containers.py b/src/dwi_ml/data/dataset/multi_subject_containers.py index 5aa999c3..c3978de5 100644 --- a/src/dwi_ml/data/dataset/multi_subject_containers.py +++ b/src/dwi_ml/data/dataset/multi_subject_containers.py @@ -15,7 +15,7 @@ from dwi_ml.data.dataset.checks_for_groups import prepare_groups_info from dwi_ml.data.dataset.mri_data_containers import MRIDataAbstract from dwi_ml.data.dataset.subjectdata_list_containers import ( - LazySubjectsDataList, SubjectsDataList) + SubjectsDataListAbstract, LazySubjectsDataList, SubjectsDataList) from dwi_ml.data.dataset.single_subject_containers import (LazySubjectData, SubjectData) @@ -50,7 +50,7 @@ def __init__(self, set_name: str, hdf5_file: str, lazy: bool, # The subjects data list will be either a SubjectsDataList or a # LazySubjectsDataList depending on MultisubjectDataset.is_lazy. - self.subjs_data_list = None + self.subjs_data_list = None # type:SubjectsDataListAbstract self.subjects = [] # type:List[str] self.nb_subjects = 0 diff --git a/src/dwi_ml/data/dataset/subjectdata_list_containers.py b/src/dwi_ml/data/dataset/subjectdata_list_containers.py index 7d203ccd..d3632bbe 100644 --- a/src/dwi_ml/data/dataset/subjectdata_list_containers.py +++ b/src/dwi_ml/data/dataset/subjectdata_list_containers.py @@ -21,7 +21,7 @@ def __init__(self, hdf5_path: str, log): # Do not access it directly. Use get_subj. 
         # Will be a list of SubjectData or LazySubjectData
-        self._subjects_data_list = []
+        self._subjects_data_list = []  # type:List[SubjectDataAbstract]
 
     def add_subject(self, subject_data: SubjectDataAbstract):
         """
diff --git a/src/dwi_ml/models/main_models.py b/src/dwi_ml/models/main_models.py
index 2d25cef6..6d0d049c 100644
--- a/src/dwi_ml/models/main_models.py
+++ b/src/dwi_ml/models/main_models.py
@@ -208,10 +208,11 @@ def _load_state(cls, model_dir):
 
         return model_state
 
-    def forward(self, inputs, streamlines):
+    def forward(self, streamlines):
         raise NotImplementedError
 
-    def compute_loss(self, model_outputs, target_streamlines):
+    def compute_loss(self, model_outputs, target_streamlines,
+                     average_results: bool = False):
         raise NotImplementedError
 
     def merge_batches_outputs(self, all_outputs, new_batch):
@@ -453,6 +454,70 @@ def forward(self, inputs, target_streamlines: List[torch.tensor]):
 
 
 class MainModelOneInput(MainModelAbstract):
+    def __init__(self,
+                 nb_features_per_point: int,
+                 add_raw_coords_to_input: bool,
+                 add_relative_coords_to_input: bool,
+                 **kwargs):
+        """
+        Params
+        ------
+        add_raw_coords_to_input: bool
+            If true, add raw coordinates to input vectors.
+        add_relative_coords_to_input: bool
+            If true, add relative coordinates to input vectors, i.e.
+            coords / volume_dims."""
+        super().__init__(**kwargs)
+
+        self.nb_features_per_point = int(nb_features_per_point)
+        self.add_raw_coords_to_input = add_raw_coords_to_input
+        self.add_relative_coords_to_input = add_relative_coords_to_input
+
+        if self.add_raw_coords_to_input and self.add_relative_coords_to_input:
+            raise ValueError(
+                "add_raw_coords_to_input and add_relative_coords_to_input "
+                "cannot be used at the same time.")
+
+    @staticmethod
+    def add_args_model_one_input(p):
+        g = p.add_mutually_exclusive_group()
+        g.add_argument("--add_raw_coords_to_input", action="store_true",
+                       help="If true, add raw coordinates to input vectors.")
+        g.add_argument("--add_relative_coords_to_input", action="store_true",
+                       help="If true, add relative coordinates to input "
+                            "vectors, i.e. \ncoords / volume_dims.")
+
+    @classmethod
+    def _load_params(cls, model_dir):
+        params = super()._load_params(model_dir)
+        if 'nb_features' in params:
+            logging.warning("Deprecated model usage, with 'nb_features' in "
+                            "its params. Modifying automatically. Support "
+                            "will eventually be removed.")
+            if params['nb_neighbors'] > 0:
+                params['nb_features_per_point'] = \
+                    float(params['nb_features']) / params['nb_neighbors']
+                assert params['nb_features_per_point'].is_integer(), \
+                    ("Unexpected error. Deprecated nb_features should have "
+                     "been real_nb_features * nb_neighbors. "
+                     "Ask Emmanuelle to fix.")
+            else:
+                params['nb_features_per_point'] = params['nb_features']
+        return params
+
+    @property
+    def params_for_checkpoint(self):
+        # Every parameter necessary to build the different layers again
+        # during checkpoint state saving.
+        params = super().params_for_checkpoint
+        params.update({
+            'nb_features_per_point': self.nb_features_per_point,
+            'add_raw_coords_to_input': self.add_raw_coords_to_input,
+            'add_relative_coords_to_input': self.add_relative_coords_to_input,
+        })
+
+        return params
+
     def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset,
                                 subj_idx, input_group_idx, prepare_mask=False,
                                 clear_cache=True):
@@ -495,16 +560,25 @@ def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset,
         # to volume bounds.
         if isinstance(self, ModelWithNeighborhood):
             # Adding neighborhood.
-            subj_x_data, coords_torch = interpolate_volume_in_neighborhood(
+            subj_x, coords_torch = interpolate_volume_in_neighborhood(
                 data_tensor, flat_subj_x_coords, self.neighborhood_vectors,
                 clear_cache=clear_cache)
         else:
-            subj_x_data, coords_torch = interpolate_volume_in_neighborhood(
+            subj_x, coords_torch = interpolate_volume_in_neighborhood(
                 data_tensor, flat_subj_x_coords, None,
                 clear_cache=clear_cache)
 
         # Split the flattened signal back to streamlines
         lengths = [len(s) for s in streamlines]
-        subj_x_data = list(subj_x_data.split(lengths))
+        subj_x = list(subj_x.split(lengths))
+
+        if self.add_raw_coords_to_input:
+            subj_x = [torch.cat((_x, _s), dim=1) for _x, _s
+                      in zip(subj_x, streamlines)]
+        elif self.add_relative_coords_to_input:
+            volume_dim = torch.as_tensor(data_tensor.shape[0:3],
+                                         device=self.device)
+            subj_x = [torch.cat((_x, torch.divide(_s, volume_dim)), dim=1)
+                      for _x, _s in zip(subj_x, streamlines)]
 
         if prepare_mask:
             logging.warning("Model OneInput: DEBUGGING MODE. Returning "
@@ -521,134 +595,171 @@ def prepare_batch_one_input(self, streamlines, subset: MultisubjectSubset,
             for s in range(len(coords_torch)):
                 input_mask.data[tuple(coords_to_idx_clipped[s, :])] = 1
 
-            return subj_x_data, input_mask
+            return subj_x, input_mask
 
-        return subj_x_data
+        return subj_x
+
+    def forward(self, inputs, streamlines):
+        raise NotImplementedError
 
 
 class ModelOneInputWithEmbedding(MainModelOneInput):
-    def __init__(self, nb_features: int,
+    def __init__(self,
                  input_embedding_key: str,
-                 input_embedded_size: int = None,
-                 nb_cnn_filters: List[int] = None,
-                 kernel_size: List[int] = None, **kw):
+                 input_embedding_nn_out_size: int = None,
+                 input_embedding_cnn_nb_filters: List[int] = None,
+                 input_embedding_cnn_kernel_size: List[int] = None, **kw):
         """
+        Note that input_size should be given later, when instantiating the
+        layers.
+
         Parameters
         ----------
-        nb_features: int
-            This value should be known from the actual data. Number of features
-            in the data (last dimension).
         input_embedding_key: str
             Key to an embedding class (one of
             dwi_ml.models.embeddings_on_tensors.keys_to_embeddings).
             Default: 'no_embedding'.
-        input_embedded_size: int
-            Output embedding size for the input. Not required for no_embedding.
-        nb_cnn_filters: int
+        input_embedding_nn_out_size: int
+            Output embedding size of the input, for the NN embedding.
+        input_embedding_cnn_nb_filters: List[int]
             Number of filters in the CNN. Output size at each voxel.
-        kernel_size: int
+        input_embedding_cnn_kernel_size: List[int]
             Used only with CNN embedding. Size of the 3D filter matrix. Will
             be of shape [k, k, k].
         """
-
         super().__init__(**kw)
 
         self.input_embedding_key = input_embedding_key
-        self.input_embedded_size = input_embedded_size
-        self.nb_cnn_filters = nb_cnn_filters
-        self.kernel_size = kernel_size
-        self.nb_features = nb_features
+        self.input_embedding_nn_out_size = input_embedding_nn_out_size
+        self.input_embedding_cnn_nb_filters = input_embedding_cnn_nb_filters
+        self.input_embedding_cnn_kernel_size = input_embedding_cnn_kernel_size
 
         # Preparing layer variables now but not instantiated. User must provide
         # input size.
self.input_embedding_layer = None - # ----------- Instantiation + checks + # ----------- Checking options if self.input_embedding_key not in keys_to_embeddings.keys(): raise ValueError("Embedding choice for x data not understood: {}" .format(self.input_embedding_key)) + elif self.input_embedding_key in ['nn_embedding', 'no_embedding']: + # No CNN option + if self.input_embedding_cnn_nb_filters is not None: + raise ValueError("Nb CNN filters should not be used when " + "embedding is not CNN.") + if self.input_embedding_cnn_kernel_size is not None: + raise ValueError("CNN kernel_size should not be used when " + "embedding is not CNN.") + + # NN options + if (self.input_embedding_key == 'nn_embedding' and + input_embedding_nn_out_size is None): + raise ValueError("Input embedded size should be defined.") + + elif self.input_embedding_key == 'cnn_embedding': + # For CNN: make sure that neighborhood is included. + if not isinstance(self, ModelWithNeighborhood): + raise ValueError("CNN embedding cannot be used without a " + "neighborhood. Add ModelWithNeighborhood as " + "parent to your model class.") + if self.neighborhood_type != 'grid': + raise ValueError("CNN embedding should be used with a " + "grid-like neighborhood.") + + # No NN options + if input_embedding_nn_out_size is not None: + raise ValueError( + "CNN embedded size will be computed automatically based " + "on kernel size and number of filters. Do not use " + "nn_embedded_size") + + # Computing expected input size + # Right now input is always flattened (interpolation is implemented + # that way). For CNN, we will rearrange it ourselves. + # For options add_raw/relative_coords_to_input, we will add 3 + # additional inputs. + if input_embedding_key == 'cnn_embedding': + if (self.add_raw_coords_to_input or + self.add_relative_coords_to_input): + raise NotImplementedError( + "Not ready to concatenate coordinates to input with " + "CNN embedding.") + self.input_size = self.nb_features_per_point + else: + self.input_size = self.nb_features_per_point * self.nb_neighbors + if (self.add_raw_coords_to_input or + self.add_relative_coords_to_input): + self.input_size += 3 - # This variable will contain final computed size. + # This variable will eventually contain the final computed size. self.computed_input_embedded_size = None + + def instantiate_input_embedding_layer(self): + """ + Params + ------ + input_size: int + This value should be known from the actual data. Number of features + in the data (last dimension) that will be passed to the input + embedding layer. + - Using CNN: data should be a neighborhood of points with last + dimension "input_size" + - Using NN: data should be a single vector of dimension + "input_size". Ex, with a neighborhood, this value is probably + input_size = real_nb_features * nb_points_in_neighborhood + """ if self.input_embedding_key == 'cnn_embedding': - self.instantiate_cnn_embedding() - else: - self.instantiate_nn_embedding() + self._instantiate_cnn_embedding() + elif self.input_embedding_key == 'nn_embedding': + self._instantiate_nn_embedding() + else: # self.input_embedding_key == 'no_embedding' + self.computed_input_embedded_size = self.input_size + self.input_embedding_layer = torch.nn.Identity() - def instantiate_cnn_embedding(self): - input_embedding_cls = keys_to_embeddings[self.input_embedding_key] + def _instantiate_cnn_embedding(self): + cnn_embedding_cls = keys_to_embeddings[self.input_embedding_key] - # For CNN: make sure that neighborhood is included. 
-        if not isinstance(self, ModelWithNeighborhood):
-            raise ValueError("CNN embedding cannot be used without a "
-                             "neighborhood. Add ModelWithNeighborhood as "
-                             "parent to your model class.")
-        if self.neighborhood_type != 'grid':
-            raise ValueError(
-                "CNN embedding should be used with a grid-like neighborhood.")
+        # Size: if neighborhood_radius is n, nb of voxels is 2*n + 1.
+        neighb_size = 2 * self.neighborhood_radius + 1
+        logging.debug("Computed that CNN (input embedding) will receive data "
+                      "coming from a neighborhood of size {}"
+                      .format(neighb_size))
 
-        if self.input_embedded_size is not None:
+        if self.input_embedding_nn_out_size is not None:
             raise ValueError(
                 "You should not use input_embedded_size with CNN embedding."
                 "Rather, use the nb_filters and kernel_size.")
-        if self.kernel_size is None:
+        if self.input_embedding_cnn_kernel_size is None:
             raise ValueError("Kernel size must be defined to use CNN "
                              "embedding")
-        if len(self.kernel_size) != len(self.nb_cnn_filters):
+        if (len(self.input_embedding_cnn_kernel_size) !=
+                len(self.input_embedding_cnn_nb_filters)):
             raise ValueError("kernel_size and nb_cnn_filters must contain "
                              "the same number of values.")
-        if self.kernel_size[0] > 2 * self.neighborhood_radius + 1:
+        if self.input_embedding_cnn_kernel_size[0] > neighb_size:
             # Kernel size cannot be bigger than the number of points.
-            # Per size, if neighborhood_radius in n, nb of voxels is 2*n + 1.
             # Kernel size of other layers would be longer to check.
             # We will wait for error in forward, if any.
             raise ValueError(
-                "CNN kernel size is bigger than the neighborhood size."
-                "Not expected, as we are not padding the data.")
+                "CNN kernel size (layer 1) is bigger than the neighborhood "
+                "size. Not expected, as we are not padding the data.")
 
-        neighb_size = 2 * self.neighborhood_radius + 1
-        self.input_embedding_layer = input_embedding_cls(
-            nb_features_in=self.nb_features,
-            nb_filters=self.nb_cnn_filters,
-            kernel_sizes=self.kernel_size,
+        self.input_embedding_layer = cnn_embedding_cls(
+            nb_features_in=self.input_size,
+            nb_filters=self.input_embedding_cnn_nb_filters,
+            kernel_sizes=self.input_embedding_cnn_kernel_size,
             image_shape=[neighb_size] * 3)
 
         self.computed_input_embedded_size = \
             self.input_embedding_layer.out_flattened_size
 
-    def instantiate_nn_embedding(self):
-        # NN embedding or identity embedding:
-
-        if self.nb_cnn_filters is not None:
-            raise ValueError("Nb CNN filters should not be used when "
-                             "embedding is not CNN.")
-        if self.kernel_size is not None:
-            raise ValueError("CNN kernel_size should not be used when "
-                             "embedding is not CNN.")
-
-        input_size = self.nb_features
-        if isinstance(self, ModelWithNeighborhood):
-            input_size *= self.nb_neighbors
-
-        if self.input_embedding_key == 'no_embedding':
-            if self.computed_input_embedded_size is None:
-                self.computed_input_embedded_size = input_size
-            else:
-                assert self.computed_input_embedded_size == input_size, \
-                    "Got input size {} ({} features x {} neighbors) but " \
-                    "expecting {}".format(input_size, self.nb_features,
-                                          self.nb_neighbors,
-                                          self.computed_input_embedded_size)
-        else:
-            if self.input_embedded_size is None:
-                raise ValueError("Input embedded size should be defined.")
-            self.computed_input_embedded_size = self.input_embedded_size
-
+    def _instantiate_nn_embedding(self):
+        self.computed_input_embedded_size = self.input_embedding_nn_out_size
         input_embedding_cls = keys_to_embeddings[self.input_embedding_key]
         self.input_embedding_layer = input_embedding_cls(
-            nb_features_in=input_size,
+            nb_features_in=self.input_size,
             nb_features_out=self.computed_input_embedded_size)
 
     @property
@@ -658,9 +769,9 @@ def params_for_checkpoint(self):
         params = super().params_for_checkpoint
         params.update({
             'input_embedding_key': self.input_embedding_key,
-            'input_embedded_size': self.input_embedded_size,
-            'kernel_size': self.kernel_size,
-            'nb_cnn_filters': self.nb_cnn_filters,
+            'input_embedding_nn_out_size': self.input_embedding_nn_out_size,
+            'input_embedding_cnn_kernel_size': self.input_embedding_cnn_kernel_size,
+            'input_embedding_cnn_nb_filters': self.input_embedding_cnn_nb_filters,
         })
 
         return params
diff --git a/src/dwi_ml/models/projects/learn2track_model.py b/src/dwi_ml/models/projects/learn2track_model.py
index 9ba8074c..39f0480c 100644
--- a/src/dwi_ml/models/projects/learn2track_model.py
+++ b/src/dwi_ml/models/projects/learn2track_model.py
@@ -81,7 +81,7 @@ class Learn2TrackModel(ModelWithPreviousDirections, ModelWithDirectionGetter,
     def __init__(self, experiment_name,
                  step_size: Union[float, None],
                  compress_lines: Union[float, None],
-                 nb_features: int,
+                 nb_features_per_point: int,
                  # PREVIOUS DIRS
                  nb_previous_dirs: Union[int, None],
                  prev_dirs_embedded_size: Union[int, None],
@@ -89,9 +89,11 @@ def __init__(self, experiment_name,
                  normalize_prev_dirs: bool,
                  # INPUTS
                  input_embedding_key: str,
-                 input_embedded_size: Union[int, None],
-                 nb_cnn_filters: Optional[List[int]],
-                 kernel_size: Optional[List[int]],
+                 input_embedding_nn_out_size: Union[int, None],
+                 input_embedding_cnn_nb_filters: Optional[List[int]],
+                 input_embedding_cnn_kernel_size: Optional[List[int]],
+                 add_raw_coords_to_input: bool,
+                 add_relative_coords_to_input: bool,
                  # RNN
                  rnn_key: str, rnn_layer_sizes: List[int],
                  use_skip_connection: bool, use_layer_normalization: bool,
@@ -110,6 +112,9 @@ def __init__(self, experiment_name,
         nb_previous_dirs: int
             Number of previous direction (i.e. [x,y,z] information) to be
             received.
+        nb_features_per_point: int
+            Number of features per point. Real input size received from batch
+            loader will be nb_features_per_point * neighborhood_size.
         rnn_key: str
             Either 'LSTM' or 'GRU'.
         rnn_layer_sizes: List[int]
@@ -140,11 +145,15 @@ def __init__(self, experiment_name,
             neighborhood_type=neighborhood_type,
             neighborhood_radius=neighborhood_radius,
             neighborhood_resolution=neighborhood_resolution,
+            # For super ModelWithOneInput:
+            nb_features_per_point=nb_features_per_point,
+            add_raw_coords_to_input=add_raw_coords_to_input,
+            add_relative_coords_to_input=add_relative_coords_to_input,
             # For super ModelWithInputEmbedding:
-            nb_features=nb_features,
             input_embedding_key=input_embedding_key,
-            input_embedded_size=input_embedded_size,
-            nb_cnn_filters=nb_cnn_filters, kernel_size=kernel_size,
+            input_embedding_nn_out_size=input_embedding_nn_out_size,
+            input_embedding_cnn_nb_filters=input_embedding_cnn_nb_filters,
+            input_embedding_cnn_kernel_size=input_embedding_cnn_kernel_size,
             # For super MainModelWithPD:
             nb_previous_dirs=nb_previous_dirs,
             prev_dirs_embedded_size=prev_dirs_embedded_size,
@@ -155,12 +164,6 @@ def __init__(self, experiment_name,
 
         self.dropout = dropout
         self.start_from_copy_prev = start_from_copy_prev
 
-        self.nb_cnn_filters = nb_cnn_filters
-        self.kernel_size = kernel_size
-
-        # Right now input is always flattened (interpolation is implemented
-        # that way). For CNN, we will rearrange it ourselves.
-        self.input_size = nb_features * self.nb_neighbors
 
         # ----------- Checks
         if dropout < 0 or dropout > 1:
@@ -176,7 +179,8 @@ def __init__(self, experiment_name,
 
         # ---------- Instantiations
         # 1.
Previous dirs embedding: prepared by super. - # 2. Input embedding + # 2. Input embedding: prepared by super. Adding dropout + self.instantiate_input_embedding_layer() self.embedding_dropout = torch.nn.Dropout(self.dropout) # 3. Stacked RNN @@ -189,14 +193,14 @@ def __init__(self, experiment_name, use_skip_connection=use_skip_connection, use_layer_normalization=use_layer_normalization, dropout=dropout) - # 4. Direction getter: + # 4. Direction getter: (calls super's method) self.instantiate_direction_getter(self.rnn_model.output_size) def set_context(self, context): # Training, validation: Used by trainer. Nothing special. # Tracking: Used by tracker. Returns only the last point. - # Preparing_backward: Used by tracker. Nothing special, but does - # not return only the last point. + # Preparing_backward: Used by tracker. Nothing special, but does + # not return only the last point. # Visu: Nothing special. Used by tester. assert context in ['training', 'validation', 'tracking', 'visu', 'preparing_backward'] @@ -208,7 +212,9 @@ def params_for_checkpoint(self): # during checkpoint state saving. params = super().params_for_checkpoint params.update({ - 'nb_features': int(self.nb_features), + 'nb_features_per_point': self.nb_features_per_point, + 'add_raw_coords_to_input': self.add_raw_coords_to_input, + 'add_relative_coords_to_input': self.add_relative_coords_to_input, 'rnn_key': self.rnn_model.rnn_torch_key, 'rnn_layer_sizes': self.rnn_model.layer_sizes, 'use_skip_connection': self.rnn_model.use_skip_connection, @@ -226,8 +232,9 @@ def computed_params_for_display(self): return p def forward(self, x: List[torch.tensor], - input_streamlines: List[torch.tensor] = None, - hidden_recurrent_states: List = None, return_hidden=False, + input_streamlines: List[torch.tensor], + hidden_recurrent_states: List = None, + return_hidden=False, point_idx: int = None): """Run the model on a batch of sequences. @@ -268,11 +275,14 @@ def forward(self, x: List[torch.tensor], # Right now input is always flattened (interpolation is implemented # that way). For CNN, we will rearrange it ourselves. # Verifying the first input - assert x[0].shape[-1] == self.input_size, \ + expected_input_size = self.nb_features_per_point * self.nb_neighbors + if self.add_raw_coords_to_input or self.add_relative_coords_to_input: + expected_input_size += 3 + assert x[0].shape[-1] == expected_input_size, \ "Not the expected input size! Should be {} (i.e. {} features for " \ "each of the {} neighbors), but got {} (input shape {})." \ - .format(self.input_size, self.nb_features, self.nb_neighbors, - x[0].shape[-1], x[0].shape) + .format(self.input_size, self.nb_features_per_point, + self.nb_neighbors, x[0].shape[-1], x[0].shape) # Making sure we can use default 'enforce_sorted=True' with packed # sequences. @@ -284,7 +294,8 @@ def forward(self, x: List[torch.tensor], unsorted_indices = invert_permutation(sorted_indices) x = [x[i] for i in sorted_indices] if input_streamlines is not None: - input_streamlines = [input_streamlines[i] for i in sorted_indices] + input_streamlines = [input_streamlines[i] + for i in sorted_indices] # ==== 0. Previous dirs. 
n_prev_dirs = None diff --git a/src/dwi_ml/models/projects/learn2track_utils.py b/src/dwi_ml/models/projects/learn2track_utils.py index 0a805c12..b17c09bf 100644 --- a/src/dwi_ml/models/projects/learn2track_utils.py +++ b/src/dwi_ml/models/projects/learn2track_utils.py @@ -23,7 +23,8 @@ def add_model_args(p: argparse.ArgumentParser): Learn2TrackModel.add_args_model_with_pd(prev_dirs_g) inputs_g = p.add_argument_group( - "Learn2track model: Main inputs embedding layer") + "Learn2track model: Main inputs and inputs embedding layer") + Learn2TrackModel.add_args_model_one_input(inputs_g) Learn2TrackModel.add_neighborhood_args_to_parser(inputs_g) Learn2TrackModel.add_args_input_embedding(inputs_g) diff --git a/src/dwi_ml/models/projects/transformer_models.py b/src/dwi_ml/models/projects/transformer_models.py index 10c259d8..a1520c88 100644 --- a/src/dwi_ml/models/projects/transformer_models.py +++ b/src/dwi_ml/models/projects/transformer_models.py @@ -101,7 +101,8 @@ def merge_one_weight_type(weights, new_weights, device): return weights -class AbstractTransformerModel(ModelWithNeighborhood, ModelWithDirectionGetter, +class AbstractTransformerModel(ModelWithNeighborhood, + ModelWithDirectionGetter, ModelOneInputWithEmbedding): """ Prepares the parts common to our two transformer versions: embeddings, @@ -123,9 +124,13 @@ def __init__(self, step_size: Union[float, None], compress_lines: Union[float, None], # INPUTS IN ENCODER - nb_features: int, input_embedding_key: str, - input_embedded_size: int, - nb_cnn_filters: Optional[int], kernel_size: Optional[int], + nb_features_per_point: int, + add_raw_coords_to_input: bool, + add_relative_coords_to_input: bool, + input_embedding_key: str, + input_embedding_nn_out_size: int, + input_embedding_cnn_nb_filters: Optional[int], + input_embedding_cnn_kernel_size: Optional[int], # GENERAL TRANSFORMER PARAMS max_len: int, positional_encoding_key: str, ffnn_hidden_size: Union[int, None], @@ -180,9 +185,6 @@ def __init__(self, n_layers_e: int All ours models have at least an encoder. """ - - # Important. Super must be called first to verify input embedded size - # through the ModelOneInputWithEmbedding. super().__init__( # MainAbstract experiment_name=experiment_name, step_size=step_size, @@ -192,11 +194,15 @@ def __init__(self, neighborhood_type=neighborhood_type, neighborhood_radius=neighborhood_radius, neighborhood_resolution=neighborhood_resolution, + # For super WithInput + nb_features_per_point=nb_features_per_point, + add_raw_coords_to_input=add_raw_coords_to_input, + add_relative_coords_to_input=add_relative_coords_to_input, # For super ModelWithInputEmbedding: - nb_features=nb_features, input_embedding_key=input_embedding_key, - input_embedded_size=input_embedded_size, - nb_cnn_filters=nb_cnn_filters, kernel_size=kernel_size, + input_embedding_nn_out_size=input_embedding_nn_out_size, + input_embedding_cnn_nb_filters=input_embedding_cnn_nb_filters, + input_embedding_cnn_kernel_size=input_embedding_cnn_kernel_size, # Tracking dg_key=dg_key, dg_args=dg_args) @@ -207,14 +213,15 @@ def __init__(self, self.dropout_rate = dropout_rate self.activation = activation self.norm_first = norm_first + + # Note: self.d_model is now a property function, getting the right info + # based on the Transformer Class. It is based on the input embedding + # layer, so let's instantiate it now. 
+        self.instantiate_input_embedding_layer()
         self.ffnn_hidden_size = ffnn_hidden_size if ffnn_hidden_size is not \
             None else self.d_model // 2
 
         # ----------- Checks
-        if self.d_model // self.nheads != float(self.d_model) / self.nheads:
-            raise ValueError("d_model ({}) must be divisible by nheads ({})"
-                             .format(self.d_model, self.nheads))
-
         if dropout_rate < 0 or dropout_rate > 1:
             raise ValueError('The dropout rate must be between 0 and 1.')
 
@@ -223,20 +230,47 @@ def __init__(self,
             raise ValueError("Positional encoding choice not understood: {}"
                              .format(self.positional_encoding_key))
 
+        # To be prepared in _finish_checks_and_instantiations:
+        self.dropout_layer = None
+        self.position_encoding_layer = None
+
+    def _finish_checks_and_instantiations(self):
+        """
+        Should be overridden by child classes and called at the end of the
+        model's init(), once d_model is final.
+        """
+        # This should have been computed in ModelOneInputWithEmbedding
+        assert self.computed_input_embedded_size is not None
+
+        # This should now be available through the property method
+        assert self.d_model is not None
+
         # ----------- Instantiations
         # This dropout is only used in the embedding; torch's transformer
         # prepares its own dropout elsewhere, and direction getter too.
-        self.dropout = Dropout(self.dropout_rate)
+        self.dropout_layer = Dropout(self.dropout_rate)
 
         # 1. x embedding layer
+        if self.d_model // self.nheads != float(self.d_model) / self.nheads:
+            raise ValueError("d_model ({}) must be divisible by nheads ({})"
+                             .format(self.d_model, self.nheads))
         assert self.computed_input_embedded_size > 3, \
-            "Current computation of the positional encoding required data " \
-            "of size > 3, but got {}".format(self.computed_input_embedded_size)
+            ("Current computation of the positional encoding requires data "
+             "of size > 3, but got {}\n"
+             "(nb features in: {}. Embedding choice: {}.\n"
+             " -> If CNN embedding: embedded size based on filter size and "
+             "kernel size.\n"
+             " -> If NN embedding: input is {} * nb_neighbors ({}), +3 if "
+             "coords are added to inputs. Output is based on options: {}."
+             .format(self.computed_input_embedded_size,
                     self.nb_features_per_point, self.input_embedding_key,
                     self.nb_features_per_point, self.nb_neighbors,
                     self.input_embedding_nn_out_size))
 
         # 2. positional encoding layer
         cls_p = keys_to_positional_encodings[self.positional_encoding_key]
-        self.position_encoding_layer = cls_p(self.d_model, dropout_rate,
-                                             max_len)
+        self.position_encoding_layer = cls_p(self.d_model, self.dropout_rate,
+                                             self.max_len)
 
         # 3. target embedding layer: See child class with Target
 
@@ -262,9 +296,6 @@ def params_for_checkpoint(self):
         """
         p = super().params_for_checkpoint
         p.update({
-            'nb_features': int(self.nb_features),
-            'input_embedding_key': self.input_embedding_key,
-            'input_embedded_size': self.input_embedded_size,
             'max_len': self.max_len,
             'n_layers_e': self.n_layers_e,
             'positional_encoding_key': self.positional_encoding_key,
@@ -283,8 +314,11 @@ def _load_params(cls, model_dir):
 
         # d_model now a property method.
         if 'd_model' in params:
+            logging.warning("Deprecated model. d_model is now called "
+                            "differently. Trying to load the old "
+                            "way.
Support for this will end soon.") if isinstance(cls, TransformerSrcOnlyModel): - params['input_embedded_size'] = params['d_model'] + params['input_embedding_nn_out_size'] = params['d_model'] del params['d_model'] @@ -566,11 +600,15 @@ def merge_batches_weights(self, weights, new_weights, device): class TransformerSrcOnlyModel(AbstractTransformerModel): def __init__(self, **kw): """ - No additional params. d_model = input_embedded_size. + No additional params. d_model = computed_input_embedded_size. """ super().__init__(**kw) - # ----------- Additional instantiations + self._finish_checks_and_instantiations() + + def _finish_checks_and_instantiations(self): + super()._finish_checks_and_instantiations() + logger.debug("Instantiating Transformer...") main_layer_encoder = ModifiedTransformerEncoderLayer( self.d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size, @@ -615,7 +653,7 @@ def _run_embeddings(self, inputs, use_padding, batch_max_len): def _run_position_encoding(self, inputs): inputs = self.position_encoding_layer(inputs) - inputs = self.dropout(inputs) + inputs = self.dropout_layer(inputs) return inputs def _run_main_layer_forward(self, inputs, masks, return_weights): @@ -636,10 +674,12 @@ def merge_batches_weights(self, weights, new_weights, device): class AbstractTransformerModelWithTarget(AbstractTransformerModel): + # Original model: target embedding is the same size as input (d_model) + # Source and target model: we concatenate them so they can be of different + # embedding sizes. def __init__(self, # TARGETS IN DECODER sos_token_type: str, target_embedding_key: str, - target_embedded_size: int, start_from_copy_prev: bool, **kwargs): """ token_type: str @@ -649,10 +689,14 @@ def __init__(self, Target embedding, with the same choices as above. Default: 'no_embedding'. """ - # Some checks before super init, in case d_model depends on target - # embedded size. + super().__init__(**kwargs) + self.target_embedding_key = target_embedding_key - self.target_embedded_size = target_embedded_size + assert target_embedding_key in ['no_embedding', 'nn_embedding'], \ + "{} not supported for target embedding".format(target_embedding_key) + self.target_embedded_size = None # Will be set by child + + # Computing the number of features in target after formatting. if sos_token_type == 'as_label': self.token_sphere = None self.target_features = 4 @@ -662,28 +706,18 @@ def __init__(self, # nb classes = nb_vertices + SOS self.target_features = len(self.token_sphere.vertices) + 1 + if self.target_embedding_key not in keys_to_embeddings.keys(): + raise ValueError("Embedding choice for targets not understood: {}" + .format(self.target_embedding_key)) if self.target_embedding_key == 'no_embedding': - if self.target_embedded_size is None: - self.target_embedded_size = self.target_features - assert self.target_embedded_size == self.target_features, \ - "With no_embedding for the target, input size must be equal " \ - "to the output embedded size. Expecting {}"\ - .format(self.target_features) - else: - assert self.target_embedding_key == 'nn_embedding', \ - "Unrecognized embedding key for the targets." - - super().__init__(**kwargs) + self.target_embedded_size = self.target_features self.sos_token_type = sos_token_type self.start_from_copy_prev = start_from_copy_prev - # Checks. 
-        if self.target_embedding_key not in keys_to_embeddings.keys():
-            raise ValueError("Embedding choice for targets not understood: {}"
-                             .format(self.target_embedding_key))
+    def _finish_checks_and_instantiations(self):
+        super()._finish_checks_and_instantiations()
 
-        # 3. Target embedding.
         cls_t = keys_to_embeddings[self.target_embedding_key]
         self.embedding_layer_t = cls_t(self.target_features,
                                        self.target_embedded_size)
@@ -704,6 +738,13 @@ def params_for_checkpoint(self):
 
         return p
 
+    def _instantiate_target_embedding(self):
+        assert self.target_embedded_size is not None, \
+            "Code error. Target embedded size must be set by child class."
+        cls_t = keys_to_embeddings[self.target_embedding_key]
+        self.embedding_layer_t = cls_t(self.target_features,
+                                       self.target_embedded_size)
+
     def move_to(self, device):
         super().move_to(device)
         if self.token_sphere is not None:
@@ -823,30 +864,37 @@ class OriginalTransformerModel(AbstractTransformerModelWithTarget):
                                                         emb_choice_x
 
     """
-    def __init__(self, input_embedded_size, n_layers_d: int, **kw):
+    def __init__(self, n_layers_d: int, **kw):
         """
         d_model = input_embedded_size = target_embedded_size.
+        We must wait to know the embedded size. Will be computed based on
+        CNN or NN options.
 
         Args
         ----
         n_layers_d: int
             Number of encoding layers in the decoder. [6]
         """
-        super().__init__(input_embedded_size=input_embedded_size,
-                         target_embedded_size=input_embedded_size, **kw)
-
-        # Veryfing that final computed values are still ok
-        if self.computed_input_embedded_size != self.target_embedded_size:
-            raise ValueError("For the original model, the input size and "
-                             "target size after embedding must be equal "
-                             "(value d_model) but got {} and {}"
-                             .format(self.computed_input_embedded_size,
-                                     self.target_embedded_size))
-
-        # ----------- Additional params
+        super().__init__(**kw)
+
         self.n_layers_d = n_layers_d
 
-        # ----------- Additional instantiations
+        # Input embedded size should be computed by now.
+        self.target_embedded_size = self.computed_input_embedded_size
+
+        self._finish_checks_and_instantiations()
+
+    def _finish_checks_and_instantiations(self):
+        super()._finish_checks_and_instantiations()
+
+        if (self.target_embedding_key == 'no_embedding' and
+                self.target_features != self.computed_input_embedded_size):
+            raise ValueError("Cannot use no_embedding for the target; they "
+                             "must be embedded to the same dimension as the "
+                             "input (d_model). Size of Y: {}. "
+ "d_model: {}" + .format(self.target_features, + self.computed_input_embedded_size)) + logger.info("Instantiating torch transformer, may take a few " "seconds...") # Encoder: @@ -873,14 +921,13 @@ def __init__(self, input_embedded_size, n_layers_d: int, **kw): dim_feedforward=self.ffnn_hidden_size, dropout=self.dropout_rate, activation=self.activation, batch_first=True, norm_first=self.norm_first) - decoder = ModifiedTransformerDecoder(decoder_layer, n_layers_d, + decoder = ModifiedTransformerDecoder(decoder_layer, self.n_layers_d, norm=None) self.modified_torch_transformer = ModifiedTransformer( - self.d_model, self.nheads, self.n_layers_e, n_layers_d, + self.d_model, self.nheads, self.n_layers_e, self.n_layers_d, self.ffnn_hidden_size, self.dropout_rate, self.activation, - encoder, decoder, batch_first=True, - norm_first=self.norm_first) + encoder, decoder, batch_first=True, norm_first=self.norm_first) @property def d_model(self): @@ -904,10 +951,10 @@ def _run_embeddings(self, data, use_padding, batch_max_len): def _run_position_encoding(self, data): # inputs, targets = data inputs = self.position_encoding_layer(data[0]) - inputs = self.dropout(inputs) + inputs = self.dropout_layer(inputs) targets = self.position_encoding_layer(data[1]) - targets = self.dropout(targets) + targets = self.dropout_layer(targets) return inputs, targets @@ -964,16 +1011,22 @@ class TransformerSrcAndTgtModel(AbstractTransformerModelWithTarget): [ emb_choice_x ; emb_choice_y ] """ - def __init__(self, **kw): + def __init__(self, target_embedded_size, **kw): """ No additional params. d_model = input size + target size. """ super().__init__(**kw) + self.target_embedded_size = target_embedded_size + + self._finish_checks_and_instantiations() + + def _finish_checks_and_instantiations(self): + super()._finish_checks_and_instantiations() - # ----------- Additional instantiations logger.debug("Instantiating Transformer...") + d_model = self.computed_input_embedded_size + self.target_embedded_size main_layer_encoder = ModifiedTransformerEncoderLayer( - self.d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size, + d_model, self.nheads, dim_feedforward=self.ffnn_hidden_size, dropout=self.dropout_rate, activation=self.activation, batch_first=True, norm_first=self.norm_first) self.modified_torch_transformer = ModifiedTransformerEncoder( @@ -981,7 +1034,7 @@ def __init__(self, **kw): @property def d_model(self): - # d_model = input size = target size + # d_model = input size + target size # target embedded size must be verified before super init. return self.computed_input_embedded_size + self.target_embedded_size @@ -1001,7 +1054,7 @@ def _run_embeddings(self, data, use_padding, batch_max_len): def _run_position_encoding(self, data): data = self.position_encoding_layer(data) - data = self.dropout(data) + data = self.dropout_layer(data) return data def _run_main_layer_forward(self, concat_s_t, masks, return_weights): diff --git a/src/dwi_ml/models/projects/transformers_utils.py b/src/dwi_ml/models/projects/transformers_utils.py index 5191e678..c35faad9 100644 --- a/src/dwi_ml/models/projects/transformers_utils.py +++ b/src/dwi_ml/models/projects/transformers_utils.py @@ -23,11 +23,12 @@ def add_transformers_model_args(p): AbstractTransformerModel.add_args_main_model(p) gx = p.add_argument_group( - "Embedding of the input (X)", + "Input management and embedding of the input (X)", "Input embedding size defines the d_model. 
The d_model must be " "divisible by the number of heads.\n" "Note that for TTST, total d_model will rather be " "input_embedded_size + target_embedded_size.\n") + AbstractTransformerModel.add_args_model_one_input(gx) AbstractTransformerModel.add_args_input_embedding( gx, default_embedding='nn_embedding') gx.add_argument( diff --git a/src/dwi_ml/training/batch_loaders.py b/src/dwi_ml/training/batch_loaders.py index 70641502..7ff10c7a 100644 --- a/src/dwi_ml/training/batch_loaders.py +++ b/src/dwi_ml/training/batch_loaders.py @@ -429,7 +429,7 @@ def load_batch_inputs(self, batch_streamlines: List[torch.tensor], Debugging purposes. Saves the input coordinates as a mask. The inputs will be modified to a tuple containing the batch_streamlines, to compare the streamlines with masks. - device: torch device + device: torch.device Torch device. Returns diff --git a/src/dwi_ml/training/trainers.py b/src/dwi_ml/training/trainers.py index c9062186..04f24352 100644 --- a/src/dwi_ml/training/trainers.py +++ b/src/dwi_ml/training/trainers.py @@ -1208,6 +1208,9 @@ def run_one_batch(self, targets, ids_per_subj): n: int Total number of points for this batch. """ + # Data interpolation has not been done yet. GPU computations are done + # here in the main thread. + # Dataloader always works on CPU. Sending to right device. # (model is already moved). targets = [s.to(self.device, non_blocking=True, dtype=torch.float) @@ -1235,7 +1238,7 @@ def run_one_batch(self, targets, ids_per_subj): streamlines_f = self.batch_loader.add_noise_streamlines_forward( streamlines_f, self.device) model_outputs = self.model(batch_inputs, streamlines_f) - del streamlines_f + del streamlines_f, batch_inputs logger.debug('*** Computing loss') # Add noise to targets. diff --git a/src/dwi_ml/unit_tests/test_model_prepare_batchOneInput.py b/src/dwi_ml/unit_tests/test_model_prepare_batchOneInput.py index c4de67d2..dfa0bb0d 100644 --- a/src/dwi_ml/unit_tests/test_model_prepare_batchOneInput.py +++ b/src/dwi_ml/unit_tests/test_model_prepare_batchOneInput.py @@ -20,7 +20,13 @@ def test_model_batch(): lazy=False, log_level=logging.WARNING) dataset.load_data() - model = MainModelOneInput('test', step_size=0.5, compress_lines=False) + nb_features = TEST_EXPECTED_MRI_SHAPE[0][-1] + model = MainModelOneInput(experiment_name='test', + step_size=0.5, + compress_lines=False, + nb_features_per_point=nb_features, + add_raw_coords_to_input=False, + add_relative_coords_to_input=False) one_line = torch.rand(6, 3) batch_streamlines = [one_line, one_line] @@ -28,8 +34,7 @@ def test_model_batch(): streamlines=batch_streamlines, subset=dataset.training_set, subj_idx=0, input_group_idx=0) assert len(batch_inputs) == 2 - assert np.array_equal(batch_inputs[0].shape, - [6, TEST_EXPECTED_MRI_SHAPE[0][-1]]) + assert np.array_equal(batch_inputs[0].shape, [6, nb_features]) if __name__ == '__main__': diff --git a/src/dwi_ml/unit_tests/test_models_transformers.py b/src/dwi_ml/unit_tests/test_models_transformers.py index 71b234e5..9d23e276 100644 --- a/src/dwi_ml/unit_tests/test_models_transformers.py +++ b/src/dwi_ml/unit_tests/test_models_transformers.py @@ -17,36 +17,40 @@ def _prepare_original_model(): # Using defaults from script model = OriginalTransformerModel( experiment_name='test', step_size=0.5, compress_lines=None, - nb_features=4, input_embedded_size=4, max_len=5, + nb_features_per_point=4, input_embedding_nn_out_size=4, max_len=5, log_level='DEBUG', positional_encoding_key='sinusoidal', sos_token_type='as_label', input_embedding_key='nn_embedding', 
target_embedding_key='nn_embedding', ffnn_hidden_size=None, nheads=1, dropout_rate=0., activation='relu', norm_first=False, n_layers_e=1, n_layers_d=1, dg_key='cosine-regression', dg_args=None, neighborhood_type=None, neighborhood_radius=None, - nb_cnn_filters=None, kernel_size=None, + input_embedding_cnn_nb_filters=None, + input_embedding_cnn_kernel_size=None, + add_raw_coords_to_input=False, + add_relative_coords_to_input=False, start_from_copy_prev=False) return model def _prepare_ttst_model(): + # neighborhood --> No. The number of features is fixed in our fake data + model = TransformerSrcAndTgtModel( - # Varying sos_token_type. - # ffnn_hidden_size, - # norm_first - # dg_key - # eos - # neighborhood --> No. The number of features is fixed in our fake data - # start from copy prev experiment_name='test', step_size=0.5, compress_lines=None, - nb_features=4, max_len=5, input_embedded_size=4, target_embedded_size=2, + nb_features_per_point=4, max_len=5, + target_embedded_size=2, log_level='DEBUG', sos_token_type='repulsion100', - positional_encoding_key='sinusoidal', input_embedding_key='nn_embedding', + positional_encoding_key='sinusoidal', + input_embedding_key='nn_embedding', target_embedding_key='nn_embedding', ffnn_hidden_size=6, nheads=1, dropout_rate=0., activation='relu', norm_first=True, n_layers_e=1, dg_key='sphere-classification', dg_args={'add_eos': True}, neighborhood_type=None, neighborhood_radius=None, - nb_cnn_filters=None, kernel_size=None, + input_embedding_nn_out_size=4, + input_embedding_cnn_nb_filters=None, + input_embedding_cnn_kernel_size=None, + add_raw_coords_to_input=False, + add_relative_coords_to_input=False, start_from_copy_prev=True) return model @@ -54,13 +58,17 @@ def _prepare_ttst_model(): def _prepare_tts_model(): model = TransformerSrcOnlyModel( experiment_name='test', step_size=0.5, compress_lines=None, - nb_features=4, max_len=5, log_level='DEBUG', - input_embedded_size=4, - positional_encoding_key='sinusoidal', input_embedding_key='nn_embedding', + nb_features_per_point=4, max_len=5, log_level='DEBUG', + input_embedding_nn_out_size=4, + positional_encoding_key='sinusoidal', + input_embedding_key='nn_embedding', ffnn_hidden_size=None, nheads=1, dropout_rate=0., activation='relu', norm_first=False, n_layers_e=1, dg_key='cosine-regression', dg_args=None, neighborhood_type=None, neighborhood_radius=None, - nb_cnn_filters=None, kernel_size=None) + input_embedding_cnn_nb_filters=None, + input_embedding_cnn_kernel_size=None, + add_raw_coords_to_input=False, + add_relative_coords_to_input=False) return model @@ -90,7 +98,7 @@ def _run_original_model(model): assert output.shape[1] == 3 # Here, regression, should output x, y, z assert not isnan(output[0, 0]) - # Note. output[0].shape[0] ==> Depends if we unpad sequences. + # Note. output[0].shape[0] ==> Depends on if we unpad sequences. 
def _run_ttst_model(model): diff --git a/src/dwi_ml/unit_tests/test_train_batch_loader.py b/src/dwi_ml/unit_tests/test_train_batch_loader.py index 3173e7c3..a0d09e2c 100755 --- a/src/dwi_ml/unit_tests/test_train_batch_loader.py +++ b/src/dwi_ml/unit_tests/test_train_batch_loader.py @@ -61,7 +61,10 @@ def test_batch_loader(): # 1) With resampling logging.info('*** Test with batch size {} + loading with ' 'resample, noise, split, reverse.'.format(batch_size)) - model = MainModelOneInput(experiment_name='test', step_size=0.5) + model = MainModelOneInput(experiment_name='test', step_size=0.5, + nb_features_per_point=1, + add_raw_coords_to_input=False, + add_relative_coords_to_input=False) batch_loader = create_batch_loader( dataset, model, noise_size=0.2, split_ratio=SPLIT_RATIO, reverse_ratio=0.5) @@ -77,7 +80,10 @@ def test_batch_loader(): # 2) With compressing logging.info('*** Test with batch size {} + loading with compress' .format(batch_size)) - model = MainModelOneInput(experiment_name='test', compress_lines=True) + model = MainModelOneInput(experiment_name='test', compress_lines=True, + nb_features_per_point=1, + add_raw_coords_to_input=False, + add_relative_coords_to_input=False) batch_loader = create_batch_loader(dataset, model) batch_loader.set_context('training') _load_directly_and_verify(batch_loader, batch_idx_tuples, diff --git a/src/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/src/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index eb15b5b6..c8552fcc 100644 --- a/src/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/src/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -67,9 +67,15 @@ def create_test_batch_2lines_4features(): class ModelForTest(MainModelOneInput, ModelWithNeighborhood): def __init__(self, experiment_name: str = 'test', + nb_features: int = 1, + add_raw_coords_to_input: bool = False, + add_relative_coords_to_input: bool = False, neighborhood_type: str = None, neighborhood_radius=None, log_level=logging.root.level): super().__init__(experiment_name=experiment_name, + nb_features_per_point=nb_features, + add_raw_coords_to_input=add_raw_coords_to_input, + add_relative_coords_to_input=add_relative_coords_to_input, neighborhood_type=neighborhood_type, neighborhood_radius=neighborhood_radius, log_level=log_level) @@ -99,6 +105,10 @@ def __init__(self, experiment_name: str = 'test', log_level=logging.root.level, # NEIGHBORHOOD neighborhood_type: str = None, neighborhood_radius=None, + # Input + nb_features_per_point: int = 1, + add_raw_coords_to_input: bool = False, + add_relative_coords_to_input: bool = False, # PREVIOUS DIRS nb_previous_dirs=0, prev_dirs_embedded_size=None, prev_dirs_embedding_key=None, normalize_prev_dirs=True, @@ -111,6 +121,10 @@ def __init__(self, experiment_name: str = 'test', # For super ModelWithNeighborhood neighborhood_type=neighborhood_type, neighborhood_radius=neighborhood_radius, + # For model with One Input + nb_features_per_point=nb_features_per_point, + add_raw_coords_to_input=add_raw_coords_to_input, + add_relative_coords_to_input=add_relative_coords_to_input, # For super MainModelWithPD: nb_previous_dirs=nb_previous_dirs, prev_dirs_embedded_size=prev_dirs_embedded_size, From 0ed36d10c0f105b501dfe330a8f423656de7035f Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 13 Aug 2025 15:00:49 -0400 Subject: [PATCH 2/2] Fix test --- src/dwi_ml/models/main_models.py | 31 ++++++++++--------- .../models/projects/learn2track_model.py | 19 +++++------- .../unit_tests/test_models_learn2track.py | 25 
++++++++++----- .../unit_tests/test_models_transformers.py | 7 +++-- 4 files changed, 47 insertions(+), 35 deletions(-) diff --git a/src/dwi_ml/models/main_models.py b/src/dwi_ml/models/main_models.py index 6d0d049c..a3c4ef89 100644 --- a/src/dwi_ml/models/main_models.py +++ b/src/dwi_ml/models/main_models.py @@ -673,23 +673,26 @@ def __init__(self, "on kernel size and number of filters. Do not use " "nn_embedded_size") - # Computing expected input size - # Right now input is always flattened (interpolation is implemented - # that way). For CNN, we will rearrange it ourselves. + # When data is received from the trainer, input is flattened + # (interpolation is implemented that way). # For options add_raw/relative_coords_to_input, we will add 3 # additional inputs. + self.input_size_nn = None + self.input_size_cnn = None + self.expected_input_size = self.nb_features_per_point * self.nb_neighbors + add_coords = (self.add_raw_coords_to_input or + self.add_relative_coords_to_input) if input_embedding_key == 'cnn_embedding': - if (self.add_raw_coords_to_input or - self.add_relative_coords_to_input): + if add_coords: raise NotImplementedError( "Not ready to concatenate coordinates to input with " - "CNN embedding.") - self.input_size = self.nb_features_per_point + "CNN embedding. Would need to add it to data at each " + "point in the neighborhood.") + self.input_size_cnn = self.nb_features_per_point else: - self.input_size = self.nb_features_per_point * self.nb_neighbors - if (self.add_raw_coords_to_input or - self.add_relative_coords_to_input): - self.input_size += 3 + if add_coords: + self.expected_input_size += 3 + self.input_size_nn = self.expected_input_size # This variable will eventually contain the final computed size. self.computed_input_embedded_size = None @@ -713,7 +716,7 @@ def instantiate_input_embedding_layer(self): elif self.input_embedding_key == 'nn_embedding': self._instantiate_nn_embedding() else: # self.input_embedding_key == 'no_embedding' - self.computed_input_embedded_size = self.input_size + self.computed_input_embedded_size = self.input_size_nn self.input_embedding_layer = torch.nn.Identity() def _instantiate_cnn_embedding(self): @@ -748,7 +751,7 @@ def _instantiate_cnn_embedding(self): "size. Not expected, as we are not padding the data.") self.input_embedding_layer = cnn_embedding_cls( - nb_features_in=self.input_size, + nb_features_in=self.input_size_cnn, nb_filters=self.input_embedding_cnn_nb_filters, kernel_sizes=self.input_embedding_cnn_kernel_size, image_shape=[neighb_size] * 3) @@ -759,7 +762,7 @@ def _instantiate_nn_embedding(self): self.computed_input_embedded_size = self.input_embedding_nn_out_size input_embedding_cls = keys_to_embeddings[self.input_embedding_key] self.input_embedding_layer = input_embedding_cls( - nb_features_in=self.input_size, + nb_features_in=self.input_size_nn, nb_features_out=self.computed_input_embedded_size) @property diff --git a/src/dwi_ml/models/projects/learn2track_model.py b/src/dwi_ml/models/projects/learn2track_model.py index 39f0480c..2504a49c 100644 --- a/src/dwi_ml/models/projects/learn2track_model.py +++ b/src/dwi_ml/models/projects/learn2track_model.py @@ -269,21 +269,16 @@ def forward(self, x: List[torch.tensor], """ # Reminder. # Correct interpolation and management of points should be done before. + assert x[0].shape[-1] == self.expected_input_size, \ + ("Not the expected input size! Should be {} (i.e. 
{} features for " + "each of the {} neighbors, maybe with 3 additional values for " + "the current coordinates), but got {} (input shape {})." + .format(self.expected_input_size, self.nb_features_per_point, + self.nb_neighbors, x[0].shape[-1], x[0].shape)) + if self.context is None: raise ValueError("Please set context before usage.") - # Right now input is always flattened (interpolation is implemented - # that way). For CNN, we will rearrange it ourselves. - # Verifying the first input - expected_input_size = self.nb_features_per_point * self.nb_neighbors - if self.add_raw_coords_to_input or self.add_relative_coords_to_input: - expected_input_size += 3 - assert x[0].shape[-1] == expected_input_size, \ - "Not the expected input size! Should be {} (i.e. {} features for " \ - "each of the {} neighbors), but got {} (input shape {})." \ - .format(self.input_size, self.nb_features_per_point, - self.nb_neighbors, x[0].shape[-1], x[0].shape) - # Making sure we can use default 'enforce_sorted=True' with packed # sequences. unsorted_indices = None diff --git a/src/dwi_ml/unit_tests/test_models_learn2track.py b/src/dwi_ml/unit_tests/test_models_learn2track.py index 471b6cb5..7998b5fe 100644 --- a/src/dwi_ml/unit_tests/test_models_learn2track.py +++ b/src/dwi_ml/unit_tests/test_models_learn2track.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import logging +import torch from torch.nn.utils.rnn import pack_sequence from dwi_ml.experiment_utils.prints import format_dict_to_str @@ -34,21 +35,28 @@ def test_stacked_rnn(): def test_learn2track(): + + # Nb features per point = 4. + # Pretending that this is 1 + 3 coordinates. + model = Learn2TrackModel('test', step_size=0.5, compress_lines=False, - nb_features=4, rnn_layer_sizes=[3, 3], + nb_features_per_point=1, rnn_layer_sizes=[3, 3], log_level='DEBUG', # Using default from script: nb_previous_dirs=0, prev_dirs_embedded_size=None, prev_dirs_embedding_key=None, normalize_prev_dirs=True, input_embedding_key='nn_embedding', - input_embedded_size=5, kernel_size=None, - nb_cnn_filters=None, + input_embedding_nn_out_size=5, + input_embedding_cnn_kernel_size=None, + input_embedding_cnn_nb_filters=None, + add_raw_coords_to_input=False, + add_relative_coords_to_input=True, rnn_key='lstm', use_skip_connection=True, use_layer_normalization=True, dropout=0., start_from_copy_prev=False, dg_key='cosine-regression', dg_args=None, - neighborhood_type=None, neighborhood_radius=None) + neighborhood_type=None, neighborhood_radius=None,) logging.info("Learn2track model final parameters:" + format_dict_to_str(model.params_for_checkpoint)) @@ -60,15 +68,18 @@ def test_learn2track(): def test_learn2track_cnn(): model = Learn2TrackModel('test', step_size=0.5, compress_lines=False, - nb_features=4, rnn_layer_sizes=[3, 3], + nb_features_per_point=4, rnn_layer_sizes=[3, 3], log_level='DEBUG', # Using default from script: nb_previous_dirs=0, prev_dirs_embedded_size=None, prev_dirs_embedding_key=None, normalize_prev_dirs=True, input_embedding_key='cnn_embedding', - nb_cnn_filters=[4], kernel_size=[3], - input_embedded_size=None, + input_embedding_cnn_nb_filters=[4], + input_embedding_cnn_kernel_size=[3], + input_embedding_nn_out_size=None, + add_raw_coords_to_input=False, + add_relative_coords_to_input=False, rnn_key='lstm', use_skip_connection=True, use_layer_normalization=True, dropout=0., start_from_copy_prev=False, diff --git a/src/dwi_ml/unit_tests/test_models_transformers.py b/src/dwi_ml/unit_tests/test_models_transformers.py index 9d23e276..6d024d37 100644 --- 
a/src/dwi_ml/unit_tests/test_models_transformers.py +++ b/src/dwi_ml/unit_tests/test_models_transformers.py @@ -14,10 +14,13 @@ def _prepare_original_model(): + # Nb features per point = 4. + # Pretending that this is 1 + 3 coordinates. + # Using defaults from script model = OriginalTransformerModel( experiment_name='test', step_size=0.5, compress_lines=None, - nb_features_per_point=4, input_embedding_nn_out_size=4, max_len=5, + nb_features_per_point=1, input_embedding_nn_out_size=4, max_len=5, log_level='DEBUG', positional_encoding_key='sinusoidal', sos_token_type='as_label', input_embedding_key='nn_embedding', target_embedding_key='nn_embedding', ffnn_hidden_size=None, nheads=1, @@ -27,7 +30,7 @@ def _prepare_original_model(): input_embedding_cnn_nb_filters=None, input_embedding_cnn_kernel_size=None, add_raw_coords_to_input=False, - add_relative_coords_to_input=False, + add_relative_coords_to_input=True, start_from_copy_prev=False) return model
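
---

Note (reviewer sketch, not part of the patch): the effect of the two new
mutually exclusive flags added to MainModelOneInput can be illustrated
standalone. The snippet below mirrors the concatenation logic of
prepare_batch_one_input on synthetic tensors; the shapes, the volume
dimensions and all variable names are illustrative only.

    import torch

    # Fake interpolated input: 6 streamline points, 4 features per point.
    nb_features = 4
    features = torch.rand(6, nb_features)
    # Fake streamline: 6 points with (x, y, z) coordinates in voxel space.
    streamline = torch.rand(6, 3) * 50
    volume_shape = (50, 60, 70)  # (x, y, z) dimensions of the data volume

    # --add_raw_coords_to_input: concatenate the raw coordinates.
    x_raw = torch.cat((features, streamline), dim=1)
    assert x_raw.shape == (6, nb_features + 3)

    # --add_relative_coords_to_input: concatenate coords / volume_dims,
    # so the appended values lie roughly in [0, 1].
    volume_dim = torch.as_tensor(volume_shape)
    x_rel = torch.cat((features, torch.divide(streamline, volume_dim)), dim=1)
    assert x_rel.shape == (6, nb_features + 3)

Either option grows the per-point input size by 3, which is why the models
now add 3 to their expected input size when one of these flags is set.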