Skip to content

Commit 05677bb

Browse files
authored
Run multiple AMMs in parallel (#5315)
Propaedeutic to RetireWorker AMM policy
1 parent 54760d8 commit 05677bb

File tree

2 files changed

+76
-8
lines changed

2 files changed

+76
-8
lines changed

distributed/active_memory_manager.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -52,23 +52,20 @@ def __init__(
5252
interval: Optional[float] = None,
5353
):
5454
self.scheduler = scheduler
55+
self.policies = set()
5556

5657
if policies is None:
58+
# Initialize policies from config
5759
policies = set()
5860
for kwargs in dask.config.get(
5961
"distributed.scheduler.active-memory-manager.policies"
6062
):
6163
kwargs = kwargs.copy()
6264
cls = import_term(kwargs.pop("class"))
63-
if not issubclass(cls, ActiveMemoryManagerPolicy):
64-
raise TypeError(
65-
f"{cls}: Expected ActiveMemoryManagerPolicy; got {type(cls)}"
66-
)
6765
policies.add(cls(**kwargs))
6866

6967
for policy in policies:
70-
policy.manager = self
71-
self.policies = policies
68+
self.add_policy(policy)
7269

7370
if register:
7471
scheduler.extensions["amm"] = self
@@ -92,16 +89,28 @@ def __init__(
9289

9390
def start(self, comm=None) -> None:
    """Start executing every ``self.interval`` seconds until scheduler shutdown.

    Calling this on a manager that is already started is a no-op. The
    periodic callback is stored on the scheduler under a key derived from
    ``id(self)``, so each manager instance gets its own entry instead of
    sharing a single ``"amm"`` slot.

    ``comm`` is accepted but not used by this method.
    """
    if self.started:
        return
    # The interval is documented in seconds; it is scaled by 1000 before being
    # handed to PeriodicCallback (presumably milliseconds - confirm against
    # the PeriodicCallback implementation imported by this module).
    pc = PeriodicCallback(self.run_once, self.interval * 1000.0)
    self.scheduler.periodic_callbacks[f"amm-{id(self)}"] = pc
    pc.start()
9897

9998
def stop(self, comm=None) -> None:
    """Stop periodic execution.

    Removes this manager's own entry (keyed by ``id(self)``) from the
    scheduler's periodic callbacks and stops it. Does nothing if no such
    entry exists, so it is safe to call on a manager that was never started.

    ``comm`` is accepted but not used by this method.
    """
    pc = self.scheduler.periodic_callbacks.pop(f"amm-{id(self)}", None)
    if pc:
        pc.stop()
104103

104+
@property
def started(self) -> bool:
    """Whether this manager currently has a periodic callback registered.

    True exactly when the scheduler's ``periodic_callbacks`` mapping contains
    the per-instance key ``f"amm-{id(self)}"``.
    """
    return f"amm-{id(self)}" in self.scheduler.periodic_callbacks
107+
108+
def add_policy(self, policy: ActiveMemoryManagerPolicy) -> None:
109+
if not isinstance(policy, ActiveMemoryManagerPolicy):
110+
raise TypeError(f"Expected ActiveMemoryManagerPolicy; got {policy!r}")
111+
self.policies.add(policy)
112+
policy.manager = self
113+
105114
def run_once(self, comm=None) -> None:
106115
"""Run all policies once and asynchronously (fire and forget) enact their
107116
recommendations to replicate/drop keys

distributed/tests/test_active_memory_manager.py

+59
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,65 @@ async def test_auto_start(c, s, a, b):
101101
assert len(s.tasks["x"].who_has) == 1
102102

103103

104+
@gen_cluster(client=True, config=demo_config("drop", key="x"))
async def test_add_policy(c, s, a, b):
    """Policies can come from config, from the constructor, or from add_policy"""
    p2 = DemoPolicy(action="drop", key="y", n=10, candidates=None)
    p3 = DemoPolicy(action="drop", key="z", n=10, candidates=None)

    # policies parameter can be:
    # - None: get from config
    # - explicit set, which can be empty
    m1 = s.extensions["amm"]
    m2 = ActiveMemoryManagerExtension(s, {p2}, register=False, start=False)
    m3 = ActiveMemoryManagerExtension(s, set(), register=False, start=False)

    assert len(m1.policies) == 1
    assert len(m2.policies) == 1
    assert len(m3.policies) == 0
    m3.add_policy(p3)
    assert len(m3.policies) == 1

    # The result is bound to a name for the rest of the test (presumably so
    # that the scattered keys are not released early).
    futures = await c.scatter({"x": 1, "y": 2, "z": 3}, broadcast=True)

    # Each manager owns the policy for exactly one key; run the managers one
    # at a time and wait until that key has lost a replica.
    for manager, key in ((m1, "x"), (m2, "y"), (m3, "z")):
        manager.run_once()
        while len(s.tasks[key].who_has) == 2:
            await asyncio.sleep(0.01)
134+
135+
136+
@gen_cluster(client=True, config=demo_config("drop", key="x", start=False))
async def test_multi_start(c, s, a, b):
    """Multiple AMMs can be started in parallel"""
    p2 = DemoPolicy(action="drop", key="y", n=10, candidates=None)
    p3 = DemoPolicy(action="drop", key="z", n=10, candidates=None)

    # m1 is the manager created from config with start=False; m2 and m3 are
    # built explicitly, left unregistered, and started immediately.
    m1 = s.extensions["amm"]
    m2 = ActiveMemoryManagerExtension(s, {p2}, register=False, start=True, interval=0.1)
    m3 = ActiveMemoryManagerExtension(s, {p3}, register=False, start=True, interval=0.1)

    assert not m1.started
    assert m2.started
    assert m3.started

    futures = await c.scatter({"x": 1, "y": 2, "z": 3}, broadcast=True)

    # Wait for the two started AMMs to act instead of sleeping for a fixed
    # time: a fixed sleep is both slower than necessary and can be too short
    # on a loaded machine.
    while len(s.tasks["y"].who_has) == 2:
        await asyncio.sleep(0.01)
    while len(s.tasks["z"].who_has) == 2:
        await asyncio.sleep(0.01)

    # By now m2 and m3 have each run at least once. m1 was never started, so
    # its key must still be on both workers.
    assert len(s.tasks["x"].who_has) == 2
    assert len(s.tasks["y"].who_has) == 1
    assert len(s.tasks["z"].who_has) == 1
161+
162+
104163
@gen_cluster(client=True, config=NO_AMM_START)
105164
async def test_not_registered(c, s, a, b):
106165
futures = await c.scatter({"x": 1}, broadcast=True)

0 commit comments

Comments
 (0)