Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions connectomics/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Connectomics-specific losses (for direct use if needed)
from .losses import (
GANLoss,
SoftClDiceLoss,
WeightedMAELoss,
WeightedMSELoss,
)
Expand All @@ -42,6 +43,7 @@
"get_loss_metadata",
"get_loss_metadata_for_module",
# Custom losses
"SoftClDiceLoss",
"WeightedMSELoss",
"WeightedMAELoss",
"GANLoss",
Expand Down
2 changes: 2 additions & 0 deletions connectomics/models/losses/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CrossEntropyLossWrapper,
GANLoss,
PerChannelBCEWithLogitsLoss,
SoftClDiceLoss,
SmoothL1Loss,
WeightedBCEWithLogitsLoss,
WeightedMAELoss,
Expand Down Expand Up @@ -70,6 +71,7 @@ def _get_loss_registry() -> Dict[str, type[nn.Module]]:
# Custom connectomics losses
"WeightedBCEWithLogitsLoss": WeightedBCEWithLogitsLoss,
"PerChannelBCEWithLogitsLoss": PerChannelBCEWithLogitsLoss,
"SoftClDiceLoss": SoftClDiceLoss,
"WeightedMSELoss": WeightedMSELoss,
"WeightedMAELoss": WeightedMAELoss,
"GANLoss": GANLoss,
Expand Down
252 changes: 251 additions & 1 deletion connectomics/models/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Union
from typing import List, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -44,6 +44,46 @@ def _reduce_weighted_tensor(
return loss_tensor[valid].mean()


def _soft_erode_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable morphological erosion using min-pool via max-pool."""
if prob.ndim == 5:
p1 = -F.max_pool3d(-prob, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
p2 = -F.max_pool3d(-prob, kernel_size=(1, 3, 1), stride=1, padding=(0, 1, 0))
p3 = -F.max_pool3d(-prob, kernel_size=(1, 1, 3), stride=1, padding=(0, 0, 1))
return torch.minimum(torch.minimum(p1, p2), p3)
if prob.ndim == 4:
p1 = -F.max_pool2d(-prob, kernel_size=(3, 1), stride=1, padding=(1, 0))
p2 = -F.max_pool2d(-prob, kernel_size=(1, 3), stride=1, padding=(0, 1))
return torch.minimum(p1, p2)
raise ValueError(f"Expected 4D/5D tensor for soft erosion, got shape {tuple(prob.shape)}")


def _soft_dilate_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable morphological dilation."""
if prob.ndim == 5:
return F.max_pool3d(prob, kernel_size=3, stride=1, padding=1)
if prob.ndim == 4:
return F.max_pool2d(prob, kernel_size=3, stride=1, padding=1)
raise ValueError(f"Expected 4D/5D tensor for soft dilation, got shape {tuple(prob.shape)}")


def _soft_open_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable opening (erode followed by dilate)."""
return _soft_dilate_pool(_soft_erode_pool(prob))


def _soft_skeletonize_pool(prob: torch.Tensor, num_iters: int) -> torch.Tensor:
"""Iterative soft skeletonization from clDice-style morphology."""
opened = _soft_open_pool(prob)
skeleton = F.relu(prob - opened)
for _ in range(num_iters):
prob = _soft_erode_pool(prob)
opened = _soft_open_pool(prob)
delta = F.relu(prob - opened)
skeleton = skeleton + F.relu(delta - skeleton * delta)
return skeleton


class CrossEntropyLossWrapper(nn.Module):
"""
Wrapper for CrossEntropyLoss that handles shape conversion.
Expand Down Expand Up @@ -310,6 +350,215 @@ def forward(
return total


class SoftClDiceLoss(nn.Module):
"""
Soft clDice loss using differentiable skeletonization.

This loss expects probability maps (sigmoid/softmax outputs), not logits.
"""

def __init__(
self,
num_iters: int = 5,
mode: str = "binary",
reduction: str = "mean",
smooth: float = 1e-6,
foreground_channel: int = 1,
background_index: int = 0,
clamp_probabilities: bool = False,
use_fused_cuda: bool = False,
):
super().__init__()
if num_iters < 0:
raise ValueError(f"num_iters must be >= 0, got {num_iters}")
if mode not in {"binary", "multi"}:
raise ValueError(f"mode must be 'binary' or 'multi', got {mode!r}")
if reduction not in {"none", "mean", "sum"}:
raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction!r}")
if smooth <= 0:
raise ValueError(f"smooth must be > 0, got {smooth}")

self.num_iters = int(num_iters)
self.mode = mode
self.reduction = reduction
self.smooth = float(smooth)
self.foreground_channel = int(foreground_channel)
self.background_index = int(background_index)
self.clamp_probabilities = bool(clamp_probabilities)
self.use_fused_cuda = bool(use_fused_cuda)

def _prepare_target(self, target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
if target.ndim == pred.ndim - 1:
target = target.unsqueeze(1)
if target.ndim != pred.ndim:
raise ValueError(
f"Target ndim ({target.ndim}) does not match prediction ndim ({pred.ndim})"
)
if target.shape[0] != pred.shape[0] or target.shape[2:] != pred.shape[2:]:
raise ValueError(
"Target shape must match prediction shape except for channel dimension: "
f"target={tuple(target.shape)}, pred={tuple(pred.shape)}"
)

if target.shape[1] == pred.shape[1]:
return target.to(device=pred.device, dtype=pred.dtype)

if target.shape[1] == 1 and pred.shape[1] > 1:
class_index = target.squeeze(1).long()
min_label = int(class_index.min().item())
max_label = int(class_index.max().item())
if min_label < 0 or max_label >= pred.shape[1]:
raise ValueError(
f"Class-index targets must be in [0, {pred.shape[1] - 1}], "
f"got min={min_label}, max={max_label}"
)
one_hot = F.one_hot(class_index, num_classes=pred.shape[1]).movedim(-1, 1)
return one_hot.to(device=pred.device, dtype=pred.dtype)

raise ValueError(
"Target channel count is incompatible with prediction: "
f"target_channels={target.shape[1]}, pred_channels={pred.shape[1]}"
)

def _select_foreground_channels(
self, pred: torch.Tensor, target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, List[int]]:
channels = pred.shape[1]
if self.mode == "binary":
fg_idx = 0 if channels == 1 else self.foreground_channel
if fg_idx < 0 or fg_idx >= channels:
raise ValueError(
f"foreground_channel={self.foreground_channel} is invalid for {channels} channels"
)
return pred[:, fg_idx : fg_idx + 1], target[:, fg_idx : fg_idx + 1], [fg_idx]

if channels == 1:
return pred, target, [0]

background_index = self.background_index
if background_index < 0:
background_index += channels
if background_index < 0 or background_index >= channels:
raise ValueError(
f"background_index={self.background_index} is invalid for {channels} channels"
)

foreground_indices = [idx for idx in range(channels) if idx != background_index]
if not foreground_indices:
raise ValueError(
f"No foreground classes available: channels={channels}, "
f"background_index={self.background_index}"
)
index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long)
return (
torch.index_select(pred, dim=1, index=index_tensor),
torch.index_select(target, dim=1, index=index_tensor),
foreground_indices,
)

def _prepare_weight(
self,
weight: torch.Tensor,
pred: torch.Tensor,
foreground_indices: List[int],
num_fg_channels: int,
) -> torch.Tensor:
if weight.ndim == pred.ndim - 1:
weight = weight.unsqueeze(1)
if weight.ndim != pred.ndim:
raise ValueError(f"Weight ndim ({weight.ndim}) must match pred ndim ({pred.ndim})")
if weight.shape[0] != pred.shape[0] or weight.shape[2:] != pred.shape[2:]:
raise ValueError(
"Weight shape must match prediction shape except for channel dimension: "
f"weight={tuple(weight.shape)}, pred={tuple(pred.shape)}"
)

weight = weight.to(device=pred.device, dtype=pred.dtype)
if weight.shape[1] == num_fg_channels:
return weight
if weight.shape[1] == 1:
return weight.expand(weight.shape[0], num_fg_channels, *weight.shape[2:])
if weight.shape[1] == pred.shape[1]:
if num_fg_channels == pred.shape[1]:
return weight
index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long)
return torch.index_select(weight, dim=1, index=index_tensor)

raise ValueError(
"Weight channel count must be 1, foreground-channel count, or prediction-channel count; "
f"got {weight.shape[1]}"
)

def _soft_skeletonize(self, prob: torch.Tensor) -> torch.Tensor:
if self.use_fused_cuda and prob.is_cuda:
connectomics_ops = getattr(torch.ops, "connectomics", None)
fused_op = getattr(connectomics_ops, "soft_skeletonize", None)
if fused_op is not None:
try:
return fused_op(prob, self.num_iters)
except (RuntimeError, TypeError):
pass
return _soft_skeletonize_pool(prob, self.num_iters)

def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor = None,
) -> torch.Tensor:
if pred.ndim not in {4, 5}:
raise ValueError(f"SoftClDiceLoss expects 4D/5D tensors, got {tuple(pred.shape)}")

target = self._prepare_target(target, pred)
pred_fg, target_fg, foreground_indices = self._select_foreground_channels(pred, target)

if self.clamp_probabilities:
pred_fg = pred_fg.clamp(0.0, 1.0)
target_fg = target_fg.clamp(0.0, 1.0)

fg_weight = None
if weight is not None:
fg_weight = self._prepare_weight(weight, pred, foreground_indices, pred_fg.shape[1])
if self.clamp_probabilities:
fg_weight = fg_weight.clamp_min(0.0)

pred_skeleton = self._soft_skeletonize(pred_fg)
target_skeleton = self._soft_skeletonize(target_fg)

if fg_weight is not None:
pred_eval = pred_fg * fg_weight
target_eval = target_fg * fg_weight
pred_skeleton_eval = pred_skeleton * fg_weight
target_skeleton_eval = target_skeleton * fg_weight
else:
pred_eval = pred_fg
target_eval = target_fg
pred_skeleton_eval = pred_skeleton
target_skeleton_eval = target_skeleton

spatial_dims = tuple(range(2, pred_fg.ndim))
topology_precision = (
(pred_skeleton_eval * target_eval).sum(dim=spatial_dims) + self.smooth
) / (pred_skeleton_eval.sum(dim=spatial_dims) + self.smooth)
topology_sensitivity = (
(target_skeleton_eval * pred_eval).sum(dim=spatial_dims) + self.smooth
) / (target_skeleton_eval.sum(dim=spatial_dims) + self.smooth)

cl_dice = (
2.0
* topology_precision
* topology_sensitivity
/ (topology_precision + topology_sensitivity + self.smooth)
)
loss = 1.0 - cl_dice

if self.reduction == "none":
return loss
if self.reduction == "sum":
return loss.sum()
return loss.mean()


class WeightedMAELoss(nn.Module):
"""
Weighted mean absolute error loss.
Expand Down Expand Up @@ -478,6 +727,7 @@ def forward(
"CrossEntropyLossWrapper",
"WeightedBCEWithLogitsLoss",
"PerChannelBCEWithLogitsLoss",
"SoftClDiceLoss",
"WeightedMSELoss",
"WeightedMAELoss",
"SmoothL1Loss",
Expand Down
1 change: 1 addition & 0 deletions connectomics/models/losses/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class LossMetadata:
"PerChannelBCEWithLogitsLoss": LossMetadata(
"PerChannelBCEWithLogitsLoss", spatial_weight_arg="weight"
),
"SoftClDiceLoss": LossMetadata("SoftClDiceLoss", spatial_weight_arg="weight"),
"WeightedMSELoss": LossMetadata("WeightedMSELoss", spatial_weight_arg="weight"),
"WeightedMAELoss": LossMetadata("WeightedMAELoss", spatial_weight_arg="weight"),
# GAN is not compatible with the generic supervised orchestrator path
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/test_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest

import torch
import torch.nn.functional as F

from connectomics.models.losses import create_loss

Expand Down Expand Up @@ -37,6 +38,29 @@ def test_tversky_loss(self):
loss = loss_fn(pred, target)
self.assertTrue(loss >= 0.0)

def test_soft_cldice_binary_probabilities(self):
"""Soft clDice in binary mode should accept probability tensors."""
loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=3)
pred = torch.rand(2, 1, 4, 8, 8)
target = (torch.rand(2, 1, 4, 8, 8) > 0.5).float()

loss = loss_fn(pred, target)
self.assertTrue(torch.isfinite(loss))
self.assertTrue(loss >= 0.0)

def test_soft_cldice_multi_excludes_background(self):
"""Soft clDice in multi mode should use foreground classes only."""
loss_fn = create_loss("SoftClDiceLoss", mode="multi", num_iters=2)

logits = torch.randn(1, 3, 4, 8, 8)
pred = torch.softmax(logits, dim=1)
labels = torch.randint(0, 3, (1, 4, 8, 8))
target = F.one_hot(labels, num_classes=3).movedim(-1, 1).float()

loss = loss_fn(pred, target)
self.assertTrue(torch.isfinite(loss))
self.assertTrue(loss >= 0.0)


if __name__ == "__main__":
unittest.main()
6 changes: 6 additions & 0 deletions tests/unit/test_loss_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,10 +139,12 @@ def _expected_pos_weight(target: torch.Tensor) -> torch.Tensor:
def test_create_loss_attaches_metadata_for_supervised_and_regularization_losses():
weighted_bce = create_loss("WeightedBCEWithLogitsLoss")
weighted_mse = create_loss("WeightedMSELoss")
soft_cldice = create_loss("SoftClDiceLoss")
binary_reg = create_loss("BinaryRegularization")

weighted_bce_meta = weighted_bce._connectomics_loss_metadata
weighted_meta = weighted_mse._connectomics_loss_metadata
soft_cldice_meta = soft_cldice._connectomics_loss_metadata
reg_meta = binary_reg._connectomics_loss_metadata

assert weighted_bce_meta.name == "WeightedBCEWithLogitsLoss"
Expand All @@ -153,6 +155,10 @@ def test_create_loss_attaches_metadata_for_supervised_and_regularization_losses(
assert weighted_meta.call_kind == "pred_target"
assert weighted_meta.spatial_weight_arg == "weight"

assert soft_cldice_meta.name == "SoftClDiceLoss"
assert soft_cldice_meta.call_kind == "pred_target"
assert soft_cldice_meta.spatial_weight_arg == "weight"

assert reg_meta.name == "BinaryRegularization"
assert reg_meta.call_kind == "pred_only"
assert reg_meta.spatial_weight_arg == "mask"
Expand Down