Skip to content

Commit 23230ef

Browse files
committed
Add socket broker
Protocol the brokers. Run demo dispatcherctl over socket. Add socket broker unit tests. Add socket broker usage integration tests. Work out issues with test scope and server not opening client connections.
1 parent 0a93b24 commit 23230ef

File tree

12 files changed

+442
-47
lines changed

12 files changed

+442
-47
lines changed

dispatcher.yml

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ brokers:
1919
- test_channel2
2020
- test_channel3
2121
default_publish_channel: test_channel
22+
socket:
23+
socket_path: demo_dispatcher.sock
2224
producers:
2325
ScheduledProducer:
2426
task_schedule:
@@ -30,4 +32,5 @@ producers:
3032
task_list:
3133
'lambda: print("This task runs on startup")': {}
3234
publish:
35+
default_control_broker: socket
3336
default_broker: pg_notify

dispatcher/brokers/pg_notify.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from dispatcher.utils import resolve_callable
88

9+
from ..protocols import Broker as BrokerProtocol
10+
911
logger = logging.getLogger(__name__)
1012

1113

@@ -32,7 +34,7 @@ def create_connection(**config) -> psycopg.Connection: # type: ignore[no-untype
3234
return connection
3335

3436

35-
class Broker:
37+
class Broker(BrokerProtocol):
3638
NOTIFY_QUERY_TEMPLATE = 'SELECT pg_notify(%s, %s);'
3739

3840
def __init__(
@@ -157,7 +159,7 @@ async def apublish_message_from_cursor(self, cursor: psycopg.AsyncCursor, channe
157159
"""The inner logic of async message publishing where we already have a cursor"""
158160
await cursor.execute(self.NOTIFY_QUERY_TEMPLATE, (channel, message))
159161

160-
async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: # public
162+
async def apublish_message(self, channel: Optional[str] = None, origin: Union[str, int, None] = '', message: str = '') -> None: # public
161163
"""asyncio way to publish a message, used to send control in control-and-reply
162164
163165
Not strictly necessary for the service itself if it sends replies in the workers,

dispatcher/brokers/socket.py

+204
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
import asyncio
2+
import logging
3+
import os
4+
import socket
5+
from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union
6+
7+
from ..protocols import Broker as BrokerProtocol
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class Client:
13+
def __init__(self, client_id: int, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
14+
self.client_id = client_id
15+
self.reader = reader
16+
self.writer = writer
17+
self.listen_loop_active = False
18+
# This is needed for task management betewen the client tasks and the main aprocess_notify
19+
# if the client task starts listening, then we can not send replies
20+
# so this waits for the caller method to add replies to stack before continuing
21+
self.yield_clear = asyncio.Event()
22+
self.replies_to_send: list = []
23+
24+
def write(self, message) -> None:
25+
self.writer.write((message + '\n').encode())
26+
27+
def queue_reply(self, reply: str) -> None:
28+
self.replies_to_send.append(reply)
29+
30+
async def send_replies(self):
31+
for reply in self.replies_to_send.copy():
32+
logger.info(f'Sending reply to client_id={self.client_id} len={len(reply)}')
33+
self.write(reply)
34+
else:
35+
logger.info(f'No replies to send to client_id={self.client_id}')
36+
await self.writer.drain()
37+
self.replies_to_send = []
38+
39+
40+
class Broker(BrokerProtocol):
    """A Unix socket client for dispatcher as simple as possible

    Because we want to be as simple as possible we do not maintain persistent connections.
    So every control-and-reply command will connect and disconnect.

    Intended use is for dispatcherctl, so that we may bypass any flake related to pg_notify
    for debugging information.

    Messages in both directions are newline-terminated text: every writer in this
    class appends a newline, the server side reads with readline(), and the
    synchronous reader splits the received bytes on newlines.
    """

    def __init__(self, socket_path: str) -> None:
        self.socket_path = socket_path
        self.client_ct = 0  # id given to the next connecting client, only ever incremented
        self.clients: dict[int, Client] = {}
        self.sock: Optional[socket.socket] = None  # for synchronous clients
        self.incoming_queue: asyncio.Queue = asyncio.Queue()

    def __str__(self) -> str:
        return f'socket-producer-{self.socket_path}'

    async def _add_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        """Connection handler given to the unix server, runs once per client connection.

        Reads newline-terminated messages and puts (client_id, message) on the
        incoming queue, then waits for aprocess_notify to signal that the message
        has been handled before flushing any queued replies.
        """
        client = Client(self.client_ct, reader, writer)
        self.clients[self.client_ct] = client
        self.client_ct += 1
        logger.info(f'Socket client_id={client.client_id} is connected')

        try:
            client.listen_loop_active = True
            while True:
                line = await client.reader.readline()
                if not line:
                    break  # disconnect
                message = line.decode().strip()
                await self.incoming_queue.put((client.client_id, message))
                # Wait for caller to potentially fill a reply queue
                # this should realistically never take more than a trivial amount of time
                await asyncio.wait_for(client.yield_clear.wait(), timeout=2)
                client.yield_clear.clear()
                await client.send_replies()
        except asyncio.TimeoutError:
            logger.error(f'Unexpected asyncio task management bug for client_id={client.client_id}, exiting')
        except asyncio.CancelledError:
            logger.debug(f'Ack that reader task for client_id={client.client_id} has been canceled')
        except Exception:
            logger.exception(f'Exception from reader task for client_id={client.client_id}')
        finally:
            client.listen_loop_active = False
            # pop, not del: aprocess_notify resets self.clients on shutdown, so the key may be gone
            self.clients.pop(client.client_id, None)
            client.writer.close()
            await client.writer.wait_closed()
            logger.info(f'Socket client_id={client.client_id} is disconnected')

    async def aprocess_notify(
        self, connected_callback: Optional[Callable[[], Coroutine[Any, Any, None]]] = None
    ) -> AsyncGenerator[tuple[Union[int, str], str], None]:
        """Run the unix socket server and yield (client_id, message) for each received message.

        Any pre-existing file at socket_path is removed before the server starts.
        connected_callback, if given, is awaited once the server is listening.
        The generator ends when aclose() puts the internal stop message on the queue;
        on exit the server and all client connections are closed and the socket
        file is removed.
        """
        if os.path.exists(self.socket_path):
            os.remove(self.socket_path)
            logger.debug(f'Deleted pre-existing {self.socket_path}')

        aserver = None
        try:
            aserver = await asyncio.start_unix_server(self._add_client, self.socket_path)
            logger.info(f'Set up socket server on {self.socket_path}')

            if connected_callback:
                await connected_callback()

            while True:
                client_id, message = await self.incoming_queue.get()
                if (client_id == -1) and (message == 'stop'):
                    return  # internal exit signaling from aclose

                yield client_id, message
                # trigger reply messages if applicable
                client = self.clients.get(client_id)
                if client:
                    logger.info(f'Yield complete for client_id={client_id}')
                    client.yield_clear.set()

        except asyncio.CancelledError:
            logger.debug('Ack that general socket server task has been canceled')
        finally:
            if aserver:
                aserver.close()
                await aserver.wait_closed()

            # Iterate over a snapshot: client tasks remove themselves from self.clients
            # while we are awaiting here, which would break iteration over the live dict
            for client in list(self.clients.values()):
                client.writer.close()
                await client.writer.wait_closed()
            self.clients = {}

            if os.path.exists(self.socket_path):
                os.remove(self.socket_path)

    async def aclose(self) -> None:
        """Send an internal message to the async generator, which will cause it to close the server"""
        await self.incoming_queue.put((-1, 'stop'))

    async def apublish_message(self, channel: Optional[str] = '', origin: Union[int, str, None] = None, message: str = "") -> None:
        """Send a message, either as a reply to a connected client or over a new connection.

        If origin is a non-negative int it is treated as the client_id of a client
        connected to this server and the message is sent back to that client.
        Otherwise a new connection to socket_path is opened, the message is written,
        and the connection is closed. The channel argument is not used by this broker.
        """
        if not (isinstance(origin, int) and origin >= 0):
            # Acting as a client in this case, mostly for tests
            logger.info(f'Publishing async socket message len={len(message)} with new connection')
            writer = None
            try:
                _, writer = await asyncio.open_unix_connection(self.socket_path)
                writer.write((message + '\n').encode())
                await writer.drain()
            finally:
                if writer:
                    writer.close()
                    await writer.wait_closed()
            return

        client = self.clients.get(origin)
        if client is None:
            logger.error(f'Client_id={origin} is not currently connected')
            return

        if client.listen_loop_active:
            # The client task flushes queued replies once the current message is processed
            logger.info(f'Queued message len={len(message)} for client_id={origin}')
            client.queue_reply(message)
        else:
            logger.warning(f'Not currently listening to client_id={origin}, attempting reply len={len(message)}, but might be dropped')
            client.write(message)
            await client.writer.drain()

    def process_notify(
        self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1
    ) -> Iterator[tuple[Union[int, str], str]]:
        """Synchronously connect to the server and yield (0, message) for each reply received.

        connected_callback, if given, is called after connecting; while this generator
        is active publish_message() reuses the same connection. Stops after
        max_messages messages or when the server closes the connection. The socket
        timeout applies to connect and to each receive.

        Received bytes are buffered and split on newlines before decoding, so a
        message split over several reads, or several messages arriving in one read,
        are each yielded exactly once.
        """
        try:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                self.sock = sock
                sock.settimeout(timeout)
                sock.connect(self.socket_path)

                if connected_callback:
                    connected_callback()

                received_ct = 0
                buffer = b''
                while True:
                    chunk = sock.recv(1024)
                    if not chunk:
                        # Peer closed the connection; without this check the loop would spin forever
                        trailing = buffer.decode().strip()
                        if trailing:
                            yield (0, trailing)
                        return

                    buffer += chunk
                    # Everything before the last newline is complete, the remainder stays buffered
                    *complete_lines, buffer = buffer.split(b'\n')
                    for raw_line in complete_lines:
                        response = raw_line.decode().strip()
                        if not response:
                            continue
                        received_ct += 1
                        yield (0, response)
                        if received_ct >= max_messages:
                            return
                    if buffer:
                        logger.info(f'Received incomplete message len={len(buffer)}, adding to buffer')
        finally:
            self.sock = None

    def _publish_from_sock(self, sock, message) -> None:
        """Send one newline-terminated message over an already connected socket."""
        sock.sendall((message + "\n").encode())

    def publish_message(self, channel=None, message=None) -> None:
        """Synchronously send a message to the server.

        Uses the connection held open by process_notify() if there is one,
        otherwise connects, sends, and disconnects. A message of None is sent
        as an empty message. The channel argument is not used by this broker.
        """
        if message is None:
            message = ''  # default argument would otherwise fail on len() below

        if self.sock:
            logger.info(f'Publishing socket message len={len(message)} via existing connection')
            self._publish_from_sock(self.sock, message)
        else:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
                sock.connect(self.socket_path)
                logger.info(f'Publishing socket message len={len(message)} over new connection')
                self._publish_from_sock(sock, message)

dispatcher/control.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,14 @@ async def acontrol(self, command: str, data: Optional[dict] = None) -> None:
9494
await broker.aclose()
9595

9696
def control_with_reply(self, command: str, expected_replies: int = 1, timeout: float = 1.0, data: Optional[dict] = None) -> list[dict]:
97-
logger.info(f'control-and-reply {command} to {self.queuename}')
9897
start = time.time()
9998
reply_queue = Control.generate_reply_queue_name()
10099
send_message = self.create_message(command=command, reply_to=reply_queue, send_data=data)
101100

102-
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
101+
try:
102+
broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue])
103+
except TypeError:
104+
broker = get_broker(self.broker_name, self.broker_config)
103105

104106
def connected_callback() -> None:
105107
broker.publish_message(channel=self.queuename, message=send_message)

dispatcher/factories.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,14 @@ def get_publisher_from_settings(publish_broker: Optional[str] = None, settings:
9191

9292

9393
def get_control_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides):
    """Returns a Control instance based on the values in settings

    Broker selection order:
    1. publish_broker, when the caller passes one explicitly
    2. settings.publish['default_control_broker'], when configured
    3. whatever _get_publisher_broker_name resolves from settings

    The chosen broker's options are copied from settings.brokers and then
    updated with any keyword overrides before being handed to Control.
    """
    # An explicit broker from the caller must win over the configured control default
    if not publish_broker and 'default_control_broker' in settings.publish:
        result_publish_broker = settings.publish['default_control_broker']
    else:
        result_publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings)
    broker_options = settings.brokers[result_publish_broker].copy()
    broker_options.update(overrides)
    return Control(result_publish_broker, broker_options)
98102

99103

100104
# ---- Schema generation ----

dispatcher/producers/brokered.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Iterable, Optional
3+
from typing import Iterable, Optional, Union
44

55
from ..protocols import Broker, DispatcherMain
66
from .base import BaseProducer
@@ -15,6 +15,9 @@ def __init__(self, broker: Broker) -> None:
1515
self.dispatcher: Optional[DispatcherMain] = None
1616
super().__init__()
1717

18+
def __str__(self):
19+
return f'brokered-producer-{self.broker}'
20+
1821
async def start_producing(self, dispatcher: DispatcherMain) -> None:
1922
self.production_task = asyncio.create_task(self.produce_forever(dispatcher), name=f'{self.broker.__module__}_production')
2023

@@ -33,12 +36,12 @@ async def produce_forever(self, dispatcher: DispatcherMain) -> None:
3336
self.dispatcher = dispatcher
3437
async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback):
3538
self.produced_count += 1
36-
reply_to, reply_payload = await dispatcher.process_message(payload, producer=self, channel=channel)
39+
reply_to, reply_payload = await dispatcher.process_message(payload, producer=self, channel=str(channel))
3740
if reply_to and reply_payload:
38-
await self.notify(channel=reply_to, message=reply_payload)
41+
await self.notify(channel=reply_to, origin=channel, message=reply_payload)
3942

40-
async def notify(self, channel: Optional[str] = None, message: str = '') -> None:
41-
await self.broker.apublish_message(channel=channel, message=message)
43+
async def notify(self, channel: Optional[str] = None, origin: Optional[Union[int, str]] = None, message: str = '') -> None:
44+
await self.broker.apublish_message(channel=channel, origin=origin, message=message)
4245

4346
async def shutdown(self) -> None:
4447
if self.production_task:

dispatcher/protocols.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,25 @@ class Broker(Protocol):
1212

1313
async def aprocess_notify(
1414
self, connected_callback: Optional[Optional[Callable[[], Coroutine[Any, Any, None]]]] = None
15-
) -> AsyncGenerator[tuple[str, str], None]:
15+
) -> AsyncGenerator[tuple[Union[int, str], str], None]:
1616
"""The generator of messages from the broker for the dispatcher service
1717
1818
The producer iterates this to produce tasks.
1919
This uses the async connection of the broker.
2020
"""
2121
yield ('', '') # yield affects CPython type https://github.com/python/mypy/pull/18422
2222

23-
async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None:
23+
async def apublish_message(self, channel: Optional[str] = None, origin: Union[int, str, None] = None, message: str = '') -> None:
2424
"""Asynchronously send a message to the broker, used by dispatcher service for reply messages"""
2525
...
2626

2727
async def aclose(self) -> None:
2828
"""Close the asynchronous connection, used by service, and optionally by publishers"""
2929
...
3030

31-
def process_notify(self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1) -> Iterator[tuple[str, str]]:
31+
def process_notify(
32+
self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1
33+
) -> Iterator[tuple[Union[int, str], str]]:
3234
"""Synchronous method to generate messages from broker, used for synchronous control-and-reply"""
3335
...
3436

dispatcher/service/main.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,11 @@ async def run_control_action(self, action: str, control_data: Optional[dict] = N
179179

180180
# Give Nones for no reply, or the reply
181181
if reply_to:
182-
logger.info(f"Control action {action} returned {return_data}, sending back reply")
183-
return (reply_to, json.dumps(return_data))
182+
reply_msg = json.dumps(return_data)
183+
logger.info(f"Control action {action} returned message len={len(reply_msg)}, sending back reply")
184+
return (reply_to, reply_msg)
184185
else:
185-
logger.info(f"Control action {action} returned {return_data}, done")
186+
logger.info(f"Control action {action} returned {type(return_data)}, done")
186187
return (None, None)
187188

188189
async def process_message_internal(self, message: dict, producer: Optional[Producer] = None) -> tuple[Optional[str], Optional[str]]:
@@ -201,9 +202,9 @@ async def start_working(self) -> None:
201202
logger.exception(f'Pool {self.pool} failed to start working')
202203
self.events.exit_event.set()
203204

204-
logger.debug('Starting task production')
205205
async with self.fd_lock: # lots of connecting going on here
206206
for producer in self.producers:
207+
logger.debug(f'Starting task production from {producer}')
207208
try:
208209
await producer.start_producing(self)
209210
except Exception:

schema.json

+7-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
"sync_connection_factory": "typing.Optional[str]",
88
"channels": "typing.Union[tuple, list]",
99
"default_publish_channel": "typing.Optional[str]"
10+
},
11+
"socket": {
12+
"socket_path": "<class 'str'>"
1013
}
1114
},
1215
"producers": {
@@ -26,13 +29,13 @@
2629
"worker_stop_wait": "<class 'float'>",
2730
"worker_removal_wait": "<class 'float'>"
2831
},
32+
"main_kwargs": {
33+
"node_id": "typing.Optional[str]"
34+
},
2935
"process_manager_kwargs": {
3036
"preload_modules": "typing.Optional[list[str]]"
3137
},
32-
"process_manager_cls": "typing.Literal['ProcessManager', 'ForkServerManager']",
33-
"main_kwargs": {
34-
"node_id": "typing.Optional[str]"
35-
}
38+
"process_manager_cls": "typing.Literal['ProcessManager', 'ForkServerManager']"
3639
},
3740
"publish": {
3841
"default_broker": "str"

0 commit comments

Comments
 (0)