diff --git a/kiro/converters_anthropic.py b/kiro/converters_anthropic.py index 4c2d38b1..9231e0ad 100644 --- a/kiro/converters_anthropic.py +++ b/kiro/converters_anthropic.py @@ -341,6 +341,10 @@ def convert_anthropic_tools( """ Converts Anthropic tools to unified format. + Silently skips Anthropic built-in server tools (web_search, code_execution, + bash, text_editor, etc.) that have no input_schema, since the Kiro API + cannot handle them. + Args: tools: List of Anthropic tools @@ -356,11 +360,18 @@ def convert_anthropic_tools( if isinstance(tool, dict): name = tool.get("name", "") description = tool.get("description") - input_schema = tool.get("input_schema", {}) + input_schema = tool.get("input_schema") + tool_type = tool.get("type") else: - name = tool.name - description = tool.description - input_schema = tool.input_schema + name = getattr(tool, "name", "") or "" + description = getattr(tool, "description", None) + input_schema = getattr(tool, "input_schema", None) + tool_type = getattr(tool, "type", None) + + # Skip built-in server tools (no input_schema) — Kiro API can't handle them + if input_schema is None: + logger.debug(f"Skipping server tool '{name or tool_type}' (no input_schema)") + continue unified_tools.append( UnifiedTool(name=name, description=description, input_schema=input_schema) @@ -424,4 +435,4 @@ def anthropic_to_kiro( inject_thinking=True, ) - return result.payload + return result diff --git a/kiro/converters_core.py b/kiro/converters_core.py index 21cf758b..4f8c63c7 100644 --- a/kiro/converters_core.py +++ b/kiro/converters_core.py @@ -30,6 +30,7 @@ to convert their formats to Kiro API format. """ +import hashlib import json from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple @@ -93,9 +94,11 @@ class KiroPayloadResult: Attributes: payload: The complete Kiro API payload tool_documentation: Documentation for tools with long descriptions (to add to system prompt) + tool_name_mapping: Mapping of truncated tool names back to originals (short→original) """ payload: Dict[str, Any] tool_documentation: str = "" + tool_name_mapping: Dict[str, str] = field(default_factory=dict) # ================================================================================================== @@ -491,46 +494,34 @@ def process_tools_with_long_descriptions( return processed_tools if processed_tools else None, tool_documentation -def validate_tool_names(tools: Optional[List[UnifiedTool]]) -> None: +TOOL_NAME_MAX_LENGTH = 64 + + +def _make_short_name(name: str) -> str: + """Deterministically shorten a tool name to fit within TOOL_NAME_MAX_LENGTH.""" + suffix = hashlib.md5(name.encode()).hexdigest()[:8] + return name[:TOOL_NAME_MAX_LENGTH - 9] + "_" + suffix + + +def truncate_tool_names( + tools: Optional[List[UnifiedTool]], +) -> Dict[str, str]: """ - Validates tool names against Kiro API 64-character limit. - - Logs WARNING for each problematic tool and raises ValueError - with complete list of violations. - - Args: - tools: List of tools to validate - - Raises: - ValueError: If any tool name exceeds 64 characters - - Example: - >>> validate_tool_names([UnifiedTool(name="short_name", description="test")]) - # No error - >>> validate_tool_names([UnifiedTool(name="a" * 70, description="test")]) - # Raises ValueError with detailed message + Truncate tool names exceeding 64 characters in-place. + + Returns a mapping {short_name: original_name} for names that were changed. """ if not tools: - return - - problematic_tools = [] + return {} + + mapping: Dict[str, str] = {} for tool in tools: - if len(tool.name) > 64: - problematic_tools.append((tool.name, len(tool.name))) - - if problematic_tools: - # Build detailed error message for client (no logging here - routes will log) - tool_list = "\n".join([ - f" - '{name}' ({length} characters)" - for name, length in problematic_tools - ]) - - raise ValueError( - f"Tool name(s) exceed Kiro API limit of 64 characters:\n" - f"{tool_list}\n\n" - f"Solution: Use shorter tool names (max 64 characters).\n" - f"Example: 'get_user_data' instead of 'get_authenticated_user_profile_data_with_extended_information_about_it'" - ) + if len(tool.name) > TOOL_NAME_MAX_LENGTH: + short = _make_short_name(tool.name) + mapping[short] = tool.name + logger.debug(f"Truncated tool name '{tool.name}' -> '{short}'") + tool.name = short + return mapping def convert_tools_to_kiro_format(tools: Optional[List[UnifiedTool]]) -> List[Dict[str, Any]]: @@ -1370,8 +1361,10 @@ def build_kiro_payload( # Process tools with long descriptions processed_tools, tool_documentation = process_tools_with_long_descriptions(tools) - # Validate tool names against Kiro API 64-character limit - validate_tool_names(processed_tools) + # Truncate tool names that exceed 64-character Kiro API limit + tool_name_mapping = truncate_tool_names(processed_tools) + if tool_name_mapping: + logger.info(f"Truncated {len(tool_name_mapping)} tool name(s) exceeding {TOOL_NAME_MAX_LENGTH} chars") # Add tool documentation to system prompt if present full_system_prompt = system_prompt @@ -1429,6 +1422,17 @@ def build_kiro_payload( history = build_kiro_history(history_messages, model_id) + # Apply tool name truncation to tool_use blocks in history + if tool_name_mapping: + reverse = {v: k for k, v in tool_name_mapping.items()} + for entry in history: + arm = entry.get("assistantResponseMessage") + if arm: + for tu in arm.get("toolUses", []): + orig = tu.get("name", "") + if orig in reverse: + tu["name"] = reverse[orig] + # Current message (the last one) current_message = merged_messages[-1] current_content = extract_text_content(current_message.content) @@ -1519,4 +1523,4 @@ def build_kiro_payload( if profile_arn: payload["profileArn"] = profile_arn - return KiroPayloadResult(payload=payload, tool_documentation=tool_documentation) \ No newline at end of file + return KiroPayloadResult(payload=payload, tool_documentation=tool_documentation, tool_name_mapping=tool_name_mapping) \ No newline at end of file diff --git a/kiro/converters_openai.py b/kiro/converters_openai.py index aad3b83c..5a49da1a 100644 --- a/kiro/converters_openai.py +++ b/kiro/converters_openai.py @@ -345,4 +345,4 @@ def build_kiro_payload( inject_thinking=True ) - return result.payload \ No newline at end of file + return result \ No newline at end of file diff --git a/kiro/models_anthropic.py b/kiro/models_anthropic.py index 126537a0..8f7e2025 100644 --- a/kiro/models_anthropic.py +++ b/kiro/models_anthropic.py @@ -185,15 +185,23 @@ class AnthropicTool(BaseModel): """ Tool definition in Anthropic format. + Supports both custom tools (with input_schema) and Anthropic built-in + server tools like web_search, code_execution, bash, text_editor + (which have a type field but no input_schema). + Attributes: name: Tool name (must match pattern ^[a-zA-Z0-9_-]{1,64}$) description: Tool description (optional but recommended) - input_schema: JSON Schema for tool parameters + input_schema: JSON Schema for tool parameters (required for custom tools, absent for server tools) + type: Tool type identifier for built-in tools (e.g. "web_search_20250305") """ - name: str + model_config = {"extra": "allow"} + + name: Optional[str] = None description: Optional[str] = None - input_schema: Dict[str, Any] + input_schema: Optional[Dict[str, Any]] = None + type: Optional[str] = None class ToolChoiceAuto(BaseModel): @@ -263,7 +271,7 @@ class AnthropicMessagesRequest(BaseModel): model: str messages: List[AnthropicMessage] = Field(min_length=1) - max_tokens: int + max_tokens: int = 4096 # Optional parameters - system can be string or list of content blocks system: Optional[SystemPrompt] = None diff --git a/kiro/routes_anthropic.py b/kiro/routes_anthropic.py index 1bc6bd10..8173815a 100644 --- a/kiro/routes_anthropic.py +++ b/kiro/routes_anthropic.py @@ -50,7 +50,7 @@ ) from kiro.http_client import KiroHttpClient from kiro.utils import generate_conversation_id -from kiro.tokenizer import count_tools_tokens +from kiro.tokenizer import count_tokens # Import debug_logger try: @@ -257,11 +257,13 @@ async def messages( profile_arn_for_payload = auth_manager.profile_arn try: - kiro_payload = anthropic_to_kiro( + kiro_result = anthropic_to_kiro( request_data, conversation_id, profile_arn_for_payload ) + kiro_payload = kiro_result.payload + tool_name_mapping = kiro_result.tool_name_mapping except ValueError as e: logger.error(f"Conversion error: {e}") return JSONResponse( @@ -298,10 +300,11 @@ async def messages( shared_client = request.app.state.http_client http_client = KiroHttpClient(auth_manager, shared_client=shared_client) - # Prepare data for token counting - # Convert Pydantic models to dicts for tokenizer - messages_for_tokenizer = [msg.model_dump() for msg in request_data.messages] - tools_for_tokenizer = [tool.model_dump() for tool in request_data.tools] if request_data.tools else None + # Count prompt tokens from the full Kiro payload (system prompt + messages + tools) + kiro_payload_prompt_tokens = count_tokens( + kiro_request_body.decode('utf-8', errors='ignore'), + apply_claude_correction=False + ) try: # Make request to Kiro API (for both streaming and non-streaming modes) @@ -368,7 +371,8 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens, + tool_name_mapping=tool_name_mapping ): yield chunk except GeneratorExit: @@ -415,7 +419,8 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens, + tool_name_mapping=tool_name_mapping ) await http_client.close() @@ -449,4 +454,59 @@ async def stream_wrapper(): "message": f"Internal Server Error: {str(e)}" } } - ) \ No newline at end of file + ) + + +@router.post("/v1/messages/count_tokens", dependencies=[Depends(verify_anthropic_api_key)]) +async def count_tokens_endpoint( + request: Request, + request_data: AnthropicMessagesRequest, +): + """ + Anthropic Count Tokens API endpoint. + + Returns estimated token count for the given request payload. + Used by Claude Code to decide when to trigger conversation compaction. + + Builds the full Kiro payload and counts tokens on the serialized JSON, + consistent with the token counting approach used in the messages endpoint. + """ + logger.info(f"Request to /v1/messages/count_tokens (model={request_data.model}, messages={len(request_data.messages)})") + + auth_manager: KiroAuthManager = request.app.state.auth_manager + + # Build Kiro payload (same as messages endpoint) + conversation_id = generate_conversation_id() + profile_arn_for_payload = "" + if auth_manager.auth_type == AuthType.KIRO_DESKTOP and auth_manager.profile_arn: + profile_arn_for_payload = auth_manager.profile_arn + + try: + kiro_payload = anthropic_to_kiro( + request_data, + conversation_id, + profile_arn_for_payload + ).payload + except ValueError as e: + logger.error(f"Conversion error in count_tokens: {e}") + return JSONResponse( + status_code=400, + content={ + "type": "error", + "error": { + "type": "invalid_request_error", + "message": str(e) + } + } + ) + + # Count tokens from the full serialized Kiro payload (same as messages endpoint) + kiro_request_body = json.dumps(kiro_payload, ensure_ascii=False, indent=2) + input_tokens = count_tokens(kiro_request_body, apply_claude_correction=False) + + logger.info(f"Token count estimate: {input_tokens} (payload size: {len(kiro_request_body)} chars)") + + return JSONResponse(content={ + "input_tokens": input_tokens, + "context_management": {"original_input_tokens": input_tokens} + }) diff --git a/kiro/routes_openai.py b/kiro/routes_openai.py index 301ae9b5..f129811d 100644 --- a/kiro/routes_openai.py +++ b/kiro/routes_openai.py @@ -50,6 +50,7 @@ from kiro.streaming_openai import stream_kiro_to_openai, collect_stream_response, stream_with_first_token_retry from kiro.http_client import KiroHttpClient from kiro.utils import generate_conversation_id +from kiro.tokenizer import count_tokens # Import debug_logger try: @@ -240,11 +241,13 @@ async def chat_completions(request: Request, request_data: ChatCompletionRequest profile_arn_for_payload = auth_manager.profile_arn try: - kiro_payload = build_kiro_payload( + kiro_result = build_kiro_payload( request_data, conversation_id, profile_arn_for_payload ) + kiro_payload = kiro_result.payload + tool_name_mapping = kiro_result.tool_name_mapping except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -324,10 +327,12 @@ async def chat_completions(request: Request, request_data: ChatCompletionRequest } ) - # Prepare data for fallback token counting - # Convert Pydantic models to dicts for tokenizer - messages_for_tokenizer = [msg.model_dump() for msg in request_data.messages] - tools_for_tokenizer = [tool.model_dump() for tool in request_data.tools] if request_data.tools else None + # Count prompt tokens from the full Kiro payload (system prompt + messages + tools) + # This matches what actually gets sent to the API, giving accurate token counts + kiro_payload_prompt_tokens = count_tokens( + kiro_request_body.decode('utf-8', errors='ignore'), + apply_claude_correction=False + ) if request_data.stream: # Streaming mode @@ -341,8 +346,8 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer, - request_tools=tools_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens, + tool_name_mapping=tool_name_mapping ): yield chunk except GeneratorExit: @@ -387,8 +392,8 @@ async def stream_wrapper(): request_data.model, model_cache, auth_manager, - request_messages=messages_for_tokenizer, - request_tools=tools_for_tokenizer + prompt_tokens=kiro_payload_prompt_tokens, + tool_name_mapping=tool_name_mapping ) await http_client.close() diff --git a/kiro/streaming_anthropic.py b/kiro/streaming_anthropic.py index 979113ef..2177da2d 100644 --- a/kiro/streaming_anthropic.py +++ b/kiro/streaming_anthropic.py @@ -44,10 +44,9 @@ collect_stream_to_result, FirstTokenTimeoutError, KiroEvent, - calculate_tokens_from_context_usage, stream_with_first_token_retry, ) -from kiro.tokenizer import count_tokens, count_message_tokens, count_tools_tokens +from kiro.tokenizer import count_tokens from kiro.parsers import parse_bracket_tool_calls, deduplicate_tool_calls from kiro.config import FIRST_TOKEN_TIMEOUT, FIRST_TOKEN_MAX_RETRIES, FAKE_REASONING_HANDLING @@ -104,40 +103,37 @@ async def stream_kiro_to_anthropic( model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - conversation_id: Optional[str] = None + prompt_tokens: int = 0, + conversation_id: Optional[str] = None, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> AsyncGenerator[str, None]: """ Generator for converting Kiro stream to Anthropic SSE format. - + Parses Kiro AWS SSE stream and converts events to Anthropic format. Supports thinking content blocks when FAKE_REASONING_HANDLING=as_reasoning_content. - + Args: response: HTTP response with data stream model: Model name to include in response model_cache: Model cache for getting token limits auth_manager: Authentication manager first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for token counting) + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) conversation_id: Stable conversation ID for truncation recovery (optional) - + Yields: Strings in Anthropic SSE format - + Raises: FirstTokenTimeoutError: If first token not received within timeout """ message_id = generate_message_id() - input_tokens = 0 + input_tokens = prompt_tokens output_tokens = 0 full_content = "" full_thinking_content = "" - # Count input tokens from request messages - if request_messages: - input_tokens = count_message_tokens(request_messages, apply_claude_correction=False) - # Track content blocks - thinking block is index 0, text block is index 1 (when thinking enabled) current_block_index = 0 thinking_block_started = False @@ -302,6 +298,10 @@ async def stream_kiro_to_anthropic( tool_name = tool.get("function", {}).get("name", "") or tool.get("name", "") tool_input = tool.get("function", {}).get("arguments", {}) or tool.get("input", {}) + # Reverse truncated tool name back to original + if tool_name_mapping and tool_name in tool_name_mapping: + tool_name = tool_name_mapping[tool_name] + # Check if this tool was truncated if tool.get('_truncation_detected'): truncated_tools.append({ @@ -385,6 +385,10 @@ async def stream_kiro_to_anthropic( tool_name = tc.get("function", {}).get("name", "") tool_input = tc.get("function", {}).get("arguments", {}) + # Reverse truncated tool name back to original + if tool_name_mapping and tool_name in tool_name_mapping: + tool_name = tool_name_mapping[tool_name] + if isinstance(tool_input, str): try: tool_input = json.loads(tool_input) @@ -456,13 +460,9 @@ async def stream_kiro_to_anthropic( # Calculate output tokens output_tokens = count_tokens(full_content + full_thinking_content) - - # Calculate total tokens from context usage if available - if context_usage_percentage is not None: - prompt_tokens, total_tokens, _, _ = calculate_tokens_from_context_usage( - context_usage_percentage, output_tokens, model_cache, model - ) - input_tokens = prompt_tokens + + # input_tokens already set from pre-counted prompt_tokens (full Kiro payload). + # Don't override with contextUsagePercentage — it's unreliable. # Determine stop reason stop_reason = "tool_use" if tool_blocks else "end_turn" @@ -545,29 +545,28 @@ async def collect_anthropic_response( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None + prompt_tokens: int = 0, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> dict: """ Collect full response from Kiro stream in Anthropic format. - + Used for non-streaming mode. - + Args: response: HTTP response with stream model: Model name model_cache: Model cache auth_manager: Authentication manager - request_messages: Original request messages (for token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Returns: Dictionary with full response in Anthropic Messages format """ message_id = generate_message_id() - - # Count input tokens - input_tokens = 0 - if request_messages: - input_tokens = count_message_tokens(request_messages, apply_claude_correction=False) + + # Use pre-counted prompt tokens + input_tokens = prompt_tokens # Collect stream result result = await collect_stream_to_result(response) @@ -601,6 +600,10 @@ async def collect_anthropic_response( tool_name = tc.get("function", {}).get("name", "") or tc.get("name", "") tool_input = tc.get("function", {}).get("arguments", {}) or tc.get("input", {}) + # Reverse truncated tool name back to original + if tool_name_mapping and tool_name in tool_name_mapping: + tool_name = tool_name_mapping[tool_name] + if isinstance(tool_input, str): try: tool_input = json.loads(tool_input) @@ -617,12 +620,8 @@ async def collect_anthropic_response( # Calculate output tokens output_tokens = count_tokens(result.content + result.thinking_content) - # Calculate from context usage if available - if result.context_usage_percentage is not None: - prompt_tokens, _, _, _ = calculate_tokens_from_context_usage( - result.context_usage_percentage, output_tokens, model_cache, model - ) - input_tokens = prompt_tokens + # input_tokens already set from pre-counted prompt_tokens (full Kiro payload). + # Don't override with contextUsagePercentage — it's unreliable. # Determine stop reason stop_reason = "tool_use" if result.tool_calls else "end_turn" @@ -655,18 +654,17 @@ async def stream_with_first_token_retry_anthropic( auth_manager: "KiroAuthManager", max_retries: int = FIRST_TOKEN_MAX_RETRIES, first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0 ) -> AsyncGenerator[str, None]: """ Streaming with automatic retry on first token timeout for Anthropic API. - + If model doesn't respond within first_token_timeout seconds, request is cancelled and a new one is made. Maximum max_retries attempts. - + This is seamless for user - they just see a delay, but eventually get a response (or error after all attempts). - + Args: make_request: Function to create new HTTP request model: Model name @@ -674,12 +672,11 @@ async def stream_with_first_token_retry_anthropic( auth_manager: Authentication manager max_retries: Maximum number of attempts first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Yields: Strings in Anthropic SSE format - + Raises: Exception with Anthropic error format after exhausting all attempts """ @@ -711,7 +708,7 @@ async def stream_processor(response: httpx.Response) -> AsyncGenerator[str, None model_cache, auth_manager, first_token_timeout=first_token_timeout, - request_messages=request_messages + prompt_tokens=prompt_tokens ): yield chunk diff --git a/kiro/streaming_openai.py b/kiro/streaming_openai.py index 5153bc82..2ae02704 100644 --- a/kiro/streaming_openai.py +++ b/kiro/streaming_openai.py @@ -30,7 +30,7 @@ import json import time -from typing import TYPE_CHECKING, AsyncGenerator, Callable, Awaitable, Optional +from typing import TYPE_CHECKING, AsyncGenerator, Callable, Awaitable, Dict, Optional import httpx from fastapi import HTTPException @@ -43,14 +43,13 @@ FIRST_TOKEN_MAX_RETRIES, FAKE_REASONING_HANDLING, ) -from kiro.tokenizer import count_tokens, count_message_tokens, count_tools_tokens +from kiro.tokenizer import count_tokens # Import from streaming_core - reuse shared parsing logic from kiro.streaming_core import ( parse_kiro_stream, FirstTokenTimeoutError, KiroEvent, - calculate_tokens_from_context_usage, stream_with_first_token_retry as stream_with_first_token_retry_core, ) @@ -76,9 +75,9 @@ async def stream_kiro_to_openai_internal( model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None, - conversation_id: Optional[str] = None + prompt_tokens: int = 0, + conversation_id: Optional[str] = None, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> AsyncGenerator[str, None]: """ Internal generator for converting Kiro stream to OpenAI format. @@ -96,9 +95,7 @@ async def stream_kiro_to_openai_internal( model_cache: Model cache for getting token limits auth_manager: Authentication manager first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - conversation_id: Stable conversation ID for truncation recovery (optional) + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) conversation_id: Stable conversation ID for truncation recovery (optional) Yields: @@ -224,24 +221,9 @@ async def stream_kiro_to_openai_internal( # Count completion_tokens (output) using tiktoken completion_tokens = count_tokens(full_content + full_thinking_content) - - # Calculate total_tokens based on context_usage_percentage from Kiro API - # context_usage shows TOTAL percentage of context usage (input + output) - prompt_tokens, total_tokens, prompt_source, total_source = calculate_tokens_from_context_usage( - context_usage_percentage, completion_tokens, model_cache, model - ) - - # Fallback: Kiro API didn't return context_usage, use tiktoken - # Count prompt_tokens from original messages - # IMPORTANT: Don't apply correction coefficient for prompt_tokens, - # as it was calibrated for completion_tokens - if prompt_source == "unknown" and request_messages: - prompt_tokens = count_message_tokens(request_messages, apply_claude_correction=False) - if request_tools: - prompt_tokens += count_tools_tokens(request_tools, apply_claude_correction=False) - total_tokens = prompt_tokens + completion_tokens - prompt_source = "tiktoken" - total_source = "tiktoken" + + # Use pre-counted prompt_tokens from the full Kiro payload + total_tokens = prompt_tokens + completion_tokens # Send tool calls if present if all_tool_calls: @@ -257,6 +239,10 @@ async def stream_kiro_to_openai_internal( tool_name = func.get("name") or "" tool_args = func.get("arguments") or "{}" + # Reverse truncated tool name back to original + if tool_name_mapping and tool_name in tool_name_mapping: + tool_name = tool_name_mapping[tool_name] + logger.debug(f"Tool call [{idx}] '{tool_name}': id={tc.get('id')}, args_length={len(tool_args)}") indexed_tc = { @@ -329,9 +315,9 @@ async def stream_kiro_to_openai_internal( # Log final token values being sent to client logger.debug( f"[Usage] {model}: " - f"prompt_tokens={prompt_tokens} ({prompt_source}), " + f"prompt_tokens={prompt_tokens} (payload tiktoken), " f"completion_tokens={completion_tokens} (tiktoken), " - f"total_tokens={total_tokens} ({total_source})" + f"total_tokens={total_tokens} (sum)" ) yield f"data: {json.dumps(final_chunk, ensure_ascii=False)}\n\n" @@ -374,31 +360,31 @@ async def stream_kiro_to_openai( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> AsyncGenerator[str, None]: """ Generator for converting Kiro stream to OpenAI format. - + This is a wrapper over stream_kiro_to_openai_internal that does NOT retry. Retry logic is implemented in stream_with_first_token_retry. - + Args: client: HTTP client (for connection management) response: HTTP response with data stream model: Model name to include in response model_cache: Model cache for getting token limits auth_manager: Authentication manager - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + tool_name_mapping: Mapping of truncated tool names back to originals + Yields: Strings in SSE format: "data: {...}\\n\\n" or "data: [DONE]\\n\\n" """ async for chunk in stream_kiro_to_openai_internal( client, response, model, model_cache, auth_manager, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens, + tool_name_mapping=tool_name_mapping ): yield chunk @@ -411,8 +397,7 @@ async def stream_with_first_token_retry( auth_manager: "KiroAuthManager", max_retries: int = FIRST_TOKEN_MAX_RETRIES, first_token_timeout: float = FIRST_TOKEN_TIMEOUT, - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0 ) -> AsyncGenerator[str, None]: """ Streaming with automatic retry on first token timeout. @@ -433,15 +418,14 @@ async def stream_with_first_token_retry( auth_manager: Authentication manager max_retries: Maximum number of attempts first_token_timeout: First token wait timeout (seconds) - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Yields: Strings in SSE format - + Raises: HTTPException: After exhausting all attempts - + Example: >>> async def make_req(): ... return await http_client.request_with_retry("POST", url, payload, stream=True) @@ -454,14 +438,14 @@ def create_http_error(status_code: int, error_text: str) -> HTTPException: status_code=status_code, detail=f"Upstream API error: {error_text}" ) - + def create_timeout_error(retries: int, timeout: float) -> HTTPException: """Create HTTPException for timeout errors.""" return HTTPException( status_code=504, detail=f"Model did not respond within {timeout}s after {retries} attempts. Please try again." ) - + async def stream_processor(response: httpx.Response) -> AsyncGenerator[str, None]: """Process response and yield OpenAI SSE chunks.""" async for chunk in stream_kiro_to_openai_internal( @@ -471,8 +455,7 @@ async def stream_processor(response: httpx.Response) -> AsyncGenerator[str, None model_cache, auth_manager, first_token_timeout=first_token_timeout, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens ): yield chunk @@ -493,24 +476,23 @@ async def collect_stream_response( model: str, model_cache: "ModelInfoCache", auth_manager: "KiroAuthManager", - request_messages: Optional[list] = None, - request_tools: Optional[list] = None + prompt_tokens: int = 0, + tool_name_mapping: Optional[Dict[str, str]] = None ) -> dict: """ Collect full response from streaming stream. - + Used for non-streaming mode - collects all chunks and forms a single response. - + Args: client: HTTP client response: HTTP response with stream model: Model name model_cache: Model cache auth_manager: Authentication manager - request_messages: Original request messages (for fallback token counting) - request_tools: Original request tools (for fallback token counting) - + prompt_tokens: Pre-counted prompt tokens (from full Kiro payload) + Returns: Dictionary with full response in OpenAI chat.completion format """ @@ -519,15 +501,15 @@ async def collect_stream_response( final_usage = None tool_calls = [] completion_id = generate_completion_id() - + async for chunk_str in stream_kiro_to_openai( client, response, model, model_cache, auth_manager, - request_messages=request_messages, - request_tools=request_tools + prompt_tokens=prompt_tokens, + tool_name_mapping=tool_name_mapping ): if not chunk_str.startswith("data:"): continue diff --git a/tests/unit/test_converters_anthropic.py b/tests/unit/test_converters_anthropic.py index 692bf81d..eec7b7bf 100644 --- a/tests/unit/test_converters_anthropic.py +++ b/tests/unit/test_converters_anthropic.py @@ -1461,7 +1461,7 @@ def test_builds_simple_payload(self): return_value="claude-sonnet-4.5", ): with patch("kiro.converters_core.FAKE_REASONING_ENABLED", False): - result = anthropic_to_kiro(request, "conv-123", "arn:aws:test") + result = anthropic_to_kiro(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") assert "conversationState" in result @@ -1489,7 +1489,7 @@ def test_includes_system_prompt(self): return_value="claude-sonnet-4.5", ): with patch("kiro.converters_core.FAKE_REASONING_ENABLED", False): - result = anthropic_to_kiro(request, "conv-123", "arn:aws:test") + result = anthropic_to_kiro(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") current_content = result["conversationState"]["currentMessage"][ @@ -1526,7 +1526,7 @@ def test_includes_tools(self): return_value="claude-sonnet-4.5", ): with patch("kiro.converters_core.FAKE_REASONING_ENABLED", False): - result = anthropic_to_kiro(request, "conv-123", "arn:aws:test") + result = anthropic_to_kiro(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") context = result["conversationState"]["currentMessage"]["userInputMessage"].get( @@ -1559,7 +1559,7 @@ def test_builds_history_for_multi_turn(self): return_value="claude-sonnet-4.5", ): with patch("kiro.converters_core.FAKE_REASONING_ENABLED", False): - result = anthropic_to_kiro(request, "conv-123", "arn:aws:test") + result = anthropic_to_kiro(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") history = result["conversationState"].get("history", []) @@ -1621,7 +1621,7 @@ def test_handles_tool_use_and_result_flow(self): return_value="claude-sonnet-4.5", ): with patch("kiro.converters_core.FAKE_REASONING_ENABLED", False): - result = anthropic_to_kiro(request, "conv-123", "arn:aws:test") + result = anthropic_to_kiro(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") @@ -1677,7 +1677,7 @@ def test_injects_thinking_tags_when_enabled(self): ): with patch("kiro.converters_core.FAKE_REASONING_ENABLED", True): with patch("kiro.converters_core.FAKE_REASONING_MAX_TOKENS", 4000): - result = anthropic_to_kiro(request, "conv-123", "arn:aws:test") + result = anthropic_to_kiro(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") current_content = result["conversationState"]["currentMessage"][ @@ -1727,7 +1727,7 @@ def test_injects_thinking_tags_even_when_tool_results_present(self): ): with patch("kiro.converters_core.FAKE_REASONING_ENABLED", True): with patch("kiro.converters_core.FAKE_REASONING_MAX_TOKENS", 4000): - result = anthropic_to_kiro(request, "conv-123", "arn:aws:test") + result = anthropic_to_kiro(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") current_content = result["conversationState"]["currentMessage"][ diff --git a/tests/unit/test_converters_core.py b/tests/unit/test_converters_core.py index c042dd8b..32663a42 100644 --- a/tests/unit/test_converters_core.py +++ b/tests/unit/test_converters_core.py @@ -5918,205 +5918,91 @@ def test_images_with_thinking_injection(self): # ================================================================================================== -# Tests for validate_tool_names (Issue #41 fix) +# Tests for truncate_tool_names (Issue #41 fix) # ================================================================================================== -class TestValidateToolNames: +class TestTruncateToolNames: """ - Tests for validate_tool_names function. + Tests for truncate_tool_names function. - This function validates tool names against Kiro API 64-character limit. + This function truncates tool names exceeding Kiro API 64-character limit. Issue #41: 400 Improperly formed request with long tool names from MCP servers. """ - def test_accepts_short_tool_names(self): - """ - What it does: Verifies that short tool names are accepted. - Purpose: Ensure normal tool names pass validation. - """ - print("Setup: Tool with short name...") + def test_short_names_unchanged(self): + """Short names should not be modified.""" + from kiro.converters_core import truncate_tool_names tools = [UnifiedTool(name="get_weather", description="Get weather")] - - print("Action: Validating tool names...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names(tools) - print("Validation passed - OK") - except ValueError as e: - print(f"ERROR: Validation failed: {e}") - raise AssertionError("Short tool names should be accepted") + mapping = truncate_tool_names(tools) + assert mapping == {} + assert tools[0].name == "get_weather" - def test_accepts_exactly_64_character_name(self): - """ - What it does: Verifies that exactly 64-character names are accepted (boundary). - Purpose: Ensure boundary case is handled correctly. - """ - print("Setup: Tool with exactly 64-character name...") + def test_exactly_64_unchanged(self): + """Exactly 64-character names should not be modified.""" + from kiro.converters_core import truncate_tool_names name_64 = "a" * 64 tools = [UnifiedTool(name=name_64, description="Test")] - - print(f"Tool name length: {len(name_64)}") - print("Action: Validating tool names...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names(tools) - print("Validation passed - OK") - except ValueError as e: - print(f"ERROR: Validation failed: {e}") - raise AssertionError("64-character names should be accepted") + mapping = truncate_tool_names(tools) + assert mapping == {} + assert tools[0].name == name_64 - def test_rejects_65_character_name(self): - """ - What it does: Verifies that 65-character names are rejected. - Purpose: Ensure names exceeding limit are caught. - """ - print("Setup: Tool with 65-character name...") + def test_65_char_name_truncated(self): + """65-character names should be truncated to 64.""" + from kiro.converters_core import truncate_tool_names name_65 = "a" * 65 tools = [UnifiedTool(name=name_65, description="Test")] - - print(f"Tool name length: {len(name_65)}") - print("Action: Validating tool names (should raise ValueError)...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names(tools) - print("ERROR: Validation passed but should have failed") - raise AssertionError("65-character names should be rejected") - except ValueError as e: - print(f"Validation correctly rejected: {str(e)[:100]}...") - assert "exceed Kiro API limit" in str(e) - assert name_65 in str(e) - - def test_rejects_very_long_tool_names(self): - """ - What it does: Verifies that very long tool names are rejected. - Purpose: Ensure the validation works for extreme cases. - """ - print("Setup: Tool with 100-character name...") - name_100 = "mcp__GitHub__" + "a" * 87 - tools = [UnifiedTool(name=name_100, description="Test")] - - print(f"Tool name length: {len(name_100)}") - print("Action: Validating tool names (should raise ValueError)...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names(tools) - raise AssertionError("Very long names should be rejected") - except ValueError as e: - print(f"Validation correctly rejected: {str(e)[:100]}...") - assert "exceed Kiro API limit" in str(e) - assert "100 characters" in str(e) - - def test_rejects_multiple_long_names(self): - """ - What it does: Verifies that all long names are listed in error message. - Purpose: Ensure user sees all problematic tools at once. - """ - print("Setup: Multiple tools with long names...") + mapping = truncate_tool_names(tools) + assert len(tools[0].name) == 64 + assert mapping[tools[0].name] == name_65 + + def test_truncation_is_deterministic(self): + """Same input should always produce same truncated name.""" + from kiro.converters_core import truncate_tool_names + name = "mcp__GitHub__check_if_a_person_is_followed_by_the_authenticated_user" + tools1 = [UnifiedTool(name=name, description="Test")] + tools2 = [UnifiedTool(name=name, description="Test")] + truncate_tool_names(tools1) + truncate_tool_names(tools2) + assert tools1[0].name == tools2[0].name + + def test_multiple_long_names(self): + """Multiple long names should all be truncated.""" + from kiro.converters_core import truncate_tool_names tools = [ UnifiedTool(name="a" * 65, description="Test 1"), UnifiedTool(name="short", description="Test 2"), UnifiedTool(name="b" * 70, description="Test 3") ] - - print("Action: Validating tool names (should raise ValueError)...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names(tools) - raise AssertionError("Should reject multiple long names") - except ValueError as e: - error_msg = str(e) - print(f"Error message: {error_msg[:200]}...") - - print("Checking that both long names are listed...") - assert "65 characters" in error_msg - assert "70 characters" in error_msg - - def test_handles_none_tools(self): - """ - What it does: Verifies that None tools list is handled gracefully. - Purpose: Ensure function doesn't crash on None input. - """ - print("Setup: None tools...") - - print("Action: Validating None...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names(None) - print("Validation passed - OK") - except Exception as e: - print(f"ERROR: Unexpected exception: {e}") - raise AssertionError("None should be handled gracefully") + mapping = truncate_tool_names(tools) + assert len(mapping) == 2 + assert tools[1].name == "short" + assert all(len(t.name) <= 64 for t in tools) - def test_handles_empty_tools_list(self): - """ - What it does: Verifies that empty tools list is handled gracefully. - Purpose: Ensure function doesn't crash on empty list. - """ - print("Setup: Empty tools list...") - - print("Action: Validating empty list...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names([]) - print("Validation passed - OK") - except Exception as e: - print(f"ERROR: Unexpected exception: {e}") - raise AssertionError("Empty list should be handled gracefully") - - def test_error_message_includes_solution(self): - """ - What it does: Verifies that error message includes solution guidance. - Purpose: Ensure user knows how to fix the problem. - """ - print("Setup: Tool with long name...") - tools = [UnifiedTool(name="mcp__GitHub__" + "a" * 60, description="Test")] - - print("Action: Validating tool names (should raise ValueError)...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names(tools) - raise AssertionError("Should reject long name") - except ValueError as e: - error_msg = str(e) - print(f"Error message: {error_msg[:300]}...") - - print("Checking that error message includes solution...") - assert "Solution:" in error_msg - assert "64 characters" in error_msg - assert "Example:" in error_msg + def test_handles_none(self): + """None input should return empty mapping.""" + from kiro.converters_core import truncate_tool_names + assert truncate_tool_names(None) == {} - def test_real_world_mcp_tool_names(self): - """ - What it does: Verifies rejection of real MCP tool names from Issue #41. - Purpose: Ensure the fix works for actual problematic tool names. - """ - print("Setup: Real MCP tool names from Issue #41...") - problematic_names = [ + def test_handles_empty_list(self): + """Empty list should return empty mapping.""" + from kiro.converters_core import truncate_tool_names + assert truncate_tool_names([]) == {} + + def test_real_mcp_tool_names(self): + """Real MCP tool names from Issue #41 should be truncated.""" + from kiro.converters_core import truncate_tool_names + names = [ "mcp__GitHub__check_if_a_person_is_followed_by_the_authenticated_user", "mcp__GitHub__check_if_a_repository_is_starred_by_the_authenticated_user", "mcp__GitHub__remove_interaction_restrictions_from_your_public_repositories", ] - - tools = [UnifiedTool(name=name, description="Test") for name in problematic_names] - - print("Action: Validating real MCP tool names (should raise ValueError)...") - try: - from kiro.converters_core import validate_tool_names - validate_tool_names(tools) - raise AssertionError("Should reject real MCP tool names") - except ValueError as e: - error_msg = str(e) - print(f"Error message length: {len(error_msg)} chars") - print(f"Error message: {error_msg[:400]}...") - - print("Checking that all problematic names are listed...") - for name in problematic_names: - assert name in error_msg, f"Tool name '{name}' should be in error message" - - print("Checking that character counts are shown...") - assert "68 characters" in error_msg - assert "71 characters" in error_msg - assert "74 characters" in error_msg + tools = [UnifiedTool(name=n, description="Test") for n in names] + mapping = truncate_tool_names(tools) + assert len(mapping) == 3 + assert all(len(t.name) <= 64 for t in tools) + # Mapping should reverse back to originals + for tool in tools: + assert mapping[tool.name] in names # ================================================================================================== diff --git a/tests/unit/test_converters_openai.py b/tests/unit/test_converters_openai.py index 8d46a410..360cc51c 100644 --- a/tests/unit/test_converters_openai.py +++ b/tests/unit/test_converters_openai.py @@ -680,7 +680,7 @@ def test_builds_simple_payload(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "arn:aws:test") + result = build_kiro_payload(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") assert "conversationState" in result @@ -703,7 +703,7 @@ def test_includes_system_prompt_in_first_message(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") current_content = result["conversationState"]["currentMessage"]["userInputMessage"]["content"] @@ -726,7 +726,7 @@ def test_builds_history_for_multi_turn(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") assert "history" in result["conversationState"] @@ -747,7 +747,7 @@ def test_handles_assistant_as_last_message(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") current_content = result["conversationState"]["currentMessage"]["userInputMessage"]["content"] @@ -785,7 +785,7 @@ def test_uses_continue_for_empty_content(self): print("Action: Building payload (with fake reasoning and truncation recovery disabled)...") with patch('kiro.converters_core.FAKE_REASONING_ENABLED', False): with patch('kiro.config.TRUNCATION_RECOVERY', False): - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") current_content = result["conversationState"]["currentMessage"]["userInputMessage"]["content"] @@ -807,7 +807,7 @@ def test_normalizes_model_id_correctly(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") model_id = result["conversationState"]["currentMessage"]["userInputMessage"]["modelId"] @@ -835,7 +835,7 @@ def test_includes_tools_in_context(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") context = result["conversationState"]["currentMessage"]["userInputMessage"]["userInputMessageContext"] @@ -880,7 +880,7 @@ def test_injects_thinking_tags_even_when_tool_results_present(self): print("Action: Building payload with FAKE_REASONING_ENABLED=True...") with patch('kiro.converters_core.FAKE_REASONING_ENABLED', True): with patch('kiro.converters_core.FAKE_REASONING_MAX_TOKENS', 4000): - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload current_msg = result["conversationState"]["currentMessage"]["userInputMessage"] content = current_msg["content"] @@ -907,7 +907,7 @@ def test_injects_thinking_tags_when_no_tool_results(self): print("Action: Building payload with FAKE_REASONING_ENABLED=True...") with patch('kiro.converters_core.FAKE_REASONING_ENABLED', True): with patch('kiro.converters_core.FAKE_REASONING_MAX_TOKENS', 4000): - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload current_msg = result["conversationState"]["currentMessage"]["userInputMessage"] content = current_msg["content"] @@ -1045,7 +1045,7 @@ def test_empty_description_replaced_with_placeholder(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") print("Checking that description is replaced with placeholder...") @@ -1073,7 +1073,7 @@ def test_whitespace_only_description_replaced_with_placeholder(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") print("Checking that description is replaced with placeholder...") @@ -1101,7 +1101,7 @@ def test_none_description_replaced_with_placeholder(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") print("Checking that description is replaced with placeholder...") @@ -1129,7 +1129,7 @@ def test_non_empty_description_preserved(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") print("Checking that description is preserved...") @@ -1162,7 +1162,7 @@ def test_sanitizes_tool_parameters(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") print("Checking that parameters are sanitized...") @@ -1212,7 +1212,7 @@ def test_mixed_tools_with_empty_and_normal_descriptions(self): ) print("Action: Building payload...") - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print(f"Result: {result}") print("Checking descriptions...") @@ -1284,7 +1284,7 @@ def test_multiple_assistant_tool_calls_with_results(self): ) print("Action: Building Kiro payload...") - result = build_kiro_payload(request, "conv-123", "arn:aws:test") + result = build_kiro_payload(request, "conv-123", "arn:aws:test").payload print(f"Result: {result}") @@ -1351,7 +1351,7 @@ def test_long_tool_description_added_to_system_prompt(self): print("Action: Building payload...") with patch('kiro.converters_core.TOOL_DESCRIPTION_MAX_LENGTH', 10000): - result = build_kiro_payload(request, "conv-123", "") + result = build_kiro_payload(request, "conv-123", "").payload print("Checking that system prompt contains tool documentation...") current_content = result["conversationState"]["currentMessage"]["userInputMessage"]["content"] diff --git a/tests/unit/test_models_anthropic.py b/tests/unit/test_models_anthropic.py index 9e7ef2aa..5fa1e216 100644 --- a/tests/unit/test_models_anthropic.py +++ b/tests/unit/test_models_anthropic.py @@ -1012,33 +1012,24 @@ def test_valid_tool(self): print(f"Comparing input_schema: Got {tool.input_schema}") assert "properties" in tool.input_schema - def test_requires_name(self): + def test_name_is_optional_for_server_tools(self): """ - What it does: Verifies that name is required. - Purpose: Ensure validation fails without name. + What it does: Verifies that name is optional (server tools use type instead). + Purpose: Ensure server tools like web_search can be created without name. """ - print("Setup: Attempting to create AnthropicTool without name...") - - print("Action: Creating model (should raise ValidationError)...") - with pytest.raises(ValidationError) as exc_info: - AnthropicTool(input_schema={}) - - print(f"ValidationError raised: {exc_info.value}") - assert "name" in str(exc_info.value) + print("Setup: Creating AnthropicTool without name (server tool)...") + tool = AnthropicTool(type="web_search_20250305", input_schema=None) + assert tool.name is None + assert tool.type == "web_search_20250305" - def test_requires_input_schema(self): + def test_input_schema_is_optional_for_server_tools(self): """ - What it does: Verifies that input_schema is required. - Purpose: Ensure validation fails without input_schema. + What it does: Verifies that input_schema is optional (server tools don't have it). + Purpose: Ensure server tools can be created without input_schema. """ - print("Setup: Attempting to create AnthropicTool without input_schema...") - - print("Action: Creating model (should raise ValidationError)...") - with pytest.raises(ValidationError) as exc_info: - AnthropicTool(name="test") - - print(f"ValidationError raised: {exc_info.value}") - assert "input_schema" in str(exc_info.value) + print("Setup: Creating AnthropicTool without input_schema (server tool)...") + tool = AnthropicTool(type="code_execution_20250522", name="code_execution") + assert tool.input_schema is None def test_description_is_optional(self): """ diff --git a/tests/unit/test_routes_anthropic.py b/tests/unit/test_routes_anthropic.py index d8b2db4f..2006fbc0 100644 --- a/tests/unit/test_routes_anthropic.py +++ b/tests/unit/test_routes_anthropic.py @@ -266,8 +266,8 @@ def test_validates_missing_model(self, test_client, valid_proxy_api_key): def test_validates_missing_max_tokens(self, test_client, valid_proxy_api_key): """ - What it does: Verifies missing max_tokens field is rejected. - Purpose: Ensure max_tokens is required (Anthropic API requirement). + What it does: Verifies missing max_tokens defaults to 4096. + Purpose: Ensure max_tokens is optional (needed for count_tokens endpoint). """ print("Action: POST /v1/messages without max_tokens...") response = test_client.post( @@ -280,7 +280,8 @@ def test_validates_missing_max_tokens(self, test_client, valid_proxy_api_key): ) print(f"Status: {response.status_code}") - assert response.status_code == 422 + # Should not be 422 — max_tokens defaults to 4096 + assert response.status_code != 422 def test_validates_missing_messages(self, test_client, valid_proxy_api_key): """ diff --git a/tests/unit/test_streaming_anthropic.py b/tests/unit/test_streaming_anthropic.py index 6932fad1..16f77ba7 100644 --- a/tests/unit/test_streaming_anthropic.py +++ b/tests/unit/test_streaming_anthropic.py @@ -682,12 +682,11 @@ async def test_includes_usage_info(self, mock_response, mock_model_cache, mock_a print("Action: Collecting Anthropic response...") with patch('kiro.streaming_anthropic.collect_stream_to_result', return_value=mock_result): - with patch('kiro.streaming_anthropic.count_message_tokens', return_value=10): - with patch('kiro.streaming_anthropic.count_tokens', return_value=5): - result = await collect_anthropic_response( - mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, - request_messages=[{"role": "user", "content": "Hi"}] - ) + with patch('kiro.streaming_anthropic.count_tokens', return_value=5): + result = await collect_anthropic_response( + mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, + prompt_tokens=10 + ) print(f"Usage: {result['usage']}") assert "input_tokens" in result["usage"] @@ -1082,36 +1081,32 @@ async def mock_parse_kiro_stream(*args, **kwargs): print("✓ Tokens calculated from context usage") @pytest.mark.asyncio - async def test_uses_request_messages_for_input_tokens(self, mock_response, mock_model_cache, mock_auth_manager): + async def test_uses_prompt_tokens_for_input_tokens(self, mock_response, mock_model_cache, mock_auth_manager): """ - What it does: Uses request messages for input token count. - Goal: Verify input tokens are counted from request. + What it does: Uses pre-counted prompt_tokens for input token count. + Goal: Verify input tokens are passed through from prompt_tokens param. """ print("Setup: Mock stream...") - + async def mock_parse_kiro_stream(*args, **kwargs): yield KiroEvent(type="content", content="Hello") - - request_messages = [ - {"role": "user", "content": "Hi there!"} - ] - - print("Action: Streaming to Anthropic format with request messages...") + + print("Action: Streaming to Anthropic format with prompt_tokens...") events = [] - + with patch('kiro.streaming_anthropic.parse_kiro_stream', mock_parse_kiro_stream): with patch('kiro.streaming_anthropic.parse_bracket_tool_calls', return_value=[]): - with patch('kiro.streaming_anthropic.count_message_tokens', return_value=10) as mock_count: - async for event in stream_kiro_to_anthropic( - mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, - request_messages=request_messages - ): - events.append(event) - - # Verify count_message_tokens was called - mock_count.assert_called_once_with(request_messages, apply_claude_correction=False) - - print("✓ Request messages used for input token count") + async for event in stream_kiro_to_anthropic( + mock_response, "claude-sonnet-4", mock_model_cache, mock_auth_manager, + prompt_tokens=42 + ): + events.append(event) + + # Check message_start has the prompt_tokens as input_tokens + message_start = json.loads(events[0].split("data: ")[1]) + assert message_start["message"]["usage"]["input_tokens"] == 42 + + print("✓ prompt_tokens used for input token count") # ================================================================================================== @@ -1362,44 +1357,42 @@ async def mock_make_request(): print("✓ Anthropic-formatted error raised on HTTP error") @pytest.mark.asyncio - async def test_passes_request_messages_to_stream(self, mock_model_cache, mock_auth_manager): + async def test_passes_prompt_tokens_to_stream(self, mock_model_cache, mock_auth_manager): """ - What it does: Passes request_messages to underlying stream function. + What it does: Passes prompt_tokens to underlying stream function. Goal: Verify token counting parameters are forwarded. """ - print("Setup: Mock request with messages...") - + print("Setup: Mock request with prompt_tokens...") + mock_response = AsyncMock() mock_response.status_code = 200 mock_response.aclose = AsyncMock() - + async def mock_make_request(): return mock_response - + captured_kwargs = {} - + async def mock_stream_kiro_to_anthropic(*args, **kwargs): captured_kwargs.update(kwargs) yield "event: message_start\ndata: {}\n\n" yield "event: message_stop\ndata: {}\n\n" - - request_messages = [{"role": "user", "content": "Hello"}] - - print("Action: Streaming with request_messages...") - + + print("Action: Streaming with prompt_tokens...") + with patch('kiro.streaming_anthropic.stream_kiro_to_anthropic', mock_stream_kiro_to_anthropic): async for chunk in stream_with_first_token_retry_anthropic( make_request=mock_make_request, model="claude-sonnet-4", model_cache=mock_model_cache, auth_manager=mock_auth_manager, - request_messages=request_messages + prompt_tokens=42 ): pass - + print(f"Captured kwargs: {captured_kwargs}") - assert captured_kwargs.get("request_messages") == request_messages - print("✓ request_messages passed to stream function") + assert captured_kwargs.get("prompt_tokens") == 42 + print("✓ prompt_tokens passed to stream function") @pytest.mark.asyncio async def test_uses_configured_max_retries(self, mock_model_cache, mock_auth_manager):