Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 41 additions & 15 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def __init__(self,

def initialize_from_trained_model_folder(self, model_training_output_dir: str,
use_folds: Union[Tuple[Union[int, str]], None],
checkpoint_name: str = 'checkpoint_final.pth'):
checkpoint_name: str = 'checkpoint_final.pth',
trainer_class: Optional["nnUNetPredictor"] = None,):
"""
This is used when making predictions with a trained model
"""
Expand Down Expand Up @@ -96,11 +97,12 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
configuration_manager = plans_manager.get_configuration(configuration_name)
# restore network
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if trainer_class is None:
raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
f'Please place it there (in any .py file)!')
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
if trainer_class is None:
raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
f'Please place it there (in any .py file)!')
network = trainer_class.build_network_architecture(
configuration_manager.network_arch_class_name,
configuration_manager.network_arch_init_kwargs,
Expand Down Expand Up @@ -423,7 +425,8 @@ def predict_from_data_iterator(self,
def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
segmentation_previous_stage: np.ndarray = None,
output_file_truncated: str = None,
save_or_return_probabilities: bool = False):
save_or_return_probabilities: bool = False,
return_logits_per_fold: bool = False):
"""
WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON.

Expand All @@ -447,7 +450,11 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di

if self.verbose:
print('predicting')
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
# For getting logits per fold, cpu extraction has to be done for each list element
if return_logits_per_fold:
predicted_logits = [ elem.cpu() for elem in self.predict_logits_from_preprocessed_data(dct['data'], return_logits_per_fold=return_logits_per_fold)]
else:
predicted_logits = self.predict_logits_from_preprocessed_data(dct['data'], return_logits_per_fold=return_logits_per_fold).cpu()

if self.verbose:
print('resampling to original shape')
Expand All @@ -456,19 +463,34 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
self.plans_manager, self.dataset_json, output_file_truncated,
save_or_return_probabilities)
else:
ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
self.configuration_manager,
self.label_manager,
dct['data_properties'],
return_probabilities=
save_or_return_probabilities)
if return_logits_per_fold:
ret = []
for elem in predicted_logits:
ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(
elem, self.plans_manager,
self.configuration_manager,
self.label_manager,
dct['data_properties'],
return_probabilities=save_or_return_probabilities))


else:
ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
self.configuration_manager,
self.label_manager,
dct['data_properties'],
return_probabilities=
save_or_return_probabilities)
if save_or_return_probabilities:
if return_logits_per_fold:
segs, probs = zip(*ret)
ret = [list(segs), list(probs)]
return ret[0], ret[1]
else:
return ret

@torch.inference_mode()
def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
def predict_logits_from_preprocessed_data(self, data: torch.Tensor, return_logits_per_fold: bool = False) -> torch.Tensor:
"""
IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
Expand All @@ -479,6 +501,8 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
n_threads = torch.get_num_threads()
torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
prediction = None
if return_logits_per_fold:
prediction = []

for params in self.list_of_parameters:

Expand All @@ -493,10 +517,12 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
# this actually saves computation time
if prediction is None:
prediction = self.predict_sliding_window_return_logits(data).to('cpu')
if return_logits_per_fold:
prediction.append(self.predict_sliding_window_return_logits(data).to('cpu'))
else:
prediction += self.predict_sliding_window_return_logits(data).to('cpu')

if len(self.list_of_parameters) > 1:
if len(self.list_of_parameters) > 1 and not return_logits_per_fold:
prediction /= len(self.list_of_parameters)

if self.verbose: print('Prediction done')
Expand Down
5 changes: 4 additions & 1 deletion nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
from torch import distributed as dist
from torch._dynamo import OptimizedModule
from torch.cuda import device_count
from torch import GradScaler
try:
from torch import GradScaler # torch >= 2.3
except ImportError:
from torch.cuda.amp import GradScaler # torch < 2.3
from torch.nn.parallel import DistributedDataParallel as DDP

from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[project]
name = "nnunetv2"
name = "nnunetv2-neuropoly"
version = "2.6.2"
requires-python = ">=3.10"
description = "nnU-Net is a framework for out-of-the box image segmentation."
description = "NeuroPoly fork of nnU-Net"
readme = "readme.md"
license = { file = "LICENSE" }
authors = [
Expand Down Expand Up @@ -55,8 +55,8 @@ dependencies = [
]

[project.urls]
homepage = "https://github.com/MIC-DKFZ/nnUNet"
repository = "https://github.com/MIC-DKFZ/nnUNet"
homepage = "https://github.com/spinalcordtoolbox/nnUNet-neuropoly"
repository = "https://github.com/spinalcordtoolbox/nnUNet-neuropoly"

[project.scripts]
nnUNetv2_plan_and_preprocess = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry"
Expand Down
8 changes: 8 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# NeuroPoly Fork of nnU-Net

This project is a minor fork of the original nnU-Net project. It contains small fixes that improve compatibility with a wider range of Python dependencies, operating systems, and trained model versions.

The existing readme from the original nnU-Net project is preserved below.

----

# Welcome to the new nnU-Net!

Click [here](https://github.com/MIC-DKFZ/nnUNet/tree/nnunetv1) if you were looking for the old one instead.
Expand Down
Loading