Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [unreleased](https://github.com/mllam/mllam-data-prep/compare/v0.7.0...HEAD)

### Fixes
- fix bug where coordinate selection of an unshared dimension isn't applied to subsequent output variables when an output variable without this dimension is processed before the others [\#90](https://github.com/mllam/mllam-data-prep/pull/90) @zweihuehner & @leifdenby

## [v0.7.0](https://github.com/mllam/mllam-data-prep/releases/tag/v0.7.0)

[All changes](https://github.com/mllam/mllam-data-prep/compare/v0.6.1...v0.7.0)
Expand Down
22 changes: 17 additions & 5 deletions mllam_data_prep/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,17 @@ def _merge_dataarrays_by_target(dataarrays_by_target):
ds = xr.merge(dataarrays, join="exact")
except ValueError as ex:
if ex.args[0].startswith("cannot align objects with join='exact'"):

def _summarize(da):
dims = ", ".join([f"{k}: {v}" for k, v in da.sizes.items()])
return f"{da.name} ({dims})\n{da.coords}"

coord_summaries = "\n".join([_summarize(da) for da in dataarrays])
raise InvalidConfigException(
f"Couldn't merge together the dataarrays for all targets ({', '.join(dataarrays_by_target.keys())})"
f" This is likely because the dataarrays have different dimensions or coordinates."
" Maybe you need to give the 'feature' dimension a unique name for each target variable?"
f"Couldn't merge together the dataarrays for all targets ({', '.join(dataarrays_by_target.keys())}). "
"This is likely because the dataarrays have different dimensions or coordinates. "
f"Dataarray coords:\n{coord_summaries}"
"Maybe you need to give the 'feature' dimension a unique name for each target variable?"
) from ex
else:
raise ex
Expand Down Expand Up @@ -245,10 +252,15 @@ def create_dataset(config: Config):

# only need to do selection for the coordinates that the input dataset actually has
if output_coord_ranges is not None:
output_coord_ranges = {
# Use a temporary dict to apply selection on coordinate ranges to avoid
# modifying the original ranges given in the config. This is needed because
# static features, for example, do not have a time dimension. Hence, the time
# based selection returns an empty dictionary, which should not overwrite the
# selection for the other variables.
output_coord_ranges_tmp = {
k: w for k, w in output_coord_ranges.items() if k in output_dims
}
da_target = select_by_kwargs(da_target, **output_coord_ranges)
da_target = select_by_kwargs(da_target, **output_coord_ranges_tmp)

dataarrays_by_target[target_output_var].append(da_target)

Expand Down
124 changes: 124 additions & 0 deletions tests/test_output_coord_ranges_slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import itertools

import numpy as np
import pandas as pd
import pytest
import xarray as xr

import mllam_data_prep as mdp
from mllam_data_prep.config import DimMapping, InputDataset, Output, Range


def _write_zarr(ds: xr.Dataset, path):
    """Persist *ds* as a zarr store at *path*, overwriting any existing store."""
    ds.to_zarr(store=path, mode="w")


@pytest.mark.parametrize(
    "input_order",
    list(itertools.permutations(["state", "static", "forcing"])),
)
def test_output_coord_ranges_not_dropped_between_inputs(tmp_path, input_order):
    """
    Ensure output coord range slicing is applied per-input without being
    affected by the order of inputs. This guards against mutating the shared
    output_coord_ranges dict when an input lacks a dimension (e.g. `static`
    without `time`), which would otherwise remove slicing for later inputs.

    See https://github.com/mllam/mllam-data-prep/issues/81 for bug report.
    """
    times = pd.date_range("2000-01-01", "2000-01-05", freq="1D")
    xs = np.arange(2)

    # Source datasets: two time-dependent inputs and one static input
    # that has no time dimension at all.
    source_datasets = {
        "state": xr.Dataset(
            {"s": (("time", "x"), np.zeros((len(times), len(xs))))},
            coords={"time": times, "x": xs},
        ),
        "forcing": xr.Dataset(
            {"f": (("time", "x"), np.ones((len(times), len(xs))))},
            coords={"time": times, "x": xs},
        ),
        "static": xr.Dataset(
            {"static_feature": (("x",), np.array([10.0, 20.0]))},
            coords={"x": xs},
        ),
    }

    zarr_paths = {}
    for name, source_ds in source_datasets.items():
        zarr_paths[name] = tmp_path / f"{name}.zarr"
        _write_zarr(source_ds, zarr_paths[name])

    def _make_input(name, dims, variables, feature_dim, has_time):
        # Build the dim_mapping in the same order for every input:
        # (optional) time rename, grid stacking, then variable stacking.
        mapping = {}
        if has_time:
            mapping["time"] = DimMapping(method="rename", dim="time")
        mapping["grid_index"] = DimMapping(method="stack", dims=["x"])
        mapping[feature_dim] = DimMapping(
            method="stack_variables_by_var_name",
            name_format="{var_name}",
        )
        return InputDataset(
            path=str(zarr_paths[name]),
            dims=dims,
            variables=variables,
            target_output_variable=name,
            dim_mapping=mapping,
        )

    inputs_by_name = {
        "state": _make_input(
            "state", ["time", "x"], ["s"], "state_feature", has_time=True
        ),
        "static": _make_input(
            "static", ["x"], ["static_feature"], "static_feature", has_time=False
        ),
        "forcing": _make_input(
            "forcing", ["time", "x"], ["f"], "forcing_feature", has_time=True
        ),
    }

    config = mdp.Config(
        schema_version="v0.6.0",
        dataset_version="v0.0.0",
        output=Output(
            variables={
                "state": ["time", "grid_index", "state_feature"],
                "forcing": ["time", "grid_index", "forcing_feature"],
                "static": ["grid_index", "static_feature"],
            },
            coord_ranges={
                "time": Range(
                    start="2000-01-01T00:00",
                    end="2000-01-03T00:00",
                    step="PT24H",
                )
            },
        ),
        # Order the inputs per the parametrized permutation so every
        # processing order is exercised.
        inputs={name: inputs_by_name[name] for name in input_order},
    )

    ds = mdp.create_dataset(config=config)

    # Slicing 2000-01-01..2000-01-03 at 24h steps keeps exactly 3 timesteps.
    n_expected_times = 3
    assert ds["state"].sizes["time"] == n_expected_times
    assert ds["forcing"].sizes["time"] == n_expected_times
    assert "time" not in ds["static"].dims