diff --git a/mmv_im2im/configs/preset_train_AttentionUnet_Regression.yaml b/mmv_im2im/configs/preset_train_AttentionUnet_Regression.yaml new file mode 100644 index 0000000..9253d06 --- /dev/null +++ b/mmv_im2im/configs/preset_train_AttentionUnet_Regression.yaml @@ -0,0 +1,146 @@ +mode: train + +data: + category: "pair" + data_path: "path/to/yor/data" #_IM.tiff (imagen) , _GT.npy (numpy vector) in this case + dataloader: + train: + dataloader_type: + module_name: monai.data + func_name: Dataset # PersistentDataset CacheDataset Dataset + # dataset_params: + # # For CacheDataset + # # cache_rate: 1.0 + # # num_workers: 8 + # # For PersistentDataset + # # pickle_protocol: 2 + # # cache_dir: ./tmp0dmp0 + dataloader_params: + batch_size: 2 + pin_memory: True + num_workers: 4 + persistent_workers: False + val: + dataloader_type: + module_name: monai.data + func_name: Dataset # PersistentDataset CacheDataset Dataset + # dataset_params: + # # For CacheDataset + # # cache_rate: 1.0 + # # num_workers: 4 + # # For PersistentDataset + # # pickle_protocol: 2 + # # cache_dir: ./tmp0dmp0 + dataloader_params: + batch_size: 2 + pin_memory: True + num_workers: 4 + persistent_workers: False + preprocess: + - module_name: monai.transforms + func_name: LoadImaged + params: + keys: ["IM"] + dimension_order_out: "YX" + C: 0 + T: 0 + Z: 0 + - module_name: monai.transforms + func_name: LoadImaged + params: + keys: ["GT"] + dtype: float #int + - module_name: monai.transforms + func_name: EnsureChannelFirstd + params: + keys: ["IM"] + channel_dim: "no_channel" + - module_name: monai.transforms + func_name: NormalizeIntensityd + params: + channel_wise: True + keys: ["IM"] + # Activate this if the image sape is not divisible by 2^n = k n number of reducction 2 in channels AttentionUnet ([1, 2, 2, 2, 2] ->2^4 = 16 =k) + # - module_name: monai.transforms + # func_name: DivisiblePadd + # params: + # keys: ["IM", "GT"] + # k: 16 + - module_name: monai.transforms + func_name: EnsureTyped + params: + 
keys: ["IM", "GT"] + + # augmentation: + # - module_name: monai.transforms + # func_name: RandHistogramShiftd + # params: + # prob: 0.2 + # num_control_points: 50 + # keys: ["IM"] + +model: + framework: FCN + # model_extra: + # pre-train: path to ckp file + # extend: True option for transfer learning (change output layer) + net: + module_name: monai.networks.nets + func_name: AttentionUnet + params: + in_channels: 1 # number of channels in the input IM + out_channels: 4 # Vector size in this case: [v1,...,vn] -> n + spatial_dims: 2 # 2d or 3d for convolutions + channels: [32, 64, 128, 256, 512] + strides: [1, 2, 2, 2, 2] + dropout: 0.2 + task: "regression" + criterion: + module_name: torch.nn + func_name: HuberLoss #HuberLoss/L1Loss/MSELoss for vectors + params: + reduction: 'mean' + delta: 1.0 + + optimizer: + module_name: torch.optim + func_name: AdamW + params: + lr: 0.01 # 0.001 + weight_decay: 0.01 # 0.01 + + scheduler: + module_name: torch.optim.lr_scheduler + func_name: ReduceLROnPlateau + params: + mode: 'min' + factor: 0.2 + patience: 25 + monitor: 'val_loss' + +trainer: + verbose: True + # strategy: ddp_find_unused_parameters_true #avoid timeout + # gpus: 1 #number or list of gpus to use + params: + precision: 32 #16 + max_epochs: 3000 + detect_anomaly: False + log_every_n_steps: 2 # less than the number of training iterations + # gradient_clip_val: 0.5 + # gradient_clip_algorithm: "norm" + callbacks: + - module_name: lightning.pytorch.callbacks.early_stopping + func_name: EarlyStopping + params: + monitor: 'val_loss' + patience: 60 + verbose: True + - module_name: lightning.pytorch.callbacks.model_checkpoint + func_name: ModelCheckpoint + params: + monitor: 'val_loss' + filename: '{epoch}-{val_loss:.5f}' + mode: min + save_top_k: 5 + save_last: true diff --git a/mmv_im2im/configs/preset_train_AttentionUnet_regularizers.yaml b/mmv_im2im/configs/preset_train_AttentionUnet_Segmentation_regularizers.yaml similarity index 99% rename from 
mmv_im2im/configs/preset_train_AttentionUnet_regularizers.yaml rename to mmv_im2im/configs/preset_train_AttentionUnet_Segmentation_regularizers.yaml index 854956a..9fba441 100644 --- a/mmv_im2im/configs/preset_train_AttentionUnet_regularizers.yaml +++ b/mmv_im2im/configs/preset_train_AttentionUnet_Segmentation_regularizers.yaml @@ -119,6 +119,7 @@ model: channels: [32, 64, 128, 256, 512] strides: [1, 2, 2, 2, 2] dropout: 0.2 + task: 'segmentation' # segmentation/regression criterion: module_name: mmv_im2im.utils.gdl_regularized func_name: RegularizedGeneralizedDiceFocalLoss diff --git a/mmv_im2im/configs/preset_train_ProbabilisticUnet_Regression.yaml b/mmv_im2im/configs/preset_train_ProbabilisticUnet_Regression.yaml new file mode 100644 index 0000000..870bad2 --- /dev/null +++ b/mmv_im2im/configs/preset_train_ProbabilisticUnet_Regression.yaml @@ -0,0 +1,151 @@ +mode: train + +data: + category: "pair" + data_path: "path/to/yor/data" #_IM.tiff (imagen) , _GT.npy (numpy vector) in this case + dataloader: + train: + dataloader_type: + module_name: monai.data + func_name: Dataset # PersistentDataset CacheDataset Dataset + # dataset_params: + # # For CacheDataset + # # cache_rate: 1.0 + # # num_workers: 8 + # # For PersistentDataset + # # pickle_protocol: 2 + # # cache_dir: ./tmp0dmp0 + dataloader_params: + batch_size: 2 + pin_memory: True + num_workers: 4 + persistent_workers: False + val: + dataloader_type: + module_name: monai.data + func_name: Dataset # PersistentDataset CacheDataset Dataset + # dataset_params: + # # For CacheDataset + # # cache_rate: 1.0 + # # num_workers: 4 + # # For PersistentDataset + # # pickle_protocol: 2 + # # cache_dir: ./tmp0dmp0 + dataloader_params: + batch_size: 2 + pin_memory: True + num_workers: 4 + persistent_workers: False + preprocess: + - module_name: monai.transforms + func_name: LoadImaged + params: + keys: ["IM"] + dimension_order_out: "YX" + C: 0 + T: 0 + Z: 0 + - module_name: monai.transforms + func_name: LoadImaged + params: + 
keys: ["GT"] + dtype: float # float/int according to the vector type + - module_name: monai.transforms + func_name: EnsureChannelFirstd + params: + keys: ["IM"] + channel_dim: "no_channel" + - module_name: monai.transforms + func_name: NormalizeIntensityd + params: + channel_wise: True + keys: ["IM"] + # Activate this if the image shape is not divisible by k = 2^n, where n is the number of stride-2 reductions in the network (strides [1, 2, 2, 2, 2] -> 2^4 = 16 = k) + # - module_name: monai.transforms + # func_name: DivisiblePadd + # params: + # keys: ["IM", "GT"] + # k: 16 + - module_name: monai.transforms + func_name: EnsureTyped + params: + keys: ["IM", "GT"] + + # augmentation: + # - module_name: monai.transforms + # func_name: RandHistogramShiftd + # params: + # prob: 0.2 + # num_control_points: 50 + # keys: ["IM"] + +model: + framework: ProbUnet + # model_extra: + # pre-train: path to ckp file + # extend: True option for transfer learning (change output layer) + net: + module_name: mmv_im2im.models.nets.ProbUnet + func_name: ProbabilisticUNet + params: + in_channels: 1 # number of channels in the input IM + out_channels: 4 # Vector size in this case: [v1,...,vn] -> n + spatial_dims: 2 # 2d or 3d for convolutions + latent_dim: 4 + channels: [32, 64, 128, 256, 512] + strides: [1, 2, 2, 2, 2] + dropout: 0.2 + task: "regression" + criterion: + module_name: mmv_im2im.utils.elbo_loss + func_name: ELBOLoss + params: + spatial_dims: 2 #2d or 3d + beta: 1.0 # The beta parameter for ELBO loss + n_classes: 2 # Number of classes (counting background) + kl_clamp: 20.0 # for kl stability and convergence + #elbo_class_weights: None + regression_loss_type: "huber" # Required for the regression task; one of mse/l1/huber + + optimizer: + module_name: torch.optim + func_name: AdamW + params: + lr: 0.01 # 0.001 + weight_decay: 0.01 # 0.01 + + scheduler: + module_name: torch.optim.lr_scheduler + func_name: ReduceLROnPlateau + params: + mode: 'min' + factor: 0.2 + patience: 25 + monitor: 'val_loss' + 
+trainer: + verbose: True + # strategy: ddp_find_unused_parameters_true #avoid timeout + # gpus: 1 #number or list of gpus to use + params: + precision: 32 #16 + max_epochs: 3000 + detect_anomaly: False + log_every_n_steps: 10 # less than the number of training iterations + # gradient_clip_val: 0.5 + # gradient_clip_algorithm: "norm" + callbacks: + - module_name: lightning.pytorch.callbacks.early_stopping + func_name: EarlyStopping + params: + monitor: 'val_loss' + patience: 60 + verbose: True + - module_name: lightning.pytorch.callbacks.model_checkpoint + func_name: ModelCheckpoint + params: + monitor: 'val_loss' + filename: '{epoch}-{val_loss:.5f}' + mode: min + save_top_k: 5 + save_last: true \ No newline at end of file diff --git a/mmv_im2im/configs/preset_train_ProbabilisticUnet_regularizers.yaml b/mmv_im2im/configs/preset_train_ProbabilisticUnet_Segmentation_regularizers.yaml similarity index 98% rename from mmv_im2im/configs/preset_train_ProbabilisticUnet_regularizers.yaml rename to mmv_im2im/configs/preset_train_ProbabilisticUnet_Segmentation_regularizers.yaml index 34fff5b..f0adbbd 100644 --- a/mmv_im2im/configs/preset_train_ProbabilisticUnet_regularizers.yaml +++ b/mmv_im2im/configs/preset_train_ProbabilisticUnet_Segmentation_regularizers.yaml @@ -120,17 +120,17 @@ model: channels: [32, 64, 128, 256, 512] strides: [1, 2, 2, 2, 2] dropout: 0.2 - + task: 'segmentation' # segmentation/regression criterion: module_name: mmv_im2im.utils.elbo_loss func_name: ELBOLoss params: - spatial_dims: 2 #2d or 3d - task: "segment" + spatial_dims: 2 #2d or 3d beta: 0.5 # The beta parameter for ELBO loss n_classes: 3 # Number of classes (counting background) kl_clamp: 20.0 # for kl stability and convergence #elbo_class_weights: None + #regression_loss_type: "mse" # On regression task regression loss is required mse/l1/huber ################# Set the desired used regularizer true and the desired parameters ############################ diff --git 
a/mmv_im2im/configs/preset_train_nnUnet_Regression.yaml b/mmv_im2im/configs/preset_train_nnUnet_Regression.yaml new file mode 100644 index 0000000..e6b4076 --- /dev/null +++ b/mmv_im2im/configs/preset_train_nnUnet_Regression.yaml @@ -0,0 +1,156 @@ +mode: train + +data: + category: "pair" + data_path: "path/to/yor/data" #_IM.tiff (imagen) , _GT.npy (numpy vector) in this case + extra: # Info for heurisctis nnUnet + patch_size: [256, 256] + spacing: [1.0, 1.0] + modality: "non-CT" + #vram_gb: 40 #if you use cuda automatically its computed in other case you should provide + dataloader: + train: + dataloader_type: + module_name: monai.data + func_name: Dataset # PersistentDataset CacheDataset Dataset + # dataset_params: + # # For CacheDataset + # # cache_rate: 1.0 + # # num_workers: 8 + # # For PersistentDataset + # # pickle_protocol: 2 + # # cache_dir: ./tmp0dmp0 + dataloader_params: + batch_size: 2 + pin_memory: True + num_workers: 4 + persistent_workers: False + val: + dataloader_type: + module_name: monai.data + func_name: Dataset # PersistentDataset CacheDataset Dataset + # dataset_params: + # # For CacheDataset + # # cache_rate: 1.0 + # # num_workers: 4 + # # For PersistentDataset + # # pickle_protocol: 2 + # # cache_dir: ./tmp0dmp0 + dataloader_params: + batch_size: 2 + pin_memory: True + num_workers: 4 + persistent_workers: False + preprocess: + - module_name: monai.transforms + func_name: LoadImaged + params: + keys: ["IM"] + dimension_order_out: "YX" + C: 0 + T: 0 + Z: 0 + - module_name: monai.transforms + func_name: LoadImaged + params: + keys: ["GT"] + dtype: float #int + - module_name: monai.transforms + func_name: EnsureChannelFirstd + params: + keys: ["IM"] + channel_dim: "no_channel" + - module_name: monai.transforms + func_name: NormalizeIntensityd + params: + channel_wise: True + keys: ["IM"] + # Activate this if the image sape is not divisible by 2^n = k n number of reducction 2 in channels AttentionUnet ([1, 2, 2, 2, 2] ->2^4 = 16 =k) + # - 
module_name: monai.transforms + # func_name: DivisiblePadd + # params: + # keys: ["IM", "GT"] + # k: 16 + - module_name: monai.transforms + func_name: EnsureTyped + params: + keys: ["IM", "GT"] + + # augmentation: + # - module_name: monai.transforms + # func_name: RandHistogramShiftd + # params: + # prob: 0.2 + # num_control_points: 50 + # keys: ["IM"] + +model: + framework: nnUnet + # model_extra: + # pre-train: path to ckp file + # extend: True option to tranfer learing (change output layer) + net: + module_name: monai.networks.nets + func_name: DynUNet + params: + spatial_dims: 2 # 2d or 3d for convolutions + in_channels: 1 # nuber of channels in the input IM + out_channels: 4 # Vector size on this case [v1,...,vn]-> n + # DynUNet initial values + kernel_size: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]] + strides: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2]] + upsample_kernel_size: [[2, 2], [2, 2], [2, 2], [2, 2]] + filters: [32, 64, 128, 256, 512] + dropout: 0.2 + res_block: True # Enables residual connections standard in modern nnU-Nets + deep_supervision: False + task: 'regression' # segmentation/regression + criterion: + module_name: torch.nn + func_name: HuberLoss #HuberLoss/L1Loss/MSELoss for vectors + params: + reduction: 'mean' + delta: 1.0 + + optimizer: + module_name: torch.optim + func_name: AdamW + params: + lr: 0.01 # 0.001 + weight_decay: 0.01 # 0.01 + + scheduler: + module_name: torch.optim.lr_scheduler + func_name: ReduceLROnPlateau + params: + mode: 'min' + factor: 0.2 + patience: 25 + monitor: 'val_loss' + +trainer: + verbose: True + # strategy: ddp_find_unused_parameters_true #avoid timeout + # gpus: 1 #number or list of gpus to use + params: + precision: 32 #16 + max_epochs: 3000 + detect_anomaly: False + log_every_n_steps: 2 # less than the number of training iterations + # gradient_clip_val: 0.5 + # gradient_clip_algorithm: "norm" + callbacks: + - module_name: lightning.pytorch.callbacks.early_stopping + func_name: EarlyStopping + params: + 
monitor: 'val_loss' + patience: 60 + verbose: True + - module_name: lightning.pytorch.callbacks.model_checkpoint + func_name: ModelCheckpoint + params: + monitor: 'val_loss' + filename: '{epoch}-{val_loss:.5f}' + mode: min + save_top_k: 5 + save_last: true diff --git a/mmv_im2im/configs/preset_train_nnUnet_Segmentation_regularizers.yaml b/mmv_im2im/configs/preset_train_nnUnet_Segmentation_regularizers.yaml new file mode 100644 index 0000000..3035a25 --- /dev/null +++ b/mmv_im2im/configs/preset_train_nnUnet_Segmentation_regularizers.yaml @@ -0,0 +1,230 @@ +mode: train + +data: + category: "pair" + data_path: "path/to/your/data" + extra: # Info for heurisctis nnUnet + patch_size: [256, 256] + spacing: [1.0, 1.0] + modality: "non-CT" + #vram_gb: 40 #if you use cuda automatically its computed in other case you should provide + dataloader: + train: + dataloader_type: + module_name: monai.data + func_name: Dataset + dataloader_params: + batch_size: 4 + pin_memory: True + num_workers: 4 + persistent_workers: False + val: + dataloader_type: + module_name: monai.data + func_name: Dataset + dataloader_params: + batch_size: 4 + pin_memory: True + num_workers: 4 + persistent_workers: False + preprocess: + - module_name: monai.transforms + func_name: LoadImaged + params: + keys: ["IM"] + dimension_order_out: "CYX" + T: 0 + Z: 0 + - module_name: monai.transforms + func_name: LoadImaged + params: + keys: ["GT"] + dtype: int + dimension_order_out: "YX" + C: 0 + T: 0 + Z: 0 + - module_name: monai.transforms + func_name: EnsureChannelFirstd + params: + keys: ["GT"] + channel_dim: "no_channel" + - module_name: monai.transforms + func_name: NormalizeIntensityd + params: + channel_wise: True + keys: ["IM"] + - module_name: monai.transforms + func_name: EnsureTyped + params: + keys: ["IM", "GT"] + augmentation: + - module_name: monai.transforms + func_name: RandFlipd + params: + prob: 0.2 #0.5 + keys: ["IM", "GT"] + - module_name: monai.transforms + func_name: RandRotate90d + params: + 
prob: 0.2 #0.5 + keys: ["IM", "GT"] + - module_name: monai.transforms + func_name: RandHistogramShiftd + params: + prob: 0.2 + num_control_points: 50 + keys: ["IM"] + - module_name: monai.transforms + func_name: Rand2DElasticd + params: + prob: 0.2 + spacing: [32, 32] + magnitude_range: [1, 5] + rotate_range: [0, 0.5] + scale_range: [0.1, 0.25] + translate_range: [10, 50] + padding_mode: "reflection" + mode: "nearest" + keys: ["IM", "GT"] + +model: + framework: nnUnet + # model_extra: + # pre-train: path to ckp file + # extend: True option for transfer learning (change output layer) + net: + module_name: monai.networks.nets + func_name: DynUNet + params: + spatial_dims: 2 # 2d or 3d for convolutions + in_channels: 2 # number of channels in the input IM + out_channels: 3 # One per class + # DynUNet initial values + kernel_size: [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]] + strides: [[1, 1], [2, 2], [2, 2], [2, 2], [2, 2]] + upsample_kernel_size: [[2, 2], [2, 2], [2, 2], [2, 2]] + filters: [32, 64, 128, 256, 512] + dropout: 0.2 + res_block: True # Enables residual connections standard in modern nnU-Nets + deep_supervision: False + task: 'segmentation' # segmentation/regression + criterion: + module_name: mmv_im2im.utils.gdl_regularized + func_name: RegularizedGeneralizedDiceFocalLoss + params: + spatial_dims: 2 #2d or 3d + n_classes: 3 # Number of classes (counting background) + gdl_focal_weight: 1.0 + #gdl_class_weights: [1,1,1] + + #Fractal-Dimension + use_fractal_regularization: False # Set to True to enable fractal regularization; the parameters below then apply + fractal_weight: 1 # weight of the fractal loss term. + fractal_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1 + fractal_num_kernels: 9 # Number of kernels used (n kernels of size 2^1,...,2^n) + fractal_mode: "entropy" # approximation type for the FD ("classic" or "entropy"). 
+ fractal_to_binary: True #binarize the masks (classes don't matter, just the geometry) + + #CC + use_connectivity_regularization : False + connectivity_mode: learneable-linear #multiscale-exp multiscale-linear single (preset) learneable-single learneable-linear learneable-exp (enable learning the filters) + kernel_shape: square # square or gaussian + connectivity_weight : 1 + connectivity_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1 + connectivity_kernel_size : 5 # 3, 5, 7 + connectivity_ignore_background : True + lambda_density: 1.0 + lambda_gradient: 0.2 + connectivity_metric_density: huber #'l1' , 'mse' , 'huber' , 'charbonnier' + connectivity_metric_gradient: cosine # 'l1', 'mse', 'charbonnier', 'cosine' + + #Persistence Image + use_homology_regularization: False + homology_weight: 1 + homology_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1 + homology_class_context: 2 # n: take n classes excluding 0; [1,4,n]: specific classes; 'general': treat the whole segmentation as one class + homology_metric: mse # simm, mse , l1, l2 + homology_features: all # all (h0,h1) cc (connected components) holes + homology_sigma: 0.05 # Sigma of the Gaussian used in the manual Persistence Image creation. + homology_resolution: (100,100) # size of the generated Persistence Image; use a small value when no filtering is applied + homology_downsample_scale: 0.06 # reduction on the sizes for PI generation (>1) original gt size + homology_filtering: True # Apply filtering to the PI, removing noise; helps when using a larger resolution (100,100), depending on memory + homology_threshold: 0.01 # If filtering, apply noise removal keeping persistence > threshold. + homology_k_top: 500 # If filtering, keep only the k most significant topological features to prevent memory problems + chunks: 100 # 0: not applied; chunks > 0: split the PI computation into chunks + weighting_power: 2.0 # Penalize the persistent features more (W = persistence^alpha). 
+ composite_flag: True # Compute union topology to penalize inter-class gaps/fragmentation. + + # Persistence Entropy - Betti numbers + use_topological_complexity: False + topological_complexity_weight: 1 + complexity_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1 + complexity_features: all # all (h0,h1) cc (connected components) holes + complexity_downsample_scale: 0.3 # reduction on the sizes for PD generation (>1) original gt size + complexity_class_context: 2 # n: take n classes excluding 0; [1,4,n]: specific classes; 'general': treat the whole segmentation as one class + complexity_metric: wasserstein # mse , wasserstein , log_cosh + complexity_threshold: 0.001 # If filtering, apply noise removal keeping persistence > threshold. + complexity_k_top: 500 #If filtering, keep only the k most significant topological features to prevent memory problems + complexity_temperature: 0.05 #relaxation for the Betti numbers: lower approximates the discrete count, higher gives smoother gradients. + complexity_auto_balance: True #Enables dynamic inverse-frequency weighting per batch. + + #Hausdorff + use_hausdorff_regularization: False + hausdorff_weight: 1 + hausdorff_downsample_scale: 0.3 # reduction on the sizes for distance maps generation (>1) original gt size + hausdorff_dt_iterations: 25 # Boundary attraction range; determines how far the loss "sees" to pull predictions toward targets. + hausdorff_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1 + hausdorff_include_background: False + + #Topology + use_topological_regularization: False # Enable topological features + topological_weight: 1 # Adjust this weight based on the importance of this regularization. + topological_warmup_epochs: 150 # distribute the impact over the given n epochs -> recommended weight 1 + topological_connectivity: 8 # 4 or 8 for 2D; 6 or 26 for 3D. + topological_inclusion: [] + topological_exclusion: [] + topological_min_thick: 1 # Default value. 
+ + optimizer: + module_name: torch.optim + func_name: AdamW + params: + lr: 0.01 # 0.001 + weight_decay: 0.01 # 0.01 + + scheduler: + module_name: torch.optim.lr_scheduler + func_name: ReduceLROnPlateau + params: + mode: 'min' + factor: 0.2 + patience: 25 + monitor: 'val_loss' + +trainer: + verbose: True + # strategy: ddp_find_unused_parameters_true #avoid timeout + # gpus: 1 #number or list of gpus to use + params: + precision: 32 #16 + max_epochs: 3000 + detect_anomaly: False + log_every_n_steps: 2 # less than the number of training iterations + # gradient_clip_val: 0.5 + # gradient_clip_algorithm: "norm" + callbacks: + - module_name: lightning.pytorch.callbacks.early_stopping + func_name: EarlyStopping + params: + monitor: 'val_loss' + patience: 100 + verbose: True + - module_name: lightning.pytorch.callbacks.model_checkpoint + func_name: ModelCheckpoint + params: + monitor: 'val_loss' + filename: '{epoch}-{val_loss:.5f}' + mode: min + save_top_k: 5 + save_last: true diff --git a/mmv_im2im/map_extractor.py b/mmv_im2im/map_extractor.py index 855f95d..a1f956f 100644 --- a/mmv_im2im/map_extractor.py +++ b/mmv_im2im/map_extractor.py @@ -111,7 +111,10 @@ def setup_data_processing(self): self.pre_process = parse_monai_ops_vanilla(self.data_cfg.preprocess) def process_one_image( - self, img: Union[DaskArray, NumpyArray], out_fn: Union[str, Path] = None + self, + img: Union[DaskArray, NumpyArray], + dim: int = 2, + out_fn: Union[str, Path] = None, ): if isinstance(img, DaskArray): @@ -131,7 +134,8 @@ def process_one_image( # run pre-processing on tensor if needed if self.pre_process is not None: x = self.pre_process(x) - x = x[0] + if dim == 2: + x = x[0] # choose different inference function for different types of models with torch.no_grad(): @@ -412,7 +416,7 @@ def _process_vol2slice(self, img, pred_cfg, original_postprocess, pert_opt): else: inp = im_input - logits = self.process_one_image(inp) + logits = self.process_one_image(inp, dim=2) 
samplesz.append(np.squeeze(logits)) # Multi-prediction aggregation @@ -464,6 +468,7 @@ def _process_vol2slice(self, img, pred_cfg, original_postprocess, pert_opt): def _process_vol2vol(self, img, pred_cfg, original_postprocess, pert_opt): """New direct 3D volume processing logic.""" # Input img is (C, Z, Y, X) or (Z, Y, X) + # Handle dummy channel if missing if len(img.shape) == 3: img = img[None, ...] # (C, Z, Y, X) @@ -478,7 +483,7 @@ def _process_vol2vol(self, img, pred_cfg, original_postprocess, pert_opt): # Process one image will handle 4D input by adding batch dim -> (1, C, Z, Y, X) # Ensure spatial_dims is set to 3 in setup - logits = self.process_one_image(inp) + logits = self.process_one_image(inp, dim=3) samples_vol.append(np.squeeze(logits)) # Multi-prediction aggregation (3D) diff --git a/mmv_im2im/models/nets/ProbUnet.py b/mmv_im2im/models/nets/ProbUnet.py index 5a8291b..dee889e 100644 --- a/mmv_im2im/models/nets/ProbUnet.py +++ b/mmv_im2im/models/nets/ProbUnet.py @@ -125,11 +125,13 @@ def __init__( up_kernel_size=3, latent_dim=6, dropout=0.0, + task="segmentation", ): super().__init__() self.spatial_dims = spatial_dims self.out_channels = out_channels self.latent_dim = latent_dim + self.task = task # Backbone Encoder self.unet_encoder = nn.ModuleList() @@ -144,7 +146,6 @@ def __init__( self.unet_decoder = nn.ModuleList() rev_c = channels[::-1] rev_s = strides[::-1] - # We decode len(channels) - 1 times to match the skip connections for i in range(len(channels) - 1): self.unet_decoder.append( UpConv( @@ -169,6 +170,8 @@ def __init__( dropout, latent_dim, ) + + # on regression out_channels spaciol concatyenation self.posterior_net = Encoder( spatial_dims, in_channels + out_channels, @@ -227,19 +230,33 @@ def forward(self, x, seg=None, train_posterior=True): mu_post, logvar_post = None, None if train_posterior and seg is not None: - if seg.shape[1] != self.out_channels: - seg_temp = seg - if seg_temp.shape[1] == 1: - seg_temp = seg_temp.squeeze(1) - 
seg_one_hot = ( - F.one_hot(seg_temp.long(), num_classes=self.out_channels) - .permute(0, 3, 1, 2) - .float() - ) + if self.task == "regression": + # seg -> [B, out_channels] + # expand + dims_to_add = len(x.shape) - 2 + seg_spatial = seg.view(seg.shape[0], seg.shape[1], *([1] * dims_to_add)) + seg_input = seg_spatial.expand(-1, -1, *x.shape[2:]).float() else: - seg_one_hot = seg.float() + # regular segmentation + if seg.shape[1] != self.out_channels: + seg_temp = seg + if seg_temp.shape[1] == 1: + seg_temp = seg_temp.squeeze(1) + + seg_one_hot = F.one_hot( + seg_temp.long(), num_classes=self.out_channels + ) - cat_input = torch.cat([x, seg_one_hot], dim=1) + dims = list(range(seg_one_hot.ndim)) + seg_input = ( + seg_one_hot.permute(0, dims[-1], *dims[1:-1]) + .contiguous() + .float() + ) + else: + seg_input = seg.float() + + cat_input = torch.cat([x, seg_input], dim=1) mu_post, logvar_post = self.posterior_net(cat_input) z_sample = self.reparameterize(mu_post, logvar_post) else: @@ -251,6 +268,12 @@ def forward(self, x, seg=None, train_posterior=True): ).expand(-1, -1, *unet_x.shape[2:]) reconstruction = self.f_comb(torch.cat([unet_x, z_b], dim=1)) + # (GAP) + if self.task == "regression": + reconstruction = reconstruction.view( + reconstruction.size(0), reconstruction.size(1), -1 + ).mean(dim=-1) + return { "pred": reconstruction, "mu_post": mu_post, diff --git a/mmv_im2im/models/pl_FCN.py b/mmv_im2im/models/pl_FCN.py index 779e023..e53f876 100644 --- a/mmv_im2im/models/pl_FCN.py +++ b/mmv_im2im/models/pl_FCN.py @@ -15,6 +15,15 @@ class Model(pl.LightningModule): def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = False): super().__init__() + + if isinstance(model_info_xx.net["params"], dict): + self.task = model_info_xx.net["params"].pop("task", "segmentation") + + if self.task != "regression" and self.task != "segmentation": + raise ValueError( + f"Task should be regression/segmentation : {self.task} was given" + ) + self.net = 
parse_config(model_info_xx.net) init_weights(self.net, init_type="kaiming") @@ -23,6 +32,7 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals self.verbose = verbose self.weighted_loss = False self.seg_flag = False + if train: if "use_costmap" in model_info_xx.criterion[ "params" @@ -80,10 +90,16 @@ def run_step(self, batch, validation_stage): y_hat = self(x) - if isinstance(self.criterion, torch.nn.CrossEntropyLoss): - # in case of CrossEntropy related error - # see: https://discuss.pytorch.org/t/runtimeerror-expected-object-of-scalar-type-long-but-got-scalar-type-float-when-using-crossentropyloss/30542 # noqa E501 - y = torch.squeeze(y, dim=1) # remove C dimension + if self.task == "regression": + # Global Average Pooling: Dim reduction 2D / 3D + # [B, C, H, W] -> [B, C] + y_hat = y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1) + # GT: [B, C] + y = y.view(y.size(0), -1).float() + else: + if isinstance(self.criterion, torch.nn.CrossEntropyLoss): + # in case of CrossEntropy related error + y = torch.squeeze(y, dim=1) # remove C dimension if isinstance(self.criterion, regularized): current_epoch = self.current_epoch diff --git a/mmv_im2im/models/pl_ProbUnet.py b/mmv_im2im/models/pl_ProbUnet.py index e2d822b..a7fcfd6 100644 --- a/mmv_im2im/models/pl_ProbUnet.py +++ b/mmv_im2im/models/pl_ProbUnet.py @@ -11,6 +11,18 @@ class Model(pl.LightningModule): def __init__(self, model_info_xx, train=True, verbose=False): super().__init__() + + if isinstance(model_info_xx.net["params"], dict): + self.task = model_info_xx.net["params"].get("task", "segmentation") + + if self.task != "regression" and self.task != "segmentation": + raise ValueError( + f"Task should be regression/segmentation : {self.task} was given" + ) + + if "utils.elbo_loss" in model_info_xx.criterion["module_name"]: + model_info_xx.criterion["params"]["task"] = self.task + self.net = parse_config(model_info_xx.net) init_weights(self.net, init_type="kaiming") 
self.model_info = model_info_xx @@ -40,25 +52,29 @@ def forward(self, x, seg=None, train_posterior=False): def run_step(self, batch): x, y = batch["IM"], batch["GT"] - # Ensure x is (B, C, H, W) - if x.ndim == 5 and x.shape[-1] == 1: + if x.ndim > 4 and x.shape[-1] == 1: x = x.squeeze(-1) - # Ensure y is (B, 1, H, W) for passing to model and loss - if y.ndim == 5 and y.shape[-1] == 1: - y = y.squeeze(-1) - if y.ndim == 3: - y = y.unsqueeze(1) # Add channel dim if missing (B, H, W) -> (B, 1, H, W) + + if self.task == "regression": + if y.ndim > 2 and y.shape[-1] == 1: + y = y.squeeze(-1) + # [B, out_channels] + y = y.view(y.size(0), -1) + else: + if y.ndim > 4 and y.shape[-1] == 1: + y = y.squeeze(-1) + if y.ndim == x.ndim - 1: + y = y.unsqueeze(1) # Forward pass (Train Posterior) output = self(x, seg=y, train_posterior=True) # Calculate Loss - # Ensure 'epoch' is a number, not a tensor, to avoid issues in elbo_loss warmup current_ep = int(self.current_epoch) loss = self.criterion( logits=output["pred"], - y_true=y, # (B, 1, H, W) Integer labels usually + y_true=y, prior_mu=output["prior_mu"], prior_logvar=output["prior_logvar"], post_mu=output["mu_post"], diff --git a/mmv_im2im/models/pl_ProbUnet_old.py b/mmv_im2im/models/pl_ProbUnet_old.py new file mode 100644 index 0000000..f25aa89 --- /dev/null +++ b/mmv_im2im/models/pl_ProbUnet_old.py @@ -0,0 +1,163 @@ +import numpy as np +from typing import Dict +from pathlib import Path +from random import randint +import lightning as pl +import torch +from bioio.writers import OmeTiffWriter + +from mmv_im2im.utils.misc import ( + parse_config, + parse_config_func, + parse_config_func_without_params, +) +from mmv_im2im.utils.model_utils import init_weights + + +class Model(pl.LightningModule): + def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = False): + super().__init__() + self.net = parse_config(model_info_xx.net) + init_weights(self.net, init_type="kaiming") + + self.model_info = model_info_xx + 
self.verbose = verbose + self.weighted_loss = False + if train: + self.criterion = parse_config(model_info_xx.criterion) + self.optimizer_func = parse_config_func(model_info_xx.optimizer) + + # Store these as attributes for access in run_step/training_step/validation_step + self.last_prior_mu = None + self.last_prior_logvar = None + self.last_post_mu = None + self.last_post_logvar = None + + def forward(self, x, y=None): + # The underlying ProbabilisticUNet returns multiple values. + # Capture them here and store them as instance attributes. + logits, prior_mu, prior_logvar, post_mu, post_logvar = self.net(x, y) + + # Store for use in run_step (which calculates loss) + self.last_prior_mu = prior_mu + self.last_prior_logvar = prior_logvar + self.last_post_mu = post_mu + self.last_post_logvar = post_logvar + + # For the 'Model' (LightningModule) forward, only return the logits + # This makes the API consistent with other models in your framework. + return logits + + def configure_optimizers(self): + optimizer = self.optimizer_func(self.parameters()) + if self.model_info.scheduler is None: + return optimizer + else: + scheduler_func = parse_config_func_without_params(self.model_info.scheduler) + lr_scheduler = scheduler_func( + optimizer, **self.model_info.scheduler["params"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + "strict": True, + }, + } + + def run_step(self, batch, validation_stage): + x = batch["IM"] + y = batch["GT"] + + if x.size(-1) == 1: + x = torch.squeeze(x, dim=-1) + y = torch.squeeze(y, dim=-1) + + # Call forward pass of the LightningModule. + # This will internally call self.net(x,y) and store the extra outputs. + logits = self(x, y) # This is now just 'logits' + + # Calculate loss using the stored attributes + # Ensure post_mu and post_logvar are not None if y was provided + # The ELBOLoss expects these to be tensors, not None. 
+ if self.last_post_mu is None or self.last_post_logvar is None: + raise ValueError( + "Posterior distributions (mu, logvar) were not computed. Ensure 'y' is provided during training." + ) + + loss = self.criterion( + logits, + y, + self.last_prior_mu, + self.last_prior_logvar, + self.last_post_mu, + self.last_post_logvar, + ) + + return loss, logits + + def training_step(self, batch, batch_idx): + loss, y_hat = self.run_step(batch, validation_stage=False) + self.log( + "train_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + if self.verbose and batch_idx == 0: + self.log_images(batch, y_hat, "train") + + return loss + + def validation_step(self, batch, batch_idx): + loss, y_hat = self.run_step(batch, validation_stage=True) + self.log( + "val_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + if self.verbose and batch_idx == 0: + self.log_images(batch, y_hat, "val") + + return loss + + def log_images(self, batch, y_hat, stage): + src = batch["IM"] + tar = batch["GT"] + + save_path = Path(self.trainer.log_dir) + save_path.mkdir(parents=True, exist_ok=True) + + act = torch.nn.Softmax(dim=1) + yhat_act = act(y_hat) + + src_out = np.squeeze(src[0].detach().cpu().numpy()).astype(float) + tar_out = np.squeeze(tar[0].detach().cpu().numpy()).astype(float) + prd_out = np.squeeze(yhat_act[0].detach().cpu().numpy()).astype(float) + + def get_dim_order(arr): + dims = len(arr.shape) + return {2: "YX", 3: "ZYX", 4: "CZYX"}.get(dims, "YX") + + rand_tag = randint(1, 1000) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_src_{rand_tag}.tiff" + OmeTiffWriter.save(src_out, out_fn, dim_order=get_dim_order(src_out)) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_tar_{rand_tag}.tiff" + OmeTiffWriter.save(tar_out, out_fn, dim_order=get_dim_order(tar_out)) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_prd_{rand_tag}.tiff" + 
OmeTiffWriter.save(prd_out, out_fn, dim_order=get_dim_order(prd_out)) diff --git a/mmv_im2im/models/pl_nnUnet.py b/mmv_im2im/models/pl_nnUnet.py new file mode 100644 index 0000000..c8bc6c9 --- /dev/null +++ b/mmv_im2im/models/pl_nnUnet.py @@ -0,0 +1,152 @@ +from typing import Dict +import lightning as pl +import torch +from mmv_im2im.utils.gdl_regularized import ( + RegularizedGeneralizedDiceFocalLoss as regularized, +) +from mmv_im2im.utils.misc import ( + parse_config, + parse_config_func, + parse_config_func_without_params, +) +from mmv_im2im.utils.model_utils import init_weights + + +class Model(pl.LightningModule): + def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = False): + super().__init__() + + if isinstance(model_info_xx.net["params"], dict): + self.task = model_info_xx.net["params"].pop("task", "segmentation") + + if self.task != "regression" and self.task != "segmentation": + raise ValueError( + f"Task should be regression/segmentation : {self.task} was given" + ) + + self.net = parse_config(model_info_xx.net) + + init_weights(self.net, init_type="kaiming") + + self.model_info = model_info_xx + self.verbose = verbose + self.weighted_loss = False + self.seg_flag = False + if train: + if "use_costmap" in model_info_xx.criterion[ + "params" + ] and model_info_xx.criterion["params"].pop("use_costmap"): + self.weighted_loss = True + self.criterion = parse_config(model_info_xx.criterion) + self.optimizer_func = parse_config_func(model_info_xx.optimizer) + + if ( + model_info_xx.model_extra is not None + and "debug_segmentation" in model_info_xx.model_extra + and model_info_xx.model_extra["debug_segmentation"] + ): + self.seg_flag = True + + def configure_optimizers(self): + optimizer = self.optimizer_func(self.parameters()) + if self.model_info.scheduler is None: + return optimizer + else: + scheduler_func = parse_config_func_without_params(self.model_info.scheduler) + lr_scheduler = scheduler_func( + optimizer, 
**self.model_info.scheduler["params"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + "strict": True, + }, + } + + def prepare_batch(self, batch): + return + + def forward(self, x): + return self.net(x) + + def run_step(self, batch, validation_stage): + x = batch["IM"] + y = batch["GT"] + + if self.weighted_loss: + assert ( + "CM" in batch.keys() + ), "Costmap is detected, but no use_costmap param in criterion" + cm = batch["CM"] + + if x.size()[-1] == 1 and x.ndim > (self.net.spatial_dims + 2): + x = torch.squeeze(x, dim=-1) + y = torch.squeeze(y, dim=-1) + if cm is not None: + cm = torch.squeeze(cm, dim=-1) + + y_hat = self(x) + + # Handle potential MONAI DynUNet deep supervision outputs safely + if torch.is_tensor(y_hat) and y_hat.ndim == x.ndim + 1: + # DynUNet interpolates and stacks intermediate predictions along dim=1. + # We unbind and take the primary full-resolution output (index 0). + y_hat = y_hat[:, 0, ...] 
+ elif isinstance(y_hat, (list, tuple)): + y_hat = y_hat[0] + + if self.task == "regression": + # Global Average Pooling [B, C, H, W] -> [B, C] + y_hat = y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1) + y = y.view(y.size(0), -1).float() + else: + if isinstance(self.criterion, torch.nn.CrossEntropyLoss): + y = torch.squeeze(y, dim=1) + + if isinstance(self.criterion, regularized): + current_epoch = self.current_epoch + loss = self.criterion(y_hat, y, epoch=current_epoch) + else: + if self.weighted_loss: + loss = self.criterion(y_hat, y, cm) + else: + loss = self.criterion(y_hat, y) + + return loss + + def on_train_epoch_end(self): + torch.cuda.synchronize() + + def on_validation_epoch_end(self): + torch.cuda.synchronize() + + def training_step(self, batch, batch_idx): + loss = self.run_step(batch, validation_stage=False) + self.log( + "train_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.run_step(batch, validation_stage=True) + self.log( + "val_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + return loss diff --git a/mmv_im2im/postprocessing/basic_collection.py b/mmv_im2im/postprocessing/basic_collection.py index 2dbbc38..290b0c2 100644 --- a/mmv_im2im/postprocessing/basic_collection.py +++ b/mmv_im2im/postprocessing/basic_collection.py @@ -78,6 +78,8 @@ def generate_classmap(im: Union[np.ndarray, torch.Tensor]) -> np.ndarray: # convert tensor to numpy if torch.is_tensor(im): im = im.cpu().numpy() + if len(im.shape) == 4 and im.shape[0] != 1: + im = im[None, ...] 
assert len(im.shape) == 4 or len(im.shape) == 5, "extract seg only accepts 4D/5D" assert im.shape[0] == 1, "extract seg requires first dim to be 1" diff --git a/mmv_im2im/proj_tester.py b/mmv_im2im/proj_tester.py index 6851753..6ab7073 100644 --- a/mmv_im2im/proj_tester.py +++ b/mmv_im2im/proj_tester.py @@ -16,7 +16,7 @@ from mmv_im2im.utils.for_transform import parse_monai_ops_vanilla from skimage.io import imsave as save_rgb import bioio_tifffile -from tqdm import tqdm +from tqdm.auto import tqdm from monai.inferers import sliding_window_inference # https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html#predicting @@ -155,6 +155,14 @@ def process_one_image( else: y_hat = self.model(x) + if isinstance(y_hat, dict): + try: + y_hat = y_hat["pred"] + except Exception: + raise ValueError( + f"y_hat is a dictionary but the key 'pred' it's not found the y_hat output is: {y_hat}" + ) + ############################################################################### # # Note: currently, we assume y_hat is still on gpu, because embedseg clustering diff --git a/mmv_im2im/proj_trainer.py b/mmv_im2im/proj_trainer.py index d9315e2..abd9c20 100644 --- a/mmv_im2im/proj_trainer.py +++ b/mmv_im2im/proj_trainer.py @@ -7,6 +7,7 @@ import torch from mmv_im2im.data_modules import get_data_module from mmv_im2im.utils.misc import parse_ops_list +from mmv_im2im.utils.nnHeuristic import get_nnunet_plans import pyrallis import warnings @@ -40,6 +41,34 @@ def __init__(self, cfg): def run_training(self): self.data = get_data_module(self.data_cfg) + + if self.model_cfg.net["func_name"] == "DynUNet": + # 1. 
Gather inputs for heuristic + # You might need to add these fields to your YAML or extract them from data + extra_params = ( + self.data_cfg.extra if self.data_cfg.extra is not None else {} + ) + patch_size = extra_params.get("patch_size", [256, 256]) + spacing = extra_params.get("spacing", [1.0, 1.0]) + modality = extra_params.get("modality", "non-CT") + + plans = get_nnunet_plans(patch_size, spacing, modality) + + self.model_cfg.net["params"].update( + { + "kernel_size": plans["kernel_size"], + "strides": plans["strides"], + "filters": plans["filters"], + "upsample_kernel_size": plans["upsample_kernel_size"], + } + ) + + print(f"✅ nnU-Net configured for {len(patch_size)}D.") + print(f"Filters: {plans['filters']}") + print(f"Strides: {plans['strides']}") + print(f"Kernel size: {plans['kernel_size']}") + print(f"Upsample Kernel size: {plans['upsample_kernel_size']}") + model_category = self.model_cfg.framework model_module = import_module(f"mmv_im2im.models.pl_{model_category}") my_model_func = getattr(model_module, "Model") diff --git a/mmv_im2im/tests/test_dummy.py b/mmv_im2im/tests/test_dummy.py index 58dd2a0..73b446f 100644 --- a/mmv_im2im/tests/test_dummy.py +++ b/mmv_im2im/tests/test_dummy.py @@ -49,9 +49,9 @@ def test_connectivity_loss(dims): pred_softmax = F.softmax(logits, dim=1) target_one_hot = F.one_hot(target_indices, num_classes=n_classes).float() if dims == 2: - target_one_hot = target_one_hot.permute(0, 3, 1, 2) + target_one_hot = target_one_hot.permute(0, 3, 1, 2) else: - target_one_hot = target_one_hot.permute(0, 4, 1, 2, 3) + target_one_hot = target_one_hot.permute(0, 4, 1, 2, 3) loss = loss_fn(pred_softmax, target_one_hot) assert not torch.isnan(loss) diff --git a/mmv_im2im/utils/elbo_loss.py b/mmv_im2im/utils/elbo_loss.py index 706acd3..fc08241 100644 --- a/mmv_im2im/utils/elbo_loss.py +++ b/mmv_im2im/utils/elbo_loss.py @@ -35,7 +35,8 @@ def __init__( spatial_dims: int = 2, kl_clamp: float = None, elbo_class_weights: list = None, - task: str = 
"segment", + task: str = "segmentation", + regression_loss_type: str = "mse", # --- Fractal Regularization --- use_fractal_regularization: bool = False, fractal_weight: float = 0.1, @@ -112,6 +113,7 @@ def __init__( self.spatial_dims = spatial_dims self.kl_clamp = kl_clamp self.task = task + self.regression_loss_type = regression_loss_type.lower() self.kl_divergence_calculator = KLDivergence() # Class weights for the main reconstruction loss (Cross Entropy) @@ -122,149 +124,150 @@ def __init__( else: self.elbo_class_weights = None - # --- Initialize Regularizers --- - reg_used = [] - - # 1. Fractal - self.use_fractal_regularization = use_fractal_regularization - if self.use_fractal_regularization: - self.fractal_weight = fractal_weight - self.fractal_warmup_epochs = fractal_warmup_epochs - reg_used.append(f"Fractal-Dimension {fractal_mode}") - from mmv_im2im.utils.fractal_layers import FractalDimension - - self.fractal_dimension_calculator = FractalDimension( - num_kernels=fractal_num_kernels, - mode=fractal_mode, - to_binary=fractal_to_binary, - spatial_dims=self.spatial_dims, - ) + if self.task == "segmentation": + # --- Initialize Regularizers --- + reg_used = [] - # 2. Topological (TI Loss) - self.use_topological_regularization = use_topological_regularization - if self.use_topological_regularization: - self.topological_weight = topological_weight - self.topological_warmup_epochs = topological_warmup_epochs - reg_used.append("Topological-Restrictions (TI Loss)") - from mmv_im2im.utils.topological_loss import TI_Loss - - self.topological_loss_calculator = TI_Loss( - dim=self.spatial_dims, - connectivity=topological_connectivity, - inclusion=topological_inclusion if topological_inclusion else [], - exclusion=topological_exclusion if topological_exclusion else [], - min_thick=topological_min_thick, - ) + # 1. 
Fractal + self.use_fractal_regularization = use_fractal_regularization + if self.use_fractal_regularization: + self.fractal_weight = fractal_weight + self.fractal_warmup_epochs = fractal_warmup_epochs + reg_used.append(f"Fractal-Dimension {fractal_mode}") + from mmv_im2im.utils.fractal_layers import FractalDimension + + self.fractal_dimension_calculator = FractalDimension( + num_kernels=fractal_num_kernels, + mode=fractal_mode, + to_binary=fractal_to_binary, + spatial_dims=self.spatial_dims, + ) - # 3. Connectivity - self.use_connectivity_regularization = use_connectivity_regularization - if self.use_connectivity_regularization: - self.connectivity_weight = connectivity_weight - self.connectivity_warmup_epochs = connectivity_warmup_epochs - reg_used.append(f"Connectivity Coherence {connectivity_mode}") - from mmv_im2im.utils.connectivity_loss import ConnectivityCoherenceLoss - - self.connectivity_coherence_calculator = ConnectivityCoherenceLoss( - spatial_dims=self.spatial_dims, - connectivity_mode=connectivity_mode, - kernel_shape=kernel_shape, - connectivity_kernel_size=connectivity_kernel_size, - ignore_background=connectivity_ignore_background, - num_classes=n_classes, - lambda_density=lambda_density, - lambda_gradient=lambda_gradient, - metric_density=connectivity_metric_density, - metric_gradient=connectivity_metric_gradient, - ) + # 2. 
Topological (TI Loss) + self.use_topological_regularization = use_topological_regularization + if self.use_topological_regularization: + self.topological_weight = topological_weight + self.topological_warmup_epochs = topological_warmup_epochs + reg_used.append("Topological-Restrictions (TI Loss)") + from mmv_im2im.utils.topological_loss import TI_Loss + + self.topological_loss_calculator = TI_Loss( + dim=self.spatial_dims, + connectivity=topological_connectivity, + inclusion=topological_inclusion if topological_inclusion else [], + exclusion=topological_exclusion if topological_exclusion else [], + min_thick=topological_min_thick, + ) - # 4. GDL Focal (MONAI) - self.use_gdl_focal_regularization = use_gdl_focal_regularization - if self.use_gdl_focal_regularization: - self.gdl_focal_weight = gdl_focal_weight - self.gdl_warmup_epochs = gdl_warmup_epochs - reg_used.append("Generalized Dice Focal") - from monai.losses import GeneralizedDiceFocalLoss - - monai_focal_weights = ( - torch.tensor(gdl_class_weights, dtype=torch.float32) - if gdl_class_weights - else None - ) - # MONAI losses are generally dimension-agnostic given correct input shape - self.gdl_focal_loss_calculator = GeneralizedDiceFocalLoss( - softmax=True, to_onehot_y=True, weight=monai_focal_weights - ) + # 3. 
Connectivity + self.use_connectivity_regularization = use_connectivity_regularization + if self.use_connectivity_regularization: + self.connectivity_weight = connectivity_weight + self.connectivity_warmup_epochs = connectivity_warmup_epochs + reg_used.append(f"Connectivity Coherence {connectivity_mode}") + from mmv_im2im.utils.connectivity_loss import ConnectivityCoherenceLoss + + self.connectivity_coherence_calculator = ConnectivityCoherenceLoss( + spatial_dims=self.spatial_dims, + connectivity_mode=connectivity_mode, + kernel_shape=kernel_shape, + connectivity_kernel_size=connectivity_kernel_size, + ignore_background=connectivity_ignore_background, + num_classes=n_classes, + lambda_density=lambda_density, + lambda_gradient=lambda_gradient, + metric_density=connectivity_metric_density, + metric_gradient=connectivity_metric_gradient, + ) - # 5. Hausdorff - self.use_hausdorff_regularization = use_hausdorff_regularization - if self.use_hausdorff_regularization: - self.hausdorff_weight = hausdorff_weight - self.hausdorff_warmup_epochs = hausdorff_warmup_epochs - self.hausdorff_downsample_scale = hausdorff_downsample_scale - self.hausdorff_dt_iterations = hausdorff_dt_iterations - self.hausdorff_include_background = hausdorff_include_background - reg_used.append("Hausdorff") - from mmv_im2im.utils.hausdorff_loss import HausdorffLoss - - self.hausdorff_loss_calculator = HausdorffLoss( - spatial_dims=self.spatial_dims, - dt_iterations=self.hausdorff_dt_iterations, - include_background=self.hausdorff_include_background, - ) + # 4. 
GDL Focal (MONAI) + self.use_gdl_focal_regularization = use_gdl_focal_regularization + if self.use_gdl_focal_regularization: + self.gdl_focal_weight = gdl_focal_weight + self.gdl_warmup_epochs = gdl_warmup_epochs + reg_used.append("Generalized Dice Focal") + from monai.losses import GeneralizedDiceFocalLoss + + monai_focal_weights = ( + torch.tensor(gdl_class_weights, dtype=torch.float32) + if gdl_class_weights + else None + ) + # MONAI losses are generally dimension-agnostic given correct input shape + self.gdl_focal_loss_calculator = GeneralizedDiceFocalLoss( + softmax=True, to_onehot_y=True, weight=monai_focal_weights + ) - # 6. Homology (Persistence Image) - self.use_homology_regularization = use_homology_regularization - if self.use_homology_regularization: - self.homology_interval = max(1, homology_interval) - self.homology_weight = homology_weight - self.homology_warmup_epochs = homology_warmup_epochs - self.homology_downsample_scale = homology_downsample_scale - reg_used.append("Persistence Image (Homology)") - from mmv_im2im.utils.homology_loss import HomologyLoss - - self.homology_calculator = HomologyLoss( - spatial_dims=self.spatial_dims, - resolution=homology_resolution, - sigma=homology_sigma, - features=homology_features, - class_context=homology_class_context, - metric=homology_metric, - chunks=chunks, - filtering=homology_filtering, - treshold=homology_threshold, - k_top=homology_k_top, - weighting_power=weighting_power, - composite_flag=composite_flag, - ) + # 5. 
Hausdorff + self.use_hausdorff_regularization = use_hausdorff_regularization + if self.use_hausdorff_regularization: + self.hausdorff_weight = hausdorff_weight + self.hausdorff_warmup_epochs = hausdorff_warmup_epochs + self.hausdorff_downsample_scale = hausdorff_downsample_scale + self.hausdorff_dt_iterations = hausdorff_dt_iterations + self.hausdorff_include_background = hausdorff_include_background + reg_used.append("Hausdorff") + from mmv_im2im.utils.hausdorff_loss import HausdorffLoss + + self.hausdorff_loss_calculator = HausdorffLoss( + spatial_dims=self.spatial_dims, + dt_iterations=self.hausdorff_dt_iterations, + include_background=self.hausdorff_include_background, + ) - # 7. Topological Complexity - self.use_topological_complexity = use_topological_complexity - if self.use_topological_complexity: - self.complexity_interval = max(1, complexity_interval) - self.topological_complexity_weight = topological_complexity_weight - self.complexity_warmup_epochs = complexity_warmup_epochs - self.complexity_downsample_scale = complexity_downsample_scale - if complexity_metric == "wasserstein": - reg_used.append("Persistence Complexity (Diagrams)") - else: - reg_used.append("Persistence Complexity (Entropy-Betti)") - from mmv_im2im.utils.topological_complexity_loss import ( - TopologicalComplexityLoss, - ) + # 6. 
Homology (Persistence Image) + self.use_homology_regularization = use_homology_regularization + if self.use_homology_regularization: + self.homology_interval = max(1, homology_interval) + self.homology_weight = homology_weight + self.homology_warmup_epochs = homology_warmup_epochs + self.homology_downsample_scale = homology_downsample_scale + reg_used.append("Persistence Image (Homology)") + from mmv_im2im.utils.homology_loss import HomologyLoss + + self.homology_calculator = HomologyLoss( + spatial_dims=self.spatial_dims, + resolution=homology_resolution, + sigma=homology_sigma, + features=homology_features, + class_context=homology_class_context, + metric=homology_metric, + chunks=chunks, + filtering=homology_filtering, + treshold=homology_threshold, + k_top=homology_k_top, + weighting_power=weighting_power, + composite_flag=composite_flag, + ) - self.topological_complexity_calculator = TopologicalComplexityLoss( - spatial_dims=self.spatial_dims, - features=complexity_features, - class_context=complexity_class_context, - metric=complexity_metric, - threshold=complexity_threshold, - k_top=complexity_k_top, - temperature=complexity_temperature, - auto_balance=complexity_auto_balance, - ) + # 7. 
Topological Complexity + self.use_topological_complexity = use_topological_complexity + if self.use_topological_complexity: + self.complexity_interval = max(1, complexity_interval) + self.topological_complexity_weight = topological_complexity_weight + self.complexity_warmup_epochs = complexity_warmup_epochs + self.complexity_downsample_scale = complexity_downsample_scale + if complexity_metric == "wasserstein": + reg_used.append("Persistence Complexity (Diagrams)") + else: + reg_used.append("Persistence Complexity (Entropy-Betti)") + from mmv_im2im.utils.topological_complexity_loss import ( + TopologicalComplexityLoss, + ) + + self.topological_complexity_calculator = TopologicalComplexityLoss( + spatial_dims=self.spatial_dims, + features=complexity_features, + class_context=complexity_class_context, + metric=complexity_metric, + threshold=complexity_threshold, + k_top=complexity_k_top, + temperature=complexity_temperature, + auto_balance=complexity_auto_balance, + ) - if len(reg_used) > 0: - print(f"Active Regularizers: {reg_used}") + if len(reg_used) > 0: + print(f"Active Regularizers: {reg_used}") def _get_warmup_factor(self, current_epoch, warmup_epochs): if torch.is_tensor(current_epoch): @@ -342,7 +345,7 @@ def forward( self.elbo_class_weights = self.elbo_class_weights.to(logits.device) # --- 2. 
Base Reconstruction Loss (Cross Entropy) --- - if self.task == "segment": + if self.task == "segmentation": reconstruction_loss = F.cross_entropy( logits, y_true_flat.long(), @@ -350,7 +353,19 @@ def forward( weight=self.elbo_class_weights, ) else: - reconstruction_loss = F.huber_loss(logits, y_true, reduction="mean") + target = ( + y_true_ch.float() if y_true_ch.shape == logits.shape else y_true.float() + ) + if self.regression_loss_type == "mse": + reconstruction_loss = F.mse_loss(logits, target, reduction="mean") + elif self.regression_loss_type == "l1": + reconstruction_loss = F.l1_loss(logits, target, reduction="mean") + elif self.regression_loss_type == "huber": + reconstruction_loss = F.huber_loss(logits, target, reduction="mean") + else: + raise ValueError( + f"Regression loss should be mse/l1/huber but : {self.regression_loss_type} was given" + ) # --- 3. KL Divergence --- kl_div = self.kl_divergence_calculator( @@ -360,7 +375,7 @@ def forward( total_loss = reconstruction_loss + (self.beta * kl_div) # --- 4. 
Regularizers --- - if self.task == "segment": + if self.task == "segmentation": # Helper: Softmax Probs if ( diff --git a/mmv_im2im/utils/misc.py b/mmv_im2im/utils/misc.py index 17e3f64..6f93d06 100644 --- a/mmv_im2im/utils/misc.py +++ b/mmv_im2im/utils/misc.py @@ -19,68 +19,53 @@ def __init__(self, **kwargs): super().__init__() self.kwargs = kwargs + def verify_suffix(self, name: PathLike) -> bool: + if str(name).endswith(".npy"): + return True + return True + def read(self, data: Union[Sequence[PathLike], PathLike]): filenames: Sequence[PathLike] = ensure_tuple(data) img_ = [] for name in filenames: - try: - img_.append(BioImage(f"{name}", reader=bioio_tifffile.Reader)) - - except Exception: + if str(name).endswith(".npy"): + img_.append(np.load(str(name))) + else: try: - img_.append(BioImage(f"{name}")) - except Exception as e: - print(f"Error: {e}") - print(f"Image {name} failed at read process check the format.") + img_.append(BioImage(f"{name}", reader=bioio_tifffile.Reader)) + except Exception: + try: + img_.append(BioImage(f"{name}")) + except Exception as e: + print(f"Error: {e}") + print(f"Image {name} failed at read process check the format.") return img_ if len(filenames) > 1 else img_[0] def get_data(self, img) -> Tuple[np.ndarray, Dict]: + if isinstance(img, np.ndarray): + return img, {} + img_array: List[np.ndarray] = [] for img_obj in ensure_tuple(img): - data = img_obj.get_image_data(**self.kwargs) - img_array.append(data) - - return _stack_images(img_array, {}), {} - - def verify_suffix(self, filename: Union[Sequence[PathLike], PathLike]) -> bool: - return True - - -""" -def aicsimageio_reader(fn, **kwargs): - img = AICSImage(fn).reader.get_image_dask_data(**kwargs) - img_data = tio.data.io.check_uint_to_int(img.compute()) - return img_data, np.eye(4) - - -def load_yaml_cfg(yaml_path): - with open(yaml_path, "r") as stream: - opt_dict = yaml.safe_load(stream) - - # convert dictionary to attribute-like object - opt = Munch(opt_dict) - - return opt 
+ if isinstance(img_obj, np.ndarray): + img_array.append(img_obj) + else: + data = img_obj.get_image_data(**self.kwargs) + img_array.append(data) -def get_max_shape(subjects): - dataset = tio.SubjectsDataset(subjects) - shapes = np.array([s.spatial_shape for s in dataset]) - return shapes.max(axis=0) -""" + return _stack_images(img_array, {}), {} def parse_config_func_without_params(info): - # TODO add docstring my_module = importlib.import_module(info["module_name"]) my_func = getattr(my_module, info["func_name"]) return my_func def parse_config(info): - # univeral configuration parser my_func = parse_config_func_without_params(info) if "params" in info: if inspect.isclass(my_func): @@ -95,7 +80,6 @@ def parse_config(info): def parse_config_func(info): - # TODO: remove this function, use the universal parse_config above my_func = parse_config_func_without_params(info) if "params" in info: return partial(my_func, **info["params"]) @@ -105,30 +89,15 @@ def parse_config_func(info): def parse_ops_list(trans_func: List[Dict]): op_list = [] - # {"module_name":"numpy.random", - # "func_name":"randint", - # "params":{"low":0,"high":1}} - # load functions according to config for trans_dict in trans_func: op_list.append(parse_config(trans_dict)) - # trans_module = importlib.import_module(trans_dict["module_name"]) - # trans_func = getattr(trans_module, trans_dict["func_name"]) - # op_list.append(trans_func(**trans_dict["params"])) return op_list def generate_test_dataset_dict(data: Union[str, Path], data_type: str = None) -> List: - """ - different options for "data": - - one CSV - - one folder - Return - a list of filename - """ dataset_list = [] data = Path(data).expanduser() if data.is_file(): - # should be a csv of dataframe import pandas as pd df = pd.read_csv(data) @@ -152,78 +121,51 @@ def generate_test_dataset_dict(data: Union[str, Path], data_type: str = None) -> def generate_dataset_dict(data: Union[str, Path, Dict]) -> List[Dict]: - """ - different options for 
"data": - - one CSV (columns: source, target, cmap), then split - - one folder (_IM.tiff, _GT.tiff, _CM.tiff), then split - - a dictionary of two or three folders (Im, GT, CM), then split - - Return - a list of dict, each dict contains 2 or 3 keys - "source_fn", "target_fn", "costmap_fn" (optional) - """ dataset_list = [] if isinstance(data, str) or isinstance(data, Path): data = Path(data).expanduser() if data.is_file(): - # should be a csv of dataframe import pandas as pd df = pd.read_csv(data) assert "source_path" in df.columns, "column source_path not found" assert "target_path" in df.columns, "column target_path not found" - - # check if costmap is needed - if "costmap_path" in df.columns: - cm_flag = True - else: - cm_flag = False + cm_flag = "costmap_path" in df.columns for row in df.itertuples(): + item = { + "source_fn": row.source_path, + "target_fn": row.target_path, + } if cm_flag: - dataset_list.append( - { - "source_fn": row.source_path, - "target_fn": row.target_path, - "costmap_fn": row.costmap_path, - } - ) - else: - dataset_list.append( - { - "source_fn": row.source_path, - "target_fn": row.target_path, - } # noqa E501 - ) + item["costmap_fn"] = row.costmap_path + dataset_list.append(item) + elif data.is_dir(): all_filename = sorted(data.glob("*_IM.*")) assert len(all_filename) > 0, f"no file found in {data}" - all_filename.sort() + for fn in all_filename: - target_fn = data / fn.name.replace("_IM.", "_GT.") - costmap_fn = data / fn.name.replace("_IM.", "_CM.") - if costmap_fn.is_file(): - dataset_list.append( - { - "source_fn": fn, - "target_fn": target_fn, - "costmap_fn": costmap_fn, - } - ) + # Extension agnostic matching + basename = fn.name[: fn.name.rfind("_IM.")] + target_fn = list(data.glob(f"{basename}_GT.*")) + costmap_fn = list(data.glob(f"{basename}_CM.*")) + + item = {"source_fn": fn} + if target_fn: + item["target_fn"] = target_fn[0] else: - dataset_list.append( - { - "source_fn": fn, - "target_fn": target_fn, - } - ) + 
item["target_fn"] = data / fn.name.replace("_IM.", "_GT.") + + if costmap_fn: + item["costmap_fn"] = costmap_fn[0] + + dataset_list.append(item) else: print(f"{data} is not a valid file or directory") elif isinstance(data, Dict): - # assume 3~4 keys: "source_dir", "target_dir", and - # "image_type", "costmap_dir" (optional) if "costmap_dir" in data: cm_path = Path(data["costmap_dir"]).expanduser() else: @@ -231,7 +173,6 @@ def generate_dataset_dict(data: Union[str, Path, Dict]) -> List[Dict]: source_path = Path(data["source_dir"]).expanduser() target_path = Path(data["target_dir"]).expanduser() - data_type = data["image_type"] all_filename = sorted(source_path.glob(f"*.{data_type}")) @@ -240,53 +181,48 @@ def generate_dataset_dict(data: Union[str, Path, Dict]) -> List[Dict]: for fn in all_filename: target_fn = target_path / fn.name + item = {"source_fn": fn, "target_fn": target_fn} if cm_path is not None: - costmap_fn = cm_path / fn.name - dataset_list.append( - { - "source_fn": fn, - "target_fn": target_fn, - "costmap_fn": costmap_fn, - } # noqa E501 - ) - else: - dataset_list.append({"source_fn": fn, "target_fn": target_fn}) - + item["costmap_fn"] = cm_path / fn.name + dataset_list.append(item) else: print("unsupported data type") assert len(dataset_list) > 0, f"empty dataset in {data}" - return dataset_list def load_one_folder(data) -> List: dataset_list = [] - + data = Path(data) all_filename = sorted(data.glob("*_IM.*")) assert len(all_filename) > 0, f"no file found in {data}" - # parse how many images associated with one subject - basename = all_filename[0].stem - basename = basename[: basename.rfind("_")] - subject_files = sorted(data.glob(f"{basename}_*.*")) - all_tags = [sfile.stem[sfile.stem.rfind("_") + 1 :] for sfile in subject_files] + first_fn = all_filename[0] + basename_template = first_fn.name[: first_fn.name.rfind("_IM")] + subject_files = sorted(data.glob(f"{basename_template}_*.*")) + + all_tags = list( + set([sfile.stem[sfile.stem.rfind("_") + 
1 :] for sfile in subject_files]) + ) - all_filename.sort() for fn in all_filename: path_list = {} + current_basename = fn.name[: fn.name.rfind("_IM")] + for tag_name in all_tags: - fn_full = data / fn.name.replace("_IM.", f"_{tag_name}.") - path_list[tag_name] = fn_full - dataset_list.append(path_list) + tag_files = list(data.glob(f"{current_basename}_{tag_name}.*")) + if len(tag_files) > 0: + path_list[tag_name] = tag_files[0] + + if len(path_list) > 0: + dataset_list.append(path_list) return dataset_list def load_subfolders(data) -> List: dataset_list = [] - - # parse how many images associated with one subject all_tags = [ d.name for d in sorted(data.iterdir()) @@ -309,41 +245,38 @@ def load_subfolders(data) -> List: def generate_dataset_dict_monai(data: Union[str, Path, Dict]) -> List[Dict]: - """ - different options for "data": - - one CSV (columns: source, target, cmap), then split - - one folder (_IM.tiff, _GT.tiff, _CM.tiff), then split - - one folder with two or three subfolders (Im, GT, CM), then split - - a dictionary of train/val - - Return - a list of dict, each dict contains 2 or more keys, such as "IM", - "GT", and more (optional, e.g. "LM", "CM", etc.) - """ if isinstance(data, str): try: data = eval(data) except Exception as e: print(f"data path is recognized as a string ... 
def get_nnunet_plans(patch_size, spacing, modality="non-CT", min_size=8, vram_gb=None):
    """
    Derive network topology and normalization settings with an
    nnU-Net-style heuristic.

    Starting from a stride-1 first layer, stride-2 stages are appended until
    no axis qualifies for downsampling. An axis qualifies while its current
    size is strictly greater than ``min_size`` and its current spacing is
    at most twice the smallest current spacing, so coarse axes of
    anisotropic data are left alone until the other axes catch up.

    Args:
        patch_size: List/tuple of spatial dimensions
            (e.g., [256, 256] or [128, 128, 128]).
        spacing: Voxel spacing, one value per axis of ``patch_size``
            (e.g., [1.0, 1.0] or [1.0, 1.0, 5.0]).
        modality: "CT" (case-insensitive) selects the global/clipped
            normalization plan; any other value selects per-image z-score.
        min_size: An axis is only downsampled while its current size is
            strictly greater than this value. Must be >= 1.
        vram_gb: GPU memory in GB used to cap the number of filters. When
            None, the total memory of CUDA device 0 is used if CUDA is
            available; otherwise 12 GB is assumed.

    Returns:
        A dict with the keys "kernel_size" and "strides" (one list per
        layer), "filters" (one int per layer), "upsample_kernel_size"
        (the strides without the first layer) and "norm_info".

    Raises:
        ValueError: If ``patch_size`` is empty, ``spacing`` has a different
            length than ``patch_size``, or ``min_size`` is smaller than 1.
    """
    dim = len(patch_size)
    if dim == 0:
        raise ValueError("patch_size must contain at least one dimension")
    if len(spacing) != dim:
        raise ValueError(
            f"spacing has {len(spacing)} entries but patch_size has {dim}"
        )
    if min_size < 1:
        # A non-positive min_size lets sizes shrink to 0 (and a negative one
        # would keep the topology loop below running forever).
        raise ValueError(f"min_size must be >= 1, got {min_size}")

    cur_size = np.array(patch_size)
    cur_spacing = np.array(spacing)

    if vram_gb is None:
        if torch.cuda.is_available():
            # Total memory of CUDA device 0 in bytes, converted to GB
            total_memory_bytes = torch.cuda.get_device_properties(0).total_memory
            vram_gb = total_memory_bytes / (1024**3)
            print(
                f"Detected GPU with {vram_gb:.1f} GB VRAM. Adapting network topology..."
            )
        else:
            print("CUDA not available. Defaulting to safe 12GB VRAM heuristic.")
            vram_gb = 12.0

    # Cap on the number of filters per layer, chosen from the memory budget
    if vram_gb >= 38:  # Using 38 as a safe threshold for 40GB cards
        max_f = 512 if dim == 3 else 1024
    elif vram_gb >= 22:  # Safe threshold for 24GB cards
        max_f = 416 if dim == 3 else 768
    else:
        # Fallback for everything below the 24GB-class threshold
        max_f = 320 if dim == 3 else 512

    # TOPOLOGY HEURISTIC (Strides & Kernels)
    # The first layer always has stride 1
    strides = [[1] * dim]
    kernels = [[3] * dim]

    while True:
        # Highest resolution (smallest spacing) currently present
        target_spacing = np.min(cur_spacing)

        # Downsample an axis if it is large enough AND its spacing is within
        # a factor of 2 of the highest-resolution axis
        new_stride = [
            2 if size > min_size and axis_spacing <= 2 * target_spacing else 1
            for size, axis_spacing in zip(cur_size, cur_spacing)
        ]

        # Termination: no axis can be downsampled anymore
        if 2 not in new_stride:
            break

        strides.append(new_stride)
        kernels.append([3] * dim)

        # Update current state for the next iteration
        stride_arr = np.array(new_stride)
        cur_size = cur_size // stride_arr
        cur_spacing = cur_spacing * stride_arr

    # FILTER HEURISTIC
    # Start at 32 filters and double per layer, capped at max_f
    filters = [min(32 * (2**i), max_f) for i in range(len(strides))]

    # NORMALIZATION & INTENSITY HEURISTIC
    # Different intensity strategies are used depending on the modality
    if modality.upper() == "CT":
        # CT uses Global Normalization (Clipped to percentiles)
        norm_plans = {
            "norm_name": "instance",
            "intensity_clipping": [0.5, 99.5],  # Percentiles
            "z_score_type": "global",
        }
    else:
        # MRI/Microscopy uses Per-Image Z-Score
        norm_plans = {
            "norm_name": "instance",
            "intensity_clipping": None,
            "z_score_type": "per_image",
        }

    return {
        "kernel_size": kernels,
        "strides": strides,
        "filters": filters,
        "upsample_kernel_size": strides[1:],  # MONAI DynUNet specific
        "norm_info": norm_plans,
    }