diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1fe2bf0..8b62a6f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,6 +51,3 @@ jobs: - name: Typecheck the package run: uv run mypy -p climatebenchpress.compressor - - - name: Typecheck the scripts - run: uv run mypy scripts/ diff --git a/pyproject.toml b/pyproject.toml index ca84ece..8a92df2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,26 +7,26 @@ requires-python = ">=3.12" dependencies = [ "astropy~=7.0.1", "cf-xarray~=0.10", - "dask~=2024.12.0", - "numcodecs>=0.13.0,<0.16", + "dask>=2024.12.0,<2025.4", + "numcodecs>=0.13.0,<0.17", "numcodecs-combinators[xarray]~=0.2.4", "numcodecs-observers~=0.1.1", - "numcodecs-wasm~=0.1.3", - "numcodecs-wasm-bit-round~=0.2.0", - "numcodecs-wasm-fixed-offset-scale~=0.2.1", - "numcodecs-wasm-jpeg2000~=0.1.1", - "numcodecs-wasm-pco~=0.1.0", - "numcodecs-wasm-round~=0.2.0", - "numcodecs-wasm-sz3~=0.5.0", - "numcodecs-wasm-tthresh~=0.1.0", - "numcodecs-wasm-uniform-noise~=0.2.0", - "numcodecs-wasm-zfp~=0.4.0", - "numcodecs-wasm-zlib~=0.2.0", + "numcodecs-wasm~=0.1.6", + "numcodecs-wasm-bit-round~=0.3.0", + "numcodecs-wasm-fixed-offset-scale~=0.3.0", + "numcodecs-wasm-jpeg2000~=0.2.0", + "numcodecs-wasm-pco~=0.2.0", + "numcodecs-wasm-round~=0.3.0", + "numcodecs-wasm-sz3~=0.6.0", + "numcodecs-wasm-tthresh~=0.2.0", + "numcodecs-wasm-uniform-noise~=0.3.0", + "numcodecs-wasm-zfp~=0.5.1", + "numcodecs-wasm-zlib~=0.3.0", "pandas~=2.2", - "scipy~=1.15", + "scipy~=1.14", "tabulate~=0.9", "typed-classproperties~=1.1.0", - "xarray~=2024.11.0", + "xarray>=2024.11.0,<2025.4", "zarr~=2.18.0", ] diff --git a/src/climatebenchpress/compressor/__init__.py b/src/climatebenchpress/compressor/__init__.py index d4c6440..ad65d6a 100644 --- a/src/climatebenchpress/compressor/__init__.py +++ b/src/climatebenchpress/compressor/__init__.py @@ -1,4 +1,4 @@ -__all__ = ["compressors", "metrics", "tests"] +__all__ = ["compressors", "metrics", "scripts", "tests"] from . import cf as cf -from . import compressors, metrics, tests +from . import compressors, metrics, scripts, tests diff --git a/src/climatebenchpress/compressor/monitor.py b/src/climatebenchpress/compressor/monitor.py new file mode 100644 index 0000000..2da298d --- /dev/null +++ b/src/climatebenchpress/compressor/monitor.py @@ -0,0 +1,14 @@ +__all__ = ["progress_bar"] + +from contextlib import contextmanager + +from dask.diagnostics.progress import ProgressBar + + +@contextmanager +def progress_bar(progress: bool = True): + if progress: + with ProgressBar(): + yield + else: + yield diff --git a/src/climatebenchpress/compressor/scripts/__init__.py b/src/climatebenchpress/compressor/scripts/__init__.py new file mode 100644 index 0000000..c9c2ef6 --- /dev/null +++ b/src/climatebenchpress/compressor/scripts/__init__.py @@ -0,0 +1 @@ +__all__: list[str] = [] diff --git a/scripts/collect_metrics.py b/src/climatebenchpress/compressor/scripts/collect_metrics.py similarity index 93% rename from scripts/collect_metrics.py rename to src/climatebenchpress/compressor/scripts/collect_metrics.py index 3d5221d..1123594 100644 --- a/scripts/collect_metrics.py +++ b/src/climatebenchpress/compressor/scripts/collect_metrics.py @@ -1,3 +1,5 @@ +__all__ = ["collect_metrics"] + import json import re from pathlib import Path @@ -7,8 +9,6 @@ import climatebenchpress.compressor -REPO = Path(__file__).parent.parent - EVALUATION_METRICS: dict[str, climatebenchpress.compressor.metrics.abc.Metric] = { "MAE": climatebenchpress.compressor.metrics.MAE(), "Max Absolute Error": climatebenchpress.compressor.metrics.MaxAbsError(), @@ -24,10 +24,13 @@ } -def main(): - datasets = REPO.parent / "data-loader" / "datasets" - compressed_datasets = REPO / "compressed-datasets" - metrics_dir = REPO / "metrics" +def collect_metrics( + basepath: Path = Path(), + data_loader_base_path: None | Path = None, +): + datasets = (data_loader_base_path or basepath) / "datasets" + compressed_datasets = basepath / "compressed-datasets" + metrics_dir = basepath / "metrics" all_results = [] for dataset in compressed_datasets.iterdir(): @@ -74,8 +77,8 @@ def main(): df["Error Bound"] = error_bound.name all_results.append(df) - all_results = pd.concat(all_results) - all_results.to_csv(metrics_dir / "all_results.csv", index=False) + all_results_df = pd.concat(all_results) + all_results_df.to_csv(metrics_dir / "all_results.csv", index=False) def parse_error_bounds(error_bound_str: str) -> dict[str, tuple[str, float]]: @@ -192,7 +195,7 @@ def compute_tests( def load_measurements(compressed_dataset: Path, compressor: Path) -> pd.DataFrame: - with open(compressed_dataset / "measurements.json") as f: + with (compressed_dataset / "measurements.json").open() as f: measurements = json.load(f) rows = [] @@ -253,4 +256,7 @@ def merge_metrics( if __name__ == "__main__": - main() + collect_metrics( + basepath=Path(), + data_loader_base_path=Path() / ".." / "data-loader", + ) diff --git a/scripts/compress.py b/src/climatebenchpress/compressor/scripts/compress.py similarity index 82% rename from scripts/compress.py rename to src/climatebenchpress/compressor/scripts/compress.py index e350823..251f05d 100644 --- a/scripts/compress.py +++ b/src/climatebenchpress/compressor/scripts/compress.py @@ -1,17 +1,13 @@ +__all__ = ["compress"] + import argparse import json import traceback +from collections.abc import Container from pathlib import Path -from typing import Hashable import numcodecs_observers import xarray as xr -from climatebenchpress.compressor.compressors.abc import ( - Compressor, - ErrorBound, - NamedPerVariableCodec, -) -from dask.diagnostics.progress import ProgressBar from numcodecs.abc import Codec from numcodecs_combinators.stack import CodecStack from numcodecs_observers.bytesize import BytesizeObserver @@ -19,13 +15,26 @@ from numcodecs_observers.walltime import WalltimeObserver from numcodecs_wasm import WasmCodecInstructionCounterObserver -REPO = Path(__file__).parent.parent - - -def main(exclude_dataset, include_dataset, exclude_compressor, include_compressor): - datasets = REPO.parent / "data-loader" / "datasets" - compressed_datasets = REPO / "compressed-datasets" - datasets_error_bounds = REPO / "datasets-error-bounds" +from ..compressors.abc import ( + Compressor, + ErrorBound, + NamedPerVariableCodec, +) +from ..monitor import progress_bar + + +def compress( + basepath: Path = Path(), + exclude_dataset: Container[str] = tuple(), + include_dataset: None | Container[str] = None, + exclude_compressor: Container[str] = tuple(), + include_compressor: None | Container[str] = None, + data_loader_base_path: None | Path = None, + progress: bool = True, +): + datasets = (data_loader_base_path or basepath) / "datasets" + compressed_datasets = basepath / "compressed-datasets" + datasets_error_bounds = basepath / "datasets-error-bounds" for dataset in datasets.iterdir(): if dataset.name == ".gitignore" or dataset.name in exclude_dataset: @@ -34,6 +43,11 @@ def main(exclude_dataset, include_dataset, exclude_compressor, include_compresso continue dataset /= "standardized.zarr" + + if not dataset.exists(): + print(f"No input dataset at {dataset}") + continue + ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr") ds_dtypes, ds_abs_mins, ds_abs_maxs = dict(), dict(), dict() for v in ds: @@ -88,14 +102,14 @@ def main(exclude_dataset, include_dataset, exclude_compressor, include_compresso with (compressed_dataset / "measurements.json").open("w") as f: json.dump(measurements, f) - with ProgressBar(): + with progress_bar(progress): ds_new.to_zarr( compressed_dataset_path, encoding=dict(), compute=False ).compute() def compress_decompress( - codecs: dict[Hashable, Codec], + codecs: dict[str, Codec], ds: xr.Dataset, ) -> tuple[xr.Dataset, dict]: variables = dict() @@ -106,7 +120,7 @@ def compress_decompress( timing = WalltimeObserver() instructions = WasmCodecInstructionCounterObserver() - codec = codecs[v] + codec = codecs[v] # type: ignore if not isinstance(codec, CodecStack): codec = CodecStack(codec) @@ -151,7 +165,7 @@ def get_error_bounds( ) dataset_error_bounds = datasets_error_bounds / dataset_name - with open(dataset_error_bounds / "error_bounds.json") as f: + with (dataset_error_bounds / "error_bounds.json").open() as f: error_bounds = json.load(f) return [ {var_name: ErrorBound(**eb) for var_name, eb in eb_per_var.items()} @@ -166,4 +180,13 @@ def get_error_bounds( parser.add_argument("--exclude-compressor", type=str, nargs="+", default=[]) parser.add_argument("--include-compressor", type=str, nargs="+", default=None) args = parser.parse_args() - main(**vars(args)) + + compress( + basepath=Path(), + exclude_dataset=args.exclude_dataset, + include_dataset=args.include_dataset, + exclude_compressor=args.exclude_compressor, + include_compressor=args.include_compressor, + data_loader_base_path=Path() / ".." / "data-loader", + progress=True, + ) diff --git a/scripts/create_error_bounds.py b/src/climatebenchpress/compressor/scripts/create_error_bounds.py similarity index 62% rename from scripts/create_error_bounds.py rename to src/climatebenchpress/compressor/scripts/create_error_bounds.py index 16ab3c6..f33c84a 100644 --- a/scripts/create_error_bounds.py +++ b/src/climatebenchpress/compressor/scripts/create_error_bounds.py @@ -1,19 +1,26 @@ +__all__ = ["create_error_bounds"] + import json from pathlib import Path import xarray as xr -REPO = Path(__file__).parent.parent - -def main(): - datasets = REPO.parent / "data-loader" / "datasets" - datasets_error_bounds = REPO / "datasets-error-bounds" +def create_error_bounds( + basepath: Path = Path(), + data_loader_base_path: None | Path = None, +): + datasets = (data_loader_base_path or basepath) / "datasets" + datasets_error_bounds = basepath / "datasets-error-bounds" for dataset in datasets.iterdir(): if dataset.name == ".gitignore": continue + if not (dataset / "standardized.zarr").exists(): + print(f"No input dataset at {dataset / 'standardized.zarr'}") + continue + print(dataset.name) ds = xr.open_dataset( dataset / "standardized.zarr", @@ -26,7 +33,7 @@ def main(): # principled method to selct the error bounds. low_error_bounds, mid_error_bounds, high_error_bounds = dict(), dict(), dict() for v in ds: - data_range = (ds[v].max() - ds[v].min()).values.item() + data_range: float = (ds[v].max() - ds[v].min()).values.item() # type: ignore low_error_bounds[v] = {"abs_error": 0.0001 * data_range, "rel_error": None} mid_error_bounds[v] = {"abs_error": 0.001 * data_range, "rel_error": None} high_error_bounds[v] = {"abs_error": 0.01 * data_range, "rel_error": None} @@ -35,9 +42,12 @@ def main(): dataset_error_bounds = datasets_error_bounds / dataset.name dataset_error_bounds.mkdir(parents=True, exist_ok=True) - with open(dataset_error_bounds / "error_bounds.json", "w") as f: + with (dataset_error_bounds / "error_bounds.json").open("w") as f: json.dump(error_bounds, f) if __name__ == "__main__": - main() + create_error_bounds( + basepath=Path(), + data_loader_base_path=Path() / ".." / "data-loader", + ) diff --git a/scripts/print_metrics.py b/src/climatebenchpress/compressor/scripts/print_metrics.py similarity index 68% rename from scripts/print_metrics.py rename to src/climatebenchpress/compressor/scripts/print_metrics.py index 41f3a3f..d2dfc51 100644 --- a/scripts/print_metrics.py +++ b/src/climatebenchpress/compressor/scripts/print_metrics.py @@ -1,12 +1,12 @@ +__all__ = ["print_metrics"] + from pathlib import Path import pandas as pd -REPO = Path(__file__).parent.parent - -def main(): - all_results = pd.read_csv(REPO / "metrics" / "all_results.csv") +def print_metrics(basepath: Path = Path()): + all_results = pd.read_csv(basepath / "metrics" / "all_results.csv") for dataset in all_results["Dataset"].unique(): print("\n" + 100 * "=") print(f"Results on {dataset}") @@ -19,4 +19,4 @@ def main(): if __name__ == "__main__": - main() + print_metrics(basepath=Path())