From b7ec9aa3f7abb95eaa049fb3387800e838398792 Mon Sep 17 00:00:00 2001 From: Bora <94760052+bogenc@users.noreply.github.com> Date: Tue, 12 May 2026 20:28:48 +0300 Subject: [PATCH 1/2] Add SoftClDiceLoss and register it in the loss stack SoftClDiceLoss was introduced as a new segmentation loss that focuses on preserving structure using soft skeletonization, with support for binary and multi-class modes and efficient tensor operations. It was fully integrated into the system through the loss factory, metadata, and module exports, and validated with unit tests covering functionality and correct orchestration behavior. --- connectomics/models/losses/__init__.py | 2 + connectomics/models/losses/build.py | 2 + connectomics/models/losses/losses.py | 252 ++++++++++++++++++++++++- connectomics/models/losses/metadata.py | 1 + tests/unit/test_loss_functions.py | 24 +++ tests/unit/test_loss_orchestrator.py | 6 + 6 files changed, 286 insertions(+), 1 deletion(-) diff --git a/connectomics/models/losses/__init__.py b/connectomics/models/losses/__init__.py index e8b35d3f..b3f73b73 100755 --- a/connectomics/models/losses/__init__.py +++ b/connectomics/models/losses/__init__.py @@ -16,6 +16,7 @@ # Connectomics-specific losses (for direct use if needed) from .losses import ( GANLoss, + SoftClDiceLoss, WeightedMAELoss, WeightedMSELoss, ) @@ -42,6 +43,7 @@ "get_loss_metadata", "get_loss_metadata_for_module", # Custom losses + "SoftClDiceLoss", "WeightedMSELoss", "WeightedMAELoss", "GANLoss", diff --git a/connectomics/models/losses/build.py b/connectomics/models/losses/build.py index cef8c3f4..b5d690f3 100644 --- a/connectomics/models/losses/build.py +++ b/connectomics/models/losses/build.py @@ -28,6 +28,7 @@ CrossEntropyLossWrapper, GANLoss, PerChannelBCEWithLogitsLoss, + SoftClDiceLoss, SmoothL1Loss, WeightedBCEWithLogitsLoss, WeightedMAELoss, @@ -70,6 +71,7 @@ def _get_loss_registry() -> Dict[str, type[nn.Module]]: # Custom connectomics losses "WeightedBCEWithLogitsLoss": WeightedBCEWithLogitsLoss, "PerChannelBCEWithLogitsLoss": PerChannelBCEWithLogitsLoss, + "SoftClDiceLoss": SoftClDiceLoss, "WeightedMSELoss": WeightedMSELoss, "WeightedMAELoss": WeightedMAELoss, "GANLoss": GANLoss, diff --git a/connectomics/models/losses/losses.py b/connectomics/models/losses/losses.py index 2092e6c2..aad2c58f 100644 --- a/connectomics/models/losses/losses.py +++ b/connectomics/models/losses/losses.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Union +from typing import List, Union import torch import torch.nn as nn @@ -44,6 +44,46 @@ def _reduce_weighted_tensor( return loss_tensor[valid].mean() +def _soft_erode_pool(prob: torch.Tensor) -> torch.Tensor: + """Differentiable morphological erosion using min-pool via max-pool.""" + if prob.ndim == 5: + p1 = -F.max_pool3d(-prob, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)) + p2 = -F.max_pool3d(-prob, kernel_size=(1, 3, 1), stride=1, padding=(0, 1, 0)) + p3 = -F.max_pool3d(-prob, kernel_size=(1, 1, 3), stride=1, padding=(0, 0, 1)) + return torch.minimum(torch.minimum(p1, p2), p3) + if prob.ndim == 4: + p1 = -F.max_pool2d(-prob, kernel_size=(3, 1), stride=1, padding=(1, 0)) + p2 = -F.max_pool2d(-prob, kernel_size=(1, 3), stride=1, padding=(0, 1)) + return torch.minimum(p1, p2) + raise ValueError(f"Expected 4D/5D tensor for soft erosion, got shape {tuple(prob.shape)}") + + +def _soft_dilate_pool(prob: torch.Tensor) -> torch.Tensor: + """Differentiable morphological dilation.""" + if prob.ndim == 5: + return F.max_pool3d(prob, kernel_size=3, stride=1, padding=1) + if prob.ndim == 4: + return F.max_pool2d(prob, kernel_size=3, stride=1, padding=1) + raise ValueError(f"Expected 4D/5D tensor for soft dilation, got shape {tuple(prob.shape)}") + + +def _soft_open_pool(prob: torch.Tensor) -> torch.Tensor: + """Differentiable opening (erode followed by dilate).""" + return _soft_dilate_pool(_soft_erode_pool(prob)) + + +def _soft_skeletonize_pool(prob: torch.Tensor, num_iters: int) -> torch.Tensor: + """Iterative soft skeletonization from clDice-style morphology.""" + opened = _soft_open_pool(prob) + skeleton = F.relu(prob - opened) + for _ in range(num_iters): + prob = _soft_erode_pool(prob) + opened = _soft_open_pool(prob) + delta = F.relu(prob - opened) + skeleton = skeleton + F.relu(delta - skeleton * delta) + return skeleton + + class CrossEntropyLossWrapper(nn.Module): """ Wrapper for CrossEntropyLoss that handles shape conversion. @@ -310,6 +350,215 @@ def forward( return total +class SoftClDiceLoss(nn.Module): + """ + Soft clDice loss using differentiable skeletonization. + + This loss expects probability maps (sigmoid/softmax outputs), not logits. + """ + + def __init__( + self, + num_iters: int = 5, + mode: str = "binary", + reduction: str = "mean", + smooth: float = 1e-6, + foreground_channel: int = 1, + background_index: int = 0, + clamp_probabilities: bool = False, + use_fused_cuda: bool = False, + ): + super().__init__() + if num_iters < 0: + raise ValueError(f"num_iters must be >= 0, got {num_iters}") + if mode not in {"binary", "multi"}: + raise ValueError(f"mode must be 'binary' or 'multi', got {mode!r}") + if reduction not in {"none", "mean", "sum"}: + raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction!r}") + if smooth <= 0: + raise ValueError(f"smooth must be > 0, got {smooth}") + + self.num_iters = int(num_iters) + self.mode = mode + self.reduction = reduction + self.smooth = float(smooth) + self.foreground_channel = int(foreground_channel) + self.background_index = int(background_index) + self.clamp_probabilities = bool(clamp_probabilities) + self.use_fused_cuda = bool(use_fused_cuda) + + def _prepare_target(self, target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor: + if target.ndim == pred.ndim - 1: + target = target.unsqueeze(1) + if target.ndim != pred.ndim: + raise ValueError( + f"Target ndim ({target.ndim}) does not match prediction ndim ({pred.ndim})" + ) + if target.shape[0] != pred.shape[0] or target.shape[2:] != pred.shape[2:]: + raise ValueError( + "Target shape must match prediction shape except for channel dimension: " + f"target={tuple(target.shape)}, pred={tuple(pred.shape)}" + ) + + if target.shape[1] == pred.shape[1]: + return target.to(device=pred.device, dtype=pred.dtype) + + if target.shape[1] == 1 and pred.shape[1] > 1: + class_index = target.squeeze(1).long() + min_label = int(class_index.min().item()) + max_label = int(class_index.max().item()) + if min_label < 0 or max_label >= pred.shape[1]: + raise ValueError( + f"Class-index targets must be in [0, {pred.shape[1] - 1}], " + f"got min={min_label}, max={max_label}" + ) + one_hot = F.one_hot(class_index, num_classes=pred.shape[1]).movedim(-1, 1) + return one_hot.to(device=pred.device, dtype=pred.dtype) + + raise ValueError( + "Target channel count is incompatible with prediction: " + f"target_channels={target.shape[1]}, pred_channels={pred.shape[1]}" + ) + + def _select_foreground_channels( + self, pred: torch.Tensor, target: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, List[int]]: + channels = pred.shape[1] + if self.mode == "binary": + fg_idx = 0 if channels == 1 else self.foreground_channel + if fg_idx < 0 or fg_idx >= channels: + raise ValueError( + f"foreground_channel={self.foreground_channel} is invalid for {channels} channels" + ) + return pred[:, fg_idx : fg_idx + 1], target[:, fg_idx : fg_idx + 1], [fg_idx] + + if channels == 1: + return pred, target, [0] + + background_index = self.background_index + if background_index < 0: + background_index += channels + if background_index < 0 or background_index >= channels: + raise ValueError( + f"background_index={self.background_index} is invalid for {channels} channels" + ) + + foreground_indices = [idx for idx in range(channels) if idx != background_index] + if not foreground_indices: + raise ValueError( + f"No foreground classes available: channels={channels}, " + f"background_index={self.background_index}" + ) + index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long) + return ( + torch.index_select(pred, dim=1, index=index_tensor), + torch.index_select(target, dim=1, index=index_tensor), + foreground_indices, + ) + + def _prepare_weight( + self, + weight: torch.Tensor, + pred: torch.Tensor, + foreground_indices: List[int], + num_fg_channels: int, + ) -> torch.Tensor: + if weight.ndim == pred.ndim - 1: + weight = weight.unsqueeze(1) + if weight.ndim != pred.ndim: + raise ValueError(f"Weight ndim ({weight.ndim}) must match pred ndim ({pred.ndim})") + if weight.shape[0] != pred.shape[0] or weight.shape[2:] != pred.shape[2:]: + raise ValueError( + "Weight shape must match prediction shape except for channel dimension: " + f"weight={tuple(weight.shape)}, pred={tuple(pred.shape)}" + ) + + weight = weight.to(device=pred.device, dtype=pred.dtype) + if weight.shape[1] == num_fg_channels: + return weight + if weight.shape[1] == 1: + return weight.expand(weight.shape[0], num_fg_channels, *weight.shape[2:]) + if weight.shape[1] == pred.shape[1]: + if num_fg_channels == pred.shape[1]: + return weight + index_tensor = torch.tensor(foreground_indices, device=pred.device, dtype=torch.long) + return torch.index_select(weight, dim=1, index=index_tensor) + + raise ValueError( + "Weight channel count must be 1, foreground-channel count, or prediction-channel count; " + f"got {weight.shape[1]}" + ) + + def _soft_skeletonize(self, prob: torch.Tensor) -> torch.Tensor: + if self.use_fused_cuda and prob.is_cuda: + connectomics_ops = getattr(torch.ops, "connectomics", None) + fused_op = getattr(connectomics_ops, "soft_skeletonize", None) + if fused_op is not None: + try: + return fused_op(prob, self.num_iters) + except (RuntimeError, TypeError): + pass + return _soft_skeletonize_pool(prob, self.num_iters) + + def forward( + self, + pred: torch.Tensor, + target: torch.Tensor, + weight: torch.Tensor = None, + ) -> torch.Tensor: + if pred.ndim not in {4, 5}: + raise ValueError(f"SoftClDiceLoss expects 4D/5D tensors, got {tuple(pred.shape)}") + + target = self._prepare_target(target, pred) + pred_fg, target_fg, foreground_indices = self._select_foreground_channels(pred, target) + + if self.clamp_probabilities: + pred_fg = pred_fg.clamp(0.0, 1.0) + target_fg = target_fg.clamp(0.0, 1.0) + + fg_weight = None + if weight is not None: + fg_weight = self._prepare_weight(weight, pred, foreground_indices, pred_fg.shape[1]) + if self.clamp_probabilities: + fg_weight = fg_weight.clamp_min(0.0) + + pred_skeleton = self._soft_skeletonize(pred_fg) + target_skeleton = self._soft_skeletonize(target_fg) + + if fg_weight is not None: + pred_eval = pred_fg * fg_weight + target_eval = target_fg * fg_weight + pred_skeleton_eval = pred_skeleton * fg_weight + target_skeleton_eval = target_skeleton * fg_weight + else: + pred_eval = pred_fg + target_eval = target_fg + pred_skeleton_eval = pred_skeleton + target_skeleton_eval = target_skeleton + + spatial_dims = tuple(range(2, pred_fg.ndim)) + topology_precision = ( + (pred_skeleton_eval * target_eval).sum(dim=spatial_dims) + self.smooth + ) / (pred_skeleton_eval.sum(dim=spatial_dims) + self.smooth) + topology_sensitivity = ( + (target_skeleton_eval * pred_eval).sum(dim=spatial_dims) + self.smooth + ) / (target_skeleton_eval.sum(dim=spatial_dims) + self.smooth) + + cl_dice = ( + 2.0 + * topology_precision + * topology_sensitivity + / (topology_precision + topology_sensitivity + self.smooth) + ) + loss = 1.0 - cl_dice + + if self.reduction == "none": + return loss + if self.reduction == "sum": + return loss.sum() + return loss.mean() + + class WeightedMAELoss(nn.Module): """ Weighted mean absolute error loss. @@ -478,6 +727,7 @@ def forward( "CrossEntropyLossWrapper", "WeightedBCEWithLogitsLoss", "PerChannelBCEWithLogitsLoss", + "SoftClDiceLoss", "WeightedMSELoss", "WeightedMAELoss", "SmoothL1Loss", diff --git a/connectomics/models/losses/metadata.py b/connectomics/models/losses/metadata.py index 0925a4d7..095026e3 100644 --- a/connectomics/models/losses/metadata.py +++ b/connectomics/models/losses/metadata.py @@ -41,6 +41,7 @@ class LossMetadata: "PerChannelBCEWithLogitsLoss": LossMetadata( "PerChannelBCEWithLogitsLoss", spatial_weight_arg="weight" ), + "SoftClDiceLoss": LossMetadata("SoftClDiceLoss", spatial_weight_arg="weight"), "WeightedMSELoss": LossMetadata("WeightedMSELoss", spatial_weight_arg="weight"), "WeightedMAELoss": LossMetadata("WeightedMAELoss", spatial_weight_arg="weight"), # GAN is not compatible with the generic supervised orchestrator path diff --git a/tests/unit/test_loss_functions.py b/tests/unit/test_loss_functions.py index d33885ca..74ada045 100755 --- a/tests/unit/test_loss_functions.py +++ b/tests/unit/test_loss_functions.py @@ -3,6 +3,7 @@ import unittest import torch +import torch.nn.functional as F from connectomics.models.losses import create_loss @@ -37,6 +38,29 @@ def test_tversky_loss(self): loss = loss_fn(pred, target) self.assertTrue(loss >= 0.0) + def test_soft_cldice_binary_probabilities(self): + """Soft clDice in binary mode should accept probability tensors.""" + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=3) + pred = torch.rand(2, 1, 4, 8, 8) + target = (torch.rand(2, 1, 4, 8, 8) > 0.5).float() + + loss = loss_fn(pred, target) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_multi_excludes_background(self): + """Soft clDice in multi mode should use foreground classes only.""" + loss_fn = create_loss("SoftClDiceLoss", mode="multi", num_iters=2) + + logits = torch.randn(1, 3, 4, 8, 8) + pred = torch.softmax(logits, dim=1) + labels = torch.randint(0, 3, (1, 4, 8, 8)) + target = F.one_hot(labels, num_classes=3).movedim(-1, 1).float() + + loss = loss_fn(pred, target) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_loss_orchestrator.py b/tests/unit/test_loss_orchestrator.py index 3a9347e3..af073943 100644 --- a/tests/unit/test_loss_orchestrator.py +++ b/tests/unit/test_loss_orchestrator.py @@ -139,10 +139,12 @@ def _expected_pos_weight(target: torch.Tensor) -> torch.Tensor: def test_create_loss_attaches_metadata_for_supervised_and_regularization_losses(): weighted_bce = create_loss("WeightedBCEWithLogitsLoss") weighted_mse = create_loss("WeightedMSELoss") + soft_cldice = create_loss("SoftClDiceLoss") binary_reg = create_loss("BinaryRegularization") weighted_bce_meta = weighted_bce._connectomics_loss_metadata weighted_meta = weighted_mse._connectomics_loss_metadata + soft_cldice_meta = soft_cldice._connectomics_loss_metadata reg_meta = binary_reg._connectomics_loss_metadata assert weighted_bce_meta.name == "WeightedBCEWithLogitsLoss" @@ -153,6 +155,10 @@ def test_create_loss_attaches_metadata_for_supervised_and_regularization_losses( assert weighted_meta.call_kind == "pred_target" assert weighted_meta.spatial_weight_arg == "weight" + assert soft_cldice_meta.name == "SoftClDiceLoss" + assert soft_cldice_meta.call_kind == "pred_target" + assert soft_cldice_meta.spatial_weight_arg == "weight" + assert reg_meta.name == "BinaryRegularization" assert reg_meta.call_kind == "pred_only" assert reg_meta.spatial_weight_arg == "mask" From 8b08486aac34d316b5309dd7d14d1adef2413c6a Mon Sep 17 00:00:00 2001 From: Bora <94760052+bogenc@users.noreply.github.com> Date: Sun, 17 May 2026 19:30:02 +0300 Subject: [PATCH 2/2] Updated SoftClDiceLoss to be safe in the logits-first training flow The updated SoftClDiceLoss now includes MONAI-style activation options (sigmoid or softmax, mutually exclusive), input validation for probability ranges and spatial size, and a higher default smoothing value (1.0) to stabilize behavior when skeletons are nearly empty. Removed an unnecessary fallback path by deleting the unused fused CUDA option (use_fused_cuda / torch.ops.connectomics.soft_skeletonize). The implementation now consistently relies on the pooled differentiable morphology approach as the single, canonical method. Improved clarity throughout the code by documenting that targets are expected to be dense, explaining argument behavior, and noting that 3D erosion intentionally uses axis-aligned cross-shaped kernels. Typing was tightened (weight: torch.Tensor | None), and export ordering was cleaned up. Expanded test coverage to reflect realistic failure cases and integration behavior. This includes rejecting logits without activation, accepting them when sigmoid is enabled, clamping out-of-range inputs, verifying backward pass stability with finite gradients, checking weight handling (both per-channel and broadcast), and validating integration through the SoftClDiceLoss orchestrator. Added a new configuration example (loss_soft_cldice) to improve discoverability, located in connectomics/config/profiles/loss_profiles.yaml. --- .../config/profiles/loss_profiles.yaml | 5 + connectomics/models/losses/__init__.py | 4 +- connectomics/models/losses/losses.py | 109 ++++++++++++++---- tests/unit/test_loss_functions.py | 73 ++++++++++++ tests/unit/test_loss_orchestrator.py | 28 +++++ 5 files changed, 193 insertions(+), 26 deletions(-) diff --git a/connectomics/config/profiles/loss_profiles.yaml b/connectomics/config/profiles/loss_profiles.yaml index b04cc604..56c5b29a 100644 --- a/connectomics/config/profiles/loss_profiles.yaml +++ b/connectomics/config/profiles/loss_profiles.yaml @@ -7,6 +7,11 @@ loss_profiles: - function: DiceLoss weight: 1.0 kwargs: {sigmoid: true, smooth_nr: 1e-5, smooth_dr: 1e-5} + loss_soft_cldice: + losses: + - function: SoftClDiceLoss + weight: 1.0 + kwargs: {mode: binary, num_iters: 5, sigmoid: true} loss_bcd: losses: - function: WeightedBCEWithLogitsLoss diff --git a/connectomics/models/losses/__init__.py b/connectomics/models/losses/__init__.py index b3f73b73..280d0760 100755 --- a/connectomics/models/losses/__init__.py +++ b/connectomics/models/losses/__init__.py @@ -43,10 +43,10 @@ "get_loss_metadata", "get_loss_metadata_for_module", # Custom losses + "GANLoss", "SoftClDiceLoss", - "WeightedMSELoss", "WeightedMAELoss", - "GANLoss", + "WeightedMSELoss", # Regularization losses "BinaryRegularization", "ForegroundDistanceConsistency", diff --git a/connectomics/models/losses/losses.py b/connectomics/models/losses/losses.py index aad2c58f..07e8d7f9 100644 --- a/connectomics/models/losses/losses.py +++ b/connectomics/models/losses/losses.py @@ -47,6 +47,7 @@ def _reduce_weighted_tensor( def _soft_erode_pool(prob: torch.Tensor) -> torch.Tensor: """Differentiable morphological erosion using min-pool via max-pool.""" if prob.ndim == 5: + # Use an axis-aligned cross (3x1x1, 1x3x1, 1x1x3), matching clDice-style soft morphology. p1 = -F.max_pool3d(-prob, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0)) p2 = -F.max_pool3d(-prob, kernel_size=(1, 3, 1), stride=1, padding=(0, 1, 0)) p3 = -F.max_pool3d(-prob, kernel_size=(1, 1, 3), stride=1, padding=(0, 0, 1)) @@ -354,7 +355,24 @@ class SoftClDiceLoss(nn.Module): """ Soft clDice loss using differentiable skeletonization. - This loss expects probability maps (sigmoid/softmax outputs), not logits. + Supports optional activation on logits (`sigmoid=True` or `softmax=True`), + MONAI-style, before topology computation. + Targets must be dense maps (one-hot or soft labels); single-channel class-index + targets are accepted and converted to one-hot for multi-class predictions. + + Args: + num_iters: Number of soft-skeleton erosion iterations. + mode: ``"binary"`` (single foreground channel) or ``"multi"`` (all classes except + ``background_index``). + reduction: ``"none"``, ``"mean"``, or ``"sum"``. + smooth: Smoothing constant in topology precision/sensitivity fractions. + foreground_channel: Foreground channel index used in ``mode="binary"`` when C>1. + background_index: Background channel index excluded in ``mode="multi"``. + sigmoid: Apply sigmoid activation to predictions inside ``forward``. + softmax: Apply softmax activation to predictions inside ``forward``. + clamp_probabilities: Clamp predictions/targets to ``[0, 1]`` before skeletonization. + validate_inputs: When True, enforce probability-range and minimal spatial-size checks. + validation_tolerance: Tolerance used by probability-range checks. """ def __init__( @@ -362,11 +380,14 @@ def __init__( num_iters: int = 5, mode: str = "binary", reduction: str = "mean", - smooth: float = 1e-6, + smooth: float = 1.0, foreground_channel: int = 1, background_index: int = 0, + sigmoid: bool = False, + softmax: bool = False, clamp_probabilities: bool = False, - use_fused_cuda: bool = False, + validate_inputs: bool = True, + validation_tolerance: float = 1e-5, ): super().__init__() if num_iters < 0: @@ -377,6 +398,10 @@ def __init__( raise ValueError(f"reduction must be 'none', 'mean', or 'sum', got {reduction!r}") if smooth <= 0: raise ValueError(f"smooth must be > 0, got {smooth}") + if sigmoid and softmax: + raise ValueError("sigmoid and softmax are mutually exclusive") + if validation_tolerance < 0: + raise ValueError(f"validation_tolerance must be >= 0, got {validation_tolerance}") self.num_iters = int(num_iters) self.mode = mode @@ -384,8 +409,11 @@ def __init__( self.smooth = float(smooth) self.foreground_channel = int(foreground_channel) self.background_index = int(background_index) + self.sigmoid = bool(sigmoid) + self.softmax = bool(softmax) self.clamp_probabilities = bool(clamp_probabilities) - self.use_fused_cuda = bool(use_fused_cuda) + self.validate_inputs = bool(validate_inputs) + self.validation_tolerance = float(validation_tolerance) def _prepare_target(self, target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor: if target.ndim == pred.ndim - 1: @@ -420,6 +448,38 @@ def _prepare_target(self, target: torch.Tensor, pred: torch.Tensor) -> torch.Ten f"target_channels={target.shape[1]}, pred_channels={pred.shape[1]}" ) + def _apply_activation(self, pred: torch.Tensor) -> torch.Tensor: + if self.sigmoid: + return torch.sigmoid(pred) + if self.softmax: + if pred.shape[1] < 2: + raise ValueError("softmax=True requires prediction with at least 2 channels") + return F.softmax(pred, dim=1) + return pred + + def _validate_probability_range(self, tensor: torch.Tensor, name: str) -> None: + if not self.validate_inputs: + return + tol = self.validation_tolerance + min_val = float(tensor.min().item()) + max_val = float(tensor.max().item()) + if min_val < -tol or max_val > (1.0 + tol): + raise ValueError( + f"{name} must be probabilities in [0, 1] (tolerance={tol}), " + f"got min={min_val:.6f}, max={max_val:.6f}. " + "Pass sigmoid=True/softmax=True for logits." + ) + + def _validate_spatial_shape(self, pred: torch.Tensor) -> None: + if not self.validate_inputs: + return + spatial_shape = tuple(pred.shape[2:]) + if any(dim < 3 for dim in spatial_shape): + raise ValueError( + "SoftClDiceLoss expects each spatial dimension >= 3 for stable morphology, " + f"got spatial shape {spatial_shape}" + ) + def _select_foreground_channels( self, pred: torch.Tensor, target: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, List[int]]: @@ -428,7 +488,8 @@ def _select_foreground_channels( fg_idx = 0 if channels == 1 else self.foreground_channel if fg_idx < 0 or fg_idx >= channels: raise ValueError( - f"foreground_channel={self.foreground_channel} is invalid for {channels} channels" + f"foreground_channel={self.foreground_channel} is invalid " + f"for {channels} channels" ) return pred[:, fg_idx : fg_idx + 1], target[:, fg_idx : fg_idx + 1], [fg_idx] @@ -485,31 +546,31 @@ def _prepare_weight( return torch.index_select(weight, dim=1, index=index_tensor) raise ValueError( - "Weight channel count must be 1, foreground-channel count, or prediction-channel count; " + "Weight channel count must be 1, foreground-channel count, " + "or prediction-channel count; " f"got {weight.shape[1]}" ) - def _soft_skeletonize(self, prob: torch.Tensor) -> torch.Tensor: - if self.use_fused_cuda and prob.is_cuda: - connectomics_ops = getattr(torch.ops, "connectomics", None) - fused_op = getattr(connectomics_ops, "soft_skeletonize", None) - if fused_op is not None: - try: - return fused_op(prob, self.num_iters) - except (RuntimeError, TypeError): - pass - return _soft_skeletonize_pool(prob, self.num_iters) - def forward( self, pred: torch.Tensor, target: torch.Tensor, - weight: torch.Tensor = None, + weight: torch.Tensor | None = None, ) -> torch.Tensor: if pred.ndim not in {4, 5}: raise ValueError(f"SoftClDiceLoss expects 4D/5D tensors, got {tuple(pred.shape)}") + pred = self._apply_activation(pred) target = self._prepare_target(target, pred) + + if self.clamp_probabilities: + pred = pred.clamp(0.0, 1.0) + target = target.clamp(0.0, 1.0) + + self._validate_spatial_shape(pred) + self._validate_probability_range(pred, "pred") + self._validate_probability_range(target, "target") + pred_fg, target_fg, foreground_indices = self._select_foreground_channels(pred, target) if self.clamp_probabilities: @@ -522,8 +583,8 @@ def forward( if self.clamp_probabilities: fg_weight = fg_weight.clamp_min(0.0) - pred_skeleton = self._soft_skeletonize(pred_fg) - target_skeleton = self._soft_skeletonize(target_fg) + pred_skeleton = _soft_skeletonize_pool(pred_fg, self.num_iters) + target_skeleton = _soft_skeletonize_pool(target_fg, self.num_iters) if fg_weight is not None: pred_eval = pred_fg * fg_weight @@ -725,11 +786,11 @@ def forward( __all__ = [ "CrossEntropyLossWrapper", - "WeightedBCEWithLogitsLoss", + "GANLoss", "PerChannelBCEWithLogitsLoss", + "SmoothL1Loss", "SoftClDiceLoss", - "WeightedMSELoss", + "WeightedBCEWithLogitsLoss", "WeightedMAELoss", - "SmoothL1Loss", - "GANLoss", + "WeightedMSELoss", ] diff --git a/tests/unit/test_loss_functions.py b/tests/unit/test_loss_functions.py index 74ada045..856249ca 100755 --- a/tests/unit/test_loss_functions.py +++ b/tests/unit/test_loss_functions.py @@ -61,6 +61,79 @@ def test_soft_cldice_multi_excludes_background(self): self.assertTrue(torch.isfinite(loss)) self.assertTrue(loss >= 0.0) + def test_soft_cldice_rejects_logits_without_activation(self): + """Soft clDice should fail fast on logits when no activation is requested.""" + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=1, validate_inputs=True) + logits = torch.randn(1, 1, 4, 8, 8) + target = (torch.rand(1, 1, 4, 8, 8) > 0.5).float() + + with self.assertRaisesRegex(ValueError, "must be probabilities in \\[0, 1\\]"): + _ = loss_fn(logits, target) + + def test_soft_cldice_accepts_logits_with_sigmoid(self): + """Soft clDice should support logits when sigmoid=True.""" + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=2, sigmoid=True) + logits = torch.randn(1, 1, 4, 8, 8) + target = (torch.rand(1, 1, 4, 8, 8) > 0.5).float() + + loss = loss_fn(logits, target) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_clamp_probabilities_allows_out_of_range_inputs(self): + """Clamping should tolerate out-of-range values when validate_inputs is enabled.""" + loss_fn = create_loss( + "SoftClDiceLoss", + mode="binary", + num_iters=1, + clamp_probabilities=True, + validate_inputs=True, + ) + pred = torch.randn(1, 1, 4, 8, 8) * 3.0 # outside [0, 1] + target = (torch.randn(1, 1, 4, 8, 8) * 2.0) + 0.5 + + loss = loss_fn(pred, target) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_backward_produces_finite_gradients(self): + """Soft clDice should backpropagate finite gradients.""" + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=2, sigmoid=True) + logits = torch.randn(2, 1, 4, 8, 8, requires_grad=True) + target = (torch.rand(2, 1, 4, 8, 8) > 0.5).float() + + loss = loss_fn(logits, target) + loss.backward() + + self.assertIsNotNone(logits.grad) + self.assertTrue(torch.all(torch.isfinite(logits.grad))) + + def test_soft_cldice_weight_channel_routing(self): + """Soft clDice should handle weight maps with full channel count in multi mode.""" + loss_fn = create_loss("SoftClDiceLoss", mode="multi", num_iters=1) + + pred = torch.softmax(torch.randn(1, 3, 4, 8, 8), dim=1) + labels = torch.randint(0, 3, (1, 4, 8, 8)) + target = F.one_hot(labels, num_classes=3).movedim(-1, 1).float() + weight = torch.rand(1, 3, 4, 8, 8) + + loss = loss_fn(pred, target, weight=weight) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + + def test_soft_cldice_weight_single_channel_broadcast(self): + """Soft clDice should broadcast a single-channel weight map.""" + loss_fn = create_loss("SoftClDiceLoss", mode="multi", num_iters=1) + + pred = torch.softmax(torch.randn(1, 3, 4, 8, 8), dim=1) + labels = torch.randint(0, 3, (1, 4, 8, 8)) + target = F.one_hot(labels, num_classes=3).movedim(-1, 1).float() + weight = torch.rand(1, 1, 4, 8, 8) + + loss = loss_fn(pred, target, weight=weight) + self.assertTrue(torch.isfinite(loss)) + self.assertTrue(loss >= 0.0) + if __name__ == "__main__": unittest.main() diff --git a/tests/unit/test_loss_orchestrator.py b/tests/unit/test_loss_orchestrator.py index af073943..af22104b 100644 --- a/tests/unit/test_loss_orchestrator.py +++ b/tests/unit/test_loss_orchestrator.py @@ -176,6 +176,34 @@ def test_weighted_bce_rejects_unknown_loss_kwargs(): create_loss("WeightedBCEWithLogitsLoss", unknown_kwarg=True) +def test_soft_cldice_with_sigmoid_runs_through_orchestrator(): + loss_fn = create_loss("SoftClDiceLoss", mode="binary", num_iters=1, sigmoid=True) + orchestrator = LossOrchestrator( + cfg=_cfg( + losses=[ + { + "weight": 1.0, + }, + ] + ), + loss_functions=nn.ModuleList([loss_fn]), + loss_weights=[1.0], + enable_nan_detection=False, + debug_on_nan=False, + ) + + outputs = torch.randn(1, 1, 4, 8, 8, requires_grad=True) + labels = (torch.rand(1, 1, 4, 8, 8) > 0.5).float() + + total_loss, loss_dict = orchestrator.compute_standard_loss(outputs, labels, stage="train") + total_loss.backward() + + assert torch.isfinite(total_loss) + assert "train_loss_total" in loss_dict + assert outputs.grad is not None + assert torch.all(torch.isfinite(outputs.grad)) + + def test_loss_orchestrator_requires_explicit_losses(): with pytest.raises(ValueError, match="model\\.loss\\.losses is required"): LossOrchestrator(