Skip to content

Commit ab28ac2

Browse files
committed
Updated lineendings & bumped langchain version
1 parent 89c8ea8 commit ab28ac2

File tree

8 files changed

+1357
-1357
lines changed

8 files changed

+1357
-1357
lines changed

paperqa/agent.py

Lines changed: 171 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,171 +1,171 @@
1-
from langchain.agents import AgentType, initialize_agent
2-
from langchain.chains import LLMChain
3-
from langchain.chat_models import ChatOpenAI
4-
from langchain.tools import BaseTool
5-
from rmrkl import ChatZeroShotAgent, RetryAgentExecutor
6-
7-
from .docs import Answer, Docs
8-
from .qaprompts import make_chain, select_paper_prompt
9-
10-
11-
def status(answer: Answer, docs: Docs):
12-
return f" Status: Current Papers: {len(docs.doc_previews())} Current Evidence: {len(answer.contexts)} Current Cost: ${answer.cost:.2f}"
13-
14-
15-
class PaperSelection(BaseTool):
16-
name = "Select Papers"
17-
description = "Select from current papers. Provide instructions as a string to use for choosing papers."
18-
docs: Docs = None
19-
answer: Answer = None
20-
chain: LLMChain = None
21-
22-
def __init__(self, docs, answer):
23-
# call the parent class constructor
24-
super(PaperSelection, self).__init__()
25-
26-
self.docs = docs
27-
self.answer = answer
28-
self.chain = make_chain(select_paper_prompt, self.docs.summary_llm)
29-
30-
def _run(self, query: str) -> str:
31-
result = self.docs.doc_match(query)
32-
if result is None or result.strip().startswith("None"):
33-
return "No relevant papers found."
34-
return result + status(self.answer, self.docs)
35-
36-
async def _arun(self, query: str) -> str:
37-
"""Use the tool asynchronously."""
38-
raise NotImplementedError()
39-
40-
41-
class ReadPapers(BaseTool):
42-
name = "Gather Evidence"
43-
description = (
44-
"Give a specific question to a researcher that will return evidence for it. "
45-
"Optionally, you may specify papers using their key provided by the Select Papers tool. "
46-
"Use the format: $QUESTION or use format $QUESTION|$KEY1,$KEY2,..."
47-
)
48-
docs: Docs = None
49-
answer: Answer = None
50-
51-
def __init__(self, docs, answer):
52-
# call the parent class constructor
53-
super(ReadPapers, self).__init__()
54-
55-
self.docs = docs
56-
self.answer = answer
57-
58-
def _run(self, query: str) -> str:
59-
if "|" in query:
60-
question, keys = query.split("|")
61-
keys = [k.strip() for k in keys.split(",")]
62-
else:
63-
question = query
64-
keys = None
65-
# swap out the question
66-
old = self.answer.question
67-
self.answer.question = question
68-
# generator, so run it
69-
l0 = len(self.answer.contexts)
70-
self.docs.get_evidence(self.answer, key_filter=keys)
71-
l1 = len(self.answer.contexts)
72-
self.answer.question = old
73-
return f"Added {l1 - l0} pieces of evidence." + status(self.answer, self.docs)
74-
75-
async def _arun(self, query: str) -> str:
76-
"""Use the tool asynchronously."""
77-
raise NotImplementedError()
78-
79-
80-
class AnswerTool(BaseTool):
81-
name = "Propose Answer"
82-
description = "Ask a researcher to propose an answer using evidence from papers. The input is the question to be answered."
83-
docs: Docs = None
84-
answer: Answer = None
85-
86-
def __init__(self, docs, answer):
87-
# call the parent class constructor
88-
super(AnswerTool, self).__init__()
89-
90-
self.docs = docs
91-
self.answer = answer
92-
93-
def _run(self, query: str) -> str:
94-
self.answer = self.docs.query(query, answer=self.answer)
95-
if "cannot answer" in self.answer.answer:
96-
self.answer = Answer(self.answer.question)
97-
return (
98-
"Failed to answer question. Deleting evidence. Consider rephrasing question or evidence statement."
99-
+ status(self.answer, self.docs)
100-
)
101-
return self.answer.answer + status(self.answer, self.docs)
102-
103-
def _arun(self, query: str) -> str:
104-
"""Use the tool asynchronously."""
105-
raise NotImplementedError()
106-
107-
108-
class Search(BaseTool):
109-
name = "Paper Search"
110-
description = (
111-
"Search for papers to add to cur. Input should be a string of keywords."
112-
)
113-
docs: Docs = None
114-
answer: Answer = None
115-
116-
def __init__(self, docs, answer):
117-
# call the parent class constructor
118-
super(Search, self).__init__()
119-
120-
self.docs = docs
121-
self.answer = answer
122-
123-
def _run(self, query: str) -> str:
124-
try:
125-
import paperscraper
126-
except ImportError:
127-
raise ImportError(
128-
"Please install paperscraper (github.com/blackadad/paper-scraper) to use agent"
129-
)
130-
131-
papers = paperscraper.search_papers(
132-
query, limit=20, verbose=False, pdir=self.docs.index_path
133-
)
134-
for path, data in papers.items():
135-
try:
136-
self.docs.add(path, citation=data["citation"])
137-
except:
138-
pass
139-
return status(self.answer, self.docs)
140-
141-
def _arun(self, query: str) -> str:
142-
"""Use the tool asynchronously."""
143-
raise NotImplementedError()
144-
145-
146-
def make_tools(docs, answer):
147-
tools = []
148-
149-
tools.append(Search(docs, answer))
150-
tools.append(PaperSelection(docs, answer))
151-
tools.append(ReadPapers(docs, answer))
152-
tools.append(AnswerTool(docs, answer))
153-
return tools
154-
155-
156-
def run_agent(docs, question, llm=None):
157-
if llm is None:
158-
llm = ChatOpenAI(temperature=0.0, model="gpt-4")
159-
answer = Answer(question)
160-
tools = make_tools(docs, answer)
161-
mrkl = RetryAgentExecutor.from_agent_and_tools(
162-
tools=tools,
163-
agent=ChatZeroShotAgent.from_llm_and_tools(llm, tools),
164-
verbose=True,
165-
)
166-
mrkl.run(
167-
f"Answer question: {question}. Search for papers, gather evidence, and answer. If you do not have enough evidence, you can search for more papers (preferred) or gather more evidence. You may rephrase or breaking-up the question in those steps. "
168-
"Once you have five pieces of evidence, or you have tried for a while, call the Propose Answer tool. "
169-
)
170-
171-
return answer
1+
from langchain.agents import AgentType, initialize_agent
2+
from langchain.chains import LLMChain
3+
from langchain.chat_models import ChatOpenAI
4+
from langchain.tools import BaseTool
5+
from rmrkl import ChatZeroShotAgent, RetryAgentExecutor
6+
7+
from .docs import Answer, Docs
8+
from .qaprompts import make_chain, select_paper_prompt
9+
10+
11+
def status(answer: Answer, docs: Docs):
12+
return f" Status: Current Papers: {len(docs.doc_previews())} Current Evidence: {len(answer.contexts)} Current Cost: ${answer.cost:.2f}"
13+
14+
15+
class PaperSelection(BaseTool):
16+
name = "Select Papers"
17+
description = "Select from current papers. Provide instructions as a string to use for choosing papers."
18+
docs: Docs = None
19+
answer: Answer = None
20+
chain: LLMChain = None
21+
22+
def __init__(self, docs, answer):
23+
# call the parent class constructor
24+
super(PaperSelection, self).__init__()
25+
26+
self.docs = docs
27+
self.answer = answer
28+
self.chain = make_chain(select_paper_prompt, self.docs.summary_llm)
29+
30+
def _run(self, query: str) -> str:
31+
result = self.docs.doc_match(query)
32+
if result is None or result.strip().startswith("None"):
33+
return "No relevant papers found."
34+
return result + status(self.answer, self.docs)
35+
36+
async def _arun(self, query: str) -> str:
37+
"""Use the tool asynchronously."""
38+
raise NotImplementedError()
39+
40+
41+
class ReadPapers(BaseTool):
42+
name = "Gather Evidence"
43+
description = (
44+
"Give a specific question to a researcher that will return evidence for it. "
45+
"Optionally, you may specify papers using their key provided by the Select Papers tool. "
46+
"Use the format: $QUESTION or use format $QUESTION|$KEY1,$KEY2,..."
47+
)
48+
docs: Docs = None
49+
answer: Answer = None
50+
51+
def __init__(self, docs, answer):
52+
# call the parent class constructor
53+
super(ReadPapers, self).__init__()
54+
55+
self.docs = docs
56+
self.answer = answer
57+
58+
def _run(self, query: str) -> str:
59+
if "|" in query:
60+
question, keys = query.split("|")
61+
keys = [k.strip() for k in keys.split(",")]
62+
else:
63+
question = query
64+
keys = None
65+
# swap out the question
66+
old = self.answer.question
67+
self.answer.question = question
68+
# generator, so run it
69+
l0 = len(self.answer.contexts)
70+
self.docs.get_evidence(self.answer, key_filter=keys)
71+
l1 = len(self.answer.contexts)
72+
self.answer.question = old
73+
return f"Added {l1 - l0} pieces of evidence." + status(self.answer, self.docs)
74+
75+
async def _arun(self, query: str) -> str:
76+
"""Use the tool asynchronously."""
77+
raise NotImplementedError()
78+
79+
80+
class AnswerTool(BaseTool):
81+
name = "Propose Answer"
82+
description = "Ask a researcher to propose an answer using evidence from papers. The input is the question to be answered."
83+
docs: Docs = None
84+
answer: Answer = None
85+
86+
def __init__(self, docs, answer):
87+
# call the parent class constructor
88+
super(AnswerTool, self).__init__()
89+
90+
self.docs = docs
91+
self.answer = answer
92+
93+
def _run(self, query: str) -> str:
94+
self.answer = self.docs.query(query, answer=self.answer)
95+
if "cannot answer" in self.answer.answer:
96+
self.answer = Answer(self.answer.question)
97+
return (
98+
"Failed to answer question. Deleting evidence. Consider rephrasing question or evidence statement."
99+
+ status(self.answer, self.docs)
100+
)
101+
return self.answer.answer + status(self.answer, self.docs)
102+
103+
def _arun(self, query: str) -> str:
104+
"""Use the tool asynchronously."""
105+
raise NotImplementedError()
106+
107+
108+
class Search(BaseTool):
109+
name = "Paper Search"
110+
description = (
111+
"Search for papers to add to cur. Input should be a string of keywords."
112+
)
113+
docs: Docs = None
114+
answer: Answer = None
115+
116+
def __init__(self, docs, answer):
117+
# call the parent class constructor
118+
super(Search, self).__init__()
119+
120+
self.docs = docs
121+
self.answer = answer
122+
123+
def _run(self, query: str) -> str:
124+
try:
125+
import paperscraper
126+
except ImportError:
127+
raise ImportError(
128+
"Please install paperscraper (github.com/blackadad/paper-scraper) to use agent"
129+
)
130+
131+
papers = paperscraper.search_papers(
132+
query, limit=20, verbose=False, pdir=self.docs.index_path
133+
)
134+
for path, data in papers.items():
135+
try:
136+
self.docs.add(path, citation=data["citation"])
137+
except:
138+
pass
139+
return status(self.answer, self.docs)
140+
141+
def _arun(self, query: str) -> str:
142+
"""Use the tool asynchronously."""
143+
raise NotImplementedError()
144+
145+
146+
def make_tools(docs, answer):
147+
tools = []
148+
149+
tools.append(Search(docs, answer))
150+
tools.append(PaperSelection(docs, answer))
151+
tools.append(ReadPapers(docs, answer))
152+
tools.append(AnswerTool(docs, answer))
153+
return tools
154+
155+
156+
def run_agent(docs, question, llm=None):
157+
if llm is None:
158+
llm = ChatOpenAI(temperature=0.0, model_name="gpt-4")
159+
answer = Answer(question)
160+
tools = make_tools(docs, answer)
161+
mrkl = RetryAgentExecutor.from_agent_and_tools(
162+
tools=tools,
163+
agent=ChatZeroShotAgent.from_llm_and_tools(llm, tools),
164+
verbose=True,
165+
)
166+
mrkl.run(
167+
f"Answer question: {question}. Search for papers, gather evidence, and answer. If you do not have enough evidence, you can search for more papers (preferred) or gather more evidence. You may rephrase or breaking-up the question in those steps. "
168+
"Once you have five pieces of evidence, or you have tried for a while, call the Propose Answer tool. "
169+
)
170+
171+
return answer

0 commit comments

Comments
 (0)