diff --git a/app/control/account/invalid_credentials.py b/app/control/account/invalid_credentials.py index fa6826233..ad66c5a6a 100644 --- a/app/control/account/invalid_credentials.py +++ b/app/control/account/invalid_credentials.py @@ -45,8 +45,7 @@ async def mark_account_invalid_credentials( "expired_reason": reason, }, ) - ] - ) + ]) logger.info( "account expired from {}: token={}... status={} upstream_status={}", source, @@ -72,6 +71,12 @@ def feedback_kind_for_error(exc: BaseException | None) -> FeedbackKind: status = getattr(exc, "status", 0) if status == 429: return FeedbackKind.RATE_LIMITED + if status == 402: + # console.x.ai returns 402 when the account has exhausted its + # prepaid web_search/credit balance. Treat it like a rate limit + # so the account pool routes around this token until credits + # refresh or the operator tops up the balance. + return FeedbackKind.RATE_LIMITED if status == 401: return FeedbackKind.UNAUTHORIZED if status == 403: diff --git a/app/control/account/state_machine.py b/app/control/account/state_machine.py index 5e626edf3..aa2dde847 100644 --- a/app/control/account/state_machine.py +++ b/app/control/account/state_machine.py @@ -27,7 +27,6 @@ class StatePolicy: _DEFAULT_POLICY = StatePolicy() - # --------------------------------------------------------------------------- # Feedback # --------------------------------------------------------------------------- @@ -66,7 +65,9 @@ def from_status_code( kind = FeedbackKind.UNAUTHORIZED elif status_code == 403: kind = FeedbackKind.FORBIDDEN - elif status_code == 429: + elif status_code == 429 or status_code == 402: + # 402 from console.x.ai = account credits exhausted; treat as a + # rate-limited token so the pool routes around it. 
kind = FeedbackKind.RATE_LIMITED elif status_code >= 500: kind = FeedbackKind.SERVER_ERROR @@ -111,9 +112,7 @@ def derive_status(record: AccountRecord, *, now: int | None = None) -> AccountSt return AccountStatus.COOLING -def is_selectable( - record: AccountRecord, mode_id: int, *, now: int | None = None -) -> bool: +def is_selectable(record: AccountRecord, mode_id: int, *, now: int | None = None) -> bool: """Return True if the account can be selected for *mode_id*.""" if record.is_deleted(): return False @@ -185,10 +184,7 @@ def apply_feedback( win = qs.get(feedback.mode_id) if win is not None: reset_at = ( - ts + feedback.retry_after_ms - if feedback.retry_after_ms - else (ts + win.window_seconds * 1000) - ) + ts + feedback.retry_after_ms if feedback.retry_after_ms else (ts + win.window_seconds * 1000)) qs.set( feedback.mode_id, QuotaWindow( @@ -206,10 +202,10 @@ def apply_feedback( use_count += 1 last_use_at = ts elif feedback.kind not in ( - FeedbackKind.SUCCESS, - FeedbackKind.RESTORE, - FeedbackKind.DISABLE, - FeedbackKind.DELETE, + FeedbackKind.SUCCESS, + FeedbackKind.RESTORE, + FeedbackKind.DISABLE, + FeedbackKind.DELETE, ): fail_count += 1 last_fail_at = ts @@ -236,11 +232,7 @@ def apply_feedback( ext[_DISABLED_REASON_KEY] = state_reason elif feedback.kind == FeedbackKind.RATE_LIMITED: - cooldown_ms = ( - feedback.retry_after_ms - if feedback.retry_after_ms - else policy.default_cooling_ms - ) + cooldown_ms = (feedback.retry_after_ms if feedback.retry_after_ms else policy.default_cooling_ms) status = AccountStatus.COOLING state_reason = feedback.reason or "rate_limited" ext[_COOLDOWN_UNTIL_KEY] = ts + cooldown_ms @@ -292,21 +284,20 @@ def apply_feedback( "state_reason": state_reason, "ext": ext, "updated_at": ts, - } - ) + }) def clear_failures(record: AccountRecord) -> AccountRecord: """Reset failure counters and restore ACTIVE status.""" ext = dict(record.ext) for k in ( - _COOLDOWN_UNTIL_KEY, - _COOLDOWN_REASON_KEY, - _DISABLED_AT_KEY, - 
_DISABLED_REASON_KEY, - _EXPIRED_AT_KEY, - _EXPIRED_REASON_KEY, - _FORBIDDEN_STRIKE_KEY, + _COOLDOWN_UNTIL_KEY, + _COOLDOWN_REASON_KEY, + _DISABLED_AT_KEY, + _DISABLED_REASON_KEY, + _EXPIRED_AT_KEY, + _EXPIRED_REASON_KEY, + _FORBIDDEN_STRIKE_KEY, ): ext.pop(k, None) return record.model_copy( @@ -318,8 +309,7 @@ def clear_failures(record: AccountRecord) -> AccountRecord: "state_reason": None, "ext": ext, "updated_at": now_ms(), - } - ) + }) __all__ = [ diff --git a/app/control/model/registry.py b/app/control/model/registry.py index bed8e5375..073870b75 100644 --- a/app/control/model/registry.py +++ b/app/control/model/registry.py @@ -36,6 +36,23 @@ # Super+(basic 池不支持此模式) ModelSpec("grok-4.3-beta", ModeId.GROK_4_3, Tier.SUPER, Capability.CHAT, True, "Grok 4.3 Beta"), + # === Console API (console.x.ai/v1/responses) ============================ + # 通过 SSO cookie 直接调用 console.x.ai,basic 账号即可使用所有模型 + # 速率限制由 console.x.ai 控制(免费 tier: 1 rps / 60 RPM) + # Hybrid reasoning models default to effort="high" so callers that omit + # reasoning_effort still get the "think hard" experience the model name + # implies. Pass an explicit value (e.g. "minimal") to override. + ModelSpec("grok-4.3", ModeId.FAST, Tier.BASIC, Capability.CHAT, True, "Grok 4.3 (Console)", console_model="grok-4.3", default_reasoning_effort="high"), + ModelSpec("grok-4", ModeId.FAST, Tier.BASIC, Capability.CHAT, True, "Grok 4 (Console)", console_model="grok-4", default_reasoning_effort="high"), + ModelSpec("grok-4.20", ModeId.FAST, Tier.BASIC, Capability.CHAT, True, "Grok 4.20 (Console)", console_model="grok-4.20", default_reasoning_effort="high"), + # Fixed-intensity reasoning model — upstream rejects reasoning.effort. + ModelSpec("grok-4.20-reasoning", ModeId.FAST, Tier.BASIC, Capability.CHAT, True, "Grok 4.20 Reasoning (Console)", console_model="grok-4.20-0309-reasoning"), + # Non-reasoning model — effort is not applicable. 
+ ModelSpec("grok-4.20-non-reasoning", ModeId.FAST, Tier.BASIC, Capability.CHAT, True, "Grok 4.20 Non-Reasoning (Console)", console_model="grok-4.20-0309-non-reasoning"), + # Multi-agent — left default; effort behaviour with this variant has not + # been verified, so we don't auto-inject "high" to avoid surprising 400s. + ModelSpec("grok-4.20-multi-agent", ModeId.FAST, Tier.BASIC, Capability.CHAT, True, "Grok 4.20 Multi-Agent (Console)", console_model="grok-4.20-multi-agent-0309"), + # === Image ============================================================== # Basic fast @@ -66,7 +83,6 @@ for _m in MODELS: _BY_CAP.setdefault(int(_m.capability), []).append(_m) - # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- diff --git a/app/control/model/spec.py b/app/control/model/spec.py index c860be0f5..ef32881b1 100644 --- a/app/control/model/spec.py +++ b/app/control/model/spec.py @@ -20,6 +20,23 @@ class ModelSpec: ``public_name`` is the human-readable display name. ``prefer_best`` when True, reverses pool priority to try higher-tier pools first (hard priority, not soft preference). + ``console_model`` when non-empty, route this model through the + ``console.x.ai/v1/responses`` endpoint instead of the + ``grok.com`` web chat API. The string is the actual + model ID sent to console.x.ai (e.g. ``"grok-4.3"``). + SSO cookies from grok.com work for both endpoints, + so basic-tier accounts can access all models this way. + ``default_reasoning_effort`` when non-empty, this value is forwarded as + ``reasoning.effort`` to the console upstream when the + caller doesn't specify ``reasoning_effort`` themselves. + Use ``"high"`` for hybrid reasoning models the user + expects to "think hard by default" (grok-4, grok-4.3, + grok-4.20). 
Leave empty for models that either don't + support the effort field (grok-4.20-reasoning is fixed + intensity; the upstream rejects effort with HTTP 400) + or don't reason at all (grok-4.20-non-reasoning). + Only consulted when ``console_model`` is set; ignored + on the legacy grok.com path. """ model_name: str @@ -29,6 +46,8 @@ class ModelSpec: enabled: bool public_name: str prefer_best: bool = False + console_model: str = "" + default_reasoning_effort: str = "" # --- convenience predicates --- @@ -47,6 +66,10 @@ def is_video(self) -> bool: def is_voice(self) -> bool: return bool(self.capability & Capability.VOICE) + def is_console(self) -> bool: + """Return True if this model routes through console.x.ai.""" + return bool(self.console_model) + def pool_name(self) -> str: """Return the canonical pool string for this tier.""" if self.tier == Tier.SUPER: @@ -80,7 +103,7 @@ def pool_candidates(self) -> tuple[int, ...]: """ if self.prefer_best: if self.tier == Tier.HEAVY: - return (2,) # heavy only + return (2, ) # heavy only if self.tier == Tier.SUPER: return (2, 1) # heavy, super return (2, 1, 0) # heavy, super, basic @@ -88,7 +111,7 @@ def pool_candidates(self) -> tuple[int, ...]: return (0, 1, 2) # basic, super, heavy if self.tier == Tier.SUPER: return (1, 2) # super, heavy - return (2,) # heavy only + return (2, ) # heavy only __all__ = ["ModelSpec"] diff --git a/app/dataplane/reverse/planner.py b/app/dataplane/reverse/planner.py index b034b9675..3dc88fb69 100644 --- a/app/dataplane/reverse/planner.py +++ b/app/dataplane/reverse/planner.py @@ -8,27 +8,41 @@ from app.control.model.spec import ModelSpec from app.dataplane.reverse.runtime.endpoint_table import ( - CHAT, MEDIA_POST, WS_IMAGINE, + CHAT, + CONSOLE_RESPONSES, + MEDIA_POST, + WS_IMAGINE, ) from .types import ReversePlan, TransportKind - # --------------------------------------------------------------------------- # Profile defaults (timeout / content-type per transport) # 
--------------------------------------------------------------------------- _DEFAULTS: dict[TransportKind, dict[str, Any]] = { - TransportKind.HTTP_SSE: {"timeout_s": 120.0, "content_type": "application/json"}, - TransportKind.HTTP_JSON: {"timeout_s": 30.0, "content_type": "application/json"}, - TransportKind.WEBSOCKET: {"timeout_s": 300.0, "content_type": "application/json"}, - TransportKind.GRPC_WEB: {"timeout_s": 15.0, "content_type": "application/grpc-web+proto"}, + TransportKind.HTTP_SSE: { + "timeout_s": 120.0, + "content_type": "application/json" + }, + TransportKind.HTTP_JSON: { + "timeout_s": 30.0, + "content_type": "application/json" + }, + TransportKind.WEBSOCKET: { + "timeout_s": 300.0, + "content_type": "application/json" + }, + TransportKind.GRPC_WEB: { + "timeout_s": 15.0, + "content_type": "application/grpc-web+proto" + }, } - # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- + def build_plan(spec: ModelSpec, request: dict[str, Any] | None = None) -> ReversePlan: """Produce a ReversePlan for the given model spec. 
@@ -39,12 +53,12 @@ def build_plan(spec: ModelSpec, request: dict[str, Any] | None = None) -> Revers defaults = _DEFAULTS.get(tkind, _DEFAULTS[TransportKind.HTTP_JSON]) return ReversePlan( - endpoint = endpoint, - transport_kind = tkind, - pool_candidates = spec.pool_candidates(), - mode_id = int(spec.mode_id), - timeout_s = defaults["timeout_s"], - content_type = defaults["content_type"], + endpoint=endpoint, + transport_kind=tkind, + pool_candidates=spec.pool_candidates(), + mode_id=int(spec.mode_id), + timeout_s=defaults["timeout_s"], + content_type=defaults["content_type"], ) @@ -52,12 +66,21 @@ def build_plan(spec: ModelSpec, request: dict[str, Any] | None = None) -> Revers # Internal routing logic # --------------------------------------------------------------------------- + def _resolve_endpoint( spec: ModelSpec, request: dict[str, Any], ) -> tuple[str, TransportKind]: """Determine (endpoint_url, transport_kind) for the given capability.""" + # Console models route through console.x.ai/v1/responses (OpenAI Responses API) + if spec.is_console() and spec.is_chat(): + # When stream=true the response is SSE; otherwise plain JSON. + # Use HTTP_SSE as the default since both streaming and non-streaming + # share the same content-type and the timeout profile is more + # permissive (long-running responses). + return CONSOLE_RESPONSES, TransportKind.HTTP_SSE + if spec.is_chat(): return CHAT, TransportKind.HTTP_SSE diff --git a/app/dataplane/reverse/protocol/xai_console.py b/app/dataplane/reverse/protocol/xai_console.py new file mode 100644 index 000000000..e7274eee9 --- /dev/null +++ b/app/dataplane/reverse/protocol/xai_console.py @@ -0,0 +1,892 @@ +"""Console API protocol — payload builder and response parser. + +The ``console.x.ai/v1/responses`` endpoint shares SSO cookies with grok.com +but exposes the OpenAI Responses API directly. Free/basic accounts can call +all models (grok-4.3, grok-4.20-*, etc.) 
through this endpoint, bypassing +the tier restrictions of the grok.com web chat API. + +The upstream API supports: + - Plain string input or structured input arrays (for multimodal / chat history) + - Native function calling via ``tools`` field + - Reasoning summary streaming + - SSE streaming with OpenAI Responses API event names + +Request format (string input): + {"model": "grok-4.3", "input": "What is 1+1?", "stream": true} + +Request format (structured input + tools): + (tools use the flat Responses API shape, not the nested Chat Completions one) + { + "model": "grok-4.3", + "input": [ + {"role": "user", "content": [ + {"type": "input_text", "text": "What's the weather?"}, + {"type": "input_image", "image_url": "https://...", "detail": "auto"} + ]} + ], + "tools": [ + {"type": "function", "name": "get_weather", + "description": "...", "parameters": {...}} + ], + "tool_choice": "auto" + } + +Response output items (non-streaming): + - {"type": "reasoning", "summary": [{"type": "summary_text", "text": "..."}]} + - {"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "...", "annotations": [...]}]} + - {"type": "function_call", "call_id": "...", "name": "...", "arguments": "..."} +""" + +from typing import Any + +import orjson + +from app.platform.config.snapshot import get_config +from app.platform.errors import UpstreamError +from app.platform.logging.logger import logger + +# --------------------------------------------------------------------------- +# Input conversion (OpenAI Chat Completions → console.x.ai input array) +# --------------------------------------------------------------------------- + + +def build_console_input(messages: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], str]: + """Convert OpenAI Chat Completions messages → console structured input. + + Returns ``(input_array, instructions)``: + - ``input_array`` is the list passed as Responses API ``input`` field. 
+ - ``instructions`` aggregates all ``role=system`` messages and is + passed via the separate Responses API ``instructions`` field for + better reasoning model behaviour. + + Mapping rules: + - ``role=system`` → folded into ``instructions`` + - ``role=user/assistant`` → preserved with content blocks converted + - Content block ``text`` → ``{type: input_text/output_text, text}`` + - Content block ``image_url`` → ``{type: input_image, image_url, detail}`` + - ``role=tool`` → ``{type: function_call_output, + call_id, output}`` + - ``role=assistant`` with ``tool_calls`` → emit one ``function_call`` + item per call before any accompanying text. + """ + instructions_parts: list[str] = [] + output: list[dict[str, Any]] = [] + + for msg in messages: + role = msg.get("role") or "user" + content = msg.get("content") + tool_calls = msg.get("tool_calls") + + # ── system → instructions ──────────────────────────────────────── + if role == "system": + if isinstance(content, str) and content.strip(): + instructions_parts.append(content.strip()) + elif isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + text = block.get("text") or "" + if text.strip(): + instructions_parts.append(text.strip()) + continue + + # ── tool result → function_call_output ─────────────────────────── + if role == "tool": + call_id = msg.get("tool_call_id") or "" + text = content if isinstance(content, str) else _flatten_text(content) + output.append({ + "type": "function_call_output", + "call_id": call_id, + "output": text or "", + }) + continue + + # ── assistant with tool_calls → function_call items ────────────── + if role == "assistant" and tool_calls: + for tc in tool_calls: + if not isinstance(tc, dict): + continue + fn = tc.get("function") or {} + output.append( + { + "type": "function_call", + "call_id": tc.get("id") or fn.get("name") or "", + "name": fn.get("name") or "", + "arguments": fn.get("arguments") or "{}", + }) + # Trailing 
assistant text (rare) is emitted as a normal message + text = content if isinstance(content, str) else _flatten_text(content) + if text and text.strip(): + output.append( + { + "role": "assistant", + "content": [{ + "type": "output_text", + "text": text.strip() + }], + }) + continue + + # ── normal user / assistant message ────────────────────────────── + blocks = _convert_content_blocks(content, role) + if not blocks: + continue + output.append({"role": role, "content": blocks}) + + instructions = "\n\n".join(instructions_parts).strip() + return output, instructions + + +def _flatten_text(content: Any) -> str: + """Flatten an OpenAI content array into a single text string.""" + if isinstance(content, str): + return content + if not isinstance(content, list): + return "" + parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") == "text": + text = block.get("text") or "" + if text: + parts.append(text) + return "\n".join(parts) + + +def _convert_content_blocks( + content: Any, + role: str, +) -> list[dict[str, Any]]: + """Convert one OpenAI message content (str or array) → console blocks.""" + text_type = "output_text" if role == "assistant" else "input_text" + + # Plain string content + if isinstance(content, str): + text = content.strip() + if not text: + return [] + return [{"type": text_type, "text": text}] + + # Already-structured array + if not isinstance(content, list): + return [] + + blocks: list[dict[str, Any]] = [] + for block in content: + if not isinstance(block, dict): + continue + btype = block.get("type") + + if btype == "text": + text = block.get("text") or "" + if text.strip(): + blocks.append({"type": text_type, "text": text}) + elif btype == "image_url": + inner = block.get("image_url") or {} + if isinstance(inner, str): + url, detail = inner, "auto" + else: + url = inner.get("url") or "" + detail = inner.get("detail") or "auto" + if url: + blocks.append({ + "type": "input_image", + 
"image_url": url, + "detail": detail, + }) + elif btype in ("input_text", "output_text", "input_image"): + # Already in console format — pass through + blocks.append(dict(block)) + + return blocks + + +# --------------------------------------------------------------------------- +# Tool format conversion +# --------------------------------------------------------------------------- + + +def convert_openai_tools_to_console(tools: list[dict[str, Any]] | None, ) -> list[dict[str, Any]]: + """Convert OpenAI Chat Completions tools → console (Responses API) tools. + + OpenAI Chat Completions: + {"type": "function", "function": {"name", "description", "parameters"}} + + Console (Responses API): + {"type": "function", "name", "description", "parameters"} + + Already-flat tools are passed through (e.g. ``web_search`` server-side + tool, ``code_interpreter``, ``x_search`` etc.). + """ + if not tools: + return [] + out: list[dict[str, Any]] = [] + for t in tools: + if not isinstance(t, dict): + continue + if t.get("type") != "function": + # Pass through server-side tools (web_search, x_search, etc.) + out.append(dict(t)) + continue + fn = t.get("function") if isinstance(t.get("function"), dict) else None + if fn is not None: + out.append( + { + "type": "function", + "name": fn.get("name") or "", + "description": fn.get("description") or "", + "parameters": fn.get("parameters") or {}, + }) + else: + # Already flat + out.append(dict(t)) + return out + + +def convert_openai_tool_choice(tool_choice: Any) -> Any: + """Convert OpenAI tool_choice → console tool_choice. 
+ + OpenAI: "none" | "auto" | "required" | {"type":"function","function":{"name":"x"}} + Console: "none" | "auto" | "required" | {"type":"function","name":"x"} + """ + if isinstance(tool_choice, str): + return tool_choice + if isinstance(tool_choice, dict) and tool_choice.get("type") == "function": + fn = tool_choice.get("function") if isinstance(tool_choice.get("function"), dict) else None + if fn: + return {"type": "function", "name": fn.get("name") or ""} + return dict(tool_choice) + return tool_choice + + +# --------------------------------------------------------------------------- +# Payload builder +# --------------------------------------------------------------------------- + + +def build_console_payload( + *, + console_model: str, + input: Any, + instructions: str = "", + stream: bool = False, + temperature: float | None = None, + top_p: float | None = None, + reasoning_effort: str | None = None, + tools: list[dict[str, Any]] | None = None, + tool_choice: Any = None, +) -> dict[str, Any]: + """Build the JSON payload for POST /v1/responses on console.x.ai. + + ``input`` may be a plain string or an array of structured input items + (use :func:`build_console_input` to convert OpenAI messages). + + ``tools`` should already be in console format (use + :func:`convert_openai_tools_to_console`). + + ``features.custom_instruction`` (the admin-configured global system + prompt) is merged into ``instructions`` at the protocol layer so that + every console request mirrors the grok.com path's ``customPersonality`` + injection. The global instruction is prepended; per-request system + messages follow and may refine or override it. 
+ """ + payload: dict[str, Any] = { + "model": console_model, + "input": input, + } + if stream: + payload["stream"] = True + + custom = get_config().get_str("features.custom_instruction", "").strip() + user_sys = (instructions or "").strip() + merged = "\n\n".join(p for p in (custom, user_sys) if p) + if merged: + payload["instructions"] = merged + if temperature is not None: + payload["temperature"] = temperature + if top_p is not None: + payload["top_p"] = top_p + # Console upstream accepts effort ∈ {"minimal", "low", "medium", "high"}. + # Map project-specific values: "none" → omit (emit_think handles client-side + # suppression separately); "xhigh" → "high" (upstream cap). + if reasoning_effort and reasoning_effort != "none": + upstream_effort = "high" if reasoning_effort == "xhigh" else reasoning_effort + payload["reasoning"] = {"effort": upstream_effort} + if tools: + payload["tools"] = tools + if tool_choice is not None: + payload["tool_choice"] = tool_choice + + if isinstance(input, str): + msg_repr = f"len={len(input)}" + elif isinstance(input, list): + msg_repr = f"items={len(input)}" + else: + msg_repr = "unknown" + logger.debug( + "console payload built: model={} stream={} input_{} tools={}", + console_model, + stream, + msg_repr, + len(tools) if tools else 0, + ) + return payload + + +# --------------------------------------------------------------------------- +# Non-streaming response parsing +# --------------------------------------------------------------------------- + + +def extract_console_text(response_json: dict[str, Any]) -> str: + """Extract the assistant's final text from a non-streaming response.""" + output = response_json.get("output") or [] + for item in output: + if not isinstance(item, dict): + continue + if item.get("type") != "message": + continue + contents = item.get("content") or [] + for c in contents: + if not isinstance(c, dict): + continue + if c.get("type") == "output_text": + return c.get("text") or "" + return "" + + +def 
extract_console_reasoning(response_json: dict[str, Any]) -> str: + """Extract reasoning summary text if present (non-streaming).""" + output = response_json.get("output") or [] + for item in output: + if not isinstance(item, dict): + continue + if item.get("type") == "reasoning": + summary = item.get("summary") or [] + parts: list[str] = [] + for s in summary: + if isinstance(s, dict): + text = s.get("text") or s.get("content") or "" + if text: + parts.append(text) + elif isinstance(s, str): + parts.append(s) + return "\n".join(parts) + return "" + + +def extract_console_tool_calls(response_json: dict[str, Any], ) -> list[dict[str, Any]]: + """Extract tool calls from a non-streaming response. + + Returns a list of OpenAI Chat Completions tool_call dicts: + [{"id": "call_xxx", "type": "function", + "function": {"name": "...", "arguments": "..."}}] + + Console responses include each tool call as a top-level output item + of type ``function_call`` with a ``call_id``, ``name`` and + JSON-serialised ``arguments`` string. + """ + output = response_json.get("output") or [] + calls: list[dict[str, Any]] = [] + for item in output: + if not isinstance(item, dict): + continue + if item.get("type") != "function_call": + continue + call_id = item.get("call_id") or item.get("id") or "" + calls.append( + { + "id": call_id, + "type": "function", + "function": { + "name": item.get("name") or "", + "arguments": item.get("arguments") or "{}", + }, + }) + return calls + + +def extract_console_search_sources(response_json: dict[str, Any], ) -> list[dict[str, Any]]: + """Extract the search sources list from web_search_call output items. + + Returns a deduplicated list of source dicts in the format used by the + existing grok.com path: + [{"url": "https://...", "title": ""}, ...] + + Two upstream variants are handled: + + 1. 
Single-agent models (grok-4.3, grok-4.20-reasoning) emit a + ``web_search_call`` output item per search with full sources: + ``{"type": "search", "sources": [{"url": "..."}, ...]}`` + or ``{"type": "open_page", "url": "..."}``. + + 2. Multi-agent models (grok-4.20-multi-agent) skip ``web_search_call`` + items entirely and embed URLs only as document-level annotations on + the final assistant message with ``start_index == end_index == 0``. + We fall back to those annotation URLs so callers always see a + useful citation list regardless of the upstream emission format. + """ + seen: set[str] = set() + out: list[dict[str, Any]] = [] + for item in response_json.get("output") or []: + if not isinstance(item, dict): + continue + if item.get("type") != "web_search_call": + continue + action = item.get("action") or {} + if not isinstance(action, dict): + continue + # Search action with sources list + for src in action.get("sources") or []: + if not isinstance(src, dict): + continue + url = src.get("url") or "" + if not url or url in seen: + continue + seen.add(url) + out.append({ + "url": url, + "title": src.get("title") or "", + }) + # Page-open action — single URL + if action.get("type") == "open_page": + url = action.get("url") or "" + if url and url not in seen: + seen.add(url) + out.append({"url": url, "title": ""}) + + # Fallback: harvest URLs from message annotations. Multi-agent + # responses publish citations only here. We dedupe against the + # web_search_call sources collected above so single-agent paths + # remain unchanged. 
+ for item in response_json.get("output") or []: + if not isinstance(item, dict) or item.get("type") != "message": + continue + for content in item.get("content") or []: + if not isinstance(content, dict): + continue + for ann in content.get("annotations") or []: + if not isinstance(ann, dict): + continue + if ann.get("type") not in (None, "url_citation"): + continue + url = ann.get("url") or "" + if not url or url in seen: + continue + seen.add(url) + title = ann.get("title") or "" + # Multi-agent annotations sometimes set title=url; strip + # the duplicate so the source list reads cleanly. + if title == url: + title = "" + out.append({"url": url, "title": title}) + return out + + +def format_search_sources_suffix(search_sources: list[dict[str, Any]] | None) -> str: + """Format collected search sources as a ``## Sources`` markdown section. + + Returns ``""`` when ``features.show_search_sources`` is disabled or the + input list is empty. Mirrors :meth:`xai_chat.StreamAdapter.references_suffix` + so that text-parsing clients (which can't read the structured + ``search_sources`` field) see identical formatting across both the + grok.com app-chat and console.x.ai paths. + + The leading ``[grok2api-sources]: #`` marker is a markdown link reference + definition that renderers ignore; multi-turn handlers use it to identify + and strip prior-turn ``## Sources`` blocks. 
+ """ + if not search_sources: + return "" + if not get_config().get_bool("features.show_search_sources", False): + return "" + lines = ["\n\n## Sources", "[grok2api-sources]: #"] + for item in search_sources: + url = (item or {}).get("url") or "" + if not url: + continue + title = (item or {}).get("title") or url + title = title.replace("\\", "\\\\").replace("[", "\\[").replace("]", "\\]") + lines.append(f"- [{title}]({url})") + if len(lines) == 2: + return "" + return "\n".join(lines) + "\n" + + +def inject_web_search_tool(tools: list[dict[str, Any]] | None, ) -> list[dict[str, Any]]: + """Ensure a ``web_search`` tool is present in the console tools list. + + If the user already supplied any ``web_search`` tool (with or without + options), it's left untouched. Otherwise a default ``{"type": + "web_search"}`` entry is appended. xAI charges $5/1000 calls for web + search; this is consumed from the account's prepaid (trial) credits. + """ + existing = list(tools or []) + for t in existing: + if isinstance(t, dict) and t.get("type") == "web_search": + return existing + existing.append({"type": "web_search"}) + return existing + + +def extract_console_annotations(response_json: dict[str, Any], ) -> list[dict[str, Any]]: + """Extract URL citation annotations from a non-streaming response. 
+ + Returns a flat list of citation dicts in chat-completions format: + [{"url": "...", "title": "...", "start_index": 0, "end_index": 0}] + """ + out: list[dict[str, Any]] = [] + output = response_json.get("output") or [] + for item in output: + if not isinstance(item, dict): + continue + if item.get("type") != "message": + continue + contents = item.get("content") or [] + for c in contents: + if not isinstance(c, dict): + continue + anns = c.get("annotations") or [] + for a in anns: + if not isinstance(a, dict): + continue + if a.get("type") not in (None, "url_citation"): + continue + url = a.get("url") or "" + if not url: + continue + out.append( + { + "url": url, + "title": a.get("title") or "", + "start_index": int(a.get("start_index") or 0), + "end_index": int(a.get("end_index") or 0), + }) + return out + + +def extract_console_usage(response_json: dict[str, Any]) -> dict[str, int]: + """Extract usage tokens from a non-streaming response.""" + usage = response_json.get("usage") or {} + return { + "prompt_tokens": int(usage.get("input_tokens") or 0), + "completion_tokens": int(usage.get("output_tokens") or 0), + "total_tokens": int(usage.get("total_tokens") or 0), + "reasoning_tokens": int( + (usage.get("output_tokens_details") or {}).get("reasoning_tokens") or + usage.get("reasoning_tokens") or 0), + } + + +def parse_console_error(status_code: int, body: str) -> UpstreamError: + """Convert a non-200 console response into an UpstreamError.""" + message = f"Console upstream returned {status_code}" + try: + obj = orjson.loads(body) if body else {} + if isinstance(obj, dict): + err = obj.get("error") or obj.get("code") or "" + if isinstance(err, dict): + err = err.get("message") or err.get("code") or "" + if err: + message = f"{message}: {err}" + except (orjson.JSONDecodeError, ValueError, TypeError): + pass + return UpstreamError(message, status=status_code, body=body[:400]) + + +# --------------------------------------------------------------------------- +# 
# SSE streaming event parsing
# ---------------------------------------------------------------------------


def classify_console_sse_line(line: str | bytes) -> tuple[str, str]:
    """Return ``(kind, payload)`` for one raw SSE line.

    kind:
      - 'data'  — SSE data line; payload is the JSON string
      - 'event' — SSE event name line; payload is the event name
      - 'skip'  — comment / blank / unrecognized

    Bytes input is decoded as UTF-8 with undecodable bytes replaced. A bare
    line that starts with ``{`` (no ``data:`` prefix) is classified as data.
    """
    if isinstance(line, bytes):
        line = line.decode("utf-8", "replace")
    line = line.strip()
    if not line:
        return "skip", ""
    if line.startswith("event:"):
        return "event", line[6:].strip()
    if line.startswith("data:"):
        return "data", line[5:].strip()
    if line.startswith("{"):
        return "data", line
    return "skip", ""


# Event names that ``ConsoleStreamAdapter.feed_data`` dispatches on.
_REASONING_DELTA_EVENTS = (
    "response.reasoning_summary_text.delta",
    "response.reasoning_summary.delta",
)
_ERROR_EVENTS = ("response.failed", "response.error", "error")
_DISPATCHED_EVENTS = frozenset({
    "response.output_text.delta",
    "response.output_item.added",
    "response.output_item.done",
    "response.function_call_arguments.delta",
    "response.function_call_arguments.done",
    "response.output_text.annotation.added",
    "response.completed",
    *_REASONING_DELTA_EVENTS,
    *_ERROR_EVENTS,
})


def _coerce_int(value: Any) -> int:
    """Convert an upstream counter to ``int``; return 0 when it is not numeric."""
    try:
        return int(value or 0)
    except (TypeError, ValueError, OverflowError):
        return 0


class ConsoleStreamAdapter:
    """Parse upstream Console SSE frames and emit text/reasoning/tool deltas.

    The console.x.ai SSE protocol uses OpenAI Responses API event names:
      - response.created
      - response.output_item.added             ← announces a new item
      - response.content_part.added
      - response.output_text.delta             ← text chunks
      - response.output_text.done
      - response.reasoning_summary_text.delta  ← reasoning chunks
      - response.function_call_arguments.delta ← tool args streaming
      - response.function_call_arguments.done  ← tool args complete
      - response.output_item.done              ← completed item
      - response.output_text.annotation.added  ← citation annotation
      - response.completed
      - response.failed / response.cancelled / response.error

    Feed ``event:`` names through :meth:`feed_event` and ``data:`` payloads
    through :meth:`feed_data`. Everything seen is accumulated on the public
    attributes (``text_buf``, ``thinking_buf``, ``tool_calls``,
    ``annotations``, ``search_sources``) and on :attr:`usage`.
    """

    __slots__ = (
        "_current_event",
        "_active_tool_index",
        "_tool_args_buf",
        "_seen_source_urls",
        "tool_calls",
        "annotations",
        "search_sources",
        "text_buf",
        "thinking_buf",
        "_usage",
    )

    def __init__(self) -> None:
        self._current_event: str = ""
        self._active_tool_index: dict[str, int] = {}  # item_id → index
        self._tool_args_buf: dict[str, list[str]] = {}  # item_id → args chunks
        self._seen_source_urls: set[str] = set()
        self.tool_calls: list[dict[str, Any]] = []
        self.annotations: list[dict[str, Any]] = []
        self.search_sources: list[dict[str, Any]] = []
        self.text_buf: list[str] = []
        self.thinking_buf: list[str] = []
        self._usage: dict[str, int] = {}

    def references_suffix(self) -> str:
        """Return the ``## Sources`` markdown block for the collected sources.

        Returns ``""`` when ``features.show_search_sources`` is disabled or
        no sources were collected. Shared formatting with the grok.com path
        via :func:`format_search_sources_suffix`.
        """
        return format_search_sources_suffix(self.search_sources)

    def feed_event(self, event_name: str) -> None:
        """Record the most recent ``event:`` name from the SSE stream."""
        self._current_event = event_name

    def _remember_source(self, url: Any, title: Any) -> None:
        """Append ``url`` to ``search_sources`` unless it was already recorded."""
        if not isinstance(url, str) or not url or url in self._seen_source_urls:
            return
        self._seen_source_urls.add(url)
        self.search_sources.append({
            "url": url,
            "title": title or "",
        })

    @staticmethod
    def _error_message(obj: dict[str, Any]) -> str:
        """Extract a human-readable message from an error frame.

        Looks at ``error``, then ``response.error``, then the top-level
        ``message`` / ``code`` fields, and falls back to a generic message.
        """
        err = obj.get("error")
        if not err:
            resp = obj.get("response")
            # ``response`` may be null or a non-dict; never call .get on it blindly.
            err = resp.get("error") if isinstance(resp, dict) else None
        if isinstance(err, dict):
            msg = err.get("message") or err.get("code")
        else:
            msg = str(err) if err else None
        if not msg:
            # Flat error frames carry message/code at the top level.
            msg = obj.get("message") or obj.get("code")
        return str(msg) if msg else "Console stream error"

    def feed_data(self, data: str) -> dict[str, Any]:
        """Parse one SSE data frame; return the kind/content classification.

        Returns a dict like:
          {"kind": "text", "content": "Two"}
          {"kind": "thinking", "content": "Let me think..."}
          {"kind": "tool_call_start", "index": 0, "call_id": "...", "name": "..."}
          {"kind": "tool_call_args", "index": 0, "delta": "..."}
          {"kind": "tool_call_done", "index": 0}
          {"kind": "annotation", "annotation_data": {...}}
          {"kind": "done"}
          {"kind": "error", "message": "..."}
          {"kind": "skip"}

        An empty payload or the literal ``[DONE]`` yields ``done``; payloads
        that are not a JSON object yield ``skip``.
        """
        if not data or data == "[DONE]":
            return {"kind": "done"}
        try:
            obj = orjson.loads(data)
        except (orjson.JSONDecodeError, ValueError, TypeError):
            return {"kind": "skip"}
        if not isinstance(obj, dict):
            return {"kind": "skip"}

        # Pick the event name to dispatch on. The payload's own ``type`` wins
        # when it is one we handle: ``_current_event`` is sticky (it is only
        # overwritten by the next ``event:`` line), so a stale name must not
        # misroute a frame that identifies itself. Otherwise fall back to the
        # last ``event:`` line.
        payload_type = obj.get("type")
        if isinstance(payload_type, str) and payload_type in _DISPATCHED_EVENTS:
            ev = payload_type
        else:
            ev = self._current_event or ""

        # ── Text delta ────────────────────────────────────────────────────────
        if ev == "response.output_text.delta":
            delta = obj.get("delta") or ""
            if isinstance(delta, str) and delta:
                self.text_buf.append(delta)
                return {"kind": "text", "content": delta}
            return {"kind": "skip"}

        # ── Reasoning summary delta (thinking) ────────────────────────────────
        if ev in _REASONING_DELTA_EVENTS:
            delta = obj.get("delta") or ""
            if isinstance(delta, str) and delta:
                self.thinking_buf.append(delta)
                return {"kind": "thinking", "content": delta}
            return {"kind": "skip"}

        # ── Tool call start (output_item.added with type=function_call) ──────
        if ev == "response.output_item.added":
            item = obj.get("item") or {}
            if isinstance(item, dict) and item.get("type") == "function_call":
                item_id = item.get("id") or item.get("call_id") or ""
                call_id = item.get("call_id") or item_id
                name = item.get("name") or ""
                idx = len(self.tool_calls)
                self._active_tool_index[item_id] = idx
                self._tool_args_buf[item_id] = []
                self.tool_calls.append({
                    "id": call_id,
                    "type": "function",
                    "function": {
                        "name": name,
                        "arguments": "",
                    },
                })
                return {
                    "kind": "tool_call_start",
                    "index": idx,
                    "call_id": call_id,
                    "name": name,
                }
            return {"kind": "skip"}

        # ── Web search call done — collect sources ───────────────────────────
        if ev == "response.output_item.done":
            item = obj.get("item") or {}
            if isinstance(item, dict) and item.get("type") == "web_search_call":
                action = item.get("action") or {}
                if isinstance(action, dict):
                    sources = action.get("sources")
                    if isinstance(sources, list):
                        for src in sources:
                            if isinstance(src, dict):
                                self._remember_source(src.get("url"), src.get("title"))
                    if action.get("type") == "open_page":
                        self._remember_source(action.get("url"), "")
            return {"kind": "skip"}

        # ── Tool call argument delta ──────────────────────────────────────────
        if ev == "response.function_call_arguments.delta":
            item_id = obj.get("item_id") or ""
            delta = obj.get("delta") or ""
            if not isinstance(delta, str) or not delta:
                return {"kind": "skip"}
            idx = self._active_tool_index.get(item_id)
            if idx is None:
                # Arguments for a tool call that was never announced.
                return {"kind": "skip"}
            self._tool_args_buf.setdefault(item_id, []).append(delta)
            return {"kind": "tool_call_args", "index": idx, "delta": delta}

        # ── Tool call complete ────────────────────────────────────────────────
        if ev == "response.function_call_arguments.done":
            item_id = obj.get("item_id") or ""
            idx = self._active_tool_index.get(item_id)
            if idx is None:
                return {"kind": "skip"}
            # Prefer upstream-provided final arguments string when present.
            final_args = obj.get("arguments")
            if not isinstance(final_args, str) or not final_args:
                final_args = "".join(self._tool_args_buf.get(item_id, []))
            self.tool_calls[idx]["function"]["arguments"] = final_args
            return {"kind": "tool_call_done", "index": idx}

        # ── URL citation annotation ───────────────────────────────────────────
        if ev == "response.output_text.annotation.added":
            ann = obj.get("annotation") or {}
            if isinstance(ann, dict) and ann.get("type") in (None, "url_citation"):
                url = ann.get("url") or ""
                if url:
                    title = ann.get("title") or ""
                    if title == url:
                        # Multi-agent often duplicates URL into title; clean it.
                        title = ""
                    record = {
                        "url": url,
                        "title": title,
                        "start_index": _coerce_int(ann.get("start_index")),
                        "end_index": _coerce_int(ann.get("end_index")),
                    }
                    self.annotations.append(record)
                    # Fallback for multi-agent: harvest citation URL into
                    # search_sources too. Dedupe against web_search_call
                    # sources to avoid duplicating single-agent entries.
                    self._remember_source(url, title)
                    return {"kind": "annotation", "annotation_data": record}
            return {"kind": "skip"}

        # ── Final completion frame — capture usage for accounting ────────────
        if ev == "response.completed":
            resp = obj.get("response")
            if not isinstance(resp, dict) or not resp:
                # Usage may sit directly on the frame instead of under "response".
                resp = obj
            usage = resp.get("usage")
            if isinstance(usage, dict) and usage:
                details = usage.get("output_tokens_details")
                if not isinstance(details, dict):
                    details = {}
                self._usage = {
                    "prompt_tokens": _coerce_int(usage.get("input_tokens")),
                    "completion_tokens": _coerce_int(usage.get("output_tokens")),
                    "total_tokens": _coerce_int(usage.get("total_tokens")),
                    "reasoning_tokens": _coerce_int(
                        details.get("reasoning_tokens") or usage.get("reasoning_tokens")),
                }
            return {"kind": "done"}

        # ── Error frames ──────────────────────────────────────────────────────
        if ev in _ERROR_EVENTS:
            return {"kind": "error", "message": self._error_message(obj)}

        return {"kind": "skip"}

    @property
    def usage(self) -> dict[str, int]:
        """Return collected usage tokens (populated after stream completion)."""
        return dict(self._usage)


__all__ = [
    "build_console_input",
    "build_console_payload",
    "convert_openai_tools_to_console",
    "convert_openai_tool_choice",
    "inject_web_search_tool",
    "extract_console_text",
    "extract_console_reasoning",
    "extract_console_tool_calls",
    "extract_console_annotations",
    "extract_console_search_sources",
    "extract_console_usage",
    "format_search_sources_suffix",
    "parse_console_error",
    "classify_console_sse_line",
    "ConsoleStreamAdapter",
]
a/app/dataplane/reverse/runtime/endpoint_table.py b/app/dataplane/reverse/runtime/endpoint_table.py index bb390e503..37bb14990 100644 --- a/app/dataplane/reverse/runtime/endpoint_table.py +++ b/app/dataplane/reverse/runtime/endpoint_table.py @@ -8,47 +8,62 @@ hosts (accounts.x.ai, grok.com with gRPC path), listed separately. """ -BASE = "https://grok.com" +BASE = "https://grok.com" ASSETS_CDN = "https://assets.grok.com" +CONSOLE_BASE = "https://console.x.ai" + +# ── Console API (SSO-shared with grok.com, supports all models) ──────── +CONSOLE_RESPONSES = f"{CONSOLE_BASE}/v1/responses" # ── App-chat (SSE streaming, new conversation) ────────────────────────── -CHAT = f"{BASE}/rest/app-chat/conversations/new" +CHAT = f"{BASE}/rest/app-chat/conversations/new" # ── Asset management ───────────────────────────────────────────────────── -ASSETS_UPLOAD = f"{BASE}/rest/app-chat/upload-file" # POST (base64 upload) -ASSETS_LIST = f"{BASE}/rest/assets" # GET -ASSETS_DELETE = f"{BASE}/rest/assets-metadata" # DELETE /{asset_id} -ASSETS_DOWNLOAD = ASSETS_CDN # GET /{path} +ASSETS_UPLOAD = f"{BASE}/rest/app-chat/upload-file" # POST (base64 upload) +ASSETS_LIST = f"{BASE}/rest/assets" # GET +ASSETS_DELETE = f"{BASE}/rest/assets-metadata" # DELETE /{asset_id} +ASSETS_DOWNLOAD = ASSETS_CDN # GET /{path} # ── Rate limits (usage / quota sync) ───────────────────────────────────── -RATE_LIMITS = f"{BASE}/rest/rate-limits" # POST +RATE_LIMITS = f"{BASE}/rest/rate-limits" # POST # ── gRPC-Web endpoints ────────────────────────────────────────────────── -ACCEPT_TOS = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion" -NSFW_MGMT = f"{BASE}/auth_mgmt.AuthManagement/UpdateUserFeatureControls" +ACCEPT_TOS = "https://accounts.x.ai/auth_mgmt.AuthManagement/SetTosAcceptedVersion" +NSFW_MGMT = f"{BASE}/auth_mgmt.AuthManagement/UpdateUserFeatureControls" # ── Auth REST ──────────────────────────────────────────────────────────── -SET_BIRTH = f"{BASE}/rest/auth/set-birth-date" # 
POST +SET_BIRTH = f"{BASE}/rest/auth/set-birth-date" # POST # ── Media (video) ──────────────────────────────────────────────────────── -MEDIA_POST = f"{BASE}/rest/media/post/create" # POST -MEDIA_POST_LINK = f"{BASE}/rest/media/post/create-link" # POST -VIDEO_UPSCALE = f"{BASE}/rest/media/video/upscale" # POST +MEDIA_POST = f"{BASE}/rest/media/post/create" # POST +MEDIA_POST_LINK = f"{BASE}/rest/media/post/create-link" # POST +VIDEO_UPSCALE = f"{BASE}/rest/media/video/upscale" # POST # ── WebSocket endpoints ───────────────────────────────────────────────── -WS_IMAGINE = "wss://grok.com/ws/imagine/listen" -WS_LIVEKIT = "wss://livekit.grok.com" +WS_IMAGINE = "wss://grok.com/ws/imagine/listen" +WS_LIVEKIT = "wss://livekit.grok.com" # ── LiveKit ───────────────────────────────────────────────────────────── -LIVEKIT_TOKENS = f"{BASE}/rest/livekit/tokens" # POST - +LIVEKIT_TOKENS = f"{BASE}/rest/livekit/tokens" # POST __all__ = [ - "BASE", "ASSETS_CDN", + "BASE", + "ASSETS_CDN", + "CONSOLE_BASE", + "CONSOLE_RESPONSES", "CHAT", - "ASSETS_UPLOAD", "ASSETS_LIST", "ASSETS_DELETE", "ASSETS_DOWNLOAD", + "ASSETS_UPLOAD", + "ASSETS_LIST", + "ASSETS_DELETE", + "ASSETS_DOWNLOAD", "RATE_LIMITS", - "ACCEPT_TOS", "NSFW_MGMT", "SET_BIRTH", - "MEDIA_POST", "MEDIA_POST_LINK", "VIDEO_UPSCALE", - "WS_IMAGINE", "WS_LIVEKIT", "LIVEKIT_TOKENS", + "ACCEPT_TOS", + "NSFW_MGMT", + "SET_BIRTH", + "MEDIA_POST", + "MEDIA_POST_LINK", + "VIDEO_UPSCALE", + "WS_IMAGINE", + "WS_LIVEKIT", + "LIVEKIT_TOKENS", ] diff --git a/app/products/anthropic/messages.py b/app/products/anthropic/messages.py index 42ba98a17..31b77d663 100644 --- a/app/products/anthropic/messages.py +++ b/app/products/anthropic/messages.py @@ -24,23 +24,33 @@ from app.control.account.enums import FeedbackKind from app.dataplane.reverse.protocol.xai_chat import classify_line, StreamAdapter from app.dataplane.reverse.protocol.tool_prompt import ( - build_tool_system_prompt, extract_tool_names, inject_into_message, + 
build_tool_system_prompt, + extract_tool_names, + inject_into_message, ) from app.dataplane.reverse.protocol.tool_parser import parse_tool_calls from app.products.openai.chat import ( - _stream_chat, _extract_message, _resolve_image, - _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, - _configured_retry_codes, _should_retry_upstream, + _stream_chat, + _extract_message, + _resolve_image, + _quota_sync, + _fail_sync, + _parse_retry_codes, + _feedback_kind, + _log_task_exception, + _configured_retry_codes, + _should_retry_upstream, + _console_completions, ) from app.products._account_selection import reserve_account, selection_max_retries from app.products.openai._tool_sieve import ToolSieve - # --------------------------------------------------------------------------- # ID helpers # --------------------------------------------------------------------------- + def _make_msg_id() -> str: return f"msg_{int(time.time() * 1000)}{os.urandom(4).hex()}" @@ -53,6 +63,7 @@ def _make_tool_id() -> str: # SSE encoding (Anthropic event format) # --------------------------------------------------------------------------- + def _sse(event: str, data: dict) -> str: return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n" @@ -61,6 +72,7 @@ def _sse(event: str, data: dict) -> str: # Request conversion: Anthropic → internal format # --------------------------------------------------------------------------- + def _anthropic_content_to_internal(content: Any, role: str) -> list[dict]: """Convert Anthropic content (string or block list) to internal message list. 
@@ -74,16 +86,10 @@ def _anthropic_content_to_internal(content: Any, role: str) -> list[dict]: return [] # Check if content contains tool_use blocks (assistant calling tools) - has_tool_use = any( - isinstance(b, dict) and b.get("type") == "tool_use" - for b in content - ) + has_tool_use = any(isinstance(b, dict) and b.get("type") == "tool_use" for b in content) # Check if content contains tool_result blocks (user returning results) - tool_result_blocks = [ - b for b in content - if isinstance(b, dict) and b.get("type") == "tool_result" - ] + tool_result_blocks = [b for b in content if isinstance(b, dict) and b.get("type") == "tool_result"] if tool_result_blocks: # Each tool_result → a separate tool-role message @@ -93,14 +99,13 @@ def _anthropic_content_to_internal(content: Any, role: str) -> list[dict]: if isinstance(result_content, list): # array of text blocks → join result_content = "\n".join( - b.get("text", "") for b in result_content - if isinstance(b, dict) and b.get("type") == "text" - ) - messages.append({ - "role": "tool", - "tool_call_id": block.get("tool_use_id", ""), - "content": result_content or "", - }) + b.get("text", "") for b in result_content if isinstance(b, dict) and b.get("type") == "text") + messages.append( + { + "role": "tool", + "tool_call_id": block.get("tool_use_id", ""), + "content": result_content or "", + }) return messages if has_tool_use: @@ -114,17 +119,18 @@ def _anthropic_content_to_internal(content: Any, role: str) -> list[dict]: if btype == "text": text_parts.append(block.get("text", "")) elif btype == "tool_use": - tool_calls.append({ - "id": block.get("id", _make_tool_id()), - "type": "function", - "function": { - "name": block.get("name", ""), - "arguments": orjson.dumps(block.get("input") or {}).decode(), - }, - }) + tool_calls.append( + { + "id": block.get("id", _make_tool_id()), + "type": "function", + "function": { + "name": block.get("name", ""), + "arguments": orjson.dumps(block.get("input") or {}).decode(), + }, + 
}) msg: dict = { - "role": "assistant", - "content": " ".join(text_parts) if text_parts else None, + "role": "assistant", + "content": " ".join(text_parts) if text_parts else None, "tool_calls": tool_calls, } return [msg] @@ -144,25 +150,31 @@ def _anthropic_content_to_internal(content: Any, role: str) -> list[dict]: src_type = source.get("type", "") if src_type == "base64": media = source.get("media_type", "image/jpeg") - data = source.get("data", "") + data = source.get("data", "") normalized.append({ - "type": "image_url", - "image_url": {"url": f"data:{media};base64,{data}"}, + "type": "image_url", + "image_url": { + "url": f"data:{media};base64,{data}" + }, }) elif src_type == "url": normalized.append({ - "type": "image_url", - "image_url": {"url": source.get("url", "")}, + "type": "image_url", + "image_url": { + "url": source.get("url", "") + }, }) elif btype == "document": source = block.get("source") or {} src_type = source.get("type", "") if src_type == "base64": media = source.get("media_type", "application/pdf") - data = source.get("data", "") + data = source.get("data", "") normalized.append({ "type": "file", - "file": {"data": f"data:{media};base64,{data}"}, + "file": { + "data": f"data:{media};base64,{data}" + }, }) if not normalized: @@ -172,7 +184,7 @@ def _anthropic_content_to_internal(content: Any, role: str) -> list[dict]: def _parse_anthropic_messages( messages: list[dict], - system: str | list | None, + system: str | list | None, ) -> list[dict]: """Convert Anthropic messages + system prompt to internal format.""" internal: list[dict] = [] @@ -183,16 +195,14 @@ def _parse_anthropic_messages( system_text = system elif isinstance(system, list): system_text = "\n".join( - b.get("text", "") for b in system - if isinstance(b, dict) and b.get("type") == "text" - ) + b.get("text", "") for b in system if isinstance(b, dict) and b.get("type") == "text") else: system_text = str(system) if system_text.strip(): internal.append({"role": "system", "content": 
system_text}) for msg in messages: - role = msg.get("role", "user") + role = msg.get("role", "user") content = msg.get("content", "") internal.extend(_anthropic_content_to_internal(content, role)) @@ -207,14 +217,15 @@ def _convert_tools(tools: list[dict]) -> list[dict]: """ result = [] for tool in tools: - result.append({ - "type": "function", - "function": { - "name": tool.get("name", ""), - "description": tool.get("description", ""), - "parameters": tool.get("input_schema"), - }, - }) + result.append( + { + "type": "function", + "function": { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": tool.get("input_schema"), + }, + }) return result @@ -239,57 +250,310 @@ def _convert_tool_choice(tool_choice: Any) -> Any: # Response format helpers # --------------------------------------------------------------------------- + def _finish_reason_to_stop_reason(finish_reason: str | None) -> str: mapping = {"stop": "end_turn", "tool_calls": "tool_use", "length": "max_tokens"} return mapping.get(finish_reason or "stop", "end_turn") def _build_message_response( - msg_id: str, - model: str, - content: list[dict], + msg_id: str, + model: str, + content: list[dict], stop_reason: str, - input_tokens: int, + input_tokens: int, output_tokens: int, ) -> dict: return { - "id": msg_id, - "type": "message", - "role": "assistant", - "model": model, - "content": content, - "stop_reason": stop_reason, + "id": msg_id, + "type": "message", + "role": "assistant", + "model": model, + "content": content, + "stop_reason": stop_reason, "stop_sequence": None, "usage": { - "input_tokens": input_tokens, + "input_tokens": input_tokens, "output_tokens": output_tokens, }, } +# --------------------------------------------------------------------------- +# Chat Completions → Anthropic Messages conversion (used for console dispatch) +# --------------------------------------------------------------------------- + + +def _chat_completion_to_anthropic( + 
chat_response: dict, + msg_id: str, + model: str, +) -> dict: + """Convert a Chat Completions response dict → Anthropic Messages response.""" + choice = (chat_response.get("choices") or [{}])[0] + message = choice.get("message") or {} + text = message.get("content") or "" + tool_calls = message.get("tool_calls") or [] + + content: list[dict] = [] + if text: + content.append({"type": "text", "text": text}) + for tc in tool_calls: + if not isinstance(tc, dict): + continue + fn = tc.get("function") or {} + try: + input_args = orjson.loads(fn.get("arguments") or "{}") + except (orjson.JSONDecodeError, ValueError, TypeError): + input_args = {} + content.append( + { + "type": "tool_use", + "id": tc.get("id") or _make_tool_id(), + "name": fn.get("name") or "", + "input": input_args, + }) + + if not content: + # Anthropic requires at least one content block + content = [{"type": "text", "text": ""}] + + finish = choice.get("finish_reason") + stop_reason = _finish_reason_to_stop_reason(finish) + + usage = chat_response.get("usage") or {} + return _build_message_response( + msg_id, + model, + content, + stop_reason, + input_tokens=int(usage.get("prompt_tokens", 0)), + output_tokens=int(usage.get("completion_tokens", 0)), + ) + + +async def _chat_stream_to_anthropic_sse( + chat_stream: AsyncGenerator[str, None], + msg_id: str, + model: str, +) -> AsyncGenerator[str, None]: + """Convert a Chat Completions SSE stream → Anthropic Messages SSE events. + + Maps Chat Completions chunks into Anthropic's message_start / + content_block_start / content_block_delta / content_block_stop / + message_delta / message_stop event sequence. Tool call argument deltas + are forwarded as ``input_json_delta`` events on the corresponding + tool_use content block. 
+ """ + yield _sse( + "message_start", { + "type": "message_start", + "message": { + "id": msg_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "stop_sequence": None, + "usage": { + "input_tokens": 0, + "output_tokens": 0 + }, + }, + }) + + text_block_open = False + text_block_index = -1 + tool_blocks: dict[int, dict[str, Any]] = {} # tc_index → {block_index, id, name} + next_block_index = 0 + full_text = "" + output_tokens = 0 + final_stop_reason = "end_turn" + + async for chunk_line in chat_stream: + if not chunk_line.startswith("data:"): + continue + data_str = chunk_line[5:].strip() + if not data_str or data_str == "[DONE]": + continue + try: + chunk = orjson.loads(data_str) + except (orjson.JSONDecodeError, ValueError, TypeError): + continue + + choices = chunk.get("choices") or [] + if not choices: + usage = chunk.get("usage") or {} + if usage: + output_tokens = int(usage.get("completion_tokens") or 0) + continue + choice = choices[0] + delta = choice.get("delta") or {} + + # Text content delta + text = delta.get("content") + if isinstance(text, str) and text: + if not text_block_open: + text_block_index = next_block_index + next_block_index += 1 + yield _sse( + "content_block_start", { + "type": "content_block_start", + "index": text_block_index, + "content_block": { + "type": "text", + "text": "" + }, + }) + text_block_open = True + yield _sse( + "content_block_delta", { + "type": "content_block_delta", + "index": text_block_index, + "delta": { + "type": "text_delta", + "text": text + }, + }) + full_text += text + + # Tool call deltas + tool_calls_delta = delta.get("tool_calls") or [] + for tc in tool_calls_delta: + if not isinstance(tc, dict): + continue + tc_index = int(tc.get("index", 0)) + existing = tool_blocks.get(tc_index) + fn = tc.get("function") or {} + if existing is None: + # First chunk for this tool call — emit content_block_start + block_idx = next_block_index + next_block_index += 1 
+ tc_id = tc.get("id") or _make_tool_id() + tc_name = fn.get("name") or "" + tool_blocks[tc_index] = { + "block_index": block_idx, + "id": tc_id, + "name": tc_name, + } + yield _sse( + "content_block_start", { + "type": "content_block_start", + "index": block_idx, + "content_block": { + "type": "tool_use", + "id": tc_id, + "name": tc_name, + "input": {}, + }, + }) + args_delta = fn.get("arguments") or "" + if isinstance(args_delta, str) and args_delta: + yield _sse( + "content_block_delta", { + "type": "content_block_delta", + "index": tool_blocks[tc_index]["block_index"], + "delta": { + "type": "input_json_delta", + "partial_json": args_delta, + }, + }) + + # finish_reason + finish = choice.get("finish_reason") + if finish: + final_stop_reason = _finish_reason_to_stop_reason(finish) + + # Usage on final chunk + usage = chunk.get("usage") or {} + if usage: + output_tokens = int(usage.get("completion_tokens") or 0) + + # Close all open content blocks + if text_block_open: + yield _sse("content_block_stop", { + "type": "content_block_stop", + "index": text_block_index, + }) + for tc_info in tool_blocks.values(): + yield _sse("content_block_stop", { + "type": "content_block_stop", + "index": tc_info["block_index"], + }) + + if tool_blocks and final_stop_reason == "end_turn": + # If upstream didn't set finish_reason=tool_calls, force tool_use + final_stop_reason = "tool_use" + + yield _sse( + "message_delta", { + "type": "message_delta", + "delta": { + "stop_reason": final_stop_reason, + "stop_sequence": None + }, + "usage": { + "output_tokens": output_tokens or estimate_tokens(full_text), + }, + }) + yield _sse("message_stop", {"type": "message_stop"}) + + # --------------------------------------------------------------------------- # Main handler # --------------------------------------------------------------------------- + async def create( *, - model: str, - messages: list[dict], - system: str | list | None = None, - stream: bool, - emit_think: bool, - 
temperature: float, - top_p: float, - tools: list[dict] | None = None, - tool_choice: Any = None, + model: str, + messages: list[dict], + system: str | list | None = None, + stream: bool, + emit_think: bool, + temperature: float, + top_p: float, + tools: list[dict] | None = None, + tool_choice: Any = None, ) -> dict | AsyncGenerator[str, None]: - cfg = get_config() - spec = resolve_model(model) + cfg = get_config() + spec = resolve_model(model) mode_id = int(spec.mode_id) # Build internal message list internal_messages = _parse_anthropic_messages(messages, system) + + # ── Console API dispatch ───────────────────────────────────────────────── + # Models with `console_model` set route through console.x.ai/v1/responses + # via the Chat Completions bridge. This uses native function calling + # (no ToolSieve XML injection) and supports multimodal input. + if spec.is_console(): + chat_tools_arg = _convert_tools(tools) if tools else None + chat_tool_choice_arg = (_convert_tool_choice(tool_choice) if tools else None) + msg_id_for_console = _make_msg_id() + logger.info( + "console messages dispatch: model={} stream={} message_count={}", + model, + stream, + len(internal_messages), + ) + chat_result = await _console_completions( + spec=spec, + model=model, + messages=internal_messages, + is_stream=stream, + emit_think=emit_think, + temperature=temperature, + top_p=top_p, + tools=chat_tools_arg, + tool_choice=chat_tool_choice_arg, + ) + if stream: + return _chat_stream_to_anthropic_sse(chat_result, msg_id_for_console, model) + return _chat_completion_to_anthropic(chat_result, msg_id_for_console, model) + internal_message, files = _extract_message(internal_messages) if not internal_message.strip(): raise UpstreamError("Empty message after extraction", status=400) @@ -298,10 +562,10 @@ async def create( tool_names: list[str] = [] internal_tool_choice: Any = None if tools: - chat_tools = _convert_tools(tools) - tool_names = extract_tool_names(chat_tools) + chat_tools = 
_convert_tools(tools) + tool_names = extract_tool_names(chat_tools) internal_tool_choice = _convert_tool_choice(tool_choice) - tool_prompt = build_tool_system_prompt(chat_tools, internal_tool_choice) + tool_prompt = build_tool_system_prompt(chat_tools, internal_tool_choice) internal_message = inject_into_message(internal_message, tool_prompt) logger.info("messages tool injection: tool_names={} choice={}", tool_names, internal_tool_choice) @@ -312,8 +576,8 @@ async def create( max_retries = selection_max_retries() retry_codes = _configured_retry_codes(cfg) - timeout_s = cfg.get_float("chat.timeout", 120.0) - msg_id = _make_msg_id() + timeout_s = cfg.get_float("chat.timeout", 120.0) + msg_id = _make_msg_id() # ------------------------------------------------------------------------- # Streaming @@ -330,46 +594,50 @@ async def _run_stream() -> AsyncGenerator[str, None]: if acct is None: raise RateLimitError("No available accounts for this model tier") - token = acct.token + token = acct.token success = False - _retry = False + _retry = False fail_exc: BaseException | None = None - adapter = StreamAdapter() - think_buf: list[str] = [] - text_buf: list[str] = [] - think_started = False - think_closed = False - text_started = False - sieve = ToolSieve(tool_names) if tool_names else None - tool_calls_emitted = False - tool_output_tokens = 0 - block_index = 0 # tracks next content_block index + adapter = StreamAdapter() + think_buf: list[str] = [] + text_buf: list[str] = [] + think_started = False + think_closed = False + text_started = False + sieve = ToolSieve(tool_names) if tool_names else None + tool_calls_emitted = False + tool_output_tokens = 0 + block_index = 0 # tracks next content_block index collected_annotations: list[dict] = [] try: try: # message_start - yield _sse("message_start", { - "type": "message_start", - "message": { - "id": msg_id, - "type": "message", - "role": "assistant", - "model": model, - "content": [], - "stop_reason": None, - "usage": 
{"input_tokens": estimate_prompt_tokens(internal_message), "output_tokens": 0}, - }, - }) + yield _sse( + "message_start", { + "type": "message_start", + "message": { + "id": msg_id, + "type": "message", + "role": "assistant", + "model": model, + "content": [], + "stop_reason": None, + "usage": { + "input_tokens": estimate_prompt_tokens(internal_message), + "output_tokens": 0 + }, + }, + }) yield _sse("ping", {"type": "ping"}) ended = False async for line in _stream_chat( - token = token, - mode_id = ModeId(selected_mode_id), - message = internal_message, - files = files, - timeout_s = timeout_s, + token=token, + mode_id=ModeId(selected_mode_id), + message=internal_message, + files=files, + timeout_s=timeout_s, ): if tool_calls_emitted: break @@ -385,26 +653,35 @@ async def _run_stream() -> AsyncGenerator[str, None]: if ev.kind == "thinking" and emit_think and not think_closed: if not think_started: think_started = True - yield _sse("content_block_start", { - "type": "content_block_start", - "index": block_index, - "content_block": {"type": "thinking", "thinking": ""}, - }) + yield _sse( + "content_block_start", { + "type": "content_block_start", + "index": block_index, + "content_block": { + "type": "thinking", + "thinking": "" + }, + }) think_buf.append(ev.content) - yield _sse("content_block_delta", { - "type": "content_block_delta", - "index": block_index, - "delta": {"type": "thinking_delta", "thinking": ev.content}, - }) + yield _sse( + "content_block_delta", { + "type": "content_block_delta", + "index": block_index, + "delta": { + "type": "thinking_delta", + "thinking": ev.content + }, + }) elif ev.kind == "text": # Close thinking block if open if think_started and not think_closed: think_closed = True - yield _sse("content_block_stop", { - "type": "content_block_stop", - "index": block_index, - }) + yield _sse( + "content_block_stop", { + "type": "content_block_stop", + "index": block_index, + }) block_index += 1 # Feed through ToolSieve if tools active @@ 
-413,28 +690,31 @@ async def _run_stream() -> AsyncGenerator[str, None]: if calls is not None: # Emit tool_use blocks for call in calls: - yield _sse("content_block_start", { - "type": "content_block_start", - "index": block_index, - "content_block": { - "type": "tool_use", - "id": call.call_id, - "name": call.name, - "input": {}, - }, - }) - yield _sse("content_block_delta", { - "type": "content_block_delta", - "index": block_index, - "delta": { - "type": "input_json_delta", - "partial_json": call.arguments, - }, - }) - yield _sse("content_block_stop", { - "type": "content_block_stop", - "index": block_index, - }) + yield _sse( + "content_block_start", { + "type": "content_block_start", + "index": block_index, + "content_block": { + "type": "tool_use", + "id": call.call_id, + "name": call.name, + "input": {}, + }, + }) + yield _sse( + "content_block_delta", { + "type": "content_block_delta", + "index": block_index, + "delta": { + "type": "input_json_delta", + "partial_json": call.arguments, + }, + }) + yield _sse( + "content_block_stop", { + "type": "content_block_stop", + "index": block_index, + }) block_index += 1 tool_output_tokens = estimate_tool_call_tokens(calls) tool_calls_emitted = True @@ -447,17 +727,25 @@ async def _run_stream() -> AsyncGenerator[str, None]: if text_chunk: if not text_started: text_started = True - yield _sse("content_block_start", { - "type": "content_block_start", - "index": block_index, - "content_block": {"type": "text", "text": ""}, - }) + yield _sse( + "content_block_start", { + "type": "content_block_start", + "index": block_index, + "content_block": { + "type": "text", + "text": "" + }, + }) text_buf.append(text_chunk) - yield _sse("content_block_delta", { - "type": "content_block_delta", - "index": block_index, - "delta": {"type": "text_delta", "text": text_chunk}, - }) + yield _sse( + "content_block_delta", { + "type": "content_block_delta", + "index": block_index, + "delta": { + "type": "text_delta", + "text": text_chunk + }, 
+ }) elif ev.kind == "annotation" and ev.annotation_data: collected_annotations.append(ev.annotation_data) @@ -475,35 +763,39 @@ async def _run_stream() -> AsyncGenerator[str, None]: if calls: # Close text block if open if text_started: - yield _sse("content_block_stop", { - "type": "content_block_stop", - "index": block_index, - }) + yield _sse( + "content_block_stop", { + "type": "content_block_stop", + "index": block_index, + }) block_index += 1 text_started = False for call in calls: - yield _sse("content_block_start", { - "type": "content_block_start", - "index": block_index, - "content_block": { - "type": "tool_use", - "id": call.call_id, - "name": call.name, - "input": {}, - }, - }) - yield _sse("content_block_delta", { - "type": "content_block_delta", - "index": block_index, - "delta": { - "type": "input_json_delta", - "partial_json": call.arguments, - }, - }) - yield _sse("content_block_stop", { - "type": "content_block_stop", - "index": block_index, - }) + yield _sse( + "content_block_start", { + "type": "content_block_start", + "index": block_index, + "content_block": { + "type": "tool_use", + "id": call.call_id, + "name": call.name, + "input": {}, + }, + }) + yield _sse( + "content_block_delta", { + "type": "content_block_delta", + "index": block_index, + "delta": { + "type": "input_json_delta", + "partial_json": call.arguments, + }, + }) + yield _sse( + "content_block_stop", { + "type": "content_block_stop", + "index": block_index, + }) block_index += 1 tool_output_tokens = estimate_tool_call_tokens(calls) tool_calls_emitted = True @@ -514,16 +806,20 @@ async def _run_stream() -> AsyncGenerator[str, None]: sources = adapter.search_sources_list() if sources: tool_delta["search_sources"] = sources - yield _sse("message_delta", { - "type": "message_delta", - "delta": tool_delta, - "usage": {"output_tokens": tool_output_tokens}, - }) + yield _sse( + "message_delta", { + "type": "message_delta", + "delta": tool_delta, + "usage": { + "output_tokens": 
tool_output_tokens + }, + }) yield _sse("message_stop", {"type": "message_stop"}) yield "data: [DONE]\n\n" success = True - logger.info("messages stream tool_calls: attempt={}/{} model={}", - attempt + 1, max_retries + 1, model) + logger.info( + "messages stream tool_calls: attempt={}/{} model={}", attempt + 1, max_retries + 1, + model) else: # Resolve image attachments and references for url, img_id in adapter.image_urls: @@ -532,37 +828,47 @@ async def _run_stream() -> AsyncGenerator[str, None]: chunk = img_text + "\n" text_buf.append(chunk) if text_started: - yield _sse("content_block_delta", { - "type": "content_block_delta", - "index": block_index, - "delta": {"type": "text_delta", "text": chunk}, - }) + yield _sse( + "content_block_delta", { + "type": "content_block_delta", + "index": block_index, + "delta": { + "type": "text_delta", + "text": chunk + }, + }) references = adapter.references_suffix() if references: text_buf.append(references) if text_started: - yield _sse("content_block_delta", { - "type": "content_block_delta", - "index": block_index, - "delta": {"type": "text_delta", "text": references}, - }) + yield _sse( + "content_block_delta", { + "type": "content_block_delta", + "index": block_index, + "delta": { + "type": "text_delta", + "text": references + }, + }) # Close open blocks if think_started and not think_closed: - yield _sse("content_block_stop", { - "type": "content_block_stop", - "index": block_index, - }) + yield _sse( + "content_block_stop", { + "type": "content_block_stop", + "index": block_index, + }) block_index += 1 if text_started: - yield _sse("content_block_stop", { - "type": "content_block_stop", - "index": block_index, - }) + yield _sse( + "content_block_stop", { + "type": "content_block_stop", + "index": block_index, + }) - full_text = "".join(text_buf) + full_text = "".join(text_buf) full_think = "".join(think_buf) out_tokens = estimate_tokens(full_text) if full_think: @@ -575,18 +881,25 @@ async def _run_stream() -> 
AsyncGenerator[str, None]: msg_delta["search_sources"] = sources if collected_annotations: msg_delta["annotations"] = collected_annotations - yield _sse("message_delta", { - "type": "message_delta", - "delta": msg_delta, - "usage": {"output_tokens": out_tokens}, - }) + yield _sse( + "message_delta", { + "type": "message_delta", + "delta": msg_delta, + "usage": { + "output_tokens": out_tokens + }, + }) yield _sse("message_stop", {"type": "message_stop"}) yield "data: [DONE]\n\n" success = True logger.info( "messages stream completed: attempt={}/{} model={} text_len={} think_len={} images={}", - attempt + 1, max_retries + 1, model, - len(full_text), len(full_think), len(adapter.image_urls), + attempt + 1, + max_retries + 1, + model, + len(full_text), + len(full_think), + len(adapter.image_urls), ) except UpstreamError as exc: @@ -595,7 +908,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: _retry = True logger.warning( "messages stream retry: attempt={}/{} status={} token={}...", - attempt + 1, max_retries, exc.status, token[:8], + attempt + 1, + max_retries, + exc.status, + token[:8], ) else: raise @@ -603,15 +919,15 @@ async def _run_stream() -> AsyncGenerator[str, None]: finally: await directory.release(acct) kind = ( - FeedbackKind.SUCCESS if success - else _feedback_kind(fail_exc) if fail_exc - else FeedbackKind.SERVER_ERROR - ) + FeedbackKind.SUCCESS + if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR) await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: - asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) + asyncio.create_task(_quota_sync(token, + selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task(_fail_sync(token, selected_mode_id, fail_exc)).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, + fail_exc)).add_done_callback(_log_task_exception) if success or 
not _retry: return @@ -624,8 +940,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: # Non-streaming # ------------------------------------------------------------------------- excluded: list[str] = [] - token = "" - adapter = StreamAdapter() + token = "" + adapter = StreamAdapter() for attempt in range(max_retries + 1): acct, selected_mode_id = await reserve_account( @@ -637,21 +953,21 @@ async def _run_stream() -> AsyncGenerator[str, None]: if acct is None: raise RateLimitError("No available accounts for this model tier") - token = acct.token - success = False - _retry = False + token = acct.token + success = False + _retry = False fail_exc: BaseException | None = None - adapter = StreamAdapter() + adapter = StreamAdapter() try: try: ended = False async for line in _stream_chat( - token = token, - mode_id = ModeId(selected_mode_id), - message = internal_message, - files = files, - timeout_s = timeout_s, + token=token, + mode_id=ModeId(selected_mode_id), + message=internal_message, + files=files, + timeout_s=timeout_s, ): event_type, data = classify_line(line) if event_type == "done": @@ -672,7 +988,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: _retry = True logger.warning( "messages retry: attempt={}/{} status={} token={}...", - attempt + 1, max_retries, exc.status, token[:8], + attempt + 1, + max_retries, + exc.status, + token[:8], ) else: raise @@ -680,15 +999,14 @@ async def _run_stream() -> AsyncGenerator[str, None]: finally: await directory.release(acct) kind = ( - FeedbackKind.SUCCESS if success - else _feedback_kind(fail_exc) if fail_exc - else FeedbackKind.SERVER_ERROR - ) + FeedbackKind.SUCCESS + if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR) await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task(_fail_sync(token, selected_mode_id, 
fail_exc)).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, + fail_exc)).add_done_callback(_log_task_exception) if success or not _retry: break @@ -714,7 +1032,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: full_think = ("".join(adapter.thinking_buf) or "") if emit_think else "" - in_tokens = estimate_prompt_tokens(internal_message) + in_tokens = estimate_prompt_tokens(internal_message) out_tokens = estimate_tokens(full_text) if full_think: out_tokens += estimate_tokens(full_think) @@ -729,12 +1047,13 @@ async def _run_stream() -> AsyncGenerator[str, None]: parsed_input = orjson.loads(call.arguments) except (orjson.JSONDecodeError, ValueError): parsed_input = {} - content.append({ - "type": "tool_use", - "id": call.call_id, - "name": call.name, - "input": parsed_input, - }) + content.append( + { + "type": "tool_use", + "id": call.call_id, + "name": call.name, + "input": parsed_input, + }) ct = estimate_tool_call_tokens(tc_result.calls) logger.info("messages tool_calls: model={} calls={}", model, len(tc_result.calls)) resp = _build_message_response(msg_id, model, content, "tool_use", in_tokens, ct) @@ -746,7 +1065,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: logger.info( "messages request completed: model={} text_len={} think_len={} images={}", - model, len(full_text), len(full_think), len(adapter.image_urls), + model, + len(full_text), + len(full_think), + len(adapter.image_urls), ) content = [{"type": "text", "text": full_text}] diff --git a/app/products/openai/chat.py b/app/products/openai/chat.py index 7a551b13e..5d4404639 100644 --- a/app/products/openai/chat.py +++ b/app/products/openai/chat.py @@ -35,8 +35,25 @@ classify_line, StreamAdapter, ) +from app.dataplane.reverse.protocol.xai_console import ( + build_console_input, + build_console_payload, + classify_console_sse_line, + ConsoleStreamAdapter, + convert_openai_tool_choice, + convert_openai_tools_to_console, + 
extract_console_annotations, + extract_console_reasoning, + extract_console_search_sources, + extract_console_text, + extract_console_tool_calls, + extract_console_usage, + format_search_sources_suffix, + inject_web_search_tool, + parse_console_error, +) from app.dataplane.reverse.protocol.xai_usage import is_invalid_credentials_error -from app.dataplane.reverse.runtime.endpoint_table import CHAT +from app.dataplane.reverse.runtime.endpoint_table import CHAT, CONSOLE_RESPONSES from app.dataplane.reverse.transport.asset_upload import upload_from_input from app.dataplane.reverse.protocol.tool_prompt import ( build_tool_system_prompt, @@ -71,12 +88,8 @@ def _to_chat_annotations(anns: list[dict]) -> list[dict]: "start_index": a["start_index"], "end_index": a["end_index"], }, - } - for a in anns - ] - if anns - else [] - ) + } for a in anns + ] if anns else []) def _log_task_exception(task: "asyncio.Task") -> None: @@ -122,9 +135,7 @@ async def _quota_sync(token: str, mode_id: int) -> None: ) -async def _fail_sync( - token: str, mode_id: int, exc: BaseException | None = None -) -> None: +async def _fail_sync(token: str, mode_id: int, exc: BaseException | None = None) -> None: """Fire-and-forget: persist failure metadata after a failed call. In random mode this helper must not trigger upstream quota probes. It still @@ -135,10 +146,7 @@ async def _fail_sync( svc = get_refresh_service() if svc: await svc.record_failure_async(token, mode_id, exc) - if ( - current_strategy() == "quota" - and getattr(exc, "status", None) == 429 - ): + if (current_strategy() == "quota" and getattr(exc, "status", None) == 429): result = await svc.refresh_on_demand() logger.info( "account on-demand refresh triggered: token={}... 
mode_id={} refreshed={} failed={} rate_limited={}", @@ -236,9 +244,7 @@ async def _resolve_image(token: str, url: str, image_id: str) -> str: fmt = _normalize_image_format(cfg.get_str("features.image_format", "grok_url")) proxy_imagine_public = ( - _is_imagine_public_url(url) - and cfg.get_bool("features.imagine_public_image_proxy", False) - ) + _is_imagine_public_url(url) and cfg.get_bool("features.imagine_public_image_proxy", False)) # Formats that don't need downloading if fmt == "grok_url" and not proxy_imagine_public: @@ -250,9 +256,7 @@ async def _resolve_image(token: str, url: str, image_id: str) -> str: try: raw, mime = await _download_image_bytes(token, url) except Exception as exc: - logger.warning( - "chat image download failed: fallback_to=upstream_url error={}", exc - ) + logger.warning("chat image download failed: fallback_to=upstream_url error={}", exc) return url if fmt == "base64": @@ -262,11 +266,7 @@ async def _resolve_image(token: str, url: str, image_id: str) -> str: # local_url / local_md: save to disk and return local path file_id = await asyncio.to_thread(_save_image, raw, mime, image_id) app_url = cfg.get_str("app.app_url", "").rstrip("/") - local_url = ( - f"{app_url}/v1/files/image?id={file_id}" - if app_url - else f"/v1/files/image?id={file_id}" - ) + local_url = (f"{app_url}/v1/files/image?id={file_id}" if app_url else f"/v1/files/image?id={file_id}") if fmt in {"grok_url", "local_url"}: return local_url @@ -284,9 +284,7 @@ def _normalize_image_format(value: str | None) -> str: # 精确匹配 grok2api 注入的 Sources 段落(含标记行),用于多轮对话剥离 -_SOURCES_STRIP_RE = re.compile( - r"(?:^|\r?\n\r?\n)## Sources\r?\n\[grok2api-sources\]: #\r?\n[\s\S]*$" -) +_SOURCES_STRIP_RE = re.compile(r"(?:^|\r?\n\r?\n)## Sources\r?\n\[grok2api-sources\]: #\r?\n[\s\S]*$") def _strip_generated_artifacts(text: str, *, strip_sources: bool = False) -> str: @@ -311,9 +309,7 @@ def _extract_message(messages: list[dict]) -> tuple[str, list[str]]: # ── role=tool: tool execution result 
───────────────────────────────── if role == "tool": tool_call_id = msg.get("tool_call_id", "") - label = ( - f"[tool result for {tool_call_id}]" if tool_call_id else "[tool result]" - ) + label = (f"[tool result for {tool_call_id}]" if tool_call_id else "[tool result]") text = content.strip() if isinstance(content, str) else "" if text: parts.append(f"{label}:\n{text}") @@ -422,9 +418,7 @@ async def _stream_chat( stream=True, ) except Exception as exc: - raise _transport_upstream_error( - exc, context="Chat transport failed" - ) from exc + raise _transport_upstream_error(exc, context="Chat transport failed") from exc if response.status_code != 200: try: @@ -441,9 +435,499 @@ async def _stream_chat( async for line in response.aiter_lines(): yield line except Exception as exc: - raise _transport_upstream_error( - exc, context="Chat stream read failed" - ) from exc + raise _transport_upstream_error(exc, context="Chat stream read failed") from exc + + +# --------------------------------------------------------------------------- +# Console API (console.x.ai/v1/responses) dispatch +# --------------------------------------------------------------------------- + + +async def _console_post( + *, + token: str, + console_model: str, + input: Any, + instructions: str, + stream: bool, + temperature: float | None, + top_p: float | None, + tools: list[dict] | None, + tool_choice: Any, + timeout_s: float, + reasoning_effort: str | None = None, +) -> Any: + """POST to console.x.ai/v1/responses; return ``(session, response)``. + + For ``stream=True`` the response object's ``aiter_lines()`` must be + consumed by the caller. For ``stream=False`` the caller should read + ``response.content`` and parse it as JSON. The caller is responsible + for closing the returned ``session`` via ``await session.__aexit__()``. 
+ """ + proxy = await get_proxy_runtime() + lease = await proxy.acquire() + + payload = build_console_payload( + console_model=console_model, + input=input, + instructions=instructions, + stream=stream, + temperature=temperature, + top_p=top_p, + reasoning_effort=reasoning_effort, + tools=tools, + tool_choice=tool_choice, + ) + payload_bytes = orjson.dumps(payload) + + headers = build_http_headers( + token, + content_type="application/json", + origin="https://console.x.ai", + referer="https://console.x.ai/", + lease=lease, + ) + session_kwargs = build_session_kwargs(lease=lease) + + session = ResettableSession(**session_kwargs) + await session.__aenter__() + try: + response = await session.post( + CONSOLE_RESPONSES, + headers=headers, + data=payload_bytes, + timeout=timeout_s, + stream=stream, + ) + except Exception as exc: + await session.__aexit__(None, None, None) + raise _transport_upstream_error(exc, context="Console transport failed") from exc + + if response.status_code != 200: + try: + body = response.content.decode("utf-8", "replace")[:400] + except Exception: + body = "" + await session.__aexit__(None, None, None) + raise parse_console_error(response.status_code, body) + + return session, response + + +def _console_input_to_text(input_array: list[dict]) -> str: + """Flatten a console input array into plain text for token estimation.""" + parts: list[str] = [] + for item in input_array: + if not isinstance(item, dict): + continue + content = item.get("content") + if isinstance(content, str): + parts.append(content) + elif isinstance(content, list): + for block in content: + if not isinstance(block, dict): + continue + if block.get("type") in ("input_text", "output_text", "text"): + text = block.get("text") or "" + if text: + parts.append(text) + return "\n".join(parts) + + +async def _console_completions( + *, + spec, + model: str, + messages: list[dict], + is_stream: bool, + emit_think: bool, + temperature: float = 0.8, + top_p: float = 0.95, + 
reasoning_effort: str | None = None, + tools: list[dict] | None = None, + tool_choice: Any = None, +) -> dict | AsyncGenerator[str, None]: + """Dispatch a chat completion through console.x.ai/v1/responses. + + Used for models with ``spec.console_model`` set. SSO cookies from the + grok.com account pool authenticate console.x.ai requests, allowing + basic-tier (free) accounts to access all available models. + + Supports: + - Multimodal input (text + images) via OpenAI Responses-style content blocks + - Native function calling via the ``tools`` parameter + - SSE streaming for both text and tool call arguments + - URL citation annotations from upstream search results + """ + # Apply per-model default effort when caller didn't specify. Hybrid + # reasoning models (grok-4, grok-4.3, grok-4.20) default to "high" so + # callers expecting "think hard by default" get it without explicit opt-in. + if reasoning_effort is None and spec.default_reasoning_effort: + reasoning_effort = spec.default_reasoning_effort + cfg = get_config() + console_model = spec.console_model + + # Convert OpenAI messages → console structured input + instructions. + # System messages are folded into ``instructions`` for cleaner reasoning + # behaviour; text/image blocks become input_text/input_image; assistant + # tool_calls become function_call items; tool results become + # function_call_output items. + input_array, instructions = build_console_input(messages) + if not input_array and not instructions: + raise UpstreamError("Empty messages after conversion", status=400) + + # Convert OpenAI tools → console tools (flat name/description/parameters). + console_tools = convert_openai_tools_to_console(tools) if tools else None + console_tool_choice = ( + convert_openai_tool_choice(tool_choice) if console_tools and tool_choice is not None else None) + + # Always enable web search for console models — this is the primary + # reason for selecting the console route. 
Costs $5/1000 calls from + # the account's prepaid (trial) credits. Idempotent: existing + # ``web_search`` tool in the request is preserved. + console_tools = inject_web_search_tool(console_tools) + + from app.dataplane.account import _directory as _acct_dir + + if _acct_dir is None: + raise RateLimitError("Account directory not initialised") + directory = _acct_dir + + max_retries = selection_max_retries() + retry_codes = _configured_retry_codes(cfg) + response_id = make_response_id() + timeout_s = cfg.get_float("chat.timeout", 120.0) + prompt_text = _console_input_to_text(input_array) + + # ── Streaming path ──────────────────────────────────────────────────────── + if is_stream: + + async def _run_stream() -> AsyncGenerator[str, None]: + excluded: list[str] = [] + for attempt in range(max_retries + 1): + acct, selected_mode_id = await reserve_account( + directory, + spec, + now_s_override=now_s(), + exclude_tokens=excluded or None, + ) + if acct is None: + raise RateLimitError("No available accounts for this model tier") + + token = acct.token + success = False + _retry = False + fail_exc: BaseException | None = None + adapter = ConsoleStreamAdapter() + tool_calls_emitted = False + + try: + try: + session, response = await _console_post( + token=token, + console_model=console_model, + input=input_array, + instructions=instructions, + stream=True, + temperature=temperature, + top_p=top_p, + reasoning_effort=reasoning_effort, + tools=console_tools, + tool_choice=console_tool_choice, + timeout_s=timeout_s, + ) + try: + async for raw_line in response.aiter_lines(): + kind, payload = classify_console_sse_line(raw_line) + if kind == "event": + adapter.feed_event(payload) + continue + if kind != "data" or not payload: + continue + ev = adapter.feed_data(payload) + ev_kind = ev.get("kind") + if ev_kind == "text": + chunk = make_stream_chunk(response_id, model, ev["content"]) + yield f"data: {orjson.dumps(chunk).decode()}\n\n" + elif ev_kind == "thinking" and 
emit_think: + chunk = make_thinking_chunk(response_id, model, ev["content"]) + yield f"data: {orjson.dumps(chunk).decode()}\n\n" + elif ev_kind == "tool_call_start": + # First chunk for this tool call: id + name + empty args + tool_calls_emitted = True + chunk = make_tool_call_chunk( + response_id, + model, + ev["index"], + ev["call_id"], + ev["name"], + "", + is_first=True, + ) + yield f"data: {orjson.dumps(chunk).decode()}\n\n" + elif ev_kind == "tool_call_args": + # Subsequent chunks: incremental args delta + chunk = make_tool_call_chunk( + response_id, + model, + ev["index"], + "", + "", + ev["delta"], + is_first=False, + ) + yield f"data: {orjson.dumps(chunk).decode()}\n\n" + elif ev_kind == "tool_call_done": + # No-op; final done chunk is emitted after + # all events have completed. + pass + elif ev_kind == "error": + raise UpstreamError( + ev.get("message", "Console stream error"), + status=502, + ) + elif ev_kind == "done": + break + finally: + await session.__aexit__(None, None, None) + + # Stream completed — emit appropriate final chunk + if tool_calls_emitted: + done_chunk = make_tool_call_done_chunk(response_id, model) + if adapter.search_sources: + done_chunk["search_sources"] = list(adapter.search_sources) + yield f"data: {orjson.dumps(done_chunk).decode()}\n\n" + else: + chat_anns = ( + _to_chat_annotations(adapter.annotations) if adapter.annotations else None) + # Append ## Sources markdown block when + # features.show_search_sources is enabled (mirrors + # the grok.com path). Emitted as a separate text + # chunk before the final empty chunk so clients + # streaming raw deltas see the suffix in order. 
+ references = adapter.references_suffix() + if references: + ref_chunk = make_stream_chunk(response_id, model, references) + yield f"data: {orjson.dumps(ref_chunk).decode()}\n\n" + final = make_stream_chunk( + response_id, + model, + "", + is_final=True, + annotations=chat_anns, + ) + # Inject search_sources at root level (parallel + # to grok.com path behaviour). Avoids putting + # them inside delta which would violate strict + # OpenAI schemas. + if adapter.search_sources: + final["search_sources"] = list(adapter.search_sources) + yield f"data: {orjson.dumps(final).decode()}\n\n" + yield "data: [DONE]\n\n" + success = True + logger.info( + "console stream completed: attempt={}/{} model={} text_len={} tool_calls={} sources={}", + attempt + 1, + max_retries + 1, + model, + sum(len(s) for s in adapter.text_buf), + len(adapter.tool_calls), + len(adapter.search_sources), + ) + + except UpstreamError as exc: + fail_exc = exc + if (_should_retry_upstream(exc, retry_codes) and attempt < max_retries): + _retry = True + logger.warning( + "console stream retry: attempt={}/{} status={} token={}...", + attempt + 1, + max_retries, + exc.status, + token[:8], + ) + else: + logger.warning( + "console stream failed: attempt={}/{} model={} status={} body={}", + attempt + 1, + max_retries + 1, + model, + exc.status, + _upstream_body_excerpt(exc), + ) + raise + + finally: + await directory.release(acct) + kind = ( + FeedbackKind.SUCCESS + if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR) + await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) + if success: + asyncio.create_task(_quota_sync( + token, selected_mode_id)).add_done_callback(_log_task_exception) + else: + asyncio.create_task(_fail_sync(token, selected_mode_id, + fail_exc)).add_done_callback(_log_task_exception) + + if success or not _retry: + return + excluded.append(token) + + return _run_stream() + + # ── Non-streaming path 
──────────────────────────────────────────────────── + excluded: list[str] = [] + full_text = "" + full_thinking = "" + response_tool_calls: list[dict] = [] + response_annotations: list[dict] = [] + response_search_sources: list[dict] = [] + usage: dict[str, int] = {} + for attempt in range(max_retries + 1): + acct, selected_mode_id = await reserve_account( + directory, + spec, + now_s_override=now_s(), + exclude_tokens=excluded or None, + ) + if acct is None: + raise RateLimitError("No available accounts for this model tier") + + token = acct.token + success = False + _retry = False + fail_exc: BaseException | None = None + + try: + try: + session, response = await _console_post( + token=token, + console_model=console_model, + input=input_array, + instructions=instructions, + stream=False, + temperature=temperature, + top_p=top_p, + reasoning_effort=reasoning_effort, + tools=console_tools, + tool_choice=console_tool_choice, + timeout_s=timeout_s, + ) + try: + body_bytes = response.content + if hasattr(body_bytes, "__await__"): + body_bytes = await body_bytes # type: ignore[misc] + finally: + await session.__aexit__(None, None, None) + + try: + response_json = orjson.loads(body_bytes) + except (orjson.JSONDecodeError, ValueError, TypeError) as exc: + raise UpstreamError( + f"Console returned non-JSON body: {exc}", + status=502, + body=str(body_bytes)[:400], + ) from exc + + full_text = extract_console_text(response_json) + full_thinking = (extract_console_reasoning(response_json) if emit_think else "") + response_tool_calls = extract_console_tool_calls(response_json) + response_annotations = extract_console_annotations(response_json) + response_search_sources = extract_console_search_sources(response_json) + usage = extract_console_usage(response_json) + success = True + + except UpstreamError as exc: + fail_exc = exc + if _should_retry_upstream(exc, retry_codes) and attempt < max_retries: + _retry = True + logger.warning( + "console retry: attempt={}/{} status={} 
token={}...", + attempt + 1, + max_retries, + exc.status, + token[:8], + ) + else: + logger.warning( + "console request failed: attempt={}/{} model={} status={} body={}", + attempt + 1, + max_retries + 1, + model, + exc.status, + _upstream_body_excerpt(exc), + ) + raise + + finally: + await directory.release(acct) + kind = ( + FeedbackKind.SUCCESS + if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR) + await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) + if success: + asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) + else: + asyncio.create_task(_fail_sync(token, selected_mode_id, + fail_exc)).add_done_callback(_log_task_exception) + + if success or not _retry: + break + excluded.append(token) + + logger.info( + "console request completed: model={} text_len={} reasoning_len={} tool_calls={} sources={} usage={}", + model, + len(full_text), + len(full_thinking), + len(response_tool_calls), + len(response_search_sources), + usage, + ) + + # Use upstream usage when available; fall back to estimation otherwise. + pt = usage.get("prompt_tokens") or estimate_prompt_tokens(prompt_text) + ct = usage.get("completion_tokens") or estimate_tokens(full_text) + rt = usage.get("reasoning_tokens") or (estimate_tokens(full_thinking) if full_thinking else 0) + + # If upstream returned tool calls, return the tool_calls response variant. 
+ if response_tool_calls: + from app.dataplane.reverse.protocol.tool_parser import ParsedToolCall + parsed_calls = [ + ParsedToolCall( + call_id=tc["id"], + name=tc["function"]["name"], + arguments=tc["function"]["arguments"], + ) for tc in response_tool_calls + ] + resp = make_tool_call_response( + model, + parsed_calls, + prompt_content=prompt_text, + response_id=response_id, + usage=build_usage(pt, ct + rt, reasoning_tokens=rt), + ) + if response_search_sources: + resp["search_sources"] = response_search_sources + return resp + + chat_anns = (_to_chat_annotations(response_annotations) if response_annotations else None) + # Append ## Sources markdown block to the body when + # features.show_search_sources is enabled (mirrors grok.com path). + references = format_search_sources_suffix(response_search_sources) + if references: + full_text = (full_text or "") + references + return make_chat_response( + model, + full_text, + prompt_content=prompt_text, + response_id=response_id, + reasoning_content=full_thinking or None, + search_sources=response_search_sources or None, + annotations=chat_anns, + usage=build_usage(pt, ct + rt, reasoning_tokens=rt), + ) async def completions( @@ -456,6 +940,7 @@ async def completions( tool_choice: Any = None, temperature: float = 0.8, top_p: float = 0.95, + reasoning_effort: str | None = None, request_overrides: dict | None = None, ) -> dict | AsyncGenerator[str, None]: """Entry point for /v1/chat/completions. @@ -477,6 +962,25 @@ async def completions( len(messages), ) + # ── Console API dispatch ────────────────────────────────────────────────── + # Models with `console_model` set route through console.x.ai/v1/responses + # using the same SSO cookies as grok.com, but support all models for + # basic-tier (free) accounts. Supports multimodal input and native + # function calling. 
+ if spec.is_console(): + return await _console_completions( + spec=spec, + model=model, + messages=messages, + is_stream=is_stream, + emit_think=emit_think, + temperature=temperature, + top_p=top_p, + reasoning_effort=reasoning_effort, + tools=tools, + tool_choice=tool_choice, + ) + message, files = _extract_message(messages) if not message.strip(): raise UpstreamError("Empty message after extraction", status=400) @@ -528,13 +1032,13 @@ async def _run_stream() -> AsyncGenerator[str, None]: sieve = ToolSieve(tool_names) tool_calls_emitted = False async for line in _stream_chat( - token=token, - mode_id=ModeId(selected_mode_id), - message=message, - files=files, - tool_overrides=tool_overrides, - request_overrides=request_overrides, - timeout_s=timeout_s, + token=token, + mode_id=ModeId(selected_mode_id), + message=message, + files=files, + tool_overrides=tool_overrides, + request_overrides=request_overrides, + timeout_s=timeout_s, ): event_type, data = classify_line(line) if event_type == "done": @@ -549,9 +1053,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: if tool_names: safe_text, parsed_calls = sieve.feed(ev.content) if safe_text: - chunk = make_stream_chunk( - response_id, model, safe_text - ) + chunk = make_stream_chunk(response_id, model, safe_text) yield f"data: {orjson.dumps(chunk).decode()}\n\n" if parsed_calls is not None: for i, tc in enumerate(parsed_calls): @@ -565,9 +1067,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: is_first=True, ) yield f"data: {orjson.dumps(chunk).decode()}\n\n" - done_chunk = make_tool_call_done_chunk( - response_id, model - ) + done_chunk = make_tool_call_done_chunk(response_id, model) yield f"data: {orjson.dumps(done_chunk).decode()}\n\n" yield "data: [DONE]\n\n" tool_calls_emitted = True @@ -582,14 +1082,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: ended = True break # stop processing remaining events in this batch else: - chunk = make_stream_chunk( - response_id, model, ev.content - ) 
+ chunk = make_stream_chunk(response_id, model, ev.content) yield f"data: {orjson.dumps(chunk).decode()}\n\n" elif ev.kind == "thinking" and emit_think: - chunk = make_thinking_chunk( - response_id, model, ev.content - ) + chunk = make_thinking_chunk(response_id, model, ev.content) yield f"data: {orjson.dumps(chunk).decode()}\n\n" elif ev.kind == "annotation" and ev.annotation_data: collected_annotations.append(ev.annotation_data) @@ -614,9 +1110,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: is_first=True, ) yield f"data: {orjson.dumps(chunk).decode()}\n\n" - done_chunk = make_tool_call_done_chunk( - response_id, model - ) + done_chunk = make_tool_call_done_chunk(response_id, model) # 注入结构化搜索信源(tool_calls 场景) sources = adapter.search_sources_list() if sources: @@ -634,16 +1128,12 @@ async def _run_stream() -> AsyncGenerator[str, None]: if not tool_calls_emitted: for url, img_id in adapter.image_urls: img_text = await _resolve_image(token, url, img_id) - chunk = make_stream_chunk( - response_id, model, img_text + "\n" - ) + chunk = make_stream_chunk(response_id, model, img_text + "\n") yield f"data: {orjson.dumps(chunk).decode()}\n\n" references = adapter.references_suffix() if references: - chunk = make_stream_chunk( - response_id, model, references - ) + chunk = make_stream_chunk(response_id, model, references) yield f"data: {orjson.dumps(chunk).decode()}\n\n" chat_anns = _to_chat_annotations(collected_annotations) @@ -671,10 +1161,7 @@ async def _run_stream() -> AsyncGenerator[str, None]: except UpstreamError as exc: fail_exc = exc - if ( - _should_retry_upstream(exc, retry_codes) - and attempt < max_retries - ): + if (_should_retry_upstream(exc, retry_codes) and attempt < max_retries): _retry = True logger.warning( "chat stream retry scheduled: attempt={}/{} status={} token={}...", @@ -698,22 +1185,14 @@ async def _run_stream() -> AsyncGenerator[str, None]: await directory.release(acct) kind = ( FeedbackKind.SUCCESS - if success - else 
_feedback_kind(fail_exc) - if fail_exc - else FeedbackKind.SERVER_ERROR - ) - await directory.feedback( - token, kind, selected_mode_id, now_s_val=now_s() - ) + if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR) + await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: - asyncio.create_task( - _quota_sync(token, selected_mode_id) - ).add_done_callback(_log_task_exception) + asyncio.create_task(_quota_sync( + token, selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task( - _fail_sync(token, selected_mode_id, fail_exc) - ).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, + fail_exc)).add_done_callback(_log_task_exception) if success or not _retry: return @@ -744,13 +1223,13 @@ async def _run_stream() -> AsyncGenerator[str, None]: try: try: async for line in _stream_chat( - token=token, - mode_id=ModeId(selected_mode_id), - message=message, - files=files, - tool_overrides=tool_overrides, - request_overrides=request_overrides, - timeout_s=timeout_s, + token=token, + mode_id=ModeId(selected_mode_id), + message=message, + files=files, + tool_overrides=tool_overrides, + request_overrides=request_overrides, + timeout_s=timeout_s, ): event_type, data = classify_line(line) if event_type == "done": @@ -792,20 +1271,13 @@ async def _run_stream() -> AsyncGenerator[str, None]: await directory.release(acct) kind = ( FeedbackKind.SUCCESS - if success - else _feedback_kind(fail_exc) - if fail_exc - else FeedbackKind.SERVER_ERROR - ) + if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR) await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: - asyncio.create_task( - _quota_sync(token, selected_mode_id) - ).add_done_callback(_log_task_exception) + asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task( - 
_fail_sync(token, selected_mode_id, fail_exc) - ).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, + fail_exc)).add_done_callback(_log_task_exception) if success or not _retry: break diff --git a/app/products/openai/responses.py b/app/products/openai/responses.py index d816c7a92..4e8c5eaf9 100644 --- a/app/products/openai/responses.py +++ b/app/products/openai/responses.py @@ -18,24 +18,48 @@ from app.control.model.registry import resolve as resolve_model from app.control.account.enums import FeedbackKind from app.dataplane.reverse.protocol.xai_chat import classify_line, StreamAdapter +from app.dataplane.reverse.protocol.xai_console import ( + build_console_input, + convert_openai_tool_choice, + convert_openai_tools_to_console, + extract_console_usage, + inject_web_search_tool, +) from app.products._account_selection import reserve_account, selection_max_retries -from .chat import _stream_chat, _extract_message, _resolve_image, _quota_sync, _fail_sync, _parse_retry_codes, _feedback_kind, _log_task_exception, _upstream_body_excerpt +from .chat import ( + _stream_chat, + _extract_message, + _resolve_image, + _quota_sync, + _fail_sync, + _parse_retry_codes, + _feedback_kind, + _log_task_exception, + _upstream_body_excerpt, + _console_post, +) from .chat import _configured_retry_codes, _should_retry_upstream from ._format import ( - make_resp_id, build_resp_usage, make_resp_object, format_sse, + make_resp_id, + build_resp_usage, + make_resp_object, + format_sse, ) from app.dataplane.reverse.protocol.tool_prompt import ( - build_tool_system_prompt, extract_tool_names, inject_into_message, tool_calls_to_xml, + build_tool_system_prompt, + extract_tool_names, + inject_into_message, + tool_calls_to_xml, ) from app.dataplane.reverse.protocol.tool_parser import parse_tool_calls from ._tool_sieve import ToolSieve - # --------------------------------------------------------------------------- # Tool format normalisation # 
def _build_fc_items(calls: list) -> list[dict]:
    """Build the ``function_call`` output items reported in response.completed.

    Each parsed tool call gets a freshly allocated ``fc`` item id; the
    call id, name and arguments are copied from the call object, and the
    item is marked as completed.
    """
    fc_items: list[dict] = []
    for call in calls:
        # Allocate the id first so the evaluation order matches the dict layout.
        item_id = make_resp_id("fc")
        fc_items.append(
            {
                "id": item_id,
                "type": "function_call",
                "call_id": call.call_id,
                "name": call.name,
                "arguments": call.arguments,
                "status": "completed",
            }
        )
    return fc_items
""" for i, item in enumerate(items): - out_idx = base_idx + i + out_idx = base_idx + i fc_item_id = item["id"] - yield format_sse("response.output_item.added", { - "type": "response.output_item.added", - "output_index": out_idx, - "item": { - "id": fc_item_id, - "type": "function_call", - "call_id": item["call_id"], - "name": item["name"], - "arguments": "", - "status": "in_progress", - }, - }) - yield format_sse("response.function_call_arguments.delta", { - "type": "response.function_call_arguments.delta", - "item_id": fc_item_id, - "output_index": out_idx, - "delta": item["arguments"], - }) - yield format_sse("response.function_call_arguments.done", { - "type": "response.function_call_arguments.done", - "item_id": fc_item_id, - "output_index": out_idx, - "arguments": item["arguments"], - }) - yield format_sse("response.output_item.done", { - "type": "response.output_item.done", - "output_index": out_idx, - "item": item, - }) + yield format_sse( + "response.output_item.added", { + "type": "response.output_item.added", + "output_index": out_idx, + "item": { + "id": fc_item_id, + "type": "function_call", + "call_id": item["call_id"], + "name": item["name"], + "arguments": "", + "status": "in_progress", + }, + }) + yield format_sse( + "response.function_call_arguments.delta", { + "type": "response.function_call_arguments.delta", + "item_id": fc_item_id, + "output_index": out_idx, + "delta": item["arguments"], + }) + yield format_sse( + "response.function_call_arguments.done", { + "type": "response.function_call_arguments.done", + "item_id": fc_item_id, + "output_index": out_idx, + "arguments": item["arguments"], + }) + yield format_sse( + "response.output_item.done", { + "type": "response.output_item.done", + "output_index": out_idx, + "item": item, + }) # --------------------------------------------------------------------------- # Input normalisation # --------------------------------------------------------------------------- + def _parse_input(input_val: str | 
# ---------------------------------------------------------------------------
# Console API dispatch (console.x.ai/v1/responses)
# ---------------------------------------------------------------------------


async def _console_settle_account(
    directory,
    acct,
    token: str,
    mode_id,
    *,
    success: bool,
    fail_exc: BaseException | None,
) -> None:
    """Release a reserved account and report the outcome of one attempt.

    Releases *acct*, sends SUCCESS / the kind derived from *fail_exc* /
    SERVER_ERROR (when no exception was captured) to the directory, then
    schedules the quota sync (on success) or failure sync as a background
    task whose exception, if any, is handed to ``_log_task_exception``.
    """
    await directory.release(acct)
    if success:
        kind = FeedbackKind.SUCCESS
    elif fail_exc is not None:
        kind = _feedback_kind(fail_exc)
    else:
        kind = FeedbackKind.SERVER_ERROR
    await directory.feedback(token, kind, mode_id, now_s_val=now_s())
    follow_up = _quota_sync(token, mode_id) if success else _fail_sync(token, mode_id, fail_exc)
    asyncio.create_task(follow_up).add_done_callback(_log_task_exception)


async def _console_responses_dispatch(
    *,
    spec,
    model: str,
    messages: list[dict],
    stream: bool,
    temperature: float,
    top_p: float,
    tools: list[dict] | None,
    tool_choice: Any,
    reasoning_effort: str | None = None,
) -> dict | AsyncGenerator[str, None]:
    """Dispatch a /v1/responses request through console.x.ai.

    Console.x.ai natively returns OpenAI Responses API format, so for
    streaming the upstream SSE events are relayed (re-framed into
    ``event:``/``data:`` blocks) and for non-streaming the upstream JSON
    object is returned as-is.

    Args:
        spec: Resolved model spec; ``console_model`` and
            ``default_reasoning_effort`` are read from it.
        model: Public model name, used for logging only.
        messages: Internal messages list, converted with
            ``build_console_input``.
        stream: When true an async generator of SSE text blocks is returned.
        temperature, top_p, reasoning_effort: Forwarded to ``_console_post``.
            ``reasoning_effort`` falls back to the spec default when None.
        tools, tool_choice: Caller tools; converted to console format. The
            web-search tool is always injected.

    Returns:
        The upstream response object (dict), or an async generator of SSE
        blocks terminated by ``data: [DONE]``.

    Raises:
        UpstreamError: Empty input (400), an unusable upstream body (502),
            or an upstream failure that is not (or can no longer be) retried.
        RateLimitError: No account directory or no selectable account.
    """
    # Apply per-model default effort when the caller didn't specify one.
    if reasoning_effort is None and spec.default_reasoning_effort:
        reasoning_effort = spec.default_reasoning_effort
    cfg = get_config()
    console_model = spec.console_model

    # Convert our internal messages list -> console structured input.
    input_array, sys_instructions = build_console_input(messages)
    if not input_array and not sys_instructions:
        raise UpstreamError("Empty messages after conversion", status=400)

    # Tools may arrive in Chat Completions wrapper format or Responses API
    # flat format; convert_openai_tools_to_console handles both.
    console_tools = convert_openai_tools_to_console(tools) if tools else None
    console_tool_choice = (
        convert_openai_tool_choice(tool_choice)
        if console_tools and tool_choice is not None
        else None
    )

    # Always enable web search for console models; upstream web_search_call
    # output items are relayed downstream so clients can see citation URLs.
    console_tools = inject_web_search_tool(console_tools)

    from app.dataplane.account import _directory as _acct_dir
    if _acct_dir is None:
        raise RateLimitError("Account directory not initialised")
    directory = _acct_dir

    max_retries = selection_max_retries()
    retry_codes = _configured_retry_codes(cfg)

    # Everything except the token and the stream flag is identical for every
    # attempt, so build the upstream call arguments once.
    post_kwargs: dict[str, Any] = {
        "console_model": console_model,
        "input": input_array,
        "instructions": sys_instructions,
        "temperature": temperature,
        "top_p": top_p,
        "reasoning_effort": reasoning_effort,
        "tools": console_tools,
        "tool_choice": console_tool_choice,
        "timeout_s": cfg.get_float("chat.timeout", 120.0),
    }

    # -- Streaming ------------------------------------------------------------
    if stream:

        async def _run_stream() -> AsyncGenerator[str, None]:
            excluded: list[str] = []
            for attempt in range(max_retries + 1):
                acct, selected_mode_id = await reserve_account(
                    directory,
                    spec,
                    now_s_override=now_s(),
                    exclude_tokens=excluded or None,
                )
                if acct is None:
                    raise RateLimitError("No available accounts for this model tier")

                token = acct.token
                success = False
                _retry = False
                emitted = False  # True once any event reached the client
                fail_exc: BaseException | None = None

                try:
                    try:
                        session, response = await _console_post(
                            token=token,
                            stream=True,
                            **post_kwargs,
                        )
                        try:
                            # Upstream already speaks Responses API SSE; only
                            # the event/data block framing is rebuilt here.
                            current_event = ""
                            async for raw_line in response.aiter_lines():
                                if isinstance(raw_line, bytes):
                                    raw_line = raw_line.decode("utf-8", "replace")
                                raw_line = raw_line.rstrip("\r")
                                if not raw_line:
                                    continue
                                if raw_line.startswith("event:"):
                                    current_event = raw_line[6:].strip()
                                    continue
                                if raw_line.startswith("data:"):
                                    data = raw_line[5:].strip()
                                    if data == "[DONE]":
                                        break
                                    # Flag before yielding: the consumer may
                                    # never resume us after this block.
                                    emitted = True
                                    if current_event:
                                        yield f"event: {current_event}\ndata: {data}\n\n"
                                    else:
                                        yield f"data: {data}\n\n"
                                    current_event = ""
                        finally:
                            await session.__aexit__(None, None, None)

                        # The upstream exchange is complete at this point, so
                        # record success before the final sentinel; a client
                        # that disconnects on [DONE] must not cause the
                        # account to be reported as failed.
                        success = True
                        yield "data: [DONE]\n\n"
                        logger.info(
                            "console responses stream completed: attempt={}/{} model={}",
                            attempt + 1,
                            max_retries + 1,
                            model,
                        )

                    except UpstreamError as exc:
                        fail_exc = exc
                        # Never retry once events were forwarded: a second
                        # attempt would replay the stream from the start into
                        # output the client has already consumed.
                        if (
                            _should_retry_upstream(exc, retry_codes)
                            and attempt < max_retries
                            and not emitted
                        ):
                            _retry = True
                            logger.warning(
                                "console responses stream retry: attempt={}/{} status={} token={}...",
                                attempt + 1,
                                max_retries,
                                exc.status,
                                token[:8],
                            )
                        else:
                            logger.warning(
                                "console responses stream failed: attempt={}/{} model={} status={} body={}",
                                attempt + 1,
                                max_retries + 1,
                                model,
                                exc.status,
                                _upstream_body_excerpt(exc),
                            )
                            raise

                finally:
                    await _console_settle_account(
                        directory,
                        acct,
                        token,
                        selected_mode_id,
                        success=success,
                        fail_exc=fail_exc,
                    )

                if success or not _retry:
                    return
                excluded.append(token)

        return _run_stream()

    # -- Non-streaming --------------------------------------------------------
    excluded: list[str] = []
    response_json: dict[str, Any] | None = None
    for attempt in range(max_retries + 1):
        acct, selected_mode_id = await reserve_account(
            directory,
            spec,
            now_s_override=now_s(),
            exclude_tokens=excluded or None,
        )
        if acct is None:
            raise RateLimitError("No available accounts for this model tier")

        token = acct.token
        success = False
        _retry = False
        fail_exc = None

        try:
            try:
                session, response = await _console_post(
                    token=token,
                    stream=False,
                    **post_kwargs,
                )
                try:
                    body_bytes = response.content
                    # Some HTTP clients expose the body as an awaitable.
                    if hasattr(body_bytes, "__await__"):
                        body_bytes = await body_bytes  # type: ignore[misc]
                finally:
                    await session.__aexit__(None, None, None)

                try:
                    decoded = orjson.loads(body_bytes)
                except (orjson.JSONDecodeError, ValueError, TypeError) as exc:
                    raise UpstreamError(
                        f"Console returned non-JSON body: {exc}",
                        status=502,
                        body=str(body_bytes)[:400],
                    ) from exc

                # Validate the payload while still inside the retry/feedback
                # scope: a list/scalar or an empty object is an upstream
                # failure, not a success to be credited to the account.
                if not isinstance(decoded, dict):
                    raise UpstreamError(
                        "Console returned non-object JSON body",
                        status=502,
                        body=str(body_bytes)[:400],
                    )
                if not decoded:
                    raise UpstreamError("Console returned empty response", status=502)

                response_json = decoded
                success = True

            except UpstreamError as exc:
                fail_exc = exc
                if _should_retry_upstream(exc, retry_codes) and attempt < max_retries:
                    _retry = True
                    logger.warning(
                        "console responses retry: attempt={}/{} status={} token={}...",
                        attempt + 1,
                        max_retries,
                        exc.status,
                        token[:8],
                    )
                else:
                    logger.warning(
                        "console responses failed: attempt={}/{} model={} status={} body={}",
                        attempt + 1,
                        max_retries + 1,
                        model,
                        exc.status,
                        _upstream_body_excerpt(exc),
                    )
                    raise

        finally:
            await _console_settle_account(
                directory,
                acct,
                token,
                selected_mode_id,
                success=success,
                fail_exc=fail_exc,
            )

        if success or not _retry:
            break
        excluded.append(token)

    # Defensive: every loop exit without a payload should already have raised.
    if response_json is None:
        raise UpstreamError("Console returned empty response", status=502)

    logger.info(
        "console responses request completed: model={} status={} usage={}",
        model,
        response_json.get("status"),
        extract_console_usage(response_json),
    )
    return response_json
+ if spec.is_console(): + logger.info( + "console responses dispatch: model={} stream={} message_count={}", + model, + stream, + len(messages), + ) + return await _console_responses_dispatch( + spec=spec, + model=model, + messages=messages, + stream=stream, + temperature=temperature, + top_p=top_p, + reasoning_effort=reasoning_effort, + tools=tools, + tool_choice=tool_choice, + ) + message, files = _extract_message(messages) if not message.strip(): raise UpstreamError("Empty message after extraction", status=400) @@ -247,12 +575,12 @@ async def create( raise RateLimitError("Account directory not initialised") directory = _acct_dir - max_retries = selection_max_retries() - retry_codes = _configured_retry_codes(cfg) - response_id = make_resp_id("resp") + max_retries = selection_max_retries() + retry_codes = _configured_retry_codes(cfg) + response_id = make_resp_id("resp") reasoning_id = make_resp_id("rs") - message_id = make_resp_id("msg") - timeout_s = cfg.get_float("chat.timeout", 120.0) + message_id = make_resp_id("msg") + timeout_s = cfg.get_float("chat.timeout", 120.0) # ------------------------------------------------------------------------- # Streaming @@ -269,35 +597,36 @@ async def _run_stream() -> AsyncGenerator[str, None]: if acct is None: raise RateLimitError("No available accounts for this model tier") - token = acct.token + token = acct.token success = False - _retry = False + _retry = False fail_exc: BaseException | None = None - adapter = StreamAdapter() - think_buf: list[str] = [] - text_buf: list[str] = [] - reasoning_started = False - reasoning_closed = False - message_started = False - sieve = ToolSieve(tool_names) if tool_names else None - tool_calls_emitted = False + adapter = StreamAdapter() + think_buf: list[str] = [] + text_buf: list[str] = [] + reasoning_started = False + reasoning_closed = False + message_started = False + sieve = ToolSieve(tool_names) if tool_names else None + tool_calls_emitted = False detected_fc_items: list[dict] = [] 
collected_annotations: list[dict] = [] try: try: - yield format_sse("response.created", { - "type": "response.created", - "response": make_resp_object(response_id, model, "in_progress", []), - }) + yield format_sse( + "response.created", { + "type": "response.created", + "response": make_resp_object(response_id, model, "in_progress", []), + }) ended = False async for line in _stream_chat( - token = token, - mode_id = ModeId(selected_mode_id), - message = message, - files = files, - timeout_s = timeout_s, + token=token, + mode_id=ModeId(selected_mode_id), + message=message, + files=files, + timeout_s=timeout_s, ): if tool_calls_emitted: break @@ -313,58 +642,75 @@ async def _run_stream() -> AsyncGenerator[str, None]: if ev.kind == "thinking" and emit_think and not reasoning_closed: if not reasoning_started: reasoning_started = True - yield format_sse("response.output_item.added", { - "type": "response.output_item.added", + yield format_sse( + "response.output_item.added", { + "type": "response.output_item.added", + "output_index": 0, + "item": { + "id": reasoning_id, + "type": "reasoning", + "summary": [], + "status": "in_progress", + }, + }) + yield format_sse( + "response.reasoning_summary_part.added", { + "type": "response.reasoning_summary_part.added", + "item_id": reasoning_id, + "output_index": 0, + "summary_index": 0, + "part": { + "type": "summary_text", + "text": "" + }, + }) + think_buf.append(ev.content) + yield format_sse( + "response.reasoning_summary_text.delta", { + "type": "response.reasoning_summary_text.delta", + "item_id": reasoning_id, "output_index": 0, - "item": { - "id": reasoning_id, "type": "reasoning", - "summary": [], "status": "in_progress", - }, - }) - yield format_sse("response.reasoning_summary_part.added", { - "type": "response.reasoning_summary_part.added", - "item_id": reasoning_id, - "output_index": 0, "summary_index": 0, - "part": {"type": "summary_text", "text": ""}, + "delta": ev.content, }) - think_buf.append(ev.content) - 
yield format_sse("response.reasoning_summary_text.delta", { - "type": "response.reasoning_summary_text.delta", - "item_id": reasoning_id, - "output_index": 0, - "summary_index": 0, - "delta": ev.content, - }) elif ev.kind == "text": if reasoning_started and not reasoning_closed: reasoning_closed = True full_think = "".join(think_buf) - yield format_sse("response.reasoning_summary_text.done", { - "type": "response.reasoning_summary_text.done", - "item_id": reasoning_id, - "output_index": 0, - "summary_index": 0, - "text": full_think, - }) - yield format_sse("response.reasoning_summary_part.done", { - "type": "response.reasoning_summary_part.done", - "item_id": reasoning_id, - "output_index": 0, - "summary_index": 0, - "part": {"type": "summary_text", "text": full_think}, - }) - yield format_sse("response.output_item.done", { - "type": "response.output_item.done", - "output_index": 0, - "item": { - "id": reasoning_id, - "type": "reasoning", - "summary": [{"type": "summary_text", "text": full_think}], - "status": "completed", - }, - }) + yield format_sse( + "response.reasoning_summary_text.done", { + "type": "response.reasoning_summary_text.done", + "item_id": reasoning_id, + "output_index": 0, + "summary_index": 0, + "text": full_think, + }) + yield format_sse( + "response.reasoning_summary_part.done", { + "type": "response.reasoning_summary_part.done", + "item_id": reasoning_id, + "output_index": 0, + "summary_index": 0, + "part": { + "type": "summary_text", + "text": full_think + }, + }) + yield format_sse( + "response.output_item.done", { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": reasoning_id, + "type": "reasoning", + "summary": [{ + "type": "summary_text", + "text": full_think + }], + "status": "completed", + }, + }) # Feed through ToolSieve if tools are active if sieve is not None: @@ -386,43 +732,54 @@ async def _run_stream() -> AsyncGenerator[str, None]: msg_idx = 1 if reasoning_started else 0 if not message_started: 
message_started = True - yield format_sse("response.output_item.added", { - "type": "response.output_item.added", + yield format_sse( + "response.output_item.added", { + "type": "response.output_item.added", + "output_index": msg_idx, + "item": { + "id": message_id, + "type": "message", + "role": "assistant", + "content": [], + "status": "in_progress", + }, + }) + yield format_sse( + "response.content_part.added", { + "type": "response.content_part.added", + "item_id": message_id, + "output_index": msg_idx, + "content_index": 0, + "part": { + "type": "output_text", + "text": "", + "annotations": [] + }, + }) + + text_buf.append(text_chunk) + yield format_sse( + "response.output_text.delta", { + "type": "response.output_text.delta", + "item_id": message_id, "output_index": msg_idx, - "item": { - "id": message_id, "type": "message", - "role": "assistant", "content": [], "status": "in_progress", - }, - }) - yield format_sse("response.content_part.added", { - "type": "response.content_part.added", - "item_id": message_id, - "output_index": msg_idx, "content_index": 0, - "part": {"type": "output_text", "text": "", "annotations": []}, + "delta": text_chunk, }) - text_buf.append(text_chunk) - yield format_sse("response.output_text.delta", { - "type": "response.output_text.delta", - "item_id": message_id, - "output_index": msg_idx, - "content_index": 0, - "delta": text_chunk, - }) - elif ev.kind == "annotation" and ev.annotation_data: if message_started: collected_annotations.append(ev.annotation_data) msg_idx = 1 if reasoning_started else 0 - yield format_sse("response.output_text.annotation.added", { - "type": "response.output_text.annotation.added", - "item_id": message_id, - "output_index": msg_idx, - "content_index": 0, - "annotation_index": len(collected_annotations) - 1, - "annotation": ev.annotation_data, - }) + yield format_sse( + "response.output_text.annotation.added", { + "type": "response.output_text.annotation.added", + "item_id": message_id, + 
"output_index": msg_idx, + "content_index": 0, + "annotation_index": len(collected_annotations) - 1, + "annotation": ev.annotation_data, + }) elif ev.kind == "soft_stop": ended = True @@ -447,133 +804,173 @@ async def _run_stream() -> AsyncGenerator[str, None]: full_think = "".join(think_buf) output: list[dict] = [] if reasoning_started and full_think: - output.append({ - "id": reasoning_id, - "type": "reasoning", - "summary": [{"type": "summary_text", "text": full_think}], - "status": "completed", - }) + output.append( + { + "id": reasoning_id, + "type": "reasoning", + "summary": [{ + "type": "summary_text", + "text": full_think + }], + "status": "completed", + }) output.extend(detected_fc_items) pt = estimate_prompt_tokens(message) ct = estimate_tool_call_tokens(detected_fc_items) rt = estimate_tokens(full_think) if full_think else 0 - yield format_sse("response.completed", { - "type": "response.completed", - "response": make_resp_object( - response_id, model, "completed", output, - build_resp_usage(pt, ct + rt, rt), - ), - }) + yield format_sse( + "response.completed", { + "type": "response.completed", + "response": make_resp_object( + response_id, + model, + "completed", + output, + build_resp_usage(pt, ct + rt, rt), + ), + }) yield "data: [DONE]\n\n" success = True - logger.info("responses stream tool_calls: attempt={}/{} model={}", - attempt + 1, max_retries + 1, model) + logger.info( + "responses stream tool_calls: attempt={}/{} model={}", attempt + 1, max_retries + 1, + model) else: # Normal text path msg_idx = 1 if reasoning_started else 0 for url, img_id in adapter.image_urls: img_text = await _resolve_image(token, url, img_id) - img_md = img_text + "\n" + img_md = img_text + "\n" text_buf.append(img_md) if message_started: - yield format_sse("response.output_text.delta", { - "type": "response.output_text.delta", - "item_id": message_id, - "output_index": msg_idx, - "content_index": 0, - "delta": img_md, - }) + yield format_sse( + 
"response.output_text.delta", { + "type": "response.output_text.delta", + "item_id": message_id, + "output_index": msg_idx, + "content_index": 0, + "delta": img_md, + }) references = adapter.references_suffix() if references: text_buf.append(references) if message_started: - yield format_sse("response.output_text.delta", { - "type": "response.output_text.delta", - "item_id": message_id, - "output_index": msg_idx, - "content_index": 0, - "delta": references, - }) + yield format_sse( + "response.output_text.delta", { + "type": "response.output_text.delta", + "item_id": message_id, + "output_index": msg_idx, + "content_index": 0, + "delta": references, + }) full_text = "".join(text_buf) if message_started: - yield format_sse("response.output_text.done", { - "type": "response.output_text.done", - "item_id": message_id, - "output_index": msg_idx, - "content_index": 0, - "text": full_text, - }) - yield format_sse("response.content_part.done", { - "type": "response.content_part.done", - "item_id": message_id, - "output_index": msg_idx, - "content_index": 0, - "part": {"type": "output_text", "text": full_text, "annotations": collected_annotations}, - }) + yield format_sse( + "response.output_text.done", { + "type": "response.output_text.done", + "item_id": message_id, + "output_index": msg_idx, + "content_index": 0, + "text": full_text, + }) + yield format_sse( + "response.content_part.done", { + "type": "response.content_part.done", + "item_id": message_id, + "output_index": msg_idx, + "content_index": 0, + "part": { + "type": "output_text", + "text": full_text, + "annotations": collected_annotations + }, + }) # 构建 message item(流式 output_item.done + response.completed 共用) sources = adapter.search_sources_list() msg_item: dict = { - "id": message_id, - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": full_text, "annotations": collected_annotations}], - "status": "completed", + "id": message_id, + "type": "message", + "role": 
"assistant", + "content": [ + { + "type": "output_text", + "text": full_text, + "annotations": collected_annotations + } + ], + "status": "completed", } if sources: msg_item["search_sources"] = sources - yield format_sse("response.output_item.done", { - "type": "response.output_item.done", - "output_index": msg_idx, - "item": msg_item, - }) + yield format_sse( + "response.output_item.done", { + "type": "response.output_item.done", + "output_index": msg_idx, + "item": msg_item, + }) full_think = "".join(think_buf) output = [] if reasoning_started and full_think: - output.append({ - "id": reasoning_id, - "type": "reasoning", - "summary": [{"type": "summary_text", "text": full_think}], - "status": "completed", - }) + output.append( + { + "id": reasoning_id, + "type": "reasoning", + "summary": [{ + "type": "summary_text", + "text": full_think + }], + "status": "completed", + }) # 复用 msg_item(message_started 时已构建);未启动时重新构建 if not message_started: msg_item = { - "id": message_id, - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": full_text, "annotations": adapter.annotations_list()}], - "status": "completed", + "id": message_id, + "type": "message", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": full_text, + "annotations": adapter.annotations_list() + } + ], + "status": "completed", } sources = adapter.search_sources_list() if sources: msg_item["search_sources"] = sources output.append(msg_item) - pt = estimate_prompt_tokens(message) - ct = estimate_tokens(full_text) - rt = estimate_tokens(full_think) if full_think else 0 - yield format_sse("response.completed", { - "type": "response.completed", - "response": make_resp_object( - response_id, model, "completed", output, - build_resp_usage(pt, ct + rt, rt), - ), - }) + pt = estimate_prompt_tokens(message) + ct = estimate_tokens(full_text) + rt = estimate_tokens(full_think) if full_think else 0 + yield format_sse( + "response.completed", { + "type": 
"response.completed", + "response": make_resp_object( + response_id, + model, + "completed", + output, + build_resp_usage(pt, ct + rt, rt), + ), + }) yield "data: [DONE]\n\n" success = True - logger.info("responses stream completed: attempt={}/{} model={} text_len={} reasoning_len={} image_count={}", - attempt + 1, max_retries + 1, model, - len(full_text), len(full_think), len(adapter.image_urls)) + logger.info( + "responses stream completed: attempt={}/{} model={} text_len={} reasoning_len={} image_count={}", + attempt + 1, max_retries + 1, model, len(full_text), len(full_think), + len(adapter.image_urls)) except UpstreamError as exc: fail_exc = exc if _should_retry_upstream(exc, retry_codes) and attempt < max_retries: _retry = True - logger.warning("responses stream retry scheduled: attempt={}/{} status={} token={}...", - attempt + 1, max_retries, exc.status, token[:8]) + logger.warning( + "responses stream retry scheduled: attempt={}/{} status={} token={}...", attempt + 1, + max_retries, exc.status, token[:8]) else: logger.warning( "responses stream upstream failed: attempt={}/{} model={} status={} body={}", @@ -587,12 +984,15 @@ async def _run_stream() -> AsyncGenerator[str, None]: finally: await directory.release(acct) - kind = FeedbackKind.SUCCESS if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR + kind = FeedbackKind.SUCCESS if success else _feedback_kind( + fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR await directory.feedback(token, kind, selected_mode_id, now_s_val=now_s()) if success: - asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) + asyncio.create_task(_quota_sync(token, + selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task(_fail_sync(token, selected_mode_id, fail_exc)).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, + fail_exc)).add_done_callback(_log_task_exception) if success or 
not _retry: return @@ -605,8 +1005,8 @@ async def _run_stream() -> AsyncGenerator[str, None]: # Non-streaming # ------------------------------------------------------------------------- excluded: list[str] = [] - token = "" - adapter = StreamAdapter() + token = "" + adapter = StreamAdapter() for attempt in range(max_retries + 1): acct, selected_mode_id = await reserve_account( directory, @@ -617,20 +1017,20 @@ async def _run_stream() -> AsyncGenerator[str, None]: if acct is None: raise RateLimitError("No available accounts for this model tier") - token = acct.token - success = False - _retry = False + token = acct.token + success = False + _retry = False fail_exc: BaseException | None = None - adapter = StreamAdapter() # fresh adapter per attempt + adapter = StreamAdapter() # fresh adapter per attempt try: try: async for line in _stream_chat( - token = token, - mode_id = ModeId(selected_mode_id), - message = message, - files = files, - timeout_s = timeout_s, + token=token, + mode_id=ModeId(selected_mode_id), + message=message, + files=files, + timeout_s=timeout_s, ): event_type, data = classify_line(line) if event_type == "done": @@ -650,8 +1050,9 @@ async def _run_stream() -> AsyncGenerator[str, None]: fail_exc = exc if _should_retry_upstream(exc, retry_codes) and attempt < max_retries: _retry = True - logger.warning("responses retry scheduled: attempt={}/{} status={} token={}...", - attempt + 1, max_retries, exc.status, token[:8]) + logger.warning( + "responses retry scheduled: attempt={}/{} status={} token={}...", attempt + 1, + max_retries, exc.status, token[:8]) else: logger.warning( "responses upstream failed: attempt={}/{} model={} status={} body={}", @@ -665,12 +1066,14 @@ async def _run_stream() -> AsyncGenerator[str, None]: finally: await directory.release(acct) - kind = FeedbackKind.SUCCESS if success else _feedback_kind(fail_exc) if fail_exc else FeedbackKind.SERVER_ERROR + kind = FeedbackKind.SUCCESS if success else _feedback_kind( + fail_exc) if 
fail_exc else FeedbackKind.SERVER_ERROR await directory.feedback(token, kind, selected_mode_id) if success: asyncio.create_task(_quota_sync(token, selected_mode_id)).add_done_callback(_log_task_exception) else: - asyncio.create_task(_fail_sync(token, selected_mode_id, fail_exc)).add_done_callback(_log_task_exception) + asyncio.create_task(_fail_sync(token, selected_mode_id, + fail_exc)).add_done_callback(_log_task_exception) if success or not _retry: break @@ -701,39 +1104,55 @@ async def _run_stream() -> AsyncGenerator[str, None]: if tc_result.calls: output: list[dict] = [] if full_think: - output.append({ - "id": reasoning_id, - "type": "reasoning", - "summary": [{"type": "summary_text", "text": full_think}], - "status": "completed", - }) + output.append( + { + "id": reasoning_id, + "type": "reasoning", + "summary": [{ + "type": "summary_text", + "text": full_think + }], + "status": "completed", + }) output.extend(_build_fc_items(tc_result.calls)) pt = estimate_prompt_tokens(message) ct = estimate_tool_call_tokens(tc_result.calls) rt = estimate_tokens(full_think) if full_think else 0 logger.info("responses tool_calls: model={} calls={}", model, len(tc_result.calls)) return make_resp_object( - response_id, model, "completed", output, + response_id, + model, + "completed", + output, build_resp_usage(pt, ct + rt, rt), ) - logger.info("responses request completed: model={} text_len={} reasoning_len={} image_count={}", - model, len(full_text), len(full_think), len(adapter.image_urls)) + logger.info( + "responses request completed: model={} text_len={} reasoning_len={} image_count={}", model, + len(full_text), len(full_think), len(adapter.image_urls)) output = [] if full_think: - output.append({ - "id": reasoning_id, - "type": "reasoning", - "summary": [{"type": "summary_text", "text": full_think}], - "status": "completed", - }) + output.append( + { + "id": reasoning_id, + "type": "reasoning", + "summary": [{ + "type": "summary_text", + "text": full_think + }], + 
"status": "completed", + }) msg_item: dict = { - "id": message_id, - "type": "message", - "role": "assistant", - "content": [{"type": "output_text", "text": full_text, "annotations": adapter.annotations_list()}], - "status": "completed", + "id": message_id, + "type": "message", + "role": "assistant", + "content": [{ + "type": "output_text", + "text": full_text, + "annotations": adapter.annotations_list() + }], + "status": "completed", } sources = adapter.search_sources_list() if sources: @@ -744,7 +1163,10 @@ async def _run_stream() -> AsyncGenerator[str, None]: ct = estimate_tokens(full_text) rt = estimate_tokens(full_think) if full_think else 0 return make_resp_object( - response_id, model, "completed", output, + response_id, + model, + "completed", + output, build_resp_usage(pt, ct + rt, rt), ) diff --git a/app/products/openai/router.py b/app/products/openai/router.py index 01a27504a..bb06efdf4 100644 --- a/app/products/openai/router.py +++ b/app/products/openai/router.py @@ -73,16 +73,12 @@ async def list_models(request: Request): "created": int(time.time()), "owned_by": "xai", "name": m.public_name, - } - for m in model_registry.list_enabled() - if _model_available_for_pools(m, pools) + } for m in model_registry.list_enabled() if _model_available_for_pools(m, pools) ] return JSONResponse({"object": "list", "data": models}) -@router.get( - "/models/{model_id}", tags=[_TAG_MODELS], dependencies=[Depends(verify_api_key)] -) +@router.get("/models/{model_id}", tags=[_TAG_MODELS], dependencies=[Depends(verify_api_key)]) async def get_model_endpoint(model_id: str, request: Request): import time @@ -90,12 +86,10 @@ async def get_model_endpoint(model_id: str, request: Request): pools = await _available_pools(request) if spec is None or not _model_available_for_pools(spec, pools): return JSONResponse( - { - "error": { - "message": f"Model {model_id!r} not found", - "type": "invalid_request_error", - } - }, + {"error": { + "message": f"Model {model_id!r} not found", + 
"type": "invalid_request_error", + }}, status_code=404, ) return JSONResponse( @@ -105,8 +99,7 @@ async def get_model_endpoint(model_id: str, request: Request): "created": int(time.time()), "owned_by": "xai", "name": spec.public_name, - } - ) + }) # --------------------------------------------------------------------------- @@ -124,16 +117,13 @@ async def _safe_sse(stream: AsyncIterable[str]) -> AsyncGenerator[str, None]: yield f"event: error\ndata: {payload}\n\n" yield "data: [DONE]\n\n" except Exception as exc: - payload = orjson.dumps( - {"error": {"message": str(exc), "type": "server_error"}} - ).decode() + payload = orjson.dumps({"error": {"message": str(exc), "type": "server_error"}}).decode() yield f"event: error\ndata: {payload}\n\n" yield "data: [DONE]\n\n" _SSE_HEADERS = {"Cache-Control": "no-cache", "Connection": "keep-alive"} - # --------------------------------------------------------------------------- # /v1/chat/completions # --------------------------------------------------------------------------- @@ -164,9 +154,7 @@ def _validate_chat(req: ChatCompletionRequest) -> None: param=f"messages.{i}.role", ) if req.temperature is not None and not (0 <= req.temperature <= 2): - raise ValidationError( - "temperature must be between 0 and 2", param="temperature" - ) + raise ValidationError("temperature must be between 0 and 2", param="temperature") if req.top_p is not None and not (0 <= req.top_p <= 1): raise ValidationError("top_p must be between 0 and 1", param="top_p") if req.reasoning_effort is not None and req.reasoning_effort not in _EFFORT_VALUES: @@ -196,10 +184,8 @@ async def _upload_to_data_uri(upload: UploadFile, *, param: str) -> str: raise ValidationError("Uploaded image cannot be empty", param=param) mime = ( - (upload.content_type or "").strip().lower() - or mimetypes.guess_type(upload.filename or "")[0] - or "application/octet-stream" - ) + (upload.content_type or "").strip().lower() or mimetypes.guess_type(upload.filename or "")[0] or + 
"application/octet-stream") if not mime.startswith("image/"): raise ValidationError("Uploaded file must be an image", param=param) @@ -210,17 +196,13 @@ async def _upload_to_data_uri(upload: UploadFile, *, param: str) -> str: return f"data:{mime};base64,{blob_b64}" -@router.post( - "/chat/completions", tags=[_TAG_CHAT], dependencies=[Depends(verify_api_key)] -) +@router.post("/chat/completions", tags=[_TAG_CHAT], dependencies=[Depends(verify_api_key)]) async def chat_completions_endpoint(req: ChatCompletionRequest): _validate_chat(req) from app.platform.config.snapshot import get_config cfg = get_config() - is_stream = ( - req.stream if req.stream is not None else cfg.get_bool("features.stream", True) - ) + is_stream = (req.stream if req.stream is not None else cfg.get_bool("features.stream", True)) spec = model_registry.get(req.model) if spec is None: @@ -259,12 +241,8 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): # Extract prompt from last user message. prompt = next( ( - m.content - for m in reversed(req.messages) - if m.role == "user" - and isinstance(m.content, str) - and m.content.strip() - ), + m.content for m in reversed(req.messages) + if m.role == "user" and isinstance(m.content, str) and m.content.strip()), "", ) result = await img_gen( @@ -309,6 +287,7 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): tool_choice=req.tool_choice, temperature=req.temperature or 0.8, top_p=req.top_p or 0.95, + reasoning_effort=req.reasoning_effort, ) except AppError: @@ -321,27 +300,19 @@ async def chat_completions_endpoint(req: ChatCompletionRequest): exc, ) if is_stream: - _err_msg = str( - exc - ) # capture before Python clears the except-scope variable + _err_msg = str(exc) # capture before Python clears the except-scope variable async def _err_stream(): - payload = orjson.dumps( - {"error": {"message": _err_msg, "type": "server_error"}} - ).decode() + payload = orjson.dumps({"error": {"message": _err_msg, "type": 
"server_error"}}).decode() yield f"event: error\ndata: {payload}\n\n" yield "data: [DONE]\n\n" - return StreamingResponse( - _err_stream(), media_type="text/event-stream", headers=_SSE_HEADERS - ) + return StreamingResponse(_err_stream(), media_type="text/event-stream", headers=_SSE_HEADERS) raise if isinstance(result, dict): return JSONResponse(result) - return StreamingResponse( - _safe_sse(result), media_type="text/event-stream", headers=_SSE_HEADERS - ) + return StreamingResponse(_safe_sse(result), media_type="text/event-stream", headers=_SSE_HEADERS) # --------------------------------------------------------------------------- @@ -371,9 +342,7 @@ async def _safe_sse_responses(stream) -> AsyncGenerator[str, None]: yield "data: [DONE]\n\n" -@router.post( - "/responses", tags=[_TAG_RESPONSES], dependencies=[Depends(verify_api_key)] -) +@router.post("/responses", tags=[_TAG_RESPONSES], dependencies=[Depends(verify_api_key)]) async def responses_endpoint(req: ResponsesCreateRequest): from app.platform.config.snapshot import get_config from app.platform.errors import ValidationError as _ValidationError @@ -389,15 +358,18 @@ async def responses_endpoint(req: ResponsesCreateRequest): raise _ValidationError("input cannot be empty", param="input") cfg = get_config() - is_stream = ( - req.stream if req.stream is not None else cfg.get_bool("features.stream", True) - ) + is_stream = (req.stream if req.stream is not None else cfg.get_bool("features.stream", True)) # Map reasoning param → emit_think flag. # reasoning=None → use config; reasoning.effort="none" → off; otherwise on. 
+ reasoning_effort: str | None = None + if isinstance(req.reasoning, dict): + eff = req.reasoning.get("effort") + if isinstance(eff, str): + reasoning_effort = eff if req.reasoning is None: emit_think = cfg.get_bool("features.thinking", True) - elif isinstance(req.reasoning, dict) and req.reasoning.get("effort") == "none": + elif reasoning_effort == "none": emit_think = False else: emit_think = True @@ -412,6 +384,7 @@ async def responses_endpoint(req: ResponsesCreateRequest): emit_think=emit_think, temperature=req.temperature or 0.8, top_p=req.top_p or 0.95, + reasoning_effort=reasoning_effort, tools=req.tools or None, tool_choice=req.tool_choice, ) @@ -420,8 +393,8 @@ async def responses_endpoint(req: ResponsesCreateRequest): return JSONResponse(result) return StreamingResponse( _safe_sse_responses(result), - media_type = "text/event-stream", - headers = _SSE_HEADERS, + media_type="text/event-stream", + headers=_SSE_HEADERS, ) @@ -430,15 +403,11 @@ async def responses_endpoint(req: ResponsesCreateRequest): # --------------------------------------------------------------------------- -@router.post( - "/images/generations", tags=[_TAG_IMAGES], dependencies=[Depends(verify_api_key)] -) +@router.post("/images/generations", tags=[_TAG_IMAGES], dependencies=[Depends(verify_api_key)]) async def image_generations(req: ImageGenerationRequest): spec = model_registry.get(req.model) if spec is None or not spec.enabled or not spec.is_image(): - raise ValidationError( - f"Model {req.model!r} is not an image model", param="model" - ) + raise ValidationError(f"Model {req.model!r} is not an image model", param="model") _validate_image_n(req.model, req.n or 1, param="n") from .images import generate as img_gen @@ -465,24 +434,21 @@ async def videos_create( model: Annotated[str, Form(...)], prompt: Annotated[str, Form(...)], seconds: Annotated[int, Form()] = 6, - size: Annotated[ - Literal["720x1280", "1280x720", "1024x1024", "1024x1792", "1792x1024"], Form() - ] = "720x1280", + 
size: Annotated[Literal["720x1280", "1280x720", "1024x1024", "1024x1792", "1792x1024"], + Form()] = "720x1280", resolution_name: Annotated[Literal["480p", "720p"] | None, Form()] = None, - preset: Annotated[ - Literal["fun", "normal", "spicy", "custom"] | None, Form() - ] = None, - input_reference: Annotated[ - list[UploadFile] | None, File(alias="input_reference[]") - ] = None, + preset: Annotated[Literal["fun", "normal", "spicy", "custom"] | None, + Form()] = None, + input_reference: Annotated[list[UploadFile] | None, File(alias="input_reference[]")] = None, ): from .video import create_video references_payload = None if input_reference: references_payload = [ - {"image_url": await _upload_to_data_uri(f, param="input_reference")} - for f in input_reference[:7] + { + "image_url": await _upload_to_data_uri(f, param="input_reference") + } for f in input_reference[:7] ] result = await create_video( @@ -497,9 +463,7 @@ async def videos_create( return JSONResponse(result) -@router.get( - "/videos/{video_id}", tags=[_TAG_VIDEOS], dependencies=[Depends(verify_api_key)] -) +@router.get("/videos/{video_id}", tags=[_TAG_VIDEOS], dependencies=[Depends(verify_api_key)]) async def videos_retrieve(video_id: str): from .video import retrieve @@ -523,9 +487,7 @@ async def videos_content(video_id: str): # --------------------------------------------------------------------------- -@router.post( - "/images/edits", tags=[_TAG_IMAGES], dependencies=[Depends(verify_api_key)] -) +@router.post("/images/edits", tags=[_TAG_IMAGES], dependencies=[Depends(verify_api_key)]) async def image_edits( model: Annotated[str, Form(...)], prompt: Annotated[str, Form(...)], @@ -537,25 +499,17 @@ async def image_edits( ): spec = model_registry.get(model) if spec is None or not spec.enabled or not spec.is_image_edit(): - raise ValidationError( - f"Model {model!r} is not an image-edit model", param="model" - ) + raise ValidationError(f"Model {model!r} is not an image-edit model", param="model") if mask 
is not None: raise ValidationError("mask is not supported yet", param="mask") _validate_image_edit_n(n, param="n") from .images import edit as img_edit - image_inputs = [ - await _upload_to_data_uri(item, param=f"image.{index}") - for index, item in enumerate(image) - ] + image_inputs = [await _upload_to_data_uri(item, param=f"image.{index}") for index, item in enumerate(image)] # Wrap input into a single-message conversation. content = [{"type": "text", "text": prompt}] - content.extend( - {"type": "image_url", "image_url": {"url": image_input}} - for image_input in image_inputs - ) + content.extend({"type": "image_url", "image_url": {"url": image_input}} for image_input in image_inputs) messages = [{"role": "user", "content": content}] result = await img_edit( model=model, diff --git a/app/products/web/webui/chat.py b/app/products/web/webui/chat.py index 9d1dc2f9c..bb061a61e 100644 --- a/app/products/web/webui/chat.py +++ b/app/products/web/webui/chat.py @@ -2,12 +2,16 @@ import time -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from fastapi.responses import JSONResponse from app.control.model import registry as model_registry from app.platform.auth.middleware import verify_webui_key -from app.products.openai.router import chat_completions_endpoint +from app.products.openai.router import ( + _available_pools, + _model_available_for_pools, + chat_completions_endpoint, +) from app.products.openai.schemas import ChatCompletionRequest router = APIRouter(prefix="/webui/api", dependencies=[Depends(verify_webui_key)], tags=["WebUI - Chat"]) @@ -24,7 +28,12 @@ def _capability_name(spec) -> str: @router.get("/models") -async def list_webui_models(): +async def list_webui_models(request: Request): + # Filter by account tier availability so the WebUI dropdown only shows + # models the configured account pool can actually serve. 
Without this + # the user would see super/heavy-tier models that fail with + # "No available accounts for this model tier" on call. + pools = await _available_pools(request) models = [ { "id": spec.model_name, @@ -33,8 +42,7 @@ async def list_webui_models(): "owned_by": "xai", "name": spec.public_name, "capability": _capability_name(spec), - } - for spec in model_registry.list_enabled() + } for spec in model_registry.list_enabled() if _model_available_for_pools(spec, pools) ] return JSONResponse({"object": "list", "data": models})