-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy path worker.py
119 lines (96 loc) · 4.61 KB
/
worker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import torch
import logging
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
from langchain_core.prompts import PromptTemplate # Updated import per deprecation notice
from langchain.chains import RetrievalQA
from langchain_community.embeddings import HuggingFaceInstructEmbeddings # New import path
from langchain_community.document_loaders import PyPDFLoader # New import path
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma # New import path
from langchain_ibm import WatsonxLLM
# Pick the compute device for the embedding model: the first CUDA GPU when
# torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Module-level state shared by the functions below.
llm_hub = None  # WatsonxLLM client; assigned by init_llm()
embeddings = None  # embedding model; assigned by init_llm()
conversation_retrieval_chain = None  # RetrievalQA chain; assigned by process_document()
chat_history = []  # conversation history passed to the chain by process_prompt()
def init_llm():
    """Create the shared LLM client and embedding model.

    Side effects:
        Rebinds the module globals ``llm_hub`` (a ``WatsonxLLM`` client for a
        Llama model on IBM watsonx.ai) and ``embeddings`` (a
        ``HuggingFaceInstructEmbeddings`` instance placed on ``DEVICE``).
        ``process_document`` and ``process_prompt`` read these globals.
    """
    global llm_hub, embeddings
    logger.info("Initializing WatsonxLLM and embeddings...")

    # Llama model configuration for watsonx.ai.
    MODEL_ID = "meta-llama/llama-3-3-70b-instruct"
    WATSONX_URL = "https://us-south.ml.cloud.ibm.com"
    PROJECT_ID = "skills-network"
    # Checkpoint used to embed document chunks and queries.
    EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

    # Generation settings: cap answer length at 256 new tokens and keep the
    # temperature low (0.1) so answers stay close to the retrieved context.
    model_parameters = {
        "max_new_tokens": 256,
        "temperature": 0.1,
    }

    llm_hub = WatsonxLLM(
        model_id=MODEL_ID,
        url=WATSONX_URL,
        project_id=PROJECT_ID,
        params=model_parameters,
    )
    logger.debug("WatsonxLLM initialized: %s", llm_hub)

    # Pre-trained embedding model that turns text into vectors for the Chroma
    # store; model_kwargs forwards the device choice to the underlying model.
    embeddings = HuggingFaceInstructEmbeddings(
        model_name=EMBEDDING_MODEL_NAME,
        model_kwargs={"device": DEVICE},
    )
    logger.debug("Embeddings initialized with model device: %s", DEVICE)
def process_document(document_path):
    """Load a PDF, index its text in Chroma, and build the QA chain.

    Args:
        document_path: Filesystem path of the PDF to load with ``PyPDFLoader``.

    Side effects:
        Rebinds the module global ``conversation_retrieval_chain`` to a
        ``RetrievalQA`` chain over this document. Reads the ``llm_hub`` and
        ``embeddings`` globals, so ``init_llm()`` must have run first.
    """
    global conversation_retrieval_chain
    logger.info("Loading document from path: %s", document_path)

    # Load the PDF into LangChain Document objects.
    loader = PyPDFLoader(document_path)
    documents = loader.load()
    logger.debug("Loaded %d document(s)", len(documents))

    # Split into 1024-character chunks; the 64-character overlap keeps some
    # shared context between neighbouring chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    texts = text_splitter.split_documents(documents)
    logger.debug("Document split into %d text chunks", len(texts))
    if not texts:
        # Nothing to index (e.g. a PDF with no extractable text); the chain
        # below would have no context to retrieve.
        logger.warning("No text chunks extracted from %s", document_path)

    # Build a Chroma vector store from the chunks.
    # NOTE(review): with no client/persist_directory, repeated calls may reuse
    # the same in-memory default collection, so chunks from earlier documents
    # could leak into later answers — confirm against the installed chromadb.
    logger.info("Initializing Chroma vector store from documents...")
    db = Chroma.from_documents(texts, embedding=embeddings)
    logger.debug("Chroma vector store initialized.")

    # Best-effort diagnostics only: _client is a private attribute, so any
    # failure here is logged and ignored.
    try:
        collections = db._client.list_collections()
        logger.debug("Available collections in Chroma: %s", collections)
    except Exception as e:
        logger.warning("Could not retrieve collections from Chroma: %s", e)

    # "stuff" chain: retrieved chunks are placed into one prompt for the LLM.
    # MMR retrieval returns 6 chunks; lambda_mult=0.25 weights the selection
    # toward diversity rather than pure similarity.
    conversation_retrieval_chain = RetrievalQA.from_chain_type(
        llm=llm_hub,
        chain_type="stuff",
        retriever=db.as_retriever(search_type="mmr", search_kwargs={'k': 6, 'lambda_mult': 0.25}),
        return_source_documents=False,
        input_key="question",
        # chain_type_kwargs={"prompt": prompt}  # uncomment to use a custom PromptTemplate
    )
    logger.info("RetrievalQA chain created successfully.")
def process_prompt(prompt):
    """Answer a user prompt with the current retrieval chain and record it.

    Args:
        prompt: The user's question text.

    Returns:
        The answer string taken from the chain output's ``"result"`` key.

    Side effects:
        Appends the ``(prompt, answer)`` pair to the module-level
        ``chat_history`` list. Requires ``process_document`` to have built
        ``conversation_retrieval_chain`` beforehand.
    """
    logger.info("Processing prompt: %s", prompt)

    # The chain was built with input_key="question".
    # NOTE(review): RetrievalQA presumably ignores the extra "chat_history"
    # key — confirm if conversational context is expected to matter.
    output = conversation_retrieval_chain.invoke({"question": prompt, "chat_history": chat_history})
    answer = output["result"]
    logger.debug("Model response: %s", answer)

    # Record this exchange as a (question, answer) pair. The list is mutated
    # in place, so no `global` declaration is needed.
    chat_history.append((prompt, answer))
    logger.debug("Chat history updated. Total exchanges: %d", len(chat_history))

    return answer
# Build the LLM client and embedding model once, when this module is executed
# (i.e. at import time), so the llm_hub and embeddings globals are populated
# before any document or prompt is processed.
init_llm()
logger.info("LLM and embeddings initialization complete.")