Commit 9638fcb

updates source code in nn example
1 parent 54b67a0 commit 9638fcb

File tree

2 files changed: +62 -24 lines changed

docs/nn.rst

Lines changed: 61 additions & 23 deletions
@@ -43,48 +43,86 @@ Example
 Below is a fully worked out example demonstrating how to use it to declare and
 optimize a small `multilayer perceptron
 <https://en.wikipedia.org/wiki/Multilayer_perceptron>`__ (MLP). This network
-implements a 2D neural field that we fit to an image.
+implements a 2D neural field (right) that we then fit to a low-resolution image of `The
+Great Wave off Kanagawa
+<https://en.wikipedia.org/wiki/The_Great_Wave_off_Kanagawa>`__ (left).
+
+.. image:: https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/coopvec-screenshot.png
+   :width: 300
+   :align: center

 .. code-block:: python

+   from tqdm.auto import tqdm
+   import imageio.v3 as iio
    import drjit as dr
    import drjit.nn as nn
+   from drjit.opt import Adam, GradScaler
+   from drjit.auto.ad import Texture2f, TensorXf, TensorXf16, Float16, Float32, Array2f, Array3f
+
+   # Load a test image and construct a texture object
+   ref = TensorXf(iio.imread("https://rgl.s3.eu-central-1.amazonaws.com/media/uploads/wjakob/2024/06/wave-128.png") / 256)
+   tex = Texture2f(ref)

-   from drjit.llvm.ad import TensorXf16
-   from drjit.opt import Adam
+   # Ensure consistent results when re-running the following
+   dr.seed(0)

    # Establish the network structure
    net = nn.Sequential(
-       nn.Linear(-1, 32, bias=False),
-       nn.ReLU(),
-       nn.Linear(-1, -1),
-       nn.ReLU(),
+       nn.TriEncode(16, 0.2),
+       nn.Cast(Float16),
+       nn.Linear(-1, -1, bias=False),
+       nn.LeakyReLU(),
+       nn.Linear(-1, -1, bias=False),
+       nn.LeakyReLU(),
+       nn.Linear(-1, -1, bias=False),
+       nn.LeakyReLU(),
        nn.Linear(-1, 3, bias=False),
-       nn.Tanh()
+       nn.Exp()
    )

    # Instantiate the network for a specific backend + input size
    net = net.alloc(TensorXf16, 2)

-   # Pack coefficients into a training-optimal layout
+   # Convert to training-optimal layout
    coeffs, net = nn.pack(net, layout='training')
+   print(net)

-   # Optimize a float32 version of the packed coefficients
+   # Optimize a single-precision copy of the parameters
    opt = Adam(lr=1e-3, params={'coeffs': Float32(coeffs)})

-   # Update network state from optimizer
-   for i in range(1000):
-       # Update neural network state
-       coeffs[:] = Float16(opt['coeffs'])
+   # This is an adaptive mixed-precision (AMP) optimization, where a half-
+   # precision computation runs within a larger single-precision program.
+   # Gradient scaling is required to make this numerically well-behaved.
+   scaler = GradScaler()

-       # Create input
-       out = net(nn.CoopVec(...))
+   res = 256

-       # Unpack
-       out = Array3f16(result)
-
-       # Backpropagate
-       dr.backward(dr.square(reference-out))
+   for i in tqdm(range(40000)):
+       # Update network state from optimizer
+       coeffs[:] = Float16(opt['coeffs'])

-       # Take a gradient step
-       opt.step()
+       # Generate jittered positions on [0, 1]^2
+       t = dr.arange(Float32, res)
+       p = (Array2f(dr.meshgrid(t, t)) + dr.rand(Array2f, (2, res*res))) / res
+
+       # Evaluate neural net + L2 loss
+       img = Array3f(net(nn.CoopVec(p)))
+       loss = dr.squared_norm(tex.eval(p) - img)
+
+       # Mixed-precision training: take suitably scaled steps
+       dr.backward(scaler.scale(loss))
+       scaler.step(opt)
+
+   # Done optimizing, now let's plot the result
+   t = dr.linspace(Float32, 0, 1, res)
+   p = Array2f(dr.meshgrid(t, t))
+   img = Array3f(net(nn.CoopVec(p)))
+   img = dr.reshape(TensorXf(img, flip_axes=True), (res, res, 3))
+
+   import matplotlib.pyplot as plt
+   fig, ax = plt.subplots(1, 2, figsize=(10, 5))
+   ax[0].imshow(ref)
+   ax[1].imshow(dr.clip(img, 0, 1))
+   fig.tight_layout()
+   plt.show()
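
A note on the new nn.TriEncode(16, 0.2) layer added by this commit: input
encodings of this kind expand each scalar network input into a set of periodic
features at increasing frequencies, which lets a small MLP represent fine image
detail. The snippet below is only a conceptual sketch of a triangle-wave
encoding; the interpretation of the two parameters (an octave count and a phase
shift) is an assumption for illustration, not the library's actual
implementation.

    import math

    def tri(x):
        # Triangle wave with period 1 and output range [0, 1]
        return abs(2.0 * (x - math.floor(x + 0.5)))

    def tri_encode(x, octaves=16, shift=0.2):
        # Map a scalar in [0, 1] to `octaves` features whose frequency
        # doubles at each level (assumed parameterization)
        return [tri(x * 2.0**i + shift) for i in range(octaves)]

    print(tri_encode(0.37, octaves=4))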

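For readers unfamiliar with the gradient-scaling step in the new training loop:
half-precision gradients can underflow to zero, so the loss is multiplied by a
large factor before backpropagation and the factor is divided back out before
the optimizer update. Below is a minimal sketch of the same pattern on a toy
scalar objective, using only the calls that appear in the diff (scaler.scale,
scaler.step, dr.backward). It assumes that, as in the example above, parameters
held by the optimizer stay AD-connected through the Float16 cast; the adaptive
rescaling noted in the last comment is inferred from the diff's own "adaptive
mixed-precision" description.

    import drjit as dr
    from drjit.auto.ad import Float16, Float32
    from drjit.opt import Adam, GradScaler

    # Single-precision master parameter, optimized through an FP16 computation
    opt = Adam(lr=1e-2, params={'x': Float32(5)})
    scaler = GradScaler()

    for i in range(100):
        x = Float16(opt['x'])            # half-precision working copy
        loss = dr.square(x - 3)          # toy objective with minimum at x = 3
        dr.backward(scaler.scale(loss))  # scale loss so FP16 gradients don't underflow
        scaler.step(opt)                 # un-scale gradients, then take an Adam step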