@@ -292,6 +292,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
292
292
// corresponding bit in the mask, indicating that this literal should be
293
293
// made opaque next time.
294
294
uint32_t opaque_cunter = 0 ;
295
+ bool new_opaques = false ;
295
296
for (uint32_t i = 0 ; i < this ->layout .size (); i++) {
296
297
Layout &layout = this ->layout [i];
297
298
Layout &prev_layout = prev.layout [i];
@@ -300,6 +301,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
300
301
prev_layout.flags & (uint32_t ) LayoutFlag::Literal &&
301
302
(layout.literal != prev_layout.literal )) {
302
303
opaque_mask[i] = true ;
304
+ new_opaques = true ;
303
305
}
304
306
if (opaque_mask[i])
305
307
opaque_cunter++;
@@ -309,7 +311,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
309
311
" compare_opaque(): %u variables will be made opaque" ,
310
312
opaque_cunter);
311
313
312
- return true ;
314
+ return new_opaques ;
313
315
}
314
316
315
317
void FlatVariables::schedule_jit_variables (bool schedule_force,
@@ -1478,9 +1480,6 @@ nb::object FunctionRecording::record(nb::callable func,
1478
1480
out_variables.traverse (output, ctx);
1479
1481
out_variables.schedule_jit_variables (false , nullptr );
1480
1482
1481
- // if (!frozen_func->prev_key)
1482
- // frozen_func->in_opaque_mask.resize(out_variables.layout.size(), false);
1483
-
1484
1483
out_variables.traverse_with_registry (input, ctx);
1485
1484
out_variables.schedule_jit_variables (false , nullptr );
1486
1485
@@ -1639,7 +1638,9 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
1639
1638
auto in_variables =
1640
1639
std::make_shared<FlatVariables>(FlatVariables (in_heuristics));
1641
1640
// Evaluate and traverse input variables (args and kwargs)
1642
- {
1641
+ // Repeat this a max of 2 times if the number of variables that should
1642
+ // be made opaque changed.
1643
+ for (uint32_t i = 0 ; i < 2 ; i++) {
1643
1644
// Enter Resume scope, so we can track gradients
1644
1645
ADScopeContext ad_scope (drjit::ADScope::Resume, 0 , nullptr , 0 ,
1645
1646
true );
@@ -1679,8 +1680,21 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
1679
1680
}
1680
1681
1681
1682
in_variables->record_jit_variables ();
1683
+ bool new_opaques = false ;
1682
1684
if (prev_key && auto_opaque)
1683
- in_variables->fill_opaque_mask (*prev_key, opaque_mask);
1685
+ new_opaques =
1686
+ in_variables->fill_opaque_mask (*prev_key, opaque_mask);
1687
+
1688
+ if (new_opaques) {
1689
+ // If new variables have been discovered that should be made
1690
+ // opaque, we repeat traversal of the input to make them opaque.
1691
+ // This reduces the number of variants that are saved by one.
1692
+ in_variables->release ();
1693
+ in_variables = std::make_shared<FlatVariables>(
1694
+ FlatVariables (in_heuristics));
1695
+ } else {
1696
+ break ;
1697
+ }
1684
1698
}
1685
1699
1686
1700
in_heuristics = in_heuristics.max (in_variables->heuristic ());
0 commit comments