Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ repository = "https://git.dkfz.de/mic/personal/group2/tassilow/nnssl"

[project.scripts]
nnssl_plan_and_preprocess = "nnssl.experiment_planning.plan_and_preprocess_entrypoints:plan_and_preprocess_entry"
nnssl_convert_openmind = "nnssl.data.dataset_conversion.Dataset001_Openmind:main"
nnssl_convert_openmind = "nnssl.dataset_conversion.Dataset001_OpenMind:main"
nnssl_extract_fingerprint = "nnssl.experiment_planning.plan_and_preprocess_entrypoints:extract_fingerprint_entry"
nnssl_plan_experiment = "nnssl.experiment_planning.plan_and_preprocess_entrypoints:plan_experiment_entry"
nnssl_preprocess = "nnssl.experiment_planning.plan_and_preprocess_entrypoints:preprocess_entry"
Expand Down
17 changes: 10 additions & 7 deletions src/nnssl/architectures/get_network_by_name.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Literal
from dynamic_network_architectures.architectures.abstract_arch import AbstractDynamicNetworkArchitectures
import torch.nn

from dynamic_network_architectures.architectures.unet import ResidualEncoderUNet
from dynamic_network_architectures.architectures.primus import PrimusS, PrimusB, PrimusM, PrimusL
from torch import nn
from dynamic_network_architectures.building_blocks.helper import get_matching_instancenorm, convert_dim_to_conv_op
from dynamic_network_architectures.architectures.abstract_arch import AbstractDynamicNetworkArchitectures

from nnssl.architectures.architecture_registry import (
SUPPORTED_ARCHITECTURES,
get_res_enc_l,
Expand Down Expand Up @@ -69,9 +69,12 @@ def get_network_by_name(
if architecture_name in ["ResEncL", "NoSkipResEncL"]:
model: ResidualEncoderUNet
try:
model = model.encoder
model.key_to_encoder = model.key_to_encoder.replace("encoder.", "")
model.keys_to_in_proj = [k.replace("encoder.", "") for k in model.keys_to_in_proj]
model.decoder = torch.nn.Identity()
# key_to_encoder = model.key_to_encoder.replace("encoder.", "")
# keys_to_in_proj = [k.replace("encoder.", "") for k in model.keys_to_in_proj]
# model = model.encoder
# model.key_to_encoder = key_to_encoder
# model.keys_to_in_proj = keys_to_in_proj
except AttributeError:
raise RuntimeError("Trying to get the 'encoder' of the network failed. Cannot return encoder only.")
elif architecture_name in ["PrimusS", "PrimusB", "PrimusM", "PrimusL"]:
Expand Down
4 changes: 2 additions & 2 deletions src/nnssl/dataset_conversion/Dataset001_OpenMind.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mport argparse
import argparse
import os
from collections import defaultdict
from pathlib import Path
Expand Down Expand Up @@ -134,4 +134,4 @@ def main():


if __name__ == "__main__":
main()
main()
120 changes: 4 additions & 116 deletions src/nnssl/experiment_planning/experiment_planners/plan_wandb.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from dataclasses import dataclass, asdict, is_dataclass
from typing import Any
from nnssl.experiment_planning.experiment_planners.plan import Plan, ConfigurationPlan

import json
import numpy as np

from nnssl.experiment_planning.experiment_planners.plan import Plan, ConfigurationPlan

def dataclass_to_dict(data):
if is_dataclass(data):
Expand All @@ -13,127 +13,15 @@ def dataclass_to_dict(data):

@dataclass
class ConfigurationPlan_wandb(ConfigurationPlan):
data_identifier: str
preprocessor_name: str
batch_size: int
batch_dice: bool
patch_size: np.ndarray
median_image_size_in_voxels: np.ndarray
spacing: np.ndarray
normalization_schemes: list[str]
use_mask_for_norm: list[str]
UNet_class_name: str
UNet_base_num_features: int
n_conv_per_stage_encoder: tuple[int]
n_conv_per_stage_decoder: tuple[int]
num_pool_per_axis: list[int]
pool_op_kernel_sizes: list[list[int]]
conv_kernel_sizes: list[list[int]]
unet_max_num_features: int
resampling_fn_data: str
resampling_fn_data_kwargs: dict[str, Any]
resampling_fn_mask: str
resampling_fn_mask_kwargs: dict[str, Any]
mask_ratio: float=None
vit_patch_size: list[int]=None
embed_dim: int=None
encoder_eva_depth: int=None
encoder_eva_numheads: int=None
decoder_eva_depth: int=None
decoder_eva_numheads: int=None
initial_lr: float=None,

#
#
# def __getitem__(self, key):
# return getattr(self, key)
#
# def __setitem__(self, key, value):
# setattr(self, key, value)
#
# def __delitem__(self, key):
# delattr(self, key)
#
# def __contains__(self, key):
# return hasattr(self, key)
#
# def __len__(self):
# return len(self.__dict__)
#
# def keys(self):
# return self.__dict__.keys()
#
# def values(self):
# return [getattr(self, key) for key in self.keys()]
#
# def items(self):
# return [(key, getattr(self, key)) for key in self.keys()]
pass


@dataclass
class Plan_wandb(Plan):
dataset_name: str
plans_name: str
original_median_spacing_after_transp: list[float]
original_median_shape_after_transp: list[int]
image_reader_writer: str
transpose_forward: list[int]
transpose_backward: list[int]
configurations: dict[str, ConfigurationPlan]
experiment_planner_used: str

# def __getitem__(self, key):
# return getattr(self, key)
#
# def __setitem__(self, key, value):
# setattr(self, key, value)
#
# def __delitem__(self, key):
# delattr(self, key)
#
# def __contains__(self, key):
# return hasattr(self, key)
#
# def _expected_save_directory(self):
# pp_path = os.environ.get("nnssl_preprocessed")
# if pp_path is None:
# raise RuntimeError(
# "nnssl_preprocessed environment variable not set. This is where the preprocessed data will be saved."
# )
# return os.path.join(pp_path, self.dataset_name, self.plans_name + ".json")
#
# def save_to_file(self, overwrite=False):
# save_dir = self._expected_save_directory()
# print(f"Saving plan to {save_dir}...")
# if os.path.isfile(save_dir) and not overwrite:
# return
# os.makedirs(os.path.dirname(save_dir), exist_ok=True)
# with open(save_dir, "w") as f:
# json.dump(self._json_serializable(), f, indent=4, sort_keys=False)
#
# def _json_serializable(self) -> dict:
# only_dicts = dataclass_to_dict(self)
# recursive_fix_for_json_export(only_dicts)
# return only_dicts
#
# def __len__(self):
# return len(self.__dict__)
#
# def keys(self):
# return self.__dict__.keys()
#
# def values(self):
# return [getattr(self, key) for key in self.keys()]
#
# def items(self):
# return [(key, getattr(self, key)) for key in self.keys()]
#
@staticmethod
def load_from_file(path: str):
json_dict: dict = json.load(open(path, "r"))
configs = {k: ConfigurationPlan_wandb(**v) for k, v in json_dict["configurations"].items()}
json_dict["configurations"] = configs
return Plan(**json_dict)
#
# def image_reader_writer_class(self) -> "Type[BaseReaderWriter]":
# return recursive_find_reader_writer_by_name(self.image_reader_writer)
34 changes: 22 additions & 12 deletions src/nnssl/run/load_pretrained_weights.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import warnings

import torch

from torch._dynamo import OptimizedModule
from torch.nn.parallel import DistributedDataParallel as DDP

Expand Down Expand Up @@ -32,14 +35,20 @@ def load_pretrained_weights(network, fname, verbose=False):
model_dict = mod.state_dict()
# verify that all but the segmentation layers have the same shape
for key, _ in model_dict.items():
if all([i not in key for i in skip_strings_in_pretrained]):
assert key in pretrained_dict, \
f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be " \
f"compatible with your network."
assert model_dict[key].shape == pretrained_dict[key].shape, \
f"The shape of the parameters of key {key} is not the same. Pretrained model: " \
f"{pretrained_dict[key].shape}; your network: {model_dict[key]}. The pretrained model " \
f"does not seem to be compatible with your network."
if all([i not in key for i in skip_strings_in_pretrained]): # perform check only for non-segmentation layers
if key not in pretrained_dict:
# user warning
warnings.warn(
f"{'=' * 20} WARNING {'=' * 20}\n"
f"Key {key} is missing in the pretrained model weights. The pretrained weights do not seem to be "
"compatible with your network. This is not an error, but you should check whether this is intended."
f"\n{'=' * 20} WARNING {'=' * 20}"
)
if (key in model_dict) and (key in pretrained_dict):
assert model_dict[key].shape == pretrained_dict[key].shape, \
f"The shape of the parameters of key {key} is not the same. Pretrained model: " \
f"{pretrained_dict[key].shape}; your network: {model_dict[key]}. The pretrained model " \
f"does not seem to be compatible with your network."

# fun fact: in principle this allows loading from parameters that do not cover the entire network. For example pretrained
# encoders. Not supported by this function though (see assertions above)
Expand All @@ -50,8 +59,11 @@ def load_pretrained_weights(network, fname, verbose=False):
# if (('module.' + k if is_ddp else k) in model_dict) and
# all([i not in k for i in skip_strings_in_pretrained])}

pretrained_dict = {k: v for k, v in pretrained_dict.items()
if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])}
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if k in model_dict.keys() and all([i not in k for i in skip_strings_in_pretrained])
}

model_dict.update(pretrained_dict)

Expand All @@ -62,5 +74,3 @@ def load_pretrained_weights(network, fname, verbose=False):
print(key, 'shape', value.shape)
print("################### Done ###################")
mod.load_state_dict(model_dict)


Loading