From 521666f8964ed5fb8cf6dad50f451179b080de3c Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Wed, 19 Mar 2025 12:03:53 +0100 Subject: [PATCH 1/4] Initial Prototype to generate Infrahub Schema from Pydantic models --- .../python-sdk/examples/schema_pydantic.py | 39 ++ infrahub_sdk/schema/__init__.py | 6 + infrahub_sdk/schema/main.py | 2 +- infrahub_sdk/schema/pydantic_utils.py | 167 ++++++++ tests/unit/sdk/test_pydantic.py | 367 ++++++++++++++++++ 5 files changed, 580 insertions(+), 1 deletion(-) create mode 100644 docs/docs/python-sdk/examples/schema_pydantic.py create mode 100644 infrahub_sdk/schema/pydantic_utils.py create mode 100644 tests/unit/sdk/test_pydantic.py diff --git a/docs/docs/python-sdk/examples/schema_pydantic.py b/docs/docs/python-sdk/examples/schema_pydantic.py new file mode 100644 index 00000000..8767f7f9 --- /dev/null +++ b/docs/docs/python-sdk/examples/schema_pydantic.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from asyncio import run as aiorun + +from typing import Annotated + +from pydantic import BaseModel, Field +from infrahub_sdk import InfrahubClient +from rich import print as rprint +from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic + + +class Tag(BaseModel): + name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")] + label: str | None = Field(description="The label of the tag") + description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None + + +class Car(BaseModel): + name: str = Field(description="The name of the car") + tags: list[Tag] + owner: Annotated[Person, RelParam(identifier="car__person")] + secondary_owner: Person | None = None + + +class Person(BaseModel): + name: str + cars: Annotated[list[Car] | None, RelParam(identifier="car__person")] = None + + +async def main(): + client = InfrahubClient() + schema = from_pydantic(models=[Person, Car, Tag]) + rprint(schema.to_schema_dict()) + response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True) + rprint(response) + +if __name__ == "__main__": + aiorun(main()) diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index 080d7237..35190836 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -19,6 +19,7 @@ from ..graphql import Mutation from ..queries import SCHEMA_HASH_SYNC_STATUS from .main import ( + AttributeKind, AttributeSchema, AttributeSchemaAPI, BranchSchema, @@ -36,6 +37,7 @@ SchemaRootAPI, TemplateSchemaAPI, ) +from .pydantic_utils import InfrahubAttributeParam, InfrahubRelationshipParam, from_pydantic if TYPE_CHECKING: from ..client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync @@ -45,11 +47,14 @@ __all__ = [ + "AttributeKind", "AttributeSchema", "AttributeSchemaAPI", "BranchSupportType", "GenericSchema", "GenericSchemaAPI", + "InfrahubAttributeParam", + "InfrahubRelationshipParam", "NodeSchema", "NodeSchemaAPI", "ProfileSchemaAPI", @@ -60,6 +65,7 @@ "SchemaRoot", "SchemaRootAPI", "TemplateSchemaAPI", + "from_pydantic", ] diff --git a/infrahub_sdk/schema/main.py b/infrahub_sdk/schema/main.py index af5556b3..2a06cb3f 100644 --- a/infrahub_sdk/schema/main.py +++ b/infrahub_sdk/schema/main.py @@ -338,7 +338,7 @@ class SchemaRoot(BaseModel): node_extensions: list[NodeExtensionSchema] = Field(default_factory=list) def to_schema_dict(self) -> dict[str, Any]: - return self.model_dump(exclude_unset=True, exclude_defaults=True) + return self.model_dump(exclude_defaults=True, mode="json") class SchemaRootAPI(BaseModel): diff --git a/infrahub_sdk/schema/pydantic_utils.py b/infrahub_sdk/schema/pydantic_utils.py new file mode 100644 index 00000000..60cc3b9a --- /dev/null +++ b/infrahub_sdk/schema/pydantic_utils.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass +from types import UnionType +from typing import Any + +from pydantic import BaseModel +from pydantic.fields import FieldInfo, PydanticUndefined + +from infrahub_sdk.schema.main import AttributeSchema, NodeSchema, RelationshipSchema, SchemaRoot + +from .main import AttributeKind, BranchSupportType, SchemaState + +KIND_MAPPING: dict[type, AttributeKind] = { + int: AttributeKind.NUMBER, + float: AttributeKind.NUMBER, + str: AttributeKind.TEXT, + bool: AttributeKind.BOOLEAN, +} + + +@dataclass +class InfrahubAttributeParam: + state: SchemaState = SchemaState.PRESENT + kind: AttributeKind | None = None + label: str | None = None + unique: bool = False + branch: BranchSupportType | None = None + + +@dataclass +class InfrahubRelationshipParam: + identifier: str | None = None + branch: BranchSupportType | None = None + + +@dataclass +class InfrahubFieldInfo: + name: str + types: list[type] + optional: bool + default: Any + + @property + def primary_type(self) -> type: + if len(self.types) == 0: + raise ValueError("No types found") + if self.is_list: + return typing.get_args(self.types[0])[0] + + return self.types[0] + + @property + def is_attribute(self) -> bool: + return self.primary_type in KIND_MAPPING + + @property + def is_relationship(self) -> bool: + return issubclass(self.primary_type, BaseModel) + + @property + def is_list(self) -> bool: + return typing.get_origin(self.types[0]) is list + + def to_dict(self) -> dict: + return { + "name": self.name, + "primary_type": self.primary_type, + "optional": self.optional, + "default": self.default, + "is_attribute": self.is_attribute, + "is_relationship": self.is_relationship, + "is_list": self.is_list, + } + + +def analyze_field(field_name: str, field: FieldInfo) -> InfrahubFieldInfo: + clean_types = [] + if isinstance(field.annotation, UnionType) or ( + hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr] + ): + clean_types = [t for t in field.annotation.__args__ if t is not type(None)] # type: ignore[union-attr] + else: + clean_types.append(field.annotation) + + return InfrahubFieldInfo( + name=field.alias or field_name, + types=clean_types, + optional=not field.is_required(), + default=field.default if field.default is not PydanticUndefined else None, + ) + + +def get_attribute_kind(field: FieldInfo) -> AttributeKind: + if field.annotation in KIND_MAPPING: + return KIND_MAPPING[field.annotation] + + if isinstance(field.annotation, UnionType) or ( + hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr] + ): + valid_types = [t for t in field.annotation.__args__ if t is not type(None)] # type: ignore[union-attr] + if len(valid_types) == 1 and valid_types[0] in KIND_MAPPING: + return KIND_MAPPING[valid_types[0]] + + raise ValueError(f"Unknown field type: {field.annotation}") + + +def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo) -> AttributeSchema: # noqa: ARG001 + field_param = InfrahubAttributeParam() + field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubAttributeParam)] + if len(field_params) == 1: + field_param = field_params[0] + + return AttributeSchema( + name=field_name, + label=field_param.label, + description=field.description, + kind=field_param.kind or get_attribute_kind(field), + optional=not field.is_required(), + unique=field_param.unique, + branch=field_param.branch, + ) + + +def field_to_relationship( + field_name: str, + field_info: InfrahubFieldInfo, + field: FieldInfo, + namespace: str = "Testing", +) -> RelationshipSchema: + field_param = InfrahubRelationshipParam() + field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubRelationshipParam)] + if len(field_params) == 1: + field_param = field_params[0] + + return RelationshipSchema( + name=field_name, + description=field.description, + peer=f"{namespace}{field_info.primary_type.__name__}", + identifier=field_param.identifier, + cardinality="many" if field_info.is_list else "one", + optional=field_info.optional, + branch=field_param.branch, + ) + + +def from_pydantic(models: list[type[BaseModel]], namespace: str = "Testing") -> SchemaRoot: + schema = SchemaRoot(version="1.0") + + for model in models: + node = NodeSchema( + name=model.__name__, + namespace=namespace, + ) + + for field_name, field in model.model_fields.items(): + field_info = analyze_field(field_name, field) + + if field_info.is_attribute: + node.attributes.append(field_to_attribute(field_name, field_info, field)) + elif field_info.is_relationship: + node.relationships.append(field_to_relationship(field_name, field_info, field, namespace)) + + schema.nodes.append(node) + + return schema diff --git a/tests/unit/sdk/test_pydantic.py b/tests/unit/sdk/test_pydantic.py new file mode 100644 index 00000000..e5d18359 --- /dev/null +++ b/tests/unit/sdk/test_pydantic.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +from typing import Annotated, Optional + +import pytest +from pydantic import BaseModel, Field + +from infrahub_sdk.schema.main import AttributeKind, AttributeSchema, RelationshipSchema +from infrahub_sdk.schema.pydantic_utils import ( + InfrahubAttributeParam as AttrParam, +) +from infrahub_sdk.schema.pydantic_utils import ( + analyze_field, + field_to_attribute, + field_to_relationship, + from_pydantic, + get_attribute_kind, +) + + +class MyModel(BaseModel): + name: str + age: int + is_active: bool + opt_age: int | None = None + default_name: str = "some_default" + old_opt_age: Optional[int] = None # noqa: UP007 + + +class Tag(BaseModel): + name: str = Field(default="test_tag", description="The name of the tag") + description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None + label: Annotated[str, AttrParam(unique=True), Field(description="The label of the tag")] + + +class Car(BaseModel): + name: str + tags: list[Tag] + owner: Person + secondary_owner: Person | None = None + + +class Person(BaseModel): + name: str + cars: list[Car] | None = None + + +@pytest.mark.parametrize( + "field_name, expected_kind", + [ + ("name", "Text"), + ("age", "Number"), + ("is_active", "Boolean"), + ("opt_age", "Number"), + ("default_name", "Text"), + ("old_opt_age", "Number"), + ], +) +def test_get_field_kind(field_name, expected_kind): + assert get_attribute_kind(MyModel.model_fields[field_name]) == expected_kind + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + ( + "name", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "name", + "optional": False, + "primary_type": str, + }, + ), + ( + "age", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "age", + "optional": False, + "primary_type": int, + }, + ), + ( + "is_active", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "is_active", + "optional": False, + "primary_type": bool, + }, + ), + ( + "opt_age", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "opt_age", + "optional": True, + "primary_type": int, + }, + ), + ( + "default_name", + MyModel, + { + "default": "some_default", + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "default_name", + "optional": True, + "primary_type": str, + }, + ), + ( + "old_opt_age", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "old_opt_age", + "optional": True, + "primary_type": int, + }, + ), + ( + "description", + Tag, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "description", + "optional": True, + "primary_type": str, + }, + ), + ( + "name", + Tag, + { + "default": "test_tag", + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "name", + "optional": True, + "primary_type": str, + }, + ), + ( + "label", + Tag, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "label", + "optional": False, + "primary_type": str, + }, + ), + ( + "owner", + Car, + { + "default": None, + "is_attribute": False, + "is_list": False, + "is_relationship": True, + "name": "owner", + "optional": False, + "primary_type": Person, + }, + ), + ( + "tags", + Car, + { + "default": None, + "is_attribute": False, + "is_list": True, + "is_relationship": True, + "name": "tags", + "optional": False, + "primary_type": Tag, + }, + ), + ( + "secondary_owner", + Car, + { + "default": None, + "is_attribute": False, + "is_list": False, + "is_relationship": True, + "name": "secondary_owner", + "optional": True, + "primary_type": Person, + }, + ), + ], +) +def test_analyze_field(field_name: str, model: BaseModel, expected: dict): + assert analyze_field(field_name, model.model_fields[field_name]).to_dict() == expected + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + ( + "name", + MyModel, + AttributeSchema( + name="name", + kind=AttributeKind.TEXT, + optional=False, + ), + ), + ( + "age", + MyModel, + AttributeSchema( + name="age", + kind=AttributeKind.NUMBER, + optional=False, + ), + ), + ( + "is_active", + MyModel, + AttributeSchema( + name="is_active", + kind=AttributeKind.BOOLEAN, + optional=False, + ), + ), + ( + "opt_age", + MyModel, + AttributeSchema( + name="opt_age", + kind=AttributeKind.NUMBER, + optional=True, + ), + ), + ( + "default_name", + MyModel, + AttributeSchema( + name="default_name", + kind=AttributeKind.TEXT, + optional=True, + default="some_default", + ), + ), + ( + "old_opt_age", + MyModel, + AttributeSchema( + name="old_opt_age", + kind=AttributeKind.NUMBER, + optional=True, + ), + ), + ( + "description", + Tag, + AttributeSchema( + name="description", + kind=AttributeKind.TEXTAREA, + optional=True, + ), + ), + ( + "name", + Tag, + AttributeSchema( + name="name", + description="The name of the tag", + kind=AttributeKind.TEXT, + optional=True, + ), + ), + ( + "label", + Tag, + AttributeSchema( + name="label", + description="The label of the tag", + kind=AttributeKind.TEXT, + optional=False, + unique=True, + ), + ), + ], +) +def test_field_to_attribute(field_name: str, model: BaseModel, expected: AttributeSchema): + field = model.model_fields[field_name] + field_info = analyze_field(field_name, field) + assert field_to_attribute(field_name, field_info, field) == expected + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + ( + "owner", + Car, + RelationshipSchema( + name="owner", + peer="TestingPerson", + cardinality="one", + optional=False, + ), + ), + ( + "tags", + Car, + RelationshipSchema( + name="tags", + peer="TestingTag", + cardinality="many", + optional=False, + ), + ), + ( + "secondary_owner", + Car, + RelationshipSchema( + name="secondary_owner", + peer="TestingPerson", + cardinality="one", + optional=True, + ), + ), + ], +) +def test_field_to_relationship(field_name: str, model: BaseModel, expected: RelationshipSchema): + field = model.model_fields[field_name] + field_info = analyze_field(field_name, field) + assert field_to_relationship(field_name, field_info, field) == expected + + +def test_related_models(): + schemas = from_pydantic(models=[Person, Car, Tag]) + assert len(schemas.nodes) == 3 From b8221e1cec58311cc5a2545e7658161c5c8f8393 Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Mon, 24 Mar 2025 12:00:19 +0100 Subject: [PATCH 2/4] Add typing support for get | filters | all methods when using Pydantic --- .../{schema_pydantic.py => pydantic_car.py} | 29 +- .../python-sdk/examples/pydantic_infra.py | 113 +++++++ infrahub_sdk/client.py | 234 +++++++++++++- infrahub_sdk/schema/__init__.py | 25 +- infrahub_sdk/schema/pydantic_utils.py | 247 +++++++++++++-- tests/unit/sdk/test_pydantic.py | 285 +++++++++++++++--- 6 files changed, 836 insertions(+), 97 deletions(-) rename docs/docs/python-sdk/examples/{schema_pydantic.py => pydantic_car.py} (53%) create mode 100644 docs/docs/python-sdk/examples/pydantic_infra.py diff --git a/docs/docs/python-sdk/examples/schema_pydantic.py b/docs/docs/python-sdk/examples/pydantic_car.py similarity index 53% rename from docs/docs/python-sdk/examples/schema_pydantic.py rename to docs/docs/python-sdk/examples/pydantic_car.py index 8767f7f9..321c7d31 100644 --- a/docs/docs/python-sdk/examples/schema_pydantic.py +++ b/docs/docs/python-sdk/examples/pydantic_car.py @@ -4,36 +4,47 @@ from typing import Annotated -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from infrahub_sdk import InfrahubClient from rich import print as rprint -from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic +from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericModel -class Tag(BaseModel): +class Tag(NodeModel): + model_config = ConfigDict( + node_schema=NodeSchema(name="Tag", namespace="Test", human_readable_fields=["name__value"]) + ) + name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")] label: str | None = Field(description="The label of the tag") description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None -class Car(BaseModel): +class TestCar(NodeModel): name: str = Field(description="The name of the car") tags: list[Tag] - owner: Annotated[Person, RelParam(identifier="car__person")] - secondary_owner: Person | None = None + owner: Annotated[TestPerson, RelParam(identifier="car__person")] + secondary_owner: TestPerson | None = None -class Person(BaseModel): +class TestPerson(GenericModel): name: str - cars: Annotated[list[Car] | None, RelParam(identifier="car__person")] = None + +class TestCarOwner(NodeModel, TestPerson): + cars: Annotated[list[TestCar] | None, RelParam(identifier="car__person")] = None async def main(): client = InfrahubClient() - schema = from_pydantic(models=[Person, Car, Tag]) + schema = from_pydantic(models=[TestPerson, TestCar, Tag, TestPerson, TestCarOwner]) rprint(schema.to_schema_dict()) response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True) rprint(response) + # Create a Tag + tag = await client.create("TestTag", name="Blue", label="Blue") + await tag.save(allow_upsert=True) + + if __name__ == "__main__": aiorun(main()) diff --git a/docs/docs/python-sdk/examples/pydantic_infra.py b/docs/docs/python-sdk/examples/pydantic_infra.py new file mode 100644 index 00000000..5ff137a4 --- /dev/null +++ b/docs/docs/python-sdk/examples/pydantic_infra.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from asyncio import run as aiorun + +from infrahub_sdk.async_typer import AsyncTyper + +from typing import Annotated + +from pydantic import BaseModel, Field, ConfigDict +from infrahub_sdk import InfrahubClient +from rich import print as rprint +from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericSchema, GenericModel, RelationshipKind + + +app = AsyncTyper() + + +class Site(NodeModel): + model_config = ConfigDict( + node_schema=NodeSchema(name="Site", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"]) + ) + + name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the site") + + +class Vlan(NodeModel): + model_config = ConfigDict( + node_schema=NodeSchema(name="Vlan", namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"]) + ) + + name: str + vlan_id: int + description: str | None = None + + +class Device(NodeModel): + model_config = ConfigDict( + node_schema=NodeSchema(name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"]) + ) + + name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the car") + site: Annotated[Site, RelParam(kind=RelationshipKind.ATTRIBUTE, identifier="device__site")] + interfaces: Annotated[list[Interface], RelParam(kind=RelationshipKind.COMPONENT, identifier="device__interfaces")] = Field(default_factory=list) + + +class Interface(GenericModel): + model_config = ConfigDict( + generic_schema=GenericSchema(name="Interface", namespace="Infra", human_friendly_id=["device__name__value", "name__value"], display_labels=["name__value"]) + ) + + device: Annotated[Device, RelParam(kind=RelationshipKind.PARENT, identifier="device__interfaces")] + name: str + description: str | None = None + +class L2Interface(Interface): + model_config = ConfigDict( + node_schema=NodeSchema(name="L2Interface", namespace="Infra") + ) + + vlans: list[Vlan] = Field(default_factory=list) + +class LoopbackInterface(Interface): + model_config = ConfigDict( + node_schema=NodeSchema(name="LoopbackInterface", namespace="Infra") + ) + + + +@app.command() +async def load_schema(): + client = InfrahubClient() + schema = from_pydantic(models=[Site, Device, Interface, L2Interface, LoopbackInterface, Vlan]) + rprint(schema.to_schema_dict()) + response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True) + rprint(response) + + +@app.command() +async def load_data(): + client = InfrahubClient() + + atl = await client.create("InfraSite", name="ATL") + await atl.save(allow_upsert=True) + cdg = await client.create("InfraSite", name="CDG") + await cdg.save(allow_upsert=True) + + device1 = await client.create("InfraDevice", name="atl1-dev1", site=atl) + await device1.save(allow_upsert=True) + device2 = await client.create("InfraDevice", name="atl1-dev2", site=atl) + await device2.save(allow_upsert=True) + + lo0dev1 = await client.create("InfraLoopbackInterface", name="lo0", device=device1) + await lo0dev1.save(allow_upsert=True) + lo0dev2 = await client.create("InfraLoopbackInterface", name="lo0", device=device2) + await lo0dev2.save(allow_upsert=True) + + for idx in range(1, 3): + interface = await client.create("InfraL2Interface", name=f"Ethernet{idx}", device=device1) + await interface.save(allow_upsert=True) + + +@app.command() +async def query_data(): + client = InfrahubClient() + sites = await client.all(kind=Site) + + breakpoint() + devices = await client.all(kind=Device) + for device in devices: + rprint(device) + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index fffa8164..24727d2c 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -50,7 +50,7 @@ from .protocols_base import CoreNode, CoreNodeSync from .queries import QUERY_USER, get_commit_update_mutation from .query_groups import InfrahubGroupContext, InfrahubGroupContextSync -from .schema import InfrahubSchema, InfrahubSchemaSync, NodeSchemaAPI +from .schema import InfrahubSchema, InfrahubSchemaSync, NodeSchemaAPI, SchemaModel from .store import NodeStore, NodeStoreSync from .task.manager import InfrahubTaskManager, InfrahubTaskManagerSync from .timestamp import Timestamp @@ -63,6 +63,7 @@ from .context import RequestContext +SchemaModelType = TypeVar("SchemaModelType", bound=SchemaModel) SchemaType = TypeVar("SchemaType", bound=CoreNode) SchemaTypeSync = TypeVar("SchemaTypeSync", bound=CoreNodeSync) @@ -402,6 +403,63 @@ async def get( **kwargs: Any, ) -> SchemaType: ... + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[False], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType | None: ... + + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[True], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: bool = ..., + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + @overload async def get( self, @@ -461,7 +519,7 @@ async def get( async def get( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, raise_when_missing: bool = True, at: Timestamp | None = None, branch: str | None = None, @@ -475,7 +533,7 @@ async def get( prefetch_relationships: bool = False, property: bool = False, **kwargs: Any, - ) -> InfrahubNode | SchemaType | None: + ) -> InfrahubNode | SchemaType | SchemaModelType | None: branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) @@ -558,7 +616,7 @@ async def _process_nodes_and_relationships( async def count( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -599,6 +657,25 @@ async def all( order: Order | None = ..., ) -> list[SchemaType]: ... + @overload + async def all( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + ) -> list[SchemaModelType]: ... + @overload async def all( self, @@ -620,7 +697,7 @@ async def all( async def all( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -634,7 +711,7 @@ async def all( property: bool = False, parallel: bool = False, order: Order | None = None, - ) -> list[InfrahubNode] | list[SchemaType]: + ) -> list[InfrahubNode] | list[SchemaType] | list[SchemaModelType]: """Retrieve all nodes of a given kind Args: @@ -693,6 +770,27 @@ async def filters( **kwargs: Any, ) -> list[SchemaType]: ... + @overload + async def filters( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + **kwargs: Any, + ) -> list[SchemaModelType]: ... + @overload async def filters( self, @@ -716,7 +814,7 @@ async def filters( async def filters( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -732,7 +830,7 @@ async def filters( parallel: bool = False, order: Order | None = None, **kwargs: Any, - ) -> list[InfrahubNode] | list[SchemaType]: + ) -> list[InfrahubNode] | list[SchemaType] | list[SchemaModelType]: """Retrieve nodes of a given kind based on provided filters. Args: @@ -756,6 +854,7 @@ async def filters( list[InfrahubNodeSync]: List of Nodes that match the given filters. """ branch = branch or self.default_branch + schema = await self.schema.get(kind=kind, branch=branch) if at: at = Timestamp(at) @@ -845,6 +944,10 @@ async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]: for node in related_nodes: if node.id: self.store.set(node=node) + + if isinstance(kind, type) and issubclass(kind, SchemaModel): + return [kind.from_node(node) for node in nodes] # type: ignore[return-value] + return nodes def clone(self) -> InfrahubClient: @@ -1679,7 +1782,7 @@ def execute_graphql( def count( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1720,6 +1823,25 @@ def all( order: Order | None = ..., ) -> list[SchemaTypeSync]: ... + @overload + def all( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + ) -> list[SchemaModelType]: ... + @overload def all( self, @@ -1741,7 +1863,7 @@ def all( def all( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1755,7 +1877,7 @@ def all( property: bool = False, parallel: bool = False, order: Order | None = None, - ) -> list[InfrahubNodeSync] | list[SchemaTypeSync]: + ) -> list[InfrahubNodeSync] | list[SchemaTypeSync] | list[SchemaModelType]: """Retrieve all nodes of a given kind Args: @@ -1849,6 +1971,27 @@ def filters( **kwargs: Any, ) -> list[SchemaTypeSync]: ... + @overload + def filters( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + **kwargs: Any, + ) -> list[SchemaModelType]: ... + @overload def filters( self, @@ -1872,7 +2015,7 @@ def filters( def filters( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1888,7 +2031,7 @@ def filters( parallel: bool = False, order: Order | None = None, **kwargs: Any, - ) -> list[InfrahubNodeSync] | list[SchemaTypeSync]: + ) -> list[InfrahubNodeSync] | list[SchemaTypeSync] | list[SchemaModelType]: """Retrieve nodes of a given kind based on provided filters. Args: @@ -2002,6 +2145,10 @@ def process_non_batch() -> tuple[list[InfrahubNodeSync], list[InfrahubNodeSync]] for node in related_nodes: if node.id: self.store.set(node=node) + + if isinstance(kind, type) and issubclass(kind, SchemaModel): + return [kind.from_node(node) for node in nodes] # type: ignore[return-value] + return nodes @overload @@ -2061,6 +2208,63 @@ def get( **kwargs: Any, ) -> SchemaTypeSync: ... + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[False], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType | None: ... + + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[True], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: bool = ..., + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + @overload def get( self, @@ -2120,7 +2324,7 @@ def get( def get( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, raise_when_missing: bool = True, at: Timestamp | None = None, branch: str | None = None, @@ -2134,7 +2338,7 @@ def get( prefetch_relationships: bool = False, property: bool = False, **kwargs: Any, - ) -> InfrahubNodeSync | SchemaTypeSync | None: + ) -> InfrahubNodeSync | SchemaTypeSync | SchemaModelType | None: branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index 35190836..41ffcaa9 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -37,10 +37,17 @@ SchemaRootAPI, TemplateSchemaAPI, ) -from .pydantic_utils import InfrahubAttributeParam, InfrahubRelationshipParam, from_pydantic +from .pydantic_utils import ( + GenericModel, + InfrahubAttributeParam, + InfrahubRelationshipParam, + NodeModel, + SchemaModel, + from_pydantic, +) if TYPE_CHECKING: - from ..client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync + from ..client import InfrahubClient, InfrahubClientSync, SchemaModelType, SchemaType, SchemaTypeSync from ..node import InfrahubNode, InfrahubNodeSync InfrahubNodeTypes = Union[InfrahubNode, InfrahubNodeSync] @@ -51,10 +58,12 @@ "AttributeSchema", "AttributeSchemaAPI", "BranchSupportType", + "GenericModel", "GenericSchema", "GenericSchemaAPI", "InfrahubAttributeParam", "InfrahubRelationshipParam", + "NodeModel", "NodeSchema", "NodeSchemaAPI", "ProfileSchemaAPI", @@ -62,6 +71,7 @@ "RelationshipKind", "RelationshipSchema", "RelationshipSchemaAPI", + "SchemaModel", "SchemaRoot", "SchemaRootAPI", "TemplateSchemaAPI", @@ -162,14 +172,17 @@ def _validate_load_schema_response(response: httpx.Response) -> SchemaLoadRespon raise InvalidResponseError(message=f"Invalid response received from server HTTP {response.status_code}") @staticmethod - def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str: + def _get_schema_name(schema: type[SchemaType | SchemaTypeSync | SchemaModelType] | str) -> str: if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr] return schema.__name__ # type: ignore[union-attr] + if isinstance(schema, type) and issubclass(schema, SchemaModel): + return schema.get_kind() + if isinstance(schema, str): return schema - raise ValueError("schema must be a protocol or a string") + raise ValueError("schema must be a protocol, a SchemaModel, or a string") class InfrahubSchema(InfrahubSchemaBase): @@ -179,7 +192,7 @@ def __init__(self, client: InfrahubClient): async def get( self, - kind: type[SchemaType | SchemaTypeSync] | str, + kind: type[SchemaType | SchemaTypeSync | SchemaModelType] | str, branch: str | None = None, refresh: bool = False, timeout: int | None = None, @@ -486,7 +499,7 @@ def all( def get( self, - kind: type[SchemaType | SchemaTypeSync] | str, + kind: type[SchemaType | SchemaTypeSync | SchemaModelType] | str, branch: str | None = None, refresh: bool = False, timeout: int | None = None, diff --git a/infrahub_sdk/schema/pydantic_utils.py b/infrahub_sdk/schema/pydantic_utils.py index 60cc3b9a..0f605c57 100644 --- a/infrahub_sdk/schema/pydantic_utils.py +++ b/infrahub_sdk/schema/pydantic_utils.py @@ -1,16 +1,29 @@ from __future__ import annotations +import re import typing from dataclasses import dataclass from types import UnionType -from typing import Any +from typing import TYPE_CHECKING, Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from pydantic.fields import FieldInfo, PydanticUndefined - -from infrahub_sdk.schema.main import AttributeSchema, NodeSchema, RelationshipSchema, SchemaRoot - -from .main import AttributeKind, BranchSupportType, SchemaState +from typing_extensions import Self + +from .main import ( + AttributeKind, + AttributeSchema, + BranchSupportType, + GenericSchema, + NodeSchema, + RelationshipKind, + RelationshipSchema, + SchemaRoot, + SchemaState, +) + +if TYPE_CHECKING: + from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync KIND_MAPPING: dict[type, AttributeKind] = { int: AttributeKind.NUMBER, @@ -19,6 +32,42 @@ bool: AttributeKind.BOOLEAN, } +NAMESPACE_REGEX = r"^[A-Z][a-z0-9]+$" +NODE_KIND_REGEX = r"^[A-Z][a-zA-Z0-9]+$" + + +class SchemaModel(BaseModel): + id: str | None = Field(default=None, description="The ID of the node") + + @classmethod + def get_kind(cls) -> str: + return get_kind(cls) + + @classmethod + def from_node(cls, node: InfrahubNode | InfrahubNodeSync) -> Self: + data = {} + for field_name, field in cls.model_fields.items(): + field_info = analyze_field(field_name, field) + if field_name == "id": + data[field_name] = node.id + elif field_info.is_attribute: + attr = getattr(node, field_name) + data[field_name] = attr.value + + # elif field_info.is_relationship: + # rel = getattr(node, field_name) + # data[field_name] = rel.value + + return cls(**data) + + +class NodeModel(SchemaModel): + pass + + +class GenericModel(SchemaModel): + pass + @dataclass class InfrahubAttributeParam: @@ -31,6 +80,7 @@ class InfrahubAttributeParam: @dataclass class InfrahubRelationshipParam: + kind: RelationshipKind | None = None identifier: str | None = None branch: BranchSupportType | None = None @@ -46,6 +96,10 @@ class InfrahubFieldInfo: def primary_type(self) -> type: if len(self.types) == 0: raise ValueError("No types found") + + # if isinstance(self.primary_type, ForwardRef): + # raise TypeError("Forward References are not supported yet, please ensure the models are defined in the right order") + if self.is_list: return typing.get_args(self.types[0])[0] @@ -61,6 +115,7 @@ def is_relationship(self) -> bool: @property def is_list(self) -> bool: + # breakpoint() return typing.get_origin(self.types[0]) is list def to_dict(self) -> dict: @@ -106,12 +161,16 @@ def get_attribute_kind(field: FieldInfo) -> AttributeKind: raise ValueError(f"Unknown field type: {field.annotation}") -def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo) -> AttributeSchema: # noqa: ARG001 +def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo) -> AttributeSchema: field_param = InfrahubAttributeParam() field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubAttributeParam)] if len(field_params) == 1: field_param = field_params[0] + pattern = field._attributes_set.get("pattern", None) + max_length = field._attributes_set.get("max_length", None) + min_length = field._attributes_set.get("min_length", None) + return AttributeSchema( name=field_name, label=field_param.label, @@ -120,6 +179,10 @@ def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: Fi optional=not field.is_required(), unique=field_param.unique, branch=field_param.branch, + default_value=field_info.default, + regex=str(pattern) if pattern else None, + max_length=int(str(max_length)) if max_length else None, + min_length=int(str(min_length)) if min_length else None, ) @@ -127,7 +190,6 @@ def field_to_relationship( field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo, - namespace: str = "Testing", ) -> RelationshipSchema: field_param = InfrahubRelationshipParam() field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubRelationshipParam)] @@ -137,7 +199,7 @@ def field_to_relationship( return RelationshipSchema( name=field_name, description=field.description, - peer=f"{namespace}{field_info.primary_type.__name__}", + peer=get_kind(field_info.primary_type), identifier=field_param.identifier, cardinality="many" if field_info.is_list else "one", optional=field_info.optional, @@ -145,23 +207,164 @@ def field_to_relationship( ) -def from_pydantic(models: list[type[BaseModel]], namespace: str = "Testing") -> SchemaRoot: - schema = SchemaRoot(version="1.0") +def extract_validate_generic(model: type[BaseModel]) -> list[str]: + return [get_kind(ancestor) for ancestor in model.__bases__ if issubclass(ancestor, GenericModel)] - for model in models: - node = NodeSchema( - name=model.__name__, - namespace=namespace, + +def validate_kind(kind: str) -> tuple[str, str]: + # First, handle transition from a lowercase to uppercase + name_with_spaces = re.sub(r"([a-z])([A-Z])", r"\1 \2", kind) + + # Then, handle consecutive uppercase letters followed by a lowercase + # (e.g., "HTTPRequest" -> "HTTP Request") + name_with_spaces = re.sub(r"([A-Z])([A-Z][a-z])", r"\1 \2", name_with_spaces) + + name_parts = name_with_spaces.split(" ") + + if len(name_parts) == 1: + raise ValueError(f"Invalid kind: {kind}, must contain a Namespace and a Name") + kind_namespace = name_parts[0] + kind_name = "".join(name_parts[1:]) + + if not kind_namespace[0].isupper(): + raise ValueError(f"Invalid namespace: {kind_namespace}, must start with an uppercase letter") + + return kind_namespace, kind_name + + +def is_generic(model: type[BaseModel]) -> bool: + return GenericModel in model.__bases__ + + +def get_kind(model: type[BaseModel]) -> str: + node_schema: NodeSchema | None = model.model_config.get("node_schema") or None # type: ignore[assignment] + generic_schema: GenericSchema | None = model.model_config.get("generic_schema") or None # type: ignore[assignment] + + if is_generic(model) and generic_schema: + return generic_schema.kind + if node_schema: + return node_schema.kind + namespace, name = validate_kind(model.__name__) + return f"{namespace}{name}" + + +def get_generics(model: type[BaseModel]) -> list[type[GenericModel]]: + return [ancestor for ancestor in model.__bases__ if issubclass(ancestor, GenericModel)] + + +def _add_fields( + node: NodeSchema | GenericSchema, model: type[BaseModel], inherited_fields: dict[str, dict[str, Any]] | None = None +) -> None: + for field_name, field in model.model_fields.items(): + if ( + inherited_fields + and field_name in inherited_fields + and field._attributes_set == inherited_fields[field_name] + ): + continue + + if field_name == "id": + continue + + field_info = analyze_field(field_name, field) + + if field_info.is_attribute: + node.attributes.append(field_to_attribute(field_name, field_info, field)) + elif field_info.is_relationship: + node.relationships.append(field_to_relationship(field_name, field_info, field)) + + +def model_to_node(model: type[BaseModel]) -> NodeSchema | GenericSchema: + # ------------------------------------------------------------ + # GenericSchema + # ------------------------------------------------------------ + if GenericModel in model.__bases__: + generic_schema: GenericSchema | None = model.model_config.get("generic_schema") or None # type: ignore[assignment] + + if not generic_schema: + namespace, name = validate_kind(model.__name__) + + generic = GenericSchema( + name=generic_schema.name if generic_schema else name, + namespace=generic_schema.namespace if generic_schema else namespace, + display_labels=generic_schema.display_labels if generic_schema else None, + description=generic_schema.description if generic_schema else None, + state=generic_schema.state if generic_schema else SchemaState.PRESENT, + label=generic_schema.label if generic_schema else None, + include_in_menu=generic_schema.include_in_menu if generic_schema else None, + menu_placement=generic_schema.menu_placement if generic_schema else None, + documentation=generic_schema.documentation if generic_schema else None, + order_by=generic_schema.order_by if generic_schema else None, + # parent=schema.parent if schema else None, + # children=schema.children if schema else None, + icon=generic_schema.icon if generic_schema else None, + # generate_profile=schema.generate_profile if schema else None, + # branch=schema.branch if schema else None, + # default_filter=schema.default_filter if schema else None, ) + _add_fields(node=generic, model=model) + return generic + + # ------------------------------------------------------------ + # NodeSchema + # ------------------------------------------------------------ + node_schema: NodeSchema | None = model.model_config.get("node_schema") or None # type: ignore[assignment] + + if not node_schema: + namespace, name = validate_kind(model.__name__) + + generics = get_generics(model) + + # list all inherited fields with a hash for each to track if they are identical on the node + inherited_fields = { + field_name: field._attributes_set for generic in generics for field_name, field in generic.model_fields.items() + } + + node = NodeSchema( + name=node_schema.name if node_schema else name, + namespace=node_schema.namespace if node_schema else namespace, + display_labels=node_schema.display_labels if node_schema else None, + description=node_schema.description if node_schema else None, + state=node_schema.state if node_schema else SchemaState.PRESENT, + label=node_schema.label if node_schema else None, + include_in_menu=node_schema.include_in_menu if node_schema else None, + menu_placement=node_schema.menu_placement if node_schema else None, + documentation=node_schema.documentation if node_schema else None, + order_by=node_schema.order_by if node_schema else None, + inherit_from=[get_kind(generic) for generic in generics], + parent=node_schema.parent if node_schema else None, + children=node_schema.children if node_schema else None, + icon=node_schema.icon if node_schema else None, + generate_profile=node_schema.generate_profile if node_schema else None, + branch=node_schema.branch if node_schema else None, + # default_filter=schema.default_filter if schema else None, + ) + + _add_fields(node=node, model=model, inherited_fields=inherited_fields) + return node - for field_name, field in model.model_fields.items(): - field_info = analyze_field(field_name, field) - if field_info.is_attribute: - node.attributes.append(field_to_attribute(field_name, field_info, field)) - elif field_info.is_relationship: - node.relationships.append(field_to_relationship(field_name, field_info, field, namespace)) +def from_pydantic(models: list[type[BaseModel]]) -> SchemaRoot: + schema = SchemaRoot(version="1.0") + + for model in models: + node = model_to_node(model=model) - schema.nodes.append(node) + if isinstance(node, NodeSchema): + schema.nodes.append(node) + elif isinstance(node, GenericSchema): + schema.generics.append(node) return schema + + +# class NodeSchema(BaseModel): +# name: str| None = None +# namespace: str| None = None +# display_labels: list[str] | None = None + +# class NodeMetaclass(ModelMetaclass): +# model_config: NodeConfig +# # model_schema: NodeSchema +# __config__: type[NodeConfig] +# # __schema__: NodeSchema diff --git a/tests/unit/sdk/test_pydantic.py b/tests/unit/sdk/test_pydantic.py index e5d18359..df6fdd74 100644 --- a/tests/unit/sdk/test_pydantic.py +++ b/tests/unit/sdk/test_pydantic.py @@ -3,22 +3,33 @@ from typing import Annotated, Optional import pytest -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field -from infrahub_sdk.schema.main import AttributeKind, AttributeSchema, RelationshipSchema -from infrahub_sdk.schema.pydantic_utils import ( - InfrahubAttributeParam as AttrParam, +from infrahub_sdk.schema.main import ( + AttributeKind, + AttributeSchema, + GenericSchema, + NodeSchema, + RelationshipSchema, + SchemaState, ) from infrahub_sdk.schema.pydantic_utils import ( + GenericModel, + NodeModel, analyze_field, field_to_attribute, field_to_relationship, from_pydantic, get_attribute_kind, + get_kind, + model_to_node, +) +from infrahub_sdk.schema.pydantic_utils import ( + InfrahubAttributeParam as AttrParam, ) -class MyModel(BaseModel): +class MyAllInOneModel(BaseModel): name: str age: int is_active: bool @@ -27,22 +38,48 @@ class MyModel(BaseModel): old_opt_age: Optional[int] = None # noqa: UP007 -class Tag(BaseModel): +class AcmeTag(BaseModel): name: str = Field(default="test_tag", description="The name of the tag") description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None label: Annotated[str, AttrParam(unique=True), Field(description="The label of the tag")] -class Car(BaseModel): +class AcmeCar(BaseModel): name: str - tags: list[Tag] - owner: Person - secondary_owner: Person | None = None + tags: list[AcmeTag] + owner: AcmePerson + secondary_owner: AcmePerson | None = None -class Person(BaseModel): +class AcmePerson(BaseModel): name: str - cars: list[Car] | None = None + cars: list[AcmeCar] | None = None + + +# -------------------------------- + + +class Book(NodeModel): + model_config = ConfigDict(node_schema=NodeSchema(name="Book", namespace="Library", display_labels=["name__value"])) + title: str + isbn: Annotated[str, AttrParam(unique=True)] + created_at: str + author: LibraryAuthor + + +class AbstractPerson(GenericModel): + model_config = ConfigDict(generic_schema=GenericSchema(name="AbstractPerson", namespace="Library")) + firstname: str = Field(..., description="The first name of the person", pattern=r"^[a-zA-Z]+$") + lastname: str + + +class LibraryAuthor(AbstractPerson): + books: list[Book] + + +class LibraryReader(AbstractPerson): + favorite_books: list[Book] + favorite_authors: list[LibraryAuthor] @pytest.mark.parametrize( @@ -57,7 +94,7 @@ class Person(BaseModel): ], ) def test_get_field_kind(field_name, expected_kind): - assert get_attribute_kind(MyModel.model_fields[field_name]) == expected_kind + assert get_attribute_kind(MyAllInOneModel.model_fields[field_name]) == expected_kind @pytest.mark.parametrize( @@ -65,7 +102,7 @@ def test_get_field_kind(field_name, expected_kind): [ ( "name", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -78,7 +115,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "age", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -91,7 +128,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "is_active", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -104,7 +141,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "opt_age", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -117,7 +154,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "default_name", - MyModel, + MyAllInOneModel, { "default": "some_default", "is_attribute": True, @@ -130,7 +167,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "old_opt_age", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -143,7 +180,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "description", - Tag, + AcmeTag, { "default": None, "is_attribute": True, @@ -156,7 +193,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "name", - Tag, + AcmeTag, { "default": "test_tag", "is_attribute": True, @@ -169,7 +206,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "label", - Tag, + AcmeTag, { "default": None, "is_attribute": True, @@ -182,7 +219,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "owner", - Car, + AcmeCar, { "default": None, "is_attribute": False, @@ -190,12 +227,12 @@ def test_get_field_kind(field_name, expected_kind): "is_relationship": True, "name": "owner", "optional": False, - "primary_type": Person, + "primary_type": AcmePerson, }, ), ( "tags", - Car, + AcmeCar, { "default": None, "is_attribute": False, @@ -203,12 +240,12 @@ def test_get_field_kind(field_name, expected_kind): "is_relationship": True, "name": "tags", "optional": False, - "primary_type": Tag, + "primary_type": AcmeTag, }, ), ( "secondary_owner", - Car, + AcmeCar, { "default": None, "is_attribute": False, @@ -216,7 +253,7 @@ def test_get_field_kind(field_name, expected_kind): "is_relationship": True, "name": "secondary_owner", "optional": True, - "primary_type": Person, + "primary_type": AcmePerson, }, ), ], @@ -230,7 +267,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): [ ( "name", - MyModel, + MyAllInOneModel, AttributeSchema( name="name", kind=AttributeKind.TEXT, @@ -239,7 +276,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "age", - MyModel, + MyAllInOneModel, AttributeSchema( name="age", kind=AttributeKind.NUMBER, @@ -248,7 +285,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "is_active", - MyModel, + MyAllInOneModel, AttributeSchema( name="is_active", kind=AttributeKind.BOOLEAN, @@ -257,7 +294,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "opt_age", - MyModel, + MyAllInOneModel, AttributeSchema( name="opt_age", kind=AttributeKind.NUMBER, @@ -266,17 +303,17 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "default_name", - MyModel, + MyAllInOneModel, AttributeSchema( name="default_name", kind=AttributeKind.TEXT, optional=True, - default="some_default", + default_value="some_default", ), ), ( "old_opt_age", - MyModel, + MyAllInOneModel, AttributeSchema( name="old_opt_age", kind=AttributeKind.NUMBER, @@ -285,7 +322,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "description", - Tag, + AcmeTag, AttributeSchema( name="description", kind=AttributeKind.TEXTAREA, @@ -294,17 +331,18 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "name", - Tag, + AcmeTag, AttributeSchema( name="name", description="The name of the tag", kind=AttributeKind.TEXT, optional=True, + default_value="test_tag", ), ), ( "label", - Tag, + AcmeTag, AttributeSchema( name="label", description="The label of the tag", @@ -313,6 +351,17 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): unique=True, ), ), + ( + "firstname", + AbstractPerson, + AttributeSchema( + name="firstname", + description="The first name of the person", + kind=AttributeKind.TEXT, + optional=False, + regex=r"^[a-zA-Z]+$", + ), + ), ], ) def test_field_to_attribute(field_name: str, model: BaseModel, expected: AttributeSchema): @@ -326,30 +375,30 @@ def test_field_to_attribute(field_name: str, model: BaseModel, expected: Attribu [ ( "owner", - Car, + AcmeCar, RelationshipSchema( name="owner", - peer="TestingPerson", + peer="AcmePerson", cardinality="one", optional=False, ), ), ( "tags", - Car, + AcmeCar, RelationshipSchema( name="tags", - peer="TestingTag", + peer="AcmeTag", cardinality="many", optional=False, ), ), ( "secondary_owner", - Car, + AcmeCar, RelationshipSchema( name="secondary_owner", - peer="TestingPerson", + peer="AcmePerson", cardinality="one", optional=True, ), @@ -362,6 +411,152 @@ def test_field_to_relationship(field_name: str, model: BaseModel, expected: Rela assert field_to_relationship(field_name, field_info, field) == expected +@pytest.mark.parametrize( + "model, expected", + [ + (MyAllInOneModel, "MyAllInOneModel"), + (Book, "LibraryBook"), + (LibraryAuthor, "LibraryAuthor"), + (LibraryReader, "LibraryReader"), + (AbstractPerson, "LibraryAbstractPerson"), + (AcmeTag, "AcmeTag"), + (AcmeCar, "AcmeCar"), + (AcmePerson, "AcmePerson"), + ], +) +def test_get_kind(model: BaseModel, expected: str): + assert get_kind(model) == expected + + +@pytest.mark.parametrize( + "model, expected", + [ + ( + MyAllInOneModel, + NodeSchema( + name="AllInOneModel", + namespace="My", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema(name="name", kind=AttributeKind.TEXT, optional=False), + AttributeSchema(name="age", kind=AttributeKind.NUMBER, optional=False), + AttributeSchema(name="is_active", kind=AttributeKind.BOOLEAN, optional=False), + AttributeSchema(name="opt_age", kind=AttributeKind.NUMBER, optional=True), + AttributeSchema( + name="default_name", kind=AttributeKind.TEXT, optional=True, default_value="some_default" + ), + AttributeSchema(name="old_opt_age", kind=AttributeKind.NUMBER, optional=True), + ], + ), + ), + ( + Book, + NodeSchema( + name="Book", + namespace="Library", + display_labels=["name__value"], + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema(name="title", kind=AttributeKind.TEXT, optional=False), + AttributeSchema(name="isbn", kind=AttributeKind.TEXT, optional=False, unique=True), + AttributeSchema(name="created_at", kind=AttributeKind.TEXT, optional=False), + ], + relationships=[ + RelationshipSchema( + name="author", + peer="LibraryAuthor", + cardinality="one", + optional=False, + relationships=[ + RelationshipSchema(name="books", peer="LibraryBook", cardinality="many", optional=False), + ], + ), + ], + ), + ), + ( + LibraryAuthor, + NodeSchema( + name="Author", + namespace="Library", + inherit_from=["LibraryAbstractPerson"], + state=SchemaState.PRESENT, + relationships=[ + RelationshipSchema(name="books", peer="LibraryBook", cardinality="many", optional=False), + ], + ), + ), + ( + LibraryReader, + NodeSchema( + name="Reader", + namespace="Library", + inherit_from=["LibraryAbstractPerson"], + state=SchemaState.PRESENT, + relationships=[ + RelationshipSchema(name="favorite_books", peer="LibraryBook", cardinality="many", optional=False), + RelationshipSchema( + name="favorite_authors", peer="LibraryAuthor", cardinality="many", optional=False + ), + ], + ), + ), + ( + AbstractPerson, + GenericSchema( + name="AbstractPerson", + namespace="Library", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema( + name="firstname", + kind=AttributeKind.TEXT, + optional=False, + description="The first name of the person", + regex=r"^[a-zA-Z]+$", + ), + AttributeSchema(name="lastname", kind=AttributeKind.TEXT, optional=False), + ], + ), + ), + ( + AcmeTag, + NodeSchema( + name="Tag", + namespace="Acme", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema( + name="name", + kind=AttributeKind.TEXT, + default_value="test_tag", + optional=True, + description="The name of the tag", + ), + AttributeSchema(name="description", kind=AttributeKind.TEXTAREA, optional=True), + AttributeSchema( + name="label", + kind=AttributeKind.TEXT, + optional=False, + unique=True, + description="The label of the tag", + ), + ], + ), + ), + ], +) +def test_model_to_node(model: BaseModel, expected: NodeSchema): + node = model_to_node(model) + assert node == expected + + def test_related_models(): - schemas = from_pydantic(models=[Person, Car, Tag]) + schemas = from_pydantic(models=[AcmePerson, AcmeCar, AcmeTag]) + assert len(schemas.nodes) == 3 + + +def test_library_models(): + schemas = from_pydantic(models=[Book, AbstractPerson, LibraryAuthor, LibraryReader]) assert len(schemas.nodes) == 3 + assert len(schemas.generics) == 1 From 178adfb97f80d6ff9a15f6e43c65d4dd9550f02c Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Mon, 24 Mar 2025 12:04:30 +0100 Subject: [PATCH 3/4] Format examples for Pydantic --- docs/docs/python-sdk/examples/pydantic_car.py | 25 +++++-- .../python-sdk/examples/pydantic_infra.py | 71 ++++++++++++------- 2 files changed, 65 insertions(+), 31 deletions(-) diff --git a/docs/docs/python-sdk/examples/pydantic_car.py b/docs/docs/python-sdk/examples/pydantic_car.py index 321c7d31..11102d19 100644 --- a/docs/docs/python-sdk/examples/pydantic_car.py +++ b/docs/docs/python-sdk/examples/pydantic_car.py @@ -1,20 +1,32 @@ from __future__ import annotations from asyncio import run as aiorun - from typing import Annotated -from pydantic import BaseModel, Field, ConfigDict -from infrahub_sdk import InfrahubClient +from pydantic import ConfigDict, Field from rich import print as rprint -from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericModel + +from infrahub_sdk import InfrahubClient +from infrahub_sdk.schema import ( + AttributeKind, + GenericModel, + NodeModel, + NodeSchema, + from_pydantic, +) +from infrahub_sdk.schema import ( + InfrahubAttributeParam as AttrParam, +) +from infrahub_sdk.schema import ( + InfrahubRelationshipParam as RelParam, +) class Tag(NodeModel): model_config = ConfigDict( node_schema=NodeSchema(name="Tag", namespace="Test", human_readable_fields=["name__value"]) ) - + name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")] label: str | None = Field(description="The label of the tag") description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None @@ -30,11 +42,12 @@ class TestCar(NodeModel): class TestPerson(GenericModel): name: str + class TestCarOwner(NodeModel, TestPerson): cars: Annotated[list[TestCar] | None, RelParam(identifier="car__person")] = None -async def main(): +async def main() -> None: client = InfrahubClient() schema = from_pydantic(models=[TestPerson, TestCar, Tag, TestPerson, TestCarOwner]) rprint(schema.to_schema_dict()) diff --git a/docs/docs/python-sdk/examples/pydantic_infra.py b/docs/docs/python-sdk/examples/pydantic_infra.py index 5ff137a4..a7182c1b 100644 --- a/docs/docs/python-sdk/examples/pydantic_infra.py +++ b/docs/docs/python-sdk/examples/pydantic_infra.py @@ -1,23 +1,35 @@ from __future__ import annotations -from asyncio import run as aiorun - -from infrahub_sdk.async_typer import AsyncTyper - from typing import Annotated -from pydantic import BaseModel, Field, ConfigDict -from infrahub_sdk import InfrahubClient +from pydantic import ConfigDict, Field from rich import print as rprint -from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericSchema, GenericModel, RelationshipKind +from infrahub_sdk import InfrahubClient +from infrahub_sdk.async_typer import AsyncTyper +from infrahub_sdk.schema import ( + GenericModel, + GenericSchema, + NodeModel, + NodeSchema, + RelationshipKind, + from_pydantic, +) +from infrahub_sdk.schema import ( + InfrahubAttributeParam as AttrParam, +) +from infrahub_sdk.schema import ( + InfrahubRelationshipParam as RelParam, +) app = AsyncTyper() class Site(NodeModel): model_config = ConfigDict( - node_schema=NodeSchema(name="Site", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"]) + node_schema=NodeSchema( + name="Site", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] + ) ) name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the site") @@ -25,7 +37,9 @@ class Site(NodeModel): class Vlan(NodeModel): model_config = ConfigDict( - node_schema=NodeSchema(name="Vlan", namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"]) + node_schema=NodeSchema( + name="Vlan", namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"] + ) ) name: str @@ -35,39 +49,45 @@ class Vlan(NodeModel): class Device(NodeModel): model_config = ConfigDict( - node_schema=NodeSchema(name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"]) + node_schema=NodeSchema( + name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] + ) ) name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the car") site: Annotated[Site, RelParam(kind=RelationshipKind.ATTRIBUTE, identifier="device__site")] - interfaces: Annotated[list[Interface], RelParam(kind=RelationshipKind.COMPONENT, identifier="device__interfaces")] = Field(default_factory=list) + interfaces: Annotated[ + list[Interface], RelParam(kind=RelationshipKind.COMPONENT, identifier="device__interfaces") + ] = Field(default_factory=list) class Interface(GenericModel): model_config = ConfigDict( - generic_schema=GenericSchema(name="Interface", namespace="Infra", human_friendly_id=["device__name__value", "name__value"], display_labels=["name__value"]) + generic_schema=GenericSchema( + name="Interface", + namespace="Infra", + human_friendly_id=["device__name__value", "name__value"], + display_labels=["name__value"], + ) ) device: Annotated[Device, RelParam(kind=RelationshipKind.PARENT, identifier="device__interfaces")] name: str description: str | None = None + class L2Interface(Interface): - model_config = ConfigDict( - node_schema=NodeSchema(name="L2Interface", namespace="Infra") - ) - + model_config = ConfigDict(node_schema=NodeSchema(name="L2Interface", namespace="Infra")) + vlans: list[Vlan] = Field(default_factory=list) + class LoopbackInterface(Interface): - model_config = ConfigDict( - node_schema=NodeSchema(name="LoopbackInterface", namespace="Infra") - ) - + model_config = ConfigDict(node_schema=NodeSchema(name="LoopbackInterface", namespace="Infra")) @app.command() -async def load_schema(): +async def load_schema() -> None: client = InfrahubClient() schema = from_pydantic(models=[Site, Device, Interface, L2Interface, LoopbackInterface, Vlan]) rprint(schema.to_schema_dict()) @@ -76,7 +96,7 @@ async def load_schema(): @app.command() -async def load_data(): +async def load_data() -> None: client = InfrahubClient() atl = await client.create("InfraSite", name="ATL") @@ -100,14 +120,15 @@ async def load_data(): @app.command() -async def query_data(): +async def query_data() -> None: client = InfrahubClient() sites = await client.all(kind=Site) + rprint(sites) - breakpoint() devices = await client.all(kind=Device) for device in devices: rprint(device) + if __name__ == "__main__": - app() \ No newline at end of file + app() From 8a8a76852be77c5a5d25a83dd805f1f4540065be Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Sun, 30 Mar 2025 18:11:27 +0200 Subject: [PATCH 4/4] Fix conflict --- tests/unit/sdk/test_pydantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sdk/test_pydantic.py b/tests/unit/sdk/test_pydantic.py index df6fdd74..bd221a7a 100644 --- a/tests/unit/sdk/test_pydantic.py +++ b/tests/unit/sdk/test_pydantic.py @@ -35,7 +35,7 @@ class MyAllInOneModel(BaseModel): is_active: bool opt_age: int | None = None default_name: str = "some_default" - old_opt_age: Optional[int] = None # noqa: UP007 + old_opt_age: Optional[int] = None class AcmeTag(BaseModel):