diff --git a/pydantic_ai_harness/__init__.py b/pydantic_ai_harness/__init__.py index 0a60fd7..a8a7dc5 100644 --- a/pydantic_ai_harness/__init__.py +++ b/pydantic_ai_harness/__init__.py @@ -4,8 +4,9 @@ if TYPE_CHECKING: from .code_mode import CodeMode + from .home_automation import HomeAutomation -__all__ = ['CodeMode'] +__all__ = ['CodeMode', 'HomeAutomation'] def __getattr__(name: str) -> object: @@ -13,4 +14,8 @@ def __getattr__(name: str) -> object: from .code_mode import CodeMode return CodeMode + if name == 'HomeAutomation': + from .home_automation import HomeAutomation + + return HomeAutomation raise AttributeError(f'module {__name__!r} has no attribute {name!r}') diff --git a/pydantic_ai_harness/home_automation/README.md b/pydantic_ai_harness/home_automation/README.md new file mode 100644 index 0000000..3612498 --- /dev/null +++ b/pydantic_ai_harness/home_automation/README.md @@ -0,0 +1,63 @@ +# Home Automation + +Home Automation exposes smart home entities and service calls to a Pydantic AI +agent. The current backend targets the Home Assistant REST API. + +```python +from pydantic_ai import Agent +from pydantic_ai_harness import HomeAutomation +from pydantic_ai_harness.home_automation import HomeAssistantBackend + +backend = HomeAssistantBackend( + url='http://localhost:8123', + token='...', +) + +agent = Agent( + 'openai:gpt-5', + capabilities=[HomeAutomation(backend=backend)], +) +``` + +## Tools + +The capability exposes the backend methods directly as tools: + +- `list_services(domain=None)`: list callable Home Assistant services and their normalized arguments. +- `list_entities(domain=None)`: list entity summaries. +- `get_state(entity_id)`: get the current state for one entity. +- `list_states(domain=None)`: list current entity states. +- `call_service(domain, entity_id, service_name, **data)`: call a service for one entity. + +Agents should usually discover entities with `list_entities`, inspect available +services with `list_services`, then call `call_service` with the exact +`domain`, `service_name`, and `entity_id`. + +## Home Assistant Backend + +`HomeAssistantBackend` validates `/api/services` and `/api/states` responses +with Pydantic models, then converts them into backend-neutral dataclasses for +the agent-facing tools. + +Service calls return a `ServiceCallResult` with: + +- `changed_states`: states Home Assistant reports as changed during execution. +- `service_response`: response data for services that support or require it. +- `verified_state`: a best-effort post-call state for the target entity. + +If Home Assistant returns no changed states or service response data, +`HomeAssistantBackend` may poll the target entity state to populate +`verified_state`. Configure this with `verification_poll_attempts` and +`verification_poll_interval`. + +```python +backend = HomeAssistantBackend( + url='http://localhost:8123', + token='...', + verification_poll_attempts=3, + verification_poll_interval=0.5, +) +``` + +Injected `httpx.AsyncClient` instances are not closed by `aclose`; clients +created by the backend are closed by `aclose`. diff --git a/pydantic_ai_harness/home_automation/__init__.py b/pydantic_ai_harness/home_automation/__init__.py new file mode 100644 index 0000000..549282b --- /dev/null +++ b/pydantic_ai_harness/home_automation/__init__.py @@ -0,0 +1,7 @@ +"""Home automation capability and backend exports.""" + +from pydantic_ai_harness.home_automation._capability import HomeAutomation +from pydantic_ai_harness.home_automation._toolset import HomeAutomationToolset +from pydantic_ai_harness.home_automation.backends.home_assistant._backend import HomeAssistantBackend + +__all__ = ['HomeAutomation', 'HomeAutomationToolset', 'HomeAssistantBackend'] diff --git a/pydantic_ai_harness/home_automation/_capability.py b/pydantic_ai_harness/home_automation/_capability.py new file mode 100644 index 0000000..87aeb09 --- /dev/null +++ b/pydantic_ai_harness/home_automation/_capability.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass + +from pydantic_ai._run_context import AgentDepsT +from pydantic_ai.capabilities import AbstractCapability + +from pydantic_ai_harness.home_automation._toolset import HomeAutomationToolset +from pydantic_ai_harness.home_automation.backends import HomeBackend + + +@dataclass +class HomeAutomation(AbstractCapability[AgentDepsT]): + """Capability exposing Home Assistant-style service discovery.""" + + backend: HomeBackend + + def get_instructions(self) -> str: + """Return guidance for agents using the home automation capability.""" + return ( + 'You have access to tools that let you inspect and control smart home ' + 'entities such as lights, switches, and climate devices through Home ' + 'Assistant. Use `list_entities` or `list_states` to discover available ' + 'entities, and `get_state` when you need the current state of one entity. ' + 'Use `list_services` to discover valid services and their arguments before ' + 'calling them. When using `call_service`, pass the exact `domain`, ' + '`service_name`, and `entity_id` returned by Home Assistant. After a ' + 'service call, prefer `verified_state` as the strongest confirmation of ' + 'what happened, then `changed_states`, and finally `service_response`. ' + '`call_service` may perform follow-up state reads when Home Assistant does ' + 'not return changed states or response data.' + ) + + def get_toolset(self) -> HomeAutomationToolset[AgentDepsT]: + """Expose the home automation tools to the agent runtime.""" + return HomeAutomationToolset(self.backend) diff --git a/pydantic_ai_harness/home_automation/_toolset.py b/pydantic_ai_harness/home_automation/_toolset.py new file mode 100644 index 0000000..72f69c7 --- /dev/null +++ b/pydantic_ai_harness/home_automation/_toolset.py @@ -0,0 +1,24 @@ +"""Toolset for exposing home automation backends to agents.""" + +from pydantic_ai import FunctionToolset +from pydantic_ai._run_context import AgentDepsT + +from pydantic_ai_harness.home_automation.backends import HomeBackend + + +class HomeAutomationToolset(FunctionToolset[AgentDepsT]): + """Function toolset backed by a `HomeBackend` implementation.""" + + backend: HomeBackend + + def __init__(self, backend: HomeBackend) -> None: + self.backend = backend + super().__init__( + tools=[ + backend.list_services, + backend.list_entities, + backend.get_state, + backend.list_states, + backend.call_service, + ] + ) diff --git a/pydantic_ai_harness/home_automation/backends/__init__.py b/pydantic_ai_harness/home_automation/backends/__init__.py new file mode 100644 index 0000000..b08f0a3 --- /dev/null +++ b/pydantic_ai_harness/home_automation/backends/__init__.py @@ -0,0 +1,12 @@ +"""Backend contracts and shared home automation models.""" + +from pydantic_ai_harness.home_automation.backends._base_backend import ( + Entity, + EntityState, + HomeBackend, + Service, + ServiceArgument, + ServiceCallResult, +) + +__all__ = ['Entity', 'EntityState', 'HomeBackend', 'Service', 'ServiceArgument', 'ServiceCallResult'] diff --git a/pydantic_ai_harness/home_automation/backends/_base_backend.py b/pydantic_ai_harness/home_automation/backends/_base_backend.py new file mode 100644 index 0000000..842565a --- /dev/null +++ b/pydantic_ai_harness/home_automation/backends/_base_backend.py @@ -0,0 +1,78 @@ +from dataclasses import dataclass, field +from typing import Any, Protocol + + +@dataclass +class ServiceArgument: + name: str + value_type: str + required: bool = False + description: str | None = None + enum: tuple[str, ...] = () + minimum: float | int | None = None + maximum: float | int | None = None + + +@dataclass +class Service: + domain: str + name: str + args: tuple[ServiceArgument, ...] = () + + +@dataclass +class Entity: + entity_id: str + domain: str + name: str | None + current_state: str + + +@dataclass +class EntityState: + entity_id: str + last_updated: str + state: str + + +def _empty_entity_states() -> list[EntityState]: + return [] + + +@dataclass +class ServiceCallResult: + changed_states: list[EntityState] = field(default_factory=_empty_entity_states) + service_response: dict[str, Any] | None = None + verified_state: EntityState | None = None + + +class HomeBackend(Protocol): + """Backend contract for home automation service discovery.""" + + async def list_services(self, domain: str | None = None) -> list[Service]: # pragma: no cover + """Return the services that can be called, optionally filtered by domain.""" + ... + + async def list_entities(self, domain: str | None = None) -> list[Entity]: # pragma: no cover + """Return known entities, optionally filtered by domain.""" + ... + + async def get_state(self, entity_id: str) -> EntityState: # pragma: no cover + """Return the current state for a single entity.""" + ... + + async def list_states(self, domain: str | None = None) -> list[EntityState]: # pragma: no cover + """Return current states, optionally filtered by domain.""" + ... + + async def call_service( + self, + domain: str, + entity_id: str, + service_name: str, + *, + want_response: bool = False, + **data: Any, + ) -> ServiceCallResult: # pragma: no cover + """Call a service for one entity and return the observed outcome.""" + ... diff --git a/pydantic_ai_harness/home_automation/backends/home_assistant/__init__.py b/pydantic_ai_harness/home_automation/backends/home_assistant/__init__.py new file mode 100644 index 0000000..d50f279 --- /dev/null +++ b/pydantic_ai_harness/home_automation/backends/home_assistant/__init__.py @@ -0,0 +1,5 @@ +"""Home Assistant backend exports.""" + +from ._backend import HomeAssistantBackend as HomeAssistantBackend + +__all__ = ['HomeAssistantBackend'] diff --git a/pydantic_ai_harness/home_automation/backends/home_assistant/_backend.py b/pydantic_ai_harness/home_automation/backends/home_assistant/_backend.py new file mode 100644 index 0000000..a566335 --- /dev/null +++ b/pydantic_ai_harness/home_automation/backends/home_assistant/_backend.py @@ -0,0 +1,215 @@ +import asyncio +from typing import Any + +import httpx + +from pydantic_ai_harness.home_automation.backends import Entity, EntityState, HomeBackend, Service, ServiceCallResult +from pydantic_ai_harness.home_automation.backends.home_assistant.models import ( + ENTITY_STATE_CATALOG_ADAPTER, + SERVICE_CATALOG_ADAPTER, + HAEntityState, + HAServiceCallResult, + ServiceCatalog, + ServiceDescription, +) + + +class HomeAssistantBackend(HomeBackend): + """Home Assistant adapter focused on `/api/services`.""" + + def __init__( + self, + url: str, + token: str, + client: httpx.AsyncClient | None = None, + *, + verification_poll_attempts: int = 3, + verification_poll_interval: float = 0.5, + ) -> None: + if verification_poll_attempts < 0: + raise ValueError('verification_poll_attempts must be greater than or equal to 0.') + if verification_poll_interval < 0: + raise ValueError('verification_poll_interval must be greater than or equal to 0.') + + self.url = url.rstrip('/') + self.token = token + self._service_catalog_cache: ServiceCatalog | None = None + self._verification_poll_attempts = verification_poll_attempts + self._verification_poll_interval = verification_poll_interval + self._owns_client = client is None + self.client = client or httpx.AsyncClient( + base_url=self.url, + headers={ + 'Authorization': f'Bearer {self.token}', + 'Content-Type': 'application/json', + }, + timeout=15.0, + ) + + async def aclose(self) -> None: + """Close the underlying HTTP client.""" + if self._owns_client: + await self.client.aclose() + + async def list_services(self, domain: str | None = None) -> list[Service]: + """Fetch and validate the Home Assistant REST service catalog.""" + catalog = await self._get_service_catalog() + if domain is not None: + catalog = [item for item in catalog if item.domain == domain] + return [service for domain_services in catalog for service in domain_services.to_services()] + + async def list_entities(self, domain: str | None = None) -> list[Entity]: + """Return Home Assistant entities, optionally filtered by domain.""" + ha_entities = await self._fetch_state_catalog() + if domain is not None: + return [ha_entity.to_entity() for ha_entity in ha_entities if ha_entity.domain == domain] + return [ha_entity.to_entity() for ha_entity in ha_entities] + + async def get_state(self, entity_id: str) -> EntityState: + """Fetch the current state for a single Home Assistant entity.""" + response = await self.client.get(f'/api/states/{entity_id}') + if response.status_code == 404: + raise ValueError(f'Entity {entity_id!r} was not found.') + response.raise_for_status() + ha_state = HAEntityState.model_validate(response.json()) + return ha_state.to_state() + + async def call_service( + self, + domain: str, + entity_id: str, + service_name: str, + *, + want_response: bool = False, + **data: Any, + ) -> ServiceCallResult: + """Call one Home Assistant service and normalize the response. + + If Home Assistant returns no changed states or response data, this may perform + follow-up state reads to populate `verified_state`. + """ + service_description = await self._get_service_description(domain, service_name) + if service_description is None: + raise ValueError(f'Service {domain}.{service_name} was not found.') + previous_state = await self._get_state_for_verification(entity_id) + + payload = {'entity_id': entity_id, **data} + response = await self.client.post( + f'/api/services/{domain}/{service_name}', + json=payload, + params=self._build_service_call_params(service_description, want_response), + ) + if response.status_code == 404: + raise ValueError(f'Service {domain}.{service_name} was not found.') + response.raise_for_status() + response_json = response.json() + if isinstance(response_json, list): + changed_states = ENTITY_STATE_CATALOG_ADAPTER.validate_python(response_json) + result = HAServiceCallResult(changed_states=changed_states).to_result() + else: + result = HAServiceCallResult.model_validate(response_json).to_result() + + result.verified_state = await self._get_verified_state( + entity_id=entity_id, + domain=domain, + previous_state=previous_state, + service_name=service_name, + result=result, + ) + return result + + async def list_states(self, domain: str | None = None) -> list[EntityState]: + """Return current entity states, optionally filtered by domain.""" + ha_entities = await self._fetch_state_catalog() + if domain is not None: + ha_entities = [ha_entity for ha_entity in ha_entities if ha_entity.domain == domain] + return [ha_entity.to_state() for ha_entity in ha_entities] + + async def _fetch_state_catalog(self) -> list[HAEntityState]: + """Fetch and validate the full Home Assistant state catalog.""" + response = await self.client.get('/api/states') + response.raise_for_status() + return ENTITY_STATE_CATALOG_ADAPTER.validate_python(response.json()) + + async def _get_service_catalog(self, *, refresh: bool = False) -> ServiceCatalog: + """Fetch the service catalog once and reuse it until refreshed.""" + catalog = self._service_catalog_cache + if catalog is None or refresh: + response = await self.client.get('/api/services') + response.raise_for_status() + catalog = SERVICE_CATALOG_ADAPTER.validate_python(response.json()) + self._service_catalog_cache = catalog + return catalog + + async def _get_service_description( + self, + domain: str, + service_name: str, + ) -> ServiceDescription | None: + """Look up one service description from the cached service catalog.""" + catalog = await self._get_service_catalog() + for domain_services in catalog: + if domain_services.domain == domain: + return domain_services.get_service_description(service_name) + return None + + @staticmethod + def _build_service_call_params( + service_description: ServiceDescription, + want_response: bool, + ) -> dict[str, str] | None: + """Build query parameters for service calls that need response payloads.""" + if service_description.response is None: + return None + + if not service_description.response.optional or want_response: + return {'return_response': ''} + + return None + + async def _get_verified_state( + self, + *, + entity_id: str, + domain: str, + service_name: str, + result: ServiceCallResult, + previous_state: EntityState | None = None, + ) -> EntityState | None: + """Resolve a concrete post-call state for the target entity when possible.""" + for changed_state in result.changed_states: + if changed_state.entity_id == entity_id: + return changed_state + + if result.changed_states or result.service_response is not None: + return None + + if self._verification_poll_attempts == 0: + return None + + latest_state: EntityState | None = None + for attempt in range(self._verification_poll_attempts): + latest_state = await self._get_state_for_verification(entity_id) + if latest_state is not None and self._state_changed(previous_state, latest_state): + return latest_state + if attempt < self._verification_poll_attempts - 1: + await asyncio.sleep(self._verification_poll_interval) + + if latest_state is not None: + return latest_state + + return None + + async def _get_state_for_verification(self, entity_id: str) -> EntityState | None: + """Best-effort state fetch used to verify a service call outcome.""" + try: + return await self.get_state(entity_id) + except (ValueError, httpx.HTTPError): + return None + + @staticmethod + def _state_changed(previous_state: EntityState | None, current_state: EntityState) -> bool: + """Return whether a polled state differs from the pre-call snapshot.""" + if previous_state is None: + return True + return current_state.state != previous_state.state or current_state.last_updated != previous_state.last_updated diff --git a/pydantic_ai_harness/home_automation/backends/home_assistant/models.py b/pydantic_ai_harness/home_automation/backends/home_assistant/models.py new file mode 100644 index 0000000..cf9640c --- /dev/null +++ b/pydantic_ai_harness/home_automation/backends/home_assistant/models.py @@ -0,0 +1,394 @@ +"""Home Assistant REST response models.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter + +from pydantic_ai_harness.home_automation.backends._base_backend import ( + Entity, + EntityState, + Service, + ServiceArgument, + ServiceCallResult, +) + + +class ServiceResponseInfo(BaseModel): + """Response metadata for a service.""" + + optional: bool + + +class FieldFilter(BaseModel): + """Visibility/applicability filters for one service field.""" + + supported_features: list[Any] = Field(default_factory=list) + attribute: dict[str, list[Any]] = Field(default_factory=dict) + + model_config = ConfigDict(extra='allow') + + +class NumberSelectorConfig(BaseModel): + """Selector config for numeric inputs.""" + + min: float | int | None = None + max: float | int | None = None + step: float | int | str | None = None + unit_of_measurement: str | None = None + + model_config = ConfigDict(extra='allow') + + +class SelectSelectorOption(BaseModel): + """Dict-style option for select inputs.""" + + value: str | int | float | bool + label: str | None = None + + model_config = ConfigDict(extra='allow') + + +def _empty_select_options() -> list[str | int | float | bool | SelectSelectorOption]: + return [] + + +class SelectSelectorConfig(BaseModel): + """Selector config for select inputs.""" + + options: list[str | int | float | bool | SelectSelectorOption] = Field(default_factory=_empty_select_options) + multiple: bool | None = None + translation_key: str | None = None + + model_config = ConfigDict(extra='allow') + + +class ColorTempSelectorConfig(BaseModel): + """Selector config for color temperature inputs.""" + + unit: str | None = None + min: int | None = None + max: int | None = None + + model_config = ConfigDict(extra='allow') + + +class HAEntitySelectorConfig(BaseModel): + """Selector config for entity targets/fields.""" + + domain: str | list[str] | None = None + integration: str | list[str] | None = None + device_class: str | list[str] | None = None + supported_features: list[Any] = Field(default_factory=list) + + model_config = ConfigDict(extra='allow') + + +class DeviceSelectorConfig(BaseModel): + """Selector config for device targets/fields.""" + + integration: str | list[str] | None = None + manufacturer: str | list[str] | None = None + model: str | list[str] | None = None + + model_config = ConfigDict(extra='allow') + + +class AreaSelectorConfig(BaseModel): + """Selector config for area targets/fields.""" + + model_config = ConfigDict(extra='allow') + + +class Selector(BaseModel): + """A Home Assistant selector declaration.""" + + number: NumberSelectorConfig | None = None + select: SelectSelectorConfig | None = None + color_temp: ColorTempSelectorConfig | None = None + color_rgb: dict[str, Any] | None = None + boolean: dict[str, Any] | None = None + text: dict[str, Any] | None = None + object: dict[str, Any] | None = None + entity: HAEntitySelectorConfig | None = None + device: DeviceSelectorConfig | None = None + area: AreaSelectorConfig | None = None + constant: dict[str, Any] | None = None + + model_config = ConfigDict(extra='allow') + + +class ServiceFieldDescription(BaseModel): + """Description for one service field, including nested field groups.""" + + name: str | None = None + description: str | None = None + required: bool | None = None + advanced: bool | None = None + example: Any = None + default: Any = None + selector: Selector | None = None + filter: FieldFilter | None = None + fields: dict[str, ServiceFieldDescription] = Field(default_factory=dict) + collapsed: bool | None = None + + model_config = ConfigDict(extra='allow') + + def to_service_argument(self, field_name: str) -> ServiceArgument: + """Convert this Home Assistant field description to a service argument.""" + selector = self.selector + + if selector is not None and selector.number is not None: + return ServiceArgument( + name=field_name, + value_type='number', + required=bool(self.required), + description=self.description, + minimum=selector.number.min, + maximum=selector.number.max, + ) + + if selector is not None and selector.select is not None: + return ServiceArgument( + name=field_name, + value_type='string', + required=bool(self.required), + description=self.description, + enum=tuple(_select_option_value(option) for option in selector.select.options), + ) + + if selector is not None and selector.color_temp is not None: + return ServiceArgument( + name=field_name, + value_type='number', + required=bool(self.required), + description=self.description, + minimum=selector.color_temp.min, + maximum=selector.color_temp.max, + ) + + if selector is not None and selector.color_rgb is not None: + return ServiceArgument( + name=field_name, + value_type='object', + required=bool(self.required), + description=self.description, + ) + + if selector is not None and selector.boolean is not None: + return ServiceArgument( + name=field_name, + value_type='boolean', + required=bool(self.required), + description=self.description, + ) + + if selector is not None and selector.constant is not None: + value_type = 'boolean' if isinstance(selector.constant.get('value'), bool) else 'object' + return ServiceArgument( + name=field_name, + value_type=value_type, + required=bool(self.required), + description=self.description, + ) + + if selector is not None and selector.text is not None: + return ServiceArgument( + name=field_name, + value_type='string', + required=bool(self.required), + description=self.description, + ) + + if selector is not None and selector.object is not None: + return ServiceArgument( + name=field_name, + value_type='object', + required=bool(self.required), + description=self.description, + ) + + if selector is not None and selector.entity is not None: + return ServiceArgument( + name=field_name, + value_type='entity_id', + required=bool(self.required), + description=self.description, + ) + + if selector is not None and selector.device is not None: + return ServiceArgument( + name=field_name, + value_type='device_id', + required=bool(self.required), + description=self.description, + ) + + if selector is not None and selector.area is not None: + return ServiceArgument( + name=field_name, + value_type='area_id', + required=bool(self.required), + description=self.description, + ) + + return ServiceArgument( + name=field_name, + value_type='string', + required=bool(self.required), + description=self.description, + ) + + +class ServiceTarget(BaseModel): + """Target metadata for a service.""" + + entity: HAEntitySelectorConfig | list[HAEntitySelectorConfig] | None = None + device: DeviceSelectorConfig | list[DeviceSelectorConfig] | None = None + area: AreaSelectorConfig | list[AreaSelectorConfig] | None = None + + model_config = ConfigDict(extra='allow') + + +class ServiceDescription(BaseModel): + """One Home Assistant service description.""" + + name: str | None = None + description: str | None = None + target: ServiceTarget | None = None + fields: dict[str, ServiceFieldDescription] = Field(default_factory=dict) + description_placeholders: dict[str, str] = Field(default_factory=dict) + response: ServiceResponseInfo | None = None + + model_config = ConfigDict(extra='allow') + + def to_service(self, domain: str, service_name: str) -> Service: + """Convert this Home Assistant service description to a service.""" + args = tuple(self._iter_service_arguments()) + return Service(domain=domain, name=service_name, args=args) + + def _iter_service_arguments(self) -> list[ServiceArgument]: + """Return top-level arguments plus one level of grouped advanced fields.""" + args: list[ServiceArgument] = [] + for field_name, field in self.fields.items(): + if field_name == 'advanced_fields': + args.extend( + nested_field.to_service_argument(nested_field_name) + for nested_field_name, nested_field in field.fields.items() + ) + continue + args.append(field.to_service_argument(field_name)) + return args + + +class DomainServices(BaseModel): + """Services grouped by Home Assistant domain.""" + + domain: str + services: dict[str, ServiceDescription] + + def to_services(self) -> list[Service]: + """Convert all services in this domain to service models.""" + return [ + service_description.to_service(self.domain, service_name) + for service_name, service_description in self.services.items() + ] + + def get_service_description(self, service_name: str) -> ServiceDescription | None: + """Return the Home Assistant description for one service.""" + return self.services.get(service_name) + + +ServiceFieldDescription.model_rebuild() +ServiceCatalog = list[DomainServices] +SERVICE_CATALOG_ADAPTER = TypeAdapter(ServiceCatalog) + + +def _select_option_value(option: str | int | float | bool | SelectSelectorOption) -> str: + if isinstance(option, SelectSelectorOption): + return str(option.value) + return str(option) + + +class HAEntityAttributes(BaseModel): + """Attributes returned for a Home Assistant entity state.""" + + friendly_name: str | None = None + supported_features: int | None = None + device_class: str | None = None + unit_of_measurement: str | None = None + state_class: str | None = None + icon: str | None = None + entity_picture: str | None = None + + model_config = ConfigDict(extra='allow') + + +class HAEntityContext(BaseModel): + """Context metadata returned with a Home Assistant entity state.""" + + id: str + parent_id: str | None = None + user_id: str | None = None + + model_config = ConfigDict(extra='allow') + + +class HAEntityState(BaseModel): + """One state object returned by Home Assistant.""" + + entity_id: str + state: str + attributes: HAEntityAttributes = Field(default_factory=HAEntityAttributes) + last_changed: str + last_reported: str | None = None + last_updated: str | None = None + context: HAEntityContext | None = None + + def to_entity(self) -> Entity: + """Convert this Home Assistant state to an entity summary.""" + domain, _, _object_id = self.entity_id.partition('.') + return Entity( + entity_id=self.entity_id, + domain=domain, + name=self.attributes.friendly_name, + current_state=self.state, + ) + + def to_state(self) -> EntityState: + """Convert this Home Assistant state to a lightweight state model.""" + return EntityState( + entity_id=self.entity_id, + state=self.state, + last_updated=self.last_updated or self.last_changed, + ) + + @property + def domain(self) -> str: + """Return the entity domain prefix.""" + return self.entity_id.split('.', 1)[0] + + model_config = ConfigDict(extra='allow') + + +EntityCatalog = list[HAEntityState] +ENTITY_STATE_CATALOG_ADAPTER = TypeAdapter(EntityCatalog) + + +def _empty_ha_entity_states() -> list[HAEntityState]: + return [] + + +class HAServiceCallResult(BaseModel): + """Response object returned by a Home Assistant service call.""" + + changed_states: list[HAEntityState] = Field(default_factory=_empty_ha_entity_states) + service_response: dict[str, Any] = Field(default_factory=dict) + + def to_result(self) -> ServiceCallResult: + """Convert this Home Assistant service response to a call result.""" + return ServiceCallResult( + changed_states=[state.to_state() for state in self.changed_states], + service_response=self.service_response or None, + ) diff --git a/tests/_home_automation/__init__.py b/tests/_home_automation/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/tests/_home_automation/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/_home_automation/test_home_automation.py b/tests/_home_automation/test_home_automation.py new file mode 100644 index 0000000..a6bd259 --- /dev/null +++ b/tests/_home_automation/test_home_automation.py @@ -0,0 +1,724 @@ +from collections.abc import Callable + +import httpx +import pytest +from pydantic_ai import FunctionToolset + +import pydantic_ai_harness +from pydantic_ai_harness.home_automation import HomeAssistantBackend, HomeAutomation, HomeAutomationToolset +from pydantic_ai_harness.home_automation.backends import Entity, EntityState, Service, ServiceCallResult +from pydantic_ai_harness.home_automation.backends.home_assistant._backend import HomeAssistantBackend as ExportedBackend +from pydantic_ai_harness.home_automation.backends.home_assistant.models import ( + HAEntityState, + HAServiceCallResult, + ServiceDescription, + ServiceFieldDescription, +) + +pytestmark = pytest.mark.anyio + +RouteHandler = Callable[[httpx.Request], httpx.Response] + + +@pytest.fixture +def anyio_backend() -> str: + return 'asyncio' + + +class MockHomeAssistant: + def __init__(self) -> None: + self.requests: list[httpx.Request] = [] + self._routes: dict[tuple[str, str], RouteHandler] = {} + + def route(self, method: str, path: str, handler: RouteHandler) -> None: + self._routes[(method.upper(), path)] = handler + + def json(self, method: str, path: str, payload: object, *, status_code: int = 200) -> None: + self.route(method, path, lambda request: httpx.Response(status_code, json=payload)) + + def count(self, method: str, path: str) -> int: + return sum(1 for request in self.requests if request.method == method.upper() and request.url.path == path) + + async def handle(self, request: httpx.Request) -> httpx.Response: + self.requests.append(request) + handler = self._routes.get((request.method, request.url.path)) + if handler is None: + raise AssertionError(f'Unexpected request: {request.method} {request.url}') # pragma: no cover + return handler(request) + + def backend( + self, + *, + verification_poll_attempts: int = 3, + verification_poll_interval: float = 0.5, + ) -> HomeAssistantBackend: + client = httpx.AsyncClient(base_url='http://example.test', transport=httpx.MockTransport(self.handle)) + return HomeAssistantBackend( + url='http://example.test', + token='token', + client=client, + verification_poll_attempts=verification_poll_attempts, + verification_poll_interval=verification_poll_interval, + ) + + +class StubBackend: # pragma: no cover - protocol test double + async def list_services(self, domain: str | None = None) -> list[Service]: + return [] + + async def list_entities(self, domain: str | None = None) -> list[Entity]: + return [] + + async def get_state(self, entity_id: str) -> EntityState: + return EntityState(entity_id=entity_id, last_updated='2026-04-19T12:00:00+00:00', state='on') + + async def list_states(self, domain: str | None = None) -> list[EntityState]: + return [] + + async def call_service( + self, + domain: str, + entity_id: str, + service_name: str, + *, + want_response: bool = False, + **data: object, + ) -> ServiceCallResult: + return ServiceCallResult() + + +def _services_payload() -> list[dict[str, object]]: + return [ + { + 'domain': 'light', + 'services': { + 'turn_on': { + 'name': 'Turn on', + 'description': 'Turn the light on.', + 'target': { + 'entity': { + 'domain': 'light', + } + }, + 'fields': { + 'brightness_pct': { + 'name': 'Brightness', + 'description': 'Brightness in percent.', + 'required': False, + 'selector': { + 'number': { + 'min': 0, + 'max': 100, + 'unit_of_measurement': '%', + } + }, + 'filter': { + 'supported_features': ['light.LightEntityFeature.EFFECT'], + 'attribute': { + 'supported_color_modes': [ + 'light.ColorMode.BRIGHTNESS', + 'light.ColorMode.HS', + ] + }, + }, + }, + 'advanced_fields': { + 'collapsed': True, + 'fields': { + 'transition': { + 'name': 'Transition', + 'selector': { + 'number': { + 'min': 0, + 'max': 300, + 'unit_of_measurement': 'seconds', + } + }, + } + }, + }, + }, + 'description_placeholders': {'docs_url': 'https://example.test/docs'}, + }, + 'turn_off': { + 'name': 'Turn off', + 'description': 'Turn the light off.', + 'target': { + 'entity': { + 'domain': 'light', + } + }, + 'fields': {}, + 'response': {'optional': True}, + }, + }, + }, + { + 'domain': 'climate', + 'services': { + 'set_temperature': { + 'name': 'Set temperature', + 'description': 'Set target temperature.', + 'target': { + 'entity': { + 'domain': 'climate', + } + }, + 'fields': { + 'temperature': { + 'required': True, + 'selector': { + 'number': { + 'min': 5, + 'max': 35, + } + }, + } + }, + 'response': {'optional': False}, + } + }, + }, + ] + + +def _states_payload() -> list[dict[str, object]]: + return [ + { + 'entity_id': 'light.bedroom_light', + 'state': 'on', + 'attributes': { + 'friendly_name': 'Bedroom Light', + }, + 'last_changed': '2026-04-19T12:00:00+00:00', + 'last_reported': '2026-04-19T12:00:00+00:00', + 'last_updated': '2026-04-19T12:00:00+00:00', + 'context': { + 'id': 'abc', + 'parent_id': None, + 'user_id': None, + }, + }, + { + 'entity_id': 'switch.coffee_machine', + 'state': 'off', + 'attributes': { + 'friendly_name': 'Coffee Machine', + }, + 'last_changed': '2026-04-19T12:05:00+00:00', + 'last_reported': '2026-04-19T12:05:00+00:00', + 'last_updated': '2026-04-19T12:05:00+00:00', + 'context': { + 'id': 'def', + 'parent_id': None, + 'user_id': None, + }, + }, + ] + + +class TestHomeAutomation: + def test_exposes_home_automation_toolset(self) -> None: + capability = HomeAutomation(backend=StubBackend()) + + toolset = capability.get_toolset() + + assert isinstance(toolset, FunctionToolset) + assert isinstance(toolset, HomeAutomationToolset) + assert toolset.backend is capability.backend + assert set(toolset.tools) == {'list_services', 'list_entities', 'get_state', 'list_states', 'call_service'} + + def test_instructions_describe_entities_services_and_verification(self) -> None: + instructions = HomeAutomation(backend=StubBackend()).get_instructions() + + assert 'smart home entities' in instructions + assert 'list_services' in instructions + assert 'verified_state' in instructions + assert 'follow-up state reads' in instructions + + def test_package_exports_public_capability_only_at_root(self) -> None: + assert pydantic_ai_harness.HomeAutomation is HomeAutomation + assert HomeAssistantBackend is ExportedBackend + with pytest.raises(AttributeError, match='HomeAssistantBackend'): + getattr(pydantic_ai_harness, 'HomeAssistantBackend') + with pytest.raises(AttributeError, match='missing'): + getattr(pydantic_ai_harness, 'missing') + + +class TestHomeAssistantServices: + async def test_list_services_parses_catalog_and_flattens_advanced_fields(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + + services = await server.backend().list_services() + + assert [(service.domain, service.name) for service in services] == [ + ('light', 'turn_on'), + ('light', 'turn_off'), + ('climate', 'set_temperature'), + ] + light_turn_on = services[0] + assert [arg.name for arg in light_turn_on.args] == ['brightness_pct', 'transition'] + assert light_turn_on.args[0].value_type == 'number' + assert light_turn_on.args[0].minimum == 0 + assert light_turn_on.args[0].maximum == 100 + assert light_turn_on.args[1].value_type == 'number' + assert light_turn_on.args[1].maximum == 300 + + async def test_list_services_can_filter_by_domain(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + + light_services = await server.backend().list_services(domain='light') + + assert [(service.domain, service.name) for service in light_services] == [ + ('light', 'turn_on'), + ('light', 'turn_off'), + ] + + async def test_list_services_maps_required_numeric_arguments(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + + services = await server.backend().list_services(domain='climate') + + assert len(services) == 1 + service = services[0] + assert service.domain == 'climate' + assert service.name == 'set_temperature' + assert len(service.args) == 1 + assert service.args[0].name == 'temperature' + assert service.args[0].required is True + assert service.args[0].minimum == 5 + assert service.args[0].maximum == 35 + + async def test_list_services_returns_empty_for_missing_domain(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + + services = await server.backend().list_services(domain='fan') + + assert services == [] + + async def test_service_catalog_can_refresh_cache(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + backend = server.backend() + + await backend._get_service_catalog() + await backend._get_service_catalog(refresh=True) + + assert server.count('GET', '/api/services') == 2 + + async def test_get_service_description_returns_none_for_missing_domain(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + + assert await server.backend()._get_service_description('fan', 'turn_on') is None + + +class TestHomeAssistantStates: + async def test_list_states_can_filter_by_domain(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states', _states_payload()) + + states = await server.backend().list_states(domain='light') + + assert len(states) == 1 + assert states[0].entity_id == 'light.bedroom_light' + assert states[0].state == 'on' + assert states[0].last_updated == '2026-04-19T12:00:00+00:00' + + async def test_list_states_returns_all_states(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states', _states_payload()) + + states = await server.backend().list_states() + + assert [state.entity_id for state in states] == ['light.bedroom_light', 'switch.coffee_machine'] + + async def test_list_entities_can_filter_by_domain(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states', _states_payload()) + + entities = await server.backend().list_entities(domain='switch') + + assert entities == [ + Entity( + entity_id='switch.coffee_machine', + domain='switch', + name='Coffee Machine', + current_state='off', + ) + ] + + async def test_list_entities_returns_all_entities(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states', _states_payload()) + + entities = await server.backend().list_entities() + + assert [entity.entity_id for entity in entities] == ['light.bedroom_light', 'switch.coffee_machine'] + + async def test_get_state_returns_one_entity_state(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states/light.bedroom_light', _states_payload()[0]) + + state = await server.backend().get_state('light.bedroom_light') + + assert state.entity_id == 'light.bedroom_light' + assert state.state == 'on' + assert state.last_updated == '2026-04-19T12:00:00+00:00' + + async def test_get_state_prefers_home_assistant_last_updated(self) -> None: + payload = { + **_states_payload()[0], + 'last_changed': '2026-04-19T12:00:00+00:00', + 'last_updated': '2026-04-19T12:10:00+00:00', + } + server = MockHomeAssistant() + server.json('GET', '/api/states/light.bedroom_light', payload) + + state = await server.backend().get_state('light.bedroom_light') + + assert state.last_updated == '2026-04-19T12:10:00+00:00' + + async def test_get_state_raises_for_missing_entity(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states/light.missing', {'message': 'Entity not found'}, status_code=404) + + with pytest.raises(ValueError, match=r"Entity 'light\.missing' was not found\."): + await server.backend().get_state('light.missing') + + +class TestHomeAssistantClient: + async def test_backend_can_close_owned_client(self) -> None: + backend = HomeAssistantBackend(url='http://example.test', token='token') + + await backend.aclose() + + assert backend.client.is_closed + + async def test_backend_does_not_close_injected_client(self) -> None: + client = httpx.AsyncClient(base_url='http://example.test') + backend = HomeAssistantBackend(url='http://example.test', token='token', client=client) + + await backend.aclose() + + assert not client.is_closed + await client.aclose() + + def test_backend_rejects_negative_verification_poll_attempts(self) -> None: + with pytest.raises(ValueError, match='verification_poll_attempts'): + HomeAssistantBackend( + url='http://example.test', + token='token', + verification_poll_attempts=-1, + ) + + def test_backend_rejects_negative_verification_poll_interval(self) -> None: + with pytest.raises(ValueError, match='verification_poll_interval'): + HomeAssistantBackend( + url='http://example.test', + token='token', + verification_poll_interval=-0.1, + ) + + +class TestHomeAssistantCallService: + async def test_call_service_uses_cached_catalog_and_normalizes_changed_states(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + server.json('GET', '/api/states/light.bedroom_light', _states_payload()[0]) + + def turn_off(request: httpx.Request) -> httpx.Response: + assert 'return_response' not in request.url.params + return httpx.Response(200, json=_states_payload()[:1]) + + server.route('POST', '/api/services/light/turn_off', turn_off) + backend = server.backend() + + await backend.list_services(domain='light') + result = await backend.call_service( + domain='light', + entity_id='light.bedroom_light', + service_name='turn_off', + ) + + assert server.count('GET', '/api/services') == 1 + assert len(result.changed_states) == 1 + assert result.changed_states[0].entity_id == 'light.bedroom_light' + assert result.service_response is None + assert result.verified_state == result.changed_states[0] + + async def test_call_service_forces_return_response_for_required_services(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + server.json('GET', '/api/states/climate.home', {'message': 'Entity not found'}, status_code=404) + + def set_temperature(request: httpx.Request) -> httpx.Response: + assert 'return_response' in request.url.params + return httpx.Response( + 200, + json={ + 'changed_states': _states_payload()[:1], + 'service_response': { + 'weather.home': { + 'forecast': [{'condition': 'sunny'}], + } + }, + }, + ) + + server.route('POST', '/api/services/climate/set_temperature', set_temperature) + + result = await server.backend().call_service( + domain='climate', + entity_id='climate.home', + service_name='set_temperature', + temperature=21, + ) + + assert len(result.changed_states) == 1 + assert result.service_response == { + 'weather.home': { + 'forecast': [{'condition': 'sunny'}], + } + } + assert result.verified_state is None + + async def test_call_service_can_request_optional_service_response(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + server.json('GET', '/api/states/light.bedroom_light', _states_payload()[0]) + + def turn_off(request: httpx.Request) -> httpx.Response: + assert 'return_response' in request.url.params + return httpx.Response( + 200, + json={ + 'changed_states': _states_payload()[:1], + 'service_response': {'light.bedroom_light': {'acknowledged': True}}, + }, + ) + + server.route('POST', '/api/services/light/turn_off', turn_off) + + result = await server.backend().call_service( + domain='light', + entity_id='light.bedroom_light', + service_name='turn_off', + want_response=True, + ) + + assert result.service_response == {'light.bedroom_light': {'acknowledged': True}} + assert result.verified_state == result.changed_states[0] + + async def test_call_service_polls_for_state_when_service_returns_no_verification_payload(self) -> None: + state_requests = 0 + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + server.json('POST', '/api/services/light/turn_off', []) + + def get_state(request: httpx.Request) -> httpx.Response: + nonlocal state_requests + state_requests += 1 + if state_requests < 3: + return httpx.Response( + 200, + json={ + **_states_payload()[0], + 'state': 'off', + 'last_changed': '2026-04-19T12:00:00+00:00', + 'last_updated': '2026-04-19T12:00:00+00:00', + }, + ) + return httpx.Response(200, json=_states_payload()[0]) + + server.route('GET', '/api/states/light.bedroom_light', get_state) + + result = await server.backend(verification_poll_interval=0.0).call_service( + domain='light', + entity_id='light.bedroom_light', + service_name='turn_off', + ) + + assert result.changed_states == [] + assert result.service_response is None + assert result.verified_state is not None + assert result.verified_state.entity_id == 'light.bedroom_light' + assert result.verified_state.state == 'on' + assert state_requests == 3 + + async def test_call_service_ignores_verification_http_errors(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + server.json('POST', '/api/services/light/turn_off', []) + server.json('GET', '/api/states/light.bedroom_light', {'message': 'server error'}, status_code=500) + + result = await server.backend(verification_poll_attempts=1).call_service( + domain='light', + entity_id='light.bedroom_light', + service_name='turn_off', + ) + + assert result.changed_states == [] + assert result.service_response is None + assert result.verified_state is None + + async def test_call_service_raises_for_unknown_service(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + + with pytest.raises(ValueError, match='Service light.missing was not found.'): + await server.backend().call_service( + domain='light', + entity_id='light.bedroom_light', + service_name='missing', + ) + + async def test_call_service_raises_when_home_assistant_reports_missing_service(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/services', _services_payload()) + server.json('GET', '/api/states/light.bedroom_light', _states_payload()[0]) + server.json('POST', '/api/services/light/turn_off', {'message': 'not found'}, status_code=404) + + with pytest.raises(ValueError, match='Service light.turn_off was not found.'): + await server.backend().call_service( + domain='light', + entity_id='light.bedroom_light', + service_name='turn_off', + ) + + def test_build_service_call_params_for_services_without_response_data(self) -> None: + service_description = ServiceDescription.model_validate({}) + + assert HomeAssistantBackend._build_service_call_params(service_description, want_response=False) is None + + async def test_get_verified_state_returns_latest_state_when_state_does_not_change(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states/light.bedroom_light', _states_payload()[0]) + previous_state = EntityState( + entity_id='light.bedroom_light', + last_updated='2026-04-19T12:00:00+00:00', + state='on', + ) + + verified_state = await server.backend(verification_poll_attempts=1)._get_verified_state( + entity_id='light.bedroom_light', + domain='light', + service_name='turn_on', + result=ServiceCallResult(), + previous_state=previous_state, + ) + + assert verified_state == previous_state + + async def test_get_verified_state_returns_none_when_state_cannot_be_verified(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states/light.missing', {}, status_code=404) + + verified_state = await server.backend(verification_poll_attempts=1)._get_verified_state( + entity_id='light.missing', + domain='light', + service_name='turn_on', + result=ServiceCallResult(), + ) + + assert verified_state is None + + async def test_get_verified_state_returns_none_when_polling_is_disabled(self) -> None: + server = MockHomeAssistant() + server.json('GET', '/api/states/light.bedroom_light', _states_payload()[0]) + + verified_state = await server.backend(verification_poll_attempts=0)._get_verified_state( + entity_id='light.bedroom_light', + domain='light', + service_name='turn_on', + result=ServiceCallResult(), + ) + + assert verified_state is None + + def test_state_changed_returns_true_without_previous_state(self) -> None: + assert HomeAssistantBackend._state_changed( + None, + EntityState( + entity_id='light.bedroom_light', + last_updated='2026-04-19T12:00:00+00:00', + state='on', + ), + ) + + +class TestHomeAssistantModels: + @pytest.mark.parametrize( + 'field_name, field, expected_type', + [ + ('select', {'selector': {'select': {'options': ['on', 'off']}}}, 'string'), + ('color_temp', {'selector': {'color_temp': {'min': 2000, 'max': 6500}}}, 'number'), + ('color_rgb', {'selector': {'color_rgb': {}}}, 'object'), + ('boolean', {'selector': {'boolean': {}}}, 'boolean'), + ('constant', {'selector': {'constant': {'value': True, 'label': 'Enabled'}}}, 'boolean'), + ('text', {'selector': {'text': {'multiline': False}}}, 'string'), + ('object', {'selector': {'object': {}}}, 'object'), + ('entity', {'selector': {'entity': {'domain': 'light'}}}, 'entity_id'), + ('device', {'selector': {'device': {'integration': 'test'}}}, 'device_id'), + ('area', {'selector': {'area': {}}}, 'area_id'), + ('fallback', {}, 'string'), + ], + ) + def test_service_field_description_maps_selector_variants( + self, + field_name: str, + field: dict[str, object], + expected_type: str, + ) -> None: + argument = ServiceFieldDescription.model_validate(field).to_service_argument(field_name) + + assert argument.name == field_name + assert argument.value_type == expected_type + + def test_service_field_description_preserves_select_options(self) -> None: + field = ServiceFieldDescription.model_validate({'selector': {'select': {'options': ['on', 'off']}}}) + + assert field.to_service_argument('mode').enum == ('on', 'off') + + def test_service_field_description_defaults_select_options_to_empty(self) -> None: + field = ServiceFieldDescription.model_validate({'selector': {'select': {}}}) + + assert field.to_service_argument('mode').enum == () + + def test_service_field_description_extracts_dict_select_option_values(self) -> None: + field = ServiceFieldDescription.model_validate( + { + 'selector': { + 'select': { + 'options': [ + {'value': 'heat', 'label': 'Heat'}, + {'value': 'cool', 'label': 'Cool'}, + ] + } + } + } + ) + + assert field.to_service_argument('hvac_mode').enum == ('heat', 'cool') + + def test_service_field_description_preserves_color_temp_range(self) -> None: + field = ServiceFieldDescription.model_validate({'selector': {'color_temp': {'min': 2000, 'max': 6500}}}) + + argument = field.to_service_argument('color_temp_kelvin') + + assert argument.minimum == 2000 + assert argument.maximum == 6500 + + def test_service_call_result_defaults_to_empty_changed_states(self) -> None: + result = HAServiceCallResult.model_validate({'service_response': {'light.test': {'acknowledged': True}}}) + + assert result.changed_states == [] + + def test_entity_domain_is_derived_from_entity_id(self) -> None: + ha_state = HAEntityState.model_validate(_states_payload()[0]) + + assert ha_state.domain == 'light'