Skip to content

Commit e8ce156

Browse files
evvaaaacoretl
authored andcommitted
swapped to generic TypedDict
1 parent ffea866 commit e8ce156

File tree

2 files changed

+31
-32
lines changed

2 files changed

+31
-32
lines changed

src/ophyd_async/core/_derived_signal.py

+23-20
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,36 @@
11
import asyncio
2-
import dataclasses
32
from abc import abstractmethod
4-
from typing import Generic, Self, TypeVar, get_args
3+
from typing import Generic, TypedDict, TypeVar, get_args
54

65
from ._device import Device
76
from ._protocol import AsyncMovable
87
from ._signal import SignalR, SignalRW
98
from ._signal_backend import SignalBackend, SignalDatatypeT
109

1110

12-
@dataclasses.dataclass
13-
class TransformArgument(Generic[SignalDatatypeT]):
14-
@classmethod
15-
async def get_dataclass_from_signals(cls, device: Device) -> Self:
16-
coros = {}
17-
for field in dataclasses.fields(cls):
18-
sig = getattr(device, field.name)
19-
assert isinstance(
20-
sig, SignalR
21-
), f"{device.name}.{field.name} is {sig}, not a Signal"
22-
coros[field.name] = sig.get_value()
23-
results = await asyncio.gather(*coros.values())
24-
kwargs = dict(zip(coros, results, strict=True))
25-
return cls(**kwargs)
11+
class TransformArgument(TypedDict, Generic[SignalDatatypeT]):
12+
pass
13+
14+
15+
T = TypeVar("T", bound=TransformArgument)
16+
17+
18+
async def _get_dataclass_from_signals(cls: type[T], device: Device) -> T:
19+
coros = {}
20+
for name in cls.__annotations__:
21+
signal = getattr(device, name)
22+
assert isinstance(
23+
signal, SignalR
24+
), f"{device.name}.{name} is {signal}, not a Signal"
25+
coros[name] = signal.get_value()
26+
results = await asyncio.gather(*coros.values())
27+
kwargs = dict(zip(coros, results, strict=True))
28+
return cls(**kwargs)
2629

2730

2831
RawT = TypeVar("RawT", bound=TransformArgument)
2932
DerivedT = TypeVar("DerivedT", bound=TransformArgument)
30-
ParametersT = TypeVar("ParametersT", bound=TransformArgument)
33+
ParametersT = TypeVar("ParametersT")
3134

3235

3336
class TransformMeta(type):
@@ -67,12 +70,12 @@ def __init__(
6770
self._transform = transform
6871

6972
async def get_parameters(self) -> ParametersT:
70-
return await self._transform.parameters_cls.get_dataclass_from_signals(
71-
self._device
73+
return await _get_dataclass_from_signals(
74+
self._transform.parameters_cls, self._device
7275
)
7376

7477
async def get_raw_values(self) -> RawT:
75-
return await self._transform.raw_cls.get_dataclass_from_signals(self._device)
78+
return await _get_dataclass_from_signals(self._transform.raw_cls, self._device)
7679

7780
async def get_derived_values(self) -> DerivedT:
7881
raw, parameters = await asyncio.gather(

tests/core/test_derived_signal.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
from dataclasses import dataclass
32
from typing import TypeVar
43

54
import numpy as np
@@ -36,7 +35,7 @@ class SomeTransform1(Transform[Raw, Derived]): ... # type: ignore
3635
TypeError,
3736
match=(
3837
"Transform classes must be defined with Raw, Derived, "
39-
"and Parameter args."
38+
"and Parameter `TransformArgument`s."
4039
),
4140
):
4241

@@ -54,39 +53,36 @@ class SomeTransform(Transform[Raw, Derived, Parameters]): ...
5453
F = TypeVar("F", float, Array1D[np.float64])
5554

5655

57-
@dataclass
5856
class SlitsRaw(TransformArgument[F]):
5957
top: F
6058
bottom: F
6159

6260

63-
@dataclass
6461
class SlitsDerived(TransformArgument[F]):
6562
gap: F
6663
centre: F
6764

6865

69-
@dataclass
70-
class SlitsParameters(TransformArgument[float]):
66+
class SlitsParameters(TransformArgument):
7167
gap_offset: float
7268

7369

7470
class SlitsTransform(Transform[SlitsRaw[F], SlitsDerived[F], SlitsParameters]):
7571
@classmethod
7672
def forward(cls, raw: SlitsRaw[F], parameters: SlitsParameters) -> SlitsDerived[F]:
7773
return SlitsDerived(
78-
gap=raw.top - raw.bottom + parameters.gap_offset,
79-
centre=(raw.top + raw.bottom) / 2,
74+
gap=raw["top"] - raw["bottom"] + parameters["gap_offset"],
75+
centre=(raw["top"] + raw["bottom"]) / 2,
8076
)
8177

8278
@classmethod
8379
def inverse(
8480
cls, derived: SlitsDerived[F], parameters: SlitsParameters
8581
) -> SlitsRaw[F]:
86-
half_gap = (derived.gap - parameters.gap_offset) / 2
82+
half_gap = (derived["gap"] - parameters["gap_offset"]) / 2
8783
return SlitsRaw(
88-
top=derived.centre + half_gap,
89-
bottom=derived.centre - half_gap,
84+
top=derived["centre"] + half_gap,
85+
bottom=derived["centre"] - half_gap,
9086
)
9187

9288

@@ -106,7 +102,7 @@ def __init__(self, name=""):
106102
@AsyncStatus.wrap
107103
async def set(self, derived: SlitsDerived[float]) -> None:
108104
raw: SlitsRaw[float] = await self._backend.calculate_raw_values(derived)
109-
await asyncio.gather(self.top.set(raw.top), self.bottom.set(raw.bottom))
105+
await asyncio.gather(self.top.set(raw["top"]), self.bottom.set(raw["bottom"]))
110106

111107

112108
async def test_derived_signals():

0 commit comments

Comments
 (0)