diff --git a/src/dwi_ml/general/data/dataset/multi_subject_containers.py b/src/dwi_ml/general/data/dataset/multi_subject_containers.py index 64b55b92..27e2344f 100644 --- a/src/dwi_ml/general/data/dataset/multi_subject_containers.py +++ b/src/dwi_ml/general/data/dataset/multi_subject_containers.py @@ -448,12 +448,20 @@ def load_data(self, load_training=True, load_validation=True, if compress == 'Not defined by user': compress = None - # Loading the first training subject's group information. + # Loading the first subject's group information. # Others should fit. - one_subj = hdf_handle.attrs['training_subjs'][0] + if len(hdf_handle.attrs['training_subjs']) > 0: + first_subj = hdf_handle.attrs['training_subjs'][0] + elif len(hdf_handle.attrs['validation_subjs']) > 0: + first_subj = hdf_handle.attrs['validation_subjs'][0] + elif len(hdf_handle.attrs['testing_subjs']) > 0: + first_subj = hdf_handle.attrs['testing_subjs'][0] + else: + raise ValueError("No subject found in the hdf5 file") + (poss_volume_groups, nb_features, poss_strea_groups, contains_connectivity) = prepare_groups_info( - one_subj, hdf_handle, ref_group_info=None) + first_subj, hdf_handle, ref_group_info=None) logger.debug("Possible volume groups are: {}" .format(poss_volume_groups)) logger.debug("Number of features in each of these groups: {}"