1
1
import logging
2
+ import random
2
3
from collections import defaultdict , deque
4
+ from functools import partial
3
5
from math import log2
4
6
from time import time
5
7
@@ -113,27 +115,21 @@ def steal_time_ratio(self, ts):
113
115
For example a result of zero implies a task without dependencies.
114
116
level: The location within a stealable list to place this value
115
117
"""
116
- if not ts .dependencies : # no dependencies fast path
117
- return 0 , 0
118
118
119
119
split = ts .prefix .name
120
120
if split in fast_tasks :
121
121
return None , None
122
122
123
123
ws = ts .processing_on
124
124
compute_time = ws .processing [ts ]
125
- if compute_time < 0.005 : # 5ms, just give up
126
- return None , None
127
125
128
126
nbytes = ts .get_nbytes_deps ()
129
127
transfer_time = nbytes / self .scheduler .bandwidth + LATENCY
130
128
cost_multiplier = transfer_time / compute_time
131
- if cost_multiplier > 100 :
132
- return None , None
133
129
134
130
level = int (round (log2 (cost_multiplier ) + 6 ))
135
- if level < 1 :
136
- level = 1
131
+
132
+ level = min ( len ( self . cost_multipliers ) - 1 , level )
137
133
138
134
return cost_multiplier , level
139
135
@@ -344,7 +340,10 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
344
340
thieves = idle
345
341
if not thieves :
346
342
break
347
- thief = thieves [i % len (thieves )]
343
+
344
+ thief = self ._maybe_pick_thief (ts , thieves )
345
+ if not thief :
346
+ thief = thieves [i % len (thieves )]
348
347
349
348
duration = sat .processing .get (ts )
350
349
if duration is None :
@@ -380,7 +379,10 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
380
379
thieves = idle
381
380
if not thieves :
382
381
continue
383
- thief = thieves [i % len (thieves )]
382
+ thief = self ._maybe_pick_thief (ts , thieves )
383
+ if not thief :
384
+ thief = thieves [i % len (thieves )]
385
+
384
386
duration = sat .processing [ts ]
385
387
386
388
maybe_move_task (
@@ -394,6 +396,32 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
394
396
if s .digests :
395
397
s .digests ["steal-duration" ].add (stop - start )
396
398
399
def _maybe_pick_thief(self, ts, thieves):
    """Pick the candidate thief best suited to run ``ts``, if one holds its data.

    Candidates that already hold at least one dependency of ``ts`` are
    preferred; among those, the one with the smallest
    ``self.scheduler.worker_objective(ts, worker)`` is returned.

    Parameters
    ----------
    ts
        Task being considered for stealing. Its ``_dependencies`` are
        iterated and each dependency's ``who_has`` collection is read.
    thieves
        Iterable of candidate workers (must be hashable).

    Returns
    -------
    The selected worker, or ``None`` when ``ts`` has no dependencies or
    none of the candidates holds any of them. Callers are expected to fall
    back to their own choice (e.g. round robin) in that case.
    """
    if not ts._dependencies:
        return None

    # Every worker that holds at least one dependency of the task.
    who_has = set()
    for dep in ts._dependencies:
        who_has.update(dep.who_has)

    thieves_with_data = who_has & set(thieves)
    if not thieves_with_data:
        return None

    # Score at most 10 candidates so the objective is not evaluated for
    # every data-holding worker.
    if len(thieves_with_data) > 10:
        # random.sample() needs a sequence: passing a set was deprecated in
        # Python 3.9 and raises TypeError from Python 3.11 onwards.
        thieves_with_data = random.sample(list(thieves_with_data), 10)

    return min(
        thieves_with_data,
        key=partial(self.scheduler.worker_objective, ts),
    )
424
+
397
425
def restart (self , scheduler ):
398
426
for stealable in self .stealable .values ():
399
427
for s in stealable :
0 commit comments