Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,27 +252,30 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
status = state.status
logger.info(status)
# only show top n contexts for this particular question to the agent
# only show context above score 0, because 0 is a sentinel for irrelevance
sorted_relevant_contexts = sorted(
[
sorted_contexts = sorted(
(
c
for c in state.session.contexts
if ((c.question is None or c.question == question) and c.score > 0)
],
if c.question is None or c.question == question
),
key=lambda x: x.score,
reverse=True,
)

top_contexts = "\n".join(
top_contexts = "\n\n".join(
[
f"{n + 1}. {sc.context}\n"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note there was trailing whitespace here previously, now there's not

f"- {sc.context}"
for n, sc in enumerate(
sorted_relevant_contexts[: self.settings.agent.agent_evidence_n]
sorted_contexts[: self.settings.agent.agent_evidence_n]
)
]
)

best_evidence = f" Best evidence(s):\n\n{top_contexts}" if top_contexts else ""
best_evidence = (
f" Best evidence(s) for the current question:\n\n{top_contexts}"
if top_contexts
else ""
)

if f"{self.TOOL_FN_NAME}_completed" in self.settings.agent.callbacks:
await asyncio.gather(
Expand Down
3 changes: 2 additions & 1 deletion src/paperqa/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,8 @@ async def aget_evidence(
for r in llm_results:
session.add_tokens(r)

session.contexts += [c for c, _ in results if c is not None]
# Filter out failed context creations or irrelevant contexts
session.contexts += [c for c, _ in results if c is not None and c.score > 0]
return session

def query(
Expand Down
12 changes: 5 additions & 7 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,11 +711,9 @@ def new_status(state: EnvironmentState) -> str:
total_added_1 = int(split[1])
assert total_added_1 > 0, "Expected non-negative added evidence count"
assert len(env_state.get_relevant_contexts()) == total_added_1
# ensure 1 piece of top evidence is returned
assert "\n1." in response, "gather_evidence did not return any results"
assert (
"\n2." not in response
), "gather_evidence should return only 1 context, not 2"
response.count("\n- ") == 1
), "Expected exactly one best evidence to be shown"

# now adjust to give the agent 2x pieces of evidence
gather_evidence_tool.settings.agent.agent_evidence_n = 2
Expand Down Expand Up @@ -745,9 +743,9 @@ def new_status(state: EnvironmentState) -> str:
total_added_2 = int(split[1])
assert total_added_2 > 0, "Expected non-negative added evidence count"
assert len(env_state.get_relevant_contexts()) == total_added_1 + total_added_2
# ensure both evidences are returned
assert "\n1." in response, "gather_evidence did not return any results"
assert "\n2." in response, "gather_evidence should return 2 contexts"
assert (
response.count("\n- ") == 2
), "Expected both evidences to be shown as best evidences"

assert session.contexts, "Evidence did not return any results"
assert not session.answer, "Expected no answer yet"
Expand Down
39 changes: 15 additions & 24 deletions tests/test_paperqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import os
import pathlib
import pickle
import random
import re
import sys
from collections.abc import AsyncIterable, Sequence
Expand Down Expand Up @@ -1077,17 +1076,21 @@ async def test_unrelated_context(
assert await docs.aadd(
stub_data_dir / "bates.txt", "WikiMedia Foundation, 2023, Accessed now"
)
assert docs.texts, "Test requires at least one text"
session = await docs.aget_evidence(
"What do scientist estimate as the planetary composition of Jupyter?",
settings=agent_test_settings,
)
assert session.contexts, "Test relies on some contexts being added"
session.contexts.append( # Give a context so the rest of the test can run
Context(
context="George Washington is a founding father",
question="What do scientist estimate as the planetary composition of Jupyter?",
text=docs.texts[0],
score=1,
)
)
for c in session.contexts:
assert c.score <= 2, "Expected contexts to be considered irrelevant"
if c.score <= 0:
# Now, let's trick the system into thinking the context
# was at least somewhat relevant
c.score = random.randint(1, 2)
session = await docs.aquery(session, settings=agent_test_settings)
assert unsure_sentinel in session.answer

Expand Down Expand Up @@ -1652,6 +1655,8 @@ async def test_images_corrupt(stub_data_dir: Path) -> None:
)
assert districts_docname, "Expected successful image addition"
(districts_doc,) = (d for d in docs.docs.values() if d.docname == districts_docname)
(districts_text,) = docs.texts
assert not districts_text.text, "Test expects no text content from image addition"
for media in (t.media for t in docs.texts if t.doc == districts_doc and t.media):
for m in media:
# Validate the image, then chop the image in half (breaking it), and
Expand All @@ -1669,27 +1674,13 @@ async def test_images_corrupt(stub_data_dir: Path) -> None:

# By suppressing the use of images, we can actually gather evidence now
settings.answer.evidence_text_only_fallback = True
# The answer will be garbage, but let's make sure we didn't claim to use images
session = await docs.aget_evidence(
"What districts neighbor the Western Addition?", settings=settings
)
assert session.contexts, "Test relies on some contexts being added"
for c in session.contexts:
assert c.score <= 2, "Expected contexts to be considered irrelevant"
if c.score <= 0:
# Now, let's trick the system into thinking the context
# was at least somewhat relevant
c.score = random.randint(1, 2)
await docs.aquery(session, settings=settings)
assert session.used_contexts
assert session.cost > 0
contexts_used = [
c
for c in session.contexts
if c.id in session.used_contexts and c.text.doc == districts_doc
]
assert contexts_used
assert all(not bool(c.used_images) for c in contexts_used) # type: ignore[attr-defined]
assert (
not session.contexts
), "Expected no contexts to be made from a bad image that has no text"
assert session.cost > 0, "Expected some costs to have been incurred in our attempt"


def test_zotero() -> None:
Expand Down