@@ -275,6 +275,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
275
275
// corresponding bit in the mask, indicating that this literal should be
276
276
// made opaque next time.
277
277
uint32_t opaque_cunter = 0 ;
278
+ bool new_opaques = false ;
278
279
for (uint32_t i = 0 ; i < this ->layout .size (); i++) {
279
280
Layout &layout = this ->layout [i];
280
281
Layout &prev_layout = prev.layout [i];
@@ -283,6 +284,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
283
284
prev_layout.flags & (uint32_t ) LayoutFlag::Literal &&
284
285
(layout.literal != prev_layout.literal )) {
285
286
opaque_mask[i] = true ;
287
+ new_opaques = true ;
286
288
}
287
289
if (opaque_mask[i])
288
290
opaque_cunter++;
@@ -292,7 +294,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
292
294
" compare_opaque(): %u variables will be made opaque" ,
293
295
opaque_cunter);
294
296
295
- return true ;
297
+ return new_opaques ;
296
298
}
297
299
298
300
void FlatVariables::schedule_jit_variables (bool schedule_force,
@@ -1482,9 +1484,6 @@ nb::object FunctionRecording::record(nb::callable func,
1482
1484
out_variables.traverse (output, ctx);
1483
1485
out_variables.schedule_jit_variables (false , nullptr );
1484
1486
1485
- // if (!frozen_func->prev_key)
1486
- // frozen_func->in_opaque_mask.resize(out_variables.layout.size(), false);
1487
-
1488
1487
out_variables.traverse_with_registry (input, ctx);
1489
1488
out_variables.schedule_jit_variables (false , nullptr );
1490
1489
@@ -1643,7 +1642,9 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
1643
1642
auto in_variables =
1644
1643
std::make_shared<FlatVariables>(FlatVariables (in_heuristics));
1645
1644
// Evaluate and traverse input variables (args and kwargs)
1646
- {
1645
+ // Repeat this a max of 2 times if the number of variables that should
1646
+ // be made opaque changed.
1647
+ for (uint32_t i = 0 ; i < 2 ; i++) {
1647
1648
// Enter Resume scope, so we can track gradients
1648
1649
ADScopeContext ad_scope (drjit::ADScope::Resume, 0 , nullptr , 0 ,
1649
1650
true );
@@ -1683,8 +1684,21 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
1683
1684
}
1684
1685
1685
1686
in_variables->record_jit_variables ();
1687
+ bool new_opaques = false ;
1686
1688
if (prev_key && auto_opaque)
1687
- in_variables->fill_opaque_mask (*prev_key, opaque_mask);
1689
+ new_opaques =
1690
+ in_variables->fill_opaque_mask (*prev_key, opaque_mask);
1691
+
1692
+ if (new_opaques) {
1693
+ // If new variables have been discovered that should be made
1694
+ // opaque, we repeat traversal of the input to make them opaque.
1695
+ // This reduces the number of saved variants by one.
1696
+ in_variables->release ();
1697
+ in_variables = std::make_shared<FlatVariables>(
1698
+ FlatVariables (in_heuristics));
1699
+ } else {
1700
+ break ;
1701
+ }
1688
1702
}
1689
1703
1690
1704
in_heuristics = in_heuristics.max (in_variables->heuristic ());
0 commit comments