diff --git a/routstr/upstream/base.py b/routstr/upstream/base.py index 69ef48d9..e1814a6c 100644 --- a/routstr/upstream/base.py +++ b/routstr/upstream/base.py @@ -2,7 +2,6 @@ import asyncio import json -import re import traceback from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Mapping @@ -169,6 +168,27 @@ def prepare_headers(self, request_headers: dict) -> dict: return headers + def _extract_usage_from_tail( + self, tail_content: str + ) -> tuple[dict | None, str | None]: + """Extract usage and model from tail content of a stream.""" + usage_data = None + model = None + + lines = tail_content.strip().split("\n") + for line in lines: + if line.startswith("data: "): + try: + data = json.loads(line[6:]) + if isinstance(data, dict): + if "usage" in data: + usage_data = data["usage"] + if "model" in data: + model = data["model"] + except json.JSONDecodeError: + continue + return usage_data, model + def prepare_params( self, path: str, query_params: Mapping[str, str] | None ) -> Mapping[str, str]: @@ -234,7 +254,11 @@ def prepare_responses_request_body( ) # Handle model in input field (alternative format) - if "input" in data and isinstance(data["input"], dict) and "model" in data["input"]: + if ( + "input" in data + and isinstance(data["input"], dict) + and "model" in data["input"] + ): original_model = model_obj.id transformed_model = self.transform_model_name(original_model) data["input"]["model"] = transformed_model @@ -432,155 +456,64 @@ async def handle_streaming_chat_completion( async def stream_with_cost( max_cost_for_model: int, ) -> AsyncGenerator[bytes, None]: - stored_chunks: list[bytes] = [] - usage_finalized: bool = False - last_model_seen: str | None = None - - async def finalize_without_usage() -> bytes | None: - nonlocal usage_finalized - if usage_finalized: - return None - async with create_session() as new_session: - fresh_key = await new_session.get(key.__class__, key.hashed_key) - if not fresh_key: - return None - try: - fallback: dict = { - "model": last_model_seen or "unknown", - "usage": None, - } - cost_data = await adjust_payment_for_tokens( - fresh_key, fallback, new_session, max_cost_for_model - ) - usage_finalized = True - logger.info( - "Finalized streaming payment without explicit usage", - extra={ - "key_hash": key.hashed_key[:8] + "...", - "cost_data": cost_data, - "balance_after_adjustment": fresh_key.balance, - }, - ) - return f"data: {json.dumps({'cost': cost_data})}\n\n".encode() - except Exception as cost_error: - logger.error( - "Error finalizing payment without usage", - extra={ - "error": str(cost_error), - "error_type": type(cost_error).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, - ) - return None + tail_buffer: list[bytes] = [] + MAX_TAIL_CHUNKS = 5 try: async for chunk in response.aiter_bytes(): - stored_chunks.append(chunk) - try: - for part in re.split(b"data: ", chunk): - if not part or part.strip() in (b"[DONE]", b""): - continue - try: - obj = json.loads(part) - if isinstance(obj, dict) and obj.get("model"): - last_model_seen = str(obj.get("model")) - except json.JSONDecodeError: - pass - except Exception: - pass - yield chunk - - logger.debug( - "Streaming completed, analyzing usage data", - extra={ - "key_hash": key.hashed_key[:8] + "...", - "chunks_count": len(stored_chunks), - }, - ) - - for i in range(len(stored_chunks) - 1, -1, -1): - chunk = stored_chunks[i] - if not chunk: - continue - try: - events = re.split(b"data: ", chunk) - for event_data in events: - if not event_data or event_data.strip() in (b"[DONE]", b""): - continue + tail_buffer.append(chunk) + if len(tail_buffer) > MAX_TAIL_CHUNKS: + tail_buffer.pop(0) + + # Post-stream processing + tail_content = b"".join(tail_buffer).decode("utf-8", errors="ignore") + usage_data, model = self._extract_usage_from_tail(tail_content) + + if usage_data: + # Calculate final cost using usage data extracted from stream tail + async with create_session() as new_session: + fresh_key = await new_session.get(key.__class__, key.hashed_key) + if fresh_key: try: - data = json.loads(event_data) - if isinstance(data, dict) and data.get("model"): - last_model_seen = str(data.get("model")) - if isinstance(data, dict) and isinstance( - data.get("usage"), dict - ): - async with create_session() as new_session: - fresh_key = await new_session.get( - key.__class__, key.hashed_key - ) - if fresh_key: - try: - cost_data = ( - await adjust_payment_for_tokens( - fresh_key, - data, - new_session, - max_cost_for_model, - ) - ) - usage_finalized = True - logger.info( - "Payment adjustment completed for streaming", - extra={ - "key_hash": key.hashed_key[:8] - + "...", - "cost_data": cost_data, - "model": last_model_seen, - "balance_after_adjustment": fresh_key.balance, - }, - ) - yield f"data: {json.dumps({'cost': cost_data})}\n\n".encode() - except Exception as cost_error: - logger.error( - "Error adjusting payment for streaming tokens", - extra={ - "error": str(cost_error), - "error_type": type( - cost_error - ).__name__, - "key_hash": key.hashed_key[:8] - + "...", - }, - ) - break - except json.JSONDecodeError: - continue - except Exception as e: - logger.error( - "Error processing streaming response chunk", - extra={ - "error": str(e), - "error_type": type(e).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, - ) + # We need to reconstruct a response object for adjust_payment_for_tokens + # or call it with usage directly if supported. + # adjust_payment_for_tokens expects a dict with "usage" and "model" keys usually. + + data = { + "usage": usage_data, + "model": model or "unknown", + } - if not usage_finalized: - maybe_cost_event = await finalize_without_usage() - if maybe_cost_event is not None: - yield maybe_cost_event + cost_data = await adjust_payment_for_tokens( + fresh_key, + data, + new_session, + max_cost_for_model, + ) + + logger.info( + "Payment adjustment completed for streaming", + extra={ + "key_hash": key.hashed_key[:8] + "...", + "cost_data": cost_data, + "model": model, + "balance_after_adjustment": fresh_key.balance, + }, + ) + # We yield the cost data event at the end + yield f"data: {json.dumps({'cost': cost_data})}\n\n".encode() + except Exception as cost_error: + logger.error( + "Error adjusting payment for streaming tokens", + extra={"error": str(cost_error)}, + ) except Exception as stream_error: logger.warning( - "Streaming interrupted; finalizing without usage", - extra={ - "error": str(stream_error), - "error_type": type(stream_error).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, + "Streaming interrupted", + extra={"error": str(stream_error)}, ) - await finalize_without_usage() raise # Remove inaccurate encoding headers from upstream response @@ -722,166 +655,69 @@ async def handle_streaming_responses_completion( async def stream_with_responses_cost( max_cost_for_model: int, ) -> AsyncGenerator[bytes, None]: - stored_chunks: list[bytes] = [] - usage_finalized: bool = False - last_model_seen: str | None = None - reasoning_tokens: int = 0 - - async def finalize_without_usage() -> bytes | None: - nonlocal usage_finalized - if usage_finalized: - return None - async with create_session() as new_session: - fresh_key = await new_session.get(key.__class__, key.hashed_key) - if not fresh_key: - return None - try: - fallback: dict = { - "model": last_model_seen or "unknown", - "usage": None, - } - cost_data = await adjust_payment_for_tokens( - fresh_key, fallback, new_session, max_cost_for_model - ) - usage_finalized = True - logger.info( - "Finalized Responses API streaming payment without explicit usage", - extra={ - "key_hash": key.hashed_key[:8] + "...", - "cost_data": cost_data, - "balance_after_adjustment": fresh_key.balance, - }, - ) - return f"data: {json.dumps({'cost': cost_data})}\\n\\n".encode() - except Exception as cost_error: - logger.error( - "Error finalizing Responses API payment without usage", - extra={ - "error": str(cost_error), - "error_type": type(cost_error).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, - ) - return None + tail_buffer: list[bytes] = [] + MAX_TAIL_CHUNKS = 5 try: async for chunk in response.aiter_bytes(): - stored_chunks.append(chunk) - try: - for part in re.split(b"data: ", chunk): - if not part or part.strip() in (b"[DONE]", b""): - continue - try: - obj = json.loads(part) - if isinstance(obj, dict): - if obj.get("model"): - last_model_seen = str(obj.get("model")) - - # Track reasoning tokens for Responses API - if usage := obj.get("usage", {}): - if isinstance(usage, dict) and "reasoning_tokens" in usage: - reasoning_tokens += usage.get("reasoning_tokens", 0) - except json.JSONDecodeError: - pass - except Exception: - pass - yield chunk - - logger.debug( - "Responses API streaming completed, analyzing usage data", - extra={ - "key_hash": key.hashed_key[:8] + "...", - "chunks_count": len(stored_chunks), - "reasoning_tokens": reasoning_tokens, - }, - ) - - # Process final usage data - for i in range(len(stored_chunks) - 1, -1, -1): - chunk = stored_chunks[i] - if not chunk: - continue - try: - events = re.split(b"data: ", chunk) - for event_data in events: - if not event_data or event_data.strip() in (b"[DONE]", b""): - continue + tail_buffer.append(chunk) + if len(tail_buffer) > MAX_TAIL_CHUNKS: + tail_buffer.pop(0) + + # Post-stream processing + tail_content = b"".join(tail_buffer).decode("utf-8", errors="ignore") + usage_data, model = self._extract_usage_from_tail(tail_content) + + if usage_data: + async with create_session() as new_session: + fresh_key = await new_session.get(key.__class__, key.hashed_key) + if fresh_key: try: - data = json.loads(event_data) - if isinstance(data, dict) and data.get("model"): - last_model_seen = str(data.get("model")) - if isinstance(data, dict) and isinstance( - data.get("usage"), dict - ): - # Include reasoning tokens in usage calculation - async with create_session() as new_session: - fresh_key = await new_session.get( - key.__class__, key.hashed_key - ) - if fresh_key: - try: - cost_data = ( - await adjust_payment_for_tokens( - fresh_key, - data, - new_session, - max_cost_for_model, - ) - ) - usage_finalized = True - logger.info( - "Payment adjustment completed for Responses API streaming", - extra={ - "key_hash": key.hashed_key[:8] - + "...", - "cost_data": cost_data, - "model": last_model_seen, - "reasoning_tokens": reasoning_tokens, - "balance_after_adjustment": fresh_key.balance, - }, - ) - yield f"data: {json.dumps({'cost': cost_data})}\\n\\n".encode() - except Exception as cost_error: - logger.error( - "Error adjusting payment for Responses API streaming tokens", - extra={ - "error": str(cost_error), - "error_type": type( - cost_error - ).__name__, - "key_hash": key.hashed_key[:8] - + "...", - }, - ) - break - except json.JSONDecodeError: - continue - except Exception as e: - logger.error( - "Error processing Responses API streaming response chunk", - extra={ - "error": str(e), - "error_type": type(e).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, - ) + # We need to reconstruct a response object for adjust_payment_for_tokens + # or call it with usage directly if supported. + # adjust_payment_for_tokens expects a dict with "usage" and "model" keys usually. - if not usage_finalized: - maybe_cost_event = await finalize_without_usage() - if maybe_cost_event is not None: - yield maybe_cost_event + data = { + "usage": usage_data, + "model": model or "unknown", + } + + cost_data = await adjust_payment_for_tokens( + fresh_key, + data, + new_session, + max_cost_for_model, + ) + + logger.info( + "Payment adjustment completed for Responses API streaming", + extra={ + "key_hash": key.hashed_key[:8] + "...", + "cost_data": cost_data, + "model": model, + "balance_after_adjustment": fresh_key.balance, + }, + ) + # We yield the cost data event at the end + yield f"data: {json.dumps({'cost': cost_data})}\\n\\n".encode() + except Exception as cost_error: + logger.error( + "Error adjusting payment for Responses API streaming tokens", + extra={ + "error": str(cost_error), + "key_hash": key.hashed_key[:8] + "...", + }, + ) except Exception as stream_error: logger.warning( - "Responses API streaming interrupted; finalizing without usage", + "Responses API streaming interrupted", extra={ "error": str(stream_error), - "error_type": type(stream_error).__name__, "key_hash": key.hashed_key[:8] + "...", }, ) - await finalize_without_usage() raise # Remove inaccurate encoding headers from upstream response @@ -933,8 +769,8 @@ async def handle_non_streaming_responses_completion( "model": response_json.get("model", "unknown"), "has_usage": "usage" in response_json, "has_reasoning_tokens": "usage" in response_json - and isinstance(response_json.get("usage"), dict) - and "reasoning_tokens" in response_json["usage"], + and isinstance(response_json.get("usage"), dict) + and "reasoning_tokens" in response_json["usage"], }, ) @@ -1679,21 +1515,8 @@ async def handle_x_cashu_streaming_response( if "content-encoding" in response_headers: del response_headers["content-encoding"] - usage_data = None - model = None - + usage_data, model = self._extract_usage_from_tail(content_str) lines = content_str.strip().split("\n") - for line in lines: - if line.startswith("data: "): - try: - data_json = json.loads(line[6:]) - if "usage" in data_json: - usage_data = data_json["usage"] - model = data_json.get("model") - elif "model" in data_json and not model: - model = data_json["model"] - except json.JSONDecodeError: - continue if usage_data and model: logger.debug( @@ -2480,7 +2303,7 @@ async def handle_x_cashu_streaming_responses_response( extra={ "amount": amount, "unit": unit, - "content_lines": len(content_str.strip().split("\\n")), + "content_lines": len(content_str.strip().split("\n")), }, ) @@ -2490,25 +2313,14 @@ async def handle_x_cashu_streaming_responses_response( if "content-encoding" in response_headers: del response_headers["content-encoding"] - usage_data = None - model = None + usage_data, model = self._extract_usage_from_tail(content_str) reasoning_tokens = 0 - lines = content_str.strip().split("\\n") - for line in lines: - if line.startswith("data: "): - try: - data_json = json.loads(line[6:]) - if "usage" in data_json: - usage_data = data_json["usage"] - model = data_json.get("model") - # Track reasoning tokens for Responses API - if isinstance(usage_data, dict) and "reasoning_tokens" in usage_data: - reasoning_tokens = usage_data.get("reasoning_tokens", 0) - elif "model" in data_json and not model: - model = data_json["model"] - except json.JSONDecodeError: - continue + # Responses API specific: check for reasoning tokens + if usage_data and "reasoning_tokens" in usage_data: + reasoning_tokens = usage_data.get("reasoning_tokens", 0) + + lines = content_str.strip().split("\n") if usage_data and model: logger.debug( @@ -2584,7 +2396,7 @@ async def handle_x_cashu_streaming_responses_response( async def generate() -> AsyncGenerator[bytes, None]: for line in lines: - yield (line + "\\n").encode("utf-8") + yield (line + "\n").encode("utf-8") return StreamingResponse( generate(),