Skip to content

Fix regression in test_weakref_cache #6033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions distributed/deploy/tests/test_adaptive.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import gc
import math
from time import sleep

Expand Down Expand Up @@ -151,7 +150,6 @@ async def test_min_max():
assert len(adapt.log) == 2 and all(d["status"] == "up" for _, d in adapt.log)

del futures
gc.collect()

start = time()
while len(cluster.scheduler.workers) != 1:
Expand Down
11 changes: 0 additions & 11 deletions distributed/deploy/tests/test_local.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import asyncio
import gc
import subprocess
import sys
import unittest
import weakref
from threading import Lock
from time import sleep
from urllib.parse import urlparse
Expand Down Expand Up @@ -644,18 +642,9 @@ def test_adapt(loop):
cluster.adapt(minimum=0, maximum=2, interval="10ms")
assert cluster._adaptive.minimum == 0
assert cluster._adaptive.maximum == 2
ref = weakref.ref(cluster._adaptive)

cluster.adapt(minimum=1, maximum=2, interval="10ms")
assert cluster._adaptive.minimum == 1
gc.collect()

# the old Adaptive class sticks around, not sure why
# start = time()
# while ref():
# sleep(0.01)
# gc.collect()
# assert time() < start + 5

start = time()
while len(cluster.scheduler.workers) != 1:
Expand Down
9 changes: 0 additions & 9 deletions distributed/diagnostics/tests/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,6 @@ async def test_AllProgress(c, s, a, b):

keys = {x.key, y.key, z.key}
del x, y, z
import gc

gc.collect()

while any(k in s.who_has for k in keys):
await asyncio.sleep(0.01)
Expand All @@ -141,9 +138,6 @@ async def test_AllProgress(c, s, a, b):

tkey = t.key
del xx, yy, zz, t
import gc

gc.collect()

while tkey in s.tasks:
await asyncio.sleep(0.01)
Expand All @@ -157,9 +151,6 @@ def f(x):

for i in range(4):
future = c.submit(f, i)
import gc

gc.collect()

await asyncio.sleep(1)

Expand Down
36 changes: 27 additions & 9 deletions distributed/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,24 +275,42 @@ def traverse(state, start, stop, height):
}


_watch_running: set[int] = set()


def wait_profiler() -> None:
"""Wait until a moment when no instances of watch() are sampling the frames.
You must call this function whenever you would otherwise expect an object to be
immediately released after it's descoped.
"""
while _watch_running:
sleep(0.0001)
Comment on lines +286 to +287
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure how I feel about this. It is not thread safe, so while this function is returning, another profiler thread might start sampling the frames again. At the same time, if a profiler thread is actually running, this loop will burn CPU hard.
For this specific use case that's likely not a problem, but I wonder whether a plain sleep(0.1) would be equally effective without putting this kind of instrumentation into the module.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, this is OK only as long as its usage is limited to unit tests



def _watch(thread_id, log, interval="20ms", cycle="2s", omit=None, stop=lambda: False):
interval = parse_timedelta(interval)
cycle = parse_timedelta(cycle)

recent = create()
last = time()
watch_id = threading.get_ident()

while not stop():
if time() > last + cycle:
log.append((time(), recent))
recent = create()
last = time()
_watch_running.add(watch_id)
try:
frame = sys._current_frames()[thread_id]
except KeyError:
return

process(frame, None, recent, omit=omit)
if time() > last + cycle:
log.append((time(), recent))
recent = create()
last = time()
try:
frame = sys._current_frames()[thread_id]
except KeyError:
return

process(frame, None, recent, omit=omit)
del frame
finally:
_watch_running.remove(watch_id)
sleep(interval)


Expand Down
4 changes: 2 additions & 2 deletions distributed/protocol/tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import gc
import pickle
import weakref
from functools import partial
from operator import add

import pytest

from distributed.profile import wait_profiler
from distributed.protocol import deserialize, serialize
from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads

Expand Down Expand Up @@ -181,7 +181,7 @@ def funcs():
assert func3(1) == func(1)

del func, func2, func3
gc.collect()
wait_profiler()
assert wr() is None
assert wr2() is None
assert wr3() is None
3 changes: 2 additions & 1 deletion distributed/tests/test_asyncprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,14 @@ async def test_simple():
assert dt <= 0.6

del proc
gc.collect()

start = time()
while wr1() is not None and time() < start + 1:
# Perhaps the GIL switched before _watch_process() exit,
# help it a little
sleep(0.001)
gc.collect()

if wr1() is not None:
# Help diagnosing
from types import FrameType
Expand Down
30 changes: 9 additions & 21 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from distributed.core import Server, Status
from distributed.metrics import time
from distributed.objects import HasWhat, WhoHas
from distributed.profile import wait_profiler
from distributed.scheduler import (
COMPILED,
CollectTaskMetaDataPlugin,
Expand Down Expand Up @@ -286,7 +287,6 @@ async def test_compute_retries_annotations(c, s, a, b):
y = delayed(varying(yargs))()

x, y = c.compute([x, y], optimize_graph=False)
gc.collect()

assert await x == 30
with pytest.raises(ZeroDivisionError, match="five"):
Expand Down Expand Up @@ -676,19 +676,15 @@ def test_get_sync(c):


def test_no_future_references(c):
from weakref import WeakSet

ws = WeakSet()
"""Test that there are neither global references to Future objects nor circular
references that need to be collected by gc
"""
ws = weakref.WeakSet()
futures = c.map(inc, range(10))
ws.update(futures)
del futures
import gc

gc.collect()
start = time()
while list(ws):
sleep(0.01)
assert time() < start + 30
wait_profiler()
assert not list(ws)


def test_get_sync_optimize_graph_passes_through(c):
Expand Down Expand Up @@ -820,9 +816,7 @@ async def test_recompute_released_key(c, s, a, b):
result1 = await x
xkey = x.key
del x
import gc

gc.collect()
wait_profiler()
await asyncio.sleep(0)
assert c.refcount[xkey] == 0

Expand Down Expand Up @@ -1231,10 +1225,6 @@ async def test_scatter_hash_2(c, s, a, b):
@gen_cluster(client=True)
async def test_get_releases_data(c, s, a, b):
await c.gather(c.get({"x": (inc, 1)}, ["x"], sync=False))
import gc

gc.collect()

while c.refcount["x"]:
await asyncio.sleep(0.01)

Expand Down Expand Up @@ -3569,9 +3559,7 @@ async def test_Client_clears_references_after_restart(c, s, a, b):

key = x.key
del x
import gc

gc.collect()
wait_profiler()
await asyncio.sleep(0)

assert key not in c.refcount
Expand Down
5 changes: 5 additions & 0 deletions distributed/tests/test_diskutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from distributed.compatibility import WINDOWS
from distributed.diskutils import WorkSpace
from distributed.metrics import time
from distributed.profile import wait_profiler
from distributed.utils import mp_context
from distributed.utils_test import captured_logger

Expand Down Expand Up @@ -52,6 +53,7 @@ def test_workdir_simple(tmpdir):
a.release()
assert_contents(["bb", "bb.dirlock"])
del b
wait_profiler()
gc.collect()
assert_contents([])

Expand Down Expand Up @@ -87,9 +89,11 @@ def test_two_workspaces_in_same_directory(tmpdir):

del ws
del b
wait_profiler()
gc.collect()
assert_contents(["aa", "aa.dirlock"], trials=5)
del a
wait_profiler()
gc.collect()
assert_contents([], trials=5)

Expand Down Expand Up @@ -184,6 +188,7 @@ def test_locking_disabled(tmpdir):
a.release()
assert_contents(["bb"])
del b
wait_profiler()
gc.collect()
assert_contents([])

Expand Down
5 changes: 2 additions & 3 deletions distributed/tests/test_failed_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from distributed.comm import CommClosedError
from distributed.compatibility import MACOS
from distributed.metrics import time
from distributed.profile import wait_profiler
from distributed.scheduler import COMPILED
from distributed.utils import CancelledError, sync
from distributed.utils_test import (
Expand Down Expand Up @@ -273,9 +274,7 @@ async def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b):
await c.restart()
y = c.submit(inc, 1)
del x
import gc

gc.collect()
wait_profiler()
await asyncio.sleep(0.1)
await y

Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from distributed.core import CommClosedError, Status
from distributed.diagnostics import SchedulerPlugin
from distributed.metrics import time
from distributed.profile import wait_profiler
from distributed.protocol.pickle import dumps
from distributed.utils import TimeoutError, parse_ports
from distributed.utils_test import captured_logger, gen_cluster, gen_test
Expand Down Expand Up @@ -208,6 +209,7 @@ async def test_num_fds(s):
w = await Nanny(s.address)
await w.close()
del w
wait_profiler()
gc.collect()

before = proc.num_fds()
Expand Down
2 changes: 2 additions & 0 deletions distributed/tests/test_spill.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dask.sizeof import sizeof

from distributed.compatibility import WINDOWS
from distributed.profile import wait_profiler
from distributed.protocol import serialize_bytelist
from distributed.spill import SpillBuffer, has_zict_210, has_zict_220
from distributed.utils_test import captured_logger
Expand Down Expand Up @@ -337,6 +338,7 @@ def test_weakref_cache(tmpdir, cls, expect_cached, size):
# the same id as a deleted one
id_x = x.id
del x
wait_profiler()

if size < 100:
buf["y"]
Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import contextlib
import gc
import itertools
import logging
import random
Expand All @@ -18,6 +17,7 @@
from distributed.config import config
from distributed.core import Status
from distributed.metrics import time
from distributed.profile import wait_profiler
from distributed.scheduler import key_split
from distributed.system import MEMORY_LIMIT
from distributed.utils_test import (
Expand Down Expand Up @@ -946,7 +946,7 @@ class Foo:
assert not s.who_has
assert not any(s.has_what.values())

gc.collect()
wait_profiler()
assert not list(ws)


Expand Down
25 changes: 12 additions & 13 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,22 +460,21 @@ async def test_loop_runner_gen():


@gen_test()
async def test_all_exceptions_logging():
async def throws():
raise Exception("foo1234")

with captured_logger("") as sio:
try:
await All([throws() for _ in range(5)], quiet_exceptions=Exception)
except Exception:
pass
async def test_all_quiet_exceptions():
class CustomError(Exception):
pass

import gc
async def throws(msg):
raise CustomError(msg)

gc.collect()
await asyncio.sleep(0.1)
with captured_logger("") as sio:
with pytest.raises(CustomError):
await All([throws("foo") for _ in range(5)])
with pytest.raises(CustomError):
await All([throws("bar") for _ in range(5)], quiet_exceptions=CustomError)

assert "foo1234" not in sio.getvalue()
assert "bar" not in sio.getvalue()
assert "foo" in sio.getvalue()


def test_warn_on_duration():
Expand Down