Skip to content

Commit eeece9f

Browse files
committed
Added support for the latest LangChain chat API
1 parent 8b47b69 commit eeece9f

File tree

3 files changed

+51
-35
lines changed

3 files changed

+51
-35
lines changed

paperqa/docs.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Tuple, Dict, Callable, Any
1+
from typing import List, Optional, Tuple, Dict, Callable, Any, Union
22
from functools import reduce
33
import os
44
import os
@@ -10,13 +10,13 @@
1010
qa_prompt,
1111
search_prompt,
1212
citation_prompt,
13-
chat_pref,
13+
make_chain,
1414
)
1515
from dataclasses import dataclass
1616
from .readers import read_doc
1717
from langchain.vectorstores import FAISS
1818
from langchain.embeddings.openai import OpenAIEmbeddings
19-
from langchain.llms import OpenAI, OpenAIChat
19+
from langchain.chat_models import ChatOpenAI
2020
from langchain.llms.base import LLM
2121
from langchain.chains import LLMChain
2222
from langchain.callbacks import get_openai_callback
@@ -64,7 +64,7 @@ def __init__(
6464
summary_llm: Optional[LLM] = None,
6565
name: str = "default",
6666
index_path: Optional[Path] = None,
67-
model_name: str = 'gpt-3.5-turbo'
67+
model_name: str = "gpt-3.5-turbo",
6868
) -> None:
6969
"""Initialize the collection of documents.
7070
@@ -82,26 +82,32 @@ def __init__(
8282
self.chunk_size_limit = chunk_size_limit
8383
self.keys = set()
8484
self._faiss_index = None
85-
if llm is None:
86-
llm = OpenAIChat(temperature=0.1, max_tokens=512, prefix_messages=chat_pref, model_name=model_name)
87-
if summary_llm is None:
88-
summary_llm = llm
8985
self.update_llm(llm, summary_llm)
9086
if index_path is None:
9187
index_path = Path.home() / ".paperqa" / name
9288
self.index_path = index_path
9389
self.name = name
9490

95-
def update_llm(self, llm: LLM, summary_llm: Optional[LLM] = None) -> None:
91+
def update_llm(
92+
self,
93+
llm: Optional[Union[LLM, str]] = None,
94+
summary_llm: Optional[Union[LLM, str]] = None,
95+
) -> None:
9696
"""Update the LLM for answering questions."""
97+
if llm is None:
98+
llm = "gpt-3.5-turbo"
99+
if type(llm) is str:
100+
llm = ChatOpenAI(temperature=0.1, model=llm)
101+
if type(summary_llm) is str:
102+
summary_llm = ChatOpenAI(temperature=0.1, model=summary_llm)
97103
self.llm = llm
98104
if summary_llm is None:
99105
summary_llm = llm
100106
self.summary_llm = summary_llm
101-
self.summary_chain = LLMChain(prompt=summary_prompt, llm=summary_llm)
102-
self.qa_chain = LLMChain(prompt=qa_prompt, llm=llm)
103-
self.search_chain = LLMChain(prompt=search_prompt, llm=llm)
104-
self.cite_chain = LLMChain(prompt=citation_prompt, llm=llm)
107+
self.summary_chain = make_chain(prompt=summary_prompt, llm=summary_llm)
108+
self.qa_chain = make_chain(prompt=qa_prompt, llm=llm)
109+
self.search_chain = make_chain(prompt=search_prompt, llm=summary_llm)
110+
self.cite_chain = make_chain(prompt=citation_prompt, llm=summary_llm)
105111

106112
def add(
107113
self,
@@ -112,12 +118,12 @@ def add(
112118
chunk_chars: Optional[int] = 3000,
113119
) -> None:
114120
"""Add a document to the collection."""
115-
116-
# first check to see if we already have this document
121+
122+
# first check to see if we already have this document
117123
# this way we don't make api call to create citation on file we already have
118124
if path in self.docs:
119125
raise ValueError(f"Document {path} already in collection.")
120-
126+
121127
if citation is None:
122128
# peak first chunk
123129
texts, _ = read_doc(path, "", "", chunk_chars=chunk_chars)
@@ -126,7 +132,6 @@ def add(
126132
if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
127133
citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"
128134

129-
130135
if key is None:
131136
# get first name and year from citation
132137
try:
@@ -212,9 +217,7 @@ def __setstate__(self, state):
212217
except:
213218
# they use some special exception type, but I don't want to import it
214219
self._faiss_index = None
215-
self.update_llm(
216-
OpenAIChat(temperature=0.1, max_tokens=512, prefix_messages=chat_pref)
217-
)
220+
self.update_llm("gpt-3.5-turbo")
218221

219222
def _build_faiss_index(self):
220223
if self._faiss_index is None:
@@ -252,7 +255,9 @@ def get_evidence(
252255
doc.metadata["key"],
253256
doc.metadata["citation"],
254257
self.summary_chain.run(
255-
question=answer.question, context_str=doc.page_content
258+
question=answer.question,
259+
context_str=doc.page_content,
260+
citation=doc.metadata["citation"],
256261
),
257262
doc.page_content,
258263
)

paperqa/qaprompts.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import langchain.prompts as prompts
22
from datetime import datetime
3+
from langchain.chains import LLMChain
4+
from langchain.chat_models import ChatOpenAI
5+
from langchain.schema import HumanMessage, SystemMessage
6+
from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate
7+
38

49
summary_prompt = prompts.PromptTemplate(
5-
input_variables=["question", "context_str"],
10+
input_variables=["question", "context_str", "citation"],
611
template="Summarize and provide direct quotes from the text below to help answer a question. "
7-
"Do not directly answer the question, instead provide a summary and quotes with the context of the question. "
12+
"Do not directly answer the question, instead summarize and "
13+
"quote to give evidence to help answer the question. "
814
"Do not use outside sources. "
915
'Reply with "Not applicable" if the text is unrelated to the question. '
1016
"Use 75 or less words."
1117
"\n\n"
1218
"{context_str}\n"
13-
"\n"
19+
"Extracted from {citation}\n"
1420
"Question: {question}\n"
1521
"Relevant Information Summary:",
1622
)
@@ -20,7 +26,7 @@
2026
input_variables=["question", "context_str", "length"],
2127
template="Write an answer ({length}) "
2228
"for the question below solely based on the provided context. "
23-
"If the context is irrelevant, "
29+
"If the context provides insufficient information, "
2430
'reply "I cannot answer". '
2531
"For each sentence in your answer, indicate which sources most support it "
2632
"via valid citation markers at the end of sentences, like (Example2012). "
@@ -35,8 +41,8 @@
3541
search_prompt = prompts.PromptTemplate(
3642
input_variables=["question"],
3743
template="We want to answer the following question: {question} \n"
38-
"Provide three different targeted keyword searches (one search per line) "
39-
"that will find papers that help answer the question. Do not use boolean operators. "
44+
"Provide three keyword searches (one search per line) "
45+
"that will find papers to help answer the question. Do not use boolean operators. "
4046
"Recent years are 2021, 2022, 2023.\n\n"
4147
"1.",
4248
)
@@ -55,10 +61,15 @@ def _get_datetime():
5561
partial_variables={"date": _get_datetime},
5662
)
5763

58-
chat_pref = [
59-
{
60-
"role": "system",
61-
"content": "You are a scholarly researcher that answers in an unbiased, scholarly tone. "
62-
"You sometimes refuse to answer if there is insufficient information.",
63-
}
64-
]
64+
65+
def make_chain(prompt, llm):
66+
if type(llm) == ChatOpenAI:
67+
system_message_prompt = SystemMessage(
68+
content="You are a scholarly researcher that answers in an unbiased, scholarly tone. "
69+
"You sometimes refuse to answer if there is insufficient information.",
70+
)
71+
human_message_prompt = HumanMessagePromptTemplate(prompt=prompt)
72+
prompt = ChatPromptTemplate.from_messages(
73+
[system_message_prompt, human_message_prompt]
74+
)
75+
return LLMChain(prompt=prompt, llm=llm)

paperqa/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.30"
1+
__version__ = "0.1.0"

0 commit comments

Comments (0)