Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pydastic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,18 @@ def get_version() -> str:
InvalidModelError,
NotFoundError,
)
from pydastic.model import ESModel
from pydastic.pydastic import PydasticClient, connect
from pydastic.model import ESAsyncModel, ESModel
from pydastic.pydastic import PydasticClient, connect, connect_async
from pydastic.session import Session

# Names exported by `from pydastic import *`; covers both the blocking
# (ESModel, connect) and the async (ESAsyncModel, connect_async) entry points.
__all__ = [
    "ESModel",
    "ESAsyncModel",
    "Session",
    "NotFoundError",
    "InvalidModelError",
    "InvalidElasticsearchResponse",
    "PydasticClient",
    "connect",
    "connect_async",
]
205 changes: 168 additions & 37 deletions pydastic/model.py
Copy link
Copy Markdown
Author

@bshakur8 bshakur8 Mar 27, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is duplicated code. My suggestion is to split it into pre and post functions for each of the save, get and delete functions. Those pre and post functions will be shared between ESModel and ESAsyncModel.
The only difference between them is that the latter will use await when calling the AsyncES client function.

For example:

def save(<args>)
    kwargs = pre_save(<args>)
    result = client.save(**kwargs)
    return post_save(result)
    
    
async def save(<args>)
    kwargs = pre_save(<args>)
    result = await client.save(**kwargs)
    return post_save(result)

@RamiAwar Do you agree with this approach or do you have a better one?

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I totally missed this! Will check.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, here are my thoughts @bshakur8:

  • The operation classes are a bit of overkill - they make the code much harder to understand. A little duplication is better if it makes maintenance and extensibility easier. Maybe functions are a better alternative here: we could put the pre- and post- logic in functions that we'd then simply call in both clients.

  • We should have 2 clients - an async one and a sync one. I don't like the idea of supporting sync and async operations on the same client as it's complicating something that should be simple. Users will either use the sync client, or they will use the async client, never both. Also every other ORM I've seen uses this same idea, so we shouldn't stray from the standard.

  • The tests look amazing, the async logic looks solid. Thanks for that! We'll definitely merge that in as soon as we address the first two points.

What do you think? I'm willing to draft a more detailed suggestion if you want.

Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from copy import copy
from datetime import datetime
from functools import partial
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, Union

from elasticsearch import NotFoundError as ElasticNotFoundError
Expand Down Expand Up @@ -124,20 +126,11 @@ def save(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] =
Args:
index (str, optional): Index name
wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False.
"""
doc = self.dict(exclude={"id"})

# Allow waiting for shards - useful when testing
refresh = "false"
if wait_for:
refresh = "wait_for"

# Use user-provided index if provided (dynamic index support)
if not index:
index = self.Meta.index

res = _client.client.index(index=index, body=doc, id=self.id, refresh=refresh)
self.id = res.get("_id")
Returns:
New document ID
"""
return Save(index=index, wait_for=wait_for, model=self).sync()

@classmethod
def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Optional[str] = None) -> M:
Expand All @@ -154,26 +147,11 @@ def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Opti
Raises:
NotFoundError: Returned if document not found
"""
source_includes = None
if not extra_fields:
fields: dict = copy(vars(cls).get("__fields__"))
fields.pop("id", None)
source_includes = list(fields.keys())

# Use user-provided index if provided (dynamic index support)
if not index:
index = cls.Meta.index

try:
res = _client.client.get(index=index, id=id, _source_includes=source_includes)
return Get(model=cls, id_=id, index=index, extra_fields=extra_fields).sync()
except ElasticNotFoundError:
raise NotFoundError(f"document with id {id} not found")

model = cls.from_es(res)
model.id = id

return model

def delete(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False):
"""Deletes document from elasticsearch.

Expand All @@ -187,17 +165,170 @@ def delete(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool]
"""
if not self.id:
raise ValueError("id missing from object")
try:
Delete(index=index, model=self).sync()
except ElasticNotFoundError:
raise NotFoundError(f"document with id {id} not found")


class ESAsyncModel(ESModel):
    """Asynchronous variant of ESModel.

    ``save``, ``get`` and ``delete`` are coroutines: each one builds the
    matching operation object (``Save`` / ``Get`` / ``Delete``) and awaits its
    ``asyncio()`` runner.
    """

    class Meta:
        @property
        def index(self) -> str:
            """Elasticsearch index name associated with this model class"""
            raise NotImplementedError

    async def save(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False):
        """Indexes document into elasticsearch.

        If document already exists, existing document will be updated as per native elasticsearch index operation.
        If model instance includes an 'id' property, this will be used as the elasticsearch _id.
        If no 'id' is provided, then document will be indexed and elasticsearch will generate a suitable id
        that will be populated on the model instance.

        Args:
            index (str, optional): Index name
            wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False.

        Returns:
            New document ID
        """
        return await Save(index=index, wait_for=wait_for, model=self).asyncio()

    @classmethod
    async def get(cls: Type[M], id: str, extra_fields: Optional[bool] = False, index: Optional[str] = None) -> M:
        """Fetches document and returns a model instance populated with properties.

        Args:
            id (str): Document id
            extra_fields (bool, Optional): Include fields found in elasticsearch but not part of the model definition
            index (str, optional): Index name

        Returns:
            ESAsyncModel

        Raises:
            NotFoundError: Raised if document not found
        """
        # Named `operation` so the local does not shadow this method's own name
        operation = Get(model=cls, id_=id, extra_fields=extra_fields, index=index)
        try:
            return await operation.asyncio()
        except ElasticNotFoundError as err:
            raise NotFoundError(f"document with id {id} not found") from err

    async def delete(self: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False):
        """Deletes document from elasticsearch.

        Args:
            index (str, optional): Index name
            wait_for (bool, optional): Waits for all shards to sync before returning response - useful when writing tests. Defaults to False.

        Raises:
            NotFoundError: Raised if document not found
            ValueError: Raised when id attribute missing from instance
        """
        if not self.id:
            raise ValueError("id missing from object")
        try:
            await Delete(index=index, model=self, wait_for=wait_for).asyncio()
        except ElasticNotFoundError as err:
            # There is no `id` argument in this method: a bare `id` would be the
            # builtin function, so report the instance's own id.
            raise NotFoundError(f"document with id {self.id} not found") from err


class BaseESOperation(ABC):
    """Abstract base for one Elasticsearch operation: save, get or delete.

    Each subclass sets ``INDEX_METHOD_NAME`` to the name of the ES client
    method to run. The initializer (``__init__``) is the pre-operation step,
    ``kwargs`` supplies the arguments for the client call and ``post`` is the
    post-operation step. ``sync`` runs the call in blocking mode, ``asyncio``
    awaits it.
    """

    INDEX_METHOD_NAME: Optional[str] = None

    def __init__(self, model: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = None):
        """Prepare the operation.

        Args:
            model: Model class or model instance the operation works on
            index (str, optional): Index name; when empty, ``model.Meta.index`` is used
            wait_for (bool, optional): When not None, sets ``self.refresh`` to
                "wait_for" (truthy) or "false" (falsy). When None, ``self.refresh`` is not set.

        Raises:
            AttributeError: If the subclass did not set INDEX_METHOD_NAME
        """
        if self.INDEX_METHOD_NAME is None:
            raise AttributeError(f"Must set INDEX_METHOD_NAME variable for class {self.__class__.__name__}")
        self._es_client_func = getattr(_client.client, self.INDEX_METHOD_NAME)

        self.model = model
        # Use user-provided index if provided (dynamic index support),
        # otherwise the index configured on the model
        self.index = index or model.Meta.index
        if wait_for is not None:
            # Allow waiting for shards - useful when testing
            self.refresh = "wait_for" if wait_for else "false"

    # `property` must be the outer decorator: abstractmethod() cannot mark an
    # already-built property object as abstract.
    @property
    @abstractmethod
    def kwargs(self):
        """ES operation kwargs"""
        raise NotImplementedError()

    @abstractmethod
    def post(self, *args, **kwargs):
        """Post activities after ES operation is done"""
        raise NotImplementedError()

    def sync(self):
        """Run in blocking mode"""
        es_result = self._es_callable()
        return self.post(es_result=es_result)

    async def asyncio(self):
        """Run in async mode"""
        es_result = await self._es_callable()
        return self.post(es_result=es_result)

    @property
    def _es_callable(self) -> Callable:
        """ES client method with this operation's kwargs already bound"""
        return partial(self._es_client_func, **self.kwargs)


class Get(BaseESOperation):
    """Fetch one document by id and turn it into a model instance."""

    INDEX_METHOD_NAME = "get"

    def __init__(self, model: Type[M], id_: str, index: Optional[str] = None, extra_fields: Optional[bool] = False):
        super().__init__(index=index, model=model)
        self.id = id_

        if extra_fields:
            # No source filter: take whatever the stored document contains
            self.source_includes = None
        else:
            # Restrict the response to the fields declared on the model;
            # "id" is dropped from that list
            declared: dict = copy(vars(model).get("__fields__"))
            declared.pop("id", None)
            self.source_includes = [*declared]

    @property
    def kwargs(self):
        return {"index": self.index, "id": self.id, "_source_includes": self.source_includes}

    def post(self, es_result: dict) -> M:
        """Build the model from the ES response and stamp the requested id on it."""
        instance = self.model.from_es(es_result)
        instance.id = self.id
        return instance


class Save(BaseESOperation):
    """Index a model instance as an Elasticsearch document."""

    INDEX_METHOD_NAME = "index"

    def __init__(self, model: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False):
        super().__init__(model=model, index=index, wait_for=wait_for)
        # The id is sent as the separate `id` argument, not inside the document
        self.doc = self.model.dict(exclude={"id"})

    def post(self, es_result: dict) -> str:
        """Copy the `_id` from the ES response onto the model and return it."""
        self.model.id = es_result.get("_id")
        return self.model.id

    @property
    def kwargs(self):
        # The document goes in under `body`; `doc` is not an argument of the
        # client's index() call
        return dict(index=self.index, body=self.doc, id=self.model.id, refresh=self.refresh)


class Delete(BaseESOperation):
    """Remove the document whose id is stored on the model instance."""

    INDEX_METHOD_NAME = "delete"

    def __init__(self, model: Type[M], index: Optional[str] = None, wait_for: Optional[bool] = False):
        super().__init__(index=index, wait_for=wait_for, model=model)

    def post(self, es_result: dict) -> None:
        # The ES response is not used after a delete
        return

    @property
    def kwargs(self):
        return {"index": self.index, "id": self.model.id, "refresh": self.refresh}
15 changes: 11 additions & 4 deletions pydastic/pydastic.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import typing as t

from elasticsearch import Elasticsearch
from elasticsearch import AsyncElasticsearch, Elasticsearch

# Type alias covering both client classes imported above, so one annotation
# fits code that may hold either an Elasticsearch or an AsyncElasticsearch.
ElasticClasses = t.Union[Elasticsearch, AsyncElasticsearch]


class PydasticClient:
    """Holder for an Elasticsearch client (sync or async).

    Reading ``client`` while it is still None raises AttributeError instead of
    handing back None.
    """

    # None until _set_client() stores a client
    client: t.Optional[ElasticClasses] = None

    def __getattribute__(self, __name: str) -> t.Any:
        """Regular attribute lookup, except that an unset ``client`` is an error.

        Raises:
            AttributeError: If ``client`` is read while it is still None
        """
        value = object.__getattribute__(self, __name)
        if __name == "client" and value is None:
            raise AttributeError("client not initialized - make sure to call Pydastic.connect() or Pydastic.connect_async()")
        return value

    def _set_client(self, client: ElasticClasses):
        """Store ``client`` on this holder and return it."""
        object.__setattr__(self, "client", client)
        return client

Expand All @@ -35,3 +37,8 @@ def wrapped(*args, **kwargs):
@copy_signature(Elasticsearch)
def connect(*args, **kwargs):
    # Placeholder body: what `connect` actually does comes from the
    # copy_signature decorator, whose definition is not shown in this view.
    # NOTE(review): presumably it builds an Elasticsearch client from the
    # arguments and stores it on the shared PydasticClient - confirm against
    # copy_signature.
    ...


@copy_signature(AsyncElasticsearch)
def connect_async(*args, **kwargs):
    # Placeholder body: what `connect_async` actually does comes from the
    # copy_signature decorator, whose definition is not shown in this view.
    # NOTE(review): presumably it builds an AsyncElasticsearch client from the
    # arguments and stores it on the shared PydasticClient - confirm against
    # copy_signature.
    ...
15 changes: 15 additions & 0 deletions tests/car.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from datetime import datetime
from typing import Optional

from pydantic import Field

from pydastic import ESAsyncModel


# Test model built on ESAsyncModel; its Meta points it at the "car" index.
class Car(ESAsyncModel):
    model: str
    year: Optional[int]
    # When not supplied, filled with datetime.now() at instance creation
    last_test: datetime = Field(default_factory=datetime.now)

    class Meta:
        index = "car"
19 changes: 17 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import pytest
from elasticsearch import Elasticsearch
from car import Car
from elasticsearch import AsyncElasticsearch, Elasticsearch
from user import User

from pydastic.pydastic import _client, connect
from pydastic.pydastic import _client, connect, connect_async


@pytest.fixture()
async def async_es() -> AsyncElasticsearch:
    """Connect the async client, delete every document in all indices, and return the client."""
    connect_async(hosts="http://localhost:9200")
    client = _client.client
    match_everything = {"query": {"match_all": {}}}
    await client.delete_by_query(index="_all", body=match_everything, wait_for_completion=True, refresh=True)
    return client


@pytest.fixture()
Expand All @@ -17,3 +25,10 @@ def user(es: Elasticsearch) -> User:
user = User(name="John", phone="123456")
user.save(wait_for=True)
return user


@pytest.fixture()
async def car(async_es: AsyncElasticsearch) -> Car:
    """Provide a Car that has already been saved through the async client.

    Depends on the ``async_es`` fixture so the async client is connected first.
    """
    # `model` is the required field on Car; there is no `name` field
    car = Car(model="Seat", year=2023)
    # Car.save is a coroutine and has to be awaited for the document to be indexed
    await car.save(wait_for=True)
    return car
Loading