Skip to content

Commit 2067ed1

Browse files
Memory layout optimization
1 parent dc5e058 commit 2067ed1

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

src/python/freeze.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
#include <sstream>
2424
#include <vector>
2525

26+
#define likely(x) DRJIT_LIKELY(x)
27+
#define unlikely(x) DRJIT_UNLIKELY(x)
28+
2629
/**
2730
* \brief Helper struct to profile and log frozen functions.
2831
*/
@@ -229,10 +232,10 @@ bool FlatVariables::fill_opaque_mask(FlatVariables &prev,
229232
Layout &layout = this->layout[i];
230233
Layout &prev_layout = prev.layout[i];
231234

232-
if (layout.flags & (uint32_t) LayoutFlag::Literal &&
233-
prev_layout.flags & (uint32_t) LayoutFlag::Literal &&
234-
(layout.literal != prev_layout.literal ||
235-
layout.index != prev_layout.index)) {
235+
if (unlikely(layout.flags & (uint32_t) LayoutFlag::Literal &&
236+
prev_layout.flags & (uint32_t) LayoutFlag::Literal &&
237+
(layout.literal != prev_layout.literal ||
238+
layout.index != prev_layout.index))) {
236239
opaque_mask[i] = true;
237240
new_opaques = true;
238241
}
@@ -288,14 +291,14 @@ void FlatVariables::schedule_jit_variables(bool schedule_force,
288291
// is nice to have a fallback.
289292
layout.literal = info.literal;
290293
// Store size in index variable, as this is not used for literals
291-
layout.index = info.size;
292-
layout.vt = info.type;
294+
layout.index = info.size;
295+
layout.vt = (uint32_t) info.type;
293296
layout.literal_index = index;
294297

295298
layout.flags |= (uint32_t) LayoutFlag::Literal;
296299
} else {
297300
layout.index = this->add_jit_index(index);
298-
layout.vt = info.type;
301+
layout.vt = (uint32_t) info.type;
299302
jit_var_dec_ref(index);
300303
}
301304
}
@@ -389,7 +392,7 @@ void FlatVariables::traverse_jit_index(uint32_t index, TraverseContext &ctx,
389392

390393
layout.flags |= (uint32_t) LayoutFlag::JitIndex;
391394
layout.index = index;
392-
layout.vt = jit_var_type(index);
395+
layout.vt = (uint32_t) jit_var_type(index);
393396
}
394397

395398
/**
@@ -408,7 +411,7 @@ uint32_t FlatVariables::construct_jit_index(uint32_t prev_index) {
408411
if (layout.flags & (uint32_t) LayoutFlag::Literal) {
409412
index = layout.literal_index;
410413
jit_var_inc_ref(index);
411-
vt = layout.vt;
414+
vt = (VarType) layout.vt;
412415
} else {
413416
VarLayout &var_layout = this->var_layout[layout.index];
414417
index = this->variables[layout.index];
@@ -1706,6 +1709,7 @@ nb::object FunctionRecording::replay(nb::callable func,
17061709
}
17071710

17081711
nb::object FrozenFunction::operator()(nb::args args, nb::kwargs kwargs) {
1712+
ProfilerPhase profiler("frozen function");
17091713
nb::object result;
17101714
{
17111715
// Enter Isolate grad scope, so that gradients are not propagated

src/python/freeze.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,27 @@ enum class LayoutFlag : uint32_t {
4545
/// sub-elements or their field keys. This can be used to reconstruct a PyTree
4646
/// from a flattened variable array.
4747
struct Layout {
48-
/// Number of members in this container.
49-
/// Can be used to traverse the layout without knowing the type.
50-
uint32_t num = 0;
48+
49+
/// The literal data
50+
uint64_t literal = 0;
51+
5152
/// Optional field identifiers of the container
5253
/// for example: keys in dictionary
5354
drjit::vector<nb::object> fields;
55+
56+
/// Number of members in this container.
57+
/// Can be used to traverse the layout without knowing the type.
58+
uint32_t num = 0;
5459
/// The index in the flat_variables array of this variable.
5560
/// This can be used to determine aliasing.
5661
uint32_t index = 0;
5762

5863
/// Flags, storing information about variables and literals.
59-
uint32_t flags = 0;
64+
uint32_t flags : 6; // LayoutFlag
6065

61-
/// The literal data
62-
uint64_t literal = 0;
6366
/// Optional drjit type of the variable
64-
VarType vt = VarType::Void;
67+
uint32_t vt: 4; // VarType
68+
6569
/// Variable index of literal. Instead of constructing a literal every time,
6670
/// we keep a reference to it.
6771
uint32_t literal_index = 0;
@@ -77,7 +81,9 @@ struct Layout {
7781
bool operator==(const Layout &rhs) const;
7882
bool operator!=(const Layout &rhs) const { return !(*this == rhs); }
7983

80-
Layout() = default;
84+
Layout()
85+
: literal(0), fields(), num(0), index(0), flags(0), vt(0),
86+
literal_index(0), py_object(), type() {};
8187

8288
Layout(const Layout &) = delete;
8389
Layout &operator=(const Layout &) = delete;

0 commit comments

Comments
 (0)