Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/langchain_google_alloydb_pg/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json
import re
import uuid
from typing import Any, Callable, Iterable, List, Optional, Sequence
from typing import Any, Callable, Iterable, Optional, Sequence

import numpy as np
import requests
Expand Down Expand Up @@ -720,7 +720,7 @@ async def asimilarity_search_with_score_by_vector(
Document(
page_content=row[self.content_column],
metadata=metadata,
id=row[self.id_column],
id=str(row[self.id_column]),
),
row["distance"],
)
Expand Down Expand Up @@ -811,7 +811,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
Document(
page_content=row[self.content_column],
metadata=metadata,
id=row[self.id_column],
id=str(row[self.id_column]),
),
row["distance"],
)
Expand Down Expand Up @@ -903,7 +903,7 @@ async def is_valid_index(
results = result_map.fetchall()
return bool(len(results) == 1)

async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]:
async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
"""Get documents by ids."""

quoted_ids = [f"'{id_val}'" for id_val in ids]
Expand All @@ -918,7 +918,7 @@ async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]:

column_names = ", ".join(f'"{col}"' for col in columns)

query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE {self.id_column} IN ({id_list_str});'
query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE "{self.id_column}" IN ({id_list_str});'

async with self.engine.connect() as conn:
result = await conn.execute(text(query))
Expand All @@ -939,14 +939,14 @@ async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]:
Document(
page_content=row[self.content_column],
metadata=metadata,
id=row[self.id_column],
id=str(row[self.id_column]),
)
)
)

return documents

def get_by_ids(self, ids: Sequence[str]) -> List[Document]:
def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead."
)
Expand Down
6 changes: 3 additions & 3 deletions src/langchain_google_alloydb_pg/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

from typing import Any, Callable, Iterable, List, Optional, Sequence
from typing import Any, Callable, Iterable, Optional, Sequence

from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
Expand Down Expand Up @@ -906,11 +906,11 @@ def is_valid_index(
"""Check if index exists in the table."""
return self._engine._run_as_sync(self.__vs.is_valid_index(index_name))

async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]:
async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
"""Get documents by ids."""
return await self._engine._run_as_async(self.__vs.aget_by_ids(ids=ids))

def get_by_ids(self, ids: Sequence[str]) -> List[Document]:
def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
"""Get documents by ids."""
return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids))

Expand Down
6 changes: 6 additions & 0 deletions tests/test_async_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
CUSTOM_TABLE = "test_table_custom" + str(uuid.uuid4()).replace("-", "_")
IMAGE_TABLE = "test_image_table" + str(uuid.uuid4()).replace("-", "_")
VECTOR_SIZE = 768
sync_method_exception_str = "Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead."

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)

Expand Down Expand Up @@ -374,3 +375,8 @@ async def test_aget_by_ids_custom_vs(self, vs_custom):
results = await vs_custom.aget_by_ids(ids=test_ids)

assert results[0] == Document(page_content="foo", id=ids[0])

def test_get_by_ids(self, vs):
test_ids = [ids[0]]
with pytest.raises(Exception, match=sync_method_exception_str):
vs.get_by_ids(ids=test_ids)