Skip to content

Commit 6a4b1a1

Browse files
committed
New dataclass for MPITaskScheduler and fixes to the backlog scheduling logic
* `schedule_backlog_tasks` is now updated to fetch all tasks in the backlog_queue and then attempt to schedule them avoiding the infinite loop. * A new `PrioritizedTask` dataclass is added that disable comparison on the task: dict element. * The priority is set num_nodes * -1 to ensure that larger jobs get prioritized.
1 parent f399687 commit 6a4b1a1

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

parsl/executors/high_throughput/mpi_resource_management.py

+22-10
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pickle
55
import queue
66
import subprocess
7+
from dataclasses import dataclass, field
78
from enum import Enum
89
from typing import Dict, List, Optional
910

@@ -69,6 +70,14 @@ def __str__(self):
6970
return f"MPINodesUnavailable(requested={self.requested} available={self.available})"
7071

7172

73+
@dataclass(order=True)
74+
class PrioritizedTask:
75+
# Comparing dict will fail since they are unhashable
76+
# This dataclass limits comparison to the priority field
77+
priority: int
78+
task: Dict = field(compare=False)
79+
80+
7281
class TaskScheduler:
7382
"""Default TaskScheduler that does no taskscheduling
7483
@@ -111,7 +120,7 @@ def __init__(
111120
super().__init__(pending_task_q, pending_result_q)
112121
self.scheduler = identify_scheduler()
113122
# PriorityQueue is threadsafe
114-
self._backlog_queue: queue.PriorityQueue = queue.PriorityQueue()
123+
self._backlog_queue: queue.PriorityQueue[PrioritizedTask] = queue.PriorityQueue()
115124
self._map_tasks_to_nodes: Dict[str, List[str]] = {}
116125
self.available_nodes = get_nodes_in_batchjob(self.scheduler)
117126
self._free_node_counter = SpawnContext.Value("i", len(self.available_nodes))
@@ -169,7 +178,8 @@ def put_task(self, task_package: dict):
169178
allocated_nodes = self._get_nodes(nodes_needed)
170179
except MPINodesUnavailable:
171180
logger.info(f"Not enough resources, placing task {tid} into backlog")
172-
self._backlog_queue.put((nodes_needed, task_package))
181+
# Negate the priority element so that larger tasks are prioritized
182+
self._backlog_queue.put(PrioritizedTask(-1 * nodes_needed, task_package))
173183
return
174184
else:
175185
resource_spec["MPI_NODELIST"] = ",".join(allocated_nodes)
@@ -182,14 +192,16 @@ def put_task(self, task_package: dict):
182192

183193
def _schedule_backlog_tasks(self):
184194
"""Attempt to schedule backlogged tasks"""
185-
try:
186-
_nodes_requested, task_package = self._backlog_queue.get(block=False)
187-
self.put_task(task_package)
188-
except queue.Empty:
189-
return
190-
else:
191-
# Keep attempting to schedule tasks till we are out of resources
192-
self._schedule_backlog_tasks()
195+
196+
# Separate fetching tasks from the _backlog_queue and scheduling them
197+
# since tasks that failed to schedule will be pushed to the _backlog_queue
198+
backlogged_tasks = []
199+
while not self._backlog_queue.empty():
200+
prioritized_task = self._backlog_queue.get(block=False)
201+
backlogged_tasks.append(prioritized_task.task)
202+
203+
for backlogged_task in backlogged_tasks:
204+
self.put_task(backlogged_task)
193205

194206
def get_result(self, block: bool = True, timeout: Optional[float] = None):
195207
"""Return result and relinquish provisioned nodes"""

0 commit comments

Comments
 (0)