@@ -3808,7 +3808,7 @@ class CoopVecPack : public dr::detail::CustomOpBase {
3808
3808
3809
3809
void forward () override {
3810
3810
std::lock_guard<Lock> guard (state.lock );
3811
- uint32_t size = (uint32_t ) m_input_indices .size ();
3811
+ uint32_t size = (uint32_t ) m_inputs .size ();
3812
3812
JitIndex *tmp = (JitIndex *) alloca (sizeof (JitIndex) * size);
3813
3813
size_t n_valid = 0 ;
3814
3814
@@ -3835,7 +3835,7 @@ class CoopVecPack : public dr::detail::CustomOpBase {
3835
3835
3836
3836
void backward () override {
3837
3837
std::lock_guard<Lock> guard (state.lock );
3838
- uint32_t n = (uint32_t ) m_input_indices .size ();
3838
+ uint32_t n = (uint32_t ) m_inputs .size ();
3839
3839
3840
3840
Variable *v = state[m_output_indices[0 ]];
3841
3841
if (!v->grad .valid ())
@@ -3844,9 +3844,13 @@ class CoopVecPack : public dr::detail::CustomOpBase {
3844
3844
JitIndex *tmp = (JitIndex *) alloca (sizeof (JitIndex) * n);
3845
3845
jit_coop_vec_unpack (v->grad .index (), n, tmp);
3846
3846
3847
- for (size_t i = 0 ; i < m_input_indices.size (); ++i) {
3848
- Variable *v2 = state[m_inputs[i]];
3849
- v2->accum (JitVar::steal (tmp[i]), v2->size );
3847
+ for (size_t i = 0 ; i < m_inputs.size (); ++i) {
3848
+ uint32_t index = m_inputs[i];
3849
+ JitVar var = JitVar::steal (tmp[i]);
3850
+ if (!index )
3851
+ continue ;
3852
+ Variable *v2 = state[index ];
3853
+ v2->accum (var, v2->size );
3850
3854
}
3851
3855
}
3852
3856
0 commit comments