4
4
import pickle
5
5
import queue
6
6
import subprocess
7
+ from dataclasses import dataclass , field
7
8
from enum import Enum
8
9
from typing import Dict , List , Optional
9
10
@@ -69,6 +70,14 @@ def __str__(self):
69
70
return f"MPINodesUnavailable(requested={ self .requested } available={ self .available } )"
70
71
71
72
73
@dataclass(order=True)
class PrioritizedTask:
    """Backlog entry that is ordered by its ``priority`` field only.

    The generated comparison methods look at ``priority`` and nothing else,
    so two entries with the same priority compare as equal regardless of
    the task they carry.
    """

    # Sort key: an entry with a lower value compares as smaller.
    priority: int
    # Task payload. Excluded from comparisons because dicts define no
    # ordering (``{} < {}`` raises TypeError), which would otherwise break
    # sorting whenever two entries share the same priority.
    task: Dict = field(compare=False)
79
+
80
+
72
81
class TaskScheduler :
73
82
"""Default TaskScheduler that does no taskscheduling
74
83
@@ -111,7 +120,7 @@ def __init__(
111
120
super ().__init__ (pending_task_q , pending_result_q )
112
121
self .scheduler = identify_scheduler ()
113
122
# PriorityQueue is threadsafe
114
- self ._backlog_queue : queue .PriorityQueue = queue .PriorityQueue ()
123
+ self ._backlog_queue : queue .PriorityQueue [ PrioritizedTask ] = queue .PriorityQueue ()
115
124
self ._map_tasks_to_nodes : Dict [str , List [str ]] = {}
116
125
self .available_nodes = get_nodes_in_batchjob (self .scheduler )
117
126
self ._free_node_counter = SpawnContext .Value ("i" , len (self .available_nodes ))
@@ -169,7 +178,8 @@ def put_task(self, task_package: dict):
169
178
allocated_nodes = self ._get_nodes (nodes_needed )
170
179
except MPINodesUnavailable :
171
180
logger .info (f"Not enough resources, placing task { tid } into backlog" )
172
- self ._backlog_queue .put ((nodes_needed , task_package ))
181
+ # Negate the priority element so that larger tasks are prioritized
182
+ self ._backlog_queue .put (PrioritizedTask (- 1 * nodes_needed , task_package ))
173
183
return
174
184
else :
175
185
resource_spec ["MPI_NODELIST" ] = "," .join (allocated_nodes )
@@ -182,14 +192,16 @@ def put_task(self, task_package: dict):
182
192
183
193
def _schedule_backlog_tasks(self):
    """Attempt to schedule backlogged tasks.

    The backlog is drained completely before any task is resubmitted:
    put_task() pushes a task that still cannot be placed back onto
    _backlog_queue, so fetching and scheduling in the same loop would keep
    pulling out the task that was just re-queued.
    """
    backlogged_tasks = []
    while True:
        try:
            # Rely on queue.Empty instead of checking empty() first: a
            # preceding empty() check gives no guarantee that this get()
            # still finds an item if another thread touches the queue.
            prioritized_task = self._backlog_queue.get(block=False)
        except queue.Empty:
            break
        backlogged_tasks.append(prioritized_task.task)

    # Resubmit in the order the queue handed the tasks out.
    for backlogged_task in backlogged_tasks:
        self.put_task(backlogged_task)
193
205
194
206
def get_result (self , block : bool = True , timeout : Optional [float ] = None ):
195
207
"""Return result and relinquish provisioned nodes"""
0 commit comments