diff --git a/README.md b/README.md
index 045177a..7f2527b 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,18 @@
 [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/ansible/dispatcher/blob/main/LICENSE)
 
-Working space for dispatcher prototyping
-
-This is firstly intended to be a code split of:
+This is intended to be a working space for prototyping a code split of:
 
 As a part of doing the split, we also want to resolve a number of long-standing design and sustainability issues, thus, asyncio.
 
+The philosophy of the dispatcher is to have a limited scope
+as a "local" runner of background tasks, but to be composable
+so that it can be "wrapped" easily to enable clustering and
+distributed task management by apps using it.
+
 Licensed under [Apache Software License 2.0](LICENSE)
 
 ### Usage
@@ -37,20 +40,11 @@ def print_hello():
     print('hello world!!')
 ```
 
-#### Dispatcher service
-
-The dispatcher service needs to be running before you submit tasks.
-This does not make any attempts at message durability or confirmation.
-If you submit a task in an outage of the service, it will be dropped.
-
-There are 2 ways to run the dispatcher service:
-
-- Importing and running (code snippet below)
-- A CLI entrypoint `dispatcher-standalone` for demo purposes
+Additionally, you need to configure dispatcher somewhere in your import path.
+This tells dispatcher how to submit tasks to be run.
 
 ```python
-from dispatcher.main import DispatcherMain
-import asyncio
+from dispatcher.config import setup
 
 config = {
     "producers": {
@@ -63,13 +57,29 @@ config = {
     },
     "pool": {"max_workers": 4},
 }
-loop = asyncio.get_event_loop()
-dispatcher = DispatcherMain(config)
+setup(config)
+```
+
+For more on how to set up and the allowed options in the config,
+see the [config](docs/config.md) docs.
+
+#### Dispatcher service
+
+The dispatcher service needs to be running before you submit tasks.
+This does not make any attempts at message durability or confirmation.
+If you submit a task in an outage of the service, it will be dropped.
+
+There are 2 ways to run the dispatcher service:
-try:
-    loop.run_until_complete(dispatcher.main())
-finally:
-    loop.close()
+- Importing and running (code snippet below)
+- A CLI entrypoint `dispatcher-standalone` for demo purposes
+
+```python
+from dispatcher import run_service
+
+# After the setup() method has been called
+
+run_service()
 ```
 
 Configuration tells how to connect to postgres, and what channel(s) to listen to.
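The config can also be loaded from a file. Below is a minimal sketch of running the service that way, assuming a config file like the demo `dispatcher.yml` is present; setting the `DISPATCHER_CONFIG_FILE` environment variable is another option when no explicit `setup()` call is made.

```python
# Sketch: the file-based alternative to setup(config) shown above.
# Assumes a config file like the demo dispatcher.yml exists at this path.
from dispatcher import run_service
from dispatcher.config import setup

setup(file_path='dispatcher.yml')
run_service()
```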
@@ -88,6 +98,8 @@ The following code will submit `print_hello` to run in the background dispatcher ```python from test_methods import print_hello +# After the setup() method has been called + print_hello.delay() ``` @@ -96,6 +108,8 @@ Also valid: ```python from test_methods import print_hello +# After the setup() method has been called + print_hello.apply_async(args=[], kwargs={}) ``` diff --git a/dispatcher.yml b/dispatcher.yml index a34eab1..be9bc3e 100644 --- a/dispatcher.yml +++ b/dispatcher.yml @@ -1,20 +1,25 @@ # Demo config --- -pool: - max_workers: 3 -producers: - brokers: - # List of channels to listen on +version: 2 +service: + pool_kwargs: + max_workers: 4 +brokers: + pg_notify: + config: + conninfo: dbname=dispatch_db user=dispatch password=dispatching host=localhost port=55777 + sync_connection_factory: dispatcher.brokers.pg_notify.connection_saver channels: - test_channel - test_channel2 - test_channel3 - pg_notify: - # Database connection details - conninfo: dbname=dispatch_db user=dispatch password=dispatching host=localhost - port=55777 - scheduled: - 'lambda: __import__("time").sleep(1)': - schedule: 3 - 'lambda: __import__("time").sleep(2)': - schedule: 3 + default_publish_channel: test_channel +producers: + ScheduledProducer: + task_schedule: + 'lambda: __import__("time").sleep(1)': + schedule: 3 + 'lambda: __import__("time").sleep(2)': + schedule: 3 +publish: + default_broker: pg_notify diff --git a/dispatcher/__init__.py b/dispatcher/__init__.py index e69de29..ae79987 100644 --- a/dispatcher/__init__.py +++ b/dispatcher/__init__.py @@ -0,0 +1,21 @@ +import asyncio +import logging + +from dispatcher.factories import from_settings + +logger = logging.getLogger(__name__) + + +def run_service() -> None: + """ + Runs dispatcher task service (runs tasks due to messages from brokers and other local producers) + Before calling this you need to configure by calling dispatcher.config.setup + """ + loop = asyncio.get_event_loop() + dispatcher = from_settings() + try: + loop.run_until_complete(dispatcher.main()) + except KeyboardInterrupt: + logger.info('Dispatcher stopped by KeyboardInterrupt') + finally: + loop.close() diff --git a/dispatcher/brokers/__init__.py b/dispatcher/brokers/__init__.py new file mode 100644 index 0000000..cd811a5 --- /dev/null +++ b/dispatcher/brokers/__init__.py @@ -0,0 +1,20 @@ +import importlib +from types import ModuleType + +from .base import BaseBroker + + +def get_broker_module(broker_name) -> ModuleType: + "Static method to alias import_module so we use a consistent import path" + return importlib.import_module(f'dispatcher.brokers.{broker_name}') + + +def get_broker(broker_name: str, broker_config: dict, **overrides) -> BaseBroker: + """ + Given the name of the broker in the settings, and the data under that entry in settings, + return the broker object. 
+ """ + broker_module = get_broker_module(broker_name) + kwargs = broker_config.copy() + kwargs.update(overrides) + return broker_module.Broker(**kwargs) diff --git a/dispatcher/brokers/base.py b/dispatcher/brokers/base.py new file mode 100644 index 0000000..8a6ffc1 --- /dev/null +++ b/dispatcher/brokers/base.py @@ -0,0 +1,14 @@ +from typing import AsyncGenerator, Optional, Protocol + + +class BaseBroker(Protocol): + async def aprocess_notify(self, connected_callback=None) -> AsyncGenerator[tuple[str, str], None]: + yield ('', '') # yield affects CPython type https://github.com/python/mypy/pull/18422 + + async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: ... + + async def aclose(self) -> None: ... + + def publish_message(self, channel=None, message=None): ... + + def close(self): ... diff --git a/dispatcher/brokers/pg_notify.py b/dispatcher/brokers/pg_notify.py index bb36032..4ffddb8 100644 --- a/dispatcher/brokers/pg_notify.py +++ b/dispatcher/brokers/pg_notify.py @@ -1,7 +1,10 @@ import logging +from typing import AsyncGenerator, Callable, Optional, Union import psycopg +from dispatcher.utils import resolve_callable + logger = logging.getLogger(__name__) @@ -13,87 +16,192 @@ """ -# TODO: get database data from settings -# # As Django settings, may not use -# DATABASES = { -# "default": { -# "ENGINE": "django.db.backends.postgresql", -# "HOST": os.getenv("DB_HOST", "127.0.0.1"), -# "PORT": os.getenv("DB_PORT", 55777), -# "USER": os.getenv("DB_USER", "dispatch"), -# "PASSWORD": os.getenv("DB_PASSWORD", "dispatching"), -# "NAME": os.getenv("DB_NAME", "dispatch_db"), -# } -# } - - -async def aget_connection(config): - return await psycopg.AsyncConnection.connect(**config, autocommit=True) - - -def get_connection(config): - return psycopg.Connection.connect(**config, autocommit=True) - - -async def aprocess_notify(connection, channels, connected_callback=None): - async with connection.cursor() as cur: - for channel in channels: - await cur.execute(f"LISTEN {channel};") - logger.info(f"Set up pg_notify listening on channel '{channel}'") - - if connected_callback: - await connected_callback() - - while True: - logger.debug('Starting listening for pg_notify notifications') - async for notify in connection.notifies(): - yield notify.channel, notify.payload - - -async def apublish_message(connection, channel, payload=None): - async with connection.cursor() as cur: - if not payload: - await cur.execute(f'NOTIFY {channel};') +async def acreate_connection(**config) -> psycopg.AsyncConnection: + "Create a new asyncio connection" + return await psycopg.AsyncConnection.connect(**config) + + +def create_connection(**config) -> psycopg.Connection: + return psycopg.Connection.connect(**config) + + +class Broker: + + def __init__( + self, + config: Optional[dict] = None, + async_connection_factory: Optional[str] = None, + sync_connection_factory: Optional[str] = None, + sync_connection: Optional[psycopg.Connection] = None, + async_connection: Optional[psycopg.AsyncConnection] = None, + channels: Union[tuple, list] = (), + default_publish_channel: Optional[str] = None, + ) -> None: + """ + config - kwargs to psycopg connect classes, if creating connection this way + (a)sync_connection_factory - importable path to callback for creating + the psycopg connection object, the normal or synchronous version + this will have the config passed as kwargs, if that is also given + async_connection - directly pass the async connection object + sync_connection - directly pass the 
async connection object + channels - listening channels for the service and used for control-and-reply + default_publish_channel - if not specified on task level or in the submission + by default messages will be sent to this channel. + this should be one of the listening channels for messages to be received. + """ + if not (config or async_connection_factory or async_connection): + raise RuntimeError('Must specify either config or async_connection_factory') + + if not (config or sync_connection_factory or sync_connection): + raise RuntimeError('Must specify either config or sync_connection_factory') + + self._async_connection_factory = async_connection_factory + self._async_connection = async_connection + + self._sync_connection_factory = sync_connection_factory + self._sync_connection = sync_connection + + if config: + self._config: dict = config.copy() + self._config['autocommit'] = True else: - await cur.execute(f"NOTIFY {channel}, '{payload}';") - - -def get_django_connection(): - try: - from django.conf import ImproperlyConfigured - from django.db import connection as pg_connection - except ImportError: - return None - else: - try: - if pg_connection.connection is None: - pg_connection.connect() - if pg_connection.connection is None: - raise RuntimeError('Unexpectedly could not connect to postgres for pg_notify actions') - return pg_connection.connection - except ImproperlyConfigured: - return None - - -def publish_message(queue, message, config=None, connection=None, new_connection=False): - conn = None - if connection: - conn = connection - - if (not conn) and (not new_connection): - conn = get_django_connection() - - created_new_conn = False - if not conn: - if config is None: - raise RuntimeError('Could not use Django connection, and no postgres config supplied') - conn = get_connection(config) - created_new_conn = True - - with conn.cursor() as cur: - cur.execute('SELECT pg_notify(%s, %s);', (queue, message)) - - logger.debug(f'Sent pg_notify message to {queue}') - - if created_new_conn: - conn.close() + self._config = {} + + self.channels = channels + self.default_publish_channel = default_publish_channel + + def get_publish_channel(self, channel: Optional[str] = None) -> str: + "Handle default for the publishing channel for calls to publish_message, shared sync and async" + if channel is not None: + return channel + elif self.default_publish_channel is not None: + return self.default_publish_channel + elif len(self.channels) == 1: + # de-facto default channel, because there is only 1 + return self.channels[0] + + raise ValueError('Could not determine a channel to use publish to from settings or PGNotify config') + + # --- asyncio connection methods --- + + async def aget_connection(self) -> psycopg.AsyncConnection: + "Return existing connection or create a new one" + if not self._async_connection: + if self._async_connection_factory: + factory = resolve_callable(self._async_connection_factory) + if not factory: + raise RuntimeError(f'Could not import async connection factory {self._async_connection_factory}') + connection = await factory(**self._config) + elif self._config: + connection = await acreate_connection(**self._config) + else: + raise RuntimeError('Could not construct async connection for lack of config or factory') + self._async_connection = connection + return connection # slightly weird due to MyPY + return self._async_connection + + async def aprocess_notify(self, connected_callback: Optional[Callable] = None) -> AsyncGenerator[tuple[str, str], None]: # public + 
connection = await self.aget_connection() + async with connection.cursor() as cur: + for channel in self.channels: + await cur.execute(f"LISTEN {channel};") + logger.info(f"Set up pg_notify listening on channel '{channel}'") + + if connected_callback: + await connected_callback() + + while True: + logger.debug('Starting listening for pg_notify notifications') + async for notify in connection.notifies(): + yield notify.channel, notify.payload + + async def apublish_message(self, channel: Optional[str] = None, message: str = '') -> None: # public + """asyncio way to publish a message, used to send control in control-and-reply + + Not strictly necessary for the service itself if it sends replies in the workers, + but this may change in the future. + """ + connection = await self.aget_connection() + channel = self.get_publish_channel(channel) + + async with connection.cursor() as cur: + if not message: + await cur.execute(f'NOTIFY {channel};') + else: + await cur.execute(f"NOTIFY {channel}, '{message}';") + + logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}') + + async def aclose(self) -> None: + if self._async_connection: + await self._async_connection.close() + self._async_connection = None + + # --- synchronous connection methods --- + + def get_connection(self) -> psycopg.Connection: + if not self._sync_connection: + if self._sync_connection_factory: + factory = resolve_callable(self._sync_connection_factory) + if not factory: + raise RuntimeError(f'Could not import connection factory {self._sync_connection_factory}') + connection = factory(**self._config) + elif self._config: + connection = create_connection(**self._config) + else: + raise RuntimeError('Could not construct connection for lack of config or factory') + self._sync_connection = connection + return connection + return self._sync_connection + + def publish_message(self, channel: Optional[str] = None, message: str = '') -> None: + connection = self.get_connection() + channel = self.get_publish_channel(channel) + + with connection.cursor() as cur: + if message: + cur.execute('SELECT pg_notify(%s, %s);', (channel, message)) + else: + cur.execute(f'NOTIFY {channel};') + + logger.debug(f'Sent pg_notify message of {len(message)} chars to {channel}') + + def close(self) -> None: + if self._sync_connection: + self._sync_connection.close() + self._sync_connection = None + + +class ConnectionSaver: + def __init__(self) -> None: + self._connection: Optional[psycopg.Connection] = None + self._async_connection: Optional[psycopg.AsyncConnection] = None + + +connection_save = ConnectionSaver() + + +def connection_saver(**config) -> psycopg.Connection: + """ + This mimics the behavior of Django for tests and demos + Philosophically, this is used by an application that uses an ORM, + or otherwise has its own connection management logic. + Dispatcher does not manage connections, so this a simulation of that. + """ + if connection_save._connection is None: + config['autocommit'] = True + connection_save._connection = create_connection(**config) + return connection_save._connection + + +async def async_connection_saver(**config) -> psycopg.AsyncConnection: + """ + This mimics the behavior of Django for tests and demos + Philosophically, this is used by an application that uses an ORM, + or otherwise has its own connection management logic. + Dispatcher does not manage connections, so this a simulation of that. 
+ """ + if connection_save._async_connection is None: + config['autocommit'] = True + connection_save._async_connection = await acreate_connection(**config) + return connection_save._async_connection diff --git a/dispatcher/cli.py b/dispatcher/cli.py index 3b77dca..0b9276a 100644 --- a/dispatcher/cli.py +++ b/dispatcher/cli.py @@ -1,12 +1,10 @@ import argparse -import asyncio import logging import os import sys -import yaml - -from dispatcher.main import DispatcherMain +from dispatcher import run_service +from dispatcher.config import setup logger = logging.getLogger(__name__) @@ -32,16 +30,6 @@ def standalone() -> None: logger.debug(f"Configured standard out logging at {args.log_level} level") - with open(args.config, 'r') as f: - config_content = f.read() - - config = yaml.safe_load(config_content) + setup(file_path=args.config) - loop = asyncio.get_event_loop() - dispatcher = DispatcherMain(config) - try: - loop.run_until_complete(dispatcher.main()) - except KeyboardInterrupt: - logger.info('CLI entry point leaving') - finally: - loop.close() + run_service() diff --git a/dispatcher/config.py b/dispatcher/config.py new file mode 100644 index 0000000..6c86d5d --- /dev/null +++ b/dispatcher/config.py @@ -0,0 +1,78 @@ +import os +from contextlib import contextmanager +from typing import Optional + +import yaml + + +class DispatcherSettings: + def __init__(self, config: dict) -> None: + self.version = 2 + if config.get('version') != self.version: + raise RuntimeError(f'Current config version is {self.version}, config version must match this') + self.brokers: dict = config.get('brokers', {}) + self.producers: dict = config.get('producers', {}) + self.service: dict = config.get('service', {}) + self.publish: dict = config.get('publish', {}) + + # Automatic defaults + if 'pool_kwargs' not in self.service: + self.service['pool_kwargs'] = {} + if 'max_workers' not in self.service['pool_kwargs']: + self.service['pool_kwargs']['max_workers'] = 3 + + # TODO: firmly planned sections of config for later + # self.callbacks: dict = config.get('callbacks', {}) + # self.options: dict = config.get('options', {}) + + def serialize(self): + return dict(version=self.version, brokers=self.brokers, producers=self.producers, service=self.service, publish=self.publish) + + +def settings_from_file(path: str) -> DispatcherSettings: + with open(path, 'r') as f: + config_content = f.read() + + config = yaml.safe_load(config_content) + return DispatcherSettings(config) + + +def settings_from_env() -> DispatcherSettings: + if file_path := os.getenv('DISPATCHER_CONFIG_FILE'): + return settings_from_file(file_path) + raise RuntimeError('Dispatcher not configured, set DISPATCHER_CONFIG_FILE or call dispatcher.config.setup') + + +class LazySettings: + def __init__(self) -> None: + self._wrapped: Optional[DispatcherSettings] = None + + def __getattr__(self, name): + if self._wrapped is None: + self._setup() + return getattr(self._wrapped, name) + + def _setup(self) -> None: + self._wrapped = settings_from_env() + + +settings = LazySettings() + + +def setup(config: Optional[dict] = None, file_path: Optional[str] = None): + if config: + settings._wrapped = DispatcherSettings(config) + elif file_path: + settings._wrapped = settings_from_file(file_path) + else: + settings._wrapped = settings_from_env() + + +@contextmanager +def temporary_settings(config): + prior_settings = settings._wrapped + try: + settings._wrapped = DispatcherSettings(config) + yield settings + finally: + settings._wrapped = prior_settings diff --git 
a/dispatcher/control.py b/dispatcher/control.py index 5e52260..35ae3d7 100644 --- a/dispatcher/control.py +++ b/dispatcher/control.py @@ -3,13 +3,19 @@ import logging import time import uuid -from types import SimpleNamespace +from typing import Optional -from dispatcher.producers.brokered import BrokeredProducer +from dispatcher.factories import get_broker +from dispatcher.producers import BrokeredProducer logger = logging.getLogger('awx.main.dispatch.control') +class ControlEvents: + def __init__(self) -> None: + self.exit_event = asyncio.Event() + + class ControlCallbacks: """This calls follows the same structure as the DispatcherMain class @@ -22,20 +28,17 @@ def __init__(self, queuename, send_data, expected_replies): self.expected_replies = expected_replies self.received_replies = [] - self.events = self._create_events() + self.events = ControlEvents() self.shutting_down = False - def _create_events(self): - return SimpleNamespace(exit_event=asyncio.Event()) - - async def process_message(self, payload, broker=None, channel=None): + async def process_message(self, payload, producer=None, channel=None): self.received_replies.append(payload) if self.expected_replies and (len(self.received_replies) >= self.expected_replies): self.events.exit_event.set() async def connected_callback(self, producer) -> None: payload = json.dumps(self.send_data) - await producer.notify(self.queuename, payload) + await producer.notify(channel=self.queuename, message=payload) logger.info('Sent control message, expecting replies soon') def fatal_error_callback(self, *args): @@ -55,10 +58,10 @@ def fatal_error_callback(self, *args): class Control(object): - def __init__(self, queue, config=None, async_connection=None): + def __init__(self, broker_name: str, broker_config: dict, queue: Optional[str] = None) -> None: self.queuename = queue - self.config = config - self.async_connection = async_connection + self.broker_name = broker_name + self.broker_config = broker_config def running(self, *args, **kwargs): return self.control_with_reply('running', *args, **kwargs) @@ -90,11 +93,8 @@ async def acontrol_with_reply_internal(self, producer, send_data, expected_repli return [json.loads(payload) for payload in control_callbacks.received_replies] def make_producer(self, reply_queue): - if self.async_connection: - conn_kwargs = {'connection': self.async_connection} - else: - conn_kwargs = {'config': self.config} - return BrokeredProducer(broker='pg_notify', channels=[reply_queue], **conn_kwargs) + broker = get_broker(self.broker_name, self.broker_config, channels=[reply_queue]) + return BrokeredProducer(broker, close_on_exit=True) async def acontrol_with_reply(self, command, expected_replies=1, timeout=1, data=None): reply_queue = Control.generate_reply_queue_name() @@ -117,10 +117,6 @@ def control_with_reply(self, command, expected_replies=1, timeout=1, data=None): logger.info('control-and-reply {} to {}'.format(command, self.queuename)) start = time.time() reply_queue = Control.generate_reply_queue_name() - - if (not self.config) and (not self.async_connection): - raise RuntimeError('Must use a new psycopg connection to do control-and-reply') - send_data = {'control': command, 'reply_to': reply_queue} if data: send_data['control_data'] = data @@ -137,13 +133,12 @@ def control_with_reply(self, command, expected_replies=1, timeout=1, data=None): logger.info(f'control-and-reply message returned in {time.time() - start} seconds') return replies - # NOTE: this is the synchronous version, only to be used for no-reply def 
control(self, command, data=None): - from dispatcher.brokers.pg_notify import publish_message - + "Send message in fire-and-forget mode, as synchronous code. Only for no-reply control." send_data = {'control': command} if data: send_data['control_data'] = data payload = json.dumps(send_data) - publish_message(self.queuename, payload, config=self.config) + broker = get_broker(self.broker_name, self.broker_config) + broker.publish_message(channel=self.queuename, message=payload) diff --git a/dispatcher/factories.py b/dispatcher/factories.py new file mode 100644 index 0000000..8f98c64 --- /dev/null +++ b/dispatcher/factories.py @@ -0,0 +1,132 @@ +import inspect +from copy import deepcopy +from typing import Iterable, Optional, Type, get_args, get_origin + +from dispatcher import producers +from dispatcher.brokers import get_broker +from dispatcher.brokers.base import BaseBroker +from dispatcher.config import LazySettings +from dispatcher.config import settings as global_settings +from dispatcher.control import Control +from dispatcher.main import DispatcherMain +from dispatcher.pool import WorkerPool +from dispatcher.process import ProcessManager + +""" +Creates objects from settings, +This is kept separate from the settings and the class definitions themselves, +which is to avoid import dependencies. +""" + +# ---- Service objects ---- + + +def pool_from_settings(settings: LazySettings = global_settings): + kwargs = settings.service.get('pool_kwargs', {}).copy() + kwargs['settings'] = settings + kwargs['process_manager'] = ProcessManager() # TODO: use process_manager_cls from settings + return WorkerPool(**kwargs) + + +def brokers_from_settings(settings: LazySettings = global_settings) -> Iterable[BaseBroker]: + return [get_broker(broker_name, broker_kwargs) for broker_name, broker_kwargs in settings.brokers.items()] + + +def producers_from_settings(settings: LazySettings = global_settings) -> Iterable[producers.BaseProducer]: + producer_objects = [] + for broker in brokers_from_settings(settings=settings): + producer = producers.BrokeredProducer(broker=broker) + producer_objects.append(producer) + + for producer_cls, producer_kwargs in settings.producers.items(): + producer_objects.append(getattr(producers, producer_cls)(**producer_kwargs)) + + return producer_objects + + +def from_settings(settings: LazySettings = global_settings) -> DispatcherMain: + """ + Returns the main dispatcher object, used for running the background task service. + You could initialize this yourself, but using the shared settings allows for consistency + between the service, publisher, and any other interacting processes. + """ + producers = producers_from_settings(settings=settings) + pool = pool_from_settings(settings=settings) + return DispatcherMain(producers, pool) + + +# ---- Publisher objects ---- + + +def _get_publisher_broker_name(publish_broker: Optional[str] = None, settings: LazySettings = global_settings) -> str: + if publish_broker: + return publish_broker + elif len(settings.brokers) == 1: + return list(settings.brokers.keys())[0] + elif 'default_broker' in settings.publish: + return settings.publish['default_broker'] + else: + raise RuntimeError(f'Could not determine which broker to publish with between options {list(settings.brokers.keys())}') + + +def get_publisher_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides) -> BaseBroker: + """ + An asynchronous publisher is the ideal choice for submitting control-and-reply actions. 
+ This returns an asyncio broker of the default publisher type. + + If channels are specified, these completely replace the channel list from settings. + For control-and-reply, this will contain only the reply_to channel, to not receive + unrelated traffic. + """ + publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings) + return get_broker(publish_broker, settings.brokers[publish_broker], **overrides) + + +def get_control_from_settings(publish_broker: Optional[str] = None, settings: LazySettings = global_settings, **overrides): + publish_broker = _get_publisher_broker_name(publish_broker=publish_broker, settings=settings) + broker_options = settings.brokers[publish_broker].copy() + broker_options.update(overrides) + return Control(publish_broker, broker_options) + + +# ---- Schema generation ---- + +SERIALIZED_TYPES = (int, str, dict, type(None), tuple, list) + + +def is_valid_annotation(annotation): + if get_origin(annotation): + for arg in get_args(annotation): + if not is_valid_annotation(arg): + return False + else: + if annotation not in SERIALIZED_TYPES: + return False + return True + + +def schema_for_cls(cls: Type) -> dict[str, str]: + signature = inspect.signature(cls.__init__) + parameters = signature.parameters + spec = {} + for k, p in parameters.items(): + if is_valid_annotation(p.annotation): + spec[k] = str(p.annotation) + return spec + + +def generate_settings_schema(settings: LazySettings = global_settings) -> dict: + ret = deepcopy(settings.serialize()) + + ret['service']['pool_kwargs'] = schema_for_cls(WorkerPool) + + for broker_name, broker_kwargs in settings.brokers.items(): + broker = get_broker(broker_name, broker_kwargs) + ret['brokers'][broker_name] = schema_for_cls(type(broker)) + + for producer_cls, producer_kwargs in settings.producers.items(): + ret['producers'][producer_cls] = schema_for_cls(getattr(producers, producer_cls)) + + ret['publish'] = {'default_broker': 'str'} + + return ret diff --git a/dispatcher/main.py b/dispatcher/main.py index 419fd26..ce0b398 100644 --- a/dispatcher/main.py +++ b/dispatcher/main.py @@ -3,13 +3,10 @@ import logging import signal from types import SimpleNamespace -from typing import Optional, Union +from typing import Iterable, Optional from dispatcher.pool import WorkerPool -from dispatcher.producers.base import BaseProducer -from dispatcher.producers.brokered import BrokeredProducer -from dispatcher.producers.scheduled import ScheduledProducer -from dispatcher.utils import MODULE_METHOD_DELIMITER +from dispatcher.producers import BaseProducer logger = logging.getLogger(__name__) @@ -79,7 +76,7 @@ def __init__(self) -> None: class DispatcherMain: - def __init__(self, config: dict): + def __init__(self, producers: Iterable[BaseProducer], pool: WorkerPool): self.delayed_messages: list[SimpleNamespace] = [] self.received_count = 0 self.control_count = 0 @@ -88,21 +85,11 @@ def __init__(self, config: dict): # Lock for file descriptor mgmnt - hold lock when forking or connecting, to avoid DNS hangs # psycopg is well-behaved IFF you do not connect while forking, compare to AWX __clean_on_fork__ self.fd_lock = asyncio.Lock() - self.pool = WorkerPool(config.get('pool', {}).get('max_workers', 3), self.fd_lock) - - # Initialize all the producers, this should not start anything, just establishes objects - self.producers: list[Union[ScheduledProducer, BrokeredProducer]] = [] - if 'producers' in config: - producer_config = config['producers'] - if 'brokers' in producer_config: - for broker_name, 
broker_config in producer_config['brokers'].items(): - # TODO: import from the broker module here, some importlib stuff - # TODO: make channels specific to broker, probably - if broker_name != 'pg_notify': - continue - self.producers.append(BrokeredProducer(broker=broker_name, config=broker_config, channels=producer_config['brokers']['channels'])) - if 'scheduled' in producer_config: - self.producers.append(ScheduledProducer(producer_config['scheduled'])) + + # Save the associated dispatcher objects, usually created by factories + # expected that these are not yet running any tasks + self.pool = pool + self.producers = producers self.events: DispatcherEvents = DispatcherEvents() @@ -183,7 +170,7 @@ def create_delayed_task(self, message: dict) -> None: capsule.task = new_task self.delayed_messages.append(capsule) - async def process_message(self, payload: dict, broker: Optional[BrokeredProducer] = None, channel: Optional[str] = None) -> None: + async def process_message(self, payload: dict, producer: Optional[BaseProducer] = None, channel: Optional[str] = None) -> None: # Convert payload from client into python dict # TODO: more structured validation of the incoming payload from publishers if isinstance(payload, str): @@ -208,9 +195,9 @@ async def process_message(self, payload: dict, broker: Optional[BrokeredProducer # NOTE: control messages with reply should never be delayed, document this for users self.create_delayed_task(message) else: - await self.process_message_internal(message, broker=broker) + await self.process_message_internal(message, producer=producer) - async def process_message_internal(self, message: dict, broker=None) -> None: + async def process_message_internal(self, message: dict, producer=None) -> None: if 'control' in message: method = getattr(self.ctl_tasks, message['control']) control_data = message.get('control_data', {}) @@ -220,9 +207,8 @@ async def process_message_internal(self, message: dict, broker=None) -> None: self.control_count += 1 await self.pool.dispatch_task( { - 'task': f'dispatcher.brokers.{broker.broker}{MODULE_METHOD_DELIMITER}publish_message', + 'task': 'dispatcher.tasks.reply_to_control', 'args': [message['reply_to'], json.dumps(returned)], - 'kwargs': {'config': broker.config, 'new_connection': True}, 'uuid': f'control-{self.control_count}', 'control': 'reply', # for record keeping } diff --git a/dispatcher/pool.py b/dispatcher/pool.py index 0b25909..b55b758 100644 --- a/dispatcher/pool.py +++ b/dispatcher/pool.py @@ -4,6 +4,8 @@ from asyncio import Task from typing import Iterator, Optional +from dispatcher.config import LazySettings +from dispatcher.config import settings as global_settings from dispatcher.process import ProcessManager, ProcessProxy from dispatcher.utils import DuplicateBehavior, MessageAction @@ -94,11 +96,12 @@ def __init__(self) -> None: class WorkerPool: - def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None): - self.num_workers = num_workers + def __init__(self, max_workers: int, process_manager: ProcessManager, settings: LazySettings = global_settings): + self.max_workers = max_workers self.workers: dict[int, PoolWorker] = {} + self.settings_stash: dict = settings.serialize() # These are passed to the workers to initialize dispatcher settings self.next_worker_id = 0 - self.process_manager = ProcessManager() + self.process_manager = process_manager self.queued_messages: list[dict] = [] # TODO: use deque, invent new kinds of logging anxiety self.read_results_task: Optional[Task] = None 
self.start_worker_task: Optional[Task] = None @@ -109,7 +112,6 @@ def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None): self.discard_count: int = 0 self.shutdown_timeout = 3 self.management_lock = asyncio.Lock() - self.fd_lock = fd_lock or asyncio.Lock() self.events: PoolEvents = PoolEvents() @@ -124,15 +126,15 @@ def received_count(self): async def start_working(self, dispatcher) -> None: self.read_results_task = asyncio.create_task(self.read_results_forever(), name='results_task') self.read_results_task.add_done_callback(dispatcher.fatal_error_callback) - self.management_task = asyncio.create_task(self.manage_workers(), name='management_task') + self.management_task = asyncio.create_task(self.manage_workers(forking_lock=dispatcher.fd_lock), name='management_task') self.management_task.add_done_callback(dispatcher.fatal_error_callback) self.timeout_task = asyncio.create_task(self.manage_timeout(), name='timeout_task') self.timeout_task.add_done_callback(dispatcher.fatal_error_callback) - async def manage_workers(self) -> None: + async def manage_workers(self, forking_lock: asyncio.Lock) -> None: """Enforces worker policy like min and max workers, and later, auto scale-down""" while not self.shutting_down: - while len(self.workers) < self.num_workers: + while len(self.workers) < self.max_workers: await self.up() # TODO: if all workers are busy, queue has unblocked work, below max_workers @@ -141,7 +143,7 @@ async def manage_workers(self) -> None: for worker in self.workers.values(): if worker.status == 'initialized': logger.debug(f'Starting subprocess for worker {worker.worker_id}') - async with self.fd_lock: # never fork while connecting + async with forking_lock: # never fork while connecting await worker.start() await self.events.management_event.wait() @@ -186,7 +188,12 @@ async def manage_timeout(self) -> None: self.events.timeout_event.clear() async def up(self) -> None: - process = self.process_manager.create_process((self.next_worker_id,)) + process = self.process_manager.create_process( + ( + self.settings_stash, + self.next_worker_id, + ) + ) worker = PoolWorker(self.next_worker_id, process) self.workers[self.next_worker_id] = worker self.next_worker_id += 1 diff --git a/dispatcher/process.py b/dispatcher/process.py index ecfce2c..5150b9c 100644 --- a/dispatcher/process.py +++ b/dispatcher/process.py @@ -46,7 +46,7 @@ def get_event_loop(self): self._loop = asyncio.get_event_loop() return self._loop - def create_process(self, args: Iterable[int | str], **kwargs) -> ProcessProxy: + def create_process(self, args: Iterable[int | str | dict], **kwargs) -> ProcessProxy: return ProcessProxy(args, self.finished_queue, **kwargs) async def read_finished(self) -> dict[str, Union[str, int]]: diff --git a/dispatcher/producers/__init__.py b/dispatcher/producers/__init__.py new file mode 100644 index 0000000..050e1a2 --- /dev/null +++ b/dispatcher/producers/__init__.py @@ -0,0 +1,5 @@ +from .base import BaseProducer +from .brokered import BrokeredProducer +from .scheduled import ScheduledProducer + +__all__ = ['BaseProducer', 'BrokeredProducer', 'ScheduledProducer'] diff --git a/dispatcher/producers/base.py b/dispatcher/producers/base.py index c1da07d..6ebfd74 100644 --- a/dispatcher/producers/base.py +++ b/dispatcher/producers/base.py @@ -7,4 +7,11 @@ def __init__(self): class BaseProducer: - pass + + def __init__(self) -> None: + self.events = ProducerEvents() + self.produced_count = 0 + + async def start_producing(self, dispatcher) -> None: ... 
+ + async def shutdown(self): ... diff --git a/dispatcher/producers/brokered.py b/dispatcher/producers/brokered.py index c7797e4..b40987b 100644 --- a/dispatcher/producers/brokered.py +++ b/dispatcher/producers/brokered.py @@ -2,26 +2,21 @@ import logging from typing import Optional -from dispatcher.brokers.pg_notify import aget_connection, aprocess_notify, apublish_message -from dispatcher.producers.base import BaseProducer, ProducerEvents +from dispatcher.brokers.base import BaseBroker +from dispatcher.producers.base import BaseProducer logger = logging.getLogger(__name__) class BrokeredProducer(BaseProducer): - def __init__(self, broker: str = 'pg_notify', config: Optional[dict] = None, channels: tuple = (), connection=None) -> None: - self.events = ProducerEvents() + def __init__(self, broker: BaseBroker, close_on_exit: bool = True) -> None: self.production_task: Optional[asyncio.Task] = None self.broker = broker - self.config = config - self.channels = channels - self.connection = connection - self.old_connection = bool(connection) + self.close_on_exit = close_on_exit self.dispatcher = None + super().__init__() async def start_producing(self, dispatcher) -> None: - await self.connect() - self.production_task = asyncio.create_task(self.produce_forever(dispatcher), name=f'{self.broker}_production') # TODO: implement connection retry logic self.production_task.add_done_callback(dispatcher.fatal_error_callback) @@ -31,22 +26,20 @@ def all_tasks(self) -> list[asyncio.Task]: return [self.production_task] return [] - async def connect(self): - if self.connection is None: - self.connection = await aget_connection(self.config) - async def connected_callback(self) -> None: - self.events.ready_event.set() + if self.events: + self.events.ready_event.set() if self.dispatcher: await self.dispatcher.connected_callback(self) async def produce_forever(self, dispatcher) -> None: self.dispatcher = dispatcher - async for channel, payload in aprocess_notify(self.connection, self.channels, connected_callback=self.connected_callback): - await dispatcher.process_message(payload, broker=self, channel=channel) + async for channel, payload in self.broker.aprocess_notify(connected_callback=self.connected_callback): + self.produced_count += 1 + await dispatcher.process_message(payload, producer=self, channel=channel) - async def notify(self, channel: str, payload: Optional[str] = None) -> None: - await apublish_message(self.connection, channel, payload=payload) + async def notify(self, channel: Optional[str] = None, message: str = '') -> None: + await self.broker.apublish_message(channel=channel, message=message) async def shutdown(self) -> None: if self.production_task: @@ -60,8 +53,6 @@ async def shutdown(self) -> None: if not hasattr(self.production_task, '_dispatcher_tb_logged'): logger.exception(f'Broker {self.broker} shutdown saw an unexpected exception from production task') self.production_task = None - if not self.old_connection: - if self.connection: - logger.debug(f'Closing {self.broker} connection') - await self.connection.close() - self.connection = None + if self.close_on_exit: + logger.debug(f'Closing {self.broker} connection') + await self.broker.aclose() diff --git a/dispatcher/producers/scheduled.py b/dispatcher/producers/scheduled.py index 6e50512..edc4fd7 100644 --- a/dispatcher/producers/scheduled.py +++ b/dispatcher/producers/scheduled.py @@ -1,17 +1,16 @@ import asyncio import logging -from dispatcher.producers.base import BaseProducer, ProducerEvents +from dispatcher.producers.base import 
BaseProducer logger = logging.getLogger(__name__) class ScheduledProducer(BaseProducer): - def __init__(self, task_schedule: dict): - self.events = ProducerEvents() + def __init__(self, task_schedule: dict[str, dict[str, int]]): self.task_schedule = task_schedule self.scheduled_tasks: list[asyncio.Task] = [] - self.produced_count = 0 + super().__init__() async def start_producing(self, dispatcher) -> None: for task_name, options in self.task_schedule.items(): @@ -19,7 +18,8 @@ async def start_producing(self, dispatcher) -> None: schedule_task = asyncio.create_task(self.run_schedule_forever(task_name, per_seconds, dispatcher)) self.scheduled_tasks.append(schedule_task) schedule_task.add_done_callback(dispatcher.fatal_error_callback) - self.events.ready_event.set() + if self.events: + self.events.ready_event.set() def all_tasks(self) -> list[asyncio.Task]: return self.scheduled_tasks diff --git a/dispatcher/registry.py b/dispatcher/registry.py index 23f602a..f67be5d 100644 --- a/dispatcher/registry.py +++ b/dispatcher/registry.py @@ -6,6 +6,8 @@ from typing import Callable, Optional, Set, Tuple from uuid import uuid4 +from dispatcher.config import LazySettings +from dispatcher.config import settings as global_settings from dispatcher.utils import MODULE_METHOD_DELIMITER, DispatcherCallable, resolve_callable logger = logging.getLogger(__name__) @@ -79,23 +81,21 @@ def get_async_body( return body - def apply_async(self, args=None, kwargs=None, queue=None, uuid=None, connection=None, config=None, **kw) -> Tuple[dict, str]: + def apply_async(self, args=None, kwargs=None, queue=None, uuid=None, settings: LazySettings = global_settings, **kw) -> Tuple[dict, str]: queue = queue or self.submission_defaults.get('queue') - if not queue: - msg = f'{self.fn}: Queue value required and may not be None' - logger.error(msg) - raise ValueError(msg) if callable(queue): queue = queue() obj = self.get_async_body(args=args, kwargs=kwargs, uuid=uuid, **kw) - # TODO: before sending, consult an app-specific callback if configured - from dispatcher.brokers.pg_notify import publish_message + from dispatcher.factories import get_publisher_from_settings - # NOTE: the kw will communicate things in the database connection data - publish_message(queue, json.dumps(obj), connection=connection, config=config) + broker = get_publisher_from_settings(settings=settings) + + # TODO: exit if a setting is applied to disable publishing + + broker.publish_message(channel=queue, message=json.dumps(obj)) return (obj, queue) diff --git a/dispatcher/tasks.py b/dispatcher/tasks.py new file mode 100644 index 0000000..710ecf6 --- /dev/null +++ b/dispatcher/tasks.py @@ -0,0 +1,8 @@ +from dispatcher.factories import get_publisher_from_settings +from dispatcher.publish import task + + +@task() +def reply_to_control(reply_channel: str, message: str): + broker = get_publisher_from_settings() + broker.publish_message(channel=reply_channel, message=message) diff --git a/dispatcher/worker/task.py b/dispatcher/worker/task.py index 420caf0..2708366 100644 --- a/dispatcher/worker/task.py +++ b/dispatcher/worker/task.py @@ -8,6 +8,7 @@ import traceback from queue import Empty as QueueEmpty +from dispatcher.config import setup from dispatcher.registry import registry logger = logging.getLogger(__name__) @@ -197,11 +198,14 @@ def get_shutdown_message(self): return {"worker": self.worker_id, "event": "shutdown"} -def work_loop(worker_id: int, queue: multiprocessing.Queue, finished_queue): +def work_loop(settings: dict, worker_id: int, queue: 
multiprocessing.Queue, finished_queue):
     """
     Worker function that processes messages from the queue
     and sends confirmation to the finished_queue once done.
     """
+    # Load settings passed from parent
+    # this ensures that workers are all configured the same
+    setup(config=settings)
     worker = TaskWorker(worker_id)
 
     # TODO: add an app callback here to set connection name and things like that
diff --git a/docs/config.md b/docs/config.md
new file mode 100644
index 0000000..1116787
--- /dev/null
+++ b/docs/config.md
@@ -0,0 +1,124 @@
+## Dispatcher Configuration
+
+Why is configuration needed? Consider doing this, which uses the demo content:
+
+```
+PYTHONPATH=$PYTHONPATH:tools/ python -c "from tools.test_methods import sleep_function; sleep_function.delay()"
+```
+
+This will result in an error:
+
+> Dispatcher not configured, set DISPATCHER_CONFIG_FILE or call dispatcher.config.setup
+
+This is an error because dispatcher does not have information to connect to a message broker.
+In the case of postgres, that information is the connection information (host, user, password, etc)
+as well as the pg_notify channel to send the message to.
+
+### Ways to configure
+
+#### From file
+
+The provided entrypoint `dispatcher-standalone` can only use a file, which is how the demo works.
+The demo runs using the `dispatcher.yml` config file at the top-level of this repo.
+
+You can do the same thing in python by:
+
+```python
+from dispatcher.config import setup
+
+setup(file_path='dispatcher.yml')
+```
+
+This approach is used by the demo's test script at `tools/write_messages.py`,
+which acts as a "publisher", meaning that it submits tasks over the message
+broker to be run.
+This setup ensures that both the service (`dispatcher-standalone`) and the publisher
+are using the same configuration.
+
+#### From a dictionary
+
+Calling `setup(config)`, where `config` is a python dictionary, is
+equivalent to dumping the `config` to a yaml file and using that as
+the file-based config.
+
+### Configuration Contents
+
+At the top level, config is broken down either by the process that uses that section,
+or by brokers, which are a shared resource used by multiple processes.
+
+At the level below that, the config gives instructions for creating python objects.
+The module `dispatcher.factories` has the task of creating those objects from settings.
+The design goal is to have as little divergence as possible between the settings structure
+and the class structure in the code.
+
+The general structure is:
+
+```yaml
+---
+version: # number
+service:
+  pool_kwargs:
+    # options
+brokers:
+  pg_notify:
+    # options
+producers:
+  ProducerClass:
+    # options
+publish:
+  # options
+```
+
+When providing `pool_kwargs`, those are the kwargs passed to `WorkerPool`, for example.
+
+#### Version
+
+The version field is mandatory and must match the current config in the library.
+This is validated against current code and saved in the [schema.json](../schema.json) file.
+
+The version will be bumped when any breaking change happens.
+
+You can re-generate the schema after making changes by running:
+
+```
+python tools/gen_schema.py > schema.json
+```
+
+#### Brokers
+
+Brokers relay messages which give instructions about code to run.
+Right now the only broker available is pg_notify.
+
+The sub-options become python `kwargs` passed to the broker class `Broker`.
+For now, you will just have to read the code to see what those options are
+at [dispatcher.brokers.pg_notify](dispatcher/brokers/pg_notify.py).
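To make that mapping concrete, here is a sketch of building the pg_notify broker by hand, the same way the factory does, using the demo connection values from `dispatcher.yml`. This is illustrative only; actually publishing assumes the demo postgres is running.

```python
# Sketch: how a "brokers" entry in the config maps onto the Broker kwargs.
from dispatcher.brokers import get_broker

broker_kwargs = {
    "config": {"conninfo": "dbname=dispatch_db user=dispatch password=dispatching host=localhost port=55777"},
    "channels": ["test_channel"],
    "default_publish_channel": "test_channel",
}

broker = get_broker("pg_notify", broker_kwargs)  # roughly Broker(**broker_kwargs)
broker.publish_message(message='lambda: print("hello")')  # sent to the default publish channel
```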
+ +The broker classes have methods that allow for submitting messages +and reading messages. + +#### Service + +This configures the background task service. + +The `pool_kwargs` options will correspond to the `WorkerPool` class +[dispatcher.pool](dispatcher/pool.py). +Process management options will be added to this section later. + +These options are mainly concerned with worker +management. For instance, auto-scaling options will be here, +like worker count, etc. + +#### Producers + +These are "producers of tasks" in the dispatcher service. + +For every listed broker, a `BrokeredProducer` is automatically +created. That means that tasks may be produced from the messaging +system that the dispatcher service is listening to. + +The other current use case is `ScheduledProducer`, +which submits tasks every certain number of seconds. + +#### Publish + +Additional options for publishers (task submitters). diff --git a/docs/vision.md b/docs/vision.md new file mode 100644 index 0000000..f600e64 --- /dev/null +++ b/docs/vision.md @@ -0,0 +1,30 @@ +## Vision for Dispatcher + +The dispatcher strives to be an extremely contained, simple library, +by assuming that your system already has a source-of-truth. + +This will: + - run python tasks + +This will not: + - provide a result backend + - treat the queue as a source of truth + +For problems that go beyond the scope of this library, +suggestions will go into a cookbook. + +```mermaid +flowchart TD + +A(web) +B(task) +C(postgres) + +A-->C +B-->C +``` + +https://taskiq-python.github.io/guide/architecture-overview.html#context + +https://python-rq.org/docs/workers/ + diff --git a/schema.json b/schema.json new file mode 100644 index 0000000..be75890 --- /dev/null +++ b/schema.json @@ -0,0 +1,25 @@ +{ + "version": 2, + "brokers": { + "pg_notify": { + "config": "typing.Optional[dict]", + "async_connection_factory": "typing.Optional[str]", + "sync_connection_factory": "typing.Optional[str]", + "channels": "typing.Union[tuple, list]", + "default_publish_channel": "typing.Optional[str]" + } + }, + "producers": { + "ScheduledProducer": { + "task_schedule": "dict[str, dict[str, int]]" + } + }, + "service": { + "pool_kwargs": { + "max_workers": "" + } + }, + "publish": { + "default_broker": "str" + } +} diff --git a/tests/conftest.py b/tests/conftest.py index 4accbfa..76836d2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,3 @@ -import asyncio - import contextlib from typing import Callable, AsyncIterator @@ -11,7 +9,10 @@ from dispatcher.main import DispatcherMain from dispatcher.control import Control -from dispatcher.brokers.pg_notify import apublish_message, aget_connection, get_connection +from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection +from dispatcher.registry import DispatcherMethodRegistry +from dispatcher.config import DispatcherSettings +from dispatcher.factories import from_settings, get_control_from_settings # List of channels to listen on @@ -20,63 +21,99 @@ # Database connection details CONNECTION_STRING = "dbname=dispatch_db user=dispatch password=dispatching host=localhost port=55777" -BASIC_CONFIG = {"producers": {"brokers": {"pg_notify": {"conninfo": CONNECTION_STRING}, "channels": CHANNELS}}, "pool": {"max_workers": 3}} +BASIC_CONFIG = { + "version": 2, + "brokers": { + "pg_notify": { + "channels": CHANNELS, + "config": {'conninfo': CONNECTION_STRING}, + "sync_connection_factory": "dispatcher.brokers.pg_notify.connection_saver", + # "async_connection_factory": 
"dispatcher.brokers.pg_notify.async_connection_saver", + "default_publish_channel": "test_channel" + } + }, + "pool": { + "max_workers": 3 + } +} + + +@contextlib.asynccontextmanager +async def aconnection_for_test(): + conn = None + try: + conn = await acreate_connection(conninfo=CONNECTION_STRING, autocommit=True) + + # Make sure database is running to avoid deadlocks which can come + # from using the loop provided by pytest asyncio + async with conn.cursor() as cursor: + await cursor.execute('SELECT 1') + await cursor.fetchall() + + yield conn + finally: + if conn: + await conn.close() + + +@pytest.fixture +def conn_config(): + return {'conninfo': CONNECTION_STRING} @pytest.fixture def pg_dispatcher() -> DispatcherMain: - return DispatcherMain(BASIC_CONFIG) + # We can not reuse the connection between tests + config = BASIC_CONFIG.copy() + config['brokers']['pg_notify'].pop('async_connection_factory') + return DispatcherMain(config) + + +@pytest.fixture +def test_settings(): + return DispatcherSettings(BASIC_CONFIG) @pytest_asyncio.fixture(loop_scope="function", scope="function") -async def apg_dispatcher(request) -> AsyncIterator[DispatcherMain]: +async def apg_dispatcher(test_settings) -> AsyncIterator[DispatcherMain]: + dispatcher = None try: - dispatcher = DispatcherMain(BASIC_CONFIG) + dispatcher = from_settings(settings=test_settings) await dispatcher.connect_signals() await dispatcher.start_working() await dispatcher.wait_for_producers_ready() + assert dispatcher.pool.finished_count == 0 # sanity + yield dispatcher finally: - await dispatcher.shutdown() - await dispatcher.cancel_tasks() + if dispatcher: + await dispatcher.shutdown() + await dispatcher.cancel_tasks() @pytest_asyncio.fixture(loop_scope="function", scope="function") -async def pg_message(psycopg_conn) -> Callable: - async def _rf(message, channel='test_channel'): - await apublish_message(psycopg_conn, channel, message) - return _rf +async def pg_control(test_settings) -> AsyncIterator[Control]: + return get_control_from_settings(settings=test_settings) -@pytest.fixture -def conn_config(): - return {'conninfo': CONNECTION_STRING} - - -@contextlib.asynccontextmanager -async def aconnection_for_test(): - conn = None - try: - conn = await aget_connection({'conninfo': CONNECTION_STRING}) +@pytest_asyncio.fixture(loop_scope="function", scope="function") +async def psycopg_conn(): + async with aconnection_for_test() as conn: yield conn - finally: - if conn: - await conn.close() @pytest_asyncio.fixture(loop_scope="function", scope="function") -async def pg_control() -> AsyncIterator[Control]: - """This has to use a different connection from dispatcher itself - - because psycopg will pool async connections, meaning that submission - for the control task would be blocked by the listening query of the dispatcher itself""" - async with aconnection_for_test() as conn: - yield Control('test_channel', async_connection=conn) +async def pg_message(psycopg_conn) -> Callable: + async def _rf(message, channel=None): + # Note on weirdness here, this broker will only be used for async publishing, so we give junk for synchronous connection + broker = Broker(async_connection=psycopg_conn, default_publish_channel='test_channel', sync_connection_factory='tests.data.methods.something') + await broker.apublish_message(channel=channel, message=message) + return _rf -@pytest_asyncio.fixture(loop_scope="function", scope="function") -async def psycopg_conn(): - async with aconnection_for_test() as conn: - yield conn +@pytest.fixture +def registry() 
-> DispatcherMethodRegistry: + "Return a fresh registry, separate from the global one, for testing" + return DispatcherMethodRegistry() diff --git a/tests/integration/brokers/test_pg_notify.py b/tests/integration/brokers/test_pg_notify.py new file mode 100644 index 0000000..9498794 --- /dev/null +++ b/tests/integration/brokers/test_pg_notify.py @@ -0,0 +1,30 @@ +import pytest + +from dispatcher.brokers.pg_notify import Broker, create_connection, acreate_connection + + +def test_sync_connection_from_config_reuse(conn_config): + broker = Broker(config=conn_config) + conn = broker.get_connection() + with conn.cursor() as cur: + cur.execute('SELECT 1') + assert cur.fetchall() == [(1,)] + + conn2 = broker.get_connection() + assert conn is conn2 + + assert conn is not create_connection(**conn_config) + + +@pytest.mark.asyncio +async def test_async_connection_from_config_reuse(conn_config): + broker = Broker(config=conn_config) + conn = await broker.aget_connection() + async with conn.cursor() as cur: + await cur.execute('SELECT 1') + assert await cur.fetchall() == [(1,)] + + conn2 = await broker.aget_connection() + assert conn is conn2 + + assert conn is not await acreate_connection(**conn_config) diff --git a/tests/integration/publish/__init__.py b/tests/integration/publish/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/publish/test_registry.py b/tests/integration/publish/test_registry.py new file mode 100644 index 0000000..ed87a4e --- /dev/null +++ b/tests/integration/publish/test_registry.py @@ -0,0 +1,28 @@ +from unittest import mock + +import pytest + +from dispatcher.publish import task +from dispatcher.config import temporary_settings + + +def test_apply_async_with_no_queue(registry, conn_config): + @task(registry=registry) + def test_method(): + return + + dmethod = registry.get_from_callable(test_method) + + # These settings do not specify a default channel, that is the main point + with temporary_settings({'version': 2, 'brokers': {'pg_notify': {'config': conn_config}}}): + + # Can not run a method if we do not have a queue + with pytest.raises(ValueError): + dmethod.apply_async() + + # But providing a queue at time of submission works + with mock.patch('dispatcher.brokers.pg_notify.Broker.publish_message') as mock_publish_method: + dmethod.apply_async(queue='fooqueue') + mock_publish_method.assert_called_once_with(channel='fooqueue', message=mock.ANY) + + mock_publish_method.assert_called_once() diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 9b2a137..af35524 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -6,6 +6,8 @@ from tests.data import methods as test_methods +from dispatcher.config import temporary_settings + SLEEP_METHOD = 'lambda: __import__("time").sleep(0.1)' @@ -24,7 +26,7 @@ async def wait_to_receive(dispatcher, ct, timeout=5.0, interval=0.05): async def test_run_lambda_function(apg_dispatcher, pg_message): assert apg_dispatcher.pool.finished_count == 0 - clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait()) + clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait(), name='test_lambda_clear_wait') await pg_message('lambda: "This worked!"') await asyncio.wait_for(clearing_task, timeout=3) @@ -32,11 +34,19 @@ async def test_run_lambda_function(apg_dispatcher, pg_message): @pytest.mark.asyncio -async def test_run_decorated_function(apg_dispatcher, conn_config): - assert apg_dispatcher.pool.finished_count == 0 
+async def test_run_decorated_function(apg_dispatcher, test_settings):
+    clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait())
+    test_methods.print_hello.apply_async(settings=test_settings)
+    await asyncio.wait_for(clearing_task, timeout=3)
+    assert apg_dispatcher.pool.finished_count == 1
+
+
+@pytest.mark.asyncio
+async def test_submit_with_global_settings(apg_dispatcher, test_settings):
     clearing_task = asyncio.create_task(apg_dispatcher.pool.events.work_cleared.wait())
-    test_methods.print_hello.apply_async(config=conn_config)
+    with temporary_settings(test_settings.serialize()):
+        test_methods.print_hello.delay()  # settings are inferred from global context
     await asyncio.wait_for(clearing_task, timeout=3)
     assert apg_dispatcher.pool.finished_count == 1
@@ -95,8 +105,6 @@ async def test_cancel_task(apg_dispatcher, pg_message, pg_control):
 
 @pytest.mark.asyncio
 async def test_message_with_delay(apg_dispatcher, pg_message, pg_control):
-    assert apg_dispatcher.pool.finished_count == 0
-
     # Send message to run task with a delay
     msg = json.dumps({'task': 'lambda: print("This task had a delay")', 'uuid': 'delay_task', 'delay': 0.3})
     await pg_message(msg)
@@ -186,11 +194,11 @@ async def test_task_discard(apg_dispatcher, pg_message):
 
 @pytest.mark.asyncio
-async def test_task_discard_in_task_definition(apg_dispatcher, conn_config):
+async def test_task_discard_in_task_definition(apg_dispatcher, test_settings):
     assert apg_dispatcher.pool.finished_count == 0
 
     for i in range(10):
-        test_methods.sleep_discard.apply_async(args=[2], config=conn_config)
+        test_methods.sleep_discard.apply_async(args=[2], settings=test_settings)
 
     await wait_to_receive(apg_dispatcher, 10)
@@ -199,11 +207,11 @@
 
 @pytest.mark.asyncio
-async def test_tasks_in_serial(apg_dispatcher, conn_config):
+async def test_tasks_in_serial(apg_dispatcher, test_settings):
     assert apg_dispatcher.pool.finished_count == 0
 
     for i in range(10):
-        test_methods.sleep_serial.apply_async(args=[2], config=conn_config)
+        test_methods.sleep_serial.apply_async(args=[2], settings=test_settings)
 
     await wait_to_receive(apg_dispatcher, 10)
@@ -212,11 +220,11 @@
 
 @pytest.mark.asyncio
-async def test_tasks_queue_one(apg_dispatcher, conn_config):
+async def test_tasks_queue_one(apg_dispatcher, test_settings):
     assert apg_dispatcher.pool.finished_count == 0
 
     for i in range(10):
-        test_methods.sleep_queue_one.apply_async(args=[2], config=conn_config)
+        test_methods.sleep_queue_one.apply_async(args=[2], settings=test_settings)
 
     await wait_to_receive(apg_dispatcher, 10)
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 7ea0277..e69de29 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,9 +0,0 @@
-import pytest
-
-from dispatcher.registry import DispatcherMethodRegistry
-
-
-@pytest.fixture
-def registry() -> DispatcherMethodRegistry:
-    "Return a fresh registry, separate from the global one, for testing"
-    return DispatcherMethodRegistry()
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 4f98be1..70717f7 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -1,3 +1,52 @@
-# this is a good place to create config files, load them, and test that we get the params we expected
-# also a good place to take some configs and test initializing dispatcher objects with them
-# None of this has been done, but you could do it
+import json
+
+import yaml
+
+import pytest
+
+from dispatcher.config import DispatcherSettings, LazySettings, temporary_settings
+from dispatcher.factories import generate_settings_schema
+
+
+def test_settings_reference_unconfigured():
+    settings = LazySettings()
+    with pytest.raises(Exception) as exc:
+        settings.brokers
+    assert 'Dispatcher not configured' in str(exc)
+
+
+def test_configured_settings():
+    settings = LazySettings()
+    settings._wrapped = DispatcherSettings({'version': 2, 'brokers': {'pg_notify': {'config': {}}}})
+    assert 'pg_notify' in settings.brokers
+
+
+def test_serialize_settings(test_settings):
+    config = test_settings.serialize()
+    assert 'producers' in config
+    assert 'publish' in config
+    assert config['publish'] == {}
+    assert 'pg_notify' in config['brokers']
+    assert config['service']['pool_kwargs']['max_workers'] == 3
+
+    re_loaded = DispatcherSettings(config)
+    assert re_loaded.brokers == test_settings.brokers
+    assert re_loaded.service == test_settings.service
+
+
+def test_version_validated():
+    with pytest.raises(RuntimeError) as exc:
+        DispatcherSettings({})
+    assert 'config version must match this' in str(exc)
+
+
+def test_schema_is_current():
+    with open('dispatcher.yml', 'r') as f:
+        file_contents = f.read()
+    demo_data = yaml.safe_load(file_contents)
+    with temporary_settings(demo_data):
+        expect_schema = generate_settings_schema()
+    with open('schema.json', 'r') as sch_f:
+        schema_contents = sch_f.read()
+    schema_data = json.loads(schema_contents)
+    assert schema_data == expect_schema
diff --git a/tests/unit/test_publish.py b/tests/unit/test_publish.py
index c14e825..be4583a 100644
--- a/tests/unit/test_publish.py
+++ b/tests/unit/test_publish.py
@@ -1,7 +1,6 @@
 from unittest import mock
 
 from dispatcher.publish import task
-from dispatcher.utils import serialize_task
 
 import pytest
@@ -76,21 +75,3 @@ def run(self):
 
     TestMethod.delay()
     mock_apply_async.assert_called_once_with((), {})
-
-
-def test_apply_async_with_no_queue(registry):
-    @task(registry=registry)
-    def test_method():
-        return
-
-    dmethod = registry.get_from_callable(test_method)
-
-    # Can not run a method if we do not have a queue
-    with pytest.raises(ValueError):
-        dmethod.apply_async()
-
-    # But providing a queue at time of submission works
-    with mock.patch('dispatcher.brokers.pg_notify.publish_message') as mock_publish_method:
-        dmethod.apply_async(queue='fooqueue')
-
-    mock_publish_method.assert_called_once()
diff --git a/tools/gen_schema.py b/tools/gen_schema.py
new file mode 100644
index 0000000..50e2e2d
--- /dev/null
+++ b/tools/gen_schema.py
@@ -0,0 +1,9 @@
+import json
+from dispatcher.config import setup
+from dispatcher.factories import generate_settings_schema
+
+setup(file_path='dispatcher.yml')
+
+data = generate_settings_schema()
+
+print(json.dumps(data, indent=2))
diff --git a/tools/write_messages.py b/tools/write_messages.py
index 4ad981a..771784a 100644
--- a/tools/write_messages.py
+++ b/tools/write_messages.py
@@ -4,9 +4,10 @@
 import os
 import sys
 
-from dispatcher.brokers.pg_notify import publish_message
+from dispatcher.factories import get_publisher_from_settings
 from dispatcher.control import Control
 from dispatcher.utils import MODULE_METHOD_DELIMITER
+from dispatcher.config import setup
 
 # Add the test methods to the path so we can use .delay type contracts
 tools_dir = os.path.abspath(
@@ -17,8 +18,12 @@ from test_methods import sleep_function, sleep_discard, task_has_timeout, hello_world_binder
 
-# Database connection details
-CONNECTION_STRING = "dbname=dispatch_db user=dispatch password=dispatching host=localhost port=55777"
+
+# Set up the global config from the settings file shared with the service
+setup(file_path='dispatcher.yml')
+
+
+broker = get_publisher_from_settings()
 
 
 TEST_MSGS = [
@@ -32,39 +37,38 @@ def main():
     print('writing some basic test messages')
     for channel, message in TEST_MSGS:
         # Send the notification
-        publish_message(channel, message, config={'conninfo': CONNECTION_STRING})
+        broker.publish_message(channel=channel, message=message)
         # await send_notification(channel, message)
 
     # send more than number of workers quickly
     print('')
     print('writing 15 messages fast')
     for i in range(15):
-        publish_message('test_channel', f'lambda: {i}', config={'conninfo': CONNECTION_STRING})
+        broker.publish_message(message=f'lambda: {i}')
 
     print('')
     print('performing a task cancel')
     # submit a task we will "find" two different ways
-    publish_message(channel, json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar'}), config={'conninfo': CONNECTION_STRING})
-    ctl = Control('test_channel', config={'conninfo': CONNECTION_STRING})
+    broker.publish_message(message=json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar'}))
+    ctl = Control('test_channel')
     canceled_jobs = ctl.control_with_reply('cancel', data={'uuid': 'foobar'})
     print(json.dumps(canceled_jobs, indent=2))
 
     print('')
     print('finding a running task by its task name')
-    publish_message(channel, json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar2'}), config={'conninfo': CONNECTION_STRING})
+    broker.publish_message(message=json.dumps({'task': 'lambda: __import__("time").sleep(3.1415)', 'uuid': 'foobar2'}))
     running_data = ctl.control_with_reply('running', data={'task': 'lambda: __import__("time").sleep(3.1415)'})
     print(json.dumps(running_data, indent=2))
 
     print('writing a message with a delay')
     print(' 4 second delay task')
-    publish_message(channel, json.dumps({'task': 'lambda: 123421', 'uuid': 'foobar2', 'delay': 4}), config={'conninfo': CONNECTION_STRING})
+    broker.publish_message(message=json.dumps({'task': 'lambda: 123421', 'uuid': 'foobar2', 'delay': 4}))
     print(' 30 second delay task')
-    publish_message(channel, json.dumps({'task': 'lambda: 987987234', 'uuid': 'foobar2', 'delay': 30}), config={'conninfo': CONNECTION_STRING})
+    broker.publish_message(message=json.dumps({'task': 'lambda: 987987234', 'uuid': 'foobar2', 'delay': 30}))
     print(' 10 second delay task')
     # NOTE: this task will error unless you run the dispatcher itself with it in the PYTHONPATH, which is intended
     sleep_function.apply_async(
         args=[3],  # sleep 3 seconds
         delay=10,
-        config={'conninfo': CONNECTION_STRING}
     )
 
     print('')
@@ -88,31 +92,31 @@ def main():
     print('')
     print('demo of submitting discarding tasks')
    for i in range(10):
-        publish_message(channel, json.dumps(
+        broker.publish_message(message=json.dumps(
             {'task': 'lambda: __import__("time").sleep(9)', 'on_duplicate': 'discard', 'uuid': f'dscd-{i}'}
-        ), config={'conninfo': CONNECTION_STRING})
+        ))
 
     print('demo of discarding task marked as discarding')
     for i in range(10):
-        sleep_discard.apply_async(args=[2], config={'conninfo': CONNECTION_STRING})
+        sleep_discard.apply_async(args=[2])
 
     print('demo of discarding tasks with apply_async contract')
     for i in range(10):
-        sleep_function.apply_async(args=[3], on_duplicate='discard', config={'conninfo': CONNECTION_STRING})
+        sleep_function.apply_async(args=[3], on_duplicate='discard')
 
     print('demo of submitting waiting tasks')
     for i in range(10):
-        publish_message(channel, json.dumps(
+        broker.publish_message(message=json.dumps(
             {'task': 'lambda: __import__("time").sleep(10)', 'on_duplicate': 'serial', 'uuid': f'wait-{i}'}
-        ), config={'conninfo': CONNECTION_STRING})
+        ))
 
     print('demo of submitting queue-once tasks')
     for i in range(10):
-        publish_message(channel, json.dumps(
+        broker.publish_message(message=json.dumps(
             {'task': 'lambda: __import__("time").sleep(8)', 'on_duplicate': 'queue_one', 'uuid': f'queue_one-{i}'}
-        ), config={'conninfo': CONNECTION_STRING})
+        ))
 
     print('demo of task_has_timeout that times out due to decorator use')
-    task_has_timeout.apply_async(config={'conninfo': CONNECTION_STRING})
+    task_has_timeout.delay()
 
     print('demo of using bind=True, with hello_world_binder')
-    hello_world_binder.apply_async(config={'conninfo': CONNECTION_STRING})
+    hello_world_binder.delay()
 
 
 if __name__ == "__main__":
     logging.basicConfig(level='ERROR', stream=sys.stdout)