Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions areal/experimental/agent_service/controller/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ class AgentServiceControllerConfig:
admin_api_key: str = DEFAULT_ADMIN_API_KEY
"""Shared admin API key for inter-service Bearer auth."""

# -- Inference service integration -------------------------------------
inference_addr: str = ""
"""Address of the inference service gateway (e.g. ``http://host:port``).
Required for ``new_session`` / ``set_reward`` APIs that interact with
the inference service for RL data collection."""

inference_model: str = ""
"""Model name served by the inference service. Passed to agents so
they can issue ``/chat/completions`` requests against the inference
gateway."""

inference_api_key: str = ""
"""Admin API key for the inference service gateway. Used to call
``/rl/start_session`` and other admin-only inference endpoints."""

# -- Scaling -----------------------------------------------------------
num_pairs: int = 1
"""Number of Worker+DataProxy pairs to launch on initialize."""
Expand All @@ -34,6 +49,10 @@ class AgentServiceControllerConfig:
setup_timeout: float = 120.0
"""Timeout (seconds) waiting for each service to become healthy."""

request_timeout: float = 600.0
"""Timeout (seconds) for runtime HTTP requests (``step()``,
``set_reward()``, ``new_session()``)."""

health_poll_interval: float = 5.0
"""Seconds between health polls for crash detection (0 = disabled)."""

Expand All @@ -49,14 +68,20 @@ class AgentServiceControllerConfig:
"""Extra environment variables to pass to all forked child processes."""

def __post_init__(self) -> None:
if not self.agent_cls_path:
raise ValueError("agent_cls_path must be a non-empty import path")
if not self.agent_cls_path and self.num_pairs > 0:
raise ValueError(
"agent_cls_path must be a non-empty import path when num_pairs > 0"
)
if self.num_pairs < 0:
raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}")
if self.setup_timeout <= 0:
raise ValueError(
f"setup_timeout must be positive, got {self.setup_timeout}"
)
if self.request_timeout <= 0:
raise ValueError(
f"request_timeout must be positive, got {self.request_timeout}"
)
if self.drain_timeout < 0:
raise ValueError(
f"drain_timeout must be non-negative, got {self.drain_timeout}"
Expand Down
208 changes: 197 additions & 11 deletions areal/experimental/agent_service/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import threading
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -74,7 +75,7 @@ class AgentServiceController:
def __init__(
self,
config: AgentServiceControllerConfig,
scheduler: Scheduler,
scheduler: Scheduler | None = None,
) -> None:
self.config = config
self.scheduler = scheduler
Expand All @@ -92,6 +93,9 @@ def __init__(

self._forked_services: list[tuple[str, str, int]] = []

self._sessions: dict[str, dict[str, Any]] = {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _sessions dictionary grows indefinitely as new sessions are created via new_session(), but there is no mechanism to remove them once they are completed (e.g., after set_reward()). This will lead to a memory leak in long-running controller instances. Consider adding a way to prune old sessions or at least clearing them in the destroy() method.

self._sessions_lock = threading.Lock()

self._health_stop = threading.Event()
self._health_thread: threading.Thread | None = None

Expand All @@ -105,7 +109,19 @@ def initialize(self) -> None:
Order: Guards (via scheduler) → Router → Worker+DataProxy pairs →
register → Gateway → health monitor.
On failure, already-forked services are cleaned up via destroy().

When ``num_pairs`` is 0 and no scheduler is provided, the stack
is skipped entirely — only the data-collection APIs
(``new_session``, ``set_reward``) are available.
"""
if self.config.num_pairs == 0 and self.scheduler is None:
logger.info(
"num_pairs=0 with no scheduler; "
"skipping micro-service stack (data-collection-only mode)"
)
return
if self.scheduler is None:
raise ValueError("A scheduler is required when num_pairs > 0")
try:
self._do_initialize()
except Exception:
Expand Down Expand Up @@ -204,16 +220,17 @@ def destroy(self) -> None:
)
self._forked_services.clear()

for role in reversed(self._service_roles):
try:
self.scheduler.delete_workers(role=role)
logger.info("Workers deleted for role: %s", role)
except Exception:
logger.error(
"Error deleting workers for role %s: %s",
role,
traceback.format_exc(),
)
if self.scheduler is not None:
for role in reversed(self._service_roles):
try:
self.scheduler.delete_workers(role=role)
logger.info("Workers deleted for role: %s", role)
except Exception:
logger.error(
"Error deleting workers for role %s: %s",
role,
traceback.format_exc(),
)
self._service_roles.clear()
self._workers.clear()
self._guard_addrs.clear()
Expand Down Expand Up @@ -371,6 +388,175 @@ def pairs(self) -> dict[int, _WorkerPair]:
with self._pairs_lock:
return dict(self._pairs)

# ------------------------------------------------------------------
# Data-collection APIs (inference service integration)
# ------------------------------------------------------------------

def new_session(self, task_id: str = "") -> dict[str, str]:
"""Create a new session for data collection.

Generates a session ID for the agent service and starts a
corresponding session on the inference service via
``/rl/start_session``.

Parameters
----------
task_id:
Task identifier forwarded to the inference service. Defaults
to the generated session ID when empty.

Returns
-------
dict with keys:

* ``session_id`` — agent-service session ID (use as ``user``
field in ``/v1/responses`` requests).
* ``inference_session_id`` — inference-service session ID
(for trajectory export).
* ``inference_api_key`` — session-scoped API key for the
inference gateway.
"""
cfg = self.config
if not cfg.inference_addr:
raise RuntimeError(
"inference_addr must be set in AgentServiceControllerConfig "
"to use data-collection APIs"
)

session_id = f"agent-sess-{uuid.uuid4().hex[:12]}"
if not task_id:
task_id = session_id

inf_addr = cfg.inference_addr.rstrip("/")
resp = requests.post(
f"{inf_addr}/rl/start_session",
json={"task_id": task_id},
headers={"Authorization": f"Bearer {cfg.inference_api_key}"},
timeout=cfg.request_timeout,
)
Comment on lines +431 to +436
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of synchronous requests.post here (and in step and set_reward) will block the event loop when called from asynchronous contexts, such as the Tau2AgentServiceWorkflow introduced in this PR. This can significantly degrade performance and scalability when handling multiple concurrent rollouts. Consider using an asynchronous client like httpx.AsyncClient or wrapping these calls in asyncio.to_thread.

resp.raise_for_status()
inf_data = resp.json()

session_info: dict[str, str] = {
"session_id": session_id,
"inference_session_id": inf_data["session_id"],
"inference_api_key": inf_data["api_key"],
}

with self._sessions_lock:
self._sessions[session_id] = session_info

logger.info(
"New session: %s (inference session: %s)",
session_id,
inf_data["session_id"],
)
return session_info

def step(
self,
input: str | list[dict[str, Any]],
session_id: str,
) -> dict[str, Any]:
"""Send a message to the agent service and return the response.

Parameters
----------
input:
A plain string or an OpenResponses-style input list
(e.g. ``[{"type": "message", "content": "hello"}]``).
session_id:
Agent-service session ID returned by :meth:`new_session`.

Returns
-------
dict
The JSON response from the agent service gateway
``POST /v1/responses``.
"""
session_info = self._resolve_session(session_id)
sid = session_info["session_id"]

if not self._gateway_addr:
raise RuntimeError(
"step() requires the agent-service gateway to be running. "
"It is not available in data-collection-only mode "
"(num_pairs=0 with no scheduler)."
)

if isinstance(input, str):
input_items: list[dict[str, Any]] = [{"type": "message", "content": input}]
else:
input_items = input

cfg = self.config
metadata: dict[str, Any] = {}
if cfg.inference_addr:
metadata["inference_base_url"] = cfg.inference_addr.rstrip("/")
if cfg.inference_model:
metadata["inference_model"] = cfg.inference_model
inf_api_key = session_info.get("inference_api_key", "")
if inf_api_key:
metadata["inference_api_key"] = inf_api_key

body: dict[str, Any] = {
"input": input_items,
"model": (cfg.inference_model or "default").replace("/", "--"),
"user": sid,
}
if metadata:
body["metadata"] = metadata

resp = requests.post(
f"{self._gateway_addr}/v1/responses",
json=body,
headers={"Authorization": f"Bearer {cfg.admin_api_key}"},
timeout=cfg.request_timeout,
)
resp.raise_for_status()
return resp.json()

def set_reward(
self,
reward: float,
session_id: str,
) -> dict[str, Any]:
"""Set a reward on the inference service for the current session.

Parameters
----------
reward:
Scalar reward value.
session_id:
Agent-service session ID returned by :meth:`new_session`.

Returns
-------
dict
The JSON response from the inference gateway
``POST /rl/set_reward``.
"""
session_info = self._resolve_session(session_id)
inf_api_key = session_info["inference_api_key"]

cfg = self.config
inf_addr = cfg.inference_addr.rstrip("/")
resp = requests.post(
f"{inf_addr}/rl/set_reward",
json={"interaction_id": None, "reward": reward},
headers={"Authorization": f"Bearer {inf_api_key}"},
timeout=cfg.request_timeout,
)
resp.raise_for_status()
return resp.json()

def _resolve_session(self, session_id: str) -> dict[str, Any]:
with self._sessions_lock:
session_info = self._sessions.get(session_id)
if session_info is None:
raise KeyError(f"Unknown session_id: {session_id!r}")
return session_info

# ------------------------------------------------------------------
# Guard interaction helpers
# ------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ def start_proxy_gateway(self) -> None:
"""No-op — gateway already acts as the proxy gateway."""

@property
def proxy_gateway_addr(self) -> str:
def gateway_addr(self) -> str:
return self._gateway_addr

# -- Properties --------------------------------------------------------
Expand Down
Loading
Loading