diff --git a/routstr/upstream/base.py b/routstr/upstream/base.py index 6586a11b..cdbdb1d0 100644 --- a/routstr/upstream/base.py +++ b/routstr/upstream/base.py @@ -4,7 +4,6 @@ import asyncio import json -import re from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Mapping @@ -356,155 +355,64 @@ async def handle_streaming_chat_completion( async def stream_with_cost( max_cost_for_model: int, ) -> AsyncGenerator[bytes, None]: - stored_chunks: list[bytes] = [] - usage_finalized: bool = False - last_model_seen: str | None = None - - async def finalize_without_usage() -> bytes | None: - nonlocal usage_finalized - if usage_finalized: - return None - async with create_session() as new_session: - fresh_key = await new_session.get(key.__class__, key.hashed_key) - if not fresh_key: - return None - try: - fallback: dict = { - "model": last_model_seen or "unknown", - "usage": None, - } - cost_data = await adjust_payment_for_tokens( - fresh_key, fallback, new_session, max_cost_for_model - ) - usage_finalized = True - logger.info( - "Finalized streaming payment without explicit usage", - extra={ - "key_hash": key.hashed_key[:8] + "...", - "cost_data": cost_data, - "balance_after_adjustment": fresh_key.balance, - }, - ) - return f"data: {json.dumps({'cost': cost_data})}\n\n".encode() - except Exception as cost_error: - logger.error( - "Error finalizing payment without usage", - extra={ - "error": str(cost_error), - "error_type": type(cost_error).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, - ) - return None + tail_buffer: list[bytes] = [] + MAX_TAIL_CHUNKS = 5 try: async for chunk in response.aiter_bytes(): - stored_chunks.append(chunk) - try: - for part in re.split(b"data: ", chunk): - if not part or part.strip() in (b"[DONE]", b""): - continue - try: - obj = json.loads(part) - if isinstance(obj, dict) and obj.get("model"): - last_model_seen = str(obj.get("model")) - except json.JSONDecodeError: - pass - except Exception: - pass - yield chunk + tail_buffer.append(chunk) + if len(tail_buffer) > MAX_TAIL_CHUNKS: + tail_buffer.pop(0) - logger.debug( - "Streaming completed, analyzing usage data", - extra={ - "key_hash": key.hashed_key[:8] + "...", - "chunks_count": len(stored_chunks), - }, - ) + # Post-stream processing + tail_content = b"".join(tail_buffer).decode("utf-8", errors="ignore") + usage_data, model = self._extract_usage_from_tail(tail_content) - for i in range(len(stored_chunks) - 1, -1, -1): - chunk = stored_chunks[i] - if not chunk: - continue - try: - events = re.split(b"data: ", chunk) - for event_data in events: - if not event_data or event_data.strip() in (b"[DONE]", b""): - continue - try: - data = json.loads(event_data) - if isinstance(data, dict) and data.get("model"): - last_model_seen = str(data.get("model")) - if isinstance(data, dict) and isinstance( - data.get("usage"), dict - ): - async with create_session() as new_session: - fresh_key = await new_session.get( - key.__class__, key.hashed_key - ) - if fresh_key: - try: - cost_data = ( - await adjust_payment_for_tokens( - fresh_key, - data, - new_session, - max_cost_for_model, - ) - ) - usage_finalized = True - logger.info( - "Payment adjustment completed for streaming", - extra={ - "key_hash": key.hashed_key[:8] - + "...", - "cost_data": cost_data, - "model": last_model_seen, - "balance_after_adjustment": fresh_key.balance, - }, - ) - yield f"data: {json.dumps({'cost': cost_data})}\n\n".encode() - except Exception as cost_error: - logger.error( - "Error adjusting payment for streaming tokens", - extra={ - "error": str(cost_error), - "error_type": type( - cost_error - ).__name__, - "key_hash": key.hashed_key[:8] - + "...", - }, - ) - break - except json.JSONDecodeError: - continue - except Exception as e: - logger.error( - "Error processing streaming response chunk", - extra={ - "error": str(e), - "error_type": type(e).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, - ) + if usage_data: + # Calculate final cost using usage data extracted from stream tail - if not usage_finalized: - maybe_cost_event = await finalize_without_usage() - if maybe_cost_event is not None: - yield maybe_cost_event + async with create_session() as new_session: + fresh_key = await new_session.get(key.__class__, key.hashed_key) + if fresh_key: + try: + # We need to reconstruct a response object for adjust_payment_for_tokens + # or call it with usage directly if supported. + # adjust_payment_for_tokens expects a dict with "usage" and "model" keys usually. + + data = { + "usage": usage_data, + "model": model or "unknown", + } + + cost_data = await adjust_payment_for_tokens( + fresh_key, + data, + new_session, + max_cost_for_model, + ) + + logger.info( + "Payment adjustment completed for streaming", + extra={ + "key_hash": key.hashed_key[:8] + "...", + "cost_data": cost_data, + "balance_after_adjustment": fresh_key.balance, + }, + ) + # We yield the cost data event at the end + yield f"data: {json.dumps({'cost': cost_data})}\n\n".encode() + except Exception as cost_error: + logger.error( + "Error adjusting payment for streaming tokens", + extra={"error": str(cost_error)}, + ) except Exception as stream_error: logger.warning( - "Streaming interrupted; finalizing without usage", - extra={ - "error": str(stream_error), - "error_type": type(stream_error).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, + "Streaming interrupted", + extra={"error": str(stream_error)}, ) - await finalize_without_usage() raise response_headers = dict(response.headers) @@ -517,6 +425,27 @@ async def finalize_without_usage() -> bytes | None: headers=response_headers, ) + def _extract_usage_from_tail( + self, tail_content: str + ) -> tuple[dict | None, str | None]: + """Extract usage and model from tail content of a stream.""" + usage_data = None + model = None + + lines = tail_content.strip().split("\n") + for line in lines: + if line.startswith("data: "): + try: + data = json.loads(line[6:]) + if isinstance(data, dict): + if "usage" in data: + usage_data = data["usage"] + if "model" in data: + model = data["model"] + except json.JSONDecodeError: + continue + return usage_data, model + async def handle_non_streaming_chat_completion( self, response: httpx.Response, @@ -626,157 +555,62 @@ async def handle_streaming_responses_api_completion( async def stream_with_responses_api_cost( max_cost_for_model: int, ) -> AsyncGenerator[bytes, None]: - stored_chunks: list[bytes] = [] - usage_finalized: bool = False - last_model_seen: str | None = None - - async def finalize_without_usage() -> bytes | None: - nonlocal usage_finalized - if usage_finalized: - return None - async with create_session() as new_session: - fresh_key = await new_session.get(key.__class__, key.hashed_key) - if not fresh_key: - return None - try: - fallback: dict = { - "model": last_model_seen or "unknown", - "usage": None, - } - cost_data = await adjust_payment_for_tokens( - fresh_key, fallback, new_session, max_cost_for_model - ) - usage_finalized = True - logger.info( - "Finalized Responses API streaming payment without explicit usage", - extra={ - "key_hash": key.hashed_key[:8] + "...", - "cost_data": cost_data, - "balance_after_adjustment": fresh_key.balance, - }, - ) - return f"data: {json.dumps({'cost': cost_data})}\n\n".encode() - except Exception as cost_error: - logger.error( - "Error finalizing Responses API payment without usage", - extra={ - "error": str(cost_error), - "error_type": type(cost_error).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, - ) - return None + tail_buffer: list[bytes] = [] + MAX_TAIL_CHUNKS = 5 try: async for chunk in response.aiter_bytes(): - stored_chunks.append(chunk) - try: - for part in re.split(b"data: ", chunk): - if not part or part.strip() in (b"[DONE]", b""): - continue - try: - obj = json.loads(part) - if isinstance(obj, dict): - if obj.get("model"): - last_model_seen = str(obj.get("model")) - except json.JSONDecodeError: - pass - except Exception: - pass - yield chunk - - logger.debug( - "Responses API streaming completed, analyzing usage data", - extra={ - "key_hash": key.hashed_key[:8] + "...", - "chunks_count": len(stored_chunks), - }, - ) - - # Process final usage data - for i in range(len(stored_chunks) - 1, -1, -1): - chunk = stored_chunks[i] - if not chunk: - continue - try: - events = re.split(b"data: ", chunk) - for event_data in events: - if not event_data or event_data.strip() in (b"[DONE]", b""): - continue + tail_buffer.append(chunk) + if len(tail_buffer) > MAX_TAIL_CHUNKS: + tail_buffer.pop(0) + + # Post-stream processing + tail_content = b"".join(tail_buffer).decode("utf-8", errors="ignore") + usage_data, model = self._extract_usage_from_tail(tail_content) + + if usage_data: + async with create_session() as new_session: + fresh_key = await new_session.get(key.__class__, key.hashed_key) + if fresh_key: try: - data = json.loads(event_data) - if isinstance(data, dict) and data.get("model"): - last_model_seen = str(data.get("model")) - if isinstance(data, dict) and isinstance( - data.get("usage"), dict - ): - async with create_session() as new_session: - fresh_key = await new_session.get( - key.__class__, key.hashed_key - ) - if fresh_key: - try: - cost_data = ( - await adjust_payment_for_tokens( - fresh_key, - data, - new_session, - max_cost_for_model, - ) - ) - usage_finalized = True - logger.info( - "Payment adjustment completed for Responses API streaming", - extra={ - "key_hash": key.hashed_key[:8] - + "...", - "cost_data": cost_data, - "model": last_model_seen, - "balance_after_adjustment": fresh_key.balance, - }, - ) - yield f"data: {json.dumps({'cost': cost_data})}\n\n".encode() - except Exception as cost_error: - logger.error( - "Error adjusting payment for Responses API streaming tokens", - extra={ - "error": str(cost_error), - "error_type": type( - cost_error - ).__name__, - "key_hash": key.hashed_key[:8] - + "...", - }, - ) - break - except json.JSONDecodeError: - continue - except Exception as e: - logger.error( - "Error processing Responses API streaming response chunk", - extra={ - "error": str(e), - "error_type": type(e).__name__, - "key_hash": key.hashed_key[:8] + "...", - }, - ) - - if not usage_finalized: - maybe_cost_event = await finalize_without_usage() - if maybe_cost_event is not None: - yield maybe_cost_event + data = { + "usage": usage_data, + "model": model or "unknown", + } + cost_data = await adjust_payment_for_tokens( + fresh_key, + data, + new_session, + max_cost_for_model, + ) + logger.info( + "Payment adjustment completed for Responses API streaming", + extra={ + "key_hash": key.hashed_key[:8] + "...", + "cost_data": cost_data, + "model": model, + "balance_after_adjustment": fresh_key.balance, + }, + ) + yield f"data: {json.dumps({'cost': cost_data})}\n\n".encode() + except Exception as cost_error: + logger.error( + "Error adjusting payment for Responses API streaming tokens", + extra={ + "error": str(cost_error), + "key_hash": key.hashed_key[:8] + "...", + }, + ) except Exception as stream_error: logger.warning( - "Responses API streaming interrupted; finalizing without usage", + "Responses API streaming interrupted", extra={ "error": str(stream_error), - "error_type": type(stream_error).__name__, "key_hash": key.hashed_key[:8] + "...", }, ) - await finalize_without_usage() raise response_headers = dict(response.headers) diff --git a/routstr/upstream/handlers/cashu_payment_handler.py b/routstr/upstream/handlers/cashu_payment_handler.py index 82d5edc1..d00e763c 100644 --- a/routstr/upstream/handlers/cashu_payment_handler.py +++ b/routstr/upstream/handlers/cashu_payment_handler.py @@ -220,20 +220,67 @@ async def _handle_chat_completion_response( get_x_cashu_cost_func: Callable, ) -> Response | StreamingResponse: """Handle chat completion response processing.""" + # For X-Cashu requests, we force non-streaming response even if upstream is streaming, + # because we need to calculate the refund and include it in the headers. + # Refund token must be in the headers, which are sent before the body. + content = await response.aread() content_str = content.decode("utf-8") if isinstance(content, bytes) else content - is_streaming = ChatCompletionProcessor.is_streaming_response(content_str) + # Check if it was streaming content to process usage correctly + is_streaming_content = ChatCompletionProcessor.is_streaming_response( + content_str + ) - if is_streaming: - return await self._handle_streaming_response( - content_str, - response, - amount, - unit, - max_cost_for_model, - mint, - get_x_cashu_cost_func, + if is_streaming_content: + # Buffer streaming content to calculate refund before sending headers. + # Headers (containing refund token) must be sent before body, preventing + # true streaming while supporting dynamic refunds. + + # Extract usage data from buffered content + usage_data, model = ChatCompletionProcessor.extract_usage_from_streaming( + content_str + ) + + if usage_data and model: + try: + response_data = {"usage": usage_data, "model": model} + cost_data = await get_x_cashu_cost_func( + response_data, max_cost_for_model + ) + + if cost_data: + refund_amount = self._calculate_refund_amount( + amount, unit, cost_data.total_msats + ) + + if refund_amount > 0: + refund_token = await self.send_refund( + refund_amount, unit, mint + ) + # Refund token will be added to new response headers below + except Exception as e: + logger.error( + "Error calculating refund for buffered streaming response", + extra={"error": str(e)}, + ) + + # We return a StreamingResponse that yields the buffered content + # This satisfies the client expecting a stream, while allowing us to set headers. + + response_headers = ChatCompletionProcessor.clean_response_headers( + dict(response.headers) + ) + + if "refund_token" in locals() and refund_token: + response_headers["X-Cashu"] = refund_token + + lines = content_str.strip().split("\n") + return StreamingResponse( + ChatCompletionProcessor.create_streaming_generator(lines), + status_code=response.status_code, + headers=response_headers, + media_type="text/event-stream", # Keep original media type if possible, or text/event-stream ) else: return await self._handle_non_streaming_response( @@ -246,55 +293,6 @@ async def _handle_chat_completion_response( get_x_cashu_cost_func, ) - async def _handle_streaming_response( - self, - content_str: str, - response: httpx.Response, - amount: int, - unit: str, - max_cost_for_model: int, - mint: str | None, - get_x_cashu_cost_func: Callable, - ) -> StreamingResponse: - """Handle streaming chat completion response with refund calculation.""" - response_headers = ChatCompletionProcessor.clean_response_headers( - dict(response.headers) - ) - - usage_data, model = ChatCompletionProcessor.extract_usage_from_streaming( - content_str - ) - - if usage_data and model: - try: - response_data = {"usage": usage_data, "model": model} - cost_data = await get_x_cashu_cost_func( - response_data, max_cost_for_model - ) - - if cost_data: - refund_amount = self._calculate_refund_amount( - amount, unit, cost_data.total_msats - ) - - if refund_amount > 0: - refund_token = await self.send_refund(refund_amount, unit, mint) - response_headers["X-Cashu"] = refund_token - - except Exception as e: - logger.error( - "Error calculating cost for streaming response", - extra={"error": str(e), "error_type": type(e).__name__}, - ) - - lines = content_str.strip().split("\n") - return StreamingResponse( - ChatCompletionProcessor.create_streaming_generator(lines), - status_code=response.status_code, - headers=response_headers, - media_type="text/plain", - ) - async def _handle_non_streaming_response( self, content_str: str, @@ -435,20 +433,53 @@ async def _handle_responses_api_completion( get_x_cashu_cost_func: Callable, ) -> Response | StreamingResponse: """Handle Responses API completion response processing.""" + # Force non-streaming for X-Cashu to handle refunds in headers content = await response.aread() content_str = content.decode("utf-8") if isinstance(content, bytes) else content - is_streaming = ResponsesApiProcessor.is_streaming_response(content_str) + is_streaming_content = ResponsesApiProcessor.is_streaming_response(content_str) - if is_streaming: - return await self._handle_streaming_responses_api_response( - content_str, - response, - amount, - unit, - max_cost_for_model, - mint, - get_x_cashu_cost_func, + if is_streaming_content: + usage_data, model = ( + ResponsesApiProcessor.extract_usage_with_reasoning_tokens(content_str) + ) + + refund_token = None + if usage_data and model: + try: + response_data = {"usage": usage_data, "model": model} + cost_data = await get_x_cashu_cost_func( + response_data, max_cost_for_model + ) + + if cost_data: + refund_amount = self._calculate_refund_amount( + amount, unit, cost_data.total_msats + ) + + if refund_amount > 0: + refund_token = await self.send_refund( + refund_amount, unit, mint + ) + except Exception as e: + logger.error( + "Error calculating refund for buffered Responses API response", + extra={"error": str(e)}, + ) + + response_headers = ResponsesApiProcessor.clean_response_headers( + dict(response.headers) + ) + + if refund_token: + response_headers["X-Cashu"] = refund_token + + lines = content_str.strip().split("\n") + return StreamingResponse( + ResponsesApiProcessor.create_streaming_generator(lines), + status_code=response.status_code, + headers=response_headers, + media_type="text/event-stream", ) else: return await self._handle_non_streaming_responses_api_response( @@ -461,56 +492,6 @@ async def _handle_responses_api_completion( get_x_cashu_cost_func, ) - async def _handle_streaming_responses_api_response( - self, - content_str: str, - response: httpx.Response, - amount: int, - unit: str, - max_cost_for_model: int, - mint: str | None, - get_x_cashu_cost_func: Callable, - ) -> StreamingResponse: - """Handle streaming Responses API response with refund calculation.""" - response_headers = ResponsesApiProcessor.clean_response_headers( - dict(response.headers) - ) - - # Extract usage data (ignoring reasoning tokens as per requirement) - usage_data, model = ResponsesApiProcessor.extract_usage_with_reasoning_tokens( - content_str - ) - - if usage_data and model: - try: - response_data = {"usage": usage_data, "model": model} - cost_data = await get_x_cashu_cost_func( - response_data, max_cost_for_model - ) - - if cost_data: - refund_amount = self._calculate_refund_amount( - amount, unit, cost_data.total_msats - ) - - if refund_amount > 0: - refund_token = await self.send_refund(refund_amount, unit, mint) - response_headers["X-Cashu"] = refund_token - - except Exception as e: - logger.error( - "Error calculating cost for streaming Responses API response", - extra={"error": str(e), "error_type": type(e).__name__}, - ) - - lines = content_str.strip().split("\n") - return StreamingResponse( - ResponsesApiProcessor.create_streaming_generator(lines), - status_code=response.status_code, - headers=response_headers, - media_type="text/plain", - ) - async def _handle_non_streaming_responses_api_response( self, content_str: str, @@ -549,11 +530,13 @@ async def _handle_non_streaming_responses_api_response( except json.JSONDecodeError: # Emergency refund on parse error emergency_refund = amount - refund_token = await send_token(emergency_refund, unit=unit, mint_url=mint) + refund_token = await self.send_refund( + emergency_refund, unit=unit, mint=mint + ) response_headers["X-Cashu"] = refund_token logger.warning( - "Emergency refund issued for Responses API due to JSON parse error", + "Emergency refund issued due to JSON parse error in Responses API", extra={ "original_amount": amount, "refund_amount": emergency_refund, @@ -571,7 +554,7 @@ async def _handle_non_streaming_responses_api_response( def _create_default_streaming_response( self, response: httpx.Response ) -> StreamingResponse: - """Create default streaming response for non-responses endpoints.""" + """Create default streaming response for non-chat endpoints.""" background_tasks = BackgroundTasks() background_tasks.add_task(response.aclose)