diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index b14d86ec5..6d7193eb3 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -66,7 +66,8 @@ def __init__(self,
 
     def initialize_from_trained_model_folder(self, model_training_output_dir: str,
                                              use_folds: Union[Tuple[Union[int, str]], None],
-                                             checkpoint_name: str = 'checkpoint_final.pth'):
+                                             checkpoint_name: str = 'checkpoint_final.pth',
+                                             trainer_class: Optional[type] = None):
         """
         This is used when making predictions with a trained model
         """
@@ -96,11 +97,12 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
         configuration_manager = plans_manager.get_configuration(configuration_name)
         # restore network
         num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
-        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
-                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')
         if trainer_class is None:
-            raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
-                               f'Please place it there (in any .py file)!')
+            trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
+                                                        trainer_name, 'nnunetv2.training.nnUNetTrainer')
+            if trainer_class is None:
+                raise RuntimeError(f'Unable to locate trainer class {trainer_name} in nnunetv2.training.nnUNetTrainer. '
+                                   f'Please place it there (in any .py file)!')
         network = trainer_class.build_network_architecture(
             configuration_manager.network_arch_class_name,
             configuration_manager.network_arch_init_kwargs,
@@ -423,7 +425,8 @@ def predict_from_data_iterator(self,
 
     def predict_single_npy_array(self, input_image: np.ndarray, image_properties: dict,
                                  segmentation_previous_stage: np.ndarray = None,
                                  output_file_truncated: str = None,
-                                 save_or_return_probabilities: bool = False):
+                                 save_or_return_probabilities: bool = False,
+                                 return_logits_per_fold: bool = False):
         """
         WARNING: SLOW. ONLY USE THIS IF YOU CANNOT GIVE NNUNET MULTIPLE IMAGES AT ONCE FOR SOME REASON.
@@ -447,7 +450,11 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
         if self.verbose:
             print('predicting')
 
-        predicted_logits = self.predict_logits_from_preprocessed_data(dct['data']).cpu()
+        # For getting logits per fold, cpu extraction has to be done for each list element
+        if return_logits_per_fold:
+            predicted_logits = [elem.cpu() for elem in self.predict_logits_from_preprocessed_data(dct['data'], return_logits_per_fold=return_logits_per_fold)]
+        else:
+            predicted_logits = self.predict_logits_from_preprocessed_data(dct['data'], return_logits_per_fold=return_logits_per_fold).cpu()
 
         if self.verbose:
             print('resampling to original shape')
@@ -456,19 +463,34 @@ def predict_single_npy_array(self, input_image: np.ndarray, image_properties: di
                                           self.plans_manager, self.dataset_json, output_file_truncated,
                                           save_or_return_probabilities)
         else:
-            ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
-                                                                              self.configuration_manager,
-                                                                              self.label_manager,
-                                                                              dct['data_properties'],
-                                                                              return_probabilities=
-                                                                              save_or_return_probabilities)
+            if return_logits_per_fold:
+                ret = []
+                for elem in predicted_logits:
+                    ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(
+                        elem, self.plans_manager,
+                        self.configuration_manager,
+                        self.label_manager,
+                        dct['data_properties'],
+                        return_probabilities=save_or_return_probabilities))
+
+
+            else:
+                ret = convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits, self.plans_manager,
+                                                                                  self.configuration_manager,
+                                                                                  self.label_manager,
+                                                                                  dct['data_properties'],
+                                                                                  return_probabilities=
+                                                                                  save_or_return_probabilities)
         if save_or_return_probabilities:
+            if return_logits_per_fold:
+                segs, probs = zip(*ret)
+                ret = [list(segs), list(probs)]
             return ret[0], ret[1]
         else:
             return ret
 
     @torch.inference_mode()
-    def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
+    def predict_logits_from_preprocessed_data(self, data: torch.Tensor, return_logits_per_fold: bool = False) -> torch.Tensor:
         """
         IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
         TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!
@@ -479,6 +501,8 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Ten
         n_threads = torch.get_num_threads()
         torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
         prediction = None
+        if return_logits_per_fold:
+            prediction = []
 
         for params in self.list_of_parameters:
 
@@ -493,10 +517,12 @@
             # this actually saves computation time
-            if prediction is None:
+            if return_logits_per_fold:
+                prediction.append(self.predict_sliding_window_return_logits(data).to('cpu'))
+            elif prediction is None:
                 prediction = self.predict_sliding_window_return_logits(data).to('cpu')
             else:
                 prediction += self.predict_sliding_window_return_logits(data).to('cpu')
 
-        if len(self.list_of_parameters) > 1:
+        if len(self.list_of_parameters) > 1 and not return_logits_per_fold:
             prediction /= len(self.list_of_parameters)
 
         if self.verbose: print('Prediction done')
diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
index fb3aa5d28..799c3c64a 100644
--- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
+++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -40,7 +40,10 @@
 from torch import distributed as dist
 from torch._dynamo import OptimizedModule
 from torch.cuda import device_count
-from torch import GradScaler
+try:
+    from torch import GradScaler  # torch >= 2.3
+except ImportError:
+    from torch.cuda.amp import GradScaler  # torch < 2.3
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from nnunetv2.configuration import ANISO_THRESHOLD, default_num_processes
diff --git a/pyproject.toml b/pyproject.toml
index 66405fb68..526d23395 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,8 +1,8 @@
 [project]
-name = "nnunetv2"
+name = "nnunetv2-neuropoly"
 version = "2.6.2"
 requires-python = ">=3.10"
-description = "nnU-Net is a framework for out-of-the box image segmentation."
+description = "NeuroPoly fork of nnU-Net"
 readme = "readme.md"
 license = { file = "LICENSE" }
 authors = [
@@ -55,8 +55,8 @@ dependencies = [
 ]
 
 [project.urls]
-homepage = "https://github.com/MIC-DKFZ/nnUNet"
-repository = "https://github.com/MIC-DKFZ/nnUNet"
+homepage = "https://github.com/spinalcordtoolbox/nnUNet-neuropoly"
+repository = "https://github.com/spinalcordtoolbox/nnUNet-neuropoly"
 
 [project.scripts]
 nnUNetv2_plan_and_preprocess = "nnunetv2.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry"
diff --git a/readme.md b/readme.md
index 9f84ff47b..c3f4fb255 100644
--- a/readme.md
+++ b/readme.md
@@ -1,3 +1,11 @@
+# NeuroPoly Fork of nnU-Net
+
+This project is a minor fork of the original nnU-Net project, with small fixes to improve compatibility with various Python dependencies, OS versions, and trained model versions.
+
+The existing readme from the original nnU-Net project is preserved below.
+
+----
+
 # Welcome to the new nnU-Net!
 Click [here](https://github.com/MIC-DKFZ/nnUNet/tree/nnunetv1) if you were looking for the old one instead.
 