diff --git a/mmv_im2im/__init__.py b/mmv_im2im/__init__.py index a8165ef..dba900b 100644 --- a/mmv_im2im/__init__.py +++ b/mmv_im2im/__init__.py @@ -6,7 +6,7 @@ __email__ = "jianxuchen.ai@gmail.com" # Do not edit this string manually, always use bumpversion # Details in CONTRIBUTING.md -__version__ = "0.8.0" +__version__ = "0.7.1" def get_module_version(): diff --git a/mmv_im2im/bin/run_im2im.py b/mmv_im2im/bin/run_im2im.py index c78065d..574327d 100644 --- a/mmv_im2im/bin/run_im2im.py +++ b/mmv_im2im/bin/run_im2im.py @@ -22,6 +22,7 @@ parse_adaptor, configuration_validation, ) +from mmv_im2im.proj_trainer_multishape import VariableSizeProjectTrainer ############################################################################### @@ -33,6 +34,7 @@ ############################################################################### TRAIN_MODE = "train" +TRAIN_MULTISHAPE_MODE = "train-multishape" INFER_MODE = "inference" MAP_MODE = "uncertainty_map" @@ -70,12 +72,35 @@ def main(): else: exe = ProjectTrainer(cfg) exe.run_training() + + elif cfg.mode.lower() == TRAIN_MULTISHAPE_MODE: + if ( + cfg.data.dataloader.train.dataloader_type["func_name"] + == "PersistentDataset" + ): + # Mirror the same PersistentDataset cache handling used in + # standard training mode. 
+ cache_root = Path(cfg.data.dataloader.train.dataset_params["cache_dir"]) + cache_root.mkdir(exist_ok=True) + with tempfile.TemporaryDirectory(dir=cache_root) as tmp_exp: + train_cache = Path(tmp_exp) / "train" + val_cache = Path(tmp_exp) / "val" + cfg.data.dataloader.train.dataset_params["cache_dir"] = train_cache + cfg.data.dataloader.val.dataset_params["cache_dir"] = val_cache + exe = VariableSizeProjectTrainer(cfg) + exe.run_training() + else: + exe = VariableSizeProjectTrainer(cfg) + exe.run_training() + elif cfg.mode.lower() == INFER_MODE: exe = ProjectTester(cfg) exe.run_inference() + elif cfg.mode.lower() == MAP_MODE: exe = MapExtractor(cfg) exe.run_inference() + else: log.error(f"Mode {cfg.mode} is not supported yet") sys.exit(1) diff --git a/mmv_im2im/configs/preset_train_ProbabilisticUnet_Segmentation_regularizers.yaml b/mmv_im2im/configs/preset_train_ProbabilisticUnet_Segmentation_regularizers.yaml index f0adbbd..6d91130 100644 --- a/mmv_im2im/configs/preset_train_ProbabilisticUnet_Segmentation_regularizers.yaml +++ b/mmv_im2im/configs/preset_train_ProbabilisticUnet_Segmentation_regularizers.yaml @@ -2,7 +2,7 @@ mode: train data: category: "pair" - data_path: "/mnt/eternus/users/Jair/Trainings/Retrain2/DataMax" + data_path: "path/to/your/data" dataloader: train: dataloader_type: diff --git a/mmv_im2im/map_extractor.py b/mmv_im2im/map_extractor.py index a1f956f..4657894 100644 --- a/mmv_im2im/map_extractor.py +++ b/mmv_im2im/map_extractor.py @@ -1,5 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import warnings + +warnings.simplefilter(action="ignore", category=FutureWarning) import logging from typing import Union from dask.array.core import Array as DaskArray diff --git a/mmv_im2im/models/pl_FCN.py b/mmv_im2im/models/pl_FCN.py index e53f876..099c209 100644 --- a/mmv_im2im/models/pl_FCN.py +++ b/mmv_im2im/models/pl_FCN.py @@ -1,6 +1,8 @@ from typing import Dict + import lightning as pl import torch +import torch.nn as nn from 
mmv_im2im.utils.gdl_regularized import ( RegularizedGeneralizedDiceFocalLoss as regularized, ) @@ -19,13 +21,12 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals if isinstance(model_info_xx.net["params"], dict): self.task = model_info_xx.net["params"].pop("task", "segmentation") - if self.task != "regression" and self.task != "segmentation": + if self.task not in ("regression", "segmentation"): raise ValueError( - f"Task should be regression/segmentation : {self.task} was given" + f"Task should be 'regression' or 'segmentation'; got '{self.task}'" ) self.net = parse_config(model_info_xx.net) - init_weights(self.net, init_type="kaiming") self.model_info = model_info_xx @@ -33,6 +34,19 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals self.weighted_loss = False self.seg_flag = False + # ── Regression: adaptive global average pooling ──────────────── + # Works for ANY spatial size (2-D or 3-D) and is more robust than + # the manual .view(...).mean(-1) idiom. + # AdaptiveAvgPool reduces every spatial dimension to size 1, after + # which squeeze() gives a flat [B, C] tensor. 
+ self._gap_2d = nn.AdaptiveAvgPool2d(1) + self._gap_3d = nn.AdaptiveAvgPool3d(1) + + # _logged_shapes → verbose print fires exactly once (any task) + # _shape_checked → regression shape validation fires exactly once + self._logged_shapes = False + self._shape_checked = False + if train: if "use_costmap" in model_info_xx.criterion[ "params" @@ -48,70 +62,114 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals ): self.seg_flag = True + # ------------------------------------------------------------------ def configure_optimizers(self): optimizer = self.optimizer_func(self.parameters()) if self.model_info.scheduler is None: return optimizer - else: - scheduler_func = parse_config_func_without_params(self.model_info.scheduler) - lr_scheduler = scheduler_func( - optimizer, **self.model_info.scheduler["params"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "monitor": "val_loss", - "interval": "epoch", - "frequency": 1, - "strict": True, - }, - } - - def prepare_batch(self, batch): - return - + scheduler_func = parse_config_func_without_params(self.model_info.scheduler) + lr_scheduler = scheduler_func(optimizer, **self.model_info.scheduler["params"]) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + "strict": True, + }, + } + + # ------------------------------------------------------------------ def forward(self, x): return self.net(x) - def run_step(self, batch, validation_stage): - x = batch["IM"] - y = batch["GT"] - if "CM" in batch.keys(): + # ------------------------------------------------------------------ + def _global_average_pool(self, y_hat: torch.Tensor) -> torch.Tensor: + """ + Reduce [B, C, *spatial] → [B, C] via global average pooling. + + Handles both 2-D [B, C, H, W] and 3-D [B, C, D, H, W] tensors + of arbitrary spatial size. 
+ """ + ndim = y_hat.dim() + if ndim == 4: # 2-D input + return self._gap_2d(y_hat).squeeze(-1).squeeze(-1) + elif ndim == 5: # 3-D input + return self._gap_3d(y_hat).squeeze(-1).squeeze(-1).squeeze(-1) + else: + # Fallback: manual mean over all spatial dims + return y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1) + + # ------------------------------------------------------------------ + def run_step(self, batch, validation_stage: bool): + x = batch["IM"] # [B, 1, Z, Y, X] – spatial dims vary per batch + y = batch["GT"] # [B, 3 + n_coeffs] + + if "CM" in batch: assert ( self.weighted_loss - ), "Costmap is detected, but no use_costmap param in criterion" + ), "Costmap detected in batch but use_costmap=False in criterion config." cm = batch["CM"] - # only for badly formated data file - if x.size()[-1] == 1: - x = torch.squeeze(x, dim=-1) - y = torch.squeeze(y, dim=-1) + # ── Remove accidental trailing singleton (badly formatted data) ── + if x.dim() > 4 and x.size(-1) == 1: + x = x.squeeze(-1) + if y.dim() > 2 and y.size(-1) == 1: + y = y.squeeze(-1) - y_hat = self(x) + # ── Forward pass ────────────────────────────────────────────── + y_hat = self(x) # [B, C, Z', Y', X'] (spatial may still vary) + + # Verbose: imprime shapes una sola vez al inicio, independiente del task + if self.verbose and not self._logged_shapes: + print( + f"[pl_FCN] x.shape={tuple(x.shape)} " + f"y_hat.shape (pre-GAP)={tuple(y_hat.shape)}" + ) + self._logged_shapes = True + # ── Task-specific output processing ──────────────────────────── if self.task == "regression": - # Global Average Pooling: Dim reduction 2D / 3D - # [B, C, H, W] -> [B, C] - y_hat = y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1) - # GT: [B, C] + # Global Average Pool: [B, C, *spatial] → [B, C] + y_hat = self._global_average_pool(y_hat) + + # GT: ensure [B, N] y = y.view(y.size(0), -1).float() - else: + + # One-time shape sanity check + if not self._shape_checked: + C_pred = y_hat.shape[1] + N_gt = 
y.shape[1] + if C_pred != N_gt: + raise ValueError( + f"[pl_FCN] Shape mismatch: network predicts {C_pred} " + f"values (out_channels={C_pred}) but GT vector has " + f"{N_gt} elements. " + f"Set out_channels: {N_gt} in your YAML." + ) + if self.verbose: + print( + f"[pl_FCN] Regression shapes OK " + f"y_hat={tuple(y_hat.shape)} y={tuple(y.shape)}" + ) + self._shape_checked = True + + else: # segmentation if isinstance(self.criterion, torch.nn.CrossEntropyLoss): - # in case of CrossEntropy related error - y = torch.squeeze(y, dim=1) # remove C dimension + y = y.squeeze(dim=1) + # ── Loss ───────────────────────────────────────────────────── if isinstance(self.criterion, regularized): - current_epoch = self.current_epoch - loss = self.criterion(y_hat, y, epoch=current_epoch) + loss = self.criterion(y_hat, y, epoch=self.current_epoch) + elif self.weighted_loss: + loss = self.criterion(y_hat, y, cm) else: - if self.weighted_loss: - loss = self.criterion(y_hat, y, cm) - else: - loss = self.criterion(y_hat, y) + loss = self.criterion(y_hat, y) return loss + # ------------------------------------------------------------------ def on_train_epoch_end(self): torch.cuda.synchronize() @@ -142,5 +200,4 @@ def validation_step(self, batch, batch_idx): logger=True, sync_dist=True, ) - return loss diff --git a/mmv_im2im/models/pl_ProbUnet.py b/mmv_im2im/models/pl_ProbUnet.py index a7fcfd6..a64fed7 100644 --- a/mmv_im2im/models/pl_ProbUnet.py +++ b/mmv_im2im/models/pl_ProbUnet.py @@ -1,5 +1,6 @@ import lightning as pl import torch +import torch.nn as nn from mmv_im2im.utils.misc import ( parse_config, parse_config_func, @@ -15,9 +16,9 @@ def __init__(self, model_info_xx, train=True, verbose=False): if isinstance(model_info_xx.net["params"], dict): self.task = model_info_xx.net["params"].get("task", "segmentation") - if self.task != "regression" and self.task != "segmentation": + if self.task not in ("regression", "segmentation"): raise ValueError( - f"Task should be 
regression/segmentation : {self.task} was given" + f"Task should be 'regression' or 'segmentation'; got '{self.task}'" ) if "utils.elbo_loss" in model_info_xx.criterion["module_name"]: @@ -25,12 +26,25 @@ def __init__(self, model_info_xx, train=True, verbose=False): self.net = parse_config(model_info_xx.net) init_weights(self.net, init_type="kaiming") + self.model_info = model_info_xx self.verbose = verbose + + # ── Adaptive global average pooling for regression ───────────── + # ProbUnet's decoder output is [B, C, *spatial] regardless of + # whether the task is segmentation or regression. For regression + # we need to collapse the spatial dims before computing the loss + # against the flat GT vector. + self._gap_2d = nn.AdaptiveAvgPool2d(1) + self._gap_3d = nn.AdaptiveAvgPool3d(1) + self._logged_shapes = False # verbose print: once, any task + self._shape_checked = False # shape validation: once, regression only + if train: self.criterion = parse_config(model_info_xx.criterion) self.optimizer_func = parse_config_func(model_info_xx.optimizer) + # ------------------------------------------------------------------ def configure_optimizers(self): optimizer = self.optimizer_func(self.parameters()) if self.model_info.scheduler is None: @@ -46,34 +60,81 @@ def configure_optimizers(self): }, } + # ------------------------------------------------------------------ def forward(self, x, seg=None, train_posterior=False): return self.net(x, seg, train_posterior) + # ------------------------------------------------------------------ + def _global_average_pool(self, y_hat: torch.Tensor) -> torch.Tensor: + """[B, C, *spatial] → [B, C] via adaptive global average pooling.""" + ndim = y_hat.dim() + if ndim == 4: + return self._gap_2d(y_hat).squeeze(-1).squeeze(-1) + elif ndim == 5: + return self._gap_3d(y_hat).squeeze(-1).squeeze(-1).squeeze(-1) + else: + return y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1) + + # 
------------------------------------------------------------------ def run_step(self, batch): x, y = batch["IM"], batch["GT"] + # ── Remove trailing singleton (badly formatted data) ─────────── if x.ndim > 4 and x.shape[-1] == 1: x = x.squeeze(-1) + # ── Prepare GT depending on task ────────────────────────────── if self.task == "regression": + # GT is a flat vector: ensure [B, N] if y.ndim > 2 and y.shape[-1] == 1: y = y.squeeze(-1) - # [B, out_channels] - y = y.view(y.size(0), -1) - else: + y = y.view(y.size(0), -1).float() + + else: # segmentation if y.ndim > 4 and y.shape[-1] == 1: y = y.squeeze(-1) + # Ensure channel dim present: [B, Z, Y, X] → [B, 1, Z, Y, X] if y.ndim == x.ndim - 1: y = y.unsqueeze(1) - # Forward pass (Train Posterior) + # ── Forward (Train Posterior) ───────────────────────────────── output = self(x, seg=y, train_posterior=True) - # Calculate Loss - current_ep = int(self.current_epoch) + # ── Spatial → flat para regression predictions ──────────────── + pred = output["pred"] # [B, C, *spatial] from the decoder + # Verbose: imprime shapes una sola vez, independiente del task + if self.verbose and not self._logged_shapes: + print( + f"[pl_ProbUnet] x.shape={tuple(x.shape)} " + f"pred.shape (pre-GAP)={tuple(pred.shape)}" + ) + self._logged_shapes = True + + if self.task == "regression": + # AdaptiveAvgPool: [B, C, *spatial] → [B, C] + pred = self._global_average_pool(pred) + + # One-time shape sanity check + if not self._shape_checked: + C_pred, N_gt = pred.shape[1], y.shape[1] + if C_pred != N_gt: + raise ValueError( + f"[pl_ProbUnet] Shape mismatch: decoder predicts {C_pred} " + f"channels but GT vector has {N_gt} elements. " + f"Set out_channels: {N_gt} in your YAML." 
+ ) + if self.verbose: + print( + f"[pl_ProbUnet] Regression shapes OK " + f"pred={tuple(pred.shape)} y={tuple(y.shape)}" + ) + self._shape_checked = True + + # ── ELBO Loss ───────────────────────────────────────────────── + current_ep = int(self.current_epoch) loss = self.criterion( - logits=output["pred"], + logits=pred, y_true=y, prior_mu=output["prior_mu"], prior_logvar=output["prior_logvar"], @@ -83,6 +144,7 @@ def run_step(self, batch): ) return loss + # ------------------------------------------------------------------ def on_train_epoch_end(self): torch.cuda.synchronize() @@ -92,12 +154,9 @@ def on_validation_epoch_end(self): def training_step(self, batch, batch_idx): loss = self.run_step(batch) self.log("train_loss", loss, prog_bar=True, on_epoch=True, logger=True) - return loss def validation_step(self, batch, batch_idx): loss = self.run_step(batch) - self.log("val_loss", loss, prog_bar=True, on_epoch=True, logger=True) - return loss diff --git a/mmv_im2im/models/pl_nnUnet.py b/mmv_im2im/models/pl_nnUnet.py index c8bc6c9..598b450 100644 --- a/mmv_im2im/models/pl_nnUnet.py +++ b/mmv_im2im/models/pl_nnUnet.py @@ -1,6 +1,7 @@ from typing import Dict import lightning as pl import torch +import torch.nn as nn from mmv_im2im.utils.gdl_regularized import ( RegularizedGeneralizedDiceFocalLoss as regularized, ) @@ -19,19 +20,26 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals if isinstance(model_info_xx.net["params"], dict): self.task = model_info_xx.net["params"].pop("task", "segmentation") - if self.task != "regression" and self.task != "segmentation": + if self.task not in ("regression", "segmentation"): raise ValueError( - f"Task should be regression/segmentation : {self.task} was given" + f"Task should be 'regression' or 'segmentation'; got '{self.task}'" ) self.net = parse_config(model_info_xx.net) - init_weights(self.net, init_type="kaiming") self.model_info = model_info_xx self.verbose = verbose self.weighted_loss = 
False self.seg_flag = False + + # ── Adaptive global average pooling for regression ───────────── + # Handles ANY spatial size, both 2-D and 3-D. + self._gap_2d = nn.AdaptiveAvgPool2d(1) + self._gap_3d = nn.AdaptiveAvgPool3d(1) + self._logged_shapes = False # verbose print: once, any task + self._shape_checked = False # shape validation: once, regression only + if train: if "use_costmap" in model_info_xx.criterion[ "params" @@ -47,25 +55,23 @@ def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = Fals ): self.seg_flag = True + # ------------------------------------------------------------------ def configure_optimizers(self): optimizer = self.optimizer_func(self.parameters()) if self.model_info.scheduler is None: return optimizer - else: - scheduler_func = parse_config_func_without_params(self.model_info.scheduler) - lr_scheduler = scheduler_func( - optimizer, **self.model_info.scheduler["params"] - ) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "monitor": "val_loss", - "interval": "epoch", - "frequency": 1, - "strict": True, - }, - } + scheduler_func = parse_config_func_without_params(self.model_info.scheduler) + lr_scheduler = scheduler_func(optimizer, **self.model_info.scheduler["params"]) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + "strict": True, + }, + } def prepare_batch(self, batch): return @@ -73,51 +79,96 @@ def prepare_batch(self, batch): def forward(self, x): return self.net(x) + # ------------------------------------------------------------------ + def _global_average_pool(self, y_hat: torch.Tensor) -> torch.Tensor: + """[B, C, *spatial] → [B, C] via adaptive global average pooling.""" + ndim = y_hat.dim() + if ndim == 4: + return self._gap_2d(y_hat).squeeze(-1).squeeze(-1) + elif ndim == 5: + return self._gap_3d(y_hat).squeeze(-1).squeeze(-1).squeeze(-1) + else: + return 
y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1) + + # ------------------------------------------------------------------ def run_step(self, batch, validation_stage): x = batch["IM"] y = batch["GT"] + # FIX: initialise cm=None so it is always defined in scope + cm = None if self.weighted_loss: - assert ( - "CM" in batch.keys() - ), "Costmap is detected, but no use_costmap param in criterion" + assert "CM" in batch, ( + "weighted_loss=True but 'CM' key not found in batch. " + "Check use_costmap setting in criterion config." + ) cm = batch["CM"] - if x.size()[-1] == 1 and x.ndim > (self.net.spatial_dims + 2): + # Remove accidental trailing singleton dimension + # (guard: only squeeze if the tensor really has an extra dim) + spatial_dims = getattr(self.net, "spatial_dims", x.ndim - 2) + if x.ndim > (spatial_dims + 2) and x.size(-1) == 1: x = torch.squeeze(x, dim=-1) y = torch.squeeze(y, dim=-1) + # FIX: only squeeze cm if it was actually loaded if cm is not None: cm = torch.squeeze(cm, dim=-1) + # ── Forward ─────────────────────────────────────────────────── y_hat = self(x) - # Handle potential MONAI DynUNet deep supervision outputs safely + # Verbose: imprime shapes una sola vez, independiente del task + if self.verbose and not self._logged_shapes: + print( + f"[pl_nnUnet] x.shape={tuple(x.shape)} " + f"y_hat.shape (pre-GAP)={tuple(y_hat.shape)}" + ) + self._logged_shapes = True + + # Handle DynUNet deep supervision: stacks intermediate outputs + # along a new leading dim → shape [B, n_outputs, C, *spatial] if torch.is_tensor(y_hat) and y_hat.ndim == x.ndim + 1: - # DynUNet interpolates and stacks intermediate predictions along dim=1. - # We unbind and take the primary full-resolution output (index 0). - y_hat = y_hat[:, 0, ...] + y_hat = y_hat[:, 0, ...] 
# take full-resolution head elif isinstance(y_hat, (list, tuple)): y_hat = y_hat[0] + # ── Task-specific post-processing ───────────────────────────── if self.task == "regression": - # Global Average Pooling [B, C, H, W] -> [B, C] - y_hat = y_hat.view(y_hat.size(0), y_hat.size(1), -1).mean(dim=-1) + # AdaptiveAvgPool: [B, C, *spatial] → [B, C] (any spatial size) + y_hat = self._global_average_pool(y_hat) y = y.view(y.size(0), -1).float() - else: + + # One-time shape sanity check + if not self._shape_checked: + C_pred, N_gt = y_hat.shape[1], y.shape[1] + if C_pred != N_gt: + raise ValueError( + f"[pl_nnUnet] Shape mismatch: network predicts {C_pred} " + f"values but GT vector has {N_gt} elements. " + f"Set out_channels: {N_gt} in your YAML." + ) + if self.verbose: + print( + f"[pl_nnUnet] Regression shapes OK " + f"y_hat={tuple(y_hat.shape)} y={tuple(y.shape)}" + ) + self._shape_checked = True + + else: # segmentation if isinstance(self.criterion, torch.nn.CrossEntropyLoss): y = torch.squeeze(y, dim=1) + # ── Loss ────────────────────────────────────────────────────── if isinstance(self.criterion, regularized): - current_epoch = self.current_epoch - loss = self.criterion(y_hat, y, epoch=current_epoch) + loss = self.criterion(y_hat, y, epoch=self.current_epoch) + elif self.weighted_loss: + loss = self.criterion(y_hat, y, cm) else: - if self.weighted_loss: - loss = self.criterion(y_hat, y, cm) - else: - loss = self.criterion(y_hat, y) + loss = self.criterion(y_hat, y) return loss + # ------------------------------------------------------------------ def on_train_epoch_end(self): torch.cuda.synchronize() @@ -148,5 +199,4 @@ def validation_step(self, batch, batch_idx): logger=True, sync_dist=True, ) - return loss diff --git a/mmv_im2im/proj_tester.py b/mmv_im2im/proj_tester.py index 6ab7073..6b81c64 100644 --- a/mmv_im2im/proj_tester.py +++ b/mmv_im2im/proj_tester.py @@ -1,5 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import warnings + 
+warnings.simplefilter(action="ignore", category=FutureWarning) import logging from typing import Union from dask.array.core import Array as DaskArray @@ -18,8 +21,8 @@ import bioio_tifffile from tqdm.auto import tqdm from monai.inferers import sliding_window_inference +from monai.transforms import Compose -# https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html#predicting ############################################################################### log = logging.getLogger(__name__) @@ -27,10 +30,37 @@ ############################################################################### +# Module/function name pairs that activate the shared-state padding mechanism. +# Both the short name and the full mmv_im2im package path are accepted so that +# the YAML can use either style. +_PAD_PREPROCESS_TRIGGERS = { + ("inverse_transforms", "RecordShapeAndPad"), + ("custom_transforms", "RecordShapeAndPad"), + ("mmv_im2im.utils.custom_transforms", "RecordShapeAndPad"), + ("custom_transforms", "DivisiblePadWithGTAdjustd"), + ("mmv_im2im.utils.custom_transforms", "DivisiblePadWithGTAdjustd"), +} +_PAD_POSTPROCESS_TRIGGERS = { + ("inverse_transforms", "RemovePadFromPrediction"), + ("custom_transforms", "RemovePadFromPrediction"), + ("mmv_im2im.utils.custom_transforms", "RemovePadFromPrediction"), +} + + +def _get_trigger_key(cfg_entry: dict) -> tuple: + """Devuelve (module_name, func_name) de una entrada de config.""" + return ( + cfg_entry.get("module_name", ""), + cfg_entry.get("func_name", ""), + ) + + +############################################################################### + class ProjectTester(object): """ - entry for training models + Entry point for model inference. 
Parameters ---------- @@ -38,17 +68,25 @@ class ProjectTester(object): """ def __init__(self, cfg): - # extract the three major chuck of the config self.model_cfg = cfg.model self.data_cfg = cfg.data - # define variables self.model = None self.data = None self.pre_process = None + self.post_process_ops = None # pre-built postprocess pipeline self.cpu = False self.spatial_dims = -1 + self.pad_state = None # shared state for variable-size padding + + # Read task BEFORE setup_model() is called. + # pl_FCN / pl_nnUnet / pl_ProbUnet all do net["params"].pop("task") + # in __init__, so by the time setup_data_processing() runs the key + # is gone from the dict. Reading it here from the raw cfg is safe. + net_params = cfg.model.net.get("params", {}) if cfg.model.net else {} + self.is_regression = net_params.get("task", "segmentation") == "regression" + # ------------------------------------------------------------------ def setup_model(self): model_category = self.model_cfg.framework model_module = import_module(f"mmv_im2im.models.pl_{model_category}") @@ -70,13 +108,11 @@ def setup_model(self): checkpoint = torch.load(self.model_cfg.checkpoint, weights_only=False) if isinstance(checkpoint, dict) and "state_dict" in checkpoint: - pre_train = checkpoint pre_train["state_dict"].pop("criterion.xym", None) pre_train["state_dict"].pop("criterion.xyzm", None) self.model.load_state_dict(pre_train["state_dict"], strict=False) else: - state_dict = checkpoint state_dict.pop("criterion.xym", None) state_dict.pop("criterion.xyzm", None) @@ -84,58 +120,174 @@ def setup_model(self): if not self.cpu: self.model.cuda() - self.model.eval() + # ------------------------------------------------------------------ + def _detect_and_wire_pad_state(self): + """ + Escanea los configs de preprocess y postprocess buscando nuestras + transforms personalizadas de padding variable. + + Si se detecta alguna: + - Crea un PadStateBuffer en self.pad_state. 
+ - Construye self.pre_process con el estado inyectado donde corresponde. + - Construye self.post_process_ops (lista de instancias) con el estado + inyectado en RemovePadFromPrediction. + + Si NO se detecta nada: + - No modifica nada; setup_data_processing() usará el camino estándar. + """ + pre_cfgs = self.data_cfg.preprocess or [] + post_cfgs = self.data_cfg.postprocess or [] + + # ── 1. Detectar si hay algún trigger en preprocess ───────────── + pad_pre_idx = None + for i, cfg in enumerate(pre_cfgs): + if _get_trigger_key(cfg) in _PAD_PREPROCESS_TRIGGERS: + pad_pre_idx = i + break + + if pad_pre_idx is None: + return # nada que hacer; comportamiento original + + # ── 2. Crear estado compartido ──────────────────────────────── + try: + from mmv_im2im.utils.custom_transforms import ( + PadStateBuffer, + RecordShapeAndPad, + ) + except ImportError as e: + raise ImportError( + "Could not import PadStateBuffer/RecordShapeAndPad from " + "mmv_im2im.utils.custom_transforms." + ) from e + + self.pad_state = PadStateBuffer() + print( + f"[ProjectTester] PadStateBuffer created — " + f"detected '{pre_cfgs[pad_pre_idx]['func_name']}' " + f"in preprocess[{pad_pre_idx}]." + ) + + # ── 3. Construir pipeline de preprocess con estado inyectado ── + pre_ops = [] + for i, cfg in enumerate(pre_cfgs): + key = _get_trigger_key(cfg) + + if key in _PAD_PREPROCESS_TRIGGERS: + # Extraer k, mode y constant_value del params original si existen + params = cfg.get("params", {}) + k = params.get("k", 16) + mode = params.get("mode", "constant") + constant_value = params.get("constant_value", 0.0) + + op = RecordShapeAndPad( + state=self.pad_state, + k=k, + mode=mode, + constant_value=constant_value, + ) + print( + f"[ProjectTester] preprocess[{i}] '{cfg['func_name']}' → " + f"RecordShapeAndPad(k={k}, mode='{mode}') with shared state." 
+ ) + else: + # Transform estándar MONAI u otro: instanciar normalmente + op = parse_config(cfg) + + pre_ops.append(op) + + self.pre_process = Compose(pre_ops) + + # ── 4. Construir pipeline de postprocess con estado inyectado ── + if not post_cfgs: + self.post_process_ops = [] + return + + try: + from mmv_im2im.utils.custom_transforms import RemovePadFromPrediction + except ImportError as e: + raise ImportError( + "Could not import RemovePadFromPrediction from " + "mmv_im2im.utils.custom_transforms." + ) from e + + post_ops = [] + for i, cfg in enumerate(post_cfgs): + key = _get_trigger_key(cfg) + + if key in _PAD_POSTPROCESS_TRIGGERS: + params = cfg.get("params", {}) + k = params.get("k", 16) + n_coord_dims = params.get("n_coord_dims", 3) + + op = RemovePadFromPrediction( + state=self.pad_state, + k=k, + n_coord_dims=n_coord_dims, + ) + print( + f"[ProjectTester] postprocess[{i}] 'RemovePadFromPrediction' " + f"with shared state (k={k}, n_coord_dims={n_coord_dims})." + ) + else: + op = parse_config(cfg) + + post_ops.append(op) + + self.post_process_ops = post_ops + + # ------------------------------------------------------------------ def setup_data_processing(self): - # determine spatial dimension from reader parameters + # Determinar dimensión espacial if "Z" in self.data_cfg.inference_input.reader_params["dimension_order_out"]: self.spatial_dims = 3 else: self.spatial_dims = 2 - # prepare data preprocessing if needed - if self.data_cfg.preprocess is not None: - # load preprocessing transformation + # Wire the shared PadStateBuffer if variable-size padding transforms are + # present in the YAML. If no trigger is detected this call does nothing. + self._detect_and_wire_pad_state() + + # Standard preprocess path — only if _detect_and_wire_pad_state did + # not already build self.pre_process. + if self.pre_process is None and self.data_cfg.preprocess is not None: self.pre_process = parse_monai_ops_vanilla(self.data_cfg.preprocess) + # Standard postprocess path. 
+ # post_process_ops is None → no trigger detected; transforms are + # instantiated per image (legacy behaviour). + # post_process_ops is [] → trigger detected but postprocess is empty. + + # ------------------------------------------------------------------ def process_one_image( self, img: Union[DaskArray, NumpyArray], out_fn: Union[str, Path] = None ): - if isinstance(img, DaskArray): - # Perform the prediction x = img.compute() - elif isinstance(img, NumpyArray): x = img else: raise ValueError("invalid image") - # check if need to add channel dimension if len(x.shape) == self.spatial_dims: x = np.expand_dims(x, axis=0) - # convert the numpy array to float tensor x = torch.tensor(x.astype(np.float32)) - # run pre-processing on tensor if needed - + # ── Preprocess ──────────────────────────────────────────────── + # Si pre_process es un Compose con RecordShapeAndPad, el estado + # se actualiza aquí para esta imagen concreta. if self.pre_process is not None: x = self.pre_process(x) - # choose different inference function for different types of models - # the input here is assumed to be a tensor + # ── Inferencia ──────────────────────────────────────────────── with torch.no_grad(): - # add batch dimension and move to GPU - if self.cpu: x = torch.unsqueeze(x, dim=0) else: x = torch.unsqueeze(x, dim=0).cuda() - # TODO: add convert to tensor with proper type, similar to torchio check - if ( self.model_cfg.model_extra is not None and "sliding_window_params" in self.model_cfg.model_extra @@ -146,10 +298,6 @@ def process_one_image( device=torch.device("cpu"), **self.model_cfg.model_extra["sliding_window_params"], ) - - # currently, we keep sliding window stiching step on CPU, but assume - # the output is on GPU (see note below). 
So, we manually move the data - # back to GPU if not self.cpu: y_hat = y_hat.cuda() else: @@ -160,50 +308,72 @@ def process_one_image( y_hat = y_hat["pred"] except Exception: raise ValueError( - f"y_hat is a dictionary but the key 'pred' it's not found the y_hat output is: {y_hat}" + f"y_hat is a dict but key 'pred' was not found. " + f"y_hat keys: {list(y_hat.keys())}" + ) + + # Global Average Pool — regression only, model-agnostic + # ----------------------------------------------------------------- + # During training, GAP is applied inside each pl_*.run_step(), + # collapsing [B, C, *spatial] → [B, C] before the loss. + # During inference those training steps are bypassed: the raw + # network output is [B, C, *spatial] regardless of which + # architecture is used (AttentionUnet, DynUNet, ProbUnet all + # share this contract). GAP is therefore applied here once, + # before any postprocess transform runs. + if self.is_regression and y_hat.dim() > 2: + import torch.nn.functional as F + + spatial_dims = y_hat.dim() - 2 + if spatial_dims == 2: + y_hat = F.adaptive_avg_pool2d(y_hat, 1).squeeze(-1).squeeze(-1) + elif spatial_dims == 3: + y_hat = ( + F.adaptive_avg_pool3d(y_hat, 1) + .squeeze(-1) + .squeeze(-1) + .squeeze(-1) ) + # Remove the batch dimension → 1-D vector (C,) + if y_hat.shape[0] == 1: + y_hat = y_hat.squeeze(0) + + # Postprocess + if self.post_process_ops is not None: + # Stateful path: pre-built pipeline with shared PadStateBuffer. + # RemovePadFromPrediction already holds a reference to pad_state, + # which was updated by RecordShapeAndPad during preprocess of this + # image — the correct pad_before offsets are applied automatically. 
+ pp_data = y_hat + for pp in self.post_process_ops: + pp_data = pp(pp_data) + pred = pp_data.cpu().numpy() if torch.is_tensor(pp_data) else pp_data - ############################################################################### - # - # Note: currently, we assume y_hat is still on gpu, because embedseg clustering - # step is still only running on GPU (possible on CPU, need to some update on - # grid loading). All the post-procesisng functions we tested so far can accept - # tensor on GPU. If it is from mmv_im2im.post_processing, it will automatically - # convert the tensor to a numpy array and return the result as numpy array; if - # it is from monai.transforms, it is tensor in and tensor out. We have two items - # as #TODO: (1) we will extend post-processing functions in mmv_im2im to work - # similarly to monai transforms, ie. ndarray in ndarray out or tensor in tensor - # out. (2) allow yaml config to control if we want to run post-processing on - # GPU tensors or ndarrays - # - ############################################################################## - - # do post-processing on the prediction - if self.data_cfg.postprocess is not None: + elif self.data_cfg.postprocess is not None: + # Standard path: instantiate each postprocess transform per image. pp_data = y_hat for pp_info in self.data_cfg.postprocess: pp = parse_config(pp_info) pp_data = pp(pp_data) - if torch.is_tensor(pp_data): - pred = pp_data.cpu().numpy() - else: - pred = pp_data + pred = pp_data.cpu().numpy() if torch.is_tensor(pp_data) else pp_data + else: pred = y_hat.cpu().numpy() if out_fn is None: return pred - # determine output dimension orders - if out_fn.suffix == ".npy": + # Save result. + # Regression predictions are always saved as .npy vectors regardless + # of the suffix specified in the output config. 
+ if self.is_regression or out_fn.suffix == ".npy": + out_fn = out_fn.with_suffix(".npy") np.save(out_fn, pred) else: if len(pred.shape) == 2: OmeTiffWriter.save(pred, out_fn, dim_order="YX") elif len(pred.shape) == 3: - # 3D output, for 2D data if self.spatial_dims == 2: - # save as RGB or multi-channel 2D if pred.shape[0] == 3: if out_fn.suffix != ".png": out_fn = out_fn.with_suffix(".png") @@ -224,62 +394,42 @@ def process_one_image( elif pred.shape[1] == 3: if out_fn.suffix != ".png": out_fn = out_fn.with_suffix(".png") - save_rgb( - out_fn, - np.moveaxis( - pred[0,], - 0, - -1, - ), - ) + save_rgb(out_fn, np.moveaxis(pred[0,], 0, -1)) else: - OmeTiffWriter.save( - pred[0,], - out_fn, - dim_order="CYX", - ) + OmeTiffWriter.save(pred[0,], out_fn, dim_order="CYX") else: raise ValueError("invalid 4D output for 2d data") elif len(pred.shape) == 5: assert pred.shape[0] == 1, "error, found non-trivial batch dimension" - OmeTiffWriter.save( - pred[0,], - out_fn, - dim_order="CZYX", - ) + OmeTiffWriter.save(pred[0,], out_fn, dim_order="CZYX") else: raise ValueError("error in prediction output shape") + # ------------------------------------------------------------------ def run_inference(self): - self.setup_model() self.setup_data_processing() - # set up data filenames + dataset_list = generate_test_dataset_dict( self.data_cfg.inference_input.dir, self.data_cfg.inference_input.data_type ) - # loop through all images and apply the model for ds in tqdm(dataset_list, desc="Predicting images"): - - # output file name info fn_core = Path(ds).stem suffix = self.data_cfg.inference_output.suffix - timelapse_data = 0 - # if timelapse ... 
+ if ( "T" in self.data_cfg.inference_input.reader_params["dimension_order_out"] ): if "T" in self.data_cfg.inference_input.reader_params: raise NotImplementedError( - "processing a subset of all timepoint is not supported yet" + "processing a subset of all timepoints is not supported yet" ) tmppath = tempfile.mkdtemp() print(f"making a temp folder at {tmppath}") - # get the number of time points try: reader = BioImage(ds, reader=bioio_tifffile.Reader) except Exception: @@ -290,8 +440,8 @@ def run_inference(self): print(f"Image {ds} failed at read process check the format.") timelapse_data = reader.dims.T - tmpfile_list = [] + for t_idx in tqdm( range(timelapse_data), desc="Predicting image timepoint" ): @@ -310,17 +460,13 @@ def run_inference(self): f"Image {ds} failed at read process check the format." ) - # prepare output filename out_fn = Path(tmppath) / f"{fn_core}_{t_idx}.npy" - self.process_one_image(img, out_fn) tmpfile_list.append(out_fn) - # gather all individual outputs and save as timelapse - out_array = [np.load(tmp_one_file) for tmp_one_file in tmpfile_list] + out_array = [np.load(f) for f in tmpfile_list] out_array = np.stack(out_array, axis=0) - # prepare output filename if "." 
in suffix: if ".tif" in suffix or ".tiff" in suffix or ".ome.tif" in suffix: out_fn = ( @@ -329,7 +475,8 @@ def run_inference(self): ) else: raise ValueError( - "please check output suffix, either unexpected dot or unsupported fileformat" # noqa E501 + "please check output suffix, either unexpected dot or " + "unsupported fileformat" ) else: out_fn = ( @@ -340,20 +487,15 @@ def run_inference(self): if len(out_array.shape) == 3: dim_order = "TYX" elif len(out_array.shape) == 4: - if self.spatial_dims == 3: - dim_order = "TZYX" - else: - dim_order = "TCYX" + dim_order = "TZYX" if self.spatial_dims == 3 else "TCYX" elif len(out_array.shape) == 5: dim_order = "TCZYX" else: raise ValueError(f"Unexpected pred of shape {out_array.shape}") - # save the file output OmeTiffWriter.save(out_array, out_fn, dim_order=dim_order) - - # clean up temporary dir shutil.rmtree(tmppath) + else: try: img = BioImage(ds, reader=bioio_tifffile.Reader).get_image_data( @@ -368,7 +510,6 @@ def run_inference(self): print(f"Error: {e}") print(f"Image {ds} failed at read process check the format.") - # prepare output filename if "." 
in suffix: if ( ".png" in suffix @@ -382,7 +523,8 @@ def run_inference(self): ) else: raise ValueError( - "please check output suffix, either unexpected dot or unsupported fileformat" # noqa E501 + "please check output suffix, either unexpected dot or " + "unsupported fileformat" ) else: out_fn = ( diff --git a/mmv_im2im/proj_trainer.py b/mmv_im2im/proj_trainer.py index abd9c20..4a663cc 100644 --- a/mmv_im2im/proj_trainer.py +++ b/mmv_im2im/proj_trainer.py @@ -1,5 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import warnings + +warnings.simplefilter(action="ignore", category=FutureWarning) import logging from pathlib import Path from importlib import import_module @@ -10,9 +13,6 @@ from mmv_im2im.utils.nnHeuristic import get_nnunet_plans import pyrallis -import warnings - -warnings.simplefilter(action="ignore", category=FutureWarning) torch.set_float32_matmul_precision("medium") ############################################################################### @@ -42,9 +42,8 @@ def __init__(self, cfg): def run_training(self): self.data = get_data_module(self.data_cfg) + dynunet_info = None if self.model_cfg.net["func_name"] == "DynUNet": - # 1. 
Gather inputs for heuristic - # You might need to add these fields to your YAML or extract them from data extra_params = ( self.data_cfg.extra if self.data_cfg.extra is not None else {} ) @@ -63,6 +62,13 @@ def run_training(self): } ) + dynunet_info = ( + f"nnU-Net configured for {len(patch_size)}D.\n" + f"Filters: {plans['filters']}\n" + f"Strides: {plans['strides']}\n" + f"Kernel size: {plans['kernel_size']}\n" + f"Upsample Kernel size: {plans['upsample_kernel_size']}\n" + ) print(f"✅ nnU-Net configured for {len(patch_size)}D.") print(f"Filters: {plans['filters']}") print(f"Strides: {plans['strides']}") @@ -127,6 +133,13 @@ def run_training(self): pyrallis.dump( self.data_cfg, open(save_path / Path("data_config.yaml"), "w") ) + if dynunet_info is not None: + nnunet_cfg_path = save_path / "nnUnet_parameter_generation.txt" + with open(nnunet_cfg_path, "w") as f: + f.write(dynunet_info) + print( + f"Inferred model configuration saved in -> {nnunet_cfg_path} for inference configuration" + ) print("start training ... ") trainer.fit(model=self.model, datamodule=self.data) diff --git a/mmv_im2im/proj_trainer_multishape.py b/mmv_im2im/proj_trainer_multishape.py new file mode 100644 index 0000000..d0ab91e --- /dev/null +++ b/mmv_im2im/proj_trainer_multishape.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +--------------- +Entry point for training the spherical-harmonic regression model with +VARIABLE-SIZE inputs. + +Key changes vs. the original proj_trainer.py +--------------------------------------------- +1. The data module is wrapped in VariableSizeDataModule so that every + batch is dynamically padded to a common size that is divisible by 16, + and the x0 coordinates in the GT vector are corrected for the padding + offset. + +2. 
The per-sample transform DivisiblePadWithGTAdjustd is added to the + preprocess pipeline in the YAML (see Attention_Unet_Reg_variable.yaml), + so that SINGLE-SAMPLE operations (batch_size=1, augmentation caching, + etc.) also work correctly. + +3. When batch_size=1 is used (the safest choice) the custom collate is + technically not needed, but VariableSizeDataModule is still applied + for consistency and forward compatibility. + +""" + +import warnings + +warnings.simplefilter(action="ignore", category=FutureWarning) +import logging +from importlib import import_module +from pathlib import Path + +import lightning as pl +import pyrallis +import torch + +from mmv_im2im.data_modules import get_data_module +from mmv_im2im.utils.misc import parse_ops_list +from mmv_im2im.utils.nnHeuristic import get_nnunet_plans + +# Variable-size additions +from mmv_im2im.utils.variable_datamodule import VariableSizeDataModule + +torch.set_float32_matmul_precision("medium") + +log = logging.getLogger(__name__) +logging.getLogger("bioio").setLevel(logging.ERROR) + + +# --------------------------------------------------------------------------- +# Trainer +# --------------------------------------------------------------------------- + + +class VariableSizeProjectTrainer: + """ + Trainer with support for variable spatial-size inputs. + + The only behavioural difference from the original ProjectTrainer is + that the DataModule is wrapped in VariableSizeDataModule which injects + a custom collate function that: + - Pads all images in each batch to the same spatial size (max dims + in the batch, rounded up to the nearest multiple of 16). + - Adjusts the first 3 elements of the GT vector (x0: z, y, x centre + coordinates) to reflect the padding offset. + - Leaves all spherical-harmonic coefficients untouched. 
+ + Parameters + ---------- + cfg : ProgramConfig + """ + + def __init__(self, cfg): + pl.seed_everything(123, workers=True) + self.model_cfg = cfg.model + self.train_cfg = cfg.trainer + self.data_cfg = cfg.data + self.model = None + self.data = None + + # ------------------------------------------------------------------ + @staticmethod + def _k_from_strides(strides) -> int: + """ + Computes the minimum spatial divisibility k required by a UNet + whose encoder strides are ``strides`` (list of per-layer stride + lists, e.g. [[1,1],[2,2],[2,2],[2,2]]). + + k = product of all strides along each spatial axis, then take + the maximum across axes (handles anisotropic downsampling). + + Examples + -------- + AttentionUnet strides=[1,2,2,2,2] -> k = 2^4 = 16 + DynUNet 6 DS strides=[[1,1],[2,2]x6] -> k = 2^6 = 64 + """ + # Normalise: accept both flat [1,2,2,2] and nested [[1,1],[2,2],...] + if not isinstance(strides[0], (list, tuple)): + strides = [[s] for s in strides] + + n_dims = len(strides[0]) + k_per_dim = [1] * n_dims + for stride_layer in strides: + for i, s in enumerate(stride_layer): + k_per_dim[i] *= s + return int(max(k_per_dim)) + + # ------------------------------------------------------------------ + def run_training(self): + + # ── 1. 
Build the base data module ───────────────────────────── + base_data = get_data_module(self.data_cfg) + + dynunet_info = None + collate_k = 16 + if self.model_cfg.net["func_name"] == "DynUNet": + extra_params = ( + self.data_cfg.extra if self.data_cfg.extra is not None else {} + ) + patch_size = extra_params.get("patch_size", [256, 256]) + spacing = extra_params.get("spacing", [1.0, 1.0]) + modality = extra_params.get("modality", "non-CT") + min_size = extra_params.get("min_size", 8) + plans = get_nnunet_plans(patch_size, spacing, modality, min_size=min_size) + self.model_cfg.net["params"].update( + { + "kernel_size": plans["kernel_size"], + "strides": plans["strides"], + "filters": plans["filters"], + "upsample_kernel_size": plans["upsample_kernel_size"], + } + ) + collate_k = self._k_from_strides(plans["strides"]) + + dynunet_info = ( + f"nnU-Net configured for {len(patch_size)}D.\n" + f"[VariableSizeProjectTrainer] DynUNet: {len(plans['strides'])-1} downsampling stages -> collate k={collate_k}\n" + f"Filters: {plans['filters']}\n" + f"Strides: {plans['strides']}\n" + f"Kernel size: {plans['kernel_size']}\n" + f"Upsample Kernel size: {plans['upsample_kernel_size']}\n" + ) + print(f"✅ nnU-Net configured for {len(patch_size)}D.") + print( + f"[VariableSizeProjectTrainer] DynUNet: " + f"{len(plans['strides'])-1} downsampling stages " + f"-> collate k={collate_k}" + ) + print(f"Filters: {plans['filters']}") + print(f"Strides: {plans['strides']}") + print(f"Kernel size: {plans['kernel_size']}") + print(f"Upsample Kernel size: {plans['upsample_kernel_size']}") + + # ── 2. Wrap data module with the correct k ───────────────────── + # DivisiblePadWithGTAdjustd (per-sample, in YAML) may use a smaller + # k. The collate pads further to collate_k and re-adjusts GT coords. 
+ self.data = VariableSizeDataModule( + base_data, + k=collate_k, + mode="constant", + constant_value=0.0, + n_coord_dims=3, + ) + + model_category = self.model_cfg.framework + model_module = import_module(f"mmv_im2im.models.pl_{model_category}") + my_model_func = getattr(model_module, "Model") + self.model = my_model_func(self.model_cfg, verbose=self.train_cfg.verbose) + + # ── 4. Optional weight loading (unchanged) ───────────────────── + if self.model_cfg.model_extra is not None: + if "resume" in self.model_cfg.model_extra: + self.model = self.model.load_from_checkpoint( + self.model_cfg.model_extra["resume"] + ) + elif "pre-train" in self.model_cfg.model_extra: + pre_train = torch.load( + self.model_cfg.model_extra["pre-train"], weights_only=False + ) + if "extend" in self.model_cfg.model_extra: + if ( + self.model_cfg.model_extra["extend"] is not None + and self.model_cfg.model_extra["extend"] is True + ): + pre_train["state_dict"].pop("criterion.xym", None) + model_state = self.model.state_dict() + pretrained_dict = pre_train["state_dict"] + filtered_dict = { + k: v + for k, v in pretrained_dict.items() + if k in model_state and v.shape == model_state[k].shape + } + model_state.update(filtered_dict) + self.model.load_state_dict(model_state, strict=False) + else: + pre_train["state_dict"].pop("criterion.xym", None) + self.model.load_state_dict(pre_train["state_dict"], strict=False) + + # ── 5. Build the Lightning Trainer (unchanged) ───────────────── + if self.train_cfg.callbacks is None: + trainer = pl.Trainer(**self.train_cfg.params) + else: + callback_list = parse_ops_list(self.train_cfg.callbacks) + trainer = pl.Trainer(callbacks=callback_list, **self.train_cfg.params) + + # ── 6. 
Save configs ──────────────────────────────────────────── + save_path = Path(trainer.log_dir) + if trainer.local_rank == 0: + save_path.mkdir(parents=True, exist_ok=True) + pyrallis.dump( + self.model_cfg, + open(save_path / "model_config.yaml", "w"), + ) + pyrallis.dump( + self.train_cfg, + open(save_path / "train_config.yaml", "w"), + ) + pyrallis.dump( + self.data_cfg, + open(save_path / "data_config.yaml", "w"), + ) + if dynunet_info is not None: + nnunet_cfg_path = save_path / "nnUnet_parameter_generation.txt" + with open(nnunet_cfg_path, "w") as f: + f.write(dynunet_info) + print( + f"Inferred model configuration saved in -> {nnunet_cfg_path} for inference configuration" + ) + + print("Starting training with variable-size input support...") + trainer.fit(model=self.model, datamodule=self.data) diff --git a/mmv_im2im/utils/connectivity_loss.py b/mmv_im2im/utils/connectivity_loss.py index 371ce78..c14ce19 100644 --- a/mmv_im2im/utils/connectivity_loss.py +++ b/mmv_im2im/utils/connectivity_loss.py @@ -35,7 +35,6 @@ def __init__( if self.spatial_dims not in [2, 3]: raise ValueError("spatial_dims must be 2 or 3.") - # Select metric functions self.density_loss_fn = self._get_metric_function(metric_density) self.gradient_loss_fn = self._get_metric_function(metric_gradient) @@ -58,10 +57,8 @@ def __init__( f"Kernel shape should be square or gaussian. {self.kernel_shape} given." 
) - # --- Prepare Sobel Kernels for Vectorized Gradient Alignment --- self._init_sobel_kernels() - # --- Initialize Kernel Sizes --- if self.connectivity_mode in ["single", "learneable-single"]: k = self.connectivity_kernel_size k = k if k % 2 != 0 else k + 1 @@ -75,15 +72,11 @@ def __init__( 2**i + 1 for i in range(1, self.connectivity_kernel_size + 1) ] - # --- Initialize Filters --- self.kernels_are_learnable = "learneable" in self.connectivity_mode if self.kernels_are_learnable: self.learnable_filters = nn.ParameterList() for k in self.kernel_sizes: - # Shape depends on dims: - # 2D: (num_classes, 1, k, k) - # 3D: (num_classes, 1, k, k, k) shape = (self.num_classes, 1) + (k,) * self.spatial_dims w_init = torch.empty(shape) nn.init.normal_(w_init, mean=0.0, std=0.01) @@ -93,7 +86,6 @@ def __init__( def _init_sobel_kernels(self): if self.spatial_dims == 2: - # Shape: (1, 1, 3, 3) sobel_x = torch.tensor( [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32 ).view(1, 1, 3, 3) @@ -104,18 +96,12 @@ def _init_sobel_kernels(self): self.register_buffer("sobel_y", sobel_y) elif self.spatial_dims == 3: - # 3D Sobel Kernels construction (3x3x3) - # Smooth (1D): [1, 2, 1] - # Diff (1D): [-1, 0, 1] smooth = torch.tensor([1, 2, 1], dtype=torch.float32) diff = torch.tensor([-1, 0, 1], dtype=torch.float32) - # Helper to create outer products def outer3(v1, v2, v3): return torch.einsum("i,j,k->ijk", v1, v2, v3).view(1, 1, 3, 3, 3) - # Sobel X: Diff(x) * Smooth(y) * Smooth(z) -> indices k(z), j(y), i(x) - # PyTorch Conv3d order is (Depth, Height, Width) i.e. 
(z, y, x) sobel_x = outer3(smooth, smooth, diff) sobel_y = outer3(smooth, diff, smooth) sobel_z = outer3(diff, smooth, smooth) @@ -130,15 +116,14 @@ def _init_fixed_kernels(self): center = k // 2 shape = (k,) * self.spatial_dims - if self.kernel_shape == "single": + if self.kernel_shape == "square": kernel = torch.ones((1, 1) + shape) / (k**self.spatial_dims - 1) - # Set center to 0 if self.spatial_dims == 2: kernel[0, 0, center, center] = 0.0 else: kernel[0, 0, center, center, center] = 0.0 else: - # Gaussian + # Gaussian kernel sigma = max(0.5, 0.2 * float(k)) coords = torch.arange(0, k) - center if self.spatial_dims == 2: @@ -160,7 +145,6 @@ def _init_fixed_kernels(self): kernel = kernel.view((1, 1) + shape) - # Expand to (Num_Classes, 1, K, K, [K]) repeat_shape = (self.num_classes, 1) + (1,) * self.spatial_dims kernel_expanded = kernel.repeat(repeat_shape) @@ -187,16 +171,13 @@ def _charbonnier_loss(self, pred, target, eps=1e-6): return torch.mean(torch.sqrt((pred - target) ** 2 + eps**2)) def _cosine_loss(self, pred, target): - # Flatten: (B, C, ...) -> (B, C, Total_Pixels) target_flat = target.reshape(target.shape[0], target.shape[1], -1) pred_flat = pred.reshape(pred.shape[0], pred.shape[1], -1) - # Cosine Similarity across the spatial dimension (dim=2 in flat) return 1.0 - F.cosine_similarity(pred_flat, target_flat, dim=2).mean() def _compute_gradient_loss_vectorized(self, pred, target): """ Vectorized gradient calculation using grouped convolutions. - Adapts to 2D or 3D. 
""" start_idx = 1 if self.ignore_background else 0 p_slice = pred[:, start_idx:].contiguous() @@ -206,38 +187,51 @@ def _compute_gradient_loss_vectorized(self, pred, target): if n_channels_eff == 0: return torch.tensor(0.0, device=pred.device) - # Convolution op and expansion shape if self.spatial_dims == 2: conv_op = F.conv2d expand_shape = (n_channels_eff, 1, 1, 1) + # Stack x and y Sobel kernels: output has 2*n_channels channels + sx = self.sobel_x.repeat(expand_shape) + sy = self.sobel_y.repeat(expand_shape) + # Concat along out_channels dim so one conv handles both axes + sobel_all = torch.cat([sx, sy], dim=0) + p_grads = conv_op(p_slice, sobel_all, padding=1, groups=n_channels_eff) + t_grads = conv_op(t_slice, sobel_all, padding=1, groups=n_channels_eff) + g_x_pred, g_y_pred = ( + p_grads[:, :n_channels_eff], + p_grads[:, n_channels_eff:], + ) + g_x_true, g_y_true = ( + t_grads[:, :n_channels_eff], + t_grads[:, n_channels_eff:], + ) + loss = self.gradient_loss_fn(g_x_pred, g_x_true) + self.gradient_loss_fn( + g_y_pred, g_y_true + ) else: conv_op = F.conv3d expand_shape = (n_channels_eff, 1, 1, 1, 1) - - # Expand Sobel kernels - sx = self.sobel_x.repeat(expand_shape) - sy = self.sobel_y.repeat(expand_shape) - - # Compute gradients - g_x_pred = conv_op(p_slice, sx, padding=1, groups=n_channels_eff) - g_y_pred = conv_op(p_slice, sy, padding=1, groups=n_channels_eff) - g_x_true = conv_op(t_slice, sx, padding=1, groups=n_channels_eff) - g_y_true = conv_op(t_slice, sy, padding=1, groups=n_channels_eff) - - loss = self.gradient_loss_fn(g_x_pred, g_x_true) + self.gradient_loss_fn( - g_y_pred, g_y_true - ) - - if self.spatial_dims == 3: + sx = self.sobel_x.repeat(expand_shape) + sy = self.sobel_y.repeat(expand_shape) sz = self.sobel_z.repeat(expand_shape) - g_z_pred = conv_op(p_slice, sz, padding=1, groups=n_channels_eff) - g_z_true = conv_op(t_slice, sz, padding=1, groups=n_channels_eff) - loss = loss + self.gradient_loss_fn(g_z_pred, g_z_true) + sobel_all = 
torch.cat([sx, sy, sz], dim=0) + p_grads = conv_op(p_slice, sobel_all, padding=1, groups=n_channels_eff) + t_grads = conv_op(t_slice, sobel_all, padding=1, groups=n_channels_eff) + g_x_pred = p_grads[:, :n_channels_eff] + g_y_pred = p_grads[:, n_channels_eff : 2 * n_channels_eff] + g_z_pred = p_grads[:, 2 * n_channels_eff :] + g_x_true = t_grads[:, :n_channels_eff] + g_y_true = t_grads[:, n_channels_eff : 2 * n_channels_eff] + g_z_true = t_grads[:, 2 * n_channels_eff :] + loss = ( + self.gradient_loss_fn(g_x_pred, g_x_true) + + self.gradient_loss_fn(g_y_pred, g_y_true) + + self.gradient_loss_fn(g_z_pred, g_z_true) + ) return loss def forward(self, y_pred_softmax, y_true_one_hot): - # Ensure inputs are contiguous y_pred_softmax = y_pred_softmax.contiguous() y_true_one_hot = y_true_one_hot.float().contiguous() @@ -257,9 +251,9 @@ def forward(self, y_pred_softmax, y_true_one_hot): if self.lambda_density <= 0: return total_loss - density_loss_accum = [] + density_loss_sum = torch.tensor(0.0, device=device) + num_kernels_applied = 0 - # Prepare list of kernels if self.kernels_are_learnable: kernels_to_use = self.learnable_filters else: @@ -272,29 +266,29 @@ def forward(self, y_pred_softmax, y_true_one_hot): padding = k // 2 weight = weight.contiguous() - # --- Vectorized Convolution --- pred_neighbor_avg = conv_op( y_pred_softmax, weight, padding=padding, groups=self.num_classes ) - true_neighbor_avg = conv_op( - y_true_one_hot, weight, padding=padding, groups=self.num_classes - ) - # --- Slice Relevant Channels --- + with torch.no_grad(): + true_neighbor_avg = conv_op( + y_true_one_hot, weight, padding=padding, groups=self.num_classes + ) + p_neigh_rel = pred_neighbor_avg[:, start_idx:] t_neigh_rel = true_neighbor_avg[:, start_idx:] p_pixel_rel = y_pred_softmax[:, start_idx:] t_pixel_rel = y_true_one_hot[:, start_idx:] - # --- Full-Consistency Logic --- loss_a = self.density_loss_fn(p_neigh_rel, t_pixel_rel) loss_b = self.density_loss_fn(p_pixel_rel, t_neigh_rel) 
loss_c = self.density_loss_fn(p_neigh_rel, t_neigh_rel) - density_loss_accum.append(loss_a + loss_b + loss_c) + density_loss_sum = density_loss_sum + (loss_a + loss_b + loss_c) + num_kernels_applied += 1 - if density_loss_accum: - density_term = torch.stack(density_loss_accum).mean() + if num_kernels_applied > 0: + density_term = density_loss_sum / num_kernels_applied total_loss = total_loss + (self.lambda_density * density_term) return total_loss diff --git a/mmv_im2im/utils/custom_transforms.py b/mmv_im2im/utils/custom_transforms.py new file mode 100644 index 0000000..e261eb6 --- /dev/null +++ b/mmv_im2im/utils/custom_transforms.py @@ -0,0 +1,509 @@ +""" +-------------------- +MONAI-compatible transform that solves the variable-size input problem +for spherical harmonic regression training. + +Problem context +--------------- +The GT vector has the format: [x0_z, x0_y, x0_x, c_1, c_2, ..., c_n] + - x0 (first 3 elements): center-of-mass of the cell in the CROPPED + volume's coordinate space (voxel units). + - c_i (remaining elements): spherical harmonic coefficients encoding + shape as radii from the center → they are SCALE-DEPENDENT but NOT + position-dependent. Padding does NOT alter them. + +When we pad an image by `pad_before = [pz, py, px]` pixels at the +start of each axis, the cell content shifts by exactly that amount, so: + x0_new[i] = x0_old[i] + pad_before[i] (i = z, y, x) + +The SH coefficients are unchanged. + +This transform: + 1. Pads the image (IM key) so every spatial dimension is divisible by `k` + (default k=16 for Unet Models with 4 downsampling stages). + 2. Applies the padding SYMMETRICALLY (half before, half after) so the + cell stays centred in the padded volume, which helps the network + generalise positional predictions. + 3. Updates x0 in the GT vector to reflect the new coordinates. 
+ +Usage in YAML (preprocess section, after NormalizeIntensityd): + - module_name: custom_transforms + func_name: DivisiblePadWithGTAdjustd + params: + image_key: "IM" + gt_key: "GT" + k: 16 + mode: "constant" # zero-padding is safest after normalisation +""" + +from typing import Dict, Hashable, Mapping, Union +import numpy as np +import torch +import torch.nn.functional as F +from monai.transforms import MapTransform +from monai.transforms import Transform + +# --------------------------------------------------------------------------- +# Core helper +# --------------------------------------------------------------------------- + + +def compute_symmetric_pad(spatial_shape, k: int): + """ + For each spatial dimension of size `s`, compute (pad_before, pad_after) + so that s + pad_before + pad_after is the smallest multiple of k >= s. + + Returns + ------- + pad_before : list[int] – one entry per spatial dim (Z, Y, X order) + pad_after : list[int] – idem + """ + pad_before, pad_after = [], [] + for s in spatial_shape: + remainder = s % k + total_pad = 0 if remainder == 0 else k - remainder + pb = total_pad // 2 + pa = total_pad - pb + pad_before.append(pb) + pad_after.append(pa) + return pad_before, pad_after + + +def apply_pad_to_tensor( + img: torch.Tensor, pad_before, pad_after, mode: str = "constant", value: float = 0.0 +) -> torch.Tensor: + """ + Pad a tensor of shape [C, *spatial] using torch.nn.functional.pad. 
+ + torch.nn.functional.pad expects padding in REVERSED dimension order + (last dim first) and WITHOUT the channel dimension: + (pad_last_front, pad_last_back, ..., pad_first_spatial_front, pad_first_spatial_back) + + Parameters + ---------- + img : [C, Z, Y, X] or [C, Y, X] + pad_before : list[int] in spatial order (Z→Y→X or Y→X) + pad_after : list[int] in spatial order + """ + # Build pad tuple: reversed spatial dims, NO channel padding + pad_args = [] + for pb, pa in reversed(list(zip(pad_before, pad_after))): + pad_args.extend([pb, pa]) + # Channel dim – no padding + pad_args.extend([0, 0]) + + if mode == "constant": + return F.pad(img.float(), pad_args, mode="constant", value=value) + elif mode == "reflect": + # reflect requires pad < dim size; fall back to constant if unsafe + try: + return F.pad(img.float(), pad_args, mode="reflect") + except RuntimeError: + return F.pad(img.float(), pad_args, mode="constant", value=value) + else: + return F.pad(img.float(), pad_args, mode="constant", value=value) + + +# --------------------------------------------------------------------------- +# MONAI MapTransform +# --------------------------------------------------------------------------- + + +class DivisiblePadWithGTAdjustd(MapTransform): + """ + Pads the image key to be spatially divisible by ``k`` and adjusts the + spatial coordinates stored in the first n_coord_dims elements of the GT key. + + Parameters + ---------- + keys : list[str] | None + [image_key, gt_key]. keys[0] is the image, keys[1] is the GT vector. + When used via parse_monai_ops (mmv_im2im training pipeline), this + argument is popped by the parser before the constructor is called and + never arrives here. In that case the transform falls back to the + explicit image_key / gt_key arguments, or to the defaults "IM" / "GT". + k : int + Each spatial dimension will be padded to the nearest multiple of k. + For AttentionUnet with strides [1,2,2,2,2] → k=16. 
+ mode : str + Padding mode for torch.nn.functional.pad. + "constant" (zero-padding) is recommended after intensity normalisation. + constant_value : float + Fill value when mode="constant". Default 0.0. + n_coord_dims : int + Number of leading GT elements that are spatial coordinates and must + be adjusted. Default 3 → (z, y, x). + image_key : str | None + Explicit image key override used when keys=None. Default "IM". + gt_key : str | None + Explicit GT key override used when keys=None. Default "GT". + + Three valid ways to instantiate + -------------------------------- + # 1. Via YAML / parse_monai_ops (keys is popped before __init__, defaults used) + - module_name: mmv_im2im.utils.custom_transforms + func_name: DivisiblePadWithGTAdjustd + params: + keys: ["IM", "GT"] # consumed by parser; sets image_key="IM", gt_key="GT" + k: 16 + + # 2. Direct Python call with keys list + t = DivisiblePadWithGTAdjustd(keys=["IM", "GT"], k=16) + + # 3. Direct Python call with custom key names + t = DivisiblePadWithGTAdjustd(image_key="image", gt_key="label", k=16) + """ + + def __init__( + self, + keys=None, # parse_monai_ops pops 'keys' from func_params before + # calling the constructor of custom (non-monai) transforms, + # so this argument may never actually arrive here. + # keys[0] → image key, keys[1] → GT key. + # Falls back to ["IM", "GT"] if None or not provided. 
+        k: int = 16,
+        mode: str = "constant",
+        constant_value: float = 0.0,
+        n_coord_dims: int = 3,
+        image_key: str = None,  # optional explicit override; ignored if keys provided
+        gt_key: str = None,  # optional explicit override; ignored if keys provided
+    ):
+        # Resolve image_key / gt_key from keys or explicit overrides or defaults
+        if keys is not None:
+            if len(keys) != 2:
+                raise ValueError(
+                    f"DivisiblePadWithGTAdjustd expects exactly 2 keys "
+                    f"[image_key, gt_key], got {keys}"
+                )
+            _image_key, _gt_key = keys
+        else:  # fall back to explicit overrides, then the "IM"/"GT" defaults
+            _image_key, _gt_key = image_key or "IM", gt_key or "GT"
+        super().__init__(keys=[_image_key, _gt_key])
+        self.image_key = _image_key
+        self.gt_key = _gt_key
+        self.k = k
+        self.mode = mode
+        self.constant_value = constant_value
+        self.n_coord_dims = n_coord_dims
+
+    # ------------------------------------------------------------------
+    def __call__(self, data: Mapping[Hashable, object]) -> Dict[Hashable, object]:
+
+        d = dict(data)
+
+        img = d[self.image_key]
+        gt = d[self.gt_key]
+
+        # ── Normalise to torch.Tensor ──────────────────────────────────
+        if isinstance(img, np.ndarray):
+            img = torch.from_numpy(img.copy())
+        else:
+            img = img.as_tensor() if hasattr(img, "as_tensor") else img.clone()
+
+        if isinstance(gt, np.ndarray):
+            gt = torch.from_numpy(gt.copy()).float()
+        else:
+            gt = (
+                gt.as_tensor().float()
+                if hasattr(gt, "as_tensor")
+                else gt.clone().float()
+            )
+
+        # ── Spatial shape: img is [C, *spatial] ───────────────────────
+        spatial_shape = img.shape[1:]  # (Z, Y, X) for 3D; (Y, X) for 2D
+
+        pad_before, pad_after = compute_symmetric_pad(spatial_shape, self.k)
+
+        # ── Apply padding to image ─────────────────────────────────────
+        img_padded = apply_pad_to_tensor(
+            img, pad_before, pad_after, mode=self.mode, value=self.constant_value
+        )
+
+        # ── Adjust spatial coordinates in GT ──────────────────────────
+        # GT layout: [x0_z, x0_y, x0_x, coeff_0, coeff_1, ...]
+        # Padding shifts the cell by pad_before[i] voxels on each axis.
+ gt_adjusted = gt.clone() + n_adjust = min(self.n_coord_dims, len(pad_before)) + for i in range(n_adjust): + gt_adjusted[i] = gt[i] + pad_before[i] + + d[self.image_key] = img_padded + d[self.gt_key] = gt_adjusted + + return d + + +""" +--------------------- +INFERENCE transforms that invert the effect of padding applied +by DivisiblePadWithGTAdjustd during training. + +Issue with variable sizes during inference +-------------------------------------------- +During training, each sample passes through DivisiblePadWithGTAdjustd which: + - Pads the image to a multiple of k + - Adds pad_before[i] to the x0 coordinates of the GT + +In inference, the network predicts in the PADDED space. To recover +the coordinates in the ORIGINAL space, pad_before[i] must be subtracted. + + pad_before[i] = ((ceil(s_i / k) * k) - s_i) // 2 + +The only piece of data needed to calculate this is the original shape of the +image BEFORE padding. Since images have variable sizes, that shape changes +for each image. + +Solution: PadStateBuffer +------------------------ +A lightweight object that acts as shared memory between: + - RecordShapeAndPad → writes the original shape and applies the padding + - RemovePadFromPrediction → reads the original shape and corrects the prediction + +Both transforms receive the SAME instance of PadStateBuffer. +As long as one image is processed at a time (standard sequential inference), +the state remains consistent. 
+ +Full inference flow +----------------------------- + state = PadStateBuffer() + + preproc = RecordShapeAndPad(state, k=16) + postproc = RemovePadFromPrediction(state, k=16, n_coord_dims=3) + + for img in images: + img_padded = preproc(img) # saves shape, pads + pred_padded = model(img_padded) # inference in padded space + pred_original = postproc(pred_padded) # corrects coordinates +""" + + +# --------------------------------------------------------------------------- +# Shared state buffer +# --------------------------------------------------------------------------- + + +class PadStateBuffer: + """ + Shared state object between RecordShapeAndPad and + RemovePadFromPrediction. + + Attributes + ---------- + original_spatial_shape : tuple[int, ...] | None + Spatial shape (without channel) of the image BEFORE padding. + Written in RecordShapeAndPad and read in RemovePadFromPrediction. + pad_before : list[int] | None + Padding applied at the start of each spatial axis. + Calculated and saved in RecordShapeAndPad. + """ + + def __init__(self): + self.original_spatial_shape: Union[tuple, None] = None + self.pad_before: Union[list, None] = None + + def reset(self): + self.original_spatial_shape = None + self.pad_before = None + + def __repr__(self): + return ( + f"PadStateBuffer(" + f"original_spatial_shape={self.original_spatial_shape}, " + f"pad_before={self.pad_before})" + ) + + +# --------------------------------------------------------------------------- +# Internal helper +# --------------------------------------------------------------------------- + + +def _apply_divisible_pad( + img: torch.Tensor, + k: int, + mode: str = "constant", + value: float = 0.0, +) -> tuple: + """ + Pads img=[C, *spatial] to the next multiple of k in each spatial axis. 
+ + Returns + ------- + img_padded : torch.Tensor [C, *spatial_padded] + pad_before : list[int] padding applied at the start of each axis + """ + spatial = img.shape[1:] + pad_before, pad_after = [], [] + for s in spatial: + remainder = s % k + total = 0 if remainder == 0 else k - remainder + pb = total // 2 + pad_before.append(pb) + pad_after.append(total - pb) + + # torch.nn.functional.pad: inverse axis order, no channel dim + pad_args = [] + for pb, pa in reversed(list(zip(pad_before, pad_after))): + pad_args.extend([pb, pa]) + pad_args.extend([0, 0]) + + try: + img_padded = F.pad(img.float(), pad_args, mode=mode, value=value) + except (RuntimeError, NotImplementedError): + img_padded = F.pad(img.float(), pad_args, mode="constant", value=value) + + return img_padded, pad_before + + +# --------------------------------------------------------------------------- +# Transform 1/2 — RecordShapeAndPad (preprocessing) +# --------------------------------------------------------------------------- + + +class RecordShapeAndPad(Transform): + """ + PREPROCESSING transform that: + 1. Registers the original spatial shape in the PadStateBuffer. + 2. Applies divisible padding to the image. + + Replaces MONAI's DivisiblePad in the inference pipeline + when padding needs to be undone in the prediction later. + + Parameters + ---------- + state : PadStateBuffer + Shared object with RemovePadFromPrediction. + k : int + All spatial dimensions are padded to the next multiple of k. + mode : str + Padding mode. "constant" recommended after NormalizeIntensity. + constant_value : float + Fill value when mode="constant". 
+ + Input / Output + -------------- + Input : tensor or numpy [C, *spatial] + Output : tensor [C, *spatial_padded] + """ + + def __init__( + self, + state: PadStateBuffer, + k: int = 16, + mode: str = "constant", + constant_value: float = 0.0, + ): + self.state = state + self.k = k + self.mode = mode + self.constant_value = constant_value + + def __call__( + self, + img: Union[np.ndarray, torch.Tensor], + ) -> torch.Tensor: + + if isinstance(img, np.ndarray): + img = torch.from_numpy(img.copy()).float() + else: + img = img.float() + + # 1. Save shape BEFORE padding + self.state.original_spatial_shape = tuple(img.shape[1:]) + + # 2. Pad and save pad_before in the shared state + img_padded, pad_before = _apply_divisible_pad( + img, self.k, self.mode, self.constant_value + ) + self.state.pad_before = pad_before + + return img_padded + + +# --------------------------------------------------------------------------- +# Transform 2/2 — RemovePadFromPrediction (postprocessing) +# --------------------------------------------------------------------------- + + +class RemovePadFromPrediction(Transform): + """ + POSTPROCESSING transform that corrects the x0 coordinates of the + prediction by subtracting the offset introduced by padding. + + Reads pad_before from the PadStateBuffer shared with RecordShapeAndPad, + which was updated when processing the corresponding image. + + Parameters + ---------- + state : PadStateBuffer + Same object passed to RecordShapeAndPad. + k : int + Same k used in RecordShapeAndPad. Used as fallback to + recalculate pad_before if it is not in the state for some reason. + n_coord_dims : int + Number of initial elements of the prediction vector that + represent spatial coordinates and must be corrected. + Default 3 → (z, y, x) of the center of mass. + + Note + ---- + SH coefficients (indices from n_coord_dims onwards) are + translation-invariant and are NOT modified. 
+ + Input / Output + -------------- + Input : tensor or numpy (N,) or (B, N) — prediction in padded space + Output : same type with the first n_coord_dims elements corrected + """ + + def __init__( + self, + state: PadStateBuffer, + k: int = 16, + n_coord_dims: int = 3, + ): + self.state = state + self.k = k + self.n_coord_dims = n_coord_dims + + def __call__( + self, + pred_vector: Union[np.ndarray, torch.Tensor], + ) -> Union[np.ndarray, torch.Tensor]: + + # -- Recover pad_before -------------------------------------- + if self.state.pad_before is not None: + pad_before = self.state.pad_before + + elif self.state.original_spatial_shape is not None: + # Fallback: recalculate if pad_before was not saved + pad_before = [] + for s in self.state.original_spatial_shape: + remainder = s % self.k + total = 0 if remainder == 0 else self.k - remainder + pad_before.append(total // 2) + else: + raise RuntimeError( + "RemovePadFromPrediction: PadStateBuffer is empty. " + "Ensure RecordShapeAndPad processed the image " + "BEFORE calling this transform." 
+ ) + + # -- Type conversion ------------------------------------------ + return_numpy = isinstance(pred_vector, np.ndarray) + if return_numpy: + out = torch.from_numpy(pred_vector.copy()).float() + else: + out = pred_vector.clone().float() + + # -- Coordinate correction ----------------------------------- + batched = out.dim() == 2 # True if shape [B, N] + n_adjust = min(self.n_coord_dims, len(pad_before)) + + for i in range(n_adjust): + if batched: + out[:, i] = out[:, i] - pad_before[i] + else: + out[i] = out[i] - pad_before[i] + + return out.numpy() if return_numpy else out diff --git a/mmv_im2im/utils/elbo_loss.py b/mmv_im2im/utils/elbo_loss.py index fc08241..63e3f10 100644 --- a/mmv_im2im/utils/elbo_loss.py +++ b/mmv_im2im/utils/elbo_loss.py @@ -1,12 +1,27 @@ import torch import torch.nn as nn import torch.nn.functional as F +from typing import Union class KLDivergence(nn.Module): """ Calculates the KL Divergence between two diagonal Gaussian distributions. Used for the Probabilistic U-Net latent space regularization. + + CHANGE: Numerically stable KL computation. + The original formula computed exp(logvar_q) and exp(logvar_p) independently, + which can overflow (→ inf) when logvar values are large and the clamp is + disabled. The reformulated version uses exp(logvar_q - logvar_p) and + exp(-logvar_p), keeping the exponent as a difference which is bounded even + when individual logvars are large: + + Original: (exp(lq) + (mu_q - mu_p)^2) / exp(lp) + Stable: exp(lq - lp) + (mu_q - mu_p)^2 * exp(-lp) + + These are mathematically identical but the stable form avoids intermediate + overflow. The improvement is most relevant in early training when the + posterior and prior are far apart. 
""" def __init__(self): @@ -17,10 +32,14 @@ def forward(self, mu_q, logvar_q, mu_p, logvar_p, kl_clamp=None): logvar_q = torch.clamp(logvar_q, min=-kl_clamp, max=kl_clamp) logvar_p = torch.clamp(logvar_p, min=-kl_clamp, max=kl_clamp) + # CHANGE: Stable formulation using log-domain subtraction. + # exp(logvar_q) / exp(logvar_p) = exp(logvar_q - logvar_p) + # (mu_q - mu_p)^2 / exp(logvar_p) = (mu_q - mu_p)^2 * exp(-logvar_p) kl_batch_sum = 0.5 * torch.sum( logvar_p - logvar_q - + (torch.exp(logvar_q) + (mu_q - mu_p) ** 2) / torch.exp(logvar_p) + + torch.exp(logvar_q - logvar_p) + + (mu_q - mu_p) ** 2 * torch.exp(-logvar_p) - 1, dim=1, ) @@ -48,7 +67,7 @@ def __init__( use_topological_regularization: bool = False, topological_weight: float = 0.1, topological_warmup_epochs: int = 10, - topological_connectivity: int = 4, # Use 6 or 26 for 3D + topological_connectivity: int = 4, topological_inclusion: list = None, topological_exclusion: list = None, topological_min_thick: int = 1, @@ -69,13 +88,15 @@ def __init__( gdl_focal_weight: float = 1.0, gdl_warmup_epochs: int = 10, gdl_class_weights: list = None, - # --- Hausdorff Regularization (MONAI) --- + # --- Hausdorff Regularization --- use_hausdorff_regularization: bool = False, hausdorff_weight: float = 0.1, hausdorff_downsample_scale: float = 0.5, - hausdorff_dt_iterations: int = 30, + hausdorff_dt_iterations: Union[int, str] = "auto", hausdorff_warmup_epochs: int = 10, hausdorff_include_background: bool = False, + hausdorff_distance_mode: str = "l2", + hausdorff_normalize_weights: bool = True, # --- Homology Regularization (Persistence Image) --- use_homology_regularization: bool = False, homology_weight: float = 0.1, @@ -93,6 +114,7 @@ def __init__( chunks: int = 2000, weighting_power: float = 2.0, composite_flag: bool = True, + homology_adaptive_sigma: bool = True, # --- Topological Complexity Regularization --- use_topological_complexity: bool = False, topological_complexity_weight: float = 0.1, @@ -106,6 +128,9 @@ 
def __init__( complexity_k_top: int = 2000, complexity_temperature: float = 0.01, complexity_auto_balance: bool = True, + complexity_normalize_lifetimes: bool = True, + # --- Warmup schedule --- + warmup_schedule: str = "linear", # "linear" or "cosine" ): super().__init__() self.beta = beta @@ -114,9 +139,9 @@ def __init__( self.kl_clamp = kl_clamp self.task = task self.regression_loss_type = regression_loss_type.lower() + self.warmup_schedule = warmup_schedule self.kl_divergence_calculator = KLDivergence() - # Class weights for the main reconstruction loss (Cross Entropy) if elbo_class_weights is not None: self.elbo_class_weights = torch.tensor( elbo_class_weights, dtype=torch.float32 @@ -125,7 +150,6 @@ def __init__( self.elbo_class_weights = None if self.task == "segmentation": - # --- Initialize Regularizers --- reg_used = [] # 1. Fractal @@ -180,7 +204,7 @@ def __init__( metric_gradient=connectivity_metric_gradient, ) - # 4. GDL Focal (MONAI) + # 4. GDL Focal self.use_gdl_focal_regularization = use_gdl_focal_regularization if self.use_gdl_focal_regularization: self.gdl_focal_weight = gdl_focal_weight @@ -193,7 +217,6 @@ def __init__( if gdl_class_weights else None ) - # MONAI losses are generally dimension-agnostic given correct input shape self.gdl_focal_loss_calculator = GeneralizedDiceFocalLoss( softmax=True, to_onehot_y=True, weight=monai_focal_weights ) @@ -213,6 +236,8 @@ def __init__( spatial_dims=self.spatial_dims, dt_iterations=self.hausdorff_dt_iterations, include_background=self.hausdorff_include_background, + distance_mode=hausdorff_distance_mode, + normalize_weights=hausdorff_normalize_weights, ) # 6. Homology (Persistence Image) @@ -234,10 +259,11 @@ def __init__( metric=homology_metric, chunks=chunks, filtering=homology_filtering, - treshold=homology_threshold, + threshold=homology_threshold, k_top=homology_k_top, weighting_power=weighting_power, composite_flag=composite_flag, + adaptive_sigma=homology_adaptive_sigma, ) # 7. 
Topological Complexity @@ -264,12 +290,21 @@ def __init__( k_top=complexity_k_top, temperature=complexity_temperature, auto_balance=complexity_auto_balance, + normalize_lifetimes=complexity_normalize_lifetimes, ) if len(reg_used) > 0: print(f"Active Regularizers: {reg_used}") def _get_warmup_factor(self, current_epoch, warmup_epochs): + """ + CHANGE: Added cosine warmup schedule as an alternative to linear. + Rationale: Linear warmup introduces the regularizer with a step-like + gradient increase that can cause loss spikes. Cosine warmup starts + very gently (near 0), accelerates in the middle, and smoothly + approaches 1.0. This produces more stable early-training dynamics, + especially for the heavy topological regularizers. + """ if torch.is_tensor(current_epoch): current_epoch = current_epoch.detach().cpu().item() if torch.is_tensor(warmup_epochs): @@ -280,26 +315,32 @@ def _get_warmup_factor(self, current_epoch, warmup_epochs): if current_epoch >= warmup_epochs: return 1.0 - return current_epoch / warmup_epochs + + t = current_epoch / warmup_epochs # in [0, 1) + + if self.warmup_schedule == "cosine": + # Cosine warmup: starts at 0, ends at 1 + import math + + return 0.5 * (1.0 - math.cos(math.pi * t)) + else: + # Default: linear warmup + return t def _downsample_inputs(self, logits, y_true, scale_factor): """Downsamples inputs for computationally expensive topological losses.""" if scale_factor >= 1.0: return logits, y_true - # Determine mode based on dimensionality if self.spatial_dims == 3: mode = "trilinear" else: mode = "bilinear" - # Downsample Logits logits_small = F.interpolate( logits, scale_factor=scale_factor, mode=mode, align_corners=False ) - # Downsample GT - # y_true can be (B, H, W), (B, D, H, W) or with channel dim if y_true.ndim == logits.ndim - 1: y_true_float = y_true.unsqueeze(1).float() else: @@ -309,7 +350,6 @@ def _downsample_inputs(self, logits, y_true, scale_factor): y_true_float, scale_factor=scale_factor, mode="nearest" ) - # Squeeze 
back if needed if y_true.ndim == logits.ndim - 1: y_true_small = y_true_small.squeeze(1) @@ -326,11 +366,10 @@ def forward( """ # --- 1. Input Standardization --- - # Ensure y_true has the right shape - if y_true.ndim == logits.ndim: # Has channel dim (B, 1, ...) + if y_true.ndim == logits.ndim: y_true_ch = y_true y_true_flat = y_true.squeeze(1) - elif y_true.ndim == logits.ndim - 1: # No channel dim (B, ...) + elif y_true.ndim == logits.ndim - 1: y_true_ch = y_true.unsqueeze(1) y_true_flat = y_true else: @@ -344,7 +383,7 @@ def forward( ): self.elbo_class_weights = self.elbo_class_weights.to(logits.device) - # --- 2. Base Reconstruction Loss (Cross Entropy) --- + # --- 2. Base Reconstruction Loss --- if self.task == "segmentation": reconstruction_loss = F.cross_entropy( logits, @@ -377,13 +416,14 @@ def forward( # --- 4. Regularizers --- if self.task == "segmentation": - # Helper: Softmax Probs - if ( + probs = None + needs_probs = ( self.use_fractal_regularization or self.use_connectivity_regularization or self.use_homology_regularization or self.use_topological_complexity - ): + ) + if needs_probs: probs = F.softmax(logits, dim=1) # A. Fractal Dimension @@ -392,12 +432,8 @@ def forward( epoch, self.fractal_warmup_epochs ) if fractal_factor > 0: - # Slice foreground depending on dimensions if self.n_classes > 1: - if self.spatial_dims == 3: - fg_probs = 1.0 - probs[:, 0:1, :, :, :] - else: - fg_probs = 1.0 - probs[:, 0:1, :, :] + fg_probs = 1.0 - probs[:, 0:1, ...] 
else: fg_probs = probs @@ -405,7 +441,6 @@ def forward( with torch.no_grad(): fd_true = self.fractal_dimension_calculator(y_fractal) fd_pred = self.fractal_dimension_calculator(fg_probs) - # fd_pred = torch.clamp(fd_pred, min=0.0, max=3.0) -> jut if nan still appear fractal_loss = F.mse_loss(fd_pred, fd_true) total_loss += (self.fractal_weight * fractal_factor) * fractal_loss @@ -426,11 +461,9 @@ def forward( epoch, self.connectivity_warmup_epochs ) if connectivity_factor > 0: - # Create One-Hot GT (B, C, ...) y_true_onehot = F.one_hot( y_true_flat.long(), num_classes=self.n_classes ) - # Permute: Last dim (C) moves to dim 1 permute_dims = (0, logits.ndim - 1) + tuple( range(1, logits.ndim - 1) ) @@ -477,32 +510,37 @@ def forward( homology_factor = self._get_warmup_factor( epoch, self.homology_warmup_epochs ) - if epoch % self.homology_interval == 0: - if homology_factor > 0: + if epoch % self.homology_interval == 0 and homology_factor > 0: + if self.homology_downsample_scale >= 1.0: + probs_h = probs + y_true_h = y_true_flat + else: logits_h, y_true_h = self._downsample_inputs( logits, y_true_flat, self.homology_downsample_scale ) probs_h = F.softmax(logits_h, dim=1) - h_loss = self.homology_calculator(probs_h, y_true_h) - total_loss += (self.homology_weight * homology_factor) * h_loss + + h_loss = self.homology_calculator(probs_h, y_true_h) + total_loss += (self.homology_weight * homology_factor) * h_loss # G. 
Topological Complexity if self.use_topological_complexity: complexity_factor = self._get_warmup_factor( epoch, self.complexity_warmup_epochs ) - if epoch % self.complexity_interval == 0: - if complexity_factor > 0: + if epoch % self.complexity_interval == 0 and complexity_factor > 0: + if self.complexity_downsample_scale >= 1.0: + probs_c = probs + y_true_c = y_true_flat + else: logits_c, y_true_c = self._downsample_inputs( logits, y_true_flat, self.complexity_downsample_scale ) probs_c = F.softmax(logits_c, dim=1) - c_loss = self.topological_complexity_calculator( - probs_c, y_true_c - ) - total_loss += ( - self.topological_complexity_weight * complexity_factor - ) * c_loss + c_loss = self.topological_complexity_calculator(probs_c, y_true_c) + total_loss += ( + self.topological_complexity_weight * complexity_factor + ) * c_loss return total_loss diff --git a/mmv_im2im/utils/for_transform.py b/mmv_im2im/utils/for_transform.py index 8794d7a..d536a6d 100644 --- a/mmv_im2im/utils/for_transform.py +++ b/mmv_im2im/utils/for_transform.py @@ -1,7 +1,7 @@ from typing import List, Dict from functools import partial from mmv_im2im.utils.misc import parse_config, parse_config_func_without_params -from monai.transforms import Compose, Lambdad, Lambda +from monai.transforms import Compose, Lambdad, Lambda, MapTransform import inspect @@ -33,16 +33,14 @@ def center_crop(img, target_shape): def parse_monai_ops(trans_func: List[Dict]): - # Here, we will use the Compose function in MONAI to merge - # all transformations. If any trnasformation not from MONAI, - # a MONAI Lambda function will be used to wrap around it. 
+ trans_list = [] - # loop throught the config + # loop through the config for func_info in trans_func: if func_info["module_name"] == "monai.transforms": if func_info["func_name"] == "LoadImaged": - # Here, we handle the LoadImaged seperatedly to allow bio-reader + # Here, we handle the LoadImaged separately to allow bio-reader from mmv_im2im.utils.misc import monai_bio_reader from monai.transforms import LoadImaged @@ -53,29 +51,33 @@ def parse_monai_ops(trans_func: List[Dict]): trans_list.append(parse_config(func_info)) else: my_func = parse_config_func_without_params(func_info) - func_params = func_info["params"] - apply_keys = func_params.pop("keys") - # check if any other params - if len(func_params) > 0: - if inspect.isclass(my_func): - callable_func = my_func(**func_params) - else: - callable_func = partial(my_func, **func_params) + # MapTransform subclasses operate on the full data dict and manage + # their own keys internally — wrap them like native MONAI transforms. + if inspect.isclass(my_func) and issubclass(my_func, MapTransform): + trans_list.append(parse_config(func_info)) else: - callable_func = my_func + # Single-key function: use Lambdad wrapper (original behaviour) + func_params = func_info["params"] + apply_keys = func_params.pop("keys") + + # check if any other params + if len(func_params) > 0: + if inspect.isclass(my_func): + callable_func = my_func(**func_params) + else: + callable_func = partial(my_func, **func_params) + else: + callable_func = my_func - trans_list.append(Lambdad(keys=apply_keys, func=callable_func)) + trans_list.append(Lambdad(keys=apply_keys, func=callable_func)) return Compose(trans_list) def parse_monai_ops_vanilla(trans_func: List[Dict]): - # Here, we will use the Compose function in MONAI to merge - # all transformations. If any trnasformation not from MONAI, - # a MONAI Lambda function will be used to wrap around it. 
trans_list = [] - # loop throught the config + # loop through the config for func_info in trans_func: if func_info["module_name"] == "monai.transforms": trans_list.append(parse_config(func_info)) diff --git a/mmv_im2im/utils/fractal_layers.py b/mmv_im2im/utils/fractal_layers.py index dd8a699..1e85b4d 100644 --- a/mmv_im2im/utils/fractal_layers.py +++ b/mmv_im2im/utils/fractal_layers.py @@ -13,7 +13,6 @@ class Slice_windows_differentiable(nn.Module): def __init__(self, num_kernels: int, mode: str = "classic", spatial_dims=2): super().__init__() self.num_kernels = num_kernels - # Kernel sizes are powers of 2 self.kernel_sizes = [2**i for i in range(1, num_kernels + 1)] self.mode = mode self.spatial_dims = spatial_dims @@ -31,32 +30,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: results = [] - # Sequential processing per scale for k in self.kernel_sizes: - # Check dimensions against kernel size if is_3d: if k > D or k > H or k > W: results.append(x.new_zeros(batch_size)) continue - # 3D Pooling kernel_vol = k * k * k window_avg = F.avg_pool3d(x, kernel_size=k, stride=k) else: if k > H or k > W: results.append(x.new_zeros(batch_size)) continue - # 2D Pooling kernel_vol = k * k window_avg = F.avg_pool2d(x, kernel_size=k, stride=k) if self.mode == "classic": - # Using sigmoid approximation for "box counting" - # window_avg is [0, 1]. We want close to 0 -> 0, >0 -> 1. 
- # Shifted Sigmoid: sigmoid(10 * (x - 0.05)) centers transition near 0.05 soft_occupied = torch.sigmoid(10.0 * (window_avg - 0.05)) - - # Sum over all spatial locations and channels - # Flatten spatial dims and sum count = soft_occupied.reshape(batch_size, -1).sum(dim=1) results.append(count) @@ -65,17 +54,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: window_sum = window_avg * kernel_vol * channels total_elems = kernel_vol * channels - # Probability of occupancy p1 = window_sum / (total_elems + 1e-8) p1 = torch.clamp(p1, eps, 1.0 - eps) p0 = 1.0 - p1 - # Binary Entropy entropy = -(p0 * torch.log(p0) + p1 * torch.log(p1)) / 0.69314718 avg_entropy = entropy.reshape(batch_size, -1).mean(dim=1) results.append(avg_entropy) - # Stack results: (B, Num_Kernels) return torch.stack(results, dim=1) @@ -90,45 +76,54 @@ def __init__( num_kernels: int, mode: str = "classic", to_binary: bool = False, - spatial_dims=2, + spatial_dims: int = 2, ): super().__init__() self.spatial_dims = spatial_dims + self.to_binary = to_binary + self.mode = mode + self.count_layer = Slice_windows_differentiable( num_kernels=num_kernels, mode=mode, spatial_dims=self.spatial_dims ) self.kernel_sizes = self.count_layer.kernel_sizes - self.mode = mode - # Pre-compute X axis (scales) - # We perform regression of Log(Count) vs Log(1/Scale) - # Log(1/k) = -Log(k) inverse_k = torch.tensor( [1.0 / k for k in self.kernel_sizes], dtype=torch.float32 ) self.register_buffer("log_inv_k", torch.log(inverse_k)) + @staticmethod + def _straight_through_binarize( + x: torch.Tensor, threshold: float = 0.5 + ) -> torch.Tensor: + """ + Straight-through estimator for binarization. + Forward: hard threshold at `threshold`. + Backward: gradient passes through as-is (identity). + This correctly separates the 'what is computed' from 'how gradients flow'. 
+ """ + binary = (x > threshold).float() + # STE: replace forward value with binary, but keep gradient from x + return x + (binary - x).detach() + def differentiable_linregress(self, x, y): """ Differentiable slope calculation of linear regression y = mx + c. x: (Num_Kernels) or (B, Num_Kernels) y: (B, Num_Kernels) """ - # Ensure x is broadcastable if x.dim() == 1: x = x.unsqueeze(0) # (1, K) - # Detach X because we only want gradients flowing through Y (counts/model output) x = x.detach() - # Centering x_mean = x.mean(dim=1, keepdim=True) y_mean = y.mean(dim=1, keepdim=True) x_centered = x - x_mean y_centered = y - y_mean - # Slope formula: sum((x-x_bar)(y-y_bar)) / sum((x-x_bar)^2) numerator = (x_centered * y_centered).sum(dim=1) denominator = (x_centered**2).sum(dim=1) @@ -139,21 +134,26 @@ def differentiable_linregress(self, x, y): return slope def forward(self, x: torch.Tensor) -> torch.Tensor: - # Get counts/entropy for each scale + + # For GT inputs (computed under no_grad), this gives true binary box-counting. + # For pred inputs, gradients still flow through via the STE. + if self.to_binary: + x = self._straight_through_binarize(x, threshold=0.5) + y_values = self.count_layer(x) # (B, Num_Kernels) eps = 1e-8 if self.mode == "classic": - # Log-Log plot for Box Counting log_y = torch.log(y_values + eps) else: - # Entropy is already a 'dimension-like' measure log_y = y_values - # Retrieve registered buffer for X axis log_x = self.log_inv_k - # Calculate Slope (Fractal Dimension) fractal_dims = self.differentiable_linregress(log_x, log_y) + # Protects against regression instabilities producing extreme values that + # would spike the MSE loss in the parent loss function. 
+ fractal_dims = torch.clamp(fractal_dims, min=0.0, max=float(self.spatial_dims)) + return fractal_dims diff --git a/mmv_im2im/utils/gdl_regularized.py b/mmv_im2im/utils/gdl_regularized.py index f9a12ff..5e4c75f 100644 --- a/mmv_im2im/utils/gdl_regularized.py +++ b/mmv_im2im/utils/gdl_regularized.py @@ -1,16 +1,21 @@ import torch import torch.nn as nn import torch.nn.functional as F +import math +from typing import Union from monai.losses import GeneralizedDiceFocalLoss class RegularizedGeneralizedDiceFocalLoss(nn.Module): + """ + Regularized Generalized Dice + Focal loss for deterministic segmentation models. + """ def __init__( self, n_classes: int = 2, spatial_dims: int = 2, - # --- Main Loss Parameters (GeneralizedDiceFocalLoss) --- + # --- Main Loss Parameters --- gdl_focal_weight: float = 1.0, gdl_class_weights: list = None, # --- Fractal Regularization --- @@ -40,13 +45,15 @@ def __init__( lambda_gradient: float = 0.2, connectivity_metric_density: str = "huber", connectivity_metric_gradient: str = "cosine", - # --- Hausdorff Regularization (MONAI) --- + # --- Hausdorff Regularization --- use_hausdorff_regularization: bool = False, hausdorff_weight: float = 0.1, hausdorff_downsample_scale: float = 0.5, - hausdorff_dt_iterations: int = 30, + hausdorff_dt_iterations: Union[int, str] = "auto", hausdorff_warmup_epochs: int = 10, hausdorff_include_background: bool = False, + hausdorff_distance_mode: str = "l2", + hausdorff_normalize_weights: bool = True, # --- Homology Regularization (Persistence Image) --- use_homology_regularization: bool = False, homology_weight: float = 0.1, @@ -64,6 +71,7 @@ def __init__( chunks: int = 2000, weighting_power: float = 2.0, composite_flag: bool = True, + homology_adaptive_sigma: bool = True, # --- Topological Complexity Regularization --- use_topological_complexity: bool = False, topological_complexity_weight: float = 0.1, @@ -77,25 +85,28 @@ def __init__( complexity_k_top: int = 2000, complexity_temperature: float = 0.01, 
complexity_auto_balance: bool = True, + complexity_normalize_lifetimes: bool = True, + # --- Warmup schedule --- + warmup_schedule: str = "linear", # "linear" or "cosine" ): super().__init__() self.n_classes = n_classes self.spatial_dims = spatial_dims + self.warmup_schedule = warmup_schedule - # Main Segmentation Loss self.gdl_focal_weight = gdl_focal_weight monai_focal_weights = ( torch.tensor(gdl_class_weights, dtype=torch.float32) if gdl_class_weights else None ) - self.main_seg_loss_calculator = GeneralizedDiceFocalLoss( softmax=True, to_onehot_y=True, weight=monai_focal_weights ) reg_used = [] + # 1. Fractal self.use_fractal_regularization = use_fractal_regularization if self.use_fractal_regularization: self.fractal_weight = fractal_weight @@ -110,7 +121,7 @@ def __init__( spatial_dims=self.spatial_dims, ) - # Topological (TI Loss) + # 2. Topological (TI Loss) self.use_topological_regularization = use_topological_regularization if self.use_topological_regularization: self.topological_weight = topological_weight @@ -126,7 +137,7 @@ def __init__( min_thick=topological_min_thick, ) - # Connectivity + # 3. Connectivity self.use_connectivity_regularization = use_connectivity_regularization if self.use_connectivity_regularization: self.connectivity_weight = connectivity_weight @@ -147,7 +158,7 @@ def __init__( metric_gradient=connectivity_metric_gradient, ) - # Hausdorff (MONAI) + # 4. Hausdorff self.use_hausdorff_regularization = use_hausdorff_regularization if self.use_hausdorff_regularization: self.hausdorff_weight = hausdorff_weight @@ -162,9 +173,11 @@ def __init__( spatial_dims=self.spatial_dims, dt_iterations=self.hausdorff_dt_iterations, include_background=self.hausdorff_include_background, + distance_mode=hausdorff_distance_mode, + normalize_weights=hausdorff_normalize_weights, ) - # Homology (Persistence Image) + # 5. 
Homology (Persistence Image) self.use_homology_regularization = use_homology_regularization if self.use_homology_regularization: self.homology_interval = max(1, homology_interval) @@ -183,13 +196,14 @@ def __init__( metric=homology_metric, chunks=chunks, filtering=homology_filtering, - treshold=homology_threshold, + threshold=homology_threshold, k_top=homology_k_top, weighting_power=weighting_power, composite_flag=composite_flag, + adaptive_sigma=homology_adaptive_sigma, ) - # Topological Complexity + # 6. Topological Complexity self.use_topological_complexity = use_topological_complexity if self.use_topological_complexity: self.complexity_interval = max(1, complexity_interval) @@ -210,12 +224,14 @@ def __init__( k_top=complexity_k_top, temperature=complexity_temperature, auto_balance=complexity_auto_balance, + normalize_lifetimes=complexity_normalize_lifetimes, ) if len(reg_used) > 0: print(f"Active Regularizers in GDL Loss: {reg_used}") def _get_warmup_factor(self, current_epoch, warmup_epochs): + if torch.is_tensor(current_epoch): current_epoch = current_epoch.detach().cpu().item() if torch.is_tensor(warmup_epochs): @@ -226,24 +242,27 @@ def _get_warmup_factor(self, current_epoch, warmup_epochs): if current_epoch >= warmup_epochs: return 1.0 - return current_epoch / warmup_epochs + + t = current_epoch / warmup_epochs + + if self.warmup_schedule == "cosine": + return 0.5 * (1.0 - math.cos(math.pi * t)) + else: + return t def _downsample_inputs(self, logits, y_true, scale_factor): if scale_factor >= 1.0: return logits, y_true - # Determine mode if self.spatial_dims == 3: mode = "trilinear" else: mode = "bilinear" - # Downsample Logits logits_small = F.interpolate( logits, scale_factor=scale_factor, mode=mode, align_corners=False ) - # Downsample GT if y_true.ndim == logits.ndim - 1: y_true_float = y_true.unsqueeze(1).float() else: @@ -253,7 +272,6 @@ def _downsample_inputs(self, logits, y_true, scale_factor): y_true_float, scale_factor=scale_factor, mode="nearest" 
) - # Squeeze back if needed if y_true.ndim == logits.ndim - 1: y_true_small = y_true_small.squeeze(1) @@ -263,8 +281,6 @@ def forward(self, logits, y_true, epoch: int = 0): """ Computes the combined segmentation loss with structural regularizers. """ - - # --- Input Standardization --- if y_true.ndim == logits.ndim: y_true_ch = y_true y_true_flat = y_true.squeeze(1) @@ -276,29 +292,25 @@ def forward(self, logits, y_true, epoch: int = 0): f"y_true shape {y_true.shape} incompatible with logits {logits.shape}" ) - # --- Main Segmentation Loss (GDL + Focal) --- primary_loss = self.main_seg_loss_calculator(logits, y_true_ch.long()) total_loss = self.gdl_focal_weight * primary_loss - # --- Regularizers --- - - if ( + probs = None + needs_probs = ( self.use_fractal_regularization or self.use_connectivity_regularization or self.use_homology_regularization or self.use_topological_complexity - ): + ) + if needs_probs: probs = F.softmax(logits, dim=1) - # Fractal Dimension + # Fractal Dimension if self.use_fractal_regularization: fractal_factor = self._get_warmup_factor(epoch, self.fractal_warmup_epochs) if fractal_factor > 0: if self.n_classes > 1: - if self.spatial_dims == 3: - fg_probs = 1.0 - probs[:, 0:1, :, :, :] - else: - fg_probs = 1.0 - probs[:, 0:1, :, :] + fg_probs = 1.0 - probs[:, 0:1, ...] 
else: fg_probs = probs @@ -306,7 +318,6 @@ def forward(self, logits, y_true, epoch: int = 0): with torch.no_grad(): fd_true = self.fractal_dimension_calculator(y_fractal) fd_pred = self.fractal_dimension_calculator(fg_probs) - # fd_pred = torch.clamp(fd_pred, min=0.0, max=3.0) -> jut if nan still appear fractal_loss = F.mse_loss(fd_pred, fd_true) total_loss += (self.fractal_weight * fractal_factor) * fractal_loss @@ -336,7 +347,7 @@ def forward(self, logits, y_true, epoch: int = 0): self.connectivity_weight * connectivity_factor ) * conn_loss - # Hausdorff (MONAI) + # Hausdorff if self.use_hausdorff_regularization: hausdorff_factor = self._get_warmup_factor( epoch, self.hausdorff_warmup_epochs @@ -356,36 +367,42 @@ def forward(self, logits, y_true, epoch: int = 0): h_loss = torch.tensor(0.0, device=logits.device) total_loss += (self.hausdorff_weight * hausdorff_factor) * h_loss - # Homology (Persistence Image) + # Homology (Persistence Image) if self.use_homology_regularization: homology_factor = self._get_warmup_factor( epoch, self.homology_warmup_epochs ) - if epoch % self.homology_interval == 0: - if homology_factor > 0: + if epoch % self.homology_interval == 0 and homology_factor > 0: + if self.homology_downsample_scale >= 1.0: + probs_h = probs + y_true_h = y_true_flat + else: logits_h, y_true_h = self._downsample_inputs( logits, y_true_flat, self.homology_downsample_scale ) probs_h = F.softmax(logits_h, dim=1) - h_loss = self.homology_calculator(probs_h, y_true_h) - total_loss += (self.homology_weight * homology_factor) * h_loss + h_loss = self.homology_calculator(probs_h, y_true_h) + total_loss += (self.homology_weight * homology_factor) * h_loss - # Topological Complexity + # Topological Complexity if self.use_topological_complexity: complexity_factor = self._get_warmup_factor( epoch, self.complexity_warmup_epochs ) - if epoch % self.complexity_interval == 0: - if complexity_factor > 0: + if epoch % self.complexity_interval == 0 and complexity_factor > 0: + 
if self.complexity_downsample_scale >= 1.0: + probs_c = probs + y_true_c = y_true_flat + else: logits_c, y_true_c = self._downsample_inputs( logits, y_true_flat, self.complexity_downsample_scale ) probs_c = F.softmax(logits_c, dim=1) - c_loss = self.topological_complexity_calculator(probs_c, y_true_c) - total_loss += ( - self.topological_complexity_weight * complexity_factor - ) * c_loss + c_loss = self.topological_complexity_calculator(probs_c, y_true_c) + total_loss += ( + self.topological_complexity_weight * complexity_factor + ) * c_loss return total_loss diff --git a/mmv_im2im/utils/hausdorff_loss.py b/mmv_im2im/utils/hausdorff_loss.py index 883c9a0..5e7031d 100644 --- a/mmv_im2im/utils/hausdorff_loss.py +++ b/mmv_im2im/utils/hausdorff_loss.py @@ -2,26 +2,23 @@ import torch.nn as nn import torch.nn.functional as F from torch.jit import script +from typing import Union @script def chamfer_distance_transform_gpu( - input_mask: torch.Tensor, iterations: int = 30, spatial_dims: int = 2 + input_mask: torch.Tensor, iterations: int, spatial_dims: int = 2 ) -> torch.Tensor: - """ - Computes a GPU-native Distance Transform using iterative Min-Pooling. - Operates on the union of classes to ensure global geometric consistency. 
- """ - # Initialize: 0 inside the object, large value outside - # We use a large enough constant to act as infinity + H = input_mask.shape[-2] + W = input_mask.shape[-1] + inf_val = float(H + W) + dist_map = torch.where( input_mask > 0.5, - torch.tensor(0.0, device=input_mask.device, dtype=input_mask.dtype), - torch.tensor(200.0, device=input_mask.device, dtype=input_mask.dtype), + torch.zeros(1, device=input_mask.device, dtype=input_mask.dtype), + torch.full((1,), inf_val, device=input_mask.device, dtype=input_mask.dtype), ) - # Optimization: Use MaxPool on negative values to simulate Min-Pooling - # This stays 100% on GPU and is fully JIT-compatible for _ in range(iterations): neg_dist = -dist_map if spatial_dims == 2: @@ -29,11 +26,8 @@ def chamfer_distance_transform_gpu( else: pooled = F.max_pool3d(neg_dist, kernel_size=3, stride=1, padding=1) - # Manhattan-like propagation: dist = min(current, neighbors + 1) dist_new = -pooled + 1.0 dist_map = torch.min(dist_map, dist_new) - - # Enforce 0 inside the binary mask dist_map = torch.where(input_mask > 0.5, 0.0, dist_map) return dist_map @@ -41,62 +35,83 @@ def chamfer_distance_transform_gpu( class HausdorffLoss(nn.Module): """ - Optimized Hausdorff Loss that treats all foreground classes as a single - global structure. This prevents inter-class geometric conflicts. + Hausdorff Loss with automatic dt_iterations estimation. 
""" def __init__( self, spatial_dims: int = 2, - dt_iterations: int = 30, + dt_iterations: Union[int, str] = "auto", include_background: bool = False, + distance_mode: str = "l2", + normalize_weights: bool = True, + coverage_fraction: float = 0.25, + dt_iterations_min: int = 10, + dt_iterations_max: int = 200, ): super().__init__() + if not (isinstance(dt_iterations, int) or dt_iterations == "auto"): + raise ValueError( + f"dt_iterations must be a positive integer or 'auto', got: {dt_iterations!r}" + ) self.spatial_dims = spatial_dims self.dt_iterations = dt_iterations self.include_background = include_background + self.distance_mode = distance_mode + self.normalize_weights = normalize_weights + self.coverage_fraction = coverage_fraction + self.dt_iterations_min = dt_iterations_min + self.dt_iterations_max = dt_iterations_max + + @staticmethod + def estimate_iterations( + spatial_shape: torch.Size, + coverage_fraction: float = 0.25, + min_iters: int = 10, + max_iters: int = 200, + ) -> int: + min_dim = min(spatial_shape) + estimated = int(min_dim * coverage_fraction) + return max(min_iters, min(estimated, max_iters)) + + def _resolve_iterations(self, spatial_shape: torch.Size) -> int: + if self.dt_iterations != "auto": + return int(self.dt_iterations) + return self.estimate_iterations( + spatial_shape, + self.coverage_fraction, + self.dt_iterations_min, + self.dt_iterations_max, + ) def forward(self, logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: - """ - Args: - logits: (B, C, ...) raw network outputs. - y_true: (B, 1, ...) label indices. - """ - - # Convert to Probabilities and Get Global Foreground probs = F.softmax(logits, dim=1) - # Union of all foreground classes: - # We take all channels except index 0 (background) and sum them or take max. - # Max is more stable for Hausdorff. 
if not self.include_background and probs.shape[1] > 1: global_probs = torch.max(probs[:, 1:, ...], dim=1, keepdim=True)[0] - - # Global Ground Truth (any label > 0 is foreground) global_gt = (y_true > 0).float() else: - # If we include background or it's a single channel, use everything global_probs = probs global_gt = y_true.float() - # Compute Distance Transforms on GPU (Union based) - # DT to nearest background pixel + iters = self._resolve_iterations(global_gt.shape[2:]) + gt_dist_map = chamfer_distance_transform_gpu( - global_gt, iterations=self.dt_iterations, spatial_dims=self.spatial_dims + global_gt, iterations=iters, spatial_dims=self.spatial_dims ) - - # DT to nearest foreground pixel (Inverted DT) bg_dist_map = chamfer_distance_transform_gpu( 1.0 - global_gt, - iterations=self.dt_iterations, + iterations=iters, spatial_dims=self.spatial_dims, ) - # Weighted Loss Computation - # We penalize based on (probs - gt)^2 * (distance^2) - # This informs the optimizer to move the union of boundaries to the GT boundaries - weight_map = gt_dist_map**2 + bg_dist_map**2 + if self.distance_mode == "l2": + weight_map = gt_dist_map**2 + bg_dist_map**2 + else: + weight_map = gt_dist_map + bg_dist_map - loss = torch.mean(((global_probs - global_gt) ** 2) * weight_map) + if self.normalize_weights: + norm_factor = weight_map.detach().quantile(0.99).clamp(min=1.0) + weight_map = weight_map / norm_factor - return loss + return torch.mean(((global_probs - global_gt) ** 2) * weight_map) diff --git a/mmv_im2im/utils/homology_loss.py b/mmv_im2im/utils/homology_loss.py index 98f34d8..111bc0c 100644 --- a/mmv_im2im/utils/homology_loss.py +++ b/mmv_im2im/utils/homology_loss.py @@ -18,9 +18,9 @@ def __init__( sigma=0.05, chunks=2000, weighting_power=2.0, + adaptive_sigma: bool = False, ): super().__init__() - # Parse resolution if passed as a string from YAML if isinstance(resolution, str): resolution = tuple(int(x) for x in resolution.strip("()[]").split(",")) self.resolution = 
resolution @@ -28,14 +28,41 @@ def __init__( self.sigma = max(sigma, 1e-4) self.chunks = chunks self.weighting_power = weighting_power + self.adaptive_sigma = adaptive_sigma + + self._cached_grid_key = None + self._cached_gx = None + self._cached_gy = None + + def _get_grid(self, device, dtype): + key = (str(device), str(dtype)) + if self._cached_grid_key != key: + x = torch.linspace( + self.range_vals[0], + self.range_vals[1], + self.resolution[0], + device=device, + dtype=dtype, + ) + y = torch.linspace( + self.range_vals[0], + self.range_vals[1], + self.resolution[1], + device=device, + dtype=dtype, + ) + gx, gy = torch.meshgrid(x, y, indexing="ij") + self._cached_gx = gx.unsqueeze(0).unsqueeze(0) + self._cached_gy = gy.unsqueeze(0).unsqueeze(0) + self._cached_grid_key = key + return self._cached_gx, self._cached_gy def forward(self, diagrams): """ Args: - diagrams (list of torch.Tensor): List of persistence diagrams. - Each tensor is (N_points, 2). + diagrams (list of torch.Tensor): Each tensor is (N_points, 2). Returns: - torch.Tensor: Stacked Persistence Images. + torch.Tensor: Stacked Persistence Images, shape (B, H, W). 
""" if len(diagrams) == 0: return torch.zeros( @@ -43,67 +70,64 @@ def forward(self, diagrams): device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), ) - device = diagrams[0].device - dtype = diagrams[0].dtype + device = None + dtype = None + for d in diagrams: + if d.shape[0] > 0: + device = d.device + dtype = d.dtype + break + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + dtype = torch.float32 - # Pre-compute grid - x = torch.linspace( - self.range_vals[0], - self.range_vals[1], - self.resolution[0], - device=device, - dtype=dtype, - ) - y = torch.linspace( - self.range_vals[0], - self.range_vals[1], - self.resolution[1], - device=device, - dtype=dtype, - ) - grid_x, grid_y = torch.meshgrid(x, y, indexing="ij") + B = len(diagrams) + H, W = self.resolution + + gx, gy = self._get_grid(device, dtype) + + sizes = [d.shape[0] for d in diagrams] + max_pts = max(sizes) if sizes else 0 - # (1, H, W) for broadcasting - grid_x_exp = grid_x.unsqueeze(0) - grid_y_exp = grid_y.unsqueeze(0) + if max_pts == 0: + return torch.zeros((B, H, W), device=device, dtype=dtype) - norm_factor = 1.0 / (2 * (self.sigma**2) + 1e-8) - pi_list = [] + padded_b = torch.zeros(B, max_pts, device=device, dtype=dtype) + padded_d = torch.zeros(B, max_pts, device=device, dtype=dtype) + padded_w = torch.zeros(B, max_pts, device=device, dtype=dtype) + sigmas = torch.full((B,), self.sigma, device=device, dtype=dtype) - # Iterate over the batch of diagrams - for diag in diagrams: - if diag.shape[0] == 0: - pi_list.append(torch.zeros(self.resolution, device=device, dtype=dtype)) + for i, diag in enumerate(diagrams): + n = diag.shape[0] + if n == 0: continue + persistence = torch.abs(diag[:, 1] - diag[:, 0]).clamp(max=10.0) + padded_b[i, :n] = diag[:, 0] + padded_d[i, :n] = diag[:, 1] + padded_w[i, :n] = torch.pow(persistence, self.weighting_power) - b_vals, d_vals = diag[:, 0], diag[:, 1] - persistence = torch.abs(d_vals - b_vals) - persistence 
= torch.clamp(persistence, max=10.0) - - weights_all = torch.pow(persistence, self.weighting_power).view(-1, 1, 1) - cx_all = b_vals.view(-1, 1, 1) - cy_all = d_vals.view(-1, 1, 1) - - # Chunking to prevent OOM on high point counts - if ( - isinstance(self.chunks, int) - and self.chunks > 0 - and diag.shape[0] > self.chunks - ): - pi_accum = torch.zeros(self.resolution, device=device, dtype=dtype) - for i in range(0, diag.shape[0], self.chunks): - end = i + self.chunks - w_c, cx_c, cy_c = weights_all[i:end], cx_all[i:end], cy_all[i:end] - dist_sq = (grid_x_exp - cx_c) ** 2 + (grid_y_exp - cy_c) ** 2 - gauss = torch.exp(-dist_sq * norm_factor) - pi_accum += (w_c * gauss).sum(dim=0) - pi_list.append(pi_accum) - else: - dist_sq = (grid_x_exp - cx_all) ** 2 + (grid_y_exp - cy_all) ** 2 - gauss = torch.exp(-dist_sq * norm_factor) - pi_list.append((weights_all * gauss).sum(dim=0)) - - return torch.stack(pi_list) + if self.adaptive_sigma and n > 4: + q75 = torch.quantile(persistence, 0.75) + q25 = torch.quantile(persistence, 0.25) + iqr = (q75 - q25).clamp(min=1e-4) + sigma_eff = (0.9 * iqr * (n ** (-0.2))).detach().clamp(min=1e-4) + sigmas[i] = sigma_eff + + norm_factor = (1.0 / (2.0 * sigmas.pow(2) + 1e-8)).view(B, 1, 1, 1) + + pi_batch = torch.zeros(B, H, W, device=device, dtype=dtype) + chunk_size = max(1, self.chunks) + + for start in range(0, max_pts, chunk_size): + end = min(start + chunk_size, max_pts) + cx = padded_b[:, start:end].unsqueeze(-1).unsqueeze(-1) + cy = padded_d[:, start:end].unsqueeze(-1).unsqueeze(-1) + w = padded_w[:, start:end].unsqueeze(-1).unsqueeze(-1) + dist_sq = (gx - cx) ** 2 + (gy - cy) ** 2 + gauss = torch.exp(-dist_sq * norm_factor) + pi_batch.add_((w * gauss).sum(dim=1)) + + return pi_batch class HomologyLoss(nn.Module): @@ -122,29 +146,31 @@ def __init__( metric="smooth_l1", chunks=2000, filtering=True, - treshold=0.01, + threshold=0.01, + treshold=None, k_top=500, weighting_power=2.0, composite_flag=True, + adaptive_sigma=True, ): 
super().__init__() self.spatial_dims = spatial_dims + resolved_threshold = treshold if treshold is not None else threshold self.pi_generator = DifferentiablePersistenceImage( resolution=resolution, sigma=sigma, chunks=chunks, weighting_power=weighting_power, + adaptive_sigma=adaptive_sigma, ) self.features = features self.class_context = class_context self.metric = metric self.filtering = filtering - self.filter_thresh = treshold + self.filter_thresh = resolved_threshold self.k_top = k_top self.composite_flag = composite_flag - # CubicalComplex handles the heavy lifting of TDA - # Set dimension based on spatial_dims (2 for 2D images, 3 for 3D volumes) self.cubical_complex = CubicalComplex(dim=self.spatial_dims) if metric == "ssim": @@ -159,32 +185,33 @@ def _extract_persistence_diagrams_from_batch_result(self, batch_info): batch_diag_1 = [] for info in batch_info: - extracted = {0: [], 1: []} - stack = [info] if not isinstance(info, list) else info - while stack: - item = stack.pop() + d0_parts = [] + d1_parts = [] + device = None + + worklist = info if isinstance(info, list) else [info] + worklist = list(worklist) + while worklist: + item = worklist.pop() if isinstance(item, list): - stack.extend(item) + worklist.extend(item) elif hasattr(item, "diagram"): + diag_tensor = item.diagram + if device is None: + device = diag_tensor.device d = item.dimension if isinstance(d, torch.Tensor): - d = int(d.detach().cpu().numpy()) + d = int(d.item()) + if d == 0: + d0_parts.append(diag_tensor) + elif d == 1: + d1_parts.append(diag_tensor) - # We currently only extract dim 0 (components) and dim 1 (loops/tunnels) - if d in [0, 1]: - extracted[d].append(item.diagram) + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - device = info.device if hasattr(info, "device") else item.diagram.device - d0 = ( - torch.cat(extracted[0], dim=0) - if extracted[0] - else torch.empty((0, 2), device=device if extracted[0] else None) - ) - d1 = ( - 
torch.cat(extracted[1], dim=0) - if extracted[1] - else torch.empty((0, 2), device=device if extracted[1] else None) - ) + d0 = torch.cat(d0_parts) if d0_parts else torch.empty((0, 2), device=device) + d1 = torch.cat(d1_parts) if d1_parts else torch.empty((0, 2), device=device) batch_diag_0.append(self._filter_and_topk(-1.0 * d0)) batch_diag_1.append(self._filter_and_topk(-1.0 * d1)) @@ -197,11 +224,12 @@ def _filter_and_topk(self, diagram): persistence = torch.abs(diagram[:, 1] - diagram[:, 0]) mask = torch.isfinite(persistence) - diagram, persistence = diagram[mask], persistence[mask] - if self.filtering: - mask = persistence > self.filter_thresh - diagram, persistence = diagram[mask], persistence[mask] + mask = mask & (persistence > self.filter_thresh) + + if not mask.all(): + diagram = diagram[mask] + persistence = persistence[mask] if diagram.shape[0] > self.k_top: _, idx = torch.topk(persistence, k=self.k_top) @@ -210,9 +238,6 @@ def _filter_and_topk(self, diagram): return diagram def _compute_image_dist_batch(self, pi_p_batch, pi_g_batch): - """ - Computes distance between batches of Persistence Images. - """ if self.metric == "mse": return F.mse_loss(pi_p_batch, pi_g_batch, reduction="none").mean(dim=[1, 2]) if self.metric == "l1": @@ -225,54 +250,41 @@ def _compute_image_dist_batch(self, pi_p_batch, pi_g_batch): return 1.0 - self.ssim_func( pi_p_batch.unsqueeze(1), pi_g_batch.unsqueeze(1), data_range=1.0 ) - return torch.zeros(pi_p_batch.shape[0], device=pi_p_batch.device) def forward(self, y_pred_softmax, y_true): """ - Vectorized forward pass. y_pred_softmax: (B, C, spatial...) y_true: (B, spatial...) or (B, 1, spatial...) """ - # Determine shapes shape = y_pred_softmax.shape batch_size = shape[0] n_channels = shape[1] spatial_shape = shape[2:] - device = y_pred_softmax.device - # Ensure y_true has channel dim if y_true.dim() == len(shape) - 1: y_true = y_true.unsqueeze(1) - # Convert y_true to One-Hot: (B, C, spatial...) 
y_true_oh = torch.zeros_like(y_pred_softmax).scatter_(1, y_true.long(), 1) - relevant_channels = range(1, n_channels) - num_relevant = len(relevant_channels) + num_relevant = n_channels - 1 if num_relevant == 0: return torch.tensor(0.0, device=device, requires_grad=True) - # Flatten spatial dims to generalize 2D and 3D - # We need independent maps for Cubical Complex - # Reshape to (B * (C-1), spatial...) p_flat = y_pred_softmax[:, 1:, ...].reshape(-1, *spatial_shape) g_flat = y_true_oh[:, 1:, ...].reshape(-1, *spatial_shape) - total_items = p_flat.shape[0] + if not g_flat.detach().any(): + return torch.tensor(0.0, device=device, requires_grad=True) - # Prediction Diagrams - # CubicalComplex expects (Batch, C=1, Spatial...) or just (Batch, Spatial...) depending on version - # We add a channel dim 1 for the TDA engine input p_input = -1.0 * p_flat.unsqueeze(1) p_info_batch = self.cubical_complex(p_input) diag_p0_list, diag_p1_list = ( self._extract_persistence_diagrams_from_batch_result(p_info_batch) ) - # Ground Truth Diagrams (No Grad) with torch.no_grad(): g_input = -1.0 * g_flat.unsqueeze(1) g_info_batch = self.cubical_complex(g_input) @@ -280,29 +292,29 @@ def forward(self, y_pred_softmax, y_true): self._extract_persistence_diagrams_from_batch_result(g_info_batch) ) - # Generate PIs - loss_vector = torch.zeros(total_items, device=device) + loss_vector = None if self.features in ["all", "cc"]: pi_p0 = self.pi_generator(diag_p0_list) pi_g0 = self.pi_generator(diag_g0_list) - loss_vector += self._compute_image_dist_batch(pi_p0, pi_g0) + cc_loss = self._compute_image_dist_batch(pi_p0, pi_g0) + loss_vector = cc_loss if self.features in ["all", "holes"]: pi_p1 = self.pi_generator(diag_p1_list) pi_g1 = self.pi_generator(diag_g1_list) - loss_vector += self._compute_image_dist_batch(pi_p1, pi_g1) - - # Reshape loss back to (Batch, Num_Relevant_Channels) - loss_matrix = loss_vector.view(batch_size, num_relevant) + holes_loss = self._compute_image_dist_batch(pi_p1, pi_g1) 
+ loss_vector = ( + holes_loss if loss_vector is None else loss_vector + holes_loss + ) - channel_losses = [] - for c_idx in range(num_relevant): - losses_for_channel = loss_matrix[:, c_idx] - c_loss = 0.7 * losses_for_channel.mean() + 0.3 * losses_for_channel.max() - channel_losses.append(c_loss) + if loss_vector is None: + return torch.tensor(0.0, device=device, requires_grad=True) - if len(channel_losses) > 0: - return torch.stack(channel_losses).mean() + loss_matrix = loss_vector.view(batch_size, num_relevant) - return torch.tensor(0.0, device=device, requires_grad=True) + mean_losses = loss_matrix.mean(dim=0) + max_losses = loss_matrix.max(dim=0)[0] + upper_bounds = 5.0 * mean_losses.detach() + 1e-8 + max_losses_clamped = torch.min(max_losses, upper_bounds) + return (0.7 * mean_losses + 0.3 * max_losses_clamped).mean() diff --git a/mmv_im2im/utils/topological_complexity_loss.py b/mmv_im2im/utils/topological_complexity_loss.py index 0bff011..c425ba1 100644 --- a/mmv_im2im/utils/topological_complexity_loss.py +++ b/mmv_im2im/utils/topological_complexity_loss.py @@ -23,6 +23,7 @@ def __init__( k_top=2000, temperature=0.01, auto_balance=True, + normalize_lifetimes=True, ): super().__init__() self.spatial_dims = spatial_dims @@ -33,31 +34,22 @@ def __init__( self.k_top = k_top self.temperature = max(temperature, 1e-4) self.auto_balance = auto_balance + self.normalize_lifetimes = normalize_lifetimes - # Initialize the Cubical Complex calculator. self.cubical_complex = CubicalComplex(dim=self.spatial_dims) - - # Internal state to track device during forward pass self.current_device = None def _stable_log_cosh(self, pred, target): """ - Computes log(cosh(pred - target)) in a numerically stable way. - Formula: - log(cosh(x)) = log( (e^x + e^-x) / 2 ) - = log(e^x + e^-x) - log(2) - For large |x|, this approximates to |x| - log(2). - We use softplus(2|x|) approach or direct approximation for stability. + Numerically stable log(cosh(pred - target)). 
+ For large |x|, log(cosh(x)) ≈ |x| - log(2), avoiding overflow. """ x = pred - target abs_x = torch.abs(x) - loss = torch.where( abs_x > 50.0, abs_x - math.log(2.0), - torch.log( - torch.cosh(x) + 1e-12 - ), # 1e-12 prevents log(0) theoretically impossible but good practice + torch.log(torch.cosh(x) + 1e-12), ) return torch.mean(loss) @@ -70,16 +62,18 @@ def _extract_lifetimes_batch(self, batch_info): for info in batch_info: lts_dict = {0: [], 1: []} - stack = [info] if not isinstance(info, list) else info + device = self.current_device - while stack: - item = stack.pop() + worklist = info if isinstance(info, list) else [info] + worklist = list(worklist) + while worklist: + item = worklist.pop() if isinstance(item, list): - stack.extend(item) + worklist.extend(item) elif hasattr(item, "dimension"): dim = item.dimension if isinstance(dim, torch.Tensor): - dim = int(dim.detach().cpu().numpy()) + dim = int(dim.item()) if dim in [0, 1]: if hasattr(item, "diagram"): @@ -91,170 +85,162 @@ def _extract_lifetimes_batch(self, batch_info): if p is not None: if not isinstance(p, torch.Tensor): - p = torch.as_tensor(p, device=self.current_device) + p = torch.as_tensor(p, device=device) else: - p = p.to(self.current_device) + if device is None: + device = p.device if p.numel() > 0: - # Calculate persistence (death - birth) pers = p[:, 1] - p[:, 0] - # Filter out infinite or NaN persistence just in case valid_mask = torch.isfinite(pers) if valid_mask.any(): lts_dict[dim].append(pers[valid_mask]) - res = [] - for d in [0, 1]: - if lts_dict[d]: - res.append(torch.cat(lts_dict[d])) - else: - res.append(torch.tensor([], device=self.current_device)) + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - lts_0_batch.append(res[0]) - lts_1_batch.append(res[1]) + lts_0_batch.append( + torch.cat(lts_dict[0]) + if lts_dict[0] + else torch.tensor([], device=device) + ) + lts_1_batch.append( + torch.cat(lts_dict[1]) + if lts_dict[1] + else 
torch.tensor([], device=device) + ) return lts_0_batch, lts_1_batch - def _prepare_padded_lifetimes(self, lts_list): - if len(lts_list) > 0 and isinstance(lts_list[0], torch.Tensor): - device = lts_list[0].device - else: - device = self.current_device + def _normalize_lifetimes(self, lts_list): + non_empty = [lt for lt in lts_list if lt.numel() > 0] + if not non_empty: + return lts_list + global_max = torch.cat(non_empty).max().clamp(min=1e-8) + return [lt / global_max for lt in lts_list] - batch_size = len(lts_list) - padded = torch.zeros((batch_size, self.k_top), device=device) + def _prepare_padded_lifetimes(self, lts_list): + device = next( + (lt.device for lt in lts_list if lt.numel() > 0), + self.current_device or torch.device("cpu"), + ) + result = torch.zeros(len(lts_list), self.k_top, device=device) for i, lt in enumerate(lts_list): - if lt.numel() > 0: - # Filter small persistence to reduce noise and computation on irrelevant features - v = lt[lt > self.threshold] - if v.numel() > 0: - v_sorted, _ = torch.sort(v, descending=True) - n = min(v_sorted.numel(), self.k_top) - padded[i, :n] = v_sorted[:n] - return padded + if lt.numel() == 0: + continue + v = lt[lt > self.threshold] + if v.numel() == 0: + continue + v_sorted, _ = torch.sort(v, descending=True) + n = min(v_sorted.numel(), self.k_top) + result[i, :n] = v_sorted[:n] + + return result def _compute_vectorized_soft_stats(self, padded_lts): - """ - Computes differentiable statistics (mean, count, max) using soft approximations. - Includes clamping to prevent NaNs in Softmax/Exp. - """ - # 1. Stability Clamp: Prevent massive values from exploding in exp() - # If persistence is > 100 (unlikely in normalized images, but possible), clamp it. + padded_clamped = torch.clamp(padded_lts, max=50.0) - # 2. Weighted Mean (Softmax based) - # If padded_lts are all zeros, softmax is uniform. weights = F.softmax(padded_clamped / self.temperature, dim=1) mean_top = torch.sum(padded_lts * weights, dim=1) - # 3. 
Soft Count (Sigmoid based) - # Calculate distance from threshold, scale by temperature - # Using padded_clamped helps stability, but we use original for logic accuracy - # unless it's huge. sigmoid_in = (padded_lts - self.threshold) / self.temperature - # Clamp input to sigmoid to avoid extremely large negative/positive values (though sigmoid handles them well, gradients can vanish) - count_top = torch.sigmoid(torch.clamp(sigmoid_in, min=-50, max=50)).sum(dim=1) - # 4. Max Value + count_top = ( + torch.sigmoid(torch.clamp(sigmoid_in, min=-50, max=50)).sum(dim=1) + / self.k_top + ) + max_val = torch.max(padded_lts, dim=1)[0] return torch.stack([mean_top, count_top, max_val], dim=1) def _compute_metric(self, pred_stats, target_stats): - """ - Route to the correct metric calculation. - """ if self.metric == "wasserstein": - # Note: For Wasserstein matching, stats are just raw padded vectors usually - # But based on the code flow, this block is handled directly in forward for padded vectors - # This method handles the 'stats' metrics. return F.mse_loss(pred_stats, target_stats) - elif self.metric == "mse": return F.mse_loss(pred_stats, target_stats) - elif self.metric == "log_cosh": return self._stable_log_cosh(pred_stats, target_stats) - elif self.metric == "l1": return F.l1_loss(pred_stats, target_stats) - else: - # Default to MSE return F.mse_loss(pred_stats, target_stats) + def _harmonic_balance(self, loss_a, loss_b): + denom = (loss_a + loss_b).clamp(min=1e-8) + return (2.0 * loss_a * loss_b) / denom + def forward(self, y_pred_softmax, y_true): device = y_pred_softmax.device self.current_device = device - # Generic shape unpacking shape = y_pred_softmax.shape n_channels = shape[1] spatial_shape = shape[2:] if self.class_context == "general": - # Skip background (assuming index 0 is background) if n_channels > 1: p_relevant = y_pred_softmax[:, 1:, ...] y_true_oh = F.one_hot(y_true.long(), num_classes=n_channels) - # Permute OH to (B, C, Spatial...) 
permute_dims = (0, len(shape) - 1) + tuple(range(1, len(shape) - 1)) y_true_oh = y_true_oh.permute(*permute_dims).float() g_relevant = y_true_oh[:, 1:, ...] else: - # Binary case handling p_relevant = y_pred_softmax g_relevant = y_true.unsqueeze(1).float() else: - # Custom context logic can be added here, currently defaulting to full pass p_relevant = y_pred_softmax y_true_oh = F.one_hot(y_true.long(), num_classes=n_channels) permute_dims = (0, len(shape) - 1) + tuple(range(1, len(shape) - 1)) y_true_oh = y_true_oh.permute(*permute_dims).float() g_relevant = y_true_oh - # Flatten batch and channels to treat them as independent maps - # This allows computing topology for all images/channels in one batched call p_flat = p_relevant.reshape(-1, *spatial_shape) g_flat = g_relevant.reshape(-1, *spatial_shape) - # Extract diagrams - # Invert image (-1.0 *) for sublevel filtration equivalent to superlevel set filtration p_input = -1.0 * p_flat.unsqueeze(1) p_info_batch = self.cubical_complex(p_input) lts_p0, lts_p1 = self._extract_lifetimes_batch(p_info_batch) - # Ground Truth Diagrams (No Grad needed) with torch.no_grad(): g_input = -1.0 * g_flat.unsqueeze(1) g_info_batch = self.cubical_complex(g_input) lts_g0, lts_g1 = self._extract_lifetimes_batch(g_info_batch) + if self.normalize_lifetimes: + lts_p0 = self._normalize_lifetimes(lts_p0) + lts_g0 = self._normalize_lifetimes(lts_g0) + lts_p1 = self._normalize_lifetimes(lts_p1) + lts_g1 = self._normalize_lifetimes(lts_g1) + total_loss = torch.tensor(0.0, device=device) + loss_cc = torch.tensor(0.0, device=device) + loss_holes = torch.tensor(0.0, device=device) - # --- Dim 0 (Components) --- if self.features in ["all", "cc"]: vp0 = self._prepare_padded_lifetimes(lts_p0) vg0 = self._prepare_padded_lifetimes(lts_g0) - if self.metric == "wasserstein": - # Wasserstein on 1D slices approximates to MSE on sorted vectors (Sliced Wasserstein) - # We use MSE on the padded sorted lifetimes directly. 
+ if vp0.detach().sum() == 0 and vg0.detach().sum() == 0: + loss_cc = torch.tensor(0.0, device=device) + elif self.metric == "wasserstein": loss_cc = F.mse_loss(vp0, vg0) else: stats_p0 = self._compute_vectorized_soft_stats(vp0) stats_g0 = self._compute_vectorized_soft_stats(vg0) loss_cc = self._compute_metric(stats_p0, stats_g0) - total_loss += loss_cc + total_loss = total_loss + loss_cc - # --- Dim 1 (Holes/Tunnels) --- if self.features in ["all", "holes"]: vp1 = self._prepare_padded_lifetimes(lts_p1) vg1 = self._prepare_padded_lifetimes(lts_g1) - if self.metric == "wasserstein": + if vp1.detach().sum() == 0 and vg1.detach().sum() == 0: + loss_holes = torch.tensor(0.0, device=device) + elif self.metric == "wasserstein": loss_holes = F.mse_loss(vp1, vg1) else: stats_p1 = self._compute_vectorized_soft_stats(vp1) @@ -262,9 +248,9 @@ def forward(self, y_pred_softmax, y_true): loss_holes = self._compute_metric(stats_p1, stats_g1) if self.auto_balance and self.features == "all": - total_loss = (total_loss + loss_holes) * 0.5 + total_loss = self._harmonic_balance(loss_cc, loss_holes) else: - total_loss += loss_holes + total_loss = total_loss + loss_holes if not torch.isfinite(total_loss): return torch.tensor(0.0, device=device, requires_grad=True) diff --git a/mmv_im2im/utils/variable_collate.py b/mmv_im2im/utils/variable_collate.py new file mode 100644 index 0000000..1318b76 --- /dev/null +++ b/mmv_im2im/utils/variable_collate.py @@ -0,0 +1,199 @@ +""" +variable_collate.py +------------------- +Custom PyTorch collate function for variable-size 3-D (or 2-D) images +paired with spherical-harmonic GT regression vectors. + +------------------ +PyTorch's default collate calls torch.stack() which requires all tensors +in a batch to have identical shapes. Our images have variable spatial +sizes so stacking fails. + +-------- +1. Find the maximum size along each spatial dimension across the batch. +2. 
Round those maxima UP to the nearest multiple of k (default 16) so + the padded tensors are compatible with the network's downsampling. +3. Pad every sample to that common target size SYMMETRICALLY (half before, + half after), adjusting x0 in the GT vector accordingly. +4. Stack the now-uniform tensors. + +The GT adjustment follows the same logic as DivisiblePadWithGTAdjustd: + x0_new[i] = x0_old[i] + pad_before[i] +The SH coefficients are left untouched. + +Usage +----- +Pass `collate_fn=variable_size_collate_fn` (or the factory variant +`make_collate_fn(k=16)`) to your DataLoader. See variable_datamodule.py +for how this is injected automatically. +""" + +from functools import partial +from typing import Dict, List +import math +import torch + + +from mmv_im2im.utils.custom_transforms import apply_pad_to_tensor + +# --------------------------------------------------------------------------- +# Core collate logic +# --------------------------------------------------------------------------- + + +def variable_size_collate_fn( + batch: List[Dict[str, torch.Tensor]], + k: int = 16, + mode: str = "constant", + constant_value: float = 0.0, + n_coord_dims: int = 3, +) -> Dict[str, torch.Tensor]: + """ + Collate a list of sample dicts with variable-size images into a batch. + + Expected keys in each sample dict + ---------------------------------- + "IM" : torch.Tensor [C, *spatial] the image (any spatial size) + "GT" : torch.Tensor [3 + n_coeffs] the GT regression vector + + Additional keys (e.g. "CM" costmaps) are stacked with torch.stack if + they already have the same shape, or padded the same way as "IM" if + their spatial dims match the image. 
+ + Parameters + ---------- + batch : list of sample dicts produced by the Dataset/transforms + k : divisibility target (16 for AttentionUnet 4× downsampling) + mode : padding mode for torch.nn.functional.pad + constant_value : fill value when mode="constant" + n_coord_dims : how many leading GT elements represent spatial coords + """ + if len(batch) == 0: + return {} + + # ── Determine target spatial shape ──────────────────────────────── + sample0 = batch[0] + spatial_ndim = len(sample0["IM"].shape) - 1 # exclude channel dim + max_spatial = [0] * spatial_ndim + + for sample in batch: + for dim_i, s in enumerate(sample["IM"].shape[1:]): + if s > max_spatial[dim_i]: + max_spatial[dim_i] = s + + # Round up each max dim to the nearest multiple of k + target_dims = [math.ceil(d / k) * k for d in max_spatial] + + # ── Pad images and adjust GT ─────────────────────────────────────── + padded_ims: List[torch.Tensor] = [] + adjusted_gts: List[torch.Tensor] = [] + + # Track pad_before for each sample (needed for optional CM padding) + all_pad_before: List[List[int]] = [] + + for sample in batch: + img = sample["IM"] + gt = sample["GT"] + + if not isinstance(img, torch.Tensor): + img = torch.tensor(img) + if not isinstance(gt, torch.Tensor): + gt = torch.tensor(gt) + + current_spatial = list(img.shape[1:]) + + # Compute symmetric padding to reach target_dims + pad_before = [(t - c) // 2 for t, c in zip(target_dims, current_spatial)] + pad_after = [ + t - c - pb for t, c, pb in zip(target_dims, current_spatial, pad_before) + ] + + img_padded = apply_pad_to_tensor( + img.float(), pad_before, pad_after, mode=mode, value=constant_value + ) + + # Adjust the coordinate elements of GT + gt_adjusted = gt.clone().float() + n_adjust = min(n_coord_dims, len(pad_before)) + for i in range(n_adjust): + gt_adjusted[i] = gt[i] + pad_before[i] + + padded_ims.append(img_padded) + adjusted_gts.append(gt_adjusted) + all_pad_before.append(pad_before) + + result: Dict[str, torch.Tensor] = { + 
"IM": torch.stack(padded_ims, dim=0), + "GT": torch.stack(adjusted_gts, dim=0), + } + + # ── Handle any extra keys (e.g., "CM" costmaps) ─────────────────── + extra_keys = [k_ for k_ in sample0.keys() if k_ not in ("IM", "GT")] + for key in extra_keys: + samples_key = [s[key] for s in batch if key in s] + if len(samples_key) != len(batch): + continue # skip if not all samples have this key + + # Try to stack without padding first + try: + result[key] = torch.stack( + [ + ( + t.float() + if isinstance(t, torch.Tensor) + else torch.tensor(t).float() + ) + for t in samples_key + ], + dim=0, + ) + except RuntimeError: + # Shape mismatch → pad the same way as IM + padded_key = [] + for sample, pad_before in zip(batch, all_pad_before): + if key not in sample: + continue + t = sample[key] + if not isinstance(t, torch.Tensor): + t = torch.tensor(t) + current_spatial = list(t.shape[1:]) + pad_after = [ + td - c - pb + for td, c, pb in zip(target_dims, current_spatial, pad_before) + ] + t_padded = apply_pad_to_tensor( + t.float(), pad_before, pad_after, mode=mode, value=constant_value + ) + padded_key.append(t_padded) + if padded_key: + result[key] = torch.stack(padded_key, dim=0) + + return result + + +# --------------------------------------------------------------------------- +# Factory for easy partial application +# --------------------------------------------------------------------------- + + +def make_collate_fn( + k: int = 16, + mode: str = "constant", + constant_value: float = 0.0, + n_coord_dims: int = 3, +): + """ + Returns a collate_fn with the given configuration. 
+
+    Example
+    -------
+    collate_fn = make_collate_fn(k=16)
+    loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
+    """
+    return partial(
+        variable_size_collate_fn,
+        k=k,
+        mode=mode,
+        constant_value=constant_value,
+        n_coord_dims=n_coord_dims,
+    )
diff --git a/mmv_im2im/utils/variable_datamodule.py b/mmv_im2im/utils/variable_datamodule.py
new file mode 100644
index 0000000..5118372
--- /dev/null
+++ b/mmv_im2im/utils/variable_datamodule.py
@@ -0,0 +1,185 @@
+"""
+variable_datamodule.py
+----------------------
+Thin LightningDataModule wrapper around mmv_im2im's existing data module
+that injects the variable_size_collate_fn into both train and validation
+DataLoaders WITHOUT modifying mmv_im2im source code.
+
+Why this wrapper exists
+-----------------------
+mmv_im2im's DataModule builds its DataLoaders internally, and the YAML
+config has no slot for `collate_fn` (it's not a standard serialisable
+object). The cleanest solution is to:
+  1. Let the original DataModule set up everything (dataset, transforms,
+     split logic, sampler, etc.) as normal.
+  2. Override train_dataloader() and val_dataloader() to swap in our
+     custom collate function on the DataLoader that would otherwise be
+     returned.
+
+This keeps 100% compatibility with the rest of the training stack.
+
+Usage in run_training.py
+------------------------
+    from variable_datamodule import VariableSizeDataModule
+
+    base_data_module = get_data_module(cfg.data)  # original mmv_im2im call
+    data_module = VariableSizeDataModule(
+        base_data_module,
+        k=16,              # divisibility target
+        mode="constant",   # padding mode
+    )
+    trainer.fit(model=model, datamodule=data_module)
+"""
+
+from typing import Optional
+
+import lightning as pl
+from torch.utils.data import DataLoader
+
+from mmv_im2im.utils.variable_collate import make_collate_fn
+
+
+class VariableSizeDataModule(pl.LightningDataModule):
+    """
+    Wraps any LightningDataModule and replaces the collate_fn of the
+    DataLoaders it produces with variable_size_collate_fn.
+ + Parameters + ---------- + base_module : pl.LightningDataModule + The original data module returned by mmv_im2im.get_data_module(). + k : int + All spatial dimensions will be padded to multiples of k. + Use k=16 for AttentionUnet with 4 downsampling stages. + mode : str + Padding mode ("constant" recommended – zero-fill after normalisation). + constant_value : float + Fill value when mode="constant". + n_coord_dims : int + Number of leading elements in the GT vector that hold spatial + coordinates and must be shifted when padding is applied. + Default 3 → (z, y, x) centre-of-mass coordinates. + """ + + def __init__( + self, + base_module: pl.LightningDataModule, + k: int = 16, + mode: str = "constant", + constant_value: float = 0.0, + n_coord_dims: int = 3, + ): + super().__init__() + self._base = base_module + self._collate_fn = make_collate_fn( + k=k, + mode=mode, + constant_value=constant_value, + n_coord_dims=n_coord_dims, + ) + + # ------------------------------------------------------------------ + # Delegation helpers + # ------------------------------------------------------------------ + + def prepare_data(self): + self._base.prepare_data() + + def setup(self, stage: Optional[str] = None): + self._base.setup(stage) + + # ------------------------------------------------------------------ + # DataLoader overrides + # ------------------------------------------------------------------ + + def _rebuild_loader(self, original_loader: DataLoader) -> DataLoader: + """ + Rebuild a DataLoader with the same parameters but our collate_fn. + + We replicate every constructor argument from the existing loader's + __dict__ so nothing is lost (batch_size, num_workers, pin_memory, + sampler, etc.). 
+ """ + init_args = dict( + dataset=original_loader.dataset, + batch_size=original_loader.batch_size, + shuffle=isinstance( + original_loader.sampler, + __import__("torch").utils.data.RandomSampler, + ), + sampler=( + original_loader.sampler + if not isinstance( + original_loader.sampler, + __import__("torch").utils.data.RandomSampler, + ) + else None + ), + batch_sampler=( + original_loader.batch_sampler + if original_loader.batch_sampler is not None + and not isinstance( + original_loader.batch_sampler, + __import__("torch").utils.data.BatchSampler, + ) + else None + ), + num_workers=original_loader.num_workers, + collate_fn=self._collate_fn, # ← our injection + pin_memory=original_loader.pin_memory, + drop_last=original_loader.drop_last, + timeout=original_loader.timeout, + worker_init_fn=original_loader.worker_init_fn, + prefetch_factor=( + original_loader.prefetch_factor + if original_loader.num_workers > 0 + else None + ), + persistent_workers=original_loader.persistent_workers, + ) + + # Remove None-valued keys that DataLoader doesn't accept as None + # (batch_sampler=None conflicts with batch_size, for example) + if init_args["batch_sampler"] is not None: + # batch_sampler is mutually exclusive with batch_size / shuffle / sampler + for conflict_key in ("batch_size", "shuffle", "sampler", "drop_last"): + init_args.pop(conflict_key, None) + + if init_args.get("shuffle") is False and init_args.get("sampler") is None: + init_args.pop("sampler", None) + + return DataLoader( + **{ + k: v + for k, v in init_args.items() + if v is not None or k in ("batch_size",) + } + ) + + def train_dataloader(self) -> DataLoader: + original = self._base.train_dataloader() + rebuilt = self._rebuild_loader(original) + print( + f"[VariableSizeDataModule] train DataLoader rebuilt with " + f"variable_size_collate_fn (batch_size={rebuilt.batch_size})" + ) + return rebuilt + + def val_dataloader(self) -> DataLoader: + original = self._base.val_dataloader() + rebuilt = 
self._rebuild_loader(original) + print( + f"[VariableSizeDataModule] val DataLoader rebuilt with " + f"variable_size_collate_fn (batch_size={rebuilt.batch_size})" + ) + return rebuilt + + # ------------------------------------------------------------------ + # Forward any attribute not defined here to the base module + # ------------------------------------------------------------------ + + def __getattr__(self, name): + # Avoid infinite recursion for attributes set in __init__ + if name.startswith("_"): + raise AttributeError(name) + return getattr(self._base, name) diff --git a/pyproject.toml b/pyproject.toml index 9922ff1..b1250b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mmv_im2im" -version = "0.8.0" +version = "0.7.1" authors = [ { name="Jianxu Chen", email="jianxuchen.ai@gmail.com" }, ] diff --git a/setup.cfg b/setup.cfg index c784fda..add961b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.8.0 +current_version = 0.7.1 commit = True tag = True