Skip to content

Commit 35b2591

Browse files
committed
fix Safeguarded(0, dSSIM) for multiple 2D slices
1 parent a9b373e commit 35b2591

4 files changed

Lines changed: 63 additions & 14 deletions

File tree

src/climatebenchpress/compressor/compressors/abc.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def abs_bound_codec(
9090
data_max: Optional[float] = None,
9191
data_abs_min: Optional[float] = None,
9292
data_abs_max: Optional[float] = None,
93+
data_min_2d: Optional[np.ndarray] = None,
94+
data_max_2d: Optional[np.ndarray] = None,
9395
) -> Codec:
9496
"""Create a codec with an absolute error bound."""
9597
pass
@@ -104,6 +106,8 @@ def rel_bound_codec(
104106
data_max: Optional[float] = None,
105107
data_abs_min: Optional[float] = None,
106108
data_abs_max: Optional[float] = None,
109+
data_min_2d: Optional[np.ndarray] = None,
110+
data_max_2d: Optional[np.ndarray] = None,
107111
) -> Codec:
108112
"""Create a codec with a relative error bound."""
109113
pass
@@ -116,6 +120,8 @@ def build(
116120
data_abs_max: dict[VariableName, float],
117121
data_min: dict[VariableName, float],
118122
data_max: dict[VariableName, float],
123+
data_min_2d: dict[VariableName, np.ndarray],
124+
data_max_2d: dict[VariableName, np.ndarray],
119125
error_bounds: list[dict[VariableName, ErrorBound]],
120126
) -> dict[VariantName, list[NamedPerVariableCodec]]:
121127
"""
@@ -139,6 +145,12 @@ def build(
139145
Dict mapping from variable name to minimum value for the variable.
140146
data_max : dict[VariableName, float]
141147
Dict mapping from variable name to maximum value for the variable.
148+
data_min_2d : dict[VariableName, np.ndarray]
149+
Dict mapping from variable name to per-lat-lon-slice minimum value for the
150+
variable.
151+
data_max_2d : dict[VariableName, np.ndarray]
152+
Dict mapping from variable name to per-lat-lon-slice maximum value for the
153+
variable.
142154
error_bounds: list[ErrorBound]
143155
List of error bounds to use for the compressor.
144156
@@ -173,6 +185,8 @@ def build(
173185
data_max=data_max[var],
174186
data_abs_min=data_abs_min[var],
175187
data_abs_max=data_abs_max[var],
188+
data_min_2d=data_min_2d[var],
189+
data_max_2d=data_max_2d[var],
176190
)
177191
elif eb.rel_error is not None and cls.has_rel_error_impl:
178192
new_codecs[var] = partial(
@@ -183,6 +197,8 @@ def build(
183197
data_max=data_max[var],
184198
data_abs_min=data_abs_min[var],
185199
data_abs_max=data_abs_max[var],
200+
data_min_2d=data_min_2d[var],
201+
data_max_2d=data_max_2d[var],
186202
)
187203
else:
188204
# This should never happen as we have already transformed the error bounds.

src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ class SafeguardedZeroDssim(Compressor):
1313
description = "Safeguarded(0, dSSIM)"
1414

1515
@staticmethod
16-
def abs_bound_codec(error_bound, data_min=None, data_max=None, **kwargs):
17-
assert data_min is not None, "data_min must be provided"
18-
assert data_max is not None, "data_max must be provided"
16+
def abs_bound_codec(error_bound, data_min_2d=None, data_max_2d=None, **kwargs):
17+
assert data_min_2d is not None, "data_min_2d must be provided"
18+
assert data_max_2d is not None, "data_max_2d must be provided"
1919

2020
return numcodecs_safeguards.SafeguardedCodec(
2121
codec=numcodecs_zero.ZeroCodec(),
@@ -46,14 +46,15 @@ def abs_bound_codec(error_bound, data_min=None, data_max=None, **kwargs):
4646
eb=0,
4747
),
4848
],
49-
# use data_min instead of $x_min to allow for chunking
50-
fixed_constants=dict(x_min=data_min, x_max=data_max),
49+
# use data_min_2d instead of $x_min since we need the minimum per
50+
# 2d latitude-longitude slice
51+
fixed_constants=dict(x_min=data_min_2d, x_max=data_max_2d),
5152
)
5253

5354
@staticmethod
54-
def rel_bound_codec(error_bound, data_min=None, data_max=None, **kwargs):
55-
assert data_min is not None, "data_min must be provided"
56-
assert data_max is not None, "data_max must be provided"
55+
def rel_bound_codec(error_bound, data_min_2d=None, data_max_2d=None, **kwargs):
56+
assert data_min_2d is not None, "data_min_2d must be provided"
57+
assert data_max_2d is not None, "data_max_2d must be provided"
5758

5859
return numcodecs_safeguards.SafeguardedCodec(
5960
codec=numcodecs_zero.ZeroCodec(),
@@ -84,6 +85,7 @@ def rel_bound_codec(error_bound, data_min=None, data_max=None, **kwargs):
8485
eb=0,
8586
),
8687
],
87-
# use data_min instead of $x_min to allow for chunking
88-
fixed_constants=dict(x_min=data_min, x_max=data_max),
88+
# use data_min_2d instead of $x_min since we need the minimum per
89+
# 2d latitude-longitude slice
90+
fixed_constants=dict(x_min=data_min_2d, x_max=data_max_2d),
8991
)

src/climatebenchpress/compressor/plotting/plot_metrics.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,9 @@ def _normalize(data):
207207

208208
# Normalize each variable by its mean and std
209209
normalized[new_col] = normalized.apply(
210-
lambda x: (x[col] - mean_std[x["Variable"]][0])
211-
/ mean_std[x["Variable"]][1],
210+
lambda x: (
211+
(x[col] - mean_std[x["Variable"]][0]) / mean_std[x["Variable"]][1]
212+
),
212213
axis=1,
213214
)
214215

src/climatebenchpress/compressor/scripts/compress.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def compress(
8787
ds_abs_maxs: dict[str, float] = dict()
8888
ds_mins: dict[str, float] = dict()
8989
ds_maxs: dict[str, float] = dict()
90+
ds_min_2ds: dict[str, np.ndarray] = dict()
91+
ds_max_2ds: dict[str, np.ndarray] = dict()
9092
for v in ds:
9193
vs: str = str(v)
9294
abs_vals = xr.ufuncs.abs(ds[v])
@@ -96,6 +98,16 @@ def compress(
9698
ds_abs_maxs[vs] = abs_vals.max().values.item()
9799
ds_mins[vs] = ds[v].min().values.item()
98100
ds_maxs[vs] = ds[v].max().values.item()
101+
ds_min_2ds[vs] = (
102+
ds[v]
103+
.min(dim=[ds[v].cf["Y"].name, ds[v].cf["X"].name], keepdims=True)
104+
.values
105+
)
106+
ds_max_2ds[vs] = (
107+
ds[v]
108+
.max(dim=[ds[v].cf["Y"].name, ds[v].cf["X"].name], keepdims=True)
109+
.values
110+
)
99111

100112
if chunked:
101113
for v in ds:
@@ -115,7 +127,14 @@ def compress(
115127

116128
compressor_variants: dict[str, list[NamedPerVariableCodec]] = (
117129
compressor.build(
118-
ds_dtypes, ds_abs_mins, ds_abs_maxs, ds_mins, ds_maxs, error_bounds
130+
ds_dtypes,
131+
ds_abs_mins,
132+
ds_abs_maxs,
133+
ds_mins,
134+
ds_maxs,
135+
ds_min_2ds,
136+
ds_max_2ds,
137+
error_bounds,
119138
)
120139
)
121140

@@ -189,6 +208,15 @@ def compress_decompress(
189208
if not isinstance(codec, CodecStack):
190209
codec = CodecStack(codec)
191210

211+
# HACK: Safeguarded(0, dSSIM) requires the per-lat-lon-slice minimum
212+
# and maximum
213+
# for potentially-chunked data we should really use xarray-safeguards,
214+
# but not using chunks also works (for now)
215+
is_safeguarded_zero_dssim = (
216+
"# === pointwise dSSIM quantity of interest === #"
217+
in json.dumps(codec.get_config())
218+
)
219+
192220
with numcodecs_observers.observe(
193221
codec,
194222
observers=[
@@ -197,7 +225,9 @@ def compress_decompress(
197225
timing,
198226
],
199227
) as codec_:
200-
variables[v] = codec_.encode_decode_data_array(ds[v]).compute()
228+
variables[v] = codec_.encode_decode_data_array(
229+
ds[v].compute() if is_safeguarded_zero_dssim else ds[v]
230+
).compute()
201231

202232
cs = [c._codec for c in codec_.__iter__()]
203233

0 commit comments

Comments
 (0)