Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mllam_data_prep/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
# expose the public API
from .config import Config, InvalidConfigException # noqa
from .create_dataset import create_dataset, create_dataset_zarr # noqa
from .recreate_inputs import recreate_inputs # noqa
24 changes: 22 additions & 2 deletions mllam_data_prep/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from pathlib import Path
from typing import Optional, Union

import cf_xarray as cfxr
import numpy as np
import pandas as pd
import xarray as xr
import yaml
import zarr
Expand Down Expand Up @@ -39,6 +41,9 @@
# support v0.2.0, v0.5.0, and v0.6.0
SUPPORTED_CONFIG_VERSIONS = ["v0.2.0", "v0.5.0", "v0.6.0"]

# Template for the names of the per-split statistics variables stored in the
# output dataset; it is filled in with the variable name, the split name and
# the statistic operation.
STATISTICS_VARIABLE_NAME_FORMAT = "{var_name}__{split_name}__{op}"
# Name of the attribute set on each target data array to record the name of
# the input dataset it was produced from.
SOURCE_DATASET_NAME_ATTR = "source_dataset"


def _check_dataset_attributes(ds, expected_attributes, dataset_name):
# check that the dataset has the expected attributes with the expected values
Expand Down Expand Up @@ -232,7 +237,7 @@ def create_dataset(config: Config):
f" produce variable {target_output_var} from dataset {dataset_name}"
) from ex

da_target.attrs["source_dataset"] = dataset_name
da_target.attrs[SOURCE_DATASET_NAME_ATTR] = dataset_name

# only need to do selection for the coordinates that the input dataset actually has
if output_coord_ranges is not None:
Expand Down Expand Up @@ -276,7 +281,10 @@ def create_dataset(config: Config):
)
for op, op_dataarrays in split_stats.items():
for var_name, da in op_dataarrays.items():
ds[f"{var_name}__{split_name}__{op}"] = da
stat_var_name = STATISTICS_VARIABLE_NAME_FORMAT.format(
var_name=var_name, split_name=split_name, op=op
)
ds[stat_var_name] = da

# add a new variable which contains the start, stop for each split, the coords would then be the split names
# and the data would be the start, stop values
Expand All @@ -288,6 +296,18 @@ def create_dataset(config: Config):
)
ds["splits"] = da_splits

# We have to deal with the fact that MultiIndex objects (this would
# commonly be, for example, `grid_index` created by stacking the `x` and `y`
# coordinates) can't be written to netcdf/zarr. In cf_xarray this has been
# handled in a cf-compliant manner using so-called "compression by
# gathering" (see
# https://cf-xarray.readthedocs.io/en/latest/generated/cf_xarray.encode_multi_index_as_compress.html#cf_xarray.encode_multi_index_as_compress),
# which allows us to safely roundtrip MultiIndexes through netcdf/zarr,
# using their encode and decode functions.
for idx in ds.indexes:
if isinstance(ds.indexes[idx], pd.MultiIndex):
ds = cfxr.encode_multi_index_as_compress(ds, idxnames=idx)

ds.attrs = {}
ds.attrs["schema_version"] = config.schema_version
ds.attrs["dataset_version"] = config.dataset_version
Expand Down
2 changes: 1 addition & 1 deletion mllam_data_prep/ops/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def map_dims_and_variables(ds, dim_mapping, expected_input_var_dims):
# in the input dataset that we want to stack to create the architecture
# dimension; this is used, for example, for flattening the spatial dimensions
# into a single dimension representing the grid index
ds = ds.stack({arch_dim: source_dims}).reset_index(arch_dim)
ds = ds.stack({arch_dim: source_dims})
else:
raise NotImplementedError(method)

Expand Down
Loading