11__all__ = ["SafeguardsSperr" ]
22
3+ import numcodecs
34import numcodecs_safeguards
45import numcodecs_wasm_sperr
6+ import numpy as np
7+ from numcodecs_combinators .stack import CodecStack
58
69from ..abc import Compressor
710
@@ -15,8 +18,38 @@ class SafeguardsSperr(Compressor):
1518 @staticmethod
1619 def abs_bound_codec (error_bound , ** kwargs ):
1720 return numcodecs_safeguards .SafeguardsCodec (
18- codec = numcodecs_wasm_sperr .Sperr (mode = "pwe" , pwe = error_bound ),
21+ codec = CodecStack (
22+ NaNToZero (),
23+ numcodecs_wasm_sperr .Sperr (mode = "pwe" , pwe = error_bound ),
24+ ),
1925 safeguards = [
2026 dict (kind = "eb" , type = "abs" , eb = error_bound , equal_nan = True ),
2127 ],
2228 )
29+
30+ @staticmethod
31+ def rel_bound_codec (error_bound , * , data_abs_min = None , ** kwargs ):
32+ assert data_abs_min is not None , "data_abs_min must be provided"
33+
34+ return numcodecs_safeguards .SafeguardsCodec (
35+ codec = CodecStack (
36+ NaNToZero (),
37+ # conservative rel->abs error bound transformation,
38+ # same as convert_rel_error_to_abs_error
39+ # so that we can inform the safeguards of the rel bound
40+ numcodecs_wasm_sperr .Sperr (mode = "pwe" , pwe = error_bound * data_abs_min ),
41+ ),
42+ safeguards = [
43+ dict (kind = "eb" , type = "rel" , eb = error_bound , equal_nan = True ),
44+ ],
45+ )
46+
47+
48+ class NaNToZero (numcodecs .abc .Codec ):
49+ codec_id = "nan-to-zero"
50+
51+ def encode (self , buf ):
52+ return np .nan_to_num (buf , nan = 0 , posinf = np .inf , neginf = - np .inf )
53+
54+ def decode (self , buf , out = None ):
55+ return numcodecs .compat .ndarray_copy (buf , out )
0 commit comments