-
Notifications
You must be signed in to change notification settings - Fork 499
feat(service): add tau2 agent+inference service rollout example #1226
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ | |
| import threading | ||
| import time | ||
| import traceback | ||
| import uuid | ||
| from concurrent.futures import ThreadPoolExecutor, as_completed | ||
| from dataclasses import dataclass | ||
| from typing import TYPE_CHECKING, Any | ||
|
|
@@ -74,7 +75,7 @@ class AgentServiceController: | |
| def __init__( | ||
| self, | ||
| config: AgentServiceControllerConfig, | ||
| scheduler: Scheduler, | ||
| scheduler: Scheduler | None = None, | ||
| ) -> None: | ||
| self.config = config | ||
| self.scheduler = scheduler | ||
|
|
@@ -92,6 +93,9 @@ def __init__( | |
|
|
||
| self._forked_services: list[tuple[str, str, int]] = [] | ||
|
|
||
| self._sessions: dict[str, dict[str, Any]] = {} | ||
| self._sessions_lock = threading.Lock() | ||
|
|
||
| self._health_stop = threading.Event() | ||
| self._health_thread: threading.Thread | None = None | ||
|
|
||
|
|
@@ -105,7 +109,19 @@ def initialize(self) -> None: | |
| Order: Guards (via scheduler) → Router → Worker+DataProxy pairs → | ||
| register → Gateway → health monitor. | ||
| On failure, already-forked services are cleaned up via destroy(). | ||
|
|
||
| When ``num_pairs`` is 0 and no scheduler is provided, the stack | ||
| is skipped entirely — only the data-collection APIs | ||
| (``new_session``, ``set_reward``) are available. | ||
| """ | ||
| if self.config.num_pairs == 0 and self.scheduler is None: | ||
| logger.info( | ||
| "num_pairs=0 with no scheduler; " | ||
| "skipping micro-service stack (data-collection-only mode)" | ||
| ) | ||
| return | ||
| if self.scheduler is None: | ||
| raise ValueError("A scheduler is required when num_pairs > 0") | ||
| try: | ||
| self._do_initialize() | ||
| except Exception: | ||
|
|
@@ -204,16 +220,17 @@ def destroy(self) -> None: | |
| ) | ||
| self._forked_services.clear() | ||
|
|
||
| for role in reversed(self._service_roles): | ||
| try: | ||
| self.scheduler.delete_workers(role=role) | ||
| logger.info("Workers deleted for role: %s", role) | ||
| except Exception: | ||
| logger.error( | ||
| "Error deleting workers for role %s: %s", | ||
| role, | ||
| traceback.format_exc(), | ||
| ) | ||
| if self.scheduler is not None: | ||
| for role in reversed(self._service_roles): | ||
| try: | ||
| self.scheduler.delete_workers(role=role) | ||
| logger.info("Workers deleted for role: %s", role) | ||
| except Exception: | ||
| logger.error( | ||
| "Error deleting workers for role %s: %s", | ||
| role, | ||
| traceback.format_exc(), | ||
| ) | ||
| self._service_roles.clear() | ||
| self._workers.clear() | ||
| self._guard_addrs.clear() | ||
|
|
@@ -371,6 +388,175 @@ def pairs(self) -> dict[int, _WorkerPair]: | |
| with self._pairs_lock: | ||
| return dict(self._pairs) | ||
|
|
||
| # ------------------------------------------------------------------ | ||
| # Data-collection APIs (inference service integration) | ||
| # ------------------------------------------------------------------ | ||
|
|
||
def new_session(self, task_id: str = "") -> dict[str, str]:
    """Open a data-collection session and register it locally.

    A fresh agent-side session ID (``agent-sess-`` followed by the first
    12 hex characters of a random UUID) is minted, then the inference
    service is asked to open its own session through
    ``POST /rl/start_session``.  The combined record is stored in
    ``self._sessions`` under the agent-side ID.

    Parameters
    ----------
    task_id:
        Task identifier sent to the inference service.  When empty,
        the newly minted agent-side session ID is sent instead.

    Returns
    -------
    dict with keys:

    * ``session_id`` — the agent-side session ID generated here
      (:meth:`step` sends it as the ``user`` field).
    * ``inference_session_id`` — the ``session_id`` field of the
      inference service's JSON reply.
    * ``inference_api_key`` — the ``api_key`` field of the inference
      service's JSON reply.

    Raises
    ------
    RuntimeError
        If ``config.inference_addr`` is empty.
    KeyError
        If the JSON reply lacks ``session_id`` or ``api_key``.
    """
    settings = self.config
    if not settings.inference_addr:
        raise RuntimeError(
            "inference_addr must be set in AgentServiceControllerConfig "
            "to use data-collection APIs"
        )

    agent_sid = f"agent-sess-{uuid.uuid4().hex[:12]}"
    start_url = f"{settings.inference_addr.rstrip('/')}/rl/start_session"

    # NOTE(review): a reviewer comment attached to this call begins
    # "The use of synchronous ..." but is truncated in this view; the
    # request below is a blocking call bounded by ``request_timeout``.
    reply = requests.post(
        start_url,
        # An empty task_id falls back to the agent-side session ID.
        json={"task_id": task_id or agent_sid},
        headers={"Authorization": f"Bearer {settings.inference_api_key}"},
        timeout=settings.request_timeout,
    )
    # Any HTTP error status aborts before the session is recorded.
    reply.raise_for_status()
    payload = reply.json()

    remote_sid = payload["session_id"]
    record: dict[str, str] = {
        "session_id": agent_sid,
        "inference_session_id": remote_sid,
        "inference_api_key": payload["api_key"],
    }

    # The registry is shared state; mutate it only under its lock.
    with self._sessions_lock:
        self._sessions[agent_sid] = record

    logger.info(
        "New session: %s (inference session: %s)",
        agent_sid,
        remote_sid,
    )
    return record
|
|
||
| def step( | ||
| self, | ||
| input: str | list[dict[str, Any]], | ||
| session_id: str, | ||
| ) -> dict[str, Any]: | ||
| """Send a message to the agent service and return the response. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| input: | ||
| A plain string or an OpenResponses-style input list | ||
| (e.g. ``[{"type": "message", "content": "hello"}]``). | ||
| session_id: | ||
| Agent-service session ID returned by :meth:`new_session`. | ||
|
|
||
| Returns | ||
| ------- | ||
| dict | ||
| The JSON response from the agent service gateway | ||
| ``POST /v1/responses``. | ||
| """ | ||
| session_info = self._resolve_session(session_id) | ||
| sid = session_info["session_id"] | ||
|
|
||
| if not self._gateway_addr: | ||
| raise RuntimeError( | ||
| "step() requires the agent-service gateway to be running. " | ||
| "It is not available in data-collection-only mode " | ||
| "(num_pairs=0 with no scheduler)." | ||
| ) | ||
|
|
||
| if isinstance(input, str): | ||
| input_items: list[dict[str, Any]] = [{"type": "message", "content": input}] | ||
| else: | ||
| input_items = input | ||
|
|
||
| cfg = self.config | ||
| metadata: dict[str, Any] = {} | ||
| if cfg.inference_addr: | ||
| metadata["inference_base_url"] = cfg.inference_addr.rstrip("/") | ||
| if cfg.inference_model: | ||
| metadata["inference_model"] = cfg.inference_model | ||
| inf_api_key = session_info.get("inference_api_key", "") | ||
| if inf_api_key: | ||
| metadata["inference_api_key"] = inf_api_key | ||
|
|
||
| body: dict[str, Any] = { | ||
| "input": input_items, | ||
| "model": (cfg.inference_model or "default").replace("/", "--"), | ||
| "user": sid, | ||
| } | ||
| if metadata: | ||
| body["metadata"] = metadata | ||
|
|
||
| resp = requests.post( | ||
| f"{self._gateway_addr}/v1/responses", | ||
| json=body, | ||
| headers={"Authorization": f"Bearer {cfg.admin_api_key}"}, | ||
| timeout=cfg.request_timeout, | ||
| ) | ||
| resp.raise_for_status() | ||
| return resp.json() | ||
|
|
||
def set_reward(
    self,
    reward: float,
    session_id: str,
) -> dict[str, Any]:
    """Report a scalar reward for a session to the inference service.

    The request is authenticated with the session's own
    ``inference_api_key`` and always sends ``interaction_id`` as
    ``None``.

    Parameters
    ----------
    reward:
        Scalar reward value.
    session_id:
        Agent-service session ID returned by :meth:`new_session`.

    Returns
    -------
    dict
        The decoded JSON body of the inference gateway's
        ``POST /rl/set_reward`` reply.

    Raises
    ------
    KeyError
        If ``session_id`` is not a registered session, or its record
        has no ``inference_api_key``.
    """
    # Resolve the session key before touching the config so an unknown
    # session fails without any network activity.
    session_key = self._resolve_session(session_id)["inference_api_key"]

    settings = self.config
    reward_url = f"{settings.inference_addr.rstrip('/')}/rl/set_reward"

    reply = requests.post(
        reward_url,
        json={"interaction_id": None, "reward": reward},
        headers={"Authorization": f"Bearer {session_key}"},
        timeout=settings.request_timeout,
    )
    reply.raise_for_status()
    return reply.json()
|
|
||
def _resolve_session(self, session_id: str) -> dict[str, Any]:
    """Return the stored record for ``session_id``.

    The registry is read under ``self._sessions_lock``; the returned
    dict is the stored object itself, not a copy.

    Raises
    ------
    KeyError
        If no record is stored under ``session_id``.
    """
    with self._sessions_lock:
        found = self._sessions.get(session_id)
    if found is not None:
        return found
    raise KeyError(f"Unknown session_id: {session_id!r}")
|
|
||
| # ------------------------------------------------------------------ | ||
| # Guard interaction helpers | ||
| # ------------------------------------------------------------------ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `_sessions` dictionary grows indefinitely as new sessions are created via `new_session()`, but there is no mechanism to remove them once they are completed (e.g., after `set_reward()`). This will lead to a memory leak in long-running controller instances. Consider adding a way to prune old sessions, or at least clearing them in the `destroy()` method.