Skip to content

Shader Execution Reordering #395

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ Miscellaneous operations
.. autofunction:: binary_search
.. autofunction:: make_opaque
.. autofunction:: copy
.. autofunction:: reorder_threads

Just-in-time compilation
------------------------
Expand Down
5 changes: 5 additions & 0 deletions src/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ nanobind_add_module(
tracker.h tracker.cpp
local.h local.cpp
resample.h resample.cpp
reorder.h reorder.cpp

# Backends
scalar.h scalar.cpp
Expand Down Expand Up @@ -141,6 +142,10 @@ if (DRJIT_ENABLE_CUDA)
target_compile_definitions(drjit-python PRIVATE -DDRJIT_ENABLE_CUDA)
endif()

if (DRJIT_ENABLE_OPTIX)
target_compile_definitions(drjit-python PRIVATE -DDRJIT_ENABLE_OPTIX)
endif()

# Disable leak warnings by default in PyPI release builds
if (SKBUILD)
target_compile_definitions(drjit-python PRIVATE -DDRJIT_DISABLE_LEAK_WARNINGS)
Expand Down
57 changes: 57 additions & 0 deletions src/python/docstr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6040,6 +6040,13 @@
Note that this information can also be queried in a more fine-grained
manner (per variable) using the :py:attr:`drjit.ArrayBase.state` field.

.. topic:: JitFlag_ShaderExecutionReordering

Enable OptiX's SER feature in ray tracing functions and in
:py:func:`reorder_threads`. This flag only applies to the CUDA backend.

This flag is *enabled* by default.

.. topic:: JitFlag_Default

The default set of optimization flags consisting of
Expand All @@ -6056,6 +6063,7 @@
- :py:attr:`drjit.JitFlag.ReuseIndices`, and
- :py:attr:`drjit.JitFlag.ScatterReduceLocal`.
- :py:attr:`drjit.JitFlag.PacketOps`.
- :py:attr:`drjit.JitFlag.ShaderExecutionReordering`.

.. topic:: JitFlag_LoopRecord

Expand Down Expand Up @@ -8104,3 +8112,52 @@
.. topic:: leak_warnings

Query whether leak warnings are enabled. See :py:func:`drjit.detail.set_leak_warnings()`.

.. topic:: reorder_threads

Trigger a call to the Shader Execution Reordering (SER) feature of the GPU.

This function performs a hardware-assisted shuffle of the GPU threads to
improve the kernel occupancy by reducing warp-level divergence in certain
workloads. In order to perorm this shuffle, it requires a sorting key to
indicate which threads should be grouped into coherent warps.

An extra ``value`` argument must be passed to the function. This argument
will be returned as is but internally Dr.Jit will add some tracking to it to
guarantee that, on any subsequent use of it, a reordering operation will be
inserted in the kernel.

Example usage:

.. code-block:: python

arg = dr.cuda.Array3f(...)
key = dr.cuda.UInt32(...) % 4

# Reorder threads before `dr.switch()` to reduce divergence
# Only do it if `arg` is used
arg = dr.reorder_threads(key, 2, arg)

callables = [...]
callable_idx = dr.cuda.UInt32(...)
out = dr.switch(callable_idx, callables, arg)

When :py:attr:`drjit.JitFlag.ShaderExecutionReordering` is **not** set, or
when using the LLVM backend, this operation is a no-op.

Args:
key (drjit.ArrayBase): A 1D unsigned integer 32-bit array that serves as
a sorting key for the shuffle operation. Only the lower ``num_bits``
are used.

num_bits (int): Number of bits from the key to use (starting from the
least signifcant bit). It is recommended to use as few as possible. At
most, 16 bits can be used.

value (object): An arbitrary Dr.Jit array, tensor, or :ref:`PyTree <pytrees>`.
This argument is returned without modification. The reordering will
only happen if the returned version of this arugment is used.

Returns:
object: The updated ``value`` variable that will trigger the reordering
if used.
3 changes: 3 additions & 0 deletions src/python/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "tracker.h"
#include "local.h"
#include "resample.h"
#include "reorder.h"

static int active_backend = -1;

Expand Down Expand Up @@ -106,6 +107,7 @@ NB_MODULE(_drjit_ext, m_) {
.value("ScatterReduceLocal", JitFlag::ScatterReduceLocal, doc_JitFlag_ScatterReduceLocal)
.value("SymbolicConditionals", JitFlag::SymbolicConditionals, doc_JitFlag_SymbolicConditionals)
.value("SymbolicScope", JitFlag::SymbolicScope, doc_JitFlag_SymbolicScope)
.value("ShaderExecutionReordering", JitFlag::ShaderExecutionReordering, doc_JitFlag_ShaderExecutionReordering)
.value("Default", JitFlag::Default, doc_JitFlag_Default)

// Deprecated aliases
Expand Down Expand Up @@ -250,6 +252,7 @@ NB_MODULE(_drjit_ext, m_) {
export_tracker(detail);
export_local(m);
export_resample(m);
export_reorder(m);

export_scalar(scalar);

Expand Down
60 changes: 60 additions & 0 deletions src/python/reorder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
reorder.cpp -- Bindings for drjit.reorder_threads()

Dr.Jit: A Just-In-Time-Compiler for Differentiable Rendering
Copyright 2023, Realistic Graphics Lab, EPFL.

All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE.txt file.
*/

#include "reorder.h"
#include "detail.h"
#include <drjit/autodiff.h>
#include <drjit-core/optix.h>

nb::object reorder_threads(nb::handle_t<dr::ArrayBase> key, int num_bits,
nb::handle value) {
const ArraySupplement &s_key = supp(key.type());
if (s_key.ndim != 1 || s_key.type != (uint8_t) VarType::UInt32 ||
s_key.backend == (uint8_t) JitBackend::None)
nb::raise("drjit.reorder_threads(): 'key' must be a JIT-compiled 32 "
"bit unsigned integer array (e.g., 'drjit.cuda.UInt32' or "
"'drjit.llvm.ad.UInt32')");

dr::vector<uint64_t> value_indices;
::collect_indices(value, value_indices);
if (value_indices.size() == 0)
nb::raise("drjit.reorder_threads(): 'value' must be a valid PyTree "
"containing at least one JIT-compiled type");

#if defined(DRJIT_ENABLE_OPTIX)
uint32_t n_values = (uint32_t) value_indices.size();

// Extract JIT indices
dr::vector<uint32_t> jit_indices(n_values);
for (size_t i = 0; i < n_values; ++i)
jit_indices[i] = (uint32_t) value_indices[i];

// Create updated values with reordering
dr::detail::index32_vector out_indices(n_values);
jit_optix_reorder(s_key.index(inst_ptr(key)), num_bits, n_values,
jit_indices.data(), out_indices.data());

// Re-combine with AD indices
dr::vector<uint64_t> new_value_indices(n_values);
for (size_t i = 0; i < n_values; ++i) {
uint32_t ad_index = value_indices[i] >> 32;
new_value_indices[i] = (((uint64_t) ad_index) << 32 | ((uint64_t) out_indices[i]));
}

return ::update_indices(value, new_value_indices);
#endif

return nb::borrow(value);
}

void export_reorder(nb::module_ &m) {
m.def("reorder_threads", &reorder_threads, "key"_a, "num_bits"_a, "value"_a,
doc_reorder_threads);
}
16 changes: 16 additions & 0 deletions src/python/reorder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
reorder.h -- Bindings for drjit.reorder_threads()

Dr.Jit: A Just-In-Time-Compiler for Differentiable Rendering
Copyright 2023, Realistic Graphics Lab, EPFL.

All rights reserved. Use of this source code is governed by a
BSD-style license that can be found in the LICENSE.txt file.
*/

#pragma once

#include "common.h"

extern nb::object reorder_threads(nb::handle_t<dr::ArrayBase>, int, nb::handle);
extern void export_reorder(nb::module_ &);
55 changes: 55 additions & 0 deletions tests/test_reorder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import drjit as dr
import pytest

# Just an existence test
@pytest.test_arrays('float32, jit, shape=(*)')
def test01_reorder_switch(t):
UInt32 = dr.uint32_array_t(t)
N = 4

idx = dr.arange(UInt32, N) % 2
arg = dr.arange(t, N)
dr.make_opaque(arg)

def cheap_func(arg):
return arg

def expensive_func(arg):
return arg * 2

result = dr.switch(idx, [cheap_func, expensive_func], arg)
dr.allclose(result, [0, 2, 2, 6])


# Test that reorder is valid inside loops
@pytest.test_arrays('float32, jit, shape=(*)')
@pytest.mark.parametrize('mode', ['evaluated', 'symbolic'])
@dr.syntax(recursive=True)
def test01_reorder_loop(t, mode):
UInt32 = dr.uint32_array_t(t)
N = 4

idx = dr.arange(UInt32, N)
arg = dr.arange(t, N)
dr.make_opaque(arg)

def f(arg):
i = UInt32(0)
while dr.hint(i < 32, mode=mode):
j = UInt32(0)

# Arbitrary aritmetic
while dr.hint(j < 10, mode=mode):
arg = arg + j
j += 1

i = dr.reorder_threads(idx % 32, 2, i)
i += 1

# Early exit one thread per iteraion
i = dr.select(idx % 32 < i, 100, i)

return arg

result = f(arg)
dr.allclose(result, [45, 91, 137, 183])