diff --git a/README.md b/README.md index 89c2e625..a2ee86bb 100644 --- a/README.md +++ b/README.md @@ -430,6 +430,7 @@ Leave `VPN_PROXY_URL` empty (default) if you don't need proxy support. | `/` | GET | Health check | | `/health` | GET | Detailed health check | | `/v1/models` | GET | List available models | +| `/v1/usage` | GET | Current Kiro plan and usage limits | | `/v1/chat/completions` | POST | OpenAI Chat Completions API | | `/v1/messages` | POST | Anthropic Messages API | @@ -457,6 +458,32 @@ curl http://localhost:8000/v1/chat/completions \ +
+📊 Check Usage Limits + +```bash +curl "http://localhost:8000/v1/usage" \ + -H "Authorization: Bearer my-super-secret-password-123" +``` + +The endpoint proxies CodeWhisperer Runtime `GetUsageLimits`, preserves the upstream JSON, +and adds a derived `usageSummary` block with normalized key fields such as reset time, +primary quota usage, and free trial usage. + +Default query parameters: +- `origin=AI_EDITOR` +- `resource_type=AGENTIC_REQUEST` +- `is_email_required=true` + +Custom example: + +```bash +curl "http://localhost:8000/v1/usage?origin=AI_EDITOR&resource_type=AGENTIC_REQUEST&is_email_required=false" \ + -H "Authorization: Bearer my-super-secret-password-123" +``` + +
+
🔹 Streaming Request diff --git a/docs/en/ARCHITECTURE.md b/docs/en/ARCHITECTURE.md index 93eb4dd3..7261dfc5 100644 --- a/docs/en/ARCHITECTURE.md +++ b/docs/en/ARCHITECTURE.md @@ -385,6 +385,7 @@ Supports async context manager (`async with`). | `/` | GET | Health check (status, message, version) | | `/health` | GET | Detailed health check (status, timestamp, version) | | `/v1/models` | GET | List of available models (requires API key) | +| `/v1/usage` | GET | Current Kiro plan and usage limits (requires API key) | | `/v1/chat/completions` | POST | Chat completions (requires API key) | **Authentication:** Bearer token in `Authorization` header @@ -666,10 +667,22 @@ TOOL_DESCRIPTION_MAX_LENGTH="10000" | Endpoint | Method | Description | |----------|--------|-------------| | `/v1/models` | GET | List of available models | +| `/v1/usage` | GET | Current Kiro plan and usage limits | | `/v1/chat/completions` | POST | Chat completions (streaming/non-streaming) | **Authentication:** `Authorization: Bearer {PROXY_API_KEY}` +`/v1/usage` proxies CodeWhisperer Runtime `GetUsageLimits` using: +- `origin=AI_EDITOR` by default +- `resource_type=AGENTIC_REQUEST` by default +- `is_email_required=true` by default + +The route also accepts those values as query parameters when callers need to override them. + +The gateway preserves the upstream response body and adds a derived `usageSummary` +block so clients can read normalized reset/free-trial values without parsing the +nested runtime payload shape. + ### 7.3 Anthropic-compatible Endpoints | Endpoint | Method | Description | diff --git a/docs/zh/README.md b/docs/zh/README.md index 41353f2d..3545aee0 100644 --- a/docs/zh/README.md +++ b/docs/zh/README.md @@ -430,6 +430,7 @@ VPN_PROXY_URL=192.168.1.100:8080 | `/` | GET | 健康检查 | | `/health` | GET | 详细健康检查 | | `/v1/models` | GET | 列出可用模型 | +| `/v1/usage` | GET | 当前 Kiro 套餐与用量限制 | | `/v1/chat/completions` | POST | OpenAI Chat Completions API | | `/v1/messages` | POST | Anthropic Messages API | @@ -457,6 +458,32 @@ curl http://localhost:8000/v1/chat/completions \
+
+📊 查询用量限制 + +```bash +curl "http://localhost:8000/v1/usage" \ + -H "Authorization: Bearer my-super-secret-password-123" +``` + +该端点会代理 CodeWhisperer Runtime 的 `GetUsageLimits`,保留上游原始 JSON, +并额外补充一个 `usageSummary` 摘要字段,方便直接读取重置时间、主额度用量和 +Free trial 用量等关键信息。 + +默认查询参数: +- `origin=AI_EDITOR` +- `resource_type=AGENTIC_REQUEST` +- `is_email_required=true` + +自定义参数示例: + +```bash +curl "http://localhost:8000/v1/usage?origin=AI_EDITOR&resource_type=AGENTIC_REQUEST&is_email_required=false" \ + -H "Authorization: Bearer my-super-secret-password-123" +``` + +
+
🔹 流式请求 diff --git a/kiro/http_client.py b/kiro/http_client.py index 943fa499..a996ff6f 100644 --- a/kiro/http_client.py +++ b/kiro/http_client.py @@ -31,7 +31,7 @@ """ import asyncio -from typing import Optional +from typing import Any, Optional import httpx from fastapi import HTTPException @@ -170,8 +170,9 @@ async def request_with_retry( self, method: str, url: str, - json_data: dict, - stream: bool = False + json_data: Optional[dict[str, Any]] = None, + stream: bool = False, + params: Optional[dict[str, Any]] = None, ) -> httpx.Response: """ Executes an HTTP request with retry logic. @@ -188,9 +189,10 @@ async def request_with_retry( Args: method: HTTP method (GET, POST, etc.) url: Request URL - json_data: Request body (JSON) + json_data: Optional request body (JSON) stream: Use streaming (default False) - + params: Optional query parameters + Returns: httpx.Response with successful response @@ -214,12 +216,24 @@ async def request_with_retry( if stream: # Prevent CLOSE_WAIT connection leak (issue #38) headers["Connection"] = "close" - req = client.build_request(method, url, json=json_data, headers=headers) + req = client.build_request( + method, + url, + params=params, + json=json_data, + headers=headers, + ) logger.debug("Sending request to Kiro API...") response = await client.send(req, stream=True) else: logger.debug("Sending request to Kiro API...") - response = await client.request(method, url, json=json_data, headers=headers) + response = await client.request( + method, + url, + params=params, + json=json_data, + headers=headers, + ) # Check status if response.status_code == 200: @@ -323,4 +337,4 @@ async def __aenter__(self) -> "KiroHttpClient": async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: """Closes the client when exiting context.""" - await self.close() \ No newline at end of file + await self.close() diff --git a/kiro/routes_openai.py b/kiro/routes_openai.py index 301ae9b5..8d2eada7 100644 --- a/kiro/routes_openai.py +++ b/kiro/routes_openai.py @@ -23,6 +23,7 @@ Contains all API endpoints: - / and /health: Health check - /v1/models: Models list +- /v1/usage: Current Kiro usage and plan limits - /v1/chat/completions: Chat completions """ @@ -49,6 +50,7 @@ from kiro.converters_openai import build_kiro_payload from kiro.streaming_openai import stream_kiro_to_openai, collect_stream_response, stream_with_first_token_retry from kiro.http_client import KiroHttpClient +from kiro.runtime_usage import fetch_usage_limits from kiro.utils import generate_conversation_id # Import debug_logger @@ -150,6 +152,48 @@ async def get_models(request: Request): return ModelList(data=openai_models) +@router.get("/v1/usage", dependencies=[Depends(verify_api_key)]) +async def get_usage( + request: Request, + origin: str = "AI_EDITOR", + resource_type: str = "AGENTIC_REQUEST", + is_email_required: bool = True, +): + """ + Return current Kiro usage and plan limits. + + Args: + request: FastAPI Request for accessing app.state + origin: Upstream runtime origin query parameter + resource_type: Upstream runtime resource type query parameter + is_email_required: Whether upstream should include user email + + Returns: + Upstream GetUsageLimits payload plus a derived `usageSummary` block + + Raises: + HTTPException: On authentication, validation, network, or upstream errors + """ + logger.info( + "Request to /v1/usage " + f"(origin={origin}, resource_type={resource_type}, include_email={is_email_required})" + ) + + auth_manager: KiroAuthManager = request.app.state.auth_manager + shared_client = request.app.state.http_client + + try: + return await fetch_usage_limits( + auth_manager=auth_manager, + shared_client=shared_client, + origin=origin, + resource_type=resource_type, + is_email_required=is_email_required, + ) + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + @router.post("/v1/chat/completions", dependencies=[Depends(verify_api_key)]) async def chat_completions(request: Request, request_data: ChatCompletionRequest): """ @@ -418,4 +462,4 @@ async def stream_wrapper(): # Flush debug logs on internal error ("errors" mode) if debug_logger: debug_logger.flush_on_error(500, str(e)) - raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") \ No newline at end of file + raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}") diff --git a/kiro/runtime_usage.py b/kiro/runtime_usage.py new file mode 100644 index 00000000..08cf15ed --- /dev/null +++ b/kiro/runtime_usage.py @@ -0,0 +1,243 @@ +""" +Usage limit helpers for CodeWhisperer Runtime endpoints. + +This module keeps usage/plan retrieval separate from route handlers so +additional runtime billing endpoints can reuse the same request flow. +""" + +import json +from datetime import datetime, timezone +from typing import Any, Optional + +import httpx +from fastapi import HTTPException +from loguru import logger + +from kiro.auth import AuthType, KiroAuthManager +from kiro.http_client import KiroHttpClient + +DEFAULT_USAGE_ORIGIN = "AI_EDITOR" +DEFAULT_USAGE_RESOURCE_TYPE = "AGENTIC_REQUEST" + + +def build_usage_limits_params( + auth_manager: KiroAuthManager, + origin: str = DEFAULT_USAGE_ORIGIN, + resource_type: str = DEFAULT_USAGE_RESOURCE_TYPE, + is_email_required: bool = True, +) -> dict[str, str]: + """ + Build query parameters for CodeWhisperer Runtime GetUsageLimits. + + Args: + auth_manager: Authentication manager with auth type and optional profile ARN + origin: Request origin expected by CodeWhisperer Runtime + resource_type: Resource type to query usage for + is_email_required: Whether upstream should include user email in the response + + Returns: + Query parameters for the `/getUsageLimits` runtime endpoint + + Raises: + ValueError: If origin or resource_type is empty + """ + if not origin: + raise ValueError("origin must not be empty") + if not resource_type: + raise ValueError("resource_type must not be empty") + + params = { + "origin": origin, + "resourceType": resource_type, + "isEmailRequired": str(is_email_required).lower(), + } + + if auth_manager.auth_type == AuthType.KIRO_DESKTOP and auth_manager.profile_arn: + params["profileArn"] = auth_manager.profile_arn + + return params + + +def _extract_usage_error_message(response: httpx.Response) -> str: + """ + Extract a readable error message from an upstream usage response. + + Args: + response: HTTP response returned by the runtime API + + Returns: + Best-effort human-readable error message + """ + try: + payload = response.json() + except json.JSONDecodeError: + return response.text or "Unknown upstream error" + + if isinstance(payload, dict): + if isinstance(payload.get("Output"), dict): + nested_message = payload["Output"].get("message") + if isinstance(nested_message, str) and nested_message: + return nested_message + + for key in ("message", "Message", "detail"): + value = payload.get(key) + if isinstance(value, str) and value: + return value + + return response.text or "Unknown upstream error" + + +def _format_utc_timestamp(timestamp: Any) -> Optional[str]: + """ + Convert a Unix timestamp into an ISO 8601 UTC string. + + Args: + timestamp: Unix timestamp in seconds + + Returns: + ISO 8601 UTC string, or None when the value is missing/invalid + """ + if not isinstance(timestamp, (int, float)) or isinstance(timestamp, bool): + return None + + dt = datetime.fromtimestamp(float(timestamp), tz=timezone.utc) + timespec = "seconds" if float(timestamp).is_integer() else "milliseconds" + return dt.isoformat(timespec=timespec).replace("+00:00", "Z") + + +def _find_primary_usage_breakdown(payload: dict[str, Any]) -> Optional[dict[str, Any]]: + """ + Select the primary usage breakdown from the upstream payload. + + Args: + payload: Parsed upstream usage response + + Returns: + Preferred usage breakdown entry, or None if unavailable + """ + breakdown_list = payload.get("usageBreakdownList") + if isinstance(breakdown_list, list): + dict_items = [item for item in breakdown_list if isinstance(item, dict)] + for item in dict_items: + if item.get("resourceType") == "CREDIT": + return item + if dict_items: + return dict_items[0] + + breakdown = payload.get("usageBreakdown") + if isinstance(breakdown, dict): + return breakdown + + return None + + +def build_usage_summary(payload: dict[str, Any]) -> dict[str, Any]: + """ + Build a stable summary block from the upstream usage payload. + + The gateway keeps the original upstream structure intact and adds this + derived block so clients can read common usage fields without reverse + engineering the nested response shape. + + Args: + payload: Raw upstream usage payload + + Returns: + Derived usage summary with normalized timestamp fields + """ + primary_breakdown = _find_primary_usage_breakdown(payload) + free_trial_info = ( + primary_breakdown.get("freeTrialInfo") + if isinstance(primary_breakdown, dict) and isinstance(primary_breakdown.get("freeTrialInfo"), dict) + else {} + ) + + return { + "resetAt": _format_utc_timestamp( + primary_breakdown.get("nextDateReset") + if isinstance(primary_breakdown, dict) + else payload.get("nextDateReset") + ), + "primaryLimit": primary_breakdown.get("usageLimit") if isinstance(primary_breakdown, dict) else None, + "primaryUsed": primary_breakdown.get("currentUsageWithPrecision") + if isinstance(primary_breakdown, dict) + else None, + "primaryUnit": primary_breakdown.get("unit") if isinstance(primary_breakdown, dict) else None, + "freeTrialLimit": free_trial_info.get("usageLimit"), + "freeTrialUsed": free_trial_info.get("currentUsageWithPrecision"), + "freeTrialExpiresAt": _format_utc_timestamp(free_trial_info.get("freeTrialExpiry")), + } + + +async def fetch_usage_limits( + auth_manager: KiroAuthManager, + shared_client: Optional[httpx.AsyncClient], + origin: str = DEFAULT_USAGE_ORIGIN, + resource_type: str = DEFAULT_USAGE_RESOURCE_TYPE, + is_email_required: bool = True, +) -> dict[str, Any]: + """ + Fetch current usage limits from CodeWhisperer Runtime. + + Args: + auth_manager: Authentication manager for token refresh and profile selection + shared_client: Shared HTTP client from app state + origin: Request origin expected by upstream + resource_type: Usage bucket to query + is_email_required: Whether upstream should include user email + + Returns: + Upstream JSON response plus a derived `usageSummary` block + + Raises: + HTTPException: If upstream returns an error or invalid JSON + ValueError: If the request parameters are invalid + """ + params = build_usage_limits_params( + auth_manager=auth_manager, + origin=origin, + resource_type=resource_type, + is_email_required=is_email_required, + ) + url = f"{auth_manager.q_host}/getUsageLimits" + + logger.info( + "Fetching usage limits from CodeWhisperer Runtime " + f"(origin={origin}, resource_type={resource_type}, include_email={is_email_required})" + ) + + async with KiroHttpClient(auth_manager, shared_client=shared_client) as http_client: + response = await http_client.request_with_retry( + "GET", + url, + params=params, + ) + + if response.status_code != 200: + error_message = _extract_usage_error_message(response) + logger.error( + f"GetUsageLimits failed: status={response.status_code}, message={error_message}" + ) + raise HTTPException( + status_code=response.status_code, + detail=f"Failed to fetch usage limits: {error_message}", + ) + + try: + payload = response.json() + except json.JSONDecodeError as exc: + logger.error(f"GetUsageLimits returned invalid JSON: {exc}") + raise HTTPException( + status_code=502, + detail="Kiro usage endpoint returned invalid JSON.", + ) from exc + + if not isinstance(payload, dict): + logger.error(f"GetUsageLimits returned unexpected JSON type: {type(payload).__name__}") + raise HTTPException( + status_code=502, + detail="Kiro usage endpoint returned an unexpected JSON structure.", + ) + + payload["usageSummary"] = build_usage_summary(payload) + return payload diff --git a/tests/integration/test_full_flow.py b/tests/integration/test_full_flow.py index 2f0f7b5c..229eae8b 100644 --- a/tests/integration/test_full_flow.py +++ b/tests/integration/test_full_flow.py @@ -303,6 +303,51 @@ def test_models_caching_behavior(self, test_client, valid_proxy_api_key): print("Caching works correctly") +class TestUsageEndpointIntegration: + """Integration tests for /v1/usage endpoint.""" + + @patch("kiro.routes_openai.fetch_usage_limits", new_callable=AsyncMock) + def test_usage_requires_auth_and_returns_payload( + self, + mock_fetch_usage_limits, + test_client, + valid_proxy_api_key, + ): + """ + What it does: Checks authentication and successful payload flow for /v1/usage. + Goal: Ensure the gateway exposes GetUsageLimits through a protected endpoint. + """ + print("Step 1: Request /v1/usage without authorization...") + unauthorized = test_client.get("/v1/usage") + assert unauthorized.status_code == 401 + + mock_fetch_usage_limits.return_value = { + "subscriptionInfo": {"subscriptionTitle": "KIRO PRO"}, + "usageBreakdownList": [{"usageLimit": 1000}], + "usageSummary": { + "resetAt": "2026-04-01T00:00:00Z", + "primaryLimit": 1000, + "primaryUsed": 0.0, + "primaryUnit": "INVOCATIONS", + "freeTrialLimit": 500, + "freeTrialUsed": 330.11, + "freeTrialExpiresAt": "2026-04-11T14:15:32.340Z", + }, + } + + print("Step 2: Request /v1/usage with valid authorization...") + authorized = test_client.get( + "/v1/usage", + headers={"Authorization": f"Bearer {valid_proxy_api_key}"}, + ) + + assert authorized.status_code == 200 + assert authorized.json()["subscriptionInfo"]["subscriptionTitle"] == "KIRO PRO" + assert authorized.json()["usageBreakdownList"][0]["usageLimit"] == 1000 + assert authorized.json()["usageSummary"]["freeTrialExpiresAt"] == "2026-04-11T14:15:32.340Z" + print(f"Usage payload: {authorized.json()}") + + class TestStreamingFlagHandling: """Integration tests for stream flag handling.""" diff --git a/tests/unit/test_http_client.py b/tests/unit/test_http_client.py index a559e646..c1995724 100644 --- a/tests/unit/test_http_client.py +++ b/tests/unit/test_http_client.py @@ -212,6 +212,51 @@ async def test_successful_request_returns_response(self, mock_auth_manager_for_h print("Verification: Response received...") assert response.status_code == 200 mock_client.request.assert_called_once() + + @pytest.mark.asyncio + async def test_successful_get_request_passes_query_params(self, mock_auth_manager_for_http): + """ + What it does: Verifies GET requests forward query params to httpx. + Purpose: Ensure runtime endpoints like getUsageLimits can use request_with_retry. + """ + print("Setup: Creating KiroHttpClient...") + http_client = KiroHttpClient(mock_auth_manager_for_http) + + mock_response = AsyncMock() + mock_response.status_code = 200 + + captured_kwargs = {} + + async def capture_request(method, url, params=None, json=None, headers=None): + captured_kwargs["method"] = method + captured_kwargs["url"] = url + captured_kwargs["params"] = params + captured_kwargs["json"] = json + captured_kwargs["headers"] = headers + return mock_response + + mock_client = AsyncMock() + mock_client.is_closed = False + mock_client.request = AsyncMock(side_effect=capture_request) + + print("Action: Executing GET request with query params...") + with patch.object(http_client, '_get_client', return_value=mock_client): + with patch('kiro.http_client.get_kiro_headers', return_value={}): + response = await http_client.request_with_retry( + "GET", + "https://api.example.com/getUsageLimits", + params={"origin": "AI_EDITOR", "resourceType": "AGENTIC_REQUEST"}, + ) + + print("Verification: Query params passed to request()...") + assert response.status_code == 200 + assert captured_kwargs["method"] == "GET" + assert captured_kwargs["url"] == "https://api.example.com/getUsageLimits" + assert captured_kwargs["params"] == { + "origin": "AI_EDITOR", + "resourceType": "AGENTIC_REQUEST", + } + assert captured_kwargs["json"] is None @pytest.mark.asyncio async def test_403_triggers_token_refresh(self, mock_auth_manager_for_http): @@ -1046,7 +1091,7 @@ async def test_streaming_request_includes_connection_close_header(self, mock_aut mock_request = Mock() captured_headers = {} - def capture_build_request(method, url, json, headers): + def capture_build_request(method, url, params=None, json=None, headers=None): captured_headers.update(headers) return mock_request @@ -1086,7 +1131,7 @@ async def test_non_streaming_request_does_not_include_connection_close_header(se captured_headers = {} - async def capture_request(method, url, json, headers): + async def capture_request(method, url, params=None, json=None, headers=None): captured_headers.update(headers) return mock_response @@ -1124,7 +1169,7 @@ async def test_streaming_connection_close_preserves_other_headers(self, mock_aut mock_request = Mock() captured_headers = {} - def capture_build_request(method, url, json, headers): + def capture_build_request(method, url, params=None, json=None, headers=None): captured_headers.update(headers) return mock_request @@ -1155,4 +1200,46 @@ def capture_build_request(method, url, json, headers): assert captured_headers["Content-Type"] == "application/json" assert captured_headers["X-Custom-Header"] == "custom_value" assert captured_headers["Connection"] == "close" - assert response.status_code == 200 \ No newline at end of file + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_streaming_request_passes_query_params_to_build_request(self, mock_auth_manager_for_http): + """ + What it does: Verifies streaming requests pass query params to build_request. + Purpose: Ensure query-based streaming endpoints remain supported. + """ + print("Setup: Creating KiroHttpClient...") + http_client = KiroHttpClient(mock_auth_manager_for_http) + + mock_response = AsyncMock() + mock_response.status_code = 200 + + mock_request = Mock() + captured_params = {} + + def capture_build_request(method, url, params=None, json=None, headers=None): + captured_params["method"] = method + captured_params["url"] = url + captured_params["params"] = params + return mock_request + + mock_client = AsyncMock() + mock_client.is_closed = False + mock_client.build_request = Mock(side_effect=capture_build_request) + mock_client.send = AsyncMock(return_value=mock_response) + + print("Action: Executing streaming GET request with query params...") + with patch.object(http_client, '_get_client', return_value=mock_client): + with patch('kiro.http_client.get_kiro_headers', return_value={"Authorization": "Bearer test"}): + response = await http_client.request_with_retry( + "GET", + "https://api.example.com/stream", + params={"cursor": "abc123"}, + stream=True, + ) + + print("Verification: Query params passed to build_request()...") + assert response.status_code == 200 + assert captured_params["method"] == "GET" + assert captured_params["url"] == "https://api.example.com/stream" + assert captured_params["params"] == {"cursor": "abc123"} diff --git a/tests/unit/test_routes_openai.py b/tests/unit/test_routes_openai.py index b407427b..b30b249e 100644 --- a/tests/unit/test_routes_openai.py +++ b/tests/unit/test_routes_openai.py @@ -8,6 +8,7 @@ - GET / - Root endpoint - GET /health - Health check - GET /v1/models - List available models +- GET /v1/usage - Current Kiro usage and plan limits - POST /v1/chat/completions - Chat completions For Anthropic API tests, see test_routes_anthropic.py. @@ -380,6 +381,117 @@ def test_models_owned_by_anthropic(self, test_client, valid_proxy_api_key): assert model["owned_by"] == "anthropic" +# ============================================================================= +# Tests for usage endpoint (/v1/usage) +# ============================================================================= + +class TestUsageEndpoint: + """Tests for the GET /v1/usage endpoint.""" + + def test_usage_requires_authentication(self, test_client): + """ + What it does: Verifies usage endpoint requires authentication. + Purpose: Ensure usage data is protected. + """ + print("Action: GET /v1/usage without auth...") + response = test_client.get("/v1/usage") + + print(f"Response: {response.status_code} {response.json()}") + assert response.status_code == 401 + + @patch("kiro.routes_openai.fetch_usage_limits", new_callable=AsyncMock) + def test_usage_returns_service_payload( + self, + mock_fetch_usage_limits, + test_client, + valid_proxy_api_key, + ): + """ + What it does: Verifies usage endpoint returns service payload unchanged. + Purpose: Preserve upstream GetUsageLimits shape. + """ + mock_fetch_usage_limits.return_value = { + "subscriptionInfo": {"subscriptionTitle": "KIRO PRO"}, + "usageBreakdownList": [], + "usageSummary": { + "resetAt": "2026-04-01T00:00:00Z", + "primaryLimit": 1000, + "primaryUsed": 0.0, + "primaryUnit": "INVOCATIONS", + "freeTrialLimit": 500, + "freeTrialUsed": 330.11, + "freeTrialExpiresAt": "2026-04-11T14:15:32.340Z", + }, + } + + print("Action: GET /v1/usage with valid auth...") + response = test_client.get( + "/v1/usage", + headers={"Authorization": f"Bearer {valid_proxy_api_key}"}, + ) + + print(f"Response: {response.json()}") + assert response.status_code == 200 + assert response.json()["subscriptionInfo"]["subscriptionTitle"] == "KIRO PRO" + assert response.json()["usageSummary"]["primaryLimit"] == 1000 + assert response.json()["usageSummary"]["freeTrialUsed"] == 330.11 + mock_fetch_usage_limits.assert_awaited_once() + + @patch("kiro.routes_openai.fetch_usage_limits", new_callable=AsyncMock) + def test_usage_forwards_query_parameters( + self, + mock_fetch_usage_limits, + test_client, + valid_proxy_api_key, + ): + """ + What it does: Verifies route forwards query parameters to usage service. + Purpose: Keep usage endpoint configurable without duplicating service logic. + """ + mock_fetch_usage_limits.return_value = {"usageBreakdownList": []} + + print("Action: GET /v1/usage with custom query params...") + response = test_client.get( + "/v1/usage", + headers={"Authorization": f"Bearer {valid_proxy_api_key}"}, + params={ + "origin": "CUSTOM_ORIGIN", + "resource_type": "CUSTOM_RESOURCE", + "is_email_required": "false", + }, + ) + + print(f"Response: {response.json()}") + assert response.status_code == 200 + _, kwargs = mock_fetch_usage_limits.await_args + assert kwargs["origin"] == "CUSTOM_ORIGIN" + assert kwargs["resource_type"] == "CUSTOM_RESOURCE" + assert kwargs["is_email_required"] is False + + @patch("kiro.routes_openai.fetch_usage_limits", new_callable=AsyncMock) + def test_usage_returns_400_for_invalid_arguments( + self, + mock_fetch_usage_limits, + test_client, + valid_proxy_api_key, + ): + """ + What it does: Verifies route converts ValueError into HTTP 400. + Purpose: Return user-friendly validation errors for usage endpoint input. + """ + mock_fetch_usage_limits.side_effect = ValueError("origin must not be empty") + + print("Action: GET /v1/usage with service ValueError...") + response = test_client.get( + "/v1/usage", + headers={"Authorization": f"Bearer {valid_proxy_api_key}"}, + ) + + print(f"Response: {response.status_code} {response.json()}") + assert response.status_code == 400 + assert response.json()["detail"] == "origin must not be empty" + + # ============================================================================= # Tests for chat completions endpoint (/v1/chat/completions) # ============================================================================= @@ -813,6 +925,17 @@ def test_router_has_models_endpoint(self): print(f"Found routes: {routes}") assert "/v1/models" in routes + + def test_router_has_usage_endpoint(self): + """ + What it does: Verifies usage endpoint is registered. + Purpose: Ensure endpoint is available. + """ + print("Checking: Router endpoints...") + routes = [route.path for route in router.routes] + + print(f"Found routes: {routes}") + assert "/v1/usage" in routes def test_router_has_chat_completions_endpoint(self): """ @@ -863,6 +986,19 @@ def test_models_endpoint_uses_get_method(self): assert "GET" in route.methods return pytest.fail("Models endpoint not found") + + def test_usage_endpoint_uses_get_method(self): + """ + What it does: Verifies usage endpoint uses GET method. + Purpose: Ensure correct HTTP method. + """ + print("Checking: HTTP methods...") + for route in router.routes: + if route.path == "/v1/usage": + print(f"Route /v1/usage methods: {route.methods}") + assert "GET" in route.methods + return + pytest.fail("Usage endpoint not found") def test_chat_completions_endpoint_uses_post_method(self): """ @@ -1396,4 +1532,4 @@ def test_content_hash_matches_first_500_chars(self): print("Checking: Match found...") assert info is not None - assert info.message_hash == hash1 \ No newline at end of file + assert info.message_hash == hash1 diff --git a/tests/unit/test_runtime_usage.py b/tests/unit/test_runtime_usage.py new file mode 100644 index 00000000..0a4219b7 --- /dev/null +++ b/tests/unit/test_runtime_usage.py @@ -0,0 +1,256 @@ +# -*- coding: utf-8 -*- + +""" +Unit tests for runtime usage helpers. +""" + +import json +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import pytest +from fastapi import HTTPException + +from kiro.auth import AuthType +from kiro.runtime_usage import ( + build_usage_limits_params, + build_usage_summary, + fetch_usage_limits, +) + + +@pytest.fixture +def mock_usage_auth_manager(): + """Create a mocked auth manager for usage helper tests.""" + manager = Mock() + manager.auth_type = AuthType.KIRO_DESKTOP + manager.profile_arn = "arn:aws:codewhisperer:us-east-1:123456789012:profile/test" + manager.q_host = "https://q.us-east-1.amazonaws.com" + return manager + + +class TestBuildUsageLimitsParams: + """Tests for usage query parameter construction.""" + + def test_includes_profile_arn_for_kiro_desktop(self, mock_usage_auth_manager): + """ + What it does: Verifies desktop auth includes profileArn in query params. + Purpose: Ensure GetUsageLimits matches Kiro desktop client behavior. + """ + print("Action: Building params for desktop auth...") + params = build_usage_limits_params(mock_usage_auth_manager) + + print(f"Result: {params}") + assert params["origin"] == "AI_EDITOR" + assert params["resourceType"] == "AGENTIC_REQUEST" + assert params["isEmailRequired"] == "true" + assert params["profileArn"] == mock_usage_auth_manager.profile_arn + + def test_omits_profile_arn_for_aws_sso(self, mock_usage_auth_manager): + """ + What it does: Verifies AWS SSO auth omits profileArn. + Purpose: Avoid sending unsupported profileArn for non-desktop auth. + """ + print("Setup: Switching auth type to AWS SSO OIDC...") + mock_usage_auth_manager.auth_type = AuthType.AWS_SSO_OIDC + + print("Action: Building params for AWS SSO auth...") + params = build_usage_limits_params(mock_usage_auth_manager) + + print(f"Result: {params}") + assert "profileArn" not in params + + def test_empty_origin_raises_value_error(self, mock_usage_auth_manager): + """ + What it does: Verifies empty origin is rejected. + Purpose: Ensure invalid query construction fails early. + """ + print("Action: Building params with empty origin...") + with pytest.raises(ValueError) as exc_info: + build_usage_limits_params(mock_usage_auth_manager, origin="") + + print(f"Error: {exc_info.value}") + assert "origin must not be empty" in str(exc_info.value) + + +class TestFetchUsageLimits: + """Tests for GetUsageLimits fetching.""" + + @pytest.mark.asyncio + @patch("kiro.runtime_usage.KiroHttpClient") + async def test_returns_usage_payload(self, mock_http_client_class, mock_usage_auth_manager): + """ + What it does: Verifies successful upstream payload is returned with a derived summary. + Purpose: Preserve transparency while exposing stable usage fields. + """ + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "subscriptionInfo": {"subscriptionTitle": "KIRO PRO"}, + "usageBreakdownList": [ + { + "resourceType": "CREDIT", + "usageLimit": 1000, + "currentUsageWithPrecision": 0.0, + "unit": "INVOCATIONS", + "nextDateReset": 1775001600.0, + "freeTrialInfo": { + "usageLimit": 500, + "currentUsageWithPrecision": 330.11, + "freeTrialExpiry": 1775916932.34, + }, + } + ], + } + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.request_with_retry = AsyncMock(return_value=mock_response) + mock_http_client_class.return_value = mock_client + + print("Action: Fetching usage limits...") + payload = await fetch_usage_limits( + auth_manager=mock_usage_auth_manager, + shared_client=None, + ) + + print(f"Result: {payload}") + assert payload["subscriptionInfo"]["subscriptionTitle"] == "KIRO PRO" + assert payload["usageSummary"] == { + "resetAt": "2026-04-01T00:00:00Z", + "primaryLimit": 1000, + "primaryUsed": 0.0, + "primaryUnit": "INVOCATIONS", + "freeTrialLimit": 500, + "freeTrialUsed": 330.11, + "freeTrialExpiresAt": "2026-04-11T14:15:32.340Z", + } + mock_client.request_with_retry.assert_awaited_once_with( + "GET", + "https://q.us-east-1.amazonaws.com/getUsageLimits", + params={ + "origin": "AI_EDITOR", + "resourceType": "AGENTIC_REQUEST", + "isEmailRequired": "true", + "profileArn": mock_usage_auth_manager.profile_arn, + }, + ) + + @pytest.mark.asyncio + @patch("kiro.runtime_usage.KiroHttpClient") + async def test_upstream_error_raises_http_exception(self, mock_http_client_class, mock_usage_auth_manager): + """ + What it does: Verifies non-200 upstream responses become HTTPException. + Purpose: Return actionable route errors to gateway clients. + """ + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 403 + mock_response.text = '{"message":"Access denied"}' + mock_response.json.return_value = {"message": "Access denied"} + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.request_with_retry = AsyncMock(return_value=mock_response) + mock_http_client_class.return_value = mock_client + + print("Action: Fetching usage limits with upstream 403...") + with pytest.raises(HTTPException) as exc_info: + await fetch_usage_limits( + auth_manager=mock_usage_auth_manager, + shared_client=None, + ) + + print(f"Error response: {exc_info.value.detail}") + assert exc_info.value.status_code == 403 + assert "Failed to fetch usage limits" in exc_info.value.detail + assert "Access denied" in exc_info.value.detail + + @pytest.mark.asyncio + @patch("kiro.runtime_usage.KiroHttpClient") + async def test_invalid_json_raises_502(self, mock_http_client_class, mock_usage_auth_manager): + """ + What it does: Verifies invalid upstream JSON raises 502. + Purpose: Prevent malformed upstream responses from leaking through silently. + """ + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.side_effect = json.JSONDecodeError("bad json", "x", 0) + + mock_client = AsyncMock() + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + mock_client.request_with_retry = AsyncMock(return_value=mock_response) + mock_http_client_class.return_value = mock_client + + print("Action: Fetching usage limits with invalid JSON response...") + with pytest.raises(HTTPException) as exc_info: + await fetch_usage_limits( + auth_manager=mock_usage_auth_manager, + shared_client=None, + ) + + print(f"Error response: {exc_info.value.detail}") + assert exc_info.value.status_code == 502 + assert "invalid JSON" in exc_info.value.detail + + +class TestBuildUsageSummary: + """Tests for derived usage summary generation.""" + + def test_builds_summary_from_usage_breakdown_list(self): + """ + What it does: Verifies the gateway derives stable summary fields. + Purpose: Expose the key usage values clients need without losing raw payload data. + """ + payload = { + "nextDateReset": 1775001600.0, + "usageBreakdownList": [ + { + "resourceType": "CREDIT", + "usageLimit": 1000, + "currentUsageWithPrecision": 0.0, + "unit": "INVOCATIONS", + "nextDateReset": 1775001600.0, + "freeTrialInfo": { + "usageLimit": 500, + "currentUsageWithPrecision": 330.11, + "freeTrialExpiry": 1775916932.34, + }, + } + ], + } + + print("Action: Building usage summary from upstream payload...") + summary = build_usage_summary(payload) + + print(f"Summary: {summary}") + assert summary == { + "resetAt": "2026-04-01T00:00:00Z", + "primaryLimit": 1000, + "primaryUsed": 0.0, + "primaryUnit": "INVOCATIONS", + "freeTrialLimit": 500, + "freeTrialUsed": 330.11, + "freeTrialExpiresAt": "2026-04-11T14:15:32.340Z", + } + + def test_handles_missing_breakdown_data(self): + """ + What it does: Verifies missing upstream fields degrade gracefully. + Purpose: Keep response shape stable even when upstream omits usage data. + """ + print("Action: Building usage summary from incomplete payload...") + summary = build_usage_summary({}) + + print(f"Summary: {summary}") + assert summary == { + "resetAt": None, + "primaryLimit": None, + "primaryUsed": None, + "primaryUnit": None, + "freeTrialLimit": None, + "freeTrialUsed": None, + "freeTrialExpiresAt": None, + }