Skip to content
This repository was archived by the owner on Oct 24, 2024. It is now read-only.

Test cases for tree broadcasting (for hollow trees) #198

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions datatree/tests/test_tree_broadcasting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
import pytest
import xarray as xr

import datatree.testing as dtt
from datatree.datatree import DataTree

empty = xr.Dataset()

def int_ds(value):
return xr.Dataset({'data':xr.DataArray(value)})

class TestTreeBroadcasting:

def test_single_root(self):
dt1 = DataTree.from_dict(d={'root': int_ds(3)})
dt2 = DataTree.from_dict(d={"root/a": int_ds(2) , "root/b": int_ds(5)})
expected = DataTree.from_dict(d={'root/a': int_ds(3*2), 'root/b': int_ds(3*5)})
dtt.assert_equal(dt1*dt2, expected)

def test_hollow_level_2(self):
dt1 = DataTree.from_dict(d={'root/a': int_ds(5), 'root/b': int_ds(4)})
dt2 = DataTree.from_dict(d={'root/a': int_ds(3), 'root/b/c': int_ds(2), 'root/b/d': int_ds(7)})
expected = DataTree.from_dict(d={'root/a': int_ds(3*5), 'root/b/c':int_ds(2*4), 'root/b/d': int_ds(7*4)})
dtt.assert_equal(dt1*dt2, expected)

def test_dense_level_2(self):
dt1 = DataTree.from_dict(d={'root/a': int_ds(5), 'root/b': int_ds(4)})
dt2 = DataTree.from_dict(d={'root/a': int_ds(3), 'root/b': int_ds(9), 'root/b/c': int_ds(2), 'root/b/d': int_ds(7)})
expected = DataTree.from_dict(d={'root/a': int_ds(3*5), 'root/b':int_ds(9*4)})
with pytest.raises(ValueError, match='not implemented for non-hollow trees')
dtt.assert_equal(dt1*dt2, expected)

def test_hollow_twoway_level_2(self):
dt1 = DataTree.from_dict(d={'root/a/e': int_ds(5), 'root/a/f': int_ds(20), 'root/b': int_ds(4)})
dt2 = DataTree.from_dict(d={'root/a': int_ds(3), 'root/b/c': int_ds(2), 'root/b/d': int_ds(7)})
expected = DataTree.from_dict(d={
'root/a/e': int_ds(3*5),
'root/a/f': int_ds(3*20),
'root/b/c': int_ds(4*2),
'root/b/d': int_ds(4*7)
}
)
dtt.assert_equal(dt1*dt2, expected)