3
3
.. _neural_nets :
4
4
5
5
Neural Networks
6
- ---------------
6
+ ===============
7
7
8
8
Dr.Jit's neural network infrastructure builds on :ref: `cooperative vectors
9
9
<coop_vec>`. Please review their documentation before reading this section.
26
26
27
27
The set of neural network module currently includes:
28
28
29
- - Sequential evaluation of a list of models: :py:class: `nn.Sequential `.
29
+ - Sequential evaluation of a list of models: :py:class: `nn.Sequential <Sequential> `.
30
30
31
- - Linear/affine layers: :py:class: `nn.Linear `.
31
+ - Linear/affine layers: :py:class: `nn.Linear <Linear> `.
32
32
33
- - Encoding layers: :py:class: `nn.SinEncode `, :py:class: `nn.TriEncode `.
33
+ - Encoding layers: :py:class: `nn.SinEncode <SinEncode> `, :py:class: `nn.TriEncode <TriEncode> `.
34
34
35
- - Activation functions and other nonlinear transformations: :py:class: `ReLU `, :py:class: `LeakyReLU `,
36
- :py:class: `Exp `, :py:class: `Exp2 `, :py:class: `Tanh `.
35
+ - Activation functions and other nonlinear transformations: :py:class: `nn. ReLU <ReLU> `, :py:class: `nn. LeakyReLU <LeakyReLU> `,
36
+ :py:class: `nn. Exp <nn.Exp> `, :py:class: `nn. Exp2 <Exp2> `, :py:class: `nn. Tanh <Tanh> `.
37
37
38
- - Miscellaneous: :py:class: `nn.Cast `, :py:class: `nn.ScaleAdd `.
38
+ - Miscellaneous: :py:class: `nn.Cast <Cast> `, :py:class: `nn.ScaleAdd <ScaleAdd> `.
39
39
40
40
Example
41
41
-------
@@ -48,9 +48,13 @@ Great Wave off Kanagawa
48
48
<https://en.wikipedia.org/wiki/The_Great_Wave_off_Kanagawa> `__ (left).
49
49
50
50
.. image :: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/coopvec-screenshot.png
51
- :width: 300
51
+ :width: 600
52
52
:align: center
53
53
54
+ The optimization uses the :py:class: `dr.opt.Adam <drjit.opt.Adam> ` optimizer
55
+ and :py:class: `dr.opt.GradScaler <drjit.opt.GradScaler> ` gradient scaler for
56
+ adaptive mixed-precision training.
57
+
54
58
.. code-block :: python
55
59
56
60
from tqdm.auto import tqdm
@@ -126,3 +130,4 @@ Great Wave off Kanagawa
126
130
ax[1 ].imshow(dr.clip(img, 0 , 1 ))
127
131
fig.tight_layout()
128
132
plt.show()
133
+
0 commit comments