Commit 21fb2e4

Added more documentation
1 parent d4bb421 commit 21fb2e4

2 files changed: +99 −18 lines changed

docs/freeze.rst

Lines changed: 47 additions & 18 deletions
@@ -6,8 +6,11 @@ Function Freezing
 =================
 
 This feature is still experimental, and we list a number of unsupported cases
-in the :ref:`pitfalls` section. If you encounter any issues please feel free to
-open an issue `here <https://github.com/mitsuba-renderer/drjit/issues>`__.
+in the :ref:`pitfalls` section. This feature also supports only a subset of
+the operations that can be performed with Dr.Jit; unsupported operations are
+listed in the :ref:`unsupported_operations` section. If you encounter any
+issues, please feel free to open an issue `here
+<https://github.com/mitsuba-renderer/drjit/issues>`__.
 
 Introduction
 ------------
@@ -23,16 +26,19 @@ default using a hash of the assembled IR code. As mentioned in the :ref:`_eval`
 page, changing literal values can cause re-compilation of the kernel and result
 in a significant performance bottleneck. However, the first two steps of
 tracing the Python code and generating the intermediary representation can
-still be expensive. This feature tries to address this performance bottleneck,
-by introducing the :py:func:`drjit.freeze` decorator. If a function is
-annotated with this decorator, Dr.Jit will try to cache the tracing and
-assembly steps as well. When a frozen function is called the first time, Dr.Jit
-will analyze the inputs, and then trace the function once, capturing all
-kernels lauched. On subsequent calls to the function Dr.Jit will try to find
-previous recordings with compatible input layouts. If such a recording is
-found, it will be launched instead of re-tracing the function. This skips
-tracing and assembly of kernels, as well as compilation, reducing the time
-spent not executing kernels.
+still be expensive. When a lot of Python code has to be traced, such as custom
+Python functions, the GIL has to be locked multiple times. Similarly, when
+tracing virtual function calls of many instances of custom plugins, these
+calls can cause a large performance overhead. This feature tries to address
+this performance bottleneck by introducing the :py:func:`drjit.freeze`
+decorator. If a function is annotated with this decorator, Dr.Jit will try to
+cache the tracing and assembly steps as well. When a frozen function is called
+for the first time, Dr.Jit will analyze the inputs and then trace the function
+once, capturing all kernels launched. On subsequent calls to the function,
+Dr.Jit will try to find previous recordings with compatible input layouts. If
+such a recording is found, it will be launched instead of re-tracing the
+function. This skips tracing and assembly of kernels, as well as compilation,
+reducing the time spent outside of kernel execution.
 
 .. code-block:: python
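
For context, a minimal sketch of how the decorator described above is used;
the function body and the ``drjit.auto`` import are illustrative assumptions,
not taken from this commit:

    import drjit as dr
    from drjit.auto import Float

    @dr.freeze
    def scaled_sum(x: Float) -> Float:
        # Traced once on the first call; all kernel launches are recorded.
        return dr.sum(x * 2)

    x = dr.arange(Float, 16)
    scaled_sum(x)      # first call: traces the function and records kernels
    scaled_sum(x + 1)  # compatible input layout: replays the recording
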
@@ -124,11 +130,11 @@ by saving the layout of the output returned when recording the function. Since
 the output has to be constructed, only a subset of traversable variables can be
 returned from frozen functions. This includes:
 
-- JIT and AD variables
-- Dr.Jit Tensors and Arrays
-- Python lists, tuples and dictionaries
-- Dataclasses
-- ``DRJIT_STRUCT`` annotated classes with a default constructor
+- JIT and AD variables.
+- Dr.Jit Tensors and Arrays.
+- Python lists, tuples and dictionaries.
+- Dataclasses, i.e. classes annotated with ``@dataclass``.
+- ``DRJIT_STRUCT`` annotated classes with a default constructor.
 
 The following example shows an unsupported return type, because the constructor
 of ``MyClass`` expects a variable.
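
The example referred to above lies outside this hunk; a plausible sketch of
such an unsupported return type (an illustrative assumption using ``dr`` and
``Float`` as in the snippets above, not the file's actual snippet):

    class MyClass:
        # No default constructor: the constructor requires a variable, so
        # Dr.Jit cannot reconstruct the object when replaying a recording.
        def __init__(self, x: Float):
            self.x = x

    @dr.freeze
    def func(x: Float) -> MyClass:
        return MyClass(x * 2)  # unsupported return type
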
@@ -197,17 +203,20 @@ then equivalent to the following function.
 .. code-block:: python
 
     def func(y):
+        # The isolate grad scope is added implicitly by the freezing decorator
         with dr.isolate_grad():
             # Some differentiable operation...
             z = dr.mean(y)
             # Propagate the gradients to the input of the function...
             dr.backward(z)
 
+.. _unsupported_operations:
+
 Unsupported Operations
 ----------------------
 
 Since frozen functions record kernel launches and have to be able to replay
-them later, certian operations are not supported inside frozen functions.
+them later, certain operations are not supported by them.
 
 Array Access
 ~~~~~~~~~~~~
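
The snippet above is described as equivalent to a decorated function shown
earlier in the document; that function is not part of this diff, but
presumably looks like the following sketch:

    @dr.freeze
    def func(y):
        # Some differentiable operation...
        z = dr.mean(y)
        # Propagate the gradients to the input of the function...
        dr.backward(z)
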
@@ -566,6 +575,26 @@ tensor array can be calculated without involving the first dimension.
 Textures
 ~~~~~~~~
 
+Textures can be used inside of frozen functions for lookups, as well as for
+gradient calculations. However, because they require special memory operations
+on CUDA, it is not possible to update or initialize CUDA textures inside of
+frozen functions.
+
+.. code-block:: python
+
+    @dr.freeze
+    def func(tex: Texture1f, pos: Float):
+        return tex.eval(pos)
+
+    tex = Texture1f([2], 1)
+    tex.set_value(Float(0, 1))
+
+    pos = dr.arange(Float, 4) / 4
+
+    # The texture can be evaluated inside the frozen function.
+    func(tex, pos)
+
+
 
 Virtual Function Calls
 ~~~~~~~~~~~~~~~~~~~~~~
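
For contrast, a sketch of the pattern the new paragraph rules out (an
illustrative assumption, not part of the commit): updating the texture inside
the frozen function is unsupported on the CUDA backend.

    @dr.freeze
    def bad_func(tex: Texture1f, value: Float, pos: Float):
        # Unsupported on CUDA: initializing or updating a texture inside
        # a frozen function requires special memory operations.
        tex.set_value(value)
        return tex.eval(pos)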

tests/test_freeze.py

Lines changed: 52 additions & 0 deletions
@@ -3345,3 +3345,55 @@ def func(y):
     # Compare against manually calculated gradient
     assert dr.allclose(dr.grad(x), [2 * 1 / dr.width(x)] * dr.width(x))
 
+
+@pytest.test_arrays("float32, jit, diff, shape=(*)")
+@pytest.mark.parametrize("auto_opaque", [False, True])
+def test89_custom_grad(t, auto_opaque):
+    """
+    Tests the code snippet from the docs section on gradients.
+    """
+    mod = sys.modules[t.__module__]
+
+    def func(x):
+        return dr.mean(x)
+
+    frozen = dr.freeze(func)
+
+    def func_bwd(x, dy):
+        dr.enable_grad(x)
+
+        y = func(x)
+
+        dr.set_grad(y, dy)
+
+        dr.backward(y)
+
+        dx = dr.grad(x)
+        dr.disable_grad(x)
+
+        return dx
+
+    frozen_bwd = dr.freeze(func_bwd)
+
+    class Custom(dr.CustomOp):
+        def eval(self, x):
+            self.x = x
+            return frozen(x)
+
+        def backward(self):
+            x = self.x
+            dy = self.grad_out()
+            dx = frozen_bwd(dr.detach(x), dy)
+            print(f"{dx=}")
+            self.set_grad_in("x", dx)
+
+
+    for i in range(3):
+        x = dr.arange(t, i + 3)
+        dr.enable_grad(x)
+
+        y = dr.custom(Custom, x)
+
+        dr.backward(y)
+
+        print(f"{dr.grad(x)=}")
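
Note that the added test freezes the primal and backward passes separately and
connects them through a ``dr.CustomOp``, so replaying a recording avoids
re-tracing in both the forward and the gradient direction.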
