Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,18 @@
# You can specify multiple suffix as a list of string:
#
source_suffix = {
".rst": "restructuredtext",
".txt": "markdown",
".md": "markdown",
".rst": "restructuredtext",
".txt": "markdown",
".md": "markdown",
}

# The master toctree document.
master_doc = "index"

# General information about the project.
project = u"MMV Im2Im Transformation"
copyright = u'2021, Jianxu Chen'
author = u"Jianxu Chen"
project = "MMV Im2Im Transformation"
copyright = "2021, Jianxu Chen"
author = "Jianxu Chen"

# The version info for the project you"re documenting, acts as replacement
# for |version| and |release|, also used in various other places throughout
Expand Down Expand Up @@ -135,15 +135,12 @@
# The paper size ("letterpaper" or "a4paper").
#
# "papersize": "letterpaper",

# The font size ("10pt", "11pt" or "12pt").
#
# "pointsize": "10pt",

# Additional stuff for the LaTeX preamble.
#
# "preamble": "",

# Latex figure (float) alignment
#
# "figure_align": "htbp",
Expand All @@ -153,9 +150,13 @@
# (source start file, target name, title, author, documentclass
# [howto, manual, or own class]).
latex_documents = [
(master_doc, "mmv_im2im.tex",
u"MMV Im2Im Transformation Documentation",
u"Jianxu Chen", "manual"),
(
master_doc,
"mmv_im2im.tex",
"MMV Im2Im Transformation Documentation",
"Jianxu Chen",
"manual",
),
]


Expand All @@ -164,9 +165,7 @@
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, "mmv_im2im",
u"MMV Im2Im Transformation Documentation",
[author], 1)
(master_doc, "mmv_im2im", "MMV Im2Im Transformation Documentation", [author], 1)
]


Expand All @@ -176,10 +175,13 @@
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, "mmv_im2im",
u"MMV Im2Im Transformation Documentation",
author,
"mmv_im2im",
"One line description of project.",
"Miscellaneous"),
(
master_doc,
"mmv_im2im",
"MMV Im2Im Transformation Documentation",
author,
"mmv_im2im",
"One line description of project.",
"Miscellaneous",
),
]
1 change: 0 additions & 1 deletion mmv_im2im/models/pl_FCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,3 @@ def validation_step(self, batch, batch_idx):
)

return loss

27 changes: 19 additions & 8 deletions mmv_im2im/proj_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from mmv_im2im.utils.for_transform import parse_monai_ops_vanilla
from skimage.io import imsave as save_rgb


# from mmv_im2im.utils.piecewise_inference import predict_piecewise
from monai.inferers import sliding_window_inference

Expand Down Expand Up @@ -60,16 +61,26 @@ def setup_model(self):
and self.model_cfg.model_extra["cpu_only"]
):
self.cpu = True
pre_train = torch.load(
self.model_cfg.checkpoint, map_location=torch.device("cpu")
checkpoint = torch.load(
self.model_cfg.checkpoint,
map_location=torch.device("cpu"),
weights_only=False,
)
else:
pre_train = torch.load(self.model_cfg.checkpoint)
checkpoint = torch.load(self.model_cfg.checkpoint, weights_only=False)

if isinstance(checkpoint, dict) and "state_dict" in checkpoint:

# TODO: hacky solution to remove a wrongly registered key
pre_train["state_dict"].pop("criterion.xym", None)
pre_train["state_dict"].pop("criterion.xyzm", None)
self.model.load_state_dict(pre_train["state_dict"])
pre_train = checkpoint
pre_train["state_dict"].pop("criterion.xym", None)
pre_train["state_dict"].pop("criterion.xyzm", None)
self.model.load_state_dict(pre_train["state_dict"], strict=False)
else:

state_dict = checkpoint
state_dict.pop("criterion.xym", None)
state_dict.pop("criterion.xyzm", None)
self.model.load_state_dict(state_dict, strict=False)

if not self.cpu:
self.model.cuda()
Expand Down Expand Up @@ -322,7 +333,7 @@ def run_inference(self):
img = BioImage(ds).get_image_data(
**self.data_cfg.inference_input.reader_params
)

# prepare output filename
if "." in suffix:
if (
Expand Down
10 changes: 6 additions & 4 deletions mmv_im2im/proj_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ def run_training(self):
self.model = self.model.load_from_checkpoint(
self.model_cfg.model_extra["resume"]
)

elif "pre-train" in self.model_cfg.model_extra:
pre_train = torch.load(self.model_cfg.model_extra["pre-train"])
pre_train = torch.load(
self.model_cfg.model_extra["pre-train"], weights_only=False
)

if "extend" in self.model_cfg.model_extra:
if (
Expand All @@ -71,10 +74,10 @@ def run_training(self):
)

model_state.update(filtered_dict)
self.model.load_state_dict(model_state)
self.model.load_state_dict(model_state, strict=False)
else:
pre_train["state_dict"].pop("criterion.xym", None)
self.model.load_state_dict(pre_train["state_dict"])
self.model.load_state_dict(pre_train["state_dict"], strict=False)

if self.train_cfg.callbacks is None:
trainer = pl.Trainer(**self.train_cfg.params)
Expand All @@ -97,4 +100,3 @@ def run_training(self):

print("start training ... ")
trainer.fit(model=self.model, datamodule=self.data)

102 changes: 102 additions & 0 deletions mmv_im2im/utils/connectivity_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConnectivityCoherenceLoss(nn.Module):
    """
    Connectivity coherence loss that penalizes fragmentation or undesired
    isolated components within predicted regions.

    Each pixel is compared against the average of its *neighbors* (center
    pixel excluded) inside a ``kernel_size`` x ``kernel_size`` window, in
    both directions:

    1. The neighborhood average of the prediction should match the ground
       truth at the center (penalizes holes / breaks inside what should be
       one connected region, e.g. breaking a vein).
    2. The prediction at the center should match the neighborhood average
       of the ground truth (penalizes isolated 'islands' of one class
       within another class).

    A perfectly coherent, correct prediction over a uniform region yields
    zero loss.

    Args:
        kernel_size (int): Odd window size (>= 3) for the neighborhood
            analysis, e.g. 3 for a 3x3 window.
        ignore_background (bool): If True, class 0 is skipped so the loss
            focuses on non-background classes. Useful if background
            fragmentation is less critical.
        num_classes (int): Total number of classes, including background.
    """

    def __init__(
        self, kernel_size: int = 3, ignore_background: bool = True, num_classes: int = 2
    ):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("kernel_size must be an odd number.")
        if kernel_size < 3:
            # a 1x1 window has no neighbors -> division by zero below
            raise ValueError("kernel_size must be >= 3.")
        self.kernel_size = kernel_size
        self.ignore_background = ignore_background
        self.num_classes = num_classes

        # Averaging kernel over the neighborhood EXCLUDING the center pixel.
        # Zeroing the center is what makes this a genuine neighbor average
        # (kernel sums to exactly 1), so a uniform region averages to itself.
        center = kernel_size // 2
        kernel = torch.ones(1, 1, kernel_size, kernel_size)
        kernel[0, 0, center, center] = 0.0
        kernel /= kernel_size**2 - 1
        # Non-persistent buffer: follows .to(device)/.cuda() with the module
        # but is excluded from the state_dict, so existing checkpoints load.
        self.register_buffer("average_kernel", kernel, persistent=False)
        self.center_offset = (center, center)

    def _neighbor_average(self, x: torch.Tensor) -> torch.Tensor:
        """Average of each pixel's neighbors (center excluded), same H x W.

        Replicate padding lets border pixels see their clamped neighborhood
        instead of zeros, avoiding an artificial loss at image edges.
        """
        pad = (
            self.center_offset[1],
            self.center_offset[1],
            self.center_offset[0],
            self.center_offset[0],
        )
        padded = F.pad(x, pad, mode="replicate")
        # .to(x.device) is a no-op when already co-located; keeps the module
        # usable even if the caller never moved it to the input's device.
        return F.conv2d(padded, self.average_kernel.to(x.device), padding=0, groups=1)

    def forward(self, y_pred_softmax, y_true_one_hot):
        """
        Args:
            y_pred_softmax (torch.Tensor): Softmax probabilities from the
                model, shape (B, C, H, W).
            y_true_one_hot (torch.Tensor): Ground truth as one-hot encoded
                tensor, shape (B, C, H, W). Cast to float internally.

        Returns:
            torch.Tensor: Scalar loss summed over the evaluated classes;
            0.0 if every class is skipped (e.g. num_classes == 1 with
            ignore_background=True).
        """
        y_true_one_hot = y_true_one_hot.float()

        loss_per_class = []
        for c in range(self.num_classes):
            if self.ignore_background and c == 0:
                continue

            true_mask_c = y_true_one_hot[:, c : c + 1, :, :]
            pred_prob_c = y_pred_softmax[:, c : c + 1, :, :]

            pred_neighbor_avg = self._neighbor_average(pred_prob_c)
            true_neighbor_avg = self._neighbor_average(true_mask_c)

            # (1) prediction's neighborhood should agree with the true center
            loss_pred_vs_true = F.mse_loss(
                pred_neighbor_avg, true_mask_c, reduction="none"
            )
            # (2) predicted center should agree with the true neighborhood
            loss_true_vs_pred = F.mse_loss(
                pred_prob_c, true_neighbor_avg, reduction="none"
            )

            loss_per_class.append(torch.mean(loss_pred_vs_true + loss_true_vs_pred))

        if not loss_per_class:
            return torch.tensor(0.0, device=y_pred_softmax.device)

        return torch.sum(torch.stack(loss_per_class))
Loading
Loading