Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 46 additions & 46 deletions src/scilpy/segment/voting_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import warnings

from dipy.io.streamline import save_tractogram, load_tractogram
from dipy.io.stateful_tractogram import StatefulTractogram
from dipy.segment.clustering import qbx_and_merge
from dipy.tracking.streamline import transform_streamlines
import nibabel as nib
Expand All @@ -21,7 +22,7 @@
from scilpy.segment.bundleseg import BundleSeg
from scilpy.utils import get_duration

logger = logging.getLogger('BundleSeg')
logger = logging.getLogger("BundleSeg")

# These parameters are leftovers from Recobundles.
# Now with BundleSeg, they do not need to be modified.
Expand Down Expand Up @@ -94,8 +95,8 @@ def _load_bundles_dictionary(self):
if len(tmp_list) > 0:
bundles_filepath.append(tmp_list)

logger.info(f'{len(self.atlas_dir)} sub-model directories were found. '
f'with {len(bundle_names)} model bundles total')
logger.info(f"{len(self.atlas_dir)} sub-model directories were found. "
f"with {len(bundle_names)} model bundles total")

model_bundles_dict = {}
bundle_counts = []
Expand Down Expand Up @@ -193,23 +194,34 @@ def _save_recognized_bundles(self, input_tractograms_path, reference,

if len(streamlines_id) == 0:
streamlines_id = np.array([], dtype=np.uint32)
logger.info(f'{bundle_names[bundle_id]} final recognition got '
f'{len(streamlines_id)} streamlines')
logger.info(f"{bundle_names[bundle_id]} final recognition got "
f"{len(streamlines_id)} streamlines")

scores = np.array([], dtype=np.float16)
if len(streamlines_id) > 0:
scores = bundles_wise_score[bundle_id,
streamlines_id].flatten()

results_dict[basename] = {"indices": streamlines_id.tolist(),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is where the important change is.

"scores": scores.tolist()}
else:
streamlines_id = np.array(
results_dict[basename]['indices'])
results_dict[basename]["indices"])

# Need to make sure the indices are valid for this sft
# Convert back to local indices (for this sft)
if len(sft) and len(streamlines_id):
# Convert back to local indices (for this sft)
streamlines_id = streamlines_id[streamlines_id >= tot_sft_len]
streamlines_id = streamlines_id[streamlines_id < tot_sft_len + curr_sft_len]
streamlines_id = streamlines_id[streamlines_id <
tot_sft_len + curr_sft_len]
else:
continue

# If the user asked to ignore metadata, remove it (simpler)
new_sft = sft[streamlines_id - tot_sft_len]
if self.ignore_metadata:
new_sft.data_per_point = {}
new_sft.data_per_streamline = {}
if len(streamlines_id) > 0:
new_sft = sft[streamlines_id - tot_sft_len]
else:
new_sft = StatefulTractogram.from_sft([], sft)

if basename in results_sft:
try:
Expand All @@ -221,18 +233,6 @@ def _save_recognized_bundles(self, input_tractograms_path, reference,
f"try --ignore_metadata.")
else:
results_sft[basename] = new_sft

# Populate the results dictionary (will be saved as json)
curr_results_dict = {}
curr_results_dict['indices'] = streamlines_id.tolist()

if len(streamlines_id) > 0:
scores = bundles_wise_score[bundle_id,
streamlines_id].flatten()
else:
scores = np.array([], dtype=np.float16)
curr_results_dict['scores'] = scores.tolist()
results_dict[basename] = curr_results_dict
tot_sft_len += len(sft)

# Once everything is done, save all bundles, at the moment only
Expand All @@ -242,10 +242,10 @@ def _save_recognized_bundles(self, input_tractograms_path, reference,
if len(sft) > 0 or self.save_empty:
sft.remove_invalid_streamlines()
save_tractogram(sft, os.path.join(self.output_directory,
basename + extension))
basename + extension))

out_logfile = os.path.join(self.output_directory, 'results.json')
with open(out_logfile, 'w') as outfile:
out_logfile = os.path.join(self.output_directory, "results.json")
with open(out_logfile, "w") as outfile:
json.dump(results_dict, outfile)

def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
Expand Down Expand Up @@ -273,9 +273,9 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
nib.streamlines.load(in_tractogram).streamlines)
len_wb_streamlines = len(wb_streamlines)

logger.debug(f'Tractogram {input_tractograms_path} with '
f'{len_wb_streamlines} streamlines '
f'is loaded in {get_duration(load_timer)} seconds')
logger.debug(f"Tractogram {input_tractograms_path} with "
f"{len_wb_streamlines} streamlines "
f"is loaded in {get_duration(load_timer)} seconds")

total_timer = time()
# Each type of bundle is processed separately
Expand All @@ -298,12 +298,12 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
clusters_indices = ArraySequence(clusters_indices)
clusters_indices._data = clusters_indices._data.astype(np.uint32)

logger.info(f'QBx with seed {seed} at {TCT}mm took '
f'{get_duration(cluster_timer)}sec. gave '
f'{len(cluster_map.centroids)} centroids')
logger.info(f"QBx with seed {seed} at {TCT}mm took "
f"{get_duration(cluster_timer)}sec. gave "
f"{len(cluster_map.centroids)} centroids")

tmp_dir, tmp_memmap_filenames = streamlines_to_memmap(wb_streamlines,
'float16')
"float16")

# Memory cleanup (before multiprocessing)
cluster_map.refdata = None
Expand Down Expand Up @@ -331,8 +331,8 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
pool.close()
pool.join()

logger.info(f'BundleSeg took {get_duration(total_timer)} sec. for '
f'{len(bundle_names)} bundles from {len(self.atlas_dir)} atlas')
logger.info(f"BundleSeg took {get_duration(total_timer)} sec. for "
f"{len(bundle_names)} bundles from {len(self.atlas_dir)} atlas")

bundles_wise_vote = np.zeros((len(bundle_names),
len_wb_streamlines),
Expand Down Expand Up @@ -369,10 +369,10 @@ def __call__(self, input_tractograms_path, nbr_processes=1, seed=None,
minimum_vote, ext)
tmp_dir.cleanup()
saved_bundles = [f for f in os.listdir(self.output_directory)
if os.path.splitext(f)[1] in ['.trk', '.tck']]
logger.info(f'Saving of {len(saved_bundles)} files in '
f'{self.output_directory} took '
f'{get_duration(save_timer)} sec.')
if os.path.splitext(f)[1] in [".trk", ".tck"]]
logger.info(f"Saving of {len(saved_bundles)} files in "
f"{self.output_directory} took "
f"{get_duration(save_timer)} sec.")


def single_recognize_parallel(args):
Expand Down Expand Up @@ -427,7 +427,7 @@ def single_recognize(args):
shorter_tag, ext = os.path.splitext(os.path.basename(model_filepath))

# Now hardcoded (not useful with FSS from Etienne)
slr_transform_type = 'similarity'
slr_transform_type = "similarity"

recognize_timer = time()
results = bsg.recognize(model_bundle,
Expand All @@ -438,11 +438,11 @@ def single_recognize(args):
recognized_indices, recognized_scores = results
del model_bundle._data, model_bundle

logger.info(f'Model {shorter_tag} recognized {len(recognized_indices)} '
'streamlines')
logger.debug(f'Model {model_filepath} with parameters tct={TCT}, mct={MCT}, '
f'bpt={bundle_pruning_thr} '
f'took {get_duration(recognize_timer)} sec.')
logger.info(f"Model {shorter_tag} recognized {len(recognized_indices)} "
"streamlines")
logger.debug(f"Model {model_filepath} with parameters tct={TCT}, mct={MCT}, "
f"bpt={bundle_pruning_thr} "
f"took {get_duration(recognize_timer)} sec.")

bundle_id = bundle_names.index(shorter_tag+ext)
return bundle_id, recognized_indices, recognized_scores