diff --git a/README.md b/README.md index 8faaf13..567e051 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ An overview of the content of the YAML configuration file specified via `-c` / ` │ inittime: │ Forecast initialization time │ │ leadtime: │ Forecast leadtime │ │ validtime: │ Forecast validtime │ +│ format: │ 'grib', 'netcdf', or 'zarr' (optional) │ │ mask: │ Sequence of [lat, lon] pairs (optional) │ │ name: │ Dataset descriptive name │ │ path: │ Filesystem path to Zarr/netCDF dataset │ @@ -145,6 +146,10 @@ Specify values under `forecast.coords.time` as follows: If a variable specified under `forecast.coords.time` names a coordinate dimension variable, that variable will be used. If no such variable exists, `wxvx` will look for a dataset attribute with the given name and try to use it, coercing it to the expected type (e.g. `datetime` or `timedelta`) as needed. For example, it will parse an ISO8601-formatted string to a Python `datetime` object. +### forecast.format + +If this optional value is omitted, `wxvx` will introspect forecast datasets to determine if they are GRIB, netCDF, or Zarr. As a performance optimization, or as an override for correctness in case `wxvx` makes the wrong determination, a value of `grib`, `netcdf`, or `zarr` may be supplied to indicate that forecast datasets are formatted as GRIB, netCDF, or Zarr, respectively. In this case, `wxvx` will behave as if it had determined the format, and will likely fail if the indicated format is incorrect. + ### forecast.mask A sequence of latitude/longitude pairs describing a masking polygon. See the [Example](#example). The specified mask will be applied to forecast, baseline, or truth grids before verification. 
diff --git a/recipe/meta.json b/recipe/meta.json index 371f9f8..01579e7 100644 --- a/recipe/meta.json +++ b/recipe/meta.json @@ -21,7 +21,6 @@ "pytest-xdist ==3.8.*", "python ==3.13", "python-eccodes ==2.42.*", - "python-magic ==0.4.*", "pyyaml ==6.0.*", "requests ==2.32.*", "ruff ==0.15.*", @@ -41,7 +40,6 @@ "pyproj ==3.7.*", "python ==3.13", "python-eccodes ==2.42.*", - "python-magic ==0.4.*", "pyyaml ==6.0.*", "requests ==2.32.*", "seaborn ==0.13.*", @@ -50,5 +48,5 @@ "zarr ==3.1.*" ] }, - "version": "0.5.1" + "version": "0.6.0" } diff --git a/recipe/meta.yaml b/recipe/meta.yaml index c328684..1358469 100644 --- a/recipe/meta.yaml +++ b/recipe/meta.yaml @@ -21,7 +21,6 @@ requirements: - netcdf4 1.7.* - pyproj 3.7.* - python-eccodes 2.42.* - - python-magic 0.4.* - pyyaml 6.0.* - requests 2.32.* - seaborn 0.13.* diff --git a/src/pyproject.toml b/src/pyproject.toml index 35d39e7..176dd1b 100644 --- a/src/pyproject.toml +++ b/src/pyproject.toml @@ -25,6 +25,7 @@ warn_return_any = true [tool.pytest.ini_options] filterwarnings = [ + "ignore:Consolidated metadata .*:UserWarning", "ignore:This process .* is multi-threaded:DeprecationWarning", "ignore:jsonschema.RefResolver is deprecated:DeprecationWarning" # from uwtools ] @@ -42,6 +43,7 @@ ignore = [ "ANN202", # missing-return-type-private-function "ANN204", # missing-return-type-special-method "ANN401", # any-type + "BLE001", # blind-except "C408", # unnecessary-collection-call "C901", # complex-structure "COM812", # missing-trailing-comma diff --git a/src/wxvx/resources/config-grid.yaml b/src/wxvx/resources/config-grid.yaml index d5c33e9..ca83019 100644 --- a/src/wxvx/resources/config-grid.yaml +++ b/src/wxvx/resources/config-grid.yaml @@ -12,6 +12,7 @@ forecast: time: inittime: time leadtime: lead_time + format: zarr mask: - [52.61564933, 225.90452027] - [52.61564933, 255.00000000] diff --git a/src/wxvx/resources/config-point.yaml b/src/wxvx/resources/config-point.yaml index 6fb6a38..334344a 100644 --- 
a/src/wxvx/resources/config-point.yaml +++ b/src/wxvx/resources/config-point.yaml @@ -11,6 +11,7 @@ forecast: time: inittime: forecast_reference_time validtime: time + format: zarr name: AIWX path: /path/to/forecast.zarr projection: diff --git a/src/wxvx/resources/config.jsonschema b/src/wxvx/resources/config.jsonschema index c1b195d..220f84a 100644 --- a/src/wxvx/resources/config.jsonschema +++ b/src/wxvx/resources/config.jsonschema @@ -153,6 +153,13 @@ ], "type": "object" }, + "format": { + "enum": [ + "grib", + "netcdf", + "zarr" + ] + }, "mask": { "items": { "items": { diff --git a/src/wxvx/resources/info.json b/src/wxvx/resources/info.json index 4331341..198fdae 100644 --- a/src/wxvx/resources/info.json +++ b/src/wxvx/resources/info.json @@ -1,4 +1,4 @@ { "buildnum": "0", - "version": "0.5.1" + "version": "0.6.0" } diff --git a/src/wxvx/strings.py b/src/wxvx/strings.py index 3918857..810b538 100644 --- a/src/wxvx/strings.py +++ b/src/wxvx/strings.py @@ -146,6 +146,7 @@ class _S(_ValsMatchKeys): firstbyte: str = _ forecast: str = _ forecast_reference_time: str = _ + format: str = _ grid: str = _ grids: str = _ grids_baseline: str = _ diff --git a/src/wxvx/tests/support.py b/src/wxvx/tests/support.py index 1e8505e..26ad11b 100644 --- a/src/wxvx/tests/support.py +++ b/src/wxvx/tests/support.py @@ -13,7 +13,9 @@ def with_del(d: dict, *args: Any) -> dict: p = new for key in args[:-1]: p = p[key] - del p[args[-1]] + key = args[-1] + if key in p: + del p[key] return new diff --git a/src/wxvx/tests/test_schema.py b/src/wxvx/tests/test_schema.py index 2a58661..b6ce1ce 100644 --- a/src/wxvx/tests/test_schema.py +++ b/src/wxvx/tests/test_schema.py @@ -113,6 +113,12 @@ def test_schema_forecast(logged, config_data, fs): # Additional keys are not allowed: assert not ok(with_set(config, 42, "n")) assert logged("'n' was unexpected") + # Some keys have enum values: + for key in [S.format]: + for val in ["grib", "netcdf", "zarr"]: + assert ok(with_set(config, val, key)) + 
assert not ok(with_set(config, "foo", key)) + assert logged(r"'foo' is not one of \['grib', 'netcdf', 'zarr'\]") # Some keys have object values: for key in [S.coords, S.projection]: assert not ok(with_set(config, None, key)) @@ -122,7 +128,7 @@ def test_schema_forecast(logged, config_data, fs): assert not ok(with_set(config, None, key)) assert logged("None is not of type 'string'") # Some keys are optional: - for key in [S.mask]: + for key in [S.format, S.mask]: assert ok(with_del(config, key)) @@ -349,6 +355,12 @@ def test_schema_variables(logged, config_data, fs): assert logged("None is not of type 'string'") +def test_support_with_del(): + # Test case where with_del() finds nothing to delete, for 100% branch coverage: + c = {"a": "apple"} + assert with_del(c, "b") == c + + # Helpers diff --git a/src/wxvx/tests/test_types.py b/src/wxvx/tests/test_types.py index 03f3d3c..989fb17 100644 --- a/src/wxvx/tests/test_types.py +++ b/src/wxvx/tests/test_types.py @@ -11,7 +11,7 @@ from wxvx import types from wxvx.strings import EC, MET, S -from wxvx.util import WXVXError, resource_path +from wxvx.util import DataFormat, WXVXError, resource_path # Fixtures @@ -247,6 +247,7 @@ def test_types_Forecast(config_data, forecast): assert obj.coords.longitude == "longitude" assert obj.coords.time.inittime == "time" assert obj.coords.time.leadtime == "lead_time" + assert obj.format is None assert obj.name == "Forecast" assert obj.path == "/path/to/forecast-{{ yyyymmdd }}-{{ hh }}-{{ '%03d' % fh }}.nc" cfg = config_data[S.forecast] @@ -258,6 +259,14 @@ def test_types_Forecast(config_data, forecast): del cfg_no_proj[S.projection] default = types.Forecast(**cfg_no_proj) assert default.projection == {S.proj: S.latlon} + for k, v in { + "grib": DataFormat.GRIB, + "netcdf": DataFormat.NETCDF, + "zarr": DataFormat.ZARR, + None: None, + }.items(): + obj = types.Forecast(**{**config_data[S.forecast], "format": k}) + assert obj.format == v def test_types_Leadtimes(): diff --git 
a/src/wxvx/tests/test_util.py b/src/wxvx/tests/test_util.py index 73d3446..bb22136 100644 --- a/src/wxvx/tests/test_util.py +++ b/src/wxvx/tests/test_util.py @@ -9,6 +9,7 @@ from pathlib import Path from unittest.mock import Mock, patch +import xarray as xr from pytest import mark, raises from wxvx import util @@ -36,60 +37,36 @@ def test_util_atomic(fakefs): assert recipient.read_text() == s2 -@mark.parametrize( - ("expected", "inferred"), - [ - (util.DataFormat.BUFR, "Binary Universal Form data (BUFR) Edition 3"), - (util.DataFormat.GRIB, "Gridded binary (GRIB) version 2"), - (util.DataFormat.NETCDF, "Hierarchical Data Format (version 5) data"), - ], -) -def test_util_classify_data_format__file(expected, fakefs, inferred): - path = fakefs / "datafile" - path.touch() - util.classify_data_format.cache_clear() - with patch.object(util.magic, "from_file", return_value=inferred): - assert util.classify_data_format(path=path) == expected - - -def test_util_classify_data_format__file_missing(fakefs, logged): - path = fakefs / "no-such-file" - util.classify_data_format.cache_clear() +def test_util_classify_data_format__fail_missing(fakefs, logged): + path = fakefs / "a.missing" assert util.classify_data_format(path=path) == util.DataFormat.UNKNOWN assert logged(f"Path not found: {path}") -def test_util_classify_data_format__file_unrecognized(fakefs, logged): - path = fakefs / "datafile" - path.touch() - util.classify_data_format.cache_clear() - with patch.object(util.magic, "from_file", return_value="What Is This I Don't Even"): - assert util.classify_data_format(path=path) == util.DataFormat.UNKNOWN - assert logged(f"Could not determine format of {path}") +def test_util_classify_data_format__fail_unknown(logged, tmp_path): + path = tmp_path / "a.foo" + path.write_text("foo") + assert util.classify_data_format(path=path) == util.DataFormat.UNKNOWN + assert logged(f"Could not determine format of: {path}") -def test_util_classify_data_format__zarr(fakefs): - path = fakefs 
@mark.parametrize(
    ("datafmt_expected", "fmtstr", "path"),
    [
        (DataFormat.NETCDF, "netcdf", "/path/to/a.nc"),
        (DataFormat.NETCDF, None, "/path/to/a.nc"),
        (DataFormat.ZARR, "zarr", "/path/to/a.zarr"),
        (DataFormat.ZARR, None, "/path/to/a.zarr"),
    ],
)
def test_workflow__forecast_grid(c, datafmt_expected, fmtstr, path, tc, testvars):
    # When a format string is configured on the forecast, _forecast_grid() should use it and never
    # consult the (patched) classifier; when none is configured, the classifier is called once.
    c.forecast._format = fmtstr
    classifier_calls_expected = 0 if fmtstr else 1
    patcher = patch.object(workflow, "classify_data_format", return_value=datafmt_expected)
    with patcher as classifier:
        node, datafmt_actual = workflow._forecast_grid(
            path=path, c=c, varname="foo", tc=tc, var=testvars[EC.t2]
        )
    assert classifier.call_count == classifier_calls_expected
    # For netCDF and Zarr forecast datasets, the grid will be extracted from the dataset and CF-
    # decorated, so the requirement is a _grid_nc task, whose taskname is "Forecast grid ..."
    assert node.taskname.startswith("Forecast grid")
    assert datafmt_actual == datafmt_expected
@cache
def classify_data_format(path: str | Path) -> DataFormat:
    """
    Determine the format of the dataset at the given path.

    Probes are tried in a fixed order -- Zarr, then netCDF, then GRIB -- and the first probe that
    succeeds decides the result. Results are memoized per `path` argument by `functools.cache`,
    and this includes `DataFormat.UNKNOWN` results.

    :param path: Filesystem path to the dataset.
    :return: The detected format, or `DataFormat.UNKNOWN` if the path does not exist or if no
        probe recognizes it.
    """

    def check(f: Callable) -> bool:
        # A probe signals "not this format" by raising. Failed probes are expected while working
        # through the candidate formats, so exceptions are only logged at debug level.
        try:
            f()
        except Exception as e:
            # One debug record per line of the message (not per whitespace-separated word).
            for line in str(e).splitlines():
                logging.debug(line)
            return False
        return True

    def grib(path: Path) -> None:
        # It might be better to just try to read the file with a GRIB library like cfgrib, but this
        # tends to be unacceptably slow: Since a GRIB file is just a series of messages without any
        # kind of header/metadata, cfgrib et al. read the entire file. Instead, inspect the initial
        # bytes in the file to see if it is apparently GRIB.
        with path.open(mode="rb") as f:
            header = f.read(8)
        editions = (1, 2)
        # Raise explicitly instead of using assert: assert statements are removed under
        # "python -O", which would make this probe accept every readable file as GRIB.
        if len(header) < 8 or header[:4] != b"GRIB" or header[7] not in editions:
            msg = f"Not apparently GRIB: {path}"
            raise ValueError(msg)

    def netcdf(path: Path) -> None:
        # Open and immediately close the dataset so that probing does not leak a file handle.
        with netCDF4.Dataset(path, mode="r"):
            pass

    path = Path(path)
    if not path.exists():
        logging.warning("Path not found: %s", path)
        return DataFormat.UNKNOWN
    if check(lambda: zarr.open(path, mode="r")):
        return DataFormat.ZARR
    if check(lambda: netcdf(path)):
        return DataFormat.NETCDF
    if check(lambda: grib(path)):
        return DataFormat.GRIB

    logging.error("Could not determine format of: %s", path)
    return DataFormat.UNKNOWN