Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 32 additions & 4 deletions src/climatebenchpress/compressor/compressors/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,24 @@ class Compressor(ABC):

@staticmethod
@abstractmethod
def abs_bound_codec(dtype: np.dtype, error_bound: float) -> Codec:
def abs_bound_codec(
error_bound: float,
*,
dtype: Optional[np.dtype] = None,
data_min: Optional[float] = None,
data_max: Optional[float] = None,
) -> Codec:
pass

@staticmethod
@abstractmethod
def rel_bound_codec(dtype: np.dtype, error_bound: float) -> Codec:
def rel_bound_codec(
error_bound: float,
*,
dtype: Optional[np.dtype] = None,
data_min: Optional[float] = None,
data_max: Optional[float] = None,
) -> Codec:
pass

@classmethod
Expand All @@ -71,6 +83,8 @@ def build(
dtypes: dict[VariableName, np.dtype],
data_abs_min: dict[VariableName, float],
data_abs_max: dict[VariableName, float],
data_min: dict[VariableName, float],
data_max: dict[VariableName, float],
error_bounds: list[dict[VariableName, ErrorBound]],
) -> dict[VariantName, list[NamedPerVariableCodec]]:
"""
Expand All @@ -90,6 +104,10 @@ def build(
Dict mapping from variable name to minimum absolute value for the variable.
data_abs_max : dict[VariableName, float]
Dict mapping from variable name to maximum absolute value for the variable.
data_min : dict[VariableName, float]
Dict mapping from variable name to minimum value for the variable.
data_max : dict[VariableName, float]
Dict mapping from variable name to maximum value for the variable.
error_bounds: list[ErrorBound]
List of error bounds to use for the compressor.

Expand All @@ -116,9 +134,19 @@ def build(
new_codecs: dict[VariableName, Codec] = dict()
for var, eb in eb_per_var.items():
if eb.abs_error is not None and cls.has_abs_error_impl:
new_codecs[var] = cls.abs_bound_codec(dtypes[var], eb.abs_error)
new_codecs[var] = cls.abs_bound_codec(
eb.abs_error,
dtype=dtypes[var],
data_min=data_min[var],
data_max=data_max[var],
)
elif eb.rel_error is not None and cls.has_rel_error_impl:
new_codecs[var] = cls.rel_bound_codec(dtypes[var], eb.rel_error)
new_codecs[var] = cls.rel_bound_codec(
eb.rel_error,
dtype=dtypes[var],
data_min=data_min[var],
data_max=data_max[var],
)
else:
# This should never happen as we have already transformed the error bounds.
# If this happens, it means there is a bug in the implementation.
Expand Down
4 changes: 3 additions & 1 deletion src/climatebenchpress/compressor/compressors/bitround.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ class BitRound(Compressor):
description = "Bit Rounding"

@staticmethod
def rel_bound_codec(dtype, error_bound):
def rel_bound_codec(error_bound, *, dtype=None, **kwargs):
assert dtype is not None, "dtype must be provided"

keepbits = compute_keepbits(dtype, error_bound)
return CodecStack(
numcodecs_wasm_bit_round.BitRound(keepbits=keepbits),
Expand Down
4 changes: 3 additions & 1 deletion src/climatebenchpress/compressor/compressors/bitround_pco.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ class BitRoundPco(Compressor):
description = "Bit Rounding + PCodec"

@staticmethod
def rel_bound_codec(dtype, error_bound):
def rel_bound_codec(error_bound, *, dtype=None, **kwargs):
assert dtype is not None, "dtype must be provided"

keepbits = compute_keepbits(dtype, error_bound)
return CodecStack(
numcodecs_wasm_bit_round.BitRound(keepbits=keepbits),
Expand Down
35 changes: 25 additions & 10 deletions src/climatebenchpress/compressor/compressors/jpeg2000.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,44 @@ class Jpeg2000(Compressor):
description = "JPEG 2000"

@staticmethod
def abs_bound_codec(dtype, error_bound):
# Currently, the input is transformed into the range
# round(min_pixel_val/ error_bound) <= x <= round(max_pixel_val / error_bound)
# This means any values outside this range will incur a larger error.
precision = error_bound
def abs_bound_codec(
error_bound,
*,
data_min=None,
data_max=None,
**kwargs,
):
assert data_min is not None, "data_min must be provided"
assert data_max is not None, "data_max must be provided"

max_pixel_val = 2**25 - 1 # maximum pixel value for our integer encoding.

data_range = data_max - data_min

# Here we use the formula for the PSNR (https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio)
# to convert between the absolute error and the PSNR value.
# The original PSNR formula uses the root mean square error (RMSE),
# therefore JPEG does not guaruantee pointwise error bounds but only
# average error bounds.
psnr = 20 * (math.log10(max_pixel_val) - math.log10(error_bound))
psnr = 20 * (math.log10(data_range) - math.log10(error_bound))

return CodecStack(
# increase precision for better rounding during linear quantization
numcodecs.astype.AsType(
encode_dtype="float64",
decode_dtype="float32",
),
# remap from [min, max] to [0, max_pixel_val]
numcodecs_wasm_fixed_offset_scale.FixedOffsetScale(
offset=0,
scale=precision,
offset=data_min,
scale=data_range / max_pixel_val,
),
# round and truncate to integer values
numcodecs_wasm_round.Round(precision=1),
numcodecs.astype.AsType(
encode_dtype="int32",
decode_dtype="float32",
encode_dtype="uint32",
decode_dtype="float64",
),
# apply the PSNR error bound
numcodecs_wasm_jpeg2000.Jpeg2000(mode="psnr", psnr=psnr),
)
2 changes: 1 addition & 1 deletion src/climatebenchpress/compressor/compressors/stochround.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class StochRound(Compressor):
description = "Stochastic Rounding"

@staticmethod
def abs_bound_codec(dtype, error_bound):
def abs_bound_codec(error_bound, **kwargs):
precision = error_bound
return CodecStack(
numcodecs_wasm_uniform_noise.UniformNoise(scale=precision / 2, seed=42),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class StochRoundPco(Compressor):
description = "Stochastic Rounding + PCodec"

@staticmethod
def abs_bound_codec(dtype, error_bound):
def abs_bound_codec(error_bound, **kwargs):
precision = error_bound
return CodecStack(
numcodecs_wasm_uniform_noise.UniformNoise(scale=precision / 2, seed=42),
Expand Down
4 changes: 2 additions & 2 deletions src/climatebenchpress/compressor/compressors/sz3.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ class Sz3(Compressor):
description = "SZ3"

@staticmethod
def abs_bound_codec(dtype, error_bound):
def abs_bound_codec(error_bound, **kwargs):
return numcodecs_wasm_sz3.Sz3(eb_mode="abs", eb_abs=error_bound)

@staticmethod
def rel_bound_codec(dtype, error_bound):
def rel_bound_codec(error_bound, **kwargs):
# SZ3 will not ensure that the relative error bound is strictly met.
# Internally, SZ3 transforms the relative error bound to an absolute error bound
# based on the range of the input data:
Expand Down
4 changes: 2 additions & 2 deletions src/climatebenchpress/compressor/compressors/tthresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class Tthresh(Compressor):
description = "tthresh"

@staticmethod
def abs_bound_codec(dtype, error_bound):
def abs_bound_codec(error_bound, **kwargs):
return numcodecs_wasm_tthresh.Tthresh(eb_mode="rmse", eb_rmse=error_bound)

@staticmethod
def rel_bound_codec(dtype, error_bound):
def rel_bound_codec(error_bound, **kwargs):
return numcodecs_wasm_tthresh.Tthresh(eb_mode="eps", eb_rmse=error_bound)
2 changes: 1 addition & 1 deletion src/climatebenchpress/compressor/compressors/zfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Zfp(Compressor):
# See https://zfp.readthedocs.io/en/release1.0.1/faq.html#q-relerr for more details.

@staticmethod
def abs_bound_codec(dtype, error_bound):
def abs_bound_codec(error_bound, **kwargs):
return numcodecs_wasm_zfp_classic.ZfpClassic(
mode="fixed-accuracy", tolerance=error_bound
)
2 changes: 1 addition & 1 deletion src/climatebenchpress/compressor/compressors/zfp_round.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ class ZfpRound(Compressor):
# See https://zfp.readthedocs.io/en/release1.0.1/faq.html#q-relerr for more details.

@staticmethod
def abs_bound_codec(dtype, error_bound):
def abs_bound_codec(error_bound, **kwargs):
return numcodecs_wasm_zfp.Zfp(mode="fixed-accuracy", tolerance=error_bound)
20 changes: 13 additions & 7 deletions src/climatebenchpress/compressor/scripts/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
from numcodecs_observers.walltime import WalltimeObserver
from numcodecs_wasm import WasmCodecInstructionCounterObserver

from ..compressors.abc import (
Compressor,
ErrorBound,
NamedPerVariableCodec,
)
from ..compressors.abc import Compressor, ErrorBound, NamedPerVariableCodec
from ..monitor import progress_bar


Expand Down Expand Up @@ -49,11 +45,19 @@ def compress(
continue

ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr")
ds_dtypes, ds_abs_mins, ds_abs_maxs = dict(), dict(), dict()
ds_dtypes, ds_abs_mins, ds_abs_maxs, ds_mins, ds_maxs = (
dict(),
dict(),
dict(),
dict(),
dict(),
)
for v in ds:
abs_vals = xr.ufuncs.abs(ds[v])
ds_abs_mins[v] = abs_vals.min().values.item()
ds_abs_maxs[v] = abs_vals.max().values.item()
ds_mins[v] = ds[v].min().values.item()
ds_maxs[v] = ds[v].max().values.item()
ds_dtypes[v] = ds[v].dtype

error_bounds = get_error_bounds(datasets_error_bounds, dataset.parent.name)
Expand All @@ -64,7 +68,9 @@ def compress(
continue

compressor_variants: dict[str, list[NamedPerVariableCodec]] = (
compressor.build(ds_dtypes, ds_abs_mins, ds_abs_maxs, error_bounds)
compressor.build(
ds_dtypes, ds_abs_mins, ds_abs_maxs, ds_mins, ds_maxs, error_bounds
)
)

for compr_name, named_codecs in compressor_variants.items():
Expand Down
Loading