Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/langchain_google_cloud_sql_pg/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,49 @@ async def __aadd_embeddings(

return ids

async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
    """Get documents by ids.

    Args:
        ids: Ids of the rows to fetch from the vector store table.

    Returns:
        One Document per matching row. Ids without a matching row are
        simply absent from the result. The query has no ORDER BY, so the
        result order is whatever the database returns and is not
        guaranteed to follow the order of ``ids``.
    """
    # `IN ()` is not valid SQL, and an empty request matches nothing anyway.
    if not ids:
        return []

    columns = self.metadata_columns + [
        self.id_column,
        self.content_column,
    ]
    if self.metadata_json_column:
        columns.append(self.metadata_json_column)

    column_names = ", ".join(f'"{col}"' for col in columns)

    # Bind every id as a parameter instead of splicing quoted literals into
    # the statement: ids containing a quote no longer break the query and
    # caller-supplied values cannot inject SQL.
    params = {f"id_{position}": id_val for position, id_val in enumerate(ids)}
    placeholders = ", ".join(f":{name}" for name in params)

    query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE "{self.id_column}" IN ({placeholders});'

    async with self.pool.connect() as conn:
        result = await conn.execute(text(query), params)
        results = result.mappings().fetchall()

    documents = []
    for row in results:
        # Start from the JSON metadata column when it is configured and
        # populated, then overlay the dedicated metadata columns.
        metadata = (
            row[self.metadata_json_column]
            if self.metadata_json_column and row[self.metadata_json_column]
            else {}
        )
        for col in self.metadata_columns:
            metadata[col] = row[col]
        documents.append(
            Document(
                page_content=row[self.content_column],
                metadata=metadata,
                # Review outcome: the langchain_core Document interface
                # declares `id` as a str, so the stored value is cast here.
                # Users can still pick any data type for the id column
                # itself; only the returned Document carries the string.
                id=str(row[self.id_column]),
            )
        )

    return documents

async def aadd_texts(
self,
texts: Iterable[str],
Expand Down Expand Up @@ -772,6 +815,11 @@ async def is_valid_index(

return bool(len(results) == 1)

def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
    """Always raise NotImplementedError; no sync lookup is provided here.

    Raises:
        NotImplementedError: On every call. The message points callers to
            the PostgresVectorStore interface.
    """
    message = "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead."
    raise NotImplementedError(message)

def similarity_search(
self,
query: str,
Expand Down
10 changes: 9 additions & 1 deletion src/langchain_google_cloud_sql_pg/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

from typing import Any, Callable, Iterable, Optional
from typing import Any, Callable, Iterable, Optional, Sequence

import numpy as np
from langchain_core.documents import Document
Expand Down Expand Up @@ -813,3 +813,11 @@ def is_valid_index(
) -> bool:
"""Check if index exists in the table."""
return self._engine._run_as_sync(self.__vs.is_valid_index(index_name))

async def aget_by_ids(self, ids: Sequence[str]) -> list[Document]:
    """Get documents by ids.

    Builds the wrapped store's ``aget_by_ids`` coroutine and awaits it
    through the engine's ``_run_as_async`` helper.
    """
    lookup = self.__vs.aget_by_ids(ids=ids)
    return await self._engine._run_as_async(lookup)

def get_by_ids(self, ids: Sequence[str]) -> list[Document]:
    """Get documents by ids.

    Builds the wrapped store's ``aget_by_ids`` coroutine and hands it to
    the engine's ``_run_as_sync`` helper, returning that helper's result.
    """
    lookup = self.__vs.aget_by_ids(ids=ids)
    return self._engine._run_as_sync(lookup)
18 changes: 18 additions & 0 deletions tests/test_async_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
# Table names get a fresh uuid4 suffix, with dashes swapped for underscores.
DEFAULT_TABLE = f'test_table{str(uuid.uuid4()).replace("-", "_")}'
CUSTOM_TABLE = f'test_table_custom{str(uuid.uuid4()).replace("-", "_")}'
VECTOR_SIZE = 768
# Message the tests expect when a sync method is called on the async store.
sync_method_exception_str = "Sync methods are not implemented for AsyncPostgresVectorStore. Use PostgresVectorStore interface instead."

# Fake embedding service sized to VECTOR_SIZE.
embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)

Expand Down Expand Up @@ -269,3 +270,20 @@ async def test_max_marginal_relevance_search_vector_score(self, vs_custom):
embedding, lambda_mult=0.75, fetch_k=10
)
assert results[0][0] == Document(page_content="bar", id=ids[1])

async def test_aget_by_ids(self, vs):
    """Looking up ids[0] yields the "foo" document as the first result."""
    wanted_id = ids[0]
    fetched = await vs.aget_by_ids(ids=[wanted_id])

    expected = Document(page_content="foo", id=wanted_id)
    assert fetched[0] == expected

async def test_aget_by_ids_custom_vs(self, vs_custom):
    """Same lookup as the default-store test, run against ``vs_custom``."""
    wanted_id = ids[0]
    fetched = await vs_custom.aget_by_ids(ids=[wanted_id])

    expected = Document(page_content="foo", id=wanted_id)
    assert fetched[0] == expected

def test_get_by_ids(self, vs):
    """Calling the sync ``get_by_ids`` raises with the sync-methods message."""
    lookup_ids = [ids[0]]
    expect_failure = pytest.raises(Exception, match=sync_method_exception_str)
    with expect_failure:
        vs.get_by_ids(ids=lookup_ids)
18 changes: 18 additions & 0 deletions tests/test_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ async def test_amax_marginal_relevance_search_vector_score(self, vs):
)
assert results[0][0] == Document(page_content="bar", id=ids[1])

async def test_aget_by_ids(self, vs):
    """Looking up ids[0] yields the "foo" document as the first result."""
    wanted_id = ids[0]
    fetched = await vs.aget_by_ids(ids=[wanted_id])

    expected = Document(page_content="foo", id=wanted_id)
    assert fetched[0] == expected

async def test_aget_by_ids_custom_vs(self, vs_custom):
    """Same lookup as the default-store test, run against ``vs_custom``."""
    wanted_id = ids[0]
    fetched = await vs_custom.aget_by_ids(ids=[wanted_id])

    expected = Document(page_content="foo", id=wanted_id)
    assert fetched[0] == expected


class TestVectorStoreSearchSync:
@pytest.fixture(scope="module")
Expand Down Expand Up @@ -331,3 +343,9 @@ def test_max_marginal_relevance_search_vector_score(self, vs_custom):
embedding, lambda_mult=0.75, fetch_k=10
)
assert results[0][0] == Document(page_content="bar", id=ids[1])

def test_get_by_ids_custom_vs(self, vs_custom):
    """Sync lookup of ids[0] on ``vs_custom`` yields the "foo" document first."""
    wanted_id = ids[0]
    fetched = vs_custom.get_by_ids(ids=[wanted_id])

    expected = Document(page_content="foo", id=wanted_id)
    assert fetched[0] == expected