Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 75 additions & 28 deletions src/scilpy/cli/scil_tracking_local_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,11 @@
assert_inputs_exist, assert_outputs_exist,
parse_sh_basis_arg, verify_compression_th,
load_matrix_in_any_format)
from scilpy.io.tensor import (convert_tensor_to_dipy_format,
supported_tensor_formats,
tensor_format_description)
from scilpy.image.volume_space_management import DataVolume
from scilpy.tracking.propagator import ODFPropagator
from scilpy.tracking.propagator import ODFPropagator, TensorPropagator
from scilpy.tracking.rap import RAPContinue
from scilpy.tracking.seed import SeedGenerator, CustomSeedsDispenser
from scilpy.tracking.tracker import Tracker
Expand All @@ -83,14 +86,37 @@ def _build_arg_parser():
formatter_class=argparse.RawTextHelpFormatter,
epilog=version_string)

# Input data options
data_g = p.add_argument_group('Input data options')
data_group = data_g.add_mutually_exclusive_group(required=True)
data_group.add_argument('--in_odf',
help='Path to the ODF SH coefficient file (.nii.gz).\n'
'Use this for ODF-based tracking.')
data_group.add_argument('--in_tensor',
help='Path to the DTI tensor file (.nii.gz).\n'
'Use this for tensor-based tracking.')
p.add_argument('in_seed',
help='Seeding mask (.nii.gz).')
p.add_argument('in_mask',
help='Tracking mask (.nii.gz).\n'
'Tracking will stop outside this mask. The last point '
'of each \nstreamline (triggering the stopping '
'criteria) IS added to the streamline.')
p.add_argument('out_tractogram',
help='Tractogram output file (must be .trk or .tck).')

# Options common to both scripts
add_mandatory_options_tracking(p)
track_g = add_tracking_options(p)
add_seeding_options(p)
track_g = add_tracking_options(p)

# Options only for here.
track_g.add_argument('--algo', default='prob', choices=['det', 'prob'],
help='Algorithm to use. [%(default)s]')
track_g.add_argument('--tensor_format', type=str, default='dipy',
choices=supported_tensor_formats,
help="Format of the input tensor file.\n"
"Only used with --in_tensor. [%(default)s]\n" +
tensor_format_description)
add_sphere_arg(track_g, symmetric_only=False)
track_g.add_argument('--sub_sphere',
type=int, default=0,
Expand Down Expand Up @@ -178,7 +204,11 @@ def main():
parser.error('Invalid output streamline file format (must be trk or ' +
'tck): {0}'.format(args.out_tractogram))

inputs = [args.in_odf, args.in_seed, args.in_mask]
inputs = [args.in_seed, args.in_mask]
if args.in_odf:
inputs.append(args.in_odf)
if args.in_tensor:
inputs.append(args.in_tensor)
assert_inputs_exist(parser, inputs)
assert_outputs_exist(parser, args, args.out_tractogram)

Expand All @@ -205,7 +235,8 @@ def main():
max_nbr_pts = int(args.max_length / args.step_size)
min_nbr_pts = max(int(args.min_length / args.step_size), 1)

assert_same_resolution([args.in_mask, args.in_odf, args.in_seed])
input_data_file = args.in_odf if args.in_odf else args.in_tensor
assert_same_resolution([args.in_mask, input_data_file, args.in_seed])

# Choosing our space and origin for this tracking
# If save_seeds, space and origin must be vox, center. Choosing those
Expand Down Expand Up @@ -254,29 +285,45 @@ def main():
mask = DataVolume(mask_data, mask_res, args.mask_interp)

# ------- INSTANTIATING PROPAGATOR -------
logging.info("Loading ODF SH data.")
odf_sh_img = nib.load(args.in_odf)
odf_sh_data = odf_sh_img.get_fdata(caching='unchanged', dtype=float)
odf_sh_res = odf_sh_img.header.get_zooms()[:3]
dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp)

logging.info("Instantiating propagator.")
# Converting step size to vox space
# We only support iso vox for now but allow slightly different vox 1e-3.
assert np.allclose(np.mean(odf_sh_res[:3]),
odf_sh_res, atol=1e-03)
voxel_size = odf_sh_img.header.get_zooms()[0]
vox_step_size = args.step_size / voxel_size

# Using space and origin in the propagator: vox and center, like
# in dipy.
sh_basis, is_legacy = parse_sh_basis_arg(args)

propagator = ODFPropagator(
dataset, vox_step_size, args.rk_order, args.algo, sh_basis,
args.sf_threshold, args.sf_threshold_init, theta, args.sphere,
sub_sphere=args.sub_sphere,
space=our_space, origin=our_origin, is_legacy=is_legacy)
if args.in_odf:
logging.info("Loading ODF SH data.")
odf_sh_img = nib.load(args.in_odf)
odf_sh_data = odf_sh_img.get_fdata(caching='unchanged', dtype=float)
odf_sh_res = odf_sh_img.header.get_zooms()[:3]
dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp)

logging.info("Instantiating ODF propagator.")
voxel_size = odf_sh_img.header.get_zooms()[0]
vox_step_size = args.step_size / voxel_size
sh_basis, is_legacy = parse_sh_basis_arg(args)

propagator = ODFPropagator(
dataset, vox_step_size, args.rk_order, args.algo, sh_basis,
args.sf_threshold, args.sf_threshold_init, theta, args.sphere,
sub_sphere=args.sub_sphere,
space=our_space, origin=our_origin, is_legacy=is_legacy)

else: # args.in_tensor
logging.info("Loading DTI tensor data.")
tensor_img = nib.load(args.in_tensor)
tensor_data = tensor_img.get_fdata(caching='unchanged', dtype=float)

# Convert tensor to dipy format if needed
if args.tensor_format != 'dipy':
logging.info(f"Converting tensor from {args.tensor_format} to dipy format.")
tensor_data = convert_tensor_to_dipy_format(tensor_data, args.tensor_format)

logging.info(f"Tensor data shape: {tensor_data.shape}")
tensor_res = tensor_img.header.get_zooms()[:3]
dataset = DataVolume(tensor_data, tensor_res, args.mask_interp)

logging.info("Instantiating tensor propagator.")
voxel_size = tensor_img.header.get_zooms()[0]
vox_step_size = args.step_size / voxel_size

propagator = TensorPropagator(
dataset, vox_step_size, args.rk_order, args.algo,
theta, space=our_space, origin=our_origin)

# ------- INSTANTIATING RAP OBJECT -------
if args.rap_mask:
Expand Down
6 changes: 3 additions & 3 deletions src/scilpy/cli/tests/test_tracking_local_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_execution_tracking_fodf(script_runner, monkeypatch):
'fodf.nii.gz')
in_mask = os.path.join(SCILPY_HOME, 'tracking',
'seeding_mask.nii.gz')
ret = script_runner.run(['scil_tracking_local_dev', in_fodf,
ret = script_runner.run(['scil_tracking_local_dev', '--in_odf', in_fodf,
in_mask, in_mask, 'local_prob.trk', '--nt', '10',
'--compress', '0.1', '--sh_basis', 'descoteaux07',
'--min_length', '20', '--max_length', '200',
Expand All @@ -44,7 +44,7 @@ def test_execution_tracking_rap(script_runner, monkeypatch):

in_rap_mask = os.path.join(SCILPY_HOME, 'tracking',
'seeding_mask.nii.gz')
ret = script_runner.run(['scil_tracking_local_dev', in_fodf,
ret = script_runner.run(['scil_tracking_local_dev', '--in_odf', in_fodf,
in_mask, in_mask, 'local_prob_rap.trk',
'--nt', '10',
'--compress', '0.1', '--sh_basis', 'descoteaux07',
Expand All @@ -68,7 +68,7 @@ def test_execution_tracking_fodf_custom_seeds(script_runner, monkeypatch):
custom_seeds = [[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]
np.save(in_custom_seeds, custom_seeds)

ret = script_runner.run(['scil_tracking_local_dev', in_fodf,
ret = script_runner.run(['scil_tracking_local_dev', '--in_odf', in_fodf,
in_mask, in_mask, 'local_prob2.trk',
'--in_custom_seeds', in_custom_seeds,
'--compress', '0.1', '--sh_basis', 'descoteaux07',
Expand Down
159 changes: 159 additions & 0 deletions src/scilpy/tracking/propagator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dipy.data import get_sphere
from dipy.io.stateful_tractogram import Space, Origin
from dipy.reconst.shm import sh_to_sf_matrix
from dipy.reconst.dti import eig_from_lo_tri

from scilpy.reconst.utils import (get_sphere_neighbours,
get_sh_order_and_fullness)
Expand Down Expand Up @@ -691,3 +692,161 @@ def _get_possible_next_dirs(self, pos, v_in):
valid_volumes = np.array(valid_volumes)

return valid_dirs, valid_volumes


class TensorPropagator(AbstractPropagator):
"""
Propagator for DTI tensor tracking. Tracks along the principal
eigenvector of the diffusion tensor.
"""
def __init__(self, datavolume, step_size, rk_order, algo, theta,
space=Space('vox'), origin=Origin('center')):
"""
Parameters
----------
datavolume: scilpy.image.volume_space_management.DataVolume
Trackable DataVolume object containing tensor data in lower
triangular format (6 coefficients: Dxx, Dxy, Dyy, Dxz, Dyz, Dzz).
step_size: float
The step size for tracking.
rk_order: int
Order for the Runge Kutta integration.
algo: string
Type of algorithm. Choices are 'det' or 'prob'
theta: float
Maximum angle (radians) between two steps.
space: dipy Space
Space of the streamlines during tracking. Default: VOX.
origin: dipy Origin
Origin of the streamlines during tracking. Default: center.
"""
super().__init__(datavolume, step_size, rk_order, space, origin)

if self.space == Space.RASMM:
raise NotImplementedError(
"This version of the propagator on tensors is not ready to work "
"in RASMM space.")

self.algo = algo
self.theta = theta
self.normalize_directions = True
self.line_rng_generator = None

def reset_data(self, new_data=None):
return super().reset_data(new_data)

def prepare_forward(self, seeding_pos, random_generator):
"""Get initial direction from tensor at seeding position."""
self.line_rng_generator = random_generator

# Get tensor at seeding position
tensor_data = self.datavolume.get_value_at_coordinate(
*seeding_pos, space=self.space, origin=self.origin)

if tensor_data is None:
logging.debug(f"Seed at {seeding_pos}: tensor_data is None")
return PropagationStatus.ERROR

if np.all(tensor_data == 0):
logging.debug(f"Seed at {seeding_pos}: tensor_data is all zeros")
return PropagationStatus.ERROR

# Get principal eigenvector
direction = self._get_direction_from_tensor(tensor_data)

if direction is None:
logging.debug(f"Seed at {seeding_pos}: failed to extract direction from tensor")
return PropagationStatus.ERROR

return TrackingDirection(direction)

def prepare_backward(self, line, forward_dir):
"""Flip direction for backward tracking."""
# forward_dir is a TrackingDirection (which is a list)
return TrackingDirection(-np.array(forward_dir))

def finalize_streamline(self, last_pos, v_in):
return super().finalize_streamline(last_pos, v_in)

def propagate(self, line, v_in):
"""Propagate using Runge-Kutta integration."""
return super().propagate(line, v_in)

def _sample_next_direction(self, pos, v_in):
"""Sample next direction from tensor."""
tensor_data = self.datavolume.get_value_at_coordinate(
*pos, space=self.space, origin=self.origin)

if tensor_data is None or np.all(tensor_data == 0):
return None

# Get principal eigenvector
direction = self._get_direction_from_tensor(tensor_data)

if direction is None:
return None

# Check angle constraint
cosine = np.dot(v_in, direction) / (np.linalg.norm(v_in) * np.linalg.norm(direction))
cosine = np.clip(cosine, -1, 1)

# Flip if needed to maintain direction continuity
if cosine < 0:
direction = -direction
cosine = abs(cosine)

if np.arccos(cosine) > self.theta:
return None

# For probabilistic tracking, add some noise
if self.algo == 'prob' and self.line_rng_generator is not None:
# Add gaussian noise to direction
noise = self.line_rng_generator.normal(0, 0.1, 3)
Copy link
Contributor

@gabknight gabknight Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make the normal std a member of the TensorPropagator class, with these default values? Potentially adding it too to scil_tracking_local_dev.py if these is use cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be less intuitive, but instead of 'det' or 'prob' could be just set the noise STD, in the doc you could suggest 0.1 std provide good prob results. if STD is 0 (default), you skip this step.

direction = direction + noise
direction = direction / np.linalg.norm(direction)

return direction

def _get_direction_from_tensor(self, tensor_data):
"""
Extract principal eigenvector and FA from tensor data.

Parameters
----------
tensor_data : ndarray
Tensor coefficients in lower triangular format (6 values).

Returns
-------
direction : ndarray or None
Principal eigenvector (3D direction).
"""
if len(tensor_data) != 6:
logging.warning(f"Expected 6 tensor coefficients, got {len(tensor_data)}")
return None, 0.0

logging.debug(f"Tensor data: {tensor_data}")

# Compute eigenvalues and eigenvectors
# eig_from_lo_tri returns a flat array of 12 values:
# [eval1, eval2, eval3, evec1_x, evec1_y, evec1_z, evec2_x, evec2_y, evec2_z, evec3_x, evec3_y, evec3_z]
try:
result = eig_from_lo_tri(tensor_data)
evals = result[:3] # First 3 values are eigenvalues
evecs = result[3:].reshape(3, 3) # Next 9 values are eigenvectors (3x3 matrix)
logging.debug(f"Eigenvalues: {evals}, Eigenvectors shape: {evecs.shape}")
except Exception as e:
logging.debug(f"Exception in eig_from_lo_tri: {e}")
return None

# Sort by eigenvalue magnitude (largest first)
order = np.argsort(evals)[::-1]
evals = evals[order]
evecs = evecs[:, order]

# Principal eigenvector (associated with largest eigenvalue)
principal_evec = evecs[:, 0]

logging.debug(f"Principal eigenvector: {principal_evec}, Sorted eigenvalues: {evals}") # Calculate FA

return principal_evec