Skip to content

Commit 53d8313

Browse files
committed
[feat] try to use embedding from vllm
1 parent a736e79 commit 53d8313

File tree

5 files changed

+49
-428
lines changed

5 files changed

+49
-428
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ dependencies = [
1212
"langchain-neo4j>=0.4.0",
1313
"pandas>=2.2.3",
1414
"s3fs>=2024.12.0",
15-
"langchain-huggingface>=0.1.2",
1615
"uvicorn>=0.34.0",
1716
"fastapi>=0.115.12",
1817
"streamlit>=1.44.0",
18+
"langchain-openai>=0.3.11",
1919
]
2020
authors = [
2121
{name="Thomas Faria", email="[email protected]"}

src/constants/graph_db.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
NEO4J_URL = "neo4j://neo4j-585569.projet-ape:7687"
44
NEO4J_USERNAME = "neo4j"
55
NEO4J_PWD = os.environ["NEO4J_API_KEY"]
6-
EMBEDDING_MODEL = "OrdalieTech/Solon-embeddings-large-0.1"
6+
EMBEDDING_MODEL = "ordalieTech/Solon-embeddings-large-0.1"
7+
URL_EMBEDDING_API = "http://user-tfaria-vllm:8000/v1"

src/llm/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
@asynccontextmanager
1010
async def get_llm_client():
11-
client = AsyncOpenAI(api_key="EMPTY", base_url=URL_LLM_API, timeout=httpx.Timeout(30.0))
11+
client = AsyncOpenAI(api_key="EMPTY", base_url=URL_LLM_API, timeout=httpx.Timeout(1 * 60 * 60))
1212
try:
1313
yield client
1414
finally:

src/vector_db/loaders.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import logging
22

3-
import torch
4-
from langchain_huggingface import HuggingFaceEmbeddings
53
from langchain_neo4j import Neo4jGraph, Neo4jVector
4+
from langchain_openai import OpenAIEmbeddings
65

7-
from constants.graph_db import EMBEDDING_MODEL, NEO4J_PWD, NEO4J_URL, NEO4J_USERNAME
6+
# from vector_db.openai_embeddings import CustomOpenAIEmbeddings
7+
from constants.graph_db import EMBEDDING_MODEL, NEO4J_PWD, NEO4J_URL, NEO4J_USERNAME, URL_EMBEDDING_API
88

99
logger = logging.getLogger(__name__)
1010

@@ -24,25 +24,16 @@ def setup_graph() -> Neo4jGraph:
2424
)
2525

2626

27-
def get_embedding_model(model_name: str) -> HuggingFaceEmbeddings:
28-
"""Initialize the HuggingFace embedding model."""
29-
30-
device = "cuda" if torch.cuda.is_available() else "cpu"
31-
32-
if device == "cpu":
33-
logger.info("No GPU found: running on CPU. The embedding step might be slow 🫠")
34-
elif device == "cuda":
35-
logger.info("Running on GPU 🚀")
36-
37-
return HuggingFaceEmbeddings(
38-
model_name=model_name,
39-
model_kwargs={"device": device},
40-
encode_kwargs={"normalize_embeddings": True},
41-
show_progress=False,
27+
def get_embedding_model(model_name: str) -> OpenAIEmbeddings:
    """Initialize the embedding client for the OpenAI-compatible embedding API.

    Args:
        model_name: Name of the embedding model to request from the server
            reachable at ``URL_EMBEDDING_API``.

    Returns:
        An ``OpenAIEmbeddings`` instance configured to call
        ``URL_EMBEDDING_API`` with the given model name.
    """
    return OpenAIEmbeddings(
        # `model` is the field that selects the embedding model. An unknown
        # keyword (the previous `open=`) is not applied as the model name, so
        # the library default model would be requested instead.
        model=model_name,
        # `base_url` is the documented alias of `openai_api_base`;
        # `openai_base_url` is not a recognised field and left the client
        # pointing at the default OpenAI endpoint.
        base_url=URL_EMBEDDING_API,
        # Placeholder key: the OpenAI client requires a non-empty value.
        # NOTE(review): assumes the embedding server does not check it — confirm.
        api_key="EMPTY",
        # Send raw text rather than tiktoken-encoded token ids: the ids come
        # from OpenAI vocabularies and are meaningless to a non-OpenAI model.
        # NOTE(review): this disables client-side splitting of over-long inputs.
        check_embedding_ctx_length=False,
    )
4334

4435

45-
def get_vector_db() -> Neo4jVector:
36+
async def get_vector_db() -> Neo4jVector:
4637
"""Initialize the Neo4jVector Store from existing graph."""
4738
emb_model = get_embedding_model(EMBEDDING_MODEL)
4839
graph = setup_graph()

0 commit comments

Comments
 (0)