diff --git a/.gitignore b/.gitignore index ad1a91c78..3ad5a6417 100644 --- a/.gitignore +++ b/.gitignore @@ -75,3 +75,5 @@ venv/ # Hidden folder .hidden/ + +.DS_Store \ No newline at end of file diff --git a/src/scilpy/cli/scil_tracking_local_dev.py b/src/scilpy/cli/scil_tracking_local_dev.py index 8904efad2..e46930016 100755 --- a/src/scilpy/cli/scil_tracking_local_dev.py +++ b/src/scilpy/cli/scil_tracking_local_dev.py @@ -66,7 +66,7 @@ load_matrix_in_any_format) from scilpy.image.volume_space_management import DataVolume from scilpy.tracking.propagator import ODFPropagator -from scilpy.tracking.rap import RAPContinue +from scilpy.tracking.rap import RAPContinue, RAPSwitch from scilpy.tracking.seed import SeedGenerator, CustomSeedsDispenser from scilpy.tracking.tracker import Tracker from scilpy.tracking.utils import (add_mandatory_options_tracking, @@ -76,6 +76,8 @@ verify_streamline_length_options, verify_seed_options) from scilpy.version import version_string +from scilpy.image.labels import get_data_as_labels +from scilpy.io.image import get_data_as_mask def _build_arg_parser(): @@ -149,15 +151,28 @@ def _build_arg_parser(): "fixed --rng_seed.\nEx: If tractogram_1 was created " "with -nt 1,000,000, \nyou can create tractogram_2 " "with \n--skip 1,000,000.") - - track_g.add_argument('--rap_mask', default=None, + rap_mode = track_g.add_mutually_exclusive_group() + rap_mode.add_argument('--rap_mask', default=None, help='Region-Adaptive Propagation mask (.nii.gz).\n' 'Region-Adaptive Propagation tractography will start within ' 'this mask.') + rap_mode.add_argument('--rap_labels', default=None, + help='Region-Adaptive Propagation label volume (.nii.gz) .\n' + 'Voxel values are integer labels (0=background, 1..N=regions) .\n' + 'Used with --rap_method switch to select policies per label.') track_g.add_argument('--rap_method', default='None', - choices=['None', 'continue'], - help="Region-Adaptive Propagation tractography method " + choices=['None', 'continue', 'switch'], + help="Region-Adaptive Propagation tractography method.\n" + "'continue': continues tracking with same params,\n" + "'switch': switches tracking params inside RAP mask.\n" " [%(default)s]") + track_g.add_argument('--rap_params', default=None, + help='JSON file containing RAP parameters.\n' + 'Required for rap_method=switch. Format:\n' + '{"step_size": float, "theta": float (degrees)}') + track_g.add_argument('--rap_save_entry_exit', default=None, + help='Save RAP entry/exit coordinates as a binary mask.\n' + 'Provide output filename (.nii.gz).') m_g = p.add_argument_group('Memory options') add_processes_arg(m_g) @@ -186,11 +201,16 @@ def main(): verify_compression_th(args.compress_th) verify_seed_options(parser, args) - if args.rap_mask is not None and args.rap_method == "None": + if (args.rap_mask is not None or args.rap_labels is not None) and args.rap_method == "None": parser.error('No RAP method selected.') - if not args.rap_method == "None" and args.rap_mask is None: - parser.error('No RAP mask selected.') - + if args.rap_method == 'continue' and args.rap_mask is None: + parser.error('RAP method "continue" requires --rap_mask.') + if args.rap_method == 'switch' and (args.rap_mask is None and args.rap_labels is None): + parser.error('RAP method "switch" requires --rap_mask or --rap_labels.') + if args.rap_method == 'switch' and args.rap_params is None: + parser.error('RAP method "switch" requires --rap_params to be specified.') + if args.rap_params is not None and args.rap_method != 'switch': + parser.error('--rap_params can only be used with --rap_method switch.') tracts_format = detect_format(args.out_tractogram) if tracts_format is not TrkFile: logging.warning("You have selected option --save_seeds but you are " @@ -282,15 +302,28 @@ def main(): if args.rap_mask: logging.info("Loading RAP mask.") rap_img = nib.load(args.rap_mask) - rap_data = rap_img.get_fdata(caching='unchanged', dtype=float) - rap_res = rap_img.header.get_zooms()[:3] - rap_mask = DataVolume(rap_data, rap_res, args.mask_interp) - else: - rap_mask = None + rap_mask_data = get_data_as_mask(rap_img) + rap_mask_res = rap_img.header.get_zooms()[:3] + rap_volume = DataVolume(rap_mask_data, rap_mask_res, args.mask_interp) + elif args.rap_labels: + logging.info("Loading RAP labels.") + rap_label_img = nib.load(args.rap_labels) + + # Convert the rap_labels image to int if float + if np.issubdtype(rap_label_img.get_data_dtype(), np.floating): + int_data = np.round(rap_label_img.get_fdata()).astype(np.int16) + rap_label_img = nib.Nifti1Image(int_data, rap_label_img.affine) + + rap_label_data = get_data_as_labels(rap_label_img) + rap_label_res = rap_label_img.header.get_zooms()[:3] + rap_volume = DataVolume(rap_label_data, rap_label_res, 'nearest') if args.rap_method == "continue": - rap = RAPContinue(rap_mask, propagator, max_nbr_pts, + rap = RAPContinue(rap_volume, propagator, max_nbr_pts, step_size=vox_step_size) + elif args.rap_method == "switch": + rap = RAPSwitch(rap_volume, propagator, max_nbr_pts, + rap_params_file=args.rap_params) else: rap = None @@ -323,6 +356,10 @@ def main(): else: data_per_streamline = {} + # Save RAP entry/exit mask if requested + if args.rap_save_entry_exit: + tracker.save_rap_entry_exit_mask(args.rap_save_entry_exit, mask_img) + # Compared with scil_tracking_local, using sft rather than # LazyTractogram to deal with space. # Contrary to scilpy or dipy, where space after tracking is vox, here diff --git a/src/scilpy/tracking/rap.py b/src/scilpy/tracking/rap.py index 552c15e58..0a48bf521 100644 --- a/src/scilpy/tracking/rap.py +++ b/src/scilpy/tracking/rap.py @@ -1,20 +1,27 @@ # -*- coding: utf-8 -*- +import json +import logging import numpy as np +from copy import deepcopy +from scilpy.tracking.propagator import get_sphere_neighbours class RAP: - def __init__(self, mask_rap, propagator, max_nbr_pts): + def __init__(self, rap_volume, propagator, max_nbr_pts): """ - RAP_mask: DataVolume + rap_volume: DataVolume HRegion-Adaptive Propagation tractography volume. """ - self.rap_mask = mask_rap + self.rap_volume = rap_volume self.propagator = propagator self.max_nbr_pts = max_nbr_pts + self._current_label = None + self._total_steps = 0 + self._current_cfg = {} def is_in_rap_region(self, curr_pos, space, origin): - return self.rap_mask.get_value_at_coordinate( + return self.rap_volume.get_value_at_coordinate( *curr_pos, space=space, origin=origin) > 0 def rap_multistep_propagate(self, line, prev_direction): @@ -42,29 +49,193 @@ def rap_multistep_propagate(self, line, prev_direction): class RAPContinue(RAP): """Dummy RAP class for tests. Goes straight""" - def __init__(self, mask_rap, propagator, max_nbr_pts, step_size): + def __init__(self, rap_volume, propagator, max_nbr_pts, step_size): """ Step size: float The step size inside the RAP mask. Could be different from the step size elsewhere. In voxel world. """ - super().__init__(mask_rap, propagator, max_nbr_pts) + super().__init__(rap_volume, propagator, max_nbr_pts) self.step_size = step_size def rap_multistep_propagate(self, line, prev_direction): is_line_valid = True - if len(line)>3: + if len(line) > 3: pos = line[-2] + self.step_size * np.array(prev_direction) line[-1] = pos return line, prev_direction, is_line_valid return line, prev_direction, is_line_valid +class RAPSwitch(RAP): + """RAP class that switches tracking parameters when inside the RAP mask or RAP label.""" + def __init__(self, rap_volume, propagator, max_nbr_pts, rap_params_file): + """ + Parameters + ---------- + rap_volume : DataVolume + Region-Adaptive Propagation mask. + propagator : Propagator + The propagator used for tracking. + max_nbr_pts : int + Maximum number of points per streamline. + rap_params_file : str + Path to JSON file containing RAP parameters. + "methods" is optionnal, if not provided, "default" will be applied + Expected format: + { + "methods": { + "1": {"algo": str, "theta": float, "step_size": float}, + "2": {"algo": str, "theta": float, "step_size": float}, + ... + } + } + """ + super().__init__(rap_volume, propagator, max_nbr_pts) + + # Load parameters from JSON file + with open(rap_params_file, 'r') as f: + rap_params = json.load(f) + + self._base = { + 'step_size': propagator.step_size, + 'theta': propagator.theta, + 'algo': getattr(propagator, 'algo', None), + 'tracking_neighbours': getattr(propagator, 'tracking_neighbours', None) + } + self.methods_cfg = rap_params.get('methods', {}) + logging.info("RAP parameters loaded:") + + # Check if all labels in the volume are covered by the configuration + unique_labels = np.unique(rap_volume.data) + # Remove 0 (background) and convert to int + unique_labels = [int(label) for label in unique_labels if label > 0] + + if unique_labels: + missing_labels = [label for label in unique_labels + if str(label) not in self.methods_cfg] + if missing_labels: + logging.warning( + f"Labels {missing_labels} found in RAP volume but not in " + f"methods config. Base parameters will be used for these labels." + ) + + def rap_multistep_propagate(self, line, prev_direction): + """ + Propagate within the RAP region using modified parameters. + + Parameters + ---------- + line : list + The current streamline. + prev_direction : np.ndarray + The previous tracking direction. + + Returns + ------- + line : list + The extended streamline. + prev_direction : np.ndarray + The last direction. + is_line_valid : bool + Whether the line is valid. + """ + # Switch to RAP parameters + label = self._get_label(line[-1], self.propagator.space, self.propagator.origin) + if label <= 0: + return line, prev_direction, False + # Apply the parameters of the RAP labels + cfg = self._merge_cfg(label) + + # Perform propagation with new parameters + self._apply_cfg(cfg) + new_pos, new_dir, is_direction_valid = self.propagator.propagate(line, prev_direction) + + # Add the new point to the line + if is_direction_valid: + line.append(new_pos) + if label != self._current_label: + if self._current_label is not None: + logging.debug(f"STEP[{self._total_steps}] label={self._current_label} algo={self._current_cfg.get('algo')} theta={self._current_cfg.get('theta')} step={self._current_cfg.get('step_size')}") + self._current_label = label + self._current_cfg = cfg + self._total_steps += 1 + return line, new_dir, True + return line, prev_direction, False + + def _get_label(self, curr_pos, space, origin): + """ + Receive label (int) at current position in RAP label volume. + + Parameters + ---------- + curr_pos: np.ndarray + This is the current 3D position of the streamline. + + space: Space + Coordinate space (here Space.VOX.). + + origin: Origin + Origin convention ('center'). + + Returns + ------- + int + The integer label at current position. + """ + v = self.rap_volume.get_value_at_coordinate(*curr_pos, space=space, origin=origin) + try: + return int(v) + except Exception: + return int(np.round(v)) + + def _merge_cfg(self, label): + """ + Merge the default configuration with the label-specific cfg override from the JSON policy. + + Parameters + ---------- + label: int + Integer of label at current position. + + Returns + ------- + dict + Configuration dict with keys 'algo', 'theta', 'step_size'. + """ + override = self.methods_cfg.get(str(label)) + if override is None: + return { + 'step_size': self._base['step_size'], + 'algo': self._base['algo'], + 'theta': float(np.degrees(self._base['theta'])) + } + return deepcopy(override) + + def _apply_cfg(self, cfg): + """ + Temporarily apply a label configuration to the propagator. + + Parameters + ---------- + cfg: dict + Configuration dict with keys 'algo', 'theta', 'step_size'. + """ + if 'step_size' in cfg and cfg['step_size'] is not None: + self.propagator.step_size = float(cfg['step_size']) + if 'algo' in cfg and cfg['algo'] is not None: + self.propagator.algo = str(cfg['algo']) + if 'theta' in cfg and cfg['theta'] is not None: + theta_rad = np.deg2rad(float(cfg['theta'])) + self.propagator.theta = theta_rad + # theta change => neighbours change + self.propagator.tracking_neighbours = get_sphere_neighbours(self.propagator.sphere, self.propagator.theta) + + class RAPGraph(RAP): def __init__(self, mask_rap, propagator, max_nbr_pts, neighboorhood_size): super().__init__(mask_rap, propagator, max_nbr_pts) self.neighboorhood_size = neighboorhood_size - def rap_multistep_propagate(self, line, prev_direction): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/src/scilpy/tracking/tracker.py b/src/scilpy/tracking/tracker.py index 9ff8a92c6..aee5f2a6a 100644 --- a/src/scilpy/tracking/tracker.py +++ b/src/scilpy/tracking/tracker.py @@ -105,6 +105,10 @@ def __init__(self, propagator: AbstractPropagator, mask: DataVolume, self.track_forward_only = track_forward_only self.append_last_point = append_last_point self.skip = skip + + # List to store RAP entry/exit coordinates as tuples (coord, type) + # where type is 1 for entry and 2 for exit + self.rap_entry_exit_coords = [] self.origin = self.propagator.origin self.space = self.propagator.space @@ -133,6 +137,55 @@ def __init__(self, propagator: AbstractPropagator, mask: DataVolume, self.verbose = verbose self.min_iter = min_iter + def save_rap_entry_exit_mask(self, output_path, reference_img): + """ + Save RAP entry/exit coordinates as a nifti mask. + Entry points have value 1, exit points have value 2. + + Parameters + ---------- + output_path : str + Path to save the nifti mask file. + reference_img : nibabel.Nifti1Image + Reference image to get affine and shape for the output mask. + """ + import nibabel as nib + + if not self.rap_entry_exit_coords: + logging.warning("No RAP entry/exit coordinates to save.") + return + + # Create empty mask with same shape as reference + mask_data = np.zeros(reference_img.shape[:3], dtype=np.uint8) + + # Convert coordinates to voxel space and set mask values + # Each element is a tuple (coord, coord_type) where coord_type is 1 (entry) or 2 (exit) + for coord, coord_type in self.rap_entry_exit_coords: + # Coordinates are already in voxel space (VOX, center) + # Round to nearest integer voxel + vox_coord = np.round(coord).astype(int) + + # Check bounds + if (0 <= vox_coord[0] < mask_data.shape[0] and + 0 <= vox_coord[1] < mask_data.shape[1] and + 0 <= vox_coord[2] < mask_data.shape[2]): + # Use max to handle overlapping entry/exit points + # If both entry and exit occur at same voxel, exit (2) will prevail + mask_data[vox_coord[0], vox_coord[1], vox_coord[2]] = max( + mask_data[vox_coord[0], vox_coord[1], vox_coord[2]], coord_type) + + # Create nifti image and save + mask_img = nib.Nifti1Image(mask_data, reference_img.affine, + reference_img.header) + nib.save(mask_img, output_path) + + entry_count = sum(1 for _, t in self.rap_entry_exit_coords if t == 1) + exit_count = sum(1 for _, t in self.rap_entry_exit_coords if t == 2) + logging.info(f"Saved RAP entry/exit mask to {output_path}") + logging.info(f"Entry coordinates: {entry_count}, Exit coordinates: {exit_count}") + logging.info(f"Unique voxels with entry (1): {np.sum(mask_data == 1)}, " + f"exit (2): {np.sum(mask_data == 2)}") + def track(self): """ Generate a set of streamline from seed, mask and odf files. @@ -443,24 +496,44 @@ def _propagate_line(self, line, previous_dir): """ invalid_direction_count = 0 propagation_can_continue = True + in_rap_region = False # Track whether we're currently in RAP region + step_count = 0 while len(line) < self.max_nbr_pts and propagation_can_continue: # Call the RAP function if needed. Can advance of as many points # as they want. - if (propagation_can_continue and self.rap and - self.rap.is_in_rap_region( - line[-1], space=self.space, origin=self.origin)): + is_currently_in_rap = (propagation_can_continue and self.rap and + self.rap.is_in_rap_region( + line[-1], space=self.space, origin=self.origin)) + + # Detect entering RAP region + if is_currently_in_rap and not in_rap_region: + self.rap_entry_exit_coords.append((line[-1].copy(), 1)) # 1 for entry + in_rap_region = True + logging.debug(f"TRACKER ENTERING pos={np.round(line[-1], 2)}") + + if is_currently_in_rap: + prev_len = len(line) line, new_dir, is_line_valid = ( self.rap.rap_multistep_propagate(line, previous_dir)) + if not is_line_valid: + logging.debug(f"TRACKER invalid, stop") + break + if len(line) == prev_len: + logging.debug(f"TRACKER no progress, stop") + propagation_can_continue = False + break + new_pos = line[-1] - if is_line_valid: - invalid_direction_count = 0 - else: + # Verify that our RAP propagated point stays within the tracking mask + propagation_can_continue = self._verify_stopping_criteria(new_pos) + if not propagation_can_continue: + logging.debug(f"TRACKER out of mask, stop.") + line.pop() break - new_pos = line[-1] - # Else, "normal" one-step propagation + step_count += 1 else: new_pos, new_dir, is_direction_valid = \ self.propagator.propagate(line, previous_dir) @@ -474,12 +547,13 @@ def _propagate_line(self, line, previous_dir): if invalid_direction_count > self.max_invalid_dirs: break - propagation_can_continue = self._verify_stopping_criteria(new_pos) - if propagation_can_continue or self.append_last_point: - line.append(new_pos) + propagation_can_continue = self._verify_stopping_criteria(new_pos) + if propagation_can_continue or self.append_last_point: + line.append(new_pos) previous_dir = new_dir + logging.debug(f"TRACKER end of propagation: {len(line)} total points, last pos={np.round(line[-1], 2)}") return line def _verify_stopping_criteria(self, last_pos):