Skip to content

Commit f84989f

Browse files
art-tapin and AlanCoding
authored and committed Mar 14, 2025
Refactor and Enhance Dispatcher: Worker Management, pg_notify Locking, Error Handling, and More (#1)
* Improve protocol documentation by adding docstrings * Improve error handling in BrokerCallbacks.listen_for_replies - Add JSON parsing with exception handling to ignore malformed messages. - Log warnings when invalid JSON is received. - Add a unit test * Rename get_send_message to create_message for clarity - Change method name in Control class to better reflect its role in constructing messages. * Add resource cleanup in Control methods and tests - Ensure broker connections are closed in acontrol, control_with_reply, and control. - Update tests to verify that aclose() or close() is called appropriately. * Fix race condition in manage_old_workers and add tests - Refactor manage_old_workers to use a two-phase locking approach. - Take a snapshot under lock, process removals, then re-acquire the lock to remove workers atomically. Potentially closes ansible#124 * Improve error propagation in NextWakeupRunner.process_wakeups - Wrap process_object callback in try/except to log and re-raise errors. - Add unit tests to verify normal operation and error propagation. * pg_notify: Improve ConnectionSaver caching, thread safety, and type correctness -- Squashed -- Fix ConnectionSaver caching and type issues for closed connections - Update get_connection and aget_connection to check if the cached connection is closed (i.e. .closed != 0) and reinitialize it if so, ensuring that run_demo.py and other users always receive a live connection. - Add type assertions to guarantee that a valid (non-None) connection is returned, resolving mypy errors. Add thread safety to ConnectionSaver in pg_notify.py and add tests - Introduce a threading.Lock in ConnectionSaver to protect _connection and _async_connection. - Wrap the initialization in connection_saver and async_connection_saver with the lock to avoid race conditions. - Update tests to verify that concurrent access creates only one connection. 
Note: We use a standard threading.Lock because this is protecting shared state across threads. * Remove redundant lock in WorkerPool.dispatch_task - Refactor dispatch_task to avoid holding workers.management_lock for the entire operation. - Blocker and Queuer functions are expected to be used within the WorkerPool context, so extra locking is unnecessary. * Add type annotations to context manager methods in ProcessProxy - Implement __enter__ and __exit__ with proper type annotations. - __exit__ ensures that a running process is terminated (or killed) and joined. It returns Optional[bool] and ensures proper process cleanup. * Use f-string in control.py log message Replace .format() with f-string for improved readability in control-and-reply log message.
1 parent a2ef856 commit f84989f

10 files changed

+429
-46
lines changed
 

‎dispatcher/brokers/pg_notify.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import threading
23
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union
34

45
import psycopg
@@ -97,8 +98,8 @@ def get_publish_channel(self, channel: Optional[str] = None) -> str:
9798
# --- asyncio connection methods ---
9899

99100
async def aget_connection(self) -> psycopg.AsyncConnection:
100-
"Return existing connection or create a new one"
101-
if not self._async_connection:
101+
# Check if the cached async connection is either None or closed.
102+
if not self._async_connection or getattr(self._async_connection, "closed", 0) != 0:
102103
if self._async_connection_factory:
103104
factory = resolve_callable(self._async_connection_factory)
104105
if not factory:
@@ -109,7 +110,7 @@ async def aget_connection(self) -> psycopg.AsyncConnection:
109110
else:
110111
raise RuntimeError('Could not construct async connection for lack of config or factory')
111112
self._async_connection = connection
112-
return connection # slightly weird due to MyPY
113+
assert self._async_connection is not None
113114
return self._async_connection
114115

115116
def get_listen_query(self, channel: str) -> psycopg.sql.Composed:
@@ -178,7 +179,8 @@ async def aclose(self) -> None:
178179
# --- synchronous connection methods ---
179180

180181
def get_connection(self) -> psycopg.Connection:
181-
if not self._sync_connection:
182+
# Check if the cached connection is either None or closed.
183+
if not self._sync_connection or getattr(self._sync_connection, "closed", 0) != 0:
182184
if self._sync_connection_factory:
183185
factory = resolve_callable(self._sync_connection_factory)
184186
if not factory:
@@ -189,7 +191,7 @@ def get_connection(self) -> psycopg.Connection:
189191
else:
190192
raise RuntimeError('Could not construct connection for lack of config or factory')
191193
self._sync_connection = connection
192-
return connection
194+
assert self._sync_connection is not None
193195
return self._sync_connection
194196

195197
def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]:
@@ -234,6 +236,7 @@ class ConnectionSaver:
234236
def __init__(self) -> None:
235237
self._connection: Optional[psycopg.Connection] = None
236238
self._async_connection: Optional[psycopg.AsyncConnection] = None
239+
self._lock = threading.Lock()
237240

238241

239242
connection_save = ConnectionSaver()
@@ -245,10 +248,14 @@ def connection_saver(**config) -> psycopg.Connection: # type: ignore[no-untyped
245248
Philosophically, this is used by an application that uses an ORM,
246249
or otherwise has its own connection management logic.
247250
Dispatcher does not manage connections, so this a simulation of that.
251+
252+
Uses a thread lock to ensure thread safety.
248253
"""
249-
if connection_save._connection is None:
250-
connection_save._connection = create_connection(**config)
251-
return connection_save._connection
254+
with connection_save._lock:
255+
# Check if we need to create a new connection because it's either None or closed.
256+
if connection_save._connection is None or getattr(connection_save._connection, 'closed', False):
257+
connection_save._connection = create_connection(**config)
258+
return connection_save._connection
252259

253260

254261
async def async_connection_saver(**config) -> psycopg.AsyncConnection: # type: ignore[no-untyped-def]
@@ -257,7 +264,10 @@ async def async_connection_saver(**config) -> psycopg.AsyncConnection: # type:
257264
Philosophically, this is used by an application that uses an ORM,
258265
or otherwise has its own connection management logic.
259266
Dispatcher does not manage connections, so this a simulation of that.
267+
268+
Uses a thread lock to ensure thread safety.
260269
"""
261-
if connection_save._async_connection is None:
262-
connection_save._async_connection = await acreate_connection(**config)
263-
return connection_save._async_connection
270+
with connection_save._lock:
271+
if connection_save._async_connection is None or getattr(connection_save._async_connection, 'closed', False):
272+
connection_save._async_connection = await acreate_connection(**config)
273+
return connection_save._async_connection

‎dispatcher/control.py

+33-18
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,17 @@ async def connected_callback(self) -> None:
2424
await self.broker.apublish_message(self.queuename, self.send_message)
2525

2626
async def listen_for_replies(self) -> None:
27-
"""Listen to the reply channel until we get the expected number of messages
27+
"""Listen to the reply channel until we get the expected number of messages.
2828
29-
This gets ran in a task, and timing out will be accomplished by the main code
29+
This gets ran in an async task, and timing out will be accomplished by the main code
3030
"""
3131
async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback):
32-
self.received_replies.append(payload)
32+
try:
33+
# If payload is a string, parse it to a dict; otherwise assume it's valid.
34+
message = json.loads(payload) if isinstance(payload, str) else payload
35+
self.received_replies.append(message)
36+
except json.JSONDecodeError as e:
37+
logger.warning(f"Invalid JSON on channel '{channel}': {payload[:100]}... (Error: {e})")
3338
if len(self.received_replies) >= self.expected_replies:
3439
return
3540

@@ -54,7 +59,7 @@ def parse_replies(received_replies: list[Union[str, dict]]) -> list[dict]:
5459
ret.append(json.loads(payload))
5560
return ret
5661

57-
def get_send_message(self, command: str, reply_to: Optional[str] = None, send_data: Optional[dict] = None) -> str:
62+
def create_message(self, command: str, reply_to: Optional[str] = None, send_data: Optional[dict] = None) -> str:
5863
to_send: dict[str, Union[dict, str]] = {'control': command}
5964
if reply_to:
6065
to_send['reply_to'] = reply_to
@@ -65,7 +70,7 @@ def get_send_message(self, command: str, reply_to: Optional[str] = None, send_da
6570
async def acontrol_with_reply(self, command: str, expected_replies: int = 1, timeout: int = 1, data: Optional[dict] = None) -> list[dict]:
6671
reply_queue = Control.generate_reply_queue_name()
6772
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
68-
send_message = self.get_send_message(command=command, reply_to=reply_queue, send_data=data)
73+
send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)
6974

7075
control_callbacks = BrokerCallbacks(broker=broker, queuename=self.queuename, send_message=send_message, expected_replies=expected_replies)
7176

@@ -77,35 +82,45 @@ async def acontrol_with_reply(self, command: str, expected_replies: int = 1, tim
7782
except asyncio.TimeoutError:
7883
logger.warning(f'Did not receive {expected_replies} reply in {timeout} seconds, only {len(control_callbacks.received_replies)}')
7984
listen_task.cancel()
85+
finally:
86+
await broker.aclose()
8087

8188
return self.parse_replies(control_callbacks.received_replies)
8289

8390
async def acontrol(self, command: str, data: Optional[dict] = None) -> None:
8491
broker = get_broker(self.broker_name, self.broker_config, channels=[])
85-
send_message = self.get_send_message(command=command, send_data=data)
86-
await broker.apublish_message(message=send_message)
92+
send_message = self.create_message(command=command, send_data=data)
93+
try:
94+
await broker.apublish_message(message=send_message)
95+
finally:
96+
await broker.aclose()
8797

8898
def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
89-
logger.info('control-and-reply {} to {}'.format(command, self.queuename))
99+
logger.info(f'control-and-reply {command} to {self.queuename}')
90100
start = time.time()
91101
reply_queue = Control.generate_reply_queue_name()
92-
send_message = self.get_send_message(command=command, reply_to=reply_queue, send_data=data)
102+
send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)
93103

94104
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
95105

96106
def connected_callback() -> None:
97107
broker.publish_message(channel=self.queuename, message=send_message)
98108

99109
replies = []
100-
for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout):
101-
reply_data = json.loads(payload)
102-
replies.append(reply_data)
103-
104-
logger.info(f'control-and-reply message returned in {time.time() - start} seconds')
105-
return replies
110+
try:
111+
for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout):
112+
reply_data = json.loads(payload)
113+
replies.append(reply_data)
114+
logger.info(f'control-and-reply message returned in {time.time() - start} seconds')
115+
return replies
116+
finally:
117+
broker.close()
106118

107119
def control(self, command: str, data: Optional[dict] = None) -> None:
108-
"Send message in fire-and-forget mode, as synchronous code. Only for no-reply control."
120+
"""Send a fire-and-forget control message synchronously."""
109121
broker = get_broker(self.broker_name, self.broker_config)
110-
send_message = self.get_send_message(command=command, send_data=data)
111-
broker.publish_message(channel=self.queuename, message=send_message)
122+
send_message = self.create_message(command=command, send_data=data)
123+
try:
124+
broker.publish_message(channel=self.queuename, message=send_message)
125+
finally:
126+
broker.close()

‎dispatcher/protocols.py

+61
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33

44

55
class Broker(Protocol):
6+
"""
7+
Describes a messaging broker interface.
8+
9+
This interface abstracts functionality for sending and receiving messages,
10+
both asynchronously and synchronously, and for managing connection lifecycles.
11+
"""
12+
613
async def aprocess_notify(
714
self, connected_callback: Optional[Optional[Callable[[], Coroutine[Any, Any, None]]]] = None
815
) -> AsyncGenerator[tuple[str, str], None]:
@@ -35,10 +42,23 @@ def close(self):
3542

3643

3744
class ProducerEvents(Protocol):
45+
"""
46+
Describes an events container for producers.
47+
48+
Typically provides a signal (like a ready event) to indicate producer readiness.
49+
"""
50+
3851
ready_event: asyncio.Event
3952

4053

4154
class Producer(Protocol):
55+
"""
56+
Describes a task producer interface.
57+
58+
This interface encapsulates behavior for starting task production,
59+
managing its lifecycle, and tracking asynchronous operations.
60+
"""
61+
4262
events: ProducerEvents
4363

4464
async def start_producing(self, dispatcher: 'DispatcherMain') -> None:
@@ -55,6 +75,13 @@ def all_tasks(self) -> Iterable[asyncio.Task]:
5575

5676

5777
class PoolWorker(Protocol):
78+
"""
79+
Describes an individual worker in a task pool.
80+
81+
It covers the properties and behaviors needed to track a worker’s execution state
82+
and control its task processing lifecycle.
83+
"""
84+
5885
current_task: Optional[dict]
5986
worker_id: int
6087

@@ -70,18 +97,37 @@ def cancel(self) -> None: ...
7097

7198

7299
class Queuer(Protocol):
100+
"""
101+
Describes an interface for managing pending tasks.
102+
103+
It provides a way to iterate over and modify tasks awaiting assignment.
104+
"""
105+
73106
def __iter__(self) -> Iterator[dict]: ...
74107

75108
def remove_task(self, message: dict) -> None: ...
76109

77110

78111
class Blocker(Protocol):
112+
"""
113+
Describes an interface for handling tasks that are temporarily deferred.
114+
115+
It offers a mechanism to view and manage tasks that cannot run immediately.
116+
"""
117+
79118
def __iter__(self) -> Iterator[dict]: ...
80119

81120
def remove_task(self, message: dict) -> None: ...
82121

83122

84123
class WorkerData(Protocol):
124+
"""
125+
Describes an interface for managing a collection of workers.
126+
127+
It abstracts how worker instances are iterated over and retrieved,
128+
and it provides a lock for safe concurrent updates.
129+
"""
130+
85131
management_lock: asyncio.Lock
86132

87133
def __iter__(self) -> Iterator[PoolWorker]: ...
@@ -90,6 +136,13 @@ def get_by_id(self, worker_id: int) -> PoolWorker: ...
90136

91137

92138
class WorkerPool(Protocol):
139+
"""
140+
Describes an interface for a pool managing task workers.
141+
142+
It includes core functionality for starting the pool, dispatching tasks,
143+
and shutting down the pool in a controlled manner.
144+
"""
145+
93146
workers: WorkerData
94147
queuer: Queuer
95148
blocker: Blocker
@@ -106,6 +159,14 @@ async def shutdown(self) -> None: ...
106159

107160

108161
class DispatcherMain(Protocol):
162+
"""
163+
Describes the primary dispatcher interface.
164+
165+
This interface defines the contract for the overall task dispatching service,
166+
including coordinating task processing, managing the worker pool, and
167+
handling delayed or control messages.
168+
"""
169+
109170
pool: WorkerPool
110171
delayed_messages: set
111172

‎dispatcher/service/next_wakeup_runner.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@ def __init__(self, wakeup_objects: Iterable[HasWakeup], process_object: Callable
5050
self.name = name
5151

5252
async def process_wakeups(self, current_time: float, do_processing: bool = True) -> Optional[float]:
53-
"""Runs process_object for objects past for which we have passed the wakeup time
53+
"""Runs process_object for objects whose wakeup time has passed.
5454
55-
Returns the time of the soonest wakeup that has not been processed here
55+
Returns the soonest upcoming wakeup time among the objects that have not been processed.
56+
57+
If do_processing is True, process_object is called for objects with wakeup times below current_time.
58+
Errors from process_object are logged and propagated.
5659
5760
Arguments:
5861
- current_time - output of time.monotonic() passed from caller to keep this deterministic
@@ -63,7 +66,11 @@ async def process_wakeups(self, current_time: float, do_processing: bool = True)
6366
for obj in list(self.wakeup_objects):
6467
if obj_wakeup := obj.next_wakeup():
6568
if do_processing and (obj_wakeup < current_time):
66-
await self.process_object(obj)
69+
try:
70+
await self.process_object(obj)
71+
except Exception as e:
72+
logger.error(f"Error processing wakeup for object {obj}: {e}", exc_info=True)
73+
raise
6774
# refresh wakeup, which should be nullified or pushed back by process_object
6875
obj_wakeup = obj.next_wakeup()
6976
if obj_wakeup is None:

0 commit comments

Comments
 (0)
Please sign in to comment.