|
| 1 | +import asyncio |
| 2 | +import logging |
| 3 | +import os |
| 4 | +import socket |
| 5 | +from typing import Any, AsyncGenerator, Callable, Coroutine, Iterator, Optional, Union |
| 6 | + |
| 7 | +from ..protocols import Broker as BrokerProtocol |
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
| 11 | + |
| 12 | +class Client: |
| 13 | + def __init__(self, client_id: int, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: |
| 14 | + self.client_id = client_id |
| 15 | + self.reader = reader |
| 16 | + self.writer = writer |
| 17 | + self.listen_loop_active = False |
| 18 | + # This is needed for task management betewen the client tasks and the main aprocess_notify |
| 19 | + # if the client task starts listening, then we can not send replies |
| 20 | + # so this waits for the caller method to add replies to stack before continuing |
| 21 | + self.yield_clear = asyncio.Event() |
| 22 | + self.replies_to_send: list = [] |
| 23 | + |
| 24 | + def write(self, message) -> None: |
| 25 | + self.writer.write((message + '\n').encode()) |
| 26 | + |
| 27 | + def queue_reply(self, reply: str) -> None: |
| 28 | + self.replies_to_send.append(reply) |
| 29 | + |
| 30 | + async def send_replies(self): |
| 31 | + for reply in self.replies_to_send.copy(): |
| 32 | + logger.info(f'Sending reply to client_id={self.client_id} len={len(reply)}') |
| 33 | + self.write(reply) |
| 34 | + else: |
| 35 | + logger.info(f'No replies to send to client_id={self.client_id}') |
| 36 | + await self.writer.drain() |
| 37 | + self.replies_to_send = [] |
| 38 | + |
| 39 | + |
| 40 | +class Broker(BrokerProtocol): |
| 41 | + """A Unix socket client for dispatcher as simple as possible |
| 42 | +
|
| 43 | + Because we want to be as simple as possible we do not maintain persistent connections. |
| 44 | + So every control-and-reply command will connect and disconnect. |
| 45 | +
|
| 46 | + Intended use is for dispatcherctl, so that we may bypass any flake related to pg_notify |
| 47 | + for debugging information. |
| 48 | + """ |
| 49 | + |
| 50 | + def __init__(self, socket_path: str) -> None: |
| 51 | + self.socket_path = socket_path |
| 52 | + self.client_ct = 0 |
| 53 | + self.clients: dict[int, Client] = {} |
| 54 | + self.sock: Optional[socket.socket] = None # for synchronous clients |
| 55 | + self.incoming_queue: asyncio.Queue = asyncio.Queue() |
| 56 | + |
| 57 | + def __str__(self): |
| 58 | + return f'socket-producer-{self.socket_path}' |
| 59 | + |
| 60 | + async def _add_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: |
| 61 | + client = Client(self.client_ct, reader, writer) |
| 62 | + self.clients[self.client_ct] = client |
| 63 | + self.client_ct += 1 |
| 64 | + logger.info(f'Socket client_id={client.client_id} is connected') |
| 65 | + |
| 66 | + try: |
| 67 | + client.listen_loop_active = True |
| 68 | + while True: |
| 69 | + line = await client.reader.readline() |
| 70 | + if not line: |
| 71 | + break # disconnect |
| 72 | + message = line.decode().strip() |
| 73 | + await self.incoming_queue.put((client.client_id, message)) |
| 74 | + # Wait for caller to potentially fill a reply queue |
| 75 | + # this should realistically never take more than a trivial amount of time |
| 76 | + await asyncio.wait_for(client.yield_clear.wait(), timeout=2) |
| 77 | + client.yield_clear.clear() |
| 78 | + await client.send_replies() |
| 79 | + except asyncio.TimeoutError: |
| 80 | + logger.error(f'Unexpected asyncio task management bug for client_id={client.client_id}, exiting') |
| 81 | + except asyncio.CancelledError: |
| 82 | + logger.debug(f'Ack that reader task for client_id={client.client_id} has been canceled') |
| 83 | + except Exception: |
| 84 | + logger.exception(f'Exception from reader task for client_id={client.client_id}') |
| 85 | + finally: |
| 86 | + del self.clients[client.client_id] |
| 87 | + client.writer.close() |
| 88 | + await client.writer.wait_closed() |
| 89 | + logger.info(f'Socket client_id={client.client_id} is disconnected') |
| 90 | + |
| 91 | + async def aprocess_notify( |
| 92 | + self, connected_callback: Optional[Callable[[], Coroutine[Any, Any, None]]] = None |
| 93 | + ) -> AsyncGenerator[tuple[Union[int, str], str], None]: |
| 94 | + if os.path.exists(self.socket_path): |
| 95 | + logger.debug(f'Deleted pre-existing {self.socket_path}') |
| 96 | + os.remove(self.socket_path) |
| 97 | + |
| 98 | + aserver = None |
| 99 | + try: |
| 100 | + aserver = await asyncio.start_unix_server(self._add_client, self.socket_path) |
| 101 | + logger.info(f'Set up socket server on {self.socket_path}') |
| 102 | + |
| 103 | + if connected_callback: |
| 104 | + await connected_callback() |
| 105 | + |
| 106 | + while True: |
| 107 | + client_id, message = await self.incoming_queue.get() |
| 108 | + if (client_id == -1) and (message == 'stop'): |
| 109 | + return # internal exit signaling from aclose |
| 110 | + |
| 111 | + yield client_id, message |
| 112 | + # trigger reply messages if applicable |
| 113 | + client = self.clients.get(client_id) |
| 114 | + if client: |
| 115 | + logger.info(f'Yield complete for client_id={client_id}') |
| 116 | + client.yield_clear.set() |
| 117 | + |
| 118 | + except asyncio.CancelledError: |
| 119 | + logger.debug('Ack that general socket server task has been canceled') |
| 120 | + finally: |
| 121 | + if aserver: |
| 122 | + aserver.close() |
| 123 | + await aserver.wait_closed() |
| 124 | + |
| 125 | + for client in self.clients.values(): |
| 126 | + client.writer.close() |
| 127 | + await client.writer.wait_closed() |
| 128 | + self.clients = {} |
| 129 | + |
| 130 | + if os.path.exists(self.socket_path): |
| 131 | + os.remove(self.socket_path) |
| 132 | + |
| 133 | + async def aclose(self) -> None: |
| 134 | + """Send an internal message to the async generator, which will cause it to close the server""" |
| 135 | + await self.incoming_queue.put((-1, 'stop')) |
| 136 | + |
| 137 | + async def apublish_message(self, channel: Optional[str] = '', origin: Union[int, str, None] = None, message: str = "") -> None: |
| 138 | + if isinstance(origin, int) and origin >= 0: |
| 139 | + client = self.clients.get(int(origin)) |
| 140 | + if client: |
| 141 | + if client.listen_loop_active: |
| 142 | + logger.info(f'Queued message len={len(message)} for client_id={origin}') |
| 143 | + client.queue_reply(message) |
| 144 | + else: |
| 145 | + logger.warning(f'Not currently listening to client_id={origin}, attempting reply len={len(message)}, but might be dropped') |
| 146 | + client.write(message) |
| 147 | + await client.writer.drain() |
| 148 | + else: |
| 149 | + logger.error(f'Client_id={origin} is not currently connected') |
| 150 | + else: |
| 151 | + # Acting as a client in this case, mostly for tests |
| 152 | + logger.info(f'Publishing async socket message len={len(message)} with new connection') |
| 153 | + writer = None |
| 154 | + try: |
| 155 | + _, writer = await asyncio.open_unix_connection(self.socket_path) |
| 156 | + writer.write((message + '\n').encode()) |
| 157 | + await writer.drain() |
| 158 | + finally: |
| 159 | + if writer: |
| 160 | + writer.close() |
| 161 | + await writer.wait_closed() |
| 162 | + |
| 163 | + def process_notify( |
| 164 | + self, connected_callback: Optional[Callable] = None, timeout: float = 5.0, max_messages: int = 1 |
| 165 | + ) -> Iterator[tuple[Union[int, str], str]]: |
| 166 | + try: |
| 167 | + with socket.socket(socket.AF_UNIX) as sock: |
| 168 | + self.sock = sock |
| 169 | + sock.settimeout(timeout) |
| 170 | + sock.connect(self.socket_path) |
| 171 | + |
| 172 | + if connected_callback: |
| 173 | + connected_callback() |
| 174 | + |
| 175 | + received_ct = 0 |
| 176 | + buffer = '' |
| 177 | + while True: |
| 178 | + response = sock.recv(1024).decode().strip() |
| 179 | + |
| 180 | + if response.endswith('}'): |
| 181 | + response = buffer + response |
| 182 | + buffer = '' |
| 183 | + received_ct += 1 |
| 184 | + yield (0, response) |
| 185 | + if received_ct >= max_messages: |
| 186 | + return |
| 187 | + else: |
| 188 | + logger.info(f'Received incomplete message len={len(response)}, adding to buffer') |
| 189 | + buffer += response |
| 190 | + finally: |
| 191 | + self.sock = None |
| 192 | + |
| 193 | + def _publish_from_sock(self, sock, message) -> None: |
| 194 | + sock.sendall((message + "\n").encode()) |
| 195 | + |
| 196 | + def publish_message(self, channel=None, message=None) -> None: |
| 197 | + if self.sock: |
| 198 | + logger.info(f'Publishing socket message len={len(message)} via existing connection') |
| 199 | + self._publish_from_sock(self.sock, message) |
| 200 | + else: |
| 201 | + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: |
| 202 | + sock.connect(self.socket_path) |
| 203 | + logger.info(f'Publishing socket message len={len(message)} over new connection') |
| 204 | + self._publish_from_sock(sock, message) |
0 commit comments