Skip to content

Commit 93c47cc

Browse files
authored
Added back paper selection to agents (#58)
* added back paper selection * fixed evidence * Bumped version * Added a small test for Zotero that is not very good (#59) * Added status report on evidence gathering * Improved agent governing prompt
1 parent 575f938 commit 93c47cc

File tree

4 files changed

+58
-22
lines changed

4 files changed

+58
-22
lines changed

dev-requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ pytest
22
pre-commit
33
requests
44
paper-scraper@git+https://github.com/blackadad/paper-scraper.git
5+
pyzotero

paperqa/agent.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,46 @@
22
from .docs import Answer, Docs
33
from langchain.agents import initialize_agent
44
from langchain.chat_models import ChatOpenAI
5+
from langchain.chains import LLMChain
6+
from .qaprompts import select_paper_prompt, make_chain
57

68

79
def status(answer: Answer, docs: Docs) -> str:
    """Summarize the agent's progress as a one-line status string.

    The text begins with a space so that callers can append it directly to
    another message.

    Args:
        answer: Answer being assembled; its ``contexts`` and ``cost`` are read.
        docs: Document collection; ``doc_previews()`` is called to count papers.

    Returns:
        Text reporting the number of papers, the number of evidence contexts,
        and the current cost.
    """
    return (
        f" Status: Current Papers: {len(docs.doc_previews())}"
        f" Current Evidence: {len(answer.contexts)}"
        f" Current Cost: {answer.cost}"
    )
911

1012

13+
class PaperSelection(BaseTool):
    """Agent tool that selects papers from the current document collection.

    The instruction string given to the tool is passed to ``docs.doc_match``;
    the matched text is returned with the status line appended.
    """

    name = "Select Papers"
    description = "Select from current papers. Provide instructions as a string to use for choosing papers."
    # Shared state handed in by the caller; populated in __init__.
    docs: Docs = None
    answer: Answer = None
    chain: LLMChain = None

    def __init__(self, docs, answer):
        """Store the shared ``docs`` and ``answer`` and build the selection chain."""
        # Call the parent class constructor before assigning the fields.
        super().__init__()

        self.docs = docs
        self.answer = answer
        # NOTE(review): this chain is built but _run below does not use it
        # (it calls docs.doc_match instead) - confirm whether it is still needed.
        self.chain = make_chain(select_paper_prompt, self.docs.summary_llm)

    def _run(self, query: str) -> str:
        """Return the papers matching ``query`` followed by the status line.

        Returns "No relevant papers found." when ``doc_match`` yields ``None``
        or text that starts with "None" once surrounding whitespace is stripped.
        """
        result = self.docs.doc_match(query)
        if result is None or result.strip().startswith("None"):
            return "No relevant papers found."
        return result + status(self.answer, self.docs)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError()
37+
38+
1139
class ReadPapers(BaseTool):
1240
name = "Gather Evidence"
1341
description = (
1442
"Give a specific question to a researcher that will return evidence for it. "
15-
"Optionally, you may specify papers using their key provided by the Select Papers tool. "
16-
"Use the format: $QUESTION or use format $QUESTION|$KEY1,$KEY2,..."
43+
# "Optionally, you may specify papers using their key provided by the Select Papers tool. "
44+
# "Use the format: $QUESTION or use format $QUESTION|$KEY1,$KEY2,..."
1745
)
1846
docs: Docs = None
1947
answer: Answer = None
@@ -26,19 +54,21 @@ def __init__(self, docs, answer):
2654
self.answer = answer
2755

2856
def _run(self, query: str) -> str:
    """Gather evidence for ``query`` and report how many pieces were added.

    The question on the shared answer is temporarily replaced by ``query``
    while evidence is collected, then restored - also when evidence
    gathering raises, so the answer is never left with the temporary question.

    Args:
        query: Question to gather evidence for.

    Returns:
        A message with the number of newly added evidence contexts, followed
        by the status line.
    """
    # The "$QUESTION|$KEY1,$KEY2,..." key-filter input format is currently
    # disabled: the whole query is the question and no key filter is applied.
    question = query
    keys = None
    # Swap out the question, remembering the original so it can be restored.
    old = self.answer.question
    self.answer.question = question
    try:
        before = len(self.answer.contexts)
        # NOTE(review): an earlier comment here called get_evidence a
        # generator, but its result is not iterated - confirm it runs eagerly.
        self.docs.get_evidence(self.answer, key_filter=keys)
        after = len(self.answer.contexts)
    finally:
        self.answer.question = old
    return f"Added {after - before} pieces of evidence." + status(self.answer, self.docs)
4272

4373
async def _arun(self, query: str) -> str:
4474
"""Use the tool asynchronously."""
@@ -76,7 +106,7 @@ def _arun(self, query: str) -> str:
76106

77107
class Search(BaseTool):
78108
name = "Paper Search"
79-
description = "Search for papers to add to current papers. Input should be a string of keywords."
109+
description = "Search for papers to add to cur. Input should be a string of keywords."
80110
docs: Docs = None
81111
answer: Answer = None
82112

@@ -95,10 +125,12 @@ def _run(self, query: str) -> str:
95125
"Please install paperscraper (github.com/blackadad/paper-scraper) to use agent"
96126
)
97127

98-
papers = paperscraper.search_papers(query, limit=20, verbose=False)
128+
papers = paperscraper.search_papers(
129+
query, limit=20, verbose=False, pdir=self.docs.index_path
130+
)
99131
for path, data in papers.items():
100132
try:
101-
self.docs.add(path)
133+
self.docs.add(path, citation=data["citation"])
102134
except:
103135
pass
104136
return status(self.answer, self.docs)
@@ -115,6 +147,7 @@ def make_tools(docs, answer):
115147
tools = []
116148

117149
tools.append(Search(docs, answer))
150+
# tools.append(PaperSelection(docs, answer))
118151
tools.append(ReadPapers(docs, answer))
119152
tools.append(AnswerTool(docs, answer))
120153
tools.append(ExceptionTool())
@@ -127,12 +160,14 @@ def run_agent(docs, question, llm=None):
127160
answer = Answer(question)
128161
tools = make_tools(docs, answer)
129162
mrkl = initialize_agent(
130-
tools, llm, agent="chat-zero-shot-react-description", verbose=True
163+
tools,
164+
llm,
165+
agent="chat-zero-shot-react-description",
166+
verbose=True,
131167
)
132168
mrkl.run(
133-
f"Answer question: {question}. Search for papers, gather evidence, and answer. "
134-
"Once you have at least five pieces of evidence, call the Propose Answer tool. "
135-
"If you do not have enough evidence, search with different keywords. "
169+
f"Answer question: {question}. Search for papers, gather evidence, and answer. If you do not have enough evidence, you can search for more papers (preferred) or gather more evidence. You may rephrase or breaking-up the question in those steps. "
170+
"Once you have five pieces of evidence, or you have tried for a while, call the Propose Answer tool. "
136171
)
137172

138173
return answer

paperqa/docs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,12 +397,12 @@ async def aquery(
397397
raise ValueError("k should be greater than max_sources")
398398
if answer is None:
399399
answer = Answer(query)
400-
if key_filter or (key_filter is None and len(self.docs) > 5):
401-
with get_openai_callback() as cb:
402-
keys = self.doc_match(answer.question)
403-
answer.tokens += cb.total_tokens
404-
answer.cost += cb.total_cost
405400
if len(answer.contexts) == 0:
401+
if key_filter or (key_filter is None and len(self.docs) > 5):
402+
with get_openai_callback() as cb:
403+
keys = self.doc_match(answer.question)
404+
answer.tokens += cb.total_tokens
405+
answer.cost += cb.total_cost
406406
answer = await self.aget_evidence(
407407
answer,
408408
k=k,

paperqa/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.1.0"
1+
__version__ = "1.1.1"

0 commit comments

Comments
 (0)