Skip to content

Commit 7b42ffa

Browse files
Added option to re-traverse the input when new opaque candidates have been discovered
1 parent 5425b17 commit 7b42ffa

File tree

3 files changed

+46
-6
lines changed

3 files changed

+46
-6
lines changed

src/python/freeze.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
292292
// corresponding bit in the mask, indicating that this literal should be
293293
// made opaque next time.
294294
uint32_t opaque_cunter = 0;
295+
bool new_opaques = false;
295296
for (uint32_t i = 0; i < this->layout.size(); i++) {
296297
Layout &layout = this->layout[i];
297298
Layout &prev_layout = prev.layout[i];
@@ -300,6 +301,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
300301
prev_layout.flags & (uint32_t) LayoutFlag::Literal &&
301302
(layout.literal != prev_layout.literal)) {
302303
opaque_mask[i] = true;
304+
new_opaques = true;
303305
}
304306
if (opaque_mask[i])
305307
opaque_cunter++;
@@ -309,7 +311,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
309311
"compare_opaque(): %u variables will be made opaque",
310312
opaque_cunter);
311313

312-
return true;
314+
return new_opaques;
313315
}
314316

315317
void FlatVariables::schedule_jit_variables(bool schedule_force,
@@ -1478,9 +1480,6 @@ nb::object FunctionRecording::record(nb::callable func,
14781480
out_variables.traverse(output, ctx);
14791481
out_variables.schedule_jit_variables(false, nullptr);
14801482

1481-
// if (!frozen_func->prev_key)
1482-
// frozen_func->in_opaque_mask.resize(out_variables.layout.size(), false);
1483-
14841483
out_variables.traverse_with_registry(input, ctx);
14851484
out_variables.schedule_jit_variables(false, nullptr);
14861485

@@ -1639,7 +1638,9 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
16391638
auto in_variables =
16401639
std::make_shared<FlatVariables>(FlatVariables(in_heuristics));
16411640
// Evaluate and traverse input variables (args and kwargs)
1642-
{
1641+
// Repeat this a max of 2 times if the number of variables that should
1642+
// be made opaque changed.
1643+
for (uint32_t i = 0; i < 2; i++) {
16431644
// Enter Resume scope, so we can track gradients
16441645
ADScopeContext ad_scope(drjit::ADScope::Resume, 0, nullptr, 0,
16451646
true);
@@ -1679,8 +1680,21 @@ nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
16791680
}
16801681

16811682
in_variables->record_jit_variables();
1683+
bool new_opaques = false;
16821684
if (prev_key && auto_opaque)
1683-
in_variables->fill_opaque_mask(*prev_key, opaque_mask);
1685+
new_opaques =
1686+
in_variables->fill_opaque_mask(*prev_key, opaque_mask);
1687+
1688+
if (new_opaques) {
1689+
// If new variables have been discovered that should be made
1690+
// opaque, we repeat traversal of the input to make them opaque.
1691+
// This reduces the number of variants that are saved by one.
1692+
in_variables->release();
1693+
in_variables = std::make_shared<FlatVariables>(
1694+
FlatVariables(in_heuristics));
1695+
} else {
1696+
break;
1697+
}
16841698
}
16851699

16861700
in_heuristics = in_heuristics.max(in_variables->heuristic());

src/python/freeze.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ struct FlatVariables {
245245
* Generates a mask of variables that should be made opaque in the next
246246
* iteration. This should only be called if \c compatible_auto_opaque
247247
* returns true for the corresponding \c FlatVariables pair.
248+
*
249+
* Returns true if new variables have been discovered that should be made
250+
* opaque, otherwise returns false.
248251
*/
249252
bool fill_opaque_mask(FlatVariables &prev, std::vector<bool> &opaque_mask);
250253

tests/test_freeze.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,6 +2851,29 @@ def func(x: t):
28512851

28522852
assert dr.allclose(ref, res)
28532853

2854+
assert frozen.n_recordings < n
2855+
2856+
2857+
@pytest.test_arrays("float32, jit, shape=(*)")
2858+
def test73_auto_opaque_retraverse(t):
2859+
2860+
def func(x: t):
2861+
return x + 1
2862+
2863+
frozen = dr.freeze(func, auto_opaque=True)
2864+
2865+
n = 3
2866+
for i in range(n):
2867+
x = t(i)
2868+
2869+
res = frozen(x)
2870+
ref = func(x)
2871+
2872+
assert dr.allclose(ref, res)
2873+
2874+
assert frozen.n_recordings == 2
2875+
2876+
28542877

28552878
# @pytest.test_arrays("float32, jit, diff, shape=(*)")
28562879
# def test42_raise(t):

0 commit comments

Comments
 (0)