Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented PaperQAEnvironment.from_task #907

Merged
merged 1 commit into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions paperqa/agents/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ToolRequestMessage,
ToolResponseMessage,
)
from aviary.env import ENV_REGISTRY
from aviary.utils import MultipleChoiceQuestion
from lmi import EmbeddingModel, LiteLLMModel

Expand Down Expand Up @@ -222,6 +223,10 @@ def __init__(
self._embedding_model = embedding_model
self._session_id = session_id

@classmethod
def from_task(cls, task: str) -> Self:
return cls(query=task, settings=Settings(), docs=Docs())

def make_tools(self) -> list[Tool]:
return settings_to_tools(
settings=self._settings,
Expand Down Expand Up @@ -346,3 +351,6 @@ def __deepcopy__(self, memo) -> Self:
# tool functions within the tools
copy_self.tools = copy_self.make_tools()
return copy_self


ENV_REGISTRY["paperqa"] = "paperqa.agents.env", PaperQAEnvironment.__name__
28 changes: 27 additions & 1 deletion tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@

import ldp.agent
import pytest
from aviary.core import Tool, ToolRequestMessage, ToolsAdapter, ToolSelector
from aviary.core import (
Environment,
Tool,
ToolRequestMessage,
ToolsAdapter,
ToolSelector,
)
from ldp.agent import MemoryAgent, SimpleAgent
from ldp.graph.memory import Memory, UIndexMemoryModel
from ldp.graph.ops import OpResult
Expand All @@ -29,6 +35,7 @@
from paperqa.agents import SearchIndex, agent_query
from paperqa.agents.env import (
CLINICAL_STATUS_SEARCH_REGEX_PATTERN,
PaperQAEnvironment,
clinical_trial_status,
settings_to_tools,
)
Expand Down Expand Up @@ -1041,3 +1048,22 @@ async def test_index_build_concurrency(agent_test_settings: Settings) -> None:
"Expected fewer save_index with high batch size, but got"
f" {high_batch_save_count} vs {low_batch_save_count}"
)


def test_env_from_name(subtests: SubTests) -> None:
assert "paperqa" in Environment.available()

with subtests.test(msg="only-task"):
env = Environment.from_name( # type: ignore[var-annotated]
"paperqa", "How can you use XAI for chemical property prediction?"
)
assert isinstance(env, PaperQAEnvironment)

with subtests.test(msg="env-kwargs"):
env = Environment.from_name(
"paperqa",
query="How can you use XAI for chemical property prediction?",
settings=Settings(),
docs=Docs(),
)
assert isinstance(env, PaperQAEnvironment)