Commit 50de570

Updated to new callback syntax
* Fixed callback bugs
* Bumped version
* Switched to new rmrkl version
1 parent d7c08f9 commit 50de570

File tree

11 files changed: +116 -89 lines changed


paperqa/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
-from .docs import Docs, maybe_is_text, Answer
-from .version import __version__
 from .agent import run_agent
+from .docs import Answer, Docs, maybe_is_text
+from .version import __version__
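
The reordered exports leave the top-level API unchanged; for reference, a minimal import sketch:

    from paperqa import Answer, Docs, maybe_is_text, run_agent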

paperqa/agent.py

Lines changed: 13 additions & 13 deletions
@@ -1,12 +1,12 @@
-from langchain.tools import BaseTool
-from .docs import Answer, Docs
-from langchain.agents import initialize_agent
-from langchain.chat_models import ChatOpenAI
+from langchain.agents import AgentType, initialize_agent
 from langchain.chains import LLMChain
-from langchain.agents import AgentType
-from .qaprompts import select_paper_prompt, make_chain
+from langchain.chat_models import ChatOpenAI
+from langchain.tools import BaseTool
 from rmrkl import ChatZeroShotAgent, RetryAgentExecutor
 
+from .docs import Answer, Docs
+from .qaprompts import make_chain, select_paper_prompt
+
 
 def status(answer: Answer, docs: Docs):
     return f" Status: Current Papers: {len(docs.doc_previews())} Current Evidence: {len(answer.contexts)} Current Cost: ${answer.cost:.2f}"
@@ -91,13 +91,12 @@ def __init__(self, docs, answer):
         self.answer = answer
 
     def _run(self, query: str) -> str:
-        self.answer = self.docs.query(
-            query, answer=self.answer
-        )
+        self.answer = self.docs.query(query, answer=self.answer)
         if "cannot answer" in self.answer.answer:
             self.answer = Answer(self.answer.question)
-            return "Failed to answer question. Deleting evidence. Consider rephrasing question or evidence statement." + status(
-                self.answer, self.docs
+            return (
+                "Failed to answer question. Deleting evidence. Consider rephrasing question or evidence statement."
+                + status(self.answer, self.docs)
             )
         return self.answer.answer + status(self.answer, self.docs)
 
@@ -108,7 +107,9 @@ def _arun(self, query: str) -> str:
 
 class Search(BaseTool):
     name = "Paper Search"
-    description = "Search for papers to add to cur. Input should be a string of keywords."
+    description = (
+        "Search for papers to add to cur. Input should be a string of keywords."
+    )
     docs: Docs = None
     answer: Answer = None
 
@@ -143,7 +144,6 @@ def _arun(self, query: str) -> str:
 
 
 def make_tools(docs, answer):
-
     tools = []
 
     tools.append(Search(docs, answer))
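
Note: the tools above follow langchain's `BaseTool` interface — a `name`, a `description` the agent reads when choosing a tool, a synchronous `_run`, and an async `_arun`. A minimal sketch of a tool in the same shape (this `WordCount` tool is hypothetical, not part of the commit):

    from langchain.tools import BaseTool


    class WordCount(BaseTool):
        # Hypothetical tool for illustration only; mirrors the shape of Search above.
        name = "Word Count"
        description = "Count the words in the input. Input should be a string."

        def _run(self, query: str) -> str:
            return str(len(query.split()))

        async def _arun(self, query: str) -> str:
            # Mirror the synchronous behavior on the async path.
            return self._run(query)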

paperqa/contrib/zotero.py

Lines changed: 4 additions & 4 deletions
@@ -1,17 +1,17 @@
 # This file gets PDF files from the user's Zotero library
-import os
-from typing import Union, Optional
-from pathlib import Path
 import logging
+import os
 from collections import namedtuple
+from pathlib import Path
+from typing import Optional, Union
 
 try:
     from pyzotero import zotero
 except ImportError:
     raise ImportError("Please install pyzotero: `pip install pyzotero`")
 from ..paths import CACHE_PATH
-from ..utils import count_pdf_pages
 from ..types import StrPath
+from ..utils import count_pdf_pages
 
 ZoteroPaper = namedtuple(
     "ZoteroPaper", ["key", "title", "pdf", "num_pages", "zotero_key", "details"]

paperqa/docs.py

Lines changed: 49 additions & 44 deletions
@@ -1,33 +1,30 @@
-from typing import List, Optional, Tuple, Union, Callable
-from functools import reduce
+import asyncio
 import os
+import re
 import sys
-import asyncio
+from datetime import datetime
+from functools import reduce
 from pathlib import Path
-import re
-from .paths import CACHE_PATH
-from .utils import maybe_is_text, md5sum
-from .qaprompts import (
-    summary_prompt,
-    qa_prompt,
-    search_prompt,
-    citation_prompt,
-    select_paper_prompt,
-    make_chain,
-)
-from .types import Answer, Context
-from .readers import read_doc
-from langchain.vectorstores import FAISS
+from typing import Callable, List, Optional, Tuple, Union
+
+import langchain
+from langchain.cache import SQLiteCache
+from langchain.callbacks import OpenAICallbackHandler, get_openai_callback
+from langchain.callbacks.base import AsyncCallbackHandler
+from langchain.callbacks.manager import AsyncCallbackManager
+from langchain.chat_models import ChatOpenAI
 from langchain.docstore.document import Document
-from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.embeddings.base import Embeddings
-from langchain.chat_models import ChatOpenAI
+from langchain.embeddings.openai import OpenAIEmbeddings
 from langchain.llms.base import LLM
-from langchain.callbacks import get_openai_callback, OpenAICallbackHandler
-from langchain.callbacks.base import AsyncCallbackHandler, AsyncCallbackManager
-from langchain.cache import SQLiteCache
-import langchain
-from datetime import datetime
+from langchain.vectorstores import FAISS
+
+from .paths import CACHE_PATH
+from .qaprompts import (citation_prompt, make_chain, qa_prompt, search_prompt,
+                        select_paper_prompt, summary_prompt)
+from .readers import read_doc
+from .types import Answer, Context
+from .utils import maybe_is_text, md5sum
 
 os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
 langchain.llm_cache = SQLiteCache(CACHE_PATH)
@@ -44,7 +41,6 @@ def __init__(
         name: str = "default",
         index_path: Optional[Path] = None,
         embeddings: Optional[Embeddings] = None,
-        get_callbacks: Callable[[str], AsyncCallbackHandler] = lambda x : []
     ) -> None:
         """Initialize the collection of documents.
 
@@ -55,7 +51,6 @@ def __init__(
             name: The name of the collection.
             index_path: The path to the index file IF pickled. If None, defaults to using name in $HOME/.paperqa/name
             embeddings: The embeddings to use for indexing documents. Default - OpenAI embeddings
-            get_callbacks: A function that allows callbacks to built per stage of the pipeline.
         """
         self.docs = dict()
         self.chunk_size_limit = chunk_size_limit
@@ -70,7 +65,6 @@ def __init__(
        if embeddings is None:
            embeddings = OpenAIEmbeddings()
        self.embeddings = embeddings
-       self.get_callbacks = get_callbacks
 
    def update_llm(
        self,
@@ -96,7 +90,6 @@ def add(
         key: Optional[str] = None,
         disable_check: bool = False,
         chunk_chars: Optional[int] = 3000,
-        overwrite: bool = False,
     ) -> None:
         """Add a document to the collection."""
 
@@ -110,8 +103,9 @@ def add(
         cite_chain = make_chain(prompt=citation_prompt, llm=self.summary_llm)
         # peak first chunk
         texts, _ = read_doc(path, "", "", chunk_chars=chunk_chars)
-        with get_openai_callback():
-            citation = cite_chain.run(texts[0])
+        if len(texts) == 0:
+            raise ValueError(f"Could not read document {path}. Is it empty?")
+        citation = cite_chain.run(texts[0])
         if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
             citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
 
@@ -207,12 +201,10 @@ def __getstate__(self):
         state["_faiss_index"].save_local(self.index_path)
         del state["_faiss_index"]
         del state["_doc_index"]
-        del state["get_callbacks"]
         return state
 
     def __setstate__(self, state):
         self.__dict__.update(state)
-        self.get_callbacks = lambda x: []
         try:
             self._faiss_index = FAISS.load_local(self.index_path, self.embeddings)
         except:
@@ -240,8 +232,9 @@ def get_evidence(
         k: int = 3,
         max_sources: int = 5,
         marginal_relevance: bool = True,
-        key_filter: Optional[List[str]] = None
-    ) -> Answer:
+        key_filter: Optional[List[str]] = None,
+        get_callbacks: Callable[[str], AsyncCallbackHandler] = lambda x: [],
+    ) -> Answer:
         # special case for jupyter notebooks
         if "get_ipython" in globals() or "google.colab" in sys.modules:
             import nest_asyncio
@@ -258,7 +251,8 @@
                 k=k,
                 max_sources=max_sources,
                 marginal_relevance=marginal_relevance,
-                key_filter=key_filter
+                key_filter=key_filter,
+                get_callbacks=get_callbacks,
             )
         )
 
@@ -269,6 +263,7 @@ async def aget_evidence(
         max_sources: int = 5,
         marginal_relevance: bool = True,
         key_filter: Optional[List[str]] = None,
+        get_callbacks: Callable[[str], AsyncCallbackHandler] = lambda x: [],
     ) -> Answer:
         if len(self.docs) == 0:
             return answer
@@ -293,29 +288,32 @@ async def process(doc):
             # check if it is already in answer (possible in agent setting)
             if doc.metadata["key"] in [c.key for c in answer.contexts]:
                 return None, None
-            cb = OpenAICallbackHandler()
-            manager = AsyncCallbackManager([cb] + self.get_callbacks('evidence:' + doc.metadata['key']))
-            summary_chain = make_chain(summary_prompt, self.summary_llm, manager)
+            callbacks = [OpenAICallbackHandler()] + get_callbacks(
+                "evidence:" + doc.metadata["key"]
+            )
+            summary_chain = make_chain(summary_prompt, self.summary_llm)
             c = Context(
                 key=doc.metadata["key"],
                 citation=doc.metadata["citation"],
                 context=await summary_chain.arun(
                     question=answer.question,
                     context_str=doc.page_content,
                     citation=doc.metadata["citation"],
+                    callbacks=callbacks,
                 ),
                 text=doc.page_content,
             )
             if "Not applicable" not in c.context:
                 return c, cb
             return None, None
 
-        results = await asyncio.gather(*[process(doc) for doc in docs])
+        with get_openai_callback() as cb:
+            results = await asyncio.gather(*[process(doc) for doc in docs])
         # filter out failures
         results = [r for r in results if r[0] is not None]
         answer.tokens += sum([cb.total_tokens for _, cb in results])
         answer.cost += sum([cb.total_cost for _, cb in results])
-        contexts = [c for c,_ in results if c is not None]
+        contexts = [c for c, _ in results if c is not None]
         if len(contexts) == 0:
             return answer
         contexts = sorted(contexts, key=lambda x: len(x.context), reverse=True)
@@ -365,6 +363,7 @@ def query(
         marginal_relevance: bool = True,
         answer: Optional[Answer] = None,
         key_filter: Optional[bool] = None,
+        get_callbacks: Callable[[str], AsyncCallbackHandler] = lambda x: [],
     ) -> Answer:
         # special case for jupyter notebooks
         if "get_ipython" in globals() or "google.colab" in sys.modules:
@@ -385,6 +384,7 @@ def query(
                 marginal_relevance=marginal_relevance,
                 answer=answer,
                 key_filter=key_filter,
+                get_callbacks=get_callbacks,
             )
         )
 
@@ -397,6 +397,7 @@ async def aquery(
         marginal_relevance: bool = True,
         answer: Optional[Answer] = None,
         key_filter: Optional[bool] = None,
+        get_callbacks: Callable[[str], AsyncCallbackHandler] = lambda x: [],
     ) -> Answer:
         if k < max_sources:
             raise ValueError("k should be greater than max_sources")
@@ -414,6 +415,7 @@
                 max_sources=max_sources,
                 marginal_relevance=marginal_relevance,
                 key_filter=keys if key_filter else None,
+                get_callbacks=get_callbacks,
             )
         context_str, contexts = answer.context, answer.contexts
         bib = dict()
@@ -424,11 +426,14 @@
             )
         else:
             cb = OpenAICallbackHandler()
-            manager = AsyncCallbackManager([cb] + self.get_callbacks('answer'))
-            qa_chain = make_chain(qa_prompt, self.llm, manager)
+            callbacks = [OpenAICallbackHandler()] + get_callbacks("answer")
+            qa_chain = make_chain(qa_prompt, self.llm)
             answer_text = await qa_chain.arun(
-                question=query, context_str=context_str, length=length_prompt
-            )
+                question=query,
+                context_str=context_str,
+                length=length_prompt,
+                callbacks=callbacks,
+            )
         answer.tokens += cb.total_tokens
         answer.cost += cb.total_cost
         # it still happens lol
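
Taken together, these changes move callback wiring off the Docs object and onto each call: `query`, `aquery`, `get_evidence`, and `aget_evidence` now take a `get_callbacks` factory that is invoked with a stage name ("evidence:<key>" or "answer") and returns the handlers for that stage. A minimal usage sketch, assuming a hypothetical handler class and file path:

    from langchain.callbacks.base import AsyncCallbackHandler

    from paperqa import Docs


    class StageLogger(AsyncCallbackHandler):
        # Hypothetical handler for illustration; any AsyncCallbackHandler works here.
        def __init__(self, stage: str):
            self.stage = stage

        async def on_llm_end(self, response, **kwargs) -> None:
            print(f"[{self.stage}] LLM call finished")


    docs = Docs()
    docs.add("example.pdf")  # hypothetical path
    answer = docs.query(
        "What manufacturing process was studied?",
        get_callbacks=lambda stage: [StageLogger(stage)],
    )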

paperqa/qaprompts.py

Lines changed: 25 additions & 9 deletions
@@ -1,9 +1,13 @@
-import langchain.prompts as prompts
 from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+import langchain.prompts as prompts
+from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
-from langchain.schema import SystemMessage
-from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate
+from langchain.prompts.chat import (ChatPromptTemplate,
+                                    HumanMessagePromptTemplate)
+from langchain.schema import LLMResult, SystemMessage
 
 summary_prompt = prompts.PromptTemplate(
     input_variables=["question", "context_str", "citation"],
@@ -73,11 +77,23 @@ def _get_datetime():
     partial_variables={"date": _get_datetime},
 )
 
-def make_chain(prompt, llm, callback_manager=None):
-    if callback_manager is not None:
-        # need to clone to attach
-        llm = llm.copy()
-        llm.callback_manager = callback_manager
+
+class FallbackLLMChain(LLMChain):
+    """Chain that falls back to synchronous generation if the async generation fails."""
+
+    async def agenerate(
+        self,
+        input_list: List[Dict[str, Any]],
+        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
+    ) -> LLMResult:
+        """Generate LLM result from inputs."""
+        try:
+            return await super().agenerate(input_list, run_manager=run_manager)
+        except NotImplementedError as e:
+            return self.generate(input_list, run_manager=run_manager)
+
+
+def make_chain(prompt, llm):
     if type(llm) == ChatOpenAI:
         system_message_prompt = SystemMessage(
             content="You are a scholarly researcher that answers in an unbiased, scholarly tone. "
@@ -87,4 +103,4 @@ def make_chain(prompt, llm, callback_manager=None):
         prompt = ChatPromptTemplate.from_messages(
             [system_message_prompt, human_message_prompt]
         )
-    return LLMChain(prompt=prompt, llm=llm)
+    return FallbackLLMChain(prompt=prompt, llm=llm)
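
FallbackLLMChain exists because some LLM classes implement only synchronous generation; routing every chain through it lets the async `arun` path degrade to `generate` instead of raising NotImplementedError. A sketch of the new call shape, with placeholder prompt values:

    from langchain.callbacks import OpenAICallbackHandler
    from langchain.chat_models import ChatOpenAI

    from paperqa.qaprompts import make_chain, summary_prompt

    chain = make_chain(summary_prompt, ChatOpenAI())
    cb = OpenAICallbackHandler()
    # Callbacks now travel with the call instead of being attached to the LLM.
    summary = chain.run(
        question="What is the melting point?",
        context_str="...",  # placeholder context
        citation="Example et al., 2023",  # placeholder citation
        callbacks=[cb],
    )
    print(cb.total_tokens, cb.total_cost)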

paperqa/readers.py

Lines changed: 7 additions & 7 deletions
@@ -1,15 +1,16 @@
-import os
-from .paths import OCR_CACHE_PATH
-from .version import __version__
-from html2text import html2text
-from pathlib import Path
 import json
 import logging
+import os
 from hashlib import md5
+from pathlib import Path
 
-from langchain.text_splitter import TokenTextSplitter
+from html2text import html2text
 from langchain.cache import SQLiteCache
 from langchain.schema import Generation
+from langchain.text_splitter import TokenTextSplitter
+
+from .paths import OCR_CACHE_PATH
+from .version import __version__
 
 OCR_CACHE = None
 
@@ -69,7 +70,6 @@ def parse_pdf(path, citation, key, chunk_chars=2000, overlap=50):
 
 
 def parse_txt(path, citation, key, chunk_chars=2000, overlap=50, html=False):
-
     try:
         with open(path) as f:
             doc = f.read()
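
For reference, docs.py consumes these readers through `read_doc`, which returns the chunk texts plus a second value (unpacked as `_` in docs.py; named `metadata` below as an assumption). A hedged sketch of that call shape, matching the empty-document guard added in docs.py above:

    from paperqa.readers import read_doc

    # Empty citation/key strings are acceptable when only peeking at content,
    # as docs.py does for its citation guess; "example.pdf" is a placeholder.
    texts, metadata = read_doc("example.pdf", "", "", chunk_chars=3000)
    if len(texts) == 0:
        raise ValueError("Could not read document example.pdf. Is it empty?")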
