
Commit 58d45e5

Incorporate review feedback
1 parent 0247195 commit 58d45e5

10 files changed: +72 -41 lines changed


docs/coop_vec.rst

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ Dr.Jit supports cooperative vectors on both of its backends:
 <https://raytracing-docs.nvidia.com/optix9/guide/index.html#cooperative_vectors#neural-rendering-with-cooperative-vectors>`__,
 leveraging built-in `tensor cores
 <https://www.nvidia.com/en-us/data-center/tensor-cores/>`__ for acceleration.
+Driver version R570 or newer is required to use this feature.

 - On the **CPU (LLVM) backend**, compilation of cooperative vector operations
 targets the available instruction set extensions (AVX512, NEON, etc.).
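For context, a minimal usage sketch of the API this documentation change concerns (a hypothetical example: the drjit.nn.CoopVec, pack, matvec, and cast names appear elsewhere in this commit, while the concrete shapes and the dr.ones call are assumptions):

    import drjit as dr
    import drjit.nn as nn
    from drjit.cuda.ad import TensorXf16, Float16, Float32

    # Row-major 16x8 half-precision weight matrix, repacked into an
    # inference-optimal layout for the CUDA/OptiX backend (driver R570+)
    A = dr.ones(TensorXf16, shape=(16, 8))
    A_buf, A_view = nn.pack(A, layout='inference')

    # Build a cooperative vector from eight half-precision inputs
    x = nn.CoopVec(*(Float16(0.1 * i) for i in range(8)))

    # y = A @ x, evaluated with tensor cores on CUDA (or AVX512/NEON on LLVM)
    y = nn.matvec(A_view, x)

    # Optionally change precision, then unpack before evaluating
    y = nn.cast(y, Float32)
    dr.eval(*y)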

docs/nn.rst

Lines changed: 10 additions & 8 deletions
@@ -90,39 +90,41 @@ mixed-precision training.
 net = net.alloc(TensorXf16, 2)

 # Convert to training-optimal layout
-coeffs, net = nn.pack(net, layout='training')
+weights, net = nn.pack(net, layout='training')
 print(net)

-# Optimize a single precision copy of the parameters
-opt = Adam(lr=1e-3, params={'coeffs': Float32(coeffs)})
+# Optimize a single-precision copy of the parameters
+opt = Adam(lr=1e-3, params={'weights': Float32(weights)})

 # This is an adaptive mixed-precision (AMP) optimization, where a half
-# precision computation runs within a larger single precision program.
+# precision computation runs within a larger single-precision program.
 # Gradient scaling is required to make this numerically well-behaved.
 scaler = GradScaler()

 res = 256

 for i in tqdm(range(40000)):
 # Update network state from optimizer
-coeffs[:] = Float16(opt['coeffs'])
+weights[:] = Float16(opt['weights'])

 # Generate jittered positions on [0, 1]^2
 t = dr.arange(Float32, res)
-p = (Array2f(dr.meshgrid(t, t)) + dr.rand(Array2f, (2, res*res))) / res
+p = (Array2f(dr.meshgrid(t, t)) + dr.rand(Array2f, (2, res * res))) / res

 # Evaluate neural net + L2 loss
 img = Array3f(net(nn.CoopVec(p)))
-loss = dr.squared_norm(tex.eval(p)-img)
+loss = dr.squared_norm(tex.eval(p) - img)

 # Mixed-precision training: take suitably scaled steps
 dr.backward(scaler.scale(loss))
 scaler.step(opt)

 # Done optimizing, now let's plot the result
 t = dr.linspace(Float32, 0, 1, res)
-p= Array2f(dr.meshgrid(t, t))
+p = Array2f(dr.meshgrid(t, t))
 img = Array3f(net(nn.CoopVec(p)))
+
+# Convert 'img' with shape 3 x (N*N) into an N x N x 3 tensor
 img = dr.reshape(TensorXf(img, flip_axes=True), (res, res, 3))

 import matplotlib.pyplot as plt
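The GradScaler line in this example compensates for the limited range of half-precision gradients. Roughly, such a scaler works as sketched below; this is a simplified illustration, not the actual Dr.Jit implementation, and the optimizer interface (``opt.keys()``, ``opt.step()``) is assumed:

    import drjit as dr

    class ToyGradScaler:
        """Simplified sketch of adaptive gradient scaling (illustration only)."""

        def __init__(self, init_scale=2.0**16):
            # Large factor that keeps FP16 gradients away from underflow
            self.scale_factor = init_scale

        def scale(self, loss):
            # Enlarge the loss before dr.backward() so its gradients stay representable
            return loss * self.scale_factor

        def step(self, opt):
            # Undo the scaling on the parameter gradients, then step.
            # A real implementation also detects overflowed gradients here,
            # skips the step, and adapts self.scale_factor accordingly.
            for k in opt.keys():                  # assumed optimizer interface
                dr.set_grad(opt[k], dr.grad(opt[k]) / self.scale_factor)
            opt.step()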

docs/what.rst

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ Using Dr.Jit involves two steps:
 Perhaps the most significant difference to the majority of existing tools is
 that Dr.Jit is *not primarily* a machine learning library. While it does
 provide support for neural network :ref:`evaluation and training <neural_nets>`,
-it its sweet spot are non-neural programs characterized by *embarrassing
+its sweet spot is non-neural programs characterized by *embarrassing
 parallelism*---that is to say, programs with large data-parallel regions. A
 good example of this are `Monte Carlo
 <https://en.wikipedia.org/wiki/Monte_Carlo_method>`__ methods with their

drjit/nn.py

Lines changed: 15 additions & 5 deletions
@@ -61,6 +61,16 @@ def __call__(self, arg: CoopVec, /) -> CoopVec:
 raise NotImplementedError(f"{type(self).__name__}.__call__() implementation is missing.")

 def _alloc(self, dtype: Type[drjit.ArrayBase], size: int, /) -> Tuple[Module, int]:
+"""
+Internal method used to propagate argument sizes and allocate weight
+storage for all NN modules.
+
+The method takes two parameters as input: a weight storage type
+``dtype`` (e.g., :py:class:`drjit.cuda.ad.TensorXf16`) and ``size``,
+the number of input arguments of the module. The function returns a
+potentially new module instance with allocated weights, plus the number
+of outputs.
+"""
 return self, size

 def alloc(self, dtype: Type[drjit.ArrayBase], size: int = -1) -> Module:
@@ -110,7 +120,7 @@ def __len__(self):
 """Return the number of contained models"""
 return len(self.layers)

-def __getitem__(self, index: Union[int], /) -> Module: # type: ignore
+def __getitem__(self, index: int, /) -> Module: # type: ignore
 """Return the model at position ``index``"""
 return self.layers[index]

@@ -155,8 +165,8 @@ class LeakyReLU(Module):
 \end{cases}
 """

-DRJIT_STRUCT = { 'negative_slope': float }
-def __init__(self, negative_slope: float = 1e-2):
+DRJIT_STRUCT = { 'negative_slope': Union[float, drjit.ArrayBase] }
+def __init__(self, negative_slope: Union[float, drjit.ArrayBase] = 1e-2):
 self.negative_slope = negative_slope

 def __call__(self, arg: CoopVec, /) -> CoopVec:
@@ -449,8 +459,8 @@ def __init__(self, octaves: int = 0, shift: float = 0) -> None:
 if shift == 0:
 self.shift = None
 else:
-self.shift = (drjit.sin(shift*2*drjit.pi),
-drjit.cos(shift*2*drjit.pi))
+self.shift = (drjit.sin(shift * 2 * drjit.pi),
+drjit.cos(shift * 2 * drjit.pi))

 def _alloc(self, dtype: Type[drjit.ArrayBase], size : int = -1, /) -> Tuple[Module, int]:
 return self, size * self.octaves * 2
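To illustrate the contract that the new ``_alloc()`` docstring describes, here is a hypothetical module sketch; the class name and its behaviour are made up for illustration, and only ``Module``, ``CoopVec``, and the ``_alloc`` signature come from this file:

    from typing import Tuple, Type
    import drjit
    from drjit.nn import Module, CoopVec

    class AddConstant(Module):
        """Adds a fixed offset to each component of the input cooperative vector."""
        DRJIT_STRUCT = { 'offset': float }

        def __init__(self, offset: float = 0.0):
            self.offset = offset

        def _alloc(self, dtype: Type[drjit.ArrayBase], size: int, /) -> Tuple[Module, int]:
            # Nothing to allocate; the module produces as many outputs as inputs
            return self, size

        def __call__(self, arg: CoopVec, /) -> CoopVec:
            # Unpack into regular Dr.Jit arrays, apply the offset, re-pack
            return CoopVec(*(x + self.offset for x in arg))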

include/drjit/extra.h

Lines changed: 1 addition & 1 deletion
@@ -546,7 +546,7 @@ extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_matvec(uint64_t A_index,
 int transpose);

 /// Cast a cooperative vector to a different precision
-extern JIT_EXPORT uint64_t ad_coop_vec_cast(uint64_t index, VarType vt);
+extern DRJIT_EXTRA_EXPORT uint64_t ad_coop_vec_cast(uint64_t index, VarType vt);

 #if defined(__cplusplus)
 }

src/python/coop_vec.cpp

Lines changed: 7 additions & 7 deletions
@@ -561,7 +561,7 @@ void export_coop_vec(nb::module_ &m) {
 coop_vector_type = nb::class_<CoopVec>(nn, "CoopVec", nb::is_generic(), nb::sig("class CoopVec(typing.Generic[T])"))
 .def(nb::init<nb::args>(),
 nb::sig("def __init__(self, *args: typing.Unpack[typing.Tuple[typing.Union[drjit.ArrayBase[SelfT, SelfCpT, ValT, ValCpT, T, PlainT, MaskT], float, int], ...]]) -> None"),
-doc_coop_CoopVec_init)
+doc_nn_CoopVec_init)
 .def("__iter__", [](const CoopVec &v) { return iter(v.expand_to_list()); },
 nb::sig("def __iter__(self, /) -> typing.Iterator[T]"))
 .def("__add__", &coop_vec_binary_op<JitOp::Add>,
@@ -587,7 +587,7 @@ void export_coop_vec(nb::module_ &m) {
 jit_var_size(v.m_index));
 });

-view_type = nb::class_<MatrixView>(nn, "MatrixView", doc_coop_MatrixView)
+view_type = nb::class_<MatrixView>(nn, "MatrixView", doc_nn_MatrixView)
 .def(nb::init<>())
 .def("__repr__", &MatrixView::repr)
 .def("__getitem__", &MatrixView::getitem,
@@ -669,12 +669,12 @@ void export_coop_vec(nb::module_ &m) {
 view_type.attr("DRJIT_STRUCT") = drjit_struct;

 nn.def("view", &view,
-doc_coop_view);
+doc_nn_view);

 nn.def("pack", [](nb::handle arg, const char *layout) { return repack("pack", layout, arg); },
 nb::arg(), "layout"_a = "inference",
 nb::sig("def pack(arg: MatrixView | drjit.AnyArray, *, layout: typing.Literal['inference', 'training'] = 'inference') -> typing.Tuple[drjit.ArrayBase, MatrixView]"),
-doc_coop_pack);
+doc_nn_pack);

 nn.def("pack",
 [](nb::args args, const char *layout) {
@@ -692,7 +692,7 @@ void export_coop_vec(nb::module_ &m) {
 nn.def("unpack", [](nb::handle arg) {
 return repack("unpack", nullptr, arg); },
 nb::sig("def unpack(arg: MatrixView | drjit.AnyArray, /) -> typing.Tuple[drjit.ArrayBase, MatrixView]"),
-doc_coop_unpack);
+doc_nn_unpack);

 nn.def("unpack",
 [](nb::args args) {
@@ -710,7 +710,7 @@ void export_coop_vec(nb::module_ &m) {
 "b"_a.noconvert() = nb::none(), "transpose"_a = false,
 nb::sig("def matvec(A: MatrixView, x: drjit.nn.CoopVec[T], b: typing.Optional[MatrixView] = "
 "None, /, transpose: bool = False) -> drjit.nn.CoopVec[T]"),
-doc_coop_matvec);
+doc_nn_matvec);

 nn.def("cast",
 [](CoopVec vec, nb::type_object_t<drjit::ArrayBase> tp) {
@@ -721,7 +721,7 @@ void export_coop_vec(nb::module_ &m) {
 return CoopVec(ad_coop_vec_cast(vec.m_index, (VarType) s.type),
 vec.m_size, new_type);
 }, nb::sig("def cast(arg0: CoopVec[T], arg1: typing.Type[ArrayT], /) -> CoopVec[ArrayT]"),
-doc_coop_cast
+doc_nn_cast
 );

 m.def("fma", &coop_vec_ternary_op<JitOp::Fma>);

src/python/docstr.rst

Lines changed: 19 additions & 16 deletions
@@ -8130,12 +8130,14 @@
 Returns:
 object: The computed array as described above

-.. topic:: coop_CoopVec
+.. topic:: nn_CoopVec

 A *cooperative vector* is a dynamically-sized container of elements of a
 consistent type. It admits both floating point and integer 1D arrays as
 elements (e.g., :py:class:`drjit.cuda.Float16`,
-:py:class:`drjit.llvm.UInt32`).
+:py:class:`drjit.llvm.UInt32`). Cooperative vectors primarily exist to
+enable the compilation of expressions that make use of matrix-vector
+multiplication.

 Seen from a high level, cooperative vectors resemble nested array types,
 such as as :py:class:`drjit.cuda.ArrayXf16`. A variety of conversions
@@ -8177,7 +8179,7 @@
 To unpack a cooperative vector into its components, use an expression
 like ``x, y, z = vec``, ``ArrayXf(vec)``, or ``list(vec)``.

-.. topic:: coop_CoopVec_init
+.. topic:: nn_CoopVec_init

 The constructor accepts a variable number of arguments including Dr.Jit
 arrays, scalar Python integers and floating point values, and :ref:`PyTrees
@@ -8188,7 +8190,7 @@
 the input contains Dr.Jit arrays of inconsistent scalar types (e.g.,
 :py:class:`drjit.cuda.Array2f` and :py:class:`drjit.cuda.UInt`).

-.. topic:: coop_MatrixView
+.. topic:: nn_MatrixView

 The :py:class:`drjit.nn.MatrixView` provides pointer into a buffer along with
 shape and type metadata.
@@ -8203,7 +8205,7 @@
 representation. The returned views can then be passed to
 :py:func:`drjit.nn.matvec()`.

-.. topic:: coop_view
+.. topic:: nn_view

 Convert a Dr.Jit array or tensor into a *view*.

@@ -8221,13 +8223,13 @@
 directly re-packed into optimal layouts without performing further
 unnecessary copies.

-.. topic:: coop_pack
+.. topic:: nn_pack

-A training-optimal layout must be used used if the program
-*backpropagates* (as in :py:func:`dr.backward*() <drjit.backward>`)
-gradients through matrix-vector products. Forward derivative propagation (as
-in :py:func:`dr.forward*() <drjit.forward>`) does not require a
-training-optimal layout.
+A training-optimal layout must be used if the program *backpropagates*
+(as in :py:func:`dr.backward*() <drjit.backward>`) gradients through
+matrix-vector products. Inference (primal evaluation) and forward derivative
+propagation (as in :py:func:`dr.forward*() <drjit.forward>`) do not
+require a training-optimal layout.

 If the input matrices are already packed in a row-major layout, call
 :py:func:`dr.nn.view() <drjit.nn.view>` to create an efficient reference
@@ -8244,7 +8246,7 @@
 mat_view[32:64, :]
 )

-.. topic:: coop_unpack
+.. topic:: nn_unpack

 The function :py:func:`dr.nn.unpack() <drjit.nn.unpack>` transforms a
 sequence (or :ref:`PyTree <pytrees>`) of vectors and optimal-layout matrices
@@ -8255,13 +8257,14 @@
 A_out, b_out = dr.nn.unpack(A_opt, b_opt)

 Note that the output of this function are (row-major) *views* into a shared
-buffer. These views can be converted back into regular tensors:
+buffer. Each view holds a reference to the shared buffer. Views can be
+converted back into regular tensors:

 .. code-block:: python

 A = TensorXf16(A)

-.. topic:: coop_matvec
+.. topic:: nn_matvec

 Evaluate a matrix-vector multiplication involving a cooperative vector.

@@ -8275,9 +8278,9 @@
 + b``). This bias vector ``b`` should also be specified as a view.

 Specify ``tranpose=True`` to multiply by the transpose of the matrix ``A``.
-On the CUDA/OptiX backend, this feature requires that ``A`` is inference
+On the CUDA/OptiX backend, this feature requires that ``A`` is in inference
 or training-optimal layout.

-.. topic:: coop_cast
+.. topic:: nn_cast

 Cast the numeric type underlying a cooperative vector
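A short sketch of the pack/view/unpack round trip these docstrings describe (hypothetical shapes; the ``drjit.cuda.ad.TensorXf16`` type and ``dr.ones`` call are assumptions):

    import drjit as dr
    import drjit.nn as nn
    from drjit.cuda.ad import TensorXf16

    # Row-major weight matrix; view() wraps it without copying
    A = dr.ones(TensorXf16, shape=(64, 32))
    A_view = nn.view(A)

    # Repack into a training-optimal layout (needed when backpropagating
    # through nn.matvec); 'inference' suffices for primal/forward evaluation
    A_buf, A_opt = nn.pack(A_view, layout='training')

    # After training, recover a row-major view and convert it back to a tensor
    _, A_rm = nn.unpack(A_opt)
    A_dense = TensorXf16(A_rm)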

src/python/eval.cpp

Lines changed: 10 additions & 0 deletions
@@ -68,6 +68,16 @@ static void make_opaque(nb::handle h) {

 ad_var_dec_ref(index_new);
 }
+
+void traverse_unknown(nb::handle h) override {
+if (h.type().is(local_type)) {
+Local & local = nb::cast<Local&>(h);
+for (uint32_t index : local.arrays())
+result |= (bool) jit_var_schedule(index);
+}
+if (h.type().is(coop_vector_type))
+nb::raise("Cooperative vectors cannot be evaluated. They must be unpacked into regular variables.");
+}
 };

 ScheduleForceCallback sfc;
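The new ``traverse_unknown()`` hook makes ``dr.schedule()``, ``dr.eval()``, and ``dr.make_opaque()`` raise on cooperative vectors; the supported pattern is to unpack them first. A minimal sketch of the Python-side usage (hypothetical example on the LLVM backend):

    import drjit as dr
    import drjit.nn as nn
    from drjit.llvm import Float16

    vec = nn.CoopVec(Float16(1), Float16(2))

    # dr.eval(vec) now raises "Cooperative vectors cannot be evaluated."
    # Unpack into regular variables before evaluating:
    x, y = vec
    dr.eval(x, y)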

src/python/tracker.cpp

Lines changed: 1 addition & 1 deletion
@@ -333,7 +333,7 @@ bool VariableTracker::Impl::traverse(Context &ctx, nb::handle h) {
 ctx.label.c_str(), nb::inst_name(prev).c_str(),
 nb::type_name(tp).c_str());

-// Were there any external changes to sub-PyTree variable indices (As
+// Were there any external changes to sub-PyTree variable indices (as
 // opposed to changes done by the VariableTracker)
 bool changed = false;


tests/test_coop_vec.py

Lines changed: 7 additions & 2 deletions
@@ -6,7 +6,7 @@
 def skip_if_coopvec_not_supported(t):
 if dr.backend_v(t) == dr.JitBackend.CUDA:
 if dr.detail.cuda_version() < (12, 8):
-pytest.skip("CUDA driver does not support cooperative vectors")
+pytest.skip("CUDA driver does not support cooperative vectors (driver R570 or later is required)")

 @pytest.test_arrays('jit,float16,shape=(3, *),-diff', 'jit,float32,shape=(3, *),-diff')
 def test01_pack_unpack(t):
@@ -20,6 +20,7 @@ def test01_pack_unpack(t):
 assert len(nn.CoopVec(*x, 2, (4, 5), *x)) == 19
 y = list(x)
 z = m.ArrayXf(x)
+assert len(y) == 8 and len(z) == 8
 result_ok = True
 for i in range(8):
 result_ok &= dr.all(y[i] == i+1)
@@ -258,7 +259,7 @@ def test10_fwd_addition(t):
 def test11_bwd_mul(t):
 skip_if_coopvec_not_supported(t)

-# Propagate forward gradients through an addition
+# Propagate forward gradients through a multiplication
 a, b = t(8), t(9)
 c, d = t(3), t(2)
 dr.enable_grad(a, b, c, d)
@@ -523,5 +524,9 @@ def test19_no_eval(t):
 # Cooperative vectors cannot be evaluted via dr.eval()
 UInt32 = dr.uint32_array_t(t)
 a = nn.CoopVec(t(1), t(2))
+with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"):
+dr.schedule(a)
 with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"):
 dr.eval(a)
+with pytest.raises(RuntimeError, match="Cooperative vectors cannot be evaluated"):
+dr.make_opaque(a)
