diff --git a/meshtastic/connection.py b/meshtastic/connection.py new file mode 100644 index 00000000..004a18b2 --- /dev/null +++ b/meshtastic/connection.py @@ -0,0 +1,285 @@ +""" +low-level radio connection API +""" +import asyncio +import io +import logging +from abc import ABC, abstractmethod +from typing import * + +import serial +import serial_asyncio +from bleak import BleakClient, BLEDevice, BleakScanner + +from meshtastic.protobuf.mesh_pb2 import FromRadio, ToRadio + + +STREAM_HEADER_MAGIC: bytes = b"\x94\xc3" # magic number used in streaming client headers +DEFAULT_BAUDRATE: int = 115200 +DEFAULT_TCP_PORT: int = 4403 +BLE_SERVICE_UUID: str = "6ba1b218-15a8-461f-9fa8-5dcae273eafd" +BLE_TORADIO_UUID: str = "f75c76d2-129e-4dad-a1dd-7866124401e7" +BLE_FROMRADIO_UUID: str = "2c55e69e-4993-11ed-b878-0242ac120002" +BLE_FROMNUM_UUID: str = "ed9da18c-a800-4f66-a670-aa7547e34453" +BLE_LEGACY_LOGRADIO_UUID: str = "6c6fd238-78fa-436b-aacf-15c5be1ef2e2" +BLE_LOGRADIO_UUID: str = "5a3d6e49-06e6-4423-9944-e9de8cdf9547" + + +class RadioConnectionError(Exception): + """Base class for RadioConnection-related errors.""" + + +class BadPayloadError(RadioConnectionError): + """Error indicating invalid payload over connection""" + def __init__(self, payload, reason: str): + self.payload = payload + super().__init__(reason) + + +class ConnectionTerminatedError(RadioConnectionError): + """Error indicating the connection was terminated.""" + + +class RadioConnection(ABC): + """A client API connection to a meshtastic radio.""" + + def __init__(self, name: str): + self.name: str = name + self.on_ready: asyncio.Event = asyncio.Event() + self.on_disconnect: asyncio.Event = asyncio.Event() + self._send_lock: asyncio.Lock = asyncio.Lock() + self._recv_lock: asyncio.Lock = asyncio.Lock() + self._listen_lock: asyncio.Lock = asyncio.Lock() + + @abstractmethod + async def _initialize(self): + """Perform any connection initialization that must be performed async + (and therefore not from the constructor).""" + + @abstractmethod + async def _send_bytes(self, msg: bytes): + """Send bytes to the mesh device.""" + + @abstractmethod + async def _recv_bytes(self) -> bytes: + """Recieve bytes from the mesh device.""" + + @staticmethod + @abstractmethod + async def get_available() -> AsyncGenerator[Any]: + """Enumerate any mesh devices that can be connected to. + + Generates values that can be passed to the concrete connection class's + constructor.""" + + def ready(self): + """Returns if the connection is ready for tx/rx""" + return self.on_ready.is_set() + + async def open(self): + """Start the connection""" + await self._initialize() + self.on_ready.set() + logging.info(f"Connected to mesh radio {self.name}") + + def _ensure_ready(self): + """Raise an exception if the connection is not ready for tx/rx""" + if not self.ready(): + raise RadioConnectionError("Connection used before it was ready") + + async def send(self, message: ToRadio): + """Send something to the connected device.""" + self._ensure_ready() + async with self._send_lock: + msg_str: str = message.SerializeToString() + await self._send_bytes(bytes(msg_str)) + + async def recv(self) -> FromRadio: + """Recieve something from the connected device.""" + self._ensure_ready() + async with self._recv_lock: + msg_bytes: bytes = await self._recv_bytes() + return FromRadio.FromString(str(msg_bytes, errors="ignore")) + + async def listen(self) -> AsyncGenerator[FromRadio]: + """Yields new messages from the radio so long as the connection is active.""" + self._ensure_ready() + async with self._listen_lock: + while not self.on_disconnect.is_set(): + yield await self.recv() + + async def close(self): + """Close the connection. + Overloaders should remember to call supermethod""" + self.on_ready.unset() + self.on_disconnect.set() + + async def __aenter__(self): + await self.open() + return self + + async def __aexit__(self, exc_type, exc_value, trace): + await self.close() + + #def __enter__(self): + # self.open() + # asyncio.run(self._init_task) + # return self + + #def __exit__(self, exc_type, exc_value, trace): + # self.close() + + +class StreamConnection(RadioConnection): + """Base class for connections using the aio stream API""" + def __init__(self, name: str): + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self.stream_debug_out: io.StringIO = io.StringIO() + super().__init__(name) + + def _handle_debug(self, debug_out: bytes): + self.stream_debug_out.write(str(debug_out)) + self.stream_debug_out.flush() + + async def _send_bytes(self, msg: bytes): + length: int = len(msg) + if length > 512: + raise BadPayloadError(msg, "Cannot send client API messages over 512 bytes") + + self._writer.write(STREAM_HEADER_MAGIC) + self._writer.write(length.to_bytes(2, "big")) + self._writer.write(msg) + await self._writer.drain() + + async def _find_stream_header(self): + """Consumes and logs debug out bytes until a valid header is detected""" + try: + while True: + from_stream: bytes = await self._reader.readuntil((b'\n', STREAM_HEADER_MAGIC)) + if from_stream.endswith(STREAM_HEADER_MAGIC): + self._handle_debug(from_stream[:-2]) + return + else: + self._handle_debug(from_stream) + + except asyncio.IncompleteReadError as err: + if len(err.partial) > 0: + self._handle_debug(err.partial) + raise + + async def _recv_bytes(self) -> bytes: + try: + while True: + await self._find_stream_header() + size_bytes: bytes = await self._reader.readexactly(2) + size: int = int.from_bytes(size_bytes, "big") + if 0 < size <= 512: + return await self._reader.readexactly(size) + + self._handle_debug(size_bytes) + + except asyncio.LimitOverrunError as err: + raise RadioConnectionError("Read buffer overrun while reading stream") from err + + except asyncio.IncompleteReadError: + self._reader.feed_eof() + logging.error(f"Connection to {self.name} terminated: stream EOF reached") + raise ConnectionTerminatedError from None + + async def close(self): + await super().close() + if self._writer.can_write_eof(): + self._writer.write_eof() + + self._writer.close() + self.stream_debug_out.close() + await self._writer.wait_closed() + + +class SerialConnection(StreamConnection): + """Connection to a mesh radio over serial port""" + def __init__(self, portaddr: str, baudrate: int=DEFAULT_BAUDRATE): + self.port: str = portaddr + self.baudrate: int = baudrate + super().__init__(portaddr) + + async def _initialize(self): + self._reader, self._writer = await serial_asyncio.open_serial_connection( + url=self.port, baudrate=self.baudrate) + + @staticmethod + async def get_available() -> AsyncGenerator[str]: + for port in serial.tools.list_ports.comports(): + # filtering for hwid gets rid of linux VT serials (e.g, /dev/ttyS0 and friends) + # FIXME: this may not be cross-platform or non-USB serial friendly + if port.hwid != "n/a": + yield port.device + + +class TCPConnection(StreamConnection): + """Connection to a mesh radio over TCP""" + + def __init__(self, host: str, port: int=DEFAULT_TCP_PORT): + self.host: str = host + self.port: int = port + super().__init__(f"{host}:{port}") + + async def _initialize(self): + self._reader, self._writer = await asyncio.open_connection(self.host, self.port) + + @staticmethod + async def get_available() -> AsyncGenerator[None]: + yield None # FIXME + + +class BLEConnection(RadioConnection): + """Connection to a mesh radio over BLE""" + + def __init__(self, device: Union[str, BLEDevice]): + self._recieved_messages: asyncio.Queue = asyncio.Queue() + self._ble_client = BleakClient(device, disconnected_callback=lambda _: self.close()) + self._ble_client.mtu_size = 512 + + name: str = device + if isinstance(device, BLEDevice): + name = device.name + super().__init__(name) + + async def _initialize(self): + await self._ble_client.connect() + await self._ble_client.start_notify(BLE_FROMNUM_UUID, self._on_recv) + + async def _on_recv(self, _sender: Any, _data: bytearray): + """Callback for handling fromnum endpoint notifs""" + data: bytearray = await self._ble_client.read_gatt_char(BLE_FROMRADIO_UUID) + if len(data) > 512: + raise BadPayloadError(data, "Cannot recieve client API messages over 512 bytes") + + await self._recieved_messages.put(data) + + async def _read_bytes(self) -> bytes: + return bytes(await self._recieved_messages.get()) + + async def _send_bytes(self, msg: bytes): + if len(msg) > 512: + raise BadPayloadError(msg, "Cannot send client API messages over 512 bytes") + + await self._ble_client.write_gatt_char(BLE_TORADIO_UUID, msg, response=True) + + async def close(self): + await super().close() + self._recieved_messages.shutdown(True) + await self._ble_client.stop_notify(BLE_FROMNUM_UUID) + await self._ble_client.disconnect() + + @staticmethod + async def get_available() -> AsyncGenerator[BLEDevice]: + async with BleakScanner(service_uuids=(BLE_SERVICE_UUID,)) as scanner: + try: + async with asyncio.timeout(10): + async for dev, _ad in scanner.advertisement_data(): + yield dev + + except TimeoutError: + pass