@@ -157,6 +157,52 @@ of ``MyClass`` expects a variable.
157
157
Gradient Propagation
158
158
--------------------
159
159
160
+ Very often tracing the backward pass of an AD-attached computation is at least
161
+ as complex as the forward pass, and caching both the tracing and assembly steps
162
+ is desireable. Therefore, the :py:func: `drjit.freeze ` decorator supports
163
+ propagating gradients to the inputs of the function. However, it is not yet
164
+ supported to propagate gradients from the result of a frozen function backwards
165
+ through the function. In terms of autodiff, anotating a function with the
166
+ :py:func: `dr.freeze ` decorator is equivalent to wrapping the content with an
167
+ isolated gradient scope.
168
+
169
+ .. code-block :: python
170
+
171
+ @dr.freeze
172
+ def func (y ):
173
+ # Some differentiable operation...
174
+ z = dr.mean(y)
175
+ # Propagate the gradients to the input of the function...
176
+ dr.backward(z)
177
+
178
+ x = dr.arange(Float, 3 )
179
+ dr.enable_grad(x)
180
+
181
+ y = dr.square(x)
182
+
183
+ # The first time the function is called, it will be recorded and the correct
184
+ # gradients will be accumulated into x.
185
+ func(y)
186
+
187
+ y = x * 2
188
+
189
+ # On subsequent calls the the function will be replayed, and gradients will
190
+ # be accumulated in x.
191
+ func(y)
192
+
193
+ The :py:func: `drjit.freeze ` decorator adds an implicit
194
+ :py:func: `drjit.isolate_grad ` context to the function. The above function is
195
+ then equivalent to the following function.
196
+
197
+ .. code-block :: python
198
+
199
+ def func (y ):
200
+ with dr.isolate_grad():
201
+ # Some differentiable operation...
202
+ z = dr.mean(y)
203
+ # Propagate the gradients to the input of the function...
204
+ dr.backward(z)
205
+
160
206
Unsupported Operations
161
207
----------------------
162
208
@@ -229,6 +275,7 @@ supported.
229
275
..code- block:: cpp
230
276
231
277
# This pattern is not supported inside of frozen functions.
278
+
232
279
UInt32::load_(x.data() + 4 )
233
280
234
281
This pattern might be used in C++ code called by the frozen function and can
0 commit comments