diff --git a/pydastic/__init__.py b/pydastic/__init__.py index 05c656c..ed483e6 100644 --- a/pydastic/__init__.py +++ b/pydastic/__init__.py @@ -23,16 +23,18 @@ def get_version() -> str: InvalidModelError, NotFoundError, ) -from pydastic.model import ESModel -from pydastic.pydastic import PydasticClient, connect +from pydastic.model import ESAsyncModel, ESModel +from pydastic.pydastic import PydasticClient, connect, connect_async from pydastic.session import Session __all__ = [ "ESModel", + "ESAsyncModel", "Session", "NotFoundError", "InvalidModelError", "InvalidElasticsearchResponse", "PydasticClient", "connect", + "connect_async", ] diff --git a/pydastic/model.py b/pydastic/model.py index ada8b72..6e707e0 100644 --- a/pydastic/model.py +++ b/pydastic/model.py @@ -1,5 +1,7 @@ +from abc import ABC, abstractmethod from copy import copy from datetime import datetime +from functools import partial from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union from elasticsearch import NotFoundError as ElasticNotFoundError @@ -124,20 +126,11 @@ def save(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = Args: index (str, optional): Index name wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False. - """ - doc = self.dict(exclude={"id"}) - - # Allow waiting for shards - useful when testing - refresh = "false" - if wait_for: - refresh = "wait_for" - - # Use user-provided index if provided (dynamic index support) - if not index: - index = self.Meta.index - res = _client.client.index(index=index, body=doc, id=self.id, refresh=refresh) - self.id = res.get("_id") + Returns: + New document ID + """ + return Save(index=index, wait_for=wait_for, model=self).sync() @classmethod def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Optional[str] = None) -> M: @@ -154,26 +147,11 @@ def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Opti Raises: NotFoundError: Returned if document not found """ - source_includes = None - if not extra_fields: - fields: dict = copy(vars(cls).get("__fields__")) - fields.pop("id", None) - source_includes = list(fields.keys()) - - # Use user-provided index if provided (dynamic index support) - if not index: - index = cls.Meta.index - try: - res = _client.client.get(index=index, id=id, _source_includes=source_includes) + return Get(model=cls, id_=id, index=index, extra_fields=extra_fields).sync() except ElasticNotFoundError: raise NotFoundError(f"document with id {id} not found") - model = cls.from_es(res) - model.id = id - - return model - def delete(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False): """Deletes document from elasticsearch. @@ -187,17 +165,170 @@ def delete(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] """ if not self.id: raise ValueError("id missing from object") + try: + Delete(index=index, model=self).sync() + except ElasticNotFoundError: + raise NotFoundError(f"document with id {id} not found") + + +class ESAsyncModel(ESModel): + class Meta: + @property + def index(self) -> str: + """Elasticsearch index name associated with this model class""" + raise NotImplementedError + + async def save(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False): + """Indexes document into elasticsearch. + If document already exists, existing document will be updated as per native elasticsearch index operation. + If model instance includes an 'id' property, this will be used as the elasticsearch _id. + If no 'id' is provided, then document will be indexed and elasticsearch will generate a suitable id that will be populated on the returned model. + + Args: + index (str, optional): Index name + wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False. + + Returns: + New document ID + """ + return await Save(index=index, wait_for=wait_for, model=self).asyncio() + + @classmethod + async def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Optional[str] = None) -> M: + """Fetches document and returns ESModel instance populated with properties. + + Args: + id (str): Document id + extra_fields (bool, Optional): Include fields found in elasticsearch but not part of the model definition + index (str, optional): Index name + + Returns: + ESAsyncModel + + Raises: + NotFoundError: Returned if document not found + """ + get = Get(model=cls, id_=id, extra_fields=extra_fields, index=index) + try: + return await get.asyncio() + except ElasticNotFoundError: + raise NotFoundError(f"document with id {id} not found") - # Allow waiting for shards - useful when testing - refresh = "false" - if wait_for: - refresh = "wait_for" + async def delete(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False): + """Deletes document from elasticsearch. - # Use user-provided index if provided (dynamic index support) - if not index: - index = self.Meta.index + Args: + index (str, optional): Index name + wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False. + Raises: + NotFoundError: Returned if document not found + ValueError: Returned when id attribute missing from instance + """ + if not self.id: + raise ValueError("id missing from object") try: - _client.client.delete(index=index, id=self.id, refresh=refresh) + await Delete(index=index, model=self, wait_for=wait_for).asyncio() except ElasticNotFoundError: raise NotFoundError(f"document with id {id} not found") + + +class BaseESOperation(ABC): + """Abstract class of ES operation: save, get, delete + + Each subclass should set INDEX_METHOD_NAME which is the ES client function to run + initiator (__init__) is used as pre-operation activity. + """ + + INDEX_METHOD_NAME: Optional[str] = None + + def __init__(self, model: Type[M], index: str, wait_for: Optional[bool] = None): + if self.INDEX_METHOD_NAME is None: + raise AttributeError(f"Must set INDEX_METHOD_NAME variable for class {self.__class__.__name__}") + self._es_client_func = getattr(_client.client, self.INDEX_METHOD_NAME) + + self.model = model + self.index = index + if wait_for is not None: + # Allow waiting for shards - useful when testing + self.refresh = "wait_for" if wait_for else "false" + + @abstractmethod + @property + def kwargs(self): + """ES operation kwargs""" + raise NotImplementedError() + + @abstractmethod + def post(self, *args, **kwargs): + """Post activites after ES operation is done""" + # do nothing by default + raise NotImplementedError() + + def sync(self): + """Run in blocking mode""" + es_result = self._es_callable() + return self.post(es_result=es_result) + + async def asyncio(self): + """Run in async mode""" + es_result = await self._es_callable() + return self.post(es_result=es_result) + + @property + def _es_callable(self) -> callable: + return partial(self._es_client_func, **self.kwargs) + + +class Get(BaseESOperation): + INDEX_METHOD_NAME = "get" + + def __init__(self, model: Type[M], id_: str, index: Optional[str] = None, extra_fields: Optional[bool] = False): + super().__init__(index=index, model=model) + self.id = id_ + + self.source_includes = None + if not extra_fields: + fields: dict = copy(vars(model).get("__fields__")) + fields.pop("id", None) + self.source_includes = list(fields.keys()) + + @property + def kwargs(self): + return dict(index=self.index, id=self.id, _source_includes=self.source_includes) + + def post(self, es_result: dict) -> M: + model = self.model.from_es(es_result) + model.id = self.id + return model + + +class Save(BaseESOperation): + INDEX_METHOD_NAME = "index" + + def __init__(self, model: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False): + super().__init__(model=model, index=index, wait_for=wait_for) + self.doc = self.model.dict(exclude={"id"}) + + def post(self, es_result: dict) -> str: + self.model.id = es_result.get("_id") + return self.model.id + + @property + def kwargs(self): + return dict(index=self.index, doc=self.doc, id=self.model.id, refresh=self.refresh) + + +class Delete(BaseESOperation): + + INDEX_METHOD_NAME = "delete" + + def __init__(self, model: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False): + super().__init__(model=model, index=index, wait_for=wait_for) + + @property + def kwargs(self): + return dict(index=self.index, id=self.model.id, refresh=self.refresh) + + def post(self, es_result: dict) -> None: + return None diff --git a/pydastic/pydastic.py b/pydastic/pydastic.py index c4356b4..01e74df 100644 --- a/pydastic/pydastic.py +++ b/pydastic/pydastic.py @@ -1,17 +1,19 @@ import typing as t -from elasticsearch import Elasticsearch +from elasticsearch import AsyncElasticsearch, Elasticsearch + +ElasticClasses = t.Union[Elasticsearch, AsyncElasticsearch] class PydasticClient: - client: Elasticsearch = None + client: t.Optional[ElasticClasses] = None def __getattribute__(self, __name: str) -> t.Any: if __name == "client" and object.__getattribute__(self, __name) is None: - raise AttributeError("client not initialized - make sure to call Pydastic.connect()") + raise AttributeError("client not initialized - make sure to call Pydastic.connect() or Pydastic.connect_async()") return object.__getattribute__(self, __name) - def _set_client(self, client: Elasticsearch): + def _set_client(self, client: ElasticClasses): object.__setattr__(self, "client", client) return client @@ -35,3 +37,8 @@ def wrapped(*args, **kwargs): @copy_signature(Elasticsearch) def connect(*args, **kwargs): ... + + +@copy_signature(AsyncElasticsearch) +def connect_async(*args, **kwargs): + ... diff --git a/tests/car.py b/tests/car.py new file mode 100644 index 0000000..315dd5e --- /dev/null +++ b/tests/car.py @@ -0,0 +1,15 @@ +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from pydastic import ESAsyncModel + + +class Car(ESAsyncModel): + model: str + year: Optional[int] + last_test: datetime = Field(default_factory=datetime.now) + + class Meta: + index = "car" diff --git a/tests/conftest.py b/tests/conftest.py index 6eef357..252aaa8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,16 @@ import pytest -from elasticsearch import Elasticsearch +from car import Car +from elasticsearch import AsyncElasticsearch, Elasticsearch from user import User -from pydastic.pydastic import _client, connect +from pydastic.pydastic import _client, connect, connect_async + + +@pytest.fixture() +async def async_es() -> AsyncElasticsearch: + connect_async(hosts="http://localhost:9200") + await _client.client.delete_by_query(index="_all", body={"query": {"match_all": {}}}, wait_for_completion=True, refresh=True) + return _client.client @pytest.fixture() @@ -17,3 +25,10 @@ def user(es: Elasticsearch) -> User: user = User(name="John", phone="123456") user.save(wait_for=True) return user + + +@pytest.fixture() +def car(aes: AsyncElasticsearch) -> Car: + car = Car(name="Seat", year="2023") + car.save(wait_for=True) + return car diff --git a/tests/test_model_async.py b/tests/test_model_async.py new file mode 100644 index 0000000..d1a9274 --- /dev/null +++ b/tests/test_model_async.py @@ -0,0 +1,253 @@ +from copy import deepcopy +from datetime import datetime +from uuid import uuid4 + +import pytest +from car import Car +from elasticsearch import AsyncElasticsearch + +from pydastic import ESAsyncModel, NotFoundError +from pydastic.error import InvalidElasticsearchResponse + +pytestmark = pytest.mark.asyncio + + +async def test_model_definition_yields_error_without_meta_class(): + with pytest.raises(NotImplementedError): + + class AsyncUser(ESAsyncModel): + pass + + +async def test_model_definition_yields_error_without_index(): + with pytest.raises(NotImplementedError): + + class AsyncUser(ESAsyncModel): + class Meta: + pass + + +@pytest.mark.asyncio +async def test_model_async_save_without_connection_raises_attribute_error(): + with pytest.raises(AttributeError): + car = Car(model="Fiat") + await car.save(wait_for=True) + + +@pytest.mark.asyncio +async def test_model_async_save(async_es: AsyncElasticsearch): + car = Car(model="Fiat") + await car.save(wait_for=True) + assert car.id is not None + + res = await async_es.get(index=car.Meta.index, id=car.id) + assert res["found"] + + # Check that fields match exactly + model = car.to_es() + assert res["_source"] == model + + +@pytest.mark.asyncio +async def test_model_save_with_index(async_es: AsyncElasticsearch): + preset_id = "sam@mail.com" + car = Car(id=preset_id, model="Fiat") + await car.save(wait_for=True) + + res = await async_es.get(index=car.Meta.index, id=preset_id) + assert res["found"] + + model = car.to_es() + assert res["_source"] == model + + +@pytest.mark.asyncio +async def test_model_save_with_dynamic_index(async_es: AsyncElasticsearch): + preset_id = "123456" + car = Car(id=preset_id, model="Fiat") + await car.save(index="custom-car", wait_for=True) + + res = await async_es.get(index="custom-car", id=preset_id) + assert res["found"] + + model = car.to_es() + assert res["_source"] == model + + +@pytest.mark.asyncio +async def test_model_save_datetime_saved_as_isoformat(async_es: AsyncElasticsearch): + date = datetime.now() + iso = date.isoformat() + + car = Car(model="Fiat", year=1994) + await car.save(wait_for=True) + + res = async_es.get(index=car.Meta.index, id=car.id) + assert res["found"] + assert res["_source"]["last_test"] == iso + + +@pytest.mark.asyncio +async def test_model_save_to_update(async_es: AsyncElasticsearch, car: Car): + # Update user details + user_copy = deepcopy(car) + + dummy_name = "xxxxx" + car.name = dummy_name + + await car.save(wait_for=True) + saved_car = await Car.get(id=car.id) + + assert saved_car.model == car.model + + # Change name back to compare with old object + saved_car.name = user_copy.name + assert saved_car == user_copy + + +@pytest.mark.asyncio +async def test_model_save_additional_fields(async_es: AsyncElasticsearch): + extra_fields = {"horse_power": "250", "color": "red"} + res = async_es.index(index=Car.Meta.index, body=extra_fields) + + car = await Car.get(res["_id"], extra_fields=True) + + # Confirm that user has these extra fields + assert car.horse_power == extra_fields["horse_power"] + assert car.color == extra_fields["color"] + + # Check that extra fields dict is exact subset + user_dict = car.dict() + assert dict(user_dict, **extra_fields) == user_dict + + +@pytest.mark.asyncio +async def test_model_ignores_additional_fields(async_es: AsyncElasticsearch): + extra_fields = {"horse_power": "250", "color": "red"} + res = async_es.index(index=Car.Meta.index, body=extra_fields) + + car = await Car.get(res["_id"]) + with pytest.raises(AttributeError): + car.horse_power + + with pytest.raises(AttributeError): + car.color + + +@pytest.mark.asyncio +async def test_model_get_fields_unaffected(async_es: AsyncElasticsearch, car: Car): + """Bug where fields get overwritten when model is fetched and ID is popped out""" + await Car.get(id=car.id) + assert "id" in Car.__fields__ + + +@pytest.mark.asyncio +async def test_model_from_es(async_es: AsyncElasticsearch): + car = Car(model="Fiat") + await car.save(wait_for=True) + + res = async_es.get(index=car.Meta.index, id=car.id) + assert res["found"] + + car_from_es = Car.from_es(res) + assert car == car_from_es + + +@pytest.mark.asyncio +async def test_model_from_es_empty_data(): + car = Car.from_es({}) + assert car is None + + +@pytest.mark.asyncio +async def test_model_from_es_invalid_format(): + res = {"does not": "include _source", "or": "_id"} + + with pytest.raises(InvalidElasticsearchResponse): + Car.from_es(res) + + +@pytest.mark.asyncio +async def test_model_to_es(async_es: AsyncElasticsearch): + car = Car(model="Fiat") + await car.save(wait_for=True) + es_from_car = car.to_es() + + res = async_es.get(index=car.Meta.index, id=car.id) + assert res["_source"] == es_from_car + + +@pytest.mark.asyncio +async def test_model_to_es_with_exclude(async_es: AsyncElasticsearch): + car = Car(model="Fiat") + await car.save(wait_for=True) + es_from_car = car.to_es(exclude={"last_test", "year"}) + + # Check that id excluded and fields excluded + assert es_from_car == {"model": "Fiat"} + + +@pytest.mark.asyncio +async def test_model_get(async_es: AsyncElasticsearch): + car = Car(model="Fiat", year="1976") + await car.save(wait_for=True) + + get = Car.get(id=car.id) + assert get == car + + +@pytest.mark.asyncio +async def test_model_get_with_dynamic_index(async_es: AsyncElasticsearch): + car = Car(model="Fiat", year="1964") + await car.save(index="custom", wait_for=True) + + get = await Car.get(index="custom", id=car.id) + assert get == car + + +@pytest.mark.asyncio +async def test_model_get_nonexistent_raises_error(async_es: AsyncElasticsearch): + with pytest.raises(NotFoundError): + await Car.get(id=str(uuid4())) + + +@pytest.mark.asyncio +async def test_model_delete_raises_error(async_es: AsyncElasticsearch): + car = Car(model="Fiat") + + with pytest.raises(ValueError): + await car.delete(wait_for=True) + + with pytest.raises(NotFoundError): + car.id = "123456" + await car.delete(wait_for=True) + + +@pytest.mark.asyncio +async def test_model_delete(async_es: AsyncElasticsearch): + car = Car(model="Fiat") + await car.save(wait_for=True) + await car.delete(wait_for=True) + + with pytest.raises(NotFoundError): + await Car.get(id=car.id) + + +@pytest.mark.asyncio +async def test_model_delete_with_dynamic_index(async_es: AsyncElasticsearch): + car = Car(model="Fiat") + await car.save(index="abc", wait_for=True) + await car.delete(index="abc", wait_for=True) + + with pytest.raises(NotFoundError): + await Car.get(id=car.id, index="abc") + + +@pytest.mark.asyncio +async def test_internal_meta_class_changes_limited_to_instance(): + # Cannot modify Meta index to have a dynamic index name + car = Car(model="Fiat") + car.Meta.index = "dev-car" + + assert Car.Meta.index == "dev-car" + assert car.Meta.index == "dev-car"