@@ -30,9 +30,9 @@ struct ProfilerPhase {
30
30
std::string m_message;
31
31
ProfilerPhase (const char *message) : m_message(message) {
32
32
jit_log (LogLevel::Debug, " profiler start: %s" , message);
33
- #if defined(DRJIT_ENABLE_NVTX)
33
+ // #if defined(DRJIT_ENABLE_NVTX)
34
34
jit_profile_range_push (message);
35
- #endif
35
+ // #endif
36
36
}
37
37
38
38
ProfilerPhase (const drjit::TraversableBase *traversable) {
@@ -47,9 +47,9 @@ struct ProfilerPhase {
47
47
}
48
48
49
49
~ProfilerPhase () {
50
- #if defined(DRJIT_ENABLE_NVTX)
50
+ // #if defined(DRJIT_ENABLE_NVTX)
51
51
jit_profile_range_pop ();
52
- #endif
52
+ // #endif
53
53
jit_log (LogLevel::Debug, " profiler end: %s" , m_message.c_str ());
54
54
}
55
55
};
@@ -208,53 +208,13 @@ uint32_t FlatVariables::add_jit_index(uint32_t index) {
208
208
* consists of changes in literal values.
209
209
*/
210
210
bool compatible_auto_opaque (FlatVariables &cur, FlatVariables &prev){
211
+ // NOTE: We only test the size of the layout, as a full test is somewhat
212
+ // expensive, and the worst case is that we make too many variables opaque,
213
+ // which does not impact correctness. If this causes problems, more
214
+ // extensive tests might have to be reintroduced.
211
215
if (cur.layout .size () != prev.layout .size ()) {
212
216
return false ;
213
217
}
214
- for (uint32_t i = 0 ; i < cur.layout .size (); i++) {
215
- Layout &cur_layout = cur.layout [i];
216
- Layout &prev_layout = prev.layout [i];
217
-
218
- if (((bool ) cur_layout.type != (bool ) prev_layout.type ) ||
219
- !(cur_layout.type .equal (prev_layout.type ))){
220
- // jit_log(LogLevel::Warn, "type");
221
- return false ;
222
- }
223
-
224
- if (cur_layout.num != prev_layout.num ){
225
- // jit_log(LogLevel::Warn, "num");
226
- return false ;
227
- }
228
-
229
- if (cur_layout.fields .size () != prev_layout.fields .size ()){
230
- // jit_log(LogLevel::Warn, "fields.size()");
231
- return false ;
232
- }
233
-
234
- for (uint32_t i = 0 ; i < cur_layout.fields .size (); ++i) {
235
- if (!(cur_layout.fields [i].equal (prev_layout.fields [i]))){
236
- // jit_log(LogLevel::Warn, "fields[%u]", i);
237
- return false ;
238
- }
239
- }
240
-
241
- // if (cur_layout.index != prev_layout.index)
242
- // return false;
243
-
244
- // if (cur_layout.flags != prev_layout.flags)
245
- // return false;
246
-
247
- if (cur_layout.vt != prev_layout.vt ){
248
- // jit_log(LogLevel::Warn, "vt");
249
- return false ;
250
- }
251
-
252
- if (((bool ) cur_layout.py_object != (bool ) prev_layout.py_object ) ||
253
- !cur_layout.py_object .equal (prev_layout.py_object )){
254
- // jit_log(LogLevel::Warn, "py_object");
255
- return false ;
256
- }
257
- }
258
218
return true ;
259
219
}
260
220
@@ -1477,12 +1437,12 @@ bool log_diff(LogLevel level, const FlatVariables &curr,
1477
1437
1478
1438
return true ;
1479
1439
}
1480
- inline void hash_combine (size_t &seed, size_t value) {
1440
+ inline void hash_combine (uint64_t &seed, uint64_t value) {
1481
1441
// / From CityHash (https://github.com/google/cityhash)
1482
- const size_t mult = 0x9ddfea08eb382d69ull ;
1483
- size_t a = (value ^ seed) * mult;
1442
+ const uint64_t mult = 0x9ddfea08eb382d69ull ;
1443
+ uint64_t a = (value ^ seed) * mult;
1484
1444
a ^= (a >> 47 );
1485
- size_t b = (seed ^ a) * mult;
1445
+ uint64_t b = (seed ^ a) * mult;
1486
1446
b ^= (b >> 47 );
1487
1447
seed = b * mult;
1488
1448
}
@@ -1491,46 +1451,39 @@ size_t
1491
1451
FlatVariablesHasher::operator ()(const std::shared_ptr<FlatVariables> &key) const {
1492
1452
ProfilerPhase profiler (" hash" );
1493
1453
// Hash the layout
1494
- // NOTE: string hashing seems to be less efficient
1495
- size_t hash = key->layout .size ();
1496
- for (const Layout &layout : key->layout ) {
1497
- hash_combine (hash, layout.num );
1498
- hash_combine (hash, layout.fields .size ());
1499
- hash_combine (hash, (size_t ) layout.flags );
1500
- hash_combine (hash, (size_t ) layout.index );
1501
- hash_combine (hash, (size_t ) layout.literal );
1502
- hash_combine (hash, (size_t ) layout.vt );
1503
- if (layout.type )
1504
- hash_combine (hash, nb::hash (layout.type ));
1505
- if (layout.py_object ){
1506
- PyObject *ptr = layout.py_object .ptr ();
1507
- uint32_t object_hash;
1508
- Py_hash_t rv = PyObject_Hash (ptr);
1509
-
1510
- // Try to hash the object, and otherwise fallback to ``id()``
1511
- if (rv == -1 && PyErr_Occurred ()) {
1512
- PyErr_Clear ();
1513
- object_hash = (uintptr_t ) ptr;
1514
- } else {
1515
- object_hash = rv;
1516
- }
1517
1454
1518
- hash_combine (hash, object_hash);
1519
- }
1520
- for (auto &field : layout.fields ) {
1521
- hash_combine (hash, nb::hash (field));
1522
- }
1455
+ // TODO: Maybe we can use xxh by first collecting in vector<uint64_t>?
1456
+
1457
+ uint64_t hash = (uint64_t ) (key->layout .size () << 32 ) |
1458
+ (uint64_t ) (key->var_layout .size () << 2 );
1459
+
1460
+ for (const Layout &layout : key->layout ) {
1461
+ // if layout.fields is not 0 then layout.num == layout.fields.size()
1462
+ // therefore we can omit layout.fields.size().
1463
+ hash_combine (hash,
1464
+ ((uint64_t ) layout.num << 32 ) | ((uint64_t ) layout.index ));
1465
+ hash_combine (hash, layout.literal );
1466
+ hash_combine (hash,
1467
+ ((uint64_t ) layout.flags << 32 ) | ((uint64_t ) layout.vt ));
1468
+ hash_combine (
1469
+ hash,
1470
+ ((uint64_t ) (uint32_t ) (uintptr_t ) layout.type .ptr () << 32 ) |
1471
+ ((uint64_t ) (uint32_t ) (uintptr_t ) layout.py_object .ptr ()));
1472
+ hash_combine (hash, (uint64_t ) layout.vt );
1473
+ for (auto &field : layout.fields )
1474
+ hash_combine (hash, (uintptr_t )field.ptr ());
1523
1475
}
1524
1476
1525
1477
for (const VarLayout &layout : key->var_layout ) {
1526
- hash_combine (hash, (size_t ) layout.vt );
1527
- hash_combine (hash, (size_t ) layout.vs );
1528
- hash_combine (hash, (size_t ) layout.flags );
1529
- hash_combine (hash, (size_t ) layout.size_index );
1478
+ // layout.vt: 4
1479
+ // layout.vs: 4
1480
+ // layout.flags: 6
1481
+ hash_combine (hash, ((uint64_t ) layout.size_index << 32 ) |
1482
+ ((uint64_t ) layout.flags << 8 ) |
1483
+ ((uint64_t ) layout.vs << 4 ) |
1484
+ ((uint64_t ) layout.vt ));
1530
1485
}
1531
1486
1532
- hash_combine (hash, (size_t ) key->flags );
1533
-
1534
1487
return hash;
1535
1488
}
1536
1489
0 commit comments