Skip to content

Commit dc9586d

Browse files
committed
Improve work stealing for scaling situations
1 parent ec9b569 commit dc9586d

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

distributed/scheduler.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -2340,8 +2340,9 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
23402340
if ts._dependencies or valid_workers is not None:
23412341
ws = decide_worker(
23422342
ts,
2343-
self._workers_dv.values(),
2343+
set(self._workers_dv.values()),
23442344
valid_workers,
2345+
set(self._idle_dv),
23452346
partial(self.worker_objective, ts),
23462347
)
23472348
else:
@@ -7471,7 +7472,11 @@ def _reevaluate_occupancy_worker(state: SchedulerState, ws: WorkerState):
74717472
@cfunc
74727473
@exceptval(check=False)
74737474
def decide_worker(
7474-
ts: TaskState, all_workers, valid_workers: set, objective
7475+
ts: TaskState,
7476+
all_workers: set,
7477+
valid_workers: Optional[set],
7478+
idle: set,
7479+
objective,
74757480
) -> WorkerState:
74767481
"""
74777482
Decide which worker should take task *ts*.
@@ -7495,19 +7500,19 @@ def decide_worker(
74957500
candidates: set
74967501
assert all([dts._who_has for dts in deps])
74977502
if ts._actor:
7498-
candidates = set(all_workers)
7503+
candidates = all_workers
74997504
else:
75007505
candidates = {wws for dts in deps for wws in dts._who_has}
75017506
if valid_workers is None:
75027507
if not candidates:
7503-
candidates = set(all_workers)
7508+
candidates = all_workers
75047509
else:
75057510
candidates &= valid_workers
75067511
if not candidates:
75077512
candidates = valid_workers
75087513
if not candidates:
75097514
if ts._loose_restrictions:
7510-
ws = decide_worker(ts, all_workers, None, objective)
7515+
ws = decide_worker(ts, all_workers, None, idle, objective)
75117516
return ws
75127517

75137518
ncandidates: Py_ssize_t = len(candidates)

distributed/stealing.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
from collections import defaultdict, deque
3+
from functools import partial
34
from math import log2
45
from time import time
56

@@ -113,27 +114,21 @@ def steal_time_ratio(self, ts):
113114
For example a result of zero implies a task without dependencies.
114115
level: The location within a stealable list to place this value
115116
"""
116-
if not ts.dependencies: # no dependencies fast path
117-
return 0, 0
118117

119118
split = ts.prefix.name
120119
if split in fast_tasks:
121120
return None, None
122121

123122
ws = ts.processing_on
124123
compute_time = ws.processing[ts]
125-
if compute_time < 0.005: # 5ms, just give up
126-
return None, None
127124

128125
nbytes = ts.get_nbytes_deps()
129126
transfer_time = nbytes / self.scheduler.bandwidth + LATENCY
130127
cost_multiplier = transfer_time / compute_time
131-
if cost_multiplier > 100:
132-
return None, None
133128

134129
level = int(round(log2(cost_multiplier) + 6))
135-
if level < 1:
136-
level = 1
130+
131+
level = min(len(self.cost_multipliers) - 1, level)
137132

138133
return cost_multiplier, level
139134

@@ -344,7 +339,10 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
344339
thieves = idle
345340
if not thieves:
346341
break
347-
thief = thieves[i % len(thieves)]
342+
343+
thief = min(
344+
thieves, key=partial(self.scheduler.worker_objective, ts)
345+
)
348346

349347
duration = sat.processing.get(ts)
350348
if duration is None:
@@ -380,7 +378,9 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
380378
thieves = idle
381379
if not thieves:
382380
continue
383-
thief = thieves[i % len(thieves)]
381+
thief = min(
382+
thieves, key=partial(self.scheduler.worker_objective, ts)
383+
)
384384
duration = sat.processing[ts]
385385

386386
maybe_move_task(

0 commit comments

Comments
 (0)