diff --git a/kiro/auth.py b/kiro/auth.py index 9ec39cf8..3780b757 100644 --- a/kiro/auth.py +++ b/kiro/auth.py @@ -637,8 +637,9 @@ async def _refresh_token_aws_sso_oidc(self) -> None: Used by kiro-cli which authenticates via AWS IAM Identity Center. Strategy: Try with current in-memory token first. If it fails with 400 - (invalid_request - token was invalidated by kiro-cli re-login), reload - credentials from SQLite and retry once. + (invalid_request - token was invalidated by kiro-cli re-login) or 401 + (invalid_client - device registration rotated by kiro-cli re-login), + reload all credentials from SQLite and retry once. This approach handles both scenarios: 1. Container successfully refreshed token (uses in-memory token) @@ -656,9 +657,11 @@ async def _refresh_token_aws_sso_oidc(self) -> None: try: await self._do_aws_sso_oidc_refresh() except httpx.HTTPStatusError as e: - # 400 = invalid_request, likely stale token after kiro-cli re-login - if e.response.status_code == 400 and self._sqlite_db: - logger.warning("Token refresh failed with 400, reloading credentials from SQLite and retrying...") + # 400 = invalid_grant (stale refresh_token) + # 401 = invalid_client (stale client_secret after kiro-cli re-login) + # Both are recoverable by reloading all credentials from SQLite + if e.response.status_code in (400, 401) and self._sqlite_db: + logger.warning(f"Token refresh failed with {e.response.status_code}, reloading credentials from SQLite and retrying...") self._load_credentials_from_sqlite(self._sqlite_db) await self._do_aws_sso_oidc_refresh() else: @@ -795,7 +798,7 @@ async def get_access_token(self) -> str: except httpx.HTTPStatusError as e: # Graceful degradation for SQLite mode when refresh fails twice # This happens when kiro-cli refreshed tokens in memory without persisting - if e.response.status_code == 400 and self._sqlite_db: + if e.response.status_code in (400, 401) and self._sqlite_db: logger.warning( "Token refresh failed with 400 after SQLite reload. " "This may happen if kiro-cli refreshed tokens in memory without persisting." diff --git a/tests/unit/test_auth_manager.py b/tests/unit/test_auth_manager.py index 5ba36dec..739588a5 100644 --- a/tests/unit/test_auth_manager.py +++ b/tests/unit/test_auth_manager.py @@ -1596,8 +1596,8 @@ async def test_refresh_token_aws_sso_oidc_no_retry_on_non_400_error( self, mock_aws_sso_oidc_token_response ): """ - What it does: Verifies that non-400 errors are not retried. - Purpose: Ensure only 400 (invalid_request) triggers SQLite reload. + What it does: Verifies that non-retryable errors (e.g. 500) are not retried. + Purpose: Ensure only 400/401 triggers SQLite reload. """ print("Setup: Creating KiroAuthManager...") manager = KiroAuthManager( @@ -1687,6 +1687,71 @@ async def test_refresh_token_aws_sso_oidc_no_retry_without_sqlite_db( # ============================================================================= + + @pytest.mark.asyncio + async def test_refresh_token_aws_sso_oidc_retries_on_401_invalid_client( + self, mock_aws_sso_oidc_token_response + ): + """ + What it does: Verifies that 401 (invalid_client) triggers SQLite reload and retry. + Purpose: When kiro-cli re-login rotates device registration (client_id/client_secret), + the gateway should reload from SQLite and retry instead of failing permanently. + """ + print("Setup: Creating KiroAuthManager with stale client credentials...") + manager = KiroAuthManager( + refresh_token="test_refresh", + client_id="old_client_id", + client_secret="old_client_secret" + ) + manager._sqlite_db = "/fake/path/data.sqlite3" + + call_count = 0 + + print("Setup: Mocking HTTP client - 401 first, then 200...") + mock_error_response = AsyncMock() + mock_error_response.status_code = 401 + mock_error_response.text = '{"error":"invalid_client","error_description":"Invalid client secret provided"}' + mock_error_response.json = Mock(return_value={"error": "invalid_client", "error_description": "Invalid client secret provided"}) + mock_error_response.raise_for_status = Mock( + side_effect=httpx.HTTPStatusError( + "401 Unauthorized", + request=Mock(), + response=mock_error_response + ) + ) + + mock_success_response = AsyncMock() + mock_success_response.status_code = 200 + mock_success_response.json = Mock(return_value=mock_aws_sso_oidc_token_response()) + mock_success_response.raise_for_status = Mock() + + async def side_effect_post(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + return mock_error_response + return mock_success_response + + with patch('kiro.auth.httpx.AsyncClient') as mock_client_class: + mock_client = AsyncMock() + mock_client.post = AsyncMock(side_effect=side_effect_post) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + with patch.object(manager, '_load_credentials_from_sqlite') as mock_load: + print("Action: Calling _refresh_token_aws_sso_oidc...") + await manager._refresh_token_aws_sso_oidc() + + print("Verification: Two requests were made (retry on 401)...") + assert call_count == 2 + + print("Verification: SQLite was reloaded on 401...") + mock_load.assert_called_once() + + print("Verification: Token was updated from successful retry...") + assert manager._access_token == "new_aws_sso_access_token" + # Tests for is_token_expired() method # ============================================================================= @@ -2966,4 +3031,4 @@ def test_enterprise_ide_and_kiro_cli_use_same_format(self): print("") print("This is verified by other tests in this class and") print("TestKiroAuthManagerSsoRegionSeparation class.") - assert True # Documentation test \ No newline at end of file + assert True # Documentation test