
Commit 77129f1
Add top-level open_datatree function (TODO: deduplicate and clean up code)
1 parent 39ba056

File tree: 4 files changed, +231 −6 lines changed


xarray/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,7 @@
     load_dataset,
     open_dataarray,
     open_dataset,
+    open_datatree,
     open_mfdataset,
     save_mfdataset,
 )
@@ -84,6 +85,7 @@
     "ones_like",
     "open_dataarray",
     "open_dataset",
+    "open_datatree",
     "open_mfdataset",
     "open_rasterio",
     "open_zarr",

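With this export in place, the new entry point should be importable straight from the top-level namespace, next to open_dataset. A minimal sketch of what that looks like, assuming this branch is installed and "example.nc" is a hypothetical netCDF file containing groups:

    import xarray as xr

    # open_datatree now sits alongside open_dataset in xarray's public API
    dt = xr.open_datatree("example.nc")  # hypothetical file with nested groups
    print(type(dt))  # expected to be a datatree.DataTree
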
xarray/backends/api.py
Lines changed: 221 additions & 0 deletions
@@ -336,6 +336,45 @@ def _chunk_ds(
     return backend_ds._replace(variables)
 
 
+def _datatree_from_backend_datatree(
+    backend_dt,
+    filename_or_obj,
+    engine,
+    chunks,
+    cache,
+    overwrite_encoded_chunks,
+    inline_array,
+    **extra_tokens,
+):
+    # TODO: deduplicate with _dataset_from_backend_dataset
+    if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}:
+        raise ValueError(
+            f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
+        )
+
+    backend_dt.map_over_subtree_inplace(_protect_dataset_variables_inplace, cache=cache)
+    if chunks is None:
+        dt = backend_dt
+    else:
+        dt = backend_dt.map_over_subtree(
+            _chunk_ds,
+            filename_or_obj=filename_or_obj,
+            engine=engine,
+            chunks=chunks,
+            overwrite_encoded_chunks=overwrite_encoded_chunks,
+            inline_array=inline_array,
+            **extra_tokens,
+        )
+
+    dt.map_over_subtree_inplace(lambda ds: ds.set_close(backend_dt._close))
+
+    # Ensure source filename always stored in dataset object
+    if "source" not in dt.encoding and isinstance(filename_or_obj, (str, os.PathLike)):
+        dt.encoding["source"] = _normalize_path(filename_or_obj)
+
+    return dt
+
+
 def _dataset_from_backend_dataset(
     backend_ds,
     filename_or_obj,
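The helper above leans on datatree's map_over_subtree, which applies a Dataset-to-Dataset function to every node and returns a new tree. A toy illustration of that pattern, assuming the standalone datatree package is installed; the doubling function stands in for the real _chunk_ds step:

    import numpy as np
    import xarray as xr
    from datatree import DataTree  # assumption: the standalone datatree package is available

    tree = DataTree.from_dict(
        {
            "/": xr.Dataset({"x": ("t", np.arange(4))}),
            "/child": xr.Dataset({"y": ("t", np.arange(8))}),
        }
    )

    # Apply a Dataset -> Dataset function to every node; open_datatree uses the
    # same mechanism to chunk each node's variables with dask.
    doubled = tree.map_over_subtree(lambda ds: ds * 2)
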
@@ -374,6 +413,188 @@ def _dataset_from_backend_dataset
     return ds
 
 
+def open_datatree(
+    filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
+    *,
+    engine: T_Engine = None,
+    chunks: T_Chunks = None,
+    cache: bool | None = None,
+    decode_cf: bool | None = None,
+    mask_and_scale: bool | None = None,
+    decode_times: bool | None = None,
+    decode_timedelta: bool | None = None,
+    use_cftime: bool | None = None,
+    concat_characters: bool | None = None,
+    decode_coords: Literal["coordinates", "all"] | bool | None = None,
+    drop_variables: str | Iterable[str] | None = None,
+    inline_array: bool = False,
+    backend_kwargs: dict[str, Any] | None = None,
+    **kwargs,
+) -> DataTree:
+    """Open and decode a datatree from a file or file-like object.
+
+    Parameters
+    ----------
+    filename_or_obj : str, Path, file-like or DataStore
+        Strings and Path objects are interpreted as a path to a netCDF file
+        or an OpenDAP URL and opened with python-netCDF4, unless the filename
+        ends with .gz, in which case the file is gunzipped and opened with
+        scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like
+        objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF).
+    engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \
+        "pseudonetcdf", "zarr", None}, installed backend \
+        or subclass of xarray.backends.BackendEntrypoint, optional
+        Engine to use when reading files. If not provided, the default engine
+        is chosen based on available dependencies, with a preference for
+        "netcdf4". A custom backend class (a subclass of ``BackendEntrypoint``)
+        can also be used.
+    chunks : int, dict, 'auto' or None, optional
+        If chunks is provided, it is used to load the new dataset into dask
+        arrays. ``chunks=-1`` loads the dataset with dask using a single
+        chunk for all arrays. ``chunks={}`` loads the dataset with dask using
+        the engine's preferred chunks if exposed by the backend, otherwise with
+        a single chunk for all arrays.
+        ``chunks='auto'`` will use dask ``auto`` chunking taking into account the
+        engine's preferred chunks. See dask chunking for more details.
+    cache : bool, optional
+        If True, cache data loaded from the underlying datastore in memory as
+        NumPy arrays when accessed to avoid reading from the underlying data-
+        store multiple times. Defaults to True unless you specify the `chunks`
+        argument to use dask, in which case it defaults to False. Does not
+        change the behavior of coordinates corresponding to dimensions, which
+        always load their data from disk into a ``pandas.Index``.
+    decode_cf : bool, optional
+        Whether to decode these variables, assuming they were saved according
+        to CF conventions.
+    mask_and_scale : bool, optional
+        If True, replace array values equal to `_FillValue` with NA and scale
+        values according to the formula `original_values * scale_factor +
+        add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are
+        taken from variable attributes (if they exist). If the `_FillValue` or
+        `missing_value` attribute contains multiple values a warning will be
+        issued and all array values matching one of the multiple values will
+        be replaced by NA. mask_and_scale defaults to True except for the
+        pseudonetcdf backend. This keyword may not be supported by all the backends.
+    decode_times : bool, optional
+        If True, decode times encoded in the standard NetCDF datetime format
+        into datetime objects. Otherwise, leave them encoded as numbers.
+        This keyword may not be supported by all the backends.
+    decode_timedelta : bool, optional
+        If True, decode variables and coordinates with time units in
+        {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"}
+        into timedelta objects. If False, leave them encoded as numbers.
+        If None (default), assume the same value as decode_times.
+        This keyword may not be supported by all the backends.
+    use_cftime : bool, optional
+        Only relevant if encoded dates come from a standard calendar
+        (e.g. "gregorian", "proleptic_gregorian", "standard", or not
+        specified). If None (default), attempt to decode times to
+        ``np.datetime64[ns]`` objects; if this is not possible, decode times to
+        ``cftime.datetime`` objects. If True, always decode times to
+        ``cftime.datetime`` objects, regardless of whether or not they can be
+        represented using ``np.datetime64[ns]`` objects. If False, always
+        decode times to ``np.datetime64[ns]`` objects; if this is not possible
+        raise an error. This keyword may not be supported by all the backends.
+    concat_characters : bool, optional
+        If True, concatenate along the last dimension of character arrays to
+        form string arrays. Dimensions will only be concatenated over (and
+        removed) if they have no corresponding variable and if they are only
+        used as the last dimension of character arrays.
+        This keyword may not be supported by all the backends.
+    decode_coords : bool or {"coordinates", "all"}, optional
+        Controls which variables are set as coordinate variables:
+
+        - "coordinates" or True: Set variables referred to in the
+          ``'coordinates'`` attribute of the datasets or individual variables
+          as coordinate variables.
+        - "all": Set variables referred to in ``'grid_mapping'``, ``'bounds'`` and
+          other attributes as coordinate variables.
+    drop_variables : str or iterable of str, optional
+        A variable or list of variables to exclude from being parsed from the
+        dataset. This may be useful to drop variables with problems or
+        inconsistent values.
+    inline_array : bool, default: False
+        How to include the array in the dask task graph.
+        By default (``inline_array=False``) the array is included in a task by
+        itself, and each chunk refers to that task by its key. With
+        ``inline_array=True``, Dask will instead inline the array directly
+        in the values of the task graph. See :py:func:`dask.array.from_array`.
+    backend_kwargs : dict
+        Additional keyword arguments passed on to the engine open function,
+        equivalent to `**kwargs`.
+    **kwargs : dict
+        Additional keyword arguments passed on to the engine open function.
+        For example:
+
+        - 'group': path to the netCDF4 group in the given file to open given as
+          a str, supported by "netcdf4", "h5netcdf", "zarr".
+        - 'lock': resource lock to use when reading data from disk. Only
+          relevant when using dask or another form of parallelism. By default,
+          appropriate locks are chosen to safely read and write files with the
+          currently active dask scheduler. Supported by "netcdf4", "h5netcdf",
+          "scipy", "pynio", "pseudonetcdf", "cfgrib".
+
+        See engine open function for kwargs accepted by each specific engine.
+
+    Returns
+    -------
+    datatree : datatree.DataTree
+        The newly created datatree.
+
+    Notes
+    -----
+    ``open_datatree`` opens the file with read-only access. When you modify
+    values of a Dataset, even one linked to files on disk, only the in-memory
+    copy you are manipulating in xarray is modified: the original file on disk
+    is never touched.
+
+    """
+    # TODO: deduplicate with open_dataset
+
+    if cache is None:
+        cache = chunks is None
+
+    if backend_kwargs is not None:
+        kwargs.update(backend_kwargs)
+
+    if engine is None:
+        engine = plugins.guess_engine(filename_or_obj)
+
+    backend = plugins.get_backend(engine)
+
+    decoders = _resolve_decoders_kwargs(
+        decode_cf,
+        open_backend_dataset_parameters=backend.open_dataset_parameters,
+        mask_and_scale=mask_and_scale,
+        decode_times=decode_times,
+        decode_timedelta=decode_timedelta,
+        concat_characters=concat_characters,
+        use_cftime=use_cftime,
+        decode_coords=decode_coords,
+    )
+
+    overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)
+    backend_dt = backend.open_datatree(
+        filename_or_obj,
+        drop_variables=drop_variables,
+        **decoders,
+        **kwargs,
+    )
+    dt = _datatree_from_backend_datatree(
+        backend_dt,
+        filename_or_obj,
+        engine,
+        chunks,
+        cache,
+        overwrite_encoded_chunks,
+        inline_array,
+        drop_variables=drop_variables,
+        **decoders,
+        **kwargs,
+    )
+    return dt
+
+
 def open_dataset(
     filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
     *,
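For context, open_datatree is meant to be called the same way as open_dataset, just returning a tree of datasets instead of a single one. A rough usage sketch; the file name, group layout, and chunking choice here are hypothetical:

    import xarray as xr

    dt = xr.open_datatree(
        "collection.nc",  # hypothetical file with groups /weather and /weather/daily
        engine="netcdf4",
        chunks={},        # ask the backend for its preferred chunking, as with open_dataset
    )

    # Nodes are addressed by path; each one wraps an ordinary Dataset
    daily = dt["weather/daily"].ds
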

xarray/backends/netCDF4_.py
Lines changed: 4 additions & 3 deletions

@@ -443,8 +443,9 @@ def get_group_stores(self):
     def select_group(self, group):
         """Return new NetCDF4DataStore for specified group of this NetCDF4DataStore."""
         if group in self.ds.groups:
-            return self.__init__(
-                manager=self._manager, group=group, mode=self._mode, lock=self.lock, autoclose=self.autoclose
+            parent_group = self._group if self._group is not None else ""
+            return self.__class__(
+                manager=self._manager, group=f"{parent_group}{group}/", mode=self._mode, lock=self.lock, autoclose=self.autoclose
             )
         else:
             raise KeyError(group)
@@ -654,7 +655,7 @@ def open_dataset(
             autoclose=autoclose,
         )
 
-    def open_dataset(
+    def open_datatree(
         self,
        filename_or_obj,
        mask_and_scale=True,
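Two things happen in the select_group change: calling self.__class__(...) builds a new store (self.__init__(...) would re-initialize the current one and return None), and the group path is now accumulated from the parent so nested groups resolve correctly. A small standalone sketch of the path composition, using a hypothetical helper name:

    def compose_group_path(parent_group, group):
        # Mirrors the construction in select_group: the root contributes "",
        # and every level appends "<name>/" so grandchildren keep their prefix.
        parent = parent_group if parent_group is not None else ""
        return f"{parent}{group}/"

    assert compose_group_path(None, "weather") == "weather/"
    assert compose_group_path("weather/", "daily") == "weather/daily/"
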

xarray/backends/store.py
Lines changed: 4 additions & 3 deletions

@@ -102,16 +102,17 @@ def _add_node(store, path, datasets):
                use_cftime=use_cftime,
                decode_timedelta=decode_timedelta,
            )
+            ds.set_close(store.close)  # TODO should this be on datatree? if so, need to add to datatree API
            datasets[path] = ds
 
            # Recursively add children to collector
-            for child_name, child_store in store.get_group_stores():
+            for child_name, child_store in store.get_group_stores().items():
                datasets = _add_node(child_store, f"{path}{child_name}/", datasets)
 
            return datasets
 
-        dt = DataTree.from_dict(_add_node(store, "/", {}))
-        dt.set_close(store.close)
+        datasets = _add_node(store, "/", {})
+        dt = DataTree.from_dict(datasets)
 
        return dt
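The backend builds the tree by walking the group hierarchy depth-first, collecting one decoded Dataset per "/"-terminated path, and only then handing the flat dict to DataTree.from_dict. A dependency-free sketch of that collection pattern, using a plain dict as a stand-in for the store:

    def collect(store, path="/", out=None):
        # Stand-in for _add_node: record this group's dataset, then recurse
        # into child groups, extending the path with "<name>/".
        out = {} if out is None else out
        out[path] = store["dataset"]
        for name, child in store.get("groups", {}).items():
            collect(child, f"{path}{name}/", out)
        return out

    toy = {"dataset": "root", "groups": {"a": {"dataset": "child a", "groups": {}}}}
    print(collect(toy))  # {'/': 'root', '/a/': 'child a'}
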
