diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 3f391df9a..16a38f796 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -6,19 +6,21 @@ from functools import wraps from pathlib import Path from pprint import pprint +from typing import Any, TypeGuard import click from bluesky.callbacks.best_effort import BestEffortCallback from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker +from click.core import Context, Parameter from click.exceptions import ClickException +from click.types import ParamType from observability_utils.tracing import setup_tracing -from pydantic import ValidationError from requests.exceptions import ConnectionError from blueapi import __version__, config from blueapi.cli.format import OutputFormat -from blueapi.client.client import BlueapiClient +from blueapi.client.client import BlueapiClient, TaskParameters from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import ( BlueskyRemoteControlError, @@ -34,12 +36,39 @@ from blueapi.log import set_up_logging from blueapi.service.authentication import SessionCacheManager, SessionManager from blueapi.service.model import SourceInfo -from blueapi.worker import ProgressEvent, Task, WorkerEvent +from blueapi.worker import ProgressEvent, WorkerEvent from .scratch import setup_scratch from .updates import CliEventRenderer +class ParametersType(ParamType): + """CLI input parameter to accept a JSON object as an argument""" + + name = "TaskParameters" + + def convert( + self, + value: str | dict[str, Any] | None, + param: Parameter | None, + ctx: Context | None, + ) -> TaskParameters: + if isinstance(value, str): + try: + params = json.loads(value) + if is_str_dict(params): + return params + self.fail("Parameters must be a JSON object with string keys") + except json.JSONDecodeError as jde: + self.fail(f"Parameters are not valid JSON: {jde}") + else: + return super().convert(value, param, ctx) + + +def is_str_dict(val: Any) -> TypeGuard[TaskParameters]: + return isinstance(val, dict) and all(isinstance(k, str) for k in val) + + @click.group( invoke_without_command=True, context_settings={"auto_envvar_prefix": "BLUEAPI"} ) @@ -220,7 +249,7 @@ def on_event( @controller.command(name="run") @click.argument("name", type=str) -@click.argument("parameters", type=str, required=False) +@click.argument("parameters", type=ParametersType(), default={}, required=False) @click.option( "--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True ) @@ -236,25 +265,13 @@ def on_event( def run_plan( obj: dict, name: str, - parameters: str | None, timeout: float | None, foreground: bool, + parameters: TaskParameters, ) -> None: """Run a plan with parameters""" client: BlueapiClient = obj["client"] - parameters = parameters or "{}" - try: - parsed_params = json.loads(parameters) if isinstance(parameters, str) else {} - except json.JSONDecodeError as jde: - raise ClickException(f"Parameters are not valid JSON: {jde}") from jde - - try: - task = Task(name=name, params=parsed_params) - except ValidationError as ve: - ip = InvalidParameters.from_validation_error(ve) - raise ClickException(ip.message()) from ip - try: if foreground: progress_bar = CliEventRenderer() @@ -266,12 +283,12 @@ def on_event(event: AnyEvent) -> None: elif isinstance(event, DataEvent): callback(event.name, event.doc) - resp = client.run_task(task, on_event=on_event) + resp = client.run_task(name, parameters, on_event=on_event) if resp.task_status is not None and not resp.task_status.task_failed: print("Plan Succeeded") else: - server_task = client.create_and_start_task(task) + server_task = client.create_and_start_task(name, parameters) click.echo(server_task.task_id) except config.MissingStompConfiguration as mse: raise ClickException(*mse.args) from mse diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 113b14686..e275b0f6c 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -1,5 +1,6 @@ import time from concurrent.futures import Future +from typing import Any from bluesky_stomp.messaging import MessageContext, StompClient from bluesky_stomp.models import Broker @@ -32,6 +33,8 @@ TRACER = get_tracer("client") +TaskParameters = dict[str, Any] + class BlueapiClient: """Unified client for controlling blueapi""" @@ -194,10 +197,12 @@ def get_active_task(self) -> WorkerTask: return self._rest.get_active_task() - @start_as_current_span(TRACER, "task", "timeout") + @start_as_current_span(TRACER, "name", "parameters", "timeout") def run_task( self, - task: Task, + name: str, + parameters: TaskParameters | None = None, + *, on_event: OnAnyEvent | None = None, timeout: float | None = None, ) -> WorkerEvent: @@ -220,7 +225,7 @@ def run_task( "Stomp configuration required to run plans is missing or disabled" ) - task_response = self.create_task(task) + task_response = self.create_task(name, parameters or {}) task_id = task_response.task_id complete: Future[WorkerEvent] = Future() @@ -257,8 +262,10 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None: self.start_task(WorkerTask(task_id=task_id)) return complete.result(timeout=timeout) - @start_as_current_span(TRACER, "task") - def create_and_start_task(self, task: Task) -> TaskResponse: + @start_as_current_span(TRACER, "name", "parameters") + def create_and_start_task( + self, name: str, parameters: TaskParameters | None = None + ) -> TaskResponse: """ Create a new task and instruct the worker to start it immediately. @@ -270,7 +277,7 @@ def create_and_start_task(self, task: Task) -> TaskResponse: TaskResponse: Acknowledgement of request """ - response = self.create_task(task) + response = self.create_task(name, parameters or {}) worker_response = self.start_task(WorkerTask(task_id=response.task_id)) if worker_response.task_id == response.task_id: return response @@ -280,8 +287,10 @@ def create_and_start_task(self, task: Task) -> TaskResponse: f"but {worker_response.task_id} was started instead" ) - @start_as_current_span(TRACER, "task") - def create_task(self, task: Task) -> TaskResponse: + @start_as_current_span(TRACER, "name", "parameters") + def create_task( + self, name: str, parameters: TaskParameters | None = None + ) -> TaskResponse: """ Create a new task, does not start execution @@ -292,6 +301,7 @@ def create_task(self, task: Task) -> TaskResponse: TaskResponse: Acknowledgement of request """ + task = Task(name=name, params=parameters or {}) return self._rest.create_task(task) @start_as_current_span(TRACER) diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index f60fe07a5..f0762d6be 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -1,6 +1,7 @@ import inspect import time from pathlib import Path +from typing import Any import pytest from bluesky_stomp.models import BasicAuthentication @@ -26,11 +27,10 @@ WorkerTask, ) from blueapi.worker.event import TaskStatus, WorkerEvent, WorkerState -from blueapi.worker.task import Task from blueapi.worker.task_worker import TrackableTask -_SIMPLE_TASK = Task(name="sleep", params={"time": 0.0}) -_LONG_TASK = Task(name="sleep", params={"time": 1.0}) +_SIMPLE_TASK = ("sleep", {"time": 0.0}) +_LONG_TASK = ("sleep", {"time": 1.0}) _DATA_PATH = Path(__file__).parent @@ -182,19 +182,19 @@ def test_get_non_existent_device(client: BlueapiClient): def test_create_task_and_delete_task_by_id(client: BlueapiClient): - create_task = client.create_task(_SIMPLE_TASK) + create_task = client.create_task(*_SIMPLE_TASK) client.clear_task(create_task.task_id) def test_create_task_validation_error(client: BlueapiClient): with pytest.raises(UnknownPlan): - client.create_task(Task(name="Not-exists", params={"Not-exists": 0.0})) + client.create_task("Not-exists", {"Not-exists": 0.0}) def test_get_all_tasks(client: BlueapiClient): created_tasks: list[TaskResponse] = [] for task in [_SIMPLE_TASK, _LONG_TASK]: - created_task = client.create_task(task) + created_task = client.create_task(*task) created_tasks.append(created_task) task_ids = [task.task_id for task in created_tasks] @@ -208,7 +208,7 @@ def test_get_all_tasks(client: BlueapiClient): def test_get_task_by_id(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) + created_task = client.create_task(*_SIMPLE_TASK) get_task = client.get_task(created_task.task_id) assert ( @@ -232,7 +232,7 @@ def test_delete_non_existent_task(client: BlueapiClient): def test_put_worker_task(client: BlueapiClient): - created_task = client.create_task(_SIMPLE_TASK) + created_task = client.create_task(*_SIMPLE_TASK) client.start_task(WorkerTask(task_id=created_task.task_id)) active_task = client.get_active_task() assert active_task.task_id == created_task.task_id @@ -240,8 +240,8 @@ def test_put_worker_task(client: BlueapiClient): def test_put_worker_task_fails_if_not_idle(client: BlueapiClient): - small_task = client.create_task(_SIMPLE_TASK) - long_task = client.create_task(_LONG_TASK) + small_task = client.create_task(*_SIMPLE_TASK) + long_task = client.create_task(*_LONG_TASK) client.start_task(WorkerTask(task_id=long_task.task_id)) active_task = client.get_active_task() @@ -269,8 +269,8 @@ def test_set_state_transition_error(client: BlueapiClient): def test_get_task_by_status(client: BlueapiClient): - task_1 = client.create_task(_SIMPLE_TASK) - task_2 = client.create_task(_SIMPLE_TASK) + task_1 = client.create_task(*_SIMPLE_TASK) + task_2 = client.create_task(*_SIMPLE_TASK) task_by_pending = client.get_all_tasks() # https://github.com/DiamondLightSource/blueapi/issues/680 # task_by_pending = client.get_tasks_by_status(TaskStatusEnum.PENDING) @@ -305,7 +305,7 @@ def test_progress_with_stomp(client_with_stomp: BlueapiClient): def on_event(event: AnyEvent): all_events.append(event) - client_with_stomp.run_task(_SIMPLE_TASK, on_event=on_event) + client_with_stomp.run_task(*_SIMPLE_TASK, on_event=on_event) assert isinstance(all_events[0], WorkerEvent) and all_events[0].task_status task_id = all_events[0].task_status.task_id assert all_events == [ @@ -350,11 +350,11 @@ def test_delete_current_environment(client: BlueapiClient): @pytest.mark.parametrize( - "task", + "plan,params", [ - Task( - name="count", - params={ + ( + "count", + { "detectors": [ "image_det", "current_det", @@ -362,9 +362,9 @@ def test_delete_current_environment(client: BlueapiClient): "num": 5, }, ), - Task( - name="spec_scan", - params={ + ( + "spec_scan", + { "detectors": [ "image_det", "current_det", @@ -372,34 +372,34 @@ def test_delete_current_environment(client: BlueapiClient): "spec": Line("x", 0.0, 10.0, 2) * Line("y", 5.0, 15.0, 3), }, ), - Task( - name="set_absolute", - params={ + ( + "set_absolute", + { "movable": "dynamic_motor", "value": "bar", }, ), - Task( - name="motor_plan", - params={ + ( + "motor_plan", + { "motor": "movable_motor", }, ), - Task( - name="motor_plan", - params={ + ( + "motor_plan", + { "motor": "dynamic_motor", }, ), - Task( - name="dataclass_motor_plan", - params={ + ( + "dataclass_motor_plan", + { "motor": "data_class_motor", }, ), ], ) -def test_plan_runs(client_with_stomp: BlueapiClient, task: Task): - final_event = client_with_stomp.run_task(task) +def test_plan_runs(client_with_stomp: BlueapiClient, plan: str, params: dict[str, Any]): + final_event = client_with_stomp.run_task(plan, params) assert final_event.is_complete() and not final_event.is_error() assert final_event.state is WorkerState.IDLE diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index 28c764488..2d343cbc9 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -171,7 +171,7 @@ def test_create_task( client: BlueapiClient, mock_rest: Mock, ): - client.create_task(task=Task(name="foo")) + client.create_task(name="foo") mock_rest.create_task.assert_called_once_with(Task(name="foo")) @@ -179,7 +179,7 @@ def test_create_task_does_not_start_task( client: BlueapiClient, mock_rest: Mock, ): - client.create_task(task=Task(name="foo")) + client.create_task(name="foo") mock_rest.update_worker_task.assert_not_called() @@ -218,7 +218,7 @@ def test_create_and_start_task_calls_both_creating_and_starting_endpoints( ): mock_rest.create_task.return_value = TaskResponse(task_id="baz") mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz") - client.create_and_start_task(Task(name="baz")) + client.create_and_start_task(name="baz") mock_rest.create_task.assert_called_once_with(Task(name="baz")) mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="baz")) @@ -229,7 +229,7 @@ def test_create_and_start_task_fails_if_task_creation_fails( ): mock_rest.create_task.side_effect = BlueskyRemoteControlError("No can do") with pytest.raises(BlueskyRemoteControlError): - client.create_and_start_task(Task(name="baz")) + client.create_and_start_task(name="baz") def test_create_and_start_task_fails_if_task_id_is_wrong( @@ -239,7 +239,7 @@ def test_create_and_start_task_fails_if_task_id_is_wrong( mock_rest.create_task.return_value = TaskResponse(task_id="baz") mock_rest.update_worker_task.return_value = TaskResponse(task_id="bar") with pytest.raises(BlueskyRemoteControlError): - client.create_and_start_task(Task(name="baz")) + client.create_and_start_task(name="baz") def test_create_and_start_task_fails_if_task_start_fails( @@ -249,7 +249,7 @@ def test_create_and_start_task_fails_if_task_start_fails( mock_rest.create_task.return_value = TaskResponse(task_id="baz") mock_rest.update_worker_task.side_effect = BlueskyRemoteControlError("No can do") with pytest.raises(BlueskyRemoteControlError): - client.create_and_start_task(Task(name="baz")) + client.create_and_start_task(name="baz") def test_get_environment(client: BlueapiClient): @@ -384,7 +384,7 @@ def test_cannot_run_task_without_message_bus(client: BlueapiClient): MissingStompConfiguration, match="Stomp configuration required to run plans is missing or disabled", ): - client.run_task(Task(name="foo")) + client.run_task(name="foo") def test_run_task_sets_up_control( @@ -398,7 +398,7 @@ def test_run_task_sets_up_control( ctx.correlation_id = "foo" mock_events.subscribe_to_all_events = lambda on_event: on_event(COMPLETE_EVENT, ctx) - client_with_events.run_task(Task(name="foo")) + client_with_events.run_task(name="foo") mock_rest.create_task.assert_called_once_with(Task(name="foo")) mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="foo")) @@ -417,7 +417,7 @@ def test_run_task_fails_on_failing_event( on_event = Mock() with pytest.raises(BlueskyStreamingError): - client_with_events.run_task(Task(name="foo"), on_event=on_event) + client_with_events.run_task(name="foo", on_event=on_event) on_event.assert_called_with(FAILED_EVENT) @@ -456,7 +456,7 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_events.subscribe_to_all_events = callback # type: ignore mock_on_event = Mock() - client_with_events.run_task(Task(name="foo"), on_event=mock_on_event) + client_with_events.run_task(name="foo", on_event=mock_on_event) assert mock_on_event.mock_calls == [call(test_event), call(COMPLETE_EVENT)] @@ -495,7 +495,7 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): mock_events.subscribe_to_all_events = callback mock_on_event = Mock() - client_with_events.run_task(Task(name="foo"), on_event=mock_on_event) + client_with_events.run_task(name="foo", on_event=mock_on_event) mock_on_event.assert_called_once_with(COMPLETE_EVENT) @@ -543,8 +543,8 @@ def test_create_task_span_ok( client: BlueapiClient, mock_rest: Mock, ): - with asserting_span_exporter(exporter, "create_task", "task"): - client.create_task(task=Task(name="foo")) + with asserting_span_exporter(exporter, "create_task", "name", "parameters"): + client.create_task(name="foo") def test_clear_task_span_ok( @@ -579,8 +579,10 @@ def test_create_and_start_task_span_ok( ): mock_rest.create_task.return_value = TaskResponse(task_id="baz") mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz") - with asserting_span_exporter(exporter, "create_and_start_task", "task"): - client.create_and_start_task(Task(name="baz")) + with asserting_span_exporter( + exporter, "create_and_start_task", "name", "parameters" + ): + client.create_and_start_task(name="baz") def test_get_environment_span_ok( @@ -644,4 +646,4 @@ def test_cannot_run_task_span_ok( match="Stomp configuration required to run plans is missing or disabled", ): with asserting_span_exporter(exporter, "grun_task"): - client.run_task(Task(name="foo")) + client.run_task(name="foo") diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 5c26b159a..b6e794716 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -24,7 +24,7 @@ from stomp.connect import StompConnection11 as Connection from blueapi import __version__ -from blueapi.cli.cli import main +from blueapi.cli.cli import ParametersType, main from blueapi.cli.format import OutputFormat, fmt_dict from blueapi.client.event_bus import BlueskyStreamingError from blueapi.client.rest import ( @@ -586,13 +586,13 @@ def test_error_handling(exception, error_message, runner: CliRunner): @pytest.mark.parametrize( - "params, error", + "params", [ - ("{", "Parameters are not valid JSON"), - ("[]", ""), + "{", + "[]", ], ) -def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner): +def test_run_task_parsing_errors(params: str, runner: CliRunner): result = runner.invoke( main, [ @@ -604,8 +604,8 @@ def test_run_task_parsing_errors(params: str, error: str, runner: CliRunner): params, ], ) - assert result.stderr.startswith("Error: " + error) - assert result.exit_code == 1 + assert "Error: Invalid value for '[PARAMETERS]'" in result.stderr + assert result.exit_code == 2 def test_device_output_formatting(): @@ -1164,3 +1164,9 @@ def test_python_env_output_formatting(): """) _assert_matching_formatting(OutputFormat.FULL, empty_python_env, full) + + +@pytest.mark.parametrize("value,result", [({}, {}), ("{}", {}), (None, None)]) +def test_task_parameter_type(value, result): + t = ParametersType() + assert t.convert(value, None, None) == result