Skip to content

Commit c1c0cdc

Browse files
authored
Add loss_fn to IgniteMetric and rename to IgniteMetricHandler (Project-MONAI#6695)
### Description As explained in Project-MONAI#6693 I would like to use the DiceCELoss as a train metric as well. This branch adds a very crude but working version of that. The added tests, which I copied from the MeanDice metric, do still fail. It would be cool if someone more experienced could check, what needs to be done there. I ran the code with my full DeepEdit setup and it appears to be working just fine there. No formatting checks done so far - I want to find out first if this code is useful for others. Docs not updated yet either. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Matthias Hadlich <[email protected]>
1 parent 48a86b2 commit c1c0cdc

14 files changed

+273
-35
lines changed

Diff for: docs/source/handlers.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ CSV saver
2929
:members:
3030

3131

32-
Ignite Metric
33-
-------------
34-
.. autoclass:: IgniteMetric
32+
Ignite Metric Handler
33+
---------------------
34+
.. autoclass:: IgniteMetricHandler
3535
:members:
3636

3737

Diff for: monai/handlers/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from .earlystop_handler import EarlyStopHandler
2121
from .garbage_collector import GarbageCollector
2222
from .hausdorff_distance import HausdorffDistance
23-
from .ignite_metric import IgniteMetric
23+
from .ignite_metric import IgniteMetric, IgniteMetricHandler
2424
from .logfile_handler import LogfileHandler
2525
from .lr_schedule_handler import LrScheduleHandler
2626
from .mean_dice import MeanDice

Diff for: monai/handlers/confusion_matrix.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import ConfusionMatrixMetric
1818
from monai.utils.enums import MetricReduction
1919

2020

21-
class ConfusionMatrix(IgniteMetric):
21+
class ConfusionMatrix(IgniteMetricHandler):
2222
"""
2323
Compute confusion matrix related metrics from full size Tensor and collects average over batch, class-channels, iterations.
2424
"""

Diff for: monai/handlers/hausdorff_distance.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import HausdorffDistanceMetric
1818
from monai.utils import MetricReduction
1919

2020

21-
class HausdorffDistance(IgniteMetric):
21+
class HausdorffDistance(IgniteMetricHandler):
2222
"""
2323
Computes Hausdorff distance from full size Tensor and collects average over batch, class-channels, iterations.
2424
"""

Diff for: monai/handlers/ignite_metric.py

+54-9
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,26 @@
1313

1414
import warnings
1515
from collections.abc import Callable, Sequence
16-
from typing import TYPE_CHECKING, Any
16+
from typing import TYPE_CHECKING, Any, cast
1717

1818
import torch
19+
from torch.nn.modules.loss import _Loss
1920

2021
from monai.config import IgniteInfo
21-
from monai.metrics import CumulativeIterationMetric
22-
from monai.utils import min_version, optional_import
22+
from monai.metrics import CumulativeIterationMetric, LossMetric
23+
from monai.utils import MetricReduction, deprecated, min_version, optional_import
2324

2425
idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
2526

2627
if TYPE_CHECKING:
27-
from ignite.engine import Engine
28-
from ignite.metrics import Metric
29-
from ignite.metrics.metric import reinit__is_reduced
28+
try:
29+
_, has_ignite = optional_import("ignite")
30+
from ignite.engine import Engine
31+
from ignite.metrics import Metric
32+
from ignite.metrics.metric import reinit__is_reduced
33+
except ImportError:
34+
has_ignite = False
35+
3036
else:
3137
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
3238
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric", as_type="base")
@@ -35,7 +41,7 @@
3541
)
3642

3743

38-
class IgniteMetric(Metric):
44+
class IgniteMetricHandler(Metric):
3945
"""
4046
Base Metric class based on ignite event handler mechanism.
4147
The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim,
@@ -44,6 +50,7 @@ class IgniteMetric(Metric):
4450
Args:
4551
metric_fn: callable function or class to compute raw metric results after every iteration.
4652
expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans).
53+
loss_fn: A torch _Loss function which is used to generate the LossMetric
4754
output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
4855
construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
4956
lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
@@ -52,18 +59,35 @@ class IgniteMetric(Metric):
5259
https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
5360
save_details: whether to save metric computation details per image, for example: mean_dice of every image.
5461
default to True, will save to `engine.state.metric_details` dict with the metric name as key.
62+
reduction: Argument for the LossMetric, look there for details
63+
get_not_nans: Argument for the LossMetric, look there for details
5564
5665
"""
5766

5867
def __init__(
59-
self, metric_fn: CumulativeIterationMetric, output_transform: Callable = lambda x: x, save_details: bool = True
68+
self,
69+
metric_fn: CumulativeIterationMetric | None = None,
70+
loss_fn: _Loss | None = None,
71+
output_transform: Callable = lambda x: x,
72+
save_details: bool = True,
73+
reduction: MetricReduction | str = MetricReduction.MEAN,
74+
get_not_nans: bool = False,
6075
) -> None:
6176
self._is_reduced: bool = False
62-
self.metric_fn = metric_fn
77+
self.metric_fn: CumulativeIterationMetric = cast(CumulativeIterationMetric, metric_fn)
78+
self.loss_fn = loss_fn
6379
self.save_details = save_details
6480
self._scores: list = []
6581
self._engine: Engine | None = None
6682
self._name: str | None = None
83+
84+
if self.metric_fn is None and self.loss_fn is None:
85+
raise ValueError("Either metric_fn or loss_fn have to be passed.")
86+
if self.metric_fn is not None and self.loss_fn is not None:
87+
raise ValueError("Either metric_fn or loss_fn have to be passed, but not both.")
88+
if self.loss_fn:
89+
self.metric_fn = LossMetric(loss_fn=self.loss_fn, reduction=reduction, get_not_nans=get_not_nans)
90+
6791
super().__init__(output_transform)
6892

6993
@reinit__is_reduced
@@ -129,3 +153,24 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override]
129153
self._name = name
130154
if self.save_details and not hasattr(engine.state, "metric_details"):
131155
engine.state.metric_details = {} # type: ignore
156+
157+
158+
@deprecated(since="1.3", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.")
159+
class IgniteMetric(IgniteMetricHandler):
160+
def __init__(
161+
self,
162+
metric_fn: CumulativeIterationMetric | None = None,
163+
loss_fn: _Loss | None = None,
164+
output_transform: Callable = lambda x: x,
165+
save_details: bool = True,
166+
reduction: MetricReduction | str = MetricReduction.MEAN,
167+
get_not_nans: bool = False,
168+
) -> None:
169+
super().__init__(
170+
metric_fn=metric_fn,
171+
loss_fn=loss_fn,
172+
output_transform=output_transform,
173+
save_details=save_details,
174+
reduction=reduction,
175+
get_not_nans=get_not_nans,
176+
)

Diff for: monai/handlers/mean_dice.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import DiceMetric
1818
from monai.utils import MetricReduction
1919

2020

21-
class MeanDice(IgniteMetric):
21+
class MeanDice(IgniteMetricHandler):
2222
"""
2323
Computes Dice score metric from full size Tensor and collects average over batch, class-channels, iterations.
2424
"""

Diff for: monai/handlers/mean_iou.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import MeanIoU
1818
from monai.utils import MetricReduction
1919

2020

21-
class MeanIoUHandler(IgniteMetric):
21+
class MeanIoUHandler(IgniteMetricHandler):
2222
"""
2323
Computes IoU score metric from full size Tensor and collects average over batch, class-channels, iterations.
2424
"""

Diff for: monai/handlers/metrics_reloaded_handler.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import MetricsReloadedBinary, MetricsReloadedCategorical
1818
from monai.utils.enums import MetricReduction
1919

2020

21-
class MetricsReloadedBinaryHandler(IgniteMetric):
21+
class MetricsReloadedBinaryHandler(IgniteMetricHandler):
2222
"""
2323
Handler of MetricsReloadedBinary, which wraps the binary pairwise metrics of MetricsReloaded.
2424
"""
@@ -65,7 +65,7 @@ def __init__(
6565
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
6666

6767

68-
class MetricsReloadedCategoricalHandler(IgniteMetric):
68+
class MetricsReloadedCategoricalHandler(IgniteMetricHandler):
6969
"""
7070
Handler of MetricsReloadedCategorical, which wraps the categorical pairwise metrics of MetricsReloaded.
7171
"""

Diff for: monai/handlers/panoptic_quality.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import PanopticQualityMetric
1818
from monai.utils import MetricReduction
1919

2020

21-
class PanopticQuality(IgniteMetric):
21+
class PanopticQuality(IgniteMetricHandler):
2222
"""
2323
Computes Panoptic quality from full size Tensor and collects average over batch, class-channels, iterations.
2424
"""

Diff for: monai/handlers/regression_metrics.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import MAEMetric, MSEMetric, PSNRMetric, RMSEMetric
1818
from monai.utils import MetricReduction
1919

2020

21-
class MeanSquaredError(IgniteMetric):
21+
class MeanSquaredError(IgniteMetricHandler):
2222
"""
2323
Computes Mean Squared Error from full size Tensor and collects average over batch, iterations.
2424
"""
@@ -51,7 +51,7 @@ def __init__(
5151
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
5252

5353

54-
class MeanAbsoluteError(IgniteMetric):
54+
class MeanAbsoluteError(IgniteMetricHandler):
5555
"""
5656
Computes Mean Absolute Error from full size Tensor and collects average over batch, iterations.
5757
"""
@@ -84,7 +84,7 @@ def __init__(
8484
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
8585

8686

87-
class RootMeanSquaredError(IgniteMetric):
87+
class RootMeanSquaredError(IgniteMetricHandler):
8888
"""
8989
Computes Root Mean Squared Error from full size Tensor and collects average over batch, iterations.
9090
"""
@@ -117,7 +117,7 @@ def __init__(
117117
super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
118118

119119

120-
class PeakSignalToNoiseRatio(IgniteMetric):
120+
class PeakSignalToNoiseRatio(IgniteMetricHandler):
121121
"""
122122
Computes Peak Signal to Noise Ratio from full size Tensor and collects average over batch, iterations.
123123
"""

Diff for: monai/handlers/roc_auc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import ROCAUCMetric
1818
from monai.utils import Average
1919

2020

21-
class ROCAUC(IgniteMetric):
21+
class ROCAUC(IgniteMetricHandler):
2222
"""
2323
Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC).
2424
accumulating predictions and the ground-truth during an epoch and applying `compute_roc_auc`.

Diff for: monai/handlers/surface_distance.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313

1414
from collections.abc import Callable
1515

16-
from monai.handlers.ignite_metric import IgniteMetric
16+
from monai.handlers.ignite_metric import IgniteMetricHandler
1717
from monai.metrics import SurfaceDistanceMetric
1818
from monai.utils import MetricReduction
1919

2020

21-
class SurfaceDistance(IgniteMetric):
21+
class SurfaceDistance(IgniteMetricHandler):
2222
"""
2323
Computes surface distance from full size Tensor and collects average over batch, class-channels, iterations.
2424
"""

Diff for: tests/min_tests.py

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def run_testsuit():
8383
"test_handler_early_stop",
8484
"test_handler_garbage_collector",
8585
"test_handler_hausdorff_distance",
86+
"test_handler_ignite_metric",
8687
"test_handler_lr_scheduler",
8788
"test_handler_mean_dice",
8889
"test_handler_panoptic_quality",

0 commit comments

Comments
 (0)