Skip to content

Commit 99838f8

Browse files
Fixed texture performance regression by handling uninitialized variables
1 parent c3073cc commit 99838f8

File tree

4 files changed

+54
-18
lines changed

4 files changed

+54
-18
lines changed

src/python/freeze.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ void FlatVariables::schedule_jit_variables(bool schedule_force,
265265
// We have to force scheduling of undefined variables, in order to
266266
// handle variables initialized with ``empty``.
267267
if (schedule_force ||
268-
(opaque_mask && (*opaque_mask)[i - layout_index]) ||
269-
jit_var_state(index) == VarState::Undefined) {
268+
(opaque_mask && (*opaque_mask)[i - layout_index]) /* ||
269+
jit_var_state(index) == VarState::Undefined */) {
270270
// Returns owning reference
271271
index = jit_var_schedule_force(index, &rv);
272272
} else {
@@ -286,16 +286,22 @@ void FlatVariables::schedule_jit_variables(bool schedule_force,
286286
}
287287

288288
if (info.state == VarState::Literal) {
289-
// Special case, where the variable is a literal. This should not
290-
// occur, as all literals are made opaque in beforehand, however it
291-
// is nice to have a fallback.
289+
// Special case, where the variable is a literal.
292290
layout.literal = info.literal;
293-
// Store size in index variable, as this is not used for literals
291+
// Store size in index variable, as this is not used for literals.
294292
layout.index = info.size;
295293
layout.vt = (uint32_t) info.type;
296294
layout.literal_index = index;
297295

298296
layout.flags |= (uint32_t) LayoutFlag::Literal;
297+
} else if (info.state == VarState::Undefined) {
298+
// Special case, where the variable is undefined.
299+
// Store size in index variable, as this is not used for undefined variables.
300+
layout.index = info.size;
301+
layout.vt = (uint32_t) info.type;
302+
layout.literal_index = index;
303+
304+
layout.flags |= (uint32_t) LayoutFlag::Undefined;
299305
} else {
300306
layout.index = this->add_jit_index(index);
301307
layout.vt = (uint32_t) info.type;
@@ -408,7 +414,8 @@ uint32_t FlatVariables::construct_jit_index(uint32_t prev_index) {
408414

409415
uint32_t index;
410416
VarType vt;
411-
if (layout.flags & (uint32_t) LayoutFlag::Literal) {
417+
if ((layout.flags & (uint32_t) LayoutFlag::Literal) ||
418+
(layout.flags & (uint32_t) LayoutFlag::Undefined)) {
412419
index = layout.literal_index;
413420
jit_var_inc_ref(index);
414421
vt = (VarType) layout.vt;
@@ -1255,7 +1262,9 @@ FlatVariables::~FlatVariables() {
12551262
state_lock_guard guard;
12561263
for (uint32_t i = 0; i < layout.size(); ++i) {
12571264
Layout &l = layout[i];
1258-
if (l.flags & (uint32_t) LayoutFlag::Literal && l.literal_index) {
1265+
if (((l.flags & (uint32_t) LayoutFlag::Literal) ||
1266+
(l.flags & (uint32_t) LayoutFlag::Undefined)) &&
1267+
l.literal_index) {
12591268
jit_var_dec_ref(l.literal_index);
12601269
}
12611270
}
@@ -1279,12 +1288,12 @@ bool log_diff_variable(LogLevel level, const FlatVariables &curr,
12791288
const VarLayout &curr_l = curr.var_layout[slot];
12801289
const VarLayout &prev_l = prev.var_layout[slot];
12811290

1282-
if(curr_l.vt != prev_l.vt){
1291+
if (curr_l.vt != prev_l.vt) {
12831292
jit_log(level, "%s: The variable type changed from %u to %u.",
12841293
path.c_str(), prev_l.vt, curr_l.vt);
12851294
return false;
12861295
}
1287-
if(curr_l.size_index != prev_l.size_index){
1296+
if (curr_l.size_index != prev_l.size_index) {
12881297
jit_log(level,
12891298
"%s: The size equivalence class of the variable changed from "
12901299
"%u to %u.",
@@ -1321,7 +1330,8 @@ bool log_diff(LogLevel level, const FlatVariables &curr,
13211330
}
13221331

13231332
if (curr_l.flags & (uint32_t) LayoutFlag::JitIndex &&
1324-
!(curr_l.flags & (uint32_t) LayoutFlag::Literal)) {
1333+
!(curr_l.flags & (uint32_t) LayoutFlag::Literal) &&
1334+
!(curr_l.flags & (uint32_t) LayoutFlag::Undefined)) {
13251335
uint32_t slot = curr_l.index;
13261336
if (!log_diff_variable(level, curr, prev, path, slot))
13271337
return false;

src/python/freeze.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,17 @@ enum class LayoutFlag : uint32_t {
3131
/// Whether this variable is unaligned in memory
3232
Unaligned = (1 << 1),
3333
/// Whether this layout represents a literal variable
34-
Literal = (1 << 2),
34+
Literal = (1 << 2),
35+
/// Whether this layout represents an undefined variable (they behave
36+
/// similarly to literals)
37+
Undefined = (1 << 3),
3538
/// Whether this variable has gradients enabled
36-
GradEnabled = (1 << 3),
39+
GradEnabled = (1 << 4),
3740
/// Did this variable have gradient edges attached when recording, that
3841
/// were postponed by the ``isolate_grad`` function?
39-
Postponed = (1 << 4),
42+
Postponed = (1 << 5),
4043
/// Does this node represent a JIT Index?
41-
JitIndex = (1 << 5),
44+
JitIndex = (1 << 6),
4245
};
4346

4447
/// Stores information about python objects, such as their type, their number of
@@ -61,7 +64,7 @@ struct Layout {
6164
uint32_t index = 0;
6265

6366
/// Flags, storing information about variables and literals.
64-
uint32_t flags : 6; // LayoutFlag
67+
uint32_t flags : 8; // LayoutFlag
6568

6669
/// Optional drjit type of the variable
6770
uint32_t vt: 4; // VarType

tests/test_freeze.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3015,7 +3015,6 @@ def init_optimizer():
30153015
@pytest.test_arrays("float32, jit, shape=(*)")
30163016
@pytest.mark.parametrize("auto_opaque", [False, True])
30173017
def test78_hash_id_fallback(t, auto_opaque):
3018-
30193018
"""
30203019
Test the hash to id fallback for object hashing if the object is not
30213020
traversable nor hashable.
@@ -3047,3 +3046,27 @@ def func(x, test):
30473046

30483047
assert frozen.n_recordings == 3
30493048

3049+
@pytest.test_arrays("float32, jit, shape=(*)")
3050+
@pytest.mark.parametrize("auto_opaque", [False, True])
3051+
def test79_empty(t, auto_opaque):
3052+
3053+
n = 5
3054+
3055+
mod = sys.modules[t.__module__]
3056+
3057+
def func(x, i, v):
3058+
dr.scatter(x, v, i)
3059+
3060+
frozen = dr.freeze(func, auto_opaque = auto_opaque)
3061+
3062+
for i in range(n):
3063+
i = mod.UInt32(i)
3064+
3065+
res = dr.empty(t, n)
3066+
frozen(res, i, 1)
3067+
3068+
ref = dr.empty(t, n)
3069+
func(ref, i, 1)
3070+
3071+
assert res[i] == ref[i]
3072+

0 commit comments

Comments
 (0)