From 1b5bf03f3ef2d06cbd071e0592146806a6b67068 Mon Sep 17 00:00:00 2001 From: Jair Sanchez Date: Mon, 9 Mar 2026 18:00:43 +0100 Subject: [PATCH] Topological regularizers minor fixes. --- mmv_im2im/map_extractor.py | 15 +- mmv_im2im/models/nets/ProbUnet.py | 11 +- mmv_im2im/models/pl_ProbUnet.py | 14 +- mmv_im2im/models/pl_ProbUnet_old.py | 163 +++++++++++++++++++ mmv_im2im/postprocessing/basic_collection.py | 2 + mmv_im2im/tests/test_dummy.py | 4 +- 6 files changed, 189 insertions(+), 20 deletions(-) create mode 100644 mmv_im2im/models/pl_ProbUnet_old.py diff --git a/mmv_im2im/map_extractor.py b/mmv_im2im/map_extractor.py index 855f95d..8df20f7 100644 --- a/mmv_im2im/map_extractor.py +++ b/mmv_im2im/map_extractor.py @@ -55,7 +55,6 @@ def __init__(self, cfg): self.model_cfg = cfg.model self.data_cfg = cfg.data - # define variables self.model = None self.data = None self.pre_process = None @@ -111,7 +110,10 @@ def setup_data_processing(self): self.pre_process = parse_monai_ops_vanilla(self.data_cfg.preprocess) def process_one_image( - self, img: Union[DaskArray, NumpyArray], out_fn: Union[str, Path] = None + self, + img: Union[DaskArray, NumpyArray], + dim: int = 2, + out_fn: Union[str, Path] = None, ): if isinstance(img, DaskArray): @@ -131,7 +133,8 @@ def process_one_image( # run pre-processing on tensor if needed if self.pre_process is not None: x = self.pre_process(x) - x = x[0] + if dim == 2: + x = x[0] # choose different inference function for different types of models with torch.no_grad(): @@ -412,7 +415,7 @@ def _process_vol2slice(self, img, pred_cfg, original_postprocess, pert_opt): else: inp = im_input - logits = self.process_one_image(inp) + logits = self.process_one_image(inp, dim=2) samplesz.append(np.squeeze(logits)) # Multi-prediction aggregation @@ -464,6 +467,7 @@ def _process_vol2slice(self, img, pred_cfg, original_postprocess, pert_opt): def _process_vol2vol(self, img, pred_cfg, original_postprocess, pert_opt): """New direct 3D volume processing logic.""" # Input img is (C, Z, Y, X) or (Z, Y, X) + # Handle dummy channel if missing if len(img.shape) == 3: img = img[None, ...] # (C, Z, Y, X) @@ -478,7 +482,7 @@ def _process_vol2vol(self, img, pred_cfg, original_postprocess, pert_opt): # Process one image will handle 4D input by adding batch dim -> (1, C, Z, Y, X) # Ensure spatial_dims is set to 3 in setup - logits = self.process_one_image(inp) + logits = self.process_one_image(inp, dim=3) samples_vol.append(np.squeeze(logits)) # Multi-prediction aggregation (3D) @@ -509,7 +513,6 @@ def _process_vol2vol(self, img, pred_cfg, original_postprocess, pert_opt): seg_full = seg # Shape (Z, Y, X) or (C, Z, Y, X) # Ensure seg_full is (Z, Y, X) if single class, usually squeezed. - # Use simple squeeze if C=1. if seg_full.ndim == 4 and seg_full.shape[0] == 1: seg_full = seg_full[0] diff --git a/mmv_im2im/models/nets/ProbUnet.py b/mmv_im2im/models/nets/ProbUnet.py index 5a8291b..c4d3277 100644 --- a/mmv_im2im/models/nets/ProbUnet.py +++ b/mmv_im2im/models/nets/ProbUnet.py @@ -16,7 +16,7 @@ def __init__( dropout=0.0, ): super().__init__() - # padding=None in MONAI Convolution defaults to "same" padding + layers = [ Convolution( spatial_dims, @@ -227,14 +227,17 @@ def forward(self, x, seg=None, train_posterior=True): mu_post, logvar_post = None, None if train_posterior and seg is not None: + if seg.shape[1] != self.out_channels: seg_temp = seg if seg_temp.shape[1] == 1: seg_temp = seg_temp.squeeze(1) + + seg_one_hot = F.one_hot(seg_temp.long(), num_classes=self.out_channels) + + dims = list(range(seg_one_hot.ndim)) seg_one_hot = ( - F.one_hot(seg_temp.long(), num_classes=self.out_channels) - .permute(0, 3, 1, 2) - .float() + seg_one_hot.permute(0, dims[-1], *dims[1:-1]).contiguous().float() ) else: seg_one_hot = seg.float() diff --git a/mmv_im2im/models/pl_ProbUnet.py b/mmv_im2im/models/pl_ProbUnet.py index e2d822b..14a6fe2 100644 --- a/mmv_im2im/models/pl_ProbUnet.py +++ b/mmv_im2im/models/pl_ProbUnet.py @@ -40,20 +40,18 @@ def forward(self, x, seg=None, train_posterior=False): def run_step(self, batch): x, y = batch["IM"], batch["GT"] - # Ensure x is (B, C, H, W) - if x.ndim == 5 and x.shape[-1] == 1: + if x.ndim > 4 and x.shape[-1] == 1: x = x.squeeze(-1) - # Ensure y is (B, 1, H, W) for passing to model and loss - if y.ndim == 5 and y.shape[-1] == 1: + + if y.ndim > 4 and y.shape[-1] == 1: y = y.squeeze(-1) - if y.ndim == 3: - y = y.unsqueeze(1) # Add channel dim if missing (B, H, W) -> (B, 1, H, W) + + if y.ndim == x.ndim - 1: + y = y.unsqueeze(1) # Forward pass (Train Posterior) output = self(x, seg=y, train_posterior=True) - # Calculate Loss - # Ensure 'epoch' is a number, not a tensor, to avoid issues in elbo_loss warmup current_ep = int(self.current_epoch) loss = self.criterion( diff --git a/mmv_im2im/models/pl_ProbUnet_old.py b/mmv_im2im/models/pl_ProbUnet_old.py new file mode 100644 index 0000000..f25aa89 --- /dev/null +++ b/mmv_im2im/models/pl_ProbUnet_old.py @@ -0,0 +1,163 @@ +import numpy as np +from typing import Dict +from pathlib import Path +from random import randint +import lightning as pl +import torch +from bioio.writers import OmeTiffWriter + +from mmv_im2im.utils.misc import ( + parse_config, + parse_config_func, + parse_config_func_without_params, +) +from mmv_im2im.utils.model_utils import init_weights + + +class Model(pl.LightningModule): + def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = False): + super().__init__() + self.net = parse_config(model_info_xx.net) + init_weights(self.net, init_type="kaiming") + + self.model_info = model_info_xx + self.verbose = verbose + self.weighted_loss = False + if train: + self.criterion = parse_config(model_info_xx.criterion) + self.optimizer_func = parse_config_func(model_info_xx.optimizer) + + # Store these as attributes for access in run_step/training_step/validation_step + self.last_prior_mu = None + self.last_prior_logvar = None + self.last_post_mu = None + self.last_post_logvar = None + + def forward(self, x, y=None): + # The underlying ProbabilisticUNet returns multiple values. + # Capture them here and store them as instance attributes. + logits, prior_mu, prior_logvar, post_mu, post_logvar = self.net(x, y) + + # Store for use in run_step (which calculates loss) + self.last_prior_mu = prior_mu + self.last_prior_logvar = prior_logvar + self.last_post_mu = post_mu + self.last_post_logvar = post_logvar + + # For the 'Model' (LightningModule) forward, only return the logits + # This makes the API consistent with other models in your framework. + return logits + + def configure_optimizers(self): + optimizer = self.optimizer_func(self.parameters()) + if self.model_info.scheduler is None: + return optimizer + else: + scheduler_func = parse_config_func_without_params(self.model_info.scheduler) + lr_scheduler = scheduler_func( + optimizer, **self.model_info.scheduler["params"] + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": lr_scheduler, + "monitor": "val_loss", + "interval": "epoch", + "frequency": 1, + "strict": True, + }, + } + + def run_step(self, batch, validation_stage): + x = batch["IM"] + y = batch["GT"] + + if x.size(-1) == 1: + x = torch.squeeze(x, dim=-1) + y = torch.squeeze(y, dim=-1) + + # Call forward pass of the LightningModule. + # This will internally call self.net(x,y) and store the extra outputs. + logits = self(x, y) # This is now just 'logits' + + # Calculate loss using the stored attributes + # Ensure post_mu and post_logvar are not None if y was provided + # The ELBOLoss expects these to be tensors, not None. + if self.last_post_mu is None or self.last_post_logvar is None: + raise ValueError( + "Posterior distributions (mu, logvar) were not computed. Ensure 'y' is provided during training." + ) + + loss = self.criterion( + logits, + y, + self.last_prior_mu, + self.last_prior_logvar, + self.last_post_mu, + self.last_post_logvar, + ) + + return loss, logits + + def training_step(self, batch, batch_idx): + loss, y_hat = self.run_step(batch, validation_stage=False) + self.log( + "train_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + if self.verbose and batch_idx == 0: + self.log_images(batch, y_hat, "train") + + return loss + + def validation_step(self, batch, batch_idx): + loss, y_hat = self.run_step(batch, validation_stage=True) + self.log( + "val_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + + if self.verbose and batch_idx == 0: + self.log_images(batch, y_hat, "val") + + return loss + + def log_images(self, batch, y_hat, stage): + src = batch["IM"] + tar = batch["GT"] + + save_path = Path(self.trainer.log_dir) + save_path.mkdir(parents=True, exist_ok=True) + + act = torch.nn.Softmax(dim=1) + yhat_act = act(y_hat) + + src_out = np.squeeze(src[0].detach().cpu().numpy()).astype(float) + tar_out = np.squeeze(tar[0].detach().cpu().numpy()).astype(float) + prd_out = np.squeeze(yhat_act[0].detach().cpu().numpy()).astype(float) + + def get_dim_order(arr): + dims = len(arr.shape) + return {2: "YX", 3: "ZYX", 4: "CZYX"}.get(dims, "YX") + + rand_tag = randint(1, 1000) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_src_{rand_tag}.tiff" + OmeTiffWriter.save(src_out, out_fn, dim_order=get_dim_order(src_out)) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_tar_{rand_tag}.tiff" + OmeTiffWriter.save(tar_out, out_fn, dim_order=get_dim_order(tar_out)) + + out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_prd_{rand_tag}.tiff" + OmeTiffWriter.save(prd_out, out_fn, dim_order=get_dim_order(prd_out)) diff --git a/mmv_im2im/postprocessing/basic_collection.py b/mmv_im2im/postprocessing/basic_collection.py index 2dbbc38..290b0c2 100644 --- a/mmv_im2im/postprocessing/basic_collection.py +++ b/mmv_im2im/postprocessing/basic_collection.py @@ -78,6 +78,8 @@ def generate_classmap(im: Union[np.ndarray, torch.Tensor]) -> np.ndarray: # convert tensor to numpy if torch.is_tensor(im): im = im.cpu().numpy() + if len(im.shape) == 4 and im.shape[0] != 1: + im = im[None, ...] assert len(im.shape) == 4 or len(im.shape) == 5, "extract seg only accepts 4D/5D" assert im.shape[0] == 1, "extract seg requires first dim to be 1" diff --git a/mmv_im2im/tests/test_dummy.py b/mmv_im2im/tests/test_dummy.py index 58dd2a0..73b446f 100644 --- a/mmv_im2im/tests/test_dummy.py +++ b/mmv_im2im/tests/test_dummy.py @@ -49,9 +49,9 @@ def test_connectivity_loss(dims): pred_softmax = F.softmax(logits, dim=1) target_one_hot = F.one_hot(target_indices, num_classes=n_classes).float() if dims == 2: - target_one_hot = target_one_hot.permute(0, 3, 1, 2) + target_one_hot = target_one_hot.permute(0, 3, 1, 2) else: - target_one_hot = target_one_hot.permute(0, 4, 1, 2, 3) + target_one_hot = target_one_hot.permute(0, 4, 1, 2, 3) loss = loss_fn(pred_softmax, target_one_hot) assert not torch.isnan(loss)