Skip to content

Commit 5e28fe7

Browse files
Added option to re-traverse the input when new opaque candidates have been discovered
1 parent 3b9536c commit 5e28fe7

File tree

3 files changed

+46
-6
lines changed

3 files changed

+46
-6
lines changed

src/python/freeze.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
275275
// corresponding bit in the mask, indicating that this literal should be
276276
// made opaque next time.
277277
uint32_t opaque_cunter = 0;
278+
bool new_opaques = false;
278279
for (uint32_t i = 0; i < this->layout.size(); i++) {
279280
Layout &layout = this->layout[i];
280281
Layout &prev_layout = prev.layout[i];
@@ -283,6 +284,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
283284
prev_layout.flags & (uint32_t) LayoutFlag::Literal &&
284285
(layout.literal != prev_layout.literal)) {
285286
opaque_mask[i] = true;
287+
new_opaques = true;
286288
}
287289
if (opaque_mask[i])
288290
opaque_cunter++;
@@ -292,7 +294,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
292294
"compare_opaque(): %u variables will be made opaque",
293295
opaque_cunter);
294296

295-
return true;
297+
return new_opaques;
296298
}
297299

298300
void FlatVariables::schedule_jit_variables(bool schedule_force,
@@ -1482,9 +1484,6 @@ nb::object FunctionRecording::record(nb::callable func,
14821484
out_variables.traverse(output, ctx);
14831485
out_variables.schedule_jit_variables(false, nullptr);
14841486

1485-
// if (!frozen_func->prev_key)
1486-
// frozen_func->in_opaque_mask.resize(out_variables.layout.size(), false);
1487-
14881487
out_variables.traverse_with_registry(input, ctx);
14891488
out_variables.schedule_jit_variables(false, nullptr);
14901489

@@ -1643,7 +1642,9 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
16431642
auto in_variables =
16441643
std::make_shared<FlatVariables>(FlatVariables(in_heuristics));
16451644
// Evaluate and traverse input variables (args and kwargs)
1646-
{
1645+
// Repeat this at most 2 times: the traversal is redone once if new
1646+
// variables that should be made opaque were discovered in the first pass.
1647+
for (uint32_t i = 0; i < 2; i++) {
16471648
// Enter Resume scope, so we can track gradients
16481649
ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, 0,
16491650
true);
@@ -1683,8 +1684,21 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
16831684
}
16841685

16851686
in_variables->record_jit_variables();
1687+
bool new_opaques = false;
16861688
if (prev_key && auto_opaque)
1687-
in_variables->fill_opaque_mask(*prev_key, opaque_mask);
1689+
new_opaques =
1690+
in_variables->fill_opaque_mask(*prev_key, opaque_mask);
1691+
1692+
if (new_opaques) {
1693+
// If new variables have been discovered that should be made
1694+
// opaque, we repeat traversal of the input to make them opaque.
1695+
// This reduces the number of variants that are saved by one.
1696+
in_variables->release();
1697+
in_variables = std::make_shared<FlatVariables>(
1698+
FlatVariables(in_heuristics));
1699+
} else {
1700+
break;
1701+
}
16881702
}
16891703

16901704
in_heuristics = in_heuristics.max(in_variables->heuristic());

src/python/freeze.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ struct FlatVariables {
245245
* Generates a mask of variables that should be made opaque in the next
246246
* iteration. This should only be called if \c compatible_auto_opaque
247247
* returns true for the corresponding \c FlatVariables pair.
248+
*
249+
* Returns true if new variables have been discovered that should be made
250+
* opaque, otherwise returns false.
248251
*/
249252
bool fill_opaque_mask(FlatVariables &prev, std::vector<bool> &opaque_mask);
250253

tests/test_freeze.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,6 +2851,29 @@ def func(x: t):
28512851

28522852
assert dr.allclose(ref, res)
28532853

2854+
assert frozen.n_recordings < n
2855+
2856+
2857+
@pytest.test_arrays("float32, jit, shape=(*)")
2858+
def test73_auto_opaque_retraverse(t):
2859+
2860+
def func(x: t):
2861+
return x + 1
2862+
2863+
frozen = dr.freeze(func, auto_opaque=True)
2864+
2865+
n = 3
2866+
for i in range(n):
2867+
x = t(i)
2868+
2869+
res = frozen(x)
2870+
ref = func(x)
2871+
2872+
assert dr.allclose(ref, res)
2873+
2874+
assert frozen.n_recordings == 2
2875+
2876+
28542877

28552878
# @pytest.test_arrays("float32, jit, diff, shape=(*)")
28562879
# def test42_raise(t):

0 commit comments

Comments
 (0)