@@ -101,6 +101,65 @@ async def test_auto_start(c, s, a, b):
101
101
assert len (s .tasks ["x" ].who_has ) == 1
102
102
103
103
104
+ @gen_cluster (client = True , config = demo_config ("drop" , key = "x" ))
105
+ async def test_add_policy (c , s , a , b ):
106
+ p2 = DemoPolicy (action = "drop" , key = "y" , n = 10 , candidates = None )
107
+ p3 = DemoPolicy (action = "drop" , key = "z" , n = 10 , candidates = None )
108
+
109
+ # policies parameter can be:
110
+ # - None: get from config
111
+ # - explicit set, which can be empty
112
+ m1 = s .extensions ["amm" ]
113
+ m2 = ActiveMemoryManagerExtension (s , {p2 }, register = False , start = False )
114
+ m3 = ActiveMemoryManagerExtension (s , set (), register = False , start = False )
115
+
116
+ assert len (m1 .policies ) == 1
117
+ assert len (m2 .policies ) == 1
118
+ assert len (m3 .policies ) == 0
119
+ m3 .add_policy (p3 )
120
+ assert len (m3 .policies ) == 1
121
+
122
+ futures = await c .scatter ({"x" : 1 , "y" : 2 , "z" : 3 }, broadcast = True )
123
+ m1 .run_once ()
124
+ while len (s .tasks ["x" ].who_has ) == 2 :
125
+ await asyncio .sleep (0.01 )
126
+
127
+ m2 .run_once ()
128
+ while len (s .tasks ["y" ].who_has ) == 2 :
129
+ await asyncio .sleep (0.01 )
130
+
131
+ m3 .run_once ()
132
+ while len (s .tasks ["z" ].who_has ) == 2 :
133
+ await asyncio .sleep (0.01 )
134
+
135
+
136
+ @gen_cluster (client = True , config = demo_config ("drop" , key = "x" , start = False ))
137
+ async def test_multi_start (c , s , a , b ):
138
+ """Multiple AMMs can be started in parallel"""
139
+ p2 = DemoPolicy (action = "drop" , key = "y" , n = 10 , candidates = None )
140
+ p3 = DemoPolicy (action = "drop" , key = "z" , n = 10 , candidates = None )
141
+
142
+ # policies parameter can be:
143
+ # - None: get from config
144
+ # - explicit set, which can be empty
145
+ m1 = s .extensions ["amm" ]
146
+ m2 = ActiveMemoryManagerExtension (s , {p2 }, register = False , start = True , interval = 0.1 )
147
+ m3 = ActiveMemoryManagerExtension (s , {p3 }, register = False , start = True , interval = 0.1 )
148
+
149
+ assert not m1 .started
150
+ assert m2 .started
151
+ assert m3 .started
152
+
153
+ futures = await c .scatter ({"x" : 1 , "y" : 2 , "z" : 3 }, broadcast = True )
154
+
155
+ # The AMMs should run within 0.1s of the broadcast.
156
+ # Add generous extra padding to prevent flakiness.
157
+ await asyncio .sleep (0.5 )
158
+ assert len (s .tasks ["x" ].who_has ) == 2
159
+ assert len (s .tasks ["y" ].who_has ) == 1
160
+ assert len (s .tasks ["z" ].who_has ) == 1
161
+
162
+
104
163
@gen_cluster (client = True , config = NO_AMM_START )
105
164
async def test_not_registered (c , s , a , b ):
106
165
futures = await c .scatter ({"x" : 1 }, broadcast = True )
0 commit comments