diff --git a/paperqa/agents/tools.py b/paperqa/agents/tools.py
index 26f36f22..d1ffe7ca 100644
--- a/paperqa/agents/tools.py
+++ b/paperqa/agents/tools.py
@@ -251,8 +251,15 @@ async def gather_evidence(self, question: str, state: EnvironmentState) -> str:
         status = state.status
         logger.info(status)
 
+        # only show top n contexts for this particular question to the agent
         sorted_contexts = sorted(
-            state.session.contexts, key=lambda x: x.score, reverse=True
+            [
+                c
+                for c in state.session.contexts
+                if (c.question is None or c.question == question)
+            ],
+            key=lambda x: x.score,
+            reverse=True,
         )
 
         top_contexts = "\n".join(
diff --git a/paperqa/core.py b/paperqa/core.py
index 3c5a4377..0eeaf60b 100644
--- a/paperqa/core.py
+++ b/paperqa/core.py
@@ -207,6 +207,7 @@ async def map_fxn_summary(
     return (
         Context(
             context=context,
+            question=question,
             text=Text(
                 text=text.text,
                 name=text.name,
diff --git a/paperqa/types.py b/paperqa/types.py
index 69564b64..c68b9d5c 100644
--- a/paperqa/types.py
+++ b/paperqa/types.py
@@ -110,6 +110,13 @@ class Context(BaseModel):
     model_config = ConfigDict(extra="allow")
 
     context: str = Field(description="Summary of the text with respect to a question.")
+    question: str | None = Field(
+        default=None,
+        description=(
+            "Question that the context is summarizing for. "
+            "Note this can differ from the user query."
+        ),
+    )
     text: Text
     score: int = 5
 
@@ -236,6 +243,7 @@ def filter_content_for_user(self) -> None:
         self.contexts = [
             Context(
                 context=c.context,
+                question=c.question,
                 score=c.score,
                 text=Text(
                     text="",
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 8ea6b281..692a386c 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -670,10 +670,27 @@ def new_status(state: EnvironmentState) -> str:
 
         # now adjust to give the agent 2x pieces of evidence
        gather_evidence_tool.settings.agent.agent_evidence_n = 2
+        # also reset the question to ensure that contexts are
+        # only returned to the agent for the new question
+        new_question = "How does XAI relate to a self-explanatory model?"
         response = await gather_evidence_tool.gather_evidence(
-            session.question, state=env_state
+            new_question, state=env_state
         )
-
+        assert len({c.question for c in session.contexts}) == 2, "Expected 2 questions"
+        # make sure contexts for the old question are not returned
+        for context in session.contexts:
+            if context.question != new_question:
+                assert (
+                    context.context[:20] not in response
+                ), "gather_evidence should not return any contexts for the old question"
+        assert (
+            sum(
+                (1 if (context.context[:20] in response) else 0)
+                for context in session.contexts
+                if context.question == new_question
+            )
+            == 2
+        ), "gather_evidence should only return 2 contexts for the new question"
         split = re.split(
             r"(\d+) pieces of evidence, (\d+) of which were relevant",
             response,
@@ -899,6 +916,7 @@ def test_answers_are_striped() -> None:
         contexts=[
             Context(
                 context="bla",
+                question="foo",
                 text=Text(
                     name="text",
                     text="The meaning of life is 42.",
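
Note: a minimal standalone sketch of the filtering behavior this diff introduces in gather_evidence. The Ctx dataclass and top_contexts_for function below are illustrative stand-ins, not paperqa's actual Context model or tool code; only the filter/sort logic mirrors the diff.

# Standalone sketch (assumed stand-in types, not paperqa's API): contexts tagged
# with the question they summarize for are filtered before ranking, so the agent
# only sees evidence gathered for the question currently being asked.
from dataclasses import dataclass


@dataclass
class Ctx:
    context: str  # summary text
    question: str | None  # question this summary was produced for (None = legacy context)
    score: int


def top_contexts_for(question: str, contexts: list[Ctx], n: int = 2) -> list[Ctx]:
    # Mirror the diff's filter: keep contexts with no question (backwards compatible)
    # or whose question matches, then rank by score and keep the top n
    matching = [c for c in contexts if c.question is None or c.question == question]
    return sorted(matching, key=lambda c: c.score, reverse=True)[:n]


if __name__ == "__main__":
    contexts = [
        Ctx("summary A", "How does XAI relate to a self-explanatory model?", 8),
        Ctx("summary B", "How does XAI relate to a self-explanatory model?", 6),
        Ctx("older summary", "What is XAI?", 9),
    ]
    # Only the two matching contexts are returned, even though the older
    # context has the highest score.
    print(top_contexts_for("How does XAI relate to a self-explanatory model?", contexts))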