diff --git a/kiro/converters_anthropic.py b/kiro/converters_anthropic.py index 4c2d38b1..f32e0f07 100644 --- a/kiro/converters_anthropic.py +++ b/kiro/converters_anthropic.py @@ -424,4 +424,4 @@ def anthropic_to_kiro( inject_thinking=True, ) - return result.payload + return result diff --git a/kiro/converters_core.py b/kiro/converters_core.py index 21cf758b..4f8c63c7 100644 --- a/kiro/converters_core.py +++ b/kiro/converters_core.py @@ -30,6 +30,7 @@ to convert their formats to Kiro API format. """ +import hashlib import json from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple @@ -93,9 +94,11 @@ class KiroPayloadResult: Attributes: payload: The complete Kiro API payload tool_documentation: Documentation for tools with long descriptions (to add to system prompt) + tool_name_mapping: Mapping of truncated tool names back to originals (short→original) """ payload: Dict[str, Any] tool_documentation: str = "" + tool_name_mapping: Dict[str, str] = field(default_factory=dict) # ================================================================================================== @@ -491,46 +494,34 @@ def process_tools_with_long_descriptions( return processed_tools if processed_tools else None, tool_documentation -def validate_tool_names(tools: Optional[List[UnifiedTool]]) -> None: +TOOL_NAME_MAX_LENGTH = 64 + + +def _make_short_name(name: str) -> str: + """Deterministically shorten a tool name to fit within TOOL_NAME_MAX_LENGTH.""" + suffix = hashlib.md5(name.encode()).hexdigest()[:8] + return name[:TOOL_NAME_MAX_LENGTH - 9] + "_" + suffix + + +def truncate_tool_names( + tools: Optional[List[UnifiedTool]], +) -> Dict[str, str]: """ - Validates tool names against Kiro API 64-character limit. - - Logs WARNING for each problematic tool and raises ValueError - with complete list of violations. 
- - Args: - tools: List of tools to validate - - Raises: - ValueError: If any tool name exceeds 64 characters - - Example: - >>> validate_tool_names([UnifiedTool(name="short_name", description="test")]) - # No error - >>> validate_tool_names([UnifiedTool(name="a" * 70, description="test")]) - # Raises ValueError with detailed message + Truncate tool names exceeding 64 characters in-place. + + Returns a mapping {short_name: original_name} for names that were changed. """ if not tools: - return - - problematic_tools = [] + return {} + + mapping: Dict[str, str] = {} for tool in tools: - if len(tool.name) > 64: - problematic_tools.append((tool.name, len(tool.name))) - - if problematic_tools: - # Build detailed error message for client (no logging here - routes will log) - tool_list = "\n".join([ - f" - '{name}' ({length} characters)" - for name, length in problematic_tools - ]) - - raise ValueError( - f"Tool name(s) exceed Kiro API limit of 64 characters:\n" - f"{tool_list}\n\n" - f"Solution: Use shorter tool names (max 64 characters).\n" - f"Example: 'get_user_data' instead of 'get_authenticated_user_profile_data_with_extended_information_about_it'" - ) + if len(tool.name) > TOOL_NAME_MAX_LENGTH: + short = _make_short_name(tool.name) + mapping[short] = tool.name + logger.debug(f"Truncated tool name '{tool.name}' -> '{short}'") + tool.name = short + return mapping def convert_tools_to_kiro_format(tools: Optional[List[UnifiedTool]]) -> List[Dict[str, Any]]: @@ -1370,8 +1361,10 @@ def build_kiro_payload( # Process tools with long descriptions processed_tools, tool_documentation = process_tools_with_long_descriptions(tools) - # Validate tool names against Kiro API 64-character limit - validate_tool_names(processed_tools) + # Truncate tool names that exceed 64-character Kiro API limit + tool_name_mapping = truncate_tool_names(processed_tools) + if tool_name_mapping: + logger.info(f"Truncated {len(tool_name_mapping)} tool name(s) exceeding {TOOL_NAME_MAX_LENGTH} chars") # 
Add tool documentation to system prompt if present full_system_prompt = system_prompt @@ -1429,6 +1422,17 @@ def build_kiro_payload( history = build_kiro_history(history_messages, model_id) + # Apply tool name truncation to tool_use blocks in history + if tool_name_mapping: + reverse = {v: k for k, v in tool_name_mapping.items()} + for entry in history: + arm = entry.get("assistantResponseMessage") + if arm: + for tu in arm.get("toolUses", []): + orig = tu.get("name", "") + if orig in reverse: + tu["name"] = reverse[orig] + # Current message (the last one) current_message = merged_messages[-1] current_content = extract_text_content(current_message.content) @@ -1519,4 +1523,4 @@ def build_kiro_payload( if profile_arn: payload["profileArn"] = profile_arn - return KiroPayloadResult(payload=payload, tool_documentation=tool_documentation) \ No newline at end of file + return KiroPayloadResult(payload=payload, tool_documentation=tool_documentation, tool_name_mapping=tool_name_mapping) \ No newline at end of file diff --git a/kiro/converters_openai.py b/kiro/converters_openai.py index aad3b83c..5a49da1a 100644 --- a/kiro/converters_openai.py +++ b/kiro/converters_openai.py @@ -345,4 +345,4 @@ def build_kiro_payload( inject_thinking=True ) - return result.payload \ No newline at end of file + return result \ No newline at end of file diff --git a/kiro/routes_anthropic.py b/kiro/routes_anthropic.py index 1bc6bd10..49f5b171 100644 --- a/kiro/routes_anthropic.py +++ b/kiro/routes_anthropic.py @@ -50,7 +50,7 @@ ) from kiro.http_client import KiroHttpClient from kiro.utils import generate_conversation_id -from kiro.tokenizer import count_tools_tokens +from kiro.tokenizer import count_tokens # Import debug_logger try: @@ -257,11 +257,13 @@ async def messages( profile_arn_for_payload = auth_manager.profile_arn try: - kiro_payload = anthropic_to_kiro( + kiro_result = anthropic_to_kiro( request_data, conversation_id, profile_arn_for_payload ) + kiro_payload = kiro_result.payload + 
tool_name_mapping = kiro_result.tool_name_mapping except ValueError as e: logger.error(f"Conversion error: {e}") return JSONResponse( @@ -298,10 +300,11 @@ async def messages( shared_client = request.app.state.http_client http_client = KiroHttpClient(auth_manager, shared_client=shared_client) - # Prepare data for token counting - # Convert Pydantic models to dicts for tokenizer - messages_for_tokenizer = [msg.model_dump() for msg in request_data.messages] - tools_for_tokenizer = [tool.model_dump() for tool in request_data.tools] if request_data.tools else None + # Count prompt tokens from the full Kiro payload (system prompt + messages + tools) + kiro_payload_prompt_tokens = count_tokens( + kiro_request_body.decode('utf-8', errors='ignore'), + apply_claude_correction=False + ) try: # Make request to Kiro API (for both streaming and non-streaming modes) @@ -368,7 +371,8 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens, + tool_name_mapping=tool_name_mapping ): yield chunk except GeneratorExit: @@ -415,7 +419,8 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens, + tool_name_mapping=tool_name_mapping ) await http_client.close() @@ -449,4 +454,56 @@ async def stream_wrapper(): "message": f"Internal Server Error: {str(e)}" } } - ) \ No newline at end of file + ) + + +@router.post("/v1/messages/count_tokens", dependencies=[Depends(verify_anthropic_api_key)]) +async def count_tokens_endpoint( + request: Request, + request_data: AnthropicMessagesRequest, +): + """ + Anthropic Count Tokens API endpoint. + + Returns estimated token count for the given request payload. + Used by Claude Code to decide when to trigger conversation compaction. 
+ + Builds the full Kiro payload and counts tokens on the serialized JSON, + consistent with the token counting approach used in the messages endpoint. + """ + logger.info(f"Request to /v1/messages/count_tokens (model={request_data.model}, messages={len(request_data.messages)})") + + auth_manager: KiroAuthManager = request.app.state.auth_manager + + # Build Kiro payload (same as messages endpoint) + conversation_id = generate_conversation_id() + profile_arn_for_payload = "" + if auth_manager.auth_type == AuthType.KIRO_DESKTOP and auth_manager.profile_arn: + profile_arn_for_payload = auth_manager.profile_arn + + try: + kiro_payload = anthropic_to_kiro( + request_data, + conversation_id, + profile_arn_for_payload + ).payload + except ValueError as e: + logger.error(f"Conversion error in count_tokens: {e}") + return JSONResponse( + status_code=400, + content={ + "type": "error", + "error": { + "type": "invalid_request_error", + "message": str(e) + } + } + ) + + # Count tokens from the full serialized Kiro payload (same as messages endpoint) + kiro_request_body = json.dumps(kiro_payload, ensure_ascii=False, indent=2) + input_tokens = count_tokens(kiro_request_body, apply_claude_correction=False) + + logger.info(f"Token count estimate: {input_tokens} (payload size: {len(kiro_request_body)} chars)") + + return JSONResponse(content={"input_tokens": input_tokens}) diff --git a/kiro/routes_openai.py b/kiro/routes_openai.py index 301ae9b5..f129811d 100644 --- a/kiro/routes_openai.py +++ b/kiro/routes_openai.py @@ -50,6 +50,7 @@ from kiro.streaming_openai import stream_kiro_to_openai, collect_stream_response, stream_with_first_token_retry from kiro.http_client import KiroHttpClient from kiro.utils import generate_conversation_id +from kiro.tokenizer import count_tokens # Import debug_logger try: @@ -240,11 +241,13 @@ async def chat_completions(request: Request, request_data: ChatCompletionRequest profile_arn_for_payload = auth_manager.profile_arn try: - kiro_payload = 
build_kiro_payload( + kiro_result = build_kiro_payload( request_data, conversation_id, profile_arn_for_payload ) + kiro_payload = kiro_result.payload + tool_name_mapping = kiro_result.tool_name_mapping except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -324,10 +327,12 @@ async def chat_completions(request: Request, request_data: ChatCompletionRequest } ) - # Prepare data for fallback token counting - # Convert Pydantic models to dicts for tokenizer - messages_for_tokenizer = [msg.model_dump() for msg in request_data.messages] - tools_for_tokenizer = [tool.model_dump() for tool in request_data.tools] if request_data.tools else None + # Count prompt tokens from the full Kiro payload (system prompt + messages + tools) + # This matches what actually gets sent to the API, giving accurate token counts + kiro_payload_prompt_tokens = count_tokens( + kiro_request_body.decode('utf-8', errors='ignore'), + apply_claude_correction=False + ) if request_data.stream: # Streaming mode @@ -341,8 +346,8 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer, - request_tools=tools_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens, + tool_name_mapping=tool_name_mapping ): yield chunk except GeneratorExit: @@ -387,8 +392,8 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer, - request_tools=tools_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens, + tool_name_mapping=tool_name_mapping ) await http_client.close() diff --git a/kiro/streaming_anthropic.py b/kiro/streaming_anthropic.py index 979113ef..2177da2d 100644 --- a/kiro/streaming_anthropic.py +++ b/kiro/streaming_anthropic.py @@ -44,10 +44,9 @@ collect_stream_to_result, FirstTokenTimeoutError, KiroEvent, - calculate_tokens_from_context_usage, stream_with_first_token_retry, ) -from kiro.tokenizer import count_tokens, count_message_tokens, count_tools_tokens +from 
kiro.tokenizer import count_tokens from kiro.parsers import parse_bracket_tool_calls, deduplicate_tool_calls from kiro.config import FIRST_TOKEN_TIMEOUT, FIRST_TOKEN_MAX_RETRIES, FAKE_REASONING_HANDLING @@ -104,40 +103,37 @@ async def stream_kiro_to_anthropic( model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - conversation_id: Optional[str] = None + prompt_tokens: int = 0, + conversation_id: Optional[str] = None, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> AsyncGenerator[str, None]: """ Generator for converting Kiro stream to Anthropic SSE format. - + Parses Kiro AWS SSE stream and converts events to Anthropic format. Supports thinking content blocks when FAKE_REASONING_HANDLING=as_reasoning_content. - + Args: response: HTTP response with data stream model: Model name to include in response model_cache: Model cache for getting token limits auth_manager: Authentication manager first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for token counting) + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) conversation_id: Stable conversation ID for truncation recovery (optional) - + Yields: Strings in Anthropic SSE format - + Raises: FirstTokenTimeoutError: If first token not received within timeout """ message_id = generate_message_id() - input_tokens = 0 + input_tokens = prompt_tokens output_tokens = 0 full_content = "" full_thinking_content = "" - # Count input tokens from request messages - if request_messages: - input_tokens = count_message_tokens(request_messages, apply_claude_correction=False) - # Track content blocks - thinking block is index 0, text block is index 1 (when thinking enabled) current_block_index = 0 thinking_block_started = False @@ -302,6 +298,10 @@ async def stream_kiro_to_anthropic( tool_name = tool.get("function", {}).get("name", "") or tool.get("name", "") 
tool_input = tool.get("function", {}).get("arguments", {}) or tool.get("input", {}) + # Reverse truncated tool name back to original + if tool_name_mapping and tool_name in tool_name_mapping: + tool_name = tool_name_mapping[tool_name] + # Check if this tool was truncated if tool.get('_truncation_detected'): truncated_tools.append({ @@ -385,6 +385,10 @@ async def stream_kiro_to_anthropic( tool_name = tc.get("function", {}).get("name", "") tool_input = tc.get("function", {}).get("arguments", {}) + # Reverse truncated tool name back to original + if tool_name_mapping and tool_name in tool_name_mapping: + tool_name = tool_name_mapping[tool_name] + if isinstance(tool_input, str): try: tool_input = json.loads(tool_input) @@ -456,13 +460,9 @@ async def stream_kiro_to_anthropic( # Calculate output tokens output_tokens = count_tokens(full_content + full_thinking_content) - - # Calculate total tokens from context usage if available - if context_usage_percentage is not None: - prompt_tokens, total_tokens, _, _ = calculate_tokens_from_context_usage( - context_usage_percentage, output_tokens, model_cache, model - ) - input_tokens = prompt_tokens + + # input_tokens already set from pre-counted prompt_tokens (full Kiro payload). + # Don't override with contextUsagePercentage — it's unreliable. # Determine stop reason stop_reason = "tool_use" if tool_blocks else "end_turn" @@ -545,29 +545,28 @@ async def collect_anthropic_response( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None + prompt_tokens: int = 0, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> dict: """ Collect full response from Kiro stream in Anthropic format. - + Used for non-streaming mode. 
- + Args: response: HTTP response with stream model: Model name model_cache: Model cache auth_manager: Authentication manager - request_messages: Original request messages (for token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Returns: Dictionary with full response in Anthropic Messages format """ message_id = generate_message_id() - - # Count input tokens - input_tokens = 0 - if request_messages: - input_tokens = count_message_tokens(request_messages, apply_claude_correction=False) + + # Use pre-counted prompt tokens + input_tokens = prompt_tokens # Collect stream result result = await collect_stream_to_result(response) @@ -601,6 +600,10 @@ async def collect_anthropic_response( tool_name = tc.get("function", {}).get("name", "") or tc.get("name", "") tool_input = tc.get("function", {}).get("arguments", {}) or tc.get("input", {}) + # Reverse truncated tool name back to original + if tool_name_mapping and tool_name in tool_name_mapping: + tool_name = tool_name_mapping[tool_name] + if isinstance(tool_input, str): try: tool_input = json.loads(tool_input) @@ -617,12 +620,8 @@ async def collect_anthropic_response( # Calculate output tokens output_tokens = count_tokens(result.content + result.thinking_content) - # Calculate from context usage if available - if result.context_usage_percentage is not None: - prompt_tokens, _, _, _ = calculate_tokens_from_context_usage( - result.context_usage_percentage, output_tokens, model_cache, model - ) - input_tokens = prompt_tokens + # input_tokens already set from pre-counted prompt_tokens (full Kiro payload). + # Don't override with contextUsagePercentage — it's unreliable. 
# Determine stop reason stop_reason = "tool_use" if result.tool_calls else "end_turn" @@ -655,18 +654,18 @@ async def stream_with_first_token_retry_anthropic( auth_manager: "KiroAuthManager", max_retries: int = FIRST_TOKEN_MAX_RETRIES, first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> AsyncGenerator[str, None]: """ Streaming with automatic retry on first token timeout for Anthropic API. - + If model doesn't respond within first_token_timeout seconds, request is cancelled and a new one is made. Maximum max_retries attempts. - + This is seamless for user - they just see a delay, but eventually get a response (or error after all attempts). - + Args: make_request: Function to create new HTTP request model: Model name @@ -674,12 +672,12 @@ auth_manager: Authentication manager max_retries: Maximum number of attempts first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + tool_name_mapping: Mapping of truncated tool names back to originals + Yields: Strings in Anthropic SSE format - + Raises: Exception with Anthropic error format after exhausting all attempts """ @@ -711,7 +708,8 @@ async def stream_processor(response: httpx.Response) -> AsyncGenerator[str, None] model_cache, auth_manager, first_token_timeout=first_token_timeout, - request_messages=request_messages + prompt_tokens=prompt_tokens, + tool_name_mapping=tool_name_mapping ): yield chunk diff --git a/kiro/streaming_openai.py b/kiro/streaming_openai.py index 5153bc82..2ae02704 100644 --- a/kiro/streaming_openai.py +++ b/kiro/streaming_openai.py @@ -30,7 +30,7 @@ import json import time -from typing import TYPE_CHECKING, AsyncGenerator, Callable, Awaitable, Optional +from typing import TYPE_CHECKING, AsyncGenerator, Callable, Awaitable, Dict,
Optional import httpx from fastapi import HTTPException @@ -43,14 +43,13 @@ FIRST_TOKEN_MAX_RETRIES, FAKE_REASONING_HANDLING, ) -from kiro.tokenizer import count_tokens, count_message_tokens, count_tools_tokens +from kiro.tokenizer import count_tokens # Import from streaming_core - reuse shared parsing logic from kiro.streaming_core import ( parse_kiro_stream, FirstTokenTimeoutError, KiroEvent, - calculate_tokens_from_context_usage, stream_with_first_token_retry as stream_with_first_token_retry_core, ) @@ -76,9 +75,9 @@ async def stream_kiro_to_openai_internal( model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None, - conversation_id: Optional[str] = None + prompt_tokens: int = 0, + conversation_id: Optional[str] = None, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> AsyncGenerator[str, None]: """ Internal generator for converting Kiro stream to OpenAI format. 
@@ -96,9 +95,7 @@ async def stream_kiro_to_openai_internal( model_cache: Model cache for getting token limits auth_manager: Authentication manager first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - conversation_id: Stable conversation ID for truncation recovery (optional) + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) conversation_id: Stable conversation ID for truncation recovery (optional) Yields: @@ -224,24 +221,9 @@ async def stream_kiro_to_openai_internal( # Count completion_tokens (output) using tiktoken completion_tokens = count_tokens(full_content + full_thinking_content) - - # Calculate total_tokens based on context_usage_percentage from Kiro API - # context_usage shows TOTAL percentage of context usage (input + output) - prompt_tokens, total_tokens, prompt_source, total_source = calculate_tokens_from_context_usage( - context_usage_percentage, completion_tokens, model_cache, model - ) - - # Fallback: Kiro API didn't return context_usage, use tiktoken - # Count prompt_tokens from original messages - # IMPORTANT: Don't apply correction coefficient for prompt_tokens, - # as it was calibrated for completion_tokens - if prompt_source == "unknown" and request_messages: - prompt_tokens = count_message_tokens(request_messages, apply_claude_correction=False) - if request_tools: - prompt_tokens += count_tools_tokens(request_tools, apply_claude_correction=False) - total_tokens = prompt_tokens + completion_tokens - prompt_source = "tiktoken" - total_source = "tiktoken" + + # Use pre-counted prompt_tokens from the full Kiro payload + total_tokens = prompt_tokens + completion_tokens # Send tool calls if present if all_tool_calls: @@ -257,6 +239,10 @@ async def stream_kiro_to_openai_internal( tool_name = func.get("name") or "" tool_args = func.get("arguments") or "{}" + # Reverse truncated tool name 
back to original + if tool_name_mapping and tool_name in tool_name_mapping: + tool_name = tool_name_mapping[tool_name] + logger.debug(f"Tool call [{idx}] '{tool_name}': id={tc.get('id')}, args_length={len(tool_args)}") indexed_tc = { @@ -329,9 +315,9 @@ async def stream_kiro_to_openai_internal( # Log final token values being sent to client logger.debug( f"[Usage] {model}: " - f"prompt_tokens={prompt_tokens} ({prompt_source}), " + f"prompt_tokens={prompt_tokens} (payload tiktoken), " f"completion_tokens={completion_tokens} (tiktoken), " - f"total_tokens={total_tokens} ({total_source})" + f"total_tokens={total_tokens} (sum)" ) yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" @@ -374,31 +360,31 @@ async def stream_kiro_to_openai( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> AsyncGenerator[str, None]: """ Generator for converting Kiro stream to OpenAI format. - + This is a wrapper over stream_kiro_to_openai_internal that does NOT retry. Retry logic is implemented in stream_with_first_token_retry. 
- + Args: client: HTTP client (for connection management) response: HTTP response with data stream model: Model name to include in response model_cache: Model cache for getting token limits auth_manager: Authentication manager - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + tool_name_mapping: Mapping of truncated tool names back to originals + Yields: Strings in SSE format: "data: {...}\\n\\n" or "data: [DONE]\\n\\n" """ async for chunk in stream_kiro_to_openai_internal( client, response, model, model_cache, auth_manager, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens, + tool_name_mapping=tool_name_mapping ): yield chunk @@ -411,8 +397,8 @@ async def stream_with_first_token_retry( auth_manager: "KiroAuthManager", max_retries: int = FIRST_TOKEN_MAX_RETRIES, first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> AsyncGenerator[str, None]: """ Streaming with automatic retry on first token timeout. @@ -433,15 +418,15 @@ auth_manager: Authentication manager max_retries: Maximum number of attempts first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + tool_name_mapping: Mapping of truncated tool names back to originals + Yields: Strings in SSE format - + Raises: HTTPException: After exhausting all attempts - + Example: >>> async def make_req(): ...
return await http_client.request_with_retry("POST", url, payload, stream=True) @@ -454,14 +438,14 @@ def create_http_error(status_code: int, error_text: str) -> HTTPException: status_code=status_code, detail=f"Upstream API error: {error_text}" ) - + def create_timeout_error(retries: int, timeout: float) -> HTTPException: """Create HTTPException for timeout errors.""" return HTTPException( status_code=504, detail=f"Model did not respond within {timeout}s after {retries} attempts. Please try again." ) - + async def stream_processor(response: httpx.Response) -> AsyncGenerator[str, None]: """Process response and yield OpenAI SSE chunks.""" async for chunk in stream_kiro_to_openai_internal( @@ -471,8 +455,8 @@ model_cache, auth_manager, first_token_timeout=first_token_timeout, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens, + tool_name_mapping=tool_name_mapping ): yield chunk @@ -493,24 +476,23 @@ async def collect_stream_response( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> dict: """ Collect full response from streaming stream. - + Used for non-streaming mode - collects all chunks and forms a single response.
- + Args: client: HTTP client response: HTTP response with stream model: Model name model_cache: Model cache auth_manager: Authentication manager - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Returns: Dictionary with full response in OpenAI chat.completion format """ @@ -519,15 +501,15 @@ async def collect_stream_response( final_usage = None tool_calls = [] completion_id = generate_completion_id() - + async for chunk_str in stream_kiro_to_openai( client, response, model, model_cache, auth_manager, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens, + tool_name_mapping=tool_name_mapping ): if not chunk_str.startswith("data:"): continue diff --git a/tests/unit/test_streaming_anthropic.py b/tests/unit/test_streaming_anthropic.py index 6932fad1..16f77ba7 100644 --- a/tests/unit/test_streaming_anthropic.py +++ b/tests/unit/test_streaming_anthropic.py @@ -682,12 +682,11 @@ async def test_includes_usage_info(self, mock_response, mock_model_cache, mock_a print("Action: Collecting Anthropic response...") with patch('kiro.streaming_anthropic.collect_stream_to_result', return_value=mock_result): - with patch('kiro.streaming_anthropic.count_message_tokens', return_value=10): - with patch('kiro.streaming_anthropic.count_tokens', return_value=5): - result = await collect_anthropic_response( - mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, - request_messages=[{"role": "user", "content": "Hi"}] - ) + with patch('kiro.streaming_anthropic.count_tokens', return_value=5): + result = await collect_anthropic_response( + mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, + prompt_tokens=10 + ) print(f"Usage: {result['usage']}") assert "input_tokens" in result["usage"] @@ -1082,36 +1081,32 @@ async def mock_parse_kiro_stream(*args, 
**kwargs): print("✓ Tokens calculated from context usage") @pytest.mark.asyncio - async def test_uses_request_messages_for_input_tokens(self, mock_response, mock_model_cache, mock_auth_manager): + async def test_uses_prompt_tokens_for_input_tokens(self, mock_response, mock_model_cache, mock_auth_manager): """ - What it does: Uses request messages for input token count. - Goal: Verify input tokens are counted from request. + What it does: Uses pre-counted prompt_tokens for input token count. + Goal: Verify input tokens are passed through from prompt_tokens param. """ print("Setup: Mock stream...") - + async def mock_parse_kiro_stream(*args, **kwargs): yield KiroEvent(type="content", content="Hello") - - request_messages = [ - {"role": "user", "content": "Hi there!"} - ] - - print("Action: Streaming to Anthropic format with request messages...") + + print("Action: Streaming to Anthropic format with prompt_tokens...") events = [] - + with patch('kiro.streaming_anthropic.parse_kiro_stream', mock_parse_kiro_stream): with patch('kiro.streaming_anthropic.parse_bracket_tool_calls', return_value=[]): - with patch('kiro.streaming_anthropic.count_message_tokens', return_value=10) as mock_count: - async for event in stream_kiro_to_anthropic( - mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, - request_messages=request_messages - ): - events.append(event) - - # Verify count_message_tokens was called - mock_count.assert_called_once_with(request_messages, apply_claude_correction=False) - - print("✓ Request messages used for input token count") + async for event in stream_kiro_to_anthropic( + mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, + prompt_tokens=42 + ): + events.append(event) + + # Check message_start has the prompt_tokens as input_tokens + message_start = json.loads(events[0].split("data: ")[1]) + assert message_start["message"]["usage"]["input_tokens"] == 42 + + print("✓ prompt_tokens used for input token count") # 
================================================================================================== @@ -1362,44 +1357,42 @@ async def mock_make_request(): print("✓ Anthropic-formatted error raised on HTTP error") @pytest.mark.asyncio - async def test_passes_request_messages_to_stream(self, mock_model_cache, mock_auth_manager): + async def test_passes_prompt_tokens_to_stream(self, mock_model_cache, mock_auth_manager): """ - What it does: Passes request_messages to underlying stream function. + What it does: Passes prompt_tokens to underlying stream function. Goal: Verify token counting parameters are forwarded. """ - print("Setup: Mock request with messages...") - + print("Setup: Mock request with prompt_tokens...") + mock_response = AsyncMock() mock_response.status_code = 200 mock_response.aclose = AsyncMock() - + async def mock_make_request(): return mock_response - + captured_kwargs = {} - + async def mock_stream_kiro_to_anthropic(*args, **kwargs): captured_kwargs.update(kwargs) yield "event: message_start\ndata: {}\n\n" yield "event: message_stop\ndata: {}\n\n" - - request_messages = [{"role": "user", "content": "Hello"}] - - print("Action: Streaming with request_messages...") - + + print("Action: Streaming with prompt_tokens...") + with patch('kiro.streaming_anthropic.stream_kiro_to_anthropic', mock_stream_kiro_to_anthropic): async for chunk in stream_with_first_token_retry_anthropic( make_request=mock_make_request, model="claude-sonnet-4", model_cache=mock_model_cache, auth_manager=mock_auth_manager, - request_messages=request_messages + prompt_tokens=42 ): pass - + print(f"Captured kwargs: {captured_kwargs}") - assert captured_kwargs.get("request_messages") == request_messages - print("✓ request_messages passed to stream function") + assert captured_kwargs.get("prompt_tokens") == 42 + print("✓ prompt_tokens passed to stream function") @pytest.mark.asyncio async def test_uses_configured_max_retries(self, mock_model_cache, mock_auth_manager):