Skip to content

Commit d758a00

Browse files
authored
Fix race condition involving wrapper lookup (#865)
Wrapper lookup can race with wrapper deallocation: a lookup may return a Python wrapper that is already in the process of being deallocated. This commit fixes the issue (see #864 for further details).
1 parent e134487 commit d758a00

File tree

3 files changed

+81
-5
lines changed

3 files changed

+81
-5
lines changed

src/nb_type.cpp

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,60 @@ static PyObject **nb_weaklist_ptr(PyObject *self) {
4040
return weaklistoffset ? (PyObject **) ((uint8_t *) self + weaklistoffset) : nullptr;
4141
}
4242

43+
// Marks a freshly constructed instance so that nb_try_inc_ref() can later
// succeed on it. Called from inst_new_int() and inst_new_ext() right after the
// instance fields are initialized.
//
// On Py_GIL_DISABLED builds this stores _Py_REF_MAYBE_WEAKREF into
// obj->ob_ref_shared; nb_try_inc_ref() refuses to take a reference through
// the shared count while that field is 0. On all other builds the function
// body is empty.
static void nb_enable_try_inc_ref(PyObject *obj) noexcept {
44+
// NOTE(review): the leading `0 &&` makes this condition always false, so the
// PyUnstable_EnableTryIncRef() branch is never compiled. The reason is not
// stated here -- confirm before removing the `0 &&`.
#if 0 && defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030E00A5
45+
PyUnstable_EnableTryIncRef(obj);
46+
#elif defined(Py_GIL_DISABLED)
47+
// Since this is called during object construction, we know that we have
48+
// the only reference to the object and can use a non-atomic write.
49+
assert(obj->ob_ref_shared == 0);
50+
obj->ob_ref_shared = _Py_REF_MAYBE_WEAKREF;
51+
#endif
52+
}
53+
54+
// Tries to take a new reference to `obj` without assuming that it is still
// alive. Used by nb_type_put() in place of an unconditional Py_INCREF() on a
// wrapper found in the instance table.
//
// Returns true if the caller may use `obj` (the reference count was
// incremented, or the object was detected as immortal and left unchanged).
// Returns false if no reference could be taken; the caller must then not
// return `obj`.
static bool nb_try_inc_ref(PyObject *obj) noexcept {
55+
// NOTE(review): the leading `0 &&` makes this condition always false, so the
// PyUnstable_TryIncRef() branch is never compiled -- confirm the intent
// before removing the `0 &&`.
#if 0 && defined(Py_GIL_DISABLED) && PY_VERSION_HEX >= 0x030E00A5
56+
return PyUnstable_TryIncRef(obj);
57+
#elif defined(Py_GIL_DISABLED)
58+
// See https://github.com/python/cpython/blob/d05140f9f77d7dfc753dd1e5ac3a5962aaa03eff/Include/internal/pycore_object.h#L761
59+
uint32_t local = _Py_atomic_load_uint32_relaxed(&obj->ob_ref_local);
60+
local += 1;
61+
// The increment wraps to 0 only when the stored local count had all bits
// set; that value is treated as the immortal marker and nothing is written.
if (local == 0) {
62+
// immortal
63+
return true;
64+
}
65+
// Owning thread: publish the incremented local count with a relaxed store
// (no compare-exchange is used on this path).
if (_Py_IsOwnedByCurrentThread(obj)) {
66+
_Py_atomic_store_uint32_relaxed(&obj->ob_ref_local, local);
67+
#ifdef Py_REF_DEBUG
68+
_Py_INCREF_IncRefTotal();
69+
#endif
70+
return true;
71+
}
72+
// Non-owning thread: increment the shared count with a compare-exchange
// loop. `shared` is loaded once here and never reloaded explicitly, so the
// loop relies on the compare-exchange writing the observed value back into
// `shared` when it fails.
Py_ssize_t shared = _Py_atomic_load_ssize_relaxed(&obj->ob_ref_shared);
73+
for (;;) {
74+
// If the shared refcount is zero and the object is either merged
75+
// or may not have weak references, then we cannot incref it.
76+
if (shared == 0 || shared == _Py_REF_MERGED) {
77+
return false;
78+
}
79+
80+
if (_Py_atomic_compare_exchange_ssize(
81+
&obj->ob_ref_shared, &shared, shared + (1 << _Py_REF_SHARED_SHIFT))) {
82+
#ifdef Py_REF_DEBUG
83+
_Py_INCREF_IncRefTotal();
84+
#endif
85+
return true;
86+
}
87+
}
88+
#else
89+
// Builds without Py_GIL_DISABLED: a plain check-then-increment with no
// atomics. Presumably safe because the GIL is held here -- TODO confirm.
if (Py_REFCNT(obj) > 0) {
90+
Py_INCREF(obj);
91+
return true;
92+
}
93+
return false;
94+
#endif
95+
}
96+
4397
static PyGetSetDef inst_getset[] = {
4498
{ "__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict, nullptr, nullptr },
4599
{ nullptr, nullptr, nullptr, nullptr, nullptr }
@@ -98,6 +152,7 @@ PyObject *inst_new_int(PyTypeObject *tp, PyObject * /* args */,
98152
self->clear_keep_alive = 0;
99153
self->intrusive = intrusive;
100154
self->unused = 0;
155+
nb_enable_try_inc_ref((PyObject *)self);
101156

102157
// Update hash table that maps from C++ to Python instance
103158
nb_shard &shard = internals->shard((void *) payload);
@@ -163,6 +218,7 @@ PyObject *inst_new_ext(PyTypeObject *tp, void *value) {
163218
self->clear_keep_alive = 0;
164219
self->intrusive = intrusive;
165220
self->unused = 0;
221+
nb_enable_try_inc_ref((PyObject *)self);
166222

167223
nb_shard &shard = internals->shard(value);
168224
lock_shard guard(shard);
@@ -1766,16 +1822,18 @@ PyObject *nb_type_put(const std::type_info *cpp_type,
17661822
PyTypeObject *tp = Py_TYPE(seq.inst);
17671823

17681824
if (nb_type_data(tp)->type == cpp_type) {
1769-
Py_INCREF(seq.inst);
1770-
return seq.inst;
1825+
if (nb_try_inc_ref(seq.inst)) {
1826+
return seq.inst;
1827+
}
17711828
}
17721829

17731830
if (!lookup_type())
17741831
return nullptr;
17751832

17761833
if (PyType_IsSubtype(tp, td->type_py)) {
1777-
Py_INCREF(seq.inst);
1778-
return seq.inst;
1834+
if (nb_try_inc_ref(seq.inst)) {
1835+
return seq.inst;
1836+
}
17791837
}
17801838

17811839
if (seq.next == nullptr)

tests/test_thread.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ struct Counter {
1212
}
1313
};
1414

15+
// Empty type with a single file-scope instance. The module exposes it through
// GlobalData.get(), which returns &global_data with rv_policy::reference.
struct GlobalData {} global_data;
16+
1517
nb::ft_mutex mutex;
1618

1719
NB_MODULE(test_thread_ext, m) {
@@ -34,4 +36,7 @@ NB_MODULE(test_thread_ext, m) {
3436
nb::ft_lock_guard guard(mutex);
3537
c.inc();
3638
}, "counter");
39+
40+
nb::class_<GlobalData>(m, "GlobalData")
41+
.def_static("get", [] { return &global_data; }, nb::rv_policy::reference);
3742
}

tests/test_thread.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import test_thread_ext as t
2-
from test_thread_ext import Counter
2+
from test_thread_ext import Counter, GlobalData
33
from common import parallelize
44

55
def test01_object_creation(n_threads=8):
@@ -75,3 +75,16 @@ def f():
7575

7676
parallelize(f, n_threads=n_threads)
7777
assert c.value == n * n_threads
78+
79+
80+
# NOTE(review): the name uses `test_06_` while the visible sibling is
# `test01_object_creation` -- confirm the intended naming scheme.
def test_06_global_wrapper(n_threads=8):
81+
# Check wrapper lookup racing with wrapper deallocation
82+
n = 10000
83+
# Each worker calls GlobalData.get() 4 * n times and discards every result.
# There are no assertions: the test passes if the run finishes without an error.
def f():
84+
for i in range(n):
85+
GlobalData.get()
86+
GlobalData.get()
87+
GlobalData.get()
88+
GlobalData.get()
89+
90+
# Run f through the `parallelize` helper from `common`, passing n_threads.
parallelize(f, n_threads=n_threads)

0 commit comments

Comments
 (0)