diff --git a/src/langchain_google_alloydb_pg/async_vectorstore.py b/src/langchain_google_alloydb_pg/async_vectorstore.py index 28e62ba3..e7b98cbc 100644 --- a/src/langchain_google_alloydb_pg/async_vectorstore.py +++ b/src/langchain_google_alloydb_pg/async_vectorstore.py @@ -19,7 +19,7 @@ import json import re import uuid -from typing import Any, Callable, Iterable, List, Optional, Sequence +from typing import Any, Callable, Iterable, Optional, Sequence import numpy as np import requests @@ -720,7 +720,7 @@ async def asimilarity_search_with_score_by_vector( Document( page_content=row[self.content_column], metadata=metadata, - id=row[self.id_column], + id=str(row[self.id_column]), ), row["distance"], ) @@ -811,7 +811,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( Document( page_content=row[self.content_column], metadata=metadata, - id=row[self.id_column], + id=str(row[self.id_column]), ), row["distance"], ) @@ -903,7 +903,7 @@ async def is_valid_index( results = result_map.fetchall() return bool(len(results) == 1) - async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]: + async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]: """Get documents by ids.""" quoted_ids = [f"'{id_val}'" for id_val in ids] @@ -918,7 +918,7 @@ async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]: column_names = ", ".join(f'"{col}"' for col in columns) - query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE {self.id_column} IN ({id_list_str});' + query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE "{self.id_column}" IN ({id_list_str});' async with self.engine.connect() as conn: result = await conn.execute(text(query)) @@ -939,14 +939,14 @@ async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]: Document( page_content=row[self.content_column], metadata=metadata, - id=row[self.id_column], + id=str(row[self.id_column]), ) ) ) return documents - def get_by_ids(self, ids: Sequence[str]) -> List[Document]: + def get_by_ids(self, ids: Sequence[str]) -> list[Document]: raise NotImplementedError( "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." ) diff --git a/src/langchain_google_alloydb_pg/vectorstore.py b/src/langchain_google_alloydb_pg/vectorstore.py index 51e34dc3..f9422fab 100644 --- a/src/langchain_google_alloydb_pg/vectorstore.py +++ b/src/langchain_google_alloydb_pg/vectorstore.py @@ -15,7 +15,7 @@ # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations -from typing import Any, Callable, Iterable, List, Optional, Sequence +from typing import Any, Callable, Iterable, Optional, Sequence from langchain_core.documents import Document from langchain_core.embeddings import Embeddings @@ -906,11 +906,11 @@ def is_valid_index( """Check if index exists in the table.""" return self._engine._run_as_sync(self.__vs.is_valid_index(index_name)) - async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]: + async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]: """Get documents by ids.""" return await self._engine._run_as_async(self.__vs.aget_by_ids(ids=ids)) - def get_by_ids(self, ids: Sequence[str]) -> List[Document]: + def get_by_ids(self, ids: Sequence[str]) -> list[Document]: """Get documents by ids.""" return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids)) diff --git a/tests/test_async_vectorstore_search.py b/tests/test_async_vectorstore_search.py index ce0e0561..f43f7fc2 100644 --- a/tests/test_async_vectorstore_search.py +++ b/tests/test_async_vectorstore_search.py @@ -34,6 +34,7 @@ CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_") IMAGE_TABLE = "test_image_table" + str(uuid.uuid4()).replace("-", "_") VECTOR_SIZE = 768 +sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead." embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) @@ -374,3 +375,8 @@ async def test_aget_by_ids_custom_vs(self, vs_custom): results = await vs_custom.aget_by_ids(ids=test_ids) assert results[0] == Document(page_content="foo", id=ids[0]) + + def test_get_by_ids(self, vs): + test_ids = [ids[0]] + with pytest.raises(Exception, match=sync_method_exception_str): + vs.get_by_ids(ids=test_ids)