diff --git a/src/climatebenchpress/compressor/compressors/abc.py b/src/climatebenchpress/compressor/compressors/abc.py index 1a16813..5156afe 100644 --- a/src/climatebenchpress/compressor/compressors/abc.py +++ b/src/climatebenchpress/compressor/compressors/abc.py @@ -57,12 +57,24 @@ class Compressor(ABC): @staticmethod @abstractmethod - def abs_bound_codec(dtype: np.dtype, error_bound: float) -> Codec: + def abs_bound_codec( + error_bound: float, + *, + dtype: Optional[np.dtype] = None, + data_min: Optional[float] = None, + data_max: Optional[float] = None, + ) -> Codec: pass @staticmethod @abstractmethod - def rel_bound_codec(dtype: np.dtype, error_bound: float) -> Codec: + def rel_bound_codec( + error_bound: float, + *, + dtype: Optional[np.dtype] = None, + data_min: Optional[float] = None, + data_max: Optional[float] = None, + ) -> Codec: pass @classmethod @@ -71,6 +83,8 @@ def build( dtypes: dict[VariableName, np.dtype], data_abs_min: dict[VariableName, float], data_abs_max: dict[VariableName, float], + data_min: dict[VariableName, float], + data_max: dict[VariableName, float], error_bounds: list[dict[VariableName, ErrorBound]], ) -> dict[VariantName, list[NamedPerVariableCodec]]: """ @@ -90,6 +104,10 @@ def build( Dict mapping from variable name to minimum absolute value for the variable. data_abs_max : dict[VariableName, float] Dict mapping from variable name to maximum absolute value for the variable. + data_min : dict[VariableName, float] + Dict mapping from variable name to minimum value for the variable. + data_max : dict[VariableName, float] + Dict mapping from variable name to maximum value for the variable. error_bounds: list[ErrorBound] List of error bounds to use for the compressor. 
@@ -116,9 +134,19 @@ def build( new_codecs: dict[VariableName, Codec] = dict() for var, eb in eb_per_var.items(): if eb.abs_error is not None and cls.has_abs_error_impl: - new_codecs[var] = cls.abs_bound_codec(dtypes[var], eb.abs_error) + new_codecs[var] = cls.abs_bound_codec( + eb.abs_error, + dtype=dtypes[var], + data_min=data_min[var], + data_max=data_max[var], + ) elif eb.rel_error is not None and cls.has_rel_error_impl: - new_codecs[var] = cls.rel_bound_codec(dtypes[var], eb.rel_error) + new_codecs[var] = cls.rel_bound_codec( + eb.rel_error, + dtype=dtypes[var], + data_min=data_min[var], + data_max=data_max[var], + ) else: # This should never happen as we have already transformed the error bounds. # If this happens, it means there is a bug in the implementation. diff --git a/src/climatebenchpress/compressor/compressors/bitround.py b/src/climatebenchpress/compressor/compressors/bitround.py index 92753fc..501d8f8 100644 --- a/src/climatebenchpress/compressor/compressors/bitround.py +++ b/src/climatebenchpress/compressor/compressors/bitround.py @@ -13,7 +13,9 @@ class BitRound(Compressor): description = "Bit Rounding" @staticmethod - def rel_bound_codec(dtype, error_bound): + def rel_bound_codec(error_bound, *, dtype=None, **kwargs): + assert dtype is not None, "dtype must be provided" + keepbits = compute_keepbits(dtype, error_bound) return CodecStack( numcodecs_wasm_bit_round.BitRound(keepbits=keepbits), diff --git a/src/climatebenchpress/compressor/compressors/bitround_pco.py b/src/climatebenchpress/compressor/compressors/bitround_pco.py index 5fda857..e1f34aa 100644 --- a/src/climatebenchpress/compressor/compressors/bitround_pco.py +++ b/src/climatebenchpress/compressor/compressors/bitround_pco.py @@ -14,7 +14,9 @@ class BitRoundPco(Compressor): description = "Bit Rounding + PCodec" @staticmethod - def rel_bound_codec(dtype, error_bound): + def rel_bound_codec(error_bound, *, dtype=None, **kwargs): + assert dtype is not None, "dtype must be provided" + 
keepbits = compute_keepbits(dtype, error_bound) return CodecStack( numcodecs_wasm_bit_round.BitRound(keepbits=keepbits), diff --git a/src/climatebenchpress/compressor/compressors/jpeg2000.py b/src/climatebenchpress/compressor/compressors/jpeg2000.py index fd643fc..1d73bba 100644 --- a/src/climatebenchpress/compressor/compressors/jpeg2000.py +++ b/src/climatebenchpress/compressor/compressors/jpeg2000.py @@ -16,29 +16,44 @@ class Jpeg2000(Compressor): description = "JPEG 2000" @staticmethod - def abs_bound_codec(dtype, error_bound): - # Currently, the input is transformed into the range - # round(min_pixel_val/ error_bound) <= x <= round(max_pixel_val / error_bound) - # This means any values outside this range will incur a larger error. - precision = error_bound + def abs_bound_codec( + error_bound, + *, + data_min=None, + data_max=None, + **kwargs, + ): + assert data_min is not None, "data_min must be provided" + assert data_max is not None, "data_max must be provided" + max_pixel_val = 2**25 - 1 # maximum pixel value for our integer encoding. + data_range = data_max - data_min + # Here we use the formula for the PSNR (https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio) # to convert between the absolute error and the PSNR value. # The original PSNR formula uses the root mean square error (RMSE), # therefore JPEG does not guarantee pointwise error bounds but only # average error bounds. 
- psnr = 20 * (math.log10(max_pixel_val) - math.log10(error_bound)) + psnr = 20 * (math.log10(data_range) - math.log10(error_bound)) return CodecStack( + # increase precision for better rounding during linear quantization + numcodecs.astype.AsType( + encode_dtype="float64", + decode_dtype="float32", + ), + # remap from [min, max] to [0, max_pixel_val] numcodecs_wasm_fixed_offset_scale.FixedOffsetScale( - offset=0, - scale=precision, + offset=data_min, + scale=data_range / max_pixel_val, ), + # round and truncate to integer values numcodecs_wasm_round.Round(precision=1), numcodecs.astype.AsType( - encode_dtype="int32", - decode_dtype="float32", + encode_dtype="uint32", + decode_dtype="float64", ), + # apply the PSNR error bound numcodecs_wasm_jpeg2000.Jpeg2000(mode="psnr", psnr=psnr), ) diff --git a/src/climatebenchpress/compressor/compressors/stochround.py b/src/climatebenchpress/compressor/compressors/stochround.py index 12847a2..3b76db9 100644 --- a/src/climatebenchpress/compressor/compressors/stochround.py +++ b/src/climatebenchpress/compressor/compressors/stochround.py @@ -13,7 +13,7 @@ class StochRound(Compressor): description = "Stochastic Rounding" @staticmethod - def abs_bound_codec(dtype, error_bound): + def abs_bound_codec(error_bound, **kwargs): precision = error_bound return CodecStack( numcodecs_wasm_uniform_noise.UniformNoise(scale=precision / 2, seed=42), diff --git a/src/climatebenchpress/compressor/compressors/stochround_pco.py b/src/climatebenchpress/compressor/compressors/stochround_pco.py index 9c0a7f0..e0b2f8b 100644 --- a/src/climatebenchpress/compressor/compressors/stochround_pco.py +++ b/src/climatebenchpress/compressor/compressors/stochround_pco.py @@ -13,7 +13,7 @@ class StochRoundPco(Compressor): description = "Stochastic Rounding + PCodec" @staticmethod - def abs_bound_codec(dtype, error_bound): + def abs_bound_codec(error_bound, **kwargs): precision = error_bound return CodecStack( 
numcodecs_wasm_uniform_noise.UniformNoise(scale=precision / 2, seed=42), diff --git a/src/climatebenchpress/compressor/compressors/sz3.py b/src/climatebenchpress/compressor/compressors/sz3.py index ddc148b..b4f6c08 100644 --- a/src/climatebenchpress/compressor/compressors/sz3.py +++ b/src/climatebenchpress/compressor/compressors/sz3.py @@ -10,11 +10,11 @@ class Sz3(Compressor): description = "SZ3" @staticmethod - def abs_bound_codec(dtype, error_bound): + def abs_bound_codec(error_bound, **kwargs): return numcodecs_wasm_sz3.Sz3(eb_mode="abs", eb_abs=error_bound) @staticmethod - def rel_bound_codec(dtype, error_bound): + def rel_bound_codec(error_bound, **kwargs): # SZ3 will not ensure that the relative error bound is strictly met. # Internally, SZ3 transforms the relative error bound to an absolute error bound # based on the range of the input data: diff --git a/src/climatebenchpress/compressor/compressors/tthresh.py b/src/climatebenchpress/compressor/compressors/tthresh.py index 68ba853..401186f 100644 --- a/src/climatebenchpress/compressor/compressors/tthresh.py +++ b/src/climatebenchpress/compressor/compressors/tthresh.py @@ -10,9 +10,9 @@ class Tthresh(Compressor): description = "tthresh" @staticmethod - def abs_bound_codec(dtype, error_bound): + def abs_bound_codec(error_bound, **kwargs): return numcodecs_wasm_tthresh.Tthresh(eb_mode="rmse", eb_rmse=error_bound) @staticmethod - def rel_bound_codec(dtype, error_bound): + def rel_bound_codec(error_bound, **kwargs): return numcodecs_wasm_tthresh.Tthresh(eb_mode="eps", eb_rmse=error_bound) diff --git a/src/climatebenchpress/compressor/compressors/zfp.py b/src/climatebenchpress/compressor/compressors/zfp.py index 82bc23e..9e28cb8 100644 --- a/src/climatebenchpress/compressor/compressors/zfp.py +++ b/src/climatebenchpress/compressor/compressors/zfp.py @@ -18,7 +18,7 @@ class Zfp(Compressor): # See https://zfp.readthedocs.io/en/release1.0.1/faq.html#q-relerr for more details. 
@staticmethod - def abs_bound_codec(dtype, error_bound): + def abs_bound_codec(error_bound, **kwargs): return numcodecs_wasm_zfp_classic.ZfpClassic( mode="fixed-accuracy", tolerance=error_bound ) diff --git a/src/climatebenchpress/compressor/compressors/zfp_round.py b/src/climatebenchpress/compressor/compressors/zfp_round.py index 91e07ef..bc55e81 100644 --- a/src/climatebenchpress/compressor/compressors/zfp_round.py +++ b/src/climatebenchpress/compressor/compressors/zfp_round.py @@ -18,5 +18,5 @@ class ZfpRound(Compressor): # See https://zfp.readthedocs.io/en/release1.0.1/faq.html#q-relerr for more details. @staticmethod - def abs_bound_codec(dtype, error_bound): + def abs_bound_codec(error_bound, **kwargs): return numcodecs_wasm_zfp.Zfp(mode="fixed-accuracy", tolerance=error_bound) diff --git a/src/climatebenchpress/compressor/scripts/compress.py b/src/climatebenchpress/compressor/scripts/compress.py index 251f05d..bd039e2 100644 --- a/src/climatebenchpress/compressor/scripts/compress.py +++ b/src/climatebenchpress/compressor/scripts/compress.py @@ -15,11 +15,7 @@ from numcodecs_observers.walltime import WalltimeObserver from numcodecs_wasm import WasmCodecInstructionCounterObserver -from ..compressors.abc import ( - Compressor, - ErrorBound, - NamedPerVariableCodec, -) +from ..compressors.abc import Compressor, ErrorBound, NamedPerVariableCodec from ..monitor import progress_bar @@ -49,11 +45,19 @@ def compress( continue ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr") - ds_dtypes, ds_abs_mins, ds_abs_maxs = dict(), dict(), dict() + ds_dtypes, ds_abs_mins, ds_abs_maxs, ds_mins, ds_maxs = ( + dict(), + dict(), + dict(), + dict(), + dict(), + ) for v in ds: abs_vals = xr.ufuncs.abs(ds[v]) ds_abs_mins[v] = abs_vals.min().values.item() ds_abs_maxs[v] = abs_vals.max().values.item() + ds_mins[v] = ds[v].min().values.item() + ds_maxs[v] = ds[v].max().values.item() ds_dtypes[v] = ds[v].dtype error_bounds = get_error_bounds(datasets_error_bounds, 
dataset.parent.name) @@ -64,7 +68,9 @@ def compress( continue compressor_variants: dict[str, list[NamedPerVariableCodec]] = ( - compressor.build(ds_dtypes, ds_abs_mins, ds_abs_maxs, error_bounds) + compressor.build( + ds_dtypes, ds_abs_mins, ds_abs_maxs, ds_mins, ds_maxs, error_bounds + ) ) for compr_name, named_codecs in compressor_variants.items():