diff --git a/src/scilpy/segment/voting_scheme.py b/src/scilpy/segment/voting_scheme.py
index 7990ec588..8bcc6d768 100644
--- a/src/scilpy/segment/voting_scheme.py
+++ b/src/scilpy/segment/voting_scheme.py
@@ -11,6 +11,7 @@
 import warnings
 
 from dipy.io.streamline import save_tractogram, load_tractogram
+from dipy.io.stateful_tractogram import StatefulTractogram
 from dipy.segment.clustering import qbx_and_merge
 from dipy.tracking.streamline import transform_streamlines
 import nibabel as nib
@@ -21,7 +22,7 @@
 from scilpy.segment.bundleseg import BundleSeg
 from scilpy.utils import get_duration
 
-logger = logging.getLogger('BundleSeg')
+logger = logging.getLogger("BundleSeg")
 
 # These parameters are leftovers from Recobundles.
 # Now with BundleSeg, they do not need to be modified.
@@ -94,8 +95,8 @@ def _load_bundles_dictionary(self):
             if len(tmp_list) > 0:
                 bundles_filepath.append(tmp_list)
 
-        logger.info(f'{len(self.atlas_dir)} sub-model directories were found. '
-                    f'with {len(bundle_names)} model bundles total')
+        logger.info(f"{len(self.atlas_dir)} sub-model directories were found. "
+                    f"with {len(bundle_names)} model bundles total")
 
         model_bundles_dict = {}
         bundle_counts = []
@@ -193,23 +194,34 @@ def _save_recognized_bundles(self, input_tractograms_path, reference,
             if len(streamlines_id) == 0:
                 streamlines_id = np.array([], dtype=np.uint32)
 
-            logger.info(f'{bundle_names[bundle_id]} final recognition got '
-                        f'{len(streamlines_id)} streamlines')
+            logger.info(f"{bundle_names[bundle_id]} final recognition got "
+                        f"{len(streamlines_id)} streamlines")
+
+            scores = np.array([], dtype=np.float16)
+            if len(streamlines_id) > 0:
+                scores = bundles_wise_score[bundle_id,
+                                            streamlines_id].flatten()
+
+            results_dict[basename] = {"indices": streamlines_id.tolist(),
+                                      "scores": scores.tolist()}
         else:
             streamlines_id = np.array(
-                results_dict[basename]['indices'])
+                results_dict[basename]["indices"])
 
         # Need to make sure the indices are valid for this sft
+        # Convert back to local indices (for this sft)
         if len(sft) and len(streamlines_id):
-            # Convert back to local indices (for this sft)
             streamlines_id = streamlines_id[streamlines_id >= tot_sft_len]
-            streamlines_id = streamlines_id[streamlines_id < tot_sft_len + curr_sft_len]
+            streamlines_id = streamlines_id[streamlines_id <
+                                            tot_sft_len + curr_sft_len]
+        else:
+            continue
 
         # If the user asked to ignore metadata, remove it (simpler)
-        new_sft = sft[streamlines_id - tot_sft_len]
-        if self.ignore_metadata:
-            new_sft.data_per_point = {}
-            new_sft.data_per_streamline = {}
+        if len(streamlines_id) > 0:
+            new_sft = sft[streamlines_id - tot_sft_len]
+        else:
+            new_sft = StatefulTractogram.from_sft([], sft)
 
         if basename in results_sft:
             try:
@@ -221,18 +233,6 @@ def _save_recognized_bundles(self, input_tractograms_path, reference,
                     f"try --ignore_metadata.")
         else:
             results_sft[basename] = new_sft
-
-        # Populate the results dictionary (will be saved as json)
-        curr_results_dict = {}
-        curr_results_dict['indices'] = streamlines_id.tolist()
-
-        if len(streamlines_id) > 0:
-            scores = bundles_wise_score[bundle_id,
-                                        streamlines_id].flatten()
-        else:
-            scores = np.array([], dtype=np.float16)
-        curr_results_dict['scores'] = scores.tolist()
-        results_dict[basename] = curr_results_dict
         tot_sft_len += len(sft)
 
     # Once everything is done, save all bundles, at the moment only
@@ -242,10 +242,10 @@ def _save_recognized_bundles(self, input_tractograms_path, reference,
         if len(sft) > 0 or self.save_empty:
             sft.remove_invalid_streamlines()
             save_tractogram(sft, os.path.join(self.output_directory,
-                                             basename + extension))
+                                              basename + extension))
 
-        out_logfile = os.path.join(self.output_directory, 'results.json')
-        with open(out_logfile, 'w') as outfile:
+        out_logfile = os.path.join(self.output_directory, "results.json")
+        with open(out_logfile, "w") as outfile:
             json.dump(results_dict, outfile)
 
     def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
@@ -273,9 +273,9 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
             nib.streamlines.load(in_tractogram).streamlines)
         len_wb_streamlines = len(wb_streamlines)
 
-        logger.debug(f'Tractogram {input_tractograms_path} with '
-                     f'{len_wb_streamlines} streamlines '
-                     f'is loaded in {get_duration(load_timer)} seconds')
+        logger.debug(f"Tractogram {input_tractograms_path} with "
+                     f"{len_wb_streamlines} streamlines "
+                     f"is loaded in {get_duration(load_timer)} seconds")
 
         total_timer = time()
         # Each type of bundle is processed separately
@@ -298,12 +298,12 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
             clusters_indices = ArraySequence(clusters_indices)
             clusters_indices._data = clusters_indices._data.astype(np.uint32)
 
-        logger.info(f'QBx with seed {seed} at {TCT}mm took '
-                    f'{get_duration(cluster_timer)}sec. gave '
-                    f'{len(cluster_map.centroids)} centroids')
+        logger.info(f"QBx with seed {seed} at {TCT}mm took "
+                    f"{get_duration(cluster_timer)}sec. gave "
+                    f"{len(cluster_map.centroids)} centroids")
 
         tmp_dir, tmp_memmap_filenames = streamlines_to_memmap(wb_streamlines,
-                                                              'float16')
+                                                              "float16")
 
         # Memory cleanup (before multiprocessing)
         cluster_map.refdata = None
@@ -331,8 +331,8 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
         pool.close()
         pool.join()
 
-        logger.info(f'BundleSeg took {get_duration(total_timer)} sec. for '
-                    f'{len(bundle_names)} bundles from {len(self.atlas_dir)} atlas')
+        logger.info(f"BundleSeg took {get_duration(total_timer)} sec. for "
+                    f"{len(bundle_names)} bundles from {len(self.atlas_dir)} atlas")
 
         bundles_wise_vote = np.zeros((len(bundle_names),
                                       len_wb_streamlines),
@@ -369,10 +369,10 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
                                        minimum_vote, ext)
         tmp_dir.cleanup()
 
         saved_bundles = [f for f in os.listdir(self.output_directory)
-                         if os.path.splitext(f)[1] in ['.trk', '.tck']]
-        logger.info(f'Saving of {len(saved_bundles)} files in '
-                    f'{self.output_directory} took '
-                    f'{get_duration(save_timer)} sec.')
+                         if os.path.splitext(f)[1] in [".trk", ".tck"]]
+        logger.info(f"Saving of {len(saved_bundles)} files in "
+                    f"{self.output_directory} took "
+                    f"{get_duration(save_timer)} sec.")
 
 
 def single_recognize_parallel(args):
@@ -427,7 +427,7 @@ def single_recognize(args):
     shorter_tag, ext = os.path.splitext(os.path.basename(model_filepath))
 
     # Now hardcoded (not useful with FSS from Etienne)
-    slr_transform_type = 'similarity'
+    slr_transform_type = "similarity"
 
     recognize_timer = time()
     results = bsg.recognize(model_bundle,
@@ -438,11 +438,11 @@ def single_recognize(args):
     recognized_indices, recognized_scores = results
     del model_bundle._data, model_bundle
 
-    logger.info(f'Model {shorter_tag} recognized {len(recognized_indices)} '
-                'streamlines')
-    logger.debug(f'Model {model_filepath} with parameters tct={TCT}, mct={MCT}, '
-                 f'bpt={bundle_pruning_thr} '
-                 f'took {get_duration(recognize_timer)} sec.')
+    logger.info(f"Model {shorter_tag} recognized {len(recognized_indices)} "
+                "streamlines")
+    logger.debug(f"Model {model_filepath} with parameters tct={TCT}, mct={MCT}, "
+                 f"bpt={bundle_pruning_thr} "
+                 f"took {get_duration(recognize_timer)} sec.")
 
     bundle_id = bundle_names.index(shorter_tag+ext)
     return bundle_id, recognized_indices, recognized_scores