Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ authors:
- family-names: "Chakravarty"
given-names: "M. Mallar"
title: "RABIES: Rodent Automated Bold Improvement of EPI Sequences."
version: 0.5.5
date-released: 2025-12-19
version: 0.6.0
date-released: 2026-03-26
url: "https://github.com/CoBrALab/RABIES"


Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
author = 'CoBrALab'

# The full version, including alpha/beta/rc tags
release = '0.5.5'
release = '0.6.0'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ sphinxcontrib-bibtex
sphinxcontrib-programoutput
jinja2==3.1.1
pillow==10.1.0
rabies==0.5.5
rabies==0.6.0
traits<7.0
2 changes: 1 addition & 1 deletion rabies/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
# 88YbdP88 8P 88""" dP__Yb Yb 88"Yb dP__Yb Yb "88 88""
# 88 YY 88 dP 88 dP""""Yb YboodP 88 Yb dP""""Yb YboodP 888888

VERSION = (0, 5, 5)
VERSION = (0, 6, 0)

__version__ = '.'.join(map(str, VERSION))
2 changes: 1 addition & 1 deletion rabies/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ def get_parser():
'--timeseries_interval', type=str, default='0-end',
help=
"Before confound correction, can crop the timeseries within a specific interval.\n"
"e.g. '0,80' for timepoint 0 to 80. 0 is the first time frame, and 'end' stands for \n"
"e.g. '0-80' for timepoint 0 to 80. 0 is the first time frame, and 'end' stands for \n"
"the last time frame. \n"
"(default: %(default)s)\n"
"\n"
Expand Down
36 changes: 23 additions & 13 deletions rabies/preprocess_pkg/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def init_bold_hmc_wf(opts, name='bold_hmc_wf'):
])

if opts.hmc_qc_report['apply']:
# multiply mem_gb 20X because there is massive RAM overload with this node
hmc_qc_node = pe.Node(HMC_QC(fps=opts.hmc_qc_report['fps'], figure_format=opts.figure_format),
name='hmc_qc_node', mem_gb=1.1*opts.scale_min_memory, n_procs=n_procs)
name='hmc_qc_node', mem_gb=1.1*opts.scale_min_memory*20, n_procs=n_procs)

workflow.connect([
(inputnode, hmc_qc_node, [('ref_image', 'ref_file'),
Expand Down Expand Up @@ -246,30 +247,35 @@ class HMC_QC(BaseInterface):

def _run_interface(self, runtime):
import pandas as pd
from simpleitk_timeseries_motion_correction.create_animation import main
filename = pathlib.Path(self.inputs.in_file).name.rsplit(".nii")[0]
figure_path = os.path.abspath(f'{filename}_HMC_QC.{self.inputs.figure_format}')
csv_path = os.path.abspath(f'{filename}_derivatives.csv')

n_procs = int(os.environ['RABIES_ITK_NUM_THREADS']) if "RABIES_ITK_NUM_THREADS" in os.environ else os.cpu_count() # default to number of CPUs

derivatives_dict = HMC_derivatives(self.inputs.in_file, self.inputs.ref_file, self.inputs.csv_params, get_R2=False, n_procs=n_procs)
print('Calculating HMC derivatives for QC visualization...')
img_preHMC, img_postHMC, derivatives_dict = HMC_derivatives(self.inputs.in_file, self.inputs.ref_file, self.inputs.csv_params, get_R2=False, n_procs=n_procs)

print('Creating video pre/post HMC...')
# write .webp file
video_file = os.path.abspath(f'{filename}_HMC.webp')
main(input_img=img_preHMC,
output_file=video_file,
additional_input_imgs=[img_postHMC],
labels=['Before correction', 'After correction'],
scale=2.0,
fps=self.inputs.fps)
del img_preHMC, img_postHMC

# save some outputs to .csv
key_l = ['D_Sc_preHMC', 'D_Sc_postHMC', 'mse_preHMC', 'mse_postHMC']
pd.DataFrame(np.array([derivatives_dict[key].flatten() for key in key_l]).T, columns=key_l).to_csv(csv_path)

print('Creating QC figure...')
fig = plot_motion_QC(derivatives_dict, self.inputs.ref_file, plot_R2=False)
fig.savefig(figure_path, bbox_inches='tight')

# write .webp file
video_file = os.path.abspath(f'{filename}_HMC.webp')
from simpleitk_timeseries_motion_correction.create_animation import main
main(input_img=derivatives_dict['img_preHMC'],
output_file=video_file,
additional_input_imgs=[derivatives_dict['img_postHMC']],
labels=['Before correction', 'After correction'],
scale=2.0,
fps=self.inputs.fps)
setattr(self, 'out_figure', figure_path)
setattr(self, 'out_csv', csv_path)
setattr(self, 'video_file', video_file)
Expand Down Expand Up @@ -339,7 +345,8 @@ def HMC_derivatives(in_img, in_ref, motcorr_params_file, n_procs=1, get_R2=False
ref_img = sitk.ReadImage(in_ref)
else:
raise ValueError(f"in_ref must be an SITK image or valid file path. Got: {in_ref}")

del in_img, in_ref

# prepare timeseries post-correction
transforms = read_transforms_from_csv(motcorr_params_file)
img_postHMC = framewise_resample_volume(
Expand All @@ -366,17 +373,20 @@ def HMC_derivatives(in_img, in_ref, motcorr_params_file, n_procs=1, get_R2=False

mse_preHMC = np.mean((timeseries_preHMC.T - ref_img_array)**2, axis=0) # taking mean square error
mse_postHMC = np.mean((timeseries_postHMC.T - ref_img_array)**2, axis=0) # taking mean square error
del ref_img_array, timeseries_preHMC, timeseries_postHMC

img_preHMC_SD = get_SD(img_preHMC)
img_postHMC_SD = get_SD(img_postHMC)

img_preHMC_R2 = get_motion_R2(img_preHMC, translations,rotations) if get_R2 else None
img_postHMC_R2 = get_motion_R2(img_postHMC, translations,rotations) if get_R2 else None

return {'img_preHMC':img_preHMC, 'img_postHMC':img_postHMC, 'translations':translations, 'rotations':rotations,
derivatives_dict = {'translations':translations, 'rotations':rotations,
'D_Sc_preHMC':D_Sc_preHMC, 'D_Sc_postHMC':D_Sc_postHMC,'mse_preHMC':mse_preHMC, 'mse_postHMC':mse_postHMC,
'img_preHMC_SD':img_preHMC_SD, 'img_postHMC_SD':img_postHMC_SD, 'img_preHMC_R2':img_preHMC_R2, 'img_postHMC_R2':img_postHMC_R2}

return img_preHMC, img_postHMC, derivatives_dict


def plot_motion_QC(derivatives_dict, ref_file, plot_R2=False):
import matplotlib.pyplot as plt
Expand Down
Loading