diff --git a/README.md b/README.md index bbc6c18..009ea5c 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ If you have a powerful GPU is available then a full-quality segmentation can be - Input volume: input CT image - Segmentation task: instead of the default "total" segmentation, a more specialized segmentation model can be chosen - Fast: performs segmentation faster, but at lower resolution + - Body crop: crops the images to the body region before processing them, saves GPU memory - Outputs - Segmentation: it will contain a brain segment, which specifies the brain region - Show 3D: show/hide segments in 3D views diff --git a/TotalSegmentator/Resources/UI/TotalSegmentator.ui b/TotalSegmentator/Resources/UI/TotalSegmentator.ui index d005c4e..dfd49b4 100644 --- a/TotalSegmentator/Resources/UI/TotalSegmentator.ui +++ b/TotalSegmentator/Resources/UI/TotalSegmentator.ui @@ -12,8 +12,8 @@ - - + + Inputs @@ -25,22 +25,22 @@ - + Input abdominal, chest, or whole body CT. - + vtkMRMLScalarVolumeNode - + false - + false - + false @@ -90,12 +90,30 @@ + + + + Body crop: + + + + + + + + + + + + + + - - + + Outputs @@ -107,37 +125,37 @@ - + This will store the segmentation result. - + vtkMRMLSegmentationNode - + false - + Brain - + true - + true - + true - + true - + true - + Create new segmentation on Apply @@ -149,23 +167,23 @@ - - + + Advanced - + true - + false - + - + diff --git a/TotalSegmentator/TotalSegmentator.py b/TotalSegmentator/TotalSegmentator.py index 8819552..b392c66 100644 --- a/TotalSegmentator/TotalSegmentator.py +++ b/TotalSegmentator/TotalSegmentator.py @@ -198,6 +198,7 @@ def updateGUIFromParameterNode(self, caller=None, event=None): task = self._parameterNode.GetParameter("Task") self.ui.taskComboBox.setCurrentIndex(self.ui.taskComboBox.findData(task)) self.ui.fastCheckBox.checked = self._parameterNode.GetParameter("Fast") == "true" + self.ui.bodycropCheckBox.checked = self._parameterNode.GetParameter("BodyCrop") == "true" self.ui.useStandardSegmentNamesCheckBox.checked = self._parameterNode.GetParameter("UseStandardSegmentNames") == "true" self.ui.outputSegmentationSelector.setCurrentNode(self._parameterNode.GetNodeReference("OutputSegmentation")) @@ -234,6 +235,7 @@ def updateParameterNodeFromGUI(self, caller=None, event=None): self._parameterNode.SetNodeReferenceID("InputVolume", self.ui.inputVolumeSelector.currentNodeID) self._parameterNode.SetParameter("Task", self.ui.taskComboBox.currentData) self._parameterNode.SetParameter("Fast", "true" if self.ui.fastCheckBox.checked else "false") + self._parameterNode.SetParameter("BodyCrop", "true" if self.ui.bodycropCheckBox.checked else "false") self._parameterNode.SetParameter("UseStandardSegmentNames", "true" if self.ui.useStandardSegmentNamesCheckBox.checked else "false") self._parameterNode.SetNodeReferenceID("OutputSegmentation", self.ui.outputSegmentationSelector.currentNodeID) @@ -262,7 +264,7 @@ def onApplyButton(self): # Compute output self.logic.process(self.ui.inputVolumeSelector.currentNode(), self.ui.outputSegmentationSelector.currentNode(), - self.ui.fastCheckBox.checked, self.ui.taskComboBox.currentData) + self.ui.fastCheckBox.checked, self.ui.bodycropCheckBox.checked, self.ui.taskComboBox.currentData) self.ui.statusLabel.appendPlainText("\nProcessing finished.") @@ -631,6 +633,8 @@ def setDefaultParameters(self, parameterNode): """ if not parameterNode.GetParameter("Fast"): parameterNode.SetParameter("Fast", "True") + if not parameterNode.GetParameter("BodyCrop"): + parameterNode.SetParameter("BodyCrop", "False") if not parameterNode.GetParameter("Task"): parameterNode.SetParameter("Task", "total") if not parameterNode.GetParameter("UseStandardSegmentNames"): @@ -655,7 +659,7 @@ def logProcessOutput(self, proc): if retcode != 0: raise CalledProcessError(retcode, proc.args, output=proc.stdout, stderr=proc.stderr) - def process(self, inputVolume, outputSegmentation, fast=True, task=None): + def process(self, inputVolume, outputSegmentation, fast=True, bodycrop=False, task=None): """ Run the processing algorithm. @@ -663,6 +667,7 @@ def process(self, inputVolume, outputSegmentation, fast=True, task=None): :param inputVolume: volume to be thresholded :param outputVolume: thresholding result :param fast: faster and less accurate output + :param bodycrop: crop the image to the body region before processing it, saves GPU memory :param task: one of self.tasks, default is "total" """ @@ -757,6 +762,8 @@ def process(self, inputVolume, outputSegmentation, fast=True, task=None): options.extend(["--task", task]) if fast: options.append("--fast") + if bodycrop: + options.append("--body_seg") self.log('Creating segmentations with TotalSegmentator AI...') self.log(f"Total Segmentator arguments: {options}") proc = slicer.util.launchConsoleProcess(totalSegmentatorCommand + options)