diff --git a/CHANGELOG.md b/CHANGELOG.md index 98605aa..846a62f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [unreleased](https://github.com/mllam/mllam-data-prep/compare/v0.7.0...HEAD) + +### Fixes +- fix bug where coordinate selection of an unshared dimension isn't applied to subsequent output variables when an output variable without this dimension is processed before the others [\#90](https://github.com/mllam/mllam-data-prep/pull/90) @zweihuehner & @leifdenby + ## [v0.7.0](https://github.com/mllam/mllam-data-prep/release/tag/v0.7.0) [All changes](https://github.com/mllam/mllam-data-prep/compare/v0.7.0...v0.6.1) diff --git a/mllam_data_prep/create_dataset.py b/mllam_data_prep/create_dataset.py index 5f12612..3b7cfb1 100644 --- a/mllam_data_prep/create_dataset.py +++ b/mllam_data_prep/create_dataset.py @@ -105,10 +105,17 @@ def _merge_dataarrays_by_target(dataarrays_by_target): ds = xr.merge(dataarrays, join="exact") except ValueError as ex: if ex.args[0].startswith("cannot align objects with join='exact'"): + + def _summarize(da): + dims = ", ".join([f"{k}: {v}" for k, v in da.sizes.items()]) + return f"{da.name} ({dims})\n{da.coords}" + + coord_summaries = "\n".join([_summarize(da) for da in dataarrays]) raise InvalidConfigException( - f"Couldn't merge together the dataarrays for all targets ({', '.join(dataarrays_by_target.keys())})" - f" This is likely because the dataarrays have different dimensions or coordinates." - " Maybe you need to give the 'feature' dimension a unique name for each target variable?" + f"Couldn't merge together the dataarrays for all targets ({', '.join(dataarrays_by_target.keys())}). " + "This is likely because the dataarrays have different dimensions or coordinates. 
" + f"Dataarray coords:\n{coord_summaries}" + "Maybe you need to give the 'feature' dimension a unique name for each target variable?" ) from ex else: raise ex @@ -245,10 +252,15 @@ def create_dataset(config: Config): # only need to do selection for the coordinates that the input dataset actually has if output_coord_ranges is not None: - output_coord_ranges = { + # Use a temporary dict to apply selection on coordinate ranges to avoid + # modifying the original ranges given in the config. This is needed because + # static features, for example, do not have a time dimension. Hence, the time + # based selection returns an empty dictionary, which should not overwrite the + # selection for the other variables. + output_coord_ranges_tmp = { k: w for k, w in output_coord_ranges.items() if k in output_dims } - da_target = select_by_kwargs(da_target, **output_coord_ranges) + da_target = select_by_kwargs(da_target, **output_coord_ranges_tmp) dataarrays_by_target[target_output_var].append(da_target) diff --git a/tests/test_output_coord_ranges_slicing.py b/tests/test_output_coord_ranges_slicing.py new file mode 100644 index 0000000..17f840a --- /dev/null +++ b/tests/test_output_coord_ranges_slicing.py @@ -0,0 +1,124 @@ +import itertools + +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +import mllam_data_prep as mdp +from mllam_data_prep.config import DimMapping, InputDataset, Output, Range + + +def _write_zarr(ds: xr.Dataset, path): + ds.to_zarr(path, mode="w") + + +@pytest.mark.parametrize( + "input_order", + list(itertools.permutations(["state", "static", "forcing"])), +) +def test_output_coord_ranges_not_dropped_between_inputs(tmp_path, input_order): + """ + Ensure output coord range slicing is applied per-input without being + affected by the order of inputs. This guards against mutating the shared + output_coord_ranges dict when an input lacks a dimension (e.g. 
`static` + without `time`), which would otherwise remove slicing for later inputs. + + See https://github.com/mllam/mllam-data-prep/issues/81 for bug report. + """ + time = pd.date_range("2000-01-01", "2000-01-05", freq="1D") + x = np.arange(2) + + state_ds = xr.Dataset( + {"s": (("time", "x"), np.zeros((len(time), len(x))))}, + coords={"time": time, "x": x}, + ) + forcing_ds = xr.Dataset( + {"f": (("time", "x"), np.ones((len(time), len(x))))}, + coords={"time": time, "x": x}, + ) + static_ds = xr.Dataset( + {"static_feature": (("x",), np.array([10.0, 20.0]))}, + coords={"x": x}, + ) + + state_path = tmp_path / "state.zarr" + forcing_path = tmp_path / "forcing.zarr" + static_path = tmp_path / "static.zarr" + + _write_zarr(state_ds, state_path) + _write_zarr(forcing_ds, forcing_path) + _write_zarr(static_ds, static_path) + + inputs_by_name = { + "state": InputDataset( + path=str(state_path), + dims=["time", "x"], + variables=["s"], + target_output_variable="state", + dim_mapping={ + "time": DimMapping(method="rename", dim="time"), + "grid_index": DimMapping(method="stack", dims=["x"]), + "state_feature": DimMapping( + method="stack_variables_by_var_name", + name_format="{var_name}", + ), + }, + ), + "static": InputDataset( + path=str(static_path), + dims=["x"], + variables=["static_feature"], + target_output_variable="static", + dim_mapping={ + "grid_index": DimMapping(method="stack", dims=["x"]), + "static_feature": DimMapping( + method="stack_variables_by_var_name", + name_format="{var_name}", + ), + }, + ), + "forcing": InputDataset( + path=str(forcing_path), + dims=["time", "x"], + variables=["f"], + target_output_variable="forcing", + dim_mapping={ + "time": DimMapping(method="rename", dim="time"), + "grid_index": DimMapping(method="stack", dims=["x"]), + "forcing_feature": DimMapping( + method="stack_variables_by_var_name", + name_format="{var_name}", + ), + }, + ), + } + + ordered_inputs = {name: inputs_by_name[name] for name in input_order} + + config = 
mdp.Config( + schema_version="v0.6.0", + dataset_version="v0.0.0", + output=Output( + variables={ + "state": ["time", "grid_index", "state_feature"], + "forcing": ["time", "grid_index", "forcing_feature"], + "static": ["grid_index", "static_feature"], + }, + coord_ranges={ + "time": Range( + start="2000-01-01T00:00", + end="2000-01-03T00:00", + step="PT24H", + ) + }, + ), + inputs=ordered_inputs, + ) + + ds = mdp.create_dataset(config=config) + + expected_len = 3 + assert ds["state"].sizes["time"] == expected_len + assert ds["forcing"].sizes["time"] == expected_len + assert "time" not in ds["static"].dims