diff --git a/paperqa/agents/env.py b/paperqa/agents/env.py index 20118779..3fa327ef 100644 --- a/paperqa/agents/env.py +++ b/paperqa/agents/env.py @@ -12,6 +12,7 @@ ToolRequestMessage, ToolResponseMessage, ) +from aviary.env import ENV_REGISTRY from aviary.utils import MultipleChoiceQuestion from lmi import EmbeddingModel, LiteLLMModel @@ -222,6 +223,10 @@ def __init__( self._embedding_model = embedding_model self._session_id = session_id + @classmethod + def from_task(cls, task: str) -> Self: + return cls(query=task, settings=Settings(), docs=Docs()) + def make_tools(self) -> list[Tool]: return settings_to_tools( settings=self._settings, @@ -346,3 +351,6 @@ def __deepcopy__(self, memo) -> Self: # tool functions within the tools copy_self.tools = copy_self.make_tools() return copy_self + + +ENV_REGISTRY["paperqa"] = "paperqa.agents.env", PaperQAEnvironment.__name__ diff --git a/tests/test_agents.py b/tests/test_agents.py index 5f60aab8..8ea6b281 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -17,7 +17,13 @@ import ldp.agent import pytest -from aviary.core import Tool, ToolRequestMessage, ToolsAdapter, ToolSelector +from aviary.core import ( + Environment, + Tool, + ToolRequestMessage, + ToolsAdapter, + ToolSelector, +) from ldp.agent import MemoryAgent, SimpleAgent from ldp.graph.memory import Memory, UIndexMemoryModel from ldp.graph.ops import OpResult @@ -29,6 +35,7 @@ from paperqa.agents import SearchIndex, agent_query from paperqa.agents.env import ( CLINICAL_STATUS_SEARCH_REGEX_PATTERN, + PaperQAEnvironment, clinical_trial_status, settings_to_tools, ) @@ -1041,3 +1048,22 @@ async def test_index_build_concurrency(agent_test_settings: Settings) -> None: "Expected fewer save_index with high batch size, but got" f" {high_batch_save_count} vs {low_batch_save_count}" ) + + +def test_env_from_name(subtests: SubTests) -> None: + assert "paperqa" in Environment.available() + + with subtests.test(msg="only-task"): + env = Environment.from_name( # type: ignore[var-annotated] + "paperqa", "How can you use XAI for chemical property prediction?" + ) + assert isinstance(env, PaperQAEnvironment) + + with subtests.test(msg="env-kwargs"): + env = Environment.from_name( + "paperqa", + query="How can you use XAI for chemical property prediction?", + settings=Settings(), + docs=Docs(), + ) + assert isinstance(env, PaperQAEnvironment)