Skip to content

Commit d9efa93

Browse files
author
nuzant
committed
feat(archon): add tau2 agent-service rollout workflow
Combine the agent-service controller refactor with the Tau2 rollout examples so that the example stack uses a single session lifecycle and data-collection path. Key changes: (1) move the session lifecycle and data-collection APIs into the agent controller; (2) add the Tau2 agent-service rollout example workflow and docs; (3) align the controller tests and gateway/session handling.
1 parent 2d6ea23 commit d9efa93

15 files changed

Lines changed: 1580 additions & 115 deletions

File tree

areal/experimental/agent_service/controller/config.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,21 @@ class AgentServiceControllerConfig:
2626
admin_api_key: str = DEFAULT_ADMIN_API_KEY
2727
"""Shared admin API key for inter-service Bearer auth."""
2828

29+
# -- Inference service integration -------------------------------------
30+
inference_addr: str = ""
31+
"""Address of the inference service gateway (e.g. ``http://host:port``).
32+
Required for ``new_session`` / ``set_reward`` APIs that interact with
33+
the inference service for RL data collection."""
34+
35+
inference_model: str = ""
36+
"""Model name served by the inference service. Passed to agents so
37+
they can issue ``/chat/completions`` requests against the inference
38+
gateway."""
39+
40+
inference_api_key: str = ""
41+
"""Admin API key for the inference service gateway. Used to call
42+
``/rl/start_session`` and other admin-only inference endpoints."""
43+
2944
# -- Scaling -----------------------------------------------------------
3045
num_pairs: int = 1
3146
"""Number of Worker+DataProxy pairs to launch on initialize."""
@@ -34,6 +49,10 @@ class AgentServiceControllerConfig:
3449
setup_timeout: float = 120.0
3550
"""Timeout (seconds) waiting for each service to become healthy."""
3651

52+
request_timeout: float = 600.0
53+
"""Timeout (seconds) for runtime HTTP requests (``step()``,
54+
``set_reward()``, ``new_session()``)."""
55+
3756
health_poll_interval: float = 5.0
3857
"""Seconds between health polls for crash detection (0 = disabled)."""
3958

@@ -49,14 +68,20 @@ class AgentServiceControllerConfig:
4968
"""Extra environment variables to pass to all forked child processes."""
5069

5170
def __post_init__(self) -> None:
52-
if not self.agent_cls_path:
53-
raise ValueError("agent_cls_path must be a non-empty import path")
71+
if not self.agent_cls_path and self.num_pairs > 0:
72+
raise ValueError(
73+
"agent_cls_path must be a non-empty import path when num_pairs > 0"
74+
)
5475
if self.num_pairs < 0:
5576
raise ValueError(f"num_pairs must be non-negative, got {self.num_pairs}")
5677
if self.setup_timeout <= 0:
5778
raise ValueError(
5879
f"setup_timeout must be positive, got {self.setup_timeout}"
5980
)
81+
if self.request_timeout <= 0:
82+
raise ValueError(
83+
f"request_timeout must be positive, got {self.request_timeout}"
84+
)
6085
if self.drain_timeout < 0:
6186
raise ValueError(
6287
f"drain_timeout must be non-negative, got {self.drain_timeout}"

areal/experimental/agent_service/controller/controller.py

Lines changed: 197 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import threading
2727
import time
2828
import traceback
29+
import uuid
2930
from concurrent.futures import ThreadPoolExecutor, as_completed
3031
from dataclasses import dataclass
3132
from typing import TYPE_CHECKING, Any
@@ -74,7 +75,7 @@ class AgentServiceController:
7475
def __init__(
7576
self,
7677
config: AgentServiceControllerConfig,
77-
scheduler: Scheduler,
78+
scheduler: Scheduler | None = None,
7879
) -> None:
7980
self.config = config
8081
self.scheduler = scheduler
@@ -92,6 +93,9 @@ def __init__(
9293

9394
self._forked_services: list[tuple[str, str, int]] = []
9495

96+
self._sessions: dict[str, dict[str, Any]] = {}
97+
self._sessions_lock = threading.Lock()
98+
9599
self._health_stop = threading.Event()
96100
self._health_thread: threading.Thread | None = None
97101

@@ -105,7 +109,19 @@ def initialize(self) -> None:
105109
Order: Guards (via scheduler) → Router → Worker+DataProxy pairs →
106110
register → Gateway → health monitor.
107111
On failure, already-forked services are cleaned up via destroy().
112+
113+
When ``num_pairs`` is 0 and no scheduler is provided, the stack
114+
is skipped entirely — only the data-collection APIs
115+
(``new_session``, ``set_reward``) are available.
108116
"""
117+
if self.config.num_pairs == 0 and self.scheduler is None:
118+
logger.info(
119+
"num_pairs=0 with no scheduler; "
120+
"skipping micro-service stack (data-collection-only mode)"
121+
)
122+
return
123+
if self.scheduler is None:
124+
raise ValueError("A scheduler is required when num_pairs > 0")
109125
try:
110126
self._do_initialize()
111127
except Exception:
@@ -204,16 +220,17 @@ def destroy(self) -> None:
204220
)
205221
self._forked_services.clear()
206222

207-
for role in reversed(self._service_roles):
208-
try:
209-
self.scheduler.delete_workers(role=role)
210-
logger.info("Workers deleted for role: %s", role)
211-
except Exception:
212-
logger.error(
213-
"Error deleting workers for role %s: %s",
214-
role,
215-
traceback.format_exc(),
216-
)
223+
if self.scheduler is not None:
224+
for role in reversed(self._service_roles):
225+
try:
226+
self.scheduler.delete_workers(role=role)
227+
logger.info("Workers deleted for role: %s", role)
228+
except Exception:
229+
logger.error(
230+
"Error deleting workers for role %s: %s",
231+
role,
232+
traceback.format_exc(),
233+
)
217234
self._service_roles.clear()
218235
self._workers.clear()
219236
self._guard_addrs.clear()
@@ -371,6 +388,175 @@ def pairs(self) -> dict[int, _WorkerPair]:
371388
with self._pairs_lock:
372389
return dict(self._pairs)
373390

391+
# ------------------------------------------------------------------
392+
# Data-collection APIs (inference service integration)
393+
# ------------------------------------------------------------------
394+
395+
def new_session(self, task_id: str = "") -> dict[str, str]:
    """Create a new session for data collection.

    Generates an agent-service session ID, starts a corresponding
    session on the inference service via ``POST /rl/start_session``,
    and records the pairing in the controller's session registry.

    Parameters
    ----------
    task_id:
        Task identifier forwarded to the inference service.  Defaults
        to the generated session ID when empty.

    Returns
    -------
    dict with keys:

    * ``session_id`` — agent-service session ID (use as ``user``
      field in ``/v1/responses`` requests).
    * ``inference_session_id`` — inference-service session ID
      (for trajectory export).
    * ``inference_api_key`` — session-scoped API key for the
      inference gateway.

    The returned dict is a copy of the registry entry, so mutating it
    does not affect later ``step()`` / ``set_reward()`` calls.

    Raises
    ------
    RuntimeError
        If ``inference_addr`` is not configured.
    requests.HTTPError
        If the inference service responds with an error status
        (raised by ``raise_for_status()``).
    KeyError
        If the inference response lacks ``session_id`` or ``api_key``.
    """
    cfg = self.config
    if not cfg.inference_addr:
        raise RuntimeError(
            "inference_addr must be set in AgentServiceControllerConfig "
            "to use data-collection APIs"
        )

    session_id = f"agent-sess-{uuid.uuid4().hex[:12]}"
    if not task_id:
        task_id = session_id

    inf_addr = cfg.inference_addr.rstrip("/")
    resp = requests.post(
        f"{inf_addr}/rl/start_session",
        json={"task_id": task_id},
        # Session creation authenticates with the configured inference key.
        headers={"Authorization": f"Bearer {cfg.inference_api_key}"},
        timeout=cfg.request_timeout,
    )
    resp.raise_for_status()
    inf_data = resp.json()

    session_info: dict[str, str] = {
        "session_id": session_id,
        "inference_session_id": inf_data["session_id"],
        "inference_api_key": inf_data["api_key"],
    }

    with self._sessions_lock:
        self._sessions[session_id] = session_info

    logger.info(
        "New session: %s (inference session: %s)",
        session_id,
        session_info["inference_session_id"],
    )
    # Hand back a shallow copy: the registry entry is read again by
    # step() and set_reward(), so callers must not be able to alter it
    # through the return value.
    return dict(session_info)
455+
456+
def step(
457+
self,
458+
input: str | list[dict[str, Any]],
459+
session_id: str,
460+
) -> dict[str, Any]:
461+
"""Send a message to the agent service and return the response.
462+
463+
Parameters
464+
----------
465+
input:
466+
A plain string or an OpenResponses-style input list
467+
(e.g. ``[{"type": "message", "content": "hello"}]``).
468+
session_id:
469+
Agent-service session ID returned by :meth:`new_session`.
470+
471+
Returns
472+
-------
473+
dict
474+
The JSON response from the agent service gateway
475+
``POST /v1/responses``.
476+
"""
477+
session_info = self._resolve_session(session_id)
478+
sid = session_info["session_id"]
479+
480+
if not self._gateway_addr:
481+
raise RuntimeError(
482+
"step() requires the agent-service gateway to be running. "
483+
"It is not available in data-collection-only mode "
484+
"(num_pairs=0 with no scheduler)."
485+
)
486+
487+
if isinstance(input, str):
488+
input_items: list[dict[str, Any]] = [{"type": "message", "content": input}]
489+
else:
490+
input_items = input
491+
492+
cfg = self.config
493+
metadata: dict[str, Any] = {}
494+
if cfg.inference_addr:
495+
metadata["inference_base_url"] = cfg.inference_addr.rstrip("/")
496+
if cfg.inference_model:
497+
metadata["inference_model"] = cfg.inference_model
498+
inf_api_key = session_info.get("inference_api_key", "")
499+
if inf_api_key:
500+
metadata["inference_api_key"] = inf_api_key
501+
502+
body: dict[str, Any] = {
503+
"input": input_items,
504+
"model": (cfg.inference_model or "default").replace("/", "--"),
505+
"user": sid,
506+
}
507+
if metadata:
508+
body["metadata"] = metadata
509+
510+
resp = requests.post(
511+
f"{self._gateway_addr}/v1/responses",
512+
json=body,
513+
headers={"Authorization": f"Bearer {cfg.admin_api_key}"},
514+
timeout=cfg.request_timeout,
515+
)
516+
resp.raise_for_status()
517+
return resp.json()
518+
519+
def set_reward(
    self,
    reward: float,
    session_id: str,
) -> dict[str, Any]:
    """Report a reward for a session to the inference service.

    Posts to ``/rl/set_reward`` on the inference gateway, authenticated
    with the API key stored for the session.  ``interaction_id`` is
    always sent as ``None``.

    Parameters
    ----------
    reward:
        Scalar reward value.
    session_id:
        Agent-service session ID returned by :meth:`new_session`.

    Returns
    -------
    dict
        The decoded JSON body of the inference gateway's
        ``POST /rl/set_reward`` reply.

    Raises
    ------
    KeyError
        If ``session_id`` is unknown or its record has no
        ``inference_api_key``.
    """
    session_key = self._resolve_session(session_id)["inference_api_key"]

    cfg = self.config
    base_url = cfg.inference_addr.rstrip("/")
    payload = {"interaction_id": None, "reward": reward}
    auth_header = {"Authorization": f"Bearer {session_key}"}

    reply = requests.post(
        f"{base_url}/rl/set_reward",
        json=payload,
        headers=auth_header,
        timeout=cfg.request_timeout,
    )
    reply.raise_for_status()
    return reply.json()
552+
553+
def _resolve_session(self, session_id: str) -> dict[str, Any]:
    """Return the stored record for ``session_id``.

    The registry is read while holding ``_sessions_lock``.  The stored
    dict itself is returned, not a copy.

    Raises
    ------
    KeyError
        If no record exists for ``session_id``.
    """
    with self._sessions_lock:
        record = self._sessions.get(session_id)
    if record is not None:
        return record
    raise KeyError(f"Unknown session_id: {session_id!r}")
559+
374560
# ------------------------------------------------------------------
375561
# Guard interaction helpers
376562
# ------------------------------------------------------------------

areal/experimental/inference_service/controller/controller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1304,7 +1304,7 @@ def start_proxy_gateway(self) -> None:
13041304
"""No-op — gateway already acts as the proxy gateway."""
13051305

13061306
@property
1307-
def proxy_gateway_addr(self) -> str:
1307+
def gateway_addr(self) -> str:
13081308
return self._gateway_addr
13091309

13101310
# -- Properties --------------------------------------------------------

0 commit comments

Comments
 (0)