6
6
import subprocess
7
7
from dataclasses import dataclass , field
8
8
from enum import Enum
9
- from typing import Dict , List , Optional
9
+ from typing import Dict , List , Optional , Tuple
10
10
11
11
from parsl .multiprocessing import SpawnContext
12
12
from parsl .serialize import pack_res_spec_apply_message , unpack_res_spec_apply_message
@@ -76,6 +76,8 @@ class PrioritizedTask:
76
76
# This dataclass limits comparison to the priority field
77
77
priority : int
78
78
task : Dict = field (compare = False )
79
+ unpacked_task : Tuple = field (compare = False )
80
+ nodes_needed : int = field (compare = False )
79
81
80
82
81
83
class TaskScheduler :
@@ -165,29 +167,41 @@ def _return_nodes(self, nodes: List[str]) -> None:
165
167
with self ._free_node_counter .get_lock ():
166
168
self ._free_node_counter .value += len (nodes ) # type: ignore[attr-defined]
167
169
168
- def put_task (self , task_package : dict ):
169
- """Schedule task if resources are available otherwise backlog the task"""
170
- user_ns = locals ()
171
- user_ns .update ({"__builtins__" : __builtins__ })
172
- _f , _args , _kwargs , resource_spec = unpack_res_spec_apply_message (task_package ["buffer" ])
173
-
174
- nodes_needed = resource_spec .get ("num_nodes" )
175
- tid = task_package ["task_id" ]
170
+ def schedule_task (self , p_task : PrioritizedTask ):
171
+ """Schedule a prioritized task if resources are available, and push to backlog
172
+ otherwise."""
173
+ nodes_needed = p_task .nodes_needed
174
+ tid = p_task .task ["task_id" ]
176
175
if nodes_needed :
177
176
try :
178
177
allocated_nodes = self ._get_nodes (nodes_needed )
179
178
except MPINodesUnavailable :
180
179
logger .info (f"Not enough resources, placing task { tid } into backlog" )
181
- self ._backlog_queue .put (PrioritizedTask ( nodes_needed , task_package ) )
180
+ self ._backlog_queue .put (p_task )
182
181
return
183
182
else :
183
+ f , args , kwargs , resource_spec = p_task .unpacked_task
184
184
resource_spec ["MPI_NODELIST" ] = "," .join (allocated_nodes )
185
185
self ._map_tasks_to_nodes [tid ] = allocated_nodes
186
- buffer = pack_res_spec_apply_message (_f , _args , _kwargs , resource_spec )
187
- task_package ["buffer" ] = buffer
188
- task_package ["resource_spec" ] = resource_spec
186
+ buffer = pack_res_spec_apply_message (f , args , kwargs , resource_spec )
187
+ p_task .task ["buffer" ] = buffer
188
+ p_task .task ["resource_spec" ] = resource_spec
189
+
190
+ self .pending_task_q .put (p_task .task )
191
+
192
+ def put_task (self , task_package : dict ):
193
+ """Schedule task if resources are available otherwise backlog the task"""
194
+ user_ns = locals ()
195
+ user_ns .update ({"__builtins__" : __builtins__ })
196
+ _f , _args , _kwargs , resource_spec = unpack_res_spec_apply_message (task_package ["buffer" ])
197
+
198
+ nodes_needed = resource_spec .get ("num_nodes" )
199
+ prioritized_task = PrioritizedTask (priority = nodes_needed ,
200
+ task = task_package ,
201
+ unpacked_task = (_f , _args , _kwargs , resource_spec ),
202
+ nodes_needed = nodes_needed )
189
203
190
- self .pending_task_q . put ( task_package )
204
+ self .schedule_task ( prioritized_task )
191
205
192
206
def _schedule_backlog_tasks (self ):
193
207
"""Attempt to schedule backlogged tasks"""
@@ -198,12 +212,12 @@ def _schedule_backlog_tasks(self):
198
212
while True :
199
213
try :
200
214
prioritized_task = self ._backlog_queue .get (block = False )
201
- backlogged_tasks .append (prioritized_task . task )
215
+ backlogged_tasks .append (prioritized_task )
202
216
except queue .Empty :
203
217
break
204
218
205
219
for backlogged_task in backlogged_tasks :
206
- self .put_task (backlogged_task )
220
+ self .schedule_task (backlogged_task )
207
221
208
222
def get_result (self , block : bool = True , timeout : Optional [float ] = None ):
209
223
"""Return result and relinquish provisioned nodes"""
0 commit comments