Skip to content

Commit 40666b2

Browse files
shoyerTomNicholaspre-commit-ci[bot]
authored
Fix inheritance in DataTree.copy() (#9457)
* Fix inheritance in DataTree.copy() Fixes #9454 Previously, we were copying parent coordinates/dimensions onto all child nodes. This is not obvious in the current repr, but you can see it from looking at the private `._node_coord_variables` and `._node_dims`. To make the use of `_to_dataset_view()` a little more obvious, I've added a required boolean `inherited` argument. * typing error * add missing inherited argument * Apply suggestions from code review Co-authored-by: Tom Nicholas <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tweaks to from_dict * add issue link --------- Co-authored-by: Tom Nicholas <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 45a0027 commit 40666b2

File tree

5 files changed

+61
-28
lines changed

5 files changed

+61
-28
lines changed

Diff for: asv_bench/benchmarks/datatree.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
class Datatree:
66
def setup(self):
77
run1 = DataTree.from_dict({"run1": xr.Dataset({"a": 1})})
8-
self.d = {"run1": run1}
8+
self.d_few = {"run1": run1}
9+
self.d_many = {f"run{i}": run1.copy() for i in range(100)}
910

10-
def time_from_dict(self):
11-
DataTree.from_dict(self.d)
11+
def time_from_dict_few(self):
12+
DataTree.from_dict(self.d_few)
13+
14+
def time_from_dict_many(self):
15+
DataTree.from_dict(self.d_many)

Diff for: xarray/core/datatree.py

+41-23
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
473473
)
474474
path = str(NodePath(parent.path) / name)
475475
node_ds = self.to_dataset(inherited=False)
476-
parent_ds = parent._to_dataset_view(rebuild_dims=False)
476+
parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True)
477477
_check_alignment(path, node_ds, parent_ds, self.children)
478478

479479
@property
@@ -490,30 +490,46 @@ def _dims(self) -> ChainMap[Hashable, int]:
490490
def _indexes(self) -> ChainMap[Hashable, Index]:
491491
return ChainMap(self._node_indexes, *(p._node_indexes for p in self.parents))
492492

493-
def _to_dataset_view(self, rebuild_dims: bool) -> DatasetView:
493+
def _to_dataset_view(self, rebuild_dims: bool, inherited: bool) -> DatasetView:
494+
coord_vars = self._coord_variables if inherited else self._node_coord_variables
494495
variables = dict(self._data_variables)
495-
variables |= self._coord_variables
496+
variables |= coord_vars
496497
if rebuild_dims:
497498
dims = calculate_dimensions(variables)
498-
else:
499-
# Note: rebuild_dims=False can create technically invalid Dataset
500-
# objects because it may not contain all dimensions on its direct
501-
# member variables, e.g., consider:
502-
# tree = DataTree.from_dict(
503-
# {
504-
# "/": xr.Dataset({"a": (("x",), [1, 2])}), # x has size 2
505-
# "/b/c": xr.Dataset({"d": (("x",), [3])}), # x has size1
506-
# }
507-
# )
508-
# However, they are fine for internal use cases, for align() or
509-
# building a repr().
499+
elif inherited:
500+
# Note: rebuild_dims=False with inherited=True can create
501+
# technically invalid Dataset objects because it still includes
502+
# dimensions that are only defined on parent data variables (i.e. not present on any parent coordinate variables), e.g.,
503+
# consider:
504+
# >>> tree = DataTree.from_dict(
505+
# ... {
506+
# ... "/": xr.Dataset({"foo": ("x", [1, 2])}), # x has size 2
507+
# ... "/b": xr.Dataset(),
508+
# ... }
509+
# ... )
510+
# >>> ds = tree["b"]._to_dataset_view(rebuild_dims=False, inherited=True)
511+
# >>> ds
512+
# <xarray.DatasetView> Size: 0B
513+
# Dimensions: (x: 2)
514+
# Dimensions without coordinates: x
515+
# Data variables:
516+
# *empty*
517+
#
518+
# Notice the "x" dimension is still defined, even though there are no
519+
# variables or coordinates.
520+
# Normally this is not supposed to be possible in xarray's data model, but here it is useful internally for use cases where we
521+
# want to inherit everything from parent nodes, e.g., for align()
522+
# and repr().
523+
# The user should never be able to see this dimension via public API.
510524
dims = dict(self._dims)
525+
else:
526+
dims = dict(self._node_dims)
511527
return DatasetView._constructor(
512528
variables=variables,
513529
coord_names=set(self._coord_variables),
514530
dims=dims,
515531
attrs=self._attrs,
516-
indexes=dict(self._indexes),
532+
indexes=dict(self._indexes if inherited else self._node_indexes),
517533
encoding=self._encoding,
518534
close=None,
519535
)
@@ -532,7 +548,7 @@ def ds(self) -> DatasetView:
532548
--------
533549
DataTree.to_dataset
534550
"""
535-
return self._to_dataset_view(rebuild_dims=True)
551+
return self._to_dataset_view(rebuild_dims=True, inherited=True)
536552

537553
@ds.setter
538554
def ds(self, data: Dataset | None = None) -> None:
@@ -739,7 +755,7 @@ def _replace_node(
739755
raise ValueError(f"node already contains a variable named {child_name}")
740756

741757
parent_ds = (
742-
self.parent._to_dataset_view(rebuild_dims=False)
758+
self.parent._to_dataset_view(rebuild_dims=False, inherited=True)
743759
if self.parent is not None
744760
else None
745761
)
@@ -800,8 +816,10 @@ def _copy_node(
800816
deep: bool = False,
801817
) -> DataTree:
802818
"""Copy just one node of a tree"""
803-
data = self.ds.copy(deep=deep)
804-
new_node: DataTree = DataTree(data, name=self.name)
819+
data = self._to_dataset_view(rebuild_dims=False, inherited=False)
820+
if deep:
821+
data = data.copy(deep=True)
822+
new_node = DataTree(data, name=self.name)
805823
return new_node
806824

807825
def __copy__(self: DataTree) -> DataTree:
@@ -1096,7 +1114,6 @@ def from_dict(
10961114
root_data = d_cast.pop("/", None)
10971115
if isinstance(root_data, DataTree):
10981116
obj = root_data.copy()
1099-
obj.orphan()
11001117
elif root_data is None or isinstance(root_data, Dataset):
11011118
obj = cls(name=name, data=root_data, children=None)
11021119
else:
@@ -1116,9 +1133,10 @@ def depth(item) -> int:
11161133
node_name = NodePath(path).name
11171134
if isinstance(data, DataTree):
11181135
new_node = data.copy()
1119-
new_node.orphan()
1120-
else:
1136+
elif isinstance(data, Dataset) or data is None:
11211137
new_node = cls(name=node_name, data=data)
1138+
else:
1139+
raise TypeError(f"invalid values: {data}")
11221140
obj._set_item(
11231141
path,
11241142
new_node,

Diff for: xarray/core/formatting.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1051,7 +1051,10 @@ def diff_datatree_repr(a: DataTree, b: DataTree, compat):
10511051
def _single_node_repr(node: DataTree) -> str:
10521052
"""Information about this node, not including its relationships to other nodes."""
10531053
if node.has_data or node.has_attrs:
1054-
ds_info = "\n" + repr(node._to_dataset_view(rebuild_dims=False))
1054+
# TODO: change this to inherited=False, in order to clarify what is
1055+
# inherited? https://github.com/pydata/xarray/issues/9463
1056+
node_view = node._to_dataset_view(rebuild_dims=False, inherited=True)
1057+
ds_info = "\n" + repr(node_view)
10551058
else:
10561059
ds_info = ""
10571060
return f"Group: {node.path}{ds_info}"

Diff for: xarray/core/formatting_html.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
386386
def datatree_node_repr(group_title: str, dt: DataTree) -> str:
387387
header_components = [f"<div class='xr-obj-type'>{escape(group_title)}</div>"]
388388

389-
ds = dt._to_dataset_view(rebuild_dims=False)
389+
ds = dt._to_dataset_view(rebuild_dims=False, inherited=True)
390390

391391
sections = [
392392
children_section(dt.children),

Diff for: xarray/tests/test_datatree.py

+8
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,14 @@ def test_copy_subtree(self):
368368

369369
assert_identical(actual, expected)
370370

371+
def test_copy_coord_inheritance(self) -> None:
372+
tree = DataTree.from_dict(
373+
{"/": xr.Dataset(coords={"x": [0, 1]}), "/c": DataTree()}
374+
)
375+
tree2 = tree.copy()
376+
node_ds = tree2.children["c"].to_dataset(inherited=False)
377+
assert_identical(node_ds, xr.Dataset())
378+
371379
def test_deepcopy(self, create_test_datatree):
372380
dt = create_test_datatree()
373381

0 commit comments

Comments
 (0)