Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mllam_data_prep/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,8 @@ def create_dataset(config: Config):
if input_config.coord_ranges is not None:
ds_input = selection.select_by_kwargs(ds_input, **input_config.coord_ranges)

# Initialize the output dataset
ds = xr.Dataset()
ds.attrs.update(ds_input.attrs)
# Initialize independent output data storage dict
ds = {}

if selected_variables:
logger.info(f"Extracting selected variables from dataset {dataset_name}")
Expand Down Expand Up @@ -213,8 +212,9 @@ def create_dataset(config: Config):
target_dims=expected_input_var_dims,
)

# Verify attributes on the intact input dataset
_check_dataset_attributes(
ds=ds,
ds=ds_input,
expected_attributes=expected_input_attributes,
dataset_name=dataset_name,
)
Expand Down
39 changes: 28 additions & 11 deletions mllam_data_prep/ops/cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,29 @@ def create_convex_hull_mask(ds: xr.Dataset, ds_reference: xr.Dataset) -> xr.Data

chull_lam = SphericalPolygon.convex_hull(da_ref_xyz.values)

# call .load() to avoid using dask arrays in the following apply_ufunc
def _mask_points_in_hull(lon_vals, lat_vals):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this relates to the issue you are fixing here, right?

# Flatten all dimensions
shape = lon_vals.shape
lon = lon_vals.ravel()
lat = lat_vals.ravel()

xyz_pts = np.array(sg.vector.lonlat_to_vector(lon, lat)).T

# We iterate over pre-calculated Cartesian coordinates to avoid millions
# of redundant trigonometric python calls that would be present if we
# passed (lon, lat) points individually. (SphericalPolygon does not yet
# support vectorized batched arrays internally).
mask = np.array([chull_lam.contains_point(pt) for pt in xyz_pts], dtype=bool)

return mask.reshape(shape)

# use dask-parallelized vectorized containment test without np.vectorize
da_interior_mask = xr.apply_ufunc(
chull_lam.contains_lonlat, da_lon.load(), da_lat.load(), vectorize=True
_mask_points_in_hull,
da_lon.load(),
da_lat.load(),
dask="parallelized",
output_dtypes=[bool],
).astype(bool)
da_interior_mask.attrs[
"long_name"
Expand Down Expand Up @@ -241,15 +261,12 @@ def distance_to_convex_hull_boundary(
(da_xyz_chull[-1], da_xyz_chull[0])
] # Add arc from last to first point

# Calculate minimum distance to each arc and take the minimum
# distance over all arcs
mindist_to_ref = np.stack(
[
shortest_distance_to_arc(da_xyz, arc_start, arc_end)
for arc_start, arc_end in chull_arcs
],
axis=0,
).min(axis=0)
mindist_to_ref = np.full(da_xyz.shape[0], np.inf)
xyz_arr = da_xyz.values

for arc_start, arc_end in chull_arcs:
dist = shortest_distance_to_arc(xyz_arr, arc_start, arc_end)
np.minimum(mindist_to_ref, dist, out=mindist_to_ref)

da_mindist_to_ref = xr.DataArray(
mindist_to_ref, coords=ds_exterior_lat.coords, dims=ds_exterior_lat.dims
Expand Down
19 changes: 16 additions & 3 deletions mllam_data_prep/ops/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def map_dims_and_variables(ds, dim_mapping, expected_input_var_dims):
)

# check that none of the variables have dims that are not in the expected_input_var_dims
for var_name in ds.data_vars:
data_vars = list(ds.data_vars) if hasattr(ds, "data_vars") else list(ds.keys())
for var_name in data_vars:
if not set(ds[var_name].dims).issubset(expected_input_var_dims):
extra_dims = set(ds[var_name].dims) - set(expected_input_var_dims)
raise ValueError(
Expand All @@ -93,14 +94,26 @@ def map_dims_and_variables(ds, dim_mapping, expected_input_var_dims):

if method == "rename":
source_dim = input_dim_map.dim
ds = ds.rename({source_dim: arch_dim})
if hasattr(ds, "data_vars"): # xr.Dataset
ds = ds.rename({source_dim: arch_dim})
else: # dictionary of DataArrays
ds = {
k: (v.rename({source_dim: arch_dim}) if source_dim in v.dims else v)
for k, v in ds.items()
}
elif method == "stack":
source_dims = input_dim_map.dims
# when stacking we assume that the input_dims is a list of dimensions
# in the input dataset that we want to stack to create the architecture
# dimension, this is for example used for flattening the spatial dimensions
# into a single dimension representing the grid index
ds = ds.stack({arch_dim: source_dims}).reset_index(arch_dim)
if hasattr(ds, "data_vars"):
ds = ds.stack({arch_dim: source_dims}).reset_index(arch_dim)
else:
ds = {
k: v.stack({arch_dim: source_dims}).reset_index(arch_dim)
for k, v in ds.items()
}
else:
raise NotImplementedError(method)

Expand Down
25 changes: 15 additions & 10 deletions mllam_data_prep/ops/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ def stack_variables_as_coord_values(ds, name_format, combined_dim_name):

Parameters
----------
ds : xr.Dataset
source dataset with variables to stack
ds : xr.Dataset or dict
source dataset or dictionary of variables to stack
name_format : str
format string to construct the new coordinate values for the
stacked variables, e.g. "{var_name}_level"
Expand All @@ -31,7 +31,8 @@ def stack_variables_as_coord_values(ds, name_format, combined_dim_name):
" {var_name} to construct the new coordinate values"
)
dataarrays = []
for var_name in list(ds.data_vars):
data_vars = list(ds.data_vars) if hasattr(ds, "data_vars") else list(ds.keys())
for var_name in data_vars:
da = ds[var_name].expand_dims(combined_dim_name)
da.coords[combined_dim_name] = [name_format.format(var_name=var_name)]

Expand Down Expand Up @@ -76,8 +77,8 @@ def stack_variables_by_coord_values(ds, coord, name_format, combined_dim_name):

Parameters
----------
ds : xr.Dataset
dataset with variables as data_vars and `level_dim` as a coordinate
ds : xr.Dataset or dict
dataset or dict of variables as data_vars and `level_dim` as a coordinate
coord : str
name of the coordinate that should mapped over
name_format : str
Expand All @@ -101,14 +102,18 @@ def stack_variables_by_coord_values(ds, coord, name_format, combined_dim_name):
"The name_format should include the coordinate name as"
f" {{{coord}}} to construct the new coordinate values"
)
if coord not in ds.coords:
raise ValueError(
f"The coordinate {coord} is not in the dataset, found coords: {list(ds.coords)}"
)

# Note: validation that the coord exists is slightly harder when we just have a dict
# of variables, as not all variables may have the same dimensionality
datasets = []
for var_name in list(ds.data_vars):
data_vars = list(ds.data_vars) if hasattr(ds, "data_vars") else list(ds.keys())
for var_name in data_vars:
da = ds[var_name]
if coord not in da.coords:
raise ValueError(
f"The coordinate {coord} is not in the variable {var_name}, found coords: {list(da.coords)}"
)

coord_values = da.coords[coord].values
new_coord_values = [
name_format.format(var_name=var_name, **{coord: val})
Expand Down
101 changes: 101 additions & 0 deletions tests/test_issue_61.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import numpy as np
import xarray as xr

from mllam_data_prep.config import Config
from mllam_data_prep.create_dataset import create_dataset

def test_variable_selection_by_independent_coords():
    """
    Regression test for issue #61: selecting variables by *different*
    per-variable coordinate values must not produce NaN-filled
    cartesian-product features in the output.

    Builds a small mock dataset with ``u``/``v``/``t`` on altitude levels,
    requests a different subset of altitudes for each variable, runs the
    full ``create_dataset`` pipeline, and checks that exactly the requested
    variable/level combinations appear and that none of them is all-NaN.
    """
    import pathlib
    import tempfile

    import yaml

    # Mock input dataset: three variables on (time, x, y, altitude)
    altitudes = [30, 50, 75, 100]
    time = [1, 2]
    x = [0, 1]
    y = [0, 1]

    shape = (len(time), len(x), len(y), len(altitudes))
    coords = {"time": time, "x": x, "y": y, "altitude": altitudes}

    ds_mock = xr.Dataset(
        data_vars={
            "u": (["time", "x", "y", "altitude"], np.ones(shape)),
            "v": (["time", "x", "y", "altitude"], np.ones(shape) * 2),
            "t": (["time", "x", "y", "altitude"], np.ones(shape) * 3),
        },
        coords=coords,
    )
    # The extraction step validates units on the selection coordinate
    ds_mock.altitude.attrs["units"] = "m"

    # `load_input_dataset` reads from a path, so persist the mock dataset
    # to a temporary zarr store for the duration of the test
    with tempfile.TemporaryDirectory() as tmpdir:
        tmp_path = pathlib.Path(tmpdir) / "height_levels.zarr"
        ds_mock.to_zarr(tmp_path)

        # Config mirroring the one reported in issue #61: each variable
        # selects a *different* set of altitude values
        config_dict = {
            "schema_version": "v0.6.0",
            "dataset_version": "v1.0",
            "inputs": {
                "danra_height_levels": {
                    "path": str(tmp_path),
                    "dims": ["time", "x", "y", "altitude"],
                    "variables": {
                        "u": {"altitude": {"values": [100, 50], "units": "m"}},
                        "v": {"altitude": {"values": [100, 75], "units": "m"}},
                        "t": {"altitude": {"values": [30], "units": "m"}},
                    },
                    "dim_mapping": {
                        "time": {"method": "rename", "dim": "time"},
                        "state_feature": {
                            "method": "stack_variables_by_var_name",
                            "dims": ["altitude"],
                            "name_format": "{var_name}{altitude}m",
                        },
                        "grid_index": {"method": "stack", "dims": ["x", "y"]},
                    },
                    "target_output_variable": "state",
                }
            },
            "output": {
                "variables": {"state": ["time", "grid_index", "state_feature"]}
            },
        }

        config_path = pathlib.Path(tmpdir) / "config.yaml"
        with open(config_path, "w") as f:
            yaml.dump(config_dict, f)

        config = Config.from_yaml_file(config_path)

        # Run the pipeline while the temporary zarr store still exists
        ds_out = create_dataset(config)

    # The selected variable/level combinations become values of the
    # `state_feature` coordinate of the `state` data variable
    expected_features = {"u100m", "u50m", "v100m", "v75m", "t30m"}

    assert "state" in ds_out.data_vars
    state_features = ds_out.coords["state_feature"].values.tolist()
    assert (
        set(state_features) == expected_features
    ), f"Expected {expected_features}, got {state_features}"

    # Issue #61 manifested as NaN-filled cartesian-product entries; none
    # of the selected features should be entirely NaN
    for feature in expected_features:
        da_feature = ds_out["state"].sel(state_feature=feature)
        assert not da_feature.isnull().all(), f"Feature {feature} is entirely NaNs!"