
Commit 1cb66e8

Added new jit_var_scatter_inc operation for stream compaction
This commit adds a new and relatively advanced Dr.Jit operation named ``jit_var_scatter_inc`` that atomically increments a value within a ``uint32``-typed Dr.Jit array. It works just like the standard ``jit_var_scatter`` operation for 32-bit unsigned integer operands, but with a fixed ``value=1`` parameter and ``reduce_op=ReduceOp::Add``.

The main difference is that this variant additionally returns the *old* value of the target array prior to the atomic update, in contrast to the more general scatter-reduction, which just returns ``None``. The operation also supports masking---the return value of masked-out entries is undefined.

This operation is a building block for stream compaction: threads can scatter-increment a global counter to request a spot in an array and then write their result there. The recipe looks as follows:

```python
ctr = UInt32(0)                        # Counter array
active = dr.ones(Bool, len(data_1))    # .. or a more complex condition

my_index = dr.scatter_inc(target=ctr, index=UInt32(0), mask=active)

dr.scatter(
    target=data_compact_1,
    value=data_1,
    index=my_index,
    mask=active
)

dr.scatter(
    target=data_compact_2,
    value=data_2,
    index=my_index,
    mask=active
)
```

When following this approach, be sure to provide the same mask value to the ``dr.scatter_inc()`` and subsequent ``dr.scatter()`` operations.

``dr.scatter_inc()`` exhibits the following unusual behavior compared to normal Dr.Jit operations: the return value references the instantaneous state during a potentially large sequence of atomic operations. This instantaneous state is not reproducible in later kernel evaluations, and Dr.Jit will therefore raise an exception when the computed index is reused across kernels:

```python
my_index = dr.scatter_inc(target=ctr, index=UInt32(0), mask=active)

dr.scatter(
    target=data_compact_1,
    value=data_1,
    index=my_index,
    mask=active
)

dr.eval(data_compact_1)  # Run Kernel #1

dr.scatter(
    target=data_compact_2,
    value=data_2,
    index=my_index,  # <-- oops, reusing my_index in another kernel.
    mask=active      #     This raises an exception.
)
```

To get the above code to work, you will need to evaluate ``my_index`` at the same time to materialize it into a stored (and therefore trivially reproducible) representation. For this to work as intended, ensure that the size of the ``active`` mask matches ``len(data_*)`` and that it is not the trivial ``True`` default mask (otherwise, the evaluated ``my_index`` will be scalar).

```python
dr.eval(data_compact_1, my_index)
```

Such multi-stage evaluation is potentially inefficient and may defeat the purpose of performing stream compaction in the first place. In general, prefer keeping all scatter operations involving the computed index in the same kernel, in which case this issue does not arise.
1 parent a8f95ab commit 1cb66e8

16 files changed (+373, -54 lines)

include/drjit-core/array.h

Lines changed: 6 additions & 1 deletion
```diff
@@ -368,7 +368,7 @@ Array empty(size_t size) {
 }
 
 template <typename Array>
-Array zero(size_t size = 1) {
+Array zeros(size_t size = 1) {
     typename Array::Value value = 0;
     return Array::steal(
         jit_var_literal(Array::Backend, Array::Type, &value, size));
@@ -409,6 +409,11 @@ void scatter_reduce(ReduceOp op, Array &target, const Array &value,
                         index.index(), mask.index(), op));
 }
 
+template <typename Array>
+Array scatter_inc(Array &target, const Array index, const JitArray<Array::Backend, bool> &mask = true) {
+    return Array::steal(jit_var_scatter_inc(target.index_ptr(), index.index(), mask.index()));
+}
+
 template <typename Array, typename Index>
 void scatter_reduce_kahan(Array &target_1, Array &target_2, const Array &value,
                           const JitArray<Array::Backend, Index> &index,
```
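For readers coming from the C++ side, the recipe from the commit message might look roughly as follows when expressed through the `scatter_inc()` helper added above. This is a minimal sketch only: the `CUDAArray<T>` aliases, the scalar-value constructor, and a masked `scatter()` helper analogous to `scatter_reduce()` are assumed to be provided by drjit-core's array.h and are not part of this diff.

```cpp
// Hypothetical C++ counterpart of the Python stream-compaction recipe.
// Type aliases, constructors, and scatter() are assumptions, not part of the diff.
#include <drjit-core/array.h>

using Float  = CUDAArray<float>;
using UInt32 = CUDAArray<uint32_t>;
using Mask   = CUDAArray<bool>;

void compact(const Float &data, const Mask &active,
             Float &data_compact, UInt32 &ctr) {
    // Each active entry atomically bumps ctr[0] and receives the counter
    // value *before* its increment, i.e. a unique slot in the output array
    UInt32 slot = scatter_inc(ctr, UInt32(0), active);

    // Write the surviving entries to their reserved slots (masked scatter,
    // assuming a scatter() helper mirroring scatter_reduce() above)
    scatter(data_compact, data, slot, active);
}
```

As in the Python version, the same mask must be passed to both `scatter_inc()` and the subsequent `scatter()`, and both operations should stay within the same kernel.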

include/drjit-core/jit.h

Lines changed: 22 additions & 3 deletions
```diff
@@ -929,9 +929,11 @@ extern JIT_EXPORT uint32_t jit_var_scatter(uint32_t target, uint32_t value,
 /**
  * \brief Schedule a Kahan-compensated floating point atomic scatter-write
  *
- * This operation is just like `jit_var(scatter, ..., ReduceOp::Add)`. The
- * difference is that it simultaneously adds to two different target buffers
- * using the Kahan summation algorithm.
+ * This operation is just like ``jit_var_scatter`` invoked with floating
+ * point operands and reduce_op=ReduceOp::Add.
+ *
+ * The difference is that it simultaneously adds to
+ * two different target buffers using the Kahan summation algorithm.
  *
  * The implementation may overwrite the 'target_1' / 'target_2' pointers
  * if a copy needs to be made (for example, if another variable elsewhere
@@ -943,6 +945,23 @@ extern JIT_EXPORT void jit_var_scatter_reduce_kahan(uint32_t *target_1,
                                                     uint32_t index,
                                                     uint32_t mask);
 
+/**
+ * \brief Atomically increment a counter and return the old value
+ *
+ * This operation is just like ``jit_var_scatter`` invoked with 32-bit unsigned
+ * integer operands, the value ``1``, and reduce_op=ReduceOp::Add.
+ *
+ * The main difference is that this variant returns the *old* value before the
+ * atomic write (in contrast to the more general scatter reduction, where doing
+ * so would be rather complicated).
+ *
+ * This operation is a building block for stream compaction: threads can
+ * scatter-increment a global counter to request a spot in an array.
+ */
+extern JIT_EXPORT uint32_t jit_var_scatter_inc(uint32_t *target,
+                                               uint32_t index,
+                                               uint32_t mask);
+
 /**
  * \brief Create an identical copy of the given variable
  *
```

src/api.cpp

Lines changed: 6 additions & 0 deletions
```diff
@@ -407,6 +407,12 @@ void jit_var_scatter_reduce_kahan(uint32_t *target_1, uint32_t *target_2,
     jitc_var_scatter_reduce_kahan(target_1, target_2, value, index, mask);
 }
 
+uint32_t jit_var_scatter_inc(uint32_t *target, uint32_t index, uint32_t mask) {
+    lock_guard guard(state.lock);
+    return jitc_var_scatter_inc(target, index, mask);
+}
+
+
 uint32_t jit_var_pointer(JitBackend backend, const void *value,
                          uint32_t dep, int write) {
     lock_guard guard(state.lock);
```

src/cuda_eval.cpp

Lines changed: 66 additions & 5 deletions
```diff
@@ -50,10 +50,12 @@
 
 // Forward declarations
 static void jitc_cuda_render_stmt(uint32_t index, const Variable *v);
-static void jitc_cuda_render_var(uint32_t index, const Variable *v);
+static void jitc_cuda_render_var(uint32_t index, Variable *v);
 static void jitc_cuda_render_scatter(const Variable *v, const Variable *ptr,
                                      const Variable *value, const Variable *index,
                                      const Variable *mask);
+static void jitc_cuda_render_scatter_inc(Variable *v, const Variable *ptr,
+                                         const Variable *index, const Variable *mask);
 static void jitc_cuda_render_scatter_kahan(const Variable *v, uint32_t index);
 static void jitc_cuda_render_printf(uint32_t index, const Variable *v,
                                     const Variable *mask);
@@ -157,7 +159,7 @@ void jitc_cuda_assemble(ThreadState *ts, ScheduledGroup group,
 
     for (uint32_t gi = group.start; gi != group.end; ++gi) {
         uint32_t index = schedule[gi].index;
-        const Variable *v = jitc_var(index);
+        Variable *v = jitc_var(index);
         const uint32_t vti = v->type,
                        size = v->size;
         const VarType vt = (VarType) vti;
@@ -311,7 +313,7 @@ void jitc_cuda_assemble_func(const char *name, uint32_t inst_id,
         name, n_regs, n_regs, n_regs, n_regs, n_regs, n_regs, n_regs);
 
     for (ScheduledVariable &sv : schedule) {
-        const Variable *v = jitc_var(sv.index);
+        Variable *v = jitc_var(sv.index);
         const uint32_t vti = v->type;
         const VarType vt = (VarType) vti;
 
@@ -401,7 +403,7 @@ static const char *reduce_op_name[(int) ReduceOp::Count] = {
     "", "add", "mul", "min", "max", "and", "or"
 };
 
-static void jitc_cuda_render_var(uint32_t index, const Variable *v) {
+static void jitc_cuda_render_var(uint32_t index, Variable *v) {
     const char *stmt = nullptr;
     Variable *a0 = v->dep[0] ? jitc_var(v->dep[0]) : nullptr,
              *a1 = v->dep[1] ? jitc_var(v->dep[1]) : nullptr,
@@ -720,6 +722,10 @@ static void jitc_cuda_render_var(uint32_t index, const Variable *v) {
             jitc_cuda_render_scatter(v, a0, a1, a2, a3);
             break;
 
+        case VarKind::ScatterInc:
+            jitc_cuda_render_scatter_inc(v, a0, a1, a2);
+            break;
+
         case VarKind::ScatterKahan:
             jitc_cuda_render_scatter_kahan(v, index);
             break;
@@ -808,7 +814,7 @@ static void jitc_cuda_render_scatter(const Variable *v,
         (jitc_flags() & (uint32_t) JitFlag::AtomicReduceLocal)) {
         fmt("    {\n"
             "        .func reduce_$s_$t(.param .u64 ptr, .param .$t value);\n"
-            "        call reduce_$s_$t, (%rd3, $v);\n"
+            "        call.uni reduce_$s_$t, (%rd3, $v);\n"
            "    }\n",
            op, value, value, op, value, value);
 
@@ -901,6 +907,61 @@ static void jitc_cuda_render_scatter(const Variable *v,
     fmt("\nl_$u_done:\n", v->reg_index);
 }
 
+static void jitc_cuda_render_scatter_inc(Variable *v,
+                                         const Variable *ptr,
+                                         const Variable *index,
+                                         const Variable *mask) {
+    bool index_zero = index->is_literal() && index->literal == 0;
+    bool unmasked = mask->is_literal() && mask->literal == 1;
+
+    fmt_intrinsic(
+        ".func (.param .u32 rv) reduce_inc_u32 (.param .u64 ptr) {\n"
+        "    .reg .pred %p<2>;\n"
+        "    .reg .b32 %r<11>;\n"
+        "    .reg .b64 %rd<2>;\n"
+        "\n"
+        "    ld.param.u64 %rd1, [ptr];\n"
+        "    activemask.b32 %r2;\n"
+        "    mov.u32 %r3, %lanemask_lt;\n"
+        "    and.b32 %r3, %r3, %r2;\n"
+        "    setp.ne.u32 %p1, %r3, 0;\n"
+        "    @%p1 bra L2;\n"
+        "\n"
+        "L1:\n"
+        "    popc.b32 %r4, %r2;\n"
+        "    atom.global.add.u32 %r5, [%rd1], %r4;\n"
+        "\n"
+        "L2:\n"
+        "    popc.b32 %r6, %r3;\n"
+        "    brev.b32 %r7, %r2;\n"
+        "    bfind.shiftamt.u32 %r8, %r7;\n"
+        "    shfl.sync.idx.b32 %r9, %r5, %r8, 31, %r2;\n"
+        "    add.u32 %r10, %r6, %r9;\n"
+        "    st.param.u32 [rv], %r10;\n"
+        "    ret;\n"
+        "}\n");
+
+    if (!unmasked)
+        fmt("    @!$v bra l_$u_done;\n", mask, v->reg_index);
+
+    if (index_zero) {
+        fmt("    mov.u64 %rd3, $v;\n", ptr);
+    } else {
+        fmt("    mad.wide.$t %rd3, $v, 4, $v;\n",
+            index, index, ptr);
+    }
+
+    fmt("    {\n"
+        "        .func (.param .u32 rv) reduce_inc_u32 (.param .u64 ptr);\n"
+        "        call.uni ($v), reduce_inc_u32, (%rd3);\n"
+        "    }\n", v);
+
+    if (!unmasked)
+        fmt("\nl_$u_done:\n", v->reg_index);
+
+    v->consumed = 1;
+}
+
 static void jitc_cuda_render_scatter_kahan(const Variable *v, uint32_t v_index) {
     const Extra &extra = state.extra[v_index];
 
```
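The PTX intrinsic emitted above follows the familiar warp-aggregated atomic pattern: the first active lane performs a single `atom.global.add` of `popc(activemask)` on behalf of the whole warp, the pre-increment counter value is broadcast back with `shfl.sync`, and each lane adds the number of active lanes below it to obtain a unique slot. The following CUDA device-side sketch shows the same idea for illustration only; it is not code from this commit, and the lane-index computation assumes a 1D thread block.

```cuda
#include <cstdint>

// Warp-aggregated atomic increment: one atomic per warp instead of one per lane.
__device__ uint32_t warp_scatter_inc(uint32_t *ctr) {
    uint32_t active = __activemask();                // lanes participating in this call
    uint32_t lane   = threadIdx.x & 31u;             // lane index (1D block assumed)
    uint32_t below  = active & ((1u << lane) - 1u);  // active lanes below this one
    uint32_t leader = __ffs(active) - 1u;            // lowest active lane
    uint32_t base   = 0;
    if (below == 0)                                  // only the leader issues the atomic
        base = atomicAdd(ctr, (uint32_t) __popc(active));
    base = __shfl_sync(active, base, leader);        // broadcast the pre-increment value
    return base + __popc(below);                     // unique slot for this lane
}
```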
src/eval.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -709,6 +709,7 @@ void jitc_eval(ThreadState *ts) {
             v->kind = (uint32_t) VarKind::Data;
             v->data = sv.data;
             v->output_flag = false;
+            v->consumed = false;
         }
 
         if (unlikely(v->extra)) {
```

src/internal.h

Lines changed: 5 additions & 2 deletions
```diff
@@ -88,7 +88,7 @@ enum VarKind : uint32_t {
     Cast, Bitcast,
 
     // Memory-related operations
-    Gather, Scatter, ScatterKahan,
+    Gather, Scatter, ScatterInc, ScatterKahan,
 
     // Specialized nodes for vcalls
     VCallMask, VCallSelf,
@@ -231,8 +231,11 @@ struct Variable {
     /// Is this variable marked as an output?
     uint32_t output_flag : 1;
 
+    /// Consumed bit for operations that should only be executed once
+    uint32_t consumed : 1;
+
     /// Unused for now
-    uint32_t unused_2 : 6;
+    uint32_t unused_2 : 5;
 
     /// Offset of the argument in the list of kernel parameters
     uint32_t param_offset;
```

src/io.cpp

Lines changed: 4 additions & 4 deletions
```diff
@@ -63,8 +63,8 @@ void jitc_lz4_init() {
     jitc_lz4_dict_ready = true;
 }
 
-/* Computes padding to align cache file content to a multiple of sizeof(void*). 
-   This prevents undefiend behavior due to misaligned memory reads/writes. */ 
+/* Computes padding to align cache file content to a multiple of sizeof(void*).
+   This prevents undefiend behavior due to misaligned memory reads/writes. */
 static uint32_t compute_padding(const CacheFileHeader &header) {
     uint32_t padding_size = (header.source_size + header.kernel_size) % sizeof(void *);
     if (padding_size)
@@ -353,7 +353,7 @@ bool jitc_kernel_write(const char *source, uint32_t source_size,
     header.reloc_size = kernel.llvm.n_reloc * sizeof(void *);
 
     uint32_t padding_size = compute_padding(header);
-    uint32_t in_size  = header.source_size + header.kernel_size 
+    uint32_t in_size  = header.source_size + header.kernel_size
                         + padding_size + header.reloc_size,
              out_size = LZ4_compressBound(in_size);
 
@@ -365,7 +365,7 @@ bool jitc_kernel_write(const char *source, uint32_t source_size,
     memset(temp_in + header.source_size + header.kernel_size, 0, padding_size);
 
     if (backend == JitBackend::LLVM) {
-        uintptr_t *reloc_out = (uintptr_t *) (temp_in + header.source_size + 
+        uintptr_t *reloc_out = (uintptr_t *) (temp_in + header.source_size +
                                               header.kernel_size + padding_size);
         for (uint32_t i = 0; i < kernel.llvm.n_reloc; ++i)
             reloc_out[i] = (uintptr_t) kernel.llvm.reloc[i] - (uintptr_t) kernel.data;
```

src/llvm_eval.cpp

Lines changed: 58 additions & 2 deletions
```diff
@@ -86,6 +86,10 @@ static void jitc_llvm_render_scatter(const Variable *v, const Variable *ptr,
                                      const Variable *value, const Variable *index,
                                      const Variable *mask);
 static void jitc_llvm_render_scatter_kahan(const Variable *v, uint32_t index);
+static void jitc_llvm_render_scatter_inc(Variable *v,
+                                         const Variable *ptr,
+                                         const Variable *index,
+                                         const Variable *mask);
 static void jitc_llvm_render_printf(uint32_t index, const Variable *v,
                                     const Variable *mask, const Variable *target);
 static void jitc_llvm_render_trace(uint32_t index, const Variable *v,
@@ -777,6 +781,10 @@ static void jitc_llvm_render_var(uint32_t index, Variable *v) {
             jitc_llvm_render_scatter(v, a0, a1, a2, a3);
             break;
 
+        case VarKind::ScatterInc:
+            jitc_llvm_render_scatter_inc(v, a0, a1, a2);
+            break;
+
         case VarKind::ScatterKahan:
             jitc_llvm_render_scatter_kahan(v, index);
             break;
@@ -948,6 +956,54 @@ static void jitc_llvm_render_scatter(const Variable *v,
     }
 }
 
+static void jitc_llvm_render_scatter_inc(Variable *v,
+                                         const Variable *ptr,
+                                         const Variable *index,
+                                         const Variable *mask) {
+    fmt("    $v_1 = extractelement $V, i32 0\n"
+        "{    $v_2 = bitcast i8* $v to i32*\n"
+        "    $v_3 = getelementptr i32, i32* $v_2, i32 $v_1\n|"
+        "    $v_3 = getelementptr i32, ptr $v, i32 $v_1\n}"
+        "    $v = call $T @reduce_inc_u32({$t*} $v_3, $V)\n",
+        v, index,
+        v, ptr,
+        v, v, v,
+        v, ptr, v,
+        v, v, v, v, mask);
+
+    fmt_intrinsic(
+        "define internal <$w x i32> @reduce_inc_u32({i32*} %ptr, <$w x i1> %active) #0 ${\n"
+        "L0:\n"
+        "    br label %L1\n\n"
+        "L1:\n"
+        "    %index = phi i32 [ 0, %L0 ], [ %index_next, %L1 ]\n"
+        "    %sum = phi i32 [ 0, %L0 ], [ %sum_next, %L1 ]\n"
+        "    %sum_vec = phi <$w x i32> [ undef, %L0 ], [ %sum_vec_next, %L1 ]\n"
+        "    %active_i = extractelement <$w x i1> %active, i32 %index\n"
+        "    %active_u = zext i1 %active_i to i32\n"
+        "    %sum_next = add nuw i32 %sum, %active_u\n"
+        "    %sum_vec_next = insertelement <$w x i32> %sum_vec, i32 %sum, i32 %index\n"
+        "    %index_next = add nuw nsw i32 %index, 1\n"
+        "    %cond_1 = icmp eq i32 %index_next, $w\n"
+        "    br i1 %cond_1, label %L2, label %L1\n\n"
+        "L2:\n"
+        "    %cond_2 = icmp eq i32 %sum_next, 0\n"
+        "    br i1 %cond_2, label %L4, label %L3\n\n"
+        "L3:\n"
+        "    %old_1 = atomicrmw add {i32*} %ptr, i32 %sum_next monotonic\n"
+        "    %old_2 = insertelement <$w x i32> undef, i32 %old_1, i32 0\n"
+        "    %old_3 = shufflevector <$w x i32> %old_2, <$w x i32> undef, <$w x i32> $z\n"
+        "    %sum_vec_final = add <$w x i32> %sum_vec_next, %old_3\n"
+        "    br label %L4;\n\n"
+        "L4:\n"
+        "    %sum_vec_combined = phi <$w x i32> [ %sum_vec_next, %L2 ], [ %sum_vec_final, %L3 ]\n"
+        "    ret <$w x i32> %sum_vec_combined\n"
+        "$}"
+    );
+
+    v->consumed = 1;
+}
+
 static void jitc_llvm_render_scatter_kahan(const Variable *v, uint32_t v_index) {
     const Extra &extra = state.extra[v_index];
     const Variable *ptr_1 = jitc_var(extra.dep[0]),
@@ -1258,7 +1314,7 @@ void jitc_llvm_ray_trace(uint32_t func, uint32_t scene, int shadow_ray,
         jitc_var_inc_ref(id);
     }
 
-    for (uint32_t i = 0; i < (shadow_ray ? 1 : 6); ++i)
+    for (int i = 0; i < (shadow_ray ? 1 : 6); ++i)
         out[i] = jitc_var_new_node_1(JitBackend::LLVM, VarKind::Extract,
                                      i < 3 ? float_type : VarType::UInt32, size,
                                      placeholder, index, jitc_var(index),
@@ -1467,7 +1523,7 @@ static void jitc_llvm_render_trace(uint32_t index, const Variable *v,
 
     offset = (8 * float_size + 4) * width;
 
-    for (uint32_t i = 0; i < (shadow_ray ? 1 : 6); ++i) {
+    for (int i = 0; i < (shadow_ray ? 1 : 6); ++i) {
         VarType vt = (i < 3) ? float_type : VarType::UInt32;
         const char *tname = type_name_llvm[(int) vt];
         uint32_t tsize = type_size[(int) vt];
```
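The LLVM backend cannot rely on warp shuffles, so the `@reduce_inc_u32` helper emitted above instead loops over the SIMD packet once: it builds a per-lane prefix count of active lanes, performs a single `atomicrmw add` for the entire packet (skipped when every lane is masked), and adds the old counter value to each prefix. Below is a scalar C++ model of that computation, for illustration only; the helper name comes from the IR above, everything else is an assumption.

```cpp
#include <atomic>
#include <cstdint>

// Scalar model of the packet-wise counter increment used by the LLVM backend:
// each active lane receives old_counter + (number of active lanes before it).
void reduce_inc_u32_model(std::atomic<uint32_t> &ctr, const bool *active,
                          uint32_t *slot, uint32_t width) {
    uint32_t sum = 0;
    for (uint32_t i = 0; i < width; ++i) {
        slot[i] = sum;                 // prefix count of active lanes before lane i
        if (active[i])
            ++sum;
    }
    if (sum != 0) {                    // skip the atomic if the whole packet is masked
        uint32_t old = ctr.fetch_add(sum, std::memory_order_relaxed);
        for (uint32_t i = 0; i < width; ++i)
            slot[i] += old;            // values of masked-out lanes remain unspecified
    }
}
```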
