Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions mmv_im2im/map_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def __init__(self, cfg):
self.model_cfg = cfg.model
self.data_cfg = cfg.data

# define variables
self.model = None
self.data = None
self.pre_process = None
Expand Down Expand Up @@ -111,7 +110,10 @@ def setup_data_processing(self):
self.pre_process = parse_monai_ops_vanilla(self.data_cfg.preprocess)

def process_one_image(
self, img: Union[DaskArray, NumpyArray], out_fn: Union[str, Path] = None
self,
img: Union[DaskArray, NumpyArray],
dim: int = 2,
out_fn: Union[str, Path] = None,
):

if isinstance(img, DaskArray):
Expand All @@ -131,7 +133,8 @@ def process_one_image(
# run pre-processing on tensor if needed
if self.pre_process is not None:
x = self.pre_process(x)
x = x[0]
if dim == 2:
x = x[0]

# choose different inference function for different types of models
with torch.no_grad():
Expand Down Expand Up @@ -412,7 +415,7 @@ def _process_vol2slice(self, img, pred_cfg, original_postprocess, pert_opt):
else:
inp = im_input

logits = self.process_one_image(inp)
logits = self.process_one_image(inp, dim=2)
samplesz.append(np.squeeze(logits))

# Multi-prediction aggregation
Expand Down Expand Up @@ -464,6 +467,7 @@ def _process_vol2slice(self, img, pred_cfg, original_postprocess, pert_opt):
def _process_vol2vol(self, img, pred_cfg, original_postprocess, pert_opt):
"""New direct 3D volume processing logic."""
# Input img is (C, Z, Y, X) or (Z, Y, X)
# Handle dummy channel if missing
if len(img.shape) == 3:
img = img[None, ...] # (C, Z, Y, X)

Expand All @@ -478,7 +482,7 @@ def _process_vol2vol(self, img, pred_cfg, original_postprocess, pert_opt):

# Process one image will handle 4D input by adding batch dim -> (1, C, Z, Y, X)
# Ensure spatial_dims is set to 3 in setup
logits = self.process_one_image(inp)
logits = self.process_one_image(inp, dim=3)
samples_vol.append(np.squeeze(logits))

# Multi-prediction aggregation (3D)
Expand Down Expand Up @@ -509,7 +513,6 @@ def _process_vol2vol(self, img, pred_cfg, original_postprocess, pert_opt):

seg_full = seg # Shape (Z, Y, X) or (C, Z, Y, X)
# Ensure seg_full is (Z, Y, X) if single class, usually squeezed.
# Use simple squeeze if C=1.
if seg_full.ndim == 4 and seg_full.shape[0] == 1:
seg_full = seg_full[0]

Expand Down
11 changes: 7 additions & 4 deletions mmv_im2im/models/nets/ProbUnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
dropout=0.0,
):
super().__init__()
# padding=None in MONAI Convolution defaults to "same" padding

layers = [
Convolution(
spatial_dims,
Expand Down Expand Up @@ -227,14 +227,17 @@ def forward(self, x, seg=None, train_posterior=True):
mu_post, logvar_post = None, None

if train_posterior and seg is not None:

if seg.shape[1] != self.out_channels:
seg_temp = seg
if seg_temp.shape[1] == 1:
seg_temp = seg_temp.squeeze(1)

seg_one_hot = F.one_hot(seg_temp.long(), num_classes=self.out_channels)

dims = list(range(seg_one_hot.ndim))
seg_one_hot = (
F.one_hot(seg_temp.long(), num_classes=self.out_channels)
.permute(0, 3, 1, 2)
.float()
seg_one_hot.permute(0, dims[-1], *dims[1:-1]).contiguous().float()
)
else:
seg_one_hot = seg.float()
Expand Down
14 changes: 6 additions & 8 deletions mmv_im2im/models/pl_ProbUnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,18 @@ def forward(self, x, seg=None, train_posterior=False):
def run_step(self, batch):
x, y = batch["IM"], batch["GT"]

# Ensure x is (B, C, H, W)
if x.ndim == 5 and x.shape[-1] == 1:
if x.ndim > 4 and x.shape[-1] == 1:
x = x.squeeze(-1)
# Ensure y is (B, 1, H, W) for passing to model and loss
if y.ndim == 5 and y.shape[-1] == 1:

if y.ndim > 4 and y.shape[-1] == 1:
y = y.squeeze(-1)
if y.ndim == 3:
y = y.unsqueeze(1) # Add channel dim if missing (B, H, W) -> (B, 1, H, W)

if y.ndim == x.ndim - 1:
y = y.unsqueeze(1)

# Forward pass (Train Posterior)
output = self(x, seg=y, train_posterior=True)

# Calculate Loss
# Ensure 'epoch' is a number, not a tensor, to avoid issues in elbo_loss warmup
current_ep = int(self.current_epoch)

loss = self.criterion(
Expand Down
163 changes: 163 additions & 0 deletions mmv_im2im/models/pl_ProbUnet_old.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import numpy as np
from typing import Dict
from pathlib import Path
from random import randint
import lightning as pl
import torch
from bioio.writers import OmeTiffWriter

from mmv_im2im.utils.misc import (
parse_config,
parse_config_func,
parse_config_func_without_params,
)
from mmv_im2im.utils.model_utils import init_weights


class Model(pl.LightningModule):
    """Lightning wrapper around a probabilistic UNet.

    The wrapped network (built from ``model_info_xx.net``) returns five
    values from its forward pass: logits plus the prior and posterior
    Gaussian parameters (mu, logvar).  This module caches the four
    distribution tensors on the instance so ``run_step`` can hand them to
    the criterion, while ``forward`` keeps the framework-standard
    "logits only" return value.

    Parameters
    ----------
    model_info_xx : Dict
        Config object providing ``net``, ``criterion``, ``optimizer`` and
        ``scheduler`` entries, consumed via the project's ``parse_config``
        helpers.
    train : bool
        When True, also build the criterion and the optimizer factory.
    verbose : bool
        When True, dump sample images on the first batch of each epoch.
    """

    def __init__(self, model_info_xx: Dict, train: bool = True, verbose: bool = False):
        super().__init__()
        self.net = parse_config(model_info_xx.net)
        init_weights(self.net, init_type="kaiming")

        self.model_info = model_info_xx
        self.verbose = verbose
        self.weighted_loss = False
        if train:
            self.criterion = parse_config(model_info_xx.criterion)
            self.optimizer_func = parse_config_func(model_info_xx.optimizer)

        # Caches for the distribution parameters produced by the most
        # recent forward pass; consumed by run_step when computing loss.
        self.last_prior_mu = None
        self.last_prior_logvar = None
        self.last_post_mu = None
        self.last_post_logvar = None

    def forward(self, x, y=None):
        """Run the network and return only the logits.

        The prior/posterior (mu, logvar) pairs emitted by the underlying
        net are stashed on ``self`` for later use by the loss computation,
        which keeps this forward API consistent with the other models in
        the framework.
        """
        logits, prior_mu, prior_logvar, post_mu, post_logvar = self.net(x, y)

        self.last_prior_mu = prior_mu
        self.last_prior_logvar = prior_logvar
        self.last_post_mu = post_mu
        self.last_post_logvar = post_logvar

        return logits

    def configure_optimizers(self):
        """Build the optimizer and, when configured, an LR-scheduler dict."""
        optimizer = self.optimizer_func(self.parameters())
        if self.model_info.scheduler is None:
            return optimizer

        sched_factory = parse_config_func_without_params(self.model_info.scheduler)
        sched = sched_factory(optimizer, **self.model_info.scheduler["params"])
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": sched,
                "monitor": "val_loss",
                "interval": "epoch",
                "frequency": 1,
                "strict": True,
            },
        }

    def run_step(self, batch, validation_stage):
        """Forward one batch and compute the loss.

        ``validation_stage`` is accepted for API symmetry but unused here.
        Returns ``(loss, logits)``.  Raises ValueError when the posterior
        parameters were not produced by the forward pass.
        """
        im, gt = batch["IM"], batch["GT"]

        # Drop a trailing singleton dim when present (e.g. a dummy axis
        # added by the data pipeline — presumably Z; confirm with loader).
        if im.size(-1) == 1:
            im = torch.squeeze(im, dim=-1)
            gt = torch.squeeze(gt, dim=-1)

        # Invokes self.net(im, gt) internally and caches the
        # prior/posterior parameters on the instance (see forward).
        logits = self(im, gt)

        # The criterion expects tensors, not None, for the posterior terms.
        if self.last_post_mu is None or self.last_post_logvar is None:
            raise ValueError(
                "Posterior distributions (mu, logvar) were not computed. Ensure 'y' is provided during training."
            )

        loss = self.criterion(
            logits,
            gt,
            self.last_prior_mu,
            self.last_prior_logvar,
            self.last_post_mu,
            self.last_post_logvar,
        )

        return loss, logits

    def _step_and_log(self, batch, batch_idx, loss_key, stage, validation_stage):
        # Shared body of training_step/validation_step: run the batch,
        # log the loss, and optionally dump images on the first batch.
        loss, pred = self.run_step(batch, validation_stage=validation_stage)
        self.log(
            loss_key,
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )

        if self.verbose and batch_idx == 0:
            self.log_images(batch, pred, stage)

        return loss

    def training_step(self, batch, batch_idx):
        return self._step_and_log(batch, batch_idx, "train_loss", "train", False)

    def validation_step(self, batch, batch_idx):
        return self._step_and_log(batch, batch_idx, "val_loss", "val", True)

    def log_images(self, batch, y_hat, stage):
        """Save source / target / prediction of the first sample as OME-TIFFs.

        Files land in the trainer's log dir, tagged with the epoch, the
        stage name, and a random suffix to reduce filename collisions.
        """
        save_path = Path(self.trainer.log_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        # Channel-wise class probabilities for the prediction dump.
        probs = torch.nn.Softmax(dim=1)(y_hat)

        arrays = {
            "src": np.squeeze(batch["IM"][0].detach().cpu().numpy()).astype(float),
            "tar": np.squeeze(batch["GT"][0].detach().cpu().numpy()).astype(float),
            "prd": np.squeeze(probs[0].detach().cpu().numpy()).astype(float),
        }

        rand_tag = randint(1, 1000)
        for tag, arr in arrays.items():
            # Infer a plausible axis order from the rank; falls back to "YX".
            order = {2: "YX", 3: "ZYX", 4: "CZYX"}.get(arr.ndim, "YX")
            out_fn = save_path / f"epoch_{self.current_epoch}_{stage}_{tag}_{rand_tag}.tiff"
            OmeTiffWriter.save(arr, out_fn, dim_order=order)
2 changes: 2 additions & 0 deletions mmv_im2im/postprocessing/basic_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def generate_classmap(im: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
# convert tensor to numpy
if torch.is_tensor(im):
im = im.cpu().numpy()
if len(im.shape) == 4 and im.shape[0] != 1:
im = im[None, ...]
assert len(im.shape) == 4 or len(im.shape) == 5, "extract seg only accepts 4D/5D"
assert im.shape[0] == 1, "extract seg requires first dim to be 1"

Expand Down
4 changes: 2 additions & 2 deletions mmv_im2im/tests/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def test_connectivity_loss(dims):
pred_softmax = F.softmax(logits, dim=1)
target_one_hot = F.one_hot(target_indices, num_classes=n_classes).float()
if dims == 2:
target_one_hot = target_one_hot.permute(0, 3, 1, 2)
target_one_hot = target_one_hot.permute(0, 3, 1, 2)
else:
target_one_hot = target_one_hot.permute(0, 4, 1, 2, 3)
target_one_hot = target_one_hot.permute(0, 4, 1, 2, 3)

loss = loss_fn(pred_softmax, target_one_hot)
assert not torch.isnan(loss)
Expand Down