Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,22 +141,25 @@ def make_clinical_trial_status(
total_clinical_trials: int,
relevant_clinical_trials: int,
evidence_count: int,
relevant_evidence_count: int,
cost: float,
) -> str:
return (
f"Status: Paper Count={total_paper_count}"
f" | Relevant Papers={relevant_paper_count}"
f" | Clinical Trial Count={total_clinical_trials}"
f" | Relevant Clinical Trials={relevant_clinical_trials}"
f" | Current Evidence={evidence_count}"
f" | Evidence Count={evidence_count}"
f" | Relevant Evidence={relevant_evidence_count}"
f" | Current Cost=${cost:.4f}"
)


# SEE: https://regex101.com/r/L0L5MH/1
# SEE: https://regex101.com/r/L0L5MH/4
# Matches the status line containing both the total and the relevant evidence
# counts. The two clinical-trial counts live in an optional non-capturing
# group, so a status line without them still matches (those groups are None).
# Only one pattern may appear here: adjacent string literals are concatenated,
# so leaving a superseded pattern in place yields a regex that never matches.
CLINICAL_STATUS_SEARCH_REGEX_PATTERN: str = (
    r"Status: Paper Count=(\d+)\s\|\sRelevant Papers=(\d+)"
    r"(?:\s\|\sClinical Trial Count=(\d+)\s\|\sRelevant Clinical Trials=(\d+))?"
    r"\s\|\sEvidence Count=(\d+)\s\|\sRelevant Evidence=(\d+)"
)


Expand Down Expand Up @@ -195,7 +198,8 @@ def clinical_trial_status(state: "EnvironmentState") -> str:
in getattr(c.text.doc, "other", {}).get("client_source", [])
}
),
evidence_count=len(relevant_contexts),
evidence_count=len(state.session.contexts),
relevant_evidence_count=len(relevant_contexts),
cost=state.session.cost,
)

Expand Down
34 changes: 24 additions & 10 deletions src/paperqa/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,17 @@


def make_status(
    total_paper_count: int,
    relevant_paper_count: int,
    evidence_count: int,
    relevant_evidence_count: int,
    cost: float,
) -> str:
    """Render the one-line status summary of papers, evidence and cost.

    The field names and ` | ` separators must stay in sync with the status
    search regex, which parses this line back into its integer counts.

    Args:
        total_paper_count: Value shown as ``Paper Count``.
        relevant_paper_count: Value shown as ``Relevant Papers``.
        evidence_count: Value shown as ``Evidence Count``.
        relevant_evidence_count: Value shown as ``Relevant Evidence``.
        cost: Value shown as ``Current Cost``, formatted with a leading ``$``
            and four decimal places.

    Returns:
        The formatted status string.
    """
    return (
        f"Status: Paper Count={total_paper_count}"
        f" | Relevant Papers={relevant_paper_count}"
        f" | Evidence Count={evidence_count}"
        f" | Relevant Evidence={relevant_evidence_count}"
        f" | Current Cost=${cost:.4f}"
    )

Expand All @@ -39,7 +45,8 @@ def default_status(state: "EnvironmentState") -> str:
return make_status(
total_paper_count=len(state.docs.docs),
relevant_paper_count=len({c.text.doc.dockey for c in relevant_contexts}),
evidence_count=len(relevant_contexts),
evidence_count=len(state.session.contexts),
relevant_evidence_count=len(relevant_contexts),
cost=state.session.cost,
)

Expand All @@ -60,9 +67,10 @@ class EnvironmentState(BaseModel):
),
)

# SEE: https://regex101.com/r/RmuVdC/1
# SEE: https://regex101.com/r/RmuVdC/3
# Parses a status line into four integer groups: paper count, relevant
# papers, evidence count and relevant evidence. Only one pattern may appear
# here: adjacent string literals are concatenated, so leaving a superseded
# pattern in place yields a regex that never matches.
STATUS_SEARCH_REGEX_PATTERN: ClassVar[str] = (
    r"Status: Paper Count=(\d+)\s\|\sRelevant Papers=(\d+)"
    r"\s\|\sEvidence Count=(\d+)\s\|\sRelevant Evidence=(\d+)"
)

@computed_field # type: ignore[prop-decorator]
Expand Down Expand Up @@ -272,7 +280,13 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
]
)

best_evidence = f" Best evidence(s):\n\n{top_contexts}" if top_contexts else ""
# Include 'current question' because different questions will lead to
# different best evidences being shown
best_evidence = (
f" Best evidence(s) for the current question:\n\n{top_contexts}"
if top_contexts
else ""
)

if f"{self.TOOL_FN_NAME}_completed" in self.settings.agent.callbacks:
await asyncio.gather(
Expand All @@ -298,7 +312,7 @@ class GenerateAnswer(NamedTool):

async def gen_answer(self, state: EnvironmentState) -> str:
"""
Generate an answer using current evidence.
Generate an answer using relevant evidence.

The tool may fail, indicating that better or different evidence should be found.
Aim for at least five pieces of evidence from multiple sources before invoking this tool.
Expand Down Expand Up @@ -350,7 +364,7 @@ async def gen_answer(self, state: EnvironmentState) -> str:
# Use to separate answer from status
# NOTE: can match failure to answer or an actual answer
# The status pattern prefixed with the whitespace-pipe-whitespace separator
# that precedes the status line in a message. Because the status pattern has
# four capture groups, re.split(..., maxsplit=1) on a matching message returns
# the leading text, the four captured counts, and the trailing text.
ANSWER_SPLIT_REGEX_PATTERN: ClassVar[str] = (
    r"\s\|\s" + EnvironmentState.STATUS_SEARCH_REGEX_PATTERN
)

@classmethod
Expand All @@ -359,7 +373,7 @@ def extract_answer_from_message(cls, content: str) -> str:
answer, *rest = re.split(
pattern=cls.ANSWER_SPLIT_REGEX_PATTERN, string=content, maxsplit=1
)
return answer if len(rest) == 4 else "" # noqa: PLR2004
return answer if len(rest) == 5 else "" # noqa: PLR2004


class Reset(NamedTool):
Expand All @@ -383,7 +397,7 @@ class Complete(NamedTool):

# Use to separate certainty from status
# The status pattern prefixed with the whitespace-pipe-whitespace separator
# that precedes the status line in a message. Only one expression may appear
# inside the parentheses; a superseded variant left next to it is a syntax
# error.
CERTAINTY_SPLIT_REGEX_PATTERN: ClassVar[str] = (
    r"\s\|\s" + EnvironmentState.STATUS_SEARCH_REGEX_PATTERN
)

NO_ANSWER_PHRASE: ClassVar[str] = "No answer generated."
Expand Down
8 changes: 4 additions & 4 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,11 @@ async def test_successful_memory_agent(agent_test_settings: Settings) -> None:
" and you have already tried to answer several times,"
" you can terminate by calling the {complete_tool_name} tool."
" The current status of evidence/papers/cost is "
f"{make_status(total_paper_count=0, relevant_paper_count=0, evidence_count=0, cost=0.0)}" # Started 0 # noqa: E501
f"{make_status(total_paper_count=0, relevant_paper_count=0, evidence_count=0, relevant_evidence_count=0, cost=0.0)}" # Started 0 # noqa: E501
"\n\nTool request message '' for tool calls: paper_search(query='XAI for"
" chemical property prediction', min_year='2018', max_year='2024')"
f" [id={memory_id}]\n\nTool response message '"
f"{make_status(total_paper_count=2, relevant_paper_count=0, evidence_count=0, cost=0.0)}" # Found 2 # noqa: E501
f"{make_status(total_paper_count=2, relevant_paper_count=0, evidence_count=0, relevant_evidence_count=0, cost=0.0)}" # Found 2 # noqa: E501
f"' for tool call ID {memory_id} of tool 'paper_search'"
),
input=(
Expand All @@ -412,7 +412,7 @@ async def test_successful_memory_agent(agent_test_settings: Settings) -> None:
" and you have already tried to answer several times,"
" you can terminate by calling the {complete_tool_name} tool."
" The current status of evidence/papers/cost is "
f"{make_status(total_paper_count=0, relevant_paper_count=0, evidence_count=0, cost=0.0)}"
f"{make_status(total_paper_count=0, relevant_paper_count=0, evidence_count=0, relevant_evidence_count=0, cost=0.0)}" # noqa: E501
),
output=(
"Tool request message '' for tool calls: paper_search(query='XAI for"
Expand Down Expand Up @@ -837,7 +837,7 @@ def test_tool_schema(agent_test_settings: Settings) -> None:
"info": {
"name": "gen_answer",
"description": (
"Generate an answer using current evidence.\n\nThe tool may fail,"
"Generate an answer using relevant evidence.\n\nThe tool may fail,"
" indicating that better or different evidence should be"
" found.\nAim for at least five pieces of evidence from multiple"
" sources before invoking this tool.\nFeel free to invoke this tool"
Expand Down
Loading