@@ -473,7 +473,7 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None:
473
473
)
474
474
path = str (NodePath (parent .path ) / name )
475
475
node_ds = self .to_dataset (inherited = False )
476
- parent_ds = parent ._to_dataset_view (rebuild_dims = False )
476
+ parent_ds = parent ._to_dataset_view (rebuild_dims = False , inherited = True )
477
477
_check_alignment (path , node_ds , parent_ds , self .children )
478
478
479
479
@property
@@ -490,30 +490,46 @@ def _dims(self) -> ChainMap[Hashable, int]:
490
490
def _indexes (self ) -> ChainMap [Hashable , Index ]:
491
491
return ChainMap (self ._node_indexes , * (p ._node_indexes for p in self .parents ))
492
492
493
- def _to_dataset_view (self , rebuild_dims : bool ) -> DatasetView :
493
+ def _to_dataset_view (self , rebuild_dims : bool , inherited : bool ) -> DatasetView :
494
+ coord_vars = self ._coord_variables if inherited else self ._node_coord_variables
494
495
variables = dict (self ._data_variables )
495
- variables |= self . _coord_variables
496
+ variables |= coord_vars
496
497
if rebuild_dims :
497
498
dims = calculate_dimensions (variables )
498
- else :
499
- # Note: rebuild_dims=False can create technically invalid Dataset
500
- # objects because it may not contain all dimensions on its direct
501
- # member variables, e.g., consider:
502
- # tree = DataTree.from_dict(
503
- # {
504
- # "/": xr.Dataset({"a": (("x",), [1, 2])}), # x has size 2
505
- # "/b/c": xr.Dataset({"d": (("x",), [3])}), # x has size1
506
- # }
507
- # )
508
- # However, they are fine for internal use cases, for align() or
509
- # building a repr().
499
+ elif inherited :
500
+ # Note: rebuild_dims=False with inherited=True can create
501
+ # technically invalid Dataset objects because it still includes
502
+ # dimensions that are only defined on parent data variables (i.e. not present on any parent coordinate variables), e.g.,
503
+ # consider:
504
+ # >>> tree = DataTree.from_dict(
505
+ # ... {
506
+ # ... "/": xr.Dataset({"foo": ("x", [1, 2])}), # x has size 2
507
+ # ... "/b": xr.Dataset(),
508
+ # ... }
509
+ # ... )
510
+ # >>> ds = tree["b"]._to_dataset_view(rebuild_dims=False, inherited=True)
511
+ # >>> ds
512
+ # <xarray.DatasetView> Size: 0B
513
+ # Dimensions: (x: 2)
514
+ # Dimensions without coordinates: x
515
+ # Data variables:
516
+ # *empty*
517
+ #
518
+ # Notice the "x" dimension is still defined, even though there are no
519
+ # variables or coordinates.
520
+ # Normally this is not supposed to be possible in xarray's data model, but here it is useful internally for use cases where we
521
+ # want to inherit everything from parent nodes, e.g., for align()
522
+ # and repr().
523
+ # The user should never be able to see this dimension via public API.
510
524
dims = dict (self ._dims )
525
+ else :
526
+ dims = dict (self ._node_dims )
511
527
return DatasetView ._constructor (
512
528
variables = variables ,
513
529
coord_names = set (self ._coord_variables ),
514
530
dims = dims ,
515
531
attrs = self ._attrs ,
516
- indexes = dict (self ._indexes ),
532
+ indexes = dict (self ._indexes if inherited else self . _node_indexes ),
517
533
encoding = self ._encoding ,
518
534
close = None ,
519
535
)
@@ -532,7 +548,7 @@ def ds(self) -> DatasetView:
532
548
--------
533
549
DataTree.to_dataset
534
550
"""
535
- return self ._to_dataset_view (rebuild_dims = True )
551
+ return self ._to_dataset_view (rebuild_dims = True , inherited = True )
536
552
537
553
@ds .setter
538
554
def ds (self , data : Dataset | None = None ) -> None :
@@ -739,7 +755,7 @@ def _replace_node(
739
755
raise ValueError (f"node already contains a variable named { child_name } " )
740
756
741
757
parent_ds = (
742
- self .parent ._to_dataset_view (rebuild_dims = False )
758
+ self .parent ._to_dataset_view (rebuild_dims = False , inherited = True )
743
759
if self .parent is not None
744
760
else None
745
761
)
@@ -800,8 +816,10 @@ def _copy_node(
800
816
deep : bool = False ,
801
817
) -> DataTree :
802
818
"""Copy just one node of a tree"""
803
- data = self .ds .copy (deep = deep )
804
- new_node : DataTree = DataTree (data , name = self .name )
819
+ data = self ._to_dataset_view (rebuild_dims = False , inherited = False )
820
+ if deep :
821
+ data = data .copy (deep = True )
822
+ new_node = DataTree (data , name = self .name )
805
823
return new_node
806
824
807
825
def __copy__ (self : DataTree ) -> DataTree :
@@ -1096,7 +1114,6 @@ def from_dict(
1096
1114
root_data = d_cast .pop ("/" , None )
1097
1115
if isinstance (root_data , DataTree ):
1098
1116
obj = root_data .copy ()
1099
- obj .orphan ()
1100
1117
elif root_data is None or isinstance (root_data , Dataset ):
1101
1118
obj = cls (name = name , data = root_data , children = None )
1102
1119
else :
@@ -1116,9 +1133,10 @@ def depth(item) -> int:
1116
1133
node_name = NodePath (path ).name
1117
1134
if isinstance (data , DataTree ):
1118
1135
new_node = data .copy ()
1119
- new_node .orphan ()
1120
- else :
1136
+ elif isinstance (data , Dataset ) or data is None :
1121
1137
new_node = cls (name = node_name , data = data )
1138
+ else :
1139
+ raise TypeError (f"invalid values: { data } " )
1122
1140
obj ._set_item (
1123
1141
path ,
1124
1142
new_node ,
0 commit comments