Skip to content

Commit 58a5a3c

Browse files
authored
Overhaul transitions for the resumed state (#6699)
1 parent 864b59c commit 58a5a3c

File tree

3 files changed

+482
-429
lines changed

3 files changed

+482
-429
lines changed

distributed/tests/test_cancelled_state.py

Lines changed: 119 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,22 @@
1616
gen_cluster,
1717
inc,
1818
lock_inc,
19-
slowinc,
2019
wait_for_state,
2120
wait_for_stimulus,
2221
)
2322
from distributed.worker_state_machine import (
23+
AddKeysMsg,
2424
ComputeTaskEvent,
2525
Execute,
2626
ExecuteFailureEvent,
2727
ExecuteSuccessEvent,
2828
FreeKeysEvent,
2929
GatherDep,
30+
GatherDepFailureEvent,
3031
GatherDepNetworkFailureEvent,
3132
GatherDepSuccessEvent,
3233
TaskFinishedMsg,
34+
UpdateDataEvent,
3335
)
3436

3537

@@ -231,53 +233,30 @@ async def wait_and_raise(*args, **kwargs):
231233
w.state.story(f1.key),
232234
[
233235
(f1.key, "executing", "released", "cancelled", {}),
234-
(
235-
f1.key,
236-
"cancelled",
237-
"error",
238-
"error",
239-
{f2.key: "executing", f1.key: "released"},
240-
),
241-
(f1.key, "error", "released", "released", {f1.key: "forgotten"}),
236+
(f1.key, "cancelled", "error", "released", {f1.key: "forgotten"}),
242237
(f1.key, "released", "forgotten", "forgotten", {}),
243238
],
244239
)
245240

246241

247-
@gen_cluster(client=True, nthreads=[("", 1)])
248-
async def test_flight_cancelled_error(c, s, b):
249-
"""One worker with one thread. We provoke an flight->cancelled transition
250-
and let the task err."""
251-
lock = asyncio.Lock()
252-
await lock.acquire()
242+
def test_flight_cancelled_error(ws):
243+
"""Test flight -> cancelled -> error transition loop.
244+
This can be caused by an issue while (un)pickling or a bug in the network stack.
253245
254-
class BrokenWorker(Worker):
255-
block_get_data = True
256-
257-
async def get_data(self, comm, *args, **kwargs):
258-
if self.block_get_data:
259-
async with lock:
260-
comm.abort()
261-
return await super().get_data(comm, *args, **kwargs)
262-
263-
async with BrokenWorker(s.address) as a:
264-
await c.wait_for_workers(2)
265-
fut1 = c.submit(inc, 1, workers=[a.address], allow_other_workers=True)
266-
fut2 = c.submit(inc, fut1, workers=[b.address])
267-
await wait_for_state(fut1.key, "flight", b)
268-
fut2.release()
269-
fut1.release()
270-
await wait_for_state(fut1.key, "cancelled", b)
271-
lock.release()
272-
# At this point we do not fetch the result of the future since the
273-
# future itself would raise a cancelled exception. At this point we're
274-
# concerned about the worker. The task should transition over error to
275-
# be eventually forgotten since we no longer hold a ref.
276-
while fut1.key in b.state.tasks:
277-
await asyncio.sleep(0.01)
278-
a.block_get_data = False
279-
# Everything should still be executing as usual after this
280-
assert await c.submit(sum, c.map(inc, range(10))) == sum(map(inc, range(10)))
246+
See https://github.com/dask/distributed/issues/6877
247+
"""
248+
ws2 = "127.0.0.1:2"
249+
instructions = ws.handle_stimulus(
250+
ComputeTaskEvent.dummy("y", who_has={"x": [ws2]}, stimulus_id="s1"),
251+
FreeKeysEvent(keys=["y", "x"], stimulus_id="s2"),
252+
GatherDepFailureEvent.from_exception(
253+
Exception(), worker=ws2, total_nbytes=1, stimulus_id="s3"
254+
),
255+
)
256+
assert instructions == [
257+
GatherDep(worker=ws2, to_gather={"x"}, total_nbytes=1, stimulus_id="s1")
258+
]
259+
assert not ws.tasks
281260

282261

283262
@gen_cluster(client=True, nthreads=[("", 1)])
@@ -332,6 +311,7 @@ def block_execution(lock):
332311
(fut1.key, "resumed", "released", "cancelled", {}),
333312
# After gather_dep receives the data, the task is forgotten
334313
(fut1.key, "cancelled", "memory", "released", {fut1.key: "forgotten"}),
314+
(fut1.key, "released", "forgotten", "forgotten", {}),
335315
],
336316
)
337317

@@ -369,7 +349,8 @@ def block_execution(event, lock):
369349
b.state.story(fut1.key),
370350
[
371351
(fut1.key, "executing", "released", "cancelled", {}),
372-
(fut1.key, "cancelled", "error", "error", {fut1.key: "released"}),
352+
(fut1.key, "cancelled", "error", "released", {fut1.key: "forgotten"}),
353+
(fut1.key, "released", "forgotten", "forgotten", {}),
373354
],
374355
)
375356

@@ -480,14 +461,18 @@ async def test_resumed_cancelled_handle_compute(
480461
lock_compute = Lock()
481462
await lock_compute.acquire()
482463
enter_compute = Event()
464+
exit_compute = Event()
483465

484-
def block(x, lock, enter_event):
466+
def block(x, lock, enter_event, exit_event):
485467
enter_event.set()
486-
with lock:
487-
if raise_error:
488-
raise RuntimeError("test error")
489-
else:
490-
return x + 1
468+
try:
469+
with lock:
470+
if raise_error:
471+
raise RuntimeError("test error")
472+
else:
473+
return x + 1
474+
finally:
475+
exit_event.set()
491476

492477
f1 = c.submit(inc, 1, key="f1", workers=[a.address])
493478
f2 = c.submit(inc, f1, key="f2", workers=[a.address])
@@ -496,6 +481,7 @@ def block(x, lock, enter_event):
496481
f2,
497482
lock=lock_compute,
498483
enter_event=enter_compute,
484+
exit_event=exit_compute,
499485
key="f3",
500486
workers=[b.address],
501487
)
@@ -523,17 +509,20 @@ async def release_all_futures():
523509
await wait_for_state(f3.key, "resumed", b)
524510
await release_all_futures()
525511

512+
if not wait_for_processing:
513+
await lock_compute.release()
514+
await exit_compute.wait()
515+
526516
f1 = c.submit(inc, 1, key="f1", workers=[a.address])
527517
f2 = c.submit(inc, f1, key="f2", workers=[a.address])
528518
f3 = c.submit(inc, f2, key="f3", workers=[b.address])
529519
f4 = c.submit(sum, [f1, f3], key="f4", workers=[b.address])
530520

531521
if wait_for_processing:
532522
await wait_for_state(f3.key, "processing", s)
523+
await lock_compute.release()
533524

534-
await lock_compute.release()
535-
536-
if not raise_error:
525+
if not wait_for_processing and not raise_error:
537526
assert await f4 == 4 + 2
538527

539528
assert_story(
@@ -546,19 +535,55 @@ async def release_all_futures():
546535
],
547536
)
548537

549-
else:
538+
elif not wait_for_processing and raise_error:
539+
assert await f4 == 4 + 2
540+
541+
assert_story(
542+
b.state.story(f3.key),
543+
expect=[
544+
(f3.key, "ready", "executing", "executing", {}),
545+
(f3.key, "executing", "released", "cancelled", {}),
546+
(f3.key, "cancelled", "fetch", "resumed", {}),
547+
(f3.key, "resumed", "error", "released", {f3.key: "fetch"}),
548+
(f3.key, "fetch", "flight", "flight", {}),
549+
(f3.key, "flight", "missing", "missing", {}),
550+
(f3.key, "missing", "waiting", "waiting", {f2.key: "fetch"}),
551+
(f3.key, "waiting", "ready", "ready", {f3.key: "executing"}),
552+
(f3.key, "ready", "executing", "executing", {}),
553+
(f3.key, "executing", "memory", "memory", {}),
554+
],
555+
)
556+
557+
elif wait_for_processing and not raise_error:
558+
assert await f4 == 4 + 2
559+
560+
assert_story(
561+
b.state.story(f3.key),
562+
expect=[
563+
(f3.key, "ready", "executing", "executing", {}),
564+
(f3.key, "executing", "released", "cancelled", {}),
565+
(f3.key, "cancelled", "fetch", "resumed", {}),
566+
(f3.key, "resumed", "waiting", "executing", {}),
567+
(f3.key, "executing", "memory", "memory", {}),
568+
],
569+
)
570+
571+
elif wait_for_processing and raise_error:
550572
with pytest.raises(RuntimeError, match="test error"):
551573
await f3
552574

553575
assert_story(
554576
b.state.story(f3.key),
555-
expect=[
577+
[
556578
(f3.key, "ready", "executing", "executing", {}),
557579
(f3.key, "executing", "released", "cancelled", {}),
558580
(f3.key, "cancelled", "fetch", "resumed", {}),
559-
(f3.key, "resumed", "error", "error", {}),
581+
(f3.key, "resumed", "waiting", "executing", {}),
582+
(f3.key, "executing", "error", "error", {}),
560583
],
561584
)
585+
else:
586+
assert False, "unreachable"
562587

563588

564589
@pytest.mark.parametrize("intermediate_state", ["resumed", "cancelled"])
@@ -570,13 +595,9 @@ async def test_deadlock_cancelled_after_inflight_before_gather_from_worker(
570595
"""If a task was transitioned to in-flight, the gather_dep coroutine was scheduled
571596
but a cancel request came in before gather_data_from_worker was issued. This might
572597
corrupt the state machine if the cancelled key is not properly handled.
573-
574-
See also
575-
--------
576-
test_workerstate_deadlock_cancelled_after_inflight_before_gather_from_worker
577598
"""
578-
fut1 = c.submit(slowinc, 1, workers=[a.address], key="f1")
579-
fut1B = c.submit(slowinc, 2, workers=[x.address], key="f1B")
599+
fut1 = c.submit(inc, 1, workers=[a.address], key="f1")
600+
fut1B = c.submit(inc, 2, workers=[x.address], key="f1B")
580601
fut2 = c.submit(sum, [fut1, fut1B], workers=[x.address], key="f2")
581602
await fut2
582603

@@ -661,14 +682,13 @@ def test_workerstate_executing_skips_fetch_on_success(ws_with_running_task):
661682
ExecuteSuccessEvent.dummy("x", 123, stimulus_id="s3"),
662683
)
663684
assert instructions == [
664-
TaskFinishedMsg.match(key="x", stimulus_id="s3"),
685+
AddKeysMsg(keys=["x"], stimulus_id="s3"),
665686
Execute(key="y", stimulus_id="s3"),
666687
]
667688
assert ws.tasks["x"].state == "memory"
668689
assert ws.data["x"] == 123
669690

670691

671-
@pytest.mark.xfail(reason="distributed#6689")
672692
def test_workerstate_executing_failure_to_fetch(ws_with_running_task):
673693
"""Test state loops:
674694
@@ -887,3 +907,39 @@ async def resume():
887907

888908
# Test that x does not get stuck.
889909
assert await fut == expect
910+
911+
912+
@pytest.mark.parametrize("release_dep", [False, True])
913+
@pytest.mark.parametrize("done_ev_cls", [ExecuteSuccessEvent, ExecuteFailureEvent])
914+
def test_cancel_with_dependencies_in_memory(ws, release_dep, done_ev_cls):
915+
"""Cancel an executing task y with an in-memory dependency x; then simulate that x
916+
did not have any further dependents, so cancel x as well.
917+
918+
Test that x immediately transitions to released state and is forgotten as soon as
919+
y finishes computing.
920+
921+
Read: https://github.com/dask/distributed/issues/6893"""
922+
ws.handle_stimulus(
923+
UpdateDataEvent(data={"x": 1}, report=False, stimulus_id="s1"),
924+
ComputeTaskEvent.dummy("y", who_has={"x": [ws.address]}, stimulus_id="s2"),
925+
)
926+
assert ws.tasks["x"].state == "memory"
927+
assert ws.tasks["y"].state == "executing"
928+
929+
ws.handle_stimulus(FreeKeysEvent(keys=["y"], stimulus_id="s3"))
930+
assert ws.tasks["x"].state == "memory"
931+
assert ws.tasks["y"].state == "cancelled"
932+
933+
if release_dep:
934+
# This will happen iff x has no dependents or waiters on the scheduler
935+
ws.handle_stimulus(FreeKeysEvent(keys=["x"], stimulus_id="s4"))
936+
assert ws.tasks["x"].state == "released"
937+
assert ws.tasks["y"].state == "cancelled"
938+
939+
ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5"))
940+
assert "y" not in ws.tasks
941+
assert "x" not in ws.tasks
942+
else:
943+
ws.handle_stimulus(done_ev_cls.dummy("y", stimulus_id="s5"))
944+
assert "y" not in ws.tasks
945+
assert ws.tasks["x"].state == "memory"

distributed/tests/test_worker_state_machine.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,16 +1171,7 @@ def test_task_with_dependencies_acquires_resources(ws):
11711171

11721172
@pytest.mark.parametrize(
11731173
"done_ev_cls,done_status",
1174-
[
1175-
(ExecuteSuccessEvent, "memory"),
1176-
pytest.param(
1177-
ExecuteFailureEvent,
1178-
"flight",
1179-
marks=pytest.mark.xfail(
1180-
reason="distributed#6682,distributed#6689,distributed#6693"
1181-
),
1182-
),
1183-
],
1174+
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "flight")],
11841175
)
11851176
def test_resumed_task_releases_resources(
11861177
ws_with_running_task, done_ev_cls, done_status
@@ -1247,14 +1238,7 @@ def test_done_task_not_in_all_running_tasks(
12471238

12481239
@pytest.mark.parametrize(
12491240
"done_ev_cls,done_status",
1250-
[
1251-
(ExecuteSuccessEvent, "memory"),
1252-
pytest.param(
1253-
ExecuteFailureEvent,
1254-
"flight",
1255-
marks=pytest.mark.xfail(reason="distributed#6689"),
1256-
),
1257-
],
1241+
[(ExecuteSuccessEvent, "memory"), (ExecuteFailureEvent, "flight")],
12581242
)
12591243
def test_done_resumed_task_not_in_all_running_tasks(
12601244
ws_with_running_task, done_ev_cls, done_status

0 commit comments

Comments
 (0)