Skip to content

Commit 32c75cc

Browse files
committed
Add OracleEmbeddings optimization
1 parent 1cd111f commit 32c75cc

File tree

2 files changed

+172
-41
lines changed

2 files changed

+172
-41
lines changed

libs/oracledb/langchain_oracledb/vectorstores/oraclevs.py

Lines changed: 97 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
from langchain_core.embeddings import Embeddings
4949
from langchain_core.vectorstores import VectorStore
5050

51+
from ..embeddings import OracleEmbeddings
52+
5153
logger = logging.getLogger(__name__)
5254
log_level = os.getenv("LOG_LEVEL", "ERROR").upper()
5355
logging.basicConfig(
@@ -862,9 +864,9 @@ def _get_similarity_search_query(
862864
k: int,
863865
db_filter: Optional[FilterGroup | FilterCondition] = None,
864866
return_embeddings: bool = False,
865-
) -> str:
867+
) -> Tuple[str, list[str]]:
866868
where_clause = ""
867-
bind_variables = []
869+
bind_variables: list[str] = []
868870
if db_filter:
869871
where_clause = _generate_where_clause(db_filter, bind_variables)
870872

@@ -1142,33 +1144,61 @@ def add_texts(
11421144
texts = list(texts)
11431145
processed_ids = get_processed_ids(texts, metadatas, ids)
11441146

1145-
embeddings = self._embed_documents(texts)
11461147
if not metadatas:
11471148
metadatas = [{} for _ in texts]
11481149

1149-
docs: List[Tuple[Any, Any, Any, Any]] = [
1150-
(
1151-
id_,
1152-
array.array("f", embedding),
1153-
metadata,
1154-
text,
1155-
)
1156-
for id_, embedding, metadata, text in zip(
1157-
processed_ids, embeddings, metadatas, texts
1158-
)
1159-
]
1150+
docs: Any
1151+
if not isinstance(self.embeddings, OracleEmbeddings):
1152+
embeddings = self._embed_documents(texts)
1153+
1154+
docs = [
1155+
(
1156+
id_,
1157+
array.array("f", embedding),
1158+
metadata,
1159+
text,
1160+
)
1161+
for id_, embedding, metadata, text in zip(
1162+
processed_ids, embeddings, metadatas, texts
1163+
)
1164+
]
1165+
else:
1166+
docs = list(zip(processed_ids, metadatas, texts))
11601167

11611168
connection = _get_connection(self.client)
11621169
if connection is None:
11631170
raise ValueError("Failed to acquire a connection.")
11641171
with connection.cursor() as cursor:
1165-
cursor.setinputsizes(None, None, oracledb.DB_TYPE_JSON, None)
1166-
cursor.executemany(
1167-
f"INSERT INTO {self.table_name} (id, embedding, metadata, "
1168-
f"text) VALUES (:1, :2, :3, :4)",
1169-
docs,
1170-
)
1171-
connection.commit()
1172+
if not isinstance(self.embeddings, OracleEmbeddings):
1173+
cursor.setinputsizes(None, None, oracledb.DB_TYPE_JSON, None)
1174+
cursor.executemany(
1175+
f"INSERT INTO {self.table_name} (id, embedding, metadata, "
1176+
f"text) VALUES (:1, :2, :3, :4)",
1177+
docs,
1178+
)
1179+
connection.commit()
1180+
else:
1181+
if self.embeddings.proxy:
1182+
cursor.execute(
1183+
"begin utl_http.set_proxy(:proxy); end;",
1184+
proxy=self.embeddings.proxy,
1185+
)
1186+
1187+
cursor.setinputsizes(None, oracledb.DB_TYPE_JSON, None)
1188+
cursor.executemany(
1189+
f"INSERT INTO {self.table_name} (id, metadata, "
1190+
f"text) VALUES (:1, :2, :3)",
1191+
docs,
1192+
)
1193+
1194+
cursor.setinputsizes(oracledb.DB_TYPE_JSON)
1195+
update_sql = (
1196+
f"UPDATE {self.table_name} "
1197+
"SET embedding = dbms_vector_chain.utl_to_embedding(text, json(:1))"
1198+
)
1199+
cursor.execute(update_sql, [self.embeddings.params])
1200+
connection.commit()
1201+
11721202
return processed_ids
11731203

11741204
@_ahandle_exceptions
@@ -1192,33 +1222,60 @@ async def aadd_texts(
11921222
texts = list(texts)
11931223
processed_ids = get_processed_ids(texts, metadatas, ids)
11941224

1195-
embeddings = await self._aembed_documents(texts)
11961225
if not metadatas:
11971226
metadatas = [{} for _ in texts]
11981227

1199-
docs: List[Tuple[Any, Any, Any, Any]] = [
1200-
(
1201-
id_,
1202-
array.array("f", embedding),
1203-
metadata,
1204-
text,
1205-
)
1206-
for id_, embedding, metadata, text in zip(
1207-
processed_ids, embeddings, metadatas, texts
1208-
)
1209-
]
1228+
docs: Any
1229+
if not isinstance(self.embeddings, OracleEmbeddings):
1230+
embeddings = await self._aembed_documents(texts)
1231+
1232+
docs = [
1233+
(
1234+
id_,
1235+
array.array("f", embedding),
1236+
metadata,
1237+
text,
1238+
)
1239+
for id_, embedding, metadata, text in zip(
1240+
processed_ids, embeddings, metadatas, texts
1241+
)
1242+
]
1243+
else:
1244+
docs = list(zip(processed_ids, metadatas, texts))
12101245

12111246
async def context(connection: Any) -> None:
12121247
if connection is None:
12131248
raise ValueError("Failed to acquire a connection.")
12141249
with connection.cursor() as cursor:
1215-
cursor.setinputsizes(None, None, oracledb.DB_TYPE_JSON, None)
1216-
await cursor.executemany(
1217-
f"INSERT INTO {self.table_name} (id, embedding, metadata, "
1218-
f"text) VALUES (:1, :2, :3, :4)",
1219-
docs,
1220-
)
1221-
await connection.commit()
1250+
if not isinstance(self.embeddings, OracleEmbeddings):
1251+
cursor.setinputsizes(None, None, oracledb.DB_TYPE_JSON, None)
1252+
await cursor.executemany(
1253+
f"INSERT INTO {self.table_name} (id, embedding, metadata, "
1254+
f"text) VALUES (:1, :2, :3, :4)",
1255+
docs,
1256+
)
1257+
await connection.commit()
1258+
else:
1259+
if self.embeddings.proxy:
1260+
await cursor.execute(
1261+
"begin utl_http.set_proxy(:proxy); end;",
1262+
proxy=self.embeddings.proxy,
1263+
)
1264+
1265+
cursor.setinputsizes(None, oracledb.DB_TYPE_JSON, None)
1266+
await cursor.executemany(
1267+
f"INSERT INTO {self.table_name} (id, metadata, "
1268+
f"text) VALUES (:1, :2, :3)",
1269+
docs,
1270+
)
1271+
1272+
cursor.setinputsizes(oracledb.DB_TYPE_JSON)
1273+
update_sql = (
1274+
f"UPDATE {self.table_name} "
1275+
"SET embedding = dbms_vector_chain.utl_to_embedding(text, json(:1))" # noqa: E501
1276+
)
1277+
await cursor.execute(update_sql, [self.embeddings.params])
1278+
await connection.commit()
12221279

12231280
await _handle_context(self.client, context)
12241281

@@ -1628,7 +1685,6 @@ async def amax_marginal_relevance_search_with_score_by_vector(
16281685
def max_marginal_relevance_search_by_vector(
16291686
self,
16301687
embedding: List[float],
1631-
*,
16321688
k: int = 4,
16331689
fetch_k: int = 20,
16341690
lambda_mult: float = 0.5,

libs/oracledb/tests/integration_tests/vectorstores/test_oraclevs.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from langchain_community.embeddings import HuggingFaceEmbeddings
1919
from langchain_community.vectorstores.utils import DistanceStrategy
2020

21+
from langchain_oracledb.embeddings import OracleEmbeddings
2122
from langchain_oracledb.vectorstores.oraclevs import (
2223
FilterCondition,
2324
FilterGroup,
@@ -2479,6 +2480,80 @@ async def test_index_table_case_async(caplog: pytest.LogCaptureFixture) -> None:
24792480
await adrop_table_purge(connection, "TB2")
24802481

24812482

2483+
##################################
2484+
##### test_oracle_embeddings ####
2485+
##################################
2486+
2487+
2488+
def test_oracle_embeddings() -> None:
2489+
try:
2490+
connection = oracledb.connect(user=username, password=password, dsn=dsn)
2491+
except Exception:
2492+
sys.exit(1)
2493+
2494+
drop_table_purge(connection, "TB1")
2495+
2496+
texts = ["Database Document", "Code Document"]
2497+
metadata = [
2498+
{"id": "100", "link": "Document Example Test 1"},
2499+
{"id": "101", "link": "Document Example Test 2"},
2500+
]
2501+
embedder_params = {"provider": "database", "model": "allminilm"}
2502+
proxy = ""
2503+
2504+
# instance
2505+
model = OracleEmbeddings(conn=connection, params=embedder_params, proxy=proxy)
2506+
2507+
vs_obj = OracleVS(connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE)
2508+
2509+
vs_obj.add_texts(texts, metadata)
2510+
res = vs_obj.similarity_search("database", 1)
2511+
2512+
assert "Database" in res[0].page_content
2513+
2514+
drop_table_purge(connection, "TB1")
2515+
2516+
connection.close()
2517+
2518+
2519+
@pytest.mark.asyncio
2520+
async def test_oracle_embeddings_async(caplog: pytest.LogCaptureFixture) -> None:
2521+
try:
2522+
connection = await oracledb.connect_async(
2523+
user=username, password=password, dsn=dsn
2524+
)
2525+
2526+
connection_sync = oracledb.connect(user=username, password=password, dsn=dsn)
2527+
except Exception:
2528+
sys.exit(1)
2529+
2530+
await adrop_table_purge(connection, "TB1")
2531+
2532+
texts = ["Database Document", "Code Document"]
2533+
metadata = [
2534+
{"id": "100", "link": "Document Example Test 1"},
2535+
{"id": "101", "link": "Document Example Test 2"},
2536+
]
2537+
embedder_params = {"provider": "database", "model": "allminilm"}
2538+
proxy = ""
2539+
2540+
# instance
2541+
model = OracleEmbeddings(conn=connection_sync, params=embedder_params, proxy=proxy)
2542+
2543+
vs_obj = await OracleVS.acreate(
2544+
connection, model, "TB1", DistanceStrategy.EUCLIDEAN_DISTANCE
2545+
)
2546+
2547+
await vs_obj.aadd_texts(texts, metadata)
2548+
res = await vs_obj.asimilarity_search("database", 1)
2549+
2550+
assert "Database" in res[0].page_content
2551+
2552+
await adrop_table_purge(connection, "TB1")
2553+
2554+
await connection.close()
2555+
2556+
24822557
##################################
24832558
##### test_quote_identifier #####
24842559
##################################

0 commit comments

Comments
 (0)