@@ -27,4 +27,38 @@ def test01_tile(t):
2727def test02_repeat (t ):
2828 pkg = get_pkg (t )
2929 x = dr .arange (t , 10 )
30- assert dr .all (pkg .repeat (x , 3 ) == dr .repeat (x , 3 ))
30+ assert dr .all (pkg .repeat (x , 3 ) == dr .repeat (x , 3 ))
31+
32+ @pytest .test_arrays ('float32,is_diff,shape=(*)' )
33+ def test03_tile_ad (t ):
34+ pkg = get_pkg (t )
35+ x = dr .arange (t , 10 )
36+ x_tiled_dr = dr .tile (x , 3 )
37+ x_tiled_pkg = pkg .tile (x , 3 )
38+
39+ dr .enable_grad (x_tiled_dr , x_tiled_pkg )
40+
41+ x2_tiled_dr = x_tiled_dr * x_tiled_dr
42+ x2_tiled_pkg = x_tiled_pkg * x_tiled_pkg
43+
44+ dr .backward (x2_tiled_dr )
45+ dr .backward (x2_tiled_pkg )
46+
47+ assert dr .all (dr .grad (x_tiled_dr ) == dr .grad (x_tiled_pkg ))
48+
49+ @pytest .test_arrays ('float32,is_diff,shape=(*)' )
50+ def test04_repeat_ad (t ):
51+ pkg = get_pkg (t )
52+ x = dr .arange (t , 10 )
53+ x_repeated_dr = dr .repeat (x , 3 )
54+ x_repeated_pkg = pkg .repeat (x , 3 )
55+
56+ dr .enable_grad (x_repeated_dr , x_repeated_pkg )
57+
58+ x2_repeated_dr = x_repeated_dr * x_repeated_dr
59+ x2_repeated_pkg = x_repeated_pkg * x_repeated_pkg
60+
61+ dr .backward (x2_repeated_dr )
62+ dr .backward (x2_repeated_pkg )
63+
64+ assert dr .all (dr .grad (x_repeated_dr ) == dr .grad (x_repeated_pkg ))
0 commit comments