Skip to content

Commit 2d46aae

Browse files
committed
fix ad_coop_vec_pack() derivative
1 parent 58d45e5 commit 2d46aae

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/extra/autodiff.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3808,7 +3808,7 @@ class CoopVecPack : public dr::detail::CustomOpBase {
38083808

38093809
void forward() override {
38103810
std::lock_guard<Lock> guard(state.lock);
3811-
uint32_t size = (uint32_t) m_input_indices.size();
3811+
uint32_t size = (uint32_t) m_inputs.size();
38123812
JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * size);
38133813
size_t n_valid = 0;
38143814

@@ -3835,7 +3835,7 @@ class CoopVecPack : public dr::detail::CustomOpBase {
38353835

38363836
void backward() override {
38373837
std::lock_guard<Lock> guard(state.lock);
3838-
uint32_t n = (uint32_t) m_input_indices.size();
3838+
uint32_t n = (uint32_t) m_inputs.size();
38393839

38403840
Variable *v = state[m_output_indices[0]];
38413841
if (!v->grad.valid())
@@ -3844,9 +3844,13 @@ class CoopVecPack : public dr::detail::CustomOpBase {
38443844
JitIndex *tmp = (JitIndex *) alloca(sizeof(JitIndex) * n);
38453845
jit_coop_vec_unpack(v->grad.index(), n, tmp);
38463846

3847-
for (size_t i = 0; i < m_input_indices.size(); ++i) {
3848-
Variable *v2 = state[m_inputs[i]];
3849-
v2->accum(JitVar::steal(tmp[i]), v2->size);
3847+
for (size_t i = 0; i < m_inputs.size(); ++i) {
3848+
uint32_t index = m_inputs[i];
3849+
JitVar var = JitVar::steal(tmp[i]);
3850+
if (!index)
3851+
continue;
3852+
Variable *v2 = state[index];
3853+
v2->accum(var, v2->size);
38503854
}
38513855
}
38523856

0 commit comments

Comments
 (0)