Commit d4bb421

Added docs about gradient propagation and corresponding test
1 parent 63d9ef7 commit d4bb421

2 files changed: +85 −0 lines changed

docs/freeze.rst

Lines changed: 47 additions & 0 deletions
@@ -157,6 +157,52 @@ of ``MyClass`` expects a variable.
Gradient Propagation
--------------------

Very often, tracing the backward pass of an AD-attached computation is at
least as complex as the forward pass, and caching both the tracing and
assembly steps is desirable. The :py:func:`drjit.freeze` decorator therefore
supports propagating gradients to the inputs of the function. However,
propagating gradients from the result of a frozen function backwards through
the function is not yet supported. In terms of autodiff, annotating a function
with the :py:func:`dr.freeze` decorator is equivalent to wrapping its content
in an isolated gradient scope.

.. code-block:: python

    @dr.freeze
    def func(y):
        # Some differentiable operation...
        z = dr.mean(y)
        # Propagate the gradients to the input of the function...
        dr.backward(z)

    x = dr.arange(Float, 3)
    dr.enable_grad(x)

    y = dr.square(x)

    # The first time the function is called, it will be recorded, and the
    # correct gradients will be accumulated into x.
    func(y)

    y = x * 2

    # On subsequent calls, the function will be replayed, and gradients will
    # be accumulated in x.
    func(y)

The :py:func:`drjit.freeze` decorator adds an implicit
:py:func:`drjit.isolate_grad` context to the function. The above function is
therefore equivalent to the following one.

.. code-block:: python

    def func(y):
        with dr.isolate_grad():
            # Some differentiable operation...
            z = dr.mean(y)
            # Propagate the gradients to the input of the function...
            dr.backward(z)

Unsupported Operations
----------------------

@@ -229,6 +275,7 @@ supported.
.. code-block:: cpp

    // This pattern is not supported inside of frozen functions.
    UInt32::load_(x.data() + 4)

This pattern might be used in C++ code called by the frozen function and can
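
The gradient-flow restriction described in the new "Gradient Propagation"
section above can be made concrete with a short sketch. This is an
illustration rather than part of the diff; the ``drjit.auto.ad`` import path
and the function names are assumptions.

.. code-block:: python

    import drjit as dr
    from drjit.auto.ad import Float  # assumed AD-enabled array namespace

    @dr.freeze
    def supported(y):
        z = dr.mean(y)
        dr.backward(z)   # backward pass runs inside the frozen function

    @dr.freeze
    def unsupported(y):
        return dr.mean(y)

    x = dr.arange(Float, 3)
    dr.enable_grad(x)

    # Supported: the frozen function itself propagates gradients to its input.
    supported(dr.square(x))
    print(dr.grad(x))

    # Not yet supported: backpropagating from the *result* of a frozen
    # function through the frozen call.
    z = unsupported(dr.square(x))
    # dr.backward(z)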

tests/test_freeze.py

Lines changed: 38 additions & 0 deletions
@@ -3307,3 +3307,41 @@ def func(x: mod.TensorXf, row: mod.UInt32, col: mod.UInt32):

    assert frozen.n_recordings == 1


@pytest.test_arrays("float32, jit, diff, shape=(*)")
@pytest.mark.parametrize("auto_opaque", [False, True])
def test88_grad_doc(t, auto_opaque):
    """
    Tests the code snippet from the docs section on gradients.
    """

    @dr.freeze
    def func(y):
        # Some differentiable operation...
        z = dr.mean(y)
        # Propagate the gradients to the input of the function...
        dr.backward(z)

    x = dr.arange(t, 3)
    dr.enable_grad(x)

    y = dr.square(x)

    # The first time the function is called, it will be recorded, and the
    # correct gradients will be accumulated into x.
    func(y)

    # Compare against the manually calculated gradient
    assert dr.allclose(dr.grad(x), 2 * 1 / dr.width(x) * x)

    dr.clear_grad(x)

    y = x * 2

    # On subsequent calls, the function will be replayed, and gradients will
    # be accumulated in x.
    func(y)

    # Compare against the manually calculated gradient
    assert dr.allclose(dr.grad(x), [2 * 1 / dr.width(x)] * dr.width(x))
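
The two assertions compare against gradients worked out by hand: for
``z = dr.mean(y)`` with ``y = x**2``, each component receives ``2 * x_i / n``,
and for ``y = 2 * x`` it receives ``2 / n``, where ``n = dr.width(x)``. A small
plain-Python sketch of that arithmetic (``n = 3`` mirrors the test; this is
illustrative only):

.. code-block:: python

    n = 3
    x = [float(i) for i in range(n)]         # mirrors dr.arange(t, 3)

    # First call: y = x**2, z = mean(y), so dz/dx_i = 2 * x_i / n
    grad_first = [2.0 * xi / n for xi in x]  # == 2 * 1 / dr.width(x) * x

    # Second call (after clearing gradients): y = 2 * x, z = mean(y),
    # so dz/dx_i = 2 / n for every component
    grad_second = [2.0 / n for _ in x]       # == [2 * 1 / dr.width(x)] * dr.width(x)

    print(grad_first)   # [0.0, 0.666..., 1.333...]
    print(grad_second)  # [0.666..., 0.666..., 0.666...]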
