Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"astropy~=7.0.1",
"cf-xarray~=0.10",
"dask~=2024.12.0",
"numcodecs>=0.13.0,<0.16",
Expand Down Expand Up @@ -46,5 +47,5 @@ where = ["src"]
"climatebenchpress.compressor" = ["py.typed"]

[[tool.mypy.overrides]]
module = ["numcodecs.*"]
module = ["numcodecs.*", "astropy.convolution.*"]
follow_untyped_imports = true
85 changes: 49 additions & 36 deletions scripts/collect_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# Registry mapping a human-readable metric name to the metric instance that
# computes it. NOTE(review): presumably consumed by compute_metrics (not
# visible in this hunk) and surfaced as column names in the results table —
# confirm against the full file.
EVALUATION_METRICS: dict[str, climatebenchpress.compressor.metrics.abc.Metric] = {
"MAE": climatebenchpress.compressor.metrics.MAE(),
"Spectral Error": climatebenchpress.compressor.metrics.SpectralError(),
"DSSIM": climatebenchpress.compressor.metrics.DSSIM(),
"PSNR": climatebenchpress.compressor.metrics.PSNR(),
}

Expand All @@ -29,32 +30,41 @@ def main():
if dataset.name == ".gitignore":
continue

for compressor in dataset.iterdir():
print(f"Evaluating {compressor.stem} on {dataset.name}...")

compressed_dataset = compressed_datasets / dataset.name / compressor.stem
compressed_dataset_path = compressed_dataset / "decompressed.zarr"
uncompressed_dataset = datasets / dataset.name / "standardized.zarr"
assert compressed_dataset_path.exists(), (
f"No compressed dataset at {compressed_dataset_path}"
)
assert uncompressed_dataset.exists(), (
f"No uncompressed dataset at {uncompressed_dataset}"
)

ds = xr.open_zarr(uncompressed_dataset, chunks=dict()).compute()
ds_new = xr.open_zarr(compressed_dataset_path, chunks=dict()).compute()

compressor_metrics = metrics_dir / dataset.name / compressor.stem
compressor_metrics.mkdir(parents=True, exist_ok=True)

metrics = compute_metrics(compressor_metrics, ds, ds_new)
tests = compute_tests(compressor_metrics, ds, ds_new)
measurements = load_measurements(compressed_datasets, dataset, compressor)

df = merge_metrics(measurements, metrics, tests)
df["Dataset"] = dataset.name
all_results.append(df)
for error_bound in dataset.iterdir():
for compressor in error_bound.iterdir():
print(f"Evaluating {compressor.stem} on {dataset.name}...")

compressed_dataset = (
compressed_datasets
/ dataset.name
/ error_bound.name
/ compressor.stem
)
compressed_dataset_path = compressed_dataset / "decompressed.zarr"
uncompressed_dataset = datasets / dataset.name / "standardized.zarr"
if not compressed_dataset_path.exists():
print(f"No compressed dataset at {compressed_dataset_path}")
continue
if not uncompressed_dataset.exists():
print(f"No uncompressed dataset at {uncompressed_dataset}")
continue

ds = xr.open_zarr(uncompressed_dataset, chunks=dict()).compute()
ds_new = xr.open_zarr(compressed_dataset_path, chunks=dict()).compute()

compressor_metrics = (
metrics_dir / dataset.name / error_bound.name / compressor.stem
)
compressor_metrics.mkdir(parents=True, exist_ok=True)

metrics = compute_metrics(compressor_metrics, ds, ds_new)
tests = compute_tests(compressor_metrics, ds, ds_new)
measurements = load_measurements(compressed_dataset, compressor)

df = merge_metrics(measurements, metrics, tests)
df["Dataset"] = dataset.name
df["Error Bound"] = error_bound.name
all_results.append(df)

all_results = pd.concat(all_results)
all_results.to_csv(metrics_dir / "all_results.csv", index=False)
Expand Down Expand Up @@ -107,12 +117,8 @@ def compute_tests(
return tests


def load_measurements(
compressed_datasets: Path, dataset: Path, compressor: Path
) -> pd.DataFrame:
with open(
compressed_datasets / dataset.name / compressor.stem / "measurements.json"
) as f:
def load_measurements(compressed_dataset: Path, compressor: Path) -> pd.DataFrame:
with open(compressed_dataset / "measurements.json") as f:
measurements = json.load(f)

rows = []
Expand Down Expand Up @@ -149,16 +155,23 @@ def load_measurements(
def merge_metrics(
measurements: pd.DataFrame, metrics: pd.DataFrame, tests: pd.DataFrame
) -> pd.DataFrame:
# Turn each metric into a column. Merge on "variable" to avoid duplicating
# Turn each metric/test into a column. Merge on "variable" to avoid duplicating
# the "variable" column.
test_per_variable = tests.pivot(
index="Variable", columns="Test", values=["Passed", "Value"]
)
# mypy cannot infer that test_per_variable.columns is a MultiIndex and therefore
# gives spurious errors for this assignment.
test_per_variable.columns = [ # type: ignore
f"{metric_name} ({passed_or_val})" # type: ignore
for passed_or_val, metric_name in test_per_variable.columns # type: ignore
]
return pd.merge(
measurements,
metrics.pivot(index="Variable", columns="Metric", values="Error")
.reset_index()
.merge(
tests.pivot(
index="Variable", columns="Test", values="Passed"
).reset_index(),
test_per_variable.reset_index(),
on="Variable",
),
on="Variable",
Expand Down
3 changes: 2 additions & 1 deletion src/climatebenchpress/compressor/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__all__ = ["MAE", "SpectralError", "PSNR"]
__all__ = ["MAE", "SpectralError", "PSNR", "DSSIM"]

from . import abc as abc
from .dssim import DSSIM
from .mae import MAE
from .psnr import PSNR
from .spectral_error import SpectralError
133 changes: 133 additions & 0 deletions src/climatebenchpress/compressor/metrics/dssim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
__all__ = ["DSSIM"]

import numpy as np
import xarray as xr
from astropy.convolution import Gaussian2DKernel, convolve

from .abc import Metric


class DSSIM(Metric):
    """Data-SSIM (dSSIM) similarity metric between two 5-D data arrays."""

    def __call__(self, x: xr.DataArray, y: xr.DataArray) -> float:
        """
        Implementation of the data-SSIM (dSSIM) metric presented in [1]. This is an
        extension of the standard structural similarity index (SSIM) to floating
        point data.

        Here we assume that the input data has shape
        (realization, time, vertical, latitude, longitude). The dSSIM metric is
        defined for 2D fields, so we compute the dSSIM for each vertical slice and
        then take the minimum value over all vertical slices (this follows the
        official implementation of [1]). The final dSSIM value is the average over
        the realization and time dimensions.

        NOTE: This implementation can return values > 1.0 in the case that one of
        the inputs has large regions with NaNs and the other input does not. This
        is because the `astropy.convolution.convolve` function linearly
        interpolates the NaN values. The interpolation of NaN is an explicit
        design decision made in [1]. In practice, this metric should not be used
        for data with large regions of NaNs.

        References:
        [1] A. H. Baker, A. Pinard and D. M. Hammerling, "On a Structural Similarity
            Index Approach for Floating-Point Data," in IEEE Transactions on
            Visualization and Computer Graphics

        Parameters
        ----------
        x : xr.DataArray
            Shape (realization, time, vertical, latitude, longitude)
        y : xr.DataArray
            Shape (realization, time, vertical, latitude, longitude)

        Returns
        -------
        float
            Mean over all (realization, time) samples of the minimum per-level
            dSSIM value.

        Raises
        ------
        ValueError
            If `x` is not 5-dimensional or `x` and `y` have different shapes.
        """
        if x.ndim != 5:
            raise ValueError(
                "DSSIM expects 5-dimensional input with shape "
                "(realization, time, vertical, latitude, longitude), "
                f"got shape {x.shape}"
            )
        # Both arrays are flattened below using x's dimensions. Without this
        # check, a y with a different shape but the same number of elements
        # would be reshaped silently and compared against the wrong values.
        if x.shape != y.shape:
            raise ValueError(
                f"DSSIM inputs must have identical shapes, got {x.shape} and {y.shape}"
            )

        _, _, num_vert, num_lat, num_lon = x.shape
        # Collapse (realization, time) into a single leading "sample" axis.
        x_ = x.values.reshape(-1, num_vert, num_lat, num_lon)
        y_ = y.values.reshape(-1, num_vert, num_lat, num_lon)

        dssims = np.zeros(x_.shape[0])
        for i, (x_sample, y_sample) in enumerate(zip(x_, y_)):
            vertical_dssims = np.array(
                [_dssim(x_level, y_level) for x_level, y_level in zip(x_sample, y_sample)],
                dtype=float,
            )
            # The worst (lowest) vertical level determines the sample's score.
            dssims[i] = vertical_dssims.min()
        return float(dssims.mean())


def _dssim(
    a1: np.ndarray,
    a2: np.ndarray,
    eps: float = 1e-8,
    kernel_size: tuple[int, int] = (11, 11),
) -> float:
    """
    Compute the data-SSIM value between two 2D fields.

    Implementation adapted from the official dSSIM implementation at
    https://github.com/NCAR/ldcpy/blob/6c5bcb8149ec7876a4f53b0e784e9c528f6f14cb/ldcpy/calcs.py#L2516

    The official implementation makes assumptions about the input data that are
    specific to models developed at NCAR which is why we cannot use the official
    implementation directly.

    Parameters
    ----------
    a1 : np.ndarray
        Shape: (latitude, longitude)
    a2 : np.ndarray
        Shape: (latitude, longitude)
    eps : float
        Small constant added to the numerator and denominator of both SSIM
        terms (used for both C1 and C2).
    kernel_size : tuple[int, int]
        `(x_size, y_size)` passed to `Gaussian2DKernel`, i.e. the kernel extent
        along the column (longitude) and row (latitude) axes respectively.
        Half of each size is cropped from the matching border of the SSIM map.

    Returns
    -------
    float
        The data-SSIM value between the two input arrays. NOTE(review): if a
        field is no larger than the cropped border in either direction, the
        cropped SSIM map is empty and the result is presumably NaN — confirm
        whether callers need to guard against this.
    """
    # re-scale to [0,1] - if not constant
    smin = min(np.nanmin(a1), np.nanmin(a2))
    smax = max(np.nanmax(a1), np.nanmax(a2))
    r = smax - smin
    if r == 0.0:  # scale by smax if field is a constant (and smax != 0)
        if smax == 0.0:
            sc_a1 = a1
            sc_a2 = a2
        else:
            sc_a1 = a1 / smax
            sc_a2 = a2 / smax
    else:
        sc_a1 = (a1 - smin) / r
        sc_a2 = (a2 - smin) / r

    # now quantize to 256 bins
    sc_a1 = np.round(sc_a1 * 255) / 255
    sc_a2 = np.round(sc_a2 * 255) / 255

    # gaussian filter
    x_size, y_size = kernel_size
    kernel = Gaussian2DKernel(x_stddev=1.5, x_size=x_size, y_size=y_size)
    # Border widths follow the kernel half-widths instead of a fixed 5, so the
    # crop stays correct for non-default kernel sizes. The kernel array has
    # shape (y_size, x_size): rows pair with y_size, columns with x_size.
    crop_rows = y_size // 2
    crop_cols = x_size // 2
    filter_args = {"boundary": "fill", "preserve_nan": True}

    a1_mu = convolve(sc_a1, kernel, **filter_args)
    a2_mu = convolve(sc_a2, kernel, **filter_args)

    a1a1 = convolve(sc_a1 * sc_a1, kernel, **filter_args)
    a2a2 = convolve(sc_a2 * sc_a2, kernel, **filter_args)

    a1a2 = convolve(sc_a1 * sc_a2, kernel, **filter_args)

    # local (co)variances from the filtered first and second moments
    var_a1 = a1a1 - a1_mu * a1_mu
    var_a2 = a2a2 - a2_mu * a2_mu
    cov_a1a2 = a1a2 - a1_mu * a2_mu

    # ssim constants
    C1 = eps
    C2 = eps

    ssim_t1 = 2 * a1_mu * a2_mu + C1
    ssim_t2 = 2 * cov_a1a2 + C2

    ssim_b1 = a1_mu * a1_mu + a2_mu * a2_mu + C1
    ssim_b2 = var_a1 + var_a2 + C2

    ssim_1 = ssim_t1 / ssim_b1
    ssim_2 = ssim_t2 / ssim_b2
    ssim_mat = ssim_1 * ssim_2

    # cropping (the border region)
    num_rows, num_cols = ssim_mat.shape
    ssim_mat = ssim_mat[
        crop_rows : num_rows - crop_rows, crop_cols : num_cols - crop_cols
    ]
    # Convert the NumPy scalar to a built-in float to match the annotation.
    return float(np.nanmean(ssim_mat))