Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/dwi_ml/cli/l2t_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ def init_from_args(args, sub_loggers_level):
dg_args = check_args_direction_getter(args)
# (Nb features)
input_group_idx = dataset.volume_groups.index(args.input_group_name)
args.nb_features = dataset.nb_features[input_group_idx]

nb_features = dataset.nb_features[input_group_idx]
# Final model
with Timer("\n\nPreparing model", newline=True, color='yellow'):
# INPUTS: verifying args
Expand All @@ -84,10 +85,13 @@ def init_from_args(args, sub_loggers_level):
nb_previous_dirs=args.nb_previous_dirs,
normalize_prev_dirs=args.normalize_prev_dirs,
# INPUTS
nb_features_per_point=nb_features,
input_embedding_key=args.input_embedding_key,
input_embedded_size=args.input_embedded_size,
nb_features=args.nb_features, kernel_size=args.kernel_size,
nb_cnn_filters=args.nb_cnn_filters,
input_embedding_nn_out_size=args.input_embedded_size,
input_embedding_cnn_kernel_size=args.kernel_size,
input_embedding_cnn_nb_filters=args.nb_cnn_filters,
add_raw_coords_to_input=args.add_raw_coords_to_input,
add_relative_coords_to_input=args.add_relative_coords_to_input,
# RNN
rnn_key=args.rnn_key, rnn_layer_sizes=args.rnn_layer_sizes,
dropout=args.dropout,
Expand Down
3 changes: 3 additions & 0 deletions src/dwi_ml/cli/tests/test_all_steps_l2t.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def test_training(script_runner, experiments_path):
ret = script_runner.run('l2t_train_model',
experiments_path, experiment_name, hdf5_file,
input_group_name, streamline_group_name,
'--add_relative_coords_to_input',
'--input_embedding_key', 'nn_embedding',
'--input_embedded_size', '8',
'--max_epochs', '1', '--batch_size_training', '5',
'--batch_size_validation', '5',
'--dg_key', 'gaussian',
Expand Down
12 changes: 8 additions & 4 deletions src/dwi_ml/cli/tt_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,23 @@ def init_from_args(args, sub_loggers_level):

# (nb features)
input_group_idx = dataset.volume_groups.index(args.input_group_name)
args.nb_features = dataset.nb_features[input_group_idx]
nb_features = dataset.nb_features[input_group_idx]

# Final model
with Timer("\n\nPreparing model", newline=True, color='yellow'):
model = cls(
experiment_name=args.experiment_name,
step_size=args.step_size, compress_lines=args.compress_th,
# Concerning inputs:
max_len=args.max_len, nb_features=args.nb_features,
max_len=args.max_len,
nb_features_per_point=nb_features,
add_raw_coords_to_input=args.add_raw_coords_to_input,
add_relative_coords_to_input=args.add_relative_coords_to_input,
positional_encoding_key=args.position_encoding,
input_embedding_key=args.input_embedding_key,
input_embedded_size=args.input_embedded_size,
nb_cnn_filters=args.nb_cnn_filters, kernel_size=args.kernel_size,
input_embedding_nn_out_size=args.input_embedded_size,
input_embedding_cnn_nb_filters=args.nb_cnn_filters,
input_embedding_cnn_kernel_size=args.kernel_size,
# Torch's transformer parameters
ffnn_hidden_size=args.ffnn_hidden_size,
nheads=args.nheads, dropout_rate=args.dropout_rate,
Expand Down
4 changes: 2 additions & 2 deletions src/dwi_ml/data/dataset/multi_subject_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from dwi_ml.data.dataset.checks_for_groups import prepare_groups_info
from dwi_ml.data.dataset.mri_data_containers import MRIDataAbstract
from dwi_ml.data.dataset.subjectdata_list_containers import (
LazySubjectsDataList, SubjectsDataList)
SubjectsDataListAbstract, LazySubjectsDataList, SubjectsDataList)
from dwi_ml.data.dataset.single_subject_containers import (LazySubjectData,
SubjectData)

Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(self, set_name: str, hdf5_file: str, lazy: bool,

# The subjects data list will be either a SubjectsDataList or a
# LazySubjectsDataList depending on MultisubjectDataset.is_lazy.
self.subjs_data_list = None
self.subjs_data_list = None # type:SubjectsDataListAbstract
self.subjects = [] # type:List[str]
self.nb_subjects = 0

Expand Down
2 changes: 1 addition & 1 deletion src/dwi_ml/data/dataset/subjectdata_list_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, hdf5_path: str, log):

# Do not access it directly. Use get_subj.
# Will be a list of SubjectData or LazySubjectData
self._subjects_data_list = []
self._subjects_data_list = [] #type:List[SubjectDataAbstract]

def add_subject(self, subject_data: SubjectDataAbstract):
"""
Expand Down
Loading