@@ -3281,3 +3281,29 @@ def func(c, x, s):
3281
3281
3282
3282
assert frozen .n_recordings == 1
3283
3283
3284
+ @pytest .test_arrays ("float32, jit, diff, shape=(*)" )
3285
+ @pytest .mark .parametrize ("auto_opaque" , [False , True ])
3286
+ def test87_tensor_indexing (t , auto_opaque ):
3287
+ """
3288
+ Tests that changes in the first dimension of a tensor do not cause re-tracing.
3289
+ """
3290
+ mod = sys .modules [t .__module__ ]
3291
+
3292
+ def func (x : mod .TensorXf , row : mod .UInt32 , col : mod .UInt32 ):
3293
+ return dr .gather (mod .Float , x .array , row * dr .shape (x )[1 ] + col )
3294
+
3295
+ frozen = dr .freeze (func , auto_opaque = auto_opaque )
3296
+
3297
+ for i in range (3 ):
3298
+ shape = ((i + 4 ), 10 )
3299
+ x = mod .TensorXf (dr .arange (mod .Float , dr .prod (shape )), shape = shape )
3300
+ row = dr .arange (mod .UInt32 , i + 3 )
3301
+ col = dr .arange (mod .UInt32 , i + 3 ) + 1
3302
+
3303
+ res = frozen (x , row , col )
3304
+ ref = func (x , row , col )
3305
+
3306
+ assert dr .allclose (res , ref )
3307
+
3308
+ assert frozen .n_recordings == 1
3309
+
0 commit comments