diff --git a/pyproject.toml b/pyproject.toml index 26a7b59..9497e3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ license = {text = "GPL-3.0"} requires-python = ">=3.8" dependencies = [ "voluptuous", - "zigpy>=0.60.2", + "zigpy>=0.68.1", 'async-timeout; python_version<"3.11"', ] diff --git a/zigpy_espzb/api.py b/zigpy_espzb/api.py index 4e09a62..c629ae0 100644 --- a/zigpy_espzb/api.py +++ b/zigpy_espzb/api.py @@ -4,7 +4,6 @@ import asyncio import collections -import itertools import logging import sys from typing import Any, Callable @@ -16,546 +15,37 @@ from zigpy.config import CONF_DEVICE_PATH import zigpy.types as t -from zigpy.zdo.types import SimpleDescriptor -from zigpy_espzb.exception import APIException, CommandError, MismatchedResponseError -from zigpy_espzb.types import Bytes, DeviceAddrMode, ZnspTransmitOptions, list_replace +from zigpy_espzb import commands +from zigpy_espzb.commands import ( + COMMAND_SCHEMA_TO_COMMAND_ID, + COMMAND_SCHEMAS, + CommandFrame, + FrameType, +) +from zigpy_espzb.exception import APIException, CommandError +from zigpy_espzb.types import ( + Bytes, + DeviceType, + ExtendedAddrMode, + FirmwareVersion, + NetworkState, + SecurityMode, + ShiftedChannels, + Status, + TransmitOptions, + TXStatus, + addr_mode_with_eui64_to_addr_mode_address, +) import zigpy_espzb.uart LOGGER = logging.getLogger(__name__) +POLL_UNTIL_RUNNING_TIMEOUT = 10 COMMAND_TIMEOUT = 1.8 PROBE_TIMEOUT = 2 REQUEST_RETRY_DELAYS = (0.5, 1.0, 1.5, None) -FRAME_LENGTH = object() -PAYLOAD_LENGTH = object() - - -class DeviceType(t.enum8): - COORDINATOR = 0 - ROUTER = 1 - ED = 2 - - -class Status(t.enum8): - SUCCESS = 0 - FAILURE = 1 - INVALID_VALUE = 2 - TIMEOUT = 3 - UNSUPPORTED = 4 - ERROR = 5 - NO_NETWORK = 6 - BUSY = 7 - - -class FirmwareVersion(t.Struct, t.uint32_t): - reserved: t.uint8_t - patch: t.uint8_t - minor: t.uint8_t - major: t.uint8_t - - -class NetworkState(t.enum8): - OFFLINE = 0 - JOINING = 1 - CONNECTED = 2 - LEAVING = 3 - CONFIRM = (4,) - INDICATION = (5,) - - -class DeviceState(t.Struct): - network_state: NetworkState - - -class SecurityMode(t.enum8): - NO_SECURITY = 0x00 - PRECONFIGURED_NETWORK_KEY = 0x01 - NETWORK_KEY_FROM_TC = 0x02 - ONLY_TCLK = 0x03 - - -class ZDPResponseHandling(t.bitmap16): - NONE = 0x0000 - NodeDescRsp = 0x0001 - - -class FormNetwork(t.Struct): - role: DeviceType - policy: t.Bool - nwk_cfg0: t.uint8_t - nwk_cfg1: t.uint32_t - - -class CommandId(t.uint16_t): - networkinit = 0x0000 - start = 0x0001 - device_state = 0x0002 - change_network_state = 0x0003 - form_network = 0x0004 - permit_joining = 0x0005 - panid_get = 0x000B - panid_set = 0x000C - extpanid_get = 0x000D - extpanid_set = 0x000E - channel_mask_get = 0x000F - channel_mask_set = 0x0010 - current_channel_get = 0x0013 - current_channel_set = 0x0014 - network_key_get = 0x0017 - network_key_set = 0x0018 - nwk_frame_counter_get = 0x0019 - nwk_frame_counter_set = 0x001A - aps_designed_coordinator_get = 0x001B - aps_designed_coordinator_set = 0x001C - short_addr_get = 0x001D - short_addr_set = 0x001E - long_addr_get = 0x001F - long_addr_set = 0x0020 - nwk_update_id_get = 0x0023 - nwk_update_id_set = 0x0024 - trust_center_address_get = 0x0025 - trust_center_address_set = 0x0026 - link_key_get = 0x0027 - link_key_set = 0x0028 - security_mode_get = 0x0029 - security_mode_set = 0x002A - use_predefined_nwk_panid_set = 0x002B - addendpoint = 0x0100 - aps_data_request = 0x0300 - aps_data_indication = 0x0301 - aps_data_confirm = 0x0302 - - -class TXStatus(t.enum8): - SUCCESS = 0x00 - - @classmethod - def _missing_(cls, value): - chained = t.APSStatus(value) - status = t.uint8_t.__new__(cls, chained.value) - status._name_ = chained.name - status._value_ = value - return status - - -class IndexedKey(t.Struct): - index: t.uint8_t - key: t.KeyData - - -class LinkKey(t.Struct): - ieee: t.EUI64 - key: t.KeyData - - -class IndexedEndpoint(t.Struct): - index: t.uint8_t - descriptor: SimpleDescriptor - - -class UpdateNeighborAction(t.enum8): - ADD = 0x01 - - -class Command(t.Struct): - flags: t.uint16_t - command_id: CommandId - seq: t.uint8_t - payload: Bytes - - -COMMAND_SCHEMAS = { - CommandId.networkinit: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.start: ( - { - "payload_length": PAYLOAD_LENGTH, - "autostart": t.Bool, - }, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.form_network: ( - { - "payload_length": PAYLOAD_LENGTH, - "form_mwk": FormNetwork, - }, - { - "payload_length": t.uint16_t, - "status": Status, - }, - { - "payload_length": t.uint16_t, - "extended_panid": t.EUI64, - "panid": t.uint16_t, - "channel": t.uint8_t, - }, - ), - CommandId.permit_joining: ( - { - "payload_length": PAYLOAD_LENGTH, - "form_mwk": FormNetwork, - }, - { - "payload_length": t.uint16_t, - "permit": t.uint8_t, - }, - { - "payload_length": t.uint16_t, - "permit": t.uint8_t, - }, - ), - CommandId.extpanid_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "ieee": t.EUI64}, - {}, - ), - CommandId.extpanid_set: ( - {"payload_length": PAYLOAD_LENGTH, "ieee": t.EUI64}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.panid_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "panid": t.uint16_t}, - {}, - ), - CommandId.panid_set: ( - {"payload_length": PAYLOAD_LENGTH, "panid": t.PanId}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.short_addr_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "short_addr": t.uint16_t}, - {}, - ), - CommandId.short_addr_set: ( - {"payload_length": PAYLOAD_LENGTH, "short_addr": t.uint16_t}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.long_addr_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "ieee": t.EUI64}, - {}, - ), - CommandId.long_addr_set: ( - {"payload_length": PAYLOAD_LENGTH, "ieee": t.EUI64}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.current_channel_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "channel": t.uint8_t}, - {}, - ), - CommandId.current_channel_set: ( - {"payload_length": PAYLOAD_LENGTH, "channel": t.uint8_t}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.channel_mask_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "channel_mask": t.uint32_t}, - {}, - ), - CommandId.channel_mask_set: ( - {"payload_length": PAYLOAD_LENGTH, "channel_mask": t.Channels}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.addendpoint: ( - { - "payload_length": PAYLOAD_LENGTH, - "endpoint": t.uint8_t, - "profileId": t.uint16_t, - "deviceId": t.uint16_t, - "appFlags": t.uint8_t, - "inputClusterCount": t.uint8_t, - "outputClusterCount": t.uint8_t, - "inputClusterList": t.List[t.uint8_t], - "outputClusterList": t.List[t.uint8_t], - }, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.device_state: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - { - "payload_length": t.uint16_t, - "device_state": DeviceState, - }, - {}, - ), - CommandId.change_network_state: ( - { - "payload_length": PAYLOAD_LENGTH, - "network_state": t.uint8_t, - }, - { - "payload_length": t.uint16_t, - "network_state": t.uint8_t, - }, - {}, - ), - CommandId.aps_data_request: ( - { - "payload_length": PAYLOAD_LENGTH, - "dst_addr": t.EUI64, - "dst_endpoint": t.uint8_t, - "src_endpoint": t.uint8_t, - "address_mode": t.uint8_t, - "profile_id": t.uint16_t, - "cluster_id": t.uint16_t, - "tx_options": t.uint8_t, - "use_alias": t.Bool, - "src_addr": t.EUI64, - "sequence": t.uint8_t, - "radius": t.uint8_t, - "asdu_length": t.uint32_t, - "asdu": t.List[t.uint8_t], - }, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.aps_data_indication: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - { - "payload_length": t.uint16_t, - "device_state": DeviceState, - "dst_addr_mode": t.uint8_t, - "dst_addr": t.EUI64, - "dst_ep": t.uint8_t, - "src_addr_mode": t.uint8_t, - "src_addr": t.EUI64, - "src_ep": t.uint8_t, - "profile_id": t.uint16_t, - "cluster_id": t.uint16_t, - "indication_status": TXStatus, - "security_status": t.uint8_t, - "lqi": t.uint8_t, - "rx_time": t.uint32_t, - "asdu_length": t.uint32_t, - "asdu": t.List[t.uint8_t], - }, - { - "payload_length": t.uint16_t, - "device_state": DeviceState, - "dst_addr_mode": t.uint8_t, - "dst_addr": t.EUI64, - "dst_ep": t.uint8_t, - "src_addr_mode": t.uint8_t, - "src_addr": t.EUI64, - "src_ep": t.uint8_t, - "profile_id": t.uint16_t, - "cluster_id": t.uint16_t, - "indication_status": TXStatus, - "security_status": t.uint8_t, - "lqi": t.uint8_t, - "rx_time": t.uint32_t, - "asdu_length": t.uint32_t, - "asdu": t.List[t.uint8_t], - }, - ), - CommandId.aps_data_confirm: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - { - "payload_length": t.uint16_t, - "device_state": DeviceState, - "dst_addr_mode": t.uint8_t, - "dst_addr": t.EUI64, - "dst_ep": t.uint8_t, - "src_ep": t.uint8_t, - "tx_time": t.uint32_t, - "request_id": t.uint8_t, - "confirm_status": TXStatus, - "asdu_length": t.uint32_t, - "asdu": t.List[t.uint8_t], - }, - { - "payload_length": t.uint16_t, - "device_state": DeviceState, - "dst_addr_mode": t.uint8_t, - "dst_addr": t.EUI64, - "dst_ep": t.uint8_t, - "src_ep": t.uint8_t, - "tx_time": t.uint32_t, - "confirm_status": TXStatus, - "asdu_length": t.uint32_t, - "asdu": t.List[t.uint8_t], - }, - ), - CommandId.network_key_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "nwk_key": t.KeyData}, - {}, - ), - CommandId.network_key_set: ( - {"payload_length": PAYLOAD_LENGTH, "nwk_key": t.KeyData}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.nwk_frame_counter_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "nwk_frame_counter": t.uint32_t}, - {}, - ), - CommandId.nwk_frame_counter_set: ( - {"payload_length": PAYLOAD_LENGTH, "nwk_frame_counter": t.uint32_t}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.aps_designed_coordinator_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "role": t.uint8_t}, - {}, - ), - CommandId.aps_designed_coordinator_set: ( - {"payload_length": PAYLOAD_LENGTH, "role": t.uint8_t}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.use_predefined_nwk_panid_set: ( - {"payload_length": PAYLOAD_LENGTH, "predefined": t.Bool}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.nwk_update_id_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "nwk_update_id": t.uint8_t}, - {}, - ), - CommandId.nwk_update_id_set: ( - {"payload_length": PAYLOAD_LENGTH, "nwk_update_id": t.uint8_t}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.trust_center_address_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "trust_center_addr": t.EUI64}, - {}, - ), - CommandId.trust_center_address_set: ( - {"payload_length": PAYLOAD_LENGTH, "trust_center_addr": t.EUI64}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.link_key_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "link_key": LinkKey}, - {}, - ), - CommandId.link_key_set: ( - {"payload_length": PAYLOAD_LENGTH, "link_key": t.KeyData}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), - CommandId.security_mode_get: ( - { - "payload_length": PAYLOAD_LENGTH, - }, - {"payload_length": t.uint16_t, "security_mode": SecurityMode}, - {}, - ), - CommandId.security_mode_set: ( - {"payload_length": PAYLOAD_LENGTH, "security_mode": SecurityMode}, - { - "payload_length": t.uint16_t, - "status": Status, - }, - {}, - ), -} - class Znsp: """Espressif ZNSP API class.""" @@ -568,9 +58,7 @@ def __init__(self, app: Callable, device_config: dict[str, Any]): self._awaiting = collections.defaultdict(lambda: collections.defaultdict(list)) self._command_lock = asyncio.Lock() self._config = device_config - self._device_state = DeviceState( - network_state=NetworkState.OFFLINE, - ) + self._network_state = NetworkState.OFFLINE self._data_poller_event = asyncio.Event() self._data_poller_event.set() @@ -589,18 +77,15 @@ def firmware_version(self) -> FirmwareVersion: @property def network_state(self) -> NetworkState: """Return current network state.""" - return self._device_state.network_state + return self._network_state async def connect(self) -> None: assert self._uart is None self._uart = await zigpy_espzb.uart.connect(self._config, self) - await self.network_init() - - device_state_rsp = await self.send_command(CommandId.device_state) - self._device_state = device_state_rsp["device_state"] - - self._data_poller_task = asyncio.create_task(self._data_poller()) + # TODO: implement a firmware version command + self._firmware_version = await self.system_firmware() + self._network_state = await self.get_network_state() def connection_lost(self, exc: Exception) -> None: """Lost serial connection.""" @@ -624,60 +109,18 @@ def close(self): self._uart.close() self._uart = None - async def send_command(self, cmd, **kwargs) -> Any: - while True: - try: - return await self._command(cmd, **kwargs) - except MismatchedResponseError as exc: - LOGGER.debug("Firmware responded incorrectly (%s), retrying", exc) - - async def _command(self, cmd, **kwargs): - payload = [] - tx_schema, _, _ = COMMAND_SCHEMAS[cmd] - trailing_optional = False - - for name, param_type in tx_schema.items(): - if isinstance(param_type, int): - if name not in kwargs: - # Default value - value = param_type.serialize() - else: - value = type(param_type)(kwargs[name]).serialize() - elif name in ("frame_length", "payload_length"): - value = param_type - elif kwargs.get(name) is None: - trailing_optional = True - value = None - elif not isinstance(kwargs[name], param_type): - value = param_type(kwargs[name]).serialize() - else: - value = kwargs[name].serialize() - - if value is None: - continue - - if trailing_optional: - raise ValueError( - f"Command {cmd} with kwargs {kwargs}" - f" has non-trailing optional argument" - ) - - payload.append(value) + async def send_command(self, command: t.Struct, *, wait_for_response: bool = True): + command_id = COMMAND_SCHEMA_TO_COMMAND_ID[type(command)] + serialized_payload = command.serialize() - if PAYLOAD_LENGTH in payload: - payload = list_replace( - lst=payload, - old=PAYLOAD_LENGTH, - new=t.uint16_t( - sum(len(p) for p in payload[payload.index(PAYLOAD_LENGTH) + 1 :]) - ).serialize(), - ) - - command = Command( - flags=0x0000, - command_id=cmd, + command_frame = CommandFrame( + version=0b0000, + frame_type=FrameType.Request, + reserved=0x00, + command_id=command_id, seq=None, - payload=b"".join(payload), + length=len(serialized_payload), + payload=serialized_payload, ) if self._uart is None: @@ -687,51 +130,63 @@ async def _command(self, cmd, **kwargs): async with self._command_lock: seq = self._seq - LOGGER.debug("Sending %s%s (seq=%s)", cmd, kwargs, seq) - self._uart.send(command.replace(seq=seq).serialize()) + LOGGER.debug("Sending %s (seq=%s)", command, seq) + self._uart.send(command_frame.replace(seq=seq).serialize()) self._seq = (self._seq % 255) + 1 + if not wait_for_response: + LOGGER.debug("Not waiting for a response") + return + fut = asyncio.Future() - self._awaiting[seq][cmd].append(fut) + self._awaiting[seq][command_id].append(fut) try: async with asyncio_timeout(COMMAND_TIMEOUT): return await fut except asyncio.TimeoutError: - LOGGER.debug("No response to '%s' command with seq %d", cmd, seq) + LOGGER.debug("No response to '%s' command with seq %d", command, seq) raise finally: - self._awaiting[seq][cmd].remove(fut) + self._awaiting[seq][command_id].remove(fut) def data_received(self, data: bytes) -> None: - command, _ = Command.deserialize(data) + command, _ = CommandFrame.deserialize(data) if command.command_id not in COMMAND_SCHEMAS: LOGGER.warning("Unknown command received: %s", command) return - if command.flags == 0x0010: - _, rx_schema, _ = COMMAND_SCHEMAS[command.command_id] - elif command.flags == 0x0020: - _, _, rx_schema = COMMAND_SCHEMAS[command.command_id] + tx_schema, rx_schema, ind_schema = COMMAND_SCHEMAS[command.command_id] + + if command.frame_type == FrameType.Request: + schema = tx_schema + elif command.frame_type == FrameType.Response: + schema = rx_schema + elif command.frame_type == FrameType.Indicate: + schema = ind_schema + else: + raise ValueError(f"Unknown frame type: {command}") + + # We won't implement requests for now + assert command.frame_type != FrameType.Request + + if schema is None: + return fut = None - wrong_fut_cmd_id = None - try: - fut = self._awaiting[command.seq][command.command_id][0] - except IndexError: - # XXX: The firmware can sometimes respond with the wrong response. Find the - # future associated with it so we can throw an appropriate error. - for cmd_id, futs in self._awaiting[command.seq].items(): - if futs: - fut = futs[0] - wrong_fut_cmd_id = cmd_id - break + if command.frame_type == FrameType.Response: + try: + fut = self._awaiting[command.seq][command.command_id][0] + except IndexError: + LOGGER.warning( + "Received unexpected response %s%s", command.command_id, command + ) try: - params, rest = t.deserialize_dict(command.payload, rx_schema) + params, rest = schema.deserialize(command.payload) except Exception: LOGGER.warning("Failed to parse command %s", command, exc_info=True) @@ -745,37 +200,17 @@ def data_received(self, data: bytes) -> None: if rest: LOGGER.debug("Unparsed data remains after frame: %s, %s", command, rest) - if "payload_length" in params: - running_length = itertools.accumulate( - len(v.serialize()) if v is not None else 0 for v in params.values() - ) - length_at_param = dict(zip(params.keys(), running_length)) - - assert ( - len(data) - length_at_param["payload_length"] - 5 - == params["payload_length"] - ) - LOGGER.debug( - "Received command %s%s (seq %d)", command.command_id, params, command.seq + "Received %s %s (seq %d)", + ("indication" if command.frame_type == FrameType.Indicate else "response"), + params, + command.seq, ) - status = Status.SUCCESS - if "status" in params: - status = params["status"] - exc = None + status = getattr(params, "status", None) - if wrong_fut_cmd_id is not None: - exc = MismatchedResponseError( - command.command_id, - params, - ( - f"Response is mismatched! Sent {wrong_fut_cmd_id}," - f" received {command.command_id}" - ), - ) - elif status != Status.SUCCESS: + if status is not None and status != Status.SUCCESS: exc = CommandError(status, f"{command.command_id}, status: {status}") if fut is not None: @@ -786,246 +221,203 @@ def data_received(self, data: bytes) -> None: fut.set_exception(exc) except asyncio.InvalidStateError: LOGGER.warning( - "Duplicate or delayed response for 0x:%02x sequence", + "Duplicate or delayed response for 0x%02x sequence", command.seq, ) if exc is not None: return - if handler := getattr(self, f"_handle_{command.command_id}", None): - handler_params = { - k: v - for k, v in params.items() - if k not in ("frame_length", "payload_length") - } - + if handler := getattr(self, f"_handle_{command.command_id.name}", None): # Queue up the callback within the event loop - asyncio.get_running_loop().call_soon(lambda: handler(**handler_params)) - - async def _data_poller(self): - while True: - await self._data_poller_event.wait() - self._data_poller_event.clear() - - if self._device_state.network_state == NetworkState.OFFLINE: - continue - - # Poll data indication - rsp = await self.send_command(CommandId.aps_data_indication) - self._handle_device_state_changed( - Status.SUCCESS, device_state=rsp["device_state"] - ) + asyncio.get_running_loop().call_soon(lambda: handler(**params.as_dict())) - if rsp["device_state"] == NetworkState.INDICATION: - self._app.packet_received( - t.ZigbeePacket( - src=t.AddrModeAddress( - addr_mode=rsp["src_addr_mode"], - address=rsp["src_addr"], - ), - src_ep=rsp["src_ep"], - dst=t.AddrModeAddress( - addr_mode=rsp["dst_addr_mode"], - address=rsp["dst_addr"], - ), - dst_ep=rsp["dst_ep"], - tsn=None, - profile_id=rsp["profile_id"], - cluster_id=rsp["cluster_id"], - data=t.SerializableBytes(rsp["asdu"]), - lqi=rsp["lqi"], - rssi=rsp["rssi"], - ) + def _handle_aps_data_indication( + self, + network_state: NetworkState, + dst_addr_mode: ExtendedAddrMode, + dst_addr: t.EUI64, + dst_ep: t.uint8_t, + src_addr_mode: ExtendedAddrMode, + src_addr: t.EUI64, + src_ep: t.uint8_t, + profile_id: t.uint16_t, + cluster_id: t.uint16_t, + indication_status: TXStatus, + security_status: t.uint8_t, + lqi: t.uint8_t, + rx_time: t.uint32_t, + asdu_length: t.uint32_t, + asdu: Bytes, + ): + if network_state == NetworkState.INDICATION: + self._app.packet_received( + t.ZigbeePacket( + src=addr_mode_with_eui64_to_addr_mode_address( + src_addr_mode, src_addr + ), + src_ep=src_ep, + dst=addr_mode_with_eui64_to_addr_mode_address( + dst_addr_mode, dst_addr + ), + dst_ep=dst_ep, + tsn=None, + profile_id=profile_id, + cluster_id=cluster_id, + data=t.SerializableBytes(asdu), + lqi=lqi, + rssi=None, ) - - # Poll data confirm - rsp = await self.send_command(CommandId.aps_data_confirm) - self._handle_device_state_changed( - Status.SUCCESS, device_state=rsp["device_state"] ) - def _handle_device_state_changed( - self, - status: t.Status, - device_state: DeviceState, - reserved: t.uint8_t = 0, - ) -> None: - if device_state.network_state != self.network_state: + def _handle_network_state_changed(self, network_state: NetworkState) -> None: + if network_state != self.network_state: LOGGER.debug( - "Network device_state transition: %s -> %s", + "Network network_state transition: %s -> %s", self.network_state.name, - device_state.network_state.name, + network_state.name, ) - self._device_state = device_state + self._network_state = network_state self._data_poller_event.set() - def _handle_device_state( - self, - status: t.Status, - device_state: DeviceState, - reserved1: t.uint8_t, - reserved2: t.uint8_t, - ) -> None: - self._handle_device_state_changed(status=status, device_state=device_state) + def _handle_network_state(self, network_state: NetworkState) -> None: + self._handle_network_state_changed(network_state=network_state) - async def network_init(self): - await self.send_command(CommandId.networkinit) - await self.form_network( - FormNetwork( - role=DeviceType.COORDINATOR, policy=False, nwk_cfg0=0x14, nwk_cfg1=0 - ) - ) - await self.start(False) - - return Status.SUCCESS - - async def channel_mask(self): - rssult = [] - rsp = await self.send_command(CommandId.channel_mask_get) + async def network_init(self) -> None: + await self.send_command(commands.NetworkInitReq()) - for index in range(32): - if (rsp["channel_mask"] & (1 << index)) != 0: - rssult.append(index) + async def get_channel_mask(self) -> t.Channels: + rsp = await self.send_command(commands.PrimaryChannelMaskGetReq()) + return t.Channels.from_channel_list(tuple(rsp.channel_mask)) - return rssult - - async def set_channel_mask(self, parameter: t.Channels): - rsp = await self.send_command( - CommandId.channel_mask_set, channel_mask=parameter + async def set_channel_mask(self, channels: t.Channels) -> None: + await self.send_command( + commands.PrimaryChannelMaskSetReq( + channel_mask=ShiftedChannels.from_channel_list(channels) + ) ) - return rsp["status"] + async def set_channel(self, channel: int) -> None: + await self.set_channel_mask(channels=t.Channels.from_channel_list([channel])) - async def form_network(self, parameter: FormNetwork): - rsp = await self.send_command( - CommandId.form_network, - form_mwk=parameter, + async def form_network( + self, + role: DeviceType = DeviceType.COORDINATOR, + install_code_policy: bool = False, + # For coordinators/routers + max_children: t.uint8_t = 20, + # For end devices + ed_timeout: t.uint8_t = 0, + keep_alive: t.uint32_t = 0, + ) -> None: + await self.send_command( + commands.FormNetworkReq( + role=role, + install_code_policy=install_code_policy, + max_children=max_children, + ed_timeout=ed_timeout, + keep_alive=keep_alive, + ) ) - return rsp["status"] + # TODO: wait for the `form_network` indication as well? + await asyncio.sleep(2) - async def start(self, parameter: t.uint8_t): - rsp = await self.send_command(CommandId.start, autostart=parameter) + async def start(self, autostart: bool) -> Status: + await self.send_command(commands.StartReq(autostart=autostart)) - return rsp["status"] + # TODO: wait for the `form_network` indication as well? + await asyncio.sleep(2) - async def mac_address(self): - rsp = await self.send_command(CommandId.long_addr_get) + async def get_mac_address(self): + rsp = await self.send_command(commands.LongAddrGetReq()) - return rsp["ieee"] + return rsp.ieee async def set_mac_address(self, parameter: t.EUI64): - rsp = await self.send_command(CommandId.long_addr_set, ieee=parameter) + await self.send_command(commands.LongAddrSetReq(ieee=parameter)) - return rsp["status"] + async def get_nwk_address(self): + rsp = await self.send_command(commands.ShortAddrGetReq()) - async def nwk_address(self): - rsp = await self.send_command(CommandId.short_addr_get) - - return rsp["short_addr"] + return rsp.short_addr async def set_nwk_address(self, parameter: t.uint16_t): - rsp = await self.send_command(CommandId.short_addr_set, short_addr=parameter) - - return rsp["status"] + await self.send_command(commands.ShortAddrSetReq(short_addr=parameter)) - async def nwk_panid(self): - rsp = await self.send_command(CommandId.panid_get) + async def get_nwk_panid(self): + rsp = await self.send_command(commands.PanidGetReq()) - return rsp["panid"] + return rsp.panid async def set_nwk_panid(self, parameter: t.PanId): - rsp = await self.send_command(CommandId.panid_set, panid=parameter) - - return rsp["status"] + await self.send_command(commands.PanidSetReq(panid=parameter)) - async def nwk_extended_panid(self): - rsp = await self.send_command(CommandId.extpanid_get) + async def get_nwk_extended_panid(self): + rsp = await self.send_command(commands.ExtpanidGetReq()) - return rsp["ieee"] + return rsp.ieee async def set_nwk_extended_panid(self, parameter: t.ExtendedPanId): - rsp = await self.send_command(CommandId.extpanid_set, panid=parameter) + await self.send_command(commands.ExtpanidSetReq(ieee=parameter)) - return rsp["status"] + async def get_current_channel(self) -> int: + rsp = await self.send_command(commands.CurrentChannelGetReq()) - async def current_channel(self): - rsp = await self.send_command(CommandId.current_channel_get) + return rsp.channel - return rsp["channel"] + async def get_nwk_update_id(self): + rsp = await self.send_command(commands.NwkUpdateIdGetReq()) - async def nwk_update_id(self): - rsp = await self.send_command(CommandId.nwk_update_id_get) - - return rsp["nwk_update_id"] + return rsp.nwk_update_id async def set_nwk_update_id(self, parameter: t.uint8_t): - rsp = await self.send_command( - CommandId.nwk_update_id_set, nwk_update_id=parameter - ) - - return rsp["status"] - - async def network_key(self): - rsp = await self.send_command(CommandId.network_key_get) + await self.send_command(commands.NwkUpdateIdSetReq(nwk_update_id=parameter)) - indexed_key = IndexedKey(index=0, key=rsp["nwk_key"]) + async def get_network_key(self): + rsp = await self.send_command(commands.NetworkKeyGetReq()) - return indexed_key + return rsp.nwk_key - async def set_network_key(self, parameter: IndexedKey): - rsp = await self.send_command(CommandId.network_key_set, nwk_key=parameter.key) + async def set_network_key(self, key: t.KeyData): + await self.send_command(commands.NetworkKeySetReq(nwk_key=key)) - return rsp["status"] + async def get_nwk_frame_counter(self): + rsp = await self.send_command(commands.NwkFrameCounterGetReq()) - async def nwk_frame_counter(self): - rsp = await self.send_command(CommandId.nwk_frame_counter_get) + return rsp.nwk_frame_counter - return rsp["nwk_frame_counter"] - - async def set_nwk_frame_counter(self, parameter: t.uint32_t): - rsp = await self.send_command( - CommandId.nwk_frame_counter_set, - nwk_frame_counter=parameter, + async def set_nwk_frame_counter(self, counter: t.uint32_t): + await self.send_command( + commands.NwkFrameCounterSetReq(nwk_frame_counter=counter) ) - return rsp["status"] - - async def trust_center_address(self): - rsp = await self.send_command(CommandId.trust_center_address_get) + async def get_trust_center_address(self): + rsp = await self.send_command(commands.TrustCenterAddressGetReq()) - return rsp["trust_center_addr"] + return rsp.trust_center_addr - async def set_trust_center_address(self, parameter: t.EUI64): - rsp = await self.send_command( - CommandId.trust_center_address_set, trust_center_addr=parameter + async def set_trust_center_address(self, addr: t.EUI64) -> None: + await self.send_command( + commands.TrustCenterAddressSetReq(trust_center_addr=addr) ) - return rsp["status"] - - async def link_key(self, parameter: Any = None) -> Any: - rsp = await self.send_command(CommandId.link_key_get) + async def get_link_key(self) -> Any: + rsp = await self.send_command(commands.LinkKeyGetReq()) - return rsp["link_key"] + return rsp.key - async def set_link_key(self, parameter: LinkKey): - rsp = await self.send_command(CommandId.link_key_set, link_key=parameter.key) + async def set_link_key(self, key: t.KeyData): + await self.send_command(commands.LinkKeySetReq(key=key)) - return rsp["status"] + async def get_security_mode(self): + rsp = await self.send_command(commands.SecurityModeGetReq()) - async def security_mode(self): - rsp = await self.send_command(CommandId.security_mode_get) + return rsp.security_mode - return rsp["security_mode"] - - async def set_security_mode(self, parameter: SecurityMode): - rsp = await self.send_command( - CommandId.security_mode_set, security_mode=parameter - ) - - return rsp["status"] + async def set_security_mode(self, mode: SecurityMode): + await self.send_command(commands.SecurityModeSetReq(security_mode=mode)) async def add_endpoint( self, @@ -1033,78 +425,45 @@ async def add_endpoint( profile: t.uint16_t, device_type: t.uint16_t, device_version: t.uint8_t, - input_clusters: t.LVList[t.uint16_t], - output_clusters: t.LVList[t.uint16_t], + input_clusters: list[t.ClusterId], + output_clusters: list[t.ClusterId], ): - inputClusterList = t.LVList[t.uint16_t].serialize(input_clusters) - outputClusterList = t.LVList[t.uint16_t].serialize(output_clusters) - if profile == 0xC05E: - return Status.SUCCESS - - rsp = await self.send_command( - CommandId.addendpoint, - endpoint=endpoint, - profileId=profile, - deviceId=device_type, - appFlags=device_version, - inputClusterCount=len(input_clusters), - outputClusterCount=len(output_clusters), - inputClusterList=t.List(inputClusterList[1:]), - outputClusterList=t.List(outputClusterList[1:]), - ) - - return rsp["status"] - - async def set_use_predefined_nwk_panid(self, parameter: t.Bool): - rsp = await self.send_command( - CommandId.use_predefined_nwk_panid_set, - predefined=parameter, - ) - - return rsp["status"] - - async def set_permit_join(self, parameter: t.uint8_t): - rsp = await self.send_command( - CommandId.permit_join_set, - role=parameter, - ) - - return rsp["status"] + return - async def set_watchdog_ttl(self, parameter: t.uint16_t): - rsp = await self.send_command( - CommandId.watchdog_ttl_set, - role=parameter, + await self.send_command( + commands.AddEndpointReq( + endpoint=endpoint, + profile_id=profile, + device_id=device_type, + app_flags=device_version, + input_cluster_count=len(input_clusters), + output_cluster_count=len(output_clusters), + input_cluster_list=input_clusters, + output_cluster_list=output_clusters, + ) ) - return rsp["status"] - - async def aps_designed_coordinator(self): - rsp = await self.send_command( - CommandId.aps_designed_coordinator_get, - reserved=0, + async def set_use_predefined_nwk_panid(self, use_predefined: t.Bool): + await self.send_command( + commands.UsePredefinedNwkPanidSetReq( + predefined=use_predefined, + ) ) - return rsp["role"] - - async def set_aps_designed_coordinator(self, parameter: t.uint8_t): - rsp = await self.send_command( - CommandId.aps_designed_coordinator_set, - role=parameter, + async def set_permit_join(self, duration: t.uint8_t): + await self.send_command( + commands.PermitJoiningReq( + duration=duration, + ) ) - return rsp["status"] - - async def aps_extended_panid(self): - rsp = await self.send_command(CommandId.extpanid_get) + async def get_network_role(self) -> DeviceType: + rsp = await self.send_command(commands.NetworkRoleGetReq()) + return rsp.role - return rsp["ieee"] - - async def set_aps_extended_panid(self, parameter: t.ExtendedPanId): - rsp = await self.send_command(CommandId.extpanid_set, ieee=parameter) - - return rsp["status"] + async def set_network_role(self, role: DeviceType) -> None: + await self.send_command(commands.NetworkRoleSetReq(role=role)) async def aps_data_request( self, @@ -1113,32 +472,31 @@ async def aps_data_request( src_addr: t.EUI64, src_ep: t.uint8_t, profile: t.uint16_t, - addr_mode: DeviceAddrMode, + addr_mode: t.AddrMode, cluster: t.uint16_t, sequence: t.uint16_t, - options: ZnspTransmitOptions, + options: TransmitOptions, radius: t.uint16_t, data: bytes, - relays: list[int] | None = None, - extended_timeout: bool = False, ): for delay in REQUEST_RETRY_DELAYS: try: - rsp = await self.send_command( - CommandId.aps_data_request, - dst_addr=dst_addr, - dst_endpoint=dst_ep, - src_endpoint=src_ep, - address_mode=addr_mode, - profile_id=profile, - cluster_id=cluster, - tx_options=options, - use_alias=False, - src_addr=src_addr, - sequence=sequence, - radius=radius, - asdu_length=len(data), - asdu=t.List(data), + await self.send_command( + commands.ApsDataRequestReq( + dst_addr=dst_addr, + dst_endpoint=dst_ep, + src_endpoint=src_ep, + address_mode=addr_mode, + profile_id=profile, + cluster_id=cluster, + tx_options=options, + use_alias=False, + alias_src_addr=src_addr, + alias_seq_num=sequence, + radius=radius, + asdu_length=len(data), + asdu=data, + ) ) except CommandError as ex: LOGGER.debug("'aps_data_request' failure: %s", ex) @@ -1148,27 +506,44 @@ async def aps_data_request( LOGGER.debug("retrying 'aps_data_request' in %ss", delay) await asyncio.sleep(delay) else: - self._handle_device_state_changed( - status=rsp["status"], - device_state=DeviceState(network_state=NetworkState.CONNECTED), - ) return - async def get_device_state(self) -> DeviceState: - rsp = await self.send_command(CommandId.device_state) + async def get_network_state(self) -> NetworkState: + rsp = await self.send_command(commands.NetworkStateReq()) - return rsp["device_state"] + return rsp.network_state - async def change_network_state(self, new_state: NetworkState) -> None: - await self.send_command(CommandId.change_network_state, network_state=new_state) + async def _poll_until_running(self): + async with asyncio_timeout(POLL_UNTIL_RUNNING_TIMEOUT): + while True: + await asyncio.sleep(0.5) - async def add_neighbour( - self, nwk: t.NWK, ieee: t.EUI64, mac_capability_flags: t.uint8_t - ) -> None: - await self.send_command( - CommandId.update_neighbor, - action=UpdateNeighborAction.ADD, - nwk=nwk, - ieee=ieee, - mac_capability_flags=mac_capability_flags, - ) + try: + LOGGER.debug("Polling firmware to see if it is running") + await self.system_firmware() + break + except asyncio.TimeoutError: + pass + + async def reset(self) -> None: + await self.send_command(commands.SystemResetReq(), wait_for_response=False) + await self._poll_until_running() + + async def factory_reset(self): + await self.send_command(commands.SystemFactoryReq(), wait_for_response=False) + await self._poll_until_running() + + async def system_firmware(self): + rsp = await self.send_command(commands.SystemFirmwareReq()) + + return rsp.firmware_version + + async def system_model(self): + rsp = await self.send_command(commands.SystemModelReq()) + + return rsp.payload + + async def system_manufacturer(self): + rsp = await self.send_command(commands.SystemManufacturerReq()) + + return rsp.payload diff --git a/zigpy_espzb/commands.py b/zigpy_espzb/commands.py new file mode 100644 index 0000000..b6cc609 --- /dev/null +++ b/zigpy_espzb/commands.py @@ -0,0 +1,780 @@ +"""Serial command schemas.""" + +import zigpy.types as t + +from zigpy_espzb.types import ( + Bytes, + DeviceType, + ExtendedAddrMode, + FirmwareVersion, + NetworkState, + SecurityMode, + ShiftedChannels, + Status, + TransmitOptions, + TXStatus, +) + + +class CommandId(t.enum16): + network_init = 0x0000 + start = 0x0001 + network_state = 0x0002 + stack_status_handler = 0x0003 + form_network = 0x0004 + permit_joining = 0x0005 + join_network = 0x0006 + leave_network = 0x0007 + start_scan = 0x0008 + scan_complete_handler = 0x0009 + stop_scan = 0x000A + panid_get = 0x000B + panid_set = 0x000C + extpanid_get = 0x000D + extpanid_set = 0x000E + primary_channel_mask_get = 0x000F + primary_channel_mask_set = 0x0010 + secondary_channel_mask_get = 0x0011 + secondary_channel_mask_set = 0x0012 + current_channel_get = 0x0013 + current_channel_set = 0x0014 + tx_power_get = 0x0015 + tx_power_set = 0x0016 + network_key_get = 0x0017 + network_key_set = 0x0018 + nwk_frame_counter_get = 0x0019 + nwk_frame_counter_set = 0x001A + network_role_get = 0x001B + network_role_set = 0x001C + short_addr_get = 0x001D + short_addr_set = 0x001E + long_addr_get = 0x001F + long_addr_set = 0x0020 + channel_masks_get = 0x0021 + channel_masks_set = 0x0022 + nwk_update_id_get = 0x0023 + nwk_update_id_set = 0x0024 + trust_center_address_get = 0x0025 + trust_center_address_set = 0x0026 + link_key_get = 0x0027 + link_key_set = 0x0028 + security_mode_get = 0x0029 + security_mode_set = 0x002A + use_predefined_nwk_panid_set = 0x002B + short_to_ieee = 0x002C + ieee_to_short = 0x002D + add_endpoint = 0x0100 + remove_endpoint = 0x0101 + attribute_read = 0x0102 + attribute_write = 0x0103 + attribute_report = 0x0104 + attribute_discover = 0x0105 + aps_read = 0x0106 + aps_write = 0x0107 + report_config = 0x0108 + bind_set = 0x0200 + unbind_set = 0x0201 + find_match = 0x0202 + aps_data_request = 0x0300 + aps_data_indication = 0x0301 + aps_data_confirm = 0x0302 + system_reset = 0x0400 + system_factory = 0x0401 + system_firmware = 0x0402 + system_model = 0x0403 + system_manufacturer = 0x0404 + + +class FrameType(t.enum4): + Request = 0 + Response = 1 + Indicate = 2 + + +class CommandFrame(t.Struct): + version: t.uint4_t + frame_type: FrameType + reserved: t.uint8_t + + command_id: CommandId + seq: t.uint8_t + length: t.uint16_t + payload: Bytes + + +class BaseCommand(t.Struct): + pass + + +class NetworkInitReq(BaseCommand): + pass + + +class NetworkInitRsp(BaseCommand): + status: Status + + +class StartReq(BaseCommand): + autostart: t.Bool + + +class StartRsp(BaseCommand): + status: Status + + +class FormNetworkReq(BaseCommand): + role: DeviceType + install_code_policy: t.Bool + + # For coordinators/routers + max_children: t.uint8_t = t.StructField( + requires=lambda f: f.role in (DeviceType.ROUTER, DeviceType.COORDINATOR) + ) + + # For end devices + ed_timeout: t.uint8_t = t.StructField( + requires=lambda f: f.role == DeviceType.END_DEVICE + ) + keep_alive: t.uint32_t = t.StructField( + requires=lambda f: f.role == DeviceType.END_DEVICE + ) + + +class FormNetworkRsp(BaseCommand): + status: Status + + +class FormNetworkInd(BaseCommand): + extended_panid: t.EUI64 + panid: t.PanId + channel: t.uint8_t + + +class PermitJoiningReq(BaseCommand): + duration: t.uint8_t + + +class PermitJoiningRsp(BaseCommand): + status: Status + + +class PermitJoiningInd(BaseCommand): + duration: t.uint8_t + + +class LeaveNetworkReq(BaseCommand): + pass + + +class LeaveNetworkRsp(BaseCommand): + status: Status + + +class LeaveNetworkInd(BaseCommand): + short_addr: t.NWK + device_addr: t.EUI64 + rejoin: t.Bool + + +class ExtpanidGetReq(BaseCommand): + pass + + +class ExtpanidGetRsp(BaseCommand): + ieee: t.EUI64 + + +class ExtpanidSetReq(BaseCommand): + ieee: t.EUI64 + + +class ExtpanidSetRsp(BaseCommand): + status: Status + + +class PanidGetReq(BaseCommand): + pass + + +class PanidGetRsp(BaseCommand): + panid: t.PanId + + +class PanidSetReq(BaseCommand): + panid: t.PanId + + +class PanidSetRsp(BaseCommand): + status: Status + + +class ShortAddrGetReq(BaseCommand): + pass + + +class ShortAddrGetRsp(BaseCommand): + short_addr: t.NWK + + +class ShortAddrSetReq(BaseCommand): + short_addr: t.NWK + + +class ShortAddrSetRsp(BaseCommand): + status: Status + + +class LongAddrGetReq(BaseCommand): + pass + + +class LongAddrGetRsp(BaseCommand): + ieee: t.EUI64 + + +class LongAddrSetReq(BaseCommand): + ieee: t.EUI64 + + +class LongAddrSetRsp(BaseCommand): + status: Status + + +class CurrentChannelGetReq(BaseCommand): + pass + + +class CurrentChannelGetRsp(BaseCommand): + channel: t.uint8_t + + +class CurrentChannelSetReq(BaseCommand): + channel: t.uint8_t + + +class CurrentChannelSetRsp(BaseCommand): + status: Status + + +class PrimaryChannelMaskGetReq(BaseCommand): + pass + + +class PrimaryChannelMaskGetRsp(BaseCommand): + channel_mask: ShiftedChannels + + +class PrimaryChannelMaskSetReq(BaseCommand): + channel_mask: ShiftedChannels + + +class PrimaryChannelMaskSetRsp(BaseCommand): + status: Status + + +class AddEndpointReq(BaseCommand): + endpoint: t.uint8_t + profile_id: t.uint16_t + device_id: t.uint16_t + app_flags: t.uint8_t + input_cluster_count: t.uint8_t + output_cluster_count: t.uint8_t + input_cluster_list: t.List[t.uint16_t] + output_cluster_list: t.List[t.uint16_t] + + +class AddEndpointRsp(BaseCommand): + status: Status + + +class NetworkStateReq(BaseCommand): + pass + + +class NetworkStateRsp(BaseCommand): + network_state: NetworkState + + +class StackStatusHandlerReq(BaseCommand): + pass + + +class StackStatusHandlerRsp(BaseCommand): + network_state: t.uint8_t + + +class StackStatusHandlerInd(BaseCommand): + network_state: t.uint8_t + + +class ApsDataRequestReq(BaseCommand): + dst_addr: t.EUI64 + dst_endpoint: t.uint8_t + src_endpoint: t.uint8_t + address_mode: ExtendedAddrMode + profile_id: t.uint16_t + cluster_id: t.uint16_t + tx_options: TransmitOptions + use_alias: t.Bool + alias_src_addr: t.EUI64 + alias_seq_num: t.uint8_t + radius: t.uint8_t + asdu_length: t.uint32_t + asdu: Bytes + + +class ApsDataRequestRsp(BaseCommand): + status: Status + + +class ApsDataIndicationRsp(BaseCommand): + network_state: NetworkState + dst_addr_mode: ExtendedAddrMode + dst_addr: t.EUI64 + dst_ep: t.uint8_t + src_addr_mode: ExtendedAddrMode + src_addr: t.EUI64 + src_ep: t.uint8_t + profile_id: t.uint16_t + cluster_id: t.uint16_t + indication_status: TXStatus + security_status: t.uint8_t + lqi: t.uint8_t + rx_time: t.uint32_t + asdu_length: t.uint32_t + asdu: Bytes + + +class ApsDataIndicationInd(BaseCommand): + network_state: NetworkState + dst_addr_mode: ExtendedAddrMode + dst_addr: t.EUI64 + dst_ep: t.uint8_t + src_addr_mode: ExtendedAddrMode + src_addr: t.EUI64 + src_ep: t.uint8_t + profile_id: t.uint16_t + cluster_id: t.uint16_t + indication_status: TXStatus + security_status: t.uint8_t + lqi: t.uint8_t + rx_time: t.uint32_t + asdu_length: t.uint32_t + asdu: Bytes + + +class ApsDataConfirmReq(BaseCommand): + pass + + +class ApsDataConfirmRsp(BaseCommand): + network_state: NetworkState + dst_addr_mode: ExtendedAddrMode + dst_addr: t.EUI64 + dst_ep: t.uint8_t + src_ep: t.uint8_t + tx_time: t.uint32_t + request_id: t.uint8_t + confirm_status: TXStatus + asdu_length: t.uint32_t + asdu: Bytes + + +class ApsDataConfirmInd(BaseCommand): + network_state: NetworkState + dst_addr_mode: ExtendedAddrMode + dst_addr: t.EUI64 + dst_ep: t.uint8_t + src_ep: t.uint8_t + tx_time: t.uint32_t + confirm_status: TXStatus + asdu_length: t.uint32_t + asdu: Bytes + + +class NetworkKeyGetReq(BaseCommand): + pass + + +class NetworkKeyGetRsp(BaseCommand): + nwk_key: t.KeyData + + +class NetworkKeySetReq(BaseCommand): + nwk_key: t.KeyData + + +class NetworkKeySetRsp(BaseCommand): + status: Status + + +class NwkFrameCounterGetReq(BaseCommand): + pass + + +class NwkFrameCounterGetRsp(BaseCommand): + nwk_frame_counter: t.uint32_t + + +class NwkFrameCounterSetReq(BaseCommand): + nwk_frame_counter: t.uint32_t + + +class NwkFrameCounterSetRsp(BaseCommand): + status: Status + + +class NetworkRoleGetReq(BaseCommand): + pass + + +class NetworkRoleGetRsp(BaseCommand): + role: DeviceType + + +class NetworkRoleSetReq(BaseCommand): + role: DeviceType + + +class NetworkRoleSetRsp(BaseCommand): + status: Status + + +class UsePredefinedNwkPanidSetReq(BaseCommand): + predefined: t.Bool + + +class UsePredefinedNwkPanidSetRsp(BaseCommand): + status: Status + + +class NwkUpdateIdGetReq(BaseCommand): + pass + + +class NwkUpdateIdGetRsp(BaseCommand): + nwk_update_id: t.uint8_t + + +class NwkUpdateIdSetReq(BaseCommand): + nwk_update_id: t.uint8_t + + +class NwkUpdateIdSetRsp(BaseCommand): + status: Status + + +class TrustCenterAddressGetReq(BaseCommand): + pass + + +class TrustCenterAddressGetRsp(BaseCommand): + trust_center_addr: t.EUI64 + + +class TrustCenterAddressSetReq(BaseCommand): + trust_center_addr: t.EUI64 + + +class TrustCenterAddressSetRsp(BaseCommand): + status: Status + + +class LinkKeyGetReq(BaseCommand): + pass + + +class LinkKeyGetRsp(BaseCommand): + ieee: t.EUI64 + key: t.KeyData + + +class LinkKeySetReq(BaseCommand): + key: t.KeyData + + +class LinkKeySetRsp(BaseCommand): + status: Status + + +class SecurityModeGetReq(BaseCommand): + pass + + +class SecurityModeGetRsp(BaseCommand): + security_mode: SecurityMode + + +class SecurityModeSetReq(BaseCommand): + security_mode: SecurityMode + + +class SecurityModeSetRsp(BaseCommand): + status: Status + + +class SystemResetReq(BaseCommand): + pass + + +class SystemResetRsp(BaseCommand): + status: Status + +class SystemResetInd(BaseCommand): + error: t.uint32_t + +class SystemFactoryReq(BaseCommand): + pass + + +class SystemFactoryRsp(BaseCommand): + status: Status + +class SystemFactoryInd(BaseCommand): + error: t.uint32_t + +class SystemFirmwareReq(BaseCommand): + pass + + +class SystemFirmwareRsp(BaseCommand): + firmware_version: FirmwareVersion + + +class SystemModelReq(BaseCommand): + pass + + +class SystemModelRsp(BaseCommand): + payload: t.CharacterString + + +class SystemManufacturerReq(BaseCommand): + pass + + +class SystemManufacturerRsp(BaseCommand): + payload: t.CharacterString + + +COMMAND_SCHEMAS = { + CommandId.network_init: ( + NetworkInitReq, + NetworkInitRsp, + None, + ), + CommandId.start: ( + StartReq, + StartRsp, + None, + ), + CommandId.form_network: ( + FormNetworkReq, + FormNetworkRsp, + FormNetworkInd, + ), + CommandId.permit_joining: ( + PermitJoiningReq, + PermitJoiningRsp, + PermitJoiningInd, + ), + CommandId.leave_network: ( + LeaveNetworkReq, + LeaveNetworkRsp, + LeaveNetworkInd, + ), + CommandId.extpanid_get: ( + ExtpanidGetReq, + ExtpanidGetRsp, + None, + ), + CommandId.extpanid_set: ( + ExtpanidSetReq, + ExtpanidSetRsp, + None, + ), + CommandId.panid_get: ( + PanidGetReq, + PanidGetRsp, + None, + ), + CommandId.panid_set: ( + PanidSetReq, + PanidSetRsp, + None, + ), + CommandId.short_addr_get: ( + ShortAddrGetReq, + ShortAddrGetRsp, + None, + ), + CommandId.short_addr_set: ( + ShortAddrSetReq, + ShortAddrSetRsp, + None, + ), + CommandId.long_addr_get: ( + LongAddrGetReq, + LongAddrGetRsp, + None, + ), + CommandId.long_addr_set: ( + LongAddrSetReq, + LongAddrSetRsp, + None, + ), + CommandId.current_channel_get: ( + CurrentChannelGetReq, + CurrentChannelGetRsp, + None, + ), + CommandId.current_channel_set: ( + CurrentChannelSetReq, + CurrentChannelSetRsp, + None, + ), + CommandId.primary_channel_mask_get: ( + PrimaryChannelMaskGetReq, + PrimaryChannelMaskGetRsp, + None, + ), + CommandId.primary_channel_mask_set: ( + PrimaryChannelMaskSetReq, + PrimaryChannelMaskSetRsp, + None, + ), + CommandId.add_endpoint: ( + AddEndpointReq, + AddEndpointRsp, + None, + ), + CommandId.network_state: ( + NetworkStateReq, + NetworkStateRsp, + None, + ), + CommandId.stack_status_handler: ( + StackStatusHandlerReq, + StackStatusHandlerRsp, + StackStatusHandlerInd, + ), + CommandId.aps_data_request: ( + ApsDataRequestReq, + ApsDataRequestRsp, + None, + ), + CommandId.aps_data_indication: ( + None, + ApsDataIndicationRsp, + ApsDataIndicationInd, + ), + CommandId.aps_data_confirm: ( + ApsDataConfirmReq, + ApsDataConfirmRsp, + ApsDataConfirmInd, + ), + CommandId.network_key_get: ( + NetworkKeyGetReq, + NetworkKeyGetRsp, + None, + ), + CommandId.network_key_set: ( + NetworkKeySetReq, + NetworkKeySetRsp, + None, + ), + CommandId.nwk_frame_counter_get: ( + NwkFrameCounterGetReq, + NwkFrameCounterGetRsp, + None, + ), + CommandId.nwk_frame_counter_set: ( + NwkFrameCounterSetReq, + NwkFrameCounterSetRsp, + None, + ), + CommandId.network_role_get: ( + NetworkRoleGetReq, + NetworkRoleGetRsp, + None, + ), + CommandId.network_role_set: ( + NetworkRoleSetReq, + NetworkRoleSetRsp, + None, + ), + CommandId.use_predefined_nwk_panid_set: ( + UsePredefinedNwkPanidSetReq, + UsePredefinedNwkPanidSetRsp, + None, + ), + CommandId.nwk_update_id_get: ( + NwkUpdateIdGetReq, + NwkUpdateIdGetRsp, + None, + ), + CommandId.nwk_update_id_set: ( + NwkUpdateIdSetReq, + NwkUpdateIdSetRsp, + None, + ), + CommandId.trust_center_address_get: ( + TrustCenterAddressGetReq, + TrustCenterAddressGetRsp, + None, + ), + CommandId.trust_center_address_set: ( + TrustCenterAddressSetReq, + TrustCenterAddressSetRsp, + None, + ), + CommandId.link_key_get: ( + LinkKeyGetReq, + LinkKeyGetRsp, + None, + ), + CommandId.link_key_set: ( + LinkKeySetReq, + LinkKeySetRsp, + None, + ), + CommandId.security_mode_get: ( + SecurityModeGetReq, + SecurityModeGetRsp, + None, + ), + CommandId.security_mode_set: ( + SecurityModeSetReq, + SecurityModeSetRsp, + None, + ), + CommandId.system_reset: ( + SystemResetReq, + SystemResetRsp, + SystemResetInd, + ), + CommandId.system_factory: ( + SystemFactoryReq, + SystemFactoryRsp, + SystemFactoryInd, + ), + CommandId.system_firmware: ( + SystemFirmwareReq, + SystemFirmwareRsp, + None, + ), + CommandId.system_model: ( + SystemModelReq, + SystemModelRsp, + None, + ), + CommandId.system_manufacturer: ( + SystemManufacturerReq, + SystemManufacturerRsp, + None, + ), +} + +COMMAND_SCHEMA_TO_COMMAND_ID = { + req: command_id for command_id, (req, _, _) in COMMAND_SCHEMAS.items() +} diff --git a/zigpy_espzb/types.py b/zigpy_espzb/types.py index 5a3f601..dc6b229 100644 --- a/zigpy_espzb/types.py +++ b/zigpy_espzb/types.py @@ -1,57 +1,174 @@ """Data types module.""" -from zigpy.types import bitmap8 +from __future__ import annotations +import zigpy.types as t -def serialize_dict(data, schema): - chunks = [] - for key in schema: - value = data[key] - if value is None: - break +class Bytes(bytes): + def serialize(self): + return self - if not isinstance(value, schema[key]): - value = schema[key](value) + @classmethod + def deserialize(cls, data): + return cls(data), b"" - chunks.append(value.serialize()) - return b"".join(chunks) +class TransmitOptions(t.bitmap8): + NONE = 0x00 + # Security enabled transmission + SECURITY_ENABLED = 0x01 + # Use NWK key (obsolete) + USE_NWK_KEY_R21OBSOLETE = 0x02 + # Extension: do not include long src/dst addresses into NWK hdr + NO_LONG_ADDR = 0x02 + # Acknowledged transmission + ACK_TX = 0x04 + # Fragmentation permitted + FRAG_PERMITTED = 0x08 + # Include extended nonce in APS security frame + INC_EXT_NONCE = 0x10 + + +class ExtendedAddrMode(t.enum8): + # DstAddress and DstEndpoint not present + MODE_DST_ADDR_ENDP_NOT_PRESENT = 0x00 + # 16-bit group address for DstAddress; DstEndpoint not present + MODE_16_GROUP_ENDP_NOT_PRESENT = 0x01 + # 16-bit address for DstAddress and DstEndpoint present + MODE_16_ENDP_PRESENT = 0x02 + # 64-bit extended address for DstAddress and DstEndpoint present + MODE_64_ENDP_PRESENT = 0x03 -def deserialize_dict(data, schema): - result = {} - for name, type_ in schema.items(): - try: - result[name], data = type_.deserialize(data) - except ValueError: - if data: - raise + @classmethod + def from_zigpy_addr_mode(cls, addr_mode: t.AddrMode) -> ExtendedAddrMode: + """Convert a Zigpy AddrMode to an ExtendedAddrMode.""" + return { + t.AddrMode.IEEE: cls.MODE_64_ENDP_PRESENT, + t.AddrMode.NWK: cls.MODE_16_ENDP_PRESENT, + t.AddrMode.Group: cls.MODE_16_GROUP_ENDP_NOT_PRESENT, + t.AddrMode.Broadcast: cls.MODE_16_GROUP_ENDP_NOT_PRESENT, + }[addr_mode] + + def to_zigpy_addr_mode(self) -> t.AddrMode: + """Convert a Zigpy AddrMode to an ExtendedAddrMode.""" + return { + self.MODE_64_ENDP_PRESENT: t.AddrMode.IEEE, + self.MODE_16_ENDP_PRESENT: t.AddrMode.NWK, + self.MODE_DST_ADDR_ENDP_NOT_PRESENT: t.AddrMode.NWK, + self.MODE_16_GROUP_ENDP_NOT_PRESENT: t.AddrMode.Group, + self.MODE_16_GROUP_ENDP_NOT_PRESENT: t.AddrMode.Broadcast, + # TODO: why is this necessary? + 0xFF: t.AddrMode.NWK, + }[self] + + +def addr_mode_with_eui64_to_addr_mode_address( + addr_mode: ExtendedAddrMode, address: t.EUI64 +) -> t.AddrModeAddress: + """Convert an address mode and an EUI64 address to an AddrModeAddress.""" + address_short, _ = t.uint16_t.deserialize(address.serialize()[:2]) + zigpy_addr_mode = addr_mode.to_zigpy_addr_mode() + + if zigpy_addr_mode == t.AddrMode.IEEE: + address = address + elif zigpy_addr_mode == t.AddrMode.NWK: + address = t.NWK(address_short) + elif zigpy_addr_mode == t.AddrMode.Group: + address = t.Group(address_short) + elif zigpy_addr_mode == t.AddrMode.Broadcast: + address = t.BroadcastAddress(address_short) + else: + raise ValueError(f"Unknown address mode: {zigpy_addr_mode}") + + return t.AddrModeAddress(addr_mode=zigpy_addr_mode, address=address) + + +class ShiftedChannels(t.bitmap32): + """Zigbee Channels.""" + + # fmt: off + CHANNEL_11 = 0b00000000000000000000100000000000 + CHANNEL_12 = 0b00000000000000000001000000000000 + CHANNEL_13 = 0b00000000000000000010000000000000 + CHANNEL_14 = 0b00000000000000000100000000000000 + CHANNEL_15 = 0b00000000000000001000000000000000 + CHANNEL_16 = 0b00000000000000010000000000000000 + CHANNEL_17 = 0b00000000000000100000000000000000 + CHANNEL_18 = 0b00000000000001000000000000000000 + CHANNEL_19 = 0b00000000000010000000000000000000 + CHANNEL_20 = 0b00000000000100000000000000000000 + CHANNEL_21 = 0b00000000001000000000000000000000 + CHANNEL_22 = 0b00000000010000000000000000000000 + CHANNEL_23 = 0b00000000100000000000000000000000 + CHANNEL_24 = 0b00000001000000000000000000000000 + CHANNEL_25 = 0b00000010000000000000000000000000 + CHANNEL_26 = 0b00000100000000000000000000000000 + ALL_CHANNELS = 0b00000111111111111111100000000000 + NO_CHANNELS = 0b00000000000000000000000000000000 + # fmt: on + + __iter__ = t.Channels.__iter__ + from_channel_list = classmethod(t.Channels.from_channel_list.__func__) - result[name] = None - return result, data + @classmethod + def from_zigpy_channels(cls, channels: t.Channels) -> ShiftedChannels: + """Convert a Zigpy Channels to a ShiftedChannels.""" + return cls.from_channel_list(tuple(channels)) -def list_replace(lst: list, old: object, new: object) -> list: - """Replace all occurrences of `old` with `new` in `lst`.""" - return [new if x == old else x for x in lst] +class DeviceType(t.enum8): + COORDINATOR = 0x00 + ROUTER = 0x01 + END_DEVICE = 0x02 + NONE = 0x03 -class Bytes(bytes): - def serialize(self): - return self +class Status(t.enum8): + SUCCESS = 0 + FAILURE = 1 + INVALID_VALUE = 2 + TIMEOUT = 3 + UNSUPPORTED = 4 + ERROR = 5 + NO_NETWORK = 6 + BUSY = 7 - @classmethod - def deserialize(cls, data): - return cls(data), b"" +class FirmwareVersion(t.Struct, t.uint32_t): + reserved: t.uint8_t + patch: t.uint8_t + minor: t.uint8_t + major: t.uint8_t + + +class NetworkState(t.enum8): + OFFLINE = 0 + JOINING = 1 + CONNECTED = 2 + LEAVING = 3 + CONFIRM = 4 + INDICATION = 5 -class ZnspTransmitOptions(bitmap8): - NONE = 0x00 - ACK_ENABLED = 0x01 - SECURITY_ENABLED = 0x02 +class SecurityMode(t.enum8): + NO_SECURITY = 0x00 + PRECONFIGURED_NETWORK_KEY = 0x01 -class DeviceAddrMode: - # TODO: implement this class - pass + +class ZDPResponseHandling(t.bitmap16): + NONE = 0x0000 + NodeDescRsp = 0x0001 + + +class TXStatus(t.enum8): + SUCCESS = 0x00 + + @classmethod + def _missing_(cls, value): + chained = t.APSStatus(value) + status = t.uint8_t.__new__(cls, chained.value) + status._name_ = chained.name + status._value_ = value + return status diff --git a/zigpy_espzb/zigbee/application.py b/zigpy_espzb/zigbee/application.py index 6fe04aa..352d896 100644 --- a/zigpy_espzb/zigbee/application.py +++ b/zigpy_espzb/zigbee/application.py @@ -20,33 +20,25 @@ import zigpy.exceptions from zigpy.exceptions import FormationFailure, NetworkNotFormed import zigpy.state -import zigpy.types +import zigpy.types as t import zigpy.util import zigpy.zdo.types as zdo_t -import zigpy_espzb -from zigpy_espzb import types as t -from zigpy_espzb.api import ( - IndexedKey, - LinkKey, +from zigpy_espzb.api import Znsp +from zigpy_espzb.types import ( + DeviceType, + ExtendedAddrMode, NetworkState, SecurityMode, - Status, - Znsp, + TransmitOptions, ) -import zigpy_espzb.exception LOGGER = logging.getLogger(__name__) CHANGE_NETWORK_POLL_TIME = 1 CHANGE_NETWORK_STATE_DELAY = 2 -DELAY_NEIGHBOUR_SCAN_S = 1500 SEND_CONFIRM_TIMEOUT = 60 -PROTO_VER_MANUAL_SOURCE_ROUTE = 0x010C -PROTO_VER_WATCHDOG = 0x0108 -PROTO_VER_NEIGBOURS = 0x0107 - ENERGY_SCAN_ATTEMPTS = 5 @@ -60,18 +52,15 @@ class ControllerApplication(zigpy.application.ControllerApplication): def __init__(self, config: dict[str, Any]): """Initialize instance.""" - super().__init__(config=zigpy.config.ZIGPY_SCHEMA(config)) + super().__init__(config=config) self._api = None self._pending = zigpy.util.Requests() - - self._delayed_neighbor_scan_task = None self._reconnect_task = None - self._written_endpoints = set() - async def _watchdog_feed(self): - await self._api.set_watchdog_ttl(int(self._watchdog_period / 0.75)) + # TODO: implement a proper software-driven watchdog + await self._api.get_network_state() async def connect(self): api = Znsp(self, self._config[zigpy.config.CONF_DEVICE]) @@ -82,41 +71,46 @@ async def connect(self): api.close() raise + await api.reset() + + # TODO: Most commands fail if the network is not formed. Why? + await api.network_init() + await api.start(autostart=False) + self._api = api - self._written_endpoints.clear() async def disconnect(self): - if self._delayed_neighbor_scan_task is not None: - self._delayed_neighbor_scan_task.cancel() - self._delayed_neighbor_scan_task = None - if self._api is not None: self._api.close() self._api = None async def permit_with_link_key(self, node: t.EUI64, link_key: t.KeyData, time_s=60): - await self._api.set_link_key( - LinkKey(ieee=node, key=link_key), - ) - await self.permit(time_s) + raise NotImplementedError() async def start_network(self): - await self.register_endpoints() await self.load_network_info(load_devices=False) - await self._change_network_state(NetworkState.CONNECTED) + await self.register_endpoints() - coordinator = await ZnspDevice.new( - self, - self.state.node_info.ieee, - self.state.node_info.nwk, - self.state.node_info.model, + # Create the coordinator device + coordinator = zigpy.device.Device( + application=self, + ieee=self.state.node_info.ieee, + nwk=self.state.node_info.nwk, ) - self.devices[self.state.node_info.ieee] = coordinator - self._delayed_neighbor_scan_task = asyncio.create_task( - self._delayed_neighbour_scan() - ) + # TODO: why does the coordinator respond to the loopback ZDO Active_EP_req with + # [242, 242]? It should include endpoints 1 and 2, we registered them. + await coordinator.schedule_initialize() + + # TODO: add our registered endpoints manually so things don't crash. These + # should be discovered automatically. + ep1 = coordinator.add_endpoint(1) + ep1.status = zigpy.endpoint.Status.ZDO_INIT + ep2 = coordinator.add_endpoint(2) + ep2.status = zigpy.endpoint.Status.ZDO_INIT + + await self._api.form_network(role=DeviceType.COORDINATOR) async def _change_network_state( self, @@ -127,11 +121,11 @@ async def _change_network_state( async def change_loop(): while True: try: - device_state = await self._api.get_device_state() + network_state = await self._api.get_network_state() except asyncio.TimeoutError: LOGGER.debug("Failed to poll device state") else: - if NetworkState(device_state.network_state) == target_state: + if network_state == target_state: break await asyncio.sleep(CHANGE_NETWORK_POLL_TIME) @@ -148,54 +142,33 @@ async def change_loop(): raise FormationFailure("Network formation refused.") async def reset_network_info(self): - await self.form_network() + await self._api.factory_reset() async def write_network_info(self, *, network_info, node_info): - try: - await self._api.set_nwk_frame_counter(network_info.network_key.tx_counter) - except zigpy_espzb.exception.CommandError as ex: - assert ex.status == Status.UNSUPPORTED - LOGGER.warning( - "Doesn't support writing the network frame counter with this firmware" - ) + await self._api.factory_reset() + await self._api.network_init() + await self._api.start(autostart=False) - if node_info.logical_type == zdo_t.LogicalType.Coordinator: - await self._api.set_aps_designed_coordinator(1) - else: - await self._api.set_aps_designed_coordinator(0) + role = { + zdo_t.LogicalType.Coordinator: DeviceType.COORDINATOR, + zdo_t.LogicalType.Router: DeviceType.ROUTER, + }[node_info.logical_type] + await self._api.set_network_role(role) await self._api.set_nwk_address(node_info.nwk) - if node_info.ieee != zigpy.types.EUI64.UNKNOWN: + if node_info.ieee != t.EUI64.UNKNOWN: await self._api.set_mac_address(node_info.ieee) node_ieee = node_info.ieee else: - ieee = await self._api.mac_address() - node_ieee = zigpy.types.EUI64(ieee) - - if network_info.channel is not None: - channel_mask = zigpy.types.Channels.from_channel_list( - [network_info.channel] - ) - - if network_info.channel_mask and channel_mask != network_info.channel_mask: - LOGGER.warning( - "Channel mask %s will be replaced with current logical channel %s", - network_info.channel_mask, - channel_mask, - ) - else: - channel_mask = network_info.channel_mask + node_ieee = await self._api.get_mac_address() - await self._api.set_channel_mask(channel_mask) await self._api.set_use_predefined_nwk_panid(True) await self._api.set_nwk_panid(network_info.pan_id) - await self._api.set_aps_extended_panid(network_info.extended_pan_id) + await self._api.set_nwk_extended_panid(network_info.extended_pan_id) await self._api.set_nwk_update_id(network_info.nwk_update_id) - - await self._api.set_network_key( - IndexedKey(index=0, key=network_info.network_key.key), - ) + await self._api.set_network_key(network_info.network_key.key) + await self._api.set_nwk_frame_counter(network_info.network_key.tx_counter) if network_info.network_key.seq != 0: LOGGER.warning( @@ -205,94 +178,71 @@ async def write_network_info(self, *, network_info, node_info): tc_link_key_partner_ieee = network_info.tc_link_key.partner_ieee - if tc_link_key_partner_ieee == zigpy.types.EUI64.UNKNOWN: + if tc_link_key_partner_ieee == t.EUI64.UNKNOWN: tc_link_key_partner_ieee = node_ieee - await self._api.set_trust_center_address( - tc_link_key_partner_ieee, - ) - await self._api.set_link_key( - LinkKey( - ieee=tc_link_key_partner_ieee, - key=network_info.tc_link_key.key, - ), - ) + await self._api.set_trust_center_address(tc_link_key_partner_ieee) + await self._api.set_link_key(network_info.tc_link_key.key) if network_info.security_level == 0x00: await self._api.set_security_mode(SecurityMode.NO_SECURITY) else: - await self._api.set_security_mode(SecurityMode.ONLY_TCLK) + await self._api.set_security_mode(SecurityMode.PRECONFIGURED_NETWORK_KEY) - await self._change_network_state(NetworkState.OFFLINE) - await asyncio.sleep(CHANGE_NETWORK_STATE_DELAY) - await self._change_network_state(NetworkState.CONNECTED) + await self._api.set_channel(network_info.channel) async def load_network_info(self, *, load_devices=False): + channel = await self._api.get_current_channel() + + if not 11 <= channel <= 26: + raise NetworkNotFormed(f"Channel is invalid: {channel}") + network_info = self.state.network_info node_info = self.state.node_info - ieee = await self._api.mac_address() - node_info.ieee = zigpy.types.EUI64(ieee) - designed_coord = await self._api.aps_designed_coordinator() + role = await self._api.get_network_role() - if designed_coord == 0x01: + if role == DeviceType.COORDINATOR: node_info.logical_type = zdo_t.LogicalType.Coordinator else: node_info.logical_type = zdo_t.LogicalType.Router - node_info.nwk = await self._api.nwk_address() + node_info.nwk = await self._api.get_nwk_address() + node_info.ieee = await self._api.get_mac_address() - node_info.manufacturer = "Espressif Systems" - - node_info.model = "ESP32H2" + # TODO: implement firmware commands to read the board name, manufacturer + node_info.manufacturer = await self._api.system_manufacturer() + node_info.model = await self._api.system_model() + # TODO: implement firmware command to read out the firmware version and build ID node_info.version = f"{int(self._api.firmware_version):#010x}" network_info.source = f"zigpy-espzb@{importlib.metadata.version('zigpy-espzb')}" - network_info.metadata = { - "espzb": { - "version": node_info.version, - } - } - - network_info.pan_id = await self._api.nwk_panid() - network_info.extended_pan_id = await self._api.aps_extended_panid() - - if network_info.extended_pan_id == zigpy.types.EUI64.convert( - "00:00:00:00:00:00:00:00" - ): - network_info.extended_pan_id = await self._api.nwk_extended_panid() - - network_info.channel = await self._api.current_channel() - network_info.channel_mask = await self._api.channel_mask() - network_info.nwk_update_id = await self._api.nwk_update_id() + network_info.metadata = {} - if network_info.channel == 0: - raise NetworkNotFormed("Network channel is zero") + network_info.pan_id = await self._api.get_nwk_panid() + network_info.extended_pan_id = await self._api.get_nwk_extended_panid() + network_info.channel = await self._api.get_current_channel() + network_info.channel_mask = await self._api.get_channel_mask() + network_info.nwk_update_id = await self._api.get_nwk_update_id() - indexed_key = await self._api.network_key() + if network_info.channel in (0, 255): + raise NetworkNotFormed(f"Channel is invalid: {network_info.channel}") - network_info.network_key = zigpy.state.Key() - network_info.network_key.key = indexed_key.key - - try: - network_info.network_key.tx_counter = await self._api.nwk_frame_counter() - except zigpy_espzb.exception.CommandError as ex: - assert ex.status == Status.UNSUPPORTED + network_info.network_key.key = await self._api.get_network_key() + network_info.network_key.tx_counter = await self._api.get_nwk_frame_counter() network_info.tc_link_key = zigpy.state.Key() - network_info.tc_link_key.partner_ieee = await self._api.trust_center_address() - - link_key = await self._api.link_key( - network_info.tc_link_key.partner_ieee, + network_info.tc_link_key.key = await self._api.get_link_key() + network_info.tc_link_key.partner_ieee = ( + await self._api.get_trust_center_address() ) - network_info.tc_link_key.key = link_key.key - security_mode = await self._api.security_mode() + security_mode = await self._api.get_security_mode() if security_mode == SecurityMode.NO_SECURITY: network_info.security_level = 0x00 - elif security_mode == SecurityMode.ONLY_TCLK: + elif security_mode == SecurityMode.PRECONFIGURED_NETWORK_KEY: network_info.security_level = 0x05 else: LOGGER.warning("Unsupported security mode %r", security_mode) @@ -301,45 +251,11 @@ async def load_network_info(self, *, load_devices=False): async def force_remove(self, dev): """Forcibly remove device from NCP.""" - async def energy_scan( - self, channels: t.Channels.ALL_CHANNELS, duration_exp: int, count: int - ) -> dict[int, float]: - results = await super().energy_scan( - channels=channels, duration_exp=duration_exp, count=count - ) - - return {c: v * 3 for c, v in results.items()} - - for i in range(ENERGY_SCAN_ATTEMPTS): - try: - rsp = await self._device.zdo.Mgmt_NWK_Update_req( - zigpy.zdo.types.NwkUpdate( - ScanChannels=channels, - ScanDuration=duration_exp, - ScanCount=count, - ) - ) - break - except (asyncio.TimeoutError, zigpy.exceptions.DeliveryError): - if i == ENERGY_SCAN_ATTEMPTS - 1: - raise - - continue - - _, scanned_channels, _, _, energy_values = rsp - return dict(zip(scanned_channels, energy_values)) - async def _move_network_to_channel( self, new_channel: int, new_nwk_update_id: int ) -> None: """Move device to a new channel.""" - channel_mask = zigpy.types.Channels.from_channel_list([new_channel]) - await self._api.set_channel_mask(channel_mask) - await self._api.set_nwk_update_id(new_nwk_update_id) - - await self._change_network_state(NetworkState.OFFLINE) - await asyncio.sleep(CHANGE_NETWORK_STATE_DELAY) - await self._change_network_state(NetworkState.CONNECTED) + raise NotImplementedError() async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None: """Register a new endpoint on the device.""" @@ -356,170 +272,92 @@ async def add_endpoint(self, descriptor: zdo_t.SimpleDescriptor) -> None: async def send_packet(self, packet): LOGGER.debug("Sending packet: %r", packet) - force_relays = None - - dst_addr = packet.dst.address - addr_mode = packet.dst.addr_mode - if packet.dst.addr_mode != zigpy.types.AddrMode.IEEE: - dst_addr = t.EUI64( - [ - packet.dst.address % 0x100, - packet.dst.address >> 8, - 0, - 0, - 0, - 0, - 0, - 0, - ] - ) - if packet.dst.addr_mode == zigpy.types.AddrMode.Broadcast: - addr_mode = zigpy.types.AddrMode.Group - - if packet.dst.addr_mode != zigpy.types.AddrMode.IEEE: - src_addr = t.EUI64( - [ - packet.dst.address % 0x100, - packet.dst.address >> 8, - 0, - 0, - 0, - 0, - 0, - 0, - ] + try: + device = self.get_device_with_address(packet.dst) + except (KeyError, ValueError): + device = None + + if packet.dst.addr_mode == t.AddrMode.IEEE: + LOGGER.warning("IEEE addressing is not supported, falling back to NWK") + + if device is None: + raise ValueError(f"Cannot find device with IEEE {packet.dst.address}") + + packet = packet.replace( + dst=t.AddrModeAddress(addr_mode=t.AddrMode.NWK, address=device.nwk) ) - if packet.source_route is not None: - force_relays = packet.source_route + assert packet.src.addr_mode == t.AddrMode.NWK + src_addr = t.EUI64( + [ + packet.src.address % 0x100, + packet.src.address >> 8, + 0, + 0, + 0, + 0, + 0, + 0, + ] + ) - tx_options = t.ZnspTransmitOptions.NONE + dst_addr_mode = { + t.AddrMode.NWK: ExtendedAddrMode.MODE_16_ENDP_PRESENT, + t.AddrMode.IEEE: ExtendedAddrMode.MODE_64_ENDP_PRESENT, + t.AddrMode.Group: ExtendedAddrMode.MODE_16_GROUP_ENDP_NOT_PRESENT, + t.AddrMode.Broadcast: ExtendedAddrMode.MODE_16_GROUP_ENDP_NOT_PRESENT, + }[packet.dst.addr_mode] + + dst_addr = t.EUI64( + [ + packet.dst.address % 0x100, + packet.dst.address >> 8, + 0, + 0, + 0, + 0, + 0, + 0, + ] + ) + + tx_options = TransmitOptions.NONE - if zigpy.types.TransmitOptions.ACK in packet.tx_options: - tx_options |= t.ZnspTransmitOptions.ACK_ENABLED + if t.TransmitOptions.ACK in packet.tx_options: + tx_options |= TransmitOptions.ACK_TX - if zigpy.types.TransmitOptions.APS_Encryption in packet.tx_options: - tx_options |= t.ZnspTransmitOptions.SECURITY_ENABLED + if t.TransmitOptions.APS_Encryption in packet.tx_options: + tx_options |= TransmitOptions.SECURITY_ENABLED - async with self._limit_concurrency(): + async with self._limit_concurrency(priority=packet.priority): await self._api.aps_data_request( dst_addr=dst_addr, dst_ep=packet.dst_ep, src_addr=src_addr, src_ep=packet.src_ep, - profile=packet.profile_id, - addr_mode=addr_mode, + profile=packet.profile_id or 0, + addr_mode=dst_addr_mode, cluster=packet.cluster_id, sequence=packet.tsn, options=tx_options, radius=packet.radius or 0, data=packet.data.serialize(), - relays=force_relays, - extended_timeout=packet.extended_timeout, ) async def permit_ncp(self, time_s=60): assert 0 <= time_s <= 254 - await self._api.set_permit_join(time_s) - - async def restore_neighbours(self) -> None: - """Restore children.""" - coord = self.get_device(ieee=self.state.node_info.ieee) - - for neighbor in self.topology.neighbors[coord.ieee]: - try: - device = self.get_device(ieee=neighbor.ieee) - except KeyError: - continue - - descr = device.node_desc - LOGGER.debug( - "device: 0x%04x - %s %s, FFD=%s, Rx_on_when_idle=%s", - device.nwk, - device.manufacturer, - device.model, - descr.is_full_function_device if descr is not None else None, - descr.is_receiver_on_when_idle if descr is not None else None, - ) - if ( - descr is None - or descr.is_full_function_device - or descr.is_receiver_on_when_idle - ): - continue - - LOGGER.debug("Restoring %s as direct child", device) - - try: - await self._api.add_neighbour( - nwk=device.nwk, - ieee=device.ieee, - mac_capability_flags=descr.mac_capability_flags, - ) - except zigpy_espzb.exception.CommandError as ex: - assert ex.status == Status.FAILURE - LOGGER.debug("Failed to add device to neighbor table: %s", ex) - - async def _delayed_neighbour_scan(self) -> None: - """Scan coordinator's neighbours.""" - await asyncio.sleep(DELAY_NEIGHBOUR_SCAN_S) - coord = self.get_device(ieee=self.state.node_info.ieee) - await self.topology.scan(devices=[coord]) - - -class ZnspDevice(zigpy.device.Device): - """Zigpy Device representing Coordinator.""" - - def __init__(self, model: str, *args): - """Initialize instance.""" - super().__init__(*args) - self._model = model - - async def add_to_group(self, grp_id: int, name: str = None) -> None: - group = self.application.groups.add_group(grp_id, name) - - for epid in self.endpoints: - if not epid: - continue # skip ZDO - group.add_member(self.endpoints[epid]) - return [0] - - async def remove_from_group(self, grp_id: int) -> None: - for epid in self.endpoints: - if not epid: - continue # skip ZDO - self.application.groups[grp_id].remove_member(self.endpoints[epid]) - return [0] - - @property - def manufacturer(self): - return "Espressif Systems" - - @property - def model(self): - return self._model - - @classmethod - async def new(cls, application, ieee, nwk, model: str): - """Create or replace zigpy device.""" - dev = cls(model, application, ieee, nwk) - - if ieee in application.devices: - from_dev = application.get_device(ieee=ieee) - dev.status = from_dev.status - dev.node_desc = from_dev.node_desc - for ep_id, from_ep in from_dev.endpoints.items(): - if not ep_id: - continue # Skip ZDO - ep = dev.add_endpoint(ep_id) - ep.profile_id = from_ep.profile_id - ep.device_type = from_ep.device_type - ep.status = from_ep.status - ep.in_clusters = from_ep.in_clusters - ep.out_clusters = from_ep.out_clusters - else: - application.devices[ieee] = dev - await dev.initialize() + await self._device.zdo.permit(time_s) + + # TODO: this does not work, the NCP responds again with: + # Unknown command received: Command( + # version=0, + # frame_type=, + # reserved=0, + # command_id=, + # seq=144, + # length=1, + # payload=b'\x02' + # ) - return dev + # await self._api.set_permit_join(time_s)