Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,3 @@ jobs:

- name: Typecheck the package
run: uv run mypy -p climatebenchpress.compressor

- name: Typecheck the scripts
run: uv run mypy scripts/
30 changes: 15 additions & 15 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@ requires-python = ">=3.12"
dependencies = [
"astropy~=7.0.1",
"cf-xarray~=0.10",
"dask~=2024.12.0",
"numcodecs>=0.13.0,<0.16",
"dask>=2024.12.0,<2025.4",
"numcodecs>=0.13.0,<0.17",
"numcodecs-combinators[xarray]~=0.2.4",
"numcodecs-observers~=0.1.1",
"numcodecs-wasm~=0.1.3",
Comment thread
juntyr marked this conversation as resolved.
"numcodecs-wasm-bit-round~=0.2.0",
"numcodecs-wasm-fixed-offset-scale~=0.2.1",
"numcodecs-wasm-jpeg2000~=0.1.1",
"numcodecs-wasm-pco~=0.1.0",
"numcodecs-wasm-round~=0.2.0",
"numcodecs-wasm-sz3~=0.5.0",
"numcodecs-wasm-tthresh~=0.1.0",
"numcodecs-wasm-uniform-noise~=0.2.0",
"numcodecs-wasm-zfp~=0.4.0",
"numcodecs-wasm-zlib~=0.2.0",
"numcodecs-wasm~=0.1.6",
"numcodecs-wasm-bit-round~=0.3.0",
"numcodecs-wasm-fixed-offset-scale~=0.3.0",
"numcodecs-wasm-jpeg2000~=0.2.0",
"numcodecs-wasm-pco~=0.2.0",
"numcodecs-wasm-round~=0.3.0",
"numcodecs-wasm-sz3~=0.6.0",
"numcodecs-wasm-tthresh~=0.2.0",
"numcodecs-wasm-uniform-noise~=0.3.0",
"numcodecs-wasm-zfp~=0.5.1",
"numcodecs-wasm-zlib~=0.3.0",
"pandas~=2.2",
"scipy~=1.15",
Comment thread
juntyr marked this conversation as resolved.
"scipy~=1.14",
"tabulate~=0.9",
"typed-classproperties~=1.1.0",
"xarray~=2024.11.0",
"xarray>=2024.11.0,<2025.4",
"zarr~=2.18.0",
]

Expand Down
4 changes: 2 additions & 2 deletions src/climatebenchpress/compressor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["compressors", "metrics", "tests"]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm definitely amenable to a different name, this just makes things work with the least effort without making anything worse

__all__ = ["compressors", "metrics", "scripts", "tests"]

from . import cf as cf
from . import compressors, metrics, tests
from . import compressors, metrics, scripts, tests
14 changes: 14 additions & 0 deletions src/climatebenchpress/compressor/monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
__all__ = ["progress_bar"]

from contextlib import contextmanager

from dask.diagnostics.progress import ProgressBar


@contextmanager
def progress_bar(progress: bool = True):
if progress:
with ProgressBar():
yield
else:
yield
1 change: 1 addition & 0 deletions src/climatebenchpress/compressor/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__all__: list[str] = []
Comment thread
treigerm marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
__all__ = ["collect_metrics"]

import json
import re
from pathlib import Path
Expand All @@ -7,8 +9,6 @@

import climatebenchpress.compressor

REPO = Path(__file__).parent.parent

EVALUATION_METRICS: dict[str, climatebenchpress.compressor.metrics.abc.Metric] = {
"MAE": climatebenchpress.compressor.metrics.MAE(),
"Max Absolute Error": climatebenchpress.compressor.metrics.MaxAbsError(),
Expand All @@ -24,10 +24,13 @@
}


def main():
datasets = REPO.parent / "data-loader" / "datasets"
compressed_datasets = REPO / "compressed-datasets"
metrics_dir = REPO / "metrics"
def collect_metrics(
basepath: Path = Path(),
data_loader_base_path: None | Path = None,
Comment thread
juntyr marked this conversation as resolved.
):
datasets = (data_loader_base_path or basepath) / "datasets"
compressed_datasets = basepath / "compressed-datasets"
metrics_dir = basepath / "metrics"

all_results = []
for dataset in compressed_datasets.iterdir():
Expand Down Expand Up @@ -74,8 +77,8 @@ def main():
df["Error Bound"] = error_bound.name
all_results.append(df)

all_results = pd.concat(all_results)
all_results.to_csv(metrics_dir / "all_results.csv", index=False)
all_results_df = pd.concat(all_results)
all_results_df.to_csv(metrics_dir / "all_results.csv", index=False)


def parse_error_bounds(error_bound_str: str) -> dict[str, tuple[str, float]]:
Expand Down Expand Up @@ -192,7 +195,7 @@ def compute_tests(


def load_measurements(compressed_dataset: Path, compressor: Path) -> pd.DataFrame:
with open(compressed_dataset / "measurements.json") as f:
with (compressed_dataset / "measurements.json").open() as f:
Comment thread
juntyr marked this conversation as resolved.
measurements = json.load(f)

rows = []
Expand Down Expand Up @@ -253,4 +256,7 @@ def merge_metrics(


if __name__ == "__main__":
main()
collect_metrics(
basepath=Path(),
data_loader_base_path=Path() / ".." / "data-loader",
)
Original file line number Diff line number Diff line change
@@ -1,31 +1,40 @@
__all__ = ["compress"]

import argparse
import json
import traceback
from collections.abc import Container
from pathlib import Path
from typing import Hashable

import numcodecs_observers
import xarray as xr
from climatebenchpress.compressor.compressors.abc import (
Compressor,
ErrorBound,
NamedPerVariableCodec,
)
from dask.diagnostics.progress import ProgressBar
from numcodecs.abc import Codec
from numcodecs_combinators.stack import CodecStack
from numcodecs_observers.bytesize import BytesizeObserver
from numcodecs_observers.hash import HashableCodec
from numcodecs_observers.walltime import WalltimeObserver
from numcodecs_wasm import WasmCodecInstructionCounterObserver

REPO = Path(__file__).parent.parent


def main(exclude_dataset, include_dataset, exclude_compressor, include_compressor):
datasets = REPO.parent / "data-loader" / "datasets"
compressed_datasets = REPO / "compressed-datasets"
datasets_error_bounds = REPO / "datasets-error-bounds"
from ..compressors.abc import (
Compressor,
ErrorBound,
NamedPerVariableCodec,
)
from ..monitor import progress_bar


def compress(
basepath: Path = Path(),
exclude_dataset: Container[str] = tuple(),
include_dataset: None | Container[str] = None,
exclude_compressor: Container[str] = tuple(),
include_compressor: None | Container[str] = None,
data_loader_base_path: None | Path = None,
progress: bool = True,
):
datasets = (data_loader_base_path or basepath) / "datasets"
compressed_datasets = basepath / "compressed-datasets"
datasets_error_bounds = basepath / "datasets-error-bounds"

for dataset in datasets.iterdir():
if dataset.name == ".gitignore" or dataset.name in exclude_dataset:
Expand All @@ -34,6 +43,11 @@ def main(exclude_dataset, include_dataset, exclude_compressor, include_compresso
continue

dataset /= "standardized.zarr"

if not dataset.exists():
Comment thread
juntyr marked this conversation as resolved.
print(f"No input dataset at {dataset}")
continue

ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr")
ds_dtypes, ds_abs_mins, ds_abs_maxs = dict(), dict(), dict()
for v in ds:
Expand Down Expand Up @@ -88,14 +102,14 @@ def main(exclude_dataset, include_dataset, exclude_compressor, include_compresso
with (compressed_dataset / "measurements.json").open("w") as f:
json.dump(measurements, f)

with ProgressBar():
with progress_bar(progress):
ds_new.to_zarr(
compressed_dataset_path, encoding=dict(), compute=False
).compute()


def compress_decompress(
codecs: dict[Hashable, Codec],
codecs: dict[str, Codec],
ds: xr.Dataset,
) -> tuple[xr.Dataset, dict]:
variables = dict()
Expand All @@ -106,7 +120,7 @@ def compress_decompress(
timing = WalltimeObserver()
instructions = WasmCodecInstructionCounterObserver()

codec = codecs[v]
codec = codecs[v] # type: ignore
if not isinstance(codec, CodecStack):
codec = CodecStack(codec)

Expand Down Expand Up @@ -151,7 +165,7 @@ def get_error_bounds(
)

dataset_error_bounds = datasets_error_bounds / dataset_name
with open(dataset_error_bounds / "error_bounds.json") as f:
with (dataset_error_bounds / "error_bounds.json").open() as f:
error_bounds = json.load(f)
return [
{var_name: ErrorBound(**eb) for var_name, eb in eb_per_var.items()}
Expand All @@ -166,4 +180,13 @@ def get_error_bounds(
parser.add_argument("--exclude-compressor", type=str, nargs="+", default=[])
parser.add_argument("--include-compressor", type=str, nargs="+", default=None)
args = parser.parse_args()
main(**vars(args))

compress(
basepath=Path(),
exclude_dataset=args.exclude_dataset,
include_dataset=args.include_dataset,
exclude_compressor=args.exclude_compressor,
include_compressor=args.include_compressor,
data_loader_base_path=Path() / ".." / "data-loader",
progress=True,
)
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
__all__ = ["create_error_bounds"]

import json
from pathlib import Path

import xarray as xr

REPO = Path(__file__).parent.parent


def main():
datasets = REPO.parent / "data-loader" / "datasets"
datasets_error_bounds = REPO / "datasets-error-bounds"
def create_error_bounds(
basepath: Path = Path(),
data_loader_base_path: None | Path = None,
):
datasets = (data_loader_base_path or basepath) / "datasets"
datasets_error_bounds = basepath / "datasets-error-bounds"

for dataset in datasets.iterdir():
if dataset.name == ".gitignore":
continue

if not (dataset / "standardized.zarr").exists():
print(f"No input dataset at {dataset / 'standardized.zarr'}")
continue

print(dataset.name)
ds = xr.open_dataset(
dataset / "standardized.zarr",
Expand All @@ -26,7 +33,7 @@ def main():
# principled method to selct the error bounds.
low_error_bounds, mid_error_bounds, high_error_bounds = dict(), dict(), dict()
for v in ds:
data_range = (ds[v].max() - ds[v].min()).values.item()
data_range: float = (ds[v].max() - ds[v].min()).values.item() # type: ignore
low_error_bounds[v] = {"abs_error": 0.0001 * data_range, "rel_error": None}
mid_error_bounds[v] = {"abs_error": 0.001 * data_range, "rel_error": None}
high_error_bounds[v] = {"abs_error": 0.01 * data_range, "rel_error": None}
Expand All @@ -35,9 +42,12 @@ def main():

dataset_error_bounds = datasets_error_bounds / dataset.name
dataset_error_bounds.mkdir(parents=True, exist_ok=True)
with open(dataset_error_bounds / "error_bounds.json", "w") as f:
with (dataset_error_bounds / "error_bounds.json").open("w") as f:
json.dump(error_bounds, f)


if __name__ == "__main__":
main()
create_error_bounds(
basepath=Path(),
data_loader_base_path=Path() / ".." / "data-loader",
)
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
__all__ = ["print_metrics"]

from pathlib import Path

import pandas as pd

REPO = Path(__file__).parent.parent


def main():
all_results = pd.read_csv(REPO / "metrics" / "all_results.csv")
def print_metrics(basepath: Path = Path()):
all_results = pd.read_csv(basepath / "metrics" / "all_results.csv")
for dataset in all_results["Dataset"].unique():
print("\n" + 100 * "=")
print(f"Results on {dataset}")
Expand All @@ -19,4 +19,4 @@ def main():


if __name__ == "__main__":
main()
print_metrics(basepath=Path())
Loading