Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion docs/how-to/write-plans.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,24 @@ The type annotations in the example above (e.g. `: str`, `: int`, `-> MsgGenerat

## Injecting Devices

Some plans are created for specific sets of devices, or will almost always be used with the same devices, it is useful to be able to specify defaults. [Dodal makes this easy with its factory functions](https://diamondlightsource.github.io/dodal/main/how-to/include-devices-in-plans.html).
Some plans are created for specific sets of devices, or will almost always be used with the same devices, it is useful to be able to specify defaults. [Dodal makes this easy with its inject function](https://diamondlightsource.github.io/dodal/main/reference/generated/dodal.common.html#dodal.common.inject).

## Injecting multiple devices

If a plan requires multiple devices to be injected at once, rather than have a plan with several device parameters each of them with their own injection default, it is possible to define a device composite which can be accepted as a parameter.

For example you could define a composite as below:

```{literalinclude} ../../tests/unit_tests/code_examples/device_composite.py
:language: python
```

Then in your plan module:

```{literalinclude} ../../tests/unit_tests/code_examples/plan_with_composite.py
:language: python
```


## Injecting Metadata

Expand Down
50 changes: 32 additions & 18 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from collections.abc import Callable
from dataclasses import InitVar, dataclass, field
from dataclasses import InitVar, dataclass, field, fields, is_dataclass
from importlib import import_module
from inspect import Parameter, isclass, signature
from types import ModuleType, NoneType, UnionType
Expand All @@ -11,7 +11,12 @@
from dodal.common.beamlines.beamline_utils import get_path_provider, set_path_provider
from dodal.utils import AnyDevice, make_all_devices
from ophyd_async.core import NotConnected
from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler, create_model
from pydantic import (
BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
create_model,
)
from pydantic.fields import FieldInfo
from pydantic.json_schema import JsonSchemaValue, SkipJsonSchema
from pydantic_core import CoreSchema, core_schema
Expand Down Expand Up @@ -82,7 +87,7 @@ def is_bluesky_type(typ: type) -> bool:
return typ in BLUESKY_PROTOCOLS or isinstance(typ, BLUESKY_PROTOCOLS)


C = TypeVar("C", bound=BaseModel, covariant=True)
C = TypeVar("C", covariant=True)


@dataclass
Expand Down Expand Up @@ -386,16 +391,19 @@ def _type_spec_for_function(
)

no_default = para.default is Parameter.empty
default_factory = (
self._composite_factory(arg_type)
if isclass(arg_type)
and issubclass(arg_type, BaseModel)
if (
isclass(arg_type)
and (issubclass(arg_type, BaseModel) or is_dataclass(arg_type))
and isinstance(para.default, str)
else DefaultFactory(para.default)
)
):
default_factory = self._composite_factory(arg_type)
_type = SkipJsonSchema[self._convert_type(arg_type, no_default)]
else:
default_factory = DefaultFactory(para.default)
_type = self._convert_type(arg_type, no_default)
factory = None if no_default else default_factory
new_args[name] = (
self._convert_type(arg_type, no_default),
_type,
FieldInfo(default_factory=factory),
)
return new_args
Expand Down Expand Up @@ -431,14 +439,20 @@ def _convert_type(self, typ: type | Any, no_default: bool = True) -> type:

def _composite_factory(self, composite_class: type[C]) -> Callable[[], C]:
def _inject_composite():
devices = {
field: self.find_device(info.default)
if info.annotation is not None
and is_bluesky_type(info.annotation)
and isinstance(info.default, str)
else info.default
for field, info in composite_class.model_fields.items()
}
if issubclass(composite_class, BaseModel):
devices = {
field_name: self.find_device(field_name)
for field_name in composite_class.model_fields.keys()
}
else:
assert is_dataclass(composite_class), (
f"Unsupported composite type: {composite_class}, composite must be"
" a pydantic BaseModel or a dataclass"
)
devices = {
field.name: self.find_device(field.name)
for field in fields(composite_class)
}
return composite_class(**devices)

return _inject_composite
Expand Down
8 changes: 8 additions & 0 deletions tests/unit_tests/code_examples/device_composite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pydantic
from tests.unit_tests.code_examples.device_module import BimorphMirror


@pydantic.dataclasses.dataclass(config={"arbitrary_types_allowed": True})
class MyDeviceComposite:
oav: BimorphMirror
# More devices here....
12 changes: 12 additions & 0 deletions tests/unit_tests/code_examples/plan_with_composite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from bluesky.utils import MsgGenerator
from dodal.common import inject
from tests.unit_tests.code_examples.device_composite import MyDeviceComposite


def my_plan(
parameter_one: int,
parameter_two: str,
my_necessary_devices: MyDeviceComposite = inject(""),
) -> MsgGenerator[None]:
# logic goes here
...
88 changes: 82 additions & 6 deletions tests/unit_tests/worker/test_task_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import itertools
import threading
from collections.abc import Callable, Iterable
Expand All @@ -7,6 +8,7 @@
from typing import Any, TypeVar
from unittest.mock import ANY, MagicMock, Mock, patch

import pydantic
import pytest
from bluesky.protocols import Movable, Status
from bluesky.utils import MsgGenerator
Expand All @@ -21,6 +23,7 @@
from blueapi.config import EnvironmentConfig, Source, SourceKind
from blueapi.core import BlueskyContext, EventStream
from blueapi.core.bluesky_types import DataEvent
from blueapi.service.model import PlanModel
from blueapi.utils.base_model import BlueapiBaseModel
from blueapi.worker import (
Task,
Expand Down Expand Up @@ -661,6 +664,21 @@ def injected_device_plan(
assert params["dev"] == fake_device


def test_injected_devices_plan_model(
fake_device: FakeDevice,
context: BlueskyContext,
):
def injected_device_plan(
dev: FakeDevice = inject(fake_device.name),
) -> MsgGenerator:
yield from ()

context.register_plan(injected_device_plan)
plan = context.plans["injected_device_plan"]
model = PlanModel.from_plan(plan)
print(model)


def test_missing_injected_devices_fail_early(
context: BlueskyContext,
):
Expand Down Expand Up @@ -699,25 +717,83 @@ def test_cycle_without_otel_context(mock_logger: Mock, inert_worker: TaskWorker)


class MyComposite(BlueapiBaseModel):
dev_a: FakeDevice = inject(fake_device.name)
dev_b: FakeDevice = inject(second_fake_device.name)
fake_device: FakeDevice
second_fake_device: FakeDevice

model_config = {"arbitrary_types_allowed": True}


@pydantic.dataclasses.dataclass(config={"arbitrary_types_allowed": True})
class MyPydanticDataClassComposite:
fake_device: FakeDevice
second_fake_device: FakeDevice


@dataclasses.dataclass()
class MyStandardDataClassComposite:
fake_device: FakeDevice
second_fake_device: FakeDevice


def injected_device_plan(composite: MyComposite = inject("")) -> MsgGenerator:
yield from ()


def injected_dataclass_device_plan(
composite: MyPydanticDataClassComposite = inject(""),
) -> MsgGenerator:
yield from ()


def injected_standard_dataclass_device_plan(
composite: MyStandardDataClassComposite = inject(""),
) -> MsgGenerator:
yield from ()


def test_injected_composite_devices_are_found(
fake_device: FakeDevice,
second_fake_device: FakeDevice,
context: BlueskyContext,
):
context.register_plan(injected_device_plan)
params = Task(name="injected_device_plan").prepare_params(context)
assert params["composite"].dev_a == fake_device
assert params["composite"].dev_b == second_fake_device
assert params["composite"].fake_device == fake_device
assert params["composite"].second_fake_device == second_fake_device


def test_injected_composite_devices_plan_model(
fake_device: FakeDevice,
second_fake_device: FakeDevice,
context: BlueskyContext,
):
context.register_plan(injected_device_plan)
plan = context.plans["injected_device_plan"]
PlanModel.from_plan(plan)


def test_injected_composite_with_pydantic_dataclass(
context: BlueskyContext,
fake_device: FakeDevice,
second_fake_device: FakeDevice,
):
context.register_plan(injected_dataclass_device_plan)
params = Task(name="injected_dataclass_device_plan").prepare_params(context)
assert params["composite"].fake_device == fake_device
assert params["composite"].second_fake_device == second_fake_device


def test_injected_composite_with_standard_dataclass(
context: BlueskyContext,
fake_device: FakeDevice,
second_fake_device: FakeDevice,
):
context.register_plan(injected_standard_dataclass_device_plan)
params = Task(name="injected_standard_dataclass_device_plan").prepare_params(
context
)
assert params["composite"].fake_device == fake_device
assert params["composite"].second_fake_device == second_fake_device


def test_plan_module_with_composite_devices_can_be_loaded_before_device_module(
Expand All @@ -729,5 +805,5 @@ def test_plan_module_with_composite_devices_can_be_loaded_before_device_module(
context_without_devices.register_device(fake_device)
context_without_devices.register_device(second_fake_device)
params = Task(name="injected_device_plan").prepare_params(context_without_devices)
assert params["composite"].dev_a == fake_device
assert params["composite"].dev_b == second_fake_device
assert params["composite"].fake_device == fake_device
assert params["composite"].second_fake_device == second_fake_device