Skip to content

Commit d074c9e

Browse files
authored
Fix regression in test_weakref_cache (#6033)
1 parent: 971a96d · commit: d074c9e

File tree

13 files changed: 65 additions (+65) and 73 deletions (-73)

distributed/deploy/tests/test_adaptive.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import gc
32
import math
43
from time import sleep
54

@@ -151,7 +150,6 @@ async def test_min_max():
151150
assert len(adapt.log) == 2 and all(d["status"] == "up" for _, d in adapt.log)
152151

153152
del futures
154-
gc.collect()
155153

156154
start = time()
157155
while len(cluster.scheduler.workers) != 1:

distributed/deploy/tests/test_local.py

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,7 @@
11
import asyncio
2-
import gc
32
import subprocess
43
import sys
54
import unittest
6-
import weakref
75
from threading import Lock
86
from time import sleep
97
from urllib.parse import urlparse
@@ -644,18 +642,9 @@ def test_adapt(loop):
644642
cluster.adapt(minimum=0, maximum=2, interval="10ms")
645643
assert cluster._adaptive.minimum == 0
646644
assert cluster._adaptive.maximum == 2
647-
ref = weakref.ref(cluster._adaptive)
648645

649646
cluster.adapt(minimum=1, maximum=2, interval="10ms")
650647
assert cluster._adaptive.minimum == 1
651-
gc.collect()
652-
653-
# the old Adaptive class sticks around, not sure why
654-
# start = time()
655-
# while ref():
656-
# sleep(0.01)
657-
# gc.collect()
658-
# assert time() < start + 5
659648

660649
start = time()
661650
while len(cluster.scheduler.workers) != 1:

distributed/diagnostics/tests/test_progress.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -122,9 +122,6 @@ async def test_AllProgress(c, s, a, b):
122122

123123
keys = {x.key, y.key, z.key}
124124
del x, y, z
125-
import gc
126-
127-
gc.collect()
128125

129126
while any(k in s.who_has for k in keys):
130127
await asyncio.sleep(0.01)
@@ -141,9 +138,6 @@ async def test_AllProgress(c, s, a, b):
141138

142139
tkey = t.key
143140
del xx, yy, zz, t
144-
import gc
145-
146-
gc.collect()
147141

148142
while tkey in s.tasks:
149143
await asyncio.sleep(0.01)
@@ -157,9 +151,6 @@ def f(x):
157151

158152
for i in range(4):
159153
future = c.submit(f, i)
160-
import gc
161-
162-
gc.collect()
163154

164155
await asyncio.sleep(1)
165156

distributed/profile.py

Lines changed: 27 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -275,24 +275,42 @@ def traverse(state, start, stop, height):
275275
}
276276

277277

278+
_watch_running: set[int] = set()
279+
280+
281+
def wait_profiler() -> None:
282+
"""Wait until a moment when no instances of watch() are sampling the frames.
283+
You must call this function whenever you would otherwise expect an object to be
284+
immediately released after it's descoped.
285+
"""
286+
while _watch_running:
287+
sleep(0.0001)
288+
289+
278290
def _watch(thread_id, log, interval="20ms", cycle="2s", omit=None, stop=lambda: False):
279291
interval = parse_timedelta(interval)
280292
cycle = parse_timedelta(cycle)
281293

282294
recent = create()
283295
last = time()
296+
watch_id = threading.get_ident()
284297

285298
while not stop():
286-
if time() > last + cycle:
287-
log.append((time(), recent))
288-
recent = create()
289-
last = time()
299+
_watch_running.add(watch_id)
290300
try:
291-
frame = sys._current_frames()[thread_id]
292-
except KeyError:
293-
return
294-
295-
process(frame, None, recent, omit=omit)
301+
if time() > last + cycle:
302+
log.append((time(), recent))
303+
recent = create()
304+
last = time()
305+
try:
306+
frame = sys._current_frames()[thread_id]
307+
except KeyError:
308+
return
309+
310+
process(frame, None, recent, omit=omit)
311+
del frame
312+
finally:
313+
_watch_running.remove(watch_id)
296314
sleep(interval)
297315

298316

distributed/protocol/tests/test_pickle.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,11 @@
1-
import gc
21
import pickle
32
import weakref
43
from functools import partial
54
from operator import add
65

76
import pytest
87

8+
from distributed.profile import wait_profiler
99
from distributed.protocol import deserialize, serialize
1010
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads
1111

@@ -181,7 +181,7 @@ def funcs():
181181
assert func3(1) == func(1)
182182

183183
del func, func2, func3
184-
gc.collect()
184+
wait_profiler()
185185
assert wr() is None
186186
assert wr2() is None
187187
assert wr3() is None

distributed/tests/test_asyncprocess.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,13 +107,14 @@ async def test_simple():
107107
assert dt <= 0.6
108108

109109
del proc
110-
gc.collect()
110+
111111
start = time()
112112
while wr1() is not None and time() < start + 1:
113113
# Perhaps the GIL switched before _watch_process() exit,
114114
# help it a little
115115
sleep(0.001)
116116
gc.collect()
117+
117118
if wr1() is not None:
118119
# Help diagnosing
119120
from types import FrameType

distributed/tests/test_client.py

Lines changed: 9 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,7 @@
6969
from distributed.core import Server, Status
7070
from distributed.metrics import time
7171
from distributed.objects import HasWhat, WhoHas
72+
from distributed.profile import wait_profiler
7273
from distributed.scheduler import (
7374
COMPILED,
7475
CollectTaskMetaDataPlugin,
@@ -286,7 +287,6 @@ async def test_compute_retries_annotations(c, s, a, b):
286287
y = delayed(varying(yargs))()
287288

288289
x, y = c.compute([x, y], optimize_graph=False)
289-
gc.collect()
290290

291291
assert await x == 30
292292
with pytest.raises(ZeroDivisionError, match="five"):
@@ -676,19 +676,15 @@ def test_get_sync(c):
676676

677677

678678
def test_no_future_references(c):
679-
from weakref import WeakSet
680-
681-
ws = WeakSet()
679+
"""Test that there are neither global references to Future objects nor circular
680+
references that need to be collected by gc
681+
"""
682+
ws = weakref.WeakSet()
682683
futures = c.map(inc, range(10))
683684
ws.update(futures)
684685
del futures
685-
import gc
686-
687-
gc.collect()
688-
start = time()
689-
while list(ws):
690-
sleep(0.01)
691-
assert time() < start + 30
686+
wait_profiler()
687+
assert not list(ws)
692688

693689

694690
def test_get_sync_optimize_graph_passes_through(c):
@@ -820,9 +816,7 @@ async def test_recompute_released_key(c, s, a, b):
820816
result1 = await x
821817
xkey = x.key
822818
del x
823-
import gc
824-
825-
gc.collect()
819+
wait_profiler()
826820
await asyncio.sleep(0)
827821
assert c.refcount[xkey] == 0
828822

@@ -1231,10 +1225,6 @@ async def test_scatter_hash_2(c, s, a, b):
12311225
@gen_cluster(client=True)
12321226
async def test_get_releases_data(c, s, a, b):
12331227
await c.gather(c.get({"x": (inc, 1)}, ["x"], sync=False))
1234-
import gc
1235-
1236-
gc.collect()
1237-
12381228
while c.refcount["x"]:
12391229
await asyncio.sleep(0.01)
12401230

@@ -3569,9 +3559,7 @@ async def test_Client_clears_references_after_restart(c, s, a, b):
35693559

35703560
key = x.key
35713561
del x
3572-
import gc
3573-
3574-
gc.collect()
3562+
wait_profiler()
35753563
await asyncio.sleep(0)
35763564

35773565
assert key not in c.refcount

distributed/tests/test_diskutils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from distributed.compatibility import WINDOWS
1616
from distributed.diskutils import WorkSpace
1717
from distributed.metrics import time
18+
from distributed.profile import wait_profiler
1819
from distributed.utils import mp_context
1920
from distributed.utils_test import captured_logger
2021

@@ -52,6 +53,7 @@ def test_workdir_simple(tmpdir):
5253
a.release()
5354
assert_contents(["bb", "bb.dirlock"])
5455
del b
56+
wait_profiler()
5557
gc.collect()
5658
assert_contents([])
5759

@@ -87,9 +89,11 @@ def test_two_workspaces_in_same_directory(tmpdir):
8789

8890
del ws
8991
del b
92+
wait_profiler()
9093
gc.collect()
9194
assert_contents(["aa", "aa.dirlock"], trials=5)
9295
del a
96+
wait_profiler()
9397
gc.collect()
9498
assert_contents([], trials=5)
9599

@@ -184,6 +188,7 @@ def test_locking_disabled(tmpdir):
184188
a.release()
185189
assert_contents(["bb"])
186190
del b
191+
wait_profiler()
187192
gc.collect()
188193
assert_contents([])
189194

distributed/tests/test_failed_workers.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
from distributed.comm import CommClosedError
1515
from distributed.compatibility import MACOS
1616
from distributed.metrics import time
17+
from distributed.profile import wait_profiler
1718
from distributed.scheduler import COMPILED
1819
from distributed.utils import CancelledError, sync
1920
from distributed.utils_test import (
@@ -273,9 +274,7 @@ async def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b):
273274
await c.restart()
274275
y = c.submit(inc, 1)
275276
del x
276-
import gc
277-
278-
gc.collect()
277+
wait_profiler()
279278
await asyncio.sleep(0.1)
280279
await y
281280

distributed/tests/test_nanny.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
from distributed.core import CommClosedError, Status
2525
from distributed.diagnostics import SchedulerPlugin
2626
from distributed.metrics import time
27+
from distributed.profile import wait_profiler
2728
from distributed.protocol.pickle import dumps
2829
from distributed.utils import TimeoutError, parse_ports
2930
from distributed.utils_test import captured_logger, gen_cluster, gen_test
@@ -208,6 +209,7 @@ async def test_num_fds(s):
208209
w = await Nanny(s.address)
209210
await w.close()
210211
del w
212+
wait_profiler()
211213
gc.collect()
212214

213215
before = proc.num_fds()

distributed/tests/test_spill.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from dask.sizeof import sizeof
1010

1111
from distributed.compatibility import WINDOWS
12+
from distributed.profile import wait_profiler
1213
from distributed.protocol import serialize_bytelist
1314
from distributed.spill import SpillBuffer, has_zict_210, has_zict_220
1415
from distributed.utils_test import captured_logger
@@ -337,6 +338,7 @@ def test_weakref_cache(tmpdir, cls, expect_cached, size):
337338
# the same id as a deleted one
338339
id_x = x.id
339340
del x
341+
wait_profiler()
340342

341343
if size < 100:
342344
buf["y"]

distributed/tests/test_steal.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import contextlib
3-
import gc
43
import itertools
54
import logging
65
import random
@@ -18,6 +17,7 @@
1817
from distributed.config import config
1918
from distributed.core import Status
2019
from distributed.metrics import time
20+
from distributed.profile import wait_profiler
2121
from distributed.scheduler import key_split
2222
from distributed.system import MEMORY_LIMIT
2323
from distributed.utils_test import (
@@ -946,7 +946,7 @@ class Foo:
946946
assert not s.who_has
947947
assert not any(s.has_what.values())
948948

949-
gc.collect()
949+
wait_profiler()
950950
assert not list(ws)
951951

952952

distributed/tests/test_utils.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -460,22 +460,21 @@ async def test_loop_runner_gen():
460460

461461

462462
@gen_test()
463-
async def test_all_exceptions_logging():
464-
async def throws():
465-
raise Exception("foo1234")
466-
467-
with captured_logger("") as sio:
468-
try:
469-
await All([throws() for _ in range(5)], quiet_exceptions=Exception)
470-
except Exception:
471-
pass
463+
async def test_all_quiet_exceptions():
464+
class CustomError(Exception):
465+
pass
472466

473-
import gc
467+
async def throws(msg):
468+
raise CustomError(msg)
474469

475-
gc.collect()
476-
await asyncio.sleep(0.1)
470+
with captured_logger("") as sio:
471+
with pytest.raises(CustomError):
472+
await All([throws("foo") for _ in range(5)])
473+
with pytest.raises(CustomError):
474+
await All([throws("bar") for _ in range(5)], quiet_exceptions=CustomError)
477475

478-
assert "foo1234" not in sio.getvalue()
476+
assert "bar" not in sio.getvalue()
477+
assert "foo" in sio.getvalue()
479478

480479

481480
def test_warn_on_duration():

0 commit comments

Comments
 (0)