
Refactor and Enhance Dispatcher: Worker Management, pg_notify Locking, Error Handling, and More #1

Merged
32 changes: 21 additions & 11 deletions dispatcher/brokers/pg_notify.py
@@ -1,4 +1,5 @@
import logging
import threading
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union

import psycopg
@@ -97,8 +98,8 @@ def get_publish_channel(self, channel: Optional[str] = None) -> str:
# --- asyncio connection methods ---

async def aget_connection(self) -> psycopg.AsyncConnection:
"Return existing connection or create a new one"
if not self._async_connection:
# Check if the cached async connection is either None or closed.
if not self._async_connection or getattr(self._async_connection, "closed", 0) != 0:
Review comment (Owner):
There's got to be a story behind getattr(self._async_connection, "closed", 0) here, and I'd like to hear it!
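
A likely story: psycopg connections expose a closed attribute once they have been terminated, and the getattr form with a falsy default presumably guards against connection-like objects (for example, test doubles returned by a custom connection factory) that do not define the attribute. A minimal sketch of the check, with an illustrative helper name:

    def _needs_new_connection(conn) -> bool:
        # No cached connection yet, or the cached one reports itself closed.
        # getattr tolerates connection-like objects without a `closed`
        # attribute; the default of 0 (falsy) treats them as still open.
        return conn is None or getattr(conn, "closed", 0) != 0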

if self._async_connection_factory:
factory = resolve_callable(self._async_connection_factory)
if not factory:
@@ -109,7 +110,7 @@ async def aget_connection(self) -> psycopg.AsyncConnection:
else:
raise RuntimeError('Could not construct async connection for lack of config or factory')
self._async_connection = connection
return connection # slightly weird due to MyPY
assert self._async_connection is not None
return self._async_connection

def get_listen_query(self, channel: str) -> psycopg.sql.Composed:
@@ -178,7 +179,8 @@ async def aclose(self) -> None:
# --- synchronous connection methods ---

def get_connection(self) -> psycopg.Connection:
if not self._sync_connection:
# Check if the cached connection is either None or closed.
if not self._sync_connection or getattr(self._sync_connection, "closed", 0) != 0:
if self._sync_connection_factory:
factory = resolve_callable(self._sync_connection_factory)
if not factory:
@@ -189,7 +191,7 @@ def get_connection(self) -> psycopg.Connection:
else:
raise RuntimeError('Could not construct connection for lack of config or factory')
self._sync_connection = connection
return connection
assert self._sync_connection is not None
return self._sync_connection

def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]:
@@ -234,6 +236,7 @@ class ConnectionSaver:
def __init__(self) -> None:
self._connection: Optional[psycopg.Connection] = None
self._async_connection: Optional[psycopg.AsyncConnection] = None
self._lock = threading.Lock()


connection_save = ConnectionSaver()
@@ -245,10 +248,14 @@ def connection_saver(**config) -> psycopg.Connection: # type: ignore[no-untyped
Philosophically, this is used by an application that uses an ORM,
or otherwise has its own connection management logic.
Dispatcher does not manage connections, so this is a simulation of that.

Uses a thread lock to ensure thread safety.
"""
if connection_save._connection is None:
connection_save._connection = create_connection(**config)
return connection_save._connection
with connection_save._lock:
# Check if we need to create a new connection because it's either None or closed.
if connection_save._connection is None or getattr(connection_save._connection, 'closed', False):
connection_save._connection = create_connection(**config)
return connection_save._connection


async def async_connection_saver(**config) -> psycopg.AsyncConnection: # type: ignore[no-untyped-def]
@@ -257,7 +264,10 @@ async def async_connection_saver(**config) -> psycopg.AsyncConnection: # type:
Philosophically, this is used by an application that uses an ORM,
or otherwise has its own connection management logic.
Dispatcher does not manage connections, so this is a simulation of that.

Uses a thread lock to ensure thread safety.
"""
if connection_save._async_connection is None:
connection_save._async_connection = await acreate_connection(**config)
return connection_save._async_connection
with connection_save._lock:
if connection_save._async_connection is None or getattr(connection_save._async_connection, 'closed', False):
connection_save._async_connection = await acreate_connection(**config)
return connection_save._async_connection
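
Worth noting: async_connection_saver acquires a threading.Lock from inside a coroutine, so connection creation briefly blocks the event loop; presumably an accepted trade-off for sharing one lock between the sync and async savers. A rough standalone sketch of the locked caching pattern for the synchronous case (connection parameters are illustrative):

    import threading
    from typing import Optional

    import psycopg

    _lock = threading.Lock()
    _cached_conn: Optional[psycopg.Connection] = None

    def get_cached_connection(**config) -> psycopg.Connection:
        global _cached_conn
        with _lock:
            # Recreate the connection if it was never made or has been closed
            if _cached_conn is None or getattr(_cached_conn, 'closed', False):
                _cached_conn = psycopg.connect(autocommit=True, **config)
            return _cached_conn
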
51 changes: 33 additions & 18 deletions dispatcher/control.py
@@ -24,12 +24,17 @@ async def connected_callback(self) -> None:
await self.broker.apublish_message(self.queuename, self.send_message)

async def listen_for_replies(self) -> None:
"""Listen to the reply channel until we get the expected number of messages
"""Listen to the reply channel until we get the expected number of messages.

This gets ran in a task, and timing out will be accomplished by the main code
This runs in an async task, and timing out is handled by the main code
"""
async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback):
self.received_replies.append(payload)
try:
# If payload is a string, parse it to a dict; otherwise assume it's valid.
message = json.loads(payload) if isinstance(payload, str) else payload
self.received_replies.append(message)
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON on channel '{channel}': {payload[:100]}... (Error: {e})")
Review comment (Owner):
It would be more localized to do this same thing inside of parse_replies. If you wanted to include the channel information, that wouldn't be completely unreasonable to pass through to that method, but I think channel is low-quality information. There are not going to be multiple channels involved, so we shouldn't care to log it.

if len(self.received_replies) >= self.expected_replies:
return
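
Following the comment above, the decode guard could instead live in parse_replies, keeping listen_for_replies a plain accumulator. A possible sketch of that alternative (not what this diff does; json and logger are already module-level in control.py):

    @staticmethod
    def parse_replies(received_replies: list) -> list[dict]:
        ret = []
        for payload in received_replies:
            if isinstance(payload, dict):
                ret.append(payload)
                continue
            try:
                ret.append(json.loads(payload))
            except json.JSONDecodeError as e:
                # Skip malformed replies rather than failing the whole batch
                logger.warning(f'Invalid JSON in reply: {payload[:100]} (Error: {e})')
        return ret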

@@ -54,7 +59,7 @@ def parse_replies(received_replies: list[Union[str, dict]]) -> list[dict]:
ret.append(json.loads(payload))
return ret

def get_send_message(self, command: str, reply_to: Optional[str] = None, send_data: Optional[dict] = None) -> str:
def create_message(self, command: str, reply_to: Optional[str] = None, send_data: Optional[dict] = None) -> str:
to_send: dict[str, Union[dict, str]] = {'control': command}
if reply_to:
to_send['reply_to'] = reply_to
@@ -65,7 +70,7 @@ def get_send_message(self, command: str, reply_to: Optional[str] = None, send_da
async def acontrol_with_reply(self, command: str, expected_replies: int = 1, timeout: int = 1, data: Optional[dict] = None) -> list[dict]:
reply_queue = Control.generate_reply_queue_name()
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
send_message = self.get_send_message(command=command, reply_to=reply_queue, send_data=data)
send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)

control_callbacks = BrokerCallbacks(broker=broker, queuename=self.queuename, send_message=send_message, expected_replies=expected_replies)

@@ -77,35 +82,45 @@ async def acontrol_with_reply(self, command: str, expected_replies: int = 1, tim
except asyncio.TimeoutError:
logger.warning(f'Did not receive {expected_replies} replies in {timeout} seconds, only {len(control_callbacks.received_replies)}')
listen_task.cancel()
finally:
await broker.aclose()

return self.parse_replies(control_callbacks.received_replies)

async def acontrol(self, command: str, data: Optional[dict] = None) -> None:
broker = get_broker(self.broker_name, self.broker_config, channels=[])
send_message = self.get_send_message(command=command, send_data=data)
await broker.apublish_message(message=send_message)
send_message = self.create_message(command=command, send_data=data)
try:
await broker.apublish_message(message=send_message)
finally:
await broker.aclose()

def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
logger.info('control-and-reply {} to {}'.format(command, self.queuename))
logger.info(f'control-and-reply {command} to {self.queuename}')
start = time.time()
reply_queue = Control.generate_reply_queue_name()
send_message = self.get_send_message(command=command, reply_to=reply_queue, send_data=data)
send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)

broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])

def connected_callback() -> None:
broker.publish_message(channel=self.queuename, message=send_message)

replies = []
for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout):
reply_data = json.loads(payload)
replies.append(reply_data)

logger.info(f'control-and-reply message returned in {time.time() - start} seconds')
return replies
try:
for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout):
reply_data = json.loads(payload)
replies.append(reply_data)
logger.info(f'control-and-reply message returned in {time.time() - start} seconds')
return replies
finally:
broker.close()

def control(self, command: str, data: Optional[dict] = None) -> None:
"Send message in fire-and-forget mode, as synchronous code. Only for no-reply control."
"""Send a fire-and-forget control message synchronously."""
broker = get_broker(self.broker_name, self.broker_config)
send_message = self.get_send_message(command=command, send_data=data)
broker.publish_message(channel=self.queuename, message=send_message)
send_message = self.create_message(command=command, send_data=data)
try:
broker.publish_message(channel=self.queuename, message=send_message)
finally:
broker.close()
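
The same open/publish/close discipline now appears in all four control paths. If it keeps recurring, a small context manager could factor it out; a hypothetical sketch (broker_session is not part of this diff):

    from contextlib import contextmanager

    @contextmanager
    def broker_session(broker_name, broker_config, channels=()):
        # Hypothetical helper: yields a broker and guarantees close() on exit
        broker = get_broker(broker_name, broker_config, channels=list(channels))
        try:
            yield broker
        finally:
            broker.close()

    # Usage, mirroring Control.control above:
    # with broker_session(self.broker_name, self.broker_config) as broker:
    #     broker.publish_message(channel=self.queuename, message=send_message)
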
61 changes: 61 additions & 0 deletions dispatcher/protocols.py
@@ -3,6 +3,13 @@


class Broker(Protocol):
"""
Describes a messaging broker interface.

This interface abstracts functionality for sending and receiving messages,
both asynchronously and synchronously, and for managing connection lifecycles.
"""

async def aprocess_notify(
self, connected_callback: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
) -> AsyncGenerator[tuple[str, str], None]:
@@ -35,10 +42,23 @@ def close(self):


class ProducerEvents(Protocol):
"""
Describes an events container for producers.

Typically provides a signal (like a ready event) to indicate producer readiness.
"""

ready_event: asyncio.Event


class Producer(Protocol):
"""
Describes a task producer interface.

This interface encapsulates behavior for starting task production,
managing its lifecycle, and tracking asynchronous operations.
"""

events: ProducerEvents

async def start_producing(self, dispatcher: 'DispatcherMain') -> None:
@@ -55,6 +75,13 @@ def all_tasks(self) -> Iterable[asyncio.Task]:


class PoolWorker(Protocol):
"""
Describes an individual worker in a task pool.

It covers the properties and behaviors needed to track a worker’s execution state
and control its task processing lifecycle.
"""

current_task: Optional[dict]
worker_id: int

@@ -70,18 +97,37 @@ def cancel(self) -> None: ...


class Queuer(Protocol):
"""
Describes an interface for managing pending tasks.

It provides a way to iterate over and modify tasks awaiting assignment.
"""

def __iter__(self) -> Iterator[dict]: ...

def remove_task(self, message: dict) -> None: ...


class Blocker(Protocol):
"""
Describes an interface for handling tasks that are temporarily deferred.

It offers a mechanism to view and manage tasks that cannot run immediately.
"""

def __iter__(self) -> Iterator[dict]: ...

def remove_task(self, message: dict) -> None: ...


class WorkerData(Protocol):
"""
Describes an interface for managing a collection of workers.

It abstracts how worker instances are iterated over and retrieved,
and it provides a lock for safe concurrent updates.
"""

management_lock: asyncio.Lock

def __iter__(self) -> Iterator[PoolWorker]: ...
@@ -90,6 +136,13 @@ def get_by_id(self, worker_id: int) -> PoolWorker: ...


class WorkerPool(Protocol):
"""
Describes an interface for a pool managing task workers.

It includes core functionality for starting the pool, dispatching tasks,
and shutting down the pool in a controlled manner.
"""

workers: WorkerData
queuer: Queuer
blocker: Blocker
@@ -106,6 +159,14 @@ async def shutdown(self) -> None: ...


class DispatcherMain(Protocol):
"""
Describes the primary dispatcher interface.

This interface defines the contract for the overall task dispatching service,
including coordinating task processing, managing the worker pool, and
handling delayed or control messages.
"""

pool: WorkerPool
delayed_messages: set
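
Since these are typing.Protocol classes, implementations satisfy them structurally; nothing needs to subclass them. A toy illustration against the WorkerData protocol (the class below is made up):

    import asyncio
    from typing import Iterator

    class InMemoryWorkerData:
        # Satisfies WorkerData by shape alone: it provides management_lock,
        # __iter__, and get_by_id without inheriting from the Protocol.
        def __init__(self) -> None:
            self.management_lock = asyncio.Lock()
            self._workers: dict[int, object] = {}

        def __iter__(self) -> Iterator[object]:
            return iter(self._workers.values())

        def get_by_id(self, worker_id: int) -> object:
            return self._workers[worker_id]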

13 changes: 10 additions & 3 deletions dispatcher/service/next_wakeup_runner.py
@@ -50,9 +50,12 @@ def __init__(self, wakeup_objects: Iterable[HasWakeup], process_object: Callable
self.name = name

async def process_wakeups(self, current_time: float, do_processing: bool = True) -> Optional[float]:
"""Runs process_object for objects past for which we have passed the wakeup time
"""Runs process_object for objects whose wakeup time has passed.

Returns the time of the soonest wakeup that has not been processed here
Returns the soonest upcoming wakeup time among the objects that have not been processed.

If do_processing is True, process_object is called for objects with wakeup times below current_time.
Errors from process_object are logged and propagated.

Arguments:
- current_time - output of time.monotonic() passed from caller to keep this deterministic
@@ -63,7 +66,11 @@ async def process_wakeups(self, current_time: float, do_processing: bool = True)
for obj in list(self.wakeup_objects):
if obj_wakeup := obj.next_wakeup():
if do_processing and (obj_wakeup < current_time):
await self.process_object(obj)
try:
await self.process_object(obj)
except Exception as e:
logger.error(f"Error processing wakeup for object {obj}: {e}", exc_info=True)
raise
# refresh wakeup, which should be nullified or pushed back by process_object
obj_wakeup = obj.next_wakeup()
if obj_wakeup is None:
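
For orientation, objects managed by this runner only need a next_wakeup() method returning a monotonic timestamp or None, and process_object is expected to nullify or push back that wakeup. A toy sketch of the contract (names illustrative):

    from typing import Optional

    class DelayedMessage:
        # Toy HasWakeup implementation: due once, then never again
        def __init__(self, due: float) -> None:
            self.due: Optional[float] = due

        def next_wakeup(self) -> Optional[float]:
            return self.due

    async def process_object(obj: DelayedMessage) -> None:
        # Nullify the wakeup so process_wakeups does not run it again
        obj.due = None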