diff --git a/parsl/executors/high_throughput/interchange.py b/parsl/executors/high_throughput/interchange.py index 12d3e07f31..61fa29f0f7 100644 --- a/parsl/executors/high_throughput/interchange.py +++ b/parsl/executors/high_throughput/interchange.py @@ -5,13 +5,13 @@ import os import pickle import platform -import queue import sys import threading import time from typing import Any, Dict, List, NoReturn, Optional, Sequence, Set, Tuple, cast import zmq +from sortedcontainers import SortedList from parsl import curvezmq from parsl.addresses import tcp_url @@ -131,7 +131,7 @@ def __init__(self, self.hub_address = hub_address self.hub_zmq_port = hub_zmq_port - self.pending_task_queue: queue.Queue[Any] = queue.Queue(maxsize=10 ** 6) + self.pending_task_queue: SortedList[Any] = SortedList(key=lambda msg: -msg['resource_spec']['priority']) self.count = 0 self.worker_ports = worker_ports @@ -192,12 +192,9 @@ def get_tasks(self, count: int) -> Sequence[dict]: """ tasks = [] for _ in range(0, count): - try: - x = self.pending_task_queue.get(block=False) - except queue.Empty: - break - else: - tasks.append(x) + if len(self.pending_task_queue) > 0: + x = self.pending_task_queue.pop(-1) + tasks.append(x) return tasks @@ -215,11 +212,14 @@ def task_puller(self) -> NoReturn: msg = self.task_incoming.recv_pyobj() except zmq.Again: # We just timed out while attempting to receive - logger.debug("zmq.Again with {} tasks in internal queue".format(self.pending_task_queue.qsize())) + logger.debug("zmq.Again with {} tasks in internal queue".format(len(self.pending_task_queue))) continue + resource_spec = msg.get('resource_spec', {}) + resource_spec.setdefault("priority", float('inf')) + msg['resource_spec'] = resource_spec logger.debug("putting message onto pending_task_queue") - self.pending_task_queue.put(msg) + self.pending_task_queue.add(msg) task_counter += 1 logger.debug(f"Fetched {task_counter} tasks so far") @@ -476,10 +476,10 @@ def process_tasks_to_send(self, interesting_managers: Set[bytes]) -> None: len(self._ready_managers) ) - if interesting_managers and not self.pending_task_queue.empty(): + if interesting_managers and (len(self.pending_task_queue) != 0): shuffled_managers = self.manager_selector.sort_managers(self._ready_managers, interesting_managers) - while shuffled_managers and not self.pending_task_queue.empty(): # cf. the if statement above... + while shuffled_managers and (len(self.pending_task_queue) != 0): # cf. the if statement above... manager_id = shuffled_managers.pop() m = self._ready_managers[manager_id] tasks_inflight = len(m['tasks']) diff --git a/parsl/tests/test_htex/test_priority_queue.py b/parsl/tests/test_htex/test_priority_queue.py new file mode 100644 index 0000000000..f36a64ae4b --- /dev/null +++ b/parsl/tests/test_htex/test_priority_queue.py @@ -0,0 +1,53 @@ +import pytest + +import parsl +from parsl.app.app import bash_app, python_app +from parsl.config import Config +from parsl.executors import HighThroughputExecutor +from parsl.executors.high_throughput.manager_selector import ( + ManagerSelector, + RandomManagerSelector, +) +from parsl.launchers import WrappedLauncher +from parsl.providers import LocalProvider +from parsl.usage_tracking.levels import LEVEL_1 + + +@parsl.python_app +def fake_task(parsl_resource_specification={'priority': 1}): + import time + return time.time() + + +@pytest.mark.local +def test_priority_queue(): + p = LocalProvider( + init_blocks=0, + max_blocks=0, + min_blocks=0, + ) + + htex = HighThroughputExecutor( + label="htex_local", + max_workers_per_node=1, + manager_selector=RandomManagerSelector(), + provider=p, + ) + + config = Config( + executors=[htex], + strategy="htex_auto_scale", + usage_tracking=LEVEL_1, + ) + + with parsl.load(config): + futures = {} + for priority in range(10, 0, -1): + spec = {'priority': priority} + futures[priority] = fake_task(parsl_resource_specification=spec) + + p.max_blocks = 1 + results = {priority: future.result() for priority, future in futures.items()} + sorted_results = dict(sorted(results.items(), key=lambda item: item[1])) + sorted_priorities = list(sorted_results.keys()) + assert sorted_priorities == sorted(sorted_priorities) diff --git a/test-requirements.txt b/test-requirements.txt index 82ec5172c2..e043171e8f 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -16,6 +16,7 @@ mpi4py # (where it's specified in setup.py) sqlalchemy>=1.4,<2 sqlalchemy2-stubs +sortedcontainers-stubs Sphinx==4.5.0 twine