diff --git a/distributed/scheduler.py b/distributed/scheduler.py index 3708ef3fc06..fe1949e3cf0 100644 --- a/distributed/scheduler.py +++ b/distributed/scheduler.py @@ -818,6 +818,19 @@ class TaskGroup: "last_worker_tasks_left", ] + name: str + prefix: TaskPrefix | None + states: dict[str, int] + dependencies: set[TaskGroup] + nbytes_total: int + duration: float + types: set[str] + start: float + stop: float + all_durations: defaultdict[str, float] + last_worker: WorkerState | None + last_worker_tasks_left: int + def __init__(self, name: str): self.name = name self.prefix: TaskPrefix | None = None @@ -830,7 +843,7 @@ def __init__(self, name: str): self.start: float = 0.0 self.stop: float = 0.0 self.all_durations: defaultdict[str, float] = defaultdict(float) - self.last_worker = None # type: ignore + self.last_worker = None self.last_worker_tasks_left = 0 def add_duration(self, action: str, start: float, stop: float): @@ -1865,7 +1878,7 @@ def transition_no_worker_memory( pdb.set_trace() raise - def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None + def decide_worker(self, ts: TaskState) -> WorkerState | None: """ Decide on a worker for task *ts*. Return a WorkerState. @@ -1879,9 +1892,8 @@ def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None in a round-robin fashion. """ if not self.workers: - return None # type: ignore + return None - ws: WorkerState tg: TaskGroup = ts.group valid_workers: set = self.valid_workers(ts) @@ -1892,15 +1904,15 @@ def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None ): self.unrunnable.add(ts) ts.state = "no-worker" - return None # type: ignore + return None - # Group is larger than cluster with few dependencies? - # Minimize future data transfers. + # Group fills the cluster and dependencies are much smaller than cluster? Minimize future data transfers. + ndeps_cutoff: int = min(5, len(self.workers)) if ( valid_workers is None - and len(tg) > self.total_nthreads * 2 - and len(tg.dependencies) < 5 - and sum(map(len, tg.dependencies)) < 5 + and len(tg) >= self.total_nthreads + and len(tg.dependencies) < ndeps_cutoff + and sum(map(len, tg.dependencies)) < ndeps_cutoff ): ws = tg.last_worker @@ -1955,7 +1967,8 @@ def decide_worker(self, ts: TaskState) -> WorkerState: # -> WorkerState | None type(ws), ws, ) - assert ws.address in self.workers + if ws: + assert ws.address in self.workers return ws @@ -7478,8 +7491,11 @@ def _task_to_client_msgs(state: SchedulerState, ts: TaskState) -> dict: def decide_worker( - ts: TaskState, all_workers, valid_workers: set | None, objective -) -> WorkerState: # -> WorkerState | None + ts: TaskState, + all_workers: Collection[WorkerState], + valid_workers: set[WorkerState] | None, + objective, +) -> WorkerState | None: """ Decide which worker should take task *ts*. @@ -7495,16 +7511,34 @@ def decide_worker( of bytes sent between workers. This is determined by calling the *objective* function. 
""" - ws: WorkerState = None # type: ignore + ws: WorkerState | None = None wws: WorkerState dts: TaskState deps: set = ts.dependencies candidates: set + n_workers: int = len(valid_workers if valid_workers is not None else all_workers) assert all([dts.who_has for dts in deps]) if ts.actor: candidates = set(all_workers) else: - candidates = {wws for dts in deps for wws in dts.who_has} + candidates = { + wws + for dts in deps + # Ignore dependencies that will need to be, or already are, copied to all workers + if len(dts.who_has) < n_workers + and not ( + len(dts.dependents) >= n_workers + and len(dts.group) < n_workers // 2 + # Really want something like: + # map(len, dts.group.dependents) >= nthreads and len(dts.group) < n_workers // 2 + # Or at least + # len(dts.dependents) * len(dts.group) >= nthreads and len(dts.group) < n_workers // 2 + # But `nthreads` is O(k) to calculate if given `valid_workers`. + # and the `map(len, dts.group.dependents)` could be extremely expensive since we can't put + # much of an upper bound on it. + ) + for wws in dts.who_has + } if valid_workers is None: if not candidates: candidates = set(all_workers) diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py index 6401a873e60..43d5ef7f016 100644 --- a/distributed/tests/test_scheduler.py +++ b/distributed/tests/test_scheduler.py @@ -136,6 +136,9 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c): ], ) def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads): + if ndeps >= len(nthreads): + pytest.skip() + @gen_cluster( client=True, nthreads=nthreads, @@ -239,6 +242,153 @@ def random(**kwargs): test_decide_worker_coschedule_order_neighbors_() +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) +async def test_decide_worker_common_dep_ignored(client, s, *workers): + r""" + When we have basic linear chains, but all the downstream tasks also share a common dependency, ignore that dependency. + + i j k l m n o p + \__\__\__\___/__/__/__/ + | | | | | | | | | + | | | | X | | | | + a b c d e f g h + + ^ Ignore the location of X when picking a worker for i..p. + It will end up being copied to all workers anyway. + + If a dependency will end up on every worker regardless, because many things depend on it, + we should ignore it when selecting our candidate workers. Otherwise, we'll end up considering + every worker as a candidate, which is 1) slow and 2) often leads to poor choices. 
+ """ + roots = [ + delayed(slowinc)(1, 0.1 / (i + 1), dask_key_name=f"root-{i}") for i in range(16) + ] + # This shared dependency will get copied to all workers, eventually making all workers valid candidates for each dep + everywhere = delayed(None, name="everywhere") + deps = [ + delayed(lambda x, y: None)(r, everywhere, dask_key_name=f"dep-{i}") + for i, r in enumerate(roots) + ] + + rs, ds = dask.persist(roots, deps) + await wait(ds) + + keys = { + worker.name: dict( + root_keys=sorted( + [int(k.split("-")[1]) for k in worker.data if k.startswith("root")] + ), + deps_of_root=sorted( + [int(k.split("-")[1]) for k in worker.data if k.startswith("dep")] + ), + ) + for worker in workers + } + + for k in keys.values(): + assert k["root_keys"] == k["deps_of_root"] + + for worker in workers: + log = worker.incoming_transfer_log + if log: + assert len(log) == 1 + assert list(log[0]["keys"]) == ["everywhere"] + + +@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 4) +async def test_decide_worker_large_subtrees_colocated(client, s, *workers): + r""" + Ensure that the above "ignore common dependencies" logic doesn't affect wide (but isolated) subtrees. + + ........ ........ ........ ........ + \\\\//// \\\\//// \\\\//// \\\\//// + a b c d + + Each one of a, b, etc. has more dependents than there are workers. But just because a has + lots of dependents doesn't necessarily mean it will end up copied to every worker. + Because a also has a few siblings, a's dependents shouldn't spread out over the whole cluster. + """ + roots = [delayed(inc)(i, dask_key_name=f"root-{i}") for i in range(len(workers))] + deps = [ + delayed(inc)(r, dask_key_name=f"dep-{i}-{j}") + for i, r in enumerate(roots) + for j in range(len(workers) * 2) + ] + + rs, ds = dask.persist(roots, deps) + await wait(ds) + + keys = { + worker.name: dict( + root_keys={ + int(k.split("-")[1]) for k in worker.data if k.startswith("root") + }, + deps_of_root={ + int(k.split("-")[1]) for k in worker.data if k.startswith("dep") + }, + ) + for worker in workers + } + + for k in keys.values(): + assert k["root_keys"] == k["deps_of_root"] + assert len(k["root_keys"]) == len(roots) / len(workers) + + for worker in workers: + assert not worker.incoming_transfer_log + + +@gen_cluster( + client=True, + nthreads=[("127.0.0.1", 1)] * 4, + config={"distributed.scheduler.work-stealing": False}, +) +async def test_decide_worker_large_multiroot_subtrees_colocated(client, s, *workers): + r""" + Same as the above test, but also check isolated trees with multiple roots. + + ........ ........ ........ ........ 
+ \\\\//// \\\\//// \\\\//// \\\\//// + a b c d e f g h + """ + roots = [ + delayed(inc)(i, dask_key_name=f"root-{i}") for i in range(len(workers) * 2) + ] + deps = [ + delayed(lambda x, y: None)( + r, roots[i * 2 + 1], dask_key_name=f"dep-{i * 2}-{j}" + ) + for i, r in enumerate(roots[::2]) + for j in range(len(workers) * 2) + ] + + rs, ds = dask.persist(roots, deps) + await wait(ds) + + keys = { + worker.name: dict( + root_keys={ + int(k.split("-")[1]) for k in worker.data if k.startswith("root") + }, + deps_of_root=set().union( + *( + (int(k.split("-")[1]), int(k.split("-")[1]) + 1) + for k in worker.data + if k.startswith("dep") + ) + ), + ) + for worker in workers + } + + for k in keys.values(): + assert k["root_keys"] == k["deps_of_root"] + assert len(k["root_keys"]) == len(roots) / len(workers) + + for worker in workers: + assert not worker.incoming_transfer_log + + @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3) async def test_move_data_over_break_restrictions(client, s, a, b, c): [x] = await client.scatter([1], workers=b.address)
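
For reference, below is a minimal, self-contained sketch (plain Python; `group_fills_cluster`, `Dep`, and `pick_candidates` are illustrative names, not part of `distributed`) of the two heuristics this diff introduces: the relaxed "root-ish task group" check in `SchedulerState.decide_worker`, and the filtering of ubiquitous dependencies when building the candidate set in the module-level `decide_worker`. It deliberately omits the `valid_workers` intersection and restriction handling that the real code performs.

```python
from __future__ import annotations

from dataclasses import dataclass


def group_fills_cluster(
    group_size: int,
    n_group_deps: int,
    group_deps_sizes: list[int],
    total_nthreads: int,
    n_workers: int,
) -> bool:
    """Sketch of the relaxed check: a task group is batched onto one worker at a
    time when it can saturate the cluster and its dependencies are much smaller
    than the cluster (previously `len(tg) > nthreads * 2` with hard-coded 5s)."""
    ndeps_cutoff = min(5, n_workers)
    return (
        group_size >= total_nthreads
        and n_group_deps < ndeps_cutoff
        and sum(group_deps_sizes) < ndeps_cutoff
    )


@dataclass
class Dep:
    """Illustrative stand-in for the TaskState fields used by the new filter."""

    who_has: set[str]  # workers already holding this key
    n_dependents: int  # len(dts.dependents)
    group_size: int    # len(dts.group)


def pick_candidates(
    deps: list[Dep], all_workers: set[str], valid_workers: set[str] | None
) -> set[str]:
    """Return the workers worth considering for a task, ignoring dependencies
    that already live on every worker or will inevitably be replicated there."""
    n_workers = len(valid_workers if valid_workers is not None else all_workers)
    candidates: set[str] = set()
    for dep in deps:
        if len(dep.who_has) >= n_workers:
            continue  # already on every worker: carries no locality signal
        if dep.n_dependents >= n_workers and dep.group_size < n_workers // 2:
            continue  # widely-shared output of a tiny group: will be copied everywhere anyway
        candidates |= dep.who_has
    if not candidates:
        # No dependency narrows the choice; fall back to all (valid) workers.
        candidates = set(valid_workers if valid_workers is not None else all_workers)
    return candidates


# Mirrors test_decide_worker_common_dep_ignored: the `everywhere` key is ignored,
# so only the worker holding the chain's root remains a candidate.
workers = {"w1", "w2", "w3", "w4"}
root = Dep(who_has={"w2"}, n_dependents=1, group_size=16)
everywhere = Dep(who_has=set(workers), n_dependents=16, group_size=1)
assert pick_candidates([root, everywhere], workers, None) == {"w2"}
```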