diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index 1af96a9..065124d 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -1,4 +1,5 @@ import logging +import threading from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union import psycopg @@ -97,8 +98,8 @@ def get_publish_channel(self, channel: Optional[str] = None) -> str: # --- asyncio connection methods --- async def aget_connection(self) -> psycopg.AsyncConnection: - "Return existing connection or create a new one" - if not self._async_connection: + # Check if the cached async connection is either None or closed. + if not self._async_connection or getattr(self._async_connection, "closed", 0) != 0: if self._async_connection_factory: factory = resolve_callable(self._async_connection_factory) if not factory: @@ -109,7 +110,7 @@ async def aget_connection(self) -> psycopg.AsyncConnection: else: raise RuntimeError('Could not construct async connection for lack of config or factory') self._async_connection = connection - return connection # slightly weird due to MyPY + assert self._async_connection is not None return self._async_connection def get_listen_query(self, channel: str) -> psycopg.sql.Composed: @@ -178,7 +179,8 @@ async def aclose(self) -> None: # --- synchronous connection methods --- def get_connection(self) -> psycopg.Connection: - if not self._sync_connection: + # Check if the cached connection is either None or closed. + if not self._sync_connection or getattr(self._sync_connection, "closed", 0) != 0: if self._sync_connection_factory: factory = resolve_callable(self._sync_connection_factory) if not factory: @@ -189,7 +191,7 @@ def get_connection(self) -> psycopg.Connection: else: raise RuntimeError('Could not construct connection for lack of config or factory') self._sync_connection = connection - return connection + assert self._sync_connection is not None return self._sync_connection def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]: @@ -234,6 +236,7 @@ class ConnectionSaver: def __init__(self) -> None: self._connection: Optional[psycopg.Connection] = None self._async_connection: Optional[psycopg.AsyncConnection] = None + self._lock = threading.Lock() connection_save = ConnectionSaver() @@ -245,10 +248,14 @@ def connection_saver(**config) -> psycopg.Connection: # type: ignore[no-untyped Philosophically, this is used by an application that uses an ORM, or otherwise has its own connection management logic. Dispatcher does not manage connections, so this a simulation of that. + + Uses a thread lock to ensure thread safety. """ - if connection_save._connection is None: - connection_save._connection = create_connection(**config) - return connection_save._connection + with connection_save._lock: + # Check if we need to create a new connection because it's either None or closed. + if connection_save._connection is None or getattr(connection_save._connection, 'closed', False): + connection_save._connection = create_connection(**config) + return connection_save._connection async def async_connection_saver(**config) -> psycopg.AsyncConnection: # type: ignore[no-untyped-def] @@ -257,7 +264,10 @@ async def async_connection_saver(**config) -> psycopg.AsyncConnection: # type: Philosophically, this is used by an application that uses an ORM, or otherwise has its own connection management logic. Dispatcher does not manage connections, so this a simulation of that. + + Uses a thread lock to ensure thread safety. """ - if connection_save._async_connection is None: - connection_save._async_connection = await acreate_connection(**config) - return connection_save._async_connection + with connection_save._lock: + if connection_save._async_connection is None or getattr(connection_save._async_connection, 'closed', False): + connection_save._async_connection = await acreate_connection(**config) + return connection_save._async_connection diff --git a/dispatcher/control.py b/dispatcher/control.py index aaf0b45..15f15b4 100644 --- a/dispatcher/control.py +++ b/dispatcher/control.py @@ -24,12 +24,17 @@ async def connected_callback(self) -> None: await self.broker.apublish_message(self.queuename, self.send_message) async def listen_for_replies(self) -> None: - """Listen to the reply channel until we get the expected number of messages + """Listen to the reply channel until we get the expected number of messages. - This gets ran in a task, and timing out will be accomplished by the main code + This gets ran in an async task, and timing out will be accomplished by the main code """ async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback): - self.received_replies.append(payload) + try: + # If payload is a string, parse it to a dict; otherwise assume it's valid. + message = json.loads(payload) if isinstance(payload, str) else payload + self.received_replies.append(message) + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON on channel '{channel}': {payload[:100]}... (Error: {e})") if len(self.received_replies) >= self.expected_replies: return @@ -54,7 +59,7 @@ def parse_replies(received_replies: list[Union[str, dict]]) -> list[dict]: ret.append(json.loads(payload)) return ret - def get_send_message(self, command: str, reply_to: Optional[str] = None, send_data: Optional[dict] = None) -> str: + def create_message(self, command: str, reply_to: Optional[str] = None, send_data: Optional[dict] = None) -> str: to_send: dict[str, Union[dict, str]] = {'control': command} if reply_to: to_send['reply_to'] = reply_to @@ -65,7 +70,7 @@ def get_send_message(self, command: str, reply_to: Optional[str] = None, send_da async def acontrol_with_reply(self, command: str, expected_replies: int = 1, timeout: int = 1, data: Optional[dict] = None) -> list[dict]: reply_queue = Control.generate_reply_queue_name() broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue]) - send_message = self.get_send_message(command=command, reply_to=reply_queue, send_data=data) + send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data) control_callbacks = BrokerCallbacks(broker=broker, queuename=self.queuename, send_message=send_message, expected_replies=expected_replies) @@ -77,19 +82,24 @@ async def acontrol_with_reply(self, command: str, expected_replies: int = 1, tim except asyncio.TimeoutError: logger.warning(f'Did not receive {expected_replies} reply in {timeout} seconds, only {len(control_callbacks.received_replies)}') listen_task.cancel() + finally: + await broker.aclose() return self.parse_replies(control_callbacks.received_replies) async def acontrol(self, command: str, data: Optional[dict] = None) -> None: broker = get_broker(self.broker_name, self.broker_config, channels=[]) - send_message = self.get_send_message(command=command, send_data=data) - await broker.apublish_message(message=send_message) + send_message = self.create_message(command=command, send_data=data) + try: + await broker.apublish_message(message=send_message) + finally: + await broker.aclose() def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]: - logger.info('control-and-reply {} to {}'.format(command, self.queuename)) + logger.info(f'control-and-reply {command} to {self.queuename}') start = time.time() reply_queue = Control.generate_reply_queue_name() - send_message = self.get_send_message(command=command, reply_to=reply_queue, send_data=data) + send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data) broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue]) @@ -97,15 +107,20 @@ def connected_callback() -> None: broker.publish_message(channel=self.queuename, message=send_message) replies = [] - for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout): - reply_data = json.loads(payload) - replies.append(reply_data) - - logger.info(f'control-and-reply message returned in {time.time() - start} seconds') - return replies + try: + for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout): + reply_data = json.loads(payload) + replies.append(reply_data) + logger.info(f'control-and-reply message returned in {time.time() - start} seconds') + return replies + finally: + broker.close() def control(self, command: str, data: Optional[dict] = None) -> None: - "Send message in fire-and-forget mode, as synchronous code. Only for no-reply control." + """Send a fire-and-forget control message synchronously.""" broker = get_broker(self.broker_name, self.broker_config) - send_message = self.get_send_message(command=command, send_data=data) - broker.publish_message(channel=self.queuename, message=send_message) + send_message = self.create_message(command=command, send_data=data) + try: + broker.publish_message(channel=self.queuename, message=send_message) + finally: + broker.close() diff --git a/dispatcher/protocols.py b/dispatcher/protocols.py index 2c2ddc2..4b5dd1a 100644 --- a/dispatcher/protocols.py +++ b/dispatcher/protocols.py @@ -3,6 +3,13 @@ class Broker(Protocol): + """ + Describes a messaging broker interface. + + This interface abstracts functionality for sending and receiving messages, + both asynchronously and synchronously, and for managing connection lifecycles. + """ + async def aprocess_notify( self, connected_callback: Optional[Optional[Callable[[], Coroutine[Any, Any, None]]]] = None ) -> AsyncGenerator[tuple[str, str], None]: @@ -35,10 +42,23 @@ def close(self): class ProducerEvents(Protocol): + """ + Describes an events container for producers. + + Typically provides a signal (like a ready event) to indicate producer readiness. + """ + ready_event: asyncio.Event class Producer(Protocol): + """ + Describes a task producer interface. + + This interface encapsulates behavior for starting task production, + managing its lifecycle, and tracking asynchronous operations. + """ + events: ProducerEvents async def start_producing(self, dispatcher: 'DispatcherMain') -> None: @@ -55,6 +75,13 @@ def all_tasks(self) -> Iterable[asyncio.Task]: class PoolWorker(Protocol): + """ + Describes an individual worker in a task pool. + + It covers the properties and behaviors needed to track a worker’s execution state + and control its task processing lifecycle. + """ + current_task: Optional[dict] worker_id: int @@ -70,18 +97,37 @@ def cancel(self) -> None: ... class Queuer(Protocol): + """ + Describes an interface for managing pending tasks. + + It provides a way to iterate over and modify tasks awaiting assignment. + """ + def __iter__(self) -> Iterator[dict]: ... def remove_task(self, message: dict) -> None: ... class Blocker(Protocol): + """ + Describes an interface for handling tasks that are temporarily deferred. + + It offers a mechanism to view and manage tasks that cannot run immediately. + """ + def __iter__(self) -> Iterator[dict]: ... def remove_task(self, message: dict) -> None: ... class WorkerData(Protocol): + """ + Describes an interface for managing a collection of workers. + + It abstracts how worker instances are iterated over and retrieved, + and it provides a lock for safe concurrent updates. + """ + management_lock: asyncio.Lock def __iter__(self) -> Iterator[PoolWorker]: ... @@ -90,6 +136,13 @@ def get_by_id(self, worker_id: int) -> PoolWorker: ... class WorkerPool(Protocol): + """ + Describes an interface for a pool managing task workers. + + It includes core functionality for starting the pool, dispatching tasks, + and shutting down the pool in a controlled manner. + """ + workers: WorkerData queuer: Queuer blocker: Blocker @@ -106,6 +159,14 @@ async def shutdown(self) -> None: ... class DispatcherMain(Protocol): + """ + Describes the primary dispatcher interface. + + This interface defines the contract for the overall task dispatching service, + including coordinating task processing, managing the worker pool, and + handling delayed or control messages. + """ + pool: WorkerPool delayed_messages: set diff --git a/dispatcher/service/next_wakeup_runner.py b/dispatcher/service/next_wakeup_runner.py index 67ea56d..e45ab01 100644 --- a/dispatcher/service/next_wakeup_runner.py +++ b/dispatcher/service/next_wakeup_runner.py @@ -50,9 +50,12 @@ def __init__(self, wakeup_objects: Iterable[HasWakeup], process_object: Callable self.name = name async def process_wakeups(self, current_time: float, do_processing: bool = True) -> Optional[float]: - """Runs process_object for objects past for which we have passed the wakeup time + """Runs process_object for objects whose wakeup time has passed. - Returns the time of the soonest wakeup that has not been processed here + Returns the soonest upcoming wakeup time among the objects that have not been processed. + + If do_processing is True, process_object is called for objects with wakeup times below current_time. + Errors from process_object are logged and propagated. Arguments: - current_time - output of time.monotonic() passed from caller to keep this deterministic @@ -63,7 +66,11 @@ async def process_wakeups(self, current_time: float, do_processing: bool = True) for obj in list(self.wakeup_objects): if obj_wakeup := obj.next_wakeup(): if do_processing and (obj_wakeup < current_time): - await self.process_object(obj) + try: + await self.process_object(obj) + except Exception as e: + logger.error(f"Error processing wakeup for object {obj}: {e}", exc_info=True) + raise # refresh wakeup, which should be nullified or pushed back by process_object obj_wakeup = obj.next_wakeup() if obj_wakeup is None: diff --git a/dispatcher/service/pool.py b/dispatcher/service/pool.py index 0e2510f..a1ac89c 100644 --- a/dispatcher/service/pool.py +++ b/dispatcher/service/pool.py @@ -317,14 +317,22 @@ async def manage_new_workers(self, forking_lock: asyncio.Lock) -> None: async def manage_old_workers(self) -> None: """Clear internal memory of workers whose process has exited, and assures processes are gone + This method takes a snapshot of the current workers under lock, + processes them outside the lock (including awaiting worker stops), + and then re-acquires the lock to remove workers marked for deletion. + happy path: The scale_workers method notifies a worker they need to exit The read_results_task will mark the worker status to exited This method will see the updated status, join the process, and remove it from self.workers """ + # Phase 1: Get a consistent snapshot of workers. + async with self.workers.management_lock: + current_workers = list(self.workers) + remove_ids = [] - for worker in self.workers: - # Check for workers that died unexpectedly + for worker in current_workers: + # Check if the worker has died unexpectedly. if worker.status not in ['retired', 'error', 'exited', 'initialized', 'spawned'] and not worker.process.is_alive(): logger.error(f'Worker {worker.worker_id} pid={worker.process.pid} has died unexpectedly, status was {worker.status}') @@ -332,8 +340,7 @@ async def manage_old_workers(self) -> None: uuid = worker.current_task.get('uuid', '') logger.error(f'Task (uuid={uuid}) was running on worker {worker.worker_id} but the worker died unexpectedly') self.canceled_count += 1 - worker.is_active_cancel = False # Ensure it's not processed by timeout runner - + worker.is_active_cancel = False # Prevent further processing. worker.status = 'error' worker.retired_at = time.monotonic() @@ -345,9 +352,9 @@ async def manage_old_workers(self) -> None: elif worker.status in ['retired', 'error'] and worker.retired_at and (time.monotonic() - worker.retired_at) > self.worker_removal_wait: remove_ids.append(worker.worker_id) - # Remove workers from memory, done as separate loop due to locking concerns - for worker_id in remove_ids: - async with self.workers.management_lock: + # Phase 2: Remove workers from the collection under lock. + async with self.workers.management_lock: + for worker_id in remove_ids: if worker_id in self.workers: logger.debug(f'Fully removing worker id={worker_id}') self.workers.remove_by_id(worker_id) @@ -453,14 +460,16 @@ async def post_task_start(self, message: dict) -> None: async def dispatch_task(self, message: dict) -> None: uuid = message.get("uuid", "") - async with self.workers.management_lock: - if unblocked_task := self.blocker.process_task(message): - if worker := self.queuer.get_worker_or_process_task(unblocked_task): - logger.debug(f"Dispatching task (uuid={uuid}) to worker (id={worker.worker_id})") + unblocked_task = self.blocker.process_task(message) + if unblocked_task: + worker = self.queuer.get_worker_or_process_task(unblocked_task) + if worker: + logger.debug(f"Dispatching task (uuid={uuid}) to worker (id={worker.worker_id})") + async with self.workers.management_lock: await worker.start_task(unblocked_task) await self.post_task_start(unblocked_task) - else: - self.events.management_event.set() # kick manager task to start auto-scale up + else: + self.events.management_event.set() # kick manager task to start auto-scale up if needed async def drain_queue(self) -> None: async with self.workers.management_lock: diff --git a/dispatcher/service/process.py b/dispatcher/service/process.py index 4064575..2ebbc84 100644 --- a/dispatcher/service/process.py +++ b/dispatcher/service/process.py @@ -2,7 +2,7 @@ import multiprocessing from multiprocessing.context import BaseContext from types import ModuleType -from typing import Callable, Iterable, Optional, Union +from typing import Any, Callable, Iterable, Optional, Union from ..config import LazySettings from ..config import settings as global_settings @@ -51,6 +51,23 @@ def kill(self) -> None: def terminate(self) -> None: self._process.terminate() + def __enter__(self) -> "ProcessProxy": + """Enter the runtime context and return this ProcessProxy.""" + return self + + def __exit__(self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any]) -> Optional[bool]: + """Ensure the process is terminated and joined when exiting the context. + + If the process is still alive, it will be terminated (or killed if necessary) and then joined. + """ + if self.is_alive(): + try: + self.terminate() + except Exception: + self.kill() + self.join() + return None + class ProcessManager: mp_context = 'fork' diff --git a/tests/unit/test_broker_callbacks.py b/tests/unit/test_broker_callbacks.py new file mode 100644 index 0000000..bf886e5 --- /dev/null +++ b/tests/unit/test_broker_callbacks.py @@ -0,0 +1,51 @@ +import json +import logging +import pytest + +from dispatcher.control import BrokerCallbacks +from dispatcher.protocols import Broker + + +# Dummy broker that yields first an invalid JSON message and then a valid one. +class DummyBroker(Broker): + async def aprocess_notify(self, connected_callback=None): + if connected_callback: + await connected_callback() + # First yield an invalid JSON string, then a valid one. + yield ("reply_channel", "invalid json") + yield ("reply_channel", json.dumps({"result": "ok"})) + + async def apublish_message(self, channel, message): + # No-op for testing. + return + + async def aclose(self): + return + + def process_notify(self, connected_callback=None, timeout: float = 5.0, max_messages: int = 1): + # Not used in this test. + yield ("reply_channel", "") + + def publish_message(self, channel=None, message=None): + return + + def close(self): + return + + +@pytest.mark.asyncio +async def test_listen_for_replies_with_invalid_json(caplog): + caplog.set_level(logging.WARNING) + dummy_broker = DummyBroker() + callbacks = BrokerCallbacks( + queuename="reply_channel", + broker=dummy_broker, + send_message="{}", + expected_replies=1 + ) + await callbacks.listen_for_replies() + # The invalid JSON should be ignored and only the valid message appended. + assert len(callbacks.received_replies) == 1 + assert callbacks.received_replies[0] == {"result": "ok"} + # Verify that a warning was logged for the malformed message. + assert any("Invalid JSON" in record.message for record in caplog.records) diff --git a/tests/unit/test_connection_saver.py b/tests/unit/test_connection_saver.py new file mode 100644 index 0000000..32d3f65 --- /dev/null +++ b/tests/unit/test_connection_saver.py @@ -0,0 +1,70 @@ +import threading +import asyncio +import pytest + +from dispatcher.brokers.pg_notify import connection_saver, async_connection_saver, connection_save + +# Define a dummy connection object that supports both sync and async close methods. +class DummyConnection: + def __init__(self): + self.closed = False + def close(self): + self.closed = True + async def aclose(self): + self.close() + +connection_create_count = 0 + +def dummy_create_connection(**config): + global connection_create_count + connection_create_count += 1 + return DummyConnection() + +@pytest.fixture(autouse=True) +def reset_sync(monkeypatch): + global connection_create_count + connection_create_count = 0 + monkeypatch.setattr("dispatcher.brokers.pg_notify.create_connection", dummy_create_connection) + connection_save._connection = None + +def test_connection_saver_thread_safety(): + results = [] + def worker(): + res = connection_saver(foo="bar") + results.append(res) + + threads = [threading.Thread(target=worker) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + # Ensure all threads got the same connection object. + assert all(r is results[0] for r in results) + # Ensure only one connection was created. + assert connection_create_count == 1 + # Check that the connection supports close() properly. + results[0].close() + assert results[0].closed is True + +@pytest.mark.asyncio +async def test_async_connection_saver_thread_safety(monkeypatch): + global connection_create_count + connection_create_count = 0 + + async def dummy_acreate_connection(**config): + global connection_create_count + connection_create_count += 1 + return DummyConnection() + + monkeypatch.setattr("dispatcher.brokers.pg_notify.acreate_connection", dummy_acreate_connection) + connection_save._async_connection = None + + async def worker(): + return await async_connection_saver(foo="bar") + results = await asyncio.gather(*[worker() for _ in range(10)]) + # Ensure all tasks returned the same connection object. + assert all(r is results[0] for r in results) + # Ensure only one async connection was created. + assert connection_create_count == 1 + await results[0].aclose() + assert results[0].closed is True diff --git a/tests/unit/test_control_cleanup.py b/tests/unit/test_control_cleanup.py new file mode 100644 index 0000000..64f55a5 --- /dev/null +++ b/tests/unit/test_control_cleanup.py @@ -0,0 +1,97 @@ +import asyncio +import json +import pytest + +from dispatcher.control import Control, BrokerCallbacks +from dispatcher.protocols import Broker + +# Dummy broker implementation for testing cleanup. +class DummyBroker(Broker): + def __init__(self): + self.aclose_called = False + self.close_called = False + self.sent_message = None + + async def aprocess_notify(self, connected_callback=None): + if connected_callback: + await connected_callback() + # Yield one valid reply message. + yield ("dummy_channel", json.dumps({"result": "ok"})) + + async def apublish_message(self, channel=None, message=""): + self.sent_message = message + + async def aclose(self): + self.aclose_called = True + + def process_notify(self, connected_callback=None, timeout: float = 5.0, max_messages: int = 1): + if connected_callback: + connected_callback() + # Yield one valid reply message. + yield ("dummy_channel", json.dumps({"result": "ok"})) + + def publish_message(self, channel=None, message=""): + self.sent_message = message + + def close(self): + self.close_called = True + +# Test for async control with reply cleanup +@pytest.mark.asyncio +async def test_acontrol_with_reply_resource_cleanup(monkeypatch): + dummy_broker = DummyBroker() + + def dummy_get_broker(broker_name, broker_config, channels=None): + return dummy_broker + + monkeypatch.setattr("dispatcher.control.get_broker", dummy_get_broker) + + control = Control(broker_name="dummy", broker_config={}) + result = await control.acontrol_with_reply( + command="test_command", expected_replies=1, timeout=2, data={"key": "value"} + ) + assert result == [{"result": "ok"}] + assert dummy_broker.aclose_called is True + +# Test for async control (fire-and-forget) cleanup +@pytest.mark.asyncio +async def test_acontrol_resource_cleanup(monkeypatch): + dummy_broker = DummyBroker() + + def dummy_get_broker(broker_name, broker_config, channels=None): + return dummy_broker + + monkeypatch.setattr("dispatcher.control.get_broker", dummy_get_broker) + + control = Control(broker_name="dummy", broker_config={}) + await control.acontrol(command="test_command", data={"foo": "bar"}) + # In acontrol, broker.aclose() should be called. + assert dummy_broker.aclose_called is True + +# Test for synchronous control_with_reply cleanup +def test_control_with_reply_resource_cleanup(monkeypatch): + dummy_broker = DummyBroker() + + def dummy_get_broker(broker_name, broker_config, channels=None): + return dummy_broker + + monkeypatch.setattr("dispatcher.control.get_broker", dummy_get_broker) + + control = Control(broker_name="dummy", broker_config={}, queue="test_queue") + result = control.control_with_reply(command="test_command", expected_replies=1, timeout=2, data={"foo": "bar"}) + assert result == [{"result": "ok"}] + # For sync methods, broker.close() should be called. + assert dummy_broker.close_called is True + +# Test for synchronous control (fire-and-forget) cleanup +def test_control_resource_cleanup(monkeypatch): + dummy_broker = DummyBroker() + + def dummy_get_broker(broker_name, broker_config, channels=None): + return dummy_broker + + monkeypatch.setattr("dispatcher.control.get_broker", dummy_get_broker) + + control = Control(broker_name="dummy", broker_config={}, queue="test_queue") + control.control(command="test_command", data={"foo": "bar"}) + assert dummy_broker.close_called is True diff --git a/tests/unit/test_next_wakeup_runner_errors.py b/tests/unit/test_next_wakeup_runner_errors.py new file mode 100644 index 0000000..2ea5f6d --- /dev/null +++ b/tests/unit/test_next_wakeup_runner_errors.py @@ -0,0 +1,46 @@ +import time +import pytest +from dispatcher.service.next_wakeup_runner import NextWakeupRunner, HasWakeup + + +# Dummy object that implements HasWakeup. +class DummySchedule(HasWakeup): + def __init__(self, wakeup_time: float): + self._wakeup_time = wakeup_time + + def next_wakeup(self) -> float: + return self._wakeup_time + +# Dummy process_object that simulates successful processing by pushing wakeup time forward. +async def dummy_process_object(schedule: DummySchedule) -> None: + # Simulate processing by adding 10 seconds. + schedule._wakeup_time += 10 + +# Dummy process_object that raises an exception. +async def failing_process_object(schedule: DummySchedule) -> None: + raise ValueError("Processing error") + +@pytest.mark.asyncio +async def test_process_wakeups_normal(): + # Set up a dummy schedule with a wakeup time in the past. + past_time = time.monotonic() - 5 + schedule = DummySchedule(past_time) + # Use dummy_process_object that adds 10 seconds. + runner = NextWakeupRunner([schedule], dummy_process_object) + current_time = time.monotonic() + next_wakeup = await runner.process_wakeups(current_time) + # The wakeup time should now be 10 seconds later than the original past time. + assert next_wakeup == schedule._wakeup_time + # Also, since the schedule was processed, it should not return None. + assert next_wakeup is not None + +@pytest.mark.asyncio +async def test_process_wakeups_error_propagation(): + # Set up a dummy schedule with a wakeup time in the past. + past_time = time.monotonic() - 5 + schedule = DummySchedule(past_time) + # Use failing_process_object that raises an exception. + runner = NextWakeupRunner([schedule], failing_process_object) + current_time = time.monotonic() + with pytest.raises(ValueError, match="Processing error"): + await runner.process_wakeups(current_time)