Skip to content

Commit 282da88

Browse files
lnuicnjroussel
authored andcommitted
add ad support and fix bug for c++ repeat/tile operations
1 parent b07811c commit 282da88

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

include/drjit/autodiff.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,11 +408,11 @@ struct DRJIT_TRIVIAL_ABI DiffArray
408408
return steal(jit_var_tile(m_index, count));
409409
}
410410

411-
DiffArray repeat_(size_t count) const {
411+
DiffArray repeat_(size_t count, size_t max_size = 0) const {
412412
if constexpr (IsFloat)
413-
return steal(ad_var_repeat(m_index, count));
413+
return steal(ad_var_repeat(m_index, count, max_size));
414414
else
415-
return steal(jit_var_repeat(m_index, count));
415+
return steal(jit_var_repeat(m_index, count, max_size));
416416
}
417417

418418
DiffArray dot_(const DiffArray &a) const { return sum(*this * a); }

include/drjit/extra.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ extern DRJIT_EXTRA_EXPORT uint64_t ad_var_block_reduce(ReduceOp op,
168168
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_tile(uint64_t index, uint32_t count);
169169

170170
/// Repeat values of an array into larger blocks
171-
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_repeat(uint64_t index, uint32_t count);
171+
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_repeat(uint64_t index, uint32_t count,
172+
size_t max_size = 0);
172173

173174
/// Perform a differentiable gather operation. See jit_var_gather for signature.
174175
extern DRJIT_EXTRA_EXPORT uint64_t ad_var_gather(uint64_t source,

src/extra/autodiff.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3031,6 +3031,54 @@ static Index ad_var_memop_remap(Index index, bool input) {
30313031

30323032
// ==========================================================================
30333033

3034+
uint64_t ad_var_tile(Index index, uint32_t count) {
3035+
JitVar result = JitVar::steal(jit_var_tile(jit_index(index), count));
3036+
3037+
if (likely(is_detached(index)))
3038+
return result.release();
3039+
else {
3040+
VarInfo info = jit_set_backend(jit_index(index));
3041+
3042+
JitVar input_size_var = JitVar::steal(jit_var_literal(info.backend, VarType::UInt32, &info.size, 1, 0));
3043+
JitVar offset = JitVar::steal(jit_var_counter(info.backend, result.size()));
3044+
offset = JitVar::steal(jit_var_mod(offset.index(), input_size_var.index()));
3045+
3046+
uint64_t one_u64 = 1;
3047+
JitVar mask = JitVar::steal(jit_var_literal(info.backend, VarType::Bool, &one_u64, 1, 0));
3048+
3049+
return ad_var_new("tile", std::move(result),
3050+
SpecialArg(index, new Gather(
3051+
GenericArray<uint32_t>::borrow(offset.index()),
3052+
JitMask::borrow(mask.index()),
3053+
ReduceMode::Auto)));
3054+
}
3055+
}
3056+
3057+
uint64_t ad_var_repeat(Index index, uint32_t count, size_t max_size) {
3058+
JitVar result = JitVar::steal(jit_var_repeat(jit_index(index), count, max_size));
3059+
3060+
if (likely(is_detached(index)))
3061+
return result.release();
3062+
else {
3063+
VarInfo info = jit_set_backend(jit_index(index));
3064+
3065+
JitVar offset = JitVar::steal(jit_var_counter(info.backend, result.size()));
3066+
JitVar divisor = JitVar::steal(jit_var_literal(info.backend, VarType::UInt32, &count, 1, 0));
3067+
offset = JitVar::steal(jit_var_div(offset.index(), divisor.index()));
3068+
3069+
uint64_t one_u64 = 1;
3070+
JitVar mask = JitVar::steal(jit_var_literal(info.backend, VarType::Bool, &one_u64, 1, 0));
3071+
3072+
return ad_var_new("repeat", std::move(result),
3073+
SpecialArg(index, new Gather(
3074+
GenericArray<uint32_t>::borrow(offset.index()),
3075+
JitMask::borrow(mask.index()),
3076+
ReduceMode::Auto)));
3077+
}
3078+
}
3079+
3080+
// ==========================================================================
3081+
30343082
uint64_t ad_var_gather(Index source, JitIndex offset, JitIndex mask, ReduceMode mode) {
30353083
JitVar result = JitVar::steal(jit_var_gather(jit_index(source), offset, mask));
30363084

0 commit comments

Comments
 (0)