Skip to content
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,5 @@ venv/

# Hidden folder
.hidden/

.DS_Store
69 changes: 55 additions & 14 deletions src/scilpy/cli/scil_tracking_local_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
load_matrix_in_any_format)
from scilpy.image.volume_space_management import DataVolume
from scilpy.tracking.propagator import ODFPropagator
from scilpy.tracking.rap import RAPContinue
from scilpy.tracking.rap import RAPContinue, RAPSwitch
from scilpy.tracking.seed import SeedGenerator, CustomSeedsDispenser
from scilpy.tracking.tracker import Tracker
from scilpy.tracking.utils import (add_mandatory_options_tracking,
Expand All @@ -76,6 +76,8 @@
verify_streamline_length_options,
verify_seed_options)
from scilpy.version import version_string
from scilpy.image.labels import get_data_as_labels
from scilpy.io.image import get_data_as_mask


def _build_arg_parser():
Expand Down Expand Up @@ -149,15 +151,28 @@ def _build_arg_parser():
"fixed --rng_seed.\nEx: If tractogram_1 was created "
"with -nt 1,000,000, \nyou can create tractogram_2 "
"with \n--skip 1,000,000.")

track_g.add_argument('--rap_mask', default=None,
rap_mode = track_g.add_mutually_exclusive_group()
rap_mode.add_argument('--rap_mask', default=None,
help='Region-Adaptive Propagation mask (.nii.gz).\n'
'Region-Adaptive Propagation tractography will start within '
'this mask.')
rap_mode.add_argument('--rap_labels', default=None,
help='Region-Adaptive Propagation label volume (.nii.gz) .\n'
'Voxel values are integer labels (0=background, 1..N=regions) .\n'
'Used with --rap_method switch to select policies per label.')
track_g.add_argument('--rap_method', default='None',
choices=['None', 'continue'],
help="Region-Adaptive Propagation tractography method "
choices=['None', 'continue', 'switch'],
help="Region-Adaptive Propagation tractography method.\n"
"'continue': continues tracking with same params,\n"
"'switch': switches tracking params inside RAP mask.\n"
" [%(default)s]")
track_g.add_argument('--rap_params', default=None,
help='JSON file containing RAP parameters.\n'
'Required for rap_method=switch. Format:\n'
'{"step_size": float, "theta": float (degrees)}')
track_g.add_argument('--rap_save_entry_exit', default=None,
help='Save RAP entry/exit coordinates as a binary mask.\n'
'Provide output filename (.nii.gz).')

m_g = p.add_argument_group('Memory options')
add_processes_arg(m_g)
Expand Down Expand Up @@ -186,11 +201,16 @@ def main():
verify_compression_th(args.compress_th)
verify_seed_options(parser, args)

if args.rap_mask is not None and args.rap_method == "None":
if (args.rap_mask is not None or args.rap_labels is not None) and args.rap_method == "None":
parser.error('No RAP method selected.')
if not args.rap_method == "None" and args.rap_mask is None:
parser.error('No RAP mask selected.')

if args.rap_method == 'continue' and args.rap_mask is None:
parser.error('RAP method "continue" requires --rap_mask.')
if args.rap_method == 'switch' and (args.rap_mask is None and args.rap_labels is None):
parser.error('RAP method "switch" requires --rap_mask or --rap_labels.')
if args.rap_method == 'switch' and args.rap_params is None:
parser.error('RAP method "switch" requires --rap_params to be specified.')
if args.rap_params is not None and args.rap_method != 'switch':
parser.error('--rap_params can only be used with --rap_method switch.')
tracts_format = detect_format(args.out_tractogram)
if tracts_format is not TrkFile:
logging.warning("You have selected option --save_seeds but you are "
Expand Down Expand Up @@ -279,18 +299,35 @@ def main():
space=our_space, origin=our_origin, is_legacy=is_legacy)

# ------- INSTANTIATING RAP OBJECT -------
rap_mask = None
rap_labels = None
if args.rap_mask:
logging.info("Loading RAP mask.")
rap_img = nib.load(args.rap_mask)
rap_data = rap_img.get_fdata(caching='unchanged', dtype=float)
rap_res = rap_img.header.get_zooms()[:3]
rap_mask = DataVolume(rap_data, rap_res, args.mask_interp)
else:
rap_mask = None
rap_mask_data = get_data_as_mask(rap_img)
rap_mask_res = rap_img.header.get_zooms()[:3]
rap_mask = DataVolume(rap_mask_data, rap_mask_res, args.mask_interp)

if args.rap_labels:
logging.info("Loading RAP labels.")
rap_label_img = nib.load(args.rap_labels)

# Convert the rap_labels image to int if float
if np.issubdtype(rap_label_img.get_data_dtype(), np.floating):
int_data = np.round(rap_label_img.get_fdata()).astype(np.int16)
rap_label_img = nib.Nifti1Image(int_data, rap_label_img.affine)

rap_label_data = get_data_as_labels(rap_label_img)
rap_label_res = rap_label_img.header.get_zooms()[:3]
rap_labels = DataVolume(rap_label_data, rap_label_res, 'nearest')
rap_mask = rap_labels

if args.rap_method == "continue":
rap = RAPContinue(rap_mask, propagator, max_nbr_pts,
step_size=vox_step_size)
elif args.rap_method == "switch":
rap = RAPSwitch(rap_mask, propagator, max_nbr_pts,
rap_params_file=args.rap_params)
else:
rap = None

Expand Down Expand Up @@ -323,6 +360,10 @@ def main():
else:
data_per_streamline = {}

# Save RAP entry/exit mask if requested
if args.rap_save_entry_exit:
tracker.save_rap_entry_exit_mask(args.rap_save_entry_exit, mask_img)

# Compared with scil_tracking_local, using sft rather than
# LazyTractogram to deal with space.
# Contrary to scilpy or dipy, where space after tracking is vox, here
Expand Down
167 changes: 164 additions & 3 deletions src/scilpy/tracking/rap.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# -*- coding: utf-8 -*-

import json
import logging
import numpy as np
from copy import deepcopy
from scilpy.tracking.propagator import get_sphere_neighbours


class RAP:
Expand All @@ -12,6 +16,9 @@ def __init__(self, mask_rap, propagator, max_nbr_pts):
self.rap_mask = mask_rap
self.propagator = propagator
self.max_nbr_pts = max_nbr_pts
self._current_label = None
self._total_steps = 0
self._current_cfg = {}

def is_in_rap_region(self, curr_pos, space, origin):
return self.rap_mask.get_value_at_coordinate(
Expand Down Expand Up @@ -53,18 +60,172 @@ def __init__(self, mask_rap, propagator, max_nbr_pts, step_size):

def rap_multistep_propagate(self, line, prev_direction):
is_line_valid = True
if len(line)>3:
if len(line) > 3:
pos = line[-2] + self.step_size * np.array(prev_direction)
line[-1] = pos
return line, prev_direction, is_line_valid
return line, prev_direction, is_line_valid


class RAPSwitch(RAP):
"""RAP class that switches tracking parameters when inside the RAP mask or RAP label."""
def __init__(self, mask_rap, propagator, max_nbr_pts, rap_params_file):
"""
Parameters
----------
mask_rap : DataVolume
Region-Adaptive Propagation mask.
propagator : Propagator
The propagator used for tracking.
max_nbr_pts : int
Maximum number of points per streamline.
rap_params_file : str
Path to JSON file containing RAP parameters.
"methods" is optionnal, if not provided, "default" will be applied
Expected format:
{
"methods": {
"1": {"algo": str, "theta": float, "step_size": float},
"2": {"algo": str, "theta": float, "step_size": float},
...
}
}
"""
super().__init__(mask_rap, propagator, max_nbr_pts)

# Load parameters from JSON file
with open(rap_params_file, 'r') as f:
rap_params = json.load(f)

self._base = {
'step_size': propagator.step_size,
'theta': propagator.theta,
'algo' : getattr(propagator, 'algo', None),
'tracking_neighbours' : getattr(propagator, 'tracking_neighbours', None)
}
self.methods_cfg = rap_params.get('methods', {})
logging.info("RAP parameters loaded:")

def rap_multistep_propagate(self, line, prev_direction):
"""
Propagate within the RAP region using modified parameters.

Parameters
----------
line : list
The current streamline.
prev_direction : np.ndarray
The previous tracking direction.

Returns
-------
line : list
The extended streamline.
prev_direction : np.ndarray
The last direction.
is_line_valid : bool
Whether the line is valid.
"""
# Switch to RAP parameters
label = self._get_label(line[-1], self.propagator.space, self.propagator.origin)
if label <= 0:
return line, prev_direction, False
# Apply the parameters of the RAP labels
cfg = self._merge_cfg(label)

# Perform propagation with new parameters
self._apply_cfg(cfg)
new_pos, new_dir, is_direction_valid = self.propagator.propagate(line, prev_direction)

# Add the new point to the line
if is_direction_valid:
line.append(new_pos)
if label != self._current_label:
if self._current_label is not None:
logging.debug(f"STEP[{self._total_steps}] label={self._current_label} algo={self._current_cfg.get('algo')} theta={self._current_cfg.get('theta')} step={self._current_cfg.get('step_size')}")
self._current_label = label
self._current_cfg = cfg
self._total_steps += 1
return line, new_dir, True
return line, prev_direction, False


def _get_label(self, curr_pos, space, origin):
"""
Receive label (int) at current position in RAP label volume.

Parameters
----------
curr_pos: np.ndarray
This is the current 3D position of the streamline.

space: Space
Coordinate space (here Space.VOX.).

origin: Origin
Origin convention ('center').

Returns
-------
int
The integer label at current position.
"""
v = self.rap_mask.get_value_at_coordinate(*curr_pos, space=space, origin=origin)
try:
return int(v)
except Exception:
return int(np.round(v))


def _merge_cfg(self, label):
"""
Merge the default configuration with the label-specific cfg override from the JSON policy.

Parameters
----------
label: int
Integer of label at current position.

Returns
-------
dict
Configuration dict with keys 'algo', 'theta', 'step_size'.
"""
override = self.methods_cfg.get(str(label))
if override is None:
if label != 1 and label != self._current_label:
logging.warning(f"Label {label} not found in methods, base params used.")
return {
'step_size': self._base['step_size'],
'algo': self._base['algo'],
'theta': float(np.degrees(self._base['theta']))
}
return deepcopy(override)


def _apply_cfg(self, cfg):
"""
Temporarily apply a label configuration to the propagator.

Parameters
----------
cfg: dict
Configuration dict with keys 'algo', 'theta', 'step_size'.
"""
if 'step_size' in cfg and cfg['step_size'] is not None:
self.propagator.step_size = float(cfg['step_size'])
if 'algo' in cfg and cfg['algo'] is not None:
self.propagator.algo = str(cfg['algo'])
if 'theta' in cfg and cfg['theta'] is not None:
theta_rad = np.deg2rad(float(cfg['theta']))
self.propagator.theta = theta_rad
# theta change => neighbours change
self.propagator.tracking_neighbours = get_sphere_neighbours(self.propagator.sphere, self.propagator.theta)

class RAPGraph(RAP):
def __init__(self, mask_rap, propagator, max_nbr_pts, neighboorhood_size):
super().__init__(mask_rap, propagator, max_nbr_pts)
self.neighboorhood_size = neighboorhood_size


def rap_multistep_propagate(self, line, prev_direction):
raise NotImplementedError
raise NotImplementedError
Loading
Loading