Skip to content

Commit 0f6b41b

Browse files
committed
Add socket broker
Protocol the brokers. Run demo dispatcherctl over socket. Add socket broker unit tests. Add socket broker usage integration tests. Work out issues with test scope and server not opening client connections.
1 parent 44082d7 commit 0f6b41b

File tree

12 files changed

+448
-49
lines changed

12 files changed

+448
-49
lines changed

dispatcher.yml

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ brokers:
1919
- test_channel2
2020
- test_channel3
2121
default_publish_channel: test_channel
22+
socket:
23+
socket_path: demo_dispatcher.sock
2224
producers:
2325
ScheduledProducer:
2426
task_schedule:
@@ -30,4 +32,5 @@ producers:
3032
task_list:
3133
'lambda: print("This task runs on startup")': {}
3234
publish:
35+
default_control_broker: socket
3336
default_broker: pg_notify

dispatcher/brokers/pg_notify.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from dispatcher.utils import resolve_callable
77

8+
from ..protocols import Broker as BrokerProtocol
9+
810
logger = logging.getLogger(__name__)
911

1012

@@ -31,7 +33,7 @@ def create_connection(**config) -> psycopg.Connection:
3133
return connection
3234

3335

34-
class Broker:
36+
class Broker(BrokerProtocol):
3537
NOTIFY_QUERY_TEMPLATE = 'SELECT pg_notify(%s, %s);'
3638

3739
def __init__(
@@ -152,7 +154,7 @@ async def apublish_message_from_cursor(self, cursor: psycopg.AsyncCursor, channe
152154
"""The inner logic of async message publishing where we already have a cursor"""
153155
await cursor.execute(self.NOTIFY_QUERY_TEMPLATE, (channel, message))
154156

155-
async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: # public
157+
async def apublish_message(self, channel: Optional[str] = None, origin: Union[str, int, None] = '', message: str = '') -> None: # public
156158
"""asyncio way to publish a message, used to send control in control-and-reply
157159
158160
Not strictly necessary for the service itself if it sends replies in the workers,

dispatcher/brokers/socket.py

+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import asyncio
2+
import logging
3+
import os
4+
import socket
5+
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union
6+
7+
from ..protocols import Broker as BrokerProtocol
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class Client:
    """Server-side bookkeeping for one connected socket client.

    Holds the asyncio stream pair for the connection plus the replies that
    have been queued for it but not yet written back.
    """

    def __init__(self, client_id: int, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Store the connection streams and start with an empty reply queue.

        Args:
            client_id: integer identifier assigned to this connection.
            reader: stream used to read newline-terminated messages from the client.
            writer: stream used to write replies back to the client.
        """
        self.client_id = client_id
        self.reader = reader
        self.writer = writer
        self.listen_loop_active = False
        # This is needed for task management between the client tasks and the main aprocess_notify.
        # If the client task starts listening, then we can not send replies,
        # so this waits for the caller method to add replies to the stack before continuing.
        self.yield_clear = asyncio.Event()
        self.replies_to_send: list = []

    def write(self, message) -> None:
        """Write one newline-terminated message to the client without draining."""
        self.writer.write((message + '\n').encode())

    def queue_reply(self, reply: str) -> None:
        """Remember a reply to be written by the next send_replies call."""
        self.replies_to_send.append(reply)

    async def send_replies(self):
        """Write every queued reply to the client, then drain the writer.

        The queue is detached before any await so that a reply queued while
        the writer is draining is kept for the next call instead of being
        discarded when the queue is reset.
        """
        pending, self.replies_to_send = self.replies_to_send, []
        if pending:
            for reply in pending:
                logger.info(f'Sending reply to client_id={self.client_id} len={len(reply)}')
                self.write(reply)
        else:
            # An explicit branch is used here; a for-else clause would run after
            # every loop that finishes without break, even when replies were sent.
            logger.info(f'No replies to send to client_id={self.client_id}')
        await self.writer.drain()
38+
39+
40+
class Broker(BrokerProtocol):
    """A Unix socket broker for dispatcher, kept as simple as possible.

    Because we want to be as simple as possible we do not maintain persistent connections.
    So every control-and-reply command will connect and disconnect.

    Intended use is for dispatcherctl, so that we may bypass any flake related to pg_notify
    for debugging information.

    Messages in both directions are newline-terminated text: the server reads
    with ``readline`` and every write in this class appends a newline.
    """

    def __init__(self, socket_path: str) -> None:
        """Remember the socket path and set up empty client bookkeeping.

        Args:
            socket_path: filesystem path of the Unix socket to serve or connect to.
        """
        self.socket_path = socket_path
        self.client_ct = 0
        self.clients: dict[int, Client] = {}
        self.sock: Optional[socket.socket] = None  # for synchronous clients
        self.incoming_queue: asyncio.Queue = asyncio.Queue()

    def __str__(self):
        return f'socket-producer-{self.socket_path}'

    @staticmethod
    async def _close_writer(writer: asyncio.StreamWriter) -> None:
        """Close a stream writer, tolerating a peer that already dropped the connection."""
        writer.close()
        try:
            await writer.wait_closed()
        except OSError as exc:
            logger.debug(f'Ignoring error while closing socket connection: {exc}')

    async def _add_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Connection callback for the server: read messages from one client until it disconnects.

        Every received line is put on the incoming queue as ``(client_id, message)``.
        After each message the task waits (up to 2 seconds) for aprocess_notify to
        signal that the consumer is done, then writes any replies queued meanwhile.
        """
        client = Client(self.client_ct, reader, writer)
        self.clients[self.client_ct] = client
        self.client_ct += 1
        logger.info(f'Socket client_id={client.client_id} is connected')

        try:
            client.listen_loop_active = True
            while True:
                line = await client.reader.readline()
                if not line:
                    break  # disconnect
                message = line.decode().strip()
                await self.incoming_queue.put((client.client_id, message))
                # Wait for caller to potentially fill a reply queue
                # this should realistically never take more than a trivial amount of time
                await asyncio.wait_for(client.yield_clear.wait(), timeout=2)
                client.yield_clear.clear()
                await client.send_replies()
        except asyncio.TimeoutError:
            logger.error(f'Unexpected asyncio task management bug for client_id={client.client_id}, exiting')
        except asyncio.CancelledError:
            logger.debug(f'Ack that reader task for client_id={client.client_id} has been canceled')
        except Exception:
            logger.exception(f'Exception from reader task for client_id={client.client_id}')
        finally:
            client.listen_loop_active = False
            # pop() rather than del: the server shutdown path may already have
            # reset self.clients, and a KeyError here would mask the real exit reason.
            self.clients.pop(client.client_id, None)
            await self._close_writer(client.writer)
            logger.info(f'Socket client_id={client.client_id} is disconnected')

    async def aprocess_notify(
        self, connected_callback: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
    ) -> AsyncGenerator[tuple[Union[int, str], str], None]:
        """Serve the Unix socket and yield ``(client_id, message)`` for every received message.

        Any pre-existing file at the socket path is removed first. The generator
        ends when aclose() puts the internal stop marker on the queue, and on
        exit the server, all client connections and the socket file are cleaned up.

        Args:
            connected_callback: optional coroutine function awaited once the server is listening.
        """
        try:
            os.remove(self.socket_path)
        except FileNotFoundError:
            pass
        else:
            logger.debug(f'Deleted pre-existing {self.socket_path}')

        aserver = None
        try:
            aserver = await asyncio.start_unix_server(self._add_client, self.socket_path)
            logger.info(f'Set up socket server on {self.socket_path}')

            if connected_callback:
                await connected_callback()

            while True:
                client_id, message = await self.incoming_queue.get()
                if (client_id == -1) and (message == 'stop'):
                    return  # internal exit signaling from aclose

                yield client_id, message
                # trigger reply messages if applicable
                client = self.clients.get(client_id)
                if client:
                    logger.info(f'Yield complete for client_id={client_id}')
                    client.yield_clear.set()

        except asyncio.CancelledError:
            logger.debug('Ack that general socket server task has been canceled')
        finally:
            if aserver:
                aserver.close()  # stop accepting new connections

            # Iterate over a snapshot: client tasks remove themselves from
            # self.clients while we are awaiting here.
            for client in list(self.clients.values()):
                await self._close_writer(client.writer)
            self.clients = {}

            if aserver:
                # Awaited only after the client connections are closed, because
                # newer Python versions make wait_closed() wait for active connections.
                await aserver.wait_closed()

            try:
                os.remove(self.socket_path)
            except FileNotFoundError:
                pass

    async def aclose(self) -> None:
        """Send an internal message to the async generator, which will cause it to close the server"""
        await self.incoming_queue.put((-1, 'stop'))

    async def apublish_message(self, channel: Optional[str] = '', origin: Union[int, str, None] = None, message: str = "") -> None:
        """Send a message, either as a reply to a connected client or over a new connection.

        Args:
            channel: accepted for interface compatibility; not used by this broker.
            origin: a non-negative int is treated as the client_id to reply to;
                anything else makes this call act as a client and open a new connection.
            message: text to send; a newline terminator is appended on the wire.
        """
        if isinstance(origin, int) and origin >= 0:
            client = self.clients.get(int(origin))
            if client:
                if client.listen_loop_active:
                    logger.info(f'Queued message len={len(message)} for client_id={origin}')
                    client.queue_reply(message)
                else:
                    logger.warning(f'Not currently listening to client_id={origin}, attempting reply len={len(message)}, but might be dropped')
                    client.write(message)
                    await client.writer.drain()
            else:
                logger.error(f'Client_id={origin} is not currently connected')
        else:
            # Acting as a client in this case, mostly for tests
            logger.info(f'Publishing async socket message len={len(message)} with new connection')
            writer = None
            try:
                _, writer = await asyncio.open_unix_connection(self.socket_path)
                writer.write((message + '\n').encode())
                await writer.drain()
            finally:
                if writer:
                    await self._close_writer(writer)

    def process_notify(
        self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1
    ) -> Iterator[tuple[Union[int, str], str]]:
        """Synchronously connect to the server and yield ``(0, message)`` for each reply.

        Replies are framed by the newline terminator the server appends. Bytes are
        buffered until a full line is available, so a reply split over several
        reads, or several replies arriving in one read, are both handled.

        Args:
            connected_callback: optional callable invoked once the connection is open;
                while it runs, publish_message reuses this connection.
            timeout: socket timeout in seconds; a timed-out read raises from ``recv``.
            max_messages: stop after this many messages have been yielded.
        """
        try:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                self.sock = sock
                sock.settimeout(timeout)
                sock.connect(self.socket_path)

                if connected_callback:
                    connected_callback()

                received_ct = 0
                buffer = b''
                while True:
                    chunk = sock.recv(1024)
                    if not chunk:
                        # Peer closed the connection: recv would keep returning b''
                        # forever, so flush whatever is left and stop.
                        leftover = buffer.decode().strip()
                        if leftover:
                            yield (0, leftover)
                        else:
                            logger.info(f'Socket connection closed after receiving {received_ct} messages')
                        return

                    buffer += chunk
                    while b'\n' in buffer:
                        raw_line, _, buffer = buffer.partition(b'\n')
                        # Decode only complete lines so multi-byte characters
                        # split across reads do not break decoding.
                        response = raw_line.decode().strip()
                        if not response:
                            continue
                        received_ct += 1
                        yield (0, response)
                        if received_ct >= max_messages:
                            return

                    if buffer:
                        logger.info(f'Received incomplete message len={len(buffer)}, adding to buffer')
        finally:
            self.sock = None

    def _publish_from_sock(self, sock, message) -> None:
        """Send one newline-terminated message over an already connected socket."""
        sock.sendall((message + "\n").encode())

    def publish_message(self, channel=None, message=None) -> None:
        """Synchronously send a message to the server.

        Reuses the connection opened by process_notify when there is one,
        otherwise opens a short-lived connection just for this message.

        Args:
            channel: accepted for interface compatibility; not used by this broker.
            message: text to send; None is treated as an empty message.
        """
        if message is None:
            message = ''  # the default must not crash on len() below

        if self.sock:
            logger.info(f'Publishing socket message len={len(message)} via existing connection')
            self._publish_from_sock(self.sock, message)
        else:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                sock.connect(self.socket_path)
                logger.info(f'Publishing socket message len={len(message)} over new connection')
                self._publish_from_sock(sock, message)

dispatcher/control.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -116,25 +116,31 @@ async def acontrol(self, command: str, data: Optional[dict] = None) -> None:
116116
await control_callbacks.connected_callback(producer)
117117

118118
def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
119-
logger.info('control-and-reply {} to {}'.format(command, self.queuename))
120119
start = time.time()
121120
reply_queue = Control.generate_reply_queue_name()
122121
send_data: dict[str, Union[dict, str]] = {'control': command, 'reply_to': reply_queue}
123122
if data:
124123
send_data['control_data'] = data
124+
payload = json.dumps(send_data)
125125

126-
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
126+
try:
127+
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
128+
except TypeError:
129+
broker = get_broker(self.broker_name, self.broker_config)
127130

128131
def connected_callback() -> None:
129-
payload = json.dumps(send_data)
130132
if self.queuename:
131133
broker.publish_message(channel=self.queuename, message=payload)
132134
else:
133135
broker.publish_message(message=payload)
134136

135137
replies = []
136138
for channel, payload in broker.process_notify(connected_callback=connected_callback, max_messages=expected_replies, timeout=timeout):
137-
reply_data = json.loads(payload)
139+
try:
140+
reply_data = json.loads(payload)
141+
except Exception:
142+
logger.error(f'Failed to parse response:\n{payload}')
143+
raise
138144
replies.append(reply_data)
139145

140146
logger.info(f'control-and-reply message returned in {time.time() - start} seconds')

dispatcher/factories.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,14 @@ def get_publisher_from_settings(publish_broker: Optional[str] = None, settings:
9191

9292

9393
def get_control_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides):
    """Returns a Control instance based on the values in settings

    Broker selection order:
      1. an explicitly passed ``publish_broker``,
      2. ``default_control_broker`` from the ``publish`` settings, if present,
      3. whatever ``_get_publisher_broker_name`` resolves from settings.

    Any ``overrides`` are applied on top of a copy of the selected broker's options.
    """
    # Only fall back to the configured control broker when the caller did not
    # name one; an explicit argument must not be silently overridden by settings.
    if publish_broker is None and 'default_control_broker' in settings.publish:
        result_publish_broker = settings.publish['default_control_broker']
    else:
        result_publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings)
    broker_options = settings.brokers[result_publish_broker].copy()
    broker_options.update(overrides)
    return Control(result_publish_broker, broker_options)
98102

99103

100104
# ---- Schema generation ----

dispatcher/producers/brokered.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Iterable, Optional
3+
from typing import Iterable, Optional, Union
44

55
from ..protocols import Broker, DispatcherMain
66
from .base import BaseProducer
@@ -16,6 +16,9 @@ def __init__(self, broker: Broker, close_on_exit: bool = True) -> None:
1616
self.dispatcher: Optional[DispatcherMain] = None
1717
super().__init__()
1818

19+
def __str__(self):
20+
return f'brokered-producer-{self.broker}'
21+
1922
async def start_producing(self, dispatcher: DispatcherMain) -> None:
2023
self.production_task = asyncio.create_task(self.produce_forever(dispatcher), name=f'{self.broker.__module__}_production')
2124

@@ -34,12 +37,12 @@ async def produce_forever(self, dispatcher: DispatcherMain) -> None:
3437
self.dispatcher = dispatcher
3538
async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback):
3639
self.produced_count += 1
37-
reply_to, reply_payload = await dispatcher.process_message(payload, producer=self, channel=channel)
40+
reply_to, reply_payload = await dispatcher.process_message(payload, producer=self, channel=str(channel))
3841
if reply_to and reply_payload:
39-
await self.notify(channel=reply_to, message=reply_payload)
42+
await self.notify(channel=reply_to, origin=channel, message=reply_payload)
4043

41-
async def notify(self, channel: Optional[str] = None, message: str = '') -> None:
42-
await self.broker.apublish_message(channel=channel, message=message)
44+
async def notify(self, channel: Optional[str] = None, origin: Optional[Union[int, str]] = None, message: str = '') -> None:
45+
await self.broker.apublish_message(channel=channel, origin=origin, message=message)
4346

4447
async def shutdown(self) -> None:
4548
if self.production_task:

dispatcher/protocols.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,25 @@
55
class Broker(Protocol):
66
async def aprocess_notify(
77
self, connected_callback: Optional[Optional[Callable[[], Coroutine[Any, Any, None]]]] = None
8-
) -> AsyncGenerator[tuple[str, str], None]:
8+
) -> AsyncGenerator[tuple[Union[int, str], str], None]:
99
"""The generator of messages from the broker for the dispatcher service
1010
1111
The producer iterates this to produce tasks.
1212
This uses the async connection of the broker.
1313
"""
1414
yield ('', '') # yield affects CPython type https://github.com/python/mypy/pull/18422
1515

16-
async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None:
16+
async def apublish_message(self, channel: Optional[str] = None, origin: Union[int, str, None] = None, message: str = '') -> None:
1717
"""Asynchronously send a message to the broker, used by dispatcher service for reply messages"""
1818
...
1919

2020
async def aclose(self) -> None:
2121
"""Close the asynchronous connection, used by service, and optionally by publishers"""
2222
...
2323

24-
def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]:
24+
def process_notify(
25+
self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1
26+
) -> Iterator[tuple[Union[int, str], str]]:
2527
"""Synchronous method to generate messages from broker, used for synchronous control-and-reply"""
2628
...
2729

dispatcher/service/main.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,11 @@ async def run_control_action(self, action: str, control_data: Optional[dict] = N
179179

180180
# Give Nones for no reply, or the reply
181181
if reply_to:
182-
logger.info(f"Control action {action} returned {return_data}, sending back reply")
183-
return (reply_to, json.dumps(return_data))
182+
reply_msg = json.dumps(return_data)
183+
logger.info(f"Control action {action} returned message len={len(reply_msg)}, sending back reply")
184+
return (reply_to, reply_msg)
184185
else:
185-
logger.info(f"Control action {action} returned {return_data}, done")
186+
logger.info(f"Control action {action} returned {type(return_data)}, done")
186187
return (None, None)
187188

188189
async def process_message_internal(self, message: dict, producer=None) -> tuple[Optional[str], Optional[str]]:
@@ -201,9 +202,9 @@ async def start_working(self) -> None:
201202
logger.exception(f'Pool {self.pool} failed to start working')
202203
self.events.exit_event.set()
203204

204-
logger.debug('Starting task production')
205205
async with self.fd_lock: # lots of connecting going on here
206206
for producer in self.producers:
207+
logger.debug(f'Starting task production from {producer}')
207208
try:
208209
await producer.start_producing(self)
209210
except Exception:

0 commit comments

Comments
 (0)