.. py:currentmodule:: drjit

.. cpp:namespace:: drjit

.. _coop_vec:

Cooperative vectors
===================

*Cooperative vectors* are a `new API
<https://github.com/KhronosGroup/GLSL/blob/main/extensions/nv/GLSL_NV_cooperative_vector.txt>`__
for evaluating matrix-vector products in certain types of GPU workloads. They
are designed for cases where each thread of a parallel program needs to
multiply a vector by a reasonably small matrix (e.g., 64x64 or fewer entries).
By working together, the threads can perform these multiplications more
efficiently, which is why the approach is called *cooperative*.

Cooperative vectors are especially useful for evaluating small `multilayer
perceptrons <https://en.wikipedia.org/wiki/Multilayer_perceptron>`__ (MLPs)
within larger programs while fully *fusing* all steps of the process into a
single kernel. Other workloads that heavily rely on matrix-vector products may
benefit as well.

Dr.Jit supports cooperative vectors on both of its backends:

- On **NVIDIA GPUs (Turing or newer)**, cooperative vectors map to the OptiX
  `cooperative vector API
  <https://raytracing-docs.nvidia.com/optix9/guide/index.html#cooperative_vectors#neural-rendering-with-cooperative-vectors>`__,
  leveraging built-in `tensor cores
  <https://www.nvidia.com/en-us/data-center/tensor-cores/>`__ for acceleration.

- On the **CPU (LLVM) backend**, compilation of cooperative vector operations
  targets the available instruction set extensions (AVX512, NEON, etc.).

Code snippets in the remainder of this section assume the following import
statements:

.. code-block:: python

    import drjit as dr
    import drjit.nn as nn
    from drjit.auto.ad import Float16, TensorXf16

Motivation
----------

The cooperative vector API is available via the :py:mod:`drjit.nn` submodule.
Below is an example demonstrating how to use it to perform a matrix
multiplication.

.. code-block:: python

    # Matrix shape
    m, n = 3, 16

    # Create a random matrix + offset
    A = dr.normal(TensorXf16, (m, n))
    b = dr.rand(TensorXf16, m)

    # Pack 'A' and 'b' into a buffer with an optimal layout
    buffer, A_view, b_view = nn.pack(A, b)

    # Create a cooperative vector holding the 16 inputs
    v_in = nn.CoopVec(... 16 values ...)

    # Evaluate A @ v_in + b
    v_out = nn.matvec(A_view, v_in, b_view)

    # Unpack the resulting cooperative vector
    x, y, z = v_out

This involves the following steps:

- Initializing matrix data and packing it into an optimized memory layout using
  :py:func:`nn.pack() <drjit.nn.pack>`.

- Constructing a :py:class:`nn.CoopVec` containing the inputs to the matrix
  multiplication.

- Performing one or more matrix-vector multiplications and other arithmetic,
  while keeping the state in cooperative vector form.

- Unpacking the final cooperative vector into regular Dr.Jit arrays.

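For concreteness, here is one way the elided vector construction above could
be filled in. This is a hedged sketch rather than part of the API: the
per-thread array size (``1000``) and the use of :py:func:`dr.rand() <rand>` to
generate inputs are arbitrary choices for illustration.

.. code-block:: python

    # 16 half-precision arrays, one per input dimension (sizes are arbitrary)
    values = [dr.rand(Float16, 1000) for _ in range(n)]

    # Flatten them into a cooperative vector and evaluate A @ v_in + b
    v_in = nn.CoopVec(*values)
    v_out = nn.matvec(A_view, v_in, b_view)

    # Three outputs, since the matrix has m=3 rows
    x, y, z = v_out
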
Cooperative vectors
-------------------

The central type of this API is the *cooperative vector* class
:py:class:`nn.CoopVec`. This is a dynamically sized vector with uniformly
typed elements.

Unlike regular Dr.Jit arrays (e.g. :py:class:`drjit.cuda.ArrayXf`), cooperative
vectors *do not allow indexed element access*. For example, the following
operation raises an exception:

.. code-block:: pycon

    >>> vec = nn.CoopVec(Float16(1), Float16(2))
    >>> vec[1]
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    TypeError: 'drjit.nn.CoopVec' object is not subscriptable

This restriction exists because the compiler may arbitrarily distribute
cooperative vector components across threads for efficiency. Allowing direct
indexing would interfere with this optimization.

The :py:class:`drjit.nn.CoopVec` constructor accepts an arbitrary sequence
of :ref:`PyTrees <pytrees>` containing Dr.Jit arrays and Python scalars and
flattens them into a cooperative vector:

.. code-block:: python

    vec = nn.CoopVec(  # Construct a 4D vector
        Float16(1),
        3.0,
        Array2f(4, 5)
    )

Use the standard Python unpacking syntax to turn cooperative vectors back into
their components:

.. code-block:: python

    x, y, z = vec      # Unpack a cooperative 3D vector
    x, y, *extra = vec # Unpack first 2 components, put rest into 'extra'

The same syntax can also be used to concatenate vectors:

.. code-block:: python

    vec_3 = nn.CoopVec(*vec_1, *vec_2)

Cooperative vectors can also be converted into nested arrays, tensors, or
Python lists:

.. code-block:: python

    vec_arr = Array3f(vec)
    vec_ten = TensorXf(vec)
    vec_lst = list(vec)

Cooperative vectors are compatible with Dr.Jit's symbolic tracing
infrastructure and may be used as state variables in
:py:func:`drjit.while_loop` and :py:func:`drjit.if_stmt`.
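
For instance, the following hedged sketch iterates a cooperative vector inside
a symbolic loop. The loop count, the scaling factor, and the extra ``UInt32``
import are illustrative choices, not requirements of the API.

.. code-block:: python

    from drjit.auto.ad import UInt32

    v = nn.CoopVec(Float16(1), Float16(2))  # 2D cooperative vector
    i = UInt32(0)

    # Repeatedly rescale the vector while keeping it in cooperative form
    i, v = dr.while_loop(
        state=(i, v),
        cond=lambda i, v: i < 10,
        body=lambda i, v: (i + 1, v * 0.5)
    )

    x, y = v  # Unpack the result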

Arithmetic
^^^^^^^^^^

Cooperative vectors support a restricted set of arithmetic operations (a short
example follows the list):

- Elementary arithmetic operations: ``+``, ``-``, ``*`` (but no division),
- :py:func:`dr.fma() <fma>`,
- :py:func:`dr.minimum() <minimum>`, :py:func:`dr.maximum() <maximum>`,
- :py:func:`dr.log2() <log2>`, :py:func:`dr.exp2() <exp2>`,
- :py:func:`dr.tanh() <tanh>`,
- :py:func:`dr.step() <step>`,
- :py:func:`nn.matvec() <drjit.nn.matvec>`.
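
The sketch below chains a few of these operations; the input vector and the
constants are placeholders chosen purely for illustration.

.. code-block:: python

    v: nn.CoopVec = ...                     # e.g., the output of nn.matvec()
    v = dr.fma(v, Float16(2), Float16(1))   # 2*v + 1
    v = dr.maximum(v, 0)                    # ReLU-style clamp
    v = dr.exp2(v)                          # Elementwise 2**v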

These operations directly map to hardware-optimized operations on CUDA/OptiX.
Operations outside of this set can be realized via unpacking/repacking, e.g.:

.. code-block:: python

    x: nn.CoopVec = ...
    y = nn.CoopVec(dr.sin(v) for v in x)

However, this may degrade performance. It is best to keep cooperative vectors
in their opaque layout whenever possible.

Arithmetic operations may mix cooperative vectors and regular Dr.Jit arrays or
Python scalars, which will undergo implicit broadcasting.

.. code-block:: python

    x: nn.CoopVec[dr.cuda.Float16] = ...
    y: dr.cuda.Float16 = ...
    z = dr.maximum(x, 0) + y

.. _matrix_views:

Matrix views
------------

Input matrices and bias vectors should generally be converted into a
hardware-dependent layout to improve performance compared to the default
row-major representation (also, many operations raise exceptions on the
OptiX/CUDA backend when matrices are not in such an optimal layout).

The function :py:func:`nn.pack() <drjit.nn.pack>` performs this conversion and
furthermore packs data into a shared buffer for optimal efficiency. The
function takes an arbitrary sequence of :ref:`PyTrees <pytrees>` as input and
returns views with the same structure, along with the shared buffer that backs
them.

.. code-block:: python

    A: TensorXf = ...
    b: Float = ...
    buffer, A_view, b_view = nn.pack(A, b, layout='inference')

Every Dr.Jit array or tensor will be replaced by a
:py:class:`drjit.nn.MatrixView`, which is a thin pointer into a shared buffer
annotated with layout and type metadata. The function can generate optimal
memory layouts for either *inference* (the default) or *training*. You must
specify ``layout='training'`` if you wish to differentiate matrix
multiplication in reverse mode.
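
For example, the packing call above would change as follows when the packed
parameters are meant to be trained (same inputs, only the ``layout`` argument
differs):

.. code-block:: python

    # Request the training-optimal layout for reverse-mode differentiation
    buffer, A_view, b_view = nn.pack(A, b, layout='training')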

Following this step, ``A`` and ``b`` have been merged into ``buffer``, and
``A_view`` and ``b_view`` encode the offset and layout within this larger
buffer. Matrix views *cannot* be used in arithmetic expressions and are best
thought of as opaque handles. They only exist to describe the input of the
matrix-vector multiplication operation explained next.

Two other view-related operations may be useful in certain situations; please
see the linked documentation for details. A short sketch of the first follows
the list.

- :py:func:`drjit.nn.unpack` converts optimal-layout data back into a row-major layout.
- :py:func:`drjit.nn.view` creates row-major views.
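
The following hedged sketch shows the first of these operations; ``A_view``
and ``b_view`` are the packed views from the earlier snippets, and the result
names are arbitrary.

.. code-block:: python

    # Convert the packed parameters back into a row-major representation,
    # e.g., to inspect or export them after training
    A_out, b_out = nn.unpack(A_view, b_view)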

Matrix-vector products
----------------------

The main purpose of cooperative vectors is the matrix-vector multiplication
operation :py:func:`nn.matvec() <drjit.nn.matvec>`:

.. code-block:: python

    y = nn.matvec(A, x, b)  # Compute y = A @ x + b

Here,

- ``A`` and ``b`` are *views* (:py:class:`nn.MatrixView`) created by
  :py:func:`nn.pack() <drjit.nn.pack>` or :py:func:`nn.view()
  <drjit.nn.view>`.
- ``x`` and ``y`` are cooperative vectors. They are interpreted as *column
  vectors*, i.e., ``y = A[:, 0] * x[0] + A[:, 1] * x[1] + ... + b``.
- The ``b`` term is optional.

The function also accepts an optional ``transpose=True`` parameter to compute
:math:`A^T x + b`.

The standard Python ``A @ x`` and ``A.T @ x`` matrix multiplication syntax
works as well. However, if your computation requires the addition of a ``b``
vector, prefer :py:func:`nn.matvec() <drjit.nn.matvec>` over this syntax, since
it merges both steps into a single operation.
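
The equivalence between these forms is sketched below; ``A_view``, ``x``, and
``b_view`` are assumed to be set up as in the earlier packing example.

.. code-block:: python

    y1 = nn.matvec(A_view, x, b_view)          # A @ x + b (fused)
    y2 = A_view @ x                            # A @ x via operator syntax
    y3 = nn.matvec(A_view, x, transpose=True)  # A.T @ x
    y4 = A_view.T @ x                          # Same, via operator syntax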

Differentiation
---------------

Cooperative vectors support automatic differentiation. Simply pack variables
with tracked gradients into cooperative vectors---the system will then
propagate derivatives through subsequent operations. Here is an example:

.. code-block:: python

    # Differentiable input
    a = Array2f16(...)
    dr.enable_grad(a)

    # Differentiable matrix + bias vector, packed in the training layout
    # (required for reverse-mode differentiation)
    buffer, A_view, b_view = nn.pack(A, b, layout='training')
    dr.enable_grad(buffer)

    # Pack grad-enabled variables into a cooperative vector
    x = nn.CoopVec(a)

    # Differentiable matrix-vector multiplication
    y = nn.matvec(A_view, x, b_view)

    r0, r1 = y               # Unpack
    loss = r0**2 + r1**2     # Continue calculation and ..
    dr.backward_from(loss)   # .. eventually backpropagate

Specific views or cooperative vectors can also be detached via
:py:func:`drjit.detach()` to inhibit gradient propagation, e.g.:

.. code-block:: python

    y = nn.matvec(A_view, dr.detach(x), dr.detach(b_view))

Note that the conversion functions :py:func:`nn.pack() <drjit.nn.pack>` and
:py:func:`nn.unpack() <drjit.nn.unpack>` are *not differentiable*. This is
intentional: to train a neural network, convert the initial coefficient values
into the training-optimal layout and optimize this representation directly.
Doing so is more efficient than changing layouts twice in every optimization
step (once for the weights and once for their derivatives).

The following AD operations recognize :py:class:`nn.CoopVec <drjit.nn.CoopVec>`
and :py:class:`nn.MatrixView <drjit.nn.MatrixView>` objects (a brief example
follows the list):

- :py:func:`grad_enabled`, :py:func:`enable_grad`, :py:func:`disable_grad`.
- :py:func:`detach`.
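
The hedged fragment below illustrates these entry points; it reuses the
``buffer``, ``A_view``, and ``x`` variables from the differentiation example
above.

.. code-block:: python

    dr.enable_grad(buffer)      # Track gradients of the packed parameters
    dr.grad_enabled(A_view, x)  # True if any of the inputs carries gradients
    dr.disable_grad(x)          # Stop tracking the cooperative vector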

Performance considerations
--------------------------

- **CUDA/OptiX** backend:

  - :py:func:`nn.matvec() <drjit.nn.matvec>` currently requires 16-bit
    floating point arguments. FP8 formats may be added in the future.

  - Tensor cores work with 8x8 and 16x16 blocks. Matrices whose row or column
    counts are not multiples of 8 or 16 will be zero-padded internally. There
    is no performance benefit in working with such intermediate sizes.

- **LLVM** backend:

  - There is no difference between row-major and training/inference-optimal
    layouts on the CPU. However, using :py:func:`nn.pack()
    <drjit.nn.pack>` is still recommended, since packing multiple arrays
    into a shared buffer has a small performance benefit.

  - On Intel-compatible processors, using half-precision cooperative vectors is
    not recommended. FP16 matrix multiplication requires ``AVX512FP16``, an
    extension not yet available on consumer CPUs as of 2025. Without this
    extension, FP16 computation involves many costly FP16 ↔ FP32 roundtrips.