
Commit 63d9ef7

Added docs about tensors and output types
1 parent 2965eed commit 63d9ef7


docs/freeze.rst

Lines changed: 97 additions & 28 deletions
@@ -21,19 +21,18 @@ step can be very expensive, since the underlying compilers perform a lot of
optimization on the intermediate code. Dr.Jit therefore caches this step by
default, using a hash of the assembled IR code. As mentioned on the :ref:`eval`
page, changing literal values can cause re-compilation of the kernel and result
in a significant performance bottleneck. However, the first two steps, tracing
the Python code and generating the intermediate representation, can still be
expensive. The :py:func:`drjit.freeze` decorator addresses this bottleneck. If
a function is annotated with this decorator, Dr.Jit will try to cache the
tracing and assembly steps as well. When a frozen function is called for the
first time, Dr.Jit analyzes the inputs and traces the function once, capturing
all kernels launched. On subsequent calls to the function, Dr.Jit looks for a
previous recording with a compatible input layout. If such a recording is
found, it is launched instead of re-tracing the function. This skips tracing
and assembly of kernels, as well as compilation, reducing the time spent
outside of kernel execution.

.. code-block:: python
@@ -63,22 +62,22 @@ How Function Freezing Works

Every time the function is called, the input is analyzed and all JIT variables
are extracted into a flat, deduplicated array. Additionally, a key describing
the layout in which the variables were stored is generated. The key is used to
look up recordings of previous calls to the function in a hash map. If none is
found, the inner function is called and the backend is put into a recording
mode. In this mode, all device-level operations, such as kernel launches, are
recorded. When the function is called again, the input is traversed, and the
layout is used to look up compatible recordings. If such a recording is found,
it is used to replay the kernel launches.
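The layout-keyed lookup described above can be sketched in plain Python. This is a simplified illustration, not Dr.Jit's actual implementation: the names ``layout_key``, ``freeze_sketch``, and ``trace_count`` are hypothetical, the "layout" here is just the nesting structure and leaf types of the input, and "recording" a call is modeled by storing the traced function rather than capturing kernel launches.

```python
def layout_key(value):
    """Compute a hashable key describing the container layout and leaf types."""
    if isinstance(value, dict):
        return ("dict", tuple((k, layout_key(v)) for k, v in sorted(value.items())))
    if isinstance(value, (list, tuple)):
        return (type(value).__name__, tuple(layout_key(v) for v in value))
    return ("leaf", type(value).__name__)

class freeze_sketch:
    """Decorator caching one 'recording' per distinct input layout."""

    def __init__(self, func):
        self.func = func
        self.recordings = {}   # layout key -> recording
        self.trace_count = 0   # how many times we had to (re-)trace

    def __call__(self, *args):
        key = layout_key(args)
        if key not in self.recordings:
            self.trace_count += 1          # unseen layout: trace the function
            self.recordings[key] = self.func
        return self.recordings[key](*args)

@freeze_sketch
def add_one(values):
    return [v + 1 for v in values]

add_one([1, 2, 3])    # first call: traced
add_one([4, 5, 6])    # same layout (three ints): replayed, no re-trace
add_one([1.0, 2.0])   # float leaves: new layout, traced again
```

Note that in this sketch, as in Dr.Jit, changing only the *values* of the leaves does not invalidate the recording; changing their number or types does.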

Traversal
~~~~~~~~~

In order to map the variables provided in a frozen function's inputs to the
kernel slots, Dr.Jit has to be able to traverse the input of the function. In
addition to basic Python containers such as lists, tuples, and dictionaries,
the following containers are traversable and can be part of the input of a
frozen function.

*Dataclasses* are traversable by Dr.Jit and their fields are automatically made
visible to the traversal algorithm.
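As an illustration, flattening nested inputs, including dataclass fields, might look like the following pure-Python sketch. The helper name ``traverse`` and the ``Ray`` dataclass are hypothetical; Dr.Jit's real traversal additionally handles its own array types and deduplicates JIT variables by index.

```python
import dataclasses

def traverse(value, out):
    """Recursively collect leaf values from containers and dataclasses."""
    if isinstance(value, dict):
        for k in sorted(value):            # deterministic key order
            traverse(value[k], out)
    elif isinstance(value, (list, tuple)):
        for v in value:
            traverse(v, out)
    elif dataclasses.is_dataclass(value):
        for field in dataclasses.fields(value):  # declaration order
            traverse(getattr(value, field.name), out)
    else:
        out.append(value)

@dataclasses.dataclass
class Ray:
    origin: float
    direction: float

flat = []
traverse({"ray": Ray(0.5, 1.0), "weights": [2.0, 3.0]}, flat)
# flat now holds the leaves in deterministic traversal order
```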
@@ -104,8 +103,10 @@ traversable.
        "x": Float
    }

C++ classes such as scenes might additionally expose an interface to make them
traversable. Python classes inheriting from these classes through trampolines
are automatically traversed. This is useful when implementing your own
subclasses with vcalls.

.. code-block:: python
@@ -114,6 +115,47 @@ useful when implementing your own subclasses with vcalls.
    class MyClass(BSDF):
        x: Float

Output Construction
~~~~~~~~~~~~~~~~~~~

After a frozen function has been replayed, the outputs of the function have to
be constructed from a flat array of JIT variable indices. This is accomplished
by saving the layout of the output returned when recording the function. Since
the output has to be constructed, only a subset of traversable variables can be
returned from frozen functions. This includes:

- JIT and AD variables
- Dr.Jit Tensors and Arrays
- Python lists, tuples, and dictionaries
- Dataclasses
- ``DRJIT_STRUCT`` annotated classes with a default constructor

The following example shows an unsupported return type, because the constructor
of ``MyClass`` expects a variable.

.. code-block:: python

    class MyClass:
        x: Float

        DRJIT_STRUCT = {
            "x": Float,
        }

        def __init__(self, x):
            self.x = x

    @dr.freeze
    def func(x):
        return MyClass(x + 1)

    # Calling the function will fail, as the output of the frozen function
    # cannot be constructed without a default constructor.
    x = Float(1, 2, 3)
    func(x)
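In contrast, a class with a default constructor can be rebuilt from a flat list of output values. The pure-Python sketch below mimics that reconstruction step; the helper ``construct`` is hypothetical, and real Dr.Jit fills the fields from JIT variable indices according to the recorded output layout rather than from a plain list.

```python
class MyClass:
    # Field types in reconstruction order (mirrors a DRJIT_STRUCT annotation).
    DRJIT_STRUCT = {"x": float}

    def __init__(self):
        # Default constructor: no required arguments, so the object can be
        # created first and populated afterwards.
        self.x = None

def construct(cls, flat_values):
    """Default-construct cls, then fill its fields from a flat value list."""
    obj = cls()
    it = iter(flat_values)
    for name in cls.DRJIT_STRUCT:   # insertion order of the struct fields
        setattr(obj, name, next(it))
    return obj

result = construct(MyClass, [4.0])
# result.x is now populated from the flat output array
```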

Gradient Propagation
--------------------

Unsupported Operations
----------------------
@@ -445,10 +487,37 @@ if a JIT variable was missed.
    def outer(x):
        return inner(x)

Tensor Shapes
~~~~~~~~~~~~~

When a frozen function is called with a tensor, the first dimension of the
tensor is assumed to be dynamic: it can change from one call to the next
without triggering re-tracing of the function. Changes in any other dimension
will change the key of the function and cause it to be re-traced. This
limitation results from the way tensors are generally indexed, where the index
into the underlying tensor array can be calculated without involving the first
dimension.

.. code-block:: python

    @dr.freeze
    def func(t: TensorXf, row, col):
        # Indexes into the tensor array, getting the entry at (row, col)
        return dr.gather(Float, t.array, row * dr.shape(t)[1] + col)

    # The first call will record the function
    t = TensorXf(dr.arange(Float, 10 * 10), shape=(10, 10))
    func(t, UInt(1), UInt(1))

    # Subsequent calls with the same trailing dimensions will be replayed
    t = TensorXf(dr.arange(Float, 5 * 10), shape=(5, 10))
    func(t, UInt(1), UInt(1))

    # Changes in trailing dimensions will cause the function to be re-traced
    t = TensorXf(dr.arange(Float, 10 * 5), shape=(10, 5))
    func(t, UInt(1), UInt(1))
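The shape-dependent part of the lookup key can be illustrated with a small sketch: only the dtype and the trailing dimensions enter the key, so the first (dynamic) dimension may vary freely. The helper ``tensor_key`` is hypothetical and only mirrors the behavior described above.

```python
def tensor_key(shape, dtype="float32"):
    """Key a tensor by dtype and trailing dimensions; dimension 0 stays dynamic."""
    return (dtype, tuple(shape[1:]))

# Same trailing dimensions -> same key -> the recording is replayed
assert tensor_key((10, 10)) == tensor_key((5, 10))

# Different trailing dimension -> new key -> the function is re-traced
assert tensor_key((10, 10)) != tensor_key((10, 5))
```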

Textures
~~~~~~~~

Virtual Function Calls
~~~~~~~~~~~~~~~~~~~~~~
