Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unix socket broker #125

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dispatcher.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ brokers:
- test_channel2
- test_channel3
default_publish_channel: test_channel
socket:
socket_path: demo_dispatcher.sock
producers:
ScheduledProducer:
task_schedule:
Expand All @@ -30,4 +32,5 @@ producers:
task_list:
'lambda: print("This task runs on startup")': {}
publish:
default_control_broker: socket
default_broker: pg_notify
9 changes: 7 additions & 2 deletions dispatcher/brokers/pg_notify.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from dispatcher.utils import resolve_callable

from ..protocols import Broker as BrokerProtocol

logger = logging.getLogger(__name__)


Expand All @@ -32,7 +34,7 @@ def create_connection(**config) -> psycopg.Connection: # type: ignore[no-untype
return connection


class Broker:
class Broker(BrokerProtocol):
NOTIFY_QUERY_TEMPLATE = 'SELECT pg_notify(%s, %s);'

def __init__(
Expand Down Expand Up @@ -98,6 +100,9 @@ def get_publish_channel(self, channel: Optional[str] = None) -> str:

raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config')

def __str__(self) -> str:
return 'pg_notify-broker'

# --- asyncio connection methods ---

async def aget_connection(self) -> psycopg.AsyncConnection:
Expand Down Expand Up @@ -157,7 +162,7 @@ async def apublish_message_from_cursor(self, cursor: psycopg.AsyncCursor, channe
"""The inner logic of async message publishing where we already have a cursor"""
await cursor.execute(self.NOTIFY_QUERY_TEMPLATE, (channel, message))

async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: # public
async def apublish_message(self, channel: Optional[str] = None, origin: Union[str, int, None] = '', message: str = '') -> None: # public
"""asyncio way to publish a message, used to send control in control-and-reply

Not strictly necessary for the service itself if it sends replies in the workers,
Expand Down
223 changes: 223 additions & 0 deletions dispatcher/brokers/socket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
import asyncio
import json
import logging
import os
import socket
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union

from ..protocols import Broker as BrokerProtocol

logger = logging.getLogger(__name__)


class Client:
def __init__(self, client_id: int, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
self.client_id = client_id
self.reader = reader
self.writer = writer
self.listen_loop_active = False
# This is needed for task management betewen the client tasks and the main aprocess_notify
# if the client task starts listening, then we can not send replies
# so this waits for the caller method to add replies to stack before continuing
self.yield_clear = asyncio.Event()
self.replies_to_send: list = []

def write(self, message: str, /) -> None:
self.writer.write((message + '\n').encode())

def queue_reply(self, reply: str, /) -> None:
self.replies_to_send.append(reply)

async def send_replies(self) -> None:
for reply in self.replies_to_send.copy():
logger.info(f'Sending reply to client_id={self.client_id} len={len(reply)}')
self.write(reply)
else:
logger.info(f'No replies to send to client_id={self.client_id}')
await self.writer.drain()
self.replies_to_send = []


def extract_json(message: str) -> Iterator[str]:
"""With message that may be an incomplete JSON string, yield JSON-complete strings and leftover"""
decoder = json.JSONDecoder()
pos = 0
length = len(message)
while pos < length:
try:
_, index = decoder.raw_decode(message, pos)
json_msg = message[pos:index]
yield json_msg
pos = index
except json.JSONDecodeError:
break


class Broker(BrokerProtocol):
"""A Unix socket client for dispatcher as simple as possible

Because we want to be as simple as possible we do not maintain persistent connections.
So every control-and-reply command will connect and disconnect.

Intended use is for dispatcherctl, so that we may bypass any flake related to pg_notify
for debugging information.
"""

def __init__(self, socket_path: str) -> None:
self.socket_path = socket_path
self.client_ct = 0
self.clients: dict[int, Client] = {}
self.sock: Optional[socket.socket] = None # for synchronous clients
self.incoming_queue: asyncio.Queue = asyncio.Queue()

def __str__(self) -> str:
return f'socket-broker-{self.socket_path}'

async def _add_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
client = Client(self.client_ct, reader, writer)
self.clients[self.client_ct] = client
self.client_ct += 1
logger.info(f'Socket client_id={client.client_id} is connected')

try:
client.listen_loop_active = True
while True:
line = await client.reader.readline()
if not line:
break # disconnect
message = line.decode().strip()
await self.incoming_queue.put((client.client_id, message))
# Wait for caller to potentially fill a reply queue
# this should realistically never take more than a trivial amount of time
await asyncio.wait_for(client.yield_clear.wait(), timeout=2)
client.yield_clear.clear()
await client.send_replies()
except asyncio.TimeoutError:
logger.error(f'Unexpected asyncio task management bug for client_id={client.client_id}, exiting')
except asyncio.CancelledError:
logger.debug(f'Ack that reader task for client_id={client.client_id} has been canceled')
except Exception:
logger.exception(f'Exception from reader task for client_id={client.client_id}')
finally:
del self.clients[client.client_id]
client.writer.close()
await client.writer.wait_closed()
logger.info(f'Socket client_id={client.client_id} is disconnected')

async def aprocess_notify(
self, connected_callback: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
) -> AsyncGenerator[tuple[Union[int, str], str], None]:
if os.path.exists(self.socket_path):
logger.debug(f'Deleted pre-existing {self.socket_path}')
os.remove(self.socket_path)

aserver = None
try:
aserver = await asyncio.start_unix_server(self._add_client, self.socket_path)
logger.info(f'Set up socket server on {self.socket_path}')

if connected_callback:
await connected_callback()

while True:
client_id, message = await self.incoming_queue.get()
if (client_id == -1) and (message == 'stop'):
return # internal exit signaling from aclose

yield client_id, message
# trigger reply messages if applicable
client = self.clients.get(client_id)
if client:
logger.info(f'Yield complete for client_id={client_id}')
client.yield_clear.set()

except asyncio.CancelledError:
logger.debug('Ack that general socket server task has been canceled')
finally:
if aserver:
aserver.close()
await aserver.wait_closed()

for client in self.clients.values():
client.writer.close()
await client.writer.wait_closed()
self.clients = {}

if os.path.exists(self.socket_path):
os.remove(self.socket_path)

async def aclose(self) -> None:
"""Send an internal message to the async generator, which will cause it to close the server"""
await self.incoming_queue.put((-1, 'stop'))

async def apublish_message(self, channel: Optional[str] = '', origin: Union[int, str, None] = None, message: str = "") -> None:
if isinstance(origin, int) and origin >= 0:
client = self.clients.get(int(origin))
if client:
if client.listen_loop_active:
logger.info(f'Queued message len={len(message)} for client_id={origin}')
client.queue_reply(message)
else:
logger.warning(f'Not currently listening to client_id={origin}, attempting reply len={len(message)}, but might be dropped')
client.write(message)
await client.writer.drain()
else:
logger.error(f'Client_id={origin} is not currently connected')
else:
# Acting as a client in this case, mostly for tests
logger.info(f'Publishing async socket message len={len(message)} with new connection')
writer = None
try:
_, writer = await asyncio.open_unix_connection(self.socket_path)
writer.write((message + '\n').encode())
await writer.drain()
finally:
if writer:
writer.close()
await writer.wait_closed()

def process_notify(
self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1
) -> Iterator[tuple[Union[int, str], str]]:
try:
with socket.socket(socket.AF_UNIX) as sock:
self.sock = sock
sock.settimeout(timeout)
sock.connect(self.socket_path)

if connected_callback:
connected_callback()

received_ct = 0
buffer = ''
while True:
response = sock.recv(1024).decode().strip()

current_message = buffer + response
yielded_chars = 0
for complete_msg in extract_json(current_message):
received_ct += 1
yield (0, complete_msg)
if received_ct >= max_messages:
return
yielded_chars += len(complete_msg)
else:
buffer = current_message[yielded_chars:]
logger.info(f'Received incomplete message len={len(buffer)}, adding to buffer')

finally:
self.sock = None

def _publish_from_sock(self, sock: socket.socket, message: str) -> None:
sock.sendall((message + "\n").encode())

def publish_message(self, channel: Optional[str] = None, message: Optional[str] = None) -> None:
assert isinstance(message, str)
if self.sock:
logger.info(f'Publishing socket message len={len(message)} via existing connection')
self._publish_from_sock(self.sock, message)
else:
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
sock.connect(self.socket_path)
logger.info(f'Publishing socket message len={len(message)} over new connection')
self._publish_from_sock(sock, message)
8 changes: 5 additions & 3 deletions dispatcher/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, queuename: Optional[str], broker: Broker, send_message: str,
self.expected_replies = expected_replies

async def connected_callback(self) -> None:
await self.broker.apublish_message(self.queuename, self.send_message)
await self.broker.apublish_message(channel=self.queuename, message=self.send_message)

async def listen_for_replies(self) -> None:
"""Listen to the reply channel until we get the expected number of messages.
Expand Down Expand Up @@ -94,12 +94,14 @@ async def acontrol(self, command: str, data: Optional[dict] = None) -> None:
await broker.aclose()

def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
logger.info(f'control-and-reply {command} to {self.queuename}')
start = time.time()
reply_queue = Control.generate_reply_queue_name()
send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)

broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
try:
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
except TypeError:
broker = get_broker(self.broker_name, self.broker_config)

def connected_callback() -> None:
broker.publish_message(channel=self.queuename, message=send_message)
Expand Down
10 changes: 7 additions & 3 deletions dispatcher/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,14 @@ def get_publisher_from_settings(publish_broker: Optional[str] = None, settings:


def get_control_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides):
publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings)
broker_options = settings.brokers[publish_broker].copy()
"""Returns a Control instance based on the values in settings"""
if 'default_control_broker' in settings.publish:
result_publish_broker = settings.publish['default_control_broker']
else:
result_publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings)
broker_options = settings.brokers[result_publish_broker].copy()
broker_options.update(overrides)
return Control(publish_broker, broker_options)
return Control(result_publish_broker, broker_options)


# ---- Schema generation ----
Expand Down
13 changes: 8 additions & 5 deletions dispatcher/producers/brokered.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Iterable, Optional
from typing import Iterable, Optional, Union

from ..protocols import Broker, DispatcherMain
from .base import BaseProducer
Expand All @@ -15,6 +15,9 @@ def __init__(self, broker: Broker) -> None:
self.dispatcher: Optional[DispatcherMain] = None
super().__init__()

def __str__(self) -> str:
return f'brokered-producer-{self.broker}'

async def start_producing(self, dispatcher: DispatcherMain) -> None:
self.production_task = asyncio.create_task(self.produce_forever(dispatcher), name=f'{self.broker.__module__}_production')

Expand All @@ -33,12 +36,12 @@ async def produce_forever(self, dispatcher: DispatcherMain) -> None:
self.dispatcher = dispatcher
async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback):
self.produced_count += 1
reply_to, reply_payload = await dispatcher.process_message(payload, producer=self, channel=channel)
reply_to, reply_payload = await dispatcher.process_message(payload, producer=self, channel=str(channel))
if reply_to and reply_payload:
await self.notify(channel=reply_to, message=reply_payload)
await self.notify(channel=reply_to, origin=channel, message=reply_payload)

async def notify(self, channel: Optional[str] = None, message: str = '') -> None:
await self.broker.apublish_message(channel=channel, message=message)
async def notify(self, channel: Optional[str] = None, origin: Optional[Union[int, str]] = None, message: str = '') -> None:
await self.broker.apublish_message(channel=channel, origin=origin, message=message)

async def shutdown(self) -> None:
if self.production_task:
Expand Down
8 changes: 5 additions & 3 deletions dispatcher/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,25 @@ class Broker(Protocol):

async def aprocess_notify(
self, connected_callback: Optional[Optional[Callable[[], Coroutine[Any, Any, None]]]] = None
) -> AsyncGenerator[tuple[str, str], None]:
) -> AsyncGenerator[tuple[Union[int, str], str], None]:
"""The generator of messages from the broker for the dispatcher service

The producer iterates this to produce tasks.
This uses the async connection of the broker.
"""
yield ('', '') # yield affects CPython type https://github.com/python/mypy/pull/18422

async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None:
async def apublish_message(self, channel: Optional[str] = None, origin: Union[int, str, None] = None, message: str = '') -> None:
"""Asynchronously send a message to the broker, used by dispatcher service for reply messages"""
...

async def aclose(self) -> None:
"""Close the asynchronous connection, used by service, and optionally by publishers"""
...

def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]:
def process_notify(
self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1
) -> Iterator[tuple[Union[int, str], str]]:
"""Synchronous method to generate messages from broker, used for synchronous control-and-reply"""
...

Expand Down
Loading