Worker_completed.py
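"""
Worker_completed.py

RAG worker script: loads a PDF, splits it into chunks, stores the chunk
embeddings in a Chroma vector store, and answers user prompts through a
RetrievalQA chain backed by an IBM watsonx.ai Llama 3.3 70B Instruct model
(langchain_ibm.WatsonxLLM) and Hugging Face sentence-transformer embeddings.
"""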
import os
import torch
import logging

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

from langchain_core.prompts import PromptTemplate  # Updated import per deprecation notice
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceInstructEmbeddings  # New import path
from langchain_community.document_loaders import PyPDFLoader  # New import path
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma  # New import path
from langchain_ibm import WatsonxLLM

# Check for GPU availability and set the appropriate device for computation.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Global variables
conversation_retrieval_chain = None
chat_history = []
llm_hub = None
embeddings = None
# Function to initialize the language model and its embeddings
def init_llm():
    global llm_hub, embeddings
    logger.info("Initializing WatsonxLLM and embeddings...")

    # Llama Model Configuration
    MODEL_ID = "meta-llama/llama-3-3-70b-instruct"
    WATSONX_URL = "https://us-south.ml.cloud.ibm.com"
    PROJECT_ID = "skills-network"

    # Use the same parameters as before:
    # MAX_NEW_TOKENS: 256, TEMPERATURE: 0.1
    model_parameters = {
        # "decoding_method": "greedy",
        "max_new_tokens": 256,
        "temperature": 0.1,
    }

    # Initialize the Llama LLM using the updated WatsonxLLM API
    llm_hub = WatsonxLLM(
        model_id=MODEL_ID,
        url=WATSONX_URL,
        project_id=PROJECT_ID,
        params=model_parameters
    )
    logger.debug("WatsonxLLM initialized: %s", llm_hub)

    # Initialize embeddings using a pre-trained model to represent the text data.
    ### --> If you are using the Hugging Face API instead of watsonx.ai:
    # set the HUGGINGFACEHUB_API_TOKEN environment variable, choose a model, and load it
    # into HuggingFaceHub (don't forget to remove the WatsonxLLM initialization above and
    # to import HuggingFaceHub).
    # os.environ["HUGGINGFACEHUB_API_TOKEN"] = "YOUR API KEY"
    # model_id = "tiiuae/falcon-7b-instruct"
    # llm_hub = HuggingFaceHub(repo_id=model_id, model_kwargs={"temperature": 0.1, "max_new_tokens": 600, "max_length": 600})
    embeddings = HuggingFaceInstructEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": DEVICE}
    )
    logger.debug("Embeddings initialized with model device: %s", DEVICE)

# Function to process a PDF document
def process_document(document_path):
    global conversation_retrieval_chain
    logger.info("Loading document from path: %s", document_path)

    # Load the document
    loader = PyPDFLoader(document_path)
    documents = loader.load()
    logger.debug("Loaded %d document(s)", len(documents))

    # Split the document into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(documents)
    logger.debug("Document split into %d text chunks", len(texts))

    # Create an embeddings database using Chroma from the split text chunks.
    logger.info("Initializing Chroma vector store from documents...")
    db = Chroma.from_documents(texts, embedding=embeddings)
    logger.debug("Chroma vector store initialized.")

    # Optional: Log available collections if accessible (this may be internal API)
    try:
        collections = db._client.list_collections()  # _client is internal; adjust if needed
        logger.debug("Available collections in Chroma: %s", collections)
    except Exception as e:
        logger.warning("Could not retrieve collections from Chroma: %s", e)

    # Build the QA chain, which utilizes the LLM and retriever for answering questions.
    conversation_retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm_hub,
        chain_type="stuff",
        retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25}),
        return_source_documents=False,
        input_key="question"
        # chain_type_kwargs={"prompt": prompt}  # if you are using a prompt template, uncomment this part
    )
    logger.info("RetrievalQA chain created successfully.")

# Function to process a user prompt
def process_prompt(prompt):
    global conversation_retrieval_chain
    global chat_history
    logger.info("Processing prompt: %s", prompt)

    # Query the model using the new .invoke() method
    output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
    answer = output["result"]
    logger.debug("Model response: %s", answer)

    # Update the chat history
    chat_history.append((prompt, answer))
    logger.debug("Chat history updated. Total exchanges: %d", len(chat_history))

    # Return the model's response
    return answer

# Initialize the language model
init_llm()
logger.info("LLM and embeddings initialization complete.")