Cooperative Vector API #384
Conversation
Force-pushed from 7c65d4b to cd67909
Force-pushed from 89499a0 to 0247195
Review part 1 (three big files left to review)
a, b = t(1), t(2)
dr.enable_grad(a, b)
z = nn.CoopVec(a, b)  # pack
assert dr.grad_enabled(z)
Can also test that dr.enable_grad(z) raises as expected.
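A minimal sketch of that check (hypothetical test code; the exact exception type raised by Dr.Jit is an assumption):

```python
# Hypothetical test following the suggestion above; mirrors the quoted test and
# assumes that enabling gradients on an already-packed CoopVec is rejected.
# 't' stands for the differentiable array type under test.
import pytest
import drjit as dr
from drjit import nn

def test_enable_grad_on_packed_coopvec(t):
    a, b = t(1), t(2)
    dr.enable_grad(a, b)
    z = nn.CoopVec(a, b)  # pack
    assert dr.grad_enabled(z)
    with pytest.raises(RuntimeError):  # exception type is an assumption
        dr.enable_grad(z)
```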
dr.schedule(x.grad, y.grad)
assert x.grad == 4
assert y.grad == 5
assert dr.grad_enabled(z)
Can also test dr.detach(z)
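A sketch of that dr.detach() check, extending the quoted test (hypothetical; assumes dr.detach() returns a detached copy of the cooperative vector):

```python
# Continuation of the test above (sketch): detaching should drop gradient
# tracking on the copy while leaving the original CoopVec differentiable.
zd = dr.detach(z)
assert not dr.grad_enabled(zd)
assert dr.grad_enabled(z)
```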
    z + 3
)
b = nn.cast(a, dr.float32_array_t(t))
c = nn.cast(b, dr.float16_array_t(t))
Test grad enabled / disabled and grad propagation through casts?
Ideally, gradients would just be converted to the new precision. It's an important use case to have most of a differentiable pipeline in fp32 and locally convert to fp16 for the MLP.
One thing that will come up when we add the hash grid encoding, but good to keep in mind in general: atomic addition of …
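A sketch of the grad-propagation-through-casts test suggested above (hypothetical code; it assumes nn.cast() converts gradients along with the values, which is exactly the behavior being asked for):

```python
# Hypothetical test: gradients should survive an fp16 -> fp32 -> fp16 round
# trip through nn.cast(). 't' is an fp16 differentiable array type under test.
import drjit as dr
from drjit import nn

def test_cast_propagates_gradients(t):
    x, y = t(1), t(2)
    dr.enable_grad(x, y)
    a = nn.CoopVec(x, y)                   # pack
    b = nn.cast(a, dr.float32_array_t(t))  # up-cast to fp32
    c = nn.cast(b, dr.float16_array_t(t))  # and back to fp16
    u, v = c                               # unpack (assumed to support iteration)
    dr.backward(u * 4 + v * 5)
    assert x.grad == 4 and y.grad == 5     # derivatives converted across the casts
```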
Force-pushed from 2d46aae to f80af5a
Cooperative vectors enable efficient compilation and evaluation of expressions involving matrix multiplication. They cater to a specific use case where each execution thread performs a sequence of independent multiplications by reasonably small matrices (e.g., 64x64). This enables the fully fused evaluation of small multilayer perceptrons within a larger program. That said, the feature isn't specific to MLPs and could also be used in other ways. On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX cooperative vector API, leveraging the built-in tensor cores for acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative vector operations using available instruction set extensions (AVX512, NEON, etc.). For further details on this new API and how to use it, refer to the documentation in ``docs/coop_vec.rst``.
This commit improves the handling of evaluated loops with grad-enabled state variables. Previously, the AD variable ID of each differentiable state variable changed in every iteration, even if the loop did not touch that variable. This is an implementation detail of the loop evaluation code that should not leak into user code. This commit fixes that behavior.
This commit fixes bugs in the compilation of reverse-mode derivatives of simple loops (i.e., loops with max_iterations==-1) and updates the test suite to cover problematic cases.
This commit fixes bugs and adds tests to ensure that matrix multiplication can be correctly differentiated in reverse mode when it occurs inside a "simple" loop (i.e., a loop with max_iterations==-1).
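A hedged sketch of the pattern these fixes target: reverse-mode differentiation of a cooperative-vector matrix product inside a simple evaluated loop. Everything except the public drjit/drjit.nn names shown elsewhere in this PR is illustrative, and exact signatures may differ.

```python
import drjit as dr
from drjit import nn
from drjit.auto.ad import Float16, TensorXf16, UInt32  # backend choice is arbitrary

# Small weight matrix, packed into a training-friendly layout
A = dr.full(TensorXf16, 0.1, (4, 4))
buffer, A_view = nn.pack(A, layout='training')
dr.enable_grad(buffer)

x = nn.CoopVec(Float16(1), Float16(2), Float16(3), Float16(4))

# "Simple" loop (max_iterations=-1) repeatedly applying the matrix
i, y = dr.while_loop(
    state=(UInt32(0), x),
    cond=lambda i, v: i < 3,
    body=lambda i, v: (i + 1, nn.matvec(A_view, v)),
    max_iterations=-1)

dr.backward(sum(list(y)))  # reverse-mode derivative flows back into 'buffer'
```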
> Dr.Jit-Core always generates the f16x2 assembly operation, even when only scattering a single …

Right now, packet atomics are ignored by the CUDA backend. I think that Blackwell is the first consumer architecture that really supports these besides the f16x2 special case. In any case, such changes are out of scope for this already very big PR.
Went over the remaining files! Not many more comments this time.
I didn't really understand the changes in loop.cpp, but I trust that they make sense :)
Every Dr.Jit array or tensor will be replaced by a
:py:class:`drjit.nn.MatrixView`, which is a thin pointer into a shared buffer
annotated with layout and type metadata. The function can generate optimal
memory layouts for either *inference* (the default) and *training*. You must
Suggested change:
- memory layouts for either *inference* (the default) and *training*. You must
+ memory layouts for either *inference* (the default) or *training*. You must
Following this step, ``A`` and ``b`` have been merged into ``buffer``, and
``A_view`` and ``b_view`` encode the offset and layout within this larger
buffer. Matrix views *cannot* be used in arithmetic expressions and are best
Add a paragraph showing how to perform e.g. L2 regularization on the weights and/or weight clipping, given this limitation?
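One possible way to phrase it, sketched under the assumption that gradients of the packed parameters accumulate in the flat buffer returned by nn.pack(): regularizers can operate on ``buffer`` directly instead of on the views.

```python
# Hedged sketch: MatrixView objects cannot appear in arithmetic, but the
# underlying grad-enabled 'buffer' can, so an L2 penalty over all packed
# parameters is simply a reduction over that buffer. 'loss' is hypothetical.
weight_decay = 1e-4
loss = loss + weight_decay * dr.sum(dr.square(buffer), axis=None)
dr.backward(loss)  # the penalty's gradient accumulates into buffer.grad
```

Weight clipping could presumably be handled the same way, by updating the contents of ``buffer`` in place so that the existing views remain valid.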
buffer, A_view, b_view = nn.pack(A, b)
dr.enable_grad(buffer)

# Pack grad-enabled variables into a cooperative vector
If I remember correctly, at some point we had discussed that users should not try to enable gradients directly on the packed CoopVec(). The fact that enable_grad() is supported (noted below) makes this a bit ambiguous, so maybe it's worth saying explicitly:
Suggested change:
- # Pack grad-enabled variables into a cooperative vector
+ # Pack grad-enabled variables into a cooperative vector.
+ # Note that gradients were enabled on the components; it is not recommended
+ # to enable grads on the cooperative vector object after creation.
- Tensor cores work with 8x8 and 16x16 blocks. Matrices whose row or column
  counts are not multiples of 8 or 16 will be zero-padded internally. There
  is no performance benefit in working with such intermediate sizes.
Could repeat the point here about unpacking & repacking coopvecs:
Unpacking cooperative vectors may degrade performance. It is best to keep them in their opaque layout whenever possible.
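A small sketch of what that advice could look like in the docs (hypothetical layer names; assumes elementwise operations such as dr.maximum apply directly to cooperative vectors):

```python
# Keep intermediate results in the opaque cooperative-vector layout between
# layers; unpack only once at the very end. 'W*_view'/'b*_view' are
# hypothetical views produced by nn.pack().
h = nn.matvec(W1_view, x, b1_view)  # stays in the opaque layout
h = dr.maximum(h, 0)                # ReLU applied directly to the CoopVec
y = nn.matvec(W2_view, h, b2_view)
out = list(y)                       # a single unpack at the end
```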
return result;
}

/// Unpack a cooperative vecotr into a Dr.Jit array type like CoopVecXf
Suggested change:
- /// Unpack a cooperative vecotr into a Dr.Jit array type like CoopVecXf
+ /// Unpack a cooperative vector into a Dr.Jit array type like CoopVecXf
Does it really expand to CoopVecXf? It looks like this object is already a CoopVecXf, but maybe I'm missing something.
@@ -3605,7 +3678,7 @@ const char *ad_var_graphviz() {
     " l4 [style=filled fillcolor=yellowgreen label=\"Gradient present\"];\n"
     " l3 [style=filled fillcolor=salmon label=\"Input\"];\n"
     " l2 [style=filled fillcolor=lightblue2 label=\"Output\"];\n"
-    " l1 [style=filled fillcolor=wheat label=\"Labeled\"];\n"
+    " l0 [style=filled fillcolor=wheat label=\"Labeled\"];\n"
l1 is not correct? I thought the naming could be arbitrary here.
JitVar scale = scalar_coop_vec(i0, -2.8853900817779268), // -2/log(2)
       c0 = scalar_coop_vec(i0, 3.98046875),
       c1 = scalar_coop_vec(i0, -7.4140625),
       c2 = scalar_coop_vec(i0, 8.2421875),
       c3 = scalar_coop_vec(i0, -5.1640625),
       c4 = scalar_coop_vec(i0, 1.35546875);

JitVar x0 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, (uint32_t) i0, scale.index())),
       x1 = JitVar::steal(jit_coop_vec_unary_op(JitOp::Exp2, x0.index())),
       y0 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), c4.index(), c3.index())),
       y1 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y0.index(), c2.index())),
       y2 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y1.index(), c1.index())),
       y3 = JitVar::steal(jit_coop_vec_ternary_op(JitOp::Fma, x1.index(), y2.index(), c0.index())),
       y4 = JitVar::steal(jit_coop_vec_binary_op(JitOp::Mul, x1.index(), y3.index()));
This looks fairly heavy; I wonder what's the intended way to compute the adjoint of a tanh activation within the official OptiX CoopVec API.
Maybe I'm missing something, but if we already have the primal value tanh(x) as a CoopVec, wouldn't the adjoint 1 - tanh(x) ** 2 be cheaper?
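For reference, the identity the comment relies on, checked with ordinary Dr.Jit arrays (illustrative only, not the coop-vec code path):

```python
import drjit as dr
from drjit.auto.ad import Float32  # backend choice is arbitrary

x = Float32(0.3)
dr.enable_grad(x)
y = dr.tanh(x)
dr.backward(y)
# d/dx tanh(x) == 1 - tanh(x)^2
assert dr.allclose(x.grad, 1 - dr.square(dr.detach(y)))
```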
Arg(i2, 1.0, coop{}));

default:
    ad_raise("ad_coop_vec_ternary_op(): differentiable version not implemented.");
ad_raise("ad_coop_vec_ternary_op(): differentiable version not implemented."); | |
ad_raise("ad_coop_vec_ternary_op(): differentiable version not implemented for op ... ."); |
return result.release();

default:
    ad_raise("ad_coop_vec_binary_op(): differentiable version not implemented.");
ad_raise("ad_coop_vec_binary_op(): differentiable version not implemented."); | |
ad_raise("ad_coop_vec_binary_op(): differentiable version not implemented for op ... ."); |
def _alloc(self, dtype: Type[drjit.ArrayBase], size: int = -1, /) -> Tuple[Module, int]:
    result = []
    for l in self.layers:
        l_new, size = l._alloc(dtype, size)
I thought there was a Python convention that arguments with a default value should be passed as kwarg, but it may not be as widespread as I thought.
This feature adds cooperative vector support to Dr.Jit. Cooperative vectors enable efficient compilation and evaluation of expressions involving matrix multiplication and cater to situations where each execution thread performs a sequence of independent multiplications by reasonably small matrices (e.g., 64x64). This enables the fully fused evaluation of small multilayer perceptrons within a larger program. That said, the feature isn't specific to MLPs and could also be used in other ways.
On NVIDIA GPUs (Turing or newer), cooperative vectors map to the OptiX cooperative vector API, leveraging the built-in tensor cores for acceleration. On the CPU (LLVM) backend, Dr.Jit compiles cooperative vector operations using available instruction set extensions (AVX512, NEON, etc.).
For further details on this new API and how to use it, refer to the documentation in ``docs/coop_vec.rst``.
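As a rough illustration, here is a hedged usage sketch pieced together from the API names appearing in this PR (backend choice, layer size, and exact signatures are assumptions):

```python
import drjit as dr
from drjit import nn
from drjit.auto.ad import Float16, TensorXf16

# A small 16x16 layer: weight matrix and bias stored as tensors
A = dr.full(TensorXf16, 0.01, (16, 16))
b = dr.zeros(TensorXf16, 16)

# Merge both into a single buffer with a backend-optimal memory layout
buffer, A_view, b_view = nn.pack(A, b, layout='inference')

# Pack per-thread inputs into a cooperative vector and evaluate A @ x + b
x = nn.CoopVec(*(Float16(0.5) for _ in range(16)))
y = nn.matvec(A_view, x, b_view)

result = list(y)  # unpack into ordinary Dr.Jit arrays when needed
```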