|
1 | 1 | import logging
|
| 2 | +from typing import Any, Iterable, Optional |
2 | 3 |
|
3 | 4 | import psycopg
|
4 | 5 |
|
| 6 | +from dispatcher.brokers.base import BaseBroker |
| 7 | +from dispatcher.utils import resolve_callable |
| 8 | + |
5 | 9 | logger = logging.getLogger(__name__)
|
6 | 10 |
|
7 | 11 |
|
|
13 | 17 | """
|
14 | 18 |
|
15 | 19 |
|
16 |
| -# TODO: get database data from settings |
17 |
| -# # As Django settings, may not use |
18 |
| -# DATABASES = { |
19 |
| -# "default": { |
20 |
| -# "ENGINE": "django.db.backends.postgresql", |
21 |
| -# "HOST": os.getenv("DB_HOST", "127.0.0.1"), |
22 |
| -# "PORT": os.getenv("DB_PORT", 55777), |
23 |
| -# "USER": os.getenv("DB_USER", "dispatch"), |
24 |
| -# "PASSWORD": os.getenv("DB_PASSWORD", "dispatching"), |
25 |
| -# "NAME": os.getenv("DB_NAME", "dispatch_db"), |
26 |
| -# } |
27 |
| -# } |
28 |
| - |
29 |
| - |
30 |
| -async def aget_connection(config): |
31 |
| - return await psycopg.AsyncConnection.connect(**config, autocommit=True) |
32 |
| - |
33 |
| - |
34 |
| -def get_connection(config): |
35 |
| - return psycopg.Connection.connect(**config, autocommit=True) |
36 |
| - |
37 |
| - |
38 |
| -async def aprocess_notify(connection, channels, connected_callback=None): |
39 |
| - async with connection.cursor() as cur: |
40 |
| - for channel in channels: |
41 |
| - await cur.execute(f"LISTEN {channel};") |
42 |
| - logger.info(f"Set up pg_notify listening on channel '{channel}'") |
43 |
| - |
44 |
| - if connected_callback: |
45 |
| - await connected_callback() |
46 |
| - |
47 |
| - while True: |
48 |
| - logger.debug('Starting listening for pg_notify notifications') |
49 |
| - async for notify in connection.notifies(): |
50 |
| - yield notify.channel, notify.payload |
51 |
| - |
52 |
| - |
53 |
| -async def apublish_message(connection, channel, payload=None): |
54 |
| - async with connection.cursor() as cur: |
55 |
| - if not payload: |
56 |
| - await cur.execute(f'NOTIFY {channel};') |
| 20 | +class PGNotifyBase(BaseBroker): |
| 21 | + |
| 22 | + def __init__( |
| 23 | + self, |
| 24 | + channels: Iterable[str] = ('dispatcher_default',), |
| 25 | + default_publish_channel: str = 'dispatcher_default', |
| 26 | + ) -> None: |
| 27 | + self.channels = channels |
| 28 | + self.default_publish_channel = default_publish_channel |
| 29 | + |
| 30 | + |
| 31 | +class AsyncBroker(PGNotifyBase): |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + config: Optional[dict] = None, |
| 35 | + async_connection_factory: Optional[str] = None, |
| 36 | + sync_connection_factory: Optional[str] = None, # noqa |
| 37 | + connection: Optional[psycopg.AsyncConnection] = None, |
| 38 | + **kwargs, |
| 39 | + ) -> None: |
| 40 | + if not (config or async_connection_factory or connection): |
| 41 | + raise RuntimeError('Must specify either config or async_connection_factory') |
| 42 | + |
| 43 | + if config: |
| 44 | + self._config: Optional[dict] = config.copy() |
| 45 | + self._config['autocommit'] = True |
| 46 | + else: |
| 47 | + self._config = None |
| 48 | + |
| 49 | + self._async_connection_factory = async_connection_factory |
| 50 | + self._connection: Optional[Any] = connection |
| 51 | + |
| 52 | + super().__init__(**kwargs) |
| 53 | + |
| 54 | + async def get_connection(self) -> psycopg.AsyncConnection: |
| 55 | + if not self._connection: |
| 56 | + if self._async_connection_factory: |
| 57 | + factory = resolve_callable(self._async_connection_factory) |
| 58 | + if not factory: |
| 59 | + raise RuntimeError(f'Could not import connection factory {self._async_connection_factory}') |
| 60 | + if self._config: |
| 61 | + self._connection = await factory(**self._config) |
| 62 | + else: |
| 63 | + self._connection = await factory() |
| 64 | + elif self._config: |
| 65 | + self._connection = await AsyncBroker.create_connection(self._config) |
| 66 | + else: |
| 67 | + raise RuntimeError('Could not construct async connection for lack of config or factory') |
| 68 | + return self._connection |
| 69 | + |
| 70 | + @staticmethod |
| 71 | + async def create_connection(config) -> psycopg.AsyncConnection: |
| 72 | + return await psycopg.AsyncConnection.connect(**config) |
| 73 | + |
| 74 | + async def aprocess_notify(self, connected_callback=None): |
| 75 | + connection = await self.get_connection() |
| 76 | + async with connection.cursor() as cur: |
| 77 | + for channel in self.channels: |
| 78 | + await cur.execute(f"LISTEN {channel};") |
| 79 | + logger.info(f"Set up pg_notify listening on channel '{channel}'") |
| 80 | + |
| 81 | + if connected_callback: |
| 82 | + await connected_callback() |
| 83 | + |
| 84 | + while True: |
| 85 | + logger.debug('Starting listening for pg_notify notifications') |
| 86 | + async for notify in connection.notifies(): |
| 87 | + yield notify.channel, notify.payload |
| 88 | + |
| 89 | + async def apublish_message(self, channel: Optional[str] = None, payload=None) -> None: |
| 90 | + connection = await self.get_connection() |
| 91 | + if not channel: |
| 92 | + channel = self.default_publish_channel |
| 93 | + async with connection.cursor() as cur: |
| 94 | + if not payload: |
| 95 | + await cur.execute(f'NOTIFY {channel};') |
| 96 | + else: |
| 97 | + await cur.execute(f"NOTIFY {channel}, '{payload}';") |
| 98 | + |
| 99 | + async def aclose(self) -> None: |
| 100 | + if self._connection: |
| 101 | + await self._connection.close() |
| 102 | + self._connection = None |
| 103 | + |
| 104 | + |
| 105 | +connection_save = object() |
| 106 | + |
| 107 | + |
| 108 | +def connection_saver(**config): |
| 109 | + """ |
| 110 | + This mimics the behavior of Django for tests and demos |
| 111 | + Philosophically, this is used by an application that uses an ORM, |
| 112 | + or otherwise has its own connection management logic. |
| 113 | + Dispatcher does not manage connections, so this a simulation of that. |
| 114 | + """ |
| 115 | + if not hasattr(connection_save, '_connection'): |
| 116 | + config['autocommit'] = True |
| 117 | + connection_save._connection = SyncBroker.connect(**config) |
| 118 | + return connection_save._connection |
| 119 | + |
| 120 | + |
| 121 | +class SyncBroker(PGNotifyBase): |
| 122 | + def __init__( |
| 123 | + self, |
| 124 | + config: Optional[dict] = None, |
| 125 | + async_connection_factory: Optional[str] = None, # noqa |
| 126 | + sync_connection_factory: Optional[str] = None, |
| 127 | + connection: Optional[psycopg.Connection] = None, |
| 128 | + **kwargs, |
| 129 | + ) -> None: |
| 130 | + if not (config or sync_connection_factory or connection): |
| 131 | + raise RuntimeError('Must specify either config or async_connection_factory') |
| 132 | + |
| 133 | + if config: |
| 134 | + self._config: Optional[dict] = config.copy() |
| 135 | + self._config['autocommit'] = True |
57 | 136 | else:
|
58 |
| - await cur.execute(f"NOTIFY {channel}, '{payload}';") |
59 |
| - |
60 |
| - |
61 |
| -def get_django_connection(): |
62 |
| - try: |
63 |
| - from django.conf import ImproperlyConfigured |
64 |
| - from django.db import connection as pg_connection |
65 |
| - except ImportError: |
66 |
| - return None |
67 |
| - else: |
68 |
| - try: |
69 |
| - if pg_connection.connection is None: |
70 |
| - pg_connection.connect() |
71 |
| - if pg_connection.connection is None: |
72 |
| - raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions') |
73 |
| - return pg_connection.connection |
74 |
| - except ImproperlyConfigured: |
75 |
| - return None |
76 |
| - |
77 |
| - |
78 |
| -def publish_message(queue, message, config=None, connection=None, new_connection=False): |
79 |
| - conn = None |
80 |
| - if connection: |
81 |
| - conn = connection |
82 |
| - |
83 |
| - if (not conn) and (not new_connection): |
84 |
| - conn = get_django_connection() |
85 |
| - |
86 |
| - created_new_conn = False |
87 |
| - if not conn: |
88 |
| - if config is None: |
89 |
| - raise RuntimeError('Could not use Django connection, and no postgres config supplied') |
90 |
| - conn = get_connection(config) |
91 |
| - created_new_conn = True |
92 |
| - |
93 |
| - with conn.cursor() as cur: |
94 |
| - cur.execute('SELECT pg_notify(%s, %s);', (queue, message)) |
95 |
| - |
96 |
| - logger.debug(f'Sent pg_notify message to {queue}') |
97 |
| - |
98 |
| - if created_new_conn: |
99 |
| - conn.close() |
| 137 | + self._config = None |
| 138 | + |
| 139 | + self._sync_connection_factory = sync_connection_factory |
| 140 | + self._connection: Optional[Any] = connection |
| 141 | + super().__init__(**kwargs) |
| 142 | + |
| 143 | + def get_connection(self) -> psycopg.Connection: |
| 144 | + if not self._connection: |
| 145 | + if self._sync_connection_factory: |
| 146 | + factory = resolve_callable(self._sync_connection_factory) |
| 147 | + if not factory: |
| 148 | + raise RuntimeError(f'Could not import connection factory {self._sync_connection_factory}') |
| 149 | + if self._config: |
| 150 | + self._connection = factory(**self._config) |
| 151 | + else: |
| 152 | + self._connection = factory() |
| 153 | + elif self._config: |
| 154 | + self._connection = SyncBroker.create_connection(self._config) |
| 155 | + else: |
| 156 | + raise RuntimeError('Cound not construct synchronous connection for lack of config or factory') |
| 157 | + return self._connection |
| 158 | + |
| 159 | + @staticmethod |
| 160 | + def create_connection(config) -> psycopg.Connection: |
| 161 | + return psycopg.Connection.connect(**config) |
| 162 | + |
| 163 | + def publish_message(self, channel: Optional[str], message: dict) -> None: |
| 164 | + connection = self.get_connection() |
| 165 | + if not channel: |
| 166 | + channel = self.default_publish_channel |
| 167 | + |
| 168 | + with connection.cursor() as cur: |
| 169 | + cur.execute('SELECT pg_notify(%s, %s);', (channel, message)) |
| 170 | + |
| 171 | + logger.debug(f'Sent pg_notify message to {channel}') |
| 172 | + |
| 173 | + def close(self) -> None: |
| 174 | + if self._connection: |
| 175 | + self._connection.close() |
| 176 | + self._connection = None |
0 commit comments