Skip to content

Commit fd69305

Browse files
lnuicnjroussel
authored andcommitted
test: add coverage for ad repeat/tile c++ operations
1 parent 282da88 commit fd69305

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

tests/test_py_cpp_consistency_ext.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,38 @@ def test01_tile(t):
2727
def test02_repeat(t):
2828
pkg = get_pkg(t)
2929
x = dr.arange(t, 10)
30-
assert dr.all(pkg.repeat(x, 3) == dr.repeat(x, 3))
30+
assert dr.all(pkg.repeat(x, 3) == dr.repeat(x, 3))
31+
32+
@pytest.test_arrays('float32,is_diff,shape=(*)')
33+
def test03_tile_ad(t):
34+
pkg = get_pkg(t)
35+
x = dr.arange(t, 10)
36+
x_tiled_dr = dr.tile(x, 3)
37+
x_tiled_pkg = pkg.tile(x, 3)
38+
39+
dr.enable_grad(x_tiled_dr, x_tiled_pkg)
40+
41+
x2_tiled_dr = x_tiled_dr * x_tiled_dr
42+
x2_tiled_pkg = x_tiled_pkg * x_tiled_pkg
43+
44+
dr.backward(x2_tiled_dr)
45+
dr.backward(x2_tiled_pkg)
46+
47+
assert dr.all(dr.grad(x_tiled_dr) == dr.grad(x_tiled_pkg))
48+
49+
@pytest.test_arrays('float32,is_diff,shape=(*)')
50+
def test04_repeat_ad(t):
51+
pkg = get_pkg(t)
52+
x = dr.arange(t, 10)
53+
x_repeated_dr = dr.repeat(x, 3)
54+
x_repeated_pkg = pkg.repeat(x, 3)
55+
56+
dr.enable_grad(x_repeated_dr, x_repeated_pkg)
57+
58+
x2_repeated_dr = x_repeated_dr * x_repeated_dr
59+
x2_repeated_pkg = x_repeated_pkg * x_repeated_pkg
60+
61+
dr.backward(x2_repeated_dr)
62+
dr.backward(x2_repeated_pkg)
63+
64+
assert dr.all(dr.grad(x_repeated_dr) == dr.grad(x_repeated_pkg))

0 commit comments

Comments
 (0)