Skip to content

Commit 1720ad3

Browse files
OSC typing, harden Capture (#462)
* Tighten OSC typing * Harden Capture
1 parent 096233c commit 1720ad3

File tree

7 files changed

+51
-52
lines changed

7 files changed

+51
-52
lines changed

supriya/contexts/core.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import itertools
1111
import re
1212
import threading
13+
from collections.abc import Sequence as SequenceABC
1314
from os import PathLike
1415
from typing import (
1516
Callable,
@@ -1179,7 +1180,7 @@ def read_buffer(
11791180
return self._add_request_with_completion(request, on_completion)
11801181

11811182
@abc.abstractmethod
1182-
def send(self, message: SupportsOsc):
1183+
def send(self, message: Union[SupportsOsc, SequenceABC, str]):
11831184
"""
11841185
Send a message to the execution context.
11851186

supriya/contexts/nonrealtime.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import shlex
1010
import shutil
1111
import struct
12+
from collections.abc import Sequence as SequenceABC
1213
from contextlib import ExitStack
1314
from os import PathLike
1415
from pathlib import Path
@@ -238,7 +239,7 @@ def iterate_request_bundles(
238239
if until and until > timestamp:
239240
yield RequestBundle(timestamp=until, contents=[DoNothing()])
240241

241-
def send(self, message: SupportsOsc) -> None:
242+
def send(self, message: Union[SupportsOsc, SequenceABC, str]):
242243
"""
243244
Send a message to the execution context.
244245

supriya/contexts/realtime.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from ..osc import (
4545
AsyncOscProtocol,
4646
HealthCheck,
47-
OscBundle,
4847
OscCallback,
4948
OscMessage,
5049
OscProtocol,
@@ -374,9 +373,7 @@ def _validate_moment_timestamp(self, seconds: Optional[float]) -> None:
374373

375374
### PUBLIC METHODS ###
376375

377-
def send(
378-
self, message: Union[OscMessage, OscBundle, SupportsOsc, SequenceABC, str]
379-
) -> None:
376+
def send(self, message: Union[SupportsOsc, SequenceABC, str]) -> None:
380377
"""
381378
Send a message to the execution context.
382379
@@ -385,9 +382,7 @@ def send(
385382
if self._boot_status == BootStatus.OFFLINE:
386383
raise ServerOffline
387384
osc_protocol: OscProtocol = getattr(self, "_osc_protocol")
388-
osc_protocol.send(
389-
message.to_osc() if isinstance(message, SupportsOsc) else message
390-
)
385+
osc_protocol.send(message)
391386

392387
def set_latency(self, latency: float) -> None:
393388
"""

supriya/osc/asynchronous.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
from typing import Awaitable, Callable, Dict, Optional, Sequence, Set, Tuple, Union
44

55
from ..enums import BootStatus
6-
from ..typing import FutureLike
7-
from .messages import OscBundle, OscMessage
6+
from ..typing import FutureLike, SupportsOsc
7+
from .messages import OscMessage
88
from .protocols import (
99
HealthCheck,
1010
OscCallback,
@@ -196,7 +196,7 @@ def register(
196196
)
197197
return callback
198198

199-
def send(self, message: Union[OscBundle, OscMessage, SequenceABC, str]) -> None:
199+
def send(self, message: Union[SupportsOsc, SequenceABC, str]) -> None:
200200
self.transport.sendto(self._send(message))
201201

202202
def unregister(self, callback: OscCallback) -> None:

supriya/osc/messages.py

+6
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,9 @@ def to_list(self):
297297
result.append(x)
298298
return result
299299

300+
def to_osc(self) -> "OscMessage":
301+
return self
302+
300303

301304
class OscBundle:
302305
"""
@@ -471,3 +474,6 @@ def to_list(self):
471474
result = [self.timestamp]
472475
result.append([x.to_list() for x in self.contents])
473476
return result
477+
478+
def to_osc(self) -> "OscBundle":
479+
return self

supriya/osc/protocols.py

+33-37
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
Awaitable,
1010
Callable,
1111
Dict,
12+
Iterator,
1213
List,
14+
Literal,
1315
NamedTuple,
1416
Optional,
1517
Sequence,
@@ -19,7 +21,7 @@
1921
)
2022

2123
from ..enums import BootStatus
22-
from ..typing import FutureLike
24+
from ..typing import FutureLike, SupportsOsc
2325
from .messages import OscBundle, OscMessage
2426

2527
osc_protocol_logger = logging.getLogger(__name__)
@@ -65,31 +67,35 @@ class HealthCheck:
6567

6668
class CaptureEntry(NamedTuple):
6769
timestamp: float
68-
label: str
70+
label: Literal["R", "S"]
6971
message: Union[OscMessage, OscBundle]
72+
raw_message: Union[SupportsOsc, SequenceABC, str] | None = None
7073

7174

7275
class Capture:
7376
### INITIALIZER ###
7477

75-
def __init__(self, osc_protocol):
78+
def __init__(self, osc_protocol: "OscProtocol") -> None:
7679
self.osc_protocol = osc_protocol
77-
self.messages = []
80+
self.messages: list[CaptureEntry] = []
7881

7982
### SPECIAL METHODS ###
8083

81-
def __enter__(self):
84+
def __enter__(self) -> "Capture":
8285
self.osc_protocol.captures.add(self)
8386
self.messages[:] = []
8487
return self
8588

86-
def __exit__(self, exc_type, exc_value, traceback):
89+
def __exit__(self, exc_type, exc_value, traceback) -> None:
8790
self.osc_protocol.captures.remove(self)
8891

89-
def __iter__(self):
92+
def __getitem__(self, i: int | slice) -> CaptureEntry | list[CaptureEntry]:
93+
return self.messages[i]
94+
95+
def __iter__(self) -> Iterator[CaptureEntry]:
9096
return iter(self.messages)
9197

92-
def __len__(self):
98+
def __len__(self) -> int:
9399
return len(self.messages)
94100

95101
### PUBLIC METHODS ###
@@ -98,7 +104,7 @@ def filtered(
98104
self, sent=True, received=True, status=True
99105
) -> List[Union[OscBundle, OscMessage]]:
100106
messages = []
101-
for _, label, message in self.messages:
107+
for _, label, message, _ in self.messages:
102108
if label == "R" and not received:
103109
continue
104110
if label == "S" and not sent:
@@ -112,24 +118,6 @@ def filtered(
112118
messages.append(message)
113119
return messages
114120

115-
### PUBLIC PROPERTIES ###
116-
117-
@property
118-
def received_messages(self):
119-
return [
120-
(timestamp, osc_message)
121-
for timestamp, label, osc_message in self.messages
122-
if label == "R"
123-
]
124-
125-
@property
126-
def sent_messages(self):
127-
return [
128-
(timestamp, osc_message)
129-
for timestamp, label, osc_message in self.messages
130-
if label == "S"
131-
]
132-
133121

134122
class OscProtocol:
135123
### INITIALIZER ###
@@ -293,27 +281,35 @@ def _register(
293281
kwargs=kwargs,
294282
)
295283

296-
def _send(self, message):
284+
def _send(self, raw_message: Union[SupportsOsc, SequenceABC, str]) -> bytes:
297285
if self.status not in (BootStatus.BOOTING, BootStatus.ONLINE):
298286
raise OscProtocolOffline
299-
if not isinstance(message, (str, SequenceABC, OscBundle, OscMessage)):
300-
raise ValueError(message)
301-
if isinstance(message, str):
302-
message = OscMessage(message)
303-
elif isinstance(message, SequenceABC):
304-
message = OscMessage(*message)
287+
if not isinstance(raw_message, (str, SequenceABC, SupportsOsc)):
288+
raise ValueError(raw_message)
289+
message: OscBundle | OscMessage
290+
if isinstance(raw_message, str):
291+
message = OscMessage(raw_message)
292+
elif isinstance(raw_message, SequenceABC):
293+
message = OscMessage(*raw_message)
294+
else:
295+
message = raw_message.to_osc()
305296
osc_out_logger.debug(
306297
f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] "
307298
f"{message!r}"
308299
)
309300
for capture in self.captures:
310301
capture.messages.append(
311-
CaptureEntry(timestamp=time.time(), label="S", message=message)
302+
CaptureEntry(
303+
timestamp=time.time(),
304+
label="S",
305+
message=message,
306+
raw_message=raw_message,
307+
)
312308
)
313309
datagram = message.to_datagram()
314310
udp_out_logger.debug(
315311
f"[{self.ip_address}:{self.port}/{self.name or hex(id(self))}] "
316-
f"{datagram}"
312+
f"{datagram!r}"
317313
)
318314
return datagram
319315

@@ -373,7 +369,7 @@ def register(
373369
) -> OscCallback:
374370
raise NotImplementedError
375371

376-
def send(self, message: Union[OscBundle, OscMessage, SequenceABC, str]) -> None:
372+
def send(self, message: Union[SupportsOsc, SequenceABC, str]) -> None:
377373
raise NotImplementedError
378374

379375
def unregister(self, callback: OscCallback) -> None:

supriya/osc/threaded.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
)
1818

1919
from ..enums import BootStatus
20-
from ..typing import FutureLike
21-
from .messages import OscBundle, OscMessage
20+
from ..typing import FutureLike, SupportsOsc
21+
from .messages import OscMessage
2222
from .protocols import (
2323
HealthCheck,
2424
OscCallback,
@@ -229,7 +229,7 @@ def register(
229229
)
230230
return callback
231231

232-
def send(self, message: Union[OscBundle, OscMessage, SequenceABC, str]) -> None:
232+
def send(self, message: Union[SupportsOsc, SequenceABC, str]) -> None:
233233
try:
234234
self.osc_server.socket.sendto(
235235
self._send(message),

0 commit comments

Comments
 (0)