From 2edf0e437ea9aad5f65e28527f83d6c3e54312dd Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Fri, 3 Jan 2025 14:55:32 -0500 Subject: [PATCH 01/19] Use full class names in fixtures for simplicity First pass at global config setup Finish running and just starting on tests Add a config test Make half-way progress through demo script Cut some more stuff out of the config Fix failing unit test, handle queue can not be found Review comment to consolidate factory handling Factories refactor Adopt new patterns up to some tests passing Unfinished start on settings serialization --- README.md | 12 +- dispatcher.yml | 31 +-- dispatcher/__init__.py | 21 ++ dispatcher/brokers/__init__.py | 0 dispatcher/brokers/base.py | 25 ++ dispatcher/brokers/pg_notify.py | 273 ++++++++++++++------- dispatcher/cli.py | 20 +- dispatcher/config.py | 73 ++++++ dispatcher/control.py | 23 +- dispatcher/factories.py | 98 ++++++++ dispatcher/main.py | 41 ++-- dispatcher/pool.py | 12 +- dispatcher/producers/__init__.py | 5 + dispatcher/producers/base.py | 9 +- dispatcher/producers/brokered.py | 39 ++- dispatcher/producers/scheduled.py | 8 +- dispatcher/registry.py | 17 +- dispatcher/tasks.py | 8 + dispatcher/worker/task.py | 6 +- tests/conftest.py | 104 +++++--- tests/integration/publish/__init__.py | 0 tests/integration/publish/test_registry.py | 28 +++ tests/integration/test_main.py | 32 ++- tests/unit/conftest.py | 9 - tests/unit/test_config.py | 32 ++- tests/unit/test_publish.py | 19 -- tools/write_messages.py | 46 ++-- 27 files changed, 690 insertions(+), 301 deletions(-) create mode 100644 dispatcher/brokers/__init__.py create mode 100644 dispatcher/brokers/base.py create mode 100644 dispatcher/config.py create mode 100644 dispatcher/factories.py create mode 100644 dispatcher/producers/__init__.py create mode 100644 dispatcher/tasks.py create mode 100644 tests/integration/publish/__init__.py create mode 100644 tests/integration/publish/test_registry.py diff --git a/README.md b/README.md index 045177a..a6db44c 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,8 @@ There are 2 ways to run the dispatcher service: - A CLI entrypoint `dispatcher-standalone` for demo purposes ```python -from dispatcher.main import DispatcherMain -import asyncio +from dispatcher.config import setup +from dispatcher import run_service config = { "producers": { @@ -63,13 +63,9 @@ config = { }, "pool": {"max_workers": 4}, } -loop = asyncio.get_event_loop() -dispatcher = DispatcherMain(config) +setup(config) -try: - loop.run_until_complete(dispatcher.main()) -finally: - loop.close() +run_service() ``` Configuration tells how to connect to postgres, and what channel(s) to listen to. 
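For reference, the publisher side can reuse the same settings object. Below is a minimal sketch of submitting a task against this configuration; the `print_hello` task and its body are illustrative only (not part of this patch), while `setup()`, `@task()`, and `apply_async(queue=...)` follow the usage introduced later in this series:

```python
from dispatcher.config import setup
from dispatcher.publish import task

setup(config)  # the same config dict used for the service above


@task()
def print_hello():
    # illustrative task body
    print('hello world')


# Publishes a pg_notify message that the running dispatcher service picks up
print_hello.apply_async(queue='test_channel')
```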
diff --git a/dispatcher.yml b/dispatcher.yml index a34eab1..f00b9ca 100644 --- a/dispatcher.yml +++ b/dispatcher.yml @@ -1,20 +1,23 @@ # Demo config --- -pool: - max_workers: 3 -producers: - brokers: - # List of channels to listen on +service: + max_workers: 4 +brokers: + pg_notify: + config: + conninfo: dbname=dispatch_db user=dispatch password=dispatching host=localhost port=55777 + sync_connection_factory: dispatcher.brokers.pg_notify.connection_saver channels: - test_channel - test_channel2 - test_channel3 - pg_notify: - # Database connection details - conninfo: dbname=dispatch_db user=dispatch password=dispatching host=localhost - port=55777 - scheduled: - 'lambda: __import__("time").sleep(1)': - schedule: 3 - 'lambda: __import__("time").sleep(2)': - schedule: 3 + default_publish_channel: test_channel +producers: + ScheduledProducer: + task_schedule: + 'lambda: __import__("time").sleep(1)': + schedule: 3 + 'lambda: __import__("time").sleep(2)': + schedule: 3 +publish: + default_broker: pg_notify diff --git a/dispatcher/__init__.py b/dispatcher/__init__.py index e69de29..ae79987 100644 --- a/dispatcher/__init__.py +++ b/dispatcher/__init__.py @@ -0,0 +1,21 @@ +import asyncio +import logging + +from dispatcher.factories import from_settings + +logger = logging.getLogger(__name__) + + +def run_service() -> None: + """ + Runs dispatcher task service (runs tasks due to messages from brokers and other local producers) + Before calling this you need to configure by calling dispatcher.config.setup + """ + loop = asyncio.get_event_loop() + dispatcher = from_settings() + try: + loop.run_until_complete(dispatcher.main()) + except KeyboardInterrupt: + logger.info('Dispatcher stopped by KeyboardInterrupt') + finally: + loop.close() diff --git a/dispatcher/brokers/__init__.py b/dispatcher/brokers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/dispatcher/brokers/base.py b/dispatcher/brokers/base.py new file mode 100644 index 0000000..241fb47 --- /dev/null +++ b/dispatcher/brokers/base.py @@ -0,0 +1,25 @@ +from abc import abstractmethod +from typing import Optional + + +class BaseBroker: + @abstractmethod + async def connect(self): ... + + @abstractmethod + async def aprocess_notify(self, connected_callback=None): ... + + @abstractmethod + async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: ... + + @abstractmethod + async def aclose(self) -> None: ... + + @abstractmethod + def get_connection(self): ... + + @abstractmethod + def publish_message(self, channel=None, message=None): ... + + @abstractmethod + def close(self): ... 
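The abstract methods above spell out the contract a broker implementation must satisfy (async listen/publish/close plus sync publish/close). As a rough illustration only, assuming nothing beyond that interface, a hypothetical loopback broker could look like the sketch below; `InMemoryBroker` is not part of this patch:

```python
import asyncio
from typing import Optional

from dispatcher.brokers.base import BaseBroker


class InMemoryBroker(BaseBroker):
    """Toy broker that loops published messages back to the listener."""

    def __init__(self, channels=('demo_channel',)) -> None:
        self.channels = channels
        self._queue: asyncio.Queue = asyncio.Queue()

    async def connect(self):
        return self  # nothing external to connect to

    async def aprocess_notify(self, connected_callback=None):
        if connected_callback:
            await connected_callback()
        while True:
            channel, message = await self._queue.get()
            yield channel, message

    async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None:
        await self._queue.put((channel or self.channels[0], message))

    async def aclose(self) -> None:
        pass

    def get_connection(self):
        return None  # no underlying connection object

    def publish_message(self, channel=None, message=None):
        self._queue.put_nowait((channel or self.channels[0], message))

    def close(self):
        pass
```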
diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index bb36032..e59f5d7 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -1,7 +1,11 @@ import logging +from typing import Callable, Iterable, Optional import psycopg +from dispatcher.brokers.base import BaseBroker +from dispatcher.utils import resolve_callable + logger = logging.getLogger(__name__) @@ -13,87 +17,190 @@ """ -# TODO: get database data from settings -# # As Django settings, may not use -# DATABASES = { -# "default": { -# "ENGINE": "django.db.backends.postgresql", -# "HOST": os.getenv("DB_HOST", "127.0.0.1"), -# "PORT": os.getenv("DB_PORT", 55777), -# "USER": os.getenv("DB_USER", "dispatch"), -# "PASSWORD": os.getenv("DB_PASSWORD", "dispatching"), -# "NAME": os.getenv("DB_NAME", "dispatch_db"), -# } -# } - - -async def aget_connection(config): - return await psycopg.AsyncConnection.connect(**config, autocommit=True) - - -def get_connection(config): - return psycopg.Connection.connect(**config, autocommit=True) - - -async def aprocess_notify(connection, channels, connected_callback=None): - async with connection.cursor() as cur: - for channel in channels: - await cur.execute(f"LISTEN {channel};") - logger.info(f"Set up pg_notify listening on channel '{channel}'") - - if connected_callback: - await connected_callback() - - while True: - logger.debug('Starting listening for pg_notify notifications') - async for notify in connection.notifies(): - yield notify.channel, notify.payload - - -async def apublish_message(connection, channel, payload=None): - async with connection.cursor() as cur: - if not payload: - await cur.execute(f'NOTIFY {channel};') +class PGNotifyBase(BaseBroker): + + def __init__( + self, + config: Optional[dict] = None, + channels: Iterable[str] = ('dispatcher_default',), + default_publish_channel: Optional[str] = None, + ) -> None: + """ + channels - listening channels for the service and used for control-and-reply + default_publish_channel - if not specified on task level or in the submission + by default messages will be sent to this channel. + this should be one of the listening channels for messages to be received. 
+ """ + if config: + self._config: dict = config.copy() + self._config['autocommit'] = True + else: + self._config = {} + + self.channels = channels + self.default_publish_channel = default_publish_channel + + def get_publish_channel(self, channel: Optional[str] = None): + "Handle default for the publishing channel for calls to publish_message, shared sync and async" + if channel is not None: + return channel + if self.default_publish_channel is None: + raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config') + return self.default_publish_channel + + def get_connection_method(self, factory_path: Optional[str] = None) -> Callable: + "Handles settings, returns a method (async or sync) for getting a new connection" + if factory_path: + factory = resolve_callable(factory_path) + if not factory: + raise RuntimeError(f'Could not import connection factory {factory_path}') + return factory + elif self._config: + return self.create_connection else: - await cur.execute(f"NOTIFY {channel}, '{payload}';") - - -def get_django_connection(): - try: - from django.conf import ImproperlyConfigured - from django.db import connection as pg_connection - except ImportError: - return None - else: - try: - if pg_connection.connection is None: - pg_connection.connect() - if pg_connection.connection is None: - raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions') - return pg_connection.connection - except ImproperlyConfigured: - return None - - -def publish_message(queue, message, config=None, connection=None, new_connection=False): - conn = None - if connection: - conn = connection - - if (not conn) and (not new_connection): - conn = get_django_connection() - - created_new_conn = False - if not conn: - if config is None: - raise RuntimeError('Could not use Django connection, and no postgres config supplied') - conn = get_connection(config) - created_new_conn = True - - with conn.cursor() as cur: - cur.execute('SELECT pg_notify(%s, %s);', (queue, message)) - - logger.debug(f'Sent pg_notify message to {queue}') - - if created_new_conn: - conn.close() + raise RuntimeError('Could not construct connection for lack of config or factory') + + def create_connection(self): ... 
+ + +class AsyncBroker(PGNotifyBase): + def __init__( + self, + config: Optional[dict] = None, + async_connection_factory: Optional[str] = None, + sync_connection_factory: Optional[str] = None, # noqa + connection: Optional[psycopg.AsyncConnection] = None, + **kwargs, + ) -> None: + if not (config or async_connection_factory or connection): + raise RuntimeError('Must specify either config or async_connection_factory') + + self._async_connection_factory = async_connection_factory + self._connection = connection + + super().__init__(config=config, **kwargs) + + async def get_connection(self) -> psycopg.AsyncConnection: + if not self._connection: + factory = self.get_connection_method(factory_path=self._async_connection_factory) + connection = await factory(**self._config) + self._connection = connection + return connection # slightly weird due to MyPY + return self._connection + + @staticmethod + async def create_connection(**config) -> psycopg.AsyncConnection: + return await psycopg.AsyncConnection.connect(**config) + + async def aprocess_notify(self, connected_callback=None): + connection = await self.get_connection() + async with connection.cursor() as cur: + for channel in self.channels: + await cur.execute(f"LISTEN {channel};") + logger.info(f"Set up pg_notify listening on channel '{channel}'") + + if connected_callback: + await connected_callback() + + while True: + logger.debug('Starting listening for pg_notify notifications') + async for notify in connection.notifies(): + yield notify.channel, notify.payload + + async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: + connection = await self.get_connection() + channel = self.get_publish_channel(channel) + + async with connection.cursor() as cur: + if not message: + await cur.execute(f'NOTIFY {channel};') + else: + await cur.execute(f"NOTIFY {channel}, '{message}';") + + logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}') + + async def aclose(self) -> None: + if self._connection: + await self._connection.close() + self._connection = None + + +class SyncBroker(PGNotifyBase): + def __init__( + self, + config: Optional[dict] = None, + async_connection_factory: Optional[str] = None, # noqa + sync_connection_factory: Optional[str] = None, + connection: Optional[psycopg.Connection] = None, + **kwargs, + ) -> None: + if not (config or sync_connection_factory or connection): + raise RuntimeError('Must specify either config or async_connection_factory') + + self._sync_connection_factory = sync_connection_factory + self._connection = connection + super().__init__(config=config, **kwargs) + + def get_connection(self) -> psycopg.Connection: + if not self._connection: + factory = self.get_connection_method(factory_path=self._sync_connection_factory) + connection = factory(**self._config) + self._connection = connection + return connection + return self._connection + + @staticmethod + def create_connection(**config) -> psycopg.Connection: + return psycopg.Connection.connect(**config) + + def publish_message(self, channel: Optional[str] = None, message: str = '') -> None: + connection = self.get_connection() + channel = self.get_publish_channel(channel) + + with connection.cursor() as cur: + if message: + cur.execute('SELECT pg_notify(%s, %s);', (channel, message)) + else: + cur.execute(f'NOTIFY {channel};') + + logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}') + + def close(self) -> None: + if self._connection: + self._connection.close() + self._connection = 
None + + +class ConnectionSaver: + def __init__(self) -> None: + self._connection: Optional[psycopg.Connection] = None + self._async_connection: Optional[psycopg.AsyncConnection] = None + + +connection_save = ConnectionSaver() + + +def connection_saver(**config) -> psycopg.Connection: + """ + This mimics the behavior of Django for tests and demos + Philosophically, this is used by an application that uses an ORM, + or otherwise has its own connection management logic. + Dispatcher does not manage connections, so this a simulation of that. + """ + if connection_save._connection is None: + config['autocommit'] = True + connection_save._connection = SyncBroker.create_connection(**config) + return connection_save._connection + + +async def async_connection_saver(**config) -> psycopg.AsyncConnection: + """ + This mimics the behavior of Django for tests and demos + Philosophically, this is used by an application that uses an ORM, + or otherwise has its own connection management logic. + Dispatcher does not manage connections, so this a simulation of that. + """ + if connection_save._async_connection is None: + config['autocommit'] = True + connection_save._async_connection = await AsyncBroker.create_connection(**config) + return connection_save._async_connection diff --git a/dispatcher/cli.py b/dispatcher/cli.py index 3b77dca..0b9276a 100644 --- a/dispatcher/cli.py +++ b/dispatcher/cli.py @@ -1,12 +1,10 @@ import argparse -import asyncio import logging import os import sys -import yaml - -from dispatcher.main import DispatcherMain +from dispatcher import run_service +from dispatcher.config import setup logger = logging.getLogger(__name__) @@ -32,16 +30,6 @@ def standalone() -> None: logger.debug(f"Configured standard out logging at {args.log_level} level") - with open(args.config, 'r') as f: - config_content = f.read() - - config = yaml.safe_load(config_content) + setup(file_path=args.config) - loop = asyncio.get_event_loop() - dispatcher = DispatcherMain(config) - try: - loop.run_until_complete(dispatcher.main()) - except KeyboardInterrupt: - logger.info('CLI entry point leaving') - finally: - loop.close() + run_service() diff --git a/dispatcher/config.py b/dispatcher/config.py new file mode 100644 index 0000000..1c76a6b --- /dev/null +++ b/dispatcher/config.py @@ -0,0 +1,73 @@ +import os +from contextlib import contextmanager +from typing import Optional + +import yaml + + +class DispatcherSettings: + def __init__(self, config: dict) -> None: + self.brokers: dict = config.get('brokers', {}) + self.producers: dict = config.get('producers', {}) + self.service: dict = config.get('service', {'max_workers': 3}) + self.publish: dict = config.get('publish', {}) + # TODO: firmly planned sections of config for later + # self.callbacks: dict = config.get('callbacks', {}) + # self.options: dict = config.get('options', {}) + + def serialize(self): + return dict( + brokers=self.brokers, + producers=self.producers, + service=self.service, + publish=self.publish + ) + + +def settings_from_file(path: str) -> DispatcherSettings: + with open(path, 'r') as f: + config_content = f.read() + + config = yaml.safe_load(config_content) + return DispatcherSettings(config) + + +def settings_from_env() -> DispatcherSettings: + if file_path := os.getenv('DISPATCHER_CONFIG_FILE'): + return settings_from_file(file_path) + raise RuntimeError('Dispatcher not configured, set DISPATCHER_CONFIG_FILE or call dispatcher.config.setup') + + +class LazySettings: + def __init__(self) -> None: + self._wrapped: 
Optional[DispatcherSettings] = None + + def __getattr__(self, name): + if self._wrapped is None: + self._setup() + return getattr(self._wrapped, name) + + def _setup(self) -> None: + self._wrapped = settings_from_env() + + +settings = LazySettings() + + +def setup(config: Optional[dict] = None, file_path: Optional[str] = None): + if config: + settings._wrapped = DispatcherSettings(config) + elif file_path: + settings._wrapped = settings_from_file(file_path) + else: + settings._wrapped = settings_from_env() + + +@contextmanager +def temporary_settings(config): + prior_settings = settings._wrapped + try: + settings._wrapped = DispatcherSettings(config) + yield settings + finally: + settings._wrapped = prior_settings diff --git a/dispatcher/control.py b/dispatcher/control.py index 5e52260..424d85f 100644 --- a/dispatcher/control.py +++ b/dispatcher/control.py @@ -5,7 +5,8 @@ import uuid from types import SimpleNamespace -from dispatcher.producers.brokered import BrokeredProducer +from dispatcher.factories import get_async_publisher_from_settings, get_sync_publisher_from_settings +from dispatcher.producers import BrokeredProducer logger = logging.getLogger('awx.main.dispatch.control') @@ -28,7 +29,7 @@ def __init__(self, queuename, send_data, expected_replies): def _create_events(self): return SimpleNamespace(exit_event=asyncio.Event()) - async def process_message(self, payload, broker=None, channel=None): + async def process_message(self, payload, producer=None, channel=None): self.received_replies.append(payload) if self.expected_replies and (len(self.received_replies) >= self.expected_replies): self.events.exit_event.set() @@ -55,10 +56,9 @@ def fatal_error_callback(self, *args): class Control(object): - def __init__(self, queue, config=None, async_connection=None): + def __init__(self, queue, config=None): self.queuename = queue self.config = config - self.async_connection = async_connection def running(self, *args, **kwargs): return self.control_with_reply('running', *args, **kwargs) @@ -90,11 +90,8 @@ async def acontrol_with_reply_internal(self, producer, send_data, expected_repli return [json.loads(payload) for payload in control_callbacks.received_replies] def make_producer(self, reply_queue): - if self.async_connection: - conn_kwargs = {'connection': self.async_connection} - else: - conn_kwargs = {'config': self.config} - return BrokeredProducer(broker='pg_notify', channels=[reply_queue], **conn_kwargs) + broker = get_async_publisher_from_settings(channels=[reply_queue]) + return BrokeredProducer(broker, close_on_exit=True) async def acontrol_with_reply(self, command, expected_replies=1, timeout=1, data=None): reply_queue = Control.generate_reply_queue_name() @@ -118,9 +115,6 @@ def control_with_reply(self, command, expected_replies=1, timeout=1, data=None): start = time.time() reply_queue = Control.generate_reply_queue_name() - if (not self.config) and (not self.async_connection): - raise RuntimeError('Must use a new psycopg connection to do control-and-reply') - send_data = {'control': command, 'reply_to': reply_queue} if data: send_data['control_data'] = data @@ -139,11 +133,10 @@ def control_with_reply(self, command, expected_replies=1, timeout=1, data=None): # NOTE: this is the synchronous version, only to be used for no-reply def control(self, command, data=None): - from dispatcher.brokers.pg_notify import publish_message - send_data = {'control': command} if data: send_data['control_data'] = data payload = json.dumps(send_data) - publish_message(self.queuename, payload, 
config=self.config) + broker = get_sync_publisher_from_settings() + broker.publish_message(channel=self.queuename, message=payload) diff --git a/dispatcher/factories.py b/dispatcher/factories.py new file mode 100644 index 0000000..152567f --- /dev/null +++ b/dispatcher/factories.py @@ -0,0 +1,98 @@ +import importlib +from types import ModuleType +from typing import Iterable, Optional + +from dispatcher import producers +from dispatcher.brokers.base import BaseBroker +from dispatcher.config import LazySettings +from dispatcher.config import settings as global_settings +from dispatcher.main import DispatcherMain + +""" +Creates objects from settings, +This is kept separate from the settings and the class definitions themselves, +which is to avoid import dependencies. +""" + +# ---- Service objects ---- + + +def get_broker_module(broker_name) -> ModuleType: + "Static method to alias import_module so we use a consistent import path" + return importlib.import_module(f'dispatcher.brokers.{broker_name}') + + +def get_async_broker(broker_name: str, broker_config: dict, **overrides) -> BaseBroker: + """ + Given the name of the broker in the settings, and the data under that entry in settings, + return the asyncio broker object. + """ + broker_module = get_broker_module(broker_name) + kwargs = broker_config.copy() + kwargs.update(overrides) + return broker_module.AsyncBroker(**kwargs) + + +def producers_from_settings(settings: LazySettings = global_settings) -> Iterable[producers.BaseProducer]: + producer_objects = [] + for broker_name, broker_kwargs in settings.brokers.items(): + broker = get_async_broker(broker_name, broker_kwargs) + producer = producers.BrokeredProducer(broker=broker) + producer_objects.append(producer) + + for producer_cls, producer_kwargs in settings.producers.items(): + producer_objects.append(getattr(producers, producer_cls)(**producer_kwargs)) + + return producer_objects + + +def from_settings(settings: LazySettings = global_settings) -> DispatcherMain: + """ + Returns the main dispatcher object, used for running the background task service. + You could initialize this yourself, but using the shared settings allows for consistency + between the service, publisher, and any other interacting processes. + """ + producers = producers_from_settings(settings=settings) + return DispatcherMain(settings.service, producers) + + +# ---- Publisher objects ---- + + +def get_sync_broker(broker_name, broker_config) -> BaseBroker: + """ + Given the name of the broker in the settings, and the data under that entry in settings, + return the synchronous broker object. 
+ """ + broker_module = get_broker_module(broker_name) + return broker_module.SyncBroker(**broker_config) + + +def _get_publisher_broker_name(publish_broker: Optional[str] = None, settings: LazySettings = global_settings) -> str: + if publish_broker: + return publish_broker + elif len(settings.brokers) == 1: + return list(settings.brokers.keys())[0] + elif 'default_broker' in settings.publish: + return settings.publish['default_broker'] + else: + raise RuntimeError(f'Could not determine which broker to publish with between options {list(settings.brokers.keys())}') + + +def get_sync_publisher_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides) -> BaseBroker: + publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings) + + return get_sync_broker(publish_broker, settings.brokers[publish_broker], **overrides) + + +def get_async_publisher_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides) -> BaseBroker: + """ + An asynchronous publisher is the ideal choice for submitting control-and-reply actions. + This returns an asyncio broker of the default publisher type. + + If channels are specified, these completely replace the channel list from settings. + For control-and-reply, this will contain only the reply_to channel, to not receive + unrelated traffic. + """ + publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings) + return get_async_broker(publish_broker, settings.brokers[publish_broker], **overrides) diff --git a/dispatcher/main.py b/dispatcher/main.py index 419fd26..2327165 100644 --- a/dispatcher/main.py +++ b/dispatcher/main.py @@ -3,13 +3,10 @@ import logging import signal from types import SimpleNamespace -from typing import Optional, Union +from typing import Iterable, Optional from dispatcher.pool import WorkerPool -from dispatcher.producers.base import BaseProducer -from dispatcher.producers.brokered import BrokeredProducer -from dispatcher.producers.scheduled import ScheduledProducer -from dispatcher.utils import MODULE_METHOD_DELIMITER +from dispatcher.producers import BaseProducer logger = logging.getLogger(__name__) @@ -79,7 +76,7 @@ def __init__(self) -> None: class DispatcherMain: - def __init__(self, config: dict): + def __init__(self, service_config: dict, producers: Iterable[BaseProducer]): self.delayed_messages: list[SimpleNamespace] = [] self.received_count = 0 self.control_count = 0 @@ -88,24 +85,17 @@ def __init__(self, config: dict): # Lock for file descriptor mgmnt - hold lock when forking or connecting, to avoid DNS hangs # psycopg is well-behaved IFF you do not connect while forking, compare to AWX __clean_on_fork__ self.fd_lock = asyncio.Lock() - self.pool = WorkerPool(config.get('pool', {}).get('max_workers', 3), self.fd_lock) - - # Initialize all the producers, this should not start anything, just establishes objects - self.producers: list[Union[ScheduledProducer, BrokeredProducer]] = [] - if 'producers' in config: - producer_config = config['producers'] - if 'brokers' in producer_config: - for broker_name, broker_config in producer_config['brokers'].items(): - # TODO: import from the broker module here, some importlib stuff - # TODO: make channels specific to broker, probably - if broker_name != 'pg_notify': - continue - self.producers.append(BrokeredProducer(broker=broker_name, config=broker_config, channels=producer_config['brokers']['channels'])) - if 'scheduled' in producer_config: 
- self.producers.append(ScheduledProducer(producer_config['scheduled'])) + self.pool = WorkerPool(fd_lock=self.fd_lock, **service_config) + + # Set all the producers, this should still not start anything, just establishes objects + self.producers = producers self.events: DispatcherEvents = DispatcherEvents() + def _create_events(self): + "Benchmark tests have to re-create this because they use same object in different event loops" + return SimpleNamespace(exit_event=asyncio.Event()) + def fatal_error_callback(self, *args) -> None: """Method to connect to error callbacks of other tasks, will kick out of main loop""" if self.shutting_down: @@ -183,7 +173,7 @@ def create_delayed_task(self, message: dict) -> None: capsule.task = new_task self.delayed_messages.append(capsule) - async def process_message(self, payload: dict, broker: Optional[BrokeredProducer] = None, channel: Optional[str] = None) -> None: + async def process_message(self, payload: dict, producer: Optional[BaseProducer] = None, channel: Optional[str] = None) -> None: # Convert payload from client into python dict # TODO: more structured validation of the incoming payload from publishers if isinstance(payload, str): @@ -208,9 +198,9 @@ async def process_message(self, payload: dict, broker: Optional[BrokeredProducer # NOTE: control messages with reply should never be delayed, document this for users self.create_delayed_task(message) else: - await self.process_message_internal(message, broker=broker) + await self.process_message_internal(message, producer=producer) - async def process_message_internal(self, message: dict, broker=None) -> None: + async def process_message_internal(self, message: dict, producer=None) -> None: if 'control' in message: method = getattr(self.ctl_tasks, message['control']) control_data = message.get('control_data', {}) @@ -220,9 +210,8 @@ async def process_message_internal(self, message: dict, broker=None) -> None: self.control_count += 1 await self.pool.dispatch_task( { - 'task': f'dispatcher.brokers.{broker.broker}{MODULE_METHOD_DELIMITER}publish_message', + 'task': 'dispatcher.tasks.reply_to_control', 'args': [message['reply_to'], json.dumps(returned)], - 'kwargs': {'config': broker.config, 'new_connection': True}, 'uuid': f'control-{self.control_count}', 'control': 'reply', # for record keeping } diff --git a/dispatcher/pool.py b/dispatcher/pool.py index 0b25909..6506d50 100644 --- a/dispatcher/pool.py +++ b/dispatcher/pool.py @@ -6,17 +6,19 @@ from dispatcher.process import ProcessManager, ProcessProxy from dispatcher.utils import DuplicateBehavior, MessageAction +from dispatcher.config import settings as global_settings, LazySettings logger = logging.getLogger(__name__) class PoolWorker: - def __init__(self, worker_id: int, process: ProcessProxy) -> None: + def __init__(self, worker_id: int, process: ProcessProxy, settings: LazySettings = global_settings) -> None: self.worker_id = worker_id self.process = process self.current_task: Optional[dict] = None self.started_at: Optional[int] = None self.is_active_cancel: bool = False + self.settings_stash: dict = settings.serialize() # Tracking information for worker self.finished_count = 0 @@ -94,8 +96,8 @@ def __init__(self) -> None: class WorkerPool: - def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None): - self.num_workers = num_workers + def __init__(self, max_workers: int, fd_lock: Optional[asyncio.Lock] = None): + self.max_workers = max_workers self.workers: dict[int, PoolWorker] = {} self.next_worker_id = 0 
self.process_manager = ProcessManager() @@ -132,7 +134,7 @@ async def start_working(self, dispatcher) -> None: async def manage_workers(self) -> None: """Enforces worker policy like min and max workers, and later, auto scale-down""" while not self.shutting_down: - while len(self.workers) < self.num_workers: + while len(self.workers) < self.max_workers: await self.up() # TODO: if all workers are busy, queue has unblocked work, below max_workers @@ -186,7 +188,7 @@ async def manage_timeout(self) -> None: self.events.timeout_event.clear() async def up(self) -> None: - process = self.process_manager.create_process((self.next_worker_id,)) + process = self.process_manager.create_process((self.settings_stash, self.next_worker_id,)) worker = PoolWorker(self.next_worker_id, process) self.workers[self.next_worker_id] = worker self.next_worker_id += 1 diff --git a/dispatcher/producers/__init__.py b/dispatcher/producers/__init__.py new file mode 100644 index 0000000..050e1a2 --- /dev/null +++ b/dispatcher/producers/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseProducer +from .brokered import BrokeredProducer +from .scheduled import ScheduledProducer + +__all__ = ['BaseProducer', 'BrokeredProducer', 'ScheduledProducer'] diff --git a/dispatcher/producers/base.py b/dispatcher/producers/base.py index c1da07d..6ebfd74 100644 --- a/dispatcher/producers/base.py +++ b/dispatcher/producers/base.py @@ -7,4 +7,11 @@ def __init__(self): class BaseProducer: - pass + + def __init__(self) -> None: + self.events = ProducerEvents() + self.produced_count = 0 + + async def start_producing(self, dispatcher) -> None: ... + + async def shutdown(self): ... diff --git a/dispatcher/producers/brokered.py b/dispatcher/producers/brokered.py index c7797e4..3b276f3 100644 --- a/dispatcher/producers/brokered.py +++ b/dispatcher/producers/brokered.py @@ -2,26 +2,21 @@ import logging from typing import Optional -from dispatcher.brokers.pg_notify import aget_connection, aprocess_notify, apublish_message -from dispatcher.producers.base import BaseProducer, ProducerEvents +from dispatcher.brokers.base import BaseBroker +from dispatcher.producers.base import BaseProducer logger = logging.getLogger(__name__) class BrokeredProducer(BaseProducer): - def __init__(self, broker: str = 'pg_notify', config: Optional[dict] = None, channels: tuple = (), connection=None) -> None: - self.events = ProducerEvents() + def __init__(self, broker: BaseBroker, close_on_exit: bool = True) -> None: self.production_task: Optional[asyncio.Task] = None self.broker = broker - self.config = config - self.channels = channels - self.connection = connection - self.old_connection = bool(connection) + self.close_on_exit = close_on_exit self.dispatcher = None + super().__init__() async def start_producing(self, dispatcher) -> None: - await self.connect() - self.production_task = asyncio.create_task(self.produce_forever(dispatcher), name=f'{self.broker}_production') # TODO: implement connection retry logic self.production_task.add_done_callback(dispatcher.fatal_error_callback) @@ -31,22 +26,20 @@ def all_tasks(self) -> list[asyncio.Task]: return [self.production_task] return [] - async def connect(self): - if self.connection is None: - self.connection = await aget_connection(self.config) - async def connected_callback(self) -> None: - self.events.ready_event.set() + if self.events: + self.events.ready_event.set() if self.dispatcher: await self.dispatcher.connected_callback(self) async def produce_forever(self, dispatcher) -> None: self.dispatcher = dispatcher - 
async for channel, payload in aprocess_notify(self.connection, self.channels, connected_callback=self.connected_callback): - await dispatcher.process_message(payload, broker=self, channel=channel) + async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback): + self.produced_count += 1 + await dispatcher.process_message(payload, producer=self, channel=channel) - async def notify(self, channel: str, payload: Optional[str] = None) -> None: - await apublish_message(self.connection, channel, payload=payload) + async def notify(self, channel: str, payload: str = '') -> None: + await self.broker.apublish_message(channel=channel, message=payload) async def shutdown(self) -> None: if self.production_task: @@ -60,8 +53,6 @@ async def shutdown(self) -> None: if not hasattr(self.production_task, '_dispatcher_tb_logged'): logger.exception(f'Broker {self.broker} shutdown saw an unexpected exception from production task') self.production_task = None - if not self.old_connection: - if self.connection: - logger.debug(f'Closing {self.broker} connection') - await self.connection.close() - self.connection = None + if self.close_on_exit: + logger.debug(f'Closing {self.broker} connection') + await self.broker.aclose() diff --git a/dispatcher/producers/scheduled.py b/dispatcher/producers/scheduled.py index 6e50512..298dc61 100644 --- a/dispatcher/producers/scheduled.py +++ b/dispatcher/producers/scheduled.py @@ -1,17 +1,16 @@ import asyncio import logging -from dispatcher.producers.base import BaseProducer, ProducerEvents +from dispatcher.producers.base import BaseProducer logger = logging.getLogger(__name__) class ScheduledProducer(BaseProducer): def __init__(self, task_schedule: dict): - self.events = ProducerEvents() self.task_schedule = task_schedule self.scheduled_tasks: list[asyncio.Task] = [] - self.produced_count = 0 + super().__init__() async def start_producing(self, dispatcher) -> None: for task_name, options in self.task_schedule.items(): @@ -19,7 +18,8 @@ async def start_producing(self, dispatcher) -> None: schedule_task = asyncio.create_task(self.run_schedule_forever(task_name, per_seconds, dispatcher)) self.scheduled_tasks.append(schedule_task) schedule_task.add_done_callback(dispatcher.fatal_error_callback) - self.events.ready_event.set() + if self.events: + self.events.ready_event.set() def all_tasks(self) -> list[asyncio.Task]: return self.scheduled_tasks diff --git a/dispatcher/registry.py b/dispatcher/registry.py index 23f602a..b92a2eb 100644 --- a/dispatcher/registry.py +++ b/dispatcher/registry.py @@ -7,6 +7,7 @@ from uuid import uuid4 from dispatcher.utils import MODULE_METHOD_DELIMITER, DispatcherCallable, resolve_callable +from dispatcher.config import settings as global_settings, DispatcherSettings logger = logging.getLogger(__name__) @@ -79,23 +80,21 @@ def get_async_body( return body - def apply_async(self, args=None, kwargs=None, queue=None, uuid=None, connection=None, config=None, **kw) -> Tuple[dict, str]: + def apply_async(self, args=None, kwargs=None, queue=None, uuid=None, settings: DispatcherSettings = global_settings, **kw) -> Tuple[dict, str]: queue = queue or self.submission_defaults.get('queue') - if not queue: - msg = f'{self.fn}: Queue value required and may not be None' - logger.error(msg) - raise ValueError(msg) if callable(queue): queue = queue() obj = self.get_async_body(args=args, kwargs=kwargs, uuid=uuid, **kw) - # TODO: before sending, consult an app-specific callback if configured - from dispatcher.brokers.pg_notify 
import publish_message + from dispatcher.factories import get_sync_publisher_from_settings - # NOTE: the kw will communicate things in the database connection data - publish_message(queue, json.dumps(obj), connection=connection, config=config) + broker = get_sync_publisher_from_settings(settings=settings) + + # TODO: exit if a setting is applied to disable publishing + + broker.publish_message(channel=queue, message=json.dumps(obj)) return (obj, queue) diff --git a/dispatcher/tasks.py b/dispatcher/tasks.py new file mode 100644 index 0000000..a16d75d --- /dev/null +++ b/dispatcher/tasks.py @@ -0,0 +1,8 @@ +from dispatcher.factories import get_sync_publisher_from_settings +from dispatcher.publish import task + + +@task() +def reply_to_control(reply_channel: str, message: str): + broker = get_sync_publisher_from_settings() + broker.publish_message(channel=reply_channel, message=message) diff --git a/dispatcher/worker/task.py b/dispatcher/worker/task.py index 420caf0..ebafc49 100644 --- a/dispatcher/worker/task.py +++ b/dispatcher/worker/task.py @@ -9,6 +9,7 @@ from queue import Empty as QueueEmpty from dispatcher.registry import registry +from dispatcher.config import setup logger = logging.getLogger(__name__) @@ -197,11 +198,14 @@ def get_shutdown_message(self): return {"worker": self.worker_id, "event": "shutdown"} -def work_loop(worker_id: int, queue: multiprocessing.Queue, finished_queue): +def work_loop(settings: dict, worker_id: int, queue: multiprocessing.Queue, finished_queue): """ Worker function that processes messages from the queue and sends confirmation to the finished_queue once done. """ + # Load settings passed from parent + # this assures that workers are all configured the same + setup(config=settings) worker = TaskWorker(worker_id) # TODO: add an app callback here to set connection name and things like that diff --git a/tests/conftest.py b/tests/conftest.py index 4accbfa..8ad89b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,7 +11,10 @@ from dispatcher.main import DispatcherMain from dispatcher.control import Control -from dispatcher.brokers.pg_notify import apublish_message, aget_connection, get_connection +from dispatcher.brokers.pg_notify import SyncBroker, AsyncBroker +from dispatcher.registry import DispatcherMethodRegistry +from dispatcher.config import temporary_settings, DispatcherSettings +from dispatcher.factories import from_settings # List of channels to listen on @@ -20,63 +23,102 @@ # Database connection details CONNECTION_STRING = "dbname=dispatch_db user=dispatch password=dispatching host=localhost port=55777" -BASIC_CONFIG = {"producers": {"brokers": {"pg_notify": {"conninfo": CONNECTION_STRING}, "channels": CHANNELS}}, "pool": {"max_workers": 3}} +BASIC_CONFIG = { + "brokers": { + "pg_notify": { + "channels": CHANNELS, + "config": {'conninfo': CONNECTION_STRING}, + "sync_connection_factory": "dispatcher.brokers.pg_notify.connection_saver", + # "async_connection_factory": "dispatcher.brokers.pg_notify.async_connection_saver", + } + }, + "pool": { + "max_workers": 3 + } +} + + +@contextlib.asynccontextmanager +async def aconnection_for_test(): + conn = None + try: + conn = await AsyncBroker.create_connection(conninfo=CONNECTION_STRING, autocommit=True) + + # Make sure database is running to avoid deadlocks which can come + # from using the loop provided by pytest asyncio + async with conn.cursor() as cursor: + await cursor.execute('SELECT 1') + await cursor.fetchall() + + yield conn + finally: + if conn: + await conn.close() + + 
+@pytest.fixture +def conn_config(): + return {'conninfo': CONNECTION_STRING} @pytest.fixture def pg_dispatcher() -> DispatcherMain: - return DispatcherMain(BASIC_CONFIG) + # We can not reuse the connection between tests + config = BASIC_CONFIG.copy() + config['brokers']['pg_notify'].pop('async_connection_factory') + return DispatcherMain(config) + + +@pytest.fixture +def test_settings(): + return DispatcherSettings(BASIC_CONFIG) + + +@pytest.fixture +def test_setup(): + with temporary_settings(BASIC_CONFIG): + yield @pytest_asyncio.fixture(loop_scope="function", scope="function") -async def apg_dispatcher(request) -> AsyncIterator[DispatcherMain]: +async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]: + dispatcher = None try: - dispatcher = DispatcherMain(BASIC_CONFIG) + dispatcher = from_settings(settings=test_settings) await dispatcher.connect_signals() await dispatcher.start_working() await dispatcher.wait_for_producers_ready() + assert dispatcher.pool.finished_count == 0 # sanity + yield dispatcher finally: - await dispatcher.shutdown() - await dispatcher.cancel_tasks() + if dispatcher: + await dispatcher.shutdown() + await dispatcher.cancel_tasks() @pytest_asyncio.fixture(loop_scope="function", scope="function") async def pg_message(psycopg_conn) -> Callable: async def _rf(message, channel='test_channel'): - await apublish_message(psycopg_conn, channel, message) + broker = AsyncBroker(connection=psycopg_conn) + await broker.apublish_message(channel=channel, message=message) return _rf -@pytest.fixture -def conn_config(): - return {'conninfo': CONNECTION_STRING} - - -@contextlib.asynccontextmanager -async def aconnection_for_test(): - conn = None - try: - conn = await aget_connection({'conninfo': CONNECTION_STRING}) - yield conn - finally: - if conn: - await conn.close() - - @pytest_asyncio.fixture(loop_scope="function", scope="function") -async def pg_control() -> AsyncIterator[Control]: - """This has to use a different connection from dispatcher itself - - because psycopg will pool async connections, meaning that submission - for the control task would be blocked by the listening query of the dispatcher itself""" - async with aconnection_for_test() as conn: - yield Control('test_channel', async_connection=conn) +async def pg_control(test_setup) -> AsyncIterator[Control]: + yield Control(queue='test_channel') @pytest_asyncio.fixture(loop_scope="function", scope="function") async def psycopg_conn(): async with aconnection_for_test() as conn: yield conn + + +@pytest.fixture +def registry() -> DispatcherMethodRegistry: + "Return a fresh registry, separate from the global one, for testing" + return DispatcherMethodRegistry() diff --git a/tests/integration/publish/__init__.py b/tests/integration/publish/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/publish/test_registry.py b/tests/integration/publish/test_registry.py new file mode 100644 index 0000000..412b8e8 --- /dev/null +++ b/tests/integration/publish/test_registry.py @@ -0,0 +1,28 @@ +from unittest import mock + +import pytest + +from dispatcher.publish import task +from dispatcher.config import temporary_settings + + +def test_apply_async_with_no_queue(registry, conn_config): + @task(registry=registry) + def test_method(): + return + + dmethod = registry.get_from_callable(test_method) + + # These settings do not specify a default channel, that is the main point + with temporary_settings({'brokers': {'pg_notify': {'config': conn_config}}}): + + # Can not run a method if we 
do not have a queue + with pytest.raises(ValueError): + dmethod.apply_async() + + # But providing a queue at time of submission works + with mock.patch('dispatcher.brokers.pg_notify.SyncBroker.publish_message') as mock_publish_method: + dmethod.apply_async(queue='fooqueue') + mock_publish_method.assert_called_once_with(channel='fooqueue', message=mock.ANY) + + mock_publish_method.assert_called_once() diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 9b2a137..5095a8a 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -6,6 +6,8 @@ from tests.data import methods as test_methods +from dispatcher.config import temporary_settings + SLEEP_METHOD = 'lambda: __import__("time").sleep(0.1)' @@ -24,7 +26,7 @@ async def wait_to_receive(dispatcher, ct, timeout=5.0, interval=0.05): async def test_run_lambda_function(apg_dispatcher, pg_message): assert apg_dispatcher.pool.finished_count == 0 - clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait()) + clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait(), name='test_lambda_clear_wait') await pg_message('lambda: "This worked!"') await asyncio.wait_for(clearing_task, timeout=3) @@ -32,11 +34,19 @@ async def test_run_lambda_function(apg_dispatcher, pg_message): @pytest.mark.asyncio -async def test_run_decorated_function(apg_dispatcher, conn_config): - assert apg_dispatcher.pool.finished_count == 0 +async def test_run_decorated_function(apg_dispatcher, test_settings): + clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait()) + test_methods.print_hello.apply_async(settings=test_settings) + await asyncio.wait_for(clearing_task, timeout=3) + assert apg_dispatcher.pool.finished_count == 1 + + +@pytest.mark.asyncio +async def test_submit_with_global_settings(apg_dispatcher, test_settings): clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait()) - test_methods.print_hello.apply_async(config=conn_config) + with temporary_settings(test_settings): + test_methods.print_hello.delay() # settings are inferred from global context await asyncio.wait_for(clearing_task, timeout=3) assert apg_dispatcher.pool.finished_count == 1 @@ -95,8 +105,6 @@ async def test_cancel_task(apg_dispatcher, pg_message, pg_control): @pytest.mark.asyncio async def test_message_with_delay(apg_dispatcher, pg_message, pg_control): - assert apg_dispatcher.pool.finished_count == 0 - # Send message to run task with a delay msg = json.dumps({'task': 'lambda: print("This task had a delay")', 'uuid': 'delay_task', 'delay': 0.3}) await pg_message(msg) @@ -186,11 +194,11 @@ async def test_task_discard(apg_dispatcher, pg_message): @pytest.mark.asyncio -async def test_task_discard_in_task_definition(apg_dispatcher, conn_config): +async def test_task_discard_in_task_definition(apg_dispatcher, test_settings): assert apg_dispatcher.pool.finished_count == 0 for i in range(10): - test_methods.sleep_discard.apply_async(args=[2], config=conn_config) + test_methods.sleep_discard.apply_async(args=[2], settings=test_settings) await wait_to_receive(apg_dispatcher, 10) @@ -199,11 +207,11 @@ async def test_task_discard_in_task_definition(apg_dispatcher, conn_config): @pytest.mark.asyncio -async def test_tasks_in_serial(apg_dispatcher, conn_config): +async def test_tasks_in_serial(apg_dispatcher, test_settings): assert apg_dispatcher.pool.finished_count == 0 for i in range(10): - test_methods.sleep_serial.apply_async(args=[2], config=conn_config) + 
test_methods.sleep_serial.apply_async(args=[2], settings=test_settings) await wait_to_receive(apg_dispatcher, 10) @@ -212,11 +220,11 @@ async def test_tasks_in_serial(apg_dispatcher, conn_config): @pytest.mark.asyncio -async def test_tasks_queue_one(apg_dispatcher, conn_config): +async def test_tasks_queue_one(apg_dispatcher, test_settings): assert apg_dispatcher.pool.finished_count == 0 for i in range(10): - test_methods.sleep_queue_one.apply_async(args=[2], config=conn_config) + test_methods.sleep_queue_one.apply_async(args=[2], settings=test_settings) await wait_to_receive(apg_dispatcher, 10) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7ea0277..e69de29 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,9 +0,0 @@ -import pytest - -from dispatcher.registry import DispatcherMethodRegistry - - -@pytest.fixture -def registry() -> DispatcherMethodRegistry: - "Return a fresh registry, separate from the global one, for testing" - return DispatcherMethodRegistry() diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 4f98be1..73a8178 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,3 +1,29 @@ -# this is a good place to create config files, load them, and test that we get the params we expected -# also a good place to take some configs and test initializing dispatcher objects with them -# None of this has been done, but you could do it +import pytest + +from dispatcher.config import DispatcherSettings, LazySettings + + +def test_settings_reference_unconfigured(): + settings = LazySettings() + with pytest.raises(Exception) as exc: + settings.brokers + assert 'Dispatcher not configured' in str(exc) + + +def test_configured_settings(): + settings = LazySettings() + settings._wrapped = DispatcherSettings({'brokers': {'pg_notify': {'config': {}}}}) + 'pg_notify' in settings.brokers + + +def test_serialize_settings(test_settings): + config = test_settings.serialize() + assert 'producers' in config + assert 'publish' in config + assert config['publish'] == {} + assert 'pg_notify' in config['brokers'] + assert config['service']['max_workers'] == 3 + + re_loaded = DispatcherSettings(config) + assert re_loaded.brokers == test_settings.brokers + assert re_loaded.service == test_settings.service diff --git a/tests/unit/test_publish.py b/tests/unit/test_publish.py index c14e825..be4583a 100644 --- a/tests/unit/test_publish.py +++ b/tests/unit/test_publish.py @@ -1,7 +1,6 @@ from unittest import mock from dispatcher.publish import task -from dispatcher.utils import serialize_task import pytest @@ -76,21 +75,3 @@ def run(self): TestMethod.delay() mock_apply_async.assert_called_once_with((), {}) - - -def test_apply_async_with_no_queue(registry): - @task(registry=registry) - def test_method(): - return - - dmethod = registry.get_from_callable(test_method) - - # Can not run a method if we do not have a queue - with pytest.raises(ValueError): - dmethod.apply_async() - - # But providing a queue at time of submission works - with mock.patch('dispatcher.brokers.pg_notify.publish_message') as mock_publish_method: - dmethod.apply_async(queue='fooqueue') - - mock_publish_method.assert_called_once() diff --git a/tools/write_messages.py b/tools/write_messages.py index 4ad981a..c92fc75 100644 --- a/tools/write_messages.py +++ b/tools/write_messages.py @@ -4,9 +4,10 @@ import os import sys -from dispatcher.brokers.pg_notify import publish_message +from dispatcher.factories import get_sync_publisher_from_settings from dispatcher.control 
import Control from dispatcher.utils import MODULE_METHOD_DELIMITER +from dispatcher.config import setup # Add the test methods to the path so we can use .delay type contracts tools_dir = os.path.abspath( @@ -17,8 +18,12 @@ from test_methods import sleep_function, sleep_discard, task_has_timeout, hello_world_binder -# Database connection details -CONNECTION_STRING = "dbname=dispatch_db user=dispatch password=dispatching host=localhost port=55777" + +# Setup the global config from the settings file shared with the service +setup(file_path='dispatcher.yml') + + +broker = get_sync_publisher_from_settings() TEST_MSGS = [ @@ -32,39 +37,38 @@ def main(): print('writing some basic test messages') for channel, message in TEST_MSGS: # Send the notification - publish_message(channel, message, config={'conninfo': CONNECTION_STRING}) + broker.publish_message(channel=channel, message=message) # await send_notification(channel, message) # send more than number of workers quickly print('') print('writing 15 messages fast') for i in range(15): - publish_message('test_channel', f'lambda: {i}', config={'conninfo': CONNECTION_STRING}) + broker.publish_message(message=f'lambda: {i}') print('') print('performing a task cancel') # submit a task we will "find" two different ways - publish_message(channel, json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar'}), config={'conninfo': CONNECTION_STRING}) - ctl = Control('test_channel', config={'conninfo': CONNECTION_STRING}) + broker.publish_message(message=json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar'})) + ctl = Control('test_channel') canceled_jobs = ctl.control_with_reply('cancel', data={'uuid': 'foobar'}) print(json.dumps(canceled_jobs, indent=2)) print('') print('finding a running task by its task name') - publish_message(channel, json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar2'}), config={'conninfo': CONNECTION_STRING}) + broker.publish_message(message=json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar2'})) running_data = ctl.control_with_reply('running', data={'task': 'lambda: __import__("time").sleep(3.1415)'}) print(json.dumps(running_data, indent=2)) print('writing a message with a delay') print(' 4 second delay task') - publish_message(channel, json.dumps({'task': 'lambda: 123421', 'uuid': 'foobar2', 'delay': 4}), config={'conninfo': CONNECTION_STRING}) + broker.publish_message(message=json.dumps({'task': 'lambda: 123421', 'uuid': 'foobar2', 'delay': 4})) print(' 30 second delay task') - publish_message(channel, json.dumps({'task': 'lambda: 987987234', 'uuid': 'foobar2', 'delay': 30}), config={'conninfo': CONNECTION_STRING}) + broker.publish_message(message=json.dumps({'task': 'lambda: 987987234', 'uuid': 'foobar2', 'delay': 30})) print(' 10 second delay task') # NOTE: this task will error unless you run the dispatcher itself with it in the PYTHONPATH, which is intended sleep_function.apply_async( args=[3], # sleep 3 seconds delay=10, - config={'conninfo': CONNECTION_STRING} ) print('') @@ -88,31 +92,31 @@ def main(): print('') print('demo of submitting discarding tasks') for i in range(10): - publish_message(channel, json.dumps( + broker.publish_message(message=json.dumps( {'task': 'lambda: __import__("time").sleep(9)', 'on_duplicate': 'discard', 'uuid': f'dscd-{i}'} - ), config={'conninfo': CONNECTION_STRING}) + )) print('demo of discarding task marked as discarding') for i in range(10): - sleep_discard.apply_async(args=[2], 
config={'conninfo': CONNECTION_STRING}) + sleep_discard.apply_async(args=[2]) print('demo of discarding tasks with apply_async contract') for i in range(10): - sleep_function.apply_async(args=[3], on_duplicate='discard', config={'conninfo': CONNECTION_STRING}) + sleep_function.apply_async(args=[3], on_duplicate='discard') print('demo of submitting waiting tasks') for i in range(10): - publish_message(channel, json.dumps( + broker.publish_message(message=json.dumps( {'task': 'lambda: __import__("time").sleep(10)', 'on_duplicate': 'serial', 'uuid': f'wait-{i}'} - ), config={'conninfo': CONNECTION_STRING}) + )) print('demo of submitting queue-once tasks') for i in range(10): - publish_message(channel, json.dumps( + broker.publish_message(message=json.dumps( {'task': 'lambda: __import__("time").sleep(8)', 'on_duplicate': 'queue_one', 'uuid': f'queue_one-{i}'} - ), config={'conninfo': CONNECTION_STRING}) + )) print('demo of task_has_timeout that times out due to decorator use') - task_has_timeout.apply_async(config={'conninfo': CONNECTION_STRING}) + task_has_timeout.delay() print('demo of using bind=True, with hello_world_binder') - hello_world_binder.apply_async(config={'conninfo': CONNECTION_STRING}) + hello_world_binder.delay() if __name__ == "__main__": logging.basicConfig(level='ERROR', stream=sys.stdout) From 477439fd288f725ae84b06850742e96d85c0ce3e Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Fri, 14 Feb 2025 13:28:18 -0500 Subject: [PATCH 02/19] Complete worker settings initialization --- dispatcher/factories.py | 2 +- dispatcher/main.py | 5 +++-- dispatcher/pool.py | 6 +++--- tests/integration/test_main.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dispatcher/factories.py b/dispatcher/factories.py index 152567f..a17ed33 100644 --- a/dispatcher/factories.py +++ b/dispatcher/factories.py @@ -53,7 +53,7 @@ def from_settings(settings: LazySettings = global_settings) -> DispatcherMain: between the service, publisher, and any other interacting processes. 
""" producers = producers_from_settings(settings=settings) - return DispatcherMain(settings.service, producers) + return DispatcherMain(settings.service, producers, settings=settings) # ---- Publisher objects ---- diff --git a/dispatcher/main.py b/dispatcher/main.py index 2327165..239bc0e 100644 --- a/dispatcher/main.py +++ b/dispatcher/main.py @@ -7,6 +7,7 @@ from dispatcher.pool import WorkerPool from dispatcher.producers import BaseProducer +from dispatcher.config import settings as global_settings, LazySettings logger = logging.getLogger(__name__) @@ -76,7 +77,7 @@ def __init__(self) -> None: class DispatcherMain: - def __init__(self, service_config: dict, producers: Iterable[BaseProducer]): + def __init__(self, service_config: dict, producers: Iterable[BaseProducer], settings: LazySettings = global_settings): self.delayed_messages: list[SimpleNamespace] = [] self.received_count = 0 self.control_count = 0 @@ -85,7 +86,7 @@ def __init__(self, service_config: dict, producers: Iterable[BaseProducer]): # Lock for file descriptor mgmnt - hold lock when forking or connecting, to avoid DNS hangs # psycopg is well-behaved IFF you do not connect while forking, compare to AWX __clean_on_fork__ self.fd_lock = asyncio.Lock() - self.pool = WorkerPool(fd_lock=self.fd_lock, **service_config) + self.pool = WorkerPool(fd_lock=self.fd_lock, settings=settings, **service_config) # Set all the producers, this should still not start anything, just establishes objects self.producers = producers diff --git a/dispatcher/pool.py b/dispatcher/pool.py index 6506d50..0dfdf1e 100644 --- a/dispatcher/pool.py +++ b/dispatcher/pool.py @@ -12,13 +12,12 @@ class PoolWorker: - def __init__(self, worker_id: int, process: ProcessProxy, settings: LazySettings = global_settings) -> None: + def __init__(self, worker_id: int, process: ProcessProxy) -> None: self.worker_id = worker_id self.process = process self.current_task: Optional[dict] = None self.started_at: Optional[int] = None self.is_active_cancel: bool = False - self.settings_stash: dict = settings.serialize() # Tracking information for worker self.finished_count = 0 @@ -96,9 +95,10 @@ def __init__(self) -> None: class WorkerPool: - def __init__(self, max_workers: int, fd_lock: Optional[asyncio.Lock] = None): + def __init__(self, max_workers: int, fd_lock: Optional[asyncio.Lock] = None, settings: LazySettings = global_settings): self.max_workers = max_workers self.workers: dict[int, PoolWorker] = {} + self.settings_stash: dict = settings.serialize() # These are passed to the workers to initialize dispatcher settings self.next_worker_id = 0 self.process_manager = ProcessManager() self.queued_messages: list[dict] = [] # TODO: use deque, invent new kinds of logging anxiety diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 5095a8a..af35524 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -45,7 +45,7 @@ async def test_run_decorated_function(apg_dispatcher, test_settings): @pytest.mark.asyncio async def test_submit_with_global_settings(apg_dispatcher, test_settings): clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait()) - with temporary_settings(test_settings): + with temporary_settings(test_settings.serialize()): test_methods.print_hello.delay() # settings are inferred from global context await asyncio.wait_for(clearing_task, timeout=3) From 46b808da8b6846c12395041a2851355a126b821c Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Fri, 14 Feb 2025 13:30:41 -0500 Subject: 
[PATCH 03/19] linter fixups after worker settings fixing --- dispatcher/config.py | 7 +------ dispatcher/main.py | 3 ++- dispatcher/pool.py | 10 ++++++++-- dispatcher/process.py | 2 +- dispatcher/registry.py | 5 +++-- dispatcher/worker/task.py | 2 +- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/dispatcher/config.py b/dispatcher/config.py index 1c76a6b..ff4bf07 100644 --- a/dispatcher/config.py +++ b/dispatcher/config.py @@ -16,12 +16,7 @@ def __init__(self, config: dict) -> None: # self.options: dict = config.get('options', {}) def serialize(self): - return dict( - brokers=self.brokers, - producers=self.producers, - service=self.service, - publish=self.publish - ) + return dict(brokers=self.brokers, producers=self.producers, service=self.service, publish=self.publish) def settings_from_file(path: str) -> DispatcherSettings: diff --git a/dispatcher/main.py b/dispatcher/main.py index 239bc0e..bf2f0f6 100644 --- a/dispatcher/main.py +++ b/dispatcher/main.py @@ -5,9 +5,10 @@ from types import SimpleNamespace from typing import Iterable, Optional +from dispatcher.config import LazySettings +from dispatcher.config import settings as global_settings from dispatcher.pool import WorkerPool from dispatcher.producers import BaseProducer -from dispatcher.config import settings as global_settings, LazySettings logger = logging.getLogger(__name__) diff --git a/dispatcher/pool.py b/dispatcher/pool.py index 0dfdf1e..c6d2531 100644 --- a/dispatcher/pool.py +++ b/dispatcher/pool.py @@ -4,9 +4,10 @@ from asyncio import Task from typing import Iterator, Optional +from dispatcher.config import LazySettings +from dispatcher.config import settings as global_settings from dispatcher.process import ProcessManager, ProcessProxy from dispatcher.utils import DuplicateBehavior, MessageAction -from dispatcher.config import settings as global_settings, LazySettings logger = logging.getLogger(__name__) @@ -188,7 +189,12 @@ async def manage_timeout(self) -> None: self.events.timeout_event.clear() async def up(self) -> None: - process = self.process_manager.create_process((self.settings_stash, self.next_worker_id,)) + process = self.process_manager.create_process( + ( + self.settings_stash, + self.next_worker_id, + ) + ) worker = PoolWorker(self.next_worker_id, process) self.workers[self.next_worker_id] = worker self.next_worker_id += 1 diff --git a/dispatcher/process.py b/dispatcher/process.py index ecfce2c..5150b9c 100644 --- a/dispatcher/process.py +++ b/dispatcher/process.py @@ -46,7 +46,7 @@ def get_event_loop(self): self._loop = asyncio.get_event_loop() return self._loop - def create_process(self, args: Iterable[int | str], **kwargs) -> ProcessProxy: + def create_process(self, args: Iterable[int | str | dict], **kwargs) -> ProcessProxy: return ProcessProxy(args, self.finished_queue, **kwargs) async def read_finished(self) -> dict[str, Union[str, int]]: diff --git a/dispatcher/registry.py b/dispatcher/registry.py index b92a2eb..3e872fc 100644 --- a/dispatcher/registry.py +++ b/dispatcher/registry.py @@ -6,8 +6,9 @@ from typing import Callable, Optional, Set, Tuple from uuid import uuid4 +from dispatcher.config import LazySettings +from dispatcher.config import settings as global_settings from dispatcher.utils import MODULE_METHOD_DELIMITER, DispatcherCallable, resolve_callable -from dispatcher.config import settings as global_settings, DispatcherSettings logger = logging.getLogger(__name__) @@ -80,7 +81,7 @@ def get_async_body( return body - def apply_async(self, args=None, kwargs=None, queue=None, 
uuid=None, settings: DispatcherSettings = global_settings, **kw) -> Tuple[dict, str]: + def apply_async(self, args=None, kwargs=None, queue=None, uuid=None, settings: LazySettings = global_settings, **kw) -> Tuple[dict, str]: queue = queue or self.submission_defaults.get('queue') if callable(queue): diff --git a/dispatcher/worker/task.py b/dispatcher/worker/task.py index ebafc49..2708366 100644 --- a/dispatcher/worker/task.py +++ b/dispatcher/worker/task.py @@ -8,8 +8,8 @@ import traceback from queue import Empty as QueueEmpty -from dispatcher.registry import registry from dispatcher.config import setup +from dispatcher.registry import registry logger = logging.getLogger(__name__) From 622cd4af66cf3eacccf04289f26a40cce7837568 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Fri, 14 Feb 2025 14:28:02 -0500 Subject: [PATCH 04/19] Work factories into control module --- dispatcher/brokers/__init__.py | 29 +++++++++++++++++++++++++ dispatcher/brokers/pg_notify.py | 12 +++++++---- dispatcher/control.py | 17 ++++++++------- dispatcher/factories.py | 36 ++++++++------------------------ dispatcher/producers/brokered.py | 4 ++-- tests/conftest.py | 23 +++++++------------- 6 files changed, 65 insertions(+), 56 deletions(-) diff --git a/dispatcher/brokers/__init__.py b/dispatcher/brokers/__init__.py index e69de29..aec85d4 100644 --- a/dispatcher/brokers/__init__.py +++ b/dispatcher/brokers/__init__.py @@ -0,0 +1,29 @@ +import importlib +from types import ModuleType + +from .base import BaseBroker + + +def get_broker_module(broker_name) -> ModuleType: + "Static method to alias import_module so we use a consistent import path" + return importlib.import_module(f'dispatcher.brokers.{broker_name}') + + +def get_async_broker(broker_name: str, broker_config: dict, **overrides) -> BaseBroker: + """ + Given the name of the broker in the settings, and the data under that entry in settings, + return the asyncio broker object. + """ + broker_module = get_broker_module(broker_name) + kwargs = broker_config.copy() + kwargs.update(overrides) + return broker_module.AsyncBroker(**kwargs) + + +def get_sync_broker(broker_name, broker_config) -> BaseBroker: + """ + Given the name of the broker in the settings, and the data under that entry in settings, + return the synchronous broker object. 
+ """ + broker_module = get_broker_module(broker_name) + return broker_module.SyncBroker(**broker_config) diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index e59f5d7..e4d6bd7 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -22,7 +22,7 @@ class PGNotifyBase(BaseBroker): def __init__( self, config: Optional[dict] = None, - channels: Iterable[str] = ('dispatcher_default',), + channels: Iterable[str] = (), default_publish_channel: Optional[str] = None, ) -> None: """ @@ -44,9 +44,13 @@ def get_publish_channel(self, channel: Optional[str] = None): "Handle default for the publishing channel for calls to publish_message, shared sync and async" if channel is not None: return channel - if self.default_publish_channel is None: - raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config') - return self.default_publish_channel + elif self.default_publish_channel is not None: + return self.default_publish_channel + elif len(self.channels) == 1: + # de-facto default channel, because there is only 1 + return self.channels[0] + + raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config') def get_connection_method(self, factory_path: Optional[str] = None) -> Callable: "Handles settings, returns a method (async or sync) for getting a new connection" diff --git a/dispatcher/control.py b/dispatcher/control.py index 424d85f..ac47869 100644 --- a/dispatcher/control.py +++ b/dispatcher/control.py @@ -4,8 +4,9 @@ import time import uuid from types import SimpleNamespace +from typing import Optional -from dispatcher.factories import get_async_publisher_from_settings, get_sync_publisher_from_settings +from dispatcher.factories import get_async_broker, get_sync_broker from dispatcher.producers import BrokeredProducer logger = logging.getLogger('awx.main.dispatch.control') @@ -36,7 +37,7 @@ async def process_message(self, payload, producer=None, channel=None): async def connected_callback(self, producer) -> None: payload = json.dumps(self.send_data) - await producer.notify(self.queuename, payload) + await producer.notify(channel=self.queuename, message=payload) logger.info('Sent control message, expecting replies soon') def fatal_error_callback(self, *args): @@ -56,9 +57,10 @@ def fatal_error_callback(self, *args): class Control(object): - def __init__(self, queue, config=None): + def __init__(self, broker_name: str, broker_config: dict, queue: Optional[str] = None) -> None: self.queuename = queue - self.config = config + self.broker_name = broker_name + self.broker_config = broker_config def running(self, *args, **kwargs): return self.control_with_reply('running', *args, **kwargs) @@ -90,7 +92,7 @@ async def acontrol_with_reply_internal(self, producer, send_data, expected_repli return [json.loads(payload) for payload in control_callbacks.received_replies] def make_producer(self, reply_queue): - broker = get_async_publisher_from_settings(channels=[reply_queue]) + broker = get_async_broker(self.broker_name, self.broker_config, channels=[reply_queue]) return BrokeredProducer(broker, close_on_exit=True) async def acontrol_with_reply(self, command, expected_replies=1, timeout=1, data=None): @@ -114,7 +116,6 @@ def control_with_reply(self, command, expected_replies=1, timeout=1, data=None): logger.info('control-and-reply {} to {}'.format(command, self.queuename)) start = time.time() reply_queue = Control.generate_reply_queue_name() - send_data = {'control': command, 
'reply_to': reply_queue} if data: send_data['control_data'] = data @@ -131,12 +132,12 @@ def control_with_reply(self, command, expected_replies=1, timeout=1, data=None): logger.info(f'control-and-reply message returned in {time.time() - start} seconds') return replies - # NOTE: this is the synchronous version, only to be used for no-reply def control(self, command, data=None): + "Send message in fire-and-forget mode, as synchronous code. Only for no-reply control." send_data = {'control': command} if data: send_data['control_data'] = data payload = json.dumps(send_data) - broker = get_sync_publisher_from_settings() + broker = get_sync_broker(self.broker_name, self.broker_config) broker.publish_message(channel=self.queuename, message=payload) diff --git a/dispatcher/factories.py b/dispatcher/factories.py index a17ed33..cbe061f 100644 --- a/dispatcher/factories.py +++ b/dispatcher/factories.py @@ -1,11 +1,11 @@ -import importlib -from types import ModuleType from typing import Iterable, Optional from dispatcher import producers +from dispatcher.brokers import get_async_broker, get_sync_broker from dispatcher.brokers.base import BaseBroker from dispatcher.config import LazySettings from dispatcher.config import settings as global_settings +from dispatcher.control import Control from dispatcher.main import DispatcherMain """ @@ -17,22 +17,6 @@ # ---- Service objects ---- -def get_broker_module(broker_name) -> ModuleType: - "Static method to alias import_module so we use a consistent import path" - return importlib.import_module(f'dispatcher.brokers.{broker_name}') - - -def get_async_broker(broker_name: str, broker_config: dict, **overrides) -> BaseBroker: - """ - Given the name of the broker in the settings, and the data under that entry in settings, - return the asyncio broker object. - """ - broker_module = get_broker_module(broker_name) - kwargs = broker_config.copy() - kwargs.update(overrides) - return broker_module.AsyncBroker(**kwargs) - - def producers_from_settings(settings: LazySettings = global_settings) -> Iterable[producers.BaseProducer]: producer_objects = [] for broker_name, broker_kwargs in settings.brokers.items(): @@ -59,15 +43,6 @@ def from_settings(settings: LazySettings = global_settings) -> DispatcherMain: # ---- Publisher objects ---- -def get_sync_broker(broker_name, broker_config) -> BaseBroker: - """ - Given the name of the broker in the settings, and the data under that entry in settings, - return the synchronous broker object. 
- """ - broker_module = get_broker_module(broker_name) - return broker_module.SyncBroker(**broker_config) - - def _get_publisher_broker_name(publish_broker: Optional[str] = None, settings: LazySettings = global_settings) -> str: if publish_broker: return publish_broker @@ -96,3 +71,10 @@ def get_async_publisher_from_settings(publish_broker: Optional[str] = None, sett """ publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings) return get_async_broker(publish_broker, settings.brokers[publish_broker], **overrides) + + +def get_control_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides): + publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings) + broker_options = settings.brokers[publish_broker].copy() + broker_options.update(overrides) + return Control(publish_broker, broker_options) diff --git a/dispatcher/producers/brokered.py b/dispatcher/producers/brokered.py index 3b276f3..b40987b 100644 --- a/dispatcher/producers/brokered.py +++ b/dispatcher/producers/brokered.py @@ -38,8 +38,8 @@ async def produce_forever(self, dispatcher) -> None: self.produced_count += 1 await dispatcher.process_message(payload, producer=self, channel=channel) - async def notify(self, channel: str, payload: str = '') -> None: - await self.broker.apublish_message(channel=channel, message=payload) + async def notify(self, channel: Optional[str] = None, message: str = '') -> None: + await self.broker.apublish_message(channel=channel, message=message) async def shutdown(self) -> None: if self.production_task: diff --git a/tests/conftest.py b/tests/conftest.py index 8ad89b9..cf79404 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -import asyncio - import contextlib from typing import Callable, AsyncIterator @@ -11,10 +9,10 @@ from dispatcher.main import DispatcherMain from dispatcher.control import Control -from dispatcher.brokers.pg_notify import SyncBroker, AsyncBroker +from dispatcher.brokers.pg_notify import AsyncBroker from dispatcher.registry import DispatcherMethodRegistry -from dispatcher.config import temporary_settings, DispatcherSettings -from dispatcher.factories import from_settings +from dispatcher.config import DispatcherSettings +from dispatcher.factories import from_settings, get_control_from_settings # List of channels to listen on @@ -30,6 +28,7 @@ "config": {'conninfo': CONNECTION_STRING}, "sync_connection_factory": "dispatcher.brokers.pg_notify.connection_saver", # "async_connection_factory": "dispatcher.brokers.pg_notify.async_connection_saver", + "default_publish_channel": "test_channel" } }, "pool": { @@ -74,12 +73,6 @@ def test_settings(): return DispatcherSettings(BASIC_CONFIG) -@pytest.fixture -def test_setup(): - with temporary_settings(BASIC_CONFIG): - yield - - @pytest_asyncio.fixture(loop_scope="function", scope="function") async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]: dispatcher = None @@ -101,15 +94,15 @@ async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]: @pytest_asyncio.fixture(loop_scope="function", scope="function") async def pg_message(psycopg_conn) -> Callable: - async def _rf(message, channel='test_channel'): - broker = AsyncBroker(connection=psycopg_conn) + async def _rf(message, channel=None): + broker = AsyncBroker(connection=psycopg_conn, default_publish_channel='test_channel') await broker.apublish_message(channel=channel, message=message) return _rf 
@pytest_asyncio.fixture(loop_scope="function", scope="function")
-async def pg_control(test_setup) -> AsyncIterator[Control]:
-    yield Control(queue='test_channel')
+async def pg_control(test_settings) -> AsyncIterator[Control]:
+    return get_control_from_settings(settings=test_settings)
 
 
 @pytest_asyncio.fixture(loop_scope="function", scope="function")

From bbf9a3663eb1d4964318f6596861feaaa8c7f3f0 Mon Sep 17 00:00:00 2001
From: Alan Rominger
Date: Fri, 14 Feb 2025 15:40:58 -0500
Subject: [PATCH 05/19] Add docs on the config

---
 README.md      |  46 +++++++++++++++-------
 docs/config.md | 104 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 136 insertions(+), 14 deletions(-)
 create mode 100644 docs/config.md

diff --git a/README.md b/README.md
index a6db44c..7f2527b 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,18 @@
 [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/ansible/dispatcher/blob/main/LICENSE)
 
 
-Working space for dispatcher prototyping
-
-This is firstly intended to be a code split of:
+This is intended to be a working space for prototyping a code split of:
 
 
 As a part of doing the split, we also want to resolve a number of
 long-standing design and sustainability issues, thus, asyncio.
 
+The philosophy of the dispatcher is to have a limited scope
+as a "local" runner of background tasks, but to be composable
+so that it can be "wrapped" easily to enable clustering and
+distributed task management by apps using it.
+
 Licensed under [Apache Software License 2.0](LICENSE)
 
 ### Usage
@@ -37,20 +40,11 @@ def print_hello():
     print('hello world!!')
 ```
 
-#### Dispatcher service
-
-The dispatcher service needs to be running before you submit tasks.
-This does not make any attempts at message durability or confirmation.
-If you submit a task in an outage of the service, it will be dropped.
-
-There are 2 ways to run the dispatcher service:
-
-- Importing and running (code snippet below)
-- A CLI entrypoint `dispatcher-standalone` for demo purposes
+Additionally, you need to configure dispatcher somewhere in your import path.
+This tells dispatcher how to submit tasks to be run.
 
 ```python
 from dispatcher.config import setup
-from dispatcher import run_service
 
 config = {
     "producers": {
@@ -64,6 +58,26 @@ config = {
     "pool": {"max_workers": 4},
 }
 setup(config)
+```
+
+For more on how to set up and the allowed options in the config,
+see the [config](docs/config.md) docs.
+
+#### Dispatcher service
+
+The dispatcher service needs to be running before you submit tasks.
+This does not make any attempts at message durability or confirmation.
+If you submit a task in an outage of the service, it will be dropped.
+
+There are 2 ways to run the dispatcher service:
+
+- Importing and running (code snippet below)
+- A CLI entrypoint `dispatcher-standalone` for demo purposes
+
+```python
+from dispatcher import run_service
+
+# After the setup() method has been called
 run_service()
 ```
 
@@ -84,6 +98,8 @@ The following code will submit `print_hello` to run in the background dispatcher
 ```python
 from test_methods import print_hello
 
+# After the setup() method has been called
+
 print_hello.delay()
 ```
 
@@ -92,6 +108,8 @@ Also valid:
 ```python
 from test_methods import print_hello
 
+# After the setup() method has been called
+
 print_hello.apply_async(args=[], kwargs={})
 ```
 
diff --git a/docs/config.md b/docs/config.md
new file mode 100644
index 0000000..9c95223
--- /dev/null
+++ b/docs/config.md
@@ -0,0 +1,104 @@
+## Dispatcher Configuration
+
+Why is configuration needed? Consider doing this, which uses the demo content:
+
+```
+PYTHONPATH=$PYTHONPATH:tools/ python -c "from tools.test_methods import sleep_function; sleep_function.delay()"
+```
+
+This will result in an error:
+
+> Dispatcher not configured, set DISPATCHER_CONFIG_FILE or call dispatcher.config.setup
+
+This is an error because dispatcher does not have information to connect to a message broker.
+In the case of postgres, that information is the connection information (host, user, password, etc)
+as well as the pg_notify channel to send the message to.
+
+### Ways to configure
+
+#### From file
+
+The provided entrypoint `dispatcher-standalone` can only use a file, which is how the demo works.
+The demo runs using the `dispatcher.yml` config file at the top-level of this repo.
+
+You can do the same thing in python by:
+
+```python
+from dispatcher.config import setup
+
+setup(file_path='dispatcher.yml')
+```
+
+This approach is used by the demo's test script at `tools/write_messages.py`,
+which acts as a "publisher", meaning that it submits tasks over the message
+broker to be run.
+This setup ensures that both the service (`dispatcher-standalone`) and the publisher
+are using the same configuration.
+
+#### From a dictionary
+
+Calling `setup(config)`, where `config` is a python dictionary, is
+equivalent to dumping the `config` to a yaml file and using that as
+the file-based config.
+
+### Configuration Contents
+
+The config is broken down by either the process that uses that section,
+or shared resources used by multiple processes.
+
+The general structure is:
+
+```yaml
+---
+service:
+  # options
+brokers:
+  pg_notify:
+    # options
+producers:
+  ProducerClass:
+    # options
+publish:
+  # options
+```
+
+#### Brokers
+
+Brokers relay messages which give instructions about code to run.
+Right now the only broker available is pg_notify.
+
+The sub-options become python `kwargs` passed to the broker classes
+`AsyncBroker` and `SyncBroker`, for the sychronous and asyncio versions
+of the broker.
+For now, you will just have to read the code to see what those options are
+at [dispatcher.brokers.pg_notify](dispatcher/brokers/pg_notify.py).
+
+The broker classes have methods that allow for submitting messages
+and reading messages.
+
+#### Service
+
+This configures the background task service.
+
+The options will correspond to the `DispatcherMain` class
+in [dispatcher.main](dispatcher/main.py), or its related
+[dispatcher.pool](dispatcher/pool.py).
+
+Service-specific options are mainly concerned with worker
+management. For instance, auto-scaling options will be here,
+like worker count, etc.
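The dictionary form of the same structure can be illustrated with a short sketch. This is only an illustration: the option names below are drawn from the demo `dispatcher.yml` and the test fixtures, and the connection values are placeholders, not a complete reference of allowed keys.

```python
from dispatcher.config import setup

# A sketch of the dictionary equivalent of the YAML structure above.
# Values are placeholders; adjust conninfo and channels for your environment.
config = {
    "service": {"max_workers": 4},
    "brokers": {
        "pg_notify": {
            "config": {"conninfo": "dbname=dispatch_db user=dispatch host=localhost port=55777"},
            "channels": ["test_channel"],
            "default_publish_channel": "test_channel",
        },
    },
    "producers": {
        "ScheduledProducer": {
            "task_schedule": {'lambda: __import__("time").sleep(1)': {"schedule": 3}},
        },
    },
    "publish": {"default_broker": "pg_notify"},
}

setup(config)
```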
+ +#### Producers + +These are "producers of tasks" in the dispatcher service. + +For every listed broker, a `BrokeredProducer` is automatically +created. That means that tasks may be produced from the messaging +system that the dispatcher service is listening to. + +The other current use case is `ScheduledProducer`, +which submits tasks every certain number of seconds. + +#### Publish + +Additional options for publishers (task submitters). From 34f7098dfd99072776841fd3b61bc66fa3fdf2be Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Fri, 14 Feb 2025 15:54:21 -0500 Subject: [PATCH 06/19] Fix type hinting issue --- dispatcher/brokers/pg_notify.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index e4d6bd7..437ead3 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Iterable, Optional +from typing import Callable, Optional, Union import psycopg @@ -22,7 +22,7 @@ class PGNotifyBase(BaseBroker): def __init__( self, config: Optional[dict] = None, - channels: Iterable[str] = (), + channels: Union[tuple, list] = (), default_publish_channel: Optional[str] = None, ) -> None: """ From 6b4653f1cc36737ad1e5061bc444a6f5d04ac889 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Fri, 14 Feb 2025 15:59:32 -0500 Subject: [PATCH 07/19] Fix events data structure pattern goof --- dispatcher/control.py | 11 ++++++----- dispatcher/main.py | 4 ---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/dispatcher/control.py b/dispatcher/control.py index ac47869..68cbe2e 100644 --- a/dispatcher/control.py +++ b/dispatcher/control.py @@ -3,7 +3,6 @@ import logging import time import uuid -from types import SimpleNamespace from typing import Optional from dispatcher.factories import get_async_broker, get_sync_broker @@ -12,6 +11,11 @@ logger = logging.getLogger('awx.main.dispatch.control') +class ControlEvents: + def __init__(self) -> None: + self.exit_event = asyncio.Event() + + class ControlCallbacks: """This calls follows the same structure as the DispatcherMain class @@ -24,12 +28,9 @@ def __init__(self, queuename, send_data, expected_replies): self.expected_replies = expected_replies self.received_replies = [] - self.events = self._create_events() + self.events = ControlEvents() self.shutting_down = False - def _create_events(self): - return SimpleNamespace(exit_event=asyncio.Event()) - async def process_message(self, payload, producer=None, channel=None): self.received_replies.append(payload) if self.expected_replies and (len(self.received_replies) >= self.expected_replies): diff --git a/dispatcher/main.py b/dispatcher/main.py index bf2f0f6..5c93b03 100644 --- a/dispatcher/main.py +++ b/dispatcher/main.py @@ -94,10 +94,6 @@ def __init__(self, service_config: dict, producers: Iterable[BaseProducer], sett self.events: DispatcherEvents = DispatcherEvents() - def _create_events(self): - "Benchmark tests have to re-create this because they use same object in different event loops" - return SimpleNamespace(exit_event=asyncio.Event()) - def fatal_error_callback(self, *args) -> None: """Method to connect to error callbacks of other tasks, will kick out of main loop""" if self.shutting_down: From ff1d6e4e9e8b7e039e5248470ec2e47b749ee217 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Mon, 17 Feb 2025 09:46:02 -0500 Subject: [PATCH 08/19] Implement change described in comment, cls and kwargs patterns --- dispatcher.yml | 3 ++- 
dispatcher/config.py | 9 ++++++++- dispatcher/factories.py | 12 +++++++++++- dispatcher/main.py | 9 ++++----- dispatcher/pool.py | 11 +++++------ docs/config.md | 20 ++++++++++++++------ docs/vision.md | 30 ++++++++++++++++++++++++++++++ 7 files changed, 74 insertions(+), 20 deletions(-) create mode 100644 docs/vision.md diff --git a/dispatcher.yml b/dispatcher.yml index f00b9ca..340cf6e 100644 --- a/dispatcher.yml +++ b/dispatcher.yml @@ -1,7 +1,8 @@ # Demo config --- service: - max_workers: 4 + pool_kwargs: + max_workers: 4 brokers: pg_notify: config: diff --git a/dispatcher/config.py b/dispatcher/config.py index ff4bf07..56bbb8b 100644 --- a/dispatcher/config.py +++ b/dispatcher/config.py @@ -9,8 +9,15 @@ class DispatcherSettings: def __init__(self, config: dict) -> None: self.brokers: dict = config.get('brokers', {}) self.producers: dict = config.get('producers', {}) - self.service: dict = config.get('service', {'max_workers': 3}) + self.service: dict = config.get('service', {}) self.publish: dict = config.get('publish', {}) + + # Automatic defaults + if 'pool_kwargs' not in self.service: + self.service['pool_kwargs'] = {} + if 'max_workers' not in self.service['pool_kwargs']: + self.service['pool_kwargs']['max_workers'] = 3 + # TODO: firmly planned sections of config for later # self.callbacks: dict = config.get('callbacks', {}) # self.options: dict = config.get('options', {}) diff --git a/dispatcher/factories.py b/dispatcher/factories.py index cbe061f..60bc409 100644 --- a/dispatcher/factories.py +++ b/dispatcher/factories.py @@ -7,6 +7,8 @@ from dispatcher.config import settings as global_settings from dispatcher.control import Control from dispatcher.main import DispatcherMain +from dispatcher.pool import WorkerPool +from dispatcher.process import ProcessManager """ Creates objects from settings, @@ -17,6 +19,13 @@ # ---- Service objects ---- +def pool_from_settings(settings: LazySettings = global_settings): + kwargs = settings.service.get('pool_kwargs', {}).copy() + kwargs['settings'] = settings + kwargs['process_manager'] = ProcessManager() # TODO: use process_manager_cls from settings + return WorkerPool(**kwargs) + + def producers_from_settings(settings: LazySettings = global_settings) -> Iterable[producers.BaseProducer]: producer_objects = [] for broker_name, broker_kwargs in settings.brokers.items(): @@ -37,7 +46,8 @@ def from_settings(settings: LazySettings = global_settings) -> DispatcherMain: between the service, publisher, and any other interacting processes. 
""" producers = producers_from_settings(settings=settings) - return DispatcherMain(settings.service, producers, settings=settings) + pool = pool_from_settings(settings=settings) + return DispatcherMain(producers, pool) # ---- Publisher objects ---- diff --git a/dispatcher/main.py b/dispatcher/main.py index 5c93b03..ce0b398 100644 --- a/dispatcher/main.py +++ b/dispatcher/main.py @@ -5,8 +5,6 @@ from types import SimpleNamespace from typing import Iterable, Optional -from dispatcher.config import LazySettings -from dispatcher.config import settings as global_settings from dispatcher.pool import WorkerPool from dispatcher.producers import BaseProducer @@ -78,7 +76,7 @@ def __init__(self) -> None: class DispatcherMain: - def __init__(self, service_config: dict, producers: Iterable[BaseProducer], settings: LazySettings = global_settings): + def __init__(self, producers: Iterable[BaseProducer], pool: WorkerPool): self.delayed_messages: list[SimpleNamespace] = [] self.received_count = 0 self.control_count = 0 @@ -87,9 +85,10 @@ def __init__(self, service_config: dict, producers: Iterable[BaseProducer], sett # Lock for file descriptor mgmnt - hold lock when forking or connecting, to avoid DNS hangs # psycopg is well-behaved IFF you do not connect while forking, compare to AWX __clean_on_fork__ self.fd_lock = asyncio.Lock() - self.pool = WorkerPool(fd_lock=self.fd_lock, settings=settings, **service_config) - # Set all the producers, this should still not start anything, just establishes objects + # Save the associated dispatcher objects, usually created by factories + # expected that these are not yet running any tasks + self.pool = pool self.producers = producers self.events: DispatcherEvents = DispatcherEvents() diff --git a/dispatcher/pool.py b/dispatcher/pool.py index c6d2531..b55b758 100644 --- a/dispatcher/pool.py +++ b/dispatcher/pool.py @@ -96,12 +96,12 @@ def __init__(self) -> None: class WorkerPool: - def __init__(self, max_workers: int, fd_lock: Optional[asyncio.Lock] = None, settings: LazySettings = global_settings): + def __init__(self, max_workers: int, process_manager: ProcessManager, settings: LazySettings = global_settings): self.max_workers = max_workers self.workers: dict[int, PoolWorker] = {} self.settings_stash: dict = settings.serialize() # These are passed to the workers to initialize dispatcher settings self.next_worker_id = 0 - self.process_manager = ProcessManager() + self.process_manager = process_manager self.queued_messages: list[dict] = [] # TODO: use deque, invent new kinds of logging anxiety self.read_results_task: Optional[Task] = None self.start_worker_task: Optional[Task] = None @@ -112,7 +112,6 @@ def __init__(self, max_workers: int, fd_lock: Optional[asyncio.Lock] = None, set self.discard_count: int = 0 self.shutdown_timeout = 3 self.management_lock = asyncio.Lock() - self.fd_lock = fd_lock or asyncio.Lock() self.events: PoolEvents = PoolEvents() @@ -127,12 +126,12 @@ def received_count(self): async def start_working(self, dispatcher) -> None: self.read_results_task = asyncio.create_task(self.read_results_forever(), name='results_task') self.read_results_task.add_done_callback(dispatcher.fatal_error_callback) - self.management_task = asyncio.create_task(self.manage_workers(), name='management_task') + self.management_task = asyncio.create_task(self.manage_workers(forking_lock=dispatcher.fd_lock), name='management_task') self.management_task.add_done_callback(dispatcher.fatal_error_callback) self.timeout_task = asyncio.create_task(self.manage_timeout(), 
name='timeout_task')
         self.timeout_task.add_done_callback(dispatcher.fatal_error_callback)
 
-    async def manage_workers(self) -> None:
+    async def manage_workers(self, forking_lock: asyncio.Lock) -> None:
         """Enforces worker policy like min and max workers, and later, auto scale-down"""
         while not self.shutting_down:
             while len(self.workers) < self.max_workers:
@@ -144,7 +143,7 @@ async def manage_workers(self) -> None:
             for worker in self.workers.values():
                 if worker.status == 'initialized':
                     logger.debug(f'Starting subprocess for worker {worker.worker_id}')
-                    async with self.fd_lock:  # never fork while connecting
+                    async with forking_lock:  # never fork while connecting
                         await worker.start()
 
             await self.events.management_event.wait()
diff --git a/docs/config.md b/docs/config.md
index 9c95223..5bab855 100644
--- a/docs/config.md
+++ b/docs/config.md
@@ -43,15 +43,21 @@ the file-based config.
 
 ### Configuration Contents
 
-The config is broken down by either the process that uses that section,
-or shared resources used by multiple processes.
+At the top-level, config is broken down by either the process that uses that section,
+or brokers, which are a shared resource used by multiple processes.
+
+At the level below that, the config gives instructions for creating python objects.
+The module `dispatcher.factories` has the task of creating those objects from settings.
+The design goal is to have as little divergence as possible between the settings structure
+and the class structure in the code.
 
 The general structure is:
 
 ```yaml
 ---
 service:
-  # options
+  pool_kwargs:
+    # options
 brokers:
   pg_notify:
     # options
@@ -62,6 +68,8 @@ publish:
   # options
 ```
 
+When providing `pool_kwargs`, those are the kwargs passed to `WorkerPool`, for example.
+
 #### Brokers
 
 Brokers relay messages which give instructions about code to run.
@@ -80,11 +88,11 @@ and reading messages.
 
 This configures the background task service.
 
-The options will correspond to the `DispatcherMain` class
-in [dispatcher.main](dispatcher/main.py), or its related
+The `pool_kwargs` options will correspond to the `WorkerPool` class in
 [dispatcher.pool](dispatcher/pool.py).
+Process management options will be added to this section later.
 
-Service-specific options are mainly concerned with worker
+These options are mainly concerned with worker
 management. For instance, auto-scaling options will be here,
 like worker count, etc.
diff --git a/docs/vision.md b/docs/vision.md
new file mode 100644
index 0000000..f600e64
--- /dev/null
+++ b/docs/vision.md
@@ -0,0 +1,30 @@
+## Vision for Dispatcher
+
+The dispatcher strives to be an extremely contained, simple library,
+by assuming that your system already has a source-of-truth.
+
+This will:
+ - run python tasks
+
+This will not:
+ - provide a result backend
+ - treat the queue as a source of truth
+
+For problems that go beyond the scope of this library,
+suggestions will go into a cookbook.
+ +```mermaid +flowchart TD + +A(web) +B(task) +C(postgres) + +A-->C +B-->C +``` + +https://taskiq-python.github.io/guide/architecture-overview.html#context + +https://python-rq.org/docs/workers/ + From f97962a2c35855acbc8d4c28037ffc36302f6400 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Mon, 17 Feb 2025 09:55:17 -0500 Subject: [PATCH 09/19] Fix unit tests --- tests/unit/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 73a8178..28da8f1 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -22,7 +22,7 @@ def test_serialize_settings(test_settings): assert 'publish' in config assert config['publish'] == {} assert 'pg_notify' in config['brokers'] - assert config['service']['max_workers'] == 3 + assert config['service']['pool_kwargs']['max_workers'] == 3 re_loaded = DispatcherSettings(config) assert re_loaded.brokers == test_settings.brokers From 9706c1409a3080849b22dc1632f5d03777ec8d22 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Mon, 17 Feb 2025 16:52:32 -0500 Subject: [PATCH 10/19] Convert broker base to protocol --- dispatcher/brokers/base.py | 16 ++-------------- dispatcher/brokers/pg_notify.py | 3 +-- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/dispatcher/brokers/base.py b/dispatcher/brokers/base.py index 241fb47..5f493a2 100644 --- a/dispatcher/brokers/base.py +++ b/dispatcher/brokers/base.py @@ -1,25 +1,13 @@ -from abc import abstractmethod -from typing import Optional +from typing import Optional, Protocol -class BaseBroker: - @abstractmethod - async def connect(self): ... - - @abstractmethod +class BaseBroker(Protocol): async def aprocess_notify(self, connected_callback=None): ... - @abstractmethod async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: ... - @abstractmethod async def aclose(self) -> None: ... - @abstractmethod - def get_connection(self): ... - - @abstractmethod def publish_message(self, channel=None, message=None): ... - @abstractmethod def close(self): ... 
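Defining `BaseBroker` as a `typing.Protocol` rather than an abstract base class means brokers are matched structurally: a class needs compatible method signatures, not inheritance. A minimal sketch of that idea, using a trimmed-down protocol and a hypothetical `InMemoryBroker` that is not part of the codebase:

```python
from typing import Optional, Protocol


class MessageSink(Protocol):
    "Trimmed-down stand-in for the broker protocol, kept to one method for brevity"

    def publish_message(self, channel: Optional[str] = None, message: Optional[str] = None): ...


class InMemoryBroker:
    "Hypothetical broker used only for illustration; it never subclasses MessageSink"

    def __init__(self) -> None:
        self.sent: list[tuple[Optional[str], Optional[str]]] = []

    def publish_message(self, channel: Optional[str] = None, message: Optional[str] = None):
        self.sent.append((channel, message))


def submit(broker: MessageSink, message: str) -> None:
    # Static checkers accept InMemoryBroker here because its methods match the protocol
    broker.publish_message(message=message)


submit(InMemoryBroker(), 'lambda: "hello"')
```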
diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index 437ead3..b888e9f 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -3,7 +3,6 @@ import psycopg -from dispatcher.brokers.base import BaseBroker from dispatcher.utils import resolve_callable logger = logging.getLogger(__name__) @@ -17,7 +16,7 @@ """ -class PGNotifyBase(BaseBroker): +class PGNotifyBase: def __init__( self, From a489c1cc1b35e2c08e54af39d61a09f75be22396 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Tue, 18 Feb 2025 00:27:45 -0500 Subject: [PATCH 11/19] Refactor into single broker class --- dispatcher/brokers/__init__.py | 15 +- dispatcher/brokers/pg_notify.py | 153 ++++++++++---------- dispatcher/control.py | 6 +- dispatcher/factories.py | 14 +- dispatcher/registry.py | 4 +- dispatcher/tasks.py | 4 +- docs/config.md | 4 +- tests/conftest.py | 21 +-- tests/integration/brokers/test_pg_notify.py | 30 ++++ tests/integration/publish/test_registry.py | 2 +- tools/write_messages.py | 4 +- 11 files changed, 132 insertions(+), 125 deletions(-) create mode 100644 tests/integration/brokers/test_pg_notify.py diff --git a/dispatcher/brokers/__init__.py b/dispatcher/brokers/__init__.py index aec85d4..cd811a5 100644 --- a/dispatcher/brokers/__init__.py +++ b/dispatcher/brokers/__init__.py @@ -9,21 +9,12 @@ def get_broker_module(broker_name) -> ModuleType: return importlib.import_module(f'dispatcher.brokers.{broker_name}') -def get_async_broker(broker_name: str, broker_config: dict, **overrides) -> BaseBroker: +def get_broker(broker_name: str, broker_config: dict, **overrides) -> BaseBroker: """ Given the name of the broker in the settings, and the data under that entry in settings, - return the asyncio broker object. + return the broker object. """ broker_module = get_broker_module(broker_name) kwargs = broker_config.copy() kwargs.update(overrides) - return broker_module.AsyncBroker(**kwargs) - - -def get_sync_broker(broker_name, broker_config) -> BaseBroker: - """ - Given the name of the broker in the settings, and the data under that entry in settings, - return the synchronous broker object. 
- """ - broker_module = get_broker_module(broker_name) - return broker_module.SyncBroker(**broker_config) + return broker_module.Broker(**kwargs) diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index b888e9f..71709d7 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -16,20 +16,51 @@ """ -class PGNotifyBase: +async def acreate_connection(**config) -> psycopg.AsyncConnection: + "Create a new asyncio connection" + return await psycopg.AsyncConnection.connect(**config) + + +def create_connection(**config) -> psycopg.Connection: + return psycopg.Connection.connect(**config) + + +class Broker: def __init__( self, config: Optional[dict] = None, + async_connection_factory: Optional[str] = None, + sync_connection_factory: Optional[str] = None, + sync_connection: Optional[psycopg.Connection] = None, + async_connection: Optional[psycopg.AsyncConnection] = None, channels: Union[tuple, list] = (), default_publish_channel: Optional[str] = None, ) -> None: """ + config - kwargs to psycopg connect classes, if creating connection this way + (a)sync_connection_factory - importable path to callback for creating + the psycopg connection object, the normal or synchronous version + this will have the config passed as kwargs, if that is also given + async_connection - directly pass the async connection object + sync_connection - directly pass the async connection object channels - listening channels for the service and used for control-and-reply default_publish_channel - if not specified on task level or in the submission by default messages will be sent to this channel. this should be one of the listening channels for messages to be received. """ + if not (config or async_connection_factory or async_connection): + raise RuntimeError('Must specify either config or async_connection_factory') + + if not (config or sync_connection_factory or sync_connection): + raise RuntimeError('Must specify either config or sync_connection_factory') + + self._async_connection_factory = async_connection_factory + self._async_connection = async_connection + + self._sync_connection_factory = sync_connection_factory + self._sync_connection = sync_connection + if config: self._config: dict = config.copy() self._config['autocommit'] = True @@ -51,52 +82,26 @@ def get_publish_channel(self, channel: Optional[str] = None): raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config') - def get_connection_method(self, factory_path: Optional[str] = None) -> Callable: - "Handles settings, returns a method (async or sync) for getting a new connection" - if factory_path: - factory = resolve_callable(factory_path) - if not factory: - raise RuntimeError(f'Could not import connection factory {factory_path}') - return factory - elif self._config: - return self.create_connection - else: - raise RuntimeError('Could not construct connection for lack of config or factory') - - def create_connection(self): ... 
- - -class AsyncBroker(PGNotifyBase): - def __init__( - self, - config: Optional[dict] = None, - async_connection_factory: Optional[str] = None, - sync_connection_factory: Optional[str] = None, # noqa - connection: Optional[psycopg.AsyncConnection] = None, - **kwargs, - ) -> None: - if not (config or async_connection_factory or connection): - raise RuntimeError('Must specify either config or async_connection_factory') - - self._async_connection_factory = async_connection_factory - self._connection = connection - - super().__init__(config=config, **kwargs) - - async def get_connection(self) -> psycopg.AsyncConnection: - if not self._connection: - factory = self.get_connection_method(factory_path=self._async_connection_factory) - connection = await factory(**self._config) - self._connection = connection + # --- asyncio connection methods --- + + async def aget_connection(self) -> psycopg.AsyncConnection: + "Return existing connection or create a new one" + if not self._async_connection: + if self._async_connection_factory: + factory = resolve_callable(self._async_connection_factory) + if not factory: + raise RuntimeError(f'Could not import async connection factory {self._async_connection_factory}') + connection = await factory(**self._config) + elif self._config: + connection = await acreate_connection(**self._config) + else: + raise RuntimeError('Could not construct async connection for lack of config or factory') + self._async_connection = connection return connection # slightly weird due to MyPY - return self._connection + return self._async_connection - @staticmethod - async def create_connection(**config) -> psycopg.AsyncConnection: - return await psycopg.AsyncConnection.connect(**config) - - async def aprocess_notify(self, connected_callback=None): - connection = await self.get_connection() + async def aprocess_notify(self, connected_callback: Optional[Callable] = None): # public + connection = await self.aget_connection() async with connection.cursor() as cur: for channel in self.channels: await cur.execute(f"LISTEN {channel};") @@ -110,8 +115,8 @@ async def aprocess_notify(self, connected_callback=None): async for notify in connection.notifies(): yield notify.channel, notify.payload - async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: - connection = await self.get_connection() + async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: # public + connection = await self.aget_connection() channel = self.get_publish_channel(channel) async with connection.cursor() as cur: @@ -123,38 +128,26 @@ async def apublish_message(self, channel: Optional[str] = None, message: str = ' logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}') async def aclose(self) -> None: - if self._connection: - await self._connection.close() - self._connection = None - + if self._async_connection: + await self._async_connection.close() + self._async_connection = None -class SyncBroker(PGNotifyBase): - def __init__( - self, - config: Optional[dict] = None, - async_connection_factory: Optional[str] = None, # noqa - sync_connection_factory: Optional[str] = None, - connection: Optional[psycopg.Connection] = None, - **kwargs, - ) -> None: - if not (config or sync_connection_factory or connection): - raise RuntimeError('Must specify either config or async_connection_factory') - - self._sync_connection_factory = sync_connection_factory - self._connection = connection - super().__init__(config=config, **kwargs) + # --- 
synchronous connection methods --- def get_connection(self) -> psycopg.Connection: - if not self._connection: - factory = self.get_connection_method(factory_path=self._sync_connection_factory) - connection = factory(**self._config) - self._connection = connection + if not self._sync_connection: + if self._sync_connection_factory: + factory = resolve_callable(self._sync_connection_factory) + if not factory: + raise RuntimeError(f'Could not import connection factory {self._sync_connection_factory}') + connection = factory(**self._config) + elif self._config: + connection = create_connection(**self._config) + else: + raise RuntimeError('Could not construct connection for lack of config or factory') + self._sync_connection = connection return connection - return self._connection - - @staticmethod - def create_connection(**config) -> psycopg.Connection: - return psycopg.Connection.connect(**config) + return self._sync_connection def publish_message(self, channel: Optional[str] = None, message: str = '') -> None: connection = self.get_connection() @@ -169,9 +162,9 @@ def publish_message(self, channel: Optional[str] = None, message: str = '') -> N logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}') def close(self) -> None: - if self._connection: - self._connection.close() - self._connection = None + if self._sync_connection: + self._sync_connection.close() + self._sync_connection = None class ConnectionSaver: @@ -192,7 +185,7 @@ def connection_saver(**config) -> psycopg.Connection: """ if connection_save._connection is None: config['autocommit'] = True - connection_save._connection = SyncBroker.create_connection(**config) + connection_save._connection = create_connection(**config) return connection_save._connection @@ -205,5 +198,5 @@ async def async_connection_saver(**config) -> psycopg.AsyncConnection: """ if connection_save._async_connection is None: config['autocommit'] = True - connection_save._async_connection = await AsyncBroker.create_connection(**config) + connection_save._async_connection = await acreate_connection(**config) return connection_save._async_connection diff --git a/dispatcher/control.py b/dispatcher/control.py index 68cbe2e..35ae3d7 100644 --- a/dispatcher/control.py +++ b/dispatcher/control.py @@ -5,7 +5,7 @@ import uuid from typing import Optional -from dispatcher.factories import get_async_broker, get_sync_broker +from dispatcher.factories import get_broker from dispatcher.producers import BrokeredProducer logger = logging.getLogger('awx.main.dispatch.control') @@ -93,7 +93,7 @@ async def acontrol_with_reply_internal(self, producer, send_data, expected_repli return [json.loads(payload) for payload in control_callbacks.received_replies] def make_producer(self, reply_queue): - broker = get_async_broker(self.broker_name, self.broker_config, channels=[reply_queue]) + broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue]) return BrokeredProducer(broker, close_on_exit=True) async def acontrol_with_reply(self, command, expected_replies=1, timeout=1, data=None): @@ -140,5 +140,5 @@ def control(self, command, data=None): send_data['control_data'] = data payload = json.dumps(send_data) - broker = get_sync_broker(self.broker_name, self.broker_config) + broker = get_broker(self.broker_name, self.broker_config) broker.publish_message(channel=self.queuename, message=payload) diff --git a/dispatcher/factories.py b/dispatcher/factories.py index 60bc409..b54745c 100644 --- a/dispatcher/factories.py +++ b/dispatcher/factories.py @@ -1,7 
+1,7 @@ from typing import Iterable, Optional from dispatcher import producers -from dispatcher.brokers import get_async_broker, get_sync_broker +from dispatcher.brokers import get_broker from dispatcher.brokers.base import BaseBroker from dispatcher.config import LazySettings from dispatcher.config import settings as global_settings @@ -29,7 +29,7 @@ def pool_from_settings(settings: LazySettings = global_settings): def producers_from_settings(settings: LazySettings = global_settings) -> Iterable[producers.BaseProducer]: producer_objects = [] for broker_name, broker_kwargs in settings.brokers.items(): - broker = get_async_broker(broker_name, broker_kwargs) + broker = get_broker(broker_name, broker_kwargs) producer = producers.BrokeredProducer(broker=broker) producer_objects.append(producer) @@ -64,13 +64,7 @@ def _get_publisher_broker_name(publish_broker: Optional[str] = None, settings: L raise RuntimeError(f'Could not determine which broker to publish with between options {list(settings.brokers.keys())}') -def get_sync_publisher_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides) -> BaseBroker: - publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings) - - return get_sync_broker(publish_broker, settings.brokers[publish_broker], **overrides) - - -def get_async_publisher_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides) -> BaseBroker: +def get_publisher_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides) -> BaseBroker: """ An asynchronous publisher is the ideal choice for submitting control-and-reply actions. This returns an asyncio broker of the default publisher type. @@ -80,7 +74,7 @@ def get_async_publisher_from_settings(publish_broker: Optional[str] = None, sett unrelated traffic. 
""" publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings) - return get_async_broker(publish_broker, settings.brokers[publish_broker], **overrides) + return get_broker(publish_broker, settings.brokers[publish_broker], **overrides) def get_control_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides): diff --git a/dispatcher/registry.py b/dispatcher/registry.py index 3e872fc..f67be5d 100644 --- a/dispatcher/registry.py +++ b/dispatcher/registry.py @@ -89,9 +89,9 @@ def apply_async(self, args=None, kwargs=None, queue=None, uuid=None, settings: L obj = self.get_async_body(args=args, kwargs=kwargs, uuid=uuid, **kw) - from dispatcher.factories import get_sync_publisher_from_settings + from dispatcher.factories import get_publisher_from_settings - broker = get_sync_publisher_from_settings(settings=settings) + broker = get_publisher_from_settings(settings=settings) # TODO: exit if a setting is applied to disable publishing diff --git a/dispatcher/tasks.py b/dispatcher/tasks.py index a16d75d..710ecf6 100644 --- a/dispatcher/tasks.py +++ b/dispatcher/tasks.py @@ -1,8 +1,8 @@ -from dispatcher.factories import get_sync_publisher_from_settings +from dispatcher.factories import get_publisher_from_settings from dispatcher.publish import task @task() def reply_to_control(reply_channel: str, message: str): - broker = get_sync_publisher_from_settings() + broker = get_publisher_from_settings() broker.publish_message(channel=reply_channel, message=message) diff --git a/docs/config.md b/docs/config.md index 5bab855..7137512 100644 --- a/docs/config.md +++ b/docs/config.md @@ -75,9 +75,7 @@ When providing `pool_kwargs`, those are the kwargs passed to `WorkerPool`, for e Brokers relay messages which give instructions about code to run. Right now the only broker available is pg_notify. -The sub-options become python `kwargs` passed to the broker classes -`AsyncBroker` and `SyncBroker`, for the sychronous and asyncio versions -of the broker. +The sub-options become python `kwargs` passed to the broker class `Broker`. For now, you will just have to read the code to see what those options are at [dispatcher.brokers.pg_notify](dispatcher/brokers/pg_notify.py). 
diff --git a/tests/conftest.py b/tests/conftest.py index cf79404..de0db5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ from dispatcher.main import DispatcherMain from dispatcher.control import Control -from dispatcher.brokers.pg_notify import AsyncBroker +from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection from dispatcher.registry import DispatcherMethodRegistry from dispatcher.config import DispatcherSettings from dispatcher.factories import from_settings, get_control_from_settings @@ -41,7 +41,7 @@ async def aconnection_for_test(): conn = None try: - conn = await AsyncBroker.create_connection(conninfo=CONNECTION_STRING, autocommit=True) + conn = await acreate_connection(conninfo=CONNECTION_STRING, autocommit=True) # Make sure database is running to avoid deadlocks which can come # from using the loop provided by pytest asyncio @@ -92,14 +92,6 @@ async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]: await dispatcher.cancel_tasks() -@pytest_asyncio.fixture(loop_scope="function", scope="function") -async def pg_message(psycopg_conn) -> Callable: - async def _rf(message, channel=None): - broker = AsyncBroker(connection=psycopg_conn, default_publish_channel='test_channel') - await broker.apublish_message(channel=channel, message=message) - return _rf - - @pytest_asyncio.fixture(loop_scope="function", scope="function") async def pg_control(test_settings) -> AsyncIterator[Control]: return get_control_from_settings(settings=test_settings) @@ -111,6 +103,15 @@ async def psycopg_conn(): yield conn +@pytest_asyncio.fixture(loop_scope="function", scope="function") +async def pg_message(psycopg_conn) -> Callable: + async def _rf(message, channel=None): + # Note on weirdness here, this broker will only be used for async publishing, so we give junk for synchronous connection + broker = Broker(async_connection=psycopg_conn, default_publish_channel='test_channel', sync_connection_factory='tests.data.methods.something') + await broker.apublish_message(channel=channel, message=message) + return _rf + + @pytest.fixture def registry() -> DispatcherMethodRegistry: "Return a fresh registry, separate from the global one, for testing" diff --git a/tests/integration/brokers/test_pg_notify.py b/tests/integration/brokers/test_pg_notify.py new file mode 100644 index 0000000..9498794 --- /dev/null +++ b/tests/integration/brokers/test_pg_notify.py @@ -0,0 +1,30 @@ +import pytest + +from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection + + +def test_sync_connection_from_config_reuse(conn_config): + broker = Broker(config=conn_config) + conn = broker.get_connection() + with conn.cursor() as cur: + cur.execute('SELECT 1') + assert cur.fetchall() == [(1,)] + + conn2 = broker.get_connection() + assert conn is conn2 + + assert conn is not create_connection(**conn_config) + + +@pytest.mark.asyncio +async def test_async_connection_from_config_reuse(conn_config): + broker = Broker(config=conn_config) + conn = await broker.aget_connection() + async with conn.cursor() as cur: + await cur.execute('SELECT 1') + assert await cur.fetchall() == [(1,)] + + conn2 = await broker.aget_connection() + assert conn is conn2 + + assert conn is not await acreate_connection(**conn_config) diff --git a/tests/integration/publish/test_registry.py b/tests/integration/publish/test_registry.py index 412b8e8..4a6adcb 100644 --- a/tests/integration/publish/test_registry.py +++ b/tests/integration/publish/test_registry.py @@ -21,7 +21,7 @@ 
def test_method(): dmethod.apply_async() # But providing a queue at time of submission works - with mock.patch('dispatcher.brokers.pg_notify.SyncBroker.publish_message') as mock_publish_method: + with mock.patch('dispatcher.brokers.pg_notify.Broker.publish_message') as mock_publish_method: dmethod.apply_async(queue='fooqueue') mock_publish_method.assert_called_once_with(channel='fooqueue', message=mock.ANY) diff --git a/tools/write_messages.py b/tools/write_messages.py index c92fc75..771784a 100644 --- a/tools/write_messages.py +++ b/tools/write_messages.py @@ -4,7 +4,7 @@ import os import sys -from dispatcher.factories import get_sync_publisher_from_settings +from dispatcher.factories import get_publisher_from_settings from dispatcher.control import Control from dispatcher.utils import MODULE_METHOD_DELIMITER from dispatcher.config import setup @@ -23,7 +23,7 @@ setup(file_path='dispatcher.yml') -broker = get_sync_publisher_from_settings() +broker = get_publisher_from_settings() TEST_MSGS = [ From 7cc2b0118174953b16f96e197afe077caa3757c1 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Tue, 18 Feb 2025 09:36:39 -0500 Subject: [PATCH 12/19] Produce a reference schema --- dispatcher.yml | 1 + dispatcher/config.py | 5 +- dispatcher/factories.py | 54 ++++++++++++++++++++-- dispatcher/producers/scheduled.py | 2 +- docs/config.md | 8 ++++ schema.json | 25 ++++++++++ tests/conftest.py | 1 + tests/integration/publish/test_registry.py | 2 +- tests/unit/test_config.py | 27 ++++++++++- 9 files changed, 117 insertions(+), 8 deletions(-) create mode 100644 schema.json diff --git a/dispatcher.yml b/dispatcher.yml index 340cf6e..be9bc3e 100644 --- a/dispatcher.yml +++ b/dispatcher.yml @@ -1,5 +1,6 @@ # Demo config --- +version: 2 service: pool_kwargs: max_workers: 4 diff --git a/dispatcher/config.py b/dispatcher/config.py index 56bbb8b..6c86d5d 100644 --- a/dispatcher/config.py +++ b/dispatcher/config.py @@ -7,6 +7,9 @@ class DispatcherSettings: def __init__(self, config: dict) -> None: + self.version = 2 + if config.get('version') != self.version: + raise RuntimeError(f'Current config version is {self.version}, config version must match this') self.brokers: dict = config.get('brokers', {}) self.producers: dict = config.get('producers', {}) self.service: dict = config.get('service', {}) @@ -23,7 +26,7 @@ def __init__(self, config: dict) -> None: # self.options: dict = config.get('options', {}) def serialize(self): - return dict(brokers=self.brokers, producers=self.producers, service=self.service, publish=self.publish) + return dict(version=self.version, brokers=self.brokers, producers=self.producers, service=self.service, publish=self.publish) def settings_from_file(path: str) -> DispatcherSettings: diff --git a/dispatcher/factories.py b/dispatcher/factories.py index b54745c..455b592 100644 --- a/dispatcher/factories.py +++ b/dispatcher/factories.py @@ -1,4 +1,6 @@ -from typing import Iterable, Optional +from typing import Iterable, Optional, Type, get_origin, get_args +import inspect +from copy import deepcopy from dispatcher import producers from dispatcher.brokers import get_broker @@ -26,10 +28,13 @@ def pool_from_settings(settings: LazySettings = global_settings): return WorkerPool(**kwargs) +def brokers_from_settings(settings: LazySettings = global_settings) -> BaseBroker: + return [get_broker(broker_name, broker_kwargs) for broker_name, broker_kwargs in settings.brokers.items()] + + def producers_from_settings(settings: LazySettings = global_settings) -> Iterable[producers.BaseProducer]: 
producer_objects = [] - for broker_name, broker_kwargs in settings.brokers.items(): - broker = get_broker(broker_name, broker_kwargs) + for broker in brokers_from_settings(settings=settings): producer = producers.BrokeredProducer(broker=broker) producer_objects.append(producer) @@ -82,3 +87,46 @@ def get_control_from_settings(publish_broker: Optional[str] = None, settings: La broker_options = settings.brokers[publish_broker].copy() broker_options.update(overrides) return Control(publish_broker, broker_options) + + +# ---- Schema generation ---- + +SERIALIZED_TYPES = (int, str, dict, type(None), tuple, list) + + +def is_valid_annotation(annotation): + if get_origin(annotation): + for arg in get_args(annotation): + if not is_valid_annotation(arg): + return False + else: + if annotation not in SERIALIZED_TYPES: + return False + return True + + +def schema_for_cls(cls: Type) -> dict[str,str]: + signature = inspect.signature(cls.__init__) + parameters = signature.parameters + spec = {} + for k, p in parameters.items(): + if is_valid_annotation(p.annotation): + spec[k] = str(p.annotation) + return spec + + +def generate_settings_schema(settings: LazySettings = global_settings) -> dict: + ret = deepcopy(settings.serialize()) + + ret['service']['pool_kwargs'] = schema_for_cls(WorkerPool) + + for broker_name, broker_kwargs in settings.brokers.items(): + broker = get_broker(broker_name, broker_kwargs) + ret['brokers'][broker_name] = schema_for_cls(type(broker)) + + for producer_cls, producer_kwargs in settings.producers.items(): + ret['producers'][producer_cls] = schema_for_cls(getattr(producers, producer_cls)) + + ret['publish'] = {'default_broker': 'str'} + + return ret diff --git a/dispatcher/producers/scheduled.py b/dispatcher/producers/scheduled.py index 298dc61..07f4362 100644 --- a/dispatcher/producers/scheduled.py +++ b/dispatcher/producers/scheduled.py @@ -7,7 +7,7 @@ class ScheduledProducer(BaseProducer): - def __init__(self, task_schedule: dict): + def __init__(self, task_schedule: dict[str,dict[str,int]]): self.task_schedule = task_schedule self.scheduled_tasks: list[asyncio.Task] = [] super().__init__() diff --git a/docs/config.md b/docs/config.md index 7137512..09d4b36 100644 --- a/docs/config.md +++ b/docs/config.md @@ -55,6 +55,7 @@ The general structure is: ```yaml --- +version: # number service: pool_kwargs: # options @@ -70,6 +71,13 @@ publish: When providing `pool_kwargs`, those are the kwargs passed to `WorkerPool`, for example. +#### Version + +The version field is mandatory and must match the current config in the library. +This is validated against current code and saved in the [schema.json](schema.json) file. + +The version will be bumped when any breaking change happens. + #### Brokers Brokers relay messages which give instructions about code to run. 
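As a quick illustration of the new version gate (mirroring the `test_version_validated` unit test added later in this patch), constructing settings without the expected version fails fast:

```python
from dispatcher.config import DispatcherSettings

# Omitting the version raises at construction time
try:
    DispatcherSettings({'brokers': {'pg_notify': {'config': {}}}})
except RuntimeError as exc:
    print(exc)  # Current config version is 2, config version must match this

# Declaring the matching version lets the settings load normally
settings = DispatcherSettings({'version': 2, 'brokers': {'pg_notify': {'config': {}}}})
```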
diff --git a/schema.json b/schema.json new file mode 100644 index 0000000..be75890 --- /dev/null +++ b/schema.json @@ -0,0 +1,25 @@ +{ + "version": 2, + "brokers": { + "pg_notify": { + "config": "typing.Optional[dict]", + "async_connection_factory": "typing.Optional[str]", + "sync_connection_factory": "typing.Optional[str]", + "channels": "typing.Union[tuple, list]", + "default_publish_channel": "typing.Optional[str]" + } + }, + "producers": { + "ScheduledProducer": { + "task_schedule": "dict[str, dict[str, int]]" + } + }, + "service": { + "pool_kwargs": { + "max_workers": "" + } + }, + "publish": { + "default_broker": "str" + } +} diff --git a/tests/conftest.py b/tests/conftest.py index de0db5b..76836d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,6 +22,7 @@ CONNECTION_STRING = "dbname=dispatch_db user=dispatch password=dispatching host=localhost port=55777" BASIC_CONFIG = { + "version": 2, "brokers": { "pg_notify": { "channels": CHANNELS, diff --git a/tests/integration/publish/test_registry.py b/tests/integration/publish/test_registry.py index 4a6adcb..ed87a4e 100644 --- a/tests/integration/publish/test_registry.py +++ b/tests/integration/publish/test_registry.py @@ -14,7 +14,7 @@ def test_method(): dmethod = registry.get_from_callable(test_method) # These settings do not specify a default channel, that is the main point - with temporary_settings({'brokers': {'pg_notify': {'config': conn_config}}}): + with temporary_settings({'version': 2, 'brokers': {'pg_notify': {'config': conn_config}}}): # Can not run a method if we do not have a queue with pytest.raises(ValueError): diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 28da8f1..bbbd16e 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1,6 +1,11 @@ +import json + +import yaml + import pytest -from dispatcher.config import DispatcherSettings, LazySettings +from dispatcher.config import DispatcherSettings, LazySettings, temporary_settings +from dispatcher.factories import generate_settings_schema def test_settings_reference_unconfigured(): @@ -12,7 +17,7 @@ def test_settings_reference_unconfigured(): def test_configured_settings(): settings = LazySettings() - settings._wrapped = DispatcherSettings({'brokers': {'pg_notify': {'config': {}}}}) + settings._wrapped = DispatcherSettings({'version': 2, 'brokers': {'pg_notify': {'config': {}}}}) 'pg_notify' in settings.brokers @@ -27,3 +32,21 @@ def test_serialize_settings(test_settings): re_loaded = DispatcherSettings(config) assert re_loaded.brokers == test_settings.brokers assert re_loaded.service == test_settings.service + + +def test_version_validated(): + with pytest.raises(RuntimeError) as exc: + DispatcherSettings({}) + assert 'config version must match this' in str(exc) + + +def test_schema_is_current(): + with open('dispatcher.yml', 'r') as f: + file_contents = f.read() + demo_data = yaml.safe_load(file_contents) + with temporary_settings(demo_data): + expect_schema = generate_settings_schema() + with open('schema.json', 'r') as sch_f: + schema_contents = sch_f.read() + schema_data = json.loads(schema_contents) + assert schema_data == expect_schema From 739861be9b827f6512719d18e04d14a4064f692d Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Tue, 18 Feb 2025 09:38:54 -0500 Subject: [PATCH 13/19] Fix linters --- dispatcher/factories.py | 6 +++--- dispatcher/producers/scheduled.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dispatcher/factories.py b/dispatcher/factories.py index 455b592..8f98c64 
100644 --- a/dispatcher/factories.py +++ b/dispatcher/factories.py @@ -1,6 +1,6 @@ -from typing import Iterable, Optional, Type, get_origin, get_args import inspect from copy import deepcopy +from typing import Iterable, Optional, Type, get_args, get_origin from dispatcher import producers from dispatcher.brokers import get_broker @@ -28,7 +28,7 @@ def pool_from_settings(settings: LazySettings = global_settings): return WorkerPool(**kwargs) -def brokers_from_settings(settings: LazySettings = global_settings) -> BaseBroker: +def brokers_from_settings(settings: LazySettings = global_settings) -> Iterable[BaseBroker]: return [get_broker(broker_name, broker_kwargs) for broker_name, broker_kwargs in settings.brokers.items()] @@ -105,7 +105,7 @@ def is_valid_annotation(annotation): return True -def schema_for_cls(cls: Type) -> dict[str,str]: +def schema_for_cls(cls: Type) -> dict[str, str]: signature = inspect.signature(cls.__init__) parameters = signature.parameters spec = {} diff --git a/dispatcher/producers/scheduled.py b/dispatcher/producers/scheduled.py index 07f4362..edc4fd7 100644 --- a/dispatcher/producers/scheduled.py +++ b/dispatcher/producers/scheduled.py @@ -7,7 +7,7 @@ class ScheduledProducer(BaseProducer): - def __init__(self, task_schedule: dict[str,dict[str,int]]): + def __init__(self, task_schedule: dict[str, dict[str, int]]): self.task_schedule = task_schedule self.scheduled_tasks: list[asyncio.Task] = [] super().__init__() From a3517063c64da4af951e0d21a059be3e54c9f08d Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Tue, 18 Feb 2025 09:40:18 -0500 Subject: [PATCH 14/19] Fix link --- docs/config.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/config.md b/docs/config.md index 09d4b36..28036c2 100644 --- a/docs/config.md +++ b/docs/config.md @@ -74,7 +74,7 @@ When providing `pool_kwargs`, those are the kwargs passed to `WorkerPool`, for e #### Version The version field is mandatory and must match the current config in the library. -This is validated against current code and saved in the [schema.json](schema.json) file. +This is validated against current code and saved in the [schema.json](../schema.json) file. The version will be bumped when any breaking change happens. From 5a5125bad2f3698263239934a46907754d888446 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Wed, 19 Feb 2025 09:04:24 -0500 Subject: [PATCH 15/19] Add type hints to async generator --- dispatcher/brokers/base.py | 5 +++-- dispatcher/brokers/pg_notify.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dispatcher/brokers/base.py b/dispatcher/brokers/base.py index 5f493a2..48fe305 100644 --- a/dispatcher/brokers/base.py +++ b/dispatcher/brokers/base.py @@ -1,8 +1,9 @@ -from typing import Optional, Protocol +from typing import AsyncGenerator, Optional, Protocol class BaseBroker(Protocol): - async def aprocess_notify(self, connected_callback=None): ... + # NOTE: should be async def, but conflicts with lack of yield statement, which implementers would have + def aprocess_notify(self, connected_callback=None) -> AsyncGenerator[tuple[str, str], None]: ... async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: ... 
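To make the intent of that annotation concrete, a small sketch (not part of the patch) of how a caller drains the async generator; the `drain` helper name is made up for illustration:

```python
# Hypothetical consumer of any object satisfying the BaseBroker protocol
async def drain(broker) -> None:
    async for channel, payload in broker.aprocess_notify():
        print(f'notification on {channel}: {payload}')
```

This is the consumption pattern the pg_notify implementation supports by yielding `(notify.channel, notify.payload)` pairs.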
diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index 71709d7..28fb4f6 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Union +from typing import AsyncGenerator, Callable, Optional, Union import psycopg @@ -100,7 +100,7 @@ async def aget_connection(self) -> psycopg.AsyncConnection: return connection # slightly weird due to MyPY return self._async_connection - async def aprocess_notify(self, connected_callback: Optional[Callable] = None): # public + async def aprocess_notify(self, connected_callback: Optional[Callable] = None) -> AsyncGenerator[tuple[str, str], None]: # public connection = await self.aget_connection() async with connection.cursor() as cur: for channel in self.channels: From ea9c940b27020ff89f36b268658f5c4ad8a3f1f3 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Wed, 19 Feb 2025 09:10:52 -0500 Subject: [PATCH 16/19] Add type hint to make clear default enforcing --- dispatcher/brokers/pg_notify.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index 28fb4f6..7b19b2e 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -70,7 +70,7 @@ def __init__( self.channels = channels self.default_publish_channel = default_publish_channel - def get_publish_channel(self, channel: Optional[str] = None): + def get_publish_channel(self, channel: Optional[str] = None) -> str: "Handle default for the publishing channel for calls to publish_message, shared sync and async" if channel is not None: return channel From 883a8fbea05132fe99db00d1e546261284e32ba5 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Wed, 19 Feb 2025 09:19:29 -0500 Subject: [PATCH 17/19] Test fix from review --- tests/unit/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index bbbd16e..70717f7 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -18,7 +18,7 @@ def test_settings_reference_unconfigured(): def test_configured_settings(): settings = LazySettings() settings._wrapped = DispatcherSettings({'version': 2, 'brokers': {'pg_notify': {'config': {}}}}) - 'pg_notify' in settings.brokers + assert 'pg_notify' in settings.brokers def test_serialize_settings(test_settings): From 39c7259f9ab2d606372910096d2df8cae111df71 Mon Sep 17 00:00:00 2001 From: Alan Rominger Date: Wed, 19 Feb 2025 09:20:30 -0500 Subject: [PATCH 18/19] Schema re-gen docs --- docs/config.md | 6 ++++++ tools/gen_schema.py | 9 +++++++++ 2 files changed, 15 insertions(+) create mode 100644 tools/gen_schema.py diff --git a/docs/config.md b/docs/config.md index 28036c2..1116787 100644 --- a/docs/config.md +++ b/docs/config.md @@ -78,6 +78,12 @@ This is validated against current code and saved in the [schema.json](../schema. The version will be bumped when any breaking change happens. +You can re-generate the schema after making changes by running: + +``` +python tools/gen_schema.py > schema.json +``` + #### Brokers Brokers relay messages which give instructions about code to run. 
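Circling back to the `get_publish_channel` annotation from PATCH 16 above, a short sketch of the default-channel resolution it documents, assuming the method falls back to `default_publish_channel` as its docstring describes; the broker construction values are placeholders:

```python
from dispatcher.brokers.pg_notify import Broker

broker = Broker(config={'conninfo': '...'}, default_publish_channel='test_channel')

assert broker.get_publish_channel('fooqueue') == 'fooqueue'  # explicit channel wins
assert broker.get_publish_channel() == 'test_channel'        # falls back to the configured default
```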
diff --git a/tools/gen_schema.py b/tools/gen_schema.py
new file mode 100644
index 0000000..50e2e2d
--- /dev/null
+++ b/tools/gen_schema.py
@@ -0,0 +1,9 @@
+import json
+from dispatcher.config import setup
+from dispatcher.factories import generate_settings_schema
+
+setup(file_path='dispatcher.yml')
+
+data = generate_settings_schema()
+
+print(json.dumps(data, indent=2))

From 1e20eae562ca2597d6ad7e3bbcbb65caf405acf4 Mon Sep 17 00:00:00 2001
From: Alan Rominger
Date: Wed, 19 Feb 2025 12:10:14 -0500
Subject: [PATCH 19/19] Add yield to keep protocol method async

---
 dispatcher/brokers/base.py      | 4 ++--
 dispatcher/brokers/pg_notify.py | 5 +++++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/dispatcher/brokers/base.py b/dispatcher/brokers/base.py
index 48fe305..8a6ffc1 100644
--- a/dispatcher/brokers/base.py
+++ b/dispatcher/brokers/base.py
@@ -2,8 +2,8 @@
 
 
 class BaseBroker(Protocol):
-    # NOTE: should be async def, but conflicts with lack of yield statement, which implementers would have
-    def aprocess_notify(self, connected_callback=None) -> AsyncGenerator[tuple[str, str], None]: ...
+    async def aprocess_notify(self, connected_callback=None) -> AsyncGenerator[tuple[str, str], None]:
+        yield ('', '')  # yield affects CPython type https://github.com/python/mypy/pull/18422
 
     async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: ...
 
diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py
index 7b19b2e..4ffddb8 100644
--- a/dispatcher/brokers/pg_notify.py
+++ b/dispatcher/brokers/pg_notify.py
@@ -116,6 +116,11 @@ async def aprocess_notify(self, connected_callback: Optional[Callable] = None) -
             yield notify.channel, notify.payload
 
     async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None:  # public
+        """asyncio way to publish a message, used to send control in control-and-reply
+
+        Not strictly necessary for the service itself if it sends replies in the workers,
+        but this may change in the future.
+        """
         connection = await self.aget_connection()
         channel = self.get_publish_channel(channel)
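To close the loop on why the protocol body now contains a `yield`: with it present, type checkers treat `aprocess_notify` as an async-generator method, matching implementations such as the pg_notify broker. A minimal in-memory stand-in (purely illustrative, not part of the codebase; awaiting the callback is an assumption) that satisfies the same shape:

```python
from typing import AsyncGenerator, Callable, Optional


class InMemoryBroker:
    """Illustrative broker that replays canned (channel, payload) pairs."""

    def __init__(self, messages: list[tuple[str, str]]) -> None:
        self.messages = list(messages)

    async def aprocess_notify(self, connected_callback: Optional[Callable] = None) -> AsyncGenerator[tuple[str, str], None]:
        if connected_callback:
            await connected_callback()  # assumes an async callback
        for channel, payload in self.messages:
            yield channel, payload

    async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None:
        self.messages.append((channel or '', message))
```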