Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
from functools import wraps
from pathlib import Path
from pprint import pprint
from typing import Any, TypeGuard

import click
from bluesky.callbacks.best_effort import BestEffortCallback
from bluesky_stomp.messaging import MessageContext, StompClient
from bluesky_stomp.models import Broker
from click.core import Context, Parameter
from click.exceptions import ClickException
from click.types import ParamType
from observability_utils.tracing import setup_tracing
from pydantic import ValidationError
from requests.exceptions import ConnectionError

from blueapi import __version__, config
from blueapi.cli.format import OutputFormat
from blueapi.client.client import BlueapiClient
from blueapi.client.client import BlueapiClient, TaskParameters
from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient
from blueapi.client.rest import (
BlueskyRemoteControlError,
Expand All @@ -34,12 +36,39 @@
from blueapi.log import set_up_logging
from blueapi.service.authentication import SessionCacheManager, SessionManager
from blueapi.service.model import SourceInfo
from blueapi.worker import ProgressEvent, Task, WorkerEvent
from blueapi.worker import ProgressEvent, WorkerEvent

from .scratch import setup_scratch
from .updates import CliEventRenderer


class ParametersType(ParamType):
"""CLI input parameter to accept a JSON object as an argument"""

name = "TaskParameters"

def convert(
self,
value: str | dict[str, Any] | None,
param: Parameter | None,
ctx: Context | None,
) -> TaskParameters:
if isinstance(value, str):
try:
params = json.loads(value)
if is_str_dict(params):
return params
self.fail("Parameters must be a JSON object with string keys")
except json.JSONDecodeError as jde:
self.fail(f"Parameters are not valid JSON: {jde}")
else:
return super().convert(value, param, ctx)


def is_str_dict(val: Any) -> TypeGuard[TaskParameters]:
return isinstance(val, dict) and all(isinstance(k, str) for k in val)


@click.group(
invoke_without_command=True, context_settings={"auto_envvar_prefix": "BLUEAPI"}
)
Expand Down Expand Up @@ -220,7 +249,7 @@ def on_event(

@controller.command(name="run")
@click.argument("name", type=str)
@click.argument("parameters", type=str, required=False)
@click.argument("parameters", type=ParametersType(), default={}, required=False)
@click.option(
"--foreground/--background", "--fg/--bg", type=bool, is_flag=True, default=True
)
Expand All @@ -236,25 +265,13 @@ def on_event(
def run_plan(
obj: dict,
name: str,
parameters: str | None,
timeout: float | None,
foreground: bool,
parameters: TaskParameters,
) -> None:
"""Run a plan with parameters"""
client: BlueapiClient = obj["client"]

parameters = parameters or "{}"
try:
parsed_params = json.loads(parameters) if isinstance(parameters, str) else {}
except json.JSONDecodeError as jde:
raise ClickException(f"Parameters are not valid JSON: {jde}") from jde

try:
task = Task(name=name, params=parsed_params)
except ValidationError as ve:
ip = InvalidParameters.from_validation_error(ve)
raise ClickException(ip.message()) from ip

try:
if foreground:
progress_bar = CliEventRenderer()
Expand All @@ -266,12 +283,12 @@ def on_event(event: AnyEvent) -> None:
elif isinstance(event, DataEvent):
callback(event.name, event.doc)

resp = client.run_task(task, on_event=on_event)
resp = client.run_task(name, parameters, on_event=on_event)

if resp.task_status is not None and not resp.task_status.task_failed:
print("Plan Succeeded")
else:
server_task = client.create_and_start_task(task)
server_task = client.create_and_start_task(name, parameters)
click.echo(server_task.task_id)
except config.MissingStompConfiguration as mse:
raise ClickException(*mse.args) from mse
Expand Down
26 changes: 18 additions & 8 deletions src/blueapi/client/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from concurrent.futures import Future
from typing import Any

from bluesky_stomp.messaging import MessageContext, StompClient
from bluesky_stomp.models import Broker
Expand Down Expand Up @@ -32,6 +33,8 @@

TRACER = get_tracer("client")

TaskParameters = dict[str, Any]


class BlueapiClient:
"""Unified client for controlling blueapi"""
Expand Down Expand Up @@ -194,10 +197,12 @@ def get_active_task(self) -> WorkerTask:

return self._rest.get_active_task()

@start_as_current_span(TRACER, "task", "timeout")
@start_as_current_span(TRACER, "name", "parameters", "timeout")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the problems we've had with the logs being spammed with junk I'm hesitant to add a dict here, but I suppose better to add it and change the underlying behaviour that's causing problems.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the same content would have ended up in the logs when a task was being passed around. It still contained the same dict.

def run_task(
self,
task: Task,
name: str,
parameters: TaskParameters | None = None,
*,
on_event: OnAnyEvent | None = None,
timeout: float | None = None,
) -> WorkerEvent:
Expand All @@ -220,7 +225,7 @@ def run_task(
"Stomp configuration required to run plans is missing or disabled"
)

task_response = self.create_task(task)
task_response = self.create_task(name, parameters or {})
task_id = task_response.task_id

complete: Future[WorkerEvent] = Future()
Expand Down Expand Up @@ -257,8 +262,10 @@ def inner_on_event(event: AnyEvent, ctx: MessageContext) -> None:
self.start_task(WorkerTask(task_id=task_id))
return complete.result(timeout=timeout)

@start_as_current_span(TRACER, "task")
def create_and_start_task(self, task: Task) -> TaskResponse:
@start_as_current_span(TRACER, "name", "parameters")
def create_and_start_task(
self, name: str, parameters: TaskParameters | None = None
) -> TaskResponse:
"""
Create a new task and instruct the worker to start it
immediately.
Expand All @@ -270,7 +277,7 @@ def create_and_start_task(self, task: Task) -> TaskResponse:
TaskResponse: Acknowledgement of request
"""

response = self.create_task(task)
response = self.create_task(name, parameters or {})
worker_response = self.start_task(WorkerTask(task_id=response.task_id))
if worker_response.task_id == response.task_id:
return response
Expand All @@ -280,8 +287,10 @@ def create_and_start_task(self, task: Task) -> TaskResponse:
f"but {worker_response.task_id} was started instead"
)

@start_as_current_span(TRACER, "task")
def create_task(self, task: Task) -> TaskResponse:
@start_as_current_span(TRACER, "name", "parameters")
def create_task(
self, name: str, parameters: TaskParameters | None = None
) -> TaskResponse:
"""
Create a new task, does not start execution

Expand All @@ -292,6 +301,7 @@ def create_task(self, task: Task) -> TaskResponse:
TaskResponse: Acknowledgement of request
"""

task = Task(name=name, params=parameters or {})
return self._rest.create_task(task)

@start_as_current_span(TRACER)
Expand Down
68 changes: 34 additions & 34 deletions tests/system_tests/test_blueapi_system.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import time
from pathlib import Path
from typing import Any

import pytest
from bluesky_stomp.models import BasicAuthentication
Expand All @@ -26,11 +27,10 @@
WorkerTask,
)
from blueapi.worker.event import TaskStatus, WorkerEvent, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TrackableTask

_SIMPLE_TASK = Task(name="sleep", params={"time": 0.0})
_LONG_TASK = Task(name="sleep", params={"time": 1.0})
_SIMPLE_TASK = ("sleep", {"time": 0.0})
_LONG_TASK = ("sleep", {"time": 1.0})

_DATA_PATH = Path(__file__).parent

Expand Down Expand Up @@ -182,19 +182,19 @@ def test_get_non_existent_device(client: BlueapiClient):


def test_create_task_and_delete_task_by_id(client: BlueapiClient):
create_task = client.create_task(_SIMPLE_TASK)
create_task = client.create_task(*_SIMPLE_TASK)
client.clear_task(create_task.task_id)


def test_create_task_validation_error(client: BlueapiClient):
with pytest.raises(UnknownPlan):
client.create_task(Task(name="Not-exists", params={"Not-exists": 0.0}))
client.create_task("Not-exists", {"Not-exists": 0.0})


def test_get_all_tasks(client: BlueapiClient):
created_tasks: list[TaskResponse] = []
for task in [_SIMPLE_TASK, _LONG_TASK]:
created_task = client.create_task(task)
created_task = client.create_task(*task)
created_tasks.append(created_task)
task_ids = [task.task_id for task in created_tasks]

Expand All @@ -208,7 +208,7 @@ def test_get_all_tasks(client: BlueapiClient):


def test_get_task_by_id(client: BlueapiClient):
created_task = client.create_task(_SIMPLE_TASK)
created_task = client.create_task(*_SIMPLE_TASK)

get_task = client.get_task(created_task.task_id)
assert (
Expand All @@ -232,16 +232,16 @@ def test_delete_non_existent_task(client: BlueapiClient):


def test_put_worker_task(client: BlueapiClient):
created_task = client.create_task(_SIMPLE_TASK)
created_task = client.create_task(*_SIMPLE_TASK)
client.start_task(WorkerTask(task_id=created_task.task_id))
active_task = client.get_active_task()
assert active_task.task_id == created_task.task_id
client.clear_task(created_task.task_id)


def test_put_worker_task_fails_if_not_idle(client: BlueapiClient):
small_task = client.create_task(_SIMPLE_TASK)
long_task = client.create_task(_LONG_TASK)
small_task = client.create_task(*_SIMPLE_TASK)
long_task = client.create_task(*_LONG_TASK)

client.start_task(WorkerTask(task_id=long_task.task_id))
active_task = client.get_active_task()
Expand Down Expand Up @@ -269,8 +269,8 @@ def test_set_state_transition_error(client: BlueapiClient):


def test_get_task_by_status(client: BlueapiClient):
task_1 = client.create_task(_SIMPLE_TASK)
task_2 = client.create_task(_SIMPLE_TASK)
task_1 = client.create_task(*_SIMPLE_TASK)
task_2 = client.create_task(*_SIMPLE_TASK)
task_by_pending = client.get_all_tasks()
# https://github.com/DiamondLightSource/blueapi/issues/680
# task_by_pending = client.get_tasks_by_status(TaskStatusEnum.PENDING)
Expand Down Expand Up @@ -305,7 +305,7 @@ def test_progress_with_stomp(client_with_stomp: BlueapiClient):
def on_event(event: AnyEvent):
all_events.append(event)

client_with_stomp.run_task(_SIMPLE_TASK, on_event=on_event)
client_with_stomp.run_task(*_SIMPLE_TASK, on_event=on_event)
assert isinstance(all_events[0], WorkerEvent) and all_events[0].task_status
task_id = all_events[0].task_status.task_id
assert all_events == [
Expand Down Expand Up @@ -350,56 +350,56 @@ def test_delete_current_environment(client: BlueapiClient):


@pytest.mark.parametrize(
"task",
"plan,params",
[
Task(
name="count",
params={
(
"count",
{
"detectors": [
"image_det",
"current_det",
],
"num": 5,
},
),
Task(
name="spec_scan",
params={
(
"spec_scan",
{
"detectors": [
"image_det",
"current_det",
],
"spec": Line("x", 0.0, 10.0, 2) * Line("y", 5.0, 15.0, 3),
},
),
Task(
name="set_absolute",
params={
(
"set_absolute",
{
"movable": "dynamic_motor",
"value": "bar",
},
),
Task(
name="motor_plan",
params={
(
"motor_plan",
{
"motor": "movable_motor",
},
),
Task(
name="motor_plan",
params={
(
"motor_plan",
{
"motor": "dynamic_motor",
},
),
Task(
name="dataclass_motor_plan",
params={
(
"dataclass_motor_plan",
{
"motor": "data_class_motor",
},
),
],
)
def test_plan_runs(client_with_stomp: BlueapiClient, task: Task):
final_event = client_with_stomp.run_task(task)
def test_plan_runs(client_with_stomp: BlueapiClient, plan: str, params: dict[str, Any]):
final_event = client_with_stomp.run_task(plan, params)
assert final_event.is_complete() and not final_event.is_error()
assert final_event.state is WorkerState.IDLE
Loading
Loading