Add more control commands #132

Merged 3 commits on Apr 7, 2025

1 change: 1 addition & 0 deletions dispatcher.yml
@@ -33,6 +33,7 @@ producers:
OnStartProducer:
task_list:
'lambda: print("This task runs on startup")': {}
ControlProducer:
publish:
default_control_broker: socket
default_broker: pg_notify
3 changes: 2 additions & 1 deletion dispatcherd/cli.py
@@ -8,6 +8,7 @@
from . import run_service
from .config import setup
from .factories import get_control_from_settings
from .service import control_tasks

logger = logging.getLogger(__name__)

@@ -55,7 +56,7 @@ def standalone() -> None:

def control() -> None:
parser = get_parser()
parser.add_argument('command', help='The control action to run.')
parser.add_argument('command', choices=[cmd for cmd in control_tasks.__all__], help='The control action to run.')
parser.add_argument(
'--task',
type=str,
2 changes: 2 additions & 0 deletions dispatcherd/factories.py
@@ -46,6 +46,8 @@ def producers_from_settings(settings: LazySettings = global_settings) -> Iterabl
producer_objects.append(producer)

for producer_cls, producer_kwargs in settings.producers.items():
if producer_kwargs is None:
producer_kwargs = {}
producer_objects.append(getattr(producers, producer_cls)(**producer_kwargs))

return producer_objects
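
The new `None` check exists because a YAML producer entry with no mapping body, like the `ControlProducer:` line added to dispatcher.yml above, loads as `None` rather than an empty dict. A minimal sketch of that behavior (assuming PyYAML; the nesting shown is an assumption, since indentation is flattened in the diff view):

```python
import yaml

# Assumed nesting of the dispatcher.yml section above; only ControlProducer is new.
settings = yaml.safe_load("""
producers:
  OnStartProducer:
    task_list:
      'lambda: print("This task runs on startup")': {}
  ControlProducer:
""")

# A producer key with no mapping body parses to None, not {} ...
print(settings["producers"]["ControlProducer"])  # None
# ... which is why producers_from_settings normalizes None kwargs to {} above.
```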
3 changes: 2 additions & 1 deletion dispatcherd/producers/__init__.py
@@ -1,6 +1,7 @@
from .base import BaseProducer
from .brokered import BrokeredProducer
from .control import ControlProducer
from .on_start import OnStartProducer
from .scheduled import ScheduledProducer

__all__ = ['BaseProducer', 'BrokeredProducer', 'ScheduledProducer', 'OnStartProducer']
__all__ = ['BaseProducer', 'BrokeredProducer', 'ScheduledProducer', 'OnStartProducer', 'ControlProducer']
7 changes: 5 additions & 2 deletions dispatcherd/producers/base.py
@@ -1,6 +1,6 @@
import asyncio

from ..protocols import Producer
from ..protocols import Producer as ProducerProtocol


class ProducerEvents:
@@ -9,8 +9,11 @@ def __init__(self) -> None:
self.recycle_event = asyncio.Event()


class BaseProducer(Producer):
class BaseProducer(ProducerProtocol):

def __init__(self) -> None:
self.events = ProducerEvents()
self.produced_count = 0

def get_status_data(self) -> dict:
return {'produced_count': self.produced_count}
3 changes: 2 additions & 1 deletion dispatcherd/producers/brokered.py
@@ -29,7 +29,8 @@ async def recycle(self) -> None:
await self.start_producing(self.dispatcher)

def __str__(self) -> str:
return f'brokered-producer-{self.broker}'
broker_module = self.broker.__module__.rsplit('.', 1)[-1]
return f'{broker_module}-producer'

async def start_producing(self, dispatcher: DispatcherMain) -> None:
self.production_task = asyncio.create_task(self.produce_forever(dispatcher), name=f'{self.broker.__module__}_production')
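
The renamed `__str__` ties the producer's display name to the broker's module name, which also determines the keys returned by the new `producers` control command later in this diff. A quick illustration of the string handling (the module path here is a hypothetical stand-in for a pg_notify broker):

```python
# Hypothetical module path, used only to show what the rsplit above produces.
broker_module_path = "dispatcherd.brokers.pg_notify"
broker_module = broker_module_path.rsplit(".", 1)[-1]
print(f"{broker_module}-producer")  # -> pg_notify-producer
```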
36 changes: 36 additions & 0 deletions dispatcherd/producers/control.py
@@ -0,0 +1,36 @@
import asyncio
import json
import logging
from typing import Optional

from ..protocols import DispatcherMain
from .base import BaseProducer

logger = logging.getLogger(__name__)


class ControlProducer(BaseProducer):
"""Placeholder producer to allow control actions to start tasks

This must be enabled to start tasks via control actions.
Indirectly, this also allows tasks to start other tasks.
"""

def __init__(self) -> None:
self.dispatcher: Optional[DispatcherMain] = None
super().__init__()

async def start_producing(self, dispatcher: DispatcherMain) -> None:
self.dispatcher = dispatcher
self.events.ready_event.set()

async def submit_task(self, data: dict) -> None:
assert self.dispatcher is not None
await self.dispatcher.process_message(json.dumps(data))
self.produced_count += 1

def all_tasks(self) -> list[asyncio.Task]:
return []

async def shutdown(self) -> None:
pass
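
As a quick sanity check, the producer above takes no configuration and only tracks how many tasks it has forwarded until a dispatcher is attached via `start_producing`. A minimal sketch exercising the parts that do not need a running service:

```python
from dispatcherd.producers import ControlProducer

producer = ControlProducer()
print(producer.get_status_data())  # {'produced_count': 0}
print(producer.all_tasks())        # [] - it never owns asyncio tasks of its own
```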
15 changes: 14 additions & 1 deletion dispatcherd/protocols.py
@@ -80,6 +80,10 @@ async def start_producing(self, dispatcher: 'DispatcherMain') -> None:
"""Starts tasks which will eventually call DispatcherMain.process_message - how tasks originate in the service"""
...

def get_status_data(self) -> dict:
"""Data for debugging commands"""
...

async def shutdown(self):
"""Stop producing tasks and clean house, a producer may be shut down independently from the main program"""
...
@@ -108,7 +112,7 @@ async def start_task(self, message: dict) -> None: ...

def is_ready(self) -> bool: ...

def get_data(self) -> dict[str, Any]:
def get_status_data(self) -> dict[str, Any]:
"""Used for worker status control-and-reply command"""
...

@@ -174,6 +178,10 @@ async def dispatch_task(self, message: dict) -> None:
"""Called by DispatcherMain after in the normal task lifecycle, pool will try to hand the task to a worker"""
...

def get_status_data(self) -> dict:
"""Data for debugging commands"""
...

async def shutdown(self) -> None: ...


@@ -189,6 +197,7 @@ class DispatcherMain(Protocol):
pool: WorkerPool
delayed_messages: set
fd_lock: asyncio.Lock # Forking and locking may need to be serialized, which this does
producers: Iterable[Producer]

async def main(self) -> None:
"""This is the method that runs the service, bring your own event loop"""
@@ -207,3 +216,7 @@ async def process_message(
) -> tuple[Optional[str], Optional[str]]:
"""This is called by producers when a new request to run a task comes in"""
...

def get_status_data(self) -> dict:
"""Data for debugging commands"""
...
62 changes: 60 additions & 2 deletions dispatcherd/service/control_tasks.py
@@ -4,7 +4,7 @@

from ..protocols import DispatcherMain

__all__ = ['running', 'cancel', 'alive', 'aio_tasks', 'workers']
__all__ = ['running', 'cancel', 'alive', 'aio_tasks', 'workers', 'producers', 'run', 'main', 'status']


logger = logging.getLogger(__name__)
@@ -55,11 +55,21 @@ async def _find_tasks(dispatcher: DispatcherMain, data: dict, cancel: bool = Fal


async def running(dispatcher: DispatcherMain, data: dict) -> dict[str, dict]:
"""Information on running tasks managed by this dispatcherd service

Data may be used to filter the tasks of interest.
Keys and values in data correspond to expected key-values in the message,
but are limited to task, kwargs, args, and uuid.
"""
async with dispatcher.pool.workers.management_lock:
return await _find_tasks(dispatcher=dispatcher, data=data)


async def cancel(dispatcher: DispatcherMain, data: dict) -> dict[str, dict]:
"""Cancel all tasks that match the filter given by data

The protocol for the data filtering is the same as the running command.
"""
async with dispatcher.pool.workers.management_lock:
return await _find_tasks(dispatcher=dispatcher, cancel=True, data=data)

@@ -71,6 +81,7 @@ def _stack_from_task(task: asyncio.Task, limit: int = 6) -> str:


async def aio_tasks(dispatcher: DispatcherMain, data: dict) -> dict[str, dict]:
"""Information on the asyncio tasks running in the dispatcher main process"""
ret = {}
extra = {}
if 'limit' in data:
@@ -83,11 +94,58 @@ async def aio_tasks(dispatcher: DispatcherMain, data: dict) -> dict[str, dict]:


async def alive(dispatcher: DispatcherMain, data: dict) -> dict:
"""Returns no information, used to get fast roll-call of instances"""
return {}


async def workers(dispatcher: DispatcherMain, data: dict) -> dict:
"""Information about subprocess workers"""
ret = {}
for worker in dispatcher.pool.workers:
ret[f'worker-{worker.worker_id}'] = worker.get_data()
ret[f'worker-{worker.worker_id}'] = worker.get_status_data()
return ret


async def producers(dispatcher: DispatcherMain, data: dict) -> dict:
"""Information about the enabled task producers"""
ret = {}
for producer in dispatcher.producers:
ret[str(producer)] = producer.get_status_data()
return ret


async def run(dispatcher: DispatcherMain, data: dict) -> dict:
"""Run a task. The control data should follow the standard message protocol.

This submits a control task with the task data nested in control_data,
as opposed to submitting the task data directly.
This is useful if you:
- need a confirmation that your task has been received
- need to start a task from another task
"""
for producer in dispatcher.producers:
if hasattr(producer, 'submit_task'):
try:
await producer.submit_task(data)
except Exception as exc:
return {'error': str(exc)}
return {'ack': data}
return {'error': 'The ControlProducer is not enabled. Add it to the list of producers in the service config to use this command.'}


async def main(dispatcher: DispatcherMain, data: dict) -> dict:
"""Information about scalar quantities on the main or pool objects"""
ret = dispatcher.get_status_data()
ret["pool"] = dispatcher.pool.get_status_data()
return ret


async def status(dispatcher: DispatcherMain, data: dict) -> dict:
"""Information from all other non-destructive commands nested in a sub-dictionary"""
ret = {}
for command in __all__:
if command in ('cancel', 'alive', 'status', 'run'):
continue
control_method = globals()[command]
ret[command] = await control_method(dispatcher=dispatcher, data={})
return ret
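
Put together, the aggregate `status` reply carries one sub-dictionary per non-destructive command. The sketch below illustrates the reply shape implied by the functions above; it is not real output, and the names, pids, and counts are invented:

```python
# Illustrative shape only; values here are made up for the example.
status_reply = {
    "running": {},                                     # in-flight tasks, filtered like the running command
    "aio_tasks": {"results_task": {"stack": "..."}},   # asyncio tasks in the main process
    "workers": {"worker-0": {"worker_id": 0, "pid": 12345}},
    "producers": {"pg_notify-producer": {"produced_count": 3}},
    "main": {
        "received_count": 3,
        "control_count": 1,
        "pid": 12340,
        "pool": {"next_worker_id": 1, "finished_count": 2, "canceled_count": 0},
    },
}
```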
6 changes: 5 additions & 1 deletion dispatcherd/service/main.py
@@ -3,7 +3,8 @@
import logging
import signal
import time
from typing import Iterable, Optional, Union
from os import getpid
from typing import Any, Iterable, Optional, Union
from uuid import uuid4

from ..producers import BrokeredProducer
@@ -68,6 +69,9 @@ def receive_signal(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-de
logger.warning(f"Received exit signal args={args} kwargs={kwargs}")
self.events.exit_event.set()

def get_status_data(self) -> dict[str, Any]:
return {"received_count": self.received_count, "control_count": self.control_count, "pid": getpid()}

async def wait_for_producers_ready(self) -> None:
"Returns when all the producers have hit their ready event"
for producer in self.producers:
9 changes: 8 additions & 1 deletion dispatcherd/service/pool.py
@@ -118,7 +118,7 @@ def cancel(self) -> None:
return # it's effectively already canceled/not running
os.kill(self.process.pid, signal.SIGUSR1) # Use SIGUSR1 instead of SIGTERM

def get_data(self) -> dict[str, Any]:
def get_status_data(self) -> dict[str, Any]:
return {
'worker_id': self.worker_id,
'pid': self.process.pid,
@@ -245,6 +245,13 @@ def processed_count(self) -> int:
def received_count(self) -> int:
return self.processed_count + self.queuer.count() + self.blocker.count() + sum(1 for w in self.workers if w.current_task)

def get_status_data(self) -> dict[str, Any]:
return {
"next_worker_id": self.next_worker_id,
"finished_count": self.finished_count,
"canceled_count": self.canceled_count,
}

async def start_working(self, dispatcher: DispatcherMain, exit_event: Optional[asyncio.Event] = None) -> None:
self.dispatcher = dispatcher
self.read_results_task = ensure_fatal(asyncio.create_task(self.read_results_forever(), name='results_task'), exit_event=exit_event)
9 changes: 9 additions & 0 deletions docs/task_options.md
@@ -75,6 +75,15 @@ right now it offers:

- `uuid` - the internal id of this task call in dispatcher
- `worker_id` - the id of the worker running this task
- `control` - runs a control-and-reply command against its own parent process

Using the `dispatcher.control` interface on the bound object is
a more efficient alternative to communicating over the broker.
It also allows tasks to dispatch follow-up tasks in the local service.

More complex examples can be found in `tests.data.methods`.
The `schedules_another_task` example shows how this can be used
to have a task start another task.
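
As an illustration of the paragraph above, a bound task can issue any of the non-destructive control commands against its own parent service. The sketch below follows the pattern used in `tests.data.methods`; the `dispatcherd.publish` import path for the `task` decorator is an assumption:

```python
from dispatcherd.publish import task  # import path assumed

@task(bind=True)
def report_parent_status(binder):
    # Control-and-reply against the dispatcherd service running this task
    workers = binder.control('workers')
    producers = binder.control('producers')
    print(f'workers: {workers}\nproducers: {producers}')
```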

#### Queue

1 change: 1 addition & 0 deletions schema.json
@@ -15,6 +15,7 @@
}
},
"producers": {
"ControlProducer": {},
"ScheduledProducer": {
"task_schedule": "dict[str, dict[str, typing.Union[int, str]]]"
},
1 change: 1 addition & 0 deletions tests/conftest.py
@@ -32,6 +32,7 @@
"default_publish_channel": "test_channel",
}
},
"producers": {"ControlProducer": {}},
"pool": {"pool_kwargs": {"min_workers": 1, "max_workers": 6}},
}

8 changes: 7 additions & 1 deletion tests/data/methods.py
@@ -57,4 +57,10 @@ def run(self):
@task(bind=True)
def prints_running_tasks(binder):
r = binder.control('running')
print(r)
print(f'Obtained data on running tasks, result:\n{r}')


@task(bind=True)
def schedules_another_task(binder):
r = binder.control('run', data={'task': 'tests.data.methods.print_hello'})
print(f'Scheduled another task, result: {r}')
9 changes: 6 additions & 3 deletions tests/integration/test_disruptions.py
@@ -6,15 +6,18 @@

from tests.conftest import CONNECTION_STRING

from dispatcherd.producers.brokered import BrokeredProducer


# Change the application_name so that when we run this test we will not kill the connection for the test itself
THIS_TEST_STR = CONNECTION_STRING.replace('application_name=apg_test_server', 'application_name=do_not_delete_me')


@pytest.mark.asyncio
async def test_sever_pg_connection(apg_dispatcher, pg_message):
assert len(apg_dispatcher.producers) == 1
apg_dispatcher.producers[0].events.ready_event.clear()
brokered_producers = [producer for producer in apg_dispatcher.producers if isinstance(producer, BrokeredProducer)]
assert len(brokered_producers) == 1
brokered_producers[0].events.ready_event.clear()

query = """
SELECT pid, usename, application_name, backend_start, state
@@ -48,7 +51,7 @@ async def test_sever_pg_connection(apg_dispatcher, pg_message):
await apg_dispatcher.recycle_broker_producers()

# Continue method after the producers are ready, an effect of the recycle
ready_event_task = asyncio.create_task(apg_dispatcher.producers[0].events.ready_event.wait(), name='test_ready_event')
ready_event_task = asyncio.create_task(brokered_producers[0].events.ready_event.wait(), name='test_ready_event')
await asyncio.wait_for(ready_event_task, timeout=5)

# Submitting a new task should now work