
Commit 44dcf38

Wenzel Jakob authored and committed
Cooperative Vector API
Cooperative vectors enable efficient compilation and evaluation of expressions involving matrix multiplication. They cater to a specific use case where each execution thread performs a sequence of independent multiplications by reasonably small matrices (e.g., 64x64). This enables the fully fused evaluation of small multilayer perceptrons within a larger program. That said, the feature isn't specific to MLPs and could also be used in other ways. On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX cooperative vector API, leveraging built-in tensor cores for acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative vector operations using available instruction set extensions (AVX512, NEON, etc.). For further details on this new API and how to use it, refer to the documentation in ``docs/coop_vec.rst``.
1 parent dcb5217 commit 44dcf38

33 files changed: +3504 −145 lines

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -108,7 +108,7 @@ if (DRJIT_ENABLE_JIT)
     set_target_properties(nanothread PROPERTIES ${DRJIT_OUTPUT_DIRECTORY})
 endif()
 
-mark_as_advanced(NANOTHREAD_ENABLE_TESTS)
+mark_as_advanced(NANOTHREAD_ENABLE_TESTS NANOTHREAD_STATIC)
 mark_as_advanced(DRJIT_CORE_ENABLE_TESTS)
 mark_as_advanced(NB_TEST NB_TEST_SHARED_BUILD NB_TEST_STABLE_ABI NB_USE_SUBMODULE_DEPS NB_TEST_SANITZE NB_CREATE_INSTALL_RULES nanobind_DIR)
 mark_as_advanced(NB_TEST_CUDA NB_TEST_FREE_THREADED NB_TEST_SANITIZERS_ASAN NB_TEST_SANITIZERS_TSAN NB_TEST_SANITIZERS_UBSAN)

docs/autodiff.rst

Lines changed: 2 additions & 2 deletions
@@ -427,8 +427,8 @@ Dr.Jit how a particular operation should be differentiated. Reasons for this
 may include:
 
 - The automatic differentiation backend cannot keep track of computation
-  performed outside of Dr.Jit (e.g. using a highly optimized :ref:`CUDA kernel
-  <custom-cuda>`). In this case, review the section on :ref:`interoperability
+  performed outside of Dr.Jit (e.g. using custom CUDA kernels). In this case,
+  review the section on :ref:`interoperability
   <interop>`, since it presents a potentially simpler solution.
 
 - The derivative may admit a simplified analytic expression that is superior to

docs/changelog.rst

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ Here is what's new:
 
 
 ⚠️ Compatibility ⚠️
-------------------
+-------------------
 
 - **Symbolic loop syntax**: the old "recorded loop" syntax is no longer
   supported. Existing code will need adjustments to use

docs/coop_vec.rst

Lines changed: 316 additions & 0 deletions
@@ -0,0 +1,316 @@
.. py:currentmodule:: drjit

.. cpp:namespace:: drjit

.. _coop_vec:

Cooperative vectors
===================

*Cooperative vectors* are a `new API
<https://github.com/KhronosGroup/GLSL/blob/main/extensions/nv/GLSL_NV_cooperative_vector.txt>`__
for evaluating matrix-vector products in certain types of GPU workloads. They
are designed to handle cases where each thread of a parallel program needs to
multiply a vector by a reasonably small matrix (e.g., 64x64 or fewer entries).
By working together, the threads can perform these multiplications more
efficiently, which is why the approach is called *cooperative*.

Cooperative vectors are especially useful for evaluating small `multilayer
perceptrons <https://en.wikipedia.org/wiki/Multilayer_perceptron>`__ (MLPs)
within larger programs while fully *fusing* all steps of the process into a
single kernel. Other workloads that heavily rely on matrix-vector products may
benefit as well.

Dr.Jit supports cooperative vectors on both of its backends:

- On **NVIDIA GPUs (Turing or newer)**, cooperative vectors map to the OptiX
  `cooperative vector API
  <https://raytracing-docs.nvidia.com/optix9/guide/index.html#cooperative_vectors#neural-rendering-with-cooperative-vectors>`__,
  leveraging built-in `tensor cores
  <https://www.nvidia.com/en-us/data-center/tensor-cores/>`__ for acceleration.

- On the **CPU (LLVM) backend**, compilation of cooperative vector operations
  targets the available instruction set extensions (AVX512, NEON, etc.).

Code snippets in the remainder of this section assume the following import
statements:

.. code-block:: python

   import drjit as dr
   import drjit.nn as nn
   from drjit.auto.ad import Float16, TensorXf16
Motivation
----------

The cooperative vector API is available via the :py:mod:`drjit.nn` submodule.
Below is an example demonstrating how to use it to perform a matrix
multiplication.

.. code-block:: python

   # Matrix shape
   m, n = 3, 16

   # Create a random matrix + offset
   A = dr.normal(TensorXf16, (m, n))
   b = dr.rand(TensorXf16, m)

   # Pack 'A' and 'b' into a buffer with an optimal layout
   buffer, A_view, b_view = nn.pack(A, b)

   # Create a cooperative vector
   v_in = nn.CoopVec(... 16 values ...)

   # Evaluate A @ v_in + b
   v_out = nn.matvec(A_view, v_in, b_view)

   # Unpack the resulting cooperative vector
   x, y, z = v_out

This involves the following steps:

- Initializing matrix data and packing it into an optimized memory layout using
  :py:func:`nn.pack() <drjit.nn.pack>`.

- Constructing a :py:class:`nn.CoopVec` containing the inputs to the matrix
  multiplication.

- Performing one or more matrix-vector multiplications and other arithmetic,
  while keeping the state in cooperative vector form.

- Unpacking the final cooperative vector into regular Dr.Jit arrays.
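For concreteness, here is a complete variant of the example above. The
specific input values and the use of a Python generator expression to build
the 16-component vector are arbitrary choices for illustration:

.. code-block:: python

   import drjit as dr
   import drjit.nn as nn
   from drjit.auto.ad import Float16, TensorXf16

   m, n = 3, 16

   # Random matrix and bias, packed into a single optimal-layout buffer
   A = dr.normal(TensorXf16, (m, n))
   b = dr.rand(TensorXf16, m)
   buffer, A_view, b_view = nn.pack(A, b)

   # Build the 16-component input vector from arbitrary per-thread values
   v_in = nn.CoopVec(Float16(i / n) for i in range(n))

   # Fused evaluation of A @ v_in + b, followed by unpacking the 3 outputs
   v_out = nn.matvec(A_view, v_in, b_view)
   x, y, z = v_out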
Cooperative vectors
-------------------

The central type of this API is the *cooperative vector* class
:py:class:`nn.CoopVec`. This is a dynamically sized vector with uniformly
typed elements.

Unlike regular Dr.Jit arrays (e.g. :py:class:`drjit.cuda.ArrayXf`), cooperative
vectors *do not allow indexed element access*. For example, the following
operation raises an exception:

.. code-block:: pycon

   >>> vec = nn.CoopVec(Float16(1), Float16(2))
   >>> vec[1]
   Traceback (most recent call last):
     File "<stdin>", line 1, in <module>
   TypeError: 'drjit.nn.CoopVec' object is not subscriptable

This restriction exists because the compiler may arbitrarily distribute
cooperative vector components across threads for efficiency. Allowing direct
indexing would interfere with this optimization.

The :py:class:`drjit.nn.CoopVec` constructor accepts an arbitrary sequence
of :ref:`PyTrees <pytrees>` containing Dr.Jit arrays and Python scalars and
flattens them into a cooperative vector:

.. code-block:: python

   vec = nn.CoopVec(  # Construct a 4D vector
       Float16(1),
       3.0,
       Array2f(4, 5)
   )

Use the standard Python unpacking syntax to turn cooperative vectors back into
their components:

.. code-block:: python

   x, y, z = vec       # Unpack a cooperative 3D vector
   x, y, *extra = vec  # Unpack the first 2 components, put the rest into 'extra'

The same syntax can also be used to concatenate vectors:

.. code-block:: python

   vec_3 = nn.CoopVec(*vec_1, *vec_2)

Cooperative vectors can also be converted into nested arrays, tensors, or
Python lists:

.. code-block:: python

   vec_arr = Array3f(vec)
   vec_ten = TensorXf(vec)
   vec_lst = list(vec)

Cooperative vectors are compatible with Dr.Jit's symbolic tracing
infrastructure and may be used as state variables in
:py:func:`drjit.while_loop` and :py:func:`drjit.if_stmt`.
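For example, a cooperative vector can be carried through a symbolic loop. The
following minimal sketch assumes the standard :py:func:`drjit.while_loop`
calling convention (a ``state`` tuple plus ``cond`` and ``body`` callables);
the iteration count and the update rule are arbitrary:

.. code-block:: python

   from drjit.auto.ad import UInt32

   v = nn.CoopVec(Float16(1), Float16(2), Float16(3))

   i, v = dr.while_loop(
       state=(UInt32(0), v),
       cond=lambda i, v: i < 4,           # run four iterations
       body=lambda i, v: (i + 1, v + v),  # double the vector each time
   )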
Arithmetic
^^^^^^^^^^

Cooperative vectors support a restricted set of arithmetic operations:

- Elementary arithmetic operations: ``+``, ``-``, ``*`` (but no division),
- :py:func:`dr.fma() <fma>`,
- :py:func:`dr.minimum() <minimum>`, :py:func:`dr.maximum() <maximum>`,
- :py:func:`dr.log2() <log2>`, :py:func:`dr.exp2() <exp2>`,
- :py:func:`dr.tanh() <tanh>`,
- :py:func:`dr.step() <step>`,
- :py:func:`nn.matvec() <drjit.nn.matvec>`.

These operations directly map to hardware-optimized operations on CUDA/OptiX.
Operations outside of this set can be realized via unpacking/repacking, e.g.:

.. code-block:: python

   x: nn.CoopVec = ...
   y = nn.CoopVec(dr.sin(v) for v in x)

However, this may degrade performance. It is best to keep cooperative vectors
in their opaque layout whenever possible.

Arithmetic operations may mix cooperative vectors and regular Dr.Jit arrays or
Python scalars, which will undergo implicit broadcasting:

.. code-block:: python

   x: nn.CoopVec[dr.cuda.Float16] = ...
   y: dr.cuda.Float16 = ...
   z = dr.maximum(x, 0) + y
.. _matrix_views:

Matrix views
------------

Input matrices and bias vectors should generally be converted into a
hardware-dependent layout to improve performance compared to the default
row-major representation (also, many operations raise exceptions on the
OptiX/CUDA backend when matrices are not in such an optimal layout).

The function :py:func:`nn.pack() <drjit.nn.pack>` performs this conversion and
furthermore packs the data into a shared buffer for optimal efficiency. The
function takes an arbitrary sequence of :ref:`PyTrees <pytrees>` as input and
returns a result with the same structure.

.. code-block:: python

   A: TensorXf = ...
   b: Float = ...
   buffer, A_view, b_view = nn.pack(A, b, layout='inference')

Every Dr.Jit array or tensor will be replaced by a
:py:class:`drjit.nn.MatrixView`, which is a thin pointer into a shared buffer
annotated with layout and type metadata. The function can generate optimal
memory layouts for either *inference* (the default) or *training*. You must
specify ``layout='training'`` if you wish to differentiate the matrix
multiplication in reverse mode.

Following this step, ``A`` and ``b`` have been merged into ``buffer``, and
``A_view`` and ``b_view`` encode the offset and layout within this larger
buffer. Matrix views *cannot* be used in arithmetic expressions and are best
thought of as opaque handles. They only exist to describe the input of the
matrix-vector multiplication operation explained next.

Two other view-related operations can be useful in certain situations; please
see the linked documentation for details.

- :py:func:`drjit.nn.unpack` converts optimal-layout data back into a row-major layout.
- :py:func:`drjit.nn.view` creates row-major views.
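For example, row-major copies of packed parameters can be recovered when
needed (e.g., to serialize trained weights). The snippet below assumes that
:py:func:`nn.unpack() <drjit.nn.unpack>` mirrors the calling convention of
:py:func:`nn.pack() <drjit.nn.pack>`, returning one result per input view:

.. code-block:: python

   # Convert the optimal-layout views back into a row-major representation
   A_rm, b_rm = nn.unpack(A_view, b_view)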
Matrix-vector products
----------------------

The main purpose of cooperative vectors is the matrix-vector multiplication
operation :py:func:`nn.matvec() <drjit.nn.matvec>`:

.. code-block:: python

   y = nn.matvec(A, x, b)  # Compute y = A @ x + b

Here,

- ``A`` and ``b`` are *views* (:py:class:`nn.MatrixView`) created by
  :py:func:`nn.pack() <drjit.nn.pack>` or :py:func:`nn.view()
  <drjit.nn.view>`.
- ``x`` and ``y`` are cooperative vectors. They are interpreted as *column
  vectors*, i.e., ``y = A[:, 0] * x[0] + A[:, 1] * x[1] + ... + b``.
- The ``b`` term is optional.

The function also accepts an optional ``transpose=True`` parameter to compute
:math:`A^Tx + b`.

The standard Python ``A @ x`` and ``A.T @ x`` matrix multiplication syntax
works as well. However, if your computation requires the addition of a ``b``
vector, prefer :py:func:`nn.matvec() <drjit.nn.matvec>` over this syntax, since
it merges both steps into a single operation.
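Both variants look as follows when applied to the views and the cooperative
vector from the earlier examples:

.. code-block:: python

   # Transposed product: y = A.T @ x + b
   y = nn.matvec(A_view, x, b_view, transpose=True)

   # Operator syntax (no bias vector in this form)
   y1 = A_view @ x
   y2 = A_view.T @ x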
Differentiation
---------------

Cooperative vectors support automatic differentiation. Simply pack variables
with tracked gradients into cooperative vectors---the system will then
propagate derivatives through subsequent operations. Here is an example:

.. code-block:: python

   # Differentiable input
   a = Array2f16(...)
   dr.enable_grad(a)

   # Differentiable matrix + bias vector
   buffer, A_view, b_view = nn.pack(A, b, layout='training')
   dr.enable_grad(buffer)

   # Pack grad-enabled variables into a cooperative vector
   x = nn.CoopVec(a)

   # Differentiable matrix-vector multiplication
   y = nn.matvec(A_view, x, b_view)

   r0, r1 = y               # Unpack
   loss = r0**2 + r1**2     # Continue the calculation and ..
   dr.backward_from(loss)   # .. eventually backpropagate

Specific views or cooperative vectors can also be detached via
:py:func:`drjit.detach()` to inhibit gradient propagation, e.g.:

.. code-block:: python

   y = nn.matvec(A_view, dr.detach(x), dr.detach(b_view))

Note that the conversion functions :py:func:`nn.pack() <drjit.nn.pack>` and
:py:func:`nn.unpack() <drjit.nn.unpack>` are *not differentiable*. This is
intentional: to train a neural network, convert the initial coefficient values
into a training-optimal layout and optimize this representation directly. Doing
so is more efficient than changing layouts twice in every optimization step
(once for the weights and once for their derivatives).

The following AD operations recognize :py:class:`nn.CoopVec
<drjit.nn.CoopVec>` and :py:class:`nn.MatrixView <drjit.nn.MatrixView>` objects:

- :py:func:`grad_enabled`, :py:func:`enable_grad`, :py:func:`disable_grad`,
- :py:func:`detach`.
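For instance, continuing the example above:

.. code-block:: python

   # Both calls operate directly on views and cooperative vectors
   print(dr.grad_enabled(A_view, x))  # query whether derivatives are tracked
   dr.disable_grad(b_view)            # stop derivative propagation through 'b'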
Performance considerations
--------------------------

- **CUDA/OptiX** backend:

  - :py:func:`nn.matvec() <drjit.nn.matvec>` currently requires 16-bit
    floating point arguments. FP8 formats may be added in the future.

  - Tensor cores work with 8x8 and 16x16 blocks. Matrices whose row or column
    counts are not multiples of 8 or 16 will be zero-padded internally. There
    is no performance benefit in working with such intermediate sizes.

- **LLVM** backend:

  - There is no difference between row-major and training/inference-optimal
    layouts on the CPU. However, using :py:func:`nn.pack() <drjit.nn.pack>` is
    still recommended, since packing multiple arrays into a shared buffer has
    a small performance benefit.

  - On Intel-compatible processors, using half precision cooperative vectors is
    not recommended. FP16 matrix multiplication requires ``AVX512FP16``, an
    extension not yet available on consumer CPUs as of 2025. Without this
    extension, FP16 computation involves many costly FP16 ↔ FP32 roundtrips.

docs/index.rst

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ public API.
    bench
    cpp
    textures
+   coop_vec
+   nn
    faq
 
 .. toctree::

docs/misc.rst

Lines changed: 1 addition & 1 deletion
@@ -529,7 +529,7 @@ resolve at a later point. So here, we have
 - ``SelfCp``: a forward reference to ``drjit.llvm.ad._Array2fCp`` (more on this shortly),
 - ``ValT``: :py:class:`drjit.llvm.ad.Float`,
 - ``ValCpT``: a forward reference to ``drjit.llvm.ad._FloatCp`` (more on this shortly),
-- ``RedT``: :py:class`drjit.llvm.ad.Float`,
+- ``RedT``: :py:class:`drjit.llvm.ad.Float`,
 - ``PlainT``: :py:class:`drjit.llvm.ad.Array2f`, and
 - ``MaskT``: :py:class:`drjit.llvm.ad.Array2b`.
