diff --git a/connectomics/config/profiles/loss_profiles.yaml b/connectomics/config/profiles/loss_profiles.yaml index b04cc604..56c5b29a 100644 --- a/connectomics/config/profiles/loss_profiles.yaml +++ b/connectomics/config/profiles/loss_profiles.yaml @@ -7,6 +7,11 @@ loss_profiles: - function: DiceLoss weight: 1.0 kwargs: {sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5} + loss_soft_cldice: + losses: + - function: SoftClDiceLoss + weight: 1.0 + kwargs: {mode: binary, num_iters: 5, sigmoid: true} loss_bcd: losses: - function: WeightedBCEWithLogitsLoss diff --git a/connectomics/models/losses/__init__.py b/connectomics/models/losses/__init__.py index e8b35d3f..280d0760 100755 --- a/connectomics/models/losses/__init__.py +++ b/connectomics/models/losses/__init__.py @@ -16,6 +16,7 @@ # Connectomics-specific losses (for direct use if needed) from .losses import ( GANLoss, + SoftClDiceLoss, WeightedMAELoss, WeightedMSELoss, ) @@ -42,9 +43,10 @@ "get_loss_metadata", "get_loss_metadata_for_module", # Custom losses - "WeightedMSELoss", - "WeightedMAELoss", "GANLoss", + "SoftClDiceLoss", + "WeightedMAELoss", + "WeightedMSELoss", # Regularization losses "BinaryRegularization", "ForegroundDistanceConsistency", diff --git a/connectomics/models/losses/build.py b/connectomics/models/losses/build.py index cef8c3f4..b5d690f3 100644 --- a/connectomics/models/losses/build.py +++ b/connectomics/models/losses/build.py @@ -28,6 +28,7 @@ CrossEntropyLossWrapper, GANLoss, PerChannelBCEWithLogitsLoss, + SoftClDiceLoss, SmoothL1Loss, WeightedBCEWithLogitsLoss, WeightedMAELoss, @@ -70,6 +71,7 @@ def _get_loss_registry() -> Dict[str, type[nn.Module]]: # Custom connectomics losses "WeightedBCEWithLogitsLoss": WeightedBCEWithLogitsLoss, "PerChannelBCEWithLogitsLoss": PerChannelBCEWithLogitsLoss, + "SoftClDiceLoss": SoftClDiceLoss, "WeightedMSELoss": WeightedMSELoss, "WeightedMAELoss": WeightedMAELoss, "GANLoss": GANLoss, diff --git a/connectomics/models/losses/losses.py b/connectomics/models/losses/losses.py index 2092e6c2..07e8d7f9 100644 --- a/connectomics/models/losses/losses.py +++ b/connectomics/models/losses/losses.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Union +from typing import List, Union import torch import torch.nn as nn @@ -44,6 +44,47 @@ def _reduce_weighted_tensor( return loss_tensor[valid].mean() +def _soft_erode_pool(prob: torch.Tensor) -> torch.Tensor: + """Differentiable morphological erosion using min-pool via max-pool.""" + if prob.ndim == 5: + # Use an axis-aligned cross (3x1x1, 1x3x1, 1x1x3), matching clDice-style soft morphology. + p1 = -F.max_pool3d(-prob, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)) + p2 = -F.max_pool3d(-prob, kernel_size=(1, 3, 1), stride=1, padding=(0, 1, 0)) + p3 = -F.max_pool3d(-prob, kernel_size=(1, 1, 3), stride=1, padding=(0, 0, 1)) + return torch.minimum(torch.minimum(p1, p2), p3) + if prob.ndim == 4: + p1 = -F.max_pool2d(-prob, kernel_size=(3, 1), stride=1, padding=(1, 0)) + p2 = -F.max_pool2d(-prob, kernel_size=(1, 3), stride=1, padding=(0, 1)) + return torch.minimum(p1, p2) + raise ValueError(f"Expected 4D/5D tensor for soft erosion, got shape {tuple(prob.shape)}") + + +def _soft_dilate_pool(prob: torch.Tensor) -> torch.Tensor: + """Differentiable morphological dilation.""" + if prob.ndim == 5: + return F.max_pool3d(prob, kernel_size=3, stride=1, padding=1) + if prob.ndim == 4: + return F.max_pool2d(prob, kernel_size=3, stride=1, padding=1) + raise ValueError(f"Expected 4D/5D tensor for soft dilation, got shape {tuple(prob.shape)}") + + +def _soft_open_pool(prob: torch.Tensor) -> torch.Tensor: + """Differentiable opening (erode followed by dilate).""" + return _soft_dilate_pool(_soft_erode_pool(prob)) + + +def _soft_skeletonize_pool(prob: torch.Tensor, num_iters: int) -> torch.Tensor: + """Iterative soft skeletonization from clDice-style morphology.""" + opened = _soft_open_pool(prob) + skeleton = F.relu(prob - opened) + for _ in range(num_iters): + prob = _soft_erode_pool(prob) + opened = _soft_open_pool(prob) + delta = F.relu(prob - opened) + skeleton = skeleton + F.relu(delta - skeleton * delta) + return skeleton + + class CrossEntropyLossWrapper(nn.Module): """ Wrapper for CrossEntropyLoss that handles shape conversion. @@ -310,6 +351,275 @@ def forward( return total +class SoftClDiceLoss(nn.Module): + """ + Soft clDice loss using differentiable skeletonization. + + Supports optional activation on logits (`sigmoid=True` or `softmax=True`), + MONAI-style, before topology computation. + Targets must be dense maps (one-hot or soft labels); single-channel class-index + targets are accepted and converted to one-hot for multi-class predictions. + + Args: + num_iters: Number of soft-skeleton erosion iterations. + mode: ``"binary"`` (single foreground channel) or ``"multi"`` (all classes except + ``background_index``). + reduction: ``"none"``, ``"mean"``, or ``"sum"``. + smooth: Smoothing constant in topology precision/sensitivity fractions. + foreground_channel: Foreground channel index used in ``mode="binary"`` when C>1. + background_index: Background channel index excluded in ``mode="multi"``. + sigmoid: Apply sigmoid activation to predictions inside ``forward``. + softmax: Apply softmax activation to predictions inside ``forward``. + clamp_probabilities: Clamp predictions/targets to ``[0, 1]`` before skeletonization. + validate_inputs: When True, enforce probability-range and minimal spatial-size checks. + validation_tolerance: Tolerance used by probability-range checks. + """ + + def __init__( + self, + num_iters: int = 5, + mode: str = "binary", + reduction: str = "mean", + smooth: float = 1.0, + foreground_channel: int = 1, + background_index: int = 0, + sigmoid: bool = False, + softmax: bool = False, + clamp_probabilities: bool = False, + validate_inputs: bool = True, + validation_tolerance: float = 1e-5, + ): + super().__init__() + if num_iters < 0: + raise ValueError(f"num_iters must be >= 0, got {num_iters}") + if mode not in {"binary", "multi"}: + raise ValueError(f"mode must be 'binary' or 'multi', got {mode!r}") + if reduction not in {"none", "mean", "sum"}: + raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction!r}") + if smooth <= 0: + raise ValueError(f"smooth must be > 0, got {smooth}") + if sigmoid and softmax: + raise ValueError("sigmoid and softmax are mutually exclusive") + if validation_tolerance < 0: + raise ValueError(f"validation_tolerance must be >= 0, got {validation_tolerance}") + + self.num_iters = int(num_iters) + self.mode = mode + self.reduction = reduction + self.smooth = float(smooth) + self.foreground_channel = int(foreground_channel) + self.background_index = int(background_index) + self.sigmoid = bool(sigmoid) + self.softmax = bool(softmax) + self.clamp_probabilities = bool(clamp_probabilities) + self.validate_inputs = bool(validate_inputs) + self.validation_tolerance = float(validation_tolerance) + + def _prepare_target(self, target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor: + if target.ndim == pred.ndim - 1: + target = target.unsqueeze(1) + if target.ndim != pred.ndim: + raise ValueError( + f"Target ndim ({target.ndim}) does not match prediction ndim ({pred.ndim})" + ) + if target.shape[0] != pred.shape[0] or target.shape[2:] != pred.shape[2:]: + raise ValueError( + "Target shape must match prediction shape except for channel dimension: " + f"target={tuple(target.shape)}, pred={tuple(pred.shape)}" + ) + + if target.shape[1] == pred.shape[1]: + return target.to(device=pred.device, dtype=pred.dtype) + + if target.shape[1] == 1 and pred.shape[1] > 1: + class_index = target.squeeze(1).long() + min_label = int(class_index.min().item()) + max_label = int(class_index.max().item()) + if min_label < 0 or max_label >= pred.shape[1]: + raise ValueError( + f"Class-index targets must be in [0, {pred.shape[1] - 1}], " + f"got min={min_label}, max={max_label}" + ) + one_hot = F.one_hot(class_index, num_classes=pred.shape[1]).movedim(-1, 1) + return one_hot.to(device=pred.device, dtype=pred.dtype) + + raise ValueError( + "Target channel count is incompatible with prediction: " + f"target_channels={target.shape[1]}, pred_channels={pred.shape[1]}" + ) + + def _apply_activation(self, pred: torch.Tensor) -> torch.Tensor: + if self.sigmoid: + return torch.sigmoid(pred) + if self.softmax: + if pred.shape[1] < 2: + raise ValueError("softmax=True requires prediction with at least 2 channels") + return F.softmax(pred, dim=1) + return pred + + def _validate_probability_range(self, tensor: torch.Tensor, name: str) -> None: + if not self.validate_inputs: + return + tol = self.validation_tolerance + min_val = float(tensor.min().item()) + max_val = float(tensor.max().item()) + if min_val < -tol or max_val > (1.0 + tol): + raise ValueError( + f"{name} must be probabilities in [0, 1] (tolerance={tol}), " + f"got min={min_val:.6f}, max={max_val:.6f}. " + "Pass sigmoid=True/softmax=True for logits." + ) + + def _validate_spatial_shape(self, pred: torch.Tensor) -> None: + if not self.validate_inputs: + return + spatial_shape = tuple(pred.shape[2:]) + if any(dim < 3 for dim in spatial_shape): + raise ValueError( + "SoftClDiceLoss expects each spatial dimension >= 3 for stable morphology, " + f"got spatial shape {spatial_shape}" + ) + + def _select_foreground_channels( + self, pred: torch.Tensor, target: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, List[int]]: + channels = pred.shape[1] + if self.mode == "binary": + fg_idx = 0 if channels == 1 else self.foreground_channel + if fg_idx < 0 or fg_idx >= channels: + raise ValueError( + f"foreground_channel={self.foreground_channel} is invalid " + f"for {channels} channels" + ) + return pred[:, fg_idx : fg_idx + 1], target[:, fg_idx : fg_idx + 1], [fg_idx] + + if channels == 1: + return pred, target, [0] + + background_index = self.background_index + if background_index < 0: + background_index += channels + if background_index < 0 or background_index >= channels: + raise ValueError( + f"background_index={self.background_index} is invalid for {channels} channels" + ) + + foreground_indices = [idx for idx in range(channels) if idx != background_index] + if not foreground_indices: + raise ValueError( + f"No foreground classes available: channels={channels}, " + f"background_index={self.background_index}" + ) + index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long) + return ( + torch.index_select(pred, dim=1, index=index_tensor), + torch.index_select(target, dim=1, index=index_tensor), + foreground_indices, + ) + + def _prepare_weight( + self, + weight: torch.Tensor, + pred: torch.Tensor, + foreground_indices: List[int], + num_fg_channels: int, + ) -> torch.Tensor: + if weight.ndim == pred.ndim - 1: + weight = weight.unsqueeze(1) + if weight.ndim != pred.ndim: + raise ValueError(f"Weight ndim ({weight.ndim}) must match pred ndim ({pred.ndim})") + if weight.shape[0] != pred.shape[0] or weight.shape[2:] != pred.shape[2:]: + raise ValueError( + "Weight shape must match prediction shape except for channel dimension: " + f"weight={tuple(weight.shape)}, pred={tuple(pred.shape)}" + ) + + weight = weight.to(device=pred.device, dtype=pred.dtype) + if weight.shape[1] == num_fg_channels: + return weight + if weight.shape[1] == 1: + return weight.expand(weight.shape[0], num_fg_channels, *weight.shape[2:]) + if weight.shape[1] == pred.shape[1]: + if num_fg_channels == pred.shape[1]: + return weight + index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long) + return torch.index_select(weight, dim=1, index=index_tensor) + + raise ValueError( + "Weight channel count must be 1, foreground-channel count, " + "or prediction-channel count; " + f"got {weight.shape[1]}" + ) + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + weight: torch.Tensor | None = None, + ) -> torch.Tensor: + if pred.ndim not in {4, 5}: + raise ValueError(f"SoftClDiceLoss expects 4D/5D tensors, got {tuple(pred.shape)}") + + pred = self._apply_activation(pred) + target = self._prepare_target(target, pred) + + if self.clamp_probabilities: + pred = pred.clamp(0.0, 1.0) + target = target.clamp(0.0, 1.0) + + self._validate_spatial_shape(pred) + self._validate_probability_range(pred, "pred") + self._validate_probability_range(target, "target") + + pred_fg, target_fg, foreground_indices = self._select_foreground_channels(pred, target) + + if self.clamp_probabilities: + pred_fg = pred_fg.clamp(0.0, 1.0) + target_fg = target_fg.clamp(0.0, 1.0) + + fg_weight = None + if weight is not None: + fg_weight = self._prepare_weight(weight, pred, foreground_indices, pred_fg.shape[1]) + if self.clamp_probabilities: + fg_weight = fg_weight.clamp_min(0.0) + + pred_skeleton = _soft_skeletonize_pool(pred_fg, self.num_iters) + target_skeleton = _soft_skeletonize_pool(target_fg, self.num_iters) + + if fg_weight is not None: + pred_eval = pred_fg * fg_weight + target_eval = target_fg * fg_weight + pred_skeleton_eval = pred_skeleton * fg_weight + target_skeleton_eval = target_skeleton * fg_weight + else: + pred_eval = pred_fg + target_eval = target_fg + pred_skeleton_eval = pred_skeleton + target_skeleton_eval = target_skeleton + + spatial_dims = tuple(range(2, pred_fg.ndim)) + topology_precision = ( + (pred_skeleton_eval * target_eval).sum(dim=spatial_dims) + self.smooth + ) / (pred_skeleton_eval.sum(dim=spatial_dims) + self.smooth) + topology_sensitivity = ( + (target_skeleton_eval * pred_eval).sum(dim=spatial_dims) + self.smooth + ) / (target_skeleton_eval.sum(dim=spatial_dims) + self.smooth) + + cl_dice = ( + 2.0 + * topology_precision + * topology_sensitivity + / (topology_precision + topology_sensitivity + self.smooth) + ) + loss = 1.0 - cl_dice + + if self.reduction == "none": + return loss + if self.reduction == "sum": + return loss.sum() + return loss.mean() + + class WeightedMAELoss(nn.Module): """ Weighted mean absolute error loss. @@ -476,10 +786,11 @@ def forward( __all__ = [ "CrossEntropyLossWrapper", - "WeightedBCEWithLogitsLoss", + "GANLoss", "PerChannelBCEWithLogitsLoss", - "WeightedMSELoss", - "WeightedMAELoss", "SmoothL1Loss", - "GANLoss", + "SoftClDiceLoss", + "WeightedBCEWithLogitsLoss", + "WeightedMAELoss", + "WeightedMSELoss", ] diff --git a/connectomics/models/losses/metadata.py b/connectomics/models/losses/metadata.py index 0925a4d7..095026e3 100644 --- a/connectomics/models/losses/metadata.py +++ b/connectomics/models/losses/metadata.py @@ -41,6 +41,7 @@ class LossMetadata: "PerChannelBCEWithLogitsLoss": LossMetadata( "PerChannelBCEWithLogitsLoss", spatial_weight_arg="weight" ), + "SoftClDiceLoss": LossMetadata("SoftClDiceLoss", spatial_weight_arg="weight"), "WeightedMSELoss": LossMetadata("WeightedMSELoss", spatial_weight_arg="weight"), "WeightedMAELoss": LossMetadata("WeightedMAELoss", spatial_weight_arg="weight"), # GAN is not compatible with the generic supervised orchestrator path diff --git a/tests/unit/test_loss_functions.py b/tests/unit/test_loss_functions.py index d33885ca..856249ca 100755 --- a/tests/unit/test_loss_functions.py +++ b/tests/unit/test_loss_functions.py @@ -3,6 +3,7 @@ import unittest import torch +import torch.nn.functional as F from connectomics.models.losses import create_loss @@ -37,6 +38,102 @@ def test_tversky_loss(self): loss = loss_fn(pred, target) self.assertTrue(loss >= 0.0) + def test_soft_cldice_binary_probabilities(self): + """Soft clDice in binary mode should accept probability tensors.""" + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=3) + pred = torch.rand(2, 1, 4, 8, 8) + target = (torch.rand(2, 1, 4, 8, 8) > 0.5).float() + + loss = loss_fn(pred, target) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_multi_excludes_background(self): + """Soft clDice in multi mode should use foreground classes only.""" + loss_fn = create_loss("SoftClDiceLoss", mode="multi", num_iters=2) + + logits = torch.randn(1, 3, 4, 8, 8) + pred = torch.softmax(logits, dim=1) + labels = torch.randint(0, 3, (1, 4, 8, 8)) + target = F.one_hot(labels, num_classes=3).movedim(-1, 1).float() + + loss = loss_fn(pred, target) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_rejects_logits_without_activation(self): + """Soft clDice should fail fast on logits when no activation is requested.""" + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=1, validate_inputs=True) + logits = torch.randn(1, 1, 4, 8, 8) + target = (torch.rand(1, 1, 4, 8, 8) > 0.5).float() + + with self.assertRaisesRegex(ValueError, "must be probabilities in \\[0, 1\\]"): + _ = loss_fn(logits, target) + + def test_soft_cldice_accepts_logits_with_sigmoid(self): + """Soft clDice should support logits when sigmoid=True.""" + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=2, sigmoid=True) + logits = torch.randn(1, 1, 4, 8, 8) + target = (torch.rand(1, 1, 4, 8, 8) > 0.5).float() + + loss = loss_fn(logits, target) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_clamp_probabilities_allows_out_of_range_inputs(self): + """Clamping should tolerate out-of-range values when validate_inputs is enabled.""" + loss_fn = create_loss( + "SoftClDiceLoss", + mode="binary", + num_iters=1, + clamp_probabilities=True, + validate_inputs=True, + ) + pred = torch.randn(1, 1, 4, 8, 8) * 3.0 # outside [0, 1] + target = (torch.randn(1, 1, 4, 8, 8) * 2.0) + 0.5 + + loss = loss_fn(pred, target) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_backward_produces_finite_gradients(self): + """Soft clDice should backpropagate finite gradients.""" + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=2, sigmoid=True) + logits = torch.randn(2, 1, 4, 8, 8, requires_grad=True) + target = (torch.rand(2, 1, 4, 8, 8) > 0.5).float() + + loss = loss_fn(logits, target) + loss.backward() + + self.assertIsNotNone(logits.grad) + self.assertTrue(torch.all(torch.isfinite(logits.grad))) + + def test_soft_cldice_weight_channel_routing(self): + """Soft clDice should handle weight maps with full channel count in multi mode.""" + loss_fn = create_loss("SoftClDiceLoss", mode="multi", num_iters=1) + + pred = torch.softmax(torch.randn(1, 3, 4, 8, 8), dim=1) + labels = torch.randint(0, 3, (1, 4, 8, 8)) + target = F.one_hot(labels, num_classes=3).movedim(-1, 1).float() + weight = torch.rand(1, 3, 4, 8, 8) + + loss = loss_fn(pred, target, weight=weight) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_weight_single_channel_broadcast(self): + """Soft clDice should broadcast a single-channel weight map.""" + loss_fn = create_loss("SoftClDiceLoss", mode="multi", num_iters=1) + + pred = torch.softmax(torch.randn(1, 3, 4, 8, 8), dim=1) + labels = torch.randint(0, 3, (1, 4, 8, 8)) + target = F.one_hot(labels, num_classes=3).movedim(-1, 1).float() + weight = torch.rand(1, 1, 4, 8, 8) + + loss = loss_fn(pred, target, weight=weight) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_loss_orchestrator.py b/tests/unit/test_loss_orchestrator.py index 3a9347e3..af22104b 100644 --- a/tests/unit/test_loss_orchestrator.py +++ b/tests/unit/test_loss_orchestrator.py @@ -139,10 +139,12 @@ def _expected_pos_weight(target: torch.Tensor) -> torch.Tensor: def test_create_loss_attaches_metadata_for_supervised_and_regularization_losses(): weighted_bce = create_loss("WeightedBCEWithLogitsLoss") weighted_mse = create_loss("WeightedMSELoss") + soft_cldice = create_loss("SoftClDiceLoss") binary_reg = create_loss("BinaryRegularization") weighted_bce_meta = weighted_bce._connectomics_loss_metadata weighted_meta = weighted_mse._connectomics_loss_metadata + soft_cldice_meta = soft_cldice._connectomics_loss_metadata reg_meta = binary_reg._connectomics_loss_metadata assert weighted_bce_meta.name == "WeightedBCEWithLogitsLoss" @@ -153,6 +155,10 @@ def test_create_loss_attaches_metadata_for_supervised_and_regularization_losses( assert weighted_meta.call_kind == "pred_target" assert weighted_meta.spatial_weight_arg == "weight" + assert soft_cldice_meta.name == "SoftClDiceLoss" + assert soft_cldice_meta.call_kind == "pred_target" + assert soft_cldice_meta.spatial_weight_arg == "weight" + assert reg_meta.name == "BinaryRegularization" assert reg_meta.call_kind == "pred_only" assert reg_meta.spatial_weight_arg == "mask" @@ -170,6 +176,34 @@ def test_weighted_bce_rejects_unknown_loss_kwargs(): create_loss("WeightedBCEWithLogitsLoss", unknown_kwarg=True) +def test_soft_cldice_with_sigmoid_runs_through_orchestrator(): + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=1, sigmoid=True) + orchestrator = LossOrchestrator( + cfg=_cfg( + losses=[ + { + "weight": 1.0, + }, + ] + ), + loss_functions=nn.ModuleList([loss_fn]), + loss_weights=[1.0], + enable_nan_detection=False, + debug_on_nan=False, + ) + + outputs = torch.randn(1, 1, 4, 8, 8, requires_grad=True) + labels = (torch.rand(1, 1, 4, 8, 8) > 0.5).float() + + total_loss, loss_dict = orchestrator.compute_standard_loss(outputs, labels, stage="train") + total_loss.backward() + + assert torch.isfinite(total_loss) + assert "train_loss_total" in loss_dict + assert outputs.grad is not None + assert torch.all(torch.isfinite(outputs.grad)) + + def test_loss_orchestrator_requires_explicit_losses(): with pytest.raises(ValueError, match="model\\.loss\\.losses is required"): LossOrchestrator(