Skip to content

Commit 01e405d

Browse files
committed
Initial commit
1 parent 4d8bbee commit 01e405d

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

xarray/backends/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from xarray.backends.locks import _get_scheduler
3636
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
3737
from xarray.core import indexing
38-
from xarray.core.chunk import _get_chunk, _maybe_chunk
38+
from xarray.core.chunk import _get_chunk, _maybe_chunk, _maybe_get_path_chunk
3939
from xarray.core.combine import (
4040
_infer_concat_order_from_positions,
4141
_nested_combine,
@@ -450,7 +450,7 @@ def _datatree_from_backend_datatree(
450450
node.dataset,
451451
filename_or_obj,
452452
engine,
453-
chunks,
453+
_maybe_get_path_chunk(node.path, chunks),
454454
overwrite_encoded_chunks,
455455
inline_array,
456456
chunked_array_type,

xarray/core/chunk.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,15 @@ def _maybe_chunk(
145145
return var
146146
else:
147147
return var
148+
149+
150+
def _maybe_get_path_chunk(path: str, chunks: int | dict | Any) -> int | dict | Any:
151+
"""Returns path-specific chunks from a chunks dictionary, if path is a key of chunks.
152+
Otherwise, returns chunks as is"""
153+
if isinstance(chunks, dict):
154+
try:
155+
return chunks[path]
156+
except KeyError:
157+
pass
158+
159+
return chunks

xarray/tests/test_backends_datatree.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,37 @@ def test_open_datatree_chunks(self, tmpdir, simple_datatree) -> None:
256256

257257
assert_chunks_equal(tree, original_tree, enforce_dask=True)
258258

259+
@requires_dask
260+
def test_open_datatree_path_chunks(self, tmpdir, simple_datatree) -> None:
261+
filepath = tmpdir / "test.nc"
262+
263+
root_chunks = {"x": 2, "y": 1}
264+
set1_chunks = {"x": 1, "y": 2}
265+
set2_chunks = {"x": 2, "y": 3}
266+
267+
root_data = xr.Dataset({"a": ("y", [6, 7, 8]), "set0": ("x", [9, 10])})
268+
set1_data = xr.Dataset({"a": ("y", [-1, 0, 1]), "b": ("x", [-10, 6])})
269+
set2_data = xr.Dataset({"a": ("y", [1, 2, 3]), "b": ("x", [0.1, 0.2])})
270+
original_tree = DataTree.from_dict(
271+
{
272+
"/": root_data.chunk(root_chunks),
273+
"/group1": set1_data.chunk(set1_chunks),
274+
"/group2": set2_data.chunk(set2_chunks),
275+
}
276+
)
277+
original_tree.to_netcdf(filepath, engine="netcdf4")
278+
279+
chunks = {
280+
"/": root_chunks,
281+
"/group1": set1_chunks,
282+
"/group2": set2_chunks,
283+
}
284+
285+
with open_datatree(filepath, engine="netcdf4", chunks=chunks) as tree:
286+
xr.testing.assert_identical(tree, original_tree)
287+
288+
assert_chunks_equal(tree, original_tree, enforce_dask=True)
289+
259290
def test_open_groups(self, unaligned_datatree_nc) -> None:
260291
"""Test `open_groups` with a netCDF4 file with an unaligned group hierarchy."""
261292
unaligned_dict_of_datasets = open_groups(unaligned_datatree_nc)

0 commit comments

Comments
 (0)