Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion mmv_im2im/data_modules/data_loader_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import monai
from monai.data import list_data_collate


class Im2ImDataModule(pl.LightningDataModule):
def __init__(self, data_cfg):
super().__init__()
Expand Down
29 changes: 22 additions & 7 deletions mmv_im2im/proj_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mmv_im2im.utils.for_transform import parse_monai_ops_vanilla
from skimage.io import imsave as save_rgb


# from mmv_im2im.utils.piecewise_inference import predict_piecewise
from monai.inferers import sliding_window_inference

Expand Down Expand Up @@ -48,6 +49,9 @@ def __init__(self, cfg):
self.cpu = False
self.spatial_dims = -1




def setup_model(self):
model_category = self.model_cfg.framework
model_module = import_module(f"mmv_im2im.models.pl_{model_category}")
Expand All @@ -60,22 +64,33 @@ def setup_model(self):
and self.model_cfg.model_extra["cpu_only"]
):
self.cpu = True
pre_train = torch.load(
self.model_cfg.checkpoint, map_location=torch.device("cpu")
checkpoint = torch.load(
self.model_cfg.checkpoint, map_location=torch.device("cpu"), weights_only=False
)
else:
pre_train = torch.load(self.model_cfg.checkpoint)
checkpoint = torch.load(self.model_cfg.checkpoint, weights_only=False)


# TODO: hacky solution to remove a wrongly registered key
pre_train["state_dict"].pop("criterion.xym", None)
pre_train["state_dict"].pop("criterion.xyzm", None)
self.model.load_state_dict(pre_train["state_dict"])
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:

pre_train = checkpoint
pre_train["state_dict"].pop("criterion.xym", None)
pre_train["state_dict"].pop("criterion.xyzm", None)
self.model.load_state_dict(pre_train["state_dict"], strict=False)
else:

state_dict = checkpoint
state_dict.pop("criterion.xym", None)
state_dict.pop("criterion.xyzm", None)
self.model.load_state_dict(state_dict, strict=False)

if not self.cpu:
self.model.cuda()

self.model.eval()



def setup_data_processing(self):
# determine spatial dimension from reader parameters
if "Z" in self.data_cfg.inference_input.reader_params["dimension_order_out"]:
Expand Down
8 changes: 5 additions & 3 deletions mmv_im2im/proj_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def run_training(self):
self.model = self.model.load_from_checkpoint(
self.model_cfg.model_extra["resume"]
)

elif "pre-train" in self.model_cfg.model_extra:
pre_train = torch.load(self.model_cfg.model_extra["pre-train"])
pre_train = torch.load(self.model_cfg.model_extra["pre-train"], weights_only=False)

if "extend" in self.model_cfg.model_extra:
if (
Expand All @@ -71,10 +72,11 @@ def run_training(self):
)

model_state.update(filtered_dict)
self.model.load_state_dict(model_state)
self.model.load_state_dict(model_state, strict=False)
else:
pre_train["state_dict"].pop("criterion.xym", None)
self.model.load_state_dict(pre_train["state_dict"])
self.model.load_state_dict(pre_train["state_dict"], strict=False)


if self.train_cfg.callbacks is None:
trainer = pl.Trainer(**self.train_cfg.params)
Expand Down
80 changes: 80 additions & 0 deletions mmv_im2im/utils/connectivity_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConnectivityCoherenceLoss(nn.Module):
    """
    Calculates a connectivity coherence loss to penalize fragmentation or
    undesired isolated components within predicted regions.

    This loss encourages predicted regions to be spatially coherent and continuous
    by comparing local neighborhoods in predictions with ground truth.
    It penalizes:
    1. Discontinuities within what should be a single, connected region (e.g., breaking a vein).
    2. Isolated 'islands' of one class within another class.

    Args:
        kernel_size (int): Size of the convolutional kernel for neighborhood analysis (e.g., 3 for 3x3).
            Must be odd so the neighborhood is centered on each pixel.
        ignore_background (bool): If True, the loss focuses primarily on non-background classes.
            Useful if background fragmentation is less critical.
        num_classes (int): Total number of classes, including background.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """
    def __init__(self, kernel_size: int = 3, ignore_background: bool = True, num_classes: int = 2):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("kernel_size must be an odd number.")
        self.kernel_size = kernel_size
        self.ignore_background = ignore_background
        self.num_classes = num_classes
        # Averaging kernel over the (kernel_size**2 - 1) *neighbors* of each
        # pixel.  The center tap is zeroed so the neighbor average genuinely
        # excludes the pixel itself; previously the center was included while
        # still dividing by (k**2 - 1), which skewed the average.
        center = kernel_size // 2
        kernel = torch.ones(1, 1, kernel_size, kernel_size)
        kernel[0, 0, center, center] = 0.0
        kernel /= kernel_size ** 2 - 1
        # Registered as a buffer so model.to(device)/.cuda() moves it along
        # with the rest of the module's state.
        self.register_buffer("average_kernel", kernel)
        self.center_offset = (center, center)

    def forward(self, y_pred_softmax, y_true_one_hot):
        """
        Args:
            y_pred_softmax (torch.Tensor): Softmax probabilities from the model (B, C, H, W).
            y_true_one_hot (torch.Tensor): Ground truth as one-hot encoded tensor (B, C, H, W).
                Cast to float internally.

        Returns:
            torch.Tensor: Scalar connectivity coherence loss, summed over the
            considered classes. Zero if no class is considered (e.g.
            ``num_classes == 1`` with ``ignore_background=True``).
        """
        current_average_kernel = self.average_kernel.to(y_pred_softmax.device)
        # (left, right, top, bottom) padding so the conv output keeps H x W.
        pad = (self.center_offset[1], self.center_offset[1],
               self.center_offset[0], self.center_offset[0])

        y_true_one_hot = y_true_one_hot.float()

        loss_per_class = []

        for c in range(self.num_classes):
            if self.ignore_background and c == 0:
                continue

            true_mask_c = y_true_one_hot[:, c:c + 1, :, :]
            pred_prob_c = y_pred_softmax[:, c:c + 1, :, :]

            # Mean over each pixel's neighbors (center excluded).  Replicate
            # padding keeps border statistics reasonable.  Both averages now
            # use the *same* normalization, so a perfect prediction yields
            # zero loss — the original scaled the true average by an extra
            # k**2 / (k**2 - 1) factor, making the minimum unattainable.
            true_neighbor_avg = F.conv2d(
                F.pad(true_mask_c, pad, mode='replicate'),
                current_average_kernel, padding=0, groups=1,
            )
            pred_neighbor_avg = F.conv2d(
                F.pad(pred_prob_c, pad, mode='replicate'),
                current_average_kernel, padding=0, groups=1,
            )

            # loss_b: predicted neighborhood should match the true center
            # (penalizes holes/breaks inside a true region).
            # loss_c: predicted center should match the true neighborhood
            # (penalizes isolated predicted islands).
            loss_b = F.mse_loss(pred_neighbor_avg, true_mask_c, reduction='none')
            loss_c = F.mse_loss(pred_prob_c, true_neighbor_avg, reduction='none')

            class_coherence_loss = torch.mean(loss_b + loss_c)
            loss_per_class.append(class_coherence_loss)

        if not loss_per_class:
            return torch.tensor(0.0, device=y_pred_softmax.device)

        return torch.sum(torch.stack(loss_per_class))
Loading
Loading