Skip to content

Commit 2965eed

Browse files
Added tensor shape test
1 parent 38e74fd commit 2965eed

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/test_freeze.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,3 +3281,29 @@ def func(c, x, s):
32813281

32823282
assert frozen.n_recordings == 1
32833283

3284+
@pytest.test_arrays("float32, jit, diff, shape=(*)")
3285+
@pytest.mark.parametrize("auto_opaque", [False, True])
3286+
def test87_tensor_indexing(t, auto_opaque):
3287+
"""
3288+
Tests that changes in the first dimension of a tensor do not cause re-tracing.
3289+
"""
3290+
mod = sys.modules[t.__module__]
3291+
3292+
def func(x: mod.TensorXf, row: mod.UInt32, col: mod.UInt32):
3293+
return dr.gather(mod.Float, x.array, row * dr.shape(x)[1] + col)
3294+
3295+
frozen = dr.freeze(func, auto_opaque = auto_opaque)
3296+
3297+
for i in range(3):
3298+
shape = ((i + 4), 10)
3299+
x = mod.TensorXf(dr.arange(mod.Float, dr.prod(shape)), shape = shape)
3300+
row = dr.arange(mod.UInt32, i+3)
3301+
col = dr.arange(mod.UInt32, i+3) + 1
3302+
3303+
res = frozen(x, row, col)
3304+
ref = func(x, row, col)
3305+
3306+
assert dr.allclose(res, ref)
3307+
3308+
assert frozen.n_recordings == 1
3309+

0 commit comments

Comments
 (0)