diff --git a/README.md b/README.md
index bbc6c18..009ea5c 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,7 @@ If you have a powerful GPU is available then a full-quality segmentation can be
- Input volume: input CT image
- Segmentation task: instead of the default "total" segmentation, a more specialized segmentation model can be chosen
- Fast: performs segmentation faster, but at lower resolution
+ - Body crop: crops the images to the body region before processing them, saves GPU memory
- Outputs
- Segmentation: it will contain a brain segment, which specifies the brain region
- Show 3D: show/hide segments in 3D views
diff --git a/TotalSegmentator/Resources/UI/TotalSegmentator.ui b/TotalSegmentator/Resources/UI/TotalSegmentator.ui
index d005c4e..dfd49b4 100644
--- a/TotalSegmentator/Resources/UI/TotalSegmentator.ui
+++ b/TotalSegmentator/Resources/UI/TotalSegmentator.ui
@@ -12,8 +12,8 @@
-
-
-
+
+
Inputs
@@ -25,22 +25,22 @@
-
-
+
Input abdominal, chest, or whole body CT.
-
+
vtkMRMLScalarVolumeNode
-
+
false
-
+
false
-
+
false
@@ -90,12 +90,30 @@
+ -
+
+
+ Body crop:
+
+
+
+ -
+
+
-
+
+
+
+
+
+
+
+
-
-
-
+
+
Outputs
@@ -107,37 +125,37 @@
-
-
+
This will store the segmentation result.
-
+
vtkMRMLSegmentationNode
-
+
false
-
+
Brain
-
+
true
-
+
true
-
+
true
-
+
true
-
+
true
-
+
Create new segmentation on Apply
@@ -149,23 +167,23 @@
-
-
-
+
+
Advanced
-
+
true
-
-
+
false
-
+
-
+
diff --git a/TotalSegmentator/TotalSegmentator.py b/TotalSegmentator/TotalSegmentator.py
index 8819552..b392c66 100644
--- a/TotalSegmentator/TotalSegmentator.py
+++ b/TotalSegmentator/TotalSegmentator.py
@@ -198,6 +198,7 @@ def updateGUIFromParameterNode(self, caller=None, event=None):
task = self._parameterNode.GetParameter("Task")
self.ui.taskComboBox.setCurrentIndex(self.ui.taskComboBox.findData(task))
self.ui.fastCheckBox.checked = self._parameterNode.GetParameter("Fast") == "true"
+ self.ui.bodycropCheckBox.checked = self._parameterNode.GetParameter("BodyCrop") == "true"
self.ui.useStandardSegmentNamesCheckBox.checked = self._parameterNode.GetParameter("UseStandardSegmentNames") == "true"
self.ui.outputSegmentationSelector.setCurrentNode(self._parameterNode.GetNodeReference("OutputSegmentation"))
@@ -234,6 +235,7 @@ def updateParameterNodeFromGUI(self, caller=None, event=None):
self._parameterNode.SetNodeReferenceID("InputVolume", self.ui.inputVolumeSelector.currentNodeID)
self._parameterNode.SetParameter("Task", self.ui.taskComboBox.currentData)
self._parameterNode.SetParameter("Fast", "true" if self.ui.fastCheckBox.checked else "false")
+ self._parameterNode.SetParameter("BodyCrop", "true" if self.ui.bodycropCheckBox.checked else "false")
self._parameterNode.SetParameter("UseStandardSegmentNames", "true" if self.ui.useStandardSegmentNamesCheckBox.checked else "false")
self._parameterNode.SetNodeReferenceID("OutputSegmentation", self.ui.outputSegmentationSelector.currentNodeID)
@@ -262,7 +264,7 @@ def onApplyButton(self):
# Compute output
self.logic.process(self.ui.inputVolumeSelector.currentNode(), self.ui.outputSegmentationSelector.currentNode(),
- self.ui.fastCheckBox.checked, self.ui.taskComboBox.currentData)
+ self.ui.fastCheckBox.checked, self.ui.bodycropCheckBox.checked, self.ui.taskComboBox.currentData)
self.ui.statusLabel.appendPlainText("\nProcessing finished.")
@@ -631,6 +633,8 @@ def setDefaultParameters(self, parameterNode):
"""
if not parameterNode.GetParameter("Fast"):
parameterNode.SetParameter("Fast", "True")
+ if not parameterNode.GetParameter("BodyCrop"):
+ parameterNode.SetParameter("BodyCrop", "False")
if not parameterNode.GetParameter("Task"):
parameterNode.SetParameter("Task", "total")
if not parameterNode.GetParameter("UseStandardSegmentNames"):
@@ -655,7 +659,7 @@ def logProcessOutput(self, proc):
if retcode != 0:
raise CalledProcessError(retcode, proc.args, output=proc.stdout, stderr=proc.stderr)
- def process(self, inputVolume, outputSegmentation, fast=True, task=None):
+ def process(self, inputVolume, outputSegmentation, fast=True, bodycrop=False, task=None):
"""
Run the processing algorithm.
@@ -663,6 +667,7 @@ def process(self, inputVolume, outputSegmentation, fast=True, task=None):
:param inputVolume: volume to be thresholded
:param outputVolume: thresholding result
:param fast: faster and less accurate output
+ :param bodycrop: crop the image to the body region before processing it, saves GPU memory
:param task: one of self.tasks, default is "total"
"""
@@ -757,6 +762,8 @@ def process(self, inputVolume, outputSegmentation, fast=True, task=None):
options.extend(["--task", task])
if fast:
options.append("--fast")
+ if bodycrop:
+ options.append("--body_seg")
self.log('Creating segmentations with TotalSegmentator AI...')
self.log(f"Total Segmentator arguments: {options}")
proc = slicer.util.launchConsoleProcess(totalSegmentatorCommand + options)