Skip to content

Commit 91201ef

Browse files
committed
Prefix scan for various data types with inclusive/exclusive option
This commit improves the existing ``jit_scan()`` function with support for various data types: - int32/uint32 - uint64 - float - double The user can now also specify whether the scan should be inclusive or exclusive. Finally, the commit adds comments to facilitate future modifications of this code.
1 parent 8ecaaf2 commit 91201ef

18 files changed

+10885
-3697
lines changed

include/drjit-core/array.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ Array empty(size_t size) {
364364
: AllocType::HostAsync,
365365
byte_size);
366366
return Array::steal(
367-
jit_var_map_mem(Array::Backend, Array::Type, ptr, size, 1));
367+
jit_var_mem_map(Array::Backend, Array::Type, ptr, size, 1));
368368
}
369369

370370
template <typename Array>

include/drjit-core/jit.h

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,30 +1592,42 @@ extern JIT_EXPORT void jit_memcpy_async(JIT_ENUM JitBackend backend, void *dst,
15921592
*/
15931593
extern JIT_EXPORT void jit_reduce(JIT_ENUM JitBackend backend, JIT_ENUM VarType type,
15941594
JIT_ENUM ReduceOp rtype,
1595-
const void *ptr, uint32_t size, void *out);
1595+
const void *in, uint32_t size, void *out);
15961596

1597-
/**
1598-
* \brief Perform an exclusive scan / prefix sum over an unsigned 32 bit integer
1599-
* array
1597+
/** \brief Compute a prefix sum over the given input array
1598+
*
1599+
* Both exclusive and inclusive variants are supported. If desired, the scan
1600+
* can be performed in-place (i.e., <tt>out == in</tt>). The operation runs
1601+
* asynchronously.
16001602
*
1601-
* If desired, the scan can be performed in-place (i.e. <tt>in == out</tt>).
1602-
* Note that the CUDA implementation will round up \c size to the maximum of
1603-
* the following three values for performance reasons:
1603+
* The operation is currently implemented for the following numeric types:
1604+
* ``VarType::Int32``, ``VarType::UInt32``, ``VarType::UInt64``,
1605+
* ``VarType::Float32``, and ``VarType::Float64``.
16041606
*
1605-
* - the value 4,
1607+
* Note that the CUDA implementation may round \c size up to the maximum of the
1608+
* following three values for performance and implementation-related reasons
1609+
* (the prefix sum uses a tree-based parallelization scheme).
1610+
*
1611+
* - the value 4,
16061612
* - the next highest power of two (when size <= 4096),
16071613
* - the next highest multiple of 2K (when size > 4096),
16081614
*
16091615
* For this reason, the supplied memory regions must be sufficiently large
1610-
* to avoid both out-of-bounds reads and writes. This is not an issue for
1611-
* memory obtained using \ref jit_malloc(), which internally rounds
1612-
* allocations to the next largest power of two and enforces a 64 byte minimum
1613-
* allocation size.
1616+
* to avoid out-of-bounds reads and writes. This is not an issue for memory
1617+
* obtained using \ref jit_malloc(), which internally rounds allocations to the
1618+
* next largest power of two and enforces a 64 byte minimum allocation size.
16141619
*
1615-
* Runs asynchronously.
1620+
* The CUDA backend implementation for *large* numeric types (double precision
1621+
* floats, 64 bit integers) has the following technical limitation: when
1622+
* reducing 64-bit integers, their values must be smaller than 2**62. When
1623+
* reducing double precision arrays, the two least significant mantissa bits
1624+
* are clamped to zero when forwarding the prefix from one 512-wide block to
1625+
* the next (at a very minor loss in accuracy). See the implementation for
1626+
* details on this.
16161627
*/
1617-
extern JIT_EXPORT void jit_scan_u32(JIT_ENUM JitBackend backend, const uint32_t *in,
1618-
uint32_t size, uint32_t *out);
1628+
extern JIT_EXPORT void jit_prefix_sum(JIT_ENUM JitBackend backend,
1629+
JIT_ENUM VarType type, int exclusive,
1630+
const void *in, uint32_t size, void *out);
16191631

16201632
/**
16211633
* \brief Compress a mask into a list of nonzero indices
@@ -1625,7 +1637,7 @@ extern JIT_EXPORT void jit_scan_u32(JIT_ENUM JitBackend backend, const uint32_t
16251637
* indices of nonzero entries to \c out (in increasing order), and it
16261638
* furthermore returns the total number of nonzero mask entries.
16271639
*
1628-
* The internals resemble \ref jit_scan_u32(), and the CUDA implementation may
1640+
* The internals resemble \ref jit_prefix_sum(), and the CUDA implementation may
16291641
* similarly access regions beyond the end of \c in and \c out.
16301642
*
16311643
* This function internally performs a synchronization step.

resources/Makefile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
COMPUTE_CAPABILITY=compute_70
22
CUDA_VER=10.2
3-
NVCC=/usr/local/cuda-$(CUDA_VER)/bin/nvcc -m64 --ptx --expt-relaxed-constexpr
3+
NVCC=/usr/local/cuda-$(CUDA_VER)/bin/nvcc -m64 --ptx --expt-relaxed-constexpr -std=c++14
44

55
all: kernels.h
66

7-
kernels_50.ptx: reduce.cuh scan.cuh compress.cuh mkperm.cuh misc.cuh kernels.cu
7+
kernels_50.ptx: reduce.cuh prefix_sum.cuh compress.cuh mkperm.cuh misc.cuh kernels.cu
88
$(NVCC) --Wno-deprecated-gpu-targets -gencode arch=compute_50,code=compute_50 kernels.cu -o kernels_50.ptx
99

10-
kernels_70.ptx: reduce.cuh scan.cuh compress.cuh mkperm.cuh misc.cuh kernels.cu
10+
kernels_70.ptx: reduce.cuh prefix_sum.cuh compress.cuh mkperm.cuh misc.cuh kernels.cu
1111
$(NVCC) -Wno-deprecated-gpu-targets -gencode arch=compute_70,code=compute_70 kernels.cu -o kernels_70.ptx
1212

1313
kernels.dict:

resources/common.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
#include <limits>
66

77
#define KERNEL extern "C" __global__
8+
#define DEVICE __device__
9+
#define FINLINE __forceinline__
10+
#define WARP_SIZE 32
11+
#define FULL_MASK 0xffffffff
812

913
template <typename T> struct SharedMemory {
1014
__device__ inline static T *get() {

resources/compress.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,16 @@
1010

1111
#include "common.h"
1212

13+
DEVICE FINLINE void store_cg(uint64_t *ptr, uint64_t val) {
14+
asm volatile("st.cg.u64 [%0], %1;" : : "l"(ptr), "l"(val));
15+
}
16+
17+
DEVICE FINLINE uint64_t load_cg(uint64_t *ptr) {
18+
uint64_t retval;
19+
asm volatile("ld.cg.u64 %0, [%1];" : "=l"(retval) : "l"(ptr));
20+
return retval;
21+
}
22+
1323
KERNEL void compress_small(const uint8_t *in, uint32_t *out, uint32_t size, uint32_t *count_out) {
1424
uint32_t *shared = SharedMemory<uint32_t>::get();
1525

0 commit comments

Comments
 (0)