|
1 | | -from langchain.agents import AgentType, initialize_agent |
2 | | -from langchain.chains import LLMChain |
3 | | -from langchain.chat_models import ChatOpenAI |
4 | | -from langchain.tools import BaseTool |
5 | | -from rmrkl import ChatZeroShotAgent, RetryAgentExecutor |
6 | | - |
7 | | -from .docs import Answer, Docs |
8 | | -from .qaprompts import make_chain, select_paper_prompt |
9 | | - |
10 | | - |
def status(answer: Answer, docs: Docs):
    """Build the progress suffix that tools append to their observations.

    The suffix reports how many document previews ``docs`` currently holds,
    how many evidence contexts ``answer`` has collected, and ``answer.cost``
    formatted with two decimal places.
    """
    paper_count = len(docs.doc_previews())
    evidence_count = len(answer.contexts)
    return (
        f" Status: Current Papers: {paper_count}"
        f" Current Evidence: {evidence_count}"
        f" Current Cost: ${answer.cost:.2f}"
    )
13 | | - |
14 | | - |
class PaperSelection(BaseTool):
    """Agent tool that asks the document store which current papers match a query."""

    name = "Select Papers"
    description = "Select from current papers. Provide instructions as a string to use for choosing papers."
    # Shared state; populated in __init__.
    docs: Docs = None
    answer: Answer = None
    chain: LLMChain = None

    def __init__(self, docs, answer):
        """Store the shared ``docs`` and ``answer`` and build the paper-selection chain."""
        super().__init__()

        self.docs = docs
        self.answer = answer
        # Built from the selection prompt and the document store's summary LLM.
        self.chain = make_chain(select_paper_prompt, self.docs.summary_llm)

    def _run(self, query: str) -> str:
        """Return the matched papers plus a status suffix, or a fixed message if none match."""
        matches = self.docs.doc_match(query)
        has_match = matches is not None and not matches.strip().startswith("None")
        if not has_match:
            return "No relevant papers found."
        return matches + status(self.answer, self.docs)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()
39 | | - |
40 | | - |
class ReadPapers(BaseTool):
    """Agent tool that gathers evidence for a question, optionally restricted to given paper keys."""

    name = "Gather Evidence"
    description = (
        "Give a specific question to a researcher that will return evidence for it. "
        "Optionally, you may specify papers using their key provided by the Select Papers tool. "
        "Use the format: $QUESTION or use format $QUESTION|$KEY1,$KEY2,..."
    )
    # Shared state; populated in __init__.
    docs: Docs = None
    answer: Answer = None

    def __init__(self, docs, answer):
        """Store the shared document store and answer object."""
        super().__init__()

        self.docs = docs
        self.answer = answer

    def _run(self, query: str) -> str:
        """Gather evidence for ``query`` and report how many contexts were added.

        ``query`` is either ``$QUESTION`` or ``$QUESTION|$KEY1,$KEY2,...``.
        Everything before the first ``|`` is the question; the remainder is a
        comma-separated list of paper keys. Blank keys are dropped, and an
        empty key list means no key filter.
        """
        if "|" in query:
            # maxsplit=1 so extra "|" characters cannot break the two-way unpacking.
            question, raw_keys = query.split("|", 1)
            keys = [k.strip() for k in raw_keys.split(",") if k.strip()] or None
        else:
            question, keys = query, None

        # Temporarily swap in the sub-question; always restore the original,
        # even when evidence gathering fails.
        previous_question = self.answer.question
        self.answer.question = question
        before = len(self.answer.contexts)
        try:
            # NOTE(review): an earlier comment called get_evidence a generator;
            # it is invoked here only for its effect on self.answer.contexts —
            # confirm it is not lazy.
            self.docs.get_evidence(self.answer, key_filter=keys)
        finally:
            self.answer.question = previous_question
        after = len(self.answer.contexts)
        return f"Added {after - before} pieces of evidence." + status(self.answer, self.docs)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()
78 | | - |
79 | | - |
class AnswerTool(BaseTool):
    """Agent tool that asks the document store to propose an answer from gathered evidence."""

    name = "Propose Answer"
    description = "Ask a researcher to propose an answer using evidence from papers. The input is the question to be answered."
    # Shared state; populated in __init__.
    docs: Docs = None
    answer: Answer = None

    def __init__(self, docs, answer):
        """Store the shared document store and answer object."""
        super().__init__()

        self.docs = docs
        self.answer = answer

    def _run(self, query: str) -> str:
        """Query the documents and return the answer text plus a status suffix.

        If the answer text contains "cannot answer", this tool's answer is
        replaced with a fresh ``Answer`` for the same question and a failure
        message is returned instead.
        """
        self.answer = self.docs.query(query, answer=self.answer)
        if "cannot answer" in self.answer.answer:
            # NOTE(review): this rebinds only this tool's ``answer`` attribute.
            # Anything else that was handed the earlier Answer object keeps its
            # own reference — confirm whether evidence should be cleared there too.
            self.answer = Answer(self.answer.question)
            return (
                "Failed to answer question. Deleting evidence. Consider rephrasing question or evidence statement."
                + status(self.answer, self.docs)
            )
        return self.answer.answer + status(self.answer, self.docs)

    # Declared async to match the other tools' ``_arun`` signature.
    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()
106 | | - |
107 | | - |
class Search(BaseTool):
    """Agent tool that searches for papers by keyword and adds them to the document store."""

    name = "Paper Search"
    description = (
        "Search for papers to add to cur. Input should be a string of keywords."
    )
    # Shared state; populated in __init__.
    docs: Docs = None
    answer: Answer = None

    def __init__(self, docs, answer):
        """Store the shared document store and answer object."""
        super().__init__()

        self.docs = docs
        self.answer = answer

    def _run(self, query: str) -> str:
        """Search for up to 20 papers matching ``query`` and add each to ``docs``.

        Adding is best-effort: a paper that cannot be added is logged and
        skipped. Returns the status suffix.

        Raises:
            ImportError: if the optional ``paperscraper`` package is missing.
        """
        import logging

        try:
            import paperscraper
        except ImportError as err:
            raise ImportError(
                "Please install paperscraper (github.com/blackadad/paper-scraper) to use agent"
            ) from err

        papers = paperscraper.search_papers(
            query, limit=20, verbose=False, pdir=self.docs.index_path
        )
        for path, data in papers.items():
            try:
                self.docs.add(path, citation=data["citation"])
            except Exception as err:
                # Deliberately best-effort, but no longer swallows
                # KeyboardInterrupt/SystemExit and leaves a trace of the failure.
                logging.getLogger(__name__).warning(
                    "Skipping paper %s: %s", path, err
                )
        return status(self.answer, self.docs)

    # Declared async to match the other tools' ``_arun`` signature.
    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()
144 | | - |
145 | | - |
def make_tools(docs, answer):
    """Instantiate every agent tool with the same ``docs`` and ``answer``.

    The returned list is ordered: Search, PaperSelection, ReadPapers, AnswerTool.
    """
    tool_classes = (Search, PaperSelection, ReadPapers, AnswerTool)
    return [tool_cls(docs, answer) for tool_cls in tool_classes]
154 | | - |
155 | | - |
def run_agent(docs, question, llm=None):
    """Run the paper-QA agent on ``question`` and return the shared ``Answer``.

    Args:
        docs: document store handed to every tool.
        question: the question to answer.
        llm: chat model driving the agent; defaults to a zero-temperature
            ``ChatOpenAI`` configured for "gpt-4".

    Returns:
        The ``Answer`` object created here and passed to the tools. The string
        returned by the executor's ``run`` call is discarded.
    """
    if llm is None:
        # ``model_name`` is the ChatOpenAI field that selects the model.
        llm = ChatOpenAI(temperature=0.0, model_name="gpt-4")
    answer = Answer(question)
    tools = make_tools(docs, answer)
    mrkl = RetryAgentExecutor.from_agent_and_tools(
        tools=tools,
        agent=ChatZeroShotAgent.from_llm_and_tools(llm, tools),
        verbose=True,
    )
    mrkl.run(
        f"Answer question: {question}. Search for papers, gather evidence, and answer. If you do not have enough evidence, you can search for more papers (preferred) or gather more evidence. You may rephrase or breaking-up the question in those steps. "
        "Once you have five pieces of evidence, or you have tried for a while, call the Propose Answer tool. "
    )

    return answer
| 1 | +from langchain.agents import AgentType, initialize_agent |
| 2 | +from langchain.chains import LLMChain |
| 3 | +from langchain.chat_models import ChatOpenAI |
| 4 | +from langchain.tools import BaseTool |
| 5 | +from rmrkl import ChatZeroShotAgent, RetryAgentExecutor |
| 6 | + |
| 7 | +from .docs import Answer, Docs |
| 8 | +from .qaprompts import make_chain, select_paper_prompt |
| 9 | + |
| 10 | + |
def status(answer: Answer, docs: Docs):
    """Build the progress suffix that tools append to their observations.

    The suffix reports how many document previews ``docs`` currently holds,
    how many evidence contexts ``answer`` has collected, and ``answer.cost``
    formatted with two decimal places.
    """
    paper_count = len(docs.doc_previews())
    evidence_count = len(answer.contexts)
    return (
        f" Status: Current Papers: {paper_count}"
        f" Current Evidence: {evidence_count}"
        f" Current Cost: ${answer.cost:.2f}"
    )
| 13 | + |
| 14 | + |
class PaperSelection(BaseTool):
    """Agent tool that asks the document store which current papers match a query."""

    name = "Select Papers"
    description = "Select from current papers. Provide instructions as a string to use for choosing papers."
    # Shared state; populated in __init__.
    docs: Docs = None
    answer: Answer = None
    chain: LLMChain = None

    def __init__(self, docs, answer):
        """Store the shared ``docs`` and ``answer`` and build the paper-selection chain."""
        super().__init__()

        self.docs = docs
        self.answer = answer
        # Built from the selection prompt and the document store's summary LLM.
        self.chain = make_chain(select_paper_prompt, self.docs.summary_llm)

    def _run(self, query: str) -> str:
        """Return the matched papers plus a status suffix, or a fixed message if none match."""
        matches = self.docs.doc_match(query)
        has_match = matches is not None and not matches.strip().startswith("None")
        if not has_match:
            return "No relevant papers found."
        return matches + status(self.answer, self.docs)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()
| 39 | + |
| 40 | + |
class ReadPapers(BaseTool):
    """Agent tool that gathers evidence for a question, optionally restricted to given paper keys."""

    name = "Gather Evidence"
    description = (
        "Give a specific question to a researcher that will return evidence for it. "
        "Optionally, you may specify papers using their key provided by the Select Papers tool. "
        "Use the format: $QUESTION or use format $QUESTION|$KEY1,$KEY2,..."
    )
    # Shared state; populated in __init__.
    docs: Docs = None
    answer: Answer = None

    def __init__(self, docs, answer):
        """Store the shared document store and answer object."""
        super().__init__()

        self.docs = docs
        self.answer = answer

    def _run(self, query: str) -> str:
        """Gather evidence for ``query`` and report how many contexts were added.

        ``query`` is either ``$QUESTION`` or ``$QUESTION|$KEY1,$KEY2,...``.
        Everything before the first ``|`` is the question; the remainder is a
        comma-separated list of paper keys. Blank keys are dropped, and an
        empty key list means no key filter.
        """
        if "|" in query:
            # maxsplit=1 so extra "|" characters cannot break the two-way unpacking.
            question, raw_keys = query.split("|", 1)
            keys = [k.strip() for k in raw_keys.split(",") if k.strip()] or None
        else:
            question, keys = query, None

        # Temporarily swap in the sub-question; always restore the original,
        # even when evidence gathering fails.
        previous_question = self.answer.question
        self.answer.question = question
        before = len(self.answer.contexts)
        try:
            # NOTE(review): an earlier comment called get_evidence a generator;
            # it is invoked here only for its effect on self.answer.contexts —
            # confirm it is not lazy.
            self.docs.get_evidence(self.answer, key_filter=keys)
        finally:
            self.answer.question = previous_question
        after = len(self.answer.contexts)
        return f"Added {after - before} pieces of evidence." + status(self.answer, self.docs)

    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()
| 78 | + |
| 79 | + |
class AnswerTool(BaseTool):
    """Agent tool that asks the document store to propose an answer from gathered evidence."""

    name = "Propose Answer"
    description = "Ask a researcher to propose an answer using evidence from papers. The input is the question to be answered."
    # Shared state; populated in __init__.
    docs: Docs = None
    answer: Answer = None

    def __init__(self, docs, answer):
        """Store the shared document store and answer object."""
        super().__init__()

        self.docs = docs
        self.answer = answer

    def _run(self, query: str) -> str:
        """Query the documents and return the answer text plus a status suffix.

        If the answer text contains "cannot answer", this tool's answer is
        replaced with a fresh ``Answer`` for the same question and a failure
        message is returned instead.
        """
        self.answer = self.docs.query(query, answer=self.answer)
        if "cannot answer" in self.answer.answer:
            # NOTE(review): this rebinds only this tool's ``answer`` attribute.
            # Anything else that was handed the earlier Answer object keeps its
            # own reference — confirm whether evidence should be cleared there too.
            self.answer = Answer(self.answer.question)
            return (
                "Failed to answer question. Deleting evidence. Consider rephrasing question or evidence statement."
                + status(self.answer, self.docs)
            )
        return self.answer.answer + status(self.answer, self.docs)

    # Declared async to match the other tools' ``_arun`` signature.
    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()
| 106 | + |
| 107 | + |
class Search(BaseTool):
    """Agent tool that searches for papers by keyword and adds them to the document store."""

    name = "Paper Search"
    description = (
        "Search for papers to add to cur. Input should be a string of keywords."
    )
    # Shared state; populated in __init__.
    docs: Docs = None
    answer: Answer = None

    def __init__(self, docs, answer):
        """Store the shared document store and answer object."""
        super().__init__()

        self.docs = docs
        self.answer = answer

    def _run(self, query: str) -> str:
        """Search for up to 20 papers matching ``query`` and add each to ``docs``.

        Adding is best-effort: a paper that cannot be added is logged and
        skipped. Returns the status suffix.

        Raises:
            ImportError: if the optional ``paperscraper`` package is missing.
        """
        import logging

        try:
            import paperscraper
        except ImportError as err:
            raise ImportError(
                "Please install paperscraper (github.com/blackadad/paper-scraper) to use agent"
            ) from err

        papers = paperscraper.search_papers(
            query, limit=20, verbose=False, pdir=self.docs.index_path
        )
        for path, data in papers.items():
            try:
                self.docs.add(path, citation=data["citation"])
            except Exception as err:
                # Deliberately best-effort, but no longer swallows
                # KeyboardInterrupt/SystemExit and leaves a trace of the failure.
                logging.getLogger(__name__).warning(
                    "Skipping paper %s: %s", path, err
                )
        return status(self.answer, self.docs)

    # Declared async to match the other tools' ``_arun`` signature.
    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously (not supported)."""
        raise NotImplementedError()
| 144 | + |
| 145 | + |
def make_tools(docs, answer):
    """Instantiate every agent tool with the same ``docs`` and ``answer``.

    The returned list is ordered: Search, PaperSelection, ReadPapers, AnswerTool.
    """
    tool_classes = (Search, PaperSelection, ReadPapers, AnswerTool)
    return [tool_cls(docs, answer) for tool_cls in tool_classes]
| 154 | + |
| 155 | + |
def run_agent(docs, question, llm=None):
    """Run the paper-QA agent on ``question`` and return the shared ``Answer``.

    When ``llm`` is not given, a zero-temperature ``ChatOpenAI`` with
    ``model_name="gpt-4"`` is created. One ``Answer`` is built for the
    question and handed to every tool; the same object is returned once the
    executor's ``run`` call finishes (its return value is discarded).
    """
    llm = ChatOpenAI(temperature=0.0, model_name="gpt-4") if llm is None else llm
    answer = Answer(question)
    tools = make_tools(docs, answer)

    agent = ChatZeroShotAgent.from_llm_and_tools(llm, tools)
    executor = RetryAgentExecutor.from_agent_and_tools(
        tools=tools, agent=agent, verbose=True
    )

    # Instructions for the agent: search, gather evidence, then propose an answer.
    instructions = (
        f"Answer question: {question}. Search for papers, gather evidence, and answer. If you do not have enough evidence, you can search for more papers (preferred) or gather more evidence. You may rephrase or breaking-up the question in those steps. "
        "Once you have five pieces of evidence, or you have tried for a while, call the Propose Answer tool. "
    )
    executor.run(instructions)

    return answer
0 commit comments