Skip to content

Commit 88d50e8

Browse files
authored
Allow failure on evidence gathering without crashing (#129)
* Allow failure on evidence gathering without crashing * Fixed errors on key filters
1 parent bda7bef commit 88d50e8

File tree

5 files changed

+73
-39
lines changed

5 files changed

+73
-39
lines changed

paperqa/docs.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,11 @@
2727
search_prompt,
2828
select_paper_prompt,
2929
summary_prompt,
30+
get_score,
3031
)
3132
from .readers import read_doc
3233
from .types import Answer, Context
33-
from .utils import maybe_is_text, md5sum, gather_with_concurrency
34+
from .utils import maybe_is_text, md5sum, gather_with_concurrency, guess_is_4xx
3435

3536
os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
3637
langchain.llm_cache = SQLiteCache(CACHE_PATH)
@@ -373,31 +374,46 @@ async def aget_evidence(
373374
docs = self._faiss_index.similarity_search(
374375
answer.question, k=_k, fetch_k=5 * _k
375376
)
377+
# ok now filter
378+
if key_filter is not None:
379+
docs = [doc for doc in docs if doc.metadata["dockey"] in key_filter][:k]
376380

377381
async def process(doc):
378382
if doc.metadata["dockey"] in self._deleted_keys:
379383
return None, None
380-
if key_filter is not None and doc.metadata["dockey"] not in key_filter:
381-
return None, None
382384
# check if it is already in answer (possible in agent setting)
383385
if doc.metadata["key"] in [c.key for c in answer.contexts]:
384386
return None, None
385387
callbacks = [OpenAICallbackHandler()] + get_callbacks(
386388
"evidence:" + doc.metadata["key"]
387389
)
388390
summary_chain = make_chain(summary_prompt, self.summary_llm)
389-
c = Context(
390-
key=doc.metadata["key"],
391-
citation=doc.metadata["citation"],
392-
context=await summary_chain.arun(
391+
# This is dangerous because it
392+
# could mask errors that are important
393+
# I also cannot know what the exception
394+
# type is because any model could be used
395+
# my best idea is see if there is a 4XX
396+
# http code in the exception
397+
try:
398+
context = await summary_chain.arun(
393399
question=answer.question,
394400
context_str=doc.page_content,
395401
citation=doc.metadata["citation"],
396402
callbacks=callbacks,
397-
),
403+
)
404+
except Exception as e:
405+
if guess_is_4xx(e):
406+
return None, None
407+
raise e
408+
c = Context(
409+
key=doc.metadata["key"],
410+
citation=doc.metadata["citation"],
411+
context=context,
398412
text=doc.page_content,
413+
score=get_score(context),
399414
)
400415
if "not applicable" not in c.context.casefold():
416+
print(c.score)
401417
return c, callbacks[0]
402418
return None, None
403419

@@ -411,7 +427,7 @@ async def process(doc):
411427
contexts = [c for c, _ in results if c is not None]
412428
if len(contexts) == 0:
413429
return answer
414-
contexts = sorted(contexts, key=lambda x: len(x.context), reverse=True)
430+
contexts = sorted(contexts, key=lambda x: x.score, reverse=True)
415431
contexts = contexts[:max_sources]
416432
# add to answer (if not already there)
417433
keys = [c.key for c in answer.contexts]
@@ -499,11 +515,12 @@ async def aquery(
499515
if answer is None:
500516
answer = Answer(query)
501517
if len(answer.contexts) == 0:
502-
if key_filter or (key_filter is None and len(self.docs) > 5):
518+
if key_filter or (key_filter is None and len(self.docs) > max_sources):
503519
callbacks = [OpenAICallbackHandler()] + get_callbacks("filter")
504520
keys = await self.adoc_match(answer.question, callbacks=callbacks)
505521
answer.tokens += callbacks[0].total_tokens
506522
answer.cost += callbacks[0].total_cost
523+
key_filter = True if len(keys) > 0 else False
507524
answer = await self.aget_evidence(
508525
answer,
509526
k=k,
@@ -532,8 +549,8 @@ async def aquery(
532549
answer.tokens += cb.total_tokens
533550
answer.cost += cb.total_cost
534551
# it still happens lol
535-
if "(Foo2012)" in answer_text:
536-
answer_text = answer_text.replace("(Foo2012)", "")
552+
if "(Example2012)" in answer_text:
553+
answer_text = answer_text.replace("(Example2012)", "")
537554
for c in contexts:
538555
key = c.key
539556
text = c.context

paperqa/qaprompts.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import copy
22
from datetime import datetime
33
from typing import Any, Dict, List, Optional
4+
import re
45

56
import langchain.prompts as prompts
67
from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
@@ -11,27 +12,26 @@
1112

1213
summary_prompt = prompts.PromptTemplate(
1314
input_variables=["question", "context_str", "citation"],
14-
template="Summarize and provide direct quotes from the text below to help answer a question. "
15-
"Do not directly answer the question, instead summarize and "
16-
"quote to give evidence to help answer the question. "
17-
"Do not use outside sources. "
18-
'Reply with only "Not applicable" if the text is unrelated to the question. '
19-
"Use 100 or less words."
15+
template="Summarize the text below to help answer a question. "
16+
"Do not directly answer the question, instead summarize "
17+
"to give evidence to help answer the question. Include direct quotes. "
18+
'Reply "Not applicable" if text is irrelevant. '
19+
"Use around 100 words. At the end of your response, provide a score from 1-10 on a newline "
20+
"indicating relevance to question. Do not explain your score. "
2021
"\n\n"
2122
"{context_str}\n"
2223
"Extracted from {citation}\n"
2324
"Question: {question}\n"
2425
"Relevant Information Summary:",
2526
)
2627

27-
2828
qa_prompt = prompts.PromptTemplate(
2929
input_variables=["question", "context_str", "length"],
3030
template="Write an answer ({length}) "
3131
"for the question below based on the provided context. "
3232
"If the context provides insufficient information, "
3333
'reply "I cannot answer". '
34-
"For each sentence in your answer, indicate which sources most support it "
34+
"For each part of your answer, indicate which sources most support it "
3535
"via valid citation markers at the end of sentences, like (Example2012). "
3636
"Answer in an unbiased, comprehensive, and scholarly tone. "
3737
"If the question is subjective, provide an opinionated answer in the concluding 1-2 sentences. "
@@ -98,12 +98,21 @@ async def agenerate(
9898
def make_chain(prompt, llm):
9999
if type(llm) == ChatOpenAI:
100100
system_message_prompt = SystemMessage(
101-
content="You are a scholarly researcher that answers in an unbiased, concise, scholarly tone. "
102-
"You sometimes refuse to answer if there is insufficient information. "
103-
"If there are potentially ambiguous terms or acronyms, first define them. ",
101+
content="Answer in an unbiased, concise, scholarly tone. "
102+
"You may refuse to answer if there is insufficient information. "
103+
"If there are ambiguous terms or acronyms, first define them. ",
104104
)
105105
human_message_prompt = HumanMessagePromptTemplate(prompt=prompt)
106106
prompt = ChatPromptTemplate.from_messages(
107107
[system_message_prompt, human_message_prompt]
108108
)
109109
return FallbackLLMChain(prompt=prompt, llm=llm)
110+
111+
112+
def get_score(text):
113+
score = re.search(r"[sS]core[:is\s]+([0-9]+)", text)
114+
if score:
115+
return int(score.group(1))
116+
if len(text) < 100:
117+
return 1
118+
return 5

paperqa/types.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,29 @@
55
StrPath = Union[str, Path]
66

77

8+
@dataclass
9+
class Context:
10+
"""A class to hold the context of a question."""
11+
12+
key: str
13+
citation: str
14+
context: str
15+
text: str
16+
score: int = 5
17+
18+
def __str__(self) -> str:
19+
"""Return the context as a string."""
20+
return self.context
21+
22+
823
@dataclass
924
class Answer:
1025
"""A class to hold the answer to a question."""
1126

1227
question: str
1328
answer: str = ""
1429
context: str = ""
15-
contexts: List[Any] = None
30+
contexts: List[Context] = None
1631
references: str = ""
1732
formatted_answer: str = ""
1833
passages: Dict[str, str] = None
@@ -29,17 +44,3 @@ def __post_init__(self):
2944
def __str__(self) -> str:
3045
"""Return the answer as a string."""
3146
return self.formatted_answer
32-
33-
34-
@dataclass
35-
class Context:
36-
"""A class to hold the context of a question."""
37-
38-
key: str
39-
citation: str
40-
context: str
41-
text: str
42-
43-
def __str__(self) -> str:
44-
"""Return the context as a string."""
45-
return self.context

paperqa/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import math
22
import string
3+
import re
34
import asyncio
45

56
import pypdf
@@ -80,3 +81,9 @@ async def sem_coro(coro):
8081
return await coro
8182

8283
return await asyncio.gather(*(sem_coro(c) for c in coros))
84+
85+
86+
def guess_is_4xx(msg: str) -> bool:
87+
if re.search(r"4\d\d", msg):
88+
return True
89+
return False

paperqa/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.12.0"
1+
__version__ = "1.13.0"

0 commit comments

Comments
 (0)