1313
1414import torch
1515
16- from monai .metrics .utils import do_metric_reduction , ignore_background , is_binary_tensor
17- from monai .utils import MetricReduction , deprecated
16+ from monai .metrics .utils import do_metric_reduction
17+ from monai .utils import MetricReduction
1818
1919from .metric import CumulativeIterationMetric
2020
21+ __all__ = ["DiceMetric" , "compute_dice" , "DiceHelper" ]
22+
2123
2224class DiceMetric (CumulativeIterationMetric ):
2325 """
24- Compute average Dice score between two tensors. It can support both multi-classes and multi-labels tasks.
26+ Compute average Dice score for a set of pairs of prediction-groundtruth segmentations.
27+
28+ It supports both multi-classes and multi-labels tasks.
2529 Input `y_pred` is compared with ground truth `y`.
26- `y_preds` is expected to have binarized predictions and `y` should be in one-hot format. You can use suitable transforms
27- in ``monai.transforms.post`` first to achieve binarized values.
28- The `include_background` parameter can be set to ``False`` to exclude
30+ `y_pred` is expected to have binarized predictions and `y` can be single-channel class indices or in the
31+ one-hot format. The `include_background` parameter can be set to ``False`` to exclude
2932 the first category (channel index 0) which is by convention assumed to be background. If the non-background
3033 segmentations are small compared to the total image size they can get overwhelmed by the signal from the
31- background.
32- `y_preds` and ` y` can be a list of channel-first Tensor (CHW [D]) or a batch-first Tensor (BCHW[D]) .
34+ background. `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]),
35+ `y` can also be in the format of `B1HW [D]` .
3336
3437 Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
3538
@@ -59,36 +62,37 @@ def __init__(
5962 self .reduction = reduction
6063 self .get_not_nans = get_not_nans
6164 self .ignore_empty = ignore_empty
65+ self .dice_helper = DiceHelper (
66+ include_background = self .include_background ,
67+ reduction = MetricReduction .NONE ,
68+ get_not_nans = False ,
69+ softmax = False ,
70+ ignore_empty = self .ignore_empty ,
71+ )
6272
6373 def _compute_tensor (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor : # type: ignore[override]
6474 """
6575 Args:
6676 y_pred: input data to compute, typical segmentation model output.
6777 It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
6878 should be binarized.
69- y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
70- The values should be binarized .
79+ y: ground truth to compute mean Dice metric. `y` can be single-channel class indices or
80+ in the one-hot format .
7181
7282 Raises:
73- ValueError: when `y` is not a binarized tensor.
7483 ValueError: when `y_pred` has less than three dimensions.
7584 """
76- is_binary_tensor (y_pred , "y_pred" )
77- is_binary_tensor (y , "y" )
78-
7985 dims = y_pred .ndimension ()
8086 if dims < 3 :
8187 raise ValueError (f"y_pred should have at least 3 dimensions (batch, channel, spatial), got { dims } ." )
8288 # compute dice (BxC) for each channel for each batch
83- return compute_dice (
84- y_pred = y_pred , y = y , include_background = self .include_background , ignore_empty = self .ignore_empty
85- )
89+ return self .dice_helper (y_pred = y_pred , y = y ) # type: ignore
8690
8791 def aggregate (
8892 self , reduction : MetricReduction | str | None = None
8993 ) -> torch .Tensor | tuple [torch .Tensor , torch .Tensor ]:
9094 """
91- Execute reduction logic for the output of `compute_meandice `.
95+ Execute reduction and aggregation logic for the output of `compute_dice `.
9296
9397 Args:
9498 reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
@@ -98,7 +102,7 @@ def aggregate(
98102 """
99103 data = self .get_buffer ()
100104 if not isinstance (data , torch .Tensor ):
101- raise ValueError ("the data to aggregate must be PyTorch Tensor." )
105+ raise ValueError (f "the data to aggregate must be PyTorch Tensor, got { type ( data ) } ." )
102106
103107 # do metric reduction
104108 f , not_nans = do_metric_reduction (data , reduction or self .reduction )
@@ -114,45 +118,124 @@ def compute_dice(
114118 y_pred: input data to compute, typical segmentation model output.
115119 It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
116120 should be binarized.
117- y: ground truth to compute mean dice metric. It must be one-hot format and first dim is batch.
118- The values should be binarized.
121+ y: ground truth to compute mean dice metric. `y` can be single-channel class indices or in the one-hot format.
119122 include_background: whether to skip Dice computation on the first channel of
120123 the predicted output. Defaults to True.
121124 ignore_empty: whether to ignore empty ground truth cases during calculation.
122125 If `True`, NaN value will be set for empty ground truth cases.
123126 If `False`, 1 will be set if the predictions of empty ground truth cases are also empty.
124127
125128 Returns:
126- Dice scores per batch and per class, (shape [batch_size, num_classes]).
129+ Dice scores per batch and per class, (shape: [batch_size, num_classes]).
130+
131+ """
132+ return DiceHelper ( # type: ignore
133+ include_background = include_background ,
134+ reduction = MetricReduction .NONE ,
135+ get_not_nans = False ,
136+ softmax = False ,
137+ ignore_empty = ignore_empty ,
138+ )(y_pred = y_pred , y = y )
127139
128- Raises:
129- ValueError: when `y_pred` and `y` have different shapes.
130140
141+ class DiceHelper :
131142 """
143+ Compute Dice score between two tensors `y_pred` and `y`.
144+ `y_pred` must have N channels, `y` can be single-channel class indices or in the one-hot format.
132145
133- if not include_background :
134- y_pred , y = ignore_background (y_pred = y_pred , y = y )
146+ Example:
135147
136- y = y .float ()
137- y_pred = y_pred .float ()
148+ .. code-block:: python
138149
139- if y . shape != y_pred . shape :
140- raise ValueError ( f"y_pred and y should have same shapes, got { y_pred . shape } and { y . shape } ." )
150+ import torch
151+ from monai.metrics import DiceHelper
141152
142- # reducing only spatial dimensions (not batch nor channels)
143- n_len = len (y_pred .shape )
144- reduce_axis = list (range (2 , n_len ))
145- intersection = torch .sum (y * y_pred , dim = reduce_axis )
153+ n_classes, batch_size = 5, 16
154+ spatial_shape = (128, 128, 128)
146155
147- y_o = torch .sum (y , reduce_axis )
148- y_pred_o = torch .sum (y_pred , dim = reduce_axis )
149- denominator = y_o + y_pred_o
156+ y_pred = torch.rand(batch_size, n_classes, *spatial_shape).float() # predictions
157+ y = torch.randint(0, n_classes, size=(batch_size, 1, *spatial_shape)).long() # ground truth
150158
151- if ignore_empty :
152- return torch .where (y_o > 0 , (2.0 * intersection ) / denominator , torch .tensor (float ("nan" ), device = y_o .device ))
153- return torch .where (denominator > 0 , (2.0 * intersection ) / denominator , torch .tensor (1.0 , device = y_o .device ))
159+ score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
160+ print(score, not_nans)
154161
162+ """
155163
156- @deprecated (since = "1.0.0" , msg_suffix = "use `compute_dice` instead." )
157- def compute_meandice (* args , ** kwargs ):
158- return compute_dice (* args , ** kwargs )
164+ def __init__ (
165+ self ,
166+ include_background : bool | None = None ,
167+ sigmoid : bool = False ,
168+ softmax : bool | None = None ,
169+ activate : bool = False ,
170+ get_not_nans : bool = True ,
171+ reduction : MetricReduction | str = MetricReduction .MEAN_BATCH ,
172+ ignore_empty : bool = True ,
173+ ) -> None :
174+ """
175+
176+ Args:
177+ include_background: whether to skip the score on the first channel
178+ (default to the value of `sigmoid`, False).
179+ sigmoid: whether ``y_pred`` are/will be sigmoid activated outputs. If True, thresholding at 0.5
180+ will be performed to get the discrete prediction. Defaults to False.
181+ softmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
182+ get the discrete prediction. Defaults to the value of ``not sigmoid``.
183+ activate: whether to apply sigmoid to ``y_pred`` if ``sigmoid`` is True. Defaults to False.
184+ This option is only valid when ``sigmoid`` is True.
185+ get_not_nans: whether to return the number of not-nan values.
186+ reduction: define mode of reduction to the metrics
187+ ignore_empty: if `True`, NaN value will be set for empty ground truth cases.
188+ If `False`, 1 will be set if the Union of ``y_pred`` and ``y`` is empty.
189+ """
190+ self .sigmoid = sigmoid
191+ self .reduction = reduction
192+ self .get_not_nans = get_not_nans
193+ self .include_background = sigmoid if include_background is None else include_background
194+ self .softmax = not sigmoid if softmax is None else softmax
195+ self .activate = activate
196+ self .ignore_empty = ignore_empty
197+
198+ def compute_channel (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
199+ """"""
200+ y_o = torch .sum (y )
201+ if y_o > 0 :
202+ return (2.0 * torch .sum (torch .masked_select (y , y_pred ))) / (y_o + torch .sum (y_pred ))
203+ if self .ignore_empty :
204+ return torch .tensor (float ("nan" ), device = y_o .device )
205+ denorm = y_o + torch .sum (y_pred )
206+ if denorm <= 0 :
207+ return torch .tensor (1.0 , device = y_o .device )
208+ return (2.0 * torch .sum (torch .masked_select (y , y_pred ))) / denorm
209+
210+ def __call__ (self , y_pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor | tuple [torch .Tensor , torch .Tensor ]:
211+ """
212+
213+ Args:
214+ y_pred: input predictions with shape (batch_size, num_classes, spatial_dims...).
215+ the number of channels is inferred from ``y_pred.shape[1]``.
216+ y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
217+ """
218+ n_pred_ch = y_pred .shape [1 ]
219+
220+ if self .softmax :
221+ if n_pred_ch > 1 :
222+ y_pred = torch .argmax (y_pred , dim = 1 , keepdim = True )
223+
224+ elif self .sigmoid :
225+ if self .activate :
226+ y_pred = torch .sigmoid (y_pred )
227+ y_pred = y_pred > 0.5
228+
229+ first_ch = 0 if self .include_background else 1
230+ data = []
231+ for b in range (y_pred .shape [0 ]):
232+ c_list = []
233+ for c in range (first_ch , n_pred_ch ) if n_pred_ch > 1 else [1 ]:
234+ x_pred = (y_pred [b , 0 ] == c ) if (y_pred .shape [1 ] == 1 ) else y_pred [b , c ].bool ()
235+ x = (y [b , 0 ] == c ) if (y .shape [1 ] == 1 ) else y [b , c ]
236+ c_list .append (self .compute_channel (x_pred , x ))
237+ data .append (torch .stack (c_list ))
238+ data = torch .stack (data , dim = 0 ).contiguous () # type: ignore
239+
240+ f , not_nans = do_metric_reduction (data , self .reduction ) # type: ignore
241+ return (f , not_nans ) if self .get_not_nans else f
0 commit comments