55from collections import defaultdict , deque
66from collections .abc import Iterable
77from datetime import timedelta
8- from typing import TYPE_CHECKING , cast
8+ from typing import TYPE_CHECKING , Literal , TypedDict , cast
99
1010import tlz as toolz
1111from tornado .ioloop import IOLoop
12+ from typing_extensions import NotRequired
1213
1314import dask .config
1415from dask .utils import parse_timedelta
2324logger = logging .getLogger (__name__ )
2425
2526
27+ RecommendationStatus = Literal ["up" , "down" , "same" ]
28+
29+
30+ class Recommendation (TypedDict ):
31+ status : RecommendationStatus
32+ workers : NotRequired [set [WorkerState ]]
33+ n : NotRequired [int ]
34+
35+
2636class AdaptiveCore :
2737 """
2838 The core logic for adaptive deployments, with none of the cluster details
@@ -169,13 +179,13 @@ async def safe_target(self) -> int:
169179
170180 return n
171181
172- async def scale_down (self , n : int ) -> None :
182+ async def scale_down (self , workers : Iterable ) -> None :
173183 raise NotImplementedError ()
174184
175- async def scale_up (self , workers : Iterable ) -> None :
185+ async def scale_up (self , n : int ) -> None :
176186 raise NotImplementedError ()
177187
178- async def recommendations (self , target : int ) -> dict :
188+ async def recommendations (self , target : int ) -> Recommendation :
179189 """
180190 Make scale up/down recommendations based on current state and target
181191 """
@@ -185,11 +195,11 @@ async def recommendations(self, target: int) -> dict:
185195
186196 if target == len (plan ):
187197 self .close_counts .clear ()
188- return { " status" : " same"}
198+ return Recommendation ( status = " same")
189199
190200 if target > len (plan ):
191201 self .close_counts .clear ()
192- return { " status" : " up" , "n" : target }
202+ return Recommendation ( status = " up" , n = target )
193203
194204 # target < len(plan)
195205 not_yet_arrived = requested - observed
@@ -212,9 +222,9 @@ async def recommendations(self, target: int) -> dict:
212222 del self .close_counts [k ]
213223
214224 if firmly_close :
215- return { " status" : " down" , " workers" : list ( firmly_close )}
225+ return Recommendation ( status = " down" , workers = firmly_close )
216226 else :
217- return { " status" : " same"}
227+ return Recommendation ( status = " same")
218228
219229 async def adapt (self ) -> None :
220230 """
@@ -229,18 +239,16 @@ async def adapt(self) -> None:
229239
230240 try :
231241 target = await self .safe_target ()
232- recommendations = await self .recommendations (target )
233-
234- if recommendations ["status" ] != "same" :
235- self .log .append ((time (), dict (recommendations )))
242+ recommendation = await self .recommendations (target )
236243
237- status = recommendations .pop ("status" )
238- if status == "same" :
244+ if recommendation ["status" ] == "same" :
239245 return
240- if status == "up" :
241- await self .scale_up (** recommendations )
242- if status == "down" :
243- await self .scale_down (** recommendations )
246+ else :
247+ self .log .append ((time (), cast (dict , recommendation )))
248+ if recommendation ["status" ] == "up" :
249+ await self .scale_up (recommendation ["n" ])
250+ elif recommendation ["status" ] == "down" :
251+ await self .scale_down (recommendation ["workers" ])
244252 except OSError :
245253 if status != "down" :
246254 logger .error ("Adaptive stopping due to error" , exc_info = True )
0 commit comments