Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions src/python/freeze.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
(layout_.flags & (uint32_t) LayoutFlag::Literal) &&
(prev_layout.flags & (uint32_t) LayoutFlag::Literal) &&
(layout_.literal != prev_layout.literal ||
layout_.literal_size != prev_layout.literal_size);
this->sizes[layout_.literal_size] != prev.sizes[prev_layout.literal_size]);

opaque_mask[i] |= requires_opaque;
new_opaques |= requires_opaque;
Expand Down Expand Up @@ -420,16 +420,16 @@ void FlatVariables::schedule_jit_variables(
if (info.state == VarState::Literal) {
// Special case, where the variable is a literal.
layout_.literal = info.literal;
// Store size in index variable, as this is not used for literals.
layout_.literal_size = (uint32_t) info.size;
// Store size index, as this is not used for literals.
layout_.literal_size = this->add_size((uint32_t) info.size);
layout_.vt = (uint32_t) info.type;
layout_.literal_index = index;

layout_.flags |= (uint32_t) LayoutFlag::Literal;
} else if (info.state == VarState::Undefined) {
// Special case, where the variable is a literal.
// Store size in index variable, as this is not used for literals.
layout_.literal_size = (uint32_t) info.size;
// Store size index, as this is not used for literals.
layout_.literal_size = this->add_size((uint32_t) info.size);
layout_.vt = (uint32_t) info.type;
layout_.literal_index = index;

Expand Down Expand Up @@ -552,6 +552,13 @@ uint32_t FlatVariables::construct_jit_index(uint32_t prev_index) {
index = layout_.literal_index;
jit_var_inc_ref(index);
vt = (VarType) layout_.vt;

uint32_t target_size = this->sizes[layout_.literal_size];
if (jit_var_size(index) != target_size) {
uint32_t new_index = jit_var_resize(index, target_size);
jit_var_dec_ref(index);
index = new_index;
}
} else {
VarLayout &var_layout_ = this->var_layout[layout_.index];
index = this->variables[layout_.index];
Expand Down Expand Up @@ -1903,6 +1910,11 @@ nb::object FunctionRecording::replay(nb::callable func,
jit_freeze_replay(recording, in_variables.variables.data(),
out_variables.variables.data());
}
// Update the size equivalence classes of the output variables
for (uint32_t i = 0; i < out_variables.variables.size(); i++) {
uint32_t size_index = out_variables.var_layout[i].size_index;
out_variables.sizes[size_index] = (uint32_t) jit_var_size(out_variables.variables[i]);
}
}
jit_log(LogLevel::Info, "Replaying done:");

Expand Down
24 changes: 24 additions & 0 deletions tests/test_freeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -3922,3 +3922,27 @@ def func(A, B):
assert tuple(res.shape) == tuple(ref.shape)
assert dr.allclose(res, ref, atol=1e-3)

@pytest.mark.parametrize("auto_opaque", [False, True])
@pytest.test_arrays("float32, jit, shape=(*)")
def test106_literal_resizing(t, auto_opaque):
"""
Tests that literal outputs in a frozen function are correctly resized on replay
when input sizes change.
"""
x = t(1.5)

@dr.freeze(auto_opaque=auto_opaque)
def func(x, y):
return x * y, 0.0 * y

y0 = t(10.0, 20.0)
res0 = func(x, y0)
ref0 = (x * y0, 0.0 * y0)
assert dr.allclose(res0[0], ref0[0])
assert dr.allclose(res0[1], ref0[1])

y1 = t(10.0, 20.0, 30.0)
res1 = func(x, y1)
ref1 = (x * y1, 0.0 * y1)
assert dr.allclose(res1[0], ref1[0])
assert dr.allclose(res1[1], ref1[1])