Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions scripts/collect_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

EVALUATION_METRICS: dict[str, climatebenchpress.compressor.metrics.abc.Metric] = {
"MAE": climatebenchpress.compressor.metrics.MAE(),
"Max Absolute Error": climatebenchpress.compressor.metrics.MaxAbsError(),
"Max Relative Error": climatebenchpress.compressor.metrics.MaxRelError(),
"Spectral Error": climatebenchpress.compressor.metrics.SpectralError(),
"DSSIM": climatebenchpress.compressor.metrics.DSSIM(),
"PSNR": climatebenchpress.compressor.metrics.PSNR(),
Expand Down
4 changes: 3 additions & 1 deletion src/climatebenchpress/compressor/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
__all__ = ["MAE", "SpectralError", "PSNR", "DSSIM"]
__all__ = ["MAE", "SpectralError", "PSNR", "DSSIM", "MaxAbsError", "MaxRelError"]

from . import abc as abc
from .dssim import DSSIM
from .mae import MAE
from .max_abs_error import MaxAbsError
from .max_rel_error import MaxRelError
from .psnr import PSNR
from .spectral_error import SpectralError
7 changes: 4 additions & 3 deletions src/climatebenchpress/compressor/metrics/mae.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
__all__ = ["MAE"]

import numpy as np
import xarray as xr

from .abc import Metric
Expand All @@ -9,7 +8,7 @@
class MAE(Metric):
def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float:
"""
Compute the mean squared error between two inputs.
Compute the mean absolute error between two inputs.

Parameters
----------
Expand All @@ -18,4 +17,6 @@ def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float:
y : xr.DataArray
Shape (realization, time, vertical, latitude, longitude)
"""
return float(np.mean(np.abs(x - y)))
# If we don't use xr.ufuncs, mypy cannot infer that the result is a DataArray
abs_error = xr.ufuncs.abs(x - y)
return float(abs_error.mean(skipna=True))
22 changes: 22 additions & 0 deletions src/climatebenchpress/compressor/metrics/max_abs_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
__all__ = ["MaxAbsError"]

import xarray as xr

from .abc import Metric


class MaxAbsError(Metric):
def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float:
"""
Compute the maximum absolute error between two inputs.

Parameters
----------
x : xr.DataArray
Shape (realization, time, vertical, latitude, longitude)
y : xr.DataArray
Shape (realization, time, vertical, latitude, longitude)
"""
# If we don't use xr.ufuncs, mypy cannot infer that the result is a DataArray
abs_error = xr.ufuncs.abs(x - y)
return float(abs_error.max(skipna=True))
23 changes: 23 additions & 0 deletions src/climatebenchpress/compressor/metrics/max_rel_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
__all__ = ["MaxRelError"]

import numpy as np
import xarray as xr

from .abc import Metric


class MaxRelError(Metric):
def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float:
"""
Compute the maximum relative error between two inputs.

Parameters
----------
x : xr.DataArray
Shape (realization, time, vertical, latitude, longitude)
y : xr.DataArray
Shape (realization, time, vertical, latitude, longitude)
"""
# Avoid dividing by zero when x is zero and y is also zero.
rel_error = xr.where((x == 0) & (x == y), 0.0, np.abs(x - y) / np.abs(x))
return float(rel_error.max(skipna=True))
7 changes: 4 additions & 3 deletions src/climatebenchpress/compressor/metrics/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float:
y : xr.DataArray
Shape (realization, time, vertical, latitude, longitude)
"""
mse = np.mean((x - y) ** 2, axis=(-3, -2, -1))
max_pixel = np.max(x)
# Average over the vertical, latitude, and longitude dimensions
mse = ((x - y) ** 2).mean(skipna=True, dim=x.dims[-3:])
max_pixel = x.max(skipna=True)
psnr = 20 * np.log10(max_pixel) - 10 * np.log10(mse)
return float(np.mean(psnr))
return float(psnr.mean(skipna=True))
Loading