Skip to content

Commit cd67909

Browse files
committed
miscellaneous fixes
1 parent 9638fcb commit cd67909

File tree

3 files changed

+21
-11
lines changed

3 files changed

+21
-11
lines changed

docs/nn.rst

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
.. _neural_nets:
44

55
Neural Networks
6-
---------------
6+
===============
77

88
Dr.Jit's neural network infrastructure builds on :ref:`cooperative vectors
99
<coop_vec>`. Please review their documentation before reading this section.
@@ -26,16 +26,16 @@ List
2626

2727
The set of neural network module currently includes:
2828

29-
- Sequential evaluation of a list of models: :py:class:`nn.Sequential`.
29+
- Sequential evaluation of a list of models: :py:class:`nn.Sequential <Sequential>`.
3030

31-
- Linear/affine layers: :py:class:`nn.Linear`.
31+
- Linear/affine layers: :py:class:`nn.Linear <Linear>`.
3232

33-
- Encoding layers: :py:class:`nn.SinEncode`, :py:class:`nn.TriEncode`.
33+
- Encoding layers: :py:class:`nn.SinEncode <SinEncode>`, :py:class:`nn.TriEncode <TriEncode>`.
3434

35-
- Activation functions and other nonlinear transformations: :py:class:`ReLU`, :py:class:`LeakyReLU`,
36-
:py:class:`Exp`, :py:class:`Exp2`, :py:class:`Tanh`.
35+
- Activation functions and other nonlinear transformations: :py:class:`nn.ReLU <ReLU>`, :py:class:`nn.LeakyReLU <LeakyReLU>`,
36+
:py:class:`nn.Exp <nn.Exp>`, :py:class:`nn.Exp2 <Exp2>`, :py:class:`nn.Tanh <Tanh>`.
3737

38-
- Miscellaneous: :py:class:`nn.Cast`, :py:class:`nn.ScaleAdd`.
38+
- Miscellaneous: :py:class:`nn.Cast <Cast>`, :py:class:`nn.ScaleAdd <ScaleAdd>`.
3939

4040
Example
4141
-------
@@ -48,9 +48,13 @@ Great Wave off Kanagawa
4848
<https://en.wikipedia.org/wiki/The_Great_Wave_off_Kanagawa>`__ (left).
4949

5050
.. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/coopvec-screenshot.png
51-
:width: 300
51+
:width: 600
5252
:align: center
5353

54+
The optimization uses the :py:class:`dr.opt.Adam <drjit.opt.Adam>` optimizer
55+
and :py:class:`dr.opt.GradScaler <drjit.opt.GradScaler>` gradient scaler for
56+
adaptive mixed-precision training.
57+
5458
.. code-block:: python
5559
5660
from tqdm.auto import tqdm
@@ -126,3 +130,4 @@ Great Wave off Kanagawa
126130
ax[1].imshow(dr.clip(img, 0, 1))
127131
fig.tight_layout()
128132
plt.show()
133+

drjit/nn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
from __future__ import annotations
22
import drjit
3-
from typing import Tuple, Sequence, Union, Type, TypeAlias, Optional, Any
3+
import sys
4+
5+
if sys.version_info < (3, 11):
6+
from typing_extensions import Tuple, Sequence, Union, Type, TypeAlias, Optional, Any
7+
else:
8+
from typing import Tuple, Sequence, Union, Type, TypeAlias, Optional, Any
49

510
# Import classes/functions from C++ extension
611
MatrixView = drjit.detail.nn.MatrixView

src/python/coop_vec.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,13 +447,13 @@ static nb::object repack_impl(const char *name, MatrixLayout layout,
447447
nb::list result;
448448
for (nb::handle h : l)
449449
result.append(repack_impl(name, layout, h, offset, items));
450-
return result;
450+
return std::move(result);
451451
} else if (arg_tp.is(&PyDict_Type)) {
452452
nb::dict d = nb::borrow<nb::dict>(arg);
453453
nb::dict result;
454454
for (auto [k, v] : d)
455455
result[k] = repack_impl(name, layout, v, offset, items);
456-
return result;
456+
return std::move(result);
457457
} else if (nb::dict ds = get_drjit_struct(arg_tp); ds.is_valid()) {
458458
nb::object tmp = arg_tp();
459459
for (auto [k, v] : ds)

0 commit comments

Comments
 (0)