
Commit 5926831

Added memory to query (#140)
* Completed memory implementation
* Fixed some missing types

1 parent 899f145 commit 5926831

File tree

8 files changed: +146 -29 lines changed


README.md
Lines changed: 6 additions & 5 deletions

@@ -187,11 +187,12 @@ Version 3 includes many changes to type the code, make it more focused/modular,
 
 The following new features are in v3:
 
-1. `add_url` and `add_file` are now supported for adding from URLs and file objects
-2. Prompts can be customized, and now can be executed pre and post query
-3. Consistent use of `dockey` and `docname` for unique and natural language names enable better tracking with external databases
-4. Texts and embeddings are no longer required to be part of `Docs` object, so you can use external databases or other strategies to manage them
-5. Various simplifications, bug fixes, and performance improvements
+1. Memory is now possible in `query` by setting `Docs(memory=True)` - this means follow-up questions will have a record of the previous question and answer.
+2. `add_url` and `add_file` are now supported for adding from URLs and file objects
+3. Prompts can be customized, and now can be executed pre and post query
+4. Consistent use of `dockey` and `docname` for unique and natural language names enable better tracking with external databases
+5. Texts and embeddings are no longer required to be part of `Docs` object, so you can use external databases or other strategies to manage them
+6. Various simplifications, bug fixes, and performance improvements
 
 ### Naming
 
paperqa/chains.py
Lines changed: 39 additions & 3 deletions

@@ -8,12 +8,20 @@
 )
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
-from langchain.prompts import StringPromptTemplate
+from langchain.memory.chat_memory import BaseChatMemory
+from langchain.prompts import BasePromptTemplate, PromptTemplate, StringPromptTemplate
 from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
 from langchain.schema import LLMResult, SystemMessage
 
 from .types import CBManager
 
+memory_prompt = PromptTemplate(
+    input_variables=["memory", "start"],
+    template="Previous answers that may be helpful:\n\n{memory}\n\n"
+    "----------------------------------------\n\n"
+    "{start}",
+)
+
 
 class FallbackLLMChain(LLMChain):
     """Chain that falls back to synchronous generation if the async generation fails."""

@@ -32,16 +40,44 @@ async def agenerate(
         return self.generate(input_list)
 
 
+# TODO: If upstream is fixed remove this
+
+
+class ExtendedHumanMessagePromptTemplate(HumanMessagePromptTemplate):
+    prompt: BasePromptTemplate
+
+
 def make_chain(
-    prompt: StringPromptTemplate, llm: BaseLanguageModel, skip_system: bool = False
+    prompt: StringPromptTemplate,
+    llm: BaseLanguageModel,
+    skip_system: bool = False,
+    memory: Optional[BaseChatMemory] = None,
 ) -> FallbackLLMChain:
+    if memory and len(memory.load_memory_variables({})["memory"]) > 0:
+        # we copy the prompt so we don't modify the original
+        # TODO: Figure out pipeline prompts to avoid this
+        # the problem with pipeline prompts is that
+        # the memory is a constant (or partial), not a prompt
+        # and I cannot seem to make an empty prompt (or str)
+        # work as an input to pipeline prompt
+        assert isinstance(
+            prompt, PromptTemplate
+        ), "Memory only works with prompt templates - see comment above"
+        assert "memory" in memory.load_memory_variables({})
+        new_prompt = PromptTemplate(
+            input_variables=prompt.input_variables,
+            template=memory_prompt.format(
+                start=prompt.template, **memory.load_memory_variables({})
+            ),
+        )
+        prompt = new_prompt
     if type(llm) == ChatOpenAI:
         system_message_prompt = SystemMessage(
             content="Answer in an unbiased, concise, scholarly tone. "
             "You may refuse to answer if there is insufficient information. "
             "If there are ambiguous terms or acronyms, first define them. ",
         )
-        human_message_prompt = HumanMessagePromptTemplate(prompt=prompt)
+        human_message_prompt = ExtendedHumanMessagePromptTemplate(prompt=prompt)
        if skip_system:
            chat_prompt = ChatPromptTemplate.from_messages([human_message_prompt])
        else:
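As context, a small standalone sketch of what the memory branch of `make_chain` builds, reusing the `memory_prompt` defined in this diff (the stand-in prompt and stored Q/A string are illustrative):

```python
from langchain.prompts import PromptTemplate

memory_prompt = PromptTemplate(
    input_variables=["memory", "start"],
    template="Previous answers that may be helpful:\n\n{memory}\n\n"
    "----------------------------------------\n\n"
    "{start}",
)

# A stand-in for one of the collection's prompts (e.g. the qa prompt):
qa_prompt = PromptTemplate(
    input_variables=["question"], template="Answer this: {question}"
)

# make_chain formats memory_prompt with the stored record as `memory` and the
# original template *string* as `start`. str.format only substitutes {memory}
# and {start}, so the {question} placeholder survives into the new template
# and input_variables stay unchanged.
wrapped = PromptTemplate(
    input_variables=qa_prompt.input_variables,
    template=memory_prompt.format(
        memory="Question: When was the cease-fire?\nAnswer: In 1939.",
        start=qa_prompt.template,
    ),
)
print(wrapped.format(question="What happened after that?"))
```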

paperqa/docs.py
Lines changed: 65 additions & 15 deletions

@@ -12,6 +12,8 @@
 from langchain.chat_models import ChatOpenAI
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.openai import OpenAIEmbeddings
+from langchain.memory import ConversationTokenBufferMemory
+from langchain.memory.chat_memory import BaseChatMemory
 from langchain.vectorstores import FAISS, VectorStore
 from pydantic import BaseModel, validator

@@ -48,19 +50,39 @@ class Docs(BaseModel, arbitrary_types_allowed=True, smart_union=True):
     max_concurrent: int = 5
     deleted_dockeys: Set[DocKey] = set()
     prompts: PromptCollection = PromptCollection()
+    memory: bool = False
+    memory_model: Optional[BaseChatMemory] = None
 
     # TODO: Not sure how to get this to work
     # while also passing mypy checks
     @validator("llm", "summary_llm")
     def check_llm(cls, v: Union[BaseLanguageModel, str]) -> BaseLanguageModel:
         if type(v) is str:
             return ChatOpenAI(temperature=0.1, model=v, client=None)
-        return v
+        return cast(BaseLanguageModel, v)
 
     @validator("summary_llm", always=True)
     def copy_llm_if_not_set(cls, v, values):
         return v or values["llm"]
 
+    @validator("memory_model", always=True)
+    def check_memory_model(cls, v, values):
+        if values["memory"]:
+            if v is None:
+                return ConversationTokenBufferMemory(
+                    llm=values["summary_llm"],
+                    max_token_limit=512,
+                    memory_key="memory",
+                    human_prefix="Question",
+                    ai_prefix="Answer",
+                    input_key="Question",
+                    output_key="Answer",
+                )
+            if v.memory_variables()[0] != "memory":
+                raise ValueError("Memory model must have memory_variables=['memory']")
+            return values["memory_model"]
+        return None
+
     def update_llm(
         self,
         llm: Union[BaseLanguageModel, str],

@@ -76,7 +98,7 @@ def update_llm(
             summary_llm = llm
         self.summary_llm = cast(BaseLanguageModel, summary_llm)
 
-    def get_unique_name(self, docname: str) -> str:
+    def _get_unique_name(self, docname: str) -> str:
         """Create a unique name given proposed name"""
         suffix = ""
         while docname + suffix in self.docnames:

@@ -182,12 +204,14 @@ def add(
         if match is not None:
             year = match.group(1)  # type: ignore
         docname = f"{author}{year}"
-        docname = self.get_unique_name(docname)
+        docname = self._get_unique_name(docname)
         doc = Doc(docname=docname, citation=citation, dockey=dockey)
         texts = read_doc(path, doc, chunk_chars=chunk_chars, overlap=100)
         # loose check to see if document was loaded
-        if len(texts[0].text) < 10 or (
-            not disable_check and not maybe_is_text(texts[0].text)
+        if (
+            len(texts) == 0
+            or len(texts[0].text) < 10
+            or (not disable_check and not maybe_is_text(texts[0].text))
         ):
             raise ValueError(
                 f"This does not look like a text document: {path}. Path disable_check to ignore this error."

@@ -206,7 +230,7 @@ def add_texts(
         if len(texts) == 0:
             raise ValueError("No texts to add.")
         if doc.docname in self.docnames:
-            new_docname = self.get_unique_name(doc.docname)
+            new_docname = self._get_unique_name(doc.docname)
             for t in texts:
                 t.name = t.name.replace(doc.docname, new_docname)
             doc.docname = new_docname

@@ -261,7 +285,9 @@ async def adoc_match(
             query, k=k + len(self.deleted_dockeys)
         )
         matched_docs = [self.docs[m.metadata["dockey"]] for m in matches]
-        chain = make_chain(self.prompts.select, cast(BaseLanguageModel, self.llm))
+        chain = make_chain(
+            self.prompts.select, cast(BaseLanguageModel, self.llm), skip_system=True
+        )
         papers = [f"{d.docname}: {d.citation}" for d in matched_docs]
         result = await chain.arun(  # type: ignore
             question=query, papers="\n".join(papers), callbacks=get_callbacks("filter")

@@ -298,6 +324,11 @@ def _build_texts_index(self):
             metadatas=metadatas,
         )
 
+    def clear_memory(self):
+        """Clear the memory of the model."""
+        if self.memory_model is not None:
+            self.memory_model.clear()
+
     def get_evidence(
         self,
         answer: Answer,

@@ -375,7 +406,9 @@ async def aget_evidence(
 
         async def process(match):
             callbacks = get_callbacks("evidence:" + match.metadata["name"])
-            summary_chain = make_chain(self.prompts.summary, self.summary_llm)
+            summary_chain = make_chain(
+                self.prompts.summary, self.summary_llm, memory=self.memory_model
+            )
             # This is dangerous because it
             # could mask errors that are important- like auth errors
             # I also cannot know what the exception

@@ -391,7 +424,7 @@ async def process(match):
                     callbacks=callbacks,
                 )
             except Exception as e:
-                if guess_is_4xx(e):
+                if guess_is_4xx(str(e)):
                     return None
                 raise e
             if "not applicable" in context.lower():

@@ -476,9 +509,9 @@ async def aquery(
         if answer is None:
             answer = Answer(question=query, answer_length=length_prompt)
         if len(answer.contexts) == 0:
-            # this is heuristic - max_sources and len(docs) are not
+            # this is heuristic - k and len(docs) are not
             # comparable - one is chunks and one is docs
-            if key_filter or (key_filter is None and len(self.docs) > max_sources):
+            if key_filter or (key_filter is None and len(self.docs) > k):
                 keys = await self.adoc_match(
                     answer.question, get_callbacks=get_callbacks
                 )

@@ -492,19 +525,27 @@ async def aquery(
                 get_callbacks=get_callbacks,
             )
         if self.prompts.pre is not None:
-            chain = make_chain(self.prompts.pre, self.llm)
+            chain = make_chain(
+                self.prompts.pre,
+                cast(BaseLanguageModel, self.llm),
+                memory=self.memory_model,
+            )
             pre = await chain.arun(
                 question=answer.question, callbacks=get_callbacks("pre")
             )
             answer.context = pre + "\n\n" + answer.context
         bib = dict()
-        if len(answer.context) < 10:
+        if len(answer.context) < 10 and not self.memory:
             answer_text = (
                 "I cannot answer this question due to insufficient information."
             )
         else:
             callbacks = get_callbacks("answer")
-            qa_chain = make_chain(self.prompts.qa, self.llm)
+            qa_chain = make_chain(
+                self.prompts.qa,
+                cast(BaseLanguageModel, self.llm),
+                memory=self.memory_model,
+            )
             answer_text = await qa_chain.arun(
                 context=answer.context,
                 answer_length=answer.answer_length,

@@ -531,11 +572,20 @@ async def aquery(
         answer.references = bib_str
 
         if self.prompts.post is not None:
-            chain = make_chain(self.prompts.post, self.llm)
+            chain = make_chain(
+                self.prompts.post,
+                cast(BaseLanguageModel, self.llm),
+                memory=self.memory_model,
+            )
             post = await chain.arun(**answer.dict(), callbacks=get_callbacks("post"))
             answer.answer = post
             answer.formatted_answer = f"Question: {query}\n\n{post}\n"
             if len(bib) > 0:
                 answer.formatted_answer += f"\nReferences\n\n{bib_str}\n"
+        if self.memory_model is not None:
+            answer.memory = self.memory_model.load_memory_variables(inputs={})["memory"]
+            self.memory_model.save_context(
+                {"Question": answer.question}, {"Answer": answer.answer}
+            )
 
         return answer
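As context, a standalone sketch of the default memory configured by `check_memory_model` above, together with the save/load calls from the end of `aquery` (the model name and Q/A strings are illustrative; token counting happens locally via tiktoken, no API call is made):

```python
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationTokenBufferMemory

memory = ConversationTokenBufferMemory(
    llm=ChatOpenAI(temperature=0.1, model="gpt-3.5-turbo", client=None),
    max_token_limit=512,  # oldest Q/A pairs are pruned past this token budget
    memory_key="memory",  # so load_memory_variables({}) returns {"memory": ...}
    human_prefix="Question",
    ai_prefix="Answer",
    input_key="Question",
    output_key="Answer",
)

# Mirrors the end of aquery: record the finished Q/A pair...
memory.save_context(
    {"Question": "When was the cease-fire?"}, {"Answer": "In 1939."}
)

# ...and this is the string a later query gets prepended via make_chain,
# e.g. "Question: When was the cease-fire?\nAnswer: In 1939."
print(memory.load_memory_variables({})["memory"])
```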

paperqa/prompts.py
Lines changed: 3 additions & 3 deletions

@@ -6,13 +6,13 @@
     input_variables=["text", "citation", "question", "summary_length"],
     template="Summarize the text below to help answer a question. "
     "Do not directly answer the question, instead summarize "
-    "to give evidence to help answer the question. Include direct quotes. "
+    "to give evidence to help answer the question. "
     'Reply "Not applicable" if text is irrelevant. '
     "Use {summary_length}. At the end of your response, provide a score from 1-10 on a newline "
     "indicating relevance to question. Do not explain your score. "
     "\n\n"
-    "{text}\n"
-    "Extracted from {citation}\n"
+    "{text}\n\n"
+    "Excerpt from {citation}\n"
     "Question: {question}\n"
     "Relevant Information Summary:",
 )

paperqa/types.py
Lines changed: 1 addition & 0 deletions

@@ -106,6 +106,7 @@ class Answer(BaseModel):
     dockey_filter: Optional[Set[DocKey]] = None
     summary_length: str = "about 100 words"
     answer_length: str = "about 100 words"
+    memory: Optional[str] = None
     # these two below are for convenience
     # and are not set. But you can set them
     # if you want to use them.
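A quick illustration of the new field (assuming `Answer`'s other fields keep their defaults):

```python
from paperqa.types import Answer

answer = Answer(question="When was the cease-fire?")
# memory stays None until a memory-enabled Docs.query populates it
assert answer.memory is None
```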

paperqa/version.py
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "3.0.0.dev2"
+__version__ = "3.0.0.dev3"

setup.py
Lines changed: 2 additions & 2 deletions

@@ -18,12 +18,12 @@
     packages=["paperqa", "paperqa.contrib"],
     install_requires=[
         "pypdf",
-        "langchain>=0.0.195",
+        "langchain>=0.0.198",
         "openai >= 0.27.8",
         "faiss-cpu",
         "PyCryptodome",
         "html2text",
-        "tiktoken",
+        "tiktoken>=0.4.0",
     ],
     test_suite="tests",
     long_description=long_description,

tests/test_paperqa.py
Lines changed: 29 additions & 0 deletions

@@ -529,3 +529,32 @@ def test_post_prompt():
         f.write(r.text)
     docs.add(doc_path, "WikiMedia Foundation, 2023, Accessed now")
     docs.query("What country is Bates from?")
+
+
+def test_memory():
+    docs = Docs(memory=True, k=3, max_sources=1, llm="gpt-3.5-turbo", key_filter=False)
+    docs.add_url(
+        "https://en.wikipedia.org/wiki/Red_Army",
+        citation="WikiMedia Foundation, 2023, Accessed now",
+        dockey="test",
+    )
+    answer1 = docs.query("When did the Soviet Union and Japan agree to a cease-fire?")
+    print(answer1.answer)
+    assert answer1.memory is not None
+    assert "1939" in answer1.answer
+    assert "Answer" in docs.memory_model.load_memory_variables({})["memory"]
+    answer2 = docs.query("When was the conflict resolved?")
+    assert "1941" in answer2.answer or "1945" in answer2.answer
+    assert answer2.memory is not None
+    assert "Answer" in docs.memory_model.load_memory_variables({})["memory"]
+    print(answer2.answer)
+
+    docs.clear_memory()
+
+    answer3 = docs.query("When was the conflict resolved?")
+    assert answer3.memory is not None
+    assert (
+        "I cannot answer" in answer3.answer
+        or "insufficient" in answer3.answer
+        or "does not provide" in answer3.answer
+    )