Skip to content

Commit ffaf248

Browse files
committed
Converted ray tracing operation to abstract IR (CUDA part)
1 parent a6ccd72 commit ffaf248

File tree

7 files changed

+165
-222
lines changed

7 files changed

+165
-222
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ mark_as_advanced(NANOTHREAD_ENABLE_TESTS)
7575

7676
if (DRJIT_ENABLE_OPTIX)
7777
set(DRJIT_OPTIX_FILES
78+
src/optix.h
7879
src/optix_api.h
7980
src/optix_api.cpp
80-
src/optix_core.h
8181
src/optix_core.cpp)
8282
endif()
8383

src/cuda_eval.cpp

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,18 @@
4848
buffer.rewind_to(tmpoff); \
4949
} while (0);
5050

51-
// Forward declaration
51+
// Forward declarations
5252
static void jitc_cuda_render_stmt(uint32_t index, const Variable *v);
5353
static void jitc_cuda_render_var(uint32_t index, const Variable *v);
5454
static void jitc_cuda_render_scatter(const Variable *v, const Variable *ptr,
5555
const Variable *value, const Variable *index,
5656
const Variable *mask);
5757
static void jitc_cuda_render_printf(uint32_t index, const Variable *v,
5858
const Variable *mask);
59+
static void jitc_cuda_render_trace(uint32_t index, const Variable *v,
60+
const Variable *valid,
61+
const Variable *pipeline,
62+
const Variable *sbt);
5963

6064
void jitc_cuda_assemble(ThreadState *ts, ScheduledGroup group,
6165
uint32_t n_regs, uint32_t n_params) {
@@ -405,6 +409,9 @@ static void jitc_cuda_render_var(uint32_t index, const Variable *v) {
405409
fmt(" mov.$b $v, $l;\n", v, v, v);
406410
break;
407411

412+
case VarKind::Nop:
413+
break;
414+
408415
case VarKind::Neg:
409416
if (jitc_is_uint(v))
410417
fmt(" neg.s$u $v, $v;\n", type_size[v->type]*8, v, a0);
@@ -741,12 +748,97 @@ static void jitc_cuda_render_var(uint32_t index, const Variable *v) {
741748
fmt(" mov.$t $v, $v_v4.$c;\n", v, v, a0, "xyzw"[v->literal]);
742749
break;
743750

751+
case VarKind::TraceRay:
752+
jitc_cuda_render_trace(index, v, a0, a1, a2);
753+
break;
754+
755+
case VarKind::TraceExtract:
756+
fmt(" mov.u32 $v, %u$u_scratch_$u;\n", v, a0->reg_index, (uint32_t) v->literal);
757+
break;
758+
744759
default:
745760
jitc_fail("jitc_cuda_render_var(): unhandled variable kind \"%s\"!",
746761
var_kind_name[(uint32_t) v->kind]);
747762
}
748763
}
749764

765+
static void jitc_cuda_render_trace(uint32_t index, const Variable *v,
766+
const Variable *valid,
767+
const Variable *pipeline,
768+
const Variable *sbt) {
769+
ThreadState *ts = thread_state(JitBackend::CUDA);
770+
OptixPipelineData *pipeline_p = (OptixPipelineData *) pipeline->literal;
771+
OptixShaderBindingTable *sbt_p = (OptixShaderBindingTable*) sbt->literal;
772+
bool problem = false;
773+
774+
if (ts->optix_pipeline == state.optix_default_pipeline) {
775+
ts->optix_pipeline = pipeline_p;
776+
} else if (ts->optix_pipeline != pipeline_p) {
777+
jitc_log(
778+
Warn,
779+
"jit_eval(): more than one OptiX pipeline was used within a single "
780+
"kernel, which is not supported. Please split your kernel into "
781+
"smaller parts (e.g. using `dr::eval()`). Disabling the ray "
782+
"tracing operation to avoid potential undefined behavior.");
783+
problem = true;
784+
}
785+
786+
if (ts->optix_sbt == state.optix_default_sbt) {
787+
ts->optix_sbt = sbt_p;
788+
} else if (ts->optix_sbt != sbt_p) {
789+
jitc_log(
790+
Warn,
791+
"jit_eval(): more than one OptiX shader binding table was used "
792+
"within a single kernel, which is not supported. Please split your "
793+
"kernel into smaller parts (e.g. using `dr::eval()`). Disabling "
794+
"the ray tracing operation to avoid potential undefined behavior.");
795+
problem = true;
796+
}
797+
798+
Extra &extra = state.extra[index];
799+
uint32_t payload_count = extra.n_dep - 15,
800+
reg = v->reg_index;
801+
802+
fmt(" .reg.u32 %u$u_scratch_<32>;\n", reg);
803+
804+
if (problem) {
805+
for (int i = 0; i < 32; ++i)
806+
fmt(" mov.b32 %u$u_scratch_$u, 0;\n", reg, i);
807+
return;
808+
}
809+
810+
811+
bool masked = !valid->is_literal() || valid->literal != 1;
812+
if (masked)
813+
fmt(" @!$v bra l_masked_$u;\n", valid, reg);
814+
815+
fmt(" .reg.u32 %u$u_payload_type, %u$u_payload_count;\n"
816+
" mov.u32 %u$u_payload_type, 0;\n"
817+
" mov.u32 %u$u_payload_count, $u;\n",
818+
reg, reg, reg, reg, payload_count);
819+
820+
put(" call (");
821+
for (uint32_t i = 0; i < 32; ++i)
822+
fmt("%u$u_scratch_$u$s", reg, i, i + 1 < 32 ? ", " : "");
823+
put("), _optix_trace_typed_32, (");
824+
825+
fmt("%u$u_payload_type, ", reg);
826+
for (uint32_t i = 0; i < 15; ++i)
827+
fmt("$v, ", jitc_var(extra.dep[i]));
828+
829+
fmt("%u$u_payload_count, ", reg);
830+
for (uint32_t i = 15; i < extra.n_dep; ++i)
831+
fmt("$v$s", jitc_var(extra.dep[i]), (i - 15 < 32) ? ", " : "");
832+
833+
for (uint32_t i = payload_count; i < 32; ++i)
834+
fmt("%u$u_scratch_$u$s", reg, i, (i + 1 < 32) ? ", " : "");
835+
836+
put(");\n");
837+
838+
if (masked)
839+
fmt("\nl_masked_$u:\n", reg);
840+
}
841+
750842
static void jitc_cuda_render_scatter(const Variable *v,
751843
const Variable *ptr,
752844
const Variable *value,

src/internal.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,22 @@
2929
#define DRJIT_PTR "<0x%" PRIxPTR ">"
3030

3131
enum VarKind : uint32_t {
32-
/// Invalid node (default)
32+
// Invalid node (default)
3333
Invalid,
3434

35-
/// An evaluated node representing data
35+
// An evaluated node representing data
3636
Data,
3737

38-
/// Legacy string-based IR statement
38+
// Legacy string-based IR statement
3939
Stmt,
4040

41-
/// A literal constant
42-
/// (note: this must be the last enumeration entry before the regular nodes start)
41+
// A literal constant
42+
// (note: this must be the last enumeration entry before the regular nodes start)
4343
Literal,
4444

45+
/// A no-op (generates no code)
46+
Nop,
47+
4548
// Common unary operations
4649
Neg, Not, Sqrt, Abs,
4750

@@ -102,9 +105,15 @@ enum VarKind : uint32_t {
102105
// Load all texels used for bilinear interpolation (CUDA)
103106
TexFetchBilerp,
104107

105-
// Extract a component from a preceding texture lookup (CUDA)
108+
// Extract a component from a prior texture lookup (CUDA)
106109
TexExtract,
107110

111+
// Perform a ray tracing call
112+
TraceRay,
113+
114+
// Extract a result from a prior ray tracing call
115+
TraceExtract,
116+
108117
// Denotes the number of different node types
109118
Count
110119
};

src/optix.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,5 @@ extern void jitc_optix_free(const Kernel &kernel);
5858
extern void jitc_optix_launch(ThreadState *ts, const Kernel &kernel,
5959
uint32_t size, const void *args, uint32_t n_args);
6060

61-
/// Mark a variable as an expression requiring compilation via OptiX
62-
extern void jitc_optix_mark(uint32_t index);
63-
6461
/// Optional: set the desired launch size
6562
extern void jitc_optix_set_launch_size(uint32_t width, uint32_t height, uint32_t samples);

0 commit comments

Comments
 (0)