Skip to content

Commit 5ce8a10

Browse files
authored
more efficient Dice metrics for large num_class (#6163)
- mostly from @myron's implementation - remove `is_binary_tensor` checks, they are too expensive. - remove the deprecated `compute_meandice` (use `compute_dice` instead) ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <[email protected]>
1 parent 66d0478 commit 5ce8a10

File tree

12 files changed

+150
-97
lines changed

12 files changed

+150
-97
lines changed

docs/source/metrics.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,12 @@ Metrics
5353

5454
`Mean Dice`
5555
-----------
56-
.. autofunction:: compute_meandice
57-
5856
.. autoclass:: DiceMetric
5957
:members:
6058

59+
.. autoclass:: DiceHelper
60+
:members:
61+
6162
`Mean IoU`
6263
----------
6364
.. autofunction:: compute_meaniou

monai/handlers/mean_dice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(
4848
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
4949
5050
See also:
51-
:py:meth:`monai.metrics.meandice.compute_meandice`
51+
:py:meth:`monai.metrics.meandice.compute_dice`
5252
"""
5353
metric_fn = DiceMetric(include_background=include_background, reduction=reduction)
5454
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)

monai/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from .generalized_dice import GeneralizedDiceScore, compute_generalized_dice
2020
from .hausdorff_distance import HausdorffDistanceMetric, compute_hausdorff_distance, compute_percent_hausdorff_distance
2121
from .loss_metric import LossMetric
22-
from .meandice import DiceMetric, compute_dice, compute_meandice
22+
from .meandice import DiceHelper, DiceMetric, compute_dice
2323
from .meaniou import MeanIoU, compute_iou, compute_meaniou
2424
from .metric import Cumulative, CumulativeIterationMetric, IterationMetric, Metric
2525
from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality

monai/metrics/confusion_matrix.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import torch
1818

19-
from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor
19+
from monai.metrics.utils import do_metric_reduction, ignore_background
2020
from monai.utils import MetricReduction, ensure_tuple
2121

2222
from .metric import CumulativeIterationMetric
@@ -85,12 +85,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
8585
y: ground truth to compute the metric. It must be one-hot format and first dim is batch.
8686
The values should be binarized.
8787
Raises:
88-
ValueError: when `y` is not a binarized tensor.
8988
ValueError: when `y_pred` has less than two dimensions.
9089
"""
91-
is_binary_tensor(y_pred, "y_pred")
92-
is_binary_tensor(y, "y")
93-
9490
# check dimension
9591
dims = y_pred.ndimension()
9692
if dims < 2:

monai/metrics/f_beta_score.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import torch
1717

18-
from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor
18+
from monai.metrics.utils import do_metric_reduction, ignore_background
1919
from monai.utils import MetricReduction
2020

2121
from .metric import CumulativeIterationMetric
@@ -36,9 +36,6 @@ def __init__(
3636
self.get_not_nans = get_not_nans
3737

3838
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
39-
is_binary_tensor(y_pred, "y_pred")
40-
is_binary_tensor(y, "y")
41-
4239
if y_pred.ndimension() < 2:
4340
raise ValueError("y_pred should have at least two dimensions.")
4441

monai/metrics/generalized_dice.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from monai.utils import MetricReduction, Weight, look_up_option
1818

1919
from .metric import CumulativeIterationMetric
20-
from .utils import is_binary_tensor
2120

2221

2322
class GeneralizedDiceScore(CumulativeIterationMetric):
@@ -73,8 +72,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
7372
y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
7473
7574
Raises:
76-
ValueError: if `y_pred` or `y` is not a binarized PyTorch tensor, if `y_pred` and `y` have less than
77-
three dimensions, or `y_pred` and `y` don't have the same shape.
75+
ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
7876
"""
7977
return compute_generalized_dice(
8078
y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
@@ -127,10 +125,6 @@ def compute_generalized_dice(
127125
ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
128126
or `y_pred` and `y` don't have the same shape.
129127
"""
130-
# Ensure tensors are binarized
131-
is_binary_tensor(y_pred, "y_pred")
132-
is_binary_tensor(y, "y")
133-
134128
# Ensure tensors have at least 3 dimensions and have the same shape
135129
dims = y_pred.dim()
136130
if dims < 3:

monai/metrics/hausdorff_distance.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,7 @@
1616
import numpy as np
1717
import torch
1818

19-
from monai.metrics.utils import (
20-
do_metric_reduction,
21-
get_mask_edges,
22-
get_surface_distance,
23-
ignore_background,
24-
is_binary_tensor,
25-
)
19+
from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background
2620
from monai.utils import MetricReduction, convert_data_type
2721

2822
from .metric import CumulativeIterationMetric
@@ -86,12 +80,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
8680
The values should be binarized.
8781
8882
Raises:
89-
ValueError: when `y` is not a binarized tensor.
9083
ValueError: when `y_pred` has less than three dimensions.
9184
"""
92-
is_binary_tensor(y_pred, "y_pred")
93-
is_binary_tensor(y, "y")
94-
9585
dims = y_pred.ndimension()
9686
if dims < 3:
9787
raise ValueError("y_pred should have at least three dimensions.")

monai/metrics/meandice.py

Lines changed: 126 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,26 @@
1313

1414
import torch
1515

16-
from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor
17-
from monai.utils import MetricReduction, deprecated
16+
from monai.metrics.utils import do_metric_reduction
17+
from monai.utils import MetricReduction
1818

1919
from .metric import CumulativeIterationMetric
2020

21+
__all__ = ["DiceMetric", "compute_dice", "DiceHelper"]
22+
2123

2224
class DiceMetric(CumulativeIterationMetric):
2325
"""
24-
Compute average Dice score between two tensors. It can support both multi-classes and multi-labels tasks.
26+
Compute average Dice score for a set of pairs of prediction-groundtruth segmentations.
27+
28+
It supports both multi-classes and multi-labels tasks.
2529
Input `y_pred` is compared with ground truth `y`.
26-
`y_preds` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms
27-
in ``monai.transforms.post`` first to achieve binarized values.
28-
The `include_background` parameter can be set to ``False`` to exclude
30+
`y_pred` is expected to have binarized predictions and `y` can be single-channel class indices or in the
31+
one-hot format. The `include_background` parameter can be set to ``False`` to exclude
2932
the first category (channel index 0) which is by convention assumed to be background. If the non-background
3033
segmentations are small compared to the total image size they can get overwhelmed by the signal from the
31-
background.
32-
`y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]).
34+
background. `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]),
35+
`y` can also be in the format of `B1HW[D]`.
3336
3437
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
3538
@@ -59,36 +62,37 @@ def __init__(
5962
self.reduction = reduction
6063
self.get_not_nans = get_not_nans
6164
self.ignore_empty = ignore_empty
65+
self.dice_helper = DiceHelper(
66+
include_background=self.include_background,
67+
reduction=MetricReduction.NONE,
68+
get_not_nans=False,
69+
softmax=False,
70+
ignore_empty=self.ignore_empty,
71+
)
6272

6373
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
6474
"""
6575
Args:
6676
y_pred: input data to compute, typical segmentation model output.
6777
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
6878
should be binarized.
69-
y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
70-
The values should be binarized.
79+
y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or
80+
in the one-hot format.
7181
7282
Raises:
73-
ValueError: when `y` is not a binarized tensor.
7483
ValueError: when `y_pred` has less than three dimensions.
7584
"""
76-
is_binary_tensor(y_pred, "y_pred")
77-
is_binary_tensor(y, "y")
78-
7985
dims = y_pred.ndimension()
8086
if dims < 3:
8187
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")
8288
# compute dice (BxC) for each channel for each batch
83-
return compute_dice(
84-
y_pred=y_pred, y=y, include_background=self.include_background, ignore_empty=self.ignore_empty
85-
)
89+
return self.dice_helper(y_pred=y_pred, y=y) # type: ignore
8690

8791
def aggregate(
8892
self, reduction: MetricReduction | str | None = None
8993
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
9094
"""
91-
Execute reduction logic for the output of `compute_meandice`.
95+
Execute reduction and aggregation logic for the output of `compute_dice`.
9296
9397
Args:
9498
reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
@@ -98,7 +102,7 @@ def aggregate(
98102
"""
99103
data = self.get_buffer()
100104
if not isinstance(data, torch.Tensor):
101-
raise ValueError("the data to aggregate must be PyTorch Tensor.")
105+
raise ValueError(f"the data to aggregate must be PyTorch Tensor, got {type(data)}.")
102106

103107
# do metric reduction
104108
f, not_nans = do_metric_reduction(data, reduction or self.reduction)
@@ -114,45 +118,124 @@ def compute_dice(
114118
y_pred: input data to compute, typical segmentation model output.
115119
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
116120
should be binarized.
117-
y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
118-
The values should be binarized.
121+
y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
119122
include_background: whether to skip Dice computation on the first channel of
120123
the predicted output. Defaults to True.
121124
ignore_empty: whether to ignore empty ground truth cases during calculation.
122125
If `True`, NaN value will be set for empty ground truth cases.
123126
If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
124127
125128
Returns:
126-
Dice scores per batch and per class, (shape [batch_size, num_classes]).
129+
Dice scores per batch and per class, (shape: [batch_size, num_classes]).
130+
131+
"""
132+
return DiceHelper( # type: ignore
133+
include_background=include_background,
134+
reduction=MetricReduction.NONE,
135+
get_not_nans=False,
136+
softmax=False,
137+
ignore_empty=ignore_empty,
138+
)(y_pred=y_pred, y=y)
127139

128-
Raises:
129-
ValueError: when `y_pred` and `y` have different shapes.
130140

141+
class DiceHelper:
131142
"""
143+
Compute Dice score between two tensors `y_pred` and `y`.
144+
`y_pred` must have N channels, `y` can be single-channel class indices or in the one-hot format.
132145
133-
if not include_background:
134-
y_pred, y = ignore_background(y_pred=y_pred, y=y)
146+
Example:
135147
136-
y = y.float()
137-
y_pred = y_pred.float()
148+
.. code-block:: python
138149
139-
if y.shape != y_pred.shape:
140-
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
150+
import torch
151+
from monai.metrics import DiceHelper
141152
142-
# reducing only spatial dimensions (not batch nor channels)
143-
n_len = len(y_pred.shape)
144-
reduce_axis = list(range(2, n_len))
145-
intersection = torch.sum(y * y_pred, dim=reduce_axis)
153+
n_classes, batch_size = 5, 16
154+
spatial_shape = (128, 128, 128)
146155
147-
y_o = torch.sum(y, reduce_axis)
148-
y_pred_o = torch.sum(y_pred, dim=reduce_axis)
149-
denominator = y_o + y_pred_o
156+
y_pred = torch.rand(batch_size, n_classes, *spatial_shape).float() # predictions
157+
y = torch.randint(0, n_classes, size=(batch_size, 1, *spatial_shape)).long() # ground truth
150158
151-
if ignore_empty:
152-
return torch.where(y_o > 0, (2.0 * intersection) / denominator, torch.tensor(float("nan"), device=y_o.device))
153-
return torch.where(denominator > 0, (2.0 * intersection) / denominator, torch.tensor(1.0, device=y_o.device))
159+
score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
160+
print(score, not_nans)
154161
162+
"""
155163

156-
@deprecated(since="1.0.0", msg_suffix="use `compute_dice` instead.")
157-
def compute_meandice(*args, **kwargs):
158-
return compute_dice(*args, **kwargs)
164+
def __init__(
165+
self,
166+
include_background: bool | None = None,
167+
sigmoid: bool = False,
168+
softmax: bool | None = None,
169+
activate: bool = False,
170+
get_not_nans: bool = True,
171+
reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
172+
ignore_empty: bool = True,
173+
) -> None:
174+
"""
175+
176+
Args:
177+
include_background: whether to skip the score on the first channel
178+
(default to the value of `sigmoid`, False).
179+
sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5
180+
will be performed to get the discrete prediction. Defaults to False.
181+
softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
182+
get the discrete prediction. Defaults to the value of ``not sigmoid``.
183+
activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False.
184+
This option is only valid when ``sigmoid`` is True.
185+
get_not_nans: whether to return the number of not-nan values.
186+
reduction: define mode of reduction to the metrics
187+
ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
188+
If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
189+
"""
190+
self.sigmoid = sigmoid
191+
self.reduction = reduction
192+
self.get_not_nans = get_not_nans
193+
self.include_background = sigmoid if include_background is None else include_background
194+
self.softmax = not sigmoid if softmax is None else softmax
195+
self.activate = activate
196+
self.ignore_empty = ignore_empty
197+
198+
def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
199+
""""""
200+
y_o = torch.sum(y)
201+
if y_o > 0:
202+
return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / (y_o + torch.sum(y_pred))
203+
if self.ignore_empty:
204+
return torch.tensor(float("nan"), device=y_o.device)
205+
denorm = y_o + torch.sum(y_pred)
206+
if denorm <= 0:
207+
return torch.tensor(1.0, device=y_o.device)
208+
return (2.0 * torch.sum(torch.masked_select(y, y_pred))) / denorm
209+
210+
def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
211+
"""
212+
213+
Args:
214+
y_pred: input predictions with shape (batch_size, num_classes, spatial_dims...).
215+
the number of channels is inferred from ``y_pred.shape[1]``.
216+
y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
217+
"""
218+
n_pred_ch = y_pred.shape[1]
219+
220+
if self.softmax:
221+
if n_pred_ch > 1:
222+
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
223+
224+
elif self.sigmoid:
225+
if self.activate:
226+
y_pred = torch.sigmoid(y_pred)
227+
y_pred = y_pred > 0.5
228+
229+
first_ch = 0 if self.include_background else 1
230+
data = []
231+
for b in range(y_pred.shape[0]):
232+
c_list = []
233+
for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
234+
x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
235+
x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
236+
c_list.append(self.compute_channel(x_pred, x))
237+
data.append(torch.stack(c_list))
238+
data = torch.stack(data, dim=0).contiguous() # type: ignore
239+
240+
f, not_nans = do_metric_reduction(data, self.reduction) # type: ignore
241+
return (f, not_nans) if self.get_not_nans else f

monai/metrics/meaniou.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torch
1515

16-
from monai.metrics.utils import do_metric_reduction, ignore_background, is_binary_tensor
16+
from monai.metrics.utils import do_metric_reduction, ignore_background
1717
from monai.utils import MetricReduction, deprecated
1818

1919
from .metric import CumulativeIterationMetric
@@ -71,12 +71,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
7171
The values should be binarized.
7272
7373
Raises:
74-
ValueError: when `y` is not a binarized tensor.
7574
ValueError: when `y_pred` has less than three dimensions.
7675
"""
77-
is_binary_tensor(y_pred, "y_pred")
78-
is_binary_tensor(y, "y")
79-
8076
dims = y_pred.ndimension()
8177
if dims < 3:
8278
raise ValueError(f"y_pred should have at least 3 dimensions (batch, channel, spatial), got {dims}.")

0 commit comments

Comments
 (0)