@@ -21,19 +21,18 @@ step can be very expensive, since the underlying compilers perform a lot of
optimization on the intermediary code. Dr.Jit therefore caches this step by
default using a hash of the assembled IR code. As mentioned in the :ref:`eval`
page, changing literal values can cause re-compilation of the kernel and result
- in a significant performance bottleneck. Memoization of compilation
- significantly reduces the overhead that otherwise would be encountered.
- However, the first two steps of tracing the Python code and generating the
- intermediary representation can still be expensive. This feature tries to
- address this performance bottleneck, by introducing the :py:func: `drjit.freeze `
- decorator. If a function is annotated with this decorator, Dr.Jit will try to
- cache the tracing and assembly steps as well. When a frozen function is called
- the first time, Dr.Jit will analyze the inputs, and then trace the
- function once, capturing all kernels lauched. On subsequent calls to the
- function Dr.Jit will try to find previous recordings with compatible input
- layouts. If such a recording is found, it will launch it instead of re-tracing
- the function. This skips tracing and assembly of kernels, as well as
- compilation, reducing the time spent not executing kernels.
+ in a significant performance bottleneck. However, the first two steps,
+ tracing the Python code and generating the intermediary representation, can
+ still be expensive. This feature addresses that bottleneck by introducing the
+ :py:func:`drjit.freeze` decorator. If a function is annotated with this
+ decorator, Dr.Jit will try to cache the tracing and assembly steps as well.
+ When a frozen function is called for the first time, Dr.Jit analyzes the
+ inputs and traces the function once, capturing all kernels launched. On
+ subsequent calls to the function, Dr.Jit will try to find previous recordings
+ with compatible input layouts. If such a recording is found, it is replayed
+ instead of re-tracing the function. This skips the tracing and assembly of
+ kernels, as well as their compilation, reducing the time spent outside of
+ kernel execution.

.. code-block:: python

@@ -63,22 +62,22 @@ How Function Freezing Works

Every time the function is called, the input is analyzed and all JIT variables
are extracted into a flat, deduplicated array. Additionally, a key of the layout
- in which the variables where stored in the input is generated. The key is used
- to find recordings of previous calls to the function in a hashmap. If none are
- found, the inner function is called and the backend is put into a recording
- mode. In this mode, all device level operations, such as kernel launches are
- record. When the function is called again, the input is traversed, and the
- layout is used to lookup compatible recordings. If such a recording is found,
- it is used to replay the kernel launches.
+ in which the variables were stored is generated. The key is used to find
+ recordings of previous calls to the function in a hashmap. If none are found,
+ the inner function is called and the backend is put into a recording mode. In
+ this mode, all device-level operations, such as kernel launches, are recorded.
+ When the function is called again, the input is traversed, and the layout is
+ used to look up compatible recordings. If such a recording is found, it is
+ used to replay the kernel launches.
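+
+ The snippet below sketches this record/replay behavior (the function body is
+ an assumed example, not taken from the rest of the docs):
+
+ .. code-block:: python
+
+     @dr.freeze
+     def func(x):
+         # An arbitrary computation that results in a kernel launch
+         return x * 2 + 1
+
+     # First call: the function is traced and its kernel launches are recorded
+     y = func(dr.arange(Float, 3))
+
+     # Second call: the input layout matches the recording, so the kernels are
+     # replayed without re-tracing the function
+     y = func(dr.arange(Float, 4))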

Traversal
~~~~~~~~~

In order to map the variables provided to a frozen function in its inputs to
- the to the kernel slots, Dr.Jit has to be able to traverse the input of the
- function. In addition to basic python containers such as lists, tuples and
- dictionaries, the following containers are traversable and can be part of the
- input of a frozen function.
+ the kernel slots, Dr.Jit has to be able to traverse the input of the function.
+ In addition to basic Python containers such as lists, tuples, and
+ dictionaries, the following containers are traversable and can be part of the
+ input of a frozen function.
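+
+ For the basic containers, this traversal might look as follows (a minimal
+ sketch; the function body is an assumed example):
+
+ .. code-block:: python
+
+     @dr.freeze
+     def func(data):
+         # The Float array is discovered inside the nested dict/list structure
+         return data["points"][0] + 1
+
+     func({"points": [Float(1, 2, 3)]})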

*Dataclasses* are traversable by Dr.Jit and their fields are automatically made
visible to the traversal algorithm.
@@ -104,8 +103,10 @@ traversable.
        "x": Float
    }

- Classes inheriting from trampoline classes are automatically traversed. This is
- useful when implementing your own subclasses with vcalls.
+ C++ classes such as scenes might additionally expose an interface to make them
+ traversable. Python classes inheriting from these classes through trampolines
+ are automatically traversed. This is useful when implementing your own
+ subclasses with vcalls.

.. code-block:: python

@@ -114,6 +115,47 @@ useful when implementing your own subclasses with vcalls.
    class MyClass(BSDF):
        x: Float

+ Output Construction
+ ~~~~~~~~~~~~~~~~~~~
+
+ After a frozen function has been replayed, the outputs of the function have
+ to be constructed from a flat array of JIT variable indices. This is
+ accomplished by saving the layout of the output returned when recording the
+ function. Since the output has to be constructed, only a subset of
+ traversable variables can be returned from frozen functions. This includes:
+
+ - JIT and AD variables
+ - Dr.Jit Tensors and Arrays
+ - Python lists, tuples and dictionaries
+ - Dataclasses
+ - ``DRJIT_STRUCT`` annotated classes with a default constructor
+
+ The following example shows an unsupported return type, because the
+ constructor of ``MyClass`` expects a variable.
+
+ .. code-block:: python
+
+     class MyClass:
+         x: Float
+
+         DRJIT_STRUCT = {
+             "x": Float,
+         }
+
+         def __init__(self, x):
+             self.x = x
+
+     @dr.freeze
+     def func(x):
+         return MyClass(x + 1)
+
+     # Calling the function will fail, as the output of the frozen function
+     # cannot be constructed without a default constructor.
+     x = Float(1, 2, 3)
+     func(x)
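+
+ One possible fix (a sketch, not taken from the original example) is to give
+ ``MyClass`` a default constructor, so that an instance can be created first
+ and its fields filled in from the recorded layout:
+
+ .. code-block:: python
+
+     class MyClass:
+         x: Float
+
+         DRJIT_STRUCT = {
+             "x": Float,
+         }
+
+         # A default argument makes this a valid default constructor
+         def __init__(self, x=None):
+             self.x = x if x is not None else Float()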
+
+ Gradient Propagation
+ --------------------

Unsupported Operations
----------------------
@@ -445,10 +487,37 @@ if a JIT variable was missed.
    def outer(x):
        return inner(x)

- Unsupported Inputs
- ~~~~~~~~~~~~~~~~~~
+ Tensor Shapes
+ ~~~~~~~~~~~~~
+
+ When a frozen function is called with a tensor, the first dimension of the
+ tensor is assumed to be dynamic. It can change from one call to another
+ without triggering re-tracing of the function. Changes in any other dimension
+ will change the key of the function and cause it to be re-traced. This
+ limitation results from the way tensors are generally indexed, where the
+ index into the tensor array can be calculated without involving the first
+ dimension.
+
+ .. code-block:: python
+
+     @dr.freeze
+     def func(t: TensorXf, row, col):
+         # Indexes into the tensor array, getting the entry at (row, col)
+         return dr.gather(Float, t.array, row * dr.shape(t)[1] + col)
+
+     # The first call will record the function
+     t = TensorXf(dr.arange(Float, 10 * 10), shape=(10, 10))
+     func(t, UInt(1), UInt(1))
+
+     # Subsequent calls with the same trailing dimensions will be replayed
+     t = TensorXf(dr.arange(Float, 5 * 10), shape=(5, 10))
+     func(t, UInt(1), UInt(1))
+
+     # Changes in trailing dimensions will cause the function to be re-traced
+     t = TensorXf(dr.arange(Float, 10 * 5), shape=(10, 5))
+     func(t, UInt(1), UInt(1))
+
+ Textures
+ ~~~~~~~~
Virtual Function Calls
~~~~~~~~~~~~~~~~~~~~~~