diff --git a/pyproject.toml b/pyproject.toml index 3d84656f..9c18529c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ repository = "https://git.dkfz.de/mic/personal/group2/tassilow/nnssl" [project.scripts] nnssl_plan_and_preprocess = "nnssl.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry" -nnssl_convert_openmind = "nnssl.data.dataset_conversion.Dataset001_Openmind:main" +nnssl_convert_openmind = "nnssl.dataset_conversion.Dataset001_OpenMind:main" nnssl_extract_fingerprint = "nnssl.experiment_planning.plan_and_preprocess_entrypoints:extract_fingerprint_entry" nnssl_plan_experiment = "nnssl.experiment_planning.plan_and_preprocess_entrypoints:plan_experiment_entry" nnssl_preprocess = "nnssl.experiment_planning.plan_and_preprocess_entrypoints:preprocess_entry" diff --git a/src/nnssl/architectures/get_network_by_name.py b/src/nnssl/architectures/get_network_by_name.py index fab39dd0..186fa5c7 100644 --- a/src/nnssl/architectures/get_network_by_name.py +++ b/src/nnssl/architectures/get_network_by_name.py @@ -1,9 +1,9 @@ -from typing import Literal -from dynamic_network_architectures.architectures.abstract_arch import AbstractDynamicNetworkArchitectures +import torch.nn + from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet from dynamic_network_architectures.architectures.primus import PrimusS, PrimusB, PrimusM, PrimusL -from torch import nn -from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op +from dynamic_network_architectures.architectures.abstract_arch import AbstractDynamicNetworkArchitectures + from nnssl.architectures.architecture_registry import ( SUPPORTED_ARCHITECTURES, get_res_enc_l, @@ -69,9 +69,12 @@ def get_network_by_name( if architecture_name in ["ResEncL", "NoSkipResEncL"]: model: ResidualEncoderUNet try: - model = model.encoder - model.key_to_encoder = model.key_to_encoder.replace("encoder.", "") - 
model.keys_to_in_proj = [k.replace("encoder.", "") for k in model.keys_to_in_proj] + model.decoder = torch.nn.Identity() + # key_to_encoder = model.key_to_encoder.replace("encoder.", "") + # keys_to_in_proj = [k.replace("encoder.", "") for k in model.keys_to_in_proj] + # model = model.encoder + # model.key_to_encoder = key_to_encoder + # model.keys_to_in_proj = keys_to_in_proj except AttributeError: raise RuntimeError("Trying to get the 'encoder' of the network failed. Cannot return encoder only.") elif architecture_name in ["PrimusS", "PrimusB", "PrimusM", "PrimusL"]: diff --git a/src/nnssl/dataset_conversion/Dataset001_OpenMind.py b/src/nnssl/dataset_conversion/Dataset001_OpenMind.py index f86317c4..8694671f 100644 --- a/src/nnssl/dataset_conversion/Dataset001_OpenMind.py +++ b/src/nnssl/dataset_conversion/Dataset001_OpenMind.py @@ -1,4 +1,4 @@ -mport argparse +import argparse import os from collections import defaultdict from pathlib import Path @@ -134,4 +134,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/nnssl/experiment_planning/experiment_planners/plan_wandb.py b/src/nnssl/experiment_planning/experiment_planners/plan_wandb.py index 03472416..f89117d6 100644 --- a/src/nnssl/experiment_planning/experiment_planners/plan_wandb.py +++ b/src/nnssl/experiment_planning/experiment_planners/plan_wandb.py @@ -1,8 +1,8 @@ from dataclasses import dataclass, asdict, is_dataclass -from typing import Any -from nnssl.experiment_planning.experiment_planners.plan import Plan, ConfigurationPlan + import json -import numpy as np + +from nnssl.experiment_planning.experiment_planners.plan import Plan, ConfigurationPlan def dataclass_to_dict(data): if is_dataclass(data): @@ -13,127 +13,15 @@ def dataclass_to_dict(data): @dataclass class ConfigurationPlan_wandb(ConfigurationPlan): - data_identifier: str - preprocessor_name: str - batch_size: int - batch_dice: bool - patch_size: np.ndarray - median_image_size_in_voxels: 
np.ndarray - spacing: np.ndarray - normalization_schemes: list[str] - use_mask_for_norm: list[str] - UNet_class_name: str - UNet_base_num_features: int - n_conv_per_stage_encoder: tuple[int] - n_conv_per_stage_decoder: tuple[int] - num_pool_per_axis: list[int] - pool_op_kernel_sizes: list[list[int]] - conv_kernel_sizes: list[list[int]] - unet_max_num_features: int - resampling_fn_data: str - resampling_fn_data_kwargs: dict[str, Any] - resampling_fn_mask: str - resampling_fn_mask_kwargs: dict[str, Any] - mask_ratio: float=None - vit_patch_size: list[int]=None - embed_dim: int=None - encoder_eva_depth: int=None - encoder_eva_numheads: int=None - decoder_eva_depth: int=None - decoder_eva_numheads: int=None - initial_lr: float=None, - - # - # - # def __getitem__(self, key): - # return getattr(self, key) - # - # def __setitem__(self, key, value): - # setattr(self, key, value) - # - # def __delitem__(self, key): - # delattr(self, key) - # - # def __contains__(self, key): - # return hasattr(self, key) - # - # def __len__(self): - # return len(self.__dict__) - # - # def keys(self): - # return self.__dict__.keys() - # - # def values(self): - # return [getattr(self, key) for key in self.keys()] - # - # def items(self): - # return [(key, getattr(self, key)) for key in self.keys()] + pass @dataclass class Plan_wandb(Plan): - dataset_name: str - plans_name: str - original_median_spacing_after_transp: list[float] - original_median_shape_after_transp: list[int] - image_reader_writer: str - transpose_forward: list[int] - transpose_backward: list[int] - configurations: dict[str, ConfigurationPlan] - experiment_planner_used: str - # def __getitem__(self, key): - # return getattr(self, key) - # - # def __setitem__(self, key, value): - # setattr(self, key, value) - # - # def __delitem__(self, key): - # delattr(self, key) - # - # def __contains__(self, key): - # return hasattr(self, key) - # - # def _expected_save_directory(self): - # pp_path = os.environ.get("nnssl_preprocessed") - # 
if pp_path is None: - # raise RuntimeError( - # "nnssl_preprocessed environment variable not set. This is where the preprocessed data will be saved." - # ) - # return os.path.join(pp_path, self.dataset_name, self.plans_name + ".json") - # - # def save_to_file(self, overwrite=False): - # save_dir = self._expected_save_directory() - # print(f"Saving plan to {save_dir}...") - # if os.path.isfile(save_dir) and not overwrite: - # return - # os.makedirs(os.path.dirname(save_dir), exist_ok=True) - # with open(save_dir, "w") as f: - # json.dump(self._json_serializable(), f, indent=4, sort_keys=False) - # - # def _json_serializable(self) -> dict: - # only_dicts = dataclass_to_dict(self) - # recursive_fix_for_json_export(only_dicts) - # return only_dicts - # - # def __len__(self): - # return len(self.__dict__) - # - # def keys(self): - # return self.__dict__.keys() - # - # def values(self): - # return [getattr(self, key) for key in self.keys()] - # - # def items(self): - # return [(key, getattr(self, key)) for key in self.keys()] - # @staticmethod def load_from_file(path: str): json_dict: dict = json.load(open(path, "r")) configs = {k: ConfigurationPlan_wandb(**v) for k, v in json_dict["configurations"].items()} json_dict["configurations"] = configs return Plan(**json_dict) - # - # def image_reader_writer_class(self) -> "Type[BaseReaderWriter]": - # return recursive_find_reader_writer_by_name(self.image_reader_writer) diff --git a/src/nnssl/run/load_pretrained_weights.py b/src/nnssl/run/load_pretrained_weights.py index bb26e41e..580f5613 100644 --- a/src/nnssl/run/load_pretrained_weights.py +++ b/src/nnssl/run/load_pretrained_weights.py @@ -1,4 +1,7 @@ +import warnings + import torch + from torch._dynamo import OptimizedModule from torch.nn.parallel import DistributedDataParallel as DDP @@ -32,14 +35,20 @@ def load_pretrained_weights(network, fname, verbose=False): model_dict = mod.state_dict() # verify that all but the segmentation layers have the same shape for key, _ in 
model_dict.items(): - if all([i not in key for i in skip_strings_in_pretrained]): - assert key in pretrained_dict, \ - f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \ - f"compatible with your network." - assert model_dict[key].shape == pretrained_dict[key].shape, \ - f"The shape of the parameters of key {key} is not the same. Pretrained model: " \ - f"{pretrained_dict[key].shape}; your network: {model_dict[key]}. The pretrained model " \ - f"does not seem to be compatible with your network." + if all([i not in key for i in skip_strings_in_pretrained]): # perform check only for non-segmentation layers + if key not in pretrained_dict: + # user warning + warnings.warn( + f"{'=' * 20} WARNING {'=' * 20}\n" + f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " + "compatible with your network. This is not an error, but you should check whether this is intended." + f"\n{'=' * 20} WARNING {'=' * 20}" + ) + if (key in model_dict) and (key in pretrained_dict): + assert model_dict[key].shape == pretrained_dict[key].shape, \ + f"The shape of the parameters of key {key} is not the same. Pretrained model: " \ + f"{pretrained_dict[key].shape}; your network: {model_dict[key]}. The pretrained model " \ + f"does not seem to be compatible with your network." # fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained # encoders. Not supported by this function though (see assertions above) @@ -50,8 +59,11 @@ def load_pretrained_weights(network, fname, verbose=False): # if (('module.' 
+ k if is_ddp else k) in model_dict) and # all([i not in k for i in skip_strings_in_pretrained])} - pretrained_dict = {k: v for k, v in pretrained_dict.items() - if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])} + pretrained_dict = { + k: v + for k, v in pretrained_dict.items() + if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained]) + } model_dict.update(pretrained_dict) @@ -62,5 +74,3 @@ def load_pretrained_weights(network, fname, verbose=False): print(key, 'shape', value.shape) print("################### Done ###################") mod.load_state_dict(model_dict) - - diff --git a/src/nnssl/run/run_training_wandb.py b/src/nnssl/run/run_training_wandb.py index 57036623..22ea81ab 100644 --- a/src/nnssl/run/run_training_wandb.py +++ b/src/nnssl/run/run_training_wandb.py @@ -1,16 +1,22 @@ -from datetime import timedelta import os +import signal import socket -from typing import Type, Union, Optional -from loguru import logger +from typing import get_args +from datetime import timedelta +from typing import Type, Union, Optional -import nnssl +import wandb import torch.cuda import torch.distributed as dist import torch.multiprocessing as mp + +from loguru import logger from batchgenerators.utilities.file_and_folder_operations import join, isfile, load_json + +import nnssl from nnssl.experiment_planning.experiment_planners.plan_wandb import Plan_wandb +from nnssl.experiment_planning.experiment_planners.plan import PREPROCESS_SPACING_STYLES from nnssl.paths import nnssl_preprocessed from nnssl.run.load_pretrained_weights import load_pretrained_weights from nnssl.training.nnsslTrainer.AbstractTrainer import AbstractBaseTrainer @@ -33,23 +39,20 @@ def find_free_network_port() -> int: def get_trainer_from_args( - dataset_name_or_id: Union[int, str], - configuration: str, - fold: int, - trainer_name: str = "nnsslTrainer", - plans_identifier: str = "nnsslPlans", - device: torch.device = torch.device("cuda"), - # 
*args, - **kwargs, + dataset_name_or_id: Union[int, str], + configuration: str, + fold: int, + trainer_name: str = "nnsslTrainer", + plans_identifier: str = "nnsslPlans", + device: torch.device = torch.device("cuda"), + # *args, + **kwargs, ): - # load nnunet class and do sanity checks nnssl_trainer_cls: Type[AbstractBaseTrainer] = recursive_find_python_class( join(nnssl.__path__[0], "training", "nnsslTrainer"), trainer_name, "nnssl.training.nnsslTrainer" ) - print(nnssl_trainer_cls, trainer_name) if nnssl_trainer_cls is None: - raise RuntimeError( f"Could not find requested nnunet trainer {trainer_name} in " f"nnssl.training.nnsslTrainer (" @@ -76,27 +79,8 @@ def get_trainer_from_args( # initialize nnunet trainer preprocessed_dataset_folder_base = join(nnssl_preprocessed, maybe_convert_to_dataset_name(dataset_name_or_id)) plans_file = join(preprocessed_dataset_folder_base, plans_identifier + ".json") - plans: Plan = Plan_wandb.load_from_file(plans_file) - for param in kwargs: - if param in ["mask_ratio", "initial_lr", "vit_patch_size", "attention_drop_rate"]: - plans.plans_name = ( - plans.plans_name - + "__" - + param - + str(kwargs[param]) - .replace(".", "") - .replace("[", "") - .replace("]", "") - .replace(",", "") - .replace(" ", "") - ) - else: - plans.plans_name = plans.plans_name + "__" + param + str(kwargs[param]) - for config in plans.configurations: - plans.configurations[config][param] = kwargs[param] - - pretrain_json = load_json(join(preprocessed_dataset_folder_base, "pretrain_data.json")) - + plans: Plan_wandb = Plan_wandb.load_from_file(plans_file) + pretrain_json = load_json(join(preprocessed_dataset_folder_base, f"pretrain_data__{configuration}.json")) nnssl_trainer: AbstractBaseTrainer = nnssl_trainer_cls( plans, configuration, @@ -110,10 +94,10 @@ def get_trainer_from_args( def maybe_load_checkpoint( - nnunet_trainer: AbstractBaseTrainer, - continue_training: bool, - validation_only: bool, - pretrained_weights_file: str = None, + 
nnunet_trainer: AbstractBaseTrainer, + continue_training: bool, + validation_only: bool, + pretrained_weights_file: str = None, ): if continue_training and pretrained_weights_file is not None: raise RuntimeError( @@ -129,16 +113,19 @@ def maybe_load_checkpoint( # special case where --c is used to run a previously aborted validation if not isfile(expected_checkpoint_file): expected_checkpoint_file = join(nnunet_trainer.output_folder, "checkpoint_best.pth") - if not isfile(expected_checkpoint_file): - print( - f"WARNING: Cannot continue training because there seems to be no checkpoint available to " - f"continue from. Starting a new training..." - ) - raise RuntimeError( - f"Cannot continue training because there seems to be no checkpoint available to continue from. Starting a new training..." - ) - logger.info(f"Using {expected_checkpoint_file} as the starting checkpoint for training...") - + # if not isfile(expected_checkpoint_file): + # print( + # f"WARNING: Cannot continue training because there seems to be no checkpoint available to " + # f"continue from. Starting a new training..." + # ) + # raise RuntimeError( + # f"Cannot continue training because there seems to be no checkpoint available to continue from. Starting a new training..." 
+ # ) + if isfile(expected_checkpoint_file): + logger.info(f"Using {expected_checkpoint_file} as the starting checkpoint for training...") + else: + expected_checkpoint_file = None + logger.info(f"No starting checkpoint available, starting a new training...") elif validation_only: expected_checkpoint_file = join(nnunet_trainer.output_folder, "checkpoint_final.pth") if not isfile(expected_checkpoint_file): @@ -151,14 +138,16 @@ def maybe_load_checkpoint( expected_checkpoint_file = None if expected_checkpoint_file is not None: - nnunet_trainer.load_checkpoint(expected_checkpoint_file) + try: + nnunet_trainer.load_checkpoint(expected_checkpoint_file) + except EOFError: + os.remove(expected_checkpoint_file) def setup_ddp(rank, world_size): # initialize the process group # Unpacking actually takes about dist.init_process_group("nccl", rank=rank, world_size=world_size, timeout=timedelta(minutes=25)) - torch.cuda.set_device(rank) def cleanup_ddp(): @@ -166,25 +155,24 @@ def cleanup_ddp(): def run_ddp( - rank, - dataset_name_or_id, - configuration, - fold, - tr, - p, - disable_checkpointing, - c, - val, - pretrained_weights, - npz, - val_with_best, - world_size, - add_params, + rank, + dataset_name_or_id, + configuration, + fold, + tr, + p, + disable_checkpointing, + c, + val, + pretrained_weights, + npz, + val_with_best, + world_size, + add_params, + use_wandb: bool = False, ): setup_ddp(rank, world_size) - - # torch.cuda.set_device(torch.device("cuda", dist.get_rank())) - + torch.cuda.set_device(torch.device("cuda", dist.get_rank())) device = torch.device(f"cuda:{rank}") nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p, device, **add_params) @@ -199,8 +187,12 @@ def run_ddp( cudnn.deterministic = False cudnn.benchmark = True + # Prepare the auto-exiting in case wall-time is exceeded. + # This sets an internal flag, letting the trainer know it's 10 minutes till wall-clock time is up. 
+ signal.signal(signal.SIGUSR1, nnunet_trainer.exit_training) + if not val: - nnunet_trainer.run_training() + nnunet_trainer.run_training(use_wandb) if val_with_best: nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, "checkpoint_best.pth")) @@ -209,21 +201,21 @@ def run_ddp( def run_training( - dataset_name_or_id: Union[str, int], - configuration: str, - fold: Union[int, str], - trainer_class_name: str = "nnsslTrainer", - plans_identifier: str = "nnsslPlans", - pretrained_weights: Optional[str] = None, - num_gpus: int = 1, - export_validation_probabilities: bool = False, - continue_training: bool = False, - only_run_validation: bool = False, - disable_checkpointing: bool = False, - val_with_best: bool = False, - device: torch.device = torch.device("cuda"), - *args, - **kwargs, + dataset_name_or_id: Union[str, int], + configuration: str, + fold: Union[int, str], + trainer_class_name: str = "nnsslTrainer", + plans_identifier: str = "nnsslPlans", + pretrained_weights: Optional[str] = None, + num_gpus: int = 1, + export_validation_probabilities: bool = False, + continue_training: bool = False, + only_run_validation: bool = False, + disable_checkpointing: bool = False, + val_with_best: bool = False, + device: torch.device = torch.device("cuda"), + *args, + **kwargs, ): if isinstance(fold, str): if fold != "all": @@ -238,9 +230,27 @@ def run_training( if val_with_best: assert not disable_checkpointing, "--val_best is not compatible with --disable_checkpointing" + try: + entity = os.environ.get("WANDB_ENTITY", None) + project = os.environ.get("WANDB_PROJECT", "nnssl") + run_id = os.environ.get("WANDB_RUN_ID", None) + + wandb.init( + entity=entity, + project=project, + id=run_id, + name=f"{dataset_name_or_id}_{configuration}_fold{fold}_{trainer_class_name}_{plans_identifier}", + ) + except wandb.Error as e: + print( + "Failed to initialize wandb. " + "Make sure you have set the WANDB_ENTITY and WANDB_PROJECT environment variables correctly." 
+ ) + raise e + if num_gpus > 1: assert ( - device.type == "cuda" + device.type == "cuda" ), f"DDP training (triggered by num_gpus > 1) is only implemented for cuda devices. Your device: {device}" os.environ["MASTER_ADDR"] = "localhost" @@ -266,6 +276,7 @@ def run_training( val_with_best, num_gpus, add_params, + wandb.run is not None, # use_wandb ), nprocs=num_gpus, join=True, @@ -281,11 +292,15 @@ def run_training( **kwargs, ) + # Prepare the auto-exiting in case wall-time is exceeded. + # This sets an internal flag, letting the trainer know it's 10 minutes till wall-clock time is up. + signal.signal(signal.SIGUSR1, nnunet_trainer.exit_training) + if disable_checkpointing: nnunet_trainer.disable_checkpointing = disable_checkpointing assert not ( - continue_training and only_run_validation + continue_training and only_run_validation ), f"Cannot set --c and --val flag at the same time. Dummy." maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights) @@ -295,12 +310,16 @@ def run_training( cudnn.benchmark = True if not only_run_validation: - nnunet_trainer.run_training() + nnunet_trainer.run_training(using_wandb=True) if val_with_best: nnunet_trainer.load_checkpoint(join(nnunet_trainer.output_folder, "checkpoint_best.pth")) + nnunet_trainer.perform_actual_validation(export_validation_probabilities) + if wandb.run is not None: + wandb.finish() + def run_training_entry(): import argparse @@ -308,7 +327,12 @@ def run_training_entry(): parser = argparse.ArgumentParser() parser.add_argument("dataset_name_or_id", type=str, help="Dataset name or ID to train with") - parser.add_argument("configuration", type=str, help="Configuration that should be trained") + parser.add_argument( + "configuration", + type=str, + help="Configuration that should be trained", + choices=get_args(PREPROCESS_SPACING_STYLES), + ) parser.add_argument( "-tr", type=str, @@ -323,13 +347,22 @@ def run_training_entry(): default="nnsslPlans", help="[OPTIONAL] Use 
this flag to specify a custom plans identifier. Default: nnSSLPlans", ) + parser.add_argument( + "-fold", + type=str, + required=False, + default="all", + help="[OPTIONAL] Use this flag to specify the fold to train on. Default: all. " + "If you want to train on a specific fold, use an integer (e.g. 0-5). " + "If you want to train on all folds, use 'all'.", + ) parser.add_argument( "-pretrained_weights", type=str, required=False, default=None, help="[OPTIONAL] path to nnU-Net checkpoint file to be used as pretrained model. Will only " - "be used when actually training. Beta. Use with caution.", + "be used when actually training. Beta. Use with caution.", ) parser.add_argument( "-num_gpus", type=int, default=1, required=False, help="Specify the number of GPUs to use for training" @@ -339,7 +372,7 @@ def run_training_entry(): action="store_true", required=False, help="[OPTIONAL] Save softmax predictions from final validation as npz files (in addition to predicted " - "segmentations). Needed for finding the best ensemble.", + "segmentations). Needed for finding the best ensemble.", ) parser.add_argument( "--c", action="store_true", required=False, help="[OPTIONAL] Continue training from latest checkpoint" @@ -355,16 +388,16 @@ def run_training_entry(): action="store_true", required=False, help="[OPTIONAL] If set, the validation will be performed with the checkpoint_best instead " - "of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! " - "WARNING: This will use the same 'validation' folder as the regular validation " - "with no way of distinguishing the two!", + "of checkpoint_final. NOT COMPATIBLE with --disable_checkpointing! " + "WARNING: This will use the same 'validation' folder as the regular validation " + "with no way of distinguishing the two!", ) parser.add_argument( "--disable_checkpointing", action="store_true", required=False, help="[OPTIONAL] Set this flag to disable checkpointing. 
Ideal for testing things out and " - "you dont want to flood your hard drive with checkpoints.", + "you dont want to flood your hard drive with checkpoints.", ) parser.add_argument( "-device", @@ -372,19 +405,9 @@ def run_training_entry(): default="cuda", required=False, help="Use this to set the device the training should run with. Available options are 'cuda' " - "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " - "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!", + "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " + "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!", ) - parser.add_argument("--mask_ratio", required=False, default=0.5, type=float) - parser.add_argument("--vit_patch_size", required=False, default=[8, 8, 8], nargs="+", type=int) - parser.add_argument("--embed_dim", required=False, default=864, type=int) - parser.add_argument("--encoder_eva_depth", required=False, default=16, type=int) - parser.add_argument("--encoder_eva_numheads", required=False, default=12, type=int) - parser.add_argument("--decoder_eva_depth", required=False, default=6, type=int) - parser.add_argument("--decoder_eva_numheads", required=False, default=8, type=int) - parser.add_argument("--batch_size", required=False, default=None, type=int) - parser.add_argument("--initial_lr", required=False, default=None, type=float) - parser.add_argument("--attention_drop_rate", required=False, default=None, type=float) args = parser.parse_args() assert args.device in [ @@ -394,11 +417,12 @@ def run_training_entry(): ], f"-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}." # ------------------------------- Post Parsers ------------------------------- # + dataset_name = args.dataset_name_or_id config = args.configuration assert ( - os.environ.get("nnssl_results") is not None + os.environ.get("nnssl_results") is not None ), "nnssl_results not set. 
Stopping as no outputs would be written otherwise." if args.device == "cpu": @@ -417,7 +441,7 @@ def run_training_entry(): run_training( dataset_name, config, - "all", + args.fold, args.tr, args.p, args.pretrained_weights, @@ -428,16 +452,6 @@ def run_training_entry(): args.disable_checkpointing, args.val_best, device, - mask_ratio=args.mask_ratio, - vit_patch_size=args.vit_patch_size, - embed_dim=args.embed_dim, - encoder_eva_depth=args.encoder_eva_depth, - encoder_eva_numheads=args.encoder_eva_numheads, - decoder_eva_depth=args.decoder_eva_depth, - decoder_eva_numheads=args.decoder_eva_numheads, - batch_size=args.batch_size, - initial_lr=args.initial_lr, - attention_drop_rate=args.attention_drop_rate, ) diff --git a/src/nnssl/training/logging/nnssl_logger_wandb.py b/src/nnssl/training/logging/nnssl_logger_wandb.py index f0f5e000..2a015781 100644 --- a/src/nnssl/training/logging/nnssl_logger_wandb.py +++ b/src/nnssl/training/logging/nnssl_logger_wandb.py @@ -35,12 +35,13 @@ def __init__(self, verbose: bool = False, use_wandb: bool = False, wandb_init_ar self.wandb = use_wandb if self.wandb: project_name = "nnssl_{}".format(dataset_name) - run_id = os.getenv("WANDB_RUN_ID") + run_id = os.getenv("WANDB_RUN_ID", None) + entity = os.getenv("WANDB_ENTITY", None) maybe_resume_logging = self._maybe_resume_logging(wandb_init_args) if maybe_resume_logging: - wandb.init(project=project_name, entity='mic_rocket', id=run_id, allow_val_change=True, resume=maybe_resume_logging, **wandb_init_args) + wandb.init(project=project_name, entity=entity, id=run_id, allow_val_change=True, resume=maybe_resume_logging, **wandb_init_args) else: - wandb.init(project=project_name, entity='mic_rocket', id=run_id, allow_val_change=True, **wandb_init_args) + wandb.init(project=project_name, entity=entity, id=run_id, allow_val_change=True, **wandb_init_args) def _maybe_resume_logging(self, wandb_init_args) -> Union[None, str]: """ diff --git a/src/nnssl/training/nnsslTrainer/AbstractTrainer.py 
b/src/nnssl/training/nnsslTrainer/AbstractTrainer.py index e64cd791..7086e07d 100644 --- a/src/nnssl/training/nnsslTrainer/AbstractTrainer.py +++ b/src/nnssl/training/nnsslTrainer/AbstractTrainer.py @@ -8,6 +8,8 @@ import sys from types import FrameType + +import wandb from torch import nn from copy import deepcopy from datetime import datetime @@ -407,14 +409,14 @@ def exit_training(self, *args, **kwargs): self.print_to_log_file("Received exit signal. Terminating after finishing epoch.") self.exit_training_flag = True - def run_training(self): + def run_training(self, using_wandb: bool = False): try: self.on_train_start() for epoch in range(self.current_epoch, self.num_epochs): self.on_epoch_start() - self.on_train_epoch_start() + self.on_train_epoch_start(using_wandb) train_outputs = [] for batch_id in tqdm( @@ -422,7 +424,19 @@ def run_training(self): desc=f"Epoch {epoch}", disable=True if (("LSF_JOBID" in os.environ) or ("SLURM_JOB_ID" in os.environ)) else False, ): - train_outputs.append(self.train_step(next(self.dataloader_train))) + step_metrics = self.train_step(next(self.dataloader_train)) + train_outputs.append(step_metrics) + if using_wandb and wandb.run is not None and self.local_rank == 0: + if isinstance(step_metrics, dict): + # add train/ prefix to all keys + to_log_metrics = { + f"train/{k}": v + for k, v in step_metrics.items() + if not k.startswith("train/") and k not in ["epoch", "step"] + } + to_log_metrics["epoch"] = epoch + to_log_metrics["step"] = batch_id + epoch * self.num_iterations_per_epoch + wandb.log(to_log_metrics) self.on_train_epoch_end(train_outputs) @@ -431,7 +445,7 @@ def run_training(self): val_outputs = [] for batch_id in range(self.num_val_iterations_per_epoch): val_outputs.append(self.validation_step(next(self.dataloader_val))) - self.on_validation_epoch_end(val_outputs) + self.on_validation_epoch_end(val_outputs, using_wandb) if self.exit_training_flag: # This is a signal that we need to resubmit, so we break the loop and 
exit gracefully @@ -765,7 +779,7 @@ def on_train_end(self): empty_cache(self.device) self.print_to_log_file("Training done.") - def on_train_epoch_end(self, train_outputs: List[dict]): + def on_train_epoch_end(self, train_outputs: List[dict], using_wandb: bool = False): self.interrupt_at_nans(train_outputs) outputs = collate_outputs(train_outputs) @@ -775,9 +789,10 @@ def on_train_epoch_end(self, train_outputs: List[dict]): loss_here = np.vstack(losses_tr).mean() else: loss_here = np.mean(outputs["loss"]) + self.logger.log("train_losses", loss_here, self.current_epoch) - def on_validation_epoch_end(self, val_outputs: List[dict]): + def on_validation_epoch_end(self, val_outputs: List[dict], using_wandb: bool = False): outputs_collated = collate_outputs(val_outputs) if self.is_ddp: @@ -787,9 +802,17 @@ def on_validation_epoch_end(self, val_outputs: List[dict]): loss_here = np.vstack(losses_val).mean() else: loss_here = np.mean(outputs_collated["loss"]) + if using_wandb and wandb.run is not None: + wandb.log( + { + "val/loss": loss_here, + "epoch": self.current_epoch, + "step": self.current_epoch * self.num_iterations_per_epoch + } + ) self.logger.log("val_losses", loss_here, self.current_epoch) - def on_train_epoch_start(self): + def on_train_epoch_start(self, using_wandb: bool = False): self.network.train() self.lr_scheduler.step(self.current_epoch) self.print_to_log_file("") @@ -798,6 +821,15 @@ def on_train_epoch_start(self): # lrs are the same for all workers so we don't need to gather them in case of DDP training self.logger.log("lrs", self.optimizer.param_groups[0]["lr"], self.current_epoch) + if using_wandb and wandb.run is not None: + wandb.log( + { + "train/lr": self.optimizer.param_groups[0]["lr"], + "epoch": self.current_epoch, + "step": self.current_epoch * self.num_iterations_per_epoch + } + ) + def on_validation_epoch_start(self): self.network.eval() diff --git a/src/nnssl/training/nnsslTrainer/masked_image_modeling/BaseEvaMAETrainer.py 
b/src/nnssl/training/nnsslTrainer/masked_image_modeling/BaseEvaMAETrainer.py index 97847c66..f45722db 100644 --- a/src/nnssl/training/nnsslTrainer/masked_image_modeling/BaseEvaMAETrainer.py +++ b/src/nnssl/training/nnsslTrainer/masked_image_modeling/BaseEvaMAETrainer.py @@ -98,13 +98,13 @@ def configure_optimizers(self, stage: str = "warmup_all"): empty_cache(self.device) return optimizer, lr_scheduler - def on_train_epoch_start(self): + def on_train_epoch_start(self, using_wandb: bool = False) -> None: if self.current_epoch == 0: self.optimizer, self.lr_scheduler = self.configure_optimizers("warmup_all") elif self.current_epoch == self.warmup_duration_whole_net: self.optimizer, self.lr_scheduler = self.configure_optimizers("train") - super().on_train_epoch_start() + super().on_train_epoch_start(using_wandb) @staticmethod def create_mask( diff --git a/src/nnssl/training/nnsslTrainer/masked_image_modeling/BaseMAETrainer.py b/src/nnssl/training/nnsslTrainer/masked_image_modeling/BaseMAETrainer.py index aa465846..34424e5f 100644 --- a/src/nnssl/training/nnsslTrainer/masked_image_modeling/BaseMAETrainer.py +++ b/src/nnssl/training/nnsslTrainer/masked_image_modeling/BaseMAETrainer.py @@ -1,6 +1,7 @@ import os from typing import List, Tuple, Union import matplotlib.pyplot as plt +import wandb from tqdm import tqdm from deprecated import deprecated from typing_extensions import override @@ -318,20 +319,32 @@ def get_centercrop_val_dataloader(self): ) return dl_val - def run_training(self): + def run_training(self, using_wandb:bool = False): try: self.on_train_start() for epoch in range(self.current_epoch, self.num_epochs): self.on_epoch_start() - self.on_train_epoch_start() + self.on_train_epoch_start(using_wandb) train_outputs = [] for batch_id in tqdm( - range(self.num_iterations_per_epoch), - desc=f"Epoch {epoch}", - disable=True if (("LSF_JOBID" in os.environ) or ("SLURM_JOB_ID" in os.environ)) else False, + range(self.num_iterations_per_epoch), + desc=f"Epoch 
{epoch}", + disable=True if (("LSF_JOBID" in os.environ) or ("SLURM_JOB_ID" in os.environ)) else False, ): - train_outputs.append(self.train_step(next(self.dataloader_train))) + step_metrics = self.train_step(next(self.dataloader_train)) + train_outputs.append(step_metrics) + if using_wandb and wandb.run is not None and self.local_rank == 0: + if isinstance(step_metrics, dict): + # add train/ prefix to all keys + to_log_metrics = { + f"train/{k}": v + for k, v in step_metrics.items() + if not k.startswith("train/") and k not in ["epoch", "step"] + } + to_log_metrics["epoch"] = epoch + to_log_metrics["step"] = batch_id + epoch * self.num_iterations_per_epoch + wandb.log(to_log_metrics) self.on_train_epoch_end(train_outputs) with torch.no_grad(): @@ -340,7 +353,7 @@ def run_training(self): for batch_id in range(self.num_val_iterations_per_epoch): val_batch = next(self.dataloader_val) val_outputs.append(self.validation_step(val_batch)) - self.on_validation_epoch_end(val_outputs) + self.on_validation_epoch_end(val_outputs, using_wandb) self.on_epoch_end() if self.exit_training_flag: diff --git a/src/nnssl/training/nnsslTrainer/masked_image_modeling/evaMAETrainer.py b/src/nnssl/training/nnsslTrainer/masked_image_modeling/evaMAETrainer.py index 51e8065a..ecd0c1f6 100644 --- a/src/nnssl/training/nnsslTrainer/masked_image_modeling/evaMAETrainer.py +++ b/src/nnssl/training/nnsslTrainer/masked_image_modeling/evaMAETrainer.py @@ -1,23 +1,28 @@ -import torch import os -from torch import nn -from torch._dynamo import OptimizedModule + from typing import Tuple, Union -from nnssl.adaptation_planning.adaptation_plan import AdaptationPlan, ArchitecturePlans -from nnssl.architectures.evaMAE_module import EvaMAE + +import torch +import wandb +import numpy as np + +from torch import nn +from tqdm import tqdm from torch import autocast +from torch import distributed as dist +from torch._dynamo import OptimizedModule +from torch.nn.parallel import DistributedDataParallel as DDP +from 
batchgenerators.utilities.file_and_folder_operations import join, save_json, maybe_mkdir_p + +from nnssl.paths import nnssl_results +from nnssl.utilities.helpers import empty_cache from nnssl.utilities.helpers import dummy_context -from tqdm import tqdm +from nnssl.architectures.evaMAE_module import EvaMAE from nnssl.experiment_planning.experiment_planners.plan import Plan +from nnssl.training.logging.nnssl_logger_wandb import nnSSLLogger_wandb +from nnssl.adaptation_planning.adaptation_plan import AdaptationPlan, ArchitecturePlans from nnssl.training.nnsslTrainer.masked_image_modeling.BaseMAETrainer import BaseMAETrainer from nnssl.training.lr_scheduler.warmup import Lin_incr_LRScheduler, PolyLRScheduler_offset -import numpy as np -from nnssl.paths import nnssl_results -from torch import distributed as dist -from nnssl.training.logging.nnssl_logger_wandb import nnSSLLogger_wandb -from batchgenerators.utilities.file_and_folder_operations import join, save_json, maybe_mkdir_p -from torch.nn.parallel import DistributedDataParallel as DDP -from nnssl.utilities.helpers import empty_cache class EvaMAETrainer(BaseMAETrainer): @@ -139,13 +144,13 @@ def configure_optimizers(self, stage: str = "warmup_all"): empty_cache(self.device) return optimizer, lr_scheduler - def on_train_epoch_start(self): + def on_train_epoch_start(self, using_wandb: bool = False): if self.current_epoch == 0: self.optimizer, self.lr_scheduler = self.configure_optimizers("warmup_all") elif self.current_epoch == self.warmup_duration_whole_net: self.optimizer, self.lr_scheduler = self.configure_optimizers("train") - super().on_train_epoch_start() + super().on_train_epoch_start(using_wandb) def _overwrite_batch_size(self): if not self.is_ddp: @@ -294,7 +299,7 @@ def build_architecture_and_adaptation_plan( recommended_downstream_patchsize=self.recommended_downstream_patchsize, key_to_encoder="eva", key_to_stem="down_projection", - key_to_in_proj=("down_projection.proj",), + 
keys_to_in_proj=("down_projection.proj",), key_to_lpe="eva.pos_embed", ) raise NotImplementedError("Current AdaptationPlan is not correct") @@ -346,28 +351,40 @@ def validation_step(self, batch: dict) -> dict: return {"loss": l.detach().cpu().numpy()} - def run_training(self): + def run_training(self, using_wandb: bool = False) -> None: try: self.on_train_start() for epoch in range(self.current_epoch, self.num_epochs): self.on_epoch_start() - self.on_train_epoch_start() + self.on_train_epoch_start(using_wandb) train_outputs = [] for batch_id in tqdm( range(self.num_iterations_per_epoch), desc=f"Epoch {epoch}", disable=True if ("LSF_JOBID" in os.environ) else False, ): - train_outputs.append(self.train_step(next(self.dataloader_train))) + step_metrics = self.train_step(next(self.dataloader_train)) + train_outputs.append(step_metrics) + if using_wandb and wandb.run is not None and self.local_rank == 0: + if isinstance(step_metrics, dict): + # add train/ prefix to all keys + to_log_metrics = { + f"train/{k}": v + for k, v in step_metrics.items() + if not k.startswith("train/") and k not in ["epoch", "step"] + } + to_log_metrics["epoch"] = epoch + to_log_metrics["step"] = batch_id + epoch * self.num_iterations_per_epoch + wandb.log(to_log_metrics) self.on_train_epoch_end(train_outputs) with torch.no_grad(): self.on_validation_epoch_start() val_outputs = [] - for batch_id in tqdm(range(self.num_val_iterations_per_epoch)): + for _ in tqdm(range(self.num_val_iterations_per_epoch)): val_outputs.append(self.validation_step(next(self.dataloader_val))) - self.on_validation_epoch_end(val_outputs) + self.on_validation_epoch_end(val_outputs, using_wandb) if self.exit_training_flag: # This is a signal that we need to resubmit, so we break the loop and exit gracefully diff --git a/src/nnssl/training/nnsslTrainer/simCLR/simCLREvaTrainer.py b/src/nnssl/training/nnsslTrainer/simCLR/simCLREvaTrainer.py index 611ac307..67338633 100644 --- 
a/src/nnssl/training/nnsslTrainer/simCLR/simCLREvaTrainer.py +++ b/src/nnssl/training/nnsslTrainer/simCLR/simCLREvaTrainer.py @@ -61,13 +61,13 @@ def __init__( self.init_value = 0.1 self.scale_attn_inner = True - def on_train_epoch_start(self): + def on_train_epoch_start(self, using_wandb: bool = False) -> None: if self.current_epoch == 0: self.optimizer, self.lr_scheduler = self.configure_optimizers("warmup_all") elif self.current_epoch == self.warmup_duration_whole_net: self.optimizer, self.lr_scheduler = self.configure_optimizers("train") - super().on_train_epoch_start() + super().on_train_epoch_start(using_wandb) def configure_optimizers(self, stage: str = "warmup_all"): assert stage in ["warmup_all", "train"] @@ -142,7 +142,7 @@ def build_architecture_and_adaptation_plan( pretrain_num_input_channels=1, # This is the actual input patch size! key_to_encoder="encoder.eva", key_to_stem="encoder.down_projection", - key_to_in_proj=("encoder.down_projection.proj",), + keys_to_in_proj=("encoder.down_projection.proj",), key_to_lpe="encoder.eva.pos_embed", )