diff --git a/areal/experimental/agent_service/controller/config.py b/areal/experimental/agent_service/controller/config.py index c316d58227..ff66d3c70c 100644 --- a/areal/experimental/agent_service/controller/config.py +++ b/areal/experimental/agent_service/controller/config.py @@ -26,6 +26,21 @@ class AgentServiceControllerConfig: admin_api_key: str = DEFAULT_ADMIN_API_KEY """Shared admin API key for inter-service Bearer auth.""" + # -- Inference service integration ------------------------------------- + inference_addr: str = "" + """Address of the inference service gateway (e.g. ``http://host:port``). + Required for ``new_session`` / ``set_reward`` APIs that interact with + the inference service for RL data collection.""" + + inference_model: str = "" + """Model name served by the inference service. Passed to agents so + they can issue ``/chat/completions`` requests against the inference + gateway.""" + + inference_api_key: str = "" + """Admin API key for the inference service gateway. 
Used to call + ``/rl/start_session`` and other admin-only inference endpoints.""" + # -- Scaling ----------------------------------------------------------- num_pairs: int = 1 """Number of Worker+DataProxy pairs to launch on initialize.""" @@ -34,6 +49,10 @@ class AgentServiceControllerConfig: setup_timeout: float = 120.0 """Timeout (seconds) waiting for each service to become healthy.""" + request_timeout: float = 600.0 + """Timeout (seconds) for runtime HTTP requests (``step()``, + ``set_reward()``, ``new_session()``).""" + health_poll_interval: float = 5.0 """Seconds between health polls for crash detection (0 = disabled).""" @@ -49,14 +68,20 @@ class AgentServiceControllerConfig: """Extra environment variables to pass to all forked child processes.""" def __post_init__(self) -> None: - if not self.agent_cls_path: - raise ValueError("agent_cls_path must be a non-empty import path") + if not self.agent_cls_path and self.num_pairs > 0: + raise ValueError( + "agent_cls_path must be a non-empty import path when num_pairs > 0" + ) if self.num_pairs < 0: raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}") if self.setup_timeout <= 0: raise ValueError( f"setup_timeout must be positive, got {self.setup_timeout}" ) + if self.request_timeout <= 0: + raise ValueError( + f"request_timeout must be positive, got {self.request_timeout}" + ) if self.drain_timeout < 0: raise ValueError( f"drain_timeout must be non-negative, got {self.drain_timeout}" diff --git a/areal/experimental/agent_service/controller/controller.py b/areal/experimental/agent_service/controller/controller.py index 21b12851bb..57a784daae 100644 --- a/areal/experimental/agent_service/controller/controller.py +++ b/areal/experimental/agent_service/controller/controller.py @@ -26,6 +26,7 @@ import threading import time import traceback +import uuid from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -74,7 
+75,7 @@ class AgentServiceController: def __init__( self, config: AgentServiceControllerConfig, - scheduler: Scheduler, + scheduler: Scheduler | None = None, ) -> None: self.config = config self.scheduler = scheduler @@ -92,6 +93,9 @@ def __init__( self._forked_services: list[tuple[str, str, int]] = [] + self._sessions: dict[str, dict[str, Any]] = {} + self._sessions_lock = threading.Lock() + self._health_stop = threading.Event() self._health_thread: threading.Thread | None = None @@ -105,7 +109,19 @@ def initialize(self) -> None: Order: Guards (via scheduler) → Router → Worker+DataProxy pairs → register → Gateway → health monitor. On failure, already-forked services are cleaned up via destroy(). + + When ``num_pairs`` is 0 and no scheduler is provided, the stack + is skipped entirely — only the data-collection APIs + (``new_session``, ``set_reward``) are available. """ + if self.config.num_pairs == 0 and self.scheduler is None: + logger.info( + "num_pairs=0 with no scheduler; " + "skipping micro-service stack (data-collection-only mode)" + ) + return + if self.scheduler is None: + raise ValueError("A scheduler is required when num_pairs > 0") try: self._do_initialize() except Exception: @@ -204,16 +220,17 @@ def destroy(self) -> None: ) self._forked_services.clear() - for role in reversed(self._service_roles): - try: - self.scheduler.delete_workers(role=role) - logger.info("Workers deleted for role: %s", role) - except Exception: - logger.error( - "Error deleting workers for role %s: %s", - role, - traceback.format_exc(), - ) + if self.scheduler is not None: + for role in reversed(self._service_roles): + try: + self.scheduler.delete_workers(role=role) + logger.info("Workers deleted for role: %s", role) + except Exception: + logger.error( + "Error deleting workers for role %s: %s", + role, + traceback.format_exc(), + ) self._service_roles.clear() self._workers.clear() self._guard_addrs.clear() @@ -371,6 +388,175 @@ def pairs(self) -> dict[int, _WorkerPair]: with 
self._pairs_lock: return dict(self._pairs) + # ------------------------------------------------------------------ + # Data-collection APIs (inference service integration) + # ------------------------------------------------------------------ + + def new_session(self, task_id: str = "") -> dict[str, str]: + """Create a new session for data collection. + + Generates a session ID for the agent service and starts a + corresponding session on the inference service via + ``/rl/start_session``. + + Parameters + ---------- + task_id: + Task identifier forwarded to the inference service. Defaults + to the generated session ID when empty. + + Returns + ------- + dict with keys: + + * ``session_id`` — agent-service session ID (use as ``user`` + field in ``/v1/responses`` requests). + * ``inference_session_id`` — inference-service session ID + (for trajectory export). + * ``inference_api_key`` — session-scoped API key for the + inference gateway. + """ + cfg = self.config + if not cfg.inference_addr: + raise RuntimeError( + "inference_addr must be set in AgentServiceControllerConfig " + "to use data-collection APIs" + ) + + session_id = f"agent-sess-{uuid.uuid4().hex[:12]}" + if not task_id: + task_id = session_id + + inf_addr = cfg.inference_addr.rstrip("/") + resp = requests.post( + f"{inf_addr}/rl/start_session", + json={"task_id": task_id}, + headers={"Authorization": f"Bearer {cfg.inference_api_key}"}, + timeout=cfg.request_timeout, + ) + resp.raise_for_status() + inf_data = resp.json() + + session_info: dict[str, str] = { + "session_id": session_id, + "inference_session_id": inf_data["session_id"], + "inference_api_key": inf_data["api_key"], + } + + with self._sessions_lock: + self._sessions[session_id] = session_info + + logger.info( + "New session: %s (inference session: %s)", + session_id, + inf_data["session_id"], + ) + return session_info + + def step( + self, + input: str | list[dict[str, Any]], + session_id: str, + ) -> dict[str, Any]: + """Send a message to the 
agent service and return the response. + + Parameters + ---------- + input: + A plain string or an OpenResponses-style input list + (e.g. ``[{"type": "message", "content": "hello"}]``). + session_id: + Agent-service session ID returned by :meth:`new_session`. + + Returns + ------- + dict + The JSON response from the agent service gateway + ``POST /v1/responses``. + """ + session_info = self._resolve_session(session_id) + sid = session_info["session_id"] + + if not self._gateway_addr: + raise RuntimeError( + "step() requires the agent-service gateway to be running. " + "It is not available in data-collection-only mode " + "(num_pairs=0 with no scheduler)." + ) + + if isinstance(input, str): + input_items: list[dict[str, Any]] = [{"type": "message", "content": input}] + else: + input_items = input + + cfg = self.config + metadata: dict[str, Any] = {} + if cfg.inference_addr: + metadata["inference_base_url"] = cfg.inference_addr.rstrip("/") + if cfg.inference_model: + metadata["inference_model"] = cfg.inference_model + inf_api_key = session_info.get("inference_api_key", "") + if inf_api_key: + metadata["inference_api_key"] = inf_api_key + + body: dict[str, Any] = { + "input": input_items, + "model": (cfg.inference_model or "default").replace("/", "--"), + "user": sid, + } + if metadata: + body["metadata"] = metadata + + resp = requests.post( + f"{self._gateway_addr}/v1/responses", + json=body, + headers={"Authorization": f"Bearer {cfg.admin_api_key}"}, + timeout=cfg.request_timeout, + ) + resp.raise_for_status() + return resp.json() + + def set_reward( + self, + reward: float, + session_id: str, + ) -> dict[str, Any]: + """Set a reward on the inference service for the current session. + + Parameters + ---------- + reward: + Scalar reward value. + session_id: + Agent-service session ID returned by :meth:`new_session`. + + Returns + ------- + dict + The JSON response from the inference gateway + ``POST /rl/set_reward``. 
+ """ + session_info = self._resolve_session(session_id) + inf_api_key = session_info["inference_api_key"] + + cfg = self.config + inf_addr = cfg.inference_addr.rstrip("/") + resp = requests.post( + f"{inf_addr}/rl/set_reward", + json={"interaction_id": None, "reward": reward}, + headers={"Authorization": f"Bearer {inf_api_key}"}, + timeout=cfg.request_timeout, + ) + resp.raise_for_status() + return resp.json() + + def _resolve_session(self, session_id: str) -> dict[str, Any]: + with self._sessions_lock: + session_info = self._sessions.get(session_id) + if session_info is None: + raise KeyError(f"Unknown session_id: {session_id!r}") + return session_info + # ------------------------------------------------------------------ # Guard interaction helpers # ------------------------------------------------------------------ diff --git a/areal/experimental/inference_service/controller/controller.py b/areal/experimental/inference_service/controller/controller.py index 7cc546a554..09efe00b25 100644 --- a/areal/experimental/inference_service/controller/controller.py +++ b/areal/experimental/inference_service/controller/controller.py @@ -1304,7 +1304,7 @@ def start_proxy_gateway(self) -> None: """No-op — gateway already acts as the proxy gateway.""" @property - def proxy_gateway_addr(self) -> str: + def gateway_addr(self) -> str: return self._gateway_addr # -- Properties -------------------------------------------------------- diff --git a/areal/experimental/inference_service/controller/workflow.py b/areal/experimental/inference_service/controller/workflow.py index 95f0571770..b946e4e946 100644 --- a/areal/experimental/inference_service/controller/workflow.py +++ b/areal/experimental/inference_service/controller/workflow.py @@ -21,8 +21,6 @@ logger = logging.getLogger("InferenceServiceWorkflow") _GRANT_CAPACITY_PATHNAME = "grant_capacity" -_RL_START_SESSION_PATHNAME = "rl/start_session" -_RL_SET_REWARD_PATHNAME = "rl/set_reward" _EXPORT_TRAJECTORIES_PATHNAME = 
"export_trajectories" @@ -51,32 +49,6 @@ async def _grant_capacity(self, session: aiohttp.ClientSession) -> None: async with session.post(url, headers=headers) as resp: resp.raise_for_status() - async def _start_session( - self, session: aiohttp.ClientSession, task_id: str - ) -> tuple[str, str]: - url = f"{self.gateway_addr}/{_RL_START_SESSION_PATHNAME}" - headers = {"Authorization": f"Bearer {self._admin_api_key}"} - payload = {"task_id": task_id} - async with session.post(url, json=payload, headers=headers) as resp: - resp.raise_for_status() - data = await resp.json() - return data["session_id"], data["api_key"] - - async def _set_last_reward( - self, - session: aiohttp.ClientSession, - reward: float, - session_api_key: str, - ) -> int | None: - url = f"{self.gateway_addr}/{_RL_SET_REWARD_PATHNAME}" - headers = {"Authorization": f"Bearer {session_api_key}"} - payload: dict[str, Any] = {"interaction_id": None, "reward": reward} - async with session.post(url, json=payload, headers=headers) as resp: - resp.raise_for_status() - data = await resp.json() - trajectory_id = data.get("trajectory_id") - return int(trajectory_id) if trajectory_id is not None else None - async def _export_interactions( self, session: aiohttp.ClientSession, @@ -115,45 +87,31 @@ async def _run_offline( http_session: aiohttp.ClientSession, data: dict[str, Any], ) -> dict[str, InteractionWithTokenLogpReward] | None: + assert self.agent is not None task_id = workflow_context.get().task_id - session_id, session_api_key = await self._start_session( - http_session, str(task_id) + + http_client = await workflow_context.get_httpx_client() + result = await self.agent.run( + data, + base_url=self.gateway_addr, + http_client=http_client, + api_key=self._admin_api_key, + task_id=str(task_id), ) - assert self.agent is not None - finished = False - trajectory_id: int | None = None - try: - http_client = await workflow_context.get_httpx_client() - rewards = await self.agent.run( - data, - 
base_url=self.gateway_addr, - http_client=http_client, - api_key=session_api_key, + if not isinstance(result, dict): + raise TypeError( + f"Agent.run() must return a dict with 'session_id', " + f"'trajectory_id', and 'reward' keys, got {type(result)}" ) + _REQUIRED_KEYS = {"session_id", "trajectory_id", "reward"} + missing = _REQUIRED_KEYS - result.keys() + if missing: + raise KeyError(f"Agent.run() result is missing required keys: {missing}") - if isinstance(rewards, dict): - final_reward = float(list(rewards.values())[-1] if rewards else 0.0) - elif isinstance(rewards, (int, float)): - final_reward = float(rewards) - else: - raise ValueError(f"Invalid reward type: {type(rewards)}") - - trajectory_id = await self._set_last_reward( - http_session, final_reward, session_api_key - ) - finished = True - except Exception: - logger.warning("Agent task failed. This trajectory will be rejected.") - if not finished: - try: - await self._set_last_reward(http_session, 0.0, session_api_key) - except Exception: - logger.warning( - "Failed to finish session %s after agent failure", - session_id, - ) - raise + session_id = result["session_id"] + trajectory_id = result["trajectory_id"] + agent_reward = float(result["reward"]) interactions = await self._export_interactions( http_session, @@ -169,6 +127,14 @@ async def _run_offline( last_id = list(interactions.keys())[-1] last_reward = interactions[last_id].reward + if abs(last_reward - agent_reward) > 1e-6: + logger.warning( + "Session %s: agent reported reward %.6f but exported " + "interaction has %.6f; using exported value.", + session_id, + agent_reward, + last_reward, + ) stats_tracker.get(workflow_context.stat_scope()).scalar(reward=last_reward) return interactions diff --git a/examples/experimental/inference_service/README.md b/examples/experimental/inference_service/README.md index 1b7c0af21f..c359945a15 100644 --- a/examples/experimental/inference_service/README.md +++ b/examples/experimental/inference_service/README.md 
@@ -1,6 +1,6 @@ # AReaL Inference Service Examples -This directory contains two examples that use the AReaL Inference Service +This directory contains three examples that use the AReaL Inference Service (`GatewayInferenceController`) — an experimental rollout backend that exposes an OpenAI-compatible proxy gateway so any external agent runtime can submit chat requests and receive RL training data. @@ -13,7 +13,9 @@ This example runs rollout-only data generation on the [$\\tau^2$-Bench](https://github.com/sierra-research/tau2-bench) using the AReaL Inference Service. Unlike the full training pipeline in `examples/tau2/`, this script performs rollouts without a training step — useful for evaluation, data collection, or -debugging agent behaviour. +debugging agent behaviour. The workflow talks directly to the inference-service gateway +for `POST /rl/start_session` and `POST /rl/set_reward`; no agent-service controller is +started in this example. ### Installation @@ -54,6 +56,15 @@ python3 examples/experimental/inference_service/tau2_rollout.py \ | `` | Directory for experiment artifacts (logs, trajectories) | `/tmp/areal/experiments` | | `` | Shared path for name-resolve records | `/tmp/areal/name_resolve` | +What the script launches: + +1. A `GatewayInferenceController` for the rollout model. +1. A local `Tau2InferenceWorkflow` that creates RL sessions with the gateway, runs the + full Tau2 simulation locally, and submits rewards back to the gateway. + +This is the lightest Tau2 example in this directory because it does **not** depend on +the experimental agent service. + ### Result A successful rollout prints per-batch statistics after every batch: @@ -67,7 +78,26 @@ batch-level average reward. 
______________________________________________________________________ -## Example 2: Human-in-the-Loop Online RL Demo +## Example 2: Tau2 with Both Inference Service and Agent Service + +`tau2_agent_service_rollout.py` demonstrates the complementary setup where the inference +service collects RL data and the agent service hosts the agent runtime. This variant +uses scripted user messages inside `Tau2AgentServiceWorkflow`, so it does not require +`econfig.user_llm_base_url`. + +```bash +python3 examples/experimental/inference_service/tau2_agent_service_rollout.py \ + --config examples/experimental/inference_service/tau2_rollout.yaml \ + cluster.fileroot= \ + cluster.name_resolve.nfs_record_root= +``` + +Use this variant when you want the Tau2 agent loop to be executed by the experimental +agent service instead of inside the rollout workflow. + +______________________________________________________________________ + +## Example 3: Human-in-the-Loop Online RL Demo This example demonstrates **human-in-the-loop (HITL) online RL**: a human (or an automated script acting as one) chats with the model through any OpenAI-compatible diff --git a/examples/experimental/inference_service/online_rollout.py b/examples/experimental/inference_service/online_rollout.py index 106f8e86d0..074c15fae3 100644 --- a/examples/experimental/inference_service/online_rollout.py +++ b/examples/experimental/inference_service/online_rollout.py @@ -98,7 +98,7 @@ def main(args: list[str]) -> None: server_args=server_args, ) - logger.info("Proxy gateway available at %s", ctrl.proxy_gateway_addr) + logger.info("Proxy gateway available at %s", ctrl.gateway_addr) # Online mode: pass None for both data and workflow so the # controller creates empty-dict placeholders and uses the diff --git a/examples/experimental/inference_service/tau2_agent.py b/examples/experimental/inference_service/tau2_agent.py new file mode 100644 index 0000000000..a5ba5099f2 --- /dev/null +++ 
b/examples/experimental/inference_service/tau2_agent.py @@ -0,0 +1,242 @@ +"""Tau2 Agent for AReaL Agent Service (PydanticAI). + +Implements :class:`AgentRunnable` using PydanticAI. Each call to ``run()`` +handles a **single turn** of a tau2 customer-service dialogue. The agent +uses tau2 environment tools (registered as PydanticAI function tools) and +maintains conversation context via ``request.history``. + +Requires: ``pip install pydantic-ai tau2-bench`` +""" + +from __future__ import annotations + +import inspect +import json +import os +from typing import Any + +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.openai import OpenAIProvider +from tau2.environment.environment import Environment +from tau2.environment.tool import Tool as Tau2Tool +from tau2.registry import registry + +from areal.experimental.agent_service.types import ( + AgentRequest, + AgentResponse, + EventEmitter, +) +from areal.utils import logging + +logger = logging.getLogger("Tau2Agent") + + +def _make_pydantic_tool(tau2_tool: Tau2Tool): + """Create a plain async function from a tau2 Tool for PydanticAI.""" + fn = tau2_tool._func # noqa: SLF001 + name = tau2_tool.name + doc = tau2_tool.openai_schema["function"].get("description", name) + + async def _wrapper(**kwargs: Any) -> str: + try: + result = fn(**kwargs) + except Exception as exc: + result = f"Tool error: {exc}" + if not isinstance(result, str): + result = json.dumps(result, default=str) + return result + + _wrapper.__name__ = name + _wrapper.__qualname__ = name + _wrapper.__doc__ = doc + + sig = inspect.signature(fn) + params = [ + inspect.Parameter( + pname, + kind=inspect.Parameter.KEYWORD_ONLY, + default=param.default, + annotation=param.annotation, + ) + for pname, param in sig.parameters.items() + ] + _wrapper.__signature__ = inspect.Signature(params) # type: ignore[attr-defined] + if hasattr(fn, "__annotations__"): + _wrapper.__annotations__ = { + k: v for k, v 
in fn.__annotations__.items() if k != "return" + } + return _wrapper + + +def _think_tool_fn(thoughts: str) -> str: + """Use this tool to think. Only use when necessary.""" + return "Your thoughts are recorded. Please continue your work." + + +class Tau2Agent: + """AgentRunnable that wraps a PydanticAI Agent with tau2 tools. + + Accepts a ``config`` dict (loaded from config.yaml by run_demo.py). + Falls back to environment variables if config is not provided. + """ + + def __init__(self, config: dict | None = None, **kwargs: Any) -> None: + config = config or {} + tau2_cfg = config.get("tau2", {}) + agent_llm_cfg = config.get("agent_llm", {}) + + self._domain = tau2_cfg.get("domain") or os.environ.get( + "TAU2_DOMAIN", "airline" + ) + add_thinking = tau2_cfg.get("add_thinking_tool", False) + + data_dir = tau2_cfg.get("data_dir") or os.environ.get("TAU2_DATA_DIR") + if data_dir: + os.environ["TAU2_DATA_DIR"] = data_dir + + env = self._build_environment() + tau2_tools: list[Tau2Tool] = env.get_tools() + if add_thinking: + tau2_tools.append(Tau2Tool(_think_tool_fn)) + + tools = [_make_pydantic_tool(t) for t in tau2_tools] + system_prompt = env.get_policy() + + model_name = agent_llm_cfg.get("model", "openai:default") + base_url = agent_llm_cfg.get("base_url") + api_key = agent_llm_cfg.get("api_key", "unused") + + if base_url: + model: Any = OpenAIChatModel( + model_name.replace("openai:", ""), + provider=OpenAIProvider(base_url=base_url, api_key=api_key), + ) + else: + model = model_name + + self._agent = Agent(model, system_prompt=system_prompt, tools=tools) + + logger.info( + "Tau2Agent initialized (domain=%s, tools=%d, model=%s)", + self._domain, + len(tools), + model_name, + ) + + def _build_environment(self) -> Environment: + constructor = registry.get_env_constructor(self._domain) + return constructor(solo_mode=False) + + def _resolve_model(self, metadata: dict[str, Any]) -> Any: + base_url = metadata.get("inference_base_url") + if not base_url: + return 
self._agent.model + model_name = metadata.get("inference_model", "default") + api_key = metadata.get("inference_api_key", "unused") + return OpenAIChatModel( + model_name, + provider=OpenAIProvider(base_url=base_url, api_key=api_key), + ) + + async def run( + self, + request: AgentRequest, + *, + emitter: EventEmitter, + ) -> AgentResponse: + from pydantic_ai.messages import ( + ModelRequest, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + ) + from pydantic_ai.messages import ( + ModelResponse as PAModelResponse, + ) + + message_history: list[ModelRequest | PAModelResponse] = [] + for msg in request.history: + role = msg.get("role", "user") + content = msg.get("content", "") + + if role == "user": + message_history.append( + ModelRequest(parts=[UserPromptPart(content=content or "")]) + ) + elif role == "assistant": + tool_calls = msg.get("tool_calls") + if tool_calls: + parts = [] + for tc in tool_calls: + fn = tc.get("function", tc) + parts.append( + ToolCallPart( + tool_name=fn.get("name", ""), + args=fn.get("arguments", ""), + tool_call_id=tc.get("id", ""), + ) + ) + message_history.append(PAModelResponse(parts=parts)) + elif content: + message_history.append( + PAModelResponse(parts=[TextPart(content=content)]) + ) + elif role == "tool": + tool_call_id = msg.get("tool_call_id", "") + message_history.append( + ModelRequest( + parts=[ + ToolReturnPart( + tool_name=tool_call_id, + content=content or "", + tool_call_id=tool_call_id, + ) + ] + ) + ) + + model_override = self._resolve_model(request.metadata) + + try: + result = await self._agent.run( + request.message, + message_history=message_history, + model=model_override, + ) + except Exception as exc: + logger.error("Tau2Agent turn failed: %s", exc) + await emitter.emit_delta(f"Agent error: {exc}") + return AgentResponse( + summary=f"Agent error: {exc}", + metadata={"tool_calls": []}, + ) + + final_text = str(result.output) if result.output else "" + + tool_calls: list[dict[str, Any]] = [] + 
for msg in result.new_messages(): + if not hasattr(msg, "parts"): + continue + for part in msg.parts: + kind = getattr(part, "part_kind", "") + if kind == "tool-call": + name = getattr(part, "tool_name", "") + args = getattr(part, "args", "") + if isinstance(args, dict): + args = json.dumps(args) + await emitter.emit_tool_call(name=name, args=str(args)) + tool_calls.append({"name": name, "arguments": args}) + elif kind == "tool-return": + name = getattr(part, "tool_name", "") + content = str(getattr(part, "content", "")) + await emitter.emit_tool_result(name=name, result=content) + + if final_text: + await emitter.emit_delta(final_text) + + return AgentResponse( + summary=final_text[:200], + metadata={"tool_calls": tool_calls}, + ) diff --git a/examples/experimental/inference_service/tau2_agent_service_rollout.py b/examples/experimental/inference_service/tau2_agent_service_rollout.py new file mode 100644 index 0000000000..77a26f6301 --- /dev/null +++ b/examples/experimental/inference_service/tau2_agent_service_rollout.py @@ -0,0 +1,288 @@ +"""Rollout script for Tau2 using both agent service and inference service. + +Uses the agent service to run the Tau2Agent (PydanticAI) while the +inference service provides model inference with RL data collection. 
+ +Usage: + python3 examples/experimental/inference_service/tau2_agent_service_rollout.py \ + --config examples/experimental/inference_service/tau2_rollout.yaml +""" + +from __future__ import annotations + +import os +import sys +import warnings +from dataclasses import asdict, dataclass, field +from typing import Any + +from datasets import Dataset + +from areal.api.alloc_mode import ModelAllocation +from areal.api.cli_args import ( + BaseExperimentConfig, + GenerationHyperparameters, + InferenceEngineConfig, + SGLangConfig, + TrainDatasetConfig, + load_expr_config, +) +from areal.experimental.agent_service.controller.config import ( + AgentServiceControllerConfig, +) +from areal.experimental.agent_service.controller.controller import ( + AgentServiceController, +) +from areal.experimental.inference_service.controller.config import ( + GatewayControllerConfig, +) +from areal.experimental.inference_service.controller.controller import ( + GatewayInferenceController, +) +from areal.utils import logging + +logger = logging.getLogger("Tau2AgentServiceRollout") + + +@dataclass +class Tau2EnvConfig: + domain: str = field( + default="telecom", + metadata={"help": "The tau2 domain name."}, + ) + max_steps: int = field( + default=100, metadata={"help": "Maximum number of steps per episode."} + ) + add_thinking_tool: bool = field( + default=False, metadata={"help": "Whether to add a thinking tool."} + ) + solo_mode: bool = field( + default=False, metadata={"help": "Whether to use solo mode."} + ) + user_llm_base_url: str | None = field( + default=None, metadata={"help": "The base URL of the user LLM."} + ) + user_llm: str | None = field( + default=None, metadata={"help": "The user LLM model name."} + ) + user_llm_args: dict | None = field( + default=None, metadata={"help": "The arguments for the user LLM."} + ) + turn_discount: float = field( + default=1.0, metadata={"help": "Discount factor for turn-based learning."} + ) + invalid_format_penalty: float = field( + 
default=0.1, metadata={"help": "Penalty for invalid format."} + ) + + +@dataclass +class AgentServiceConfig: + agent_cls_path: str = field( + default="examples.experimental.inference_service.tau2_agent.Tau2Agent", + metadata={"help": "Import path for the agent-service agent class."}, + ) + num_pairs: int = field( + default=1, metadata={"help": "Number of agent Worker+DataProxy pairs."} + ) + admin_api_key: str = field( + default="areal-agent-admin", + metadata={"help": "Admin API key for agent service."}, + ) + + +def get_tau2_dataset(domain: str, type: str = "rl", split: str = "train") -> Dataset: + from tau2.registry import registry + + from examples.tau2.agent import _get_task + + assert type == "rl", "Only RL dataset is supported" + splits_loader_fn = registry.get_task_splits_loader(domain) + if splits_loader_fn is None: + raise ValueError(f"No task splits loader found for domain {domain}") + splits = splits_loader_fn() + if split not in splits: + raise ValueError( + f"Split {split} not found for domain {domain}, " + f"available: {list(splits.keys())}" + ) + task_ids = splits[split] + dataset_items = [] + for tid in task_ids: + task = _get_task(domain=domain, task_id=tid, split=split) + dataset_items.append( + { + "task_id": tid, + "split": split, + "prompt": str(task.user_scenario), + } + ) + if len(dataset_items) < 128: + original = dataset_items.copy() + while len(dataset_items) < 128: + dataset_items.extend(original) + dataset = Dataset.from_list(dataset_items) + logger.info("Created dataset with %d items for %s/%s", len(dataset), domain, split) + return dataset + + +@dataclass +class Tau2AgentServiceRolloutConfig(BaseExperimentConfig): + gconfig: GenerationHyperparameters = field( + default_factory=GenerationHyperparameters + ) + rollout: InferenceEngineConfig = field(default_factory=InferenceEngineConfig) + model_path: str = "" + econfig: Tau2EnvConfig = field(default_factory=Tau2EnvConfig) + agent_service: AgentServiceConfig = 
def main(argv: list[str]) -> None:
    """Run Tau2 rollouts through the inference service and the agent service.

    Loads the experiment config from ``argv``, builds the Tau2 dataloader,
    starts a ``GatewayInferenceController`` and an ``AgentServiceController``
    on the configured scheduler, then feeds every dataloader batch through
    ``inf_ctrl.rollout_batch`` and logs the per-batch rewards.

    Both controllers are destroyed on exit, including when agent-service
    start-up or the rollout loop fails.

    Args:
        argv: Command-line arguments (without the program name).

    Raises:
        NotImplementedError: If ``config.scheduler.type`` is not ``local``
            or ``slurm``.
        ValueError: If the rollout backend is neither ``sglang`` nor ``vllm``.
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")

    config, _ = load_expr_config(argv, Tau2AgentServiceRolloutConfig)
    econfig = config.econfig
    rollout_cfg = config.rollout
    agent_svc_cfg = config.agent_service

    os.environ["TAU2_DOMAIN"] = econfig.domain

    train_dataset = get_tau2_dataset(
        domain=econfig.domain,
        type=config.train_dataset.type,
        split=config.train_dataset.path.split("/")[-1],
    )

    from torch.utils.data import DataLoader

    dataloader = DataLoader(
        train_dataset,
        batch_size=config.train_dataset.batch_size,
        shuffle=config.train_dataset.shuffle,
        num_workers=0,
    )

    # The gateway config takes the OpenAI-proxy options as flat keyword
    # arguments; forward them only when an ``openai`` section is configured.
    openai_cfg = rollout_cfg.openai
    openai_fields = (
        "admin_api_key",
        "turn_discount",
        "export_style",
        "tool_call_parser",
        "reasoning_parser",
        "engine_max_tokens",
        "chat_template_type",
    )
    openai_kwargs: dict[str, Any] = (
        {name: getattr(openai_cfg, name) for name in openai_fields}
        if openai_cfg
        else {}
    )
    ctrl_config = GatewayControllerConfig(
        tokenizer_path=config.tokenizer_path,
        model_path=config.model_path,
        consumer_batch_size=rollout_cfg.consumer_batch_size,
        max_concurrent_rollouts=rollout_cfg.max_concurrent_rollouts,
        max_head_offpolicyness=rollout_cfg.max_head_offpolicyness,
        queue_size=rollout_cfg.queue_size,
        enable_rollout_tracing=rollout_cfg.enable_rollout_tracing,
        fileroot=rollout_cfg.fileroot,
        experiment_name=rollout_cfg.experiment_name,
        trial_name=rollout_cfg.trial_name,
        dump_to_file=rollout_cfg.dump_to_file,
        backend=rollout_cfg.backend,
        scheduling_spec=rollout_cfg.scheduling_spec,
        setup_timeout=rollout_cfg.setup_timeout,
        request_timeout=rollout_cfg.request_timeout,
        **openai_kwargs,
    )

    # Import only the scheduler implementation that is actually selected.
    sched_type = config.scheduler.type
    if sched_type == "local":
        from areal.infra.scheduler.local import LocalScheduler

        scheduler = LocalScheduler(exp_config=config)
    elif sched_type == "slurm":
        from areal.infra.scheduler.slurm import SlurmScheduler

        scheduler = SlurmScheduler(exp_config=config)
    else:
        raise NotImplementedError(f"Unknown scheduler type: {sched_type}")

    rollout_alloc = ModelAllocation.from_str(config.rollout.backend, name="rollout")
    if rollout_alloc.backend == "sglang":
        server_args = asdict(config.sglang)
    elif rollout_alloc.backend == "vllm":
        server_args = asdict(config.vllm)
    else:
        raise ValueError(f"Unsupported rollout backend: {rollout_alloc.backend}")

    inf_ctrl = GatewayInferenceController(config=ctrl_config, scheduler=scheduler)
    inf_ctrl.initialize(role="rollout", server_args=server_args)

    # From here on the inference service is live, so everything that can fail
    # (agent-service start-up included) runs under the cleanup ``finally``.
    agent_ctrl = None
    try:
        logger.info("Inference service ready at %s", inf_ctrl.gateway_addr)

        agent_ctrl_config = AgentServiceControllerConfig(
            agent_cls_path=agent_svc_cfg.agent_cls_path,
            admin_api_key=agent_svc_cfg.admin_api_key,
            num_pairs=agent_svc_cfg.num_pairs,
            inference_addr=inf_ctrl.gateway_addr,
            inference_model=config.model_path,
            inference_api_key=ctrl_config.admin_api_key or "areal-admin-key",
        )

        agent_ctrl = AgentServiceController(
            config=agent_ctrl_config, scheduler=scheduler
        )
        agent_ctrl.initialize()

        logger.info("Agent service ready at %s", agent_ctrl.gateway_addr)

        workflow_kwargs: dict[str, Any] = dict(
            agent_controller=agent_ctrl,
            econfig=asdict(econfig),
            gen_args=dict(
                temperature=config.gconfig.temperature,
                max_completion_tokens=config.gconfig.max_new_tokens,
            ),
            timeout=600.0,
        )

        import torch

        from areal.infra.rpc.rtensor import RTensor

        logger.info("Starting rollout loop")
        batch_count = 0
        for batch_idx, batch in enumerate(dataloader):
            # The dataloader yields a dict of columns; turn it into a list of
            # per-sample dicts for the rollout API.
            keys = list(batch.keys())
            batch_size = len(batch[keys[0]])
            data = [{k: batch[k][i] for k in keys} for i in range(batch_size)]

            result = inf_ctrl.rollout_batch(
                data=data,
                workflow="examples.experimental.inference_service.tau2_workflow.Tau2AgentServiceWorkflow",
                workflow_kwargs=workflow_kwargs,
            )
            if result:
                batch_rewards = [
                    RTensor.localize(traj)["rewards"] for traj in result
                ]
                all_rewards = torch.cat(batch_rewards, dim=0)
                logger.info(
                    "Batch %d: n_trajs=%d, rewards=%s, avg_reward=%.4f",
                    batch_idx,
                    len(result),
                    all_rewards,
                    all_rewards.mean().item(),
                )
            else:
                logger.warning("Batch %d: empty result (all rejected?)", batch_idx)
            batch_count += 1
        logger.info("Rollout complete (%d batches)", batch_count)
    finally:
        # Tear down in reverse start order; the nested ``finally`` guarantees
        # the inference service is destroyed even if agent teardown raises.
        try:
            if agent_ctrl is not None:
                agent_ctrl.destroy()
        finally:
            inf_ctrl.destroy()


if __name__ == "__main__":
    main(sys.argv[1:])
Scheduler --- @@ -225,9 +238,12 @@ def main(argv: list[str]) -> None: server_args=server_args, ) - # --- Workflow kwargs (identical to examples/tau2/train.py) --- econfig_dict = asdict(econfig) + workflow_kwargs: dict[str, Any] = dict( + gateway_addr=ctrl.gateway_addr, + gateway_api_key=ctrl_config.admin_api_key or "areal-admin-key", + model=config.model_path, econfig=econfig_dict, gen_args=dict( temperature=config.gconfig.temperature, @@ -248,7 +264,7 @@ def main(argv: list[str]) -> None: result = ctrl.rollout_batch( data=data, - workflow="examples.tau2.agent.Tau2AgentWorkflow", + workflow="examples.experimental.inference_service.tau2_workflow.Tau2InferenceWorkflow", workflow_kwargs=workflow_kwargs, ) if result: diff --git a/examples/experimental/inference_service/tau2_workflow.py b/examples/experimental/inference_service/tau2_workflow.py new file mode 100644 index 0000000000..38a68cc5d4 --- /dev/null +++ b/examples/experimental/inference_service/tau2_workflow.py @@ -0,0 +1,389 @@ +"""Tau2 workflows for inference-service examples. + +`Tau2InferenceWorkflow` runs the Tau2 simulation locally and manages +`/rl/start_session` + `/rl/set_reward` directly against the inference-service gateway. +`Tau2AgentServiceWorkflow` keeps the agent-service-based path for the companion +example that runs the agent loop out of process. 
+""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any + +import httpx + +from areal.utils import logging + +logger = logging.getLogger("Tau2Workflow") + + +class Tau2InferenceWorkflow: + """Run Tau2 locally while reporting session lifecycle to the gateway.""" + + def __init__( + self, + gateway_addr: str, + gateway_api_key: str, + model: str | None = None, + econfig: dict | None = None, + gen_args: dict | None = None, + timeout: float = 600.0, + ) -> None: + from examples.tau2.utils import Tau2EnvConfig + + if econfig is None: + self.econfig = Tau2EnvConfig() + elif isinstance(econfig, dict): + self.econfig = Tau2EnvConfig(**econfig) + else: + self.econfig = econfig + self.gen_args = gen_args or {} + self.timeout = timeout + self.gateway_addr = gateway_addr.rstrip("/") + self.gateway_api_key = gateway_api_key + self.model = model + + async def _request( + self, + client: httpx.AsyncClient, + endpoint: str, + api_key: str, + payload: dict[str, Any], + ) -> dict[str, Any]: + response = await client.post( + f"{self.gateway_addr}{endpoint}", + json=payload, + headers={"Authorization": f"Bearer {api_key}"}, + ) + response.raise_for_status() + return response.json() + + async def _start_session( + self, + client: httpx.AsyncClient, + task_id: str, + ) -> dict[str, str]: + response = await self._request( + client, + "/rl/start_session", + self.gateway_api_key, + {"task_id": task_id}, + ) + return { + "session_id": response["session_id"], + "api_key": response["api_key"], + } + + async def _set_reward( + self, + client: httpx.AsyncClient, + session_api_key: str, + reward: float, + ) -> dict[str, Any]: + payload: dict[str, Any] = {"interaction_id": None, "reward": reward} + if self.model: + payload["model"] = self.model + return await self._request( + client, + "/rl/set_reward", + session_api_key, + payload, + ) + + async def run(self, data: dict[str, Any], **extra_kwargs: Any) -> dict[str, Any]: + from openai import AsyncOpenAI + + 
from examples.tau2.agent import Tau2Runner, _get_task + from examples.tau2.utils import Tau2EnvConfig + + base_url: str | None = extra_kwargs.get("base_url") or os.getenv( + "OPENAI_BASE_URL" + ) + task_id_str = extra_kwargs.get("task_id", str(data.get("task_id", ""))) + http_client: httpx.AsyncClient | None = extra_kwargs.get("http_client") + + if base_url is None: + raise ValueError("base_url is required for Tau2InferenceWorkflow") + + client = http_client or httpx.AsyncClient(timeout=30.0) + owns_client = http_client is None + try: + session = await self._start_session(client, task_id_str) + session_api_key = session["api_key"] + + econfig = self.econfig + if "econfig" in data: + econfig = Tau2EnvConfig(**data["econfig"]) + + gen_args = self.gen_args.copy() + if "gconfig" in data: + gen_args.update(data["gconfig"]) + + domain = econfig.domain + split = data.get("split", "train") + task = _get_task(domain=domain, task_id=data["task_id"], split=split) + + agent_client = AsyncOpenAI( + base_url=base_url, + api_key=session_api_key, + http_client=client, + max_retries=0, + ) + + user_client = None + if not econfig.solo_mode and econfig.user_llm_base_url: + user_client = AsyncOpenAI( + base_url=econfig.user_llm_base_url, + api_key="dummy", + max_retries=3, + timeout=120.0, + ) + + runner = Tau2Runner( + econfig=econfig, + gen_args=gen_args, + agent_client=agent_client, + user_client=user_client, + ) + + finished = False + try: + run_info = await asyncio.wait_for( + runner.run(task), timeout=self.timeout + ) + reward = run_info.reward + result = await self._set_reward(client, session_api_key, reward) + finished = True + except Exception: + if not finished: + try: + await self._set_reward(client, session_api_key, 0.0) + except Exception: + logger.warning( + "Failed to set 0 reward for session %s", + session["session_id"], + ) + raise + finally: + if owns_client: + await client.aclose() + + trajectory_id = result.get("trajectory_id") + return { + "session_id": 
session["session_id"], + "trajectory_id": ( + int(trajectory_id) if trajectory_id is not None else None + ), + "reward": reward, + } + + +def _extract_response_text(response: dict[str, Any]) -> str: + """Extract assistant text from an OpenResponses /v1/responses result.""" + parts: list[str] = [] + for item in response.get("output", []): + if item.get("type") == "message": + for block in item.get("content", []): + if block.get("type") == "output_text": + parts.append(block.get("text", "")) + return "\n".join(parts) + + +def _extract_completion_text(completion: Any) -> str: + choice = completion.choices[0] + message = getattr(choice, "message", None) + content = getattr(message, "content", "") if message is not None else "" + if isinstance(content, str): + return content.strip() + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text" and item.get("text"): + parts.append(str(item["text"])) + else: + text = getattr(item, "text", None) + if text: + parts.append(str(text)) + return "\n".join(parts).strip() + return str(content).strip() + + +class Tau2AgentServiceWorkflow: + """Tau2 workflow using agent service + inference service. + + Orchestrates the agent loop through an ``AgentServiceController`` + using ``new_session``, ``step``, and ``set_reward``. + + The agent-service agent handles tool calls internally. This workflow + seeds the first user turn from rollout data and generates later user + turns through the configured user LLM. 
+ """ + + def __init__( + self, + agent_controller: Any | None = None, + econfig: dict | None = None, + gen_args: dict | None = None, + timeout: float = 600.0, + max_turns: int = 10, + ) -> None: + from examples.tau2.utils import Tau2EnvConfig + + if econfig is None: + self.econfig = Tau2EnvConfig() + elif isinstance(econfig, dict): + self.econfig = Tau2EnvConfig(**econfig) + else: + self.econfig = econfig + self.gen_args = gen_args or {} + self.timeout = timeout + self.max_turns = max_turns + self.agent_controller = agent_controller + + async def run(self, data: dict[str, Any], **extra_kwargs: Any) -> dict[str, Any]: + from openai import AsyncOpenAI + from tau2.data_model.message import AssistantMessage, UserMessage + from tau2.data_model.simulation import SimulationRun, TerminationReason + from tau2.evaluator.evaluator import EvaluationType, evaluate_simulation + + from examples.tau2.agent import _get_task + from examples.tau2.utils import Tau2EnvConfig + + ctrl = self.agent_controller + if ctrl is None: + raise ValueError( + "agent_controller is required for Tau2AgentServiceWorkflow" + ) + + task_id_str = extra_kwargs.get("task_id", str(data.get("task_id", ""))) + + econfig = self.econfig + if "econfig" in data: + econfig = Tau2EnvConfig(**data["econfig"]) + + domain = econfig.domain + split = data.get("split", "train") + task = _get_task(domain=domain, task_id=data["task_id"], split=split) + + session = ctrl.new_session(task_id=task_id_str) + first_user_message = str(data.get("prompt") or task.user_scenario).strip() + if not first_user_message: + raise ValueError("data.prompt or task.user_scenario is required") + + if not econfig.solo_mode and not econfig.user_llm_base_url: + raise ValueError( + "econfig.user_llm_base_url is required for Tau2AgentServiceWorkflow" + ) + + user_client = None + if not econfig.solo_mode: + user_client = AsyncOpenAI( + base_url=econfig.user_llm_base_url, + api_key="dummy", + max_retries=3, + timeout=120.0, + ) + + tau2_messages: 
list[UserMessage | AssistantMessage] = [] + chat_history: list[dict[str, str]] = [ + {"role": "user", "content": first_user_message} + ] + next_user_message = first_user_message + finished = False + try: + for i in range(self.max_turns): + response = ctrl.step(next_user_message, session["session_id"]) + agent_text = _extract_response_text(response) + + tau2_messages.append( + UserMessage( + role="user", + content=next_user_message, + turn_idx=len(tau2_messages), + ) + ) + tau2_messages.append( + AssistantMessage( + role="assistant", + content=agent_text or "(no response)", + turn_idx=len(tau2_messages), + ) + ) + + if i + 1 >= self.max_turns: + break + + if user_client is None: + break + + chat_history.append( + {"role": "assistant", "content": agent_text or "(no response)"} + ) + completion = await user_client.chat.completions.create( + model=econfig.user_llm or "dummy", + messages=[ + { + "role": "system", + "content": ( + "You are simulating the tau2 user described below. " + "Respond with the user's next message only, in one turn, " + "based on the conversation so far.\n\n" + f"User scenario:\n{task.user_scenario}" + ), + }, + *chat_history, + ], + **(econfig.user_llm_args or {}), + ) + next_user_message = _extract_completion_text(completion) + if not next_user_message: + break + chat_history.append({"role": "user", "content": next_user_message}) + + reward = 0.0 + if tau2_messages: + try: + simulation = SimulationRun( + id=f"agent-svc-{task.id}", + task_id=task.id, + messages=tau2_messages, + start_time="", + end_time="", + duration=0.0, + termination_reason=TerminationReason.USER_STOP, + ) + reward_info = evaluate_simulation( + simulation=simulation, + task=task, + evaluation_type=EvaluationType.ALL, + solo_mode=False, + domain=domain, + ) + reward = reward_info.reward + except Exception as e: + logger.error("Evaluation failed for task %s: %s", task.id, e) + + result = ctrl.set_reward(reward, session["session_id"]) + finished = True + except Exception: + if 
not finished: + try: + ctrl.set_reward(0.0, session["session_id"]) + except Exception: + logger.warning( + "Failed to set 0 reward for session %s", + session["session_id"], + ) + raise + + trajectory_id = result.get("trajectory_id") + return { + "session_id": session["inference_session_id"], + "trajectory_id": ( + int(trajectory_id) if trajectory_id is not None else None + ), + "reward": reward, + } diff --git a/tests/experimental/agent_service/test_controller.py b/tests/experimental/agent_service/test_controller.py index 376bed71c6..c7b9c4d66e 100644 --- a/tests/experimental/agent_service/test_controller.py +++ b/tests/experimental/agent_service/test_controller.py @@ -338,3 +338,277 @@ def test_health_monitor_disabled_when_interval_zero(self, mock_requests, config) assert ctrl._health_thread is None ctrl.destroy() + + +@pytest.fixture() +def dc_config(): + return AgentServiceControllerConfig( + agent_cls_path="my.Agent", + admin_api_key="test-key", + num_pairs=0, + setup_timeout=1.0, + health_poll_interval=0, + inference_addr="http://inf-gw:8080", + inference_model="Qwen/Qwen3-0.6B", + inference_api_key="inf-admin-key", + ) + + +def _make_dc_controller(mock_requests, dc_config) -> AgentServiceController: + _setup_mock_requests(mock_requests) + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=dc_config, scheduler=scheduler) + ctrl.initialize() + return ctrl + + +class TestNewSession: + @patch(f"{CTRL}.requests") + def test_new_session_calls_inference_start_session(self, mock_requests, dc_config): + post_calls: list[tuple[str, dict]] = [] + + ctrl = _make_dc_controller(mock_requests, dc_config) + + def mock_post(url, **kwargs): + if "/rl/start_session" in url: + post_calls.append((url, kwargs)) + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "session_id": "inf-sess-1", + "api_key": "sk-sess-abc123", + } + return resp + return MagicMock(status_code=404) + + mock_requests.post = 
mock_post + + result = ctrl.new_session(task_id="my-task") + + assert result["inference_session_id"] == "inf-sess-1" + assert result["inference_api_key"] == "sk-sess-abc123" + assert result["session_id"].startswith("agent-sess-") + + assert len(post_calls) == 1 + call_url, call_kwargs = post_calls[0] + assert "/rl/start_session" in call_url + assert call_kwargs["json"]["task_id"] == "my-task" + assert "inf-admin-key" in call_kwargs["headers"]["Authorization"] + + @patch(f"{CTRL}.requests") + def test_new_session_stores_session_and_sets_latest(self, mock_requests, dc_config): + ctrl = _make_dc_controller(mock_requests, dc_config) + + def mock_post(url, **kwargs): + if "/rl/start_session" in url: + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "session_id": "inf-sess-1", + "api_key": "sk-sess-abc123", + } + return resp + return MagicMock(status_code=404) + + mock_requests.post = mock_post + + result = ctrl.new_session() + sid = result["session_id"] + + assert ctrl._sessions[sid] == result + + @patch(f"{CTRL}.requests") + def test_new_session_defaults_task_id_to_session_id(self, mock_requests, dc_config): + captured_task_id: list[str] = [] + + ctrl = _make_dc_controller(mock_requests, dc_config) + + def mock_post(url, **kwargs): + if "/rl/start_session" in url: + captured_task_id.append(kwargs["json"]["task_id"]) + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "session_id": "inf-sess-1", + "api_key": "sk-sess-abc", + } + return resp + return MagicMock(status_code=404) + + mock_requests.post = mock_post + + result = ctrl.new_session() + assert captured_task_id[0] == result["session_id"] + + def test_new_session_raises_without_inference_addr(self, config): + scheduler = _make_scheduler(("10.0.0.1", "8090")) + ctrl = AgentServiceController(config=config, scheduler=scheduler) + + with pytest.raises(RuntimeError, match="inference_addr must be set"): + ctrl.new_session() + + +def 
_mock_start_session_post(url, **kwargs): + if "/rl/start_session" in url: + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "session_id": "inf-sess-1", + "api_key": "sk-sess-abc", + } + return resp + return MagicMock(status_code=404) + + +class TestStep: + @patch(f"{CTRL}.requests") + def test_step_sends_string_input_to_gateway(self, mock_requests, dc_config): + v1_calls: list[tuple[str, dict]] = [] + + ctrl = _make_dc_controller(mock_requests, dc_config) + mock_requests.post = _mock_start_session_post + session = ctrl.new_session() + + def mock_post(url, **kwargs): + if "/v1/responses" in url: + v1_calls.append((url, kwargs)) + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = { + "id": "resp-1", + "status": "completed", + "output": [], + } + return resp + return MagicMock(status_code=404) + + mock_requests.post = mock_post + + result = ctrl.step("Hello agent", session["session_id"]) + + assert result["status"] == "completed" + assert len(v1_calls) == 1 + _, call_kwargs = v1_calls[0] + body = call_kwargs["json"] + assert body["input"] == [{"type": "message", "content": "Hello agent"}] + assert body["user"] == session["session_id"] + assert body["model"] == "Qwen--Qwen3-0.6B" + assert body["metadata"]["inference_base_url"] == "http://inf-gw:8080" + assert body["metadata"]["inference_model"] == "Qwen/Qwen3-0.6B" + assert body["metadata"]["inference_api_key"] == "sk-sess-abc" + + @patch(f"{CTRL}.requests") + def test_step_sends_list_input_unchanged(self, mock_requests, dc_config): + v1_calls: list[tuple[str, dict]] = [] + + ctrl = _make_dc_controller(mock_requests, dc_config) + mock_requests.post = _mock_start_session_post + session = ctrl.new_session() + + def mock_post(url, **kwargs): + if "/v1/responses" in url: + v1_calls.append((url, kwargs)) + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = {"id": "resp-1", "status": "completed"} + return resp + return 
MagicMock(status_code=404) + + mock_requests.post = mock_post + + custom_input = [ + {"type": "message", "content": "first"}, + {"type": "function_call_output", "output": "42"}, + ] + ctrl.step(custom_input, session["session_id"]) + + _, call_kwargs = v1_calls[0] + assert call_kwargs["json"]["input"] == custom_input + + @patch(f"{CTRL}.requests") + def test_step_requires_explicit_session_id(self, mock_requests, dc_config): + v1_calls: list[tuple[str, dict]] = [] + + ctrl = _make_dc_controller(mock_requests, dc_config) + mock_requests.post = _mock_start_session_post + session = ctrl.new_session() + + def mock_post(url, **kwargs): + if "/v1/responses" in url: + v1_calls.append((url, kwargs)) + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = {"id": "r", "status": "completed"} + return resp + return MagicMock(status_code=404) + + mock_requests.post = mock_post + + ctrl.step("hi", session["session_id"]) + + _, call_kwargs = v1_calls[0] + assert call_kwargs["json"]["user"] == session["session_id"] + + +class TestSetReward: + @patch(f"{CTRL}.requests") + def test_set_reward_calls_inference_gateway(self, mock_requests, dc_config): + reward_calls: list[tuple[str, dict]] = [] + + ctrl = _make_dc_controller(mock_requests, dc_config) + mock_requests.post = _mock_start_session_post + session = ctrl.new_session() + + def mock_post(url, **kwargs): + if "/rl/set_reward" in url: + reward_calls.append((url, kwargs)) + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.json.return_value = {"trajectory_id": 42} + return resp + return MagicMock(status_code=404) + + mock_requests.post = mock_post + + result = ctrl.set_reward(1.0, session["session_id"]) + + assert result["trajectory_id"] == 42 + assert len(reward_calls) == 1 + call_url, call_kwargs = reward_calls[0] + assert "inf-gw:8080/rl/set_reward" in call_url + assert call_kwargs["json"]["reward"] == 1.0 + assert call_kwargs["json"]["interaction_id"] is None + assert "sk-sess-abc" in 
class InferenceServiceAgent:
    """Agent that follows the inference-service session lifecycle.

    Unlike ``SimpleAgent`` (which returns a bare float), this agent starts an
    RL session, chats using the session key, sets a reward, and returns the
    ``{"session_id", "trajectory_id", "reward"}`` dict expected by
    ``InferenceServiceWorkflow._run_offline``.
    """

    async def run(self, data: dict, **extra_kwargs: Any) -> dict[str, Any]:
        """Run one session: start it, issue one chat completion, set reward 1.0.

        Args:
            data: Sample whose ``messages`` entry is sent to the chat endpoint.
            **extra_kwargs: ``http_client``, ``base_url``, ``api_key`` and
                ``task_id`` are read when present; ``base_url`` and
                ``api_key`` fall back to ``OPENAI_BASE_URL`` and
                ``OPENAI_API_KEY``.

        Returns:
            A dict with ``session_id``, ``trajectory_id`` and ``reward``.

        Raises:
            httpx.HTTPStatusError: If a session or reward request fails.
        """
        import httpx
        from openai import AsyncOpenAI

        http_client: httpx.AsyncClient | None = extra_kwargs.get("http_client")
        base_url: str = extra_kwargs.get("base_url") or os.getenv("OPENAI_BASE_URL", "")
        api_key: str = extra_kwargs.get("api_key") or os.getenv("OPENAI_API_KEY", "")
        task_id: str = extra_kwargs.get("task_id", "0")

        # Avoid "//rl/..." when the configured base URL ends with a slash.
        rl_base = base_url.rstrip("/")

        raw_client = http_client or httpx.AsyncClient(timeout=30.0)
        # Only a client created here is ours to close.
        owns_client = raw_client is not http_client
        try:
            start_resp = await raw_client.post(
                f"{rl_base}/rl/start_session",
                json={"task_id": task_id},
                headers={"Authorization": f"Bearer {api_key}"},
            )
            start_resp.raise_for_status()
            session_info = start_resp.json()
            session_id = session_info["session_id"]
            session_api_key = session_info["api_key"]

            client = AsyncOpenAI(
                base_url=base_url,
                api_key=session_api_key,
                http_client=raw_client,
                max_retries=0,
            )
            await client.chat.completions.create(
                messages=data["messages"], model="default"
            )

            reward = 1.0
            reward_resp = await raw_client.post(
                f"{rl_base}/rl/set_reward",
                json={"interaction_id": None, "reward": reward},
                headers={"Authorization": f"Bearer {session_api_key}"},
            )
            reward_resp.raise_for_status()
            reward_data = reward_resp.json()
        finally:
            if owns_client:
                await raw_client.aclose()

        return {
            "session_id": session_id,
            "trajectory_id": reward_data.get("trajectory_id"),
            "reward": reward,
        }
diff --git a/tests/experimental/inference_service/test_controller.py b/tests/experimental/inference_service/test_controller.py index 5e61f768f1..3403b9e0fd 100644 --- a/tests/experimental/inference_service/test_controller.py +++ b/tests/experimental/inference_service/test_controller.py @@ -224,7 +224,7 @@ def test_has_properties(self): "workflow_executor", "dispatcher", "runner", - "proxy_gateway_addr", + "gateway_addr", "worker_ids", ] for p in properties: @@ -298,12 +298,11 @@ def test_start_proxy_is_noop(self): controller.start_proxy() controller.start_proxy_gateway() - def test_proxy_gateway_addr(self): + def test_gateway_addr(self): cfg = GatewayControllerConfig(admin_api_key="test-key") scheduler = MagicMock() controller = GatewayInferenceController(config=cfg, scheduler=scheduler) - # Before initialize, proxy_gateway_addr returns the empty _gateway_addr - assert controller.proxy_gateway_addr == "" + assert controller.gateway_addr == "" def test_callback_addr_formats_ipv6_hostport(self): cfg = GatewayControllerConfig(admin_api_key="test-key") @@ -592,7 +591,11 @@ async def test_offline_mode_runs_agent(self): class MockAgent: async def run(self, data, **kwargs): - return 1.0 + return { + "session_id": "sess-1", + "trajectory_id": None, + "reward": 1.0, + } mock_interaction = MagicMock(reward=1.0) workflow = InferenceServiceWorkflow( @@ -602,8 +605,6 @@ async def run(self, data, **kwargs): admin_api_key="test-key", ) workflow._grant_capacity = AsyncMock() - workflow._start_session = AsyncMock(return_value=("sess-1", "sess-api-key-1")) - workflow._set_last_reward = AsyncMock(return_value=None) workflow._export_interactions = AsyncMock( return_value={"chatcmpl-1": mock_interaction} ) @@ -628,10 +629,6 @@ async def run(self, data, **kwargs): assert result is not None assert "chatcmpl-1" in result workflow._grant_capacity.assert_awaited_once() - workflow._start_session.assert_awaited_once() - workflow._set_last_reward.assert_awaited_once_with( - mock_http_session, 
1.0, "sess-api-key-1" - ) workflow._export_interactions.assert_awaited_once_with( mock_http_session, "sess-1", trajectory_id=None ) diff --git a/tests/experimental/inference_service/test_controller_integration.py b/tests/experimental/inference_service/test_controller_integration.py index cdaa0b1868..5c52af85be 100644 --- a/tests/experimental/inference_service/test_controller_integration.py +++ b/tests/experimental/inference_service/test_controller_integration.py @@ -309,10 +309,8 @@ def test_data_proxy_health(self, gateway_controller): assert resp.status_code == 200 assert resp.json()["status"] == "ok" - def test_proxy_gateway_addr_set(self, gateway_controller): - """proxy_gateway_addr should point to the gateway port.""" - addr = gateway_controller.proxy_gateway_addr - # proxy_gateway_addr should be a valid http URL + def test_gateway_addr_set(self, gateway_controller): + addr = gateway_controller.gateway_addr assert addr.startswith("http://") assert addr == gateway_controller._gateway_addr @@ -348,7 +346,7 @@ def test_set_version_does_not_raise_without_broadcast(self, gateway_controller): gateway_controller.set_version(10) assert gateway_controller.get_version() == 10 # Verify gateway is still healthy (no stale broadcast attempted) - addr = gateway_controller.proxy_gateway_addr + addr = gateway_controller.gateway_addr resp = httpx.get(f"{addr}/health", timeout=10.0) assert resp.status_code == 200 finally: @@ -395,7 +393,7 @@ def test_pause_resume_roundtrip_keeps_services_healthy(self, gateway_controller) time.sleep(0.5) # Gateway still healthy - addr = gateway_controller.proxy_gateway_addr + addr = gateway_controller.gateway_addr resp = httpx.get(f"{addr}/health", timeout=10.0) assert resp.status_code == 200 @@ -425,7 +423,7 @@ def test_rollout_batch_with_simple_agent(self, gateway_controller): result = gateway_controller.rollout_batch( data=data, - workflow="tests.experimental.openai.utils.SimpleAgent", + 
workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", ) assert result is not None @@ -455,7 +453,7 @@ def accept_all(trajectory: dict) -> bool: result = gateway_controller.rollout_batch( data=data, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", should_accept_fn=accept_all, ) @@ -513,7 +511,7 @@ def test_prepare_batch_returns_results(self, gateway_controller): result = gateway_controller.prepare_batch( dataloader=dataloader, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", ) assert isinstance(result, list) @@ -546,7 +544,7 @@ def accept_all(trajectory: dict) -> bool: result = gateway_controller.prepare_batch( dataloader=dataloader, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", should_accept_fn=accept_all, ) @@ -580,7 +578,7 @@ def test_submit_returns_task_id(self, gateway_controller): task_id = gateway_controller.submit( data=data, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", ) assert isinstance(task_id, int) @@ -598,7 +596,7 @@ def test_submit_wait_roundtrip(self, gateway_controller): task_id = gateway_controller.submit( data=data, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", ) assert isinstance(task_id, int) @@ -625,7 +623,7 @@ class TestControllerOnlineWorkflow: def test_online_workflow_submit_wait_roundtrip(self, gateway_controller_online): import requests - gateway_url = gateway_controller_online.proxy_gateway_addr + gateway_url = gateway_controller_online.gateway_addr assert gateway_controller_online.config.admin_api_key 
is not None admin_key = gateway_controller_online.config.admin_api_key @@ -682,7 +680,7 @@ def test_online_workflow_submit_wait_roundtrip(self, gateway_controller_online): def test_offline_export_applies_discount_after_multiple_rewards_in_same_trajectory( self, gateway_controller_with_reward_timeout ): - gateway_url = gateway_controller_with_reward_timeout.proxy_gateway_addr + gateway_url = gateway_controller_with_reward_timeout.gateway_addr assert gateway_controller_with_reward_timeout.config.admin_api_key is not None admin_key = gateway_controller_with_reward_timeout.config.admin_api_key @@ -968,7 +966,7 @@ def test_rollout_batch_with_simple_agent(self, gateway_controller_full_init): result = ctrl.rollout_batch( data=data, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", ) assert result is not None @@ -994,7 +992,7 @@ def test_rtensor_localize_on_rollout_result(self, gateway_controller_full_init): result = ctrl.rollout_batch( data=data, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", ) assert result is not None @@ -1044,7 +1042,7 @@ def test_rtensor_localize_batch4(self, gateway_controller_full_init): result = ctrl.rollout_batch( data=data, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", ) assert result is not None @@ -1190,7 +1188,7 @@ def test_rollout_batch_with_simple_agent_vllm( result = ctrl.rollout_batch( data=data, - workflow="tests.experimental.openai.utils.SimpleAgent", + workflow="tests.experimental.inference_service.integration_utils.InferenceServiceAgent", ) assert result is not None diff --git a/tests/experimental/inference_service/test_external_model.py b/tests/experimental/inference_service/test_external_model.py index bef88ac7ec..bd72977721 100644 --- 
a/tests/experimental/inference_service/test_external_model.py +++ b/tests/experimental/inference_service/test_external_model.py @@ -409,11 +409,11 @@ def mock_areal_client(): @pytest_asyncio.fixture async def data_proxy_client(data_proxy_config, mock_tokenizer, mock_areal_client): - from areal.experimental.inference_service.data_proxy.backend import ( + from areal.experimental.inference_service.data_proxy.pause import PauseState + from areal.experimental.inference_service.inf_bridge import InfBridge + from areal.experimental.inference_service.sglang.bridge import ( SGLangBridgeBackend, ) - from areal.experimental.inference_service.data_proxy.inf_bridge import InfBridge - from areal.experimental.inference_service.data_proxy.pause import PauseState app = create_data_proxy_app(data_proxy_config) pause_state = PauseState()