Skip to content

Cooperative Vector API #384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open

Cooperative Vector API #384

wants to merge 9 commits into from

Conversation

wjakob
Copy link
Member

@wjakob wjakob commented Apr 15, 2025

This feature adds cooperative vector support to Dr.Jit. They enable efficient compilation and evaluation of expressions involving matrix multiplication and cater to situations where each execution thread performs a sequence of independent multiplications by reasonably small matrices (e.g., 64x64). This enables the fully fused evaluation of small multilayer perceptrons within a larger program. That said, the feature isn't specific to MLPs and could also be used in other ways.

On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX cooperative vector API leveraging the builtin tensor core for acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative vector operations using available instruction set extensions (AVX512, NEON, etc.).

For further details on this new API and now to use it, refer to the documentation:

@wjakob wjakob force-pushed the coopvec branch 2 times, most recently from 7c65d4b to cd67909 Compare April 15, 2025 10:18
@wjakob wjakob requested a review from merlinND April 15, 2025 12:41
@wjakob wjakob force-pushed the coopvec branch 2 times, most recently from 89499a0 to 0247195 Compare April 16, 2025 05:07
Copy link
Member

@merlinND merlinND left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review part 1 (three big files left to review)

a, b = t(1), t(2)
dr.enable_grad(a, b)
z = nn.CoopVec(a, b) # pack
assert dr.grad_enabled(z)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can also test that dr.enable_grad(z) raises as expected

dr.schedule(x.grad, y.grad)
assert x.grad == 4
assert y.grad == 5
assert dr.grad_enabled(z)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can also test dr.detach(z)

z + 3
)
b = nn.cast(a, dr.float32_array_t(t))
c = nn.cast(b, dr.float16_array_t(t))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test grad enabled / disabled and grad propagation through casts?
Ideally, gradients would just be converted to the new precision. It's an important use-case to have most of a differentiable pipeline in fp32 and locally convert to fp16 for the MLP.

@merlinND
Copy link
Member

merlinND commented Apr 17, 2025

One thing that will come up when we add the hash grid encoding, but good to keep in mind in general: atomic addition of f16 values is much slower than f16x2. I think it should be fairly easy to add a special case for f16 in the scatter_packet implementation?
(Not really for this PR, just to keep in mind for later)

@wjakob wjakob force-pushed the coopvec branch 2 times, most recently from 2d46aae to f80af5a Compare April 21, 2025 14:44
Wenzel Jakob and others added 7 commits April 22, 2025 07:55
Cooperative vectors enable efficient compilation and evaluation of
expressions involving matrix multiplication. They cater to a specific
use case, where each execution thread performs a sequence of independent
multiplications by reasonably small matrices (e.g., 64x64). This enables
the fully fused evaluation of small multilayer perceptrons within a
larger program. That said, the feature isn't specific to MLPs and could
also be used in other ways.

On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX
cooperative vector API leveraging the builtin tensor core for
acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative
vector operations using available instruction set extensions (AVX512,
NEON, etc.).

For further details on this new API and now to use it, refer to the
documentation in ``docs/coop_vec.rst``.
This commit improves handling of evaluated loops with grad-enabled state
variables. Previously, the AD variable ID of each differentiable state
variable changed in every iteration, even if the loop did not touch that
variable. This is an implementation detail of the loop evaluation code,
that should, however, not leak into user code. This commit fixes this
behavior.
This commit fixes bugs in the compilation of reverse-mode derivatives of
simple loops (i.e, loops with max_iterations==-1) and updates the test
suite to cover problematic cases.
This commit fixes bugs and adds tests to ensure that matrix
multiplication can be correctly differentiated in reverse-mode when it
occurs inside a "simple" loop (i.e., a loop with max_iterations==-1).
@wjakob
Copy link
Member Author

wjakob commented Apr 22, 2025

One thing that will come up when we add the hash grid encoding, but good to keep in mind in general: atomic addition of f16 values is much slower than f16x2. I think it should be fairly easy to add a special case for f16 in the scatter_packet implementation?
(Not really for this PR, just to keep in mind for later)

Dr.Jit-Core always generates the f16x2 assembly operation, even when only scattering a single f16 value. In the case of your hash grid, would it be possible to make use of the f16x2 format to scatter two values at once?

Right now, packet atomics are ignored by the CUDA backend. I think that Blackwell is the first consumer architecture that really supports these besides the f16x2 special case. In any case, such changes are out of scope for this already very big PR.

Copy link
Member

@merlinND merlinND left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Went over the remaining files! Not many more comments this time.

I didn't really understand the changes in loop.cpp, I trust that they make sense :)

Every Dr.Jit array or tensor will be replaced by a
:py:class:`drjit.nn.MatrixView`, which is a thin pointer into a shared buffer
annotated with layout and type metadata. The function can generate optimal
memory layouts for either *inference* (the default) and *training*. You must
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
memory layouts for either *inference* (the default) and *training*. You must
memory layouts for either *inference* (the default) or *training*. You must


Following this step, ``A`` and ``b`` have been merged into ``buffer``, and
``A_view`` and ``b_view`` encode the offset and layout within this larger
buffer. Matrix views *cannot* be used in arithmetic expressions and are best
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a paragraph showing how to perform e.g. L2 regularization on the weights and / or weights clipping, given this limitation?

buffer, A_view, b_view = nn.pack(A, b)
dr.enable_grad(buffer)

# Pack grad-enabled variables into a cooperative vector
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, at some point we had discussed that users should not try to enable gradients directly on the packed CoopVec(). The fact that enable_grad() is supported (noted below) makes this a bit ambiguous, so maybe it's worth saying explicitly:

Suggested change
# Pack grad-enabled variables into a cooperative vector
# Pack grad-enabled variables into a cooperative vector.
# Note that gradients were enabled on the components, it is not recommended
# to enable grads on the cooperative vector object after creation.

- Tensor cores work with 8x8 and 16x16 blocks. Matrices, whose row or column
counts are not a multiples of 8 or 16 will be zero-padded internally. There
is no performance benefit in working with such intermediate sizes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could repeat the point here about unpacking & repacking coopvecs:

Unpacking cooperative vectors may degrade performance. It is best to keep them in their opaque layout whenever possible.

return result;
}

/// Unpack a cooperative vecotr into a Dr.Jit array type like CoopVecXf
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Unpack a cooperative vecotr into a Dr.Jit array type like CoopVecXf
/// Unpack a cooperative vector into a Dr.Jit array type like CoopVecXf

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it really expand to CoopVecXf? It looks like this object is already a CoopVecXf, but maybe I'm missing something.

@@ -3605,7 +3678,7 @@ const char *ad_var_graphviz() {
" l4 [style=filled fillcolor=yellowgreen label=\"Gradient present\"];\n"
" l3 [style=filled fillcolor=salmon label=\"Input\"];\n"
" l2 [style=filled fillcolor=lightblue2 label=\"Output\"];\n"
" l1 [style=filled fillcolor=wheat label=\"Labeled\"];\n"
" l0 [style=filled fillcolor=wheat label=\"Labeled\"];\n"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

l1 is not correct? I thought the naming could be arbitrary here.

Comment on lines +4039 to +4052
JitVar scale = scalar_coop_vec(i0, -2.8853900817779268), // -2/log(2)
c0 = scalar_coop_vec(i0, 3.98046875),
c1 = scalar_coop_vec(i0, -7.4140625),
c2 = scalar_coop_vec(i0, 8.2421875),
c3 = scalar_coop_vec(i0, -5.1640625),
c4 = scalar_coop_vec(i0, 1.35546875);

JitVar x0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, (uint32_t) i0, scale.index())),
x1 = JitVar::steal(jit_coop_vec_unary_op(JitOp::Exp2, x0.index())),
y0 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), c4.index(), c3.index())),
y1 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y0.index(), c2.index())),
y2 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y1.index(), c1.index())),
y3 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y2.index(), c0.index())),
y4 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, x1.index(), y3.index()));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fairly heavy, I wonder what's the intended way to compute the adjoint of a tanh activation within the official OptiX CoopVec API.

Maybe I'm missing something, but if we already have the primal value tanh(x) as a CoopVec, wouldn't the adjoint 1 - tanh(x) ** 2 be cheaper?

Arg(i2, 1.0, coop{}));

default:
ad_raise("ad_coop_vec_ternary_op(): differentiable version not implemented.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ad_raise("ad_coop_vec_ternary_op(): differentiable version not implemented.");
ad_raise("ad_coop_vec_ternary_op(): differentiable version not implemented for op ... .");

return result.release();

default:
ad_raise("ad_coop_vec_binary_op(): differentiable version not implemented.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ad_raise("ad_coop_vec_binary_op(): differentiable version not implemented.");
ad_raise("ad_coop_vec_binary_op(): differentiable version not implemented for op ... .");

def _alloc(self, dtype: Type[drjit.ArrayBase], size: int = -1, /) -> Tuple[Module, int]:
result = []
for l in self.layers:
l_new, size = l._alloc(dtype, size)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought there was a Python convention that arguments with a default value should be passed as kwarg, but it may not be as widespread as I thought.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants