
Add SoftClDiceLoss and register it in the loss stack #203

Open

bogenc wants to merge 2 commits into PytorchConnectomics:master from bogenc:master

Conversation


@bogenc bogenc commented May 12, 2026

This PR introduces SoftClDiceLoss, a new segmentation loss that preserves the connectivity of thin, tubular structures via soft skeletonization, with support for binary and multi-class modes and efficient tensor operations. The loss is fully integrated through the loss factory, metadata table, and module exports, and validated with unit tests covering both loss behavior and orchestration wiring.
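For context, this is presumably based on the clDice objective of Shit et al. (CVPR 2021), where S_P and S_L are soft skeletons of the prediction V_P and label V_L; in the soft version, intersections are elementwise products, |·| is a voxel-wise sum, and each ratio is stabilized by a smooth term:

```math
T_{prec} = \frac{|S_P \cap V_L|}{|S_P|}, \qquad
T_{sens} = \frac{|S_L \cap V_P|}{|S_L|}, \qquad
\mathcal{L}_{clDice} = 1 - \frac{2\,T_{prec}\,T_{sens}}{T_{prec} + T_{sens}}
```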

@donglaiw donglaiw (Collaborator) left a comment


Code Review

Overview

Adds SoftClDiceLoss (skeleton-aware Dice using differentiable morphology), registers it in the loss factory + metadata table, and adds two unit tests plus one orchestrator-metadata assertion. Net +286/-1.

High-impact concerns

1. Logits-vs-probabilities mismatch will silently produce garbage. The class docstring states it "expects probability maps (sigmoid/softmax outputs), not logits", but LossOrchestrator passes raw model outputs to all losses (see orchestrator.py:497–547 — no activation step). Default training stacks use BCEWithLogitsLoss + MONAI DiceLoss(sigmoid=True), so model output is logits. A user composing SoftClDiceLoss next to BCEWithLogitsLoss will:

  • feed unbounded values into _soft_skeletonize_pool,
  • get a meaningless "skeleton" (max-pool of negatives, ReLU clipping),
  • and get no warning.

Recommendation: add sigmoid: bool = False / softmax: bool = False constructor flags matching MONAI's DiceLoss convention, applied inside forward before skeletonization. At minimum, raise if pred.min() < 0 or pred.max() > 1 (gated by a validate_inputs=True flag for dev runs).
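A minimal sketch of that recommendation, using the flag names proposed above (illustrative, not the actual diff):

```python
import torch

def activate_and_validate(pred: torch.Tensor, sigmoid: bool = False,
                          softmax: bool = False,
                          validate_inputs: bool = True) -> torch.Tensor:
    """Apply an optional activation, then verify pred is a probability map."""
    if sigmoid and softmax:
        raise ValueError("sigmoid and softmax are mutually exclusive")
    if sigmoid:
        pred = torch.sigmoid(pred)
    elif softmax:
        pred = torch.softmax(pred, dim=1)  # channel dim for NCHW / NCDHW
    if validate_inputs and (pred.min() < 0 or pred.max() > 1):
        raise ValueError("SoftClDiceLoss expects probabilities in [0, 1]; "
                         "got raw logits? Set sigmoid=True or softmax=True.")
    return pred
```

Called at the top of forward, this fails fast instead of skeletonizing unbounded logits.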

2. use_fused_cuda is dead code. It looks for torch.ops.connectomics.soft_skeletonize, which doesn't exist anywhere in this repo, and silently falls back. Per CLAUDE.md (no speculative configurability) — drop the flag and the _soft_skeletonize indirection until the fused op actually lands. It's also untestable as written.

Correctness / API

  • losses.py weight: torch.Tensor = None — should be Optional[torch.Tensor]. Other losses in this file follow that convention.
  • smooth=1e-6 default is unusually small for clDice (paper / reference impls use ~1.0). With near-empty skeletons this is numerically jittery. Consider 1.0 or 1e-3.
  • _prepare_target ambiguity (multi mode): if a caller passes a class-index map shaped (B, C, H, W) with C == pred.shape[1] but containing class indices (not one-hot), the code treats them as probabilities. Orchestrator metadata enforces target_kind="dense", so mostly theoretical, but the docstring should say the loss requires dense one-hot/probability targets.
  • 3D erosion uses an axis-aligned cross (3×1×1, 1×3×1, 1×1×3) min. Matches the original clDice formulation (see the sketch after this list), but worth a one-line comment so a future maintainer doesn't "fix" it to a 3×3×3 kernel.
  • No size guard for spatial dims < 3. max_pool3d(kernel=3, padding=1) on a depth-1 patch will run but produce no shrinkage; for size-2 dims the result is undefined-ish. Add a shape assert or document the minimum spatial size.
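For reference on the erosion bullet above, the pooled differentiable morphology from the original clDice reference implementation looks roughly like this (tensors assumed to be NCDHW probability maps):

```python
import torch
import torch.nn.functional as F

def soft_erode(img: torch.Tensor) -> torch.Tensor:
    # Axis-aligned cross: min of three 1-D min-pools (min-pool = -max_pool(-x)).
    p1 = -F.max_pool3d(-img, (3, 1, 1), (1, 1, 1), (1, 0, 0))
    p2 = -F.max_pool3d(-img, (1, 3, 1), (1, 1, 1), (0, 1, 0))
    p3 = -F.max_pool3d(-img, (1, 1, 3), (1, 1, 1), (0, 0, 1))
    return torch.min(torch.min(p1, p2), p3)

def soft_dilate(img: torch.Tensor) -> torch.Tensor:
    return F.max_pool3d(img, (3, 3, 3), (1, 1, 1), (1, 1, 1))

def soft_open(img: torch.Tensor) -> torch.Tensor:
    return soft_dilate(soft_erode(img))

def soft_skel(img: torch.Tensor, num_iters: int) -> torch.Tensor:
    # Iteratively peel the volume; accumulate what opening removes at each scale.
    skel = F.relu(img - soft_open(img))
    for _ in range(num_iters):
        img = soft_erode(img)
        delta = F.relu(img - soft_open(img))
        skel = skel + F.relu(delta - skel * delta)
    return skel
```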

Style / project convention

  • Unused import: Union in from typing import List, Union — only List is used.
  • Class docstring is one sentence; other losses in this file have multi-line arg docs. At minimum document mode, foreground_channel, background_index, clamp_probabilities, and the activation expectation.
  • __all__ ordering in losses.py (SoftClDiceLoss between PerChannelBCEWithLogitsLoss and WeightedMSELoss) and the alphabetical order in losses/__init__.py (between GANLoss and WeightedMAELoss) are inconsistent. Pick one.
  • No tutorial / docs update. CLAUDE.md "Adding New Loss Functions" explicitly asks for "Update documentation". A short example YAML snippet (or a row in the docs table) would let agents discover this loss without grepping.

Test coverage

The added tests verify forward runs and returns finite ≥0 values for binary + multi. Gaps that matter for a new loss in this codebase:

  • No backward / grad test (loss.backward() then check pred.grad is finite; see the sketch after this list). Skeletonization is iterative; a regression breaking grad would slip past these tests.
  • No weight-arg test — the _prepare_weight branches (channels=1, channels=fg, channels=pred) are entirely uncovered.
  • No logits-input test showing what the loss does when fed unbounded values (this is the most likely user mistake).
  • No clamp_probabilities=True test.
  • No deep-supervision integration test through LossOrchestrator with SoftClDiceLoss to confirm metadata wiring lines up at non-main scales.
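A sketch of the missing backward test referenced in the first bullet above (import path and constructor args are assumptions, not taken from the diff):

```python
import torch
# Hypothetical import path; adjust to wherever the loss is actually exported.
from connectomics.model.loss import SoftClDiceLoss

def test_soft_cldice_backward_finite_grads():
    loss_fn = SoftClDiceLoss(mode="binary", sigmoid=True)  # args illustrative
    pred = torch.randn(2, 1, 8, 16, 16, requires_grad=True)  # raw logits
    target = (torch.rand(2, 1, 8, 16, 16) > 0.5).float()
    loss = loss_fn(pred, target)
    loss.backward()
    assert torch.isfinite(loss).all()
    assert pred.grad is not None and torch.isfinite(pred.grad).all()
```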

Performance note

_soft_skeletonize_pool runs ~3*(num_iters+1) 3D max-pools per call, on both pred and target. Targets are static across an epoch — caching the target skeleton (or computing it once on CPU in the dataloader) is a substantial win for typical num_iters=5. Not blocking, but worth a TODO.
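One possible shape for that TODO, caching per-sample target skeletons in the dataset (names hypothetical, reusing the soft_skel sketch above; only valid when targets are not spatially augmented per epoch):

```python
import torch

class SkeletonCachingDataset(torch.utils.data.Dataset):
    """Wraps a dataset and memoizes the static target skeleton per sample."""

    def __init__(self, base, num_iters: int = 5):
        self.base = base
        self.num_iters = num_iters
        self._cache = {}

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, target = self.base[idx]  # target assumed (C, D, H, W), in [0, 1]
        if idx not in self._cache:
            with torch.no_grad():  # target skeleton needs no gradients
                self._cache[idx] = soft_skel(target[None], self.num_iters)[0]
        return img, target, self._cache[idx]
```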

Verdict

Functionally plausible and well-isolated, but the logits-vs-probabilities silent failure mode and the dead use_fused_cuda path should be addressed before merge. Tests need a backward + weight-path case to be load-bearing.

@donglaiw donglaiw (Collaborator) commented

Thank you for the pull request! You may find this repo useful for letting Codex and Claude Code collaborate: https://github.com/donglaiw/ccc-duet

The updated SoftClDiceLoss now includes MONAI-style activation options (sigmoid or softmax, mutually exclusive), input validation for probability ranges and spatial size, and a higher default smoothing value (1.0) to stabilize behavior when skeletons are nearly empty.
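Illustrative usage under the new flags (signature inferred from this description, not checked against the diff):

```python
import torch
# Hypothetical import path for illustration only.
from connectomics.model.loss import SoftClDiceLoss

logits = torch.randn(2, 1, 8, 16, 16)  # raw model output; shapes illustrative
target = (torch.rand(2, 1, 8, 16, 16) > 0.5).float()

loss_fn = SoftClDiceLoss(sigmoid=True, validate_inputs=True, smooth=1.0)
loss = loss_fn(logits, target)  # sigmoid applied internally before skeletonization
```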

Removed an unnecessary fallback path by deleting the unused fused CUDA option (use_fused_cuda / torch.ops.connectomics.soft_skeletonize). The implementation now consistently relies on the pooled differentiable morphology approach as the single, canonical method.

Improved clarity throughout the code by documenting that targets are expected to be dense, explaining argument behavior, and noting that 3D erosion intentionally uses axis-aligned cross-shaped kernels. Typing was tightened (weight: torch.Tensor | None), and export ordering was cleaned up.

Expanded test coverage to reflect realistic failure cases and integration behavior. This includes rejecting logits without activation, accepting them when sigmoid is enabled, clamping out-of-range inputs, verifying backward-pass stability with finite gradients, checking weight handling (both per-channel and broadcast), and validating SoftClDiceLoss integration through the LossOrchestrator.

Added a new configuration example (loss_soft_cldice) to improve discoverability, located in connectomics/config/profiles/loss_profiles.yaml.
@bogenc bogenc (Author) commented May 17, 2026

Hi, thanks for the review. You were right to point out the logits-vs-probabilities issue and the dead fused CUDA path; both were real problems. I've pushed updates addressing the concerns:

  • Made the loss safe in the logits-first training flow by adding MONAI-style sigmoid / softmax flags (mutually exclusive) and optional input validation. It now explicitly rejects raw logits unless an activation is enabled.
  • Removed the unused use_fused_cuda path and kept the pooled differentiable morphology as the single, tested implementation.
  • Switched smooth to 1.0 for stability with sparse skeletons.
  • Clarified the API and expectations (dense targets, argument behavior, 3D cross-shaped erosion note), and cleaned up typing/imports/ordering.

On the testing side, I expanded coverage to hit the gaps you mentioned:

  • logits rejection + correct behavior with sigmoid=True
  • clamp path
  • backward pass with finite gradients
  • weight routing branches
  • orchestrator integration path

Also added a config profile example so the loss is easier to discover/use.

Let me know if anything still feels off or if you'd like stricter validation defaults or different activation behavior.
