Skip to content

Commit b7100a7

Browse files
authored
Merge pull request #7 from jqmcginnis/main
Update Release v1.1.0 - Reimplementation of tensorflow addons, changes to BIDS naming convention and lesion stats
2 parents 785fc47 + 5011d07 commit b7100a7

File tree

13 files changed

+450
-281
lines changed

13 files changed

+450
-281
lines changed

LST_AI/annotate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ def annotate_lesions(atlas_t1, atlas_mask, t1w_native, seg_native, out_atlas_war
114114

115115
if __name__ == "__main__":
116116

117+
# Only for testing purposes
117118
lst_dir = os.getcwd()
118119
parent_directory = os.path.dirname(lst_dir)
119120
atlas_t1w_path = os.path.join(parent_directory, "atlas", "sub-mni152_space-mni_t1.nii.gz")

LST_AI/custom_tf.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
4+
def load_custom_model(model_path, compile=False):
    """
    Load a Keras model whose saved graph references ``Addons>InstanceNormalization``.

    TensorFlow Addons (tfa) is no longer maintained, so models trained with the
    tfa ``InstanceNormalization`` layer are deserialized with the drop-in
    replacement ``CustomGroupNormalization`` instead, removing the tfa
    dependency at load time.

    Args:
        model_path (str): File path of the saved Keras model.
        compile (bool): Whether to compile the model after loading.
            Defaults to False. (Named ``compile`` to mirror
            ``tf.keras.models.load_model``; intentionally shadows the builtin.)

    Returns:
        tf.keras.Model: The loaded model with every ``InstanceNormalization``
        layer replaced by ``CustomGroupNormalization``.

    Example:
        >>> model = load_custom_model('path/to/model.h5', compile=True)
    """
    replacements = {'Addons>InstanceNormalization': CustomGroupNormalization}
    return tf.keras.models.load_model(model_path,
                                      custom_objects=replacements,
                                      compile=compile)
29+
30+
31+
32+
class CustomGroupNormalization(tf.keras.layers.Layer):
    """
    Drop-in replacement for ``tfa.layers.InstanceNormalization``.

    TensorFlow Addons (tfa) is no longer maintained and is not available on
    Mac ARM platforms, so this layer wraps ``tf.keras.layers.GroupNormalization``
    to provide group/instance normalization without the tfa dependency.

    Args:
        groups (int): Number of groups for Group Normalization. Default is -1
            (the wrapped Keras layer's "one group per channel" setting, i.e.
            instance normalization -- matching the tfa layer being replaced).
        **kwargs: Remaining layer configuration as found in a saved tfa
            ``InstanceNormalization`` config: ``epsilon``, ``center``,
            ``scale``, initializers, regularizers, constraints, plus generic
            ``tf.keras.layers.Layer`` arguments (``name``, ``dtype``, ...).
    """
    def __init__(self, groups=-1, **kwargs):
        # Bug fix: previously `self.groups = kwargs.pop('groups', -1)` always
        # yielded -1, because `groups` is a named parameter and therefore can
        # never appear in **kwargs -- any explicit value (including one
        # restored from a saved model config) was silently discarded.
        self.groups = groups
        self.epsilon = kwargs.pop('epsilon', 0.001)
        self.center = kwargs.pop('center', True)
        self.scale = kwargs.pop('scale', True)
        self.beta_initializer = kwargs.pop('beta_initializer', 'zeros')
        self.gamma_initializer = kwargs.pop('gamma_initializer', 'ones')
        self.beta_regularizer = kwargs.pop('beta_regularizer', None)
        self.gamma_regularizer = kwargs.pop('gamma_regularizer', None)
        self.beta_constraint = kwargs.pop('beta_constraint', None)
        self.gamma_constraint = kwargs.pop('gamma_constraint', None)

        # 'axis' appears in tfa configs but GroupNormalization does not take it.
        kwargs.pop('axis', None)

        super(CustomGroupNormalization, self).__init__(**kwargs)
        # NOTE(review): the leftover **kwargs (e.g. name/dtype/trainable) are
        # forwarded to the inner layer as well, as in the original
        # implementation -- kept for serialization compatibility.
        self.group_norm = tf.keras.layers.GroupNormalization(
            groups=self.groups,
            epsilon=self.epsilon,
            center=self.center,
            scale=self.scale,
            beta_initializer=self.beta_initializer,
            gamma_initializer=self.gamma_initializer,
            beta_regularizer=self.beta_regularizer,
            gamma_regularizer=self.gamma_regularizer,
            beta_constraint=self.beta_constraint,
            gamma_constraint=self.gamma_constraint,
            **kwargs
        )

    def call(self, inputs, training=None):
        # Delegate the forward pass entirely to the wrapped layer.
        return self.group_norm(inputs, training=training)

    def get_config(self):
        """Return the layer configuration so the model can be re-serialized."""
        config = super(CustomGroupNormalization, self).get_config()
        # NOTE(review): initializers/regularizers/constraints are stored as
        # received (strings or deserialized objects) rather than re-serialized;
        # unchanged from the original behavior.
        config.update({
            'groups': self.groups,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': self.beta_initializer,
            'gamma_initializer': self.gamma_initializer,
            'beta_regularizer': self.beta_regularizer,
            'gamma_regularizer': self.gamma_regularizer,
            'beta_constraint': self.beta_constraint,
            'gamma_constraint': self.gamma_constraint
        })
        return config

LST_AI/lst

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,11 @@ import tempfile
1515
import shutil
1616
import argparse
1717

18-
# to filter the warning:
19-
# WARNING:root:The given value for groups will be overwritten.
20-
import logging
21-
class Filter(logging.Filter):
22-
def filter(self, record):
23-
return 'The given value for groups will be overwritten.' not in record.getMessage()
24-
25-
logging.getLogger().addFilter(Filter())
26-
2718
from LST_AI.strip import run_hdbet, apply_mask
2819
from LST_AI.register import mni_registration, apply_warp, rigid_reg
2920
from LST_AI.segment import unet_segmentation
3021
from LST_AI.annotate import annotate_lesions
22+
from LST_AI.stats import compute_stats
3123
from LST_AI.utils import download_data
3224

3325
if __name__ == "__main__":
@@ -135,10 +127,10 @@ if __name__ == "__main__":
135127
os.makedirs(work_dir)
136128

137129
# Define Image Paths (original space)
138-
path_org_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_T1w.nii.gz')
139-
path_org_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_FLAIR.nii.gz')
140-
path_org_stripped_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_desc-stripped_T1w.nii.gz')
141-
path_org_stripped_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_desc-stripped_FLAIR.nii.gz')
130+
path_org_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-t1w_T1w.nii.gz')
131+
path_org_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_FLAIR.nii.gz')
132+
path_org_stripped_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-t1w_desc-stripped_T1w.nii.gz')
133+
path_org_stripped_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_desc-stripped_FLAIR.nii.gz')
142134

143135
# Define Image Paths (MNI space)
144136
path_mni_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_T1w.nii.gz')
@@ -147,15 +139,23 @@ if __name__ == "__main__":
147139
path_mni_stripped_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_desc-stripped_FLAIR.nii.gz')
148140

149141
# Masks
150-
path_orig_brainmask_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-org_T1w_mask.nii.gz')
151-
path_orig_brainmask_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-org_FLAIR_mask.nii.gz')
142+
path_orig_brainmask_t1w = os.path.join(work_dir, 'sub-X_ses-Y_space-t1w_brainmask.nii.gz')
143+
path_orig_brainmask_flair = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_brainmask.nii.gz')
152144
path_mni_brainmask = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_brainmask.nii.gz')
153145

154-
# Segmentation results
155-
path_orig_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_seg.nii.gz')
156-
path_mni_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_seg.nii.gz')
157-
path_orig_annotated_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-orig_seg-annotated.nii.gz')
158-
path_mni_annotated_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_seg-annotated.nii.gz')
146+
# Temp Segmentation results
147+
path_orig_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_seg-lst.nii.gz')
148+
path_mni_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_seg-lst.nii.gz')
149+
path_orig_annotated_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-flair_desc-annotated_seg-lst.nii.gz')
150+
path_mni_annotated_segmentation = os.path.join(work_dir, 'sub-X_ses-Y_space-mni_desc-annotated_seg-lst.nii.gz')
151+
152+
# Output paths (in original space)
153+
filename_output_segmentation = "space-flair_seg-lst.nii.gz"
154+
filename_output_annotated_segmentation = "space-flair_desc-annotated_seg-lst.nii.gz"
155+
156+
# Stats
157+
filename_output_stats_segmentation = "lesion_stats.csv"
158+
filename_output_stats_annotated_segmentation = "annotated_lesion_stats.csv"
159159

160160
# affines
161161
path_affine_mni_t1w = os.path.join(work_dir, 'affine_t1w_to_mni.mat')
@@ -187,6 +187,7 @@ if __name__ == "__main__":
187187

188188
# Annotation only
189189
if args.annotate_only:
190+
print("LST-AI assumes existing segmentation to be in FLAIR space.")
190191
if os.path.isfile(args.existing_seg):
191192
shutil.copy(args.existing_seg, path_orig_segmentation)
192193
else:
@@ -240,7 +241,7 @@ if __name__ == "__main__":
240241
out_annotated_native=path_orig_annotated_segmentation)
241242

242243
shutil.copy(path_orig_annotated_segmentation,
243-
os.path.join(args.output, "space-orig_desc-annotated_seg-lst.nii.gz"))
244+
os.path.join(args.output, filename_output_annotated_segmentation))
244245

245246

246247
# Segmentation only + (opt. Annotation)
@@ -283,8 +284,7 @@ if __name__ == "__main__":
283284

284285
# move processed mask to correct naming convention
285286
hdbet_mask = path_mni_stripped_t1w.replace(".nii.gz", "_mask.nii.gz")
286-
print(hdbet_mask)
287-
shutil.copy(hdbet_mask, path_mni_brainmask)
287+
shutil.move(hdbet_mask, path_mni_brainmask)
288288

289289
# then apply brain mask to FLAIR
290290
apply_mask(input_image=path_mni_flair,
@@ -333,7 +333,7 @@ if __name__ == "__main__":
333333
n_threads=args.threads)
334334

335335
# store the segmentations
336-
shutil.copy(path_orig_segmentation, os.path.join(args.output, "space-orig_seg-lst.nii.gz"))
336+
shutil.copy(path_orig_segmentation, os.path.join(args.output, filename_output_segmentation))
337337

338338
# Annotation
339339
if not args.segment_only:
@@ -354,8 +354,18 @@ if __name__ == "__main__":
354354
n_threads=args.threads)
355355

356356
# store the segmentations
357-
shutil.copy(path_orig_annotated_segmentation, os.path.join(args.output, "space-orig_desc-annotated_seg-lst.nii.gz"))
358-
357+
shutil.copy(path_orig_annotated_segmentation, os.path.join(args.output, filename_output_annotated_segmentation))
358+
359+
# Compute Stats of (annotated) segmentation if they exist
360+
if os.path.exists(path_orig_segmentation):
361+
compute_stats(mask_file=path_orig_segmentation,
362+
output_file=os.path.join(args.output, filename_output_stats_segmentation),
363+
multi_class=False)
364+
365+
if os.path.exists(path_orig_annotated_segmentation):
366+
compute_stats(mask_file=path_orig_annotated_segmentation,
367+
output_file=os.path.join(args.output, filename_output_stats_annotated_segmentation),
368+
multi_class=True)
359369

360370
print(f"Results in {work_dir}")
361371
if not args.temp:

LST_AI/register.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ def apply_warp(image_org_space, affine, origin, target, reverse=False, n_threads
106106

107107
subprocess.run(shlex.split(warp_call), check=True)
108108

109-
110-
111109
if __name__ == "__main__":
112110

111+
# Testing only
112+
113113
# Working directory
114114
script_dir = os.getcwd()
115115
parent_directory = os.path.dirname(script_dir)

LST_AI/segment.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
import numpy as np
66
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
77
import tensorflow as tf
8-
import tensorflow_addons as tfa
9-
#logging.getLogger("tensorflow").setLevel(logging.CRITICAL)
10-
#logging.getLogger("tensorflow_addons").setLevel(logging.CRITICAL)
8+
9+
from LST_AI.custom_tf import load_custom_model
10+
1111

1212
def unet_segmentation(model_path, mni_t1, mni_flair, output_segmentation_path, device='cpu', input_shape=(192,192,192), threshold=0.5):
1313
"""
@@ -99,7 +99,7 @@ def preprocess_intensities(img_arr):
9999
for i, model in enumerate(unet_mdls):
100100
with tf.device(tf_device):
101101
print(f"Running model {i}. ")
102-
mdl = tf.keras.models.load_model(model, compile=False)
102+
mdl = load_custom_model(model, compile=False)
103103

104104
img_image = np.stack([flair, t1], axis=-1)
105105
img_image = np.expand_dims(img_image, axis=0)
@@ -129,7 +129,7 @@ def preprocess_intensities(img_arr):
129129

130130

131131
if __name__ == "__main__":
132-
132+
# Testing only
133133
# Working directory
134134
script_dir = os.getcwd()
135135
parent_dir = os.path.dirname(script_dir)

LST_AI/stats.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import nibabel as nib
2+
import numpy as np
3+
import csv
4+
import argparse
5+
from scipy.ndimage import label
6+
7+
def compute_stats(mask_file, output_file, multi_class):
    """
    Compute lesion statistics from a segmentation mask and write them to CSV.

    Parameters:
        mask_file (str): Path to the input mask file in NIfTI format.
        output_file (str): Path to the output CSV file where results are saved.
        multi_class (bool): If True, the mask is expected to contain the
            annotated lesion labels 1-4 (one CSV row per region); if False
            it is treated as a binary lesion mask (a single CSV row).

    Raises:
        ValueError: If `multi_class` is False and the mask contains more than
            two distinct values.

    For each class (or for the whole binary mask) the number of lesions
    (connected components), the voxel count, and the total lesion volume
    (voxel count times voxel volume from the NIfTI header zooms) are reported.
    """
    mask = nib.load(mask_file)
    mask_data = mask.get_fdata()

    # Physical volume of one voxel, from the header zooms (hoisted out of the
    # per-class loop). Presumably in mm^3 -- TODO confirm for all inputs.
    voxel_volume = float(np.prod(mask.header.get_zooms()))

    results = []

    if multi_class:
        # Annotated mask: one row per lesion-location label.
        label_names = {
            1: 'Periventricular',
            2: 'Juxtacortical',
            3: 'Subcortical',
            4: 'Infratentorial'
        }
        fieldnames = ['Region', 'Num_Lesions', 'Num_Vox', 'Lesion_Volume']

        for lesion_label, region_name in label_names.items():
            class_mask = mask_data == lesion_label

            # Connected-component analysis yields the lesion count per class.
            _, num_lesions = label(class_mask)

            voxel_count = np.count_nonzero(class_mask)
            results.append({
                'Region': region_name,
                'Num_Lesions': num_lesions,
                'Num_Vox': voxel_count,
                'Lesion_Volume': voxel_count * voxel_volume
            })
    else:
        # Binary mask processing. Validate explicitly instead of using
        # `assert`, which is silently stripped when Python runs with -O.
        unique_values = np.unique(mask_data)
        if len(unique_values) > 2:
            raise ValueError("Binary mask must contain no more than two unique values.")

        # Count lesions (connected components) in the binary mask.
        _, num_lesions = label(mask_data > 0)

        voxel_count = np.count_nonzero(mask_data)
        fieldnames = ['Num_Lesions', 'Num_Vox', 'Lesion_Volume']
        results.append({
            'Num_Lesions': num_lesions,
            'Num_Vox': voxel_count,
            'Lesion_Volume': voxel_count * voxel_volume
        })

    # Save results to CSV; DictWriter keeps header and rows in sync and
    # removes the duplicated row-writing branches of the original code.
    with open(output_file, 'w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
83+
84+
if __name__ == "__main__":
    # Command-line entry point: parse arguments and compute lesion stats.
    cli = argparse.ArgumentParser(description='Process a lesion mask file.')
    cli.add_argument('--in', dest='input_file', required=True,
                     help='Input mask file path')
    cli.add_argument('--out', dest='output_file', required=True,
                     help='Output CSV file path')
    cli.add_argument('--multi-class', dest='multi_class', action='store_true',
                     help='Flag for multi-class processing')

    parsed = cli.parse_args()
    compute_stats(parsed.input_file, parsed.output_file, parsed.multi_class)

0 commit comments

Comments
 (0)