Skip to content

Migrate datatreee assertions/extensions/formatting #8967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
@@ -42,12 +42,17 @@ Bug fixes

Internal Changes
~~~~~~~~~~~~~~~~
- Migrates ``formatting_html`` functionality for `DataTree` into ``xarray/core`` (:pull: `8930`)
- Migrates ``formatting_html`` functionality for ``DataTree`` into ``xarray/core`` (:pull: `8930`)
By `Eni Awowale <https://github.com/eni-awowale>`_, `Julia Signell <https://github.com/jsignell>`_
and `Tom Nicholas <https://github.com/TomNicholas>`_.
- Migrates ``datatree_mapping`` functionality into ``xarray/core`` (:pull:`8948`)
By `Matt Savoie <https://github.com/flamingbear>`_ `Owen Littlejohns
<https://github.com/owenlittlejohns>` and `Tom Nicholas <https://github.com/TomNicholas>`_.
<https://github.com/owenlittlejohns>`_ and `Tom Nicholas <https://github.com/TomNicholas>`_.
- Migrates ``extensions``, ``formatting`` and ``datatree_render`` functionality for
``DataTree`` into ``xarray/core``. Also migrates ``testing`` functionality into
``xarray/testing/assertions`` for ``DataTree``. (:pull:`8967`)
By `Owen Littlejohns <https://github.com/owenlittlejohns>`_ and
`Tom Nicholas <https://github.com/TomNicholas>`_.


.. _whats-new.2024.03.0:
6 changes: 3 additions & 3 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,8 @@
check_isomorphic,
map_over_subtree,
)
from xarray.core.datatree_render import RenderDataTree
from xarray.core.formatting import datatree_repr
from xarray.core.formatting_html import (
datatree_repr as datatree_repr_html,
)
@@ -40,13 +42,11 @@
)
from xarray.core.variable import Variable
from xarray.datatree_.datatree.common import TreeAttrAccessMixin
from xarray.datatree_.datatree.formatting import datatree_repr
from xarray.datatree_.datatree.ops import (
DataTreeArithmeticMixin,
MappedDatasetMethodsMixin,
MappedDataWithCoords,
)
from xarray.datatree_.datatree.render import RenderTree

try:
from xarray.core.variable import calculate_dimensions
@@ -1451,7 +1451,7 @@ def pipe(

def render(self):
"""Print tree structure, including any data stored at each node."""
for pre, fill, node in RenderTree(self):
for pre, fill, node in RenderDataTree(self):
print(f"{pre}DataTree('{self.name}')")
for ds_line in repr(node.ds)[1:]:
print(f"{fill}{ds_line}")
37 changes: 3 additions & 34 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
@@ -3,11 +3,11 @@
import functools
import sys
from itertools import repeat
from textwrap import dedent
from typing import TYPE_CHECKING, Callable

from xarray import DataArray, Dataset
from xarray.core.iterators import LevelOrderIter
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.formatting import diff_treestructure
from xarray.core.treenode import NodePath, TreeNode

if TYPE_CHECKING:
@@ -71,37 +71,6 @@ def check_isomorphic(
raise TreeIsomorphismError("DataTree objects are not isomorphic:\n" + diff)


def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this into xarray/core/formatting.py to avoid a circular dependency issue.

"""
Return a summary of why two trees are not isomorphic.
If they are isomorphic return an empty string.
"""

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
# Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree
# (which it is so long as children are stored in a tuple or list rather than in a set).
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
f"""\
Number of children on node '{path_a}' of the left object: {len(node_a.children)}
Number of children on node '{path_b}' of the right object: {len(node_b.children)}"""
)
return diff

return ""


def map_over_subtree(func: Callable) -> Callable:
"""
Decorator which turns a function which acts on (and returns) Datasets into one which acts on and returns DataTrees.
266 changes: 266 additions & 0 deletions xarray/core/datatree_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
"""
String Tree Rendering. Copied from anytree.

Minor changes to `RenderDataTree` include accessing `children.values()`, and
type hints.

"""

from __future__ import annotations

from collections import namedtuple
from collections.abc import Iterable, Iterator
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from xarray.core.datatree import DataTree

Row = namedtuple("Row", ("pre", "fill", "node"))


class AbstractStyle:
def __init__(self, vertical: str, cont: str, end: str):
"""
Tree Render Style.
Args:
vertical: Sign for vertical line.
cont: Chars for a continued branch.
end: Chars for the last branch.
"""
super().__init__()
self.vertical = vertical
self.cont = cont
self.end = end
assert (
len(cont) == len(vertical) == len(end)
), f"'{vertical}', '{cont}' and '{end}' need to have equal length"

@property
def empty(self) -> str:
"""Empty string as placeholder."""
return " " * len(self.end)

def __repr__(self) -> str:
return f"{self.__class__.__name__}()"


class ContStyle(AbstractStyle):
def __init__(self):
"""
Continued style, without gaps.

>>> from xarray.core.datatree import DataTree
>>> from xarray.core.datatree_render import RenderDataTree
>>> root = DataTree(name="root")
>>> s0 = DataTree(name="sub0", parent=root)
>>> s0b = DataTree(name="sub0B", parent=s0)
>>> s0a = DataTree(name="sub0A", parent=s0)
>>> s1 = DataTree(name="sub1", parent=root)
>>> print(RenderDataTree(root))
DataTree('root', parent=None)
├── DataTree('sub0')
│ ├── DataTree('sub0B')
│ └── DataTree('sub0A')
└── DataTree('sub1')
"""
super().__init__("\u2502 ", "\u251c\u2500\u2500 ", "\u2514\u2500\u2500 ")


class RenderDataTree:
def __init__(
self,
node: DataTree,
style=ContStyle(),
childiter: type = list,
maxlevel: int | None = None,
):
"""
Render tree starting at `node`.
Keyword Args:
style (AbstractStyle): Render Style.
childiter: Child iterator. Note, due to the use of node.children.values(),
Iterables that change the order of children cannot be used
(e.g., `reversed`).
maxlevel: Limit rendering to this depth.
:any:`RenderDataTree` is an iterator, returning a tuple with 3 items:
`pre`
tree prefix.
`fill`
filling for multiline entries.
`node`
:any:`NodeMixin` object.
It is up to the user to assemble these parts to a whole.

Examples
--------

>>> from xarray import Dataset
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The examples in this documentation string are a bit shorter than the originals from anytree. That's because using node.children.values() gets a ValuesView which isn't compatible with iterables like reversed that alter the order of the items in the iterable.

>>> from xarray.core.datatree import DataTree
>>> from xarray.core.datatree_render import RenderDataTree
>>> root = DataTree(name="root", data=Dataset({"a": 0, "b": 1}))
>>> s0 = DataTree(name="sub0", parent=root, data=Dataset({"c": 2, "d": 3}))
>>> s0b = DataTree(name="sub0B", parent=s0, data=Dataset({"e": 4}))
>>> s0a = DataTree(name="sub0A", parent=s0, data=Dataset({"f": 5, "g": 6}))
>>> s1 = DataTree(name="sub1", parent=root, data=Dataset({"h": 7}))

# Simple one line:

>>> for pre, _, node in RenderDataTree(root):
... print(f"{pre}{node.name}")
...
root
├── sub0
│ ├── sub0B
│ └── sub0A
└── sub1

# Multiline:

>>> for pre, fill, node in RenderDataTree(root):
... print(f"{pre}{node.name}")
... for variable in node.variables:
... print(f"{fill}{variable}")
...
root
a
b
├── sub0
│ c
│ d
│ ├── sub0B
│ │ e
│ └── sub0A
│ f
│ g
└── sub1
h

:any:`by_attr` simplifies attribute rendering and supports multiline:
>>> print(RenderDataTree(root).by_attr())
root
├── sub0
│ ├── sub0B
│ └── sub0A
└── sub1

# `maxlevel` limits the depth of the tree:

>>> print(RenderDataTree(root, maxlevel=2).by_attr("name"))
root
├── sub0
└── sub1
"""
if not isinstance(style, AbstractStyle):
style = style()
self.node = node
self.style = style
self.childiter = childiter
self.maxlevel = maxlevel

def __iter__(self) -> Iterator[Row]:
return self.__next(self.node, tuple())

def __next(
self, node: DataTree, continues: tuple[bool, ...], level: int = 0
) -> Iterator[Row]:
yield RenderDataTree.__item(node, continues, self.style)
children = node.children.values()
level += 1
if children and (self.maxlevel is None or level < self.maxlevel):
children = self.childiter(children)
for child, is_last in _is_last(children):
yield from self.__next(child, continues + (not is_last,), level=level)

@staticmethod
def __item(
node: DataTree, continues: tuple[bool, ...], style: AbstractStyle
) -> Row:
if not continues:
return Row("", "", node)
else:
items = [style.vertical if cont else style.empty for cont in continues]
indent = "".join(items[:-1])
branch = style.cont if continues[-1] else style.end
pre = indent + branch
fill = "".join(items)
return Row(pre, fill, node)

def __str__(self) -> str:
return str(self.node)

def __repr__(self) -> str:
classname = self.__class__.__name__
args = [
repr(self.node),
f"style={repr(self.style)}",
f"childiter={repr(self.childiter)}",
]
return f"{classname}({', '.join(args)})"

def by_attr(self, attrname: str = "name") -> str:
"""
Return rendered tree with node attribute `attrname`.

Examples
--------

>>> from xarray import Dataset
>>> from xarray.core.datatree import DataTree
>>> from xarray.core.datatree_render import RenderDataTree
>>> root = DataTree(name="root")
>>> s0 = DataTree(name="sub0", parent=root)
>>> s0b = DataTree(
... name="sub0B", parent=s0, data=Dataset({"foo": 4, "bar": 109})
... )
>>> s0a = DataTree(name="sub0A", parent=s0)
>>> s1 = DataTree(name="sub1", parent=root)
>>> s1a = DataTree(name="sub1A", parent=s1)
>>> s1b = DataTree(name="sub1B", parent=s1, data=Dataset({"bar": 8}))
>>> s1c = DataTree(name="sub1C", parent=s1)
>>> s1ca = DataTree(name="sub1Ca", parent=s1c)
>>> print(RenderDataTree(root).by_attr("name"))
root
├── sub0
│ ├── sub0B
│ └── sub0A
└── sub1
├── sub1A
├── sub1B
└── sub1C
└── sub1Ca
"""

def get() -> Iterator[str]:
for pre, fill, node in self:
attr = (
attrname(node)
if callable(attrname)
else getattr(node, attrname, "")
)
if isinstance(attr, (list, tuple)):
lines = attr
else:
lines = str(attr).split("\n")
yield f"{pre}{lines[0]}"
for line in lines[1:]:
yield f"{fill}{line}"

return "\n".join(get())


def _is_last(iterable: Iterable) -> Iterator[tuple[DataTree, bool]]:
iter_ = iter(iterable)
try:
nextitem = next(iter_)
except StopIteration:
pass
else:
item = nextitem
while True:
try:
nextitem = next(iter_)
yield item, False
except StopIteration:
yield nextitem, True
break
item = nextitem
18 changes: 18 additions & 0 deletions xarray/core/extensions.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree


class AccessorRegistrationWarning(Warning):
@@ -121,3 +122,20 @@ def register_dataset_accessor(name):
register_dataarray_accessor
"""
return _register_accessor(name, Dataset)


def register_datatree_accessor(name):
"""Register a custom accessor on DataTree objects.
Parameters
----------
name : str
Name under which the accessor should be registered. A warning is issued
if this name conflicts with a preexisting attribute.
See Also
--------
xarray.register_dataarray_accessor
xarray.register_dataset_accessor
"""
return _register_accessor(name, DataTree)
115 changes: 115 additions & 0 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
@@ -11,20 +11,24 @@
from datetime import datetime, timedelta
from itertools import chain, zip_longest
from reprlib import recursive_repr
from textwrap import dedent
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
from pandas.errors import OutOfBoundsDatetime

from xarray.core.datatree_render import RenderDataTree
from xarray.core.duck_array_ops import array_equiv, astype
from xarray.core.indexing import MemoryCachedArray
from xarray.core.iterators import LevelOrderIter
from xarray.core.options import OPTIONS, _get_boolean_with_default
from xarray.core.utils import is_duck_array
from xarray.namedarray.pycompat import array_type, to_duck_array, to_numpy

if TYPE_CHECKING:
from xarray.core.coordinates import AbstractCoordinates
from xarray.core.datatree import DataTree

UNITS = ("B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")

@@ -926,6 +930,37 @@ def diff_array_repr(a, b, compat):
return "\n".join(summary)


def diff_treestructure(a: DataTree, b: DataTree, require_names_equal: bool) -> str:
"""
Return a summary of why two trees are not isomorphic.
If they are isomorphic return an empty string.
"""

# Walking nodes in "level-order" fashion means walking down from the root breadth-first.
# Checking for isomorphism by walking in this way implicitly assumes that the tree is an ordered tree
# (which it is so long as children are stored in a tuple or list rather than in a set).
for node_a, node_b in zip(LevelOrderIter(a), LevelOrderIter(b)):
path_a, path_b = node_a.path, node_b.path

if require_names_equal and node_a.name != node_b.name:
diff = dedent(
f"""\
Node '{path_a}' in the left object has name '{node_a.name}'
Node '{path_b}' in the right object has name '{node_b.name}'"""
)
return diff

if len(node_a.children) != len(node_b.children):
diff = dedent(
f"""\
Number of children on node '{path_a}' of the left object: {len(node_a.children)}
Number of children on node '{path_b}' of the right object: {len(node_b.children)}"""
)
return diff

return ""


def diff_dataset_repr(a, b, compat):
summary = [
f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}"
@@ -945,6 +980,86 @@ def diff_dataset_repr(a, b, compat):
return "\n".join(summary)


def diff_nodewise_summary(a: DataTree, b: DataTree, compat):
"""Iterates over all corresponding nodes, recording differences between data at each location."""

compat_str = _compat_to_str(compat)

summary = []
for node_a, node_b in zip(a.subtree, b.subtree):
a_ds, b_ds = node_a.ds, node_b.ds

if not a_ds._all_compat(b_ds, compat):
dataset_diff = diff_dataset_repr(a_ds, b_ds, compat_str)
data_diff = "\n".join(dataset_diff.split("\n", 1)[1:])

nodediff = (
f"\nData in nodes at position '{node_a.path}' do not match:"
f"{data_diff}"
)
summary.append(nodediff)

return "\n".join(summary)


def diff_datatree_repr(a: DataTree, b: DataTree, compat):
summary = [
f"Left and right {type(a).__name__} objects are not {_compat_to_str(compat)}"
]

strict_names = True if compat in ["equals", "identical"] else False
treestructure_diff = diff_treestructure(a, b, strict_names)

# If the trees structures are different there is no point comparing each node
# TODO we could show any differences in nodes up to the first place that structure differs?
if treestructure_diff or compat == "isomorphic":
summary.append("\n" + treestructure_diff)
else:
nodewise_diff = diff_nodewise_summary(a, b, compat)
summary.append("\n" + nodewise_diff)

return "\n".join(summary)


def _single_node_repr(node: DataTree) -> str:
"""Information about this node, not including its relationships to other nodes."""
node_info = f"DataTree('{node.name}')"

if node.has_data or node.has_attrs:
ds_info = "\n" + repr(node.ds)
else:
ds_info = ""
return node_info + ds_info


def datatree_repr(dt: DataTree):
"""A printable representation of the structure of this entire tree."""
renderer = RenderDataTree(dt)

lines = []
for pre, fill, node in renderer:
node_repr = _single_node_repr(node)

node_line = f"{pre}{node_repr.splitlines()[0]}"
lines.append(node_line)

if node.has_data or node.has_attrs:
ds_repr = node_repr.splitlines()[2:]
for line in ds_repr:
if len(node.children) > 0:
lines.append(f"{fill}{renderer.style.vertical}{line}")
else:
lines.append(f"{fill}{' ' * len(renderer.style.vertical)}{line}")

# Tack on info about whether or not root node has a parent at the start
first_line = lines[0]
parent = f'"{dt.parent.name}"' if dt.parent is not None else "None"
first_line_with_parent = first_line[:-1] + f", parent={parent})"
lines[0] = first_line_with_parent

return "\n".join(lines)


def shorten_list_repr(items: Sequence, max_items: int) -> str:
if len(items) <= max_items:
return repr(items)
20 changes: 0 additions & 20 deletions xarray/datatree_/datatree/extensions.py

This file was deleted.

91 changes: 0 additions & 91 deletions xarray/datatree_/datatree/formatting.py

This file was deleted.

2 changes: 1 addition & 1 deletion xarray/datatree_/datatree/ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import textwrap

from xarray import Dataset
from xarray.core.dataset import Dataset
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was causing another circular dependency issue. @flamingbear - just FYI, for when you are tweaking ops.py.


from xarray.core.datatree_mapping import map_over_subtree

271 changes: 0 additions & 271 deletions xarray/datatree_/datatree/render.py

This file was deleted.

120 changes: 0 additions & 120 deletions xarray/datatree_/datatree/testing.py

This file was deleted.

29 changes: 0 additions & 29 deletions xarray/datatree_/datatree/tests/__init__.py

This file was deleted.

65 changes: 0 additions & 65 deletions xarray/datatree_/datatree/tests/conftest.py

This file was deleted.

98 changes: 0 additions & 98 deletions xarray/datatree_/datatree/tests/test_dataset_api.py

This file was deleted.

40 changes: 0 additions & 40 deletions xarray/datatree_/datatree/tests/test_extensions.py

This file was deleted.

120 changes: 0 additions & 120 deletions xarray/datatree_/datatree/tests/test_formatting.py

This file was deleted.

1 change: 1 addition & 0 deletions xarray/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# TODO: Add assert_isomorphic when making DataTree API public
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assumed we didn't want to surface assert_isomorphic until everything else was public. Does that sound good to you @TomNicholas?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought there was an issue collecting things we need to do to put a final bow on things, but I'm not finding it. Should we add it to #8572? or is that overkill?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't find a dedicated issue for that either. Yes lets' just make an explicit list under Expose datatree API publicly. on #8572 (I'll do that now)

from xarray.testing.assertions import ( # noqa: F401
_assert_dataarray_invariants,
_assert_dataset_invariants,
106 changes: 99 additions & 7 deletions xarray/testing/assertions.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
import functools
import warnings
from collections.abc import Hashable
from typing import Union
from typing import Union, overload

import numpy as np
import pandas as pd
@@ -12,6 +12,8 @@
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.datatree import DataTree
from xarray.core.formatting import diff_datatree_repr
from xarray.core.indexes import Index, PandasIndex, PandasMultiIndex, default_indexes
from xarray.core.variable import IndexVariable, Variable

@@ -50,7 +52,59 @@ def _data_allclose_or_equiv(arr1, arr2, rtol=1e-05, atol=1e-08, decode_bytes=Tru


@ensure_warnings
def assert_equal(a, b):
def assert_isomorphic(a: DataTree, b: DataTree, from_root: bool = False):
"""
Two DataTrees are considered isomorphic if every node has the same number of children.
Nothing about the data or attrs in each node is checked.
Isomorphism is a necessary condition for two trees to be used in a nodewise binary operation,
such as tree1 + tree2.
By default this function does not check any part of the tree above the given node.
Therefore this function can be used as default to check that two subtrees are isomorphic.
Parameters
----------
a : DataTree
The first object to compare.
b : DataTree
The second object to compare.
from_root : bool, optional, default is False
Whether or not to first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
See Also
--------
DataTree.isomorphic
assert_equal
assert_identical
"""
__tracebackhide__ = True
assert isinstance(a, type(b))

if isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.isomorphic(b, from_root=from_root), diff_datatree_repr(
a, b, "isomorphic"
)
else:
raise TypeError(f"{type(a)} not of type DataTree")


@overload
def assert_equal(a, b): ...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially tried to specify all the individual overloads (e.g., assert_equal(a: Dataset, b: Dataset), etc). That led to issues with the return values from DataTree.__getitem__, which are: DataTree | DataArray.

Hopefully this hits enough of the spot, though.



@overload
def assert_equal(a: DataTree, b: DataTree, from_root: bool = True): ...


@ensure_warnings
def assert_equal(a, b, from_root=True):
"""Like :py:func:`numpy.testing.assert_array_equal`, but for xarray
objects.
@@ -59,12 +113,20 @@ def assert_equal(a, b):
(except for Dataset objects for which the variable names must match).
Arrays with NaN in the same location are considered equal.
For DataTree objects, assert_equal is mapped over all Datasets on each node,
with the DataTrees being equal if both are isomorphic and the corresponding
Datasets at each node are themselves equal.
Parameters
----------
a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
The second object to compare.
a : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates
or xarray.core.datatree.DataTree. The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable, xarray.Coordinates
or xarray.core.datatree.DataTree. The second object to compare.
from_root : bool, optional, default is True
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
See Also
--------
@@ -81,23 +143,45 @@ def assert_equal(a, b):
assert a.equals(b), formatting.diff_dataset_repr(a, b, "equals")
elif isinstance(a, Coordinates):
assert a.equals(b), formatting.diff_coords_repr(a, b, "equals")
elif isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.equals(b, from_root=from_root), diff_datatree_repr(a, b, "equals")
else:
raise TypeError(f"{type(a)} not supported by assertion comparison")


@overload
def assert_identical(a, b): ...


@overload
def assert_identical(a: DataTree, b: DataTree, from_root: bool = True): ...


@ensure_warnings
def assert_identical(a, b):
def assert_identical(a, b, from_root=True):
"""Like :py:func:`xarray.testing.assert_equal`, but also matches the
objects' names and attributes.
Raises an AssertionError if two objects are not identical.
For DataTree objects, assert_identical is mapped over all Datasets on each
node, with the DataTrees being identical if both are isomorphic and the
corresponding Datasets at each node are themselves identical.
Parameters
----------
a : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
The first object to compare.
b : xarray.Dataset, xarray.DataArray, xarray.Variable or xarray.Coordinates
The second object to compare.
from_root : bool, optional, default is True
Only used when comparing DataTree objects. Indicates whether or not to
first traverse to the root of the trees before checking for isomorphism.
If a & b have no parents then this has no effect.
See Also
--------
@@ -116,6 +200,14 @@ def assert_identical(a, b):
assert a.identical(b), formatting.diff_dataset_repr(a, b, "identical")
elif isinstance(a, Coordinates):
assert a.identical(b), formatting.diff_coords_repr(a, b, "identical")
elif isinstance(a, DataTree):
if from_root:
a = a.root
b = b.root

assert a.identical(b, from_root=from_root), diff_datatree_repr(
a, b, "identical"
)
else:
raise TypeError(f"{type(a)} not supported by assertion comparison")

2 changes: 1 addition & 1 deletion xarray/tests/test_backends_datatree.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import pytest

from xarray.backends.api import open_datatree
from xarray.datatree_.datatree.testing import assert_equal
from xarray.testing import assert_equal
from xarray.tests import (
requires_h5netcdf,
requires_netCDF4,
178 changes: 136 additions & 42 deletions xarray/tests/test_datatree.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
check_isomorphic,
map_over_subtree,
)
from xarray.datatree_.datatree.testing import assert_equal
from xarray.testing import assert_equal

empty = xr.Dataset()

9 changes: 9 additions & 0 deletions xarray/tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -5,9 +5,14 @@
import pytest

import xarray as xr

# TODO: Remove imports in favour of xr.DataTree etc, once part of public API
Copy link
Contributor Author

@owenlittlejohns owenlittlejohns Apr 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also hoping that once we can use xr.DataTree that all the type annotations that mypy has insisted on in the tests might be able to be removed (e.g., dt: DataTree = DataTree())

from xarray.core.datatree import DataTree
from xarray.core.extensions import register_datatree_accessor
from xarray.tests import assert_identical


@register_datatree_accessor("example_accessor")
@xr.register_dataset_accessor("example_accessor")
@xr.register_dataarray_accessor("example_accessor")
class ExampleAccessor:
@@ -19,6 +24,7 @@ def __init__(self, xarray_obj):

class TestAccessor:
def test_register(self) -> None:
@register_datatree_accessor("demo")
@xr.register_dataset_accessor("demo")
@xr.register_dataarray_accessor("demo")
class DemoAccessor:
@@ -31,6 +37,9 @@ def __init__(self, xarray_obj):
def foo(self):
return "bar"

dt: DataTree = DataTree()
assert dt.demo.foo == "bar"

ds = xr.Dataset()
assert ds.demo.foo == "bar"

103 changes: 103 additions & 0 deletions xarray/tests/test_formatting.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@

import xarray as xr
from xarray.core import formatting
from xarray.core.datatree import DataTree # TODO: Remove when can do xr.DataTree
from xarray.tests import requires_cftime, requires_dask, requires_netCDF4

ON_WINDOWS = sys.platform == "win32"
@@ -555,6 +556,108 @@ def test_array_scalar_format(self) -> None:
format(var, ".2f")
assert "Using format_spec is only supported" in str(excinfo.value)

def test_datatree_print_empty_node(self):
dt: DataTree = DataTree(name="root")
printout = dt.__str__()
assert printout == "DataTree('root', parent=None)"

def test_datatree_print_empty_node_with_attrs(self):
dat = xr.Dataset(attrs={"note": "has attrs"})
dt: DataTree = DataTree(name="root", data=dat)
printout = dt.__str__()
assert printout == dedent(
"""\
DataTree('root', parent=None)
Dimensions: ()
Data variables:
*empty*
Attributes:
note: has attrs"""
)

def test_datatree_print_node_with_data(self):
dat = xr.Dataset({"a": [0, 2]})
dt: DataTree = DataTree(name="root", data=dat)
printout = dt.__str__()
expected = [
"DataTree('root', parent=None)",
"Dimensions",
"Coordinates",
"a",
"Data variables",
"*empty*",
]
for expected_line, printed_line in zip(expected, printout.splitlines()):
assert expected_line in printed_line

def test_datatree_printout_nested_node(self):
dat = xr.Dataset({"a": [0, 2]})
root: DataTree = DataTree(name="root")
DataTree(name="results", data=dat, parent=root)
printout = root.__str__()
assert printout.splitlines()[2].startswith(" ")

def test_datatree_repr_of_node_with_data(self):
dat = xr.Dataset({"a": [0, 2]})
dt: DataTree = DataTree(name="root", data=dat)
assert "Coordinates" in repr(dt)

def test_diff_datatree_repr_structure(self):
dt_1: DataTree = DataTree.from_dict({"a": None, "a/b": None, "a/c": None})
dt_2: DataTree = DataTree.from_dict({"d": None, "d/e": None})

expected = dedent(
"""\
Left and right DataTree objects are not isomorphic
Number of children on node '/a' of the left object: 2
Number of children on node '/d' of the right object: 1"""
)
actual = formatting.diff_datatree_repr(dt_1, dt_2, "isomorphic")
assert actual == expected

def test_diff_datatree_repr_node_names(self):
dt_1: DataTree = DataTree.from_dict({"a": None})
dt_2: DataTree = DataTree.from_dict({"b": None})

expected = dedent(
"""\
Left and right DataTree objects are not identical
Node '/a' in the left object has name 'a'
Node '/b' in the right object has name 'b'"""
)
actual = formatting.diff_datatree_repr(dt_1, dt_2, "identical")
assert actual == expected

def test_diff_datatree_repr_node_data(self):
# casting to int64 explicitly ensures that int64s are created on all architectures
ds1 = xr.Dataset({"u": np.int64(0), "v": np.int64(1)})
ds3 = xr.Dataset({"w": np.int64(5)})
dt_1: DataTree = DataTree.from_dict({"a": ds1, "a/b": ds3})
ds2 = xr.Dataset({"u": np.int64(0)})
ds4 = xr.Dataset({"w": np.int64(6)})
dt_2: DataTree = DataTree.from_dict({"a": ds2, "a/b": ds4})

expected = dedent(
"""\
Left and right DataTree objects are not equal
Data in nodes at position '/a' do not match:
Data variables only on the left object:
v int64 8B 1
Data in nodes at position '/a/b' do not match:
Differing data variables:
L w int64 8B 5
R w int64 8B 6"""
)
actual = formatting.diff_datatree_repr(dt_1, dt_2, "equals")
assert actual == expected


def test_inline_variable_array_repr_custom_repr() -> None:
class CustomArray: