Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmv_im2im/bin/run_im2im.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
This sample script will get deployed in the bin directory of the
users' virtualenv when the parent module is installed using pip.
"""

import logging
import sys
import traceback
Expand All @@ -22,7 +23,6 @@
configuration_validation,
)


###############################################################################

log = logging.getLogger()
Expand Down
39 changes: 25 additions & 14 deletions mmv_im2im/configs/config_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,10 @@ class DataloaderModuleConfig:
# the parameters to be passed in when making the dataset (e.g. PersistentDataset)
# find full API in the corresponding Dataset function
# e.g. https://docs.monai.io/en/stable/data.html#persistentdataset
dataset_params: Dict = field(
default={"cache_dir": "./tmp", "pickle_protocol": 5}, is_mutable=True
)
# dataset_params: Dict = field(
# default={"cache_dir": "./tmp", "pickle_protocol": 5}, is_mutable=True
# )
dataset_params: Dict = field(default_factory=dict, is_mutable=True)

# the parameters to be passed in when making the DataLoader
# find full API here:
Expand Down Expand Up @@ -313,6 +314,9 @@ class ModelConfig:
class TrainerConfig:
"""Config for how to run the training"""

# strategy used example ddp, ddp_find_unused_parameters_true
strategy: str = field(default=None)

# whether to save sample outputs at the beginning of each epoch
verbose: bool = field(default=False)

Expand Down Expand Up @@ -382,25 +386,32 @@ def configuration_validation(cfg):

# check 5, if a global GPU number is set, update the value in trainer
if cfg.trainer.gpus is not None:
cfg.trainer.params["gpus"] = cfg.trainer.gpu
cfg.trainer.params["devices"] = cfg.trainer.gpus
cfg.trainer.params["accelerator"] = "gpu"

# check 6, if PersistentDataset is used, make sure to add a tmp dir in a subdirectory
# (otherwise it may cause hash errors)
if cfg.mode == "train":
if (
is_train_persistent = (
cfg.data.dataloader.train.dataloader_type["func_name"]
== "PersistentDataset"
):
assert (
cfg.data.dataloader.val.dataloader_type["func_name"]
== "PersistentDataset"
), "currently, train and val can only use persisten loader together"
)
is_val_persistent = (
cfg.data.dataloader.val.dataloader_type["func_name"] == "PersistentDataset"
)

if is_train_persistent or is_val_persistent:

if cfg.data.dataloader.val.dataloader_type["func_name"] == "PersistentDataset":
assert (
cfg.data.dataloader.train.dataloader_type["func_name"]
== "PersistentDataset"
), "currently, train and val can only use persisten loader together"
is_train_persistent and is_val_persistent
), "currently, train and val can only use persisten loader together."

for loader_cfg in [cfg.data.dataloader.train, cfg.data.dataloader.val]:
if "cache_dir" not in loader_cfg.dataset_params:
loader_cfg.dataset_params["cache_dir"] = "./tmp_cache"
if "pickle_protocol" not in loader_cfg.dataset_params:
loader_cfg.dataset_params["pickle_protocol"] = 5

if (
cfg.data.dataloader.train.dataset_params["cache_dir"]
!= cfg.data.dataloader.val.dataset_params["cache_dir"]
Expand Down
245 changes: 245 additions & 0 deletions mmv_im2im/configs/preset_train_AttentionUnet_regularizers.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
mode: train

data:
category: "pair"
data_path: "/path/to/your/data"
dataloader:
train:
dataloader_type:
module_name: monai.data
func_name: Dataset # PersistentDataset CacheDataset Dataset
# dataset_params:
# # For CacheDataset
# # cache_rate: 1.0
# # num_workers: 8
# # For PersistentDataset
# # pickle_protocol: 2
# # cache_dir: ./cache28
dataloader_params:
batch_size: 2
pin_memory: True
num_workers: 4
persistent_workers: False
val:
dataloader_type:
module_name: monai.data
func_name: Dataset # PersistentDataset CacheDataset Dataset
# dataset_params:
# # For CacheDataset
# # cache_rate: 1.0
# # num_workers: 4
# # For PersistentDataset
# # pickle_protocol: 2
# # cache_dir: ./cache28
dataloader_params:
batch_size: 2
pin_memory: True
num_workers: 4
persistent_workers: False
preprocess:
- module_name: monai.transforms
func_name: LoadImaged
params:
keys: ["IM"]
dimension_order_out: "CYX"
T: 0
Z: 0
- module_name: monai.transforms
func_name: LoadImaged
params:
keys: ["GT"]
dtype: int
dimension_order_out: "YX"
C: 0
T: 0
Z: 0
- module_name: monai.transforms
func_name: EnsureChannelFirstd
params:
keys: ["GT"]
channel_dim: "no_channel"
- module_name: monai.transforms
func_name: NormalizeIntensityd
params:
channel_wise: True
keys: ["IM"]
#Activate this if the image shape is not divisible by 2^n = k, where n is the number of stride-2 reductions in the AttentionUnet channels ([1, 2, 2, 2, 2] -> 2^4 = 16 = k)
# - module_name: monai.transforms
# func_name: DivisiblePadd
# params:
# keys: ["IM", "GT"]
# k: 16
- module_name: monai.transforms
func_name: EnsureTyped
params:
keys: ["IM", "GT"]

augmentation:
- module_name: monai.transforms
func_name: RandFlipd
params:
prob: 0.2 #0.5
keys: ["IM", "GT"]
- module_name: monai.transforms
func_name: RandRotate90d
params:
prob: 0.2 #0.5
keys: ["IM", "GT"]
- module_name: monai.transforms
func_name: RandHistogramShiftd
params:
prob: 0.2
num_control_points: 50
keys: ["IM"]
- module_name: monai.transforms
func_name: Rand2DElasticd
params:
prob: 0.2
spacing: [32, 32]
magnitude_range: [1, 5]
rotate_range: [0, 0.5]
scale_range: [0.1, 0.25]
translate_range: [10, 50]
padding_mode: "reflection"
mode: "nearest"
keys: ["IM", "GT"]

model:
framework: FCN
# model_extra:
# pre-train: path to ckpt file
# extend: True option for transfer learning (changes the output layer)
net:
module_name: monai.networks.nets
func_name: AttentionUnet
params:
in_channels: 3 # number of channels in the input IM
out_channels: 2 # one per class counting background
spatial_dims: 2 # 2D or 3D convolutions
channels: [32, 64, 128, 256, 512]
strides: [1, 2, 2, 2, 2]
dropout: 0.2
criterion:
module_name: mmv_im2im.utils.gdl_regularized
func_name: RegularizedGeneralizedDiceFocalLoss
params:
spatial_dims: 2 #2d or 3d
n_classes: 3 # Number of classes (counting background)
gdl_focal_weight: 1.0
#gdl_class_weights: [1,1,1]

################# Set the desired used regularizer true and the desired parameters ############################

#Fractal-Dimension
use_fractal_regularization: False # Set to True to enable fractal regularization, then set the parameters below
fractal_weight: 1 # fractal loss weight.
fractal_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1
fractal_num_kernels: 9 # Number of kernels used (n kernels of size 2^1,...,2^n)
fractal_mode: "classic" # approximation type for the FD ("classic" or "entropy").
fractal_to_binary: True # binarization of the masks (classes don't matter, just the geometry)

#CC
use_connectivity_regularization : False
connectivity_mode: learneable-linear #multiscale-exp multiscale-linear single (preset) learneable-single learneable-linear learneable-exp (enables to learn filters)
kernel_shape: square # square or gaussian
connectivity_weight : 1
connectivity_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1
connectivity_kernel_size : 5 # 3, 5, 7
connectivity_ignore_background : True
lambda_density: 1.0
lambda_gradient: 0.2
connectivity_metric_density: huber #'l1' , 'mse' , 'huber' , 'charbonnier'
connectivity_metric_gradient: cosine # 'l1', 'mse', 'charbonnier', 'cosine'

#Persistence Image
use_homology_regularization: False
homology_weight: 1
homology_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1
homology_interval: 3 # apply the regularizer only every n epochs to lighten heavy computations
homology_class_context: 2 # n: take the first n classes excluding 0; [1,4,n]: specific classes; 'general': treat the whole segmentation as one class
homology_metric: mse # simm, mse , l1, l2
homology_features: all # all (h0,h1), cc (connected components), holes
homology_sigma: 0.05 # For gaussian definition in Persistence Image manual creation.
homology_resolution: (100,100) # size of the generated Persistence Image small value for no filtration
homology_downsample_scale: 0.06 # reduction on the sizes for PI generation (>1) original gt size
homology_filtering: True # Apply filtration to PI deleting noise and help using greater resolution (100,100) depending on memory
homology_threshold: 0.01 # If filtering, apply persistence > threshold noise removal.
homology_k_top: 500 # If filtering, keep the k_top most significant topological features to prevent memory problems
chunks: 100 # 0 no apply , chunks >0 divide PI computation
weighting_power: 2.0 # Penalize more the persistent features (W = persistence^alpha).
composite_flag: True # Compute union topology to penalize inter-class gaps/fragmentation.

# Persistence Entropy - Betti numbers
use_topological_complexity: False
topological_complexity_weight: 1
complexity_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1
complexity_interval: 3 # apply the regularizer only every n epochs to lighten heavy computations
complexity_features: all # all (h0,h1), cc (connected components), holes
complexity_downsample_scale: 0.3 # reduction on the sizes for PD generation (>1) original gt size
complexity_class_context: 2 # n: take the first n classes excluding 0; [1,4,n]: specific classes; 'general': treat the whole segmentation as one class
complexity_metric: wasserstein # mse , wasserstein , log_cosh
complexity_threshold: 0.001 # If filtering, apply persistence > threshold noise removal.
complexity_k_top: 500 # If filtering, keep the k_top most significant topological features to prevent memory problems
complexity_temperature: 0.05 # relaxation for Betti numbers: lower approximates a discrete count, higher gives smoother gradients.
complexity_auto_balance: True #Enables dynamic inverse-frequency weighting per batch.

#Hauss
use_hausdorff_regularization: False
hausdorff_weight: 1
hausdorff_downsample_scale: 0.3 # reduction on the sizes for distance maps generation (>1) original gt size
hausdorff_dt_iterations: 25 # Boundary attraction range; determines how far the loss "sees" to pull predictions toward targets.
hausdorff_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1
hausdorff_include_background: False

#Top
use_topological_regularization: False # Enable top features
topological_weight: 1 # Adjust this weight based on the importance of this regularization.
topological_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1
topological_connectivity: 8 # 4 or 8 for 2D; 6 or 26 for 3D.
topological_inclusion: []
topological_exclusion: []
topological_min_thick: 1 # Default value.


optimizer:
module_name: torch.optim
func_name: AdamW
params:
lr: 0.001 # 0.001
weight_decay: 0.01 # 0.01

scheduler:
module_name: torch.optim.lr_scheduler
func_name: ReduceLROnPlateau
params:
mode: 'min'
factor: 0.2
patience: 25
monitor: 'val_loss'

trainer:
verbose: True
# strategy: ddp_find_unused_parameters_true #avoid timeout
# gpus: 1 #number or list of gpus to use
params:
precision: 32 #16
max_epochs: 3000
detect_anomaly: False
log_every_n_steps: 10 # less than the number of training iterations
# gradient_clip_val: 0.5
# gradient_clip_algorithm: "norm"
callbacks:
- module_name: lightning.pytorch.callbacks.early_stopping
func_name: EarlyStopping
params:
monitor: 'val_loss'
patience: 100
verbose: True
- module_name: lightning.pytorch.callbacks.model_checkpoint
func_name: ModelCheckpoint
params:
monitor: 'val_loss'
filename: '{epoch}-{val_loss:.5f}'
mode: min
save_top_k: 5
save_last: true
Loading