@@ -2414,6 +2414,7 @@ def freeze(
2414
2414
max_cache_size : Optional [int ] = None ,
2415
2415
warn_recording_count : int = 10 ,
2416
2416
backend : Optional [JitBackend ] = None ,
2417
+ auto_opaque : bool = True ,
2417
2418
) -> Callable [[F ], F ]:
2418
2419
"""
2419
2420
Decorator to freeze a function for replaying kernels without re-tracing.
@@ -2460,11 +2461,6 @@ def freeze(
2460
2461
y = dr.gather(type(x), x, dr.width(x)//2)
2461
2462
```
2462
2463
2463
- Similarly, calculating the mean of a variable relies on the number of entries,
2464
- which will be baked into the frozen function. To avoid this, we suggest
2465
- supplying the number of entries as a Dr.Jit literal in the arguments to the
2466
- function.
2467
-
2468
2464
Args:
2469
2465
f: The function to be frozen
2470
2466
@@ -2485,6 +2481,12 @@ def freeze(
2485
2481
frozen function, the backend used has to be specified using this argument.
2486
2482
It has to be the same backend that is used for variables inside the function,
2487
2483
otherwise recording will fail and variables may be leaked.
2484
+
2485
+ auto_opaque: (bool): If this flag is set true and only literal values
2486
+ or their size changes between calls to the function, these variables
2487
+ will be marked and made opaque. This reduces the memory usage, traversal
2488
+ overhead, and can improved the performance of generated kernels.
2489
+ If the flag is set to false, all input variables will be made opaque.
2488
2490
"""
2489
2491
2490
2492
@@ -2496,6 +2498,7 @@ def freeze(
2496
2498
max_cache_size : Optional [int ] = None ,
2497
2499
warn_recording_count : int = 10 ,
2498
2500
backend : Optional [JitBackend ] = None ,
2501
+ auto_opaque : bool = True ,
2499
2502
) -> F : ...
2500
2503
2501
2504
@@ -2506,6 +2509,7 @@ def freeze(
2506
2509
max_cache_size : Optional [int ] = None ,
2507
2510
warn_recording_count : int = 10 ,
2508
2511
backend : Optional [JitBackend ] = None ,
2512
+ auto_opaque : bool = True ,
2509
2513
) -> Union [F , Callable [[F2 ], F2 ]]:
2510
2514
max_cache_size = max_cache_size if max_cache_size is not None else - 1
2511
2515
backend = backend if backend is not None else JitBackend .Invalid
@@ -2530,10 +2534,7 @@ def __init__(self, f) -> None:
2530
2534
closure = inspect .getclosurevars (f )
2531
2535
self .closure = (closure .nonlocals , closure .globals )
2532
2536
self .frozen = detail .FrozenFunction (
2533
- inner ,
2534
- max_cache_size ,
2535
- warn_recording_count ,
2536
- backend ,
2537
+ inner , max_cache_size , warn_recording_count , backend , auto_opaque
2537
2538
)
2538
2539
2539
2540
def __call__ (self , * args , ** kwargs ):
0 commit comments