@@ -35,8 +35,9 @@ class SurfaceDiceMetric(CumulativeIterationMetric):
3535 Computes the Normalized Surface Dice (NSD) for each batch sample and class of
3636 predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`.
3737 This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
38- Be aware that the computation of boundaries is different from DeepMind's implementation
39- https://github.com/deepmind/surface-distance. In this implementation, the length/area of a segmentation boundary is
38+ Be aware that by default (`use_subvoxels=False`), the computation of boundaries is different from DeepMind's
39+ mplementation https://github.com/deepmind/surface-distance.
40+ In this implementation, the length/area of a segmentation boundary is
4041 interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
4142 depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
4243 This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103.
@@ -86,7 +87,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
8687 It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
8788 y: Reference segmentation.
8889 It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
89- kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric.
90+ kwargs: additional parameters: ``spacing`` should be passed to correctly compute the metric.
9091 ``spacing``: spacing of pixel (or voxel). This parameter is relevant only
9192 if ``distance_metric`` is set to ``"euclidean"``.
9293 If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,
@@ -96,6 +97,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
9697 If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
9798 else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
9899 for all images in batch. Defaults to ``None``.
100+ use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``.
101+
99102
100103 Returns:
101104 Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch
@@ -108,6 +111,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
108111 include_background = self .include_background ,
109112 distance_metric = self .distance_metric ,
110113 spacing = kwargs .get ("spacing" ),
114+ use_subvoxels = kwargs .get ("use_subvoxels" , False ),
111115 )
112116
113117 def aggregate (
@@ -141,13 +145,14 @@ def compute_surface_dice(
141145 include_background : bool = False ,
142146 distance_metric : str = "euclidean" ,
143147 spacing : int | float | np .ndarray | Sequence [int | float | np .ndarray | Sequence [int | float ]] | None = None ,
148+ use_subvoxels : bool = False ,
144149) -> torch .Tensor :
145150 r"""
146151 This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as
147152 :math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation
148153 boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the
149- reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation in
150- pixels. The NSD is bounded between 0 and 1.
154+ reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation
155+ in pixels. The NSD is bounded between 0 and 1.
151156
152157 This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`.
153158 The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function:
@@ -159,24 +164,23 @@ def compute_surface_dice(
159164 :label: nsd
160165
161166 with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor
162- distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference segmentation
163- boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
167+ distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference
168+ segmentation boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
164169 :math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the
165170 acceptable distance :math:`\tau_c`:
166171
167172 .. math::
168173 \mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}.
169174
170175
171- In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan value
172- will be returned for this class. In the case of a class being present in only one of predicted segmentation or
173- reference segmentation, the class NSD will be 0.
176+ In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation,
177+ a nan value will be returned for this class. In the case of a class being present in only one of predicted
178+ segmentation or reference segmentation, the class NSD will be 0.
174179
175180 This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
176- Be aware that the computation of boundaries is different from DeepMind's implementation
177- https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
178- interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
179- depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
181+ The computation of boundaries follows DeepMind's implementation
182+ https://github.com/deepmind/surface-distance when `use_subvoxels=True`; Otherwise the length of a segmentation
183+ boundary is interpreted as the number of its edge pixels.
180184
181185 Args:
182186 y_pred: Predicted segmentation, typically segmentation model output.
@@ -198,6 +202,7 @@ def compute_surface_dice(
198202 If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
199203 else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
200204 for all images in batch. Defaults to ``None``.
205+ use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``.
201206
202207 Raises:
203208 ValueError: If `y_pred` and/or `y` are not PyTorch tensors.
@@ -227,11 +232,6 @@ def compute_surface_dice(
227232 f"y_pred and y should have same shape, but instead, shapes are { y_pred .shape } (y_pred) and { y .shape } (y)."
228233 )
229234
230- if not torch .all (y_pred .byte () == y_pred ) or not torch .all (y .byte () == y ):
231- raise ValueError ("y_pred and y should be binarized tensors (e.g. torch.int64)." )
232- if torch .any (y_pred > 1 ) or torch .any (y > 1 ):
233- raise ValueError ("y_pred and y should be one-hot encoded." )
234-
235235 y = y .float ()
236236 y_pred = y_pred .float ()
237237
@@ -254,24 +254,37 @@ def compute_surface_dice(
254254 spacing_list = prepare_spacing (spacing = spacing , batch_size = batch_size , img_dim = img_dim )
255255
256256 for b , c in np .ndindex (batch_size , n_class ):
257- (edges_pred , edges_gt ) = get_mask_edges (y_pred [b , c ], y [b , c ], crop = False )
257+ if not use_subvoxels :
258+ (edges_pred , edges_gt ) = get_mask_edges (y_pred [b , c ], y [b , c ], crop = True )
259+ distances_pred_gt = get_surface_distance (
260+ edges_pred , edges_gt , distance_metric = distance_metric , spacing = spacing_list [b ]
261+ )
262+ distances_gt_pred = get_surface_distance (
263+ edges_gt , edges_pred , distance_metric = distance_metric , spacing = spacing_list [b ]
264+ )
265+
266+ boundary_complete = len (distances_pred_gt ) + len (distances_gt_pred )
267+ boundary_correct = np .sum (distances_pred_gt <= class_thresholds [c ]) + np .sum (
268+ distances_gt_pred <= class_thresholds [c ]
269+ )
270+ else :
271+ _spacing = spacing_list [b ] if spacing_list [b ] is not None else [1 ] * img_dim
272+ areas_pred : np .ndarray
273+ areas_gt : np .ndarray
274+ edges_pred , edges_gt , areas_pred , areas_gt = get_mask_edges ( # type: ignore
275+ y_pred [b , c ], y [b , c ], crop = True , spacing = _spacing # type: ignore
276+ )
277+ dist_pred_to_gt = get_surface_distance (edges_pred , edges_gt , distance_metric , spacing = spacing_list [b ])
278+ dist_gt_to_pred = get_surface_distance (edges_gt , edges_pred , distance_metric , spacing = spacing_list [b ])
279+ areas_gt , areas_pred = areas_gt [edges_gt ], areas_pred [edges_pred ]
280+ boundary_complete = areas_gt .sum () + areas_pred .sum ()
281+ gt_true = areas_gt [dist_gt_to_pred <= class_thresholds [c ]].sum () if len (areas_gt ) > 0 else 0.0
282+ pred_true = areas_pred [dist_pred_to_gt <= class_thresholds [c ]].sum () if len (areas_pred ) > 0 else 0.0
283+ boundary_correct = gt_true + pred_true
258284 if not np .any (edges_gt ):
259285 warnings .warn (f"the ground truth of class { c } is all 0, this may result in nan/inf distance." )
260286 if not np .any (edges_pred ):
261287 warnings .warn (f"the prediction of class { c } is all 0, this may result in nan/inf distance." )
262-
263- distances_pred_gt = get_surface_distance (
264- edges_pred , edges_gt , distance_metric = distance_metric , spacing = spacing_list [b ]
265- )
266- distances_gt_pred = get_surface_distance (
267- edges_gt , edges_pred , distance_metric = distance_metric , spacing = spacing_list [b ]
268- )
269-
270- boundary_complete = len (distances_pred_gt ) + len (distances_gt_pred )
271- boundary_correct = np .sum (distances_pred_gt <= class_thresholds [c ]) + np .sum (
272- distances_gt_pred <= class_thresholds [c ]
273- )
274-
275288 if boundary_complete == 0 :
276289 # the class is neither present in the prediction, nor in the reference segmentation
277290 nsd [b , c ] = np .nan
0 commit comments