diff --git a/include/drjit-core/jit.h b/include/drjit-core/jit.h
index 1aef7dd31..80beed7e4 100644
--- a/include/drjit-core/jit.h
+++ b/include/drjit-core/jit.h
@@ -1285,7 +1285,7 @@ extern JIT_EXPORT void jit_prefix_pop(JIT_ENUM JitBackend backend);
* The default set of flags is:
*
* ConstProp | ValueNumbering | LoopRecord | LoopOptimize |
- * VCallRecord | VCallOptimize | ADOptimize
+ * VCallRecord | VCallDeduplicate | VCallOptimize | ADOptimize
*/
#if defined(__cplusplus)
enum class JitFlag : uint32_t {
@@ -1304,38 +1304,51 @@ enum class JitFlag : uint32_t {
/// Record virtual function calls instead of splitting them into many small kernel launches
VCallRecord = 16,
+ /**
+ * \brief Use branches instead of direct callables (in OptiX) or indirect
+ * function calls (in CUDA) for virtual function calls. The default
+ * branching strategy is a linear search among all targets.
+ */
+ VCallBranch = 32,
+
+ /// Use a jump table to reach appropriate target when `VCallBranch` is enabled
+ VCallBranchJumpTable = 64,
+
+ /// Perform a binary search to find the appropriate target when `VCallBranch` is enabled
+ VCallBranchBinarySearch = 128,
+
/// De-duplicate virtual function calls that produce the same code
- VCallDeduplicate = 32,
+ VCallDeduplicate = 256,
/// Enable constant propagation and elide unnecessary function arguments
- VCallOptimize = 64,
+ VCallOptimize = 512,
/**
* \brief Inline calls if there is only a single instance? (off by default,
* inlining can make kernels so large that they actually run slower in
* CUDA/OptiX).
*/
- VCallInline = 128,
+ VCallInline = 1024,
/// Force execution through OptiX even if a kernel doesn't use ray tracing
- ForceOptiX = 256,
+ ForceOptiX = 2048,
/// Temporarily postpone evaluation of statements with side effects
- Recording = 512,
+ Recording = 4096,
/// Print the intermediate representation of generated programs
- PrintIR = 1024,
+ PrintIR = 8192,
/// Enable writing of the kernel history
- KernelHistory = 2048,
+ KernelHistory = 16384,
/* Force synchronization after every kernel launch. This is useful to
isolate crashes to a specific kernel, and to benchmark kernel runtime
along with the KernelHistory feature. */
- LaunchBlocking = 4096,
+ LaunchBlocking = 32768,
/// Exploit literal constants during AD (used in the Dr.Jit parent project)
- ADOptimize = 8192,
+ ADOptimize = 65536,
/// Default flags
Default = (uint32_t) ConstProp | (uint32_t) ValueNumbering |
@@ -1345,20 +1358,23 @@ enum class JitFlag : uint32_t {
};
#else
enum JitFlag {
- JitFlagConstProp = 1,
- JitFlagValueNumbering = 2,
- JitFlagLoopRecord = 4,
- JitFlagLoopOptimize = 8,
- JitFlagVCallRecord = 16,
- JitFlagVCallDeduplicate = 32,
- JitFlagVCallOptimize = 64,
- JitFlagVCallInline = 128,
- JitFlagForceOptiX = 256,
- JitFlagRecording = 512,
- JitFlagPrintIR = 1024,
- JitFlagKernelHistory = 2048,
- JitFlagLaunchBlocking = 4096,
- JitFlagADOptimize = 8192
+ JitFlagConstProp = 1,
+ JitFlagValueNumbering = 2,
+ JitFlagLoopRecord = 4,
+ JitFlagLoopOptimize = 8,
+ JitFlagVCallRecord = 16,
+ JitFlagVCallBranch = 32,
+ JitFlagVCallBranchJumpTable = 64,
+ JitFlagVCallBranchBinarySearch = 128,
+ JitFlagVCallDeduplicate = 256,
+ JitFlagVCallOptimize = 512,
+ JitFlagVCallInline = 1024,
+ JitFlagForceOptiX = 2048,
+ JitFlagRecording = 4096,
+ JitFlagPrintIR = 8192,
+ JitFlagKernelHistory = 16384,
+ JitFlagLaunchBlocking = 32768,
+     JitFlagADOptimize             = 65536
};
#endif
diff --git a/src/eval_cuda.cpp b/src/eval_cuda.cpp
index 9d4741bc9..d4af5553b 100644
--- a/src/eval_cuda.cpp
+++ b/src/eval_cuda.cpp
@@ -203,30 +203,48 @@ void jitc_assemble_cuda(ThreadState *ts, ScheduledGroup group,
it.second.callable_index = ctr++;
}
- if (callable_count > 0 && !uses_optix) {
+ if (callable_count > 0) {
size_t insertion_point =
(char *) strstr(buffer.get(), ".address_size 64\n\n") -
buffer.get() + 18,
insertion_start = buffer.size();
- buffer.fmt(".extern .global .u64 callables[%u];\n\n",
- callable_count_unique);
+ if (jit_flag(JitFlag::VCallBranch)) {
+ // Copy signatures to very beginning
+ for (const auto &it : globals_map) {
+ if (!it.first.callable)
+ continue;
- jitc_insert_code_at(insertion_point, insertion_start);
+ const char* func_definition = globals.get() + it.second.start;
+ const char* signature_begin = strstr(func_definition, ".func");
+ const char* signature_end = strstr(func_definition, "{");
- buffer.fmt("\n.visible .global .align 8 .u64 callables[%u] = {\n",
- callable_count_unique);
- for (auto const &it : globals_map) {
- if (!it.first.callable)
- continue;
+ buffer.put(".visible ");
+ buffer.put(signature_begin,
+ signature_end - 1 - signature_begin);
+ buffer.put(";\n");
+ }
+ buffer.fmt("\n");
+ jitc_insert_code_at(insertion_point, insertion_start);
+ } else if (!uses_optix) {
+ buffer.fmt(".extern .global .u64 callables[%u];\n\n",
+ callable_count_unique);
+ jitc_insert_code_at(insertion_point, insertion_start);
+
+ buffer.fmt("\n.visible .global .align 8 .u64 callables[%u] = {\n",
+ callable_count_unique);
+ for (auto const &it : globals_map) {
+ if (!it.first.callable)
+ continue;
+
+ buffer.fmt(" func_%016llx%016llx%s\n",
+ (unsigned long long) it.first.hash.high64,
+ (unsigned long long) it.first.hash.low64,
+ it.second.callable_index + 1 < callable_count_unique ? "," : "");
+ }
- buffer.fmt(" func_%016llx%016llx%s\n",
- (unsigned long long) it.first.hash.high64,
- (unsigned long long) it.first.hash.low64,
- it.second.callable_index + 1 < callable_count_unique ? "," : "");
+ buffer.put("};\n\n");
}
-
- buffer.put("};\n\n");
}
jitc_vcall_upload(ts);
@@ -245,8 +263,9 @@ void jitc_assemble_cuda_func(const char *name, uint32_t inst_id,
buffer.put(".visible .func");
if (out_size) buffer.fmt(" (.param .align %u .b8 result[%u])", out_align, out_size);
+ bool uses_direct_callables = uses_optix && !(jit_flag(JitFlag::VCallBranch));
buffer.fmt(" %s^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^(",
- uses_optix ? "__direct_callable__" : "func_");
+ uses_direct_callables ? "__direct_callable__" : "func_");
if (use_self) {
buffer.put(".reg .u32 self");
diff --git a/src/hash.h b/src/hash.h
index a4fbaf1af..e31631004 100644
--- a/src/hash.h
+++ b/src/hash.h
@@ -71,6 +71,18 @@ struct XXH128Cmp {
}
};
+struct XXH128Eq {
+ bool operator()(const XXH128_hash_t &lhs, const XXH128_hash_t &rhs) const {
+ return lhs.high64 == rhs.high64 && lhs.low64 == rhs.low64;
+ }
+};
+
+struct XXH128Hash {
+ size_t operator()(const XXH128_hash_t &hash) const {
+ return hash.low64 ^ hash.high64;
+ }
+};
+
inline void hash_combine(size_t& seed, size_t value) {
/// From CityHash (https://github.com/google/cityhash)
const size_t mult = 0x9ddfea08eb382d69ull;
diff --git a/src/optix_api.cpp b/src/optix_api.cpp
index 830ee6ec8..c4e90c542 100644
--- a/src/optix_api.cpp
+++ b/src/optix_api.cpp
@@ -662,19 +662,21 @@ bool jitc_optix_compile(ThreadState *ts, const char *buf, size_t buf_size,
pgd[0].raygen.module = kernel.optix.mod;
pgd[0].raygen.entryFunctionName = strdup(kern_name);
- for (auto const &it : globals_map) {
- if (!it.first.callable)
- continue;
-
- char *name = (char *) malloc_check(52);
- snprintf(name, 52, "__direct_callable__%016llx%016llx",
- (unsigned long long) it.first.hash.high64,
- (unsigned long long) it.first.hash.low64);
-
- uint32_t index = 1 + it.second.callable_index;
- pgd[index].kind = OPTIX_PROGRAM_GROUP_KIND_CALLABLES;
- pgd[index].callables.moduleDC = kernel.optix.mod;
- pgd[index].callables.entryFunctionNameDC = name;
+ if (!jit_flag(JitFlag::VCallBranch)) {
+ for (auto const &it : globals_map) {
+ if (!it.first.callable)
+ continue;
+
+ char *name = (char *) malloc_check(52);
+ snprintf(name, 52, "__direct_callable__%016llx%016llx",
+ (unsigned long long) it.first.hash.high64,
+ (unsigned long long) it.first.hash.low64);
+
+ uint32_t index = 1 + it.second.callable_index;
+ pgd[index].kind = OPTIX_PROGRAM_GROUP_KIND_CALLABLES;
+ pgd[index].callables.moduleDC = kernel.optix.mod;
+ pgd[index].callables.entryFunctionNameDC = name;
+ }
}
kernel.optix.pg = new OptixProgramGroup[n_programs];
diff --git a/src/vcall.cpp b/src/vcall.cpp
index dfeb946f6..674488a43 100644
--- a/src/vcall.cpp
+++ b/src/vcall.cpp
@@ -69,6 +69,8 @@ struct VCall {
/// Does this vcall need self as argument
bool use_self = false;
+ CallablesSet callables_set;
+
~VCall() {
for (uint32_t index : out_nested)
jitc_var_dec_ref(index);
@@ -773,7 +775,7 @@ static void jitc_var_vcall_assemble(VCall *vcall,
ThreadState *ts = thread_state(vcall->backend);
- CallablesSet callables_set;
+ vcall->callables_set.clear();
for (uint32_t i = 0; i < vcall->n_inst; ++i) {
XXH128_hash_t hash = jitc_assemble_func(
ts, vcall->name, i, in_size, in_align, out_size, out_align,
@@ -783,7 +785,7 @@ static void jitc_var_vcall_assemble(VCall *vcall,
vcall->side_effects.data() + vcall->checkpoints[i],
vcall->use_self);
vcall->inst_hash[i] = hash;
- callables_set.insert(hash);
+ vcall->callables_set.insert(hash);
}
size_t se_count = vcall->side_effects.size();
@@ -819,13 +821,83 @@ static void jitc_var_vcall_assemble(VCall *vcall,
InfoSym,
"jit_var_vcall_assemble(): indirect call (\"%s\") to %zu/%u instances, "
"passing %u/%u inputs (%u/%u bytes), %u/%u outputs (%u/%u bytes), %zu side effects",
- vcall->name, callables_set.size(), vcall->n_inst, n_in_active,
+ vcall->name, vcall->callables_set.size(), vcall->n_inst, n_in_active,
vcall->in_count_initial, in_size, vcall->in_size_initial, n_out_active,
n_out, out_size, vcall->out_size_initial, se_count);
vcalls_assembled.push_back(vcall);
}
+static void jitc_var_vcall_branch_strategy_assemble_cuda(const VCall *vcall,
+ uint32_t vcall_reg) {
+ bool jump_table = jit_flag(JitFlag::VCallBranchJumpTable);
+ bool binary_search = jit_flag(JitFlag::VCallBranchBinarySearch);
+
+ if (jump_table == binary_search) {
+ // Linear search
+ if (jump_table)
+ jitc_log(Warn, "jitc_var_vcall_assemble_cuda(): both "
+ "JitFlag::VCallBranchJumpTable and "
+ "JitFlag::VCallBranchBinarySearch are enabled, "
+ "defaulting back to linear search!");
+
+ for (size_t i = 0; i < vcall->callables_set.size(); ++i) {
+ buffer.fmt(" setp.eq.u32 %%p3, %%r3, %u;\n", (uint32_t) i);
+ buffer.fmt(" @%%p3 bra l_%u_%u;\n", vcall_reg, (uint32_t) i);
+ }
+ } else if (binary_search) {
+ uint32_t size = vcall->callables_set.size();
+
+ uint32_t max_depth = log2i_ceil(size);
+ for (uint32_t depth = 0; depth < max_depth; ++depth) {
+ for (uint32_t i = 0; i < (1 << depth); ++i) {
+ uint32_t range_start = i << (max_depth - depth);
+ if (size <= range_start)
+ break;
+
+ uint32_t offset = 1 << (max_depth - depth - 1);
+ uint32_t spacing = offset * 2;
+ uint32_t mid = offset + i * spacing;
+
+ uint32_t next_offset = offset >> 1;
+ uint32_t next_spacing = spacing >> 1;
+ uint32_t left = next_offset + (i * 2) * next_spacing;
+ uint32_t right = next_offset + ((i * 2) + 1) * next_spacing;
+
+ if (depth != 0)
+ buffer.fmt(" l_%u_%u_%u:\n", vcall_reg, depth, mid);
+
+ if (mid < size) {
+ buffer.fmt(" setp.lt.u32 %%p3, %%r3, %u;\n", mid);
+ if (depth + 1 < max_depth) {
+ buffer.fmt(" @%%p3 bra l_%u_%u_%u;\n", vcall_reg, depth + 1, left);
+ buffer.fmt(" bra.uni l_%u_%u_%u;\n", vcall_reg, depth + 1, right);
+ } else {
+ buffer.fmt(" @%%p3 bra.uni l_%u_%u;\n", vcall_reg, left);
+ buffer.fmt(" bra.uni l_%u_%u;\n", vcall_reg, right);
+ }
+ } else {
+ if (depth + 1 < max_depth)
+ buffer.fmt(" bra l_%u_%u_%u;\n", vcall_reg, depth + 1, left);
+ else
+ buffer.fmt(" bra.uni l_%u_%u;\n", vcall_reg, left);
+ }
+ }
+ }
+ } else {
+ // Jump table
+ buffer.put(" ts: .branchtargets ");
+ for (size_t i = 0; i < vcall->callables_set.size(); ++i) {
+ if (i != 0)
+ buffer.put(", ");
+ buffer.fmt("l_%u_%u", vcall_reg, (uint32_t) i);
+ }
+ buffer.put(";\n brx.idx %r3, ts;\n");
+ }
+
+ buffer.put("\n");
+}
+
/// Virtual function call code generation -- CUDA/PTX-specific bits
static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg,
uint32_t self_reg, uint32_t mask_reg,
@@ -834,6 +906,8 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg,
uint32_t in_align, uint32_t out_size,
uint32_t out_align) {
+ bool branch_vcall = jit_flag(JitFlag::VCallBranch);
+
// =====================================================
// 1. Conditional branch
// =====================================================
@@ -857,10 +931,12 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg,
// 3. Turn callable ID into a function pointer
// =====================================================
- if (!uses_optix)
- buffer.fmt(" ld.global.u64 %%rd2, callables[%%r3];\n");
- else
- buffer.put(" call (%rd2), _optix_call_direct_callable, (%r3);\n");
+ if (!branch_vcall) {
+ if (!uses_optix)
+ buffer.fmt(" ld.global.u64 %%rd2, callables[%%r3];\n");
+ else
+ buffer.put(" call (%rd2), _optix_call_direct_callable, (%r3);\n");
+ }
// =====================================================
// 4. Obtain pointer to supplemental call data
@@ -871,7 +947,6 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg,
" add.u64 %%rd3, %%rd3, %%rd%u;\n",
data_reg);
}
-
// %rd2: function pointer (if applicable)
// %rd3: call data pointer with offset
@@ -895,131 +970,161 @@ static void jitc_var_vcall_assemble_cuda(VCall *vcall, uint32_t vcall_reg,
v2->reg_index, v2->reg_index);
}
- buffer.put(" {\n");
+ // Switch statement: branch to call (multiple strategies)
+ if (branch_vcall)
+ jitc_var_vcall_branch_strategy_assemble_cuda(vcall, vcall_reg);
+
+ uint32_t callable_id = 0;
+ for (const XXH128_hash_t &callable_hash : vcall->callables_set) {
+ if (!branch_vcall) {
+ // Generate call prototype
+ buffer.put(" {\n");
+ buffer.put(" proto: .callprototype");
+ if (out_size)
+ buffer.fmt(" (.param .align %u .b8 result[%u])", out_align, out_size);
+ buffer.put(" _(");
+ if (vcall->use_self) {
+ buffer.put(".reg .u32 self");
+ if (data_reg || in_size)
+ buffer.put(", ");
+ }
+ if (data_reg) {
+ buffer.put(".reg .u64 data");
+ if (in_size)
+ buffer.put(", ");
+ }
+ if (in_size)
+ buffer.fmt(".param .align %u .b8 params[%u]", in_align, in_size);
+ buffer.put(");\n");
+ } else {
+ buffer.fmt(" l_%u_%u:\n", vcall_reg, callable_id);
+ buffer.put(" {\n");
+ }
- // Call prototype
- buffer.put(" proto: .callprototype");
- if (out_size)
- buffer.fmt(" (.param .align %u .b8 result[%u])", out_align, out_size);
- buffer.put(" _(");
- if (vcall->use_self) {
- buffer.put(".reg .u32 self");
- if (data_reg || in_size)
- buffer.put(", ");
- }
- if (data_reg) {
- buffer.put(".reg .u64 data");
+ // Input/output parameter arrays
+ if (out_size)
+ buffer.fmt(" .param .align %u .b8 out[%u];\n", out_align, out_size);
if (in_size)
- buffer.put(", ");
- }
- if (in_size)
- buffer.fmt(".param .align %u .b8 params[%u]", in_align, in_size);
- buffer.put(");\n");
-
- // Input/output parameter arrays
- if (out_size)
- buffer.fmt(" .param .align %u .b8 out[%u];\n", out_align, out_size);
- if (in_size)
- buffer.fmt(" .param .align %u .b8 in[%u];\n", in_align, in_size);
-
- // =====================================================
- // 5.1. Pass the input arguments
- // =====================================================
-
- uint32_t offset = 0;
- for (uint32_t in : vcall->in) {
- auto it = state.variables.find(in);
- if (it == state.variables.end())
- continue;
- const Variable *v2 = &it->second;
- uint32_t size = type_size[v2->type];
+ buffer.fmt(" .param .align %u .b8 in[%u];\n", in_align, in_size);
- const char *tname = type_name_ptx[v2->type],
- *prefix = type_prefix[v2->type];
+ // =====================================================
+ // 5.1. Pass the input arguments
+ // =====================================================
- // Special handling for predicates (pass via u8)
- if ((VarType) v2->type == VarType::Bool) {
- tname = "u8";
- prefix = "%w";
- }
+ uint32_t offset = 0;
+ for (uint32_t in : vcall->in) {
+ auto it = state.variables.find(in);
+ if (it == state.variables.end())
+ continue;
+ const Variable *v2 = &it->second;
+ uint32_t size = type_size[v2->type];
- buffer.fmt(" st.param.%s [in+%u], %s%u;\n", tname, offset,
- prefix, v2->reg_index);
+ const char *tname = type_name_ptx[v2->type],
+ *prefix = type_prefix[v2->type];
- offset += size;
- }
-
- if (vcall->use_self) {
- buffer.fmt(" call %s%%rd2, (%%r%u%s%s), proto;\n",
- out_size ? "(out), " : "", self_reg,
- data_reg ? ", %rd3" : "",
- in_size ? ", in" : "");
- } else {
- buffer.fmt(" call %s%%rd2, (%s%s%s), proto;\n",
- out_size ? "(out), " : "", data_reg ? "%rd3" : "",
- data_reg && in_size ? ", " : "", in_size ? "in" : "");
- }
+ // Special handling for predicates (pass via u8)
+ if ((VarType) v2->type == VarType::Bool) {
+ tname = "u8";
+ prefix = "%w";
+ }
- // =====================================================
- // 5.2. Read back the output arguments
- // =====================================================
+ buffer.fmt(" st.param.%s [in+%u], %s%u;\n", tname, offset,
+ prefix, v2->reg_index);
- offset = 0;
- for (uint32_t i = 0; i < n_out; ++i) {
- uint32_t index = vcall->out_nested[i],
- index_2 = vcall->out[i];
- auto it = state.variables.find(index);
- if (it == state.variables.end())
- continue;
- uint32_t size = type_size[it->second.type],
- load_offset = offset;
- offset += size;
+ offset += size;
+ }
- // Skip if outer access expired
- auto it2 = state.variables.find(index_2);
- if (it2 == state.variables.end())
- continue;
+ // =====================================================
+ // 5.2. Setup the function call
+ // =====================================================
+
+ auto assemble_call = [&](const char* target) {
+ buffer.put(" ");
+ if (vcall->use_self) {
+ buffer.fmt("call %s%s, (%%r%u%s%s)%s;\n",
+ out_size ? "(out), " : "",
+ target,
+ self_reg,
+ data_reg ? ", %rd3" : "",
+ in_size ? ", in" : "",
+ branch_vcall ? "" : ", proto");
+ } else {
+ buffer.fmt("call %s%s, (%s%s%s)%s;\n",
+ out_size ? "(out), " : "",
+ target,
+ data_reg ? "%rd3" : "",
+ data_reg && in_size ? ", " : "",
+ in_size ? "in" : "",
+ branch_vcall ? "" : ", proto");
+ }
+ };
+
+ // =====================================================
+ // 5.3. Call the function and read the output arguments
+ // =====================================================
+
+ auto read_output_arguments = [&]() {
+ offset = 0;
+ for (uint32_t i = 0; i < n_out; ++i) {
+ uint32_t index = vcall->out_nested[i],
+ index_2 = vcall->out[i];
+ auto it = state.variables.find(index);
+ if (it == state.variables.end())
+ continue;
+ uint32_t size = type_size[it->second.type],
+ load_offset = offset;
+ offset += size;
+
+ // Skip if outer access expired
+ auto it2 = state.variables.find(index_2);
+ if (it2 == state.variables.end())
+ continue;
+
+ const Variable *v2 = &it2.value();
+ if (v2->reg_index == 0 || v2->param_type == ParamType::Input)
+ continue;
+
+ const char *tname = type_name_ptx[v2->type],
+ *prefix = type_prefix[v2->type];
+
+ // Special handling for predicates (pass via u8)
+ if ((VarType) v2->type == VarType::Bool) {
+ tname = "u8";
+ prefix = "%w";
+ }
- const Variable *v2 = &it2.value();
- if (v2->reg_index == 0 || v2->param_type == ParamType::Input)
- continue;
+ buffer.fmt(" ld.param.%s %s%u, [out+%u];\n",
+ tname, prefix, v2->reg_index, load_offset);
- const char *tname = type_name_ptx[v2->type],
- *prefix = type_prefix[v2->type];
+ // Special handling for predicates
+ if ((VarType) v2->type == VarType::Bool)
+ buffer.fmt(" setp.ne.u16 %%p%u, %%w%u, 0;\n",
+ v2->reg_index, v2->reg_index);
+ }
+ };
- // Special handling for predicates (pass via u8)
- if ((VarType) v2->type == VarType::Bool) {
- tname = "u8";
- prefix = "%w";
+ if (!branch_vcall) {
+ const char* target = "%rd2";
+ assemble_call(target);
+ read_output_arguments();
+ } else {
+ char target[38];
+ snprintf(target, sizeof(target), "func_%016llx%016llx",
+ (unsigned long long) callable_hash.high64,
+ (unsigned long long) callable_hash.low64);
+ assemble_call(target);
+ read_output_arguments();
}
- buffer.fmt(" ld.param.%s %s%u, [out+%u];\n",
- tname, prefix, v2->reg_index, load_offset);
- }
-
- buffer.put(" }\n\n");
+ buffer.put(" }\n");
+ buffer.fmt(" bra.uni l_done_%u;\n", vcall_reg);
- // =====================================================
- // 6. Special handling for predicates return value(s)
- // =====================================================
-
- for (uint32_t out : vcall->out) {
- auto it = state.variables.find(out);
- if (it == state.variables.end())
- continue;
- const Variable *v2 = &it->second;
- if ((VarType) v2->type != VarType::Bool)
- continue;
- if (v2->reg_index == 0 || v2->param_type == ParamType::Input)
- continue;
+ callable_id++;
- // Special handling for predicates
- buffer.fmt(" setp.ne.u16 %%p%u, %%w%u, 0;\n",
- v2->reg_index, v2->reg_index);
+ if (!branch_vcall)
+ break;
}
-
- buffer.fmt(" bra.uni l_done_%u;\n", vcall_reg);
buffer.put(" }\n");
// =====================================================
@@ -1370,15 +1475,34 @@ void jitc_vcall_upload(ThreadState *ts) {
uint64_t *data = (uint64_t *) jitc_malloc(at, vcall->offset_size);
memset(data, 0, vcall->offset_size);
- for (uint32_t i = 0; i < vcall->n_inst; ++i) {
- auto it = globals_map.find(GlobalKey(vcall->inst_hash[i], true));
- if (it == globals_map.end())
- jitc_fail("jitc_vcall_upload(): could not find callable!");
-
- // high part: instance data offset, low part: callable index
- data[vcall->inst_id[i]] =
- (((uint64_t) vcall->data_offset[i]) << 32) |
- it->second.callable_index;
+ if (ts->backend == JitBackend::CUDA && jit_flag(JitFlag::VCallBranch)) {
+         tsl::robin_map<XXH128_hash_t, uint32_t, XXH128Hash, XXH128Eq> callable_indices;
+ uint32_t index = 0;
+ for (const XXH128_hash_t &callable : vcall->callables_set)
+ callable_indices[callable] = index++;
+
+ for (uint32_t i = 0; i < vcall->n_inst; ++i) {
+ auto it = globals_map.find(GlobalKey(vcall->inst_hash[i], true));
+ if (it == globals_map.end())
+ jitc_fail("jitc_vcall_upload(): could not find callable!");
+
+ // high part: instance data offset, low part: callable index
+ data[vcall->inst_id[i]] =
+ (((uint64_t) vcall->data_offset[i]) << 32) |
+ callable_indices[vcall->inst_hash[i]];
+ }
+ }
+ else {
+ for (uint32_t i = 0; i < vcall->n_inst; ++i) {
+ auto it = globals_map.find(GlobalKey(vcall->inst_hash[i], true));
+ if (it == globals_map.end())
+ jitc_fail("jitc_vcall_upload(): could not find callable!");
+
+ // high part: instance data offset, low part: callable index
+ data[vcall->inst_id[i]] =
+ (((uint64_t) vcall->data_offset[i]) << 32) |
+ it->second.callable_index;
+ }
}
jitc_memcpy_async(ts->backend, vcall->offset, data, vcall->offset_size);
diff --git a/tests/vcall.cpp b/tests/vcall.cpp
index f7d2f69fa..2d9383bff 100644
--- a/tests/vcall.cpp
+++ b/tests/vcall.cpp
@@ -239,10 +239,15 @@ TEST_BOTH(01_recorded_vcall) {
     BasePtr self = arange<UInt32>(10) % 3;
for (uint32_t i = 0; i < 2; ++i) {
- jit_set_flag(JitFlag::VCallOptimize, i);
- Float y = vcall(
- "Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, x);
- jit_assert(strcmp(y.str(), "[0, 22, 204, 0, 28, 210, 0, 34, 216, 0]") == 0);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ jit_set_flag(JitFlag::VCallOptimize, i);
+ Float y = vcall(
+ "Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, x);
+ jit_assert(strcmp(y.str(), "[0, 22, 204, 0, 28, 210, 0, 34, 216, 0]") == 0);
+ }
}
jit_registry_remove(Backend, &a1);
@@ -290,35 +295,40 @@ TEST_BOTH(02_calling_conventions) {
(void) i1; (void) i2; (void) i3;
for (uint32_t i = 0; i < 2; ++i) {
- jit_set_flag(JitFlag::VCallOptimize, i);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ jit_set_flag(JitFlag::VCallOptimize, i);
-         using BasePtr = Array<Base *>;
-         BasePtr self = arange<UInt32>(10) % 3;
+             using BasePtr = Array<Base *>;
+             BasePtr self = arange<UInt32>(12) % 4;
- Mask p0(false);
- Float p1(12);
- Double p2(34);
- Float p3(56);
- Mask p4(true);
+ Mask p0(false);
+ Float p1(12);
+ Double p2(34);
+ Float p3(56);
+ Mask p4(true);
- auto result = vcall(
- "Base",
- [](Base *self2, Mask p0, Float p1, Double p2, Float p3, Mask p4) {
- return self2->f(p0, p1, p2, p3, p4);
- },
- self, p0, p1, p2, p3, p4);
-
- jit_var_schedule(result.template get<0>().index());
- jit_var_schedule(result.template get<1>().index());
- jit_var_schedule(result.template get<2>().index());
- jit_var_schedule(result.template get<3>().index());
- jit_var_schedule(result.template get<4>().index());
-
- jit_assert(strcmp(result.template get<0>().str(), "[0, 0, 1, 0, 0, 1, 0, 0, 1, 0]") == 0);
- jit_assert(strcmp(result.template get<1>().str(), "[0, 12, 13, 0, 12, 13, 0, 12, 13, 0]") == 0);
- jit_assert(strcmp(result.template get<2>().str(), "[0, 34, 36, 0, 34, 36, 0, 34, 36, 0]") == 0);
- jit_assert(strcmp(result.template get<3>().str(), "[0, 56, 59, 0, 56, 59, 0, 56, 59, 0]") == 0);
- jit_assert(strcmp(result.template get<4>().str(), "[0, 1, 0, 0, 1, 0, 0, 1, 0, 0]") == 0);
+ auto result = vcall(
+ "Base",
+ [](Base *self2, Mask p0, Float p1, Double p2, Float p3, Mask p4) {
+ return self2->f(p0, p1, p2, p3, p4);
+ },
+ self, p0, p1, p2, p3, p4);
+
+ jit_var_schedule(result.template get<0>().index());
+ jit_var_schedule(result.template get<1>().index());
+ jit_var_schedule(result.template get<2>().index());
+ jit_var_schedule(result.template get<3>().index());
+ jit_var_schedule(result.template get<4>().index());
+
+ jit_assert(strcmp(result.template get<0>().str(), "[0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0]") == 0);
+ jit_assert(strcmp(result.template get<1>().str(), "[0, 12, 13, 0, 0, 12, 13, 0, 0, 12, 13, 0]") == 0);
+ jit_assert(strcmp(result.template get<2>().str(), "[0, 34, 36, 0, 0, 34, 36, 0, 0, 34, 36, 0]") == 0);
+ jit_assert(strcmp(result.template get<3>().str(), "[0, 56, 59, 0, 0, 56, 59, 0, 0, 56, 59, 0]") == 0);
+ jit_assert(strcmp(result.template get<4>().str(), "[0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]") == 0);
+ }
}
jit_registry_remove(Backend, &b1);
@@ -362,32 +372,36 @@ TEST_BOTH(03_optimize_away_outputs) {
     BasePtr self = arange<UInt32>(10) % 4;
for (uint32_t i = 0; i < 2; ++i) {
- i = 1;
- jit_set_flag(JitFlag::VCallOptimize, i);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ jit_set_flag(JitFlag::VCallOptimize, i);
- jit_assert(jit_var_ref(p3.index()) == 1);
+ jit_assert(jit_var_ref(p3.index()) == 1);
- auto result = vcall(
- "Base",
- [](Base *self2, Float p1, Float p2, Float p3) {
- return self2->f(p1, p2, p3);
- },
- self, p1, p2, p3);
+ auto result = vcall(
+ "Base",
+ [](Base *self2, Float p1, Float p2, Float p3) {
+ return self2->f(p1, p2, p3);
+ },
+ self, p1, p2, p3);
- jit_assert(jit_var_ref(p1.index()) == 3);
- jit_assert(jit_var_ref(p2.index()) == 3);
+ jit_assert(jit_var_ref(p1.index()) == 3);
+ jit_assert(jit_var_ref(p2.index()) == 3);
- // Irrelevant input optimized away
- jit_assert(jit_var_ref(p3.index()) == 2 - i);
+ // Irrelevant input optimized away
+ jit_assert(jit_var_ref(p3.index()) == 2 - i);
- result.template get<0>() = Float(0);
+ result.template get<0>() = Float(0);
- jit_assert(jit_var_ref(p1.index()) == 3);
- jit_assert(jit_var_ref(p2.index()) == 3 - 2*i);
- jit_assert(jit_var_ref(p3.index()) == 2 - i);
+ jit_assert(jit_var_ref(p1.index()) == 3);
+ jit_assert(jit_var_ref(p2.index()) == 3 - 2*i);
+ jit_assert(jit_var_ref(p3.index()) == 2 - i);
- jit_assert(strcmp(jit_var_str(result.template get<1>().index()),
- "[0, 13, 13, 14, 0, 13, 13, 14, 0, 13]") == 0);
+ jit_assert(strcmp(jit_var_str(result.template get<1>().index()),
+ "[0, 13, 13, 14, 0, 13, 13, 14, 0, 13]") == 0);
+ }
}
jit_registry_remove(Backend, &c1);
@@ -424,47 +438,52 @@ TEST_BOTH(04_devirtualize) {
for (uint32_t k = 0; k < 2; ++k) {
for (uint32_t i = 0; i < 2; ++i) {
- Float p1, p2;
- if (k == 0) {
- p1 = 12;
- p2 = 34;
- } else {
-                 p1 = dr::opaque<Float>(12);
-                 p2 = dr::opaque<Float>(34);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ Float p1, p2;
+ if (k == 0) {
+ p1 = 12;
+ p2 = 34;
+ } else {
+                     p1 = dr::opaque<Float>(12);
+                     p2 = dr::opaque<Float>(34);
+ }
+
+ jit_set_flag(JitFlag::VCallOptimize, i);
+ uint32_t scope = jit_scope(Backend);
+
+ auto result = vcall(
+ "Base",
+ [](Base *self2, Float p1, Float p2) {
+ return self2->f(p1, p2);
+ },
+ self, p1, p2);
+
+ jit_set_scope(Backend, scope + 1);
+
+ Float p2_wrap = Float::steal(jit_var_wrap_vcall(p2.index()));
+
+ Mask mask = neq(self, nullptr),
+ mask_combined = Mask::steal(jit_var_mask_apply(mask.index(), 10));
+
+ Float alt = (p2_wrap + 2) & mask_combined;
+
+ jit_set_scope(Backend, scope + 2);
+
+ jit_assert((result.template get<0>().index() == alt.index()) == (i == 1));
+ jit_assert(jit_var_is_literal(result.template get<2>().index()) == (i == 1));
+
+ jit_var_schedule(result.template get<0>().index());
+ jit_var_schedule(result.template get<1>().index());
+
+ jit_assert(
+ strcmp(jit_var_str(result.template get<0>().index()),
+ "[0, 36, 36, 0, 36, 36, 0, 36, 36, 0]") == 0);
+ jit_assert(strcmp(jit_var_str(result.template get<1>().index()),
+ "[0, 13, 14, 0, 13, 14, 0, 13, 14, 0]") == 0);
}
-
- jit_set_flag(JitFlag::VCallOptimize, i);
- uint32_t scope = jit_scope(Backend);
-
- auto result = vcall(
- "Base",
- [](Base *self2, Float p1, Float p2) {
- return self2->f(p1, p2);
- },
- self, p1, p2);
-
- jit_set_scope(Backend, scope + 1);
-
- Float p2_wrap = Float::steal(jit_var_wrap_vcall(p2.index()));
-
- Mask mask = neq(self, nullptr),
- mask_combined = Mask::steal(jit_var_mask_apply(mask.index(), 10));
-
- Float alt = (p2_wrap + 2) & mask_combined;
-
- jit_set_scope(Backend, scope + 2);
-
- jit_assert((result.template get<0>().index() == alt.index()) == (i == 1));
- jit_assert(jit_var_is_literal(result.template get<2>().index()) == (i == 1));
-
- jit_var_schedule(result.template get<0>().index());
- jit_var_schedule(result.template get<1>().index());
-
- jit_assert(
- strcmp(jit_var_str(result.template get<0>().index()),
- "[0, 36, 36, 0, 36, 36, 0, 36, 36, 0]") == 0);
- jit_assert(strcmp(jit_var_str(result.template get<1>().index()),
- "[0, 13, 14, 0, 13, 14, 0, 13, 14, 0]") == 0);
}
}
jit_registry_remove(Backend, &d1);
@@ -509,11 +528,16 @@ TEST_BOTH(05_extra_data) {
}
for (uint32_t i = 0; i < 2; ++i) {
- jit_set_flag(JitFlag::VCallOptimize, i);
- Float result = vcall(
- "Base", [](Base *self2, Float x) { return self2->f(x); }, self,
- x);
- jit_assert(strcmp(result.str(), "[0, 9, 13, 0, 21, 28, 0, 33, 43, 0]") == 0);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ jit_set_flag(JitFlag::VCallOptimize, i);
+ Float result = vcall(
+ "Base", [](Base *self2, Float x) { return self2->f(x); }, self,
+ x);
+ jit_assert(strcmp(result.str(), "[0, 9, 13, 0, 21, 28, 0, 33, 43, 0]") == 0);
+ }
}
}
jit_registry_remove(Backend, &e1);
@@ -550,20 +574,25 @@ TEST_BOTH(06_side_effects) {
     BasePtr self = arange<UInt32>(11) % 3;
for (uint32_t i = 0; i < 2; ++i) {
- jit_set_flag(JitFlag::VCallOptimize, i);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ jit_set_flag(JitFlag::VCallOptimize, i);
- F1 f1; F2 f2;
- uint32_t i1 = jit_registry_put(Backend, "Base", &f1);
- uint32_t i2 = jit_registry_put(Backend, "Base", &f2);
- jit_assert(i1 == 1 && i2 == 2);
+ F1 f1; F2 f2;
+ uint32_t i1 = jit_registry_put(Backend, "Base", &f1);
+ uint32_t i2 = jit_registry_put(Backend, "Base", &f2);
+ jit_assert(i1 == 1 && i2 == 2);
- vcall("Base", [](Base *self2) { self2->go(); }, self);
- jit_assert(strcmp(f1.buffer.str(), "[0, 4, 0, 8, 0]") == 0);
- jit_assert(strcmp(f2.buffer.str(), "[0, 1, 5, 3]") == 0);
+ vcall("Base", [](Base *self2) { self2->go(); }, self);
+ jit_assert(strcmp(f1.buffer.str(), "[0, 4, 0, 8, 0]") == 0);
+ jit_assert(strcmp(f2.buffer.str(), "[0, 1, 5, 3]") == 0);
- jit_registry_remove(Backend, &f1);
- jit_registry_remove(Backend, &f2);
- jit_registry_trim();
+ jit_registry_remove(Backend, &f1);
+ jit_registry_remove(Backend, &f2);
+ jit_registry_trim();
+ }
}
}
@@ -595,26 +624,31 @@ TEST_BOTH(07_side_effects_only_once) {
BasePtr self = arange(11) % 3;
for (uint32_t i = 0; i < 2; ++i) {
- jit_set_flag(JitFlag::VCallOptimize, i);
-
- G1 g1; G2 g2;
- uint32_t i1 = jit_registry_put(Backend, "Base", &g1);
- uint32_t i2 = jit_registry_put(Backend, "Base", &g2);
- jit_assert(i1 == 1 && i2 == 2);
-
- auto result = vcall("Base", [](Base *self2) { return self2->f(); }, self);
- Float f1 = result.template get<0>();
- Float f2 = result.template get<1>();
- jit_assert(strcmp(f1.str(), "[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]") == 0);
- jit_assert(strcmp(g1.buffer.str(), "[0, 4, 0, 0, 0]") == 0);
- jit_assert(strcmp(g2.buffer.str(), "[0, 0, 3, 0, 0]") == 0);
- jit_assert(strcmp(f2.str(), "[0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]") == 0);
- jit_assert(strcmp(g1.buffer.str(), "[0, 4, 0, 0, 0]") == 0);
- jit_assert(strcmp(g2.buffer.str(), "[0, 0, 3, 0, 0]") == 0);
-
- jit_registry_remove(Backend, &g1);
- jit_registry_remove(Backend, &g2);
- jit_registry_trim();
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ jit_set_flag(JitFlag::VCallOptimize, i);
+
+ G1 g1; G2 g2;
+ uint32_t i1 = jit_registry_put(Backend, "Base", &g1);
+ uint32_t i2 = jit_registry_put(Backend, "Base", &g2);
+ jit_assert(i1 == 1 && i2 == 2);
+
+ auto result = vcall("Base", [](Base *self2) { return self2->f(); }, self);
+ Float f1 = result.template get<0>();
+ Float f2 = result.template get<1>();
+ jit_assert(strcmp(f1.str(), "[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1]") == 0);
+ jit_assert(strcmp(g1.buffer.str(), "[0, 4, 0, 0, 0]") == 0);
+ jit_assert(strcmp(g2.buffer.str(), "[0, 0, 3, 0, 0]") == 0);
+ jit_assert(strcmp(f2.str(), "[0, 2, 1, 0, 2, 1, 0, 2, 1, 0, 2]") == 0);
+ jit_assert(strcmp(g1.buffer.str(), "[0, 4, 0, 0, 0]") == 0);
+ jit_assert(strcmp(g2.buffer.str(), "[0, 0, 3, 0, 0]") == 0);
+
+ jit_registry_remove(Backend, &g1);
+ jit_registry_remove(Backend, &g2);
+ jit_registry_trim();
+ }
}
}
@@ -652,12 +686,17 @@ TEST_BOTH(08_multiple_calls) {
for (uint32_t i = 0; i < 2; ++i) {
- jit_set_flag(JitFlag::VCallOptimize, i);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ jit_set_flag(JitFlag::VCallOptimize, i);
- Float y = vcall("Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, x);
- Float z = vcall("Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, y);
+ Float y = vcall("Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, x);
+ Float z = vcall("Base", [](Base *self2, Float x2) { return self2->f(x2); }, self, y);
- jit_assert(strcmp(z.str(), "[0, 12, 14, 0, 12, 14, 0, 12, 14, 0]") == 0);
+ jit_assert(strcmp(z.str(), "[0, 12, 14, 0, 12, 14, 0, 12, 14, 0]") == 0);
+ }
}
jit_registry_remove(Backend, &h1);
@@ -710,26 +749,31 @@ TEST_BOTH(09_big) {
self2 = select(self2 <= n2, self2, 0);
for (uint32_t i = 0; i < 2; ++i) {
- jit_set_flag(JitFlag::VCallOptimize, i);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
+ jit_set_flag(JitFlag::VCallOptimize, i);
- Float x = vcall("Base1", [](Base1 *self_) { return self_->f(); }, Base1Ptr(self1));
- Float y = vcall("Base2", [](Base2 *self_) { return self_->f(); }, Base2Ptr(self2));
+ Float x = vcall("Base1", [](Base1 *self_) { return self_->f(); }, Base1Ptr(self1));
+ Float y = vcall("Base2", [](Base2 *self_) { return self_->f(); }, Base2Ptr(self2));
- jit_var_schedule(x.index());
- jit_var_schedule(y.index());
+ jit_var_schedule(x.index());
+ jit_var_schedule(y.index());
- jit_assert(x.read(0) == 0);
- jit_assert(y.read(0) == 0);
+ jit_assert(x.read(0) == 0);
+ jit_assert(y.read(0) == 0);
- for (uint32_t j = 1; j <= n1; ++j)
- jit_assert(x.read(j) == j - 1);
- for (uint32_t j = 1; j <= n2; ++j)
- jit_assert(y.read(j) == 100 + j - 1);
+ for (uint32_t k = 1; k <= n1; ++k)
+ jit_assert(x.read(k) == k - 1);
+ for (uint32_t k = 1; k <= n2; ++k)
+ jit_assert(y.read(k) == 100 + k - 1);
- for (uint32_t j = n1 + 1; j < n; ++j)
- jit_assert(x.read(j + 1) == 0);
- for (uint32_t j = n2 + 1; j < n; ++j)
- jit_assert(y.read(j + 1) == 0);
+ for (uint32_t k = n1 + 1; k < n; ++k)
+ jit_assert(x.read(k + 1) == 0);
+ for (uint32_t k = n2 + 1; k < n; ++k)
+ jit_assert(y.read(k + 1) == 0);
+ }
}
for (int i = 0; i < n1; ++i)
@@ -754,12 +798,18 @@ TEST_BOTH(09_self) {
uint32_t i2_id = jit_registry_put(Backend, "Base", &i2);
UInt32 self(i1_id, i2_id);
- UInt32 y = vcall(
- "Base",
- [](Base *self_) { return self_->f(); },
- BasePtr(self));
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
- jit_assert(strcmp(y.str(), "[1, 2]") == 0);
+ UInt32 y = vcall(
+ "Base",
+ [](Base *self_) { return self_->f(); },
+ BasePtr(self));
+
+ jit_assert(strcmp(y.str(), "[1, 2]") == 0);
+ }
jit_registry_remove(Backend, &i1);
jit_registry_remove(Backend, &i2);
@@ -796,14 +846,20 @@ TEST_BOTH(10_recursion) {
UInt32 self2(i21_id, i22_id);
Float x(3.f, 5.f);
- Float y = vcall(
- "Base2",
- [](Base2 *self_, const Base1Ptr &ptr_, const Float &x_) {
- return self_->g(ptr_, x_);
- },
- Base2Ptr(self2), Base1Ptr(self1), x);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
- jit_assert(strcmp(y.str(), "[7, 16]") == 0);
+ Float y = vcall(
+ "Base2",
+ [](Base2 *self_, const Base1Ptr &ptr_, const Float &x_) {
+ return self_->g(ptr_, x_);
+ },
+ Base2Ptr(self2), Base1Ptr(self1), x);
+
+ jit_assert(strcmp(y.str(), "[7, 16]") == 0);
+ }
jit_registry_remove(Backend, &i11);
jit_registry_remove(Backend, &i12);
@@ -842,14 +898,20 @@ TEST_BOTH(11_recursion_with_local) {
UInt32 self2(i21_id, i22_id);
Float x(3.f, 5.f);
- Float y = vcall(
- "Base2",
- [](Base2 *self_, const Base1Ptr &ptr_, const Float &x_) {
- return self_->g(ptr_, x_);
- },
- Base2Ptr(self2), Base1Ptr(self1), x);
+ for (uint32_t j = 0; j < 4; ++j) {
+ jit_set_flag(JitFlag::VCallBranch, j > 0);
+ jit_set_flag(JitFlag::VCallBranchBinarySearch, j == 2);
+ jit_set_flag(JitFlag::VCallBranchJumpTable, j == 3);
- jit_assert(strcmp(y.str(), "[7, 16]") == 0);
+ Float y = vcall(
+ "Base2",
+ [](Base2 *self_, const Base1Ptr &ptr_, const Float &x_) {
+ return self_->g(ptr_, x_);
+ },
+ Base2Ptr(self2), Base1Ptr(self1), x);
+
+ jit_assert(strcmp(y.str(), "[7, 16]") == 0);
+ }
jit_registry_remove(Backend, &i11);
jit_registry_remove(Backend, &i12);