|
48 | 48 | buffer.rewind_to(tmpoff); \
|
49 | 49 | } while (0);
|
50 | 50 |
|
51 |
| -// Forward declaration |
| 51 | +// Forward declarations |
52 | 52 | static void jitc_cuda_render_stmt(uint32_t index, const Variable *v);
|
53 | 53 | static void jitc_cuda_render_var(uint32_t index, const Variable *v);
|
54 | 54 | static void jitc_cuda_render_scatter(const Variable *v, const Variable *ptr,
|
55 | 55 | const Variable *value, const Variable *index,
|
56 | 56 | const Variable *mask);
|
57 | 57 | static void jitc_cuda_render_printf(uint32_t index, const Variable *v,
|
58 | 58 | const Variable *mask);
|
| 59 | +static void jitc_cuda_render_trace(uint32_t index, const Variable *v, |
| 60 | + const Variable *valid, |
| 61 | + const Variable *pipeline, |
| 62 | + const Variable *sbt); |
59 | 63 |
|
60 | 64 | void jitc_cuda_assemble(ThreadState *ts, ScheduledGroup group,
|
61 | 65 | uint32_t n_regs, uint32_t n_params) {
|
@@ -405,6 +409,9 @@ static void jitc_cuda_render_var(uint32_t index, const Variable *v) {
|
405 | 409 | fmt(" mov.$b $v, $l;\n", v, v, v);
|
406 | 410 | break;
|
407 | 411 |
|
| 412 | + case VarKind::Nop: |
| 413 | + break; |
| 414 | + |
408 | 415 | case VarKind::Neg:
|
409 | 416 | if (jitc_is_uint(v))
|
410 | 417 | fmt(" neg.s$u $v, $v;\n", type_size[v->type]*8, v, a0);
|
@@ -741,12 +748,97 @@ static void jitc_cuda_render_var(uint32_t index, const Variable *v) {
|
741 | 748 | fmt(" mov.$t $v, $v_v4.$c;\n", v, v, a0, "xyzw"[v->literal]);
|
742 | 749 | break;
|
743 | 750 |
|
| 751 | + case VarKind::TraceRay: |
| 752 | + jitc_cuda_render_trace(index, v, a0, a1, a2); |
| 753 | + break; |
| 754 | + |
| 755 | + case VarKind::TraceExtract: |
| 756 | + fmt(" mov.u32 $v, %u$u_scratch_$u;\n", v, a0->reg_index, (uint32_t) v->literal); |
| 757 | + break; |
| 758 | + |
744 | 759 | default:
|
745 | 760 | jitc_fail("jitc_cuda_render_var(): unhandled variable kind \"%s\"!",
|
746 | 761 | var_kind_name[(uint32_t) v->kind]);
|
747 | 762 | }
|
748 | 763 | }
|
749 | 764 |
|
| 765 | +static void jitc_cuda_render_trace(uint32_t index, const Variable *v, |
| 766 | + const Variable *valid, |
| 767 | + const Variable *pipeline, |
| 768 | + const Variable *sbt) { |
| 769 | + ThreadState *ts = thread_state(JitBackend::CUDA); |
| 770 | + OptixPipelineData *pipeline_p = (OptixPipelineData *) pipeline->literal; |
| 771 | + OptixShaderBindingTable *sbt_p = (OptixShaderBindingTable*) sbt->literal; |
| 772 | + bool problem = false; |
| 773 | + |
| 774 | + if (ts->optix_pipeline == state.optix_default_pipeline) { |
| 775 | + ts->optix_pipeline = pipeline_p; |
| 776 | + } else if (ts->optix_pipeline != pipeline_p) { |
| 777 | + jitc_log( |
| 778 | + Warn, |
| 779 | + "jit_eval(): more than one OptiX pipeline was used within a single " |
| 780 | + "kernel, which is not supported. Please split your kernel into " |
| 781 | + "smaller parts (e.g. using `dr::eval()`). Disabling the ray " |
| 782 | + "tracing operation to avoid potential undefined behavior."); |
| 783 | + problem = true; |
| 784 | + } |
| 785 | + |
| 786 | + if (ts->optix_sbt == state.optix_default_sbt) { |
| 787 | + ts->optix_sbt = sbt_p; |
| 788 | + } else if (ts->optix_sbt != sbt_p) { |
| 789 | + jitc_log( |
| 790 | + Warn, |
| 791 | + "jit_eval(): more than one OptiX shader binding table was used " |
| 792 | + "within a single kernel, which is not supported. Please split your " |
| 793 | + "kernel into smaller parts (e.g. using `dr::eval()`). Disabling " |
| 794 | + "the ray tracing operation to avoid potential undefined behavior."); |
| 795 | + problem = true; |
| 796 | + } |
| 797 | + |
| 798 | + Extra &extra = state.extra[index]; |
| 799 | + uint32_t payload_count = extra.n_dep - 15, |
| 800 | + reg = v->reg_index; |
| 801 | + |
| 802 | + fmt(" .reg.u32 %u$u_scratch_<32>;\n", reg); |
| 803 | + |
| 804 | + if (problem) { |
| 805 | + for (int i = 0; i < 32; ++i) |
| 806 | + fmt(" mov.b32 %u$u_scratch_$u, 0;\n", reg, i); |
| 807 | + return; |
| 808 | + } |
| 809 | + |
| 810 | + |
| 811 | + bool masked = !valid->is_literal() || valid->literal != 1; |
| 812 | + if (masked) |
| 813 | + fmt(" @!$v bra l_masked_$u;\n", valid, reg); |
| 814 | + |
| 815 | + fmt(" .reg.u32 %u$u_payload_type, %u$u_payload_count;\n" |
| 816 | + " mov.u32 %u$u_payload_type, 0;\n" |
| 817 | + " mov.u32 %u$u_payload_count, $u;\n", |
| 818 | + reg, reg, reg, reg, payload_count); |
| 819 | + |
| 820 | + put(" call ("); |
| 821 | + for (uint32_t i = 0; i < 32; ++i) |
| 822 | + fmt("%u$u_scratch_$u$s", reg, i, i + 1 < 32 ? ", " : ""); |
| 823 | + put("), _optix_trace_typed_32, ("); |
| 824 | + |
| 825 | + fmt("%u$u_payload_type, ", reg); |
| 826 | + for (uint32_t i = 0; i < 15; ++i) |
| 827 | + fmt("$v, ", jitc_var(extra.dep[i])); |
| 828 | + |
| 829 | + fmt("%u$u_payload_count, ", reg); |
| 830 | + for (uint32_t i = 15; i < extra.n_dep; ++i) |
| 831 | + fmt("$v$s", jitc_var(extra.dep[i]), (i - 15 < 32) ? ", " : ""); |
| 832 | + |
| 833 | + for (uint32_t i = payload_count; i < 32; ++i) |
| 834 | + fmt("%u$u_scratch_$u$s", reg, i, (i + 1 < 32) ? ", " : ""); |
| 835 | + |
| 836 | + put(");\n"); |
| 837 | + |
| 838 | + if (masked) |
| 839 | + fmt("\nl_masked_$u:\n", reg); |
| 840 | +} |
| 841 | + |
750 | 842 | static void jitc_cuda_render_scatter(const Variable *v,
|
751 | 843 | const Variable *ptr,
|
752 | 844 | const Variable *value,
|
|
0 commit comments