diff --git a/kiro/routes_anthropic.py b/kiro/routes_anthropic.py index 1bc6bd10..d088b1b9 100644 --- a/kiro/routes_anthropic.py +++ b/kiro/routes_anthropic.py @@ -50,7 +50,7 @@ ) from kiro.http_client import KiroHttpClient from kiro.utils import generate_conversation_id -from kiro.tokenizer import count_tools_tokens +from kiro.tokenizer import count_tokens # Import debug_logger try: @@ -298,10 +298,11 @@ async def messages( shared_client = request.app.state.http_client http_client = KiroHttpClient(auth_manager, shared_client=shared_client) - # Prepare data for token counting - # Convert Pydantic models to dicts for tokenizer - messages_for_tokenizer = [msg.model_dump() for msg in request_data.messages] - tools_for_tokenizer = [tool.model_dump() for tool in request_data.tools] if request_data.tools else None + # Count prompt tokens from the full Kiro payload (system prompt + messages + tools) + kiro_payload_prompt_tokens = count_tokens( + kiro_request_body.decode('utf-8', errors='ignore'), + apply_claude_correction=False + ) try: # Make request to Kiro API (for both streaming and non-streaming modes) @@ -368,7 +369,7 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens ): yield chunk except GeneratorExit: @@ -415,7 +416,7 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens ) await http_client.close() diff --git a/kiro/routes_openai.py b/kiro/routes_openai.py index 301ae9b5..de1cc5e6 100644 --- a/kiro/routes_openai.py +++ b/kiro/routes_openai.py @@ -50,6 +50,7 @@ from kiro.streaming_openai import stream_kiro_to_openai, collect_stream_response, stream_with_first_token_retry from kiro.http_client import KiroHttpClient from kiro.utils import generate_conversation_id +from kiro.tokenizer import count_tokens # Import debug_logger try: @@ -324,10 +325,12 @@ async def chat_completions(request: Request, request_data: ChatCompletionRequest } ) - # Prepare data for fallback token counting - # Convert Pydantic models to dicts for tokenizer - messages_for_tokenizer = [msg.model_dump() for msg in request_data.messages] - tools_for_tokenizer = [tool.model_dump() for tool in request_data.tools] if request_data.tools else None + # Count prompt tokens from the full Kiro payload (system prompt + messages + tools) + # This matches what actually gets sent to the API, giving accurate token counts + kiro_payload_prompt_tokens = count_tokens( + kiro_request_body.decode('utf-8', errors='ignore'), + apply_claude_correction=False + ) if request_data.stream: # Streaming mode @@ -341,8 +344,7 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer, - request_tools=tools_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens ): yield chunk except GeneratorExit: @@ -387,8 +389,7 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer, - request_tools=tools_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens ) await http_client.close() diff --git a/kiro/streaming_anthropic.py b/kiro/streaming_anthropic.py index 979113ef..28dfe7a1 100644 --- a/kiro/streaming_anthropic.py +++ b/kiro/streaming_anthropic.py @@ -44,10 +44,9 @@ collect_stream_to_result, FirstTokenTimeoutError, KiroEvent, - calculate_tokens_from_context_usage, stream_with_first_token_retry, ) -from kiro.tokenizer import count_tokens, count_message_tokens, count_tools_tokens +from kiro.tokenizer import count_tokens from kiro.parsers import parse_bracket_tool_calls, deduplicate_tool_calls from kiro.config import FIRST_TOKEN_TIMEOUT, FIRST_TOKEN_MAX_RETRIES, FAKE_REASONING_HANDLING @@ -104,40 +103,36 @@ async def stream_kiro_to_anthropic( model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, + prompt_tokens: int = 0, conversation_id: Optional[str] = None ) -> AsyncGenerator[str, None]: """ Generator for converting Kiro stream to Anthropic SSE format. - + Parses Kiro AWS SSE stream and converts events to Anthropic format. Supports thinking content blocks when FAKE_REASONING_HANDLING=as_reasoning_content. - + Args: response: HTTP response with data stream model: Model name to include in response model_cache: Model cache for getting token limits auth_manager: Authentication manager first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for token counting) + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) conversation_id: Stable conversation ID for truncation recovery (optional) - + Yields: Strings in Anthropic SSE format - + Raises: FirstTokenTimeoutError: If first token not received within timeout """ message_id = generate_message_id() - input_tokens = 0 + input_tokens = prompt_tokens output_tokens = 0 full_content = "" full_thinking_content = "" - # Count input tokens from request messages - if request_messages: - input_tokens = count_message_tokens(request_messages, apply_claude_correction=False) - # Track content blocks - thinking block is index 0, text block is index 1 (when thinking enabled) current_block_index = 0 thinking_block_started = False @@ -456,13 +451,9 @@ async def stream_kiro_to_anthropic( # Calculate output tokens output_tokens = count_tokens(full_content + full_thinking_content) - - # Calculate total tokens from context usage if available - if context_usage_percentage is not None: - prompt_tokens, total_tokens, _, _ = calculate_tokens_from_context_usage( - context_usage_percentage, output_tokens, model_cache, model - ) - input_tokens = prompt_tokens + + # input_tokens already set from pre-counted prompt_tokens (full Kiro payload). + # Don't override with contextUsagePercentage — it's unreliable. # Determine stop reason stop_reason = "tool_use" if tool_blocks else "end_turn" @@ -545,29 +536,27 @@ async def collect_anthropic_response( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None + prompt_tokens: int = 0 ) -> dict: """ Collect full response from Kiro stream in Anthropic format. - + Used for non-streaming mode. - + Args: response: HTTP response with stream model: Model name model_cache: Model cache auth_manager: Authentication manager - request_messages: Original request messages (for token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Returns: Dictionary with full response in Anthropic Messages format """ message_id = generate_message_id() - - # Count input tokens - input_tokens = 0 - if request_messages: - input_tokens = count_message_tokens(request_messages, apply_claude_correction=False) + + # Use pre-counted prompt tokens + input_tokens = prompt_tokens # Collect stream result result = await collect_stream_to_result(response) @@ -617,12 +606,8 @@ async def collect_anthropic_response( # Calculate output tokens output_tokens = count_tokens(result.content + result.thinking_content) - # Calculate from context usage if available - if result.context_usage_percentage is not None: - prompt_tokens, _, _, _ = calculate_tokens_from_context_usage( - result.context_usage_percentage, output_tokens, model_cache, model - ) - input_tokens = prompt_tokens + # input_tokens already set from pre-counted prompt_tokens (full Kiro payload). + # Don't override with contextUsagePercentage — it's unreliable. # Determine stop reason stop_reason = "tool_use" if result.tool_calls else "end_turn" @@ -655,18 +640,17 @@ async def stream_with_first_token_retry_anthropic( auth_manager: "KiroAuthManager", max_retries: int = FIRST_TOKEN_MAX_RETRIES, first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0 ) -> AsyncGenerator[str, None]: """ Streaming with automatic retry on first token timeout for Anthropic API. - + If model doesn't respond within first_token_timeout seconds, request is cancelled and a new one is made. Maximum max_retries attempts. - + This is seamless for user - they just see a delay, but eventually get a response (or error after all attempts). - + Args: make_request: Function to create new HTTP request model: Model name @@ -674,12 +658,11 @@ async def stream_with_first_token_retry_anthropic( auth_manager: Authentication manager max_retries: Maximum number of attempts first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Yields: Strings in Anthropic SSE format - + Raises: Exception with Anthropic error format after exhausting all attempts """ @@ -711,7 +694,7 @@ async def stream_processor(response: httpx.Response) -> AsyncGenerator[str, None model_cache, auth_manager, first_token_timeout=first_token_timeout, - request_messages=request_messages + prompt_tokens=prompt_tokens ): yield chunk diff --git a/kiro/streaming_openai.py b/kiro/streaming_openai.py index 5153bc82..e65df6a3 100644 --- a/kiro/streaming_openai.py +++ b/kiro/streaming_openai.py @@ -43,14 +43,13 @@ FIRST_TOKEN_MAX_RETRIES, FAKE_REASONING_HANDLING, ) -from kiro.tokenizer import count_tokens, count_message_tokens, count_tools_tokens +from kiro.tokenizer import count_tokens # Import from streaming_core - reuse shared parsing logic from kiro.streaming_core import ( parse_kiro_stream, FirstTokenTimeoutError, KiroEvent, - calculate_tokens_from_context_usage, stream_with_first_token_retry as stream_with_first_token_retry_core, ) @@ -76,8 +75,7 @@ async def stream_kiro_to_openai_internal( model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None, + prompt_tokens: int = 0, conversation_id: Optional[str] = None ) -> AsyncGenerator[str, None]: """ @@ -96,9 +94,7 @@ async def stream_kiro_to_openai_internal( model_cache: Model cache for getting token limits auth_manager: Authentication manager first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - conversation_id: Stable conversation ID for truncation recovery (optional) + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) conversation_id: Stable conversation ID for truncation recovery (optional) Yields: @@ -224,24 +220,9 @@ async def stream_kiro_to_openai_internal( # Count completion_tokens (output) using tiktoken completion_tokens = count_tokens(full_content + full_thinking_content) - - # Calculate total_tokens based on context_usage_percentage from Kiro API - # context_usage shows TOTAL percentage of context usage (input + output) - prompt_tokens, total_tokens, prompt_source, total_source = calculate_tokens_from_context_usage( - context_usage_percentage, completion_tokens, model_cache, model - ) - - # Fallback: Kiro API didn't return context_usage, use tiktoken - # Count prompt_tokens from original messages - # IMPORTANT: Don't apply correction coefficient for prompt_tokens, - # as it was calibrated for completion_tokens - if prompt_source == "unknown" and request_messages: - prompt_tokens = count_message_tokens(request_messages, apply_claude_correction=False) - if request_tools: - prompt_tokens += count_tools_tokens(request_tools, apply_claude_correction=False) - total_tokens = prompt_tokens + completion_tokens - prompt_source = "tiktoken" - total_source = "tiktoken" + + # Use pre-counted prompt_tokens from the full Kiro payload + total_tokens = prompt_tokens + completion_tokens # Send tool calls if present if all_tool_calls: @@ -329,9 +310,9 @@ async def stream_kiro_to_openai_internal( # Log final token values being sent to client logger.debug( f"[Usage] {model}: " - f"prompt_tokens={prompt_tokens} ({prompt_source}), " + f"prompt_tokens={prompt_tokens} (payload tiktoken), " f"completion_tokens={completion_tokens} (tiktoken), " - f"total_tokens={total_tokens} ({total_source})" + f"total_tokens={total_tokens} (sum)" ) yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" @@ -374,31 +355,28 @@ async def stream_kiro_to_openai( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0 ) -> AsyncGenerator[str, None]: """ Generator for converting Kiro stream to OpenAI format. - + This is a wrapper over stream_kiro_to_openai_internal that does NOT retry. Retry logic is implemented in stream_with_first_token_retry. - + Args: client: HTTP client (for connection management) response: HTTP response with data stream model: Model name to include in response model_cache: Model cache for getting token limits auth_manager: Authentication manager - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Yields: Strings in SSE format: "data: {...}\\n\\n" or "data: [DONE]\\n\\n" """ async for chunk in stream_kiro_to_openai_internal( client, response, model, model_cache, auth_manager, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens ): yield chunk @@ -411,8 +389,7 @@ async def stream_with_first_token_retry( auth_manager: "KiroAuthManager", max_retries: int = FIRST_TOKEN_MAX_RETRIES, first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0 ) -> AsyncGenerator[str, None]: """ Streaming with automatic retry on first token timeout. @@ -433,15 +410,14 @@ async def stream_with_first_token_retry( auth_manager: Authentication manager max_retries: Maximum number of attempts first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Yields: Strings in SSE format - + Raises: HTTPException: After exhausting all attempts - + Example: >>> async def make_req(): ... return await http_client.request_with_retry("POST", url, payload, stream=True) @@ -454,14 +430,14 @@ def create_http_error(status_code: int, error_text: str) -> HTTPException: status_code=status_code, detail=f"Upstream API error: {error_text}" ) - + def create_timeout_error(retries: int, timeout: float) -> HTTPException: """Create HTTPException for timeout errors.""" return HTTPException( status_code=504, detail=f"Model did not respond within {timeout}s after {retries} attempts. Please try again." ) - + async def stream_processor(response: httpx.Response) -> AsyncGenerator[str, None]: """Process response and yield OpenAI SSE chunks.""" async for chunk in stream_kiro_to_openai_internal( @@ -471,8 +447,7 @@ async def stream_processor(response: httpx.Response) -> AsyncGenerator[str, None model_cache, auth_manager, first_token_timeout=first_token_timeout, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens ): yield chunk @@ -493,24 +468,22 @@ async def collect_stream_response( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0 ) -> dict: """ Collect full response from streaming stream. - + Used for non-streaming mode - collects all chunks and forms a single response. - + Args: client: HTTP client response: HTTP response with stream model: Model name model_cache: Model cache auth_manager: Authentication manager - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Returns: Dictionary with full response in OpenAI chat.completion format """ @@ -519,15 +492,14 @@ async def collect_stream_response( final_usage = None tool_calls = [] completion_id = generate_completion_id() - + async for chunk_str in stream_kiro_to_openai( client, response, model, model_cache, auth_manager, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens ): if not chunk_str.startswith("data:"): continue diff --git a/tests/unit/test_streaming_anthropic.py b/tests/unit/test_streaming_anthropic.py index 6932fad1..16f77ba7 100644 --- a/tests/unit/test_streaming_anthropic.py +++ b/tests/unit/test_streaming_anthropic.py @@ -682,12 +682,11 @@ async def test_includes_usage_info(self, mock_response, mock_model_cache, mock_a print("Action: Collecting Anthropic response...") with patch('kiro.streaming_anthropic.collect_stream_to_result', return_value=mock_result): - with patch('kiro.streaming_anthropic.count_message_tokens', return_value=10): - with patch('kiro.streaming_anthropic.count_tokens', return_value=5): - result = await collect_anthropic_response( - mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, - request_messages=[{"role": "user", "content": "Hi"}] - ) + with patch('kiro.streaming_anthropic.count_tokens', return_value=5): + result = await collect_anthropic_response( + mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, + prompt_tokens=10 + ) print(f"Usage: {result['usage']}") assert "input_tokens" in result["usage"] @@ -1082,36 +1081,32 @@ async def mock_parse_kiro_stream(*args, **kwargs): print("✓ Tokens calculated from context usage") @pytest.mark.asyncio - async def test_uses_request_messages_for_input_tokens(self, mock_response, mock_model_cache, mock_auth_manager): + async def test_uses_prompt_tokens_for_input_tokens(self, mock_response, mock_model_cache, mock_auth_manager): """ - What it does: Uses request messages for input token count. - Goal: Verify input tokens are counted from request. + What it does: Uses pre-counted prompt_tokens for input token count. + Goal: Verify input tokens are passed through from prompt_tokens param. """ print("Setup: Mock stream...") - + async def mock_parse_kiro_stream(*args, **kwargs): yield KiroEvent(type="content", content="Hello") - - request_messages = [ - {"role": "user", "content": "Hi there!"} - ] - - print("Action: Streaming to Anthropic format with request messages...") + + print("Action: Streaming to Anthropic format with prompt_tokens...") events = [] - + with patch('kiro.streaming_anthropic.parse_kiro_stream', mock_parse_kiro_stream): with patch('kiro.streaming_anthropic.parse_bracket_tool_calls', return_value=[]): - with patch('kiro.streaming_anthropic.count_message_tokens', return_value=10) as mock_count: - async for event in stream_kiro_to_anthropic( - mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, - request_messages=request_messages - ): - events.append(event) - - # Verify count_message_tokens was called - mock_count.assert_called_once_with(request_messages, apply_claude_correction=False) - - print("✓ Request messages used for input token count") + async for event in stream_kiro_to_anthropic( + mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, + prompt_tokens=42 + ): + events.append(event) + + # Check message_start has the prompt_tokens as input_tokens + message_start = json.loads(events[0].split("data: ")[1]) + assert message_start["message"]["usage"]["input_tokens"] == 42 + + print("✓ prompt_tokens used for input token count") # ================================================================================================== @@ -1362,44 +1357,42 @@ async def mock_make_request(): print("✓ Anthropic-formatted error raised on HTTP error") @pytest.mark.asyncio - async def test_passes_request_messages_to_stream(self, mock_model_cache, mock_auth_manager): + async def test_passes_prompt_tokens_to_stream(self, mock_model_cache, mock_auth_manager): """ - What it does: Passes request_messages to underlying stream function. + What it does: Passes prompt_tokens to underlying stream function. Goal: Verify token counting parameters are forwarded. """ - print("Setup: Mock request with messages...") - + print("Setup: Mock request with prompt_tokens...") + mock_response = AsyncMock() mock_response.status_code = 200 mock_response.aclose = AsyncMock() - + async def mock_make_request(): return mock_response - + captured_kwargs = {} - + async def mock_stream_kiro_to_anthropic(*args, **kwargs): captured_kwargs.update(kwargs) yield "event: message_start\ndata: {}\n\n" yield "event: message_stop\ndata: {}\n\n" - - request_messages = [{"role": "user", "content": "Hello"}] - - print("Action: Streaming with request_messages...") - + + print("Action: Streaming with prompt_tokens...") + with patch('kiro.streaming_anthropic.stream_kiro_to_anthropic', mock_stream_kiro_to_anthropic): async for chunk in stream_with_first_token_retry_anthropic( make_request=mock_make_request, model="claude-sonnet-4", model_cache=mock_model_cache, auth_manager=mock_auth_manager, - request_messages=request_messages + prompt_tokens=42 ): pass - + print(f"Captured kwargs: {captured_kwargs}") - assert captured_kwargs.get("request_messages") == request_messages - print("✓ request_messages passed to stream function") + assert captured_kwargs.get("prompt_tokens") == 42 + print("✓ prompt_tokens passed to stream function") @pytest.mark.asyncio async def test_uses_configured_max_retries(self, mock_model_cache, mock_auth_manager):