diff --git a/src/scilpy/cli/scil_tracking_local_dev.py b/src/scilpy/cli/scil_tracking_local_dev.py index 8904efad2..d25e1839d 100755 --- a/src/scilpy/cli/scil_tracking_local_dev.py +++ b/src/scilpy/cli/scil_tracking_local_dev.py @@ -64,8 +64,11 @@ assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, verify_compression_th, load_matrix_in_any_format) +from scilpy.io.tensor import (convert_tensor_to_dipy_format, + supported_tensor_formats, + tensor_format_description) from scilpy.image.volume_space_management import DataVolume -from scilpy.tracking.propagator import ODFPropagator +from scilpy.tracking.propagator import ODFPropagator, TensorPropagator from scilpy.tracking.rap import RAPContinue from scilpy.tracking.seed import SeedGenerator, CustomSeedsDispenser from scilpy.tracking.tracker import Tracker @@ -83,14 +86,37 @@ def _build_arg_parser(): formatter_class=argparse.RawTextHelpFormatter, epilog=version_string) + # Input data options + data_g = p.add_argument_group('Input data options') + data_group = data_g.add_mutually_exclusive_group(required=True) + data_group.add_argument('--in_odf', + help='Path to the ODF SH coefficient file (.nii.gz).\n' + 'Use this for ODF-based tracking.') + data_group.add_argument('--in_tensor', + help='Path to the DTI tensor file (.nii.gz).\n' + 'Use this for tensor-based tracking.') + p.add_argument('in_seed', + help='Seeding mask (.nii.gz).') + p.add_argument('in_mask', + help='Tracking mask (.nii.gz).\n' + 'Tracking will stop outside this mask. The last point ' + 'of each \nstreamline (triggering the stopping ' + 'criteria) IS added to the streamline.') + p.add_argument('out_tractogram', + help='Tractogram output file (must be .trk or .tck).') + # Options common to both scripts - add_mandatory_options_tracking(p) - track_g = add_tracking_options(p) add_seeding_options(p) + track_g = add_tracking_options(p) # Options only for here. 
track_g.add_argument('--algo', default='prob', choices=['det', 'prob'], help='Algorithm to use. [%(default)s]') + track_g.add_argument('--tensor_format', type=str, default='dipy', + choices=supported_tensor_formats, + help="Format of the input tensor file.\n" + "Only used with --in_tensor. [%(default)s]\n" + + tensor_format_description) add_sphere_arg(track_g, symmetric_only=False) track_g.add_argument('--sub_sphere', type=int, default=0, @@ -178,7 +204,11 @@ def main(): parser.error('Invalid output streamline file format (must be trk or ' + 'tck): {0}'.format(args.out_tractogram)) - inputs = [args.in_odf, args.in_seed, args.in_mask] + inputs = [args.in_seed, args.in_mask] + if args.in_odf: + inputs.append(args.in_odf) + if args.in_tensor: + inputs.append(args.in_tensor) assert_inputs_exist(parser, inputs) assert_outputs_exist(parser, args, args.out_tractogram) @@ -205,7 +235,8 @@ def main(): max_nbr_pts = int(args.max_length / args.step_size) min_nbr_pts = max(int(args.min_length / args.step_size), 1) - assert_same_resolution([args.in_mask, args.in_odf, args.in_seed]) + input_data_file = args.in_odf if args.in_odf else args.in_tensor + assert_same_resolution([args.in_mask, input_data_file, args.in_seed]) # Choosing our space and origin for this tracking # If save_seeds, space and origin must be vox, center. Choosing those @@ -254,29 +285,45 @@ def main(): mask = DataVolume(mask_data, mask_res, args.mask_interp) # ------- INSTANTIATING PROPAGATOR ------- - logging.info("Loading ODF SH data.") - odf_sh_img = nib.load(args.in_odf) - odf_sh_data = odf_sh_img.get_fdata(caching='unchanged', dtype=float) - odf_sh_res = odf_sh_img.header.get_zooms()[:3] - dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp) - - logging.info("Instantiating propagator.") - # Converting step size to vox space - # We only support iso vox for now but allow slightly different vox 1e-3. 
- assert np.allclose(np.mean(odf_sh_res[:3]), - odf_sh_res, atol=1e-03) - voxel_size = odf_sh_img.header.get_zooms()[0] - vox_step_size = args.step_size / voxel_size - - # Using space and origin in the propagator: vox and center, like - # in dipy. - sh_basis, is_legacy = parse_sh_basis_arg(args) - - propagator = ODFPropagator( - dataset, vox_step_size, args.rk_order, args.algo, sh_basis, - args.sf_threshold, args.sf_threshold_init, theta, args.sphere, - sub_sphere=args.sub_sphere, - space=our_space, origin=our_origin, is_legacy=is_legacy) + if args.in_odf: + logging.info("Loading ODF SH data.") + odf_sh_img = nib.load(args.in_odf) + odf_sh_data = odf_sh_img.get_fdata(caching='unchanged', dtype=float) + odf_sh_res = odf_sh_img.header.get_zooms()[:3] + dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp) + + logging.info("Instantiating ODF propagator.") + voxel_size = odf_sh_img.header.get_zooms()[0] + vox_step_size = args.step_size / voxel_size + sh_basis, is_legacy = parse_sh_basis_arg(args) + + propagator = ODFPropagator( + dataset, vox_step_size, args.rk_order, args.algo, sh_basis, + args.sf_threshold, args.sf_threshold_init, theta, args.sphere, + sub_sphere=args.sub_sphere, + space=our_space, origin=our_origin, is_legacy=is_legacy) + + else: # args.in_tensor + logging.info("Loading DTI tensor data.") + tensor_img = nib.load(args.in_tensor) + tensor_data = tensor_img.get_fdata(caching='unchanged', dtype=float) + + # Convert tensor to dipy format if needed + if args.tensor_format != 'dipy': + logging.info(f"Converting tensor from {args.tensor_format} to dipy format.") + tensor_data = convert_tensor_to_dipy_format(tensor_data, args.tensor_format) + + logging.info(f"Tensor data shape: {tensor_data.shape}") + tensor_res = tensor_img.header.get_zooms()[:3] + dataset = DataVolume(tensor_data, tensor_res, args.mask_interp) + + logging.info("Instantiating tensor propagator.") + voxel_size = tensor_img.header.get_zooms()[0] + vox_step_size = args.step_size / 
voxel_size + + propagator = TensorPropagator( + dataset, vox_step_size, args.rk_order, args.algo, + theta, space=our_space, origin=our_origin) # ------- INSTANTIATING RAP OBJECT ------- if args.rap_mask: diff --git a/src/scilpy/cli/tests/test_tracking_local_dev.py b/src/scilpy/cli/tests/test_tracking_local_dev.py index b7d95bdee..9e8fdda54 100644 --- a/src/scilpy/cli/tests/test_tracking_local_dev.py +++ b/src/scilpy/cli/tests/test_tracking_local_dev.py @@ -25,7 +25,7 @@ def test_execution_tracking_fodf(script_runner, monkeypatch): 'fodf.nii.gz') in_mask = os.path.join(SCILPY_HOME, 'tracking', 'seeding_mask.nii.gz') - ret = script_runner.run(['scil_tracking_local_dev', in_fodf, + ret = script_runner.run(['scil_tracking_local_dev', '--in_odf', in_fodf, in_mask, in_mask, 'local_prob.trk', '--nt', '10', '--compress', '0.1', '--sh_basis', 'descoteaux07', '--min_length', '20', '--max_length', '200', @@ -44,7 +44,7 @@ def test_execution_tracking_rap(script_runner, monkeypatch): in_rap_mask = os.path.join(SCILPY_HOME, 'tracking', 'seeding_mask.nii.gz') - ret = script_runner.run(['scil_tracking_local_dev', in_fodf, + ret = script_runner.run(['scil_tracking_local_dev', '--in_odf', in_fodf, in_mask, in_mask, 'local_prob_rap.trk', '--nt', '10', '--compress', '0.1', '--sh_basis', 'descoteaux07', @@ -68,7 +68,7 @@ def test_execution_tracking_fodf_custom_seeds(script_runner, monkeypatch): custom_seeds = [[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]] np.save(in_custom_seeds, custom_seeds) - ret = script_runner.run(['scil_tracking_local_dev', in_fodf, + ret = script_runner.run(['scil_tracking_local_dev', '--in_odf', in_fodf, in_mask, in_mask, 'local_prob2.trk', '--in_custom_seeds', in_custom_seeds, '--compress', '0.1', '--sh_basis', 'descoteaux07', diff --git a/src/scilpy/tracking/propagator.py b/src/scilpy/tracking/propagator.py index 777f46c53..ececc569a 100644 --- a/src/scilpy/tracking/propagator.py +++ b/src/scilpy/tracking/propagator.py @@ -7,6 +7,7 @@ from dipy.data import 
get_sphere from dipy.io.stateful_tractogram import Space, Origin from dipy.reconst.shm import sh_to_sf_matrix +from dipy.reconst.dti import eig_from_lo_tri from scilpy.reconst.utils import (get_sphere_neighbours, get_sh_order_and_fullness) @@ -691,3 +692,161 @@ def _get_possible_next_dirs(self, pos, v_in): valid_volumes = np.array(valid_volumes) return valid_dirs, valid_volumes + + +class TensorPropagator(AbstractPropagator): + """ + Propagator for DTI tensor tracking. Tracks along the principal + eigenvector of the diffusion tensor. + """ + def __init__(self, datavolume, step_size, rk_order, algo, theta, + space=Space('vox'), origin=Origin('center')): + """ + Parameters + ---------- + datavolume: scilpy.image.volume_space_management.DataVolume + Trackable DataVolume object containing tensor data in lower + triangular format (6 coefficients: Dxx, Dxy, Dyy, Dxz, Dyz, Dzz). + step_size: float + The step size for tracking. + rk_order: int + Order for the Runge Kutta integration. + algo: string + Type of algorithm. Choices are 'det' or 'prob' + theta: float + Maximum angle (radians) between two steps. + space: dipy Space + Space of the streamlines during tracking. Default: VOX. + origin: dipy Origin + Origin of the streamlines during tracking. Default: center. 
+ """ + super().__init__(datavolume, step_size, rk_order, space, origin) + + if self.space == Space.RASMM: + raise NotImplementedError( + "This version of the propagator on tensors is not ready to work " + "in RASMM space.") + + self.algo = algo + self.theta = theta + self.normalize_directions = True + self.line_rng_generator = None + + def reset_data(self, new_data=None): + return super().reset_data(new_data) + + def prepare_forward(self, seeding_pos, random_generator): + """Get initial direction from tensor at seeding position.""" + self.line_rng_generator = random_generator + + # Get tensor at seeding position + tensor_data = self.datavolume.get_value_at_coordinate( + *seeding_pos, space=self.space, origin=self.origin) + + if tensor_data is None: + logging.debug(f"Seed at {seeding_pos}: tensor_data is None") + return PropagationStatus.ERROR + + if np.all(tensor_data == 0): + logging.debug(f"Seed at {seeding_pos}: tensor_data is all zeros") + return PropagationStatus.ERROR + + # Get principal eigenvector + direction = self._get_direction_from_tensor(tensor_data) + + if direction is None: + logging.debug(f"Seed at {seeding_pos}: failed to extract direction from tensor") + return PropagationStatus.ERROR + + return TrackingDirection(direction) + + def prepare_backward(self, line, forward_dir): + """Flip direction for backward tracking.""" + # forward_dir is a TrackingDirection (which is a list) + return TrackingDirection(-np.array(forward_dir)) + + def finalize_streamline(self, last_pos, v_in): + return super().finalize_streamline(last_pos, v_in) + + def propagate(self, line, v_in): + """Propagate using Runge-Kutta integration.""" + return super().propagate(line, v_in) + + def _sample_next_direction(self, pos, v_in): + """Sample next direction from tensor.""" + tensor_data = self.datavolume.get_value_at_coordinate( + *pos, space=self.space, origin=self.origin) + + if tensor_data is None or np.all(tensor_data == 0): + return None + + # Get principal eigenvector + 
direction = self._get_direction_from_tensor(tensor_data)
+
+        if direction is None:
+            return None
+
+        # Check angle constraint
+        cosine = np.dot(v_in, direction) / (np.linalg.norm(v_in) * np.linalg.norm(direction))
+        cosine = np.clip(cosine, -1, 1)
+
+        # Flip if needed to maintain direction continuity
+        if cosine < 0:
+            direction = -direction
+            cosine = abs(cosine)
+
+        if np.arccos(cosine) > self.theta:
+            return None
+
+        # For probabilistic tracking, add some noise
+        if self.algo == 'prob' and self.line_rng_generator is not None:
+            # Add gaussian noise to direction
+            noise = self.line_rng_generator.normal(0, 0.1, 3)
+            direction = direction + noise
+            direction = direction / np.linalg.norm(direction)
+
+        return direction
+
+    def _get_direction_from_tensor(self, tensor_data):
+        """
+        Extract the principal eigenvector from tensor data.
+
+        Parameters
+        ----------
+        tensor_data : ndarray
+            Tensor coefficients in lower triangular format (6 values).
+
+        Returns
+        -------
+        direction : ndarray or None
+            Principal eigenvector (3D direction).
+        """
+        if len(tensor_data) != 6:
+            logging.warning(f"Expected 6 tensor coefficients, got {len(tensor_data)}")
+            return None
+
+        logging.debug(f"Tensor data: {tensor_data}")
+
+        # Compute eigenvalues and eigenvectors
+        # eig_from_lo_tri returns a flat array of 12 values:
+        # [eval1, eval2, eval3, evec1_x, evec1_y, evec1_z, evec2_x, evec2_y, evec2_z, evec3_x, evec3_y, evec3_z]
+        try:
+            result = eig_from_lo_tri(tensor_data)
+            evals = result[:3]  # First 3 values are eigenvalues
+            evecs = result[3:].reshape(3, 3)  # Next 9 values are eigenvectors (3x3 matrix)
+            logging.debug(f"Eigenvalues: {evals}, Eigenvectors shape: {evecs.shape}")
+        except Exception as e:
+            logging.debug(f"Exception in eig_from_lo_tri: {e}")
+            return None
+
+        # Sort by eigenvalue magnitude (largest first)
+        order = np.argsort(evals)[::-1]
+        evals = evals[order]
+        evecs = evecs[:, order]
+
+        # Principal eigenvector (associated with largest eigenvalue)
+        principal_evec = evecs[:, 0]
+
+        logging.debug(f"Principal eigenvector: {principal_evec}, Sorted eigenvalues: {evals}")
+
+        return principal_evec