Skip to content

Commit 6cddcd8

Browse files
committed
safeguard (conservative) relative error bounds
1 parent 93c903f commit 6cddcd8

4 files changed

Lines changed: 58 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ dependencies = [
1313
"matplotlib~=3.8",
1414
"netcdf4==1.7.3",
1515
"numcodecs>=0.13.0,<0.17",
16-
"numcodecs-combinators[xarray]~=0.2.10",
16+
"numcodecs-combinators[xarray]~=0.2.13",
1717
"numcodecs-observers~=0.1.2",
18-
"numcodecs-safeguards==0.1.0a1",
18+
"numcodecs-safeguards==0.1.0b1",
1919
"numcodecs-wasm~=0.2.2",
2020
"numcodecs-wasm-bit-round~=0.4.0",
2121
"numcodecs-wasm-fixed-offset-scale~=0.4.0",
@@ -24,12 +24,12 @@ dependencies = [
2424
"numcodecs-wasm-round~=0.5.0",
2525
"numcodecs-wasm-sperr~=0.2.0",
2626
"numcodecs-wasm-stochastic-rounding~=0.2.0",
27-
"numcodecs-wasm-sz3~=0.7.0",
27+
"numcodecs-wasm-sz3~=0.8.0",
2828
"numcodecs-wasm-tthresh~=0.3.0",
2929
"numcodecs-wasm-zfp~=0.6.0",
3030
"numcodecs-wasm-zfp-classic~=0.4.0",
3131
"numcodecs-wasm-zstd~=0.4.0",
32-
"numcodecs-zero~=0.1.0",
32+
"numcodecs-zero~=0.1.2",
3333
"pandas~=2.2",
3434
"scipy~=1.14",
3535
"seaborn~=0.13.2",

src/climatebenchpress/compressor/compressors/abc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def build(
167167
dtype=dtypes[var],
168168
data_min=data_min[var],
169169
data_max=data_max[var],
170+
data_abs_min=data_abs_min[var],
171+
data_abs_max=data_abs_max[var],
170172
)
171173
elif eb.rel_error is not None and cls.has_rel_error_impl:
172174
new_codecs[var] = partial(
@@ -175,6 +177,8 @@ def build(
175177
dtype=dtypes[var],
176178
data_min=data_min[var],
177179
data_max=data_max[var],
180+
data_abs_min=data_abs_min[var],
181+
data_abs_max=data_abs_max[var],
178182
)
179183
else:
180184
# This should never happen as we have already transformed the error bounds.

src/climatebenchpress/compressor/compressors/safeguards/sperr.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
__all__ = ["SafeguardsSperr"]
22

3+
import numcodecs
34
import numcodecs_safeguards
45
import numcodecs_wasm_sperr
6+
import numpy as np
7+
from numcodecs_combinators.stack import CodecStack
58

69
from ..abc import Compressor
710

@@ -15,8 +18,38 @@ class SafeguardsSperr(Compressor):
1518
@staticmethod
1619
def abs_bound_codec(error_bound, **kwargs):
1720
return numcodecs_safeguards.SafeguardsCodec(
18-
codec=numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound),
21+
codec=CodecStack(
22+
NaNToZero(),
23+
numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound),
24+
),
1925
safeguards=[
2026
dict(kind="eb", type="abs", eb=error_bound, equal_nan=True),
2127
],
2228
)
29+
30+
@staticmethod
31+
def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs):
32+
assert data_abs_min is not None, "data_abs_min must be provided"
33+
34+
return numcodecs_safeguards.SafeguardsCodec(
35+
codec=CodecStack(
36+
NaNToZero(),
37+
# conservative rel->abs error bound transformation,
38+
# same as convert_rel_error_to_abs_error
39+
# so that we can inform the safeguards of the rel bound
40+
numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound * data_abs_min),
41+
),
42+
safeguards=[
43+
dict(kind="eb", type="rel", eb=error_bound, equal_nan=True),
44+
],
45+
)
46+
47+
48+
class NaNToZero(numcodecs.abc.Codec):
49+
codec_id = "nan-to-zero"
50+
51+
def encode(self, buf):
52+
return np.nan_to_num(buf, nan=0, posinf=np.inf, neginf=-np.inf)
53+
54+
def decode(self, buf, out=None):
55+
return numcodecs.compat.ndarray_copy(buf, out)

src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,19 @@ def abs_bound_codec(error_bound, **kwargs):
3232
dict(kind="eb", type="abs", eb=error_bound, equal_nan=True),
3333
],
3434
)
35+
36+
@staticmethod
37+
def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs):
38+
assert data_abs_min is not None, "data_abs_min must be provided"
39+
40+
return numcodecs_safeguards.SafeguardsCodec(
41+
# conservative rel->abs error bound transformation,
42+
# same as convert_rel_error_to_abs_error
43+
# so that we can inform the safeguards of the rel bound
44+
codec=numcodecs_wasm_zfp.Zfp(
45+
mode="fixed-accuracy", tolerance=error_bound * data_abs_min
46+
),
47+
safeguards=[
48+
dict(kind="eb", type="rel", eb=error_bound, equal_nan=True),
49+
],
50+
)

0 commit comments

Comments
 (0)