Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/langchain_google_cloud_sql_pg/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import json
import uuid
from typing import Any, Callable, Iterable, Optional, Sequence
from typing import Any, Callable, Iterable, List, Optional, Sequence

import numpy as np
from langchain_core.documents import Document
Expand Down Expand Up @@ -291,6 +291,54 @@ async def __aadd_embeddings(

return ids

async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use lower case "list" to follow guidance of https://peps.python.org/pep-0585/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resolved

"""Get documents by ids."""

quoted_ids = [f"'{id_val}'" for id_val in ids]
id_list_str = ", ".join(quoted_ids)

columns = self.metadata_columns + [
self.id_column,
self.content_column,
]
if self.metadata_json_column:
columns.append(self.metadata_json_column)

column_names = ", ".join(f'"{col}"' for col in columns)

query = f'SELECT {column_names} FROM "{self.schema_name}"."{self.table_name}" WHERE {self.id_column} IN ({id_list_str});'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

id_column should be in quotes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved


async with self.pool.connect() as conn:
result = await conn.execute(text(query))
result_map = result.mappings()
results = result_map.fetchall()

documents = []
for row in results:
metadata = (
row[self.metadata_json_column]
if self.metadata_json_column and row[self.metadata_json_column]
else {}
)
for col in self.metadata_columns:
metadata[col] = row[col]
documents.append(
(
Document(
page_content=row[self.content_column],
metadata=metadata,
id=row[self.id_column],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am wondering if we should cast this to a string just in case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we have that option, but wouldn't that negate the value of having customizable ID columns?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Document interface is enforcing the str Id in the document object https://github.com/langchain-ai/langchain/blob/33354f984fba660e71ca0039cfbd3cf37643cfab/libs/core/langchain_core/documents/base.py#L34. Users will still be able to insert and use the id data type of their choice

)
)
)

return documents

def get_by_ids(self, ids: Sequence[str]) -> List[Document]:
raise NotImplementedError(
"Sync methods are not implemented for AsyncAlloyDBVectorStore. Use AlloyDBVectorStore interface instead."
)

async def aadd_texts(
self,
texts: Iterable[str],
Expand Down
10 changes: 9 additions & 1 deletion src/langchain_google_cloud_sql_pg/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

from typing import Any, Callable, Iterable, Optional
from typing import Any, Callable, Iterable, List, Optional, Sequence

import numpy as np
from langchain_core.documents import Document
Expand Down Expand Up @@ -813,3 +813,11 @@ def is_valid_index(
) -> bool:
"""Check if index exists in the table."""
return self._engine._run_as_sync(self.__vs.is_valid_index(index_name))

async def aget_by_ids(self, ids: Sequence[str]) -> List[Document]:
"""Get documents by ids."""
return await self._engine._run_as_async(self.__vs.aget_by_ids(ids=ids))

def get_by_ids(self, ids: Sequence[str]) -> List[Document]:
"""Get documents by ids."""
return self._engine._run_as_sync(self.__vs.aget_by_ids(ids=ids))
12 changes: 12 additions & 0 deletions tests/test_async_vectorstore_search.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for get_by_ids ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, added a test for the sync function in the async vectorstore

Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,15 @@ async def test_max_marginal_relevance_search_vector_score(self, vs_custom):
embedding, lambda_mult=0.75, fetch_k=10
)
assert results[0][0] == Document(page_content="bar", id=ids[1])

async def test_aget_by_ids(self, vs):
test_ids = [ids[0]]
results = await vs.aget_by_ids(ids=test_ids)

assert results[0] == Document(page_content="foo", id=ids[0])

async def test_aget_by_ids_custom_vs(self, vs_custom):
test_ids = [ids[0]]
results = await vs_custom.aget_by_ids(ids=test_ids)

assert results[0] == Document(page_content="foo", id=ids[0])
18 changes: 18 additions & 0 deletions tests/test_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ async def test_amax_marginal_relevance_search_vector_score(self, vs):
)
assert results[0][0] == Document(page_content="bar", id=ids[1])

async def test_aget_by_ids(self, vs):
test_ids = [ids[0]]
results = await vs.aget_by_ids(ids=test_ids)

assert results[0] == Document(page_content="foo", id=ids[0])

async def test_aget_by_ids_custom_vs(self, vs_custom):
test_ids = [ids[0]]
results = await vs_custom.aget_by_ids(ids=test_ids)

assert results[0] == Document(page_content="foo", id=ids[0])


class TestVectorStoreSearchSync:
@pytest.fixture(scope="module")
Expand Down Expand Up @@ -331,3 +343,9 @@ def test_max_marginal_relevance_search_vector_score(self, vs_custom):
embedding, lambda_mult=0.75, fetch_k=10
)
assert results[0][0] == Document(page_content="bar", id=ids[1])

def test_get_by_ids_custom_vs(self, vs_custom):
test_ids = [ids[0]]
results = vs_custom.get_by_ids(ids=test_ids)

assert results[0] == Document(page_content="foo", id=ids[0])