From 7859993f6483aa60cea3d823c700323aa91e38c9 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 19:43:06 +0000
Subject: [PATCH 01/10] Improve protocol documentation by adding docstrings

---
 dispatcher/protocols.py | 61 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 61 insertions(+)
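
Note: because these are typing.Protocol classes, implementations match
structurally and never need to inherit from them. A minimal sketch of what
the docstrings describe (MyEvents and wait_target are illustrative names,
not part of this patch):

    import asyncio

    from dispatcher.protocols import ProducerEvents

    class MyEvents:  # no Protocol inheritance required
        def __init__(self) -> None:
            self.ready_event = asyncio.Event()

    def wait_target(events: ProducerEvents) -> asyncio.Event:
        return events.ready_event

    wait_target(MyEvents())  # mypy accepts this via structural typing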

diff --git a/dispatcher/protocols.py b/dispatcher/protocols.py
index 2c2ddc2..4b5dd1a 100644
--- a/dispatcher/protocols.py
+++ b/dispatcher/protocols.py
@@ -3,6 +3,13 @@
 
 
 class Broker(Protocol):
+    """
+    Describes a messaging broker interface.
+
+    This interface abstracts functionality for sending and receiving messages,
+    both asynchronously and synchronously, and for managing connection lifecycles.
+    """
+
     async def aprocess_notify(
         self, connected_callback: Optional[Optional[Callable[[], Coroutine[Any, Any, None]]]] = None
     ) -> AsyncGenerator[tuple[str, str], None]:
@@ -35,10 +42,23 @@ def close(self):
 
 
 class ProducerEvents(Protocol):
+    """
+    Describes an events container for producers.
+
+    Typically provides a signal (like a ready event) to indicate producer readiness.
+    """
+
     ready_event: asyncio.Event
 
 
 class Producer(Protocol):
+    """
+    Describes a task producer interface.
+
+    This interface encapsulates behavior for starting task production,
+    managing its lifecycle, and tracking asynchronous operations.
+    """
+
     events: ProducerEvents
 
     async def start_producing(self, dispatcher: 'DispatcherMain') -> None:
@@ -55,6 +75,13 @@ def all_tasks(self) -> Iterable[asyncio.Task]:
 
 
 class PoolWorker(Protocol):
+    """
+    Describes an individual worker in a task pool.
+
+    It covers the properties and behaviors needed to track a worker’s execution state
+    and control its task processing lifecycle.
+    """
+
     current_task: Optional[dict]
     worker_id: int
 
@@ -70,18 +97,37 @@ def cancel(self) -> None: ...
 
 
 class Queuer(Protocol):
+    """
+    Describes an interface for managing pending tasks.
+
+    It provides a way to iterate over and modify tasks awaiting assignment.
+    """
+
     def __iter__(self) -> Iterator[dict]: ...
 
     def remove_task(self, message: dict) -> None: ...
 
 
 class Blocker(Protocol):
+    """
+    Describes an interface for handling tasks that are temporarily deferred.
+
+    It offers a mechanism to view and manage tasks that cannot run immediately.
+    """
+
     def __iter__(self) -> Iterator[dict]: ...
 
     def remove_task(self, message: dict) -> None: ...
 
 
 class WorkerData(Protocol):
+    """
+    Describes an interface for managing a collection of workers.
+
+    It abstracts how worker instances are iterated over and retrieved,
+    and it provides a lock for safe concurrent updates.
+    """
+
     management_lock: asyncio.Lock
 
     def __iter__(self) -> Iterator[PoolWorker]: ...
@@ -90,6 +136,13 @@ def get_by_id(self, worker_id: int) -> PoolWorker: ...
 
 
 class WorkerPool(Protocol):
+    """
+    Describes an interface for a pool managing task workers.
+
+    It includes core functionality for starting the pool, dispatching tasks,
+    and shutting down the pool in a controlled manner.
+    """
+
     workers: WorkerData
     queuer: Queuer
     blocker: Blocker
@@ -106,6 +159,14 @@ async def shutdown(self) -> None: ...
 
 
 class DispatcherMain(Protocol):
+    """
+    Describes the primary dispatcher interface.
+
+    This interface defines the contract for the overall task dispatching service,
+    including coordinating task processing, managing the worker pool, and
+    handling delayed or control messages.
+    """
+
     pool: WorkerPool
     delayed_messages: set
 

From 0ac03a0880441488e17f9c32b652227c7fb8bbb6 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 19:56:00 +0000
Subject: [PATCH 02/10] Improve error handling in
 BrokerCallbacks.listen_for_replies

- Add JSON parsing with exception handling to ignore malformed messages.
- Log warnings when invalid JSON is received.
- Add a unit test covering both malformed and valid replies
---
 dispatcher/control.py               | 11 +++++--
 tests/unit/test_broker_callbacks.py | 51 +++++++++++++++++++++++++++++
 2 files changed, 59 insertions(+), 3 deletions(-)
 create mode 100644 tests/unit/test_broker_callbacks.py

diff --git a/dispatcher/control.py b/dispatcher/control.py
index aaf0b45..c60cb38 100644
--- a/dispatcher/control.py
+++ b/dispatcher/control.py
@@ -24,12 +24,17 @@ async def connected_callback(self) -> None:
         await self.broker.apublish_message(self.queuename, self.send_message)
 
     async def listen_for_replies(self) -> None:
-        """Listen to the reply channel until we get the expected number of messages
+        """Listen to the reply channel until we get the expected number of messages.
 
-        This gets ran in a task, and timing out will be accomplished by the main code
+        This is run in an async task; timing out is handled by the calling code.
         """
         async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback):
-            self.received_replies.append(payload)
+            try:
+                # If payload is a string, parse it into a dict; otherwise assume it is already deserialized.
+                message = json.loads(payload) if isinstance(payload, str) else payload
+                self.received_replies.append(message)
+            except json.JSONDecodeError as e:
+                logger.warning(f"Invalid JSON on channel '{channel}': {payload[:100]}... (Error: {e})")
             if len(self.received_replies) >= self.expected_replies:
                 return
 
diff --git a/tests/unit/test_broker_callbacks.py b/tests/unit/test_broker_callbacks.py
new file mode 100644
index 0000000..bf886e5
--- /dev/null
+++ b/tests/unit/test_broker_callbacks.py
@@ -0,0 +1,51 @@
+import json
+import logging
+import pytest
+
+from dispatcher.control import BrokerCallbacks
+from dispatcher.protocols import Broker
+
+
+# Dummy broker that yields first an invalid JSON message and then a valid one.
+class DummyBroker(Broker):
+    async def aprocess_notify(self, connected_callback=None):
+        if connected_callback:
+            await connected_callback()
+        # First yield an invalid JSON string, then a valid one.
+        yield ("reply_channel", "invalid json")
+        yield ("reply_channel", json.dumps({"result": "ok"}))
+
+    async def apublish_message(self, channel, message):
+        # No-op for testing.
+        return
+
+    async def aclose(self):
+        return
+
+    def process_notify(self, connected_callback=None, timeout: float = 5.0, max_messages: int = 1):
+        # Not used in this test.
+        yield ("reply_channel", "")
+
+    def publish_message(self, channel=None, message=None):
+        return
+
+    def close(self):
+        return
+
+
+@pytest.mark.asyncio
+async def test_listen_for_replies_with_invalid_json(caplog):
+    caplog.set_level(logging.WARNING)
+    dummy_broker = DummyBroker()
+    callbacks = BrokerCallbacks(
+        queuename="reply_channel",
+        broker=dummy_broker,
+        send_message="{}",
+        expected_replies=1
+    )
+    await callbacks.listen_for_replies()
+    # The invalid JSON should be ignored and only the valid message appended.
+    assert len(callbacks.received_replies) == 1
+    assert callbacks.received_replies[0] == {"result": "ok"}
+    # Verify that a warning was logged for the malformed message.
+    assert any("Invalid JSON" in record.message for record in caplog.records)

From 118434ed4b32581d391ea1583f56cfe489684fb4 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 19:59:50 +0000
Subject: [PATCH 03/10] Rename get_send_message to create_message for clarity

- Change method name in Control class to better reflect its role in constructing messages.
---
 dispatcher/control.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/dispatcher/control.py b/dispatcher/control.py
index c60cb38..7609e87 100644
--- a/dispatcher/control.py
+++ b/dispatcher/control.py
@@ -59,7 +59,7 @@ def parse_replies(received_replies: list[Union[str, dict]]) -> list[dict]:
                 ret.append(json.loads(payload))
         return ret
 
-    def get_send_message(self, command: str, reply_to: Optional[str] = None, send_data: Optional[dict] = None) -> str:
+    def create_message(self, command: str, reply_to: Optional[str] = None, send_data: Optional[dict] = None) -> str:
         to_send: dict[str, Union[dict, str]] = {'control': command}
         if reply_to:
             to_send['reply_to'] = reply_to
@@ -70,7 +70,7 @@ def get_send_message(self, command: str, reply_to: Optional[str] = None, send_da
     async def acontrol_with_reply(self, command: str, expected_replies: int = 1, timeout: int = 1, data: Optional[dict] = None) -> list[dict]:
         reply_queue = Control.generate_reply_queue_name()
         broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
-        send_message = self.get_send_message(command=command, reply_to=reply_queue, send_data=data)
+        send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)
 
         control_callbacks = BrokerCallbacks(broker=broker, queuename=self.queuename, send_message=send_message, expected_replies=expected_replies)
 
@@ -87,14 +87,14 @@ async def acontrol_with_reply(self, command: str, expected_replies: int = 1, tim
 
     async def acontrol(self, command: str, data: Optional[dict] = None) -> None:
         broker = get_broker(self.broker_name, self.broker_config, channels=[])
-        send_message = self.get_send_message(command=command, send_data=data)
+        send_message = self.create_message(command=command, send_data=data)
         await broker.apublish_message(message=send_message)
 
     def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
         logger.info('control-and-reply {} to {}'.format(command, self.queuename))
         start = time.time()
         reply_queue = Control.generate_reply_queue_name()
-        send_message = self.get_send_message(command=command, reply_to=reply_queue, send_data=data)
+        send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)
 
         broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
 
@@ -112,5 +112,5 @@ def connected_callback() -> None:
     def control(self, command: str, data: Optional[dict] = None) -> None:
         "Send message in fire-and-forget mode, as synchronous code. Only for no-reply control."
         broker = get_broker(self.broker_name, self.broker_config)
-        send_message = self.get_send_message(command=command, send_data=data)
+        send_message = self.create_message(command=command, send_data=data)
         broker.publish_message(channel=self.queuename, message=send_message)

From 5c60264335aac748900f3fa430cf0c8a9fbc8a27 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 20:09:18 +0000
Subject: [PATCH 04/10] Add resource cleanup in Control methods and tests

- Ensure broker connections are closed in acontrol, control_with_reply, and control.
- Add tests verifying that aclose() or close() is called appropriately.
---
 dispatcher/control.py              | 28 ++++++---
 tests/unit/test_control_cleanup.py | 97 ++++++++++++++++++++++++++++++
 2 files changed, 116 insertions(+), 9 deletions(-)
 create mode 100644 tests/unit/test_control_cleanup.py

diff --git a/dispatcher/control.py b/dispatcher/control.py
index 7609e87..d4d5842 100644
--- a/dispatcher/control.py
+++ b/dispatcher/control.py
@@ -82,13 +82,18 @@ async def acontrol_with_reply(self, command: str, expected_replies: int = 1, tim
         except asyncio.TimeoutError:
             logger.warning(f'Did not receive {expected_replies} reply in {timeout} seconds, only {len(control_callbacks.received_replies)}')
             listen_task.cancel()
+        finally:
+            await broker.aclose()
 
         return self.parse_replies(control_callbacks.received_replies)
 
     async def acontrol(self, command: str, data: Optional[dict] = None) -> None:
         broker = get_broker(self.broker_name, self.broker_config, channels=[])
         send_message = self.create_message(command=command, send_data=data)
-        await broker.apublish_message(message=send_message)
+        try:
+            await broker.apublish_message(message=send_message)
+        finally:
+            await broker.aclose()
 
     def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
         logger.info('control-and-reply {} to {}'.format(command, self.queuename))
@@ -102,15 +107,20 @@ def connected_callback() -> None:
             broker.publish_message(channel=self.queuename, message=send_message)
 
         replies = []
-        for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout):
-            reply_data = json.loads(payload)
-            replies.append(reply_data)
-
-        logger.info(f'control-and-reply message returned in {time.time() - start} seconds')
-        return replies
+        try:
+            for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout):
+                reply_data = json.loads(payload)
+                replies.append(reply_data)
+            logger.info(f'control-and-reply message returned in {time.time() - start} seconds')
+            return replies
+        finally:
+            broker.close()
 
     def control(self, command: str, data: Optional[dict] = None) -> None:
-        "Send message in fire-and-forget mode, as synchronous code. Only for no-reply control."
+        """Send a fire-and-forget control message synchronously."""
         broker = get_broker(self.broker_name, self.broker_config)
         send_message = self.create_message(command=command, send_data=data)
-        broker.publish_message(channel=self.queuename, message=send_message)
+        try:
+            broker.publish_message(channel=self.queuename, message=send_message)
+        finally:
+            broker.close()
diff --git a/tests/unit/test_control_cleanup.py b/tests/unit/test_control_cleanup.py
new file mode 100644
index 0000000..64f55a5
--- /dev/null
+++ b/tests/unit/test_control_cleanup.py
@@ -0,0 +1,97 @@
+import asyncio
+import json
+import pytest
+
+from dispatcher.control import Control, BrokerCallbacks
+from dispatcher.protocols import Broker
+
+# Dummy broker implementation for testing cleanup.
+class DummyBroker(Broker):
+    def __init__(self):
+        self.aclose_called = False
+        self.close_called = False
+        self.sent_message = None
+
+    async def aprocess_notify(self, connected_callback=None):
+        if connected_callback:
+            await connected_callback()
+        # Yield one valid reply message.
+        yield ("dummy_channel", json.dumps({"result": "ok"}))
+
+    async def apublish_message(self, channel=None, message=""):
+        self.sent_message = message
+
+    async def aclose(self):
+        self.aclose_called = True
+
+    def process_notify(self, connected_callback=None, timeout: float = 5.0, max_messages: int = 1):
+        if connected_callback:
+            connected_callback()
+        # Yield one valid reply message.
+        yield ("dummy_channel", json.dumps({"result": "ok"}))
+
+    def publish_message(self, channel=None, message=""):
+        self.sent_message = message
+
+    def close(self):
+        self.close_called = True
+
+# Test for async control with reply cleanup
+@pytest.mark.asyncio
+async def test_acontrol_with_reply_resource_cleanup(monkeypatch):
+    dummy_broker = DummyBroker()
+
+    def dummy_get_broker(broker_name, broker_config, channels=None):
+        return dummy_broker
+
+    monkeypatch.setattr("dispatcher.control.get_broker", dummy_get_broker)
+
+    control = Control(broker_name="dummy", broker_config={})
+    result = await control.acontrol_with_reply(
+        command="test_command", expected_replies=1, timeout=2, data={"key": "value"}
+    )
+    assert result == [{"result": "ok"}]
+    assert dummy_broker.aclose_called is True
+
+# Test for async control (fire-and-forget) cleanup
+@pytest.mark.asyncio
+async def test_acontrol_resource_cleanup(monkeypatch):
+    dummy_broker = DummyBroker()
+
+    def dummy_get_broker(broker_name, broker_config, channels=None):
+        return dummy_broker
+
+    monkeypatch.setattr("dispatcher.control.get_broker", dummy_get_broker)
+
+    control = Control(broker_name="dummy", broker_config={})
+    await control.acontrol(command="test_command", data={"foo": "bar"})
+    # In acontrol, broker.aclose() should be called.
+    assert dummy_broker.aclose_called is True
+
+# Test for synchronous control_with_reply cleanup
+def test_control_with_reply_resource_cleanup(monkeypatch):
+    dummy_broker = DummyBroker()
+
+    def dummy_get_broker(broker_name, broker_config, channels=None):
+        return dummy_broker
+
+    monkeypatch.setattr("dispatcher.control.get_broker", dummy_get_broker)
+
+    control = Control(broker_name="dummy", broker_config={}, queue="test_queue")
+    result = control.control_with_reply(command="test_command", expected_replies=1, timeout=2, data={"foo": "bar"})
+    assert result == [{"result": "ok"}]
+    # For sync methods, broker.close() should be called.
+    assert dummy_broker.close_called is True
+
+# Test for synchronous control (fire-and-forget) cleanup
+def test_control_resource_cleanup(monkeypatch):
+    dummy_broker = DummyBroker()
+
+    def dummy_get_broker(broker_name, broker_config, channels=None):
+        return dummy_broker
+
+    monkeypatch.setattr("dispatcher.control.get_broker", dummy_get_broker)
+
+    control = Control(broker_name="dummy", broker_config={}, queue="test_queue")
+    control.control(command="test_command", data={"foo": "bar"})
+    assert dummy_broker.close_called is True

From bd239621d9c196a195b0c16022b01d2a10c74c55 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 20:36:38 +0000
Subject: [PATCH 05/10] Fix race condition in manage_old_workers

- Refactor manage_old_workers to use a two-phase locking approach.
- Take a snapshot under the lock, evaluate workers outside it, then re-acquire the lock to remove retired workers atomically.

Potentially closes #124
---
 dispatcher/service/pool.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)
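
Note: a minimal sketch of the two-phase pattern described above, using
illustrative names (reap, is_done) rather than the pool's real API:

    import asyncio

    async def reap(items: set, lock: asyncio.Lock) -> None:
        async with lock:  # phase 1: snapshot under the lock
            snapshot = list(items)
        # inspect outside the lock, so slow checks don't block other tasks
        doomed = [item for item in snapshot if item.is_done()]
        async with lock:  # phase 2: mutate under the lock
            for item in doomed:
                items.discard(item)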

diff --git a/dispatcher/service/pool.py b/dispatcher/service/pool.py
index 0e2510f..385bd5e 100644
--- a/dispatcher/service/pool.py
+++ b/dispatcher/service/pool.py
@@ -317,14 +317,22 @@ async def manage_new_workers(self, forking_lock: asyncio.Lock) -> None:
     async def manage_old_workers(self) -> None:
         """Clear internal memory of workers whose process has exited, and assures processes are gone
 
+        This method takes a snapshot of the current workers under lock,
+        processes them outside the lock (including awaiting worker stops),
+        and then re-acquires the lock to remove workers marked for deletion.
+
         happy path:
         The scale_workers method notifies a worker they need to exit
         The read_results_task will mark the worker status to exited
         This method will see the updated status, join the process, and remove it from self.workers
         """
+        # Phase 1: Get a consistent snapshot of workers.
+        async with self.workers.management_lock:
+            current_workers = list(self.workers)
+
         remove_ids = []
-        for worker in self.workers:
-            # Check for workers that died unexpectedly
+        for worker in current_workers:
+            # Check if the worker has died unexpectedly.
             if worker.status not in ['retired', 'error', 'exited', 'initialized', 'spawned'] and not worker.process.is_alive():
                 logger.error(f'Worker {worker.worker_id} pid={worker.process.pid} has died unexpectedly, status was {worker.status}')
 
@@ -332,8 +340,7 @@ async def manage_old_workers(self) -> None:
                     uuid = worker.current_task.get('uuid', '<unknown>')
                     logger.error(f'Task (uuid={uuid}) was running on worker {worker.worker_id} but the worker died unexpectedly')
                     self.canceled_count += 1
-                    worker.is_active_cancel = False  # Ensure it's not processed by timeout runner
-
+                    worker.is_active_cancel = False  # Ensure the timeout runner does not process it further.
                 worker.status = 'error'
                 worker.retired_at = time.monotonic()
 
@@ -345,9 +352,9 @@ async def manage_old_workers(self) -> None:
             elif worker.status in ['retired', 'error'] and worker.retired_at and (time.monotonic() - worker.retired_at) > self.worker_removal_wait:
                 remove_ids.append(worker.worker_id)
 
-        # Remove workers from memory, done as separate loop due to locking concerns
-        for worker_id in remove_ids:
-            async with self.workers.management_lock:
+        # Phase 2: Remove workers from the collection under lock.
+        async with self.workers.management_lock:
+            for worker_id in remove_ids:
                 if worker_id in self.workers:
                     logger.debug(f'Fully removing worker id={worker_id}')
                     self.workers.remove_by_id(worker_id)

From 7236ba72081a367e06e72bed003e564df8ded2c1 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 20:50:17 +0000
Subject: [PATCH 06/10] Improve error propagation in
 NextWakeupRunner.process_wakeups

- Wrap process_object callback in try/except to log and re-raise errors.
- Add unit tests to verify normal operation and error propagation.
---
 dispatcher/service/next_wakeup_runner.py     | 13 ++++--
 tests/unit/test_next_wakeup_runner_errors.py | 46 ++++++++++++++++++++
 2 files changed, 56 insertions(+), 3 deletions(-)
 create mode 100644 tests/unit/test_next_wakeup_runner_errors.py

diff --git a/dispatcher/service/next_wakeup_runner.py b/dispatcher/service/next_wakeup_runner.py
index 67ea56d..e45ab01 100644
--- a/dispatcher/service/next_wakeup_runner.py
+++ b/dispatcher/service/next_wakeup_runner.py
@@ -50,9 +50,12 @@ def __init__(self, wakeup_objects: Iterable[HasWakeup], process_object: Callable
             self.name = name
 
     async def process_wakeups(self, current_time: float, do_processing: bool = True) -> Optional[float]:
-        """Runs process_object for objects past for which we have passed the wakeup time
+        """Runs process_object for objects whose wakeup time has passed.
 
-        Returns the time of the soonest wakeup that has not been processed here
+        Returns the soonest upcoming wakeup time among the objects that have not been processed.
+
+    If do_processing is True, process_object is called for objects whose wakeup times fall before current_time.
+        Errors from process_object are logged and propagated.
 
         Arguments:
          - current_time - output of time.monotonic() passed from caller to keep this deterministic
@@ -63,7 +66,11 @@ async def process_wakeups(self, current_time: float, do_processing: bool = True)
         for obj in list(self.wakeup_objects):
             if obj_wakeup := obj.next_wakeup():
                 if do_processing and (obj_wakeup < current_time):
-                    await self.process_object(obj)
+                    try:
+                        await self.process_object(obj)
+                    except Exception as e:
+                        logger.error(f"Error processing wakeup for object {obj}: {e}", exc_info=True)
+                        raise
                     # refresh wakeup, which should be nullified or pushed back by process_object
                     obj_wakeup = obj.next_wakeup()
                     if obj_wakeup is None:
diff --git a/tests/unit/test_next_wakeup_runner_errors.py b/tests/unit/test_next_wakeup_runner_errors.py
new file mode 100644
index 0000000..2ea5f6d
--- /dev/null
+++ b/tests/unit/test_next_wakeup_runner_errors.py
@@ -0,0 +1,46 @@
+import time
+import pytest
+from dispatcher.service.next_wakeup_runner import NextWakeupRunner, HasWakeup
+
+
+# Dummy object that implements HasWakeup.
+class DummySchedule(HasWakeup):
+    def __init__(self, wakeup_time: float):
+        self._wakeup_time = wakeup_time
+
+    def next_wakeup(self) -> float:
+        return self._wakeup_time
+
+# Dummy process_object that simulates successful processing by pushing wakeup time forward.
+async def dummy_process_object(schedule: DummySchedule) -> None:
+    # Simulate processing by adding 10 seconds.
+    schedule._wakeup_time += 10
+
+# Dummy process_object that raises an exception.
+async def failing_process_object(schedule: DummySchedule) -> None:
+    raise ValueError("Processing error")
+
+@pytest.mark.asyncio
+async def test_process_wakeups_normal():
+    # Set up a dummy schedule with a wakeup time in the past.
+    past_time = time.monotonic() - 5
+    schedule = DummySchedule(past_time)
+    # Use dummy_process_object that adds 10 seconds.
+    runner = NextWakeupRunner([schedule], dummy_process_object)
+    current_time = time.monotonic()
+    next_wakeup = await runner.process_wakeups(current_time)
+    # Since the schedule was processed, a next wakeup should still exist.
+    assert next_wakeup is not None
+    # The wakeup time should now be 10 seconds later than the original past time.
+    assert next_wakeup == schedule._wakeup_time
+
+@pytest.mark.asyncio
+async def test_process_wakeups_error_propagation():
+    # Set up a dummy schedule with a wakeup time in the past.
+    past_time = time.monotonic() - 5
+    schedule = DummySchedule(past_time)
+    # Use failing_process_object that raises an exception.
+    runner = NextWakeupRunner([schedule], failing_process_object)
+    current_time = time.monotonic()
+    with pytest.raises(ValueError, match="Processing error"):
+        await runner.process_wakeups(current_time)

From bc2c516e6d825a702ddee42b2a20393854f99468 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 21:08:06 +0000
Subject: [PATCH 07/10] pg_notify: Improve ConnectionSaver caching, thread
 safety, and type correctness

-- Squashed --

Fix ConnectionSaver caching and type issues for closed connections

- Update get_connection and aget_connection to check whether the cached connection is closed (i.e. .closed != 0) and reinitialize it if so, ensuring that run_demo.py and other callers always receive a live connection.
- Add type assertions to guarantee that a valid (non-None) connection is returned, resolving mypy errors.

Add thread safety to ConnectionSaver in pg_notify.py and add tests

- Introduce a threading.Lock in ConnectionSaver to protect _connection and _async_connection.
- Wrap the initialization in connection_saver and async_connection_saver with the lock to avoid race conditions.
- Update tests to verify that concurrent access creates only one connection.

Note: We use a standard threading.Lock because it protects shared state across threads.
---
 dispatcher/brokers/pg_notify.py     | 32 ++++++++-----
 tests/unit/test_connection_saver.py | 70 +++++++++++++++++++++++++++++
 2 files changed, 91 insertions(+), 11 deletions(-)
 create mode 100644 tests/unit/test_connection_saver.py
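
Note: a minimal sketch of the race the lock closes (illustrative names, not
the broker's API). Without the lock, two threads can both observe the cache
as empty and both create a connection:

    import threading

    _conn = None
    _lock = threading.Lock()

    def get_cached(create):
        global _conn
        # the lock makes the check-and-set atomic across threads
        with _lock:
            if _conn is None:
                _conn = create()
            return _conn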

diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py
index 1af96a9..065124d 100644
--- a/dispatcher/brokers/pg_notify.py
+++ b/dispatcher/brokers/pg_notify.py
@@ -1,4 +1,5 @@
 import logging
+import threading
 from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union
 
 import psycopg
@@ -97,8 +98,8 @@ def get_publish_channel(self, channel: Optional[str] = None) -> str:
     # --- asyncio connection methods ---
 
     async def aget_connection(self) -> psycopg.AsyncConnection:
-        "Return existing connection or create a new one"
-        if not self._async_connection:
+        # Check if the cached async connection is either None or closed.
+        if not self._async_connection or getattr(self._async_connection, "closed", 0) != 0:
             if self._async_connection_factory:
                 factory = resolve_callable(self._async_connection_factory)
                 if not factory:
@@ -109,7 +110,7 @@ async def aget_connection(self) -> psycopg.AsyncConnection:
             else:
                 raise RuntimeError('Could not construct async connection for lack of config or factory')
             self._async_connection = connection
-            return connection  # slightly weird due to MyPY
+        assert self._async_connection is not None
         return self._async_connection
 
     def get_listen_query(self, channel: str) -> psycopg.sql.Composed:
@@ -178,7 +179,8 @@ async def aclose(self) -> None:
     # --- synchronous connection methods ---
 
     def get_connection(self) -> psycopg.Connection:
-        if not self._sync_connection:
+        # Check if the cached connection is either None or closed.
+        if not self._sync_connection or getattr(self._sync_connection, "closed", 0) != 0:
             if self._sync_connection_factory:
                 factory = resolve_callable(self._sync_connection_factory)
                 if not factory:
@@ -189,7 +191,7 @@ def get_connection(self) -> psycopg.Connection:
             else:
                 raise RuntimeError('Could not construct connection for lack of config or factory')
             self._sync_connection = connection
-            return connection
+        assert self._sync_connection is not None
         return self._sync_connection
 
     def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]:
@@ -234,6 +236,7 @@ class ConnectionSaver:
     def __init__(self) -> None:
         self._connection: Optional[psycopg.Connection] = None
         self._async_connection: Optional[psycopg.AsyncConnection] = None
+        self._lock = threading.Lock()
 
 
 connection_save = ConnectionSaver()
@@ -245,10 +248,14 @@ def connection_saver(**config) -> psycopg.Connection:  # type: ignore[no-untyped
     Philosophically, this is used by an application that uses an ORM,
     or otherwise has its own connection management logic.
     Dispatcher does not manage connections, so this a simulation of that.
+
+    Uses a thread lock to ensure thread safety.
     """
-    if connection_save._connection is None:
-        connection_save._connection = create_connection(**config)
-    return connection_save._connection
+    with connection_save._lock:
+        # Check if we need to create a new connection because it's either None or closed.
+        if connection_save._connection is None or getattr(connection_save._connection, 'closed', False):
+            connection_save._connection = create_connection(**config)
+        return connection_save._connection
 
 
 async def async_connection_saver(**config) -> psycopg.AsyncConnection:  # type: ignore[no-untyped-def]
@@ -257,7 +264,10 @@ async def async_connection_saver(**config) -> psycopg.AsyncConnection:  # type:
     Philosophically, this is used by an application that uses an ORM,
     or otherwise has its own connection management logic.
     Dispatcher does not manage connections, so this a simulation of that.
+
+    Uses a thread lock to ensure thread safety.
     """
-    if connection_save._async_connection is None:
-        connection_save._async_connection = await acreate_connection(**config)
-    return connection_save._async_connection
+    with connection_save._lock:
+        if connection_save._async_connection is None or getattr(connection_save._async_connection, 'closed', False):
+            connection_save._async_connection = await acreate_connection(**config)
+        return connection_save._async_connection
diff --git a/tests/unit/test_connection_saver.py b/tests/unit/test_connection_saver.py
new file mode 100644
index 0000000..32d3f65
--- /dev/null
+++ b/tests/unit/test_connection_saver.py
@@ -0,0 +1,70 @@
+import threading
+import asyncio
+import pytest
+
+from dispatcher.brokers.pg_notify import connection_saver, async_connection_saver, connection_save
+
+# Define a dummy connection object that supports both sync and async close methods.
+class DummyConnection:
+    def __init__(self):
+        self.closed = False
+    def close(self):
+        self.closed = True
+    async def aclose(self):
+        self.close()
+
+connection_create_count = 0
+
+def dummy_create_connection(**config):
+    global connection_create_count
+    connection_create_count += 1
+    return DummyConnection()
+
+@pytest.fixture(autouse=True)
+def reset_sync(monkeypatch):
+    global connection_create_count
+    connection_create_count = 0
+    monkeypatch.setattr("dispatcher.brokers.pg_notify.create_connection", dummy_create_connection)
+    connection_save._connection = None
+
+def test_connection_saver_thread_safety():
+    results = []
+    def worker():
+        res = connection_saver(foo="bar")
+        results.append(res)
+
+    threads = [threading.Thread(target=worker) for _ in range(10)]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+    # Ensure all threads got the same connection object.
+    assert all(r is results[0] for r in results)
+    # Ensure only one connection was created.
+    assert connection_create_count == 1
+    # Check that the connection supports close() properly.
+    results[0].close()
+    assert results[0].closed is True
+
+@pytest.mark.asyncio
+async def test_async_connection_saver_thread_safety(monkeypatch):
+    global connection_create_count
+    connection_create_count = 0
+
+    async def dummy_acreate_connection(**config):
+        global connection_create_count
+        connection_create_count += 1
+        return DummyConnection()
+
+    monkeypatch.setattr("dispatcher.brokers.pg_notify.acreate_connection", dummy_acreate_connection)
+    connection_save._async_connection = None
+
+    async def worker():
+        return await async_connection_saver(foo="bar")
+    results = await asyncio.gather(*[worker() for _ in range(10)])
+    # Ensure all tasks returned the same connection object.
+    assert all(r is results[0] for r in results)
+    # Ensure only one async connection was created.
+    assert connection_create_count == 1
+    await results[0].aclose()
+    assert results[0].closed is True

From 790de964765109f3dcb8d1be3d2a6588ee2af3b9 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 21:19:00 +0000
Subject: [PATCH 08/10] Remove redundant lock in WorkerPool.dispatch_task

- Refactor dispatch_task to avoid holding workers.management_lock for the entire operation.
- Blocker and Queuer functions are expected to be used within the WorkerPool context, so extra locking is unnecessary.
---
 dispatcher/service/pool.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/dispatcher/service/pool.py b/dispatcher/service/pool.py
index 385bd5e..a1ac89c 100644
--- a/dispatcher/service/pool.py
+++ b/dispatcher/service/pool.py
@@ -460,14 +460,16 @@ async def post_task_start(self, message: dict) -> None:
 
     async def dispatch_task(self, message: dict) -> None:
         uuid = message.get("uuid", "<unknown>")
-        async with self.workers.management_lock:
-            if unblocked_task := self.blocker.process_task(message):
-                if worker := self.queuer.get_worker_or_process_task(unblocked_task):
-                    logger.debug(f"Dispatching task (uuid={uuid}) to worker (id={worker.worker_id})")
+        unblocked_task = self.blocker.process_task(message)
+        if unblocked_task:
+            worker = self.queuer.get_worker_or_process_task(unblocked_task)
+            if worker:
+                logger.debug(f"Dispatching task (uuid={uuid}) to worker (id={worker.worker_id})")
+                async with self.workers.management_lock:
                     await worker.start_task(unblocked_task)
                     await self.post_task_start(unblocked_task)
-                else:
-                    self.events.management_event.set()  # kick manager task to start auto-scale up
+            else:
+                self.events.management_event.set()  # kick manager task to start auto-scale up if needed
 
     async def drain_queue(self) -> None:
         async with self.workers.management_lock:

From 53c7e432dd045e21f92ca41f812eb10e92fb4e8d Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 21:24:11 +0000
Subject: [PATCH 09/10] Add type annotations to context manager methods in
 ProcessProxy

- Implement __enter__ and __exit__ with proper type annotations.
- __exit__ terminates (or kills) a still-running process and joins it; it returns Optional[bool] per the context manager protocol.
---
 dispatcher/service/process.py | 19 ++++++++++++++++++-
 1 file changed, 18 insertions(+), 1 deletion(-)
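
Note: a hedged usage sketch; the constructor arguments and the start() call
are assumptions based on the multiprocessing-style surface the class wraps,
not code from this patch:

    proxy = ProcessProxy(...)  # construction details elided
    with proxy:
        proxy.start()
        # ... exchange work with the child process ...
    # __exit__ has terminated (or killed) and joined the process by this
    # point, even if the body raised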

diff --git a/dispatcher/service/process.py b/dispatcher/service/process.py
index 4064575..2ebbc84 100644
--- a/dispatcher/service/process.py
+++ b/dispatcher/service/process.py
@@ -2,7 +2,7 @@
 import multiprocessing
 from multiprocessing.context import BaseContext
 from types import ModuleType
-from typing import Callable, Iterable, Optional, Union
+from typing import Any, Callable, Iterable, Optional, Union
 
 from ..config import LazySettings
 from ..config import settings as global_settings
@@ -51,6 +51,23 @@ def kill(self) -> None:
     def terminate(self) -> None:
         self._process.terminate()
 
+    def __enter__(self) -> "ProcessProxy":
+        """Enter the runtime context and return this ProcessProxy."""
+        return self
+
+    def __exit__(self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[Any]) -> Optional[bool]:
+        """Ensure the process is terminated and joined when exiting the context.
+
+        If the process is still alive, it is terminated (or killed if terminate fails); the process is then joined unconditionally.
+        """
+        if self.is_alive():
+            try:
+                self.terminate()
+            except Exception:
+                self.kill()
+        self.join()
+        return None
+
 
 class ProcessManager:
     mp_context = 'fork'

From ea273bada5623498cc5e39c0fa568bb410d2d6c2 Mon Sep 17 00:00:00 2001
From: Artem Tiupin <artem.tiupin@gmail.com>
Date: Thu, 13 Mar 2025 21:54:15 +0000
Subject: [PATCH 10/10] Use f-string in control.py log message

Replace .format() with f-string for improved readability
in control-and-reply log message.
---
 dispatcher/control.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dispatcher/control.py b/dispatcher/control.py
index d4d5842..15f15b4 100644
--- a/dispatcher/control.py
+++ b/dispatcher/control.py
@@ -96,7 +96,7 @@ async def acontrol(self, command: str, data: Optional[dict] = None) -> None:
             await broker.aclose()
 
     def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
-        logger.info('control-and-reply {} to {}'.format(command, self.queuename))
+        logger.info(f'control-and-reply {command} to {self.queuename}')
         start = time.time()
         reply_queue = Control.generate_reply_queue_name()
         send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)