Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ test = [
"mypy==1.13.0",
"pytest-asyncio==0.24.0",
"pytest==8.3.4",
"pytest-cov==6.0.0"
"pytest-cov==6.0.0",
"langchain-tests==0.3.12"
]

[build-system]
Expand Down
18 changes: 17 additions & 1 deletion src/langchain_google_cloud_sql_pg/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,8 @@ async def __aadd_embeddings(
"""
if not ids:
ids = [str(uuid.uuid4()) for _ in texts]
else:
ids = [id if id is not None else str(uuid.uuid4()) for id in ids]
if not metadatas:
metadatas = [{} for _ in texts]
# Insert embeddings
Expand Down Expand Up @@ -271,7 +273,17 @@ async def __aadd_embeddings(
else:
values_stmt += ")"

query = insert_stmt + values_stmt
upsert_stmt = f' ON CONFLICT ("{self.id_column}") DO UPDATE SET "{self.content_column}" = EXCLUDED."{self.content_column}", "{self.embedding_column}" = EXCLUDED."{self.embedding_column}"'

if self.metadata_json_column:
upsert_stmt += f', "{self.metadata_json_column}" = EXCLUDED."{self.metadata_json_column}"'

for column in self.metadata_columns:
upsert_stmt += f', "{column}" = EXCLUDED."{column}"'

upsert_stmt += ";"

query = insert_stmt + values_stmt + upsert_stmt
async with self.pool.connect() as conn:
await conn.execute(text(query), values)
await conn.commit()
Expand Down Expand Up @@ -309,6 +321,8 @@ async def aadd_documents(
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
if not ids:
ids = [doc.id for doc in documents]
ids = await self.aadd_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
return ids

Expand Down Expand Up @@ -593,6 +607,7 @@ async def asimilarity_search_with_score_by_vector(
Document(
page_content=row[self.content_column],
metadata=metadata,
id=row[self.id_column],
),
row["distance"],
)
Expand Down Expand Up @@ -683,6 +698,7 @@ async def amax_marginal_relevance_search_with_score_by_vector(
Document(
page_content=row[self.content_column],
metadata=metadata,
id=row[self.id_column],
),
row["distance"],
)
Expand Down
45 changes: 22 additions & 23 deletions tests/test_async_vectorstore_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ async def vs(self, engine):
embedding_service=embeddings_service,
table_name=DEFAULT_TABLE,
)
ids = [str(uuid.uuid4()) for i in range(len(texts))]
await vs.aadd_documents(docs, ids=ids)
yield vs

Expand Down Expand Up @@ -132,23 +131,23 @@ async def vs_custom(self, engine):
async def test_asimilarity_search(self, vs):
results = await vs.asimilarity_search("foo", k=1)
assert len(results) == 1
assert results == [Document(page_content="foo")]
assert results == [Document(page_content="foo", id=ids[0])]
results = await vs.asimilarity_search("foo", k=1, filter="content = 'bar'")
assert results == [Document(page_content="bar")]
assert results == [Document(page_content="bar", id=ids[1])]

async def test_asimilarity_search_score(self, vs):
results = await vs.asimilarity_search_with_score("foo")
assert len(results) == 4
assert results[0][0] == Document(page_content="foo")
assert results[0][0] == Document(page_content="foo", id=ids[0])
assert results[0][1] == 0

async def test_asimilarity_search_by_vector(self, vs):
embedding = embeddings_service.embed_query("foo")
results = await vs.asimilarity_search_by_vector(embedding)
assert len(results) == 4
assert results[0] == Document(page_content="foo")
assert results[0] == Document(page_content="foo", id=ids[0])
results = await vs.asimilarity_search_with_score_by_vector(embedding)
assert results[0][0] == Document(page_content="foo")
assert results[0][0] == Document(page_content="foo", id=ids[0])
assert results[0][1] == 0

async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs):
Expand All @@ -171,7 +170,7 @@ async def test_similarity_search_with_relevance_scores_threshold_cosine(self, vs
"foo", **score_threshold
)
assert len(results) == 1
assert results[0][0] == Document(page_content="foo")
assert results[0][0] == Document(page_content="foo", id=ids[0])

score_threshold = {"score_threshold": 0.02}
vs.distance_strategy = DistanceStrategy.EUCLIDEAN
Expand All @@ -195,78 +194,78 @@ async def test_similarity_search_with_relevance_scores_threshold_euclidean(
"foo", **score_threshold
)
assert len(results) == 1
assert results[0][0] == Document(page_content="foo")
assert results[0][0] == Document(page_content="foo", id=ids[0])

async def test_amax_marginal_relevance_search(self, vs):
results = await vs.amax_marginal_relevance_search("bar")
assert results[0] == Document(page_content="bar")
assert results[0] == Document(page_content="bar", id=ids[1])
results = await vs.amax_marginal_relevance_search(
"bar", filter="content = 'boo'"
)
assert results[0] == Document(page_content="boo")
assert results[0] == Document(page_content="boo", id=ids[3])

async def test_amax_marginal_relevance_search_vector(self, vs):
embedding = embeddings_service.embed_query("bar")
results = await vs.amax_marginal_relevance_search_by_vector(embedding)
assert results[0] == Document(page_content="bar")
assert results[0] == Document(page_content="bar", id=ids[1])

async def test_amax_marginal_relevance_search_vector_score(self, vs):
embedding = embeddings_service.embed_query("bar")
results = await vs.amax_marginal_relevance_search_with_score_by_vector(
embedding
)
assert results[0][0] == Document(page_content="bar")
assert results[0][0] == Document(page_content="bar", id=ids[1])

results = await vs.amax_marginal_relevance_search_with_score_by_vector(
embedding, lambda_mult=0.75, fetch_k=10
)
assert results[0][0] == Document(page_content="bar")
assert results[0][0] == Document(page_content="bar", id=ids[1])

async def test_similarity_search(self, vs_custom):
results = await vs_custom.asimilarity_search("foo", k=1)
assert len(results) == 1
assert results == [Document(page_content="foo")]
assert results == [Document(page_content="foo", id=ids[0])]
results = await vs_custom.asimilarity_search(
"foo", k=1, filter="mycontent = 'bar'"
)
assert results == [Document(page_content="bar")]
assert results == [Document(page_content="bar", id=ids[1])]

async def test_similarity_search_score(self, vs_custom):
results = await vs_custom.asimilarity_search_with_score("foo")
assert len(results) == 4
assert results[0][0] == Document(page_content="foo")
assert results[0][0] == Document(page_content="foo", id=ids[0])
assert results[0][1] == 0

async def test_similarity_search_by_vector(self, vs_custom):
embedding = embeddings_service.embed_query("foo")
results = await vs_custom.asimilarity_search_by_vector(embedding)
assert len(results) == 4
assert results[0] == Document(page_content="foo")
assert results[0] == Document(page_content="foo", id=ids[0])
results = await vs_custom.asimilarity_search_with_score_by_vector(embedding)
assert results[0][0] == Document(page_content="foo")
assert results[0][0] == Document(page_content="foo", id=ids[0])
assert results[0][1] == 0

async def test_max_marginal_relevance_search(self, vs_custom):
results = await vs_custom.amax_marginal_relevance_search("bar")
assert results[0] == Document(page_content="bar")
assert results[0] == Document(page_content="bar", id=ids[1])
results = await vs_custom.amax_marginal_relevance_search(
"bar", filter="mycontent = 'boo'"
)
assert results[0] == Document(page_content="boo")
assert results[0] == Document(page_content="boo", id=ids[3])

async def test_max_marginal_relevance_search_vector(self, vs_custom):
embedding = embeddings_service.embed_query("bar")
results = await vs_custom.amax_marginal_relevance_search_by_vector(embedding)
assert results[0] == Document(page_content="bar")
assert results[0] == Document(page_content="bar", id=ids[1])

async def test_max_marginal_relevance_search_vector_score(self, vs_custom):
embedding = embeddings_service.embed_query("bar")
results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector(
embedding
)
assert results[0][0] == Document(page_content="bar")
assert results[0][0] == Document(page_content="bar", id=ids[1])

results = await vs_custom.amax_marginal_relevance_search_with_score_by_vector(
embedding, lambda_mult=0.75, fetch_k=10
)
assert results[0][0] == Document(page_content="bar")
assert results[0][0] == Document(page_content="bar", id=ids[1])
159 changes: 159 additions & 0 deletions tests/test_standard_test_suite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import uuid

import pytest
import pytest_asyncio
from langchain_tests.integration_tests import VectorStoreIntegrationTests
from langchain_tests.integration_tests.vectorstores import EMBEDDING_SIZE
from sqlalchemy import text

from langchain_google_cloud_sql_pg import Column, PostgresEngine, PostgresVectorStore

# Each table name is a fixed prefix followed by a freshly generated UUID4.
# NOTE(review): prefix + 36-char UUID gives names of 66 and 71 characters,
# which looks longer than PostgreSQL's default 63-byte identifier limit --
# confirm that server-side truncation of these names is acceptable.
DEFAULT_TABLE = f"test_table_standard_test_suite{uuid.uuid4()}"
DEFAULT_TABLE_SYNC = f"test_table_sync_standard_test_suite{uuid.uuid4()}"


def get_env_var(key: str, desc: str) -> str:
    """Return the value of the environment variable ``key``.

    Args:
        key: Name of the environment variable to read.
        desc: Description of the expected value; it is included in the
            error message when the variable is missing.

    Returns:
        The value stored in ``os.environ`` under ``key``.

    Raises:
        ValueError: If ``key`` is not present in the environment.
    """
    value = os.environ.get(key)
    if value is not None:
        return value
    raise ValueError(f"Must set env var {key} to: {desc}")


async def aexecute(
    engine: PostgresEngine,
    query: str,
) -> None:
    """Execute a raw SQL statement through ``engine`` and commit it.

    Args:
        engine: Engine whose ``_pool`` supplies the connection and whose
            ``_run_as_async`` is awaited with the coroutine built here.
        query: SQL text to execute; it is wrapped with ``text()`` first.
    """

    async def _execute_and_commit(target, statement):
        # Execute and commit on a single connection taken from the pool.
        async with target._pool.connect() as connection:
            await connection.execute(text(statement))
            await connection.commit()

    await engine._run_as_async(_execute_and_commit(engine, query))


@pytest.mark.filterwarnings("ignore")
@pytest.mark.asyncio
class TestStandardSuiteSync(VectorStoreIntegrationTests):
    """Run ``VectorStoreIntegrationTests`` against a store built with
    ``PostgresVectorStore.create_sync``.

    Connection settings are read from environment variables by the
    module-scoped fixtures below.
    """

    @pytest.fixture(scope="module")
    def db_project(self) -> str:
        """Value of the ``PROJECT_ID`` environment variable."""
        return get_env_var("PROJECT_ID", "project id for google cloud")

    @pytest.fixture(scope="module")
    def db_region(self) -> str:
        """Value of the ``REGION`` environment variable."""
        return get_env_var("REGION", "region for Cloud SQL instance")

    @pytest.fixture(scope="module")
    def db_instance(self) -> str:
        """Value of the ``INSTANCE_ID`` environment variable."""
        return get_env_var("INSTANCE_ID", "instance for Cloud SQL")

    @pytest.fixture(scope="module")
    def db_name(self) -> str:
        """Value of the ``DATABASE_ID`` environment variable."""
        return get_env_var("DATABASE_ID", "database name on Cloud SQL instance")

    @pytest.fixture(scope="module")
    def user(self) -> str:
        """Value of the ``DB_USER`` environment variable."""
        return get_env_var("DB_USER", "database user for Cloud SQL")

    @pytest.fixture(scope="module")
    def password(self) -> str:
        """Value of the ``DB_PASSWORD`` environment variable."""
        return get_env_var("DB_PASSWORD", "database password for Cloud SQL")

    @pytest_asyncio.fixture(loop_scope="function")
    async def sync_engine(self, db_project, db_region, db_instance, db_name):
        """Yield an engine from ``PostgresEngine.from_instance``.

        On teardown the ``DEFAULT_TABLE_SYNC`` table is dropped (if it
        exists) and the engine is closed.
        """
        engine = PostgresEngine.from_instance(
            project_id=db_project,
            instance=db_instance,
            region=db_region,
            database=db_name,
        )
        yield engine
        # Teardown: remove the table used by the test, then release the engine.
        await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"')
        await engine.close()

    @pytest.fixture(scope="function")
    def vectorstore(self, sync_engine):
        """Provide the vectorstore under test.

        Initialises ``DEFAULT_TABLE_SYNC`` with ``EMBEDDING_SIZE`` and a
        non-nullable ``VARCHAR`` id column named ``langchain_id``, then
        yields a store created on that table with ``create_sync``.
        """
        id_column = Column(name="langchain_id", data_type="VARCHAR", nullable=False)
        sync_engine.init_vectorstore_table(
            DEFAULT_TABLE_SYNC,
            EMBEDDING_SIZE,
            id_column=id_column,
        )

        store = PostgresVectorStore.create_sync(
            sync_engine,
            embedding_service=self.get_embeddings(),
            table_name=DEFAULT_TABLE_SYNC,
        )
        yield store


@pytest.mark.filterwarnings("ignore")
@pytest.mark.asyncio
class TestStandardSuiteAsync(VectorStoreIntegrationTests):
    """Run ``VectorStoreIntegrationTests`` against a store built with the
    awaitable ``PostgresVectorStore.create``.

    Connection settings are read from environment variables by the
    module-scoped fixtures below.
    """

    @pytest.fixture(scope="module")
    def db_project(self) -> str:
        """Value of the ``PROJECT_ID`` environment variable."""
        return get_env_var("PROJECT_ID", "project id for google cloud")

    @pytest.fixture(scope="module")
    def db_region(self) -> str:
        """Value of the ``REGION`` environment variable."""
        return get_env_var("REGION", "region for Cloud SQL instance")

    @pytest.fixture(scope="module")
    def db_instance(self) -> str:
        """Value of the ``INSTANCE_ID`` environment variable."""
        return get_env_var("INSTANCE_ID", "instance for Cloud SQL")

    @pytest.fixture(scope="module")
    def db_name(self) -> str:
        """Value of the ``DATABASE_ID`` environment variable."""
        return get_env_var("DATABASE_ID", "database name on Cloud SQL instance")

    @pytest.fixture(scope="module")
    def user(self) -> str:
        """Value of the ``DB_USER`` environment variable."""
        return get_env_var("DB_USER", "database user for Cloud SQL")

    @pytest.fixture(scope="module")
    def password(self) -> str:
        """Value of the ``DB_PASSWORD`` environment variable."""
        return get_env_var("DB_PASSWORD", "database password for Cloud SQL")

    @pytest_asyncio.fixture(loop_scope="function")
    async def async_engine(self, db_project, db_region, db_instance, db_name):
        """Yield an engine from ``PostgresEngine.afrom_instance``.

        On teardown the ``DEFAULT_TABLE`` table is dropped (if it exists)
        and the engine is closed.
        """
        engine = await PostgresEngine.afrom_instance(
            project_id=db_project,
            instance=db_instance,
            region=db_region,
            database=db_name,
        )
        yield engine
        # Teardown: remove the table used by the test, then release the engine.
        await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"')
        await engine.close()

    @pytest_asyncio.fixture(loop_scope="function")
    async def vectorstore(self, async_engine):
        """Provide the vectorstore under test.

        Initialises ``DEFAULT_TABLE`` with ``EMBEDDING_SIZE`` and a
        non-nullable ``VARCHAR`` id column named ``langchain_id``, then
        yields a store created on that table with ``create``.
        """
        id_column = Column(name="langchain_id", data_type="VARCHAR", nullable=False)
        await async_engine.ainit_vectorstore_table(
            DEFAULT_TABLE,
            EMBEDDING_SIZE,
            id_column=id_column,
        )

        store = await PostgresVectorStore.create(
            async_engine,
            embedding_service=self.get_embeddings(),
            table_name=DEFAULT_TABLE,
        )

        yield store
Loading