Skip to content

Commit dea9ef2

Browse files
Add a lock to distributed.profile for better concurrency control (#6421)
Adds a Lock to distributed.profile to enable better concurrency control. In particular, it allows running garbage collection without a profiling thread holding references to objects, which is necessary for #6250.
1 parent 0a77946 commit dea9ef2

File tree

10 files changed

+94
-66
lines changed

10 files changed

+94
-66
lines changed

distributed/profile.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
from distributed.metrics import time
4545
from distributed.utils import color_of
4646

47+
#: This lock can be acquired to ensure that no instance of watch() is concurrently holding references to frames
48+
lock = threading.Lock()
49+
4750

4851
def identifier(frame: FrameType | None) -> str:
4952
"""A string identifier from a frame
@@ -314,18 +317,6 @@ def traverse(state, start, stop, height):
314317
}
315318

316319

317-
_watch_running: set[int] = set()
318-
319-
320-
def wait_profiler() -> None:
321-
"""Wait until a moment when no instances of watch() are sampling the frames.
322-
You must call this function whenever you would otherwise expect an object to be
323-
immediately released after it's descoped.
324-
"""
325-
while _watch_running:
326-
sleep(0.0001)
327-
328-
329320
def _watch(
330321
thread_id: int,
331322
log: deque[tuple[float, dict[str, Any]]], # [(timestamp, output of create()), ...]
@@ -337,24 +328,20 @@ def _watch(
337328

338329
recent = create()
339330
last = time()
340-
watch_id = threading.get_ident()
341331

342332
while not stop():
343-
_watch_running.add(watch_id)
344-
try:
345-
if time() > last + cycle:
333+
if time() > last + cycle:
334+
recent = create()
335+
with lock:
346336
log.append((time(), recent))
347-
recent = create()
348337
last = time()
349-
try:
350-
frame = sys._current_frames()[thread_id]
351-
except KeyError:
352-
return
353-
354-
process(frame, None, recent, omit=omit)
355-
del frame
356-
finally:
357-
_watch_running.remove(watch_id)
338+
try:
339+
frame = sys._current_frames()[thread_id]
340+
except KeyError:
341+
return
342+
343+
process(frame, None, recent, omit=omit)
344+
del frame
358345
sleep(interval)
359346

360347

distributed/protocol/tests/test_pickle.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from distributed.profile import wait_profiler
8+
from distributed import profile
99
from distributed.protocol import deserialize, serialize
1010
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads
1111

@@ -181,7 +181,7 @@ def funcs():
181181
assert func3(1) == func(1)
182182

183183
del func, func2, func3
184-
wait_profiler()
185-
assert wr() is None
186-
assert wr2() is None
187-
assert wr3() is None
184+
with profile.lock:
185+
assert wr() is None
186+
assert wr2() is None
187+
assert wr3() is None

distributed/tests/test_client.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@
6969
from distributed.compatibility import LINUX, WINDOWS
7070
from distributed.core import Server, Status
7171
from distributed.metrics import time
72-
from distributed.profile import wait_profiler
7372
from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler
7473
from distributed.sizeof import sizeof
7574
from distributed.utils import is_valid_xml, mp_context, sync, tmp_text
@@ -678,8 +677,8 @@ def test_no_future_references(c):
678677
futures = c.map(inc, range(10))
679678
ws.update(futures)
680679
del futures
681-
wait_profiler()
682-
assert not list(ws)
680+
with profile.lock:
681+
assert not list(ws)
683682

684683

685684
def test_get_sync_optimize_graph_passes_through(c):
@@ -811,9 +810,9 @@ async def test_recompute_released_key(c, s, a, b):
811810
result1 = await x
812811
xkey = x.key
813812
del x
814-
wait_profiler()
815-
await asyncio.sleep(0)
816-
assert c.refcount[xkey] == 0
813+
with profile.lock:
814+
await asyncio.sleep(0)
815+
assert c.refcount[xkey] == 0
817816

818817
# 1 second batching needs a second action to trigger
819818
while xkey in s.tasks and s.tasks[xkey].who_has or xkey in a.data or xkey in b.data:
@@ -3483,10 +3482,9 @@ async def test_Client_clears_references_after_restart(c, s, a, b):
34833482

34843483
key = x.key
34853484
del x
3486-
wait_profiler()
3487-
await asyncio.sleep(0)
3488-
3489-
assert key not in c.refcount
3485+
with profile.lock:
3486+
await asyncio.sleep(0)
3487+
assert key not in c.refcount
34903488

34913489

34923490
@gen_cluster(Worker=Nanny, client=True)

distributed/tests/test_diskutils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212

1313
import dask
1414

15+
from distributed import profile
1516
from distributed.compatibility import WINDOWS
1617
from distributed.diskutils import WorkSpace
1718
from distributed.metrics import time
18-
from distributed.profile import wait_profiler
1919
from distributed.utils import mp_context
2020
from distributed.utils_test import captured_logger
2121

@@ -53,8 +53,8 @@ def test_workdir_simple(tmpdir):
5353
a.release()
5454
assert_contents(["bb", "bb.dirlock"])
5555
del b
56-
wait_profiler()
57-
gc.collect()
56+
with profile.lock:
57+
gc.collect()
5858
assert_contents([])
5959

6060
# Generated temporary name with a prefix
@@ -89,12 +89,12 @@ def test_two_workspaces_in_same_directory(tmpdir):
8989

9090
del ws
9191
del b
92-
wait_profiler()
93-
gc.collect()
92+
with profile.lock:
93+
gc.collect()
9494
assert_contents(["aa", "aa.dirlock"], trials=5)
9595
del a
96-
wait_profiler()
97-
gc.collect()
96+
with profile.lock:
97+
gc.collect()
9898
assert_contents([], trials=5)
9999

100100

@@ -188,8 +188,8 @@ def test_locking_disabled(tmpdir):
188188
a.release()
189189
assert_contents(["bb"])
190190
del b
191-
wait_profiler()
192-
gc.collect()
191+
with profile.lock:
192+
gc.collect()
193193
assert_contents([])
194194

195195
lock_file.assert_not_called()

distributed/tests/test_failed_workers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010

1111
from dask import delayed
1212

13-
from distributed import Client, Nanny, wait
13+
from distributed import Client, Nanny, profile, wait
1414
from distributed.comm import CommClosedError
1515
from distributed.compatibility import MACOS
1616
from distributed.metrics import time
17-
from distributed.profile import wait_profiler
1817
from distributed.utils import CancelledError, sync
1918
from distributed.utils_test import (
2019
captured_logger,
@@ -262,7 +261,10 @@ async def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b):
262261
await c.restart()
263262
y = c.submit(inc, 1)
264263
del x
265-
wait_profiler()
264+
265+
# Ensure that the profiler has stopped and released all references to x so that it can be garbage-collected
266+
with profile.lock:
267+
pass
266268
await asyncio.sleep(0.1)
267269
await y
268270

distributed/tests/test_nanny.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919
import dask
2020
from dask.utils import tmpfile
2121

22-
from distributed import Nanny, Scheduler, Worker, rpc, wait, worker
22+
from distributed import Nanny, Scheduler, Worker, profile, rpc, wait, worker
2323
from distributed.compatibility import LINUX, WINDOWS
2424
from distributed.core import CommClosedError, Status
2525
from distributed.diagnostics import SchedulerPlugin
2626
from distributed.metrics import time
27-
from distributed.profile import wait_profiler
2827
from distributed.protocol.pickle import dumps
2928
from distributed.utils import TimeoutError, parse_ports
3029
from distributed.utils_test import (
@@ -170,8 +169,8 @@ async def test_num_fds(s):
170169
# Warm up
171170
async with Nanny(s.address):
172171
pass
173-
wait_profiler()
174-
gc.collect()
172+
with profile.lock:
173+
gc.collect()
175174

176175
before = proc.num_fds()
177176

distributed/tests/test_profile.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
info_frame,
1919
ll_get_stack,
2020
llprocess,
21+
lock,
2122
merge,
2223
plot_data,
2324
process,
@@ -207,6 +208,45 @@ def stop():
207208
sleep(0.01)
208209

209210

211+
def test_watch_requires_lock_to_run():
212+
start = time()
213+
214+
def stop_lock():
215+
return time() > start + 0.600
216+
217+
def stop_profile():
218+
return time() > start + 0.500
219+
220+
def hold_lock(stop):
221+
with lock:
222+
while not stop():
223+
sleep(0.1)
224+
225+
start_threads = threading.active_count()
226+
227+
# Hog the lock over the entire duration of watch
228+
thread = threading.Thread(
229+
target=hold_lock, name="Hold Lock", kwargs={"stop": stop_lock}
230+
)
231+
thread.daemon = True
232+
thread.start()
233+
234+
log = watch(interval="10ms", cycle="50ms", stop=stop_profile)
235+
236+
start = time() # wait until thread starts up
237+
while threading.active_count() < start_threads + 2:
238+
assert time() < start + 2
239+
sleep(0.01)
240+
241+
sleep(0.5)
242+
assert len(log) == 0
243+
244+
start = time()
245+
while threading.active_count() > start_threads:
246+
assert time() < start + 2
247+
sleep(0.01)
248+
249+
210250
@dataclasses.dataclass(frozen=True)
211251
class FakeCode:
212252
co_filename: str

distributed/tests/test_spill.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from dask.sizeof import sizeof
1010

11+
from distributed import profile
1112
from distributed.compatibility import WINDOWS
12-
from distributed.profile import wait_profiler
1313
from distributed.protocol import serialize_bytelist
1414
from distributed.spill import SpillBuffer, has_zict_210, has_zict_220
1515
from distributed.utils_test import captured_logger
@@ -338,7 +338,10 @@ def test_weakref_cache(tmpdir, cls, expect_cached, size):
338338
# the same id as a deleted one
339339
id_x = x.id
340340
del x
341-
wait_profiler()
341+
342+
# Ensure that the profiler has stopped and released all references to x so that it can be garbage-collected
343+
with profile.lock:
344+
pass
342345

343346
if size < 100:
344347
buf["y"]

distributed/tests/test_steal.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,11 @@
1212

1313
import dask
1414

15-
from distributed import Event, Lock, Nanny, Worker, wait, worker_client
15+
from distributed import Event, Lock, Nanny, Worker, profile, wait, worker_client
1616
from distributed.compatibility import LINUX
1717
from distributed.config import config
1818
from distributed.core import Status
1919
from distributed.metrics import time
20-
from distributed.profile import wait_profiler
2120
from distributed.scheduler import key_split
2221
from distributed.system import MEMORY_LIMIT
2322
from distributed.utils_test import (
@@ -948,8 +947,8 @@ class Foo:
948947

949948
assert not s.tasks
950949

951-
wait_profiler()
952-
assert not list(ws)
950+
with profile.lock:
951+
assert not list(ws)
953952

954953

955954
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)

distributed/tests/test_worker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
default_client,
3535
get_client,
3636
get_worker,
37+
profile,
3738
wait,
3839
)
3940
from distributed.comm.registry import backends
@@ -42,7 +43,6 @@
4243
from distributed.diagnostics import nvml
4344
from distributed.diagnostics.plugin import PipInstall
4445
from distributed.metrics import time
45-
from distributed.profile import wait_profiler
4646
from distributed.protocol import pickle
4747
from distributed.scheduler import Scheduler
4848
from distributed.utils_test import (
@@ -1851,8 +1851,8 @@ class C:
18511851
del f
18521852
while "f" in a.data:
18531853
await asyncio.sleep(0.01)
1854-
wait_profiler()
1855-
assert ref() is None
1854+
with profile.lock:
1855+
assert ref() is None
18561856

18571857
story = a.stimulus_story("f", "f2")
18581858
assert {ev.key for ev in story} == {"f", "f2"}

0 commit comments

Comments
 (0)