Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions connectomics/config/profiles/loss_profiles.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ loss_profiles:
- function: DiceLoss
weight: 1.0
kwargs: {sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5}
loss_soft_cldice:
losses:
- function: SoftClDiceLoss
weight: 1.0
kwargs: {mode: binary, num_iters: 5, sigmoid: true}
loss_bcd:
losses:
- function: WeightedBCEWithLogitsLoss
Expand Down
6 changes: 4 additions & 2 deletions connectomics/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Connectomics-specific losses (for direct use if needed)
from .losses import (
GANLoss,
SoftClDiceLoss,
WeightedMAELoss,
WeightedMSELoss,
)
Expand All @@ -42,9 +43,10 @@
"get_loss_metadata",
"get_loss_metadata_for_module",
# Custom losses
"WeightedMSELoss",
"WeightedMAELoss",
"GANLoss",
"SoftClDiceLoss",
"WeightedMAELoss",
"WeightedMSELoss",
# Regularization losses
"BinaryRegularization",
"ForegroundDistanceConsistency",
Expand Down
2 changes: 2 additions & 0 deletions connectomics/models/losses/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CrossEntropyLossWrapper,
GANLoss,
PerChannelBCEWithLogitsLoss,
SoftClDiceLoss,
SmoothL1Loss,
WeightedBCEWithLogitsLoss,
WeightedMAELoss,
Expand Down Expand Up @@ -70,6 +71,7 @@ def _get_loss_registry() -> Dict[str, type[nn.Module]]:
# Custom connectomics losses
"WeightedBCEWithLogitsLoss": WeightedBCEWithLogitsLoss,
"PerChannelBCEWithLogitsLoss": PerChannelBCEWithLogitsLoss,
"SoftClDiceLoss": SoftClDiceLoss,
"WeightedMSELoss": WeightedMSELoss,
"WeightedMAELoss": WeightedMAELoss,
"GANLoss": GANLoss,
Expand Down
321 changes: 316 additions & 5 deletions connectomics/models/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from typing import Union
from typing import List, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -44,6 +44,47 @@ def _reduce_weighted_tensor(
return loss_tensor[valid].mean()


def _soft_erode_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable morphological erosion using min-pool via max-pool."""
if prob.ndim == 5:
# Use an axis-aligned cross (3x1x1, 1x3x1, 1x1x3), matching clDice-style soft morphology.
p1 = -F.max_pool3d(-prob, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
p2 = -F.max_pool3d(-prob, kernel_size=(1, 3, 1), stride=1, padding=(0, 1, 0))
p3 = -F.max_pool3d(-prob, kernel_size=(1, 1, 3), stride=1, padding=(0, 0, 1))
return torch.minimum(torch.minimum(p1, p2), p3)
if prob.ndim == 4:
p1 = -F.max_pool2d(-prob, kernel_size=(3, 1), stride=1, padding=(1, 0))
p2 = -F.max_pool2d(-prob, kernel_size=(1, 3), stride=1, padding=(0, 1))
return torch.minimum(p1, p2)
raise ValueError(f"Expected 4D/5D tensor for soft erosion, got shape {tuple(prob.shape)}")


def _soft_dilate_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable morphological dilation."""
if prob.ndim == 5:
return F.max_pool3d(prob, kernel_size=3, stride=1, padding=1)
if prob.ndim == 4:
return F.max_pool2d(prob, kernel_size=3, stride=1, padding=1)
raise ValueError(f"Expected 4D/5D tensor for soft dilation, got shape {tuple(prob.shape)}")


def _soft_open_pool(prob: torch.Tensor) -> torch.Tensor:
"""Differentiable opening (erode followed by dilate)."""
return _soft_dilate_pool(_soft_erode_pool(prob))


def _soft_skeletonize_pool(prob: torch.Tensor, num_iters: int) -> torch.Tensor:
"""Iterative soft skeletonization from clDice-style morphology."""
opened = _soft_open_pool(prob)
skeleton = F.relu(prob - opened)
for _ in range(num_iters):
prob = _soft_erode_pool(prob)
opened = _soft_open_pool(prob)
delta = F.relu(prob - opened)
skeleton = skeleton + F.relu(delta - skeleton * delta)
return skeleton


class CrossEntropyLossWrapper(nn.Module):
"""
Wrapper for CrossEntropyLoss that handles shape conversion.
Expand Down Expand Up @@ -310,6 +351,275 @@ def forward(
return total


class SoftClDiceLoss(nn.Module):
"""
Soft clDice loss using differentiable skeletonization.

Supports optional activation on logits (`sigmoid=True` or `softmax=True`),
MONAI-style, before topology computation.
Targets must be dense maps (one-hot or soft labels); single-channel class-index
targets are accepted and converted to one-hot for multi-class predictions.

Args:
num_iters: Number of soft-skeleton erosion iterations.
mode: ``"binary"`` (single foreground channel) or ``"multi"`` (all classes except
``background_index``).
reduction: ``"none"``, ``"mean"``, or ``"sum"``.
smooth: Smoothing constant in topology precision/sensitivity fractions.
foreground_channel: Foreground channel index used in ``mode="binary"`` when C>1.
background_index: Background channel index excluded in ``mode="multi"``.
sigmoid: Apply sigmoid activation to predictions inside ``forward``.
softmax: Apply softmax activation to predictions inside ``forward``.
clamp_probabilities: Clamp predictions/targets to ``[0, 1]`` before skeletonization.
validate_inputs: When True, enforce probability-range and minimal spatial-size checks.
validation_tolerance: Tolerance used by probability-range checks.
"""

def __init__(
self,
num_iters: int = 5,
mode: str = "binary",
reduction: str = "mean",
smooth: float = 1.0,
foreground_channel: int = 1,
background_index: int = 0,
sigmoid: bool = False,
softmax: bool = False,
clamp_probabilities: bool = False,
validate_inputs: bool = True,
validation_tolerance: float = 1e-5,
):
super().__init__()
if num_iters < 0:
raise ValueError(f"num_iters must be >= 0, got {num_iters}")
if mode not in {"binary", "multi"}:
raise ValueError(f"mode must be 'binary' or 'multi', got {mode!r}")
if reduction not in {"none", "mean", "sum"}:
raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction!r}")
if smooth <= 0:
raise ValueError(f"smooth must be > 0, got {smooth}")
if sigmoid and softmax:
raise ValueError("sigmoid and softmax are mutually exclusive")
if validation_tolerance < 0:
raise ValueError(f"validation_tolerance must be >= 0, got {validation_tolerance}")

self.num_iters = int(num_iters)
self.mode = mode
self.reduction = reduction
self.smooth = float(smooth)
self.foreground_channel = int(foreground_channel)
self.background_index = int(background_index)
self.sigmoid = bool(sigmoid)
self.softmax = bool(softmax)
self.clamp_probabilities = bool(clamp_probabilities)
self.validate_inputs = bool(validate_inputs)
self.validation_tolerance = float(validation_tolerance)

def _prepare_target(self, target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
if target.ndim == pred.ndim - 1:
target = target.unsqueeze(1)
if target.ndim != pred.ndim:
raise ValueError(
f"Target ndim ({target.ndim}) does not match prediction ndim ({pred.ndim})"
)
if target.shape[0] != pred.shape[0] or target.shape[2:] != pred.shape[2:]:
raise ValueError(
"Target shape must match prediction shape except for channel dimension: "
f"target={tuple(target.shape)}, pred={tuple(pred.shape)}"
)

if target.shape[1] == pred.shape[1]:
return target.to(device=pred.device, dtype=pred.dtype)

if target.shape[1] == 1 and pred.shape[1] > 1:
class_index = target.squeeze(1).long()
min_label = int(class_index.min().item())
max_label = int(class_index.max().item())
if min_label < 0 or max_label >= pred.shape[1]:
raise ValueError(
f"Class-index targets must be in [0, {pred.shape[1] - 1}], "
f"got min={min_label}, max={max_label}"
)
one_hot = F.one_hot(class_index, num_classes=pred.shape[1]).movedim(-1, 1)
return one_hot.to(device=pred.device, dtype=pred.dtype)

raise ValueError(
"Target channel count is incompatible with prediction: "
f"target_channels={target.shape[1]}, pred_channels={pred.shape[1]}"
)

def _apply_activation(self, pred: torch.Tensor) -> torch.Tensor:
if self.sigmoid:
return torch.sigmoid(pred)
if self.softmax:
if pred.shape[1] < 2:
raise ValueError("softmax=True requires prediction with at least 2 channels")
return F.softmax(pred, dim=1)
return pred

def _validate_probability_range(self, tensor: torch.Tensor, name: str) -> None:
if not self.validate_inputs:
return
tol = self.validation_tolerance
min_val = float(tensor.min().item())
max_val = float(tensor.max().item())
if min_val < -tol or max_val > (1.0 + tol):
raise ValueError(
f"{name} must be probabilities in [0, 1] (tolerance={tol}), "
f"got min={min_val:.6f}, max={max_val:.6f}. "
"Pass sigmoid=True/softmax=True for logits."
)

def _validate_spatial_shape(self, pred: torch.Tensor) -> None:
if not self.validate_inputs:
return
spatial_shape = tuple(pred.shape[2:])
if any(dim < 3 for dim in spatial_shape):
raise ValueError(
"SoftClDiceLoss expects each spatial dimension >= 3 for stable morphology, "
f"got spatial shape {spatial_shape}"
)

def _select_foreground_channels(
self, pred: torch.Tensor, target: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, List[int]]:
channels = pred.shape[1]
if self.mode == "binary":
fg_idx = 0 if channels == 1 else self.foreground_channel
if fg_idx < 0 or fg_idx >= channels:
raise ValueError(
f"foreground_channel={self.foreground_channel} is invalid "
f"for {channels} channels"
)
return pred[:, fg_idx : fg_idx + 1], target[:, fg_idx : fg_idx + 1], [fg_idx]

if channels == 1:
return pred, target, [0]

background_index = self.background_index
if background_index < 0:
background_index += channels
if background_index < 0 or background_index >= channels:
raise ValueError(
f"background_index={self.background_index} is invalid for {channels} channels"
)

foreground_indices = [idx for idx in range(channels) if idx != background_index]
if not foreground_indices:
raise ValueError(
f"No foreground classes available: channels={channels}, "
f"background_index={self.background_index}"
)
index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long)
return (
torch.index_select(pred, dim=1, index=index_tensor),
torch.index_select(target, dim=1, index=index_tensor),
foreground_indices,
)

def _prepare_weight(
self,
weight: torch.Tensor,
pred: torch.Tensor,
foreground_indices: List[int],
num_fg_channels: int,
) -> torch.Tensor:
if weight.ndim == pred.ndim - 1:
weight = weight.unsqueeze(1)
if weight.ndim != pred.ndim:
raise ValueError(f"Weight ndim ({weight.ndim}) must match pred ndim ({pred.ndim})")
if weight.shape[0] != pred.shape[0] or weight.shape[2:] != pred.shape[2:]:
raise ValueError(
"Weight shape must match prediction shape except for channel dimension: "
f"weight={tuple(weight.shape)}, pred={tuple(pred.shape)}"
)

weight = weight.to(device=pred.device, dtype=pred.dtype)
if weight.shape[1] == num_fg_channels:
return weight
if weight.shape[1] == 1:
return weight.expand(weight.shape[0], num_fg_channels, *weight.shape[2:])
if weight.shape[1] == pred.shape[1]:
if num_fg_channels == pred.shape[1]:
return weight
index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long)
return torch.index_select(weight, dim=1, index=index_tensor)

raise ValueError(
"Weight channel count must be 1, foreground-channel count, "
"or prediction-channel count; "
f"got {weight.shape[1]}"
)

def forward(
self,
pred: torch.Tensor,
target: torch.Tensor,
weight: torch.Tensor | None = None,
) -> torch.Tensor:
if pred.ndim not in {4, 5}:
raise ValueError(f"SoftClDiceLoss expects 4D/5D tensors, got {tuple(pred.shape)}")

pred = self._apply_activation(pred)
target = self._prepare_target(target, pred)

if self.clamp_probabilities:
pred = pred.clamp(0.0, 1.0)
target = target.clamp(0.0, 1.0)

self._validate_spatial_shape(pred)
self._validate_probability_range(pred, "pred")
self._validate_probability_range(target, "target")

pred_fg, target_fg, foreground_indices = self._select_foreground_channels(pred, target)

if self.clamp_probabilities:
pred_fg = pred_fg.clamp(0.0, 1.0)
target_fg = target_fg.clamp(0.0, 1.0)

fg_weight = None
if weight is not None:
fg_weight = self._prepare_weight(weight, pred, foreground_indices, pred_fg.shape[1])
if self.clamp_probabilities:
fg_weight = fg_weight.clamp_min(0.0)

pred_skeleton = _soft_skeletonize_pool(pred_fg, self.num_iters)
target_skeleton = _soft_skeletonize_pool(target_fg, self.num_iters)

if fg_weight is not None:
pred_eval = pred_fg * fg_weight
target_eval = target_fg * fg_weight
pred_skeleton_eval = pred_skeleton * fg_weight
target_skeleton_eval = target_skeleton * fg_weight
else:
pred_eval = pred_fg
target_eval = target_fg
pred_skeleton_eval = pred_skeleton
target_skeleton_eval = target_skeleton

spatial_dims = tuple(range(2, pred_fg.ndim))
topology_precision = (
(pred_skeleton_eval * target_eval).sum(dim=spatial_dims) + self.smooth
) / (pred_skeleton_eval.sum(dim=spatial_dims) + self.smooth)
topology_sensitivity = (
(target_skeleton_eval * pred_eval).sum(dim=spatial_dims) + self.smooth
) / (target_skeleton_eval.sum(dim=spatial_dims) + self.smooth)

cl_dice = (
2.0
* topology_precision
* topology_sensitivity
/ (topology_precision + topology_sensitivity + self.smooth)
)
loss = 1.0 - cl_dice

if self.reduction == "none":
return loss
if self.reduction == "sum":
return loss.sum()
return loss.mean()


class WeightedMAELoss(nn.Module):
"""
Weighted mean absolute error loss.
Expand Down Expand Up @@ -476,10 +786,11 @@ def forward(

__all__ = [
"CrossEntropyLossWrapper",
"WeightedBCEWithLogitsLoss",
"GANLoss",
"PerChannelBCEWithLogitsLoss",
"WeightedMSELoss",
"WeightedMAELoss",
"SmoothL1Loss",
"GANLoss",
"SoftClDiceLoss",
"WeightedBCEWithLogitsLoss",
"WeightedMAELoss",
"WeightedMSELoss",
]
Loading