Skip to content

Commit 4eeed82

Browse files
committed
Made marginal search optional
1 parent 06de777 commit 4eeed82

File tree

2 files changed

+47
-15
lines changed

2 files changed

+47
-15
lines changed

paperqa/docs.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
from langchain.llms.base import LLM
1313
from langchain.chains import LLMChain
1414
from langchain.callbacks import get_openai_callback
15+
from langchain.cache import InMemoryCache
16+
import langchain
17+
18+
langchain.llm_cache = InMemoryCache()
1519

1620

1721
@dataclass
@@ -200,14 +204,20 @@ def get_evidence(
200204
answer: Answer,
201205
k: int = 3,
202206
max_sources: int = 5,
207+
marginal_relevance: bool = True,
203208
) -> str:
204209
if self._faiss_index is None:
205210
self._build_faiss_index()
206211

207212
# want to work through indices but less k
208-
docs = self._faiss_index.max_marginal_relevance_search(
209-
answer.question, k=k, fetch_k=5 * k
210-
)
213+
if marginal_relevance:
214+
docs = self._faiss_index.max_marginal_relevance_search(
215+
answer.question, k=k, fetch_k=5 * k
216+
)
217+
else:
218+
docs = self._faiss_index.similarity_search(
219+
answer.question, k=k, fetch_k=5 * k
220+
)
211221
for doc in docs:
212222
c = (
213223
doc.metadata["key"],
@@ -251,9 +261,14 @@ def query_gen(
251261
k: int = 10,
252262
max_sources: int = 5,
253263
length_prompt: str = "about 100 words",
264+
marginal_relevance: bool = True,
254265
):
255266
yield from self._query(
256-
query, k=k, max_sources=max_sources, length_prompt=length_prompt
267+
query,
268+
k=k,
269+
max_sources=max_sources,
270+
length_prompt=length_prompt,
271+
marginal_relevance=marginal_relevance,
257272
)
258273

259274
def query(
@@ -262,20 +277,37 @@ def query(
262277
k: int = 10,
263278
max_sources: int = 5,
264279
length_prompt: str = "about 100 words",
280+
marginal_relevance: bool = True,
265281
):
266282
for answer in self._query(
267-
query, k=k, max_sources=max_sources, length_prompt=length_prompt
283+
query,
284+
k=k,
285+
max_sources=max_sources,
286+
length_prompt=length_prompt,
287+
marginal_relevance=marginal_relevance,
268288
):
269289
pass
270290
return answer
271291

272-
def _query(self, query: str, k: int, max_sources: int, length_prompt: str):
292+
def _query(
293+
self,
294+
query: str,
295+
k: int,
296+
max_sources: int,
297+
length_prompt: str,
298+
marginal_relevance: bool,
299+
):
273300
if k < max_sources:
274301
raise ValueError("k should be greater than max_sources")
275302
tokens = 0
276303
answer = Answer(query)
277304
with get_openai_callback() as cb:
278-
for answer in self.get_evidence(answer, k=k, max_sources=max_sources):
305+
for answer in self.get_evidence(
306+
answer,
307+
k=k,
308+
max_sources=max_sources,
309+
marginal_relevance=marginal_relevance,
310+
):
279311
yield answer
280312
tokens += cb.total_tokens
281313
context_str, citations = answer.context, answer.contexts
@@ -290,11 +322,14 @@ def _query(self, query: str, k: int, max_sources: int, length_prompt: str):
290322
answer_text = self.qa_chain.run(
291323
question=query, context_str=context_str, length=length_prompt
292324
)[1:]
293-
if maybe_is_truncated(answer_text):
294-
answer_text = self.edit_chain.run(
295-
question=query, answer=answer_text
296-
)
325+
# if maybe_is_truncated(answer_text):
326+
# answer_text = self.edit_chain.run(
327+
# question=query, answer=answer_text
328+
# )
297329
tokens += cb.total_tokens
330+
# it still happens lol
331+
if "(Foo2012)" in answer_text:
332+
answer_text = answer_text.replace("(Foo2012)", "")
298333
for key, citation, summary, text in citations:
299334
# do check for whole key (so we don't catch Callahan2019a with Callahan2019)
300335
skey = key.split(" ")[0]

paperqa/qaprompts.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,8 @@
2424
"For each sentence in your answer, indicate which sources most support it "
2525
"via valid citation markers at the end of sentences, like (Foo2012). "
2626
"Answer in an unbiased, balanced, and scientific tone. "
27-
"Use Markdown for formatting code or text. "
28-
# "write a complete unbiased answer prefixed by \"Answer:\""
29-
"\n--------------------\n"
27+
"Use Markdown for formatting code or text.\n\n"
3028
"{context_str}\n"
31-
"----------------------\n"
3229
"Question: {question}\n"
3330
"Answer: ",
3431
)

0 commit comments

Comments
 (0)