Skip to content

Commit c8ed3f6

Browse files
committed
Improve work stealing for scaling situations
1 parent d9bc3c6 commit c8ed3f6

File tree

1 file changed

+38
-10
lines changed

1 file changed

+38
-10
lines changed

distributed/stealing.py

+38-10
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import logging
2+
import random
23
from collections import defaultdict, deque
4+
from functools import partial
35
from math import log2
46
from time import time
57

@@ -113,27 +115,21 @@ def steal_time_ratio(self, ts):
113115
For example a result of zero implies a task without dependencies.
114116
level: The location within a stealable list to place this value
115117
"""
116-
if not ts.dependencies: # no dependencies fast path
117-
return 0, 0
118118

119119
split = ts.prefix.name
120120
if split in fast_tasks:
121121
return None, None
122122

123123
ws = ts.processing_on
124124
compute_time = ws.processing[ts]
125-
if compute_time < 0.005: # 5ms, just give up
126-
return None, None
127125

128126
nbytes = ts.get_nbytes_deps()
129127
transfer_time = nbytes / self.scheduler.bandwidth + LATENCY
130128
cost_multiplier = transfer_time / compute_time
131-
if cost_multiplier > 100:
132-
return None, None
133129

134130
level = int(round(log2(cost_multiplier) + 6))
135-
if level < 1:
136-
level = 1
131+
132+
level = min(len(self.cost_multipliers) - 1, level)
137133

138134
return cost_multiplier, level
139135

@@ -344,7 +340,10 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
344340
thieves = idle
345341
if not thieves:
346342
break
347-
thief = thieves[i % len(thieves)]
343+
344+
thief = self._maybe_pick_thief(ts, thieves)
345+
if not thief:
346+
thief = thieves[i % len(thieves)]
348347

349348
duration = sat.processing.get(ts)
350349
if duration is None:
@@ -380,7 +379,10 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
380379
thieves = idle
381380
if not thieves:
382381
continue
383-
thief = thieves[i % len(thieves)]
382+
thief = self._maybe_pick_thief(ts, thieves)
383+
if not thief:
384+
thief = thieves[i % len(thieves)]
385+
384386
duration = sat.processing[ts]
385387

386388
maybe_move_task(
@@ -394,6 +396,32 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
394396
if s.digests:
395397
s.digests["steal-duration"].add(stop - start)
396398

399+
def _maybe_pick_thief(self, ts, thieves):
400+
"""Try to be smart about picking a thief given some options. We're
401+
trying to pick a thief which has dependencies for a given task, if
402+
possible and will pick the one which works best for us given the
403+
Scheduler.worker_objective
404+
405+
If no idle worker with dependencies is found, this returns None.
406+
"""
407+
if ts._dependencies:
408+
who_has = set()
409+
for dep in ts._dependencies:
410+
who_has.update(dep.who_has)
411+
412+
thieves_with_data = who_has & set(thieves)
413+
414+
# If there are potential thieves with dependencies we
415+
# should prefer them and pick the one which works best.
416+
# Otherwise just random/round robin
417+
if thieves_with_data:
418+
if len(thieves_with_data) > 10:
419+
thieves_with_data = random.sample(thieves_with_data, 10)
420+
return min(
421+
thieves_with_data,
422+
key=partial(self.scheduler.worker_objective, ts),
423+
)
424+
397425
def restart(self, scheduler):
398426
for stealable in self.stealable.values():
399427
for s in stealable:

0 commit comments

Comments
 (0)