diff --git a/docs/how-to/write-plans.md b/docs/how-to/write-plans.md index 01ea2efaa..cba1ab495 100644 --- a/docs/how-to/write-plans.md +++ b/docs/how-to/write-plans.md @@ -24,7 +24,24 @@ The type annotations in the example above (e.g. `: str`, `: int`, `-> MsgGenerat ## Injecting Devices -Some plans are created for specific sets of devices, or will almost always be used with the same devices, it is useful to be able to specify defaults. [Dodal makes this easy with its factory functions](https://diamondlightsource.github.io/dodal/main/how-to/include-devices-in-plans.html). +Some plans are created for specific sets of devices, or will almost always be used with the same devices, it is useful to be able to specify defaults. [Dodal makes this easy with its inject function](https://diamondlightsource.github.io/dodal/main/reference/generated/dodal.common.html#dodal.common.inject). + +## Injecting multiple devices + +If a plan requires multiple devices to be injected at once, rather than have a plan with several device parameters each of them with their own injection default, it is possible to define a device composite which can be accepted as a parameter. + +For example you could define a composite as below: + +```{literalinclude} ../../tests/unit_tests/code_examples/device_composite.py +:language: python +``` + +Then in your plan module: + +```{literalinclude} ../../tests/unit_tests/code_examples/plan_with_composite.py +:language: python +``` + ## Injecting Metadata diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 31764b286..39bf9e317 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable -from dataclasses import InitVar, dataclass, field +from dataclasses import InitVar, dataclass, field, fields, is_dataclass from importlib import import_module from inspect import Parameter, isclass, signature from types import ModuleType, NoneType, UnionType @@ -11,7 +11,12 @@ from dodal.common.beamlines.beamline_utils import get_path_provider, set_path_provider from dodal.utils import AnyDevice, make_all_devices from ophyd_async.core import NotConnected -from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler, create_model +from pydantic import ( + BaseModel, + GetCoreSchemaHandler, + GetJsonSchemaHandler, + create_model, +) from pydantic.fields import FieldInfo from pydantic.json_schema import JsonSchemaValue, SkipJsonSchema from pydantic_core import CoreSchema, core_schema @@ -82,7 +87,7 @@ def is_bluesky_type(typ: type) -> bool: return typ in BLUESKY_PROTOCOLS or isinstance(typ, BLUESKY_PROTOCOLS) -C = TypeVar("C", bound=BaseModel, covariant=True) +C = TypeVar("C", covariant=True) @dataclass @@ -386,16 +391,19 @@ def _type_spec_for_function( ) no_default = para.default is Parameter.empty - default_factory = ( - self._composite_factory(arg_type) - if isclass(arg_type) - and issubclass(arg_type, BaseModel) + if ( + isclass(arg_type) + and (issubclass(arg_type, BaseModel) or is_dataclass(arg_type)) and isinstance(para.default, str) - else DefaultFactory(para.default) - ) + ): + default_factory = self._composite_factory(arg_type) + _type = SkipJsonSchema[self._convert_type(arg_type, no_default)] + else: + default_factory = DefaultFactory(para.default) + _type = self._convert_type(arg_type, no_default) factory = None if no_default else default_factory new_args[name] = ( - self._convert_type(arg_type, no_default), + _type, FieldInfo(default_factory=factory), ) return new_args @@ -431,14 +439,20 @@ def _convert_type(self, typ: type | Any, no_default: bool = True) -> type: def _composite_factory(self, composite_class: type[C]) -> Callable[[], C]: def _inject_composite(): - devices = { - field: self.find_device(info.default) - if info.annotation is not None - and is_bluesky_type(info.annotation) - and isinstance(info.default, str) - else info.default - for field, info in composite_class.model_fields.items() - } + if issubclass(composite_class, BaseModel): + devices = { + field_name: self.find_device(field_name) + for field_name in composite_class.model_fields.keys() + } + else: + assert is_dataclass(composite_class), ( + f"Unsupported composite type: {composite_class}, composite must be" + " a pydantic BaseModel or a dataclass" + ) + devices = { + field.name: self.find_device(field.name) + for field in fields(composite_class) + } return composite_class(**devices) return _inject_composite diff --git a/tests/unit_tests/code_examples/device_composite.py b/tests/unit_tests/code_examples/device_composite.py new file mode 100644 index 000000000..7f43811bb --- /dev/null +++ b/tests/unit_tests/code_examples/device_composite.py @@ -0,0 +1,8 @@ +import pydantic +from tests.unit_tests.code_examples.device_module import BimorphMirror + + +@pydantic.dataclasses.dataclass(config={"arbitrary_types_allowed": True}) +class MyDeviceComposite: + oav: BimorphMirror + # More devices here.... diff --git a/tests/unit_tests/code_examples/plan_with_composite.py b/tests/unit_tests/code_examples/plan_with_composite.py new file mode 100644 index 000000000..a03d31989 --- /dev/null +++ b/tests/unit_tests/code_examples/plan_with_composite.py @@ -0,0 +1,12 @@ +from bluesky.utils import MsgGenerator +from dodal.common import inject +from tests.unit_tests.code_examples.device_composite import MyDeviceComposite + + +def my_plan( + parameter_one: int, + parameter_two: str, + my_necessary_devices: MyDeviceComposite = inject(""), +) -> MsgGenerator[None]: + # logic goes here + ... diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index e66323dc5..a51b7b12d 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -1,3 +1,4 @@ +import dataclasses import itertools import threading from collections.abc import Callable, Iterable @@ -7,6 +8,7 @@ from typing import Any, TypeVar from unittest.mock import ANY, MagicMock, Mock, patch +import pydantic import pytest from bluesky.protocols import Movable, Status from bluesky.utils import MsgGenerator @@ -21,6 +23,7 @@ from blueapi.config import EnvironmentConfig, Source, SourceKind from blueapi.core import BlueskyContext, EventStream from blueapi.core.bluesky_types import DataEvent +from blueapi.service.model import PlanModel from blueapi.utils.base_model import BlueapiBaseModel from blueapi.worker import ( Task, @@ -661,6 +664,21 @@ def injected_device_plan( assert params["dev"] == fake_device +def test_injected_devices_plan_model( + fake_device: FakeDevice, + context: BlueskyContext, +): + def injected_device_plan( + dev: FakeDevice = inject(fake_device.name), + ) -> MsgGenerator: + yield from () + + context.register_plan(injected_device_plan) + plan = context.plans["injected_device_plan"] + model = PlanModel.from_plan(plan) + print(model) + + def test_missing_injected_devices_fail_early( context: BlueskyContext, ): @@ -699,16 +717,40 @@ def test_cycle_without_otel_context(mock_logger: Mock, inert_worker: TaskWorker) class MyComposite(BlueapiBaseModel): - dev_a: FakeDevice = inject(fake_device.name) - dev_b: FakeDevice = inject(second_fake_device.name) + fake_device: FakeDevice + second_fake_device: FakeDevice model_config = {"arbitrary_types_allowed": True} +@pydantic.dataclasses.dataclass(config={"arbitrary_types_allowed": True}) +class MyPydanticDataClassComposite: + fake_device: FakeDevice + second_fake_device: FakeDevice + + +@dataclasses.dataclass() +class MyStandardDataClassComposite: + fake_device: FakeDevice + second_fake_device: FakeDevice + + def injected_device_plan(composite: MyComposite = inject("")) -> MsgGenerator: yield from () +def injected_dataclass_device_plan( + composite: MyPydanticDataClassComposite = inject(""), +) -> MsgGenerator: + yield from () + + +def injected_standard_dataclass_device_plan( + composite: MyStandardDataClassComposite = inject(""), +) -> MsgGenerator: + yield from () + + def test_injected_composite_devices_are_found( fake_device: FakeDevice, second_fake_device: FakeDevice, @@ -716,8 +758,42 @@ def test_injected_composite_devices_are_found( ): context.register_plan(injected_device_plan) params = Task(name="injected_device_plan").prepare_params(context) - assert params["composite"].dev_a == fake_device - assert params["composite"].dev_b == second_fake_device + assert params["composite"].fake_device == fake_device + assert params["composite"].second_fake_device == second_fake_device + + +def test_injected_composite_devices_plan_model( + fake_device: FakeDevice, + second_fake_device: FakeDevice, + context: BlueskyContext, +): + context.register_plan(injected_device_plan) + plan = context.plans["injected_device_plan"] + PlanModel.from_plan(plan) + + +def test_injected_composite_with_pydantic_dataclass( + context: BlueskyContext, + fake_device: FakeDevice, + second_fake_device: FakeDevice, +): + context.register_plan(injected_dataclass_device_plan) + params = Task(name="injected_dataclass_device_plan").prepare_params(context) + assert params["composite"].fake_device == fake_device + assert params["composite"].second_fake_device == second_fake_device + + +def test_injected_composite_with_standard_dataclass( + context: BlueskyContext, + fake_device: FakeDevice, + second_fake_device: FakeDevice, +): + context.register_plan(injected_standard_dataclass_device_plan) + params = Task(name="injected_standard_dataclass_device_plan").prepare_params( + context + ) + assert params["composite"].fake_device == fake_device + assert params["composite"].second_fake_device == second_fake_device def test_plan_module_with_composite_devices_can_be_loaded_before_device_module( @@ -729,5 +805,5 @@ def test_plan_module_with_composite_devices_can_be_loaded_before_device_module( context_without_devices.register_device(fake_device) context_without_devices.register_device(second_fake_device) params = Task(name="injected_device_plan").prepare_params(context_without_devices) - assert params["composite"].dev_a == fake_device - assert params["composite"].dev_b == second_fake_device + assert params["composite"].fake_device == fake_device + assert params["composite"].second_fake_device == second_fake_device