@@ -3,11 +3,13 @@
 from langchain.agents import initialize_agent
 from langchain.chat_models import ChatOpenAI
 from langchain.chains import LLMChain
+from langchain.agents import AgentType
 from .qaprompts import select_paper_prompt, make_chain
+from rmrkl import ChatZeroShotAgent, RetryAgentExecutor
 
 
 def status(answer: Answer, docs: Docs):
-    return f" Status: Current Papers: {len(docs.doc_previews())} Current Evidence: {len(answer.contexts)} Current Cost: {answer.cost}"
+    return f" Status: Current Papers: {len(docs.doc_previews())} Current Evidence: {len(answer.contexts)} Current Cost: ${answer.cost:.2f}"
 
 
 class PaperSelection(BaseTool):
@@ -40,8 +42,8 @@ class ReadPapers(BaseTool):
     name = "Gather Evidence"
     description = (
         "Give a specific question to a researcher that will return evidence for it. "
-        # "Optionally, you may specify papers using their key provided by the Select Papers tool. "
-        # "Use the format: $QUESTION or use format $QUESTION|$KEY1,$KEY2,..."
+        "Optionally, you may specify papers using their key provided by the Select Papers tool. "
+        "Use the format: $QUESTION or use format $QUESTION|$KEY1,$KEY2,..."
     )
     docs: Docs = None
     answer: Answer = None
@@ -54,12 +56,12 @@ def __init__(self, docs, answer):
         self.answer = answer
 
     def _run(self, query: str) -> str:
-        # if "|" in query:
-        #     question, keys = query.split("|")
-        #     keys = [k.strip() for k in keys.split(",")]
-        # else:
-        question = query
-        keys = None
+        if "|" in query:
+            question, keys = query.split("|")
+            keys = [k.strip() for k in keys.split(",")]
+        else:
+            question = query
+            keys = None
         # swap out the question
         old = self.answer.question
         self.answer.question = question
@@ -90,11 +92,11 @@ def __init__(self, docs, answer):
 
     def _run(self, query: str) -> str:
         self.answer = self.docs.query(
-            query, answer=self.answer, length_prompt="length as long as needed"
+            query, answer=self.answer
         )
         if "cannot answer" in self.answer.answer:
             self.answer = Answer(self.answer.question)
-            return "Failed to answer question. Deleting evidence." + status(
+            return "Failed to answer question. Deleting evidence. Consider rephrasing question or evidence statement. " + status(
                 self.answer, self.docs
             )
         return self.answer.answer + status(self.answer, self.docs)
@@ -141,28 +143,24 @@ def _arun(self, query: str) -> str:
 
 
 def make_tools(docs, answer):
-    # putting here until langchain PR is merged
-    from langchain.tools.exception.tool import ExceptionTool
 
     tools = []
 
     tools.append(Search(docs, answer))
-    # tools.append(PaperSelection(docs, answer))
+    tools.append(PaperSelection(docs, answer))
     tools.append(ReadPapers(docs, answer))
     tools.append(AnswerTool(docs, answer))
-    tools.append(ExceptionTool())
     return tools
 
 
 def run_agent(docs, question, llm=None):
     if llm is None:
-        llm = ChatOpenAI(temperature=0.1, model="gpt-4")
+        llm = ChatOpenAI(temperature=0.0, model="gpt-4")
     answer = Answer(question)
     tools = make_tools(docs, answer)
-    mrkl = initialize_agent(
-        tools,
-        llm,
-        agent="chat-zero-shot-react-description",
+    mrkl = RetryAgentExecutor.from_agent_and_tools(
+        tools=tools,
+        agent=ChatZeroShotAgent.from_llm_and_tools(llm, tools),
         verbose=True,
     )
     mrkl.run(