Skip to content

Commit 5e8be1c

Browse files
Improved hashing performance and simplified compatible_auto_opaque
1 parent 2f3a144 commit 5e8be1c

File tree

2 files changed

+41
-86
lines changed

2 files changed

+41
-86
lines changed

src/python/freeze.cpp

Lines changed: 39 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ struct ProfilerPhase {
3030
std::string m_message;
3131
ProfilerPhase(const char *message) : m_message(message) {
3232
jit_log(LogLevel::Debug, "profiler start: %s", message);
33-
#if defined(DRJIT_ENABLE_NVTX)
33+
// #if defined(DRJIT_ENABLE_NVTX)
3434
jit_profile_range_push(message);
35-
#endif
35+
// #endif
3636
}
3737

3838
ProfilerPhase(const drjit::TraversableBase *traversable) {
@@ -47,9 +47,9 @@ struct ProfilerPhase {
4747
}
4848

4949
~ProfilerPhase() {
50-
#if defined(DRJIT_ENABLE_NVTX)
50+
// #if defined(DRJIT_ENABLE_NVTX)
5151
jit_profile_range_pop();
52-
#endif
52+
// #endif
5353
jit_log(LogLevel::Debug, "profiler end: %s", m_message.c_str());
5454
}
5555
};
@@ -208,53 +208,13 @@ uint32_t FlatVariables::add_jit_index(uint32_t index) {
208208
* consists of changes in literal values.
209209
*/
210210
bool compatible_auto_opaque(FlatVariables &cur, FlatVariables &prev){
211+
// NOTE: We only test the size of the layout, as a full test is somewhat
212+
// expensive, and the worst case is that we make too many variables opaque,
213+
// which does not impact correctness. If this causes problems, more
214+
// extensive tests might have to be reintroduced.
211215
if (cur.layout.size() != prev.layout.size()) {
212216
return false;
213217
}
214-
for (uint32_t i = 0; i < cur.layout.size(); i++) {
215-
Layout &cur_layout = cur.layout[i];
216-
Layout &prev_layout = prev.layout[i];
217-
218-
if (((bool) cur_layout.type != (bool) prev_layout.type) ||
219-
!(cur_layout.type.equal(prev_layout.type))){
220-
// jit_log(LogLevel::Warn, "type");
221-
return false;
222-
}
223-
224-
if (cur_layout.num != prev_layout.num){
225-
// jit_log(LogLevel::Warn, "num");
226-
return false;
227-
}
228-
229-
if (cur_layout.fields.size() != prev_layout.fields.size()){
230-
// jit_log(LogLevel::Warn, "fields.size()");
231-
return false;
232-
}
233-
234-
for (uint32_t i = 0; i < cur_layout.fields.size(); ++i) {
235-
if (!(cur_layout.fields[i].equal(prev_layout.fields[i]))){
236-
// jit_log(LogLevel::Warn, "fields[%u]", i);
237-
return false;
238-
}
239-
}
240-
241-
// if (cur_layout.index != prev_layout.index)
242-
// return false;
243-
244-
// if (cur_layout.flags != prev_layout.flags)
245-
// return false;
246-
247-
if (cur_layout.vt != prev_layout.vt){
248-
// jit_log(LogLevel::Warn, "vt");
249-
return false;
250-
}
251-
252-
if (((bool) cur_layout.py_object != (bool) prev_layout.py_object) ||
253-
!cur_layout.py_object.equal(prev_layout.py_object)){
254-
// jit_log(LogLevel::Warn, "py_object");
255-
return false;
256-
}
257-
}
258218
return true;
259219
}
260220

@@ -1477,12 +1437,12 @@ bool log_diff(LogLevel level, const FlatVariables &curr,
14771437

14781438
return true;
14791439
}
1480-
inline void hash_combine(size_t &seed, size_t value) {
1440+
inline void hash_combine(uint64_t &seed, uint64_t value) {
14811441
/// From CityHash (https://github.com/google/cityhash)
1482-
const size_t mult = 0x9ddfea08eb382d69ull;
1483-
size_t a = (value ^ seed) * mult;
1442+
const uint64_t mult = 0x9ddfea08eb382d69ull;
1443+
uint64_t a = (value ^ seed) * mult;
14841444
a ^= (a >> 47);
1485-
size_t b = (seed ^ a) * mult;
1445+
uint64_t b = (seed ^ a) * mult;
14861446
b ^= (b >> 47);
14871447
seed = b * mult;
14881448
}
@@ -1491,46 +1451,39 @@ size_t
14911451
FlatVariablesHasher::operator()(const std::shared_ptr<FlatVariables> &key) const {
14921452
ProfilerPhase profiler("hash");
14931453
// Hash the layout
1494-
// NOTE: string hashing seems to be less efficient
1495-
size_t hash = key->layout.size();
1496-
for (const Layout &layout : key->layout) {
1497-
hash_combine(hash, layout.num);
1498-
hash_combine(hash, layout.fields.size());
1499-
hash_combine(hash, (size_t) layout.flags);
1500-
hash_combine(hash, (size_t) layout.index);
1501-
hash_combine(hash, (size_t) layout.literal);
1502-
hash_combine(hash, (size_t) layout.vt);
1503-
if (layout.type)
1504-
hash_combine(hash, nb::hash(layout.type));
1505-
if (layout.py_object){
1506-
PyObject *ptr = layout.py_object.ptr();
1507-
uint32_t object_hash;
1508-
Py_hash_t rv = PyObject_Hash(ptr);
1509-
1510-
// Try to hash the object, and otherwise fallback to ``id()``
1511-
if (rv == -1 && PyErr_Occurred()) {
1512-
PyErr_Clear();
1513-
object_hash = (uintptr_t) ptr;
1514-
} else {
1515-
object_hash = rv;
1516-
}
15171454

1518-
hash_combine(hash, object_hash);
1519-
}
1520-
for (auto &field : layout.fields) {
1521-
hash_combine(hash, nb::hash(field));
1522-
}
1455+
// TODO: Consider hashing with xxHash (XXH) by first collecting the values in a std::vector<uint64_t>.
1456+
1457+
uint64_t hash = (uint64_t) (key->layout.size() << 32) |
1458+
(uint64_t) (key->var_layout.size() << 2);
1459+
1460+
for (const Layout &layout : key->layout) {
1461+
// If layout.fields is non-empty, then layout.num == layout.fields.size(),
1462+
// so layout.fields.size() can be omitted from the hash.
1463+
hash_combine(hash,
1464+
((uint64_t) layout.num << 32) | ((uint64_t) layout.index));
1465+
hash_combine(hash, layout.literal);
1466+
hash_combine(hash,
1467+
((uint64_t) layout.flags << 32) | ((uint64_t) layout.vt));
1468+
hash_combine(
1469+
hash,
1470+
((uint64_t) (uint32_t) (uintptr_t) layout.type.ptr() << 32) |
1471+
((uint64_t) (uint32_t) (uintptr_t) layout.py_object.ptr()));
1472+
hash_combine(hash, (uint64_t) layout.vt);
1473+
for (auto &field : layout.fields)
1474+
hash_combine(hash, (uintptr_t)field.ptr());
15231475
}
15241476

15251477
for (const VarLayout &layout : key->var_layout) {
1526-
hash_combine(hash, (size_t) layout.vt);
1527-
hash_combine(hash, (size_t) layout.vs);
1528-
hash_combine(hash, (size_t) layout.flags);
1529-
hash_combine(hash, (size_t) layout.size_index);
1478+
// layout.vt: 4 bits
1479+
// layout.vs: 4 bits
1480+
// layout.flags: 6 bits
1481+
hash_combine(hash, ((uint64_t) layout.size_index << 32) |
1482+
((uint64_t) layout.flags << 8) |
1483+
((uint64_t) layout.vs << 4) |
1484+
((uint64_t) layout.vt));
15301485
}
15311486

1532-
hash_combine(hash, (size_t) key->flags);
1533-
15341487
return hash;
15351488
}
15361489

src/python/freeze.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ struct Layout {
6262
uint64_t literal = 0;
6363
/// Optional drjit type of the variable
6464
VarType vt = VarType::Void;
65+
/// Variable index of literal. Instead of constructing a literal every time,
66+
/// we keep a reference to it.
6567
uint32_t literal_index = 0;
6668

6769
/// If a non drjit type is passed as function arguments or result, we simply

0 commit comments

Comments
 (0)