4848from langchain_core .embeddings import Embeddings
4949from langchain_core .vectorstores import VectorStore
5050
51+ from ..embeddings import OracleEmbeddings
52+
5153logger = logging .getLogger (__name__ )
5254log_level = os .getenv ("LOG_LEVEL" , "ERROR" ).upper ()
5355logging .basicConfig (
@@ -862,9 +864,9 @@ def _get_similarity_search_query(
862864 k : int ,
863865 db_filter : Optional [FilterGroup | FilterCondition ] = None ,
864866 return_embeddings : bool = False ,
865- ) -> str :
867+ ) -> Tuple [ str , list [ str ]] :
866868 where_clause = ""
867- bind_variables = []
869+ bind_variables : list [ str ] = []
868870 if db_filter :
869871 where_clause = _generate_where_clause (db_filter , bind_variables )
870872
@@ -1142,33 +1144,61 @@ def add_texts(
11421144 texts = list (texts )
11431145 processed_ids = get_processed_ids (texts , metadatas , ids )
11441146
1145- embeddings = self ._embed_documents (texts )
11461147 if not metadatas :
11471148 metadatas = [{} for _ in texts ]
11481149
1149- docs : List [Tuple [Any , Any , Any , Any ]] = [
1150- (
1151- id_ ,
1152- array .array ("f" , embedding ),
1153- metadata ,
1154- text ,
1155- )
1156- for id_ , embedding , metadata , text in zip (
1157- processed_ids , embeddings , metadatas , texts
1158- )
1159- ]
1150+ docs : Any
1151+ if not isinstance (self .embeddings , OracleEmbeddings ):
1152+ embeddings = self ._embed_documents (texts )
1153+
1154+ docs = [
1155+ (
1156+ id_ ,
1157+ array .array ("f" , embedding ),
1158+ metadata ,
1159+ text ,
1160+ )
1161+ for id_ , embedding , metadata , text in zip (
1162+ processed_ids , embeddings , metadatas , texts
1163+ )
1164+ ]
1165+ else :
1166+ docs = list (zip (processed_ids , metadatas , texts ))
11601167
11611168 connection = _get_connection (self .client )
11621169 if connection is None :
11631170 raise ValueError ("Failed to acquire a connection." )
11641171 with connection .cursor () as cursor :
1165- cursor .setinputsizes (None , None , oracledb .DB_TYPE_JSON , None )
1166- cursor .executemany (
1167- f"INSERT INTO { self .table_name } (id, embedding, metadata, "
1168- f"text) VALUES (:1, :2, :3, :4)" ,
1169- docs ,
1170- )
1171- connection .commit ()
1172+ if not isinstance (self .embeddings , OracleEmbeddings ):
1173+ cursor .setinputsizes (None , None , oracledb .DB_TYPE_JSON , None )
1174+ cursor .executemany (
1175+ f"INSERT INTO { self .table_name } (id, embedding, metadata, "
1176+ f"text) VALUES (:1, :2, :3, :4)" ,
1177+ docs ,
1178+ )
1179+ connection .commit ()
1180+ else :
1181+ if self .embeddings .proxy :
1182+ cursor .execute (
1183+ "begin utl_http.set_proxy(:proxy); end;" ,
1184+ proxy = self .embeddings .proxy ,
1185+ )
1186+
1187+ cursor .setinputsizes (None , oracledb .DB_TYPE_JSON , None )
1188+ cursor .executemany (
1189+ f"INSERT INTO { self .table_name } (id, metadata, "
1190+ f"text) VALUES (:1, :2, :3)" ,
1191+ docs ,
1192+ )
1193+
1194+ cursor .setinputsizes (oracledb .DB_TYPE_JSON )
1195+ update_sql = (
1196+ f"UPDATE { self .table_name } "
1197+ "SET embedding = dbms_vector_chain.utl_to_embedding(text, json(:1))"
1198+ )
1199+ cursor .execute (update_sql , [self .embeddings .params ])
1200+ connection .commit ()
1201+
11721202 return processed_ids
11731203
11741204 @_ahandle_exceptions
@@ -1192,33 +1222,60 @@ async def aadd_texts(
11921222 texts = list (texts )
11931223 processed_ids = get_processed_ids (texts , metadatas , ids )
11941224
1195- embeddings = await self ._aembed_documents (texts )
11961225 if not metadatas :
11971226 metadatas = [{} for _ in texts ]
11981227
1199- docs : List [Tuple [Any , Any , Any , Any ]] = [
1200- (
1201- id_ ,
1202- array .array ("f" , embedding ),
1203- metadata ,
1204- text ,
1205- )
1206- for id_ , embedding , metadata , text in zip (
1207- processed_ids , embeddings , metadatas , texts
1208- )
1209- ]
1228+ docs : Any
1229+ if not isinstance (self .embeddings , OracleEmbeddings ):
1230+ embeddings = await self ._aembed_documents (texts )
1231+
1232+ docs = [
1233+ (
1234+ id_ ,
1235+ array .array ("f" , embedding ),
1236+ metadata ,
1237+ text ,
1238+ )
1239+ for id_ , embedding , metadata , text in zip (
1240+ processed_ids , embeddings , metadatas , texts
1241+ )
1242+ ]
1243+ else :
1244+ docs = list (zip (processed_ids , metadatas , texts ))
12101245
12111246 async def context (connection : Any ) -> None :
12121247 if connection is None :
12131248 raise ValueError ("Failed to acquire a connection." )
12141249 with connection .cursor () as cursor :
1215- cursor .setinputsizes (None , None , oracledb .DB_TYPE_JSON , None )
1216- await cursor .executemany (
1217- f"INSERT INTO { self .table_name } (id, embedding, metadata, "
1218- f"text) VALUES (:1, :2, :3, :4)" ,
1219- docs ,
1220- )
1221- await connection .commit ()
1250+ if not isinstance (self .embeddings , OracleEmbeddings ):
1251+ cursor .setinputsizes (None , None , oracledb .DB_TYPE_JSON , None )
1252+ await cursor .executemany (
1253+ f"INSERT INTO { self .table_name } (id, embedding, metadata, "
1254+ f"text) VALUES (:1, :2, :3, :4)" ,
1255+ docs ,
1256+ )
1257+ await connection .commit ()
1258+ else :
1259+ if self .embeddings .proxy :
1260+ await cursor .execute (
1261+ "begin utl_http.set_proxy(:proxy); end;" ,
1262+ proxy = self .embeddings .proxy ,
1263+ )
1264+
1265+ cursor .setinputsizes (None , oracledb .DB_TYPE_JSON , None )
1266+ await cursor .executemany (
1267+ f"INSERT INTO { self .table_name } (id, metadata, "
1268+ f"text) VALUES (:1, :2, :3)" ,
1269+ docs ,
1270+ )
1271+
1272+ cursor .setinputsizes (oracledb .DB_TYPE_JSON )
1273+ update_sql = (
1274+ f"UPDATE { self .table_name } "
1275+ "SET embedding = dbms_vector_chain.utl_to_embedding(text, json(:1))" # noqa: E501
1276+ )
1277+ await cursor .execute (update_sql , [self .embeddings .params ])
1278+ await connection .commit ()
12221279
12231280 await _handle_context (self .client , context )
12241281
@@ -1628,7 +1685,6 @@ async def amax_marginal_relevance_search_with_score_by_vector(
16281685 def max_marginal_relevance_search_by_vector (
16291686 self ,
16301687 embedding : List [float ],
1631- * ,
16321688 k : int = 4 ,
16331689 fetch_k : int = 20 ,
16341690 lambda_mult : float = 0.5 ,
0 commit comments