From 21c0fae52c93e0848d23d872bfff8c6cc81753bf Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Fri, 11 Apr 2025 14:28:00 +0100 Subject: [PATCH 1/5] Add maximum absolute error metric --- scripts/collect_metrics.py | 1 + .../compressor/metrics/__init__.py | 3 ++- .../compressor/metrics/mae.py | 2 +- .../compressor/metrics/max_error.py | 21 +++++++++++++++++++ 4 files changed, 25 insertions(+), 2 deletions(-) create mode 100644 src/climatebenchpress/compressor/metrics/max_error.py diff --git a/scripts/collect_metrics.py b/scripts/collect_metrics.py index aa24a25..f8a3efe 100644 --- a/scripts/collect_metrics.py +++ b/scripts/collect_metrics.py @@ -11,6 +11,7 @@ EVALUATION_METRICS: dict[str, climatebenchpress.compressor.metrics.abc.Metric] = { "MAE": climatebenchpress.compressor.metrics.MAE(), + "Max Absolute Error": climatebenchpress.compressor.metrics.MaxAbsError(), "Spectral Error": climatebenchpress.compressor.metrics.SpectralError(), "DSSIM": climatebenchpress.compressor.metrics.DSSIM(), "PSNR": climatebenchpress.compressor.metrics.PSNR(), diff --git a/src/climatebenchpress/compressor/metrics/__init__.py b/src/climatebenchpress/compressor/metrics/__init__.py index 8cb620f..beb51bc 100644 --- a/src/climatebenchpress/compressor/metrics/__init__.py +++ b/src/climatebenchpress/compressor/metrics/__init__.py @@ -1,7 +1,8 @@ -__all__ = ["MAE", "SpectralError", "PSNR", "DSSIM"] +__all__ = ["MAE", "SpectralError", "PSNR", "DSSIM", "MaxAbsError"] from . import abc as abc from .dssim import DSSIM from .mae import MAE +from .max_error import MaxAbsError from .psnr import PSNR from .spectral_error import SpectralError diff --git a/src/climatebenchpress/compressor/metrics/mae.py b/src/climatebenchpress/compressor/metrics/mae.py index e94017a..96bfae1 100644 --- a/src/climatebenchpress/compressor/metrics/mae.py +++ b/src/climatebenchpress/compressor/metrics/mae.py @@ -9,7 +9,7 @@ class MAE(Metric): def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: """ - Compute the mean squared error between two inputs. + Compute the mean absolute error between two inputs. Parameters ---------- diff --git a/src/climatebenchpress/compressor/metrics/max_error.py b/src/climatebenchpress/compressor/metrics/max_error.py new file mode 100644 index 0000000..b906911 --- /dev/null +++ b/src/climatebenchpress/compressor/metrics/max_error.py @@ -0,0 +1,21 @@ +__all__ = ["MaxAbsError"] + +import numpy as np +import xarray as xr + +from .abc import Metric + + +class MaxAbsError(Metric): + def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: + """ + Compute the maximum absolute error between two inputs. + + Parameters + ---------- + x : xr.DataArray + Shape (realization, time, vertical, latitude, longitude) + y : xr.DataArray + Shape (realization, time, vertical, latitude, longitude) + """ + return float(np.max(np.abs(x - y))) From a411bca688508fc9fdb4cbd285452c39f1e88ac7 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Mon, 14 Apr 2025 11:46:00 +0100 Subject: [PATCH 2/5] Make skipping of NaNs explicit --- src/climatebenchpress/compressor/metrics/mae.py | 2 +- .../compressor/metrics/{max_error.py => max_abs_error.py} | 2 +- src/climatebenchpress/compressor/metrics/psnr.py | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) rename src/climatebenchpress/compressor/metrics/{max_error.py => max_abs_error.py} (90%) diff --git a/src/climatebenchpress/compressor/metrics/mae.py b/src/climatebenchpress/compressor/metrics/mae.py index 96bfae1..f5d7c7e 100644 --- a/src/climatebenchpress/compressor/metrics/mae.py +++ b/src/climatebenchpress/compressor/metrics/mae.py @@ -18,4 +18,4 @@ def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: y : xr.DataArray Shape (realization, time, vertical, latitude, longitude) """ - return float(np.mean(np.abs(x - y))) + return float(np.abs(x - y).mean(skipna=True)) diff --git a/src/climatebenchpress/compressor/metrics/max_error.py b/src/climatebenchpress/compressor/metrics/max_abs_error.py similarity index 90% rename from src/climatebenchpress/compressor/metrics/max_error.py rename to src/climatebenchpress/compressor/metrics/max_abs_error.py index b906911..f73f449 100644 --- a/src/climatebenchpress/compressor/metrics/max_error.py +++ b/src/climatebenchpress/compressor/metrics/max_abs_error.py @@ -18,4 +18,4 @@ def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: y : xr.DataArray Shape (realization, time, vertical, latitude, longitude) """ - return float(np.max(np.abs(x - y))) + return float(np.abs(x - y).max(skipna=True)) diff --git a/src/climatebenchpress/compressor/metrics/psnr.py b/src/climatebenchpress/compressor/metrics/psnr.py index b9f048e..4521cab 100644 --- a/src/climatebenchpress/compressor/metrics/psnr.py +++ b/src/climatebenchpress/compressor/metrics/psnr.py @@ -24,7 +24,8 @@ def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: y : xr.DataArray Shape (realization, time, vertical, latitude, longitude) """ - mse = np.mean((x - y) ** 2, axis=(-3, -2, -1)) - max_pixel = np.max(x) + # Average over the vertical, latitude, and longitude dimensions + mse = ((x - y) ** 2).mean(skipna=True, dim=x.dims[-3:]) + max_pixel = x.max(skipna=True) psnr = 20 * np.log10(max_pixel) - 10 * np.log10(mse) - return float(np.mean(psnr)) + return float(psnr.mean(skipna=True)) From aa951f6bdd4c9519d229e6b50437dbeb8c3b46bc Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Mon, 14 Apr 2025 11:47:11 +0100 Subject: [PATCH 3/5] Add maximum relative error --- scripts/collect_metrics.py | 1 + .../compressor/metrics/__init__.py | 5 +++-- .../compressor/metrics/max_rel_error.py | 22 +++++++++++++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 src/climatebenchpress/compressor/metrics/max_rel_error.py diff --git a/scripts/collect_metrics.py b/scripts/collect_metrics.py index f8a3efe..3d5221d 100644 --- a/scripts/collect_metrics.py +++ b/scripts/collect_metrics.py @@ -12,6 +12,7 @@ EVALUATION_METRICS: dict[str, climatebenchpress.compressor.metrics.abc.Metric] = { "MAE": climatebenchpress.compressor.metrics.MAE(), "Max Absolute Error": climatebenchpress.compressor.metrics.MaxAbsError(), + "Max Relative Error": climatebenchpress.compressor.metrics.MaxRelError(), "Spectral Error": climatebenchpress.compressor.metrics.SpectralError(), "DSSIM": climatebenchpress.compressor.metrics.DSSIM(), "PSNR": climatebenchpress.compressor.metrics.PSNR(), diff --git a/src/climatebenchpress/compressor/metrics/__init__.py b/src/climatebenchpress/compressor/metrics/__init__.py index beb51bc..6b0a008 100644 --- a/src/climatebenchpress/compressor/metrics/__init__.py +++ b/src/climatebenchpress/compressor/metrics/__init__.py @@ -1,8 +1,9 @@ -__all__ = ["MAE", "SpectralError", "PSNR", "DSSIM", "MaxAbsError"] +__all__ = ["MAE", "SpectralError", "PSNR", "DSSIM", "MaxAbsError", "MaxRelError"] from . import abc as abc from .dssim import DSSIM from .mae import MAE -from .max_error import MaxAbsError +from .max_abs_error import MaxAbsError +from .max_rel_error import MaxRelError from .psnr import PSNR from .spectral_error import SpectralError diff --git a/src/climatebenchpress/compressor/metrics/max_rel_error.py b/src/climatebenchpress/compressor/metrics/max_rel_error.py new file mode 100644 index 0000000..f3c8525 --- /dev/null +++ b/src/climatebenchpress/compressor/metrics/max_rel_error.py @@ -0,0 +1,22 @@ +__all__ = ["MaxRelError"] + +import numpy as np +import xarray as xr + +from .abc import Metric + + +class MaxRelError(Metric): + def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: + """ + Compute the maximum relative error between two inputs. + + Parameters + ---------- + x : xr.DataArray + Shape (realization, time, vertical, latitude, longitude) + y : xr.DataArray + Shape (realization, time, vertical, latitude, longitude) + """ + rel_error = np.abs(x - y) / np.abs(x) + return float(rel_error.max(skipna=True)) From 51eb6de6f6d8ddab17ad76f2c4beeb73c6712e36 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Mon, 14 Apr 2025 12:03:06 +0100 Subject: [PATCH 4/5] Make mypy happy --- src/climatebenchpress/compressor/metrics/mae.py | 5 +++-- src/climatebenchpress/compressor/metrics/max_abs_error.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/metrics/mae.py b/src/climatebenchpress/compressor/metrics/mae.py index f5d7c7e..c365a88 100644 --- a/src/climatebenchpress/compressor/metrics/mae.py +++ b/src/climatebenchpress/compressor/metrics/mae.py @@ -1,6 +1,5 @@ __all__ = ["MAE"] -import numpy as np import xarray as xr from .abc import Metric @@ -18,4 +17,6 @@ def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: y : xr.DataArray Shape (realization, time, vertical, latitude, longitude) """ - return float(np.abs(x - y).mean(skipna=True)) + # If we don't use xr.ufuncs, mypy cannot infer that the result is a DataArray + abs_error = xr.ufuncs.abs(x - y) + return float(abs_error.mean(skipna=True)) diff --git a/src/climatebenchpress/compressor/metrics/max_abs_error.py b/src/climatebenchpress/compressor/metrics/max_abs_error.py index f73f449..dc0f198 100644 --- a/src/climatebenchpress/compressor/metrics/max_abs_error.py +++ b/src/climatebenchpress/compressor/metrics/max_abs_error.py @@ -1,6 +1,5 @@ __all__ = ["MaxAbsError"] -import numpy as np import xarray as xr from .abc import Metric @@ -18,4 +17,6 @@ def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: y : xr.DataArray Shape (realization, time, vertical, latitude, longitude) """ - return float(np.abs(x - y).max(skipna=True)) + # If we don't use xr.ufuncs, mypy cannot infer that the result is a DataArray + abs_error = xr.ufuncs.abs(x - y) + return float(abs_error.max(skipna=True)) From 38528b4888ec47f25aa5569e60cd0a846b75aa8e Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Mon, 14 Apr 2025 12:33:31 +0100 Subject: [PATCH 5/5] Avoid division by zero --- src/climatebenchpress/compressor/metrics/max_rel_error.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/climatebenchpress/compressor/metrics/max_rel_error.py b/src/climatebenchpress/compressor/metrics/max_rel_error.py index f3c8525..8fb752c 100644 --- a/src/climatebenchpress/compressor/metrics/max_rel_error.py +++ b/src/climatebenchpress/compressor/metrics/max_rel_error.py @@ -18,5 +18,6 @@ def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float: y : xr.DataArray Shape (realization, time, vertical, latitude, longitude) """ - rel_error = np.abs(x - y) / np.abs(x) + # Avoid dividing by zero when x is zero and y is also zero. + rel_error = xr.where((x == 0) & (x == y), 0.0, np.abs(x - y) / np.abs(x)) return float(rel_error.max(skipna=True))