diff --git a/README.ja.md b/README.ja.md index 6544eca..8faee7f 100644 --- a/README.ja.md +++ b/README.ja.md @@ -40,8 +40,10 @@ Bearer token、カスタムヘッダー、OAuth 2.1 認証情報をリモート - §3.4–3.5 `authorization_pending` / `slow_down`(interval +=5 s)/ `expired_token` / `access_denied` ハンドリング - DCR の `grant_types` に `urn:ietf:params:oauth:grant-type:device_code` を登録(RFC 7591 §2) - [RFC 7591](https://www.rfc-editor.org/rfc/rfc7591) Dynamic Client Registration - - §3 クライアント登録リクエスト(公開クライアント、`token_endpoint_auth_method: none`) + - §3 クライアント登録リクエスト。AS メタデータの `token_endpoint_auth_methods_supported` から最適な認証方式を選択(`none` → `client_secret_post` → `client_secret_basic` の優先順) - §3.2.1 `client_secret_expires_at` に対応、期限切れ時に自動再登録 + - [RFC 6749](https://www.rfc-editor.org/rfc/rfc6749) OAuth 2.0 + - §2.3.1 `client_secret_basic`:percent-encode した認証情報を `Authorization: Basic` ヘッダーで送信(コード交換・トークンリフレッシュ・Device Authorization Grant ポーリングに適用) - [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) Bearer Token の利用 - §2.1 `Authorization: Bearer ` リクエストヘッダー - **バックオフ付きリトライ** — 接続エラー時に最大3回リトライ diff --git a/README.md b/README.md index a3d262a..a27a3b7 100644 --- a/README.md +++ b/README.md @@ -42,8 +42,10 @@ Bearer tokens, custom headers, and OAuth 2.1 credentials are forwarded to the re - §3.4–3.5 token polling with `authorization_pending` / `slow_down` (interval +=5 s) / `expired_token` / `access_denied` handling - DCR registers `urn:ietf:params:oauth:grant-type:device_code` in `grant_types` (RFC 7591 §2) - [RFC 7591](https://www.rfc-editor.org/rfc/rfc7591) Dynamic Client Registration - - §3 client registration request (public client with `token_endpoint_auth_method: none`) + - §3 client registration request; `token_endpoint_auth_method` chosen from `token_endpoint_auth_methods_supported` in AS metadata (prefers `none` → `client_secret_post` → `client_secret_basic`) - §3.2.1 `client_secret_expires_at` handling — auto re-register on expiry + - [RFC 6749](https://www.rfc-editor.org/rfc/rfc6749) OAuth 2.0 + - §2.3.1 `client_secret_basic`: `Authorization: Basic` header with percent-encoded credentials (applied to code exchange, token refresh, and Device Authorization Grant polling) - [RFC 6750](https://www.rfc-editor.org/rfc/rfc6750) Bearer Token usage - §2.1 `Authorization: Bearer ` request header - **Retry with backoff** — retries up to 3 times on connection errors diff --git a/WORKAROUNDS.md b/WORKAROUNDS.md index 6329cd5..cbb72ab 100644 --- a/WORKAROUNDS.md +++ b/WORKAROUNDS.md @@ -35,8 +35,10 @@ Works around known issues in Claude Code's HTTP transport: - **Cannot connect to servers that only support static bearer tokens** — mcp-remote always initiates an OAuth discovery handshake before any tool call; servers that authenticate via static bearer tokens (e.g. [Zabbix MCP Server](https://github.com/initMAX/zabbix-mcp-server)) respond with 404 on `/.well-known/oauth-authorization-server` and the proxy gives up before the bearer-auth path is ever reached ([zabbix-mcp-server#36](https://github.com/initMAX/zabbix-mcp-server/issues/36)); mcp-stdio connects directly with `--bearer-token YOUR_TOKEN http://your-server:8080/mcp`, skipping OAuth entirely. - **OAuth token exchange fails with URL-encoded responses** — TypeScript SDK's token exchange assumes the response is always JSON; servers that return `application/x-www-form-urlencoded` (e.g. GitHub OAuth) cause a JSON parse error, blocking authentication ([typescript-sdk#759](https://github.com/modelcontextprotocol/typescript-sdk/issues/759)); mcp-stdio's `_parse_token_response()` checks the `Content-Type` header and parses `application/x-www-form-urlencoded` via `urllib.parse.parse_qs`, so GitHub MCP and similar servers work without extra configuration. - **No proactive token refresh window** — mcp-remote (and `adaptOAuthProvider` in TypeScript SDK) only refreshes the access token after a 401 has already fired, with no early-refresh leeway. ASes that issue refresh tokens whose lifetime is barely longer than the access token's leave no margin for clock skew, so a refresh attempt can race the token expiry and fail ([mcp-remote#252](https://github.com/geelen/mcp-remote/issues/252), [typescript-sdk#1954](https://github.com/modelcontextprotocol/typescript-sdk/issues/1954)). mcp-stdio's `ensure_token()` performs proactive refresh: a cached access token is treated as expired when its expiry is within `--oauth-refresh-leeway` seconds (default 60, configurable via flag or `MCP_OAUTH_REFRESH_LEEWAY`), so the refresh hits well before the AS revokes the token. +- **Session affinity cookies not forwarded across requests** — mcp-remote creates a new HTTP client per request, so load-balancer session cookies (e.g. `AWSALB`, `AWSALBCORS`) returned in a `Set-Cookie` response header are discarded; subsequent requests land on a different backend node, breaking server-side session state ([mcp-remote#168](https://github.com/geelen/mcp-remote/issues/168)). mcp-stdio reuses a single `httpx.Client` instance across all requests within a session; httpx automatically stores `Set-Cookie` values and re-sends them on subsequent requests to the same origin, so ALB and similar sticky-session cookies are forwarded transparently without any extra configuration. - **No OAuth support in headless/SSH environments** — mcp-remote's OAuth flow requires opening a browser window, making it unusable in SSH sessions, CI/CD pipelines, or other browserless environments; there is no Device Authorization Grant support ([mcp-remote#228](https://github.com/geelen/mcp-remote/issues/228)). mcp-stdio supports RFC 8628 Device Authorization Grant via `--oauth-device`: it displays a short user code and verification URI on stderr so the user can authenticate from any browser, while the device polls the token endpoint in the background. - **`resource` indicator gets a trailing slash appended** — TypeScript SDK normalises the resource URL via `new URL(...).href`, converting `https://api.example.com` to `https://api.example.com/`; Atlassian authv2 and similar servers reject this with `InvalidTargetError: Incorrect resource parameters` ([typescript-sdk#1968](https://github.com/modelcontextprotocol/typescript-sdk/issues/1968), [mcp-remote#261](https://github.com/geelen/mcp-remote/issues/261)); mcp-stdio passes `resource=server_url` verbatim in both code exchange and refresh requests — Python's URL handling does not add trailing slashes. +- **DCR hardcodes `token_endpoint_auth_method: none`, breaking confidential-client servers** — mcp-remote always registers with `token_endpoint_auth_method: none` and sends `client_id`/`client_secret` in the POST body; authorization servers that publish only `client_secret_basic` in `token_endpoint_auth_methods_supported` (e.g. Microsoft Entra ID v2, some enterprise OIDC providers) reject the resulting token request ([mcp-remote#184](https://github.com/geelen/mcp-remote/issues/184), [mcp-remote#217](https://github.com/geelen/mcp-remote/issues/217)); mcp-stdio reads `token_endpoint_auth_methods_supported` from RFC 8414 AS metadata, picks the best supported method (`none` → `client_secret_post` → `client_secret_basic`), registers with that method via DCR, and applies it consistently across code exchange, token refresh, and Device Authorization Grant polling — `client_secret_basic` sends credentials as `Authorization: Basic base64(percent_encode(client_id):percent_encode(client_secret))` per RFC 6749 §2.3.1. ## Windows diff --git a/src/mcp_stdio/oauth.py b/src/mcp_stdio/oauth.py index 610e784..6709072 100644 --- a/src/mcp_stdio/oauth.py +++ b/src/mcp_stdio/oauth.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any -from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlsplit, urlunsplit +from urllib.parse import ParseResult, parse_qs, quote, urlencode, urlparse, urlsplit, urlunsplit import httpx @@ -35,6 +35,7 @@ class OAuthMetadata: token_endpoint: str registration_endpoint: str | None = None device_authorization_endpoint: str | None = None + token_endpoint_auth_methods_supported: list[str] | None = None # --------------------------------------------------------------------------- @@ -248,12 +249,14 @@ def _fetch_authorization_server_metadata( f"warning: RFC 8414 §3 issuer mismatch — " f"expected {auth_server_url!r}, got {issuer!r}" ) + methods = data.get("token_endpoint_auth_methods_supported") return OAuthMetadata( authorization_endpoint=data.get("authorization_endpoint") or f"{auth_server_url}/authorize", token_endpoint=data.get("token_endpoint") or f"{auth_server_url}/token", registration_endpoint=data.get("registration_endpoint") or None, device_authorization_endpoint=data.get("device_authorization_endpoint") or None, + token_endpoint_auth_methods_supported=methods if isinstance(methods, list) else None, ) except Exception: pass @@ -370,6 +373,32 @@ def discover_oauth_metadata( # Dynamic Client Registration (RFC 7591) # --------------------------------------------------------------------------- +# Methods supported by mcp-stdio, in preference order (most-compatible first). +# "none" covers public clients; "client_secret_post" and "client_secret_basic" +# cover confidential clients per RFC 6749 §2.3. +_SUPPORTED_AUTH_METHODS = ("none", "client_secret_post", "client_secret_basic") + + +def _pick_token_endpoint_auth_method(supported: list[str] | None) -> str: + """Pick the best token endpoint auth method from AS-advertised list. + + Returns the first entry of ``_SUPPORTED_AUTH_METHODS`` that the AS also + supports. Falls back to ``"none"`` when the AS list is absent (pre-RFC 8414 + servers) or only advertises methods that mcp-stdio does not implement + (e.g. ``private_key_jwt``), with a warning in the latter case. + """ + if not supported: + return "none" + for method in _SUPPORTED_AUTH_METHODS: + if method in supported: + return method + log( + f"warning: AS token_endpoint_auth_methods_supported {supported!r} " + f"contains no methods supported by mcp-stdio " + f"({', '.join(_SUPPORTED_AUTH_METHODS)}); defaulting to 'none'" + ) + return "none" + @dataclass class ClientRegistration: @@ -378,6 +407,7 @@ class ClientRegistration: client_id: str client_secret: str | None = None client_secret_expires_at: float | None = None # RFC 7591 §3.2.1; None = no expiry + auth_method: str = "none" def _is_client_secret_expired(cached: TokenData) -> bool: @@ -407,11 +437,14 @@ def register_client( "Provide a --client-id instead." ) + auth_method = _pick_token_endpoint_auth_method( + metadata.token_endpoint_auth_methods_supported + ) if device_flow: body: dict[str, object] = { "client_name": "mcp-stdio", "grant_types": ["urn:ietf:params:oauth:grant-type:device_code", "refresh_token"], - "token_endpoint_auth_method": "none", + "token_endpoint_auth_method": auth_method, } else: body = { @@ -419,7 +452,7 @@ def register_client( "redirect_uris": [redirect_uri], "response_types": ["code"], "grant_types": ["authorization_code", "refresh_token"], - "token_endpoint_auth_method": "none", + "token_endpoint_auth_method": auth_method, } resp = client.post( metadata.registration_endpoint, @@ -438,6 +471,7 @@ def register_client( client_id=data["client_id"], client_secret=data.get("client_secret"), client_secret_expires_at=expiry, + auth_method=auth_method, ) @@ -552,33 +586,47 @@ def exchange_code( client: httpx.Client, *, resource: str | None = None, + auth_method: str = "none", ) -> dict[str, Any]: """Exchange authorization code for tokens. Args: resource: RFC 8707 resource indicator (the MCP server URL). + auth_method: Token endpoint authentication method (RFC 6749 §2.3). + ``"client_secret_basic"`` sends credentials in an ``Authorization: + Basic`` header per RFC 6749 §2.3.1; ``"none"`` / ``"client_secret_post"`` + keep them in the request body. Returns the raw token response dict. """ data: dict[str, str] = { "grant_type": "authorization_code", "code": code, - "client_id": client_id, "redirect_uri": redirect_uri, "code_verifier": code_verifier, } - if client_secret: - data["client_secret"] = client_secret + req_headers: dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + if auth_method == "client_secret_basic" and client_secret: + # RFC 6749 §2.3.1: percent-encode client_id and client_secret before + # base64-encoding for HTTP Basic auth. + creds = base64.b64encode( + f"{quote(client_id, safe='')}:{quote(client_secret, safe='')}".encode() + ).decode() + req_headers["Authorization"] = f"Basic {creds}" + else: + data["client_id"] = client_id + if client_secret: + data["client_secret"] = client_secret if resource: data["resource"] = resource resp = client.post( metadata.token_endpoint, data=data, - headers={ - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json", - }, + headers=req_headers, ) return _parse_token_response(resp) @@ -591,31 +639,40 @@ def refresh_access_token( client: httpx.Client, *, resource: str | None = None, + auth_method: str = "none", ) -> dict[str, Any]: """Refresh an access token. Args: resource: RFC 8707 resource indicator (the MCP server URL). + auth_method: Token endpoint authentication method (RFC 6749 §2.3). Returns the raw token response dict. """ data: dict[str, str] = { "grant_type": "refresh_token", "refresh_token": refresh_token, - "client_id": client_id, } - if client_secret: - data["client_secret"] = client_secret + req_headers: dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + if auth_method == "client_secret_basic" and client_secret: + creds = base64.b64encode( + f"{quote(client_id, safe='')}:{quote(client_secret, safe='')}".encode() + ).decode() + req_headers["Authorization"] = f"Basic {creds}" + else: + data["client_id"] = client_id + if client_secret: + data["client_secret"] = client_secret if resource: data["resource"] = resource resp = client.post( token_endpoint, data=data, - headers={ - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json", - }, + headers=req_headers, ) return _parse_token_response(resp) @@ -633,6 +690,7 @@ def _token_response_to_data( *, previous_refresh_token: str | None = None, client_secret_expires_at: float | None = None, + auth_method: str = "none", ) -> TokenData: """Convert a raw token response to TokenData. @@ -655,6 +713,7 @@ def _token_response_to_data( token_endpoint=metadata.token_endpoint, authorization_endpoint=metadata.authorization_endpoint, registration_endpoint=metadata.registration_endpoint, + token_endpoint_auth_method=auth_method, ) @@ -683,6 +742,7 @@ def refresh_cached_token( log("OAuth client_secret expired (RFC 7591 §3.2.1) — cannot refresh") return None log("access token expired, attempting refresh") + auth_method = cached.token_endpoint_auth_method try: raw = refresh_access_token( cached.token_endpoint, @@ -691,6 +751,7 @@ def refresh_cached_token( cached.refresh_token, client, resource=server_url, + auth_method=auth_method, ) except Exception as e: log(f"token refresh failed: {e}") @@ -707,6 +768,7 @@ def refresh_cached_token( cached.client_secret, previous_refresh_token=cached.refresh_token, client_secret_expires_at=cached.client_secret_expires_at, + auth_method=auth_method, ) save_token(server_url, data) log("token refreshed successfully") @@ -740,17 +802,20 @@ def _run_authorization_flow( cid = client_id_override csecret: str | None = None cse_at: float | None = None + auth_method = "none" if not cid: if cached and cached.client_id and not _is_client_secret_expired(cached): cid = cached.client_id csecret = cached.client_secret cse_at = cached.client_secret_expires_at + auth_method = cached.token_endpoint_auth_method else: log("registering OAuth client") reg = register_client(metadata, redirect_uri, client) cid = reg.client_id csecret = reg.client_secret cse_at = reg.client_secret_expires_at + auth_method = reg.auth_method log(f"registered client: {cid}") assert cid is not None @@ -823,9 +888,10 @@ def serve() -> None: redirect_uri, client, resource=server_url, + auth_method=auth_method, ) data = _token_response_to_data( - raw, metadata, cid, csecret, client_secret_expires_at=cse_at + raw, metadata, cid, csecret, client_secret_expires_at=cse_at, auth_method=auth_method ) save_token(server_url, data) log("OAuth token obtained and saved") @@ -857,17 +923,20 @@ def _run_device_authorization_flow( cid = client_id_override csecret: str | None = None cse_at: float | None = None + auth_method = "none" if not cid: if cached and cached.client_id and not _is_client_secret_expired(cached): cid = cached.client_id csecret = cached.client_secret cse_at = cached.client_secret_expires_at + auth_method = cached.token_endpoint_auth_method elif metadata.registration_endpoint: log("registering OAuth client for device flow") reg = register_client(metadata, "", client, device_flow=True) cid = reg.client_id csecret = reg.client_secret cse_at = reg.client_secret_expires_at + auth_method = reg.auth_method log(f"registered client: {cid}") else: raise ValueError( @@ -878,19 +947,28 @@ def _run_device_authorization_flow( # Step 1: Device Authorization Request (RFC 8628 §3.1) da_params: dict[str, str] = { - "client_id": cid, "resource": server_url, } if scope: da_params["scope"] = scope + da_headers: dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + if auth_method == "client_secret_basic" and csecret: + creds = base64.b64encode( + f"{quote(cid, safe='')}:{quote(csecret, safe='')}".encode() + ).decode() + da_headers["Authorization"] = f"Basic {creds}" + else: + da_params["client_id"] = cid + if csecret: + da_params["client_secret"] = csecret da_resp = client.post( metadata.device_authorization_endpoint, data=da_params, - headers={ - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json", - }, + headers=da_headers, ) da_resp.raise_for_status() da = da_resp.json() @@ -918,19 +996,26 @@ def _run_device_authorization_flow( poll_data: dict[str, str] = { "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, - "client_id": cid, } - if csecret: - poll_data["client_secret"] = csecret + poll_headers: dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + } + if auth_method == "client_secret_basic" and csecret: + creds = base64.b64encode( + f"{quote(cid, safe='')}:{quote(csecret, safe='')}".encode() + ).decode() + poll_headers["Authorization"] = f"Basic {creds}" + else: + poll_data["client_id"] = cid + if csecret: + poll_data["client_secret"] = csecret try: tok_resp = client.post( metadata.token_endpoint, data=poll_data, - headers={ - "Content-Type": "application/x-www-form-urlencoded", - "Accept": "application/json", - }, + headers=poll_headers, ) except Exception as exc: log(f"device flow poll error: {exc}") @@ -940,7 +1025,7 @@ def _run_device_authorization_flow( if tok_resp.status_code == 200: raw = _parse_token_response(tok_resp) data = _token_response_to_data( - raw, metadata, cid, csecret, client_secret_expires_at=cse_at + raw, metadata, cid, csecret, client_secret_expires_at=cse_at, auth_method=auth_method ) save_token(server_url, data) log("device flow token obtained and saved") diff --git a/src/mcp_stdio/token_store.py b/src/mcp_stdio/token_store.py index 6e38b2b..d977c36 100644 --- a/src/mcp_stdio/token_store.py +++ b/src/mcp_stdio/token_store.py @@ -35,6 +35,8 @@ class TokenData: token_endpoint: str = "" authorization_endpoint: str = "" registration_endpoint: str | None = None + # Token endpoint authentication method (RFC 6749 §2.3 / RFC 8414) + token_endpoint_auth_method: str = "none" def _ensure_store_dir() -> None: diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 39cb3f6..267cfe5 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -19,6 +19,7 @@ _make_callback_handler, _parse_resource_metadata_hint, _parse_token_response, + _pick_token_endpoint_auth_method, _probe_www_authenticate, _run_device_authorization_flow, _token_response_to_data, @@ -3145,3 +3146,383 @@ def test_device_flow_flag_routes_to_device_flow(self, httpx_mock, tmp_path, monk # Verify device_authorization endpoint was actually called reqs = httpx_mock.get_requests() assert any(str(r.url).startswith(DEVICE_AUTH_URL) for r in reqs) + + +# --- _pick_token_endpoint_auth_method --- + + +class TestPickTokenEndpointAuthMethod: + def test_none_when_supported_is_absent(self): + assert _pick_token_endpoint_auth_method(None) == "none" + + def test_none_when_list_is_empty(self): + assert _pick_token_endpoint_auth_method([]) == "none" + + def test_prefers_none_over_post(self): + assert _pick_token_endpoint_auth_method(["none", "client_secret_post"]) == "none" + + def test_prefers_none_over_basic(self): + assert _pick_token_endpoint_auth_method(["client_secret_basic", "none"]) == "none" + + def test_prefers_post_over_basic(self): + assert _pick_token_endpoint_auth_method(["client_secret_basic", "client_secret_post"]) == "client_secret_post" + + def test_selects_client_secret_basic_when_only_option(self): + assert _pick_token_endpoint_auth_method(["client_secret_basic"]) == "client_secret_basic" + + def test_selects_client_secret_post(self): + assert _pick_token_endpoint_auth_method(["client_secret_post"]) == "client_secret_post" + + def test_falls_back_to_none_for_unsupported_methods(self, capsys): + """AS that only advertises private_key_jwt → warn + default none.""" + result = _pick_token_endpoint_auth_method(["private_key_jwt", "tls_client_auth"]) + assert result == "none" + # warning should have been emitted + captured = capsys.readouterr() + assert "warning" in captured.err.lower() + + +# --- token_endpoint_auth_methods_supported in discovery --- + + +class TestDiscoverMetadataAuthMethods: + def test_parses_token_endpoint_auth_methods_supported(self, httpx_mock): + """RFC 8414 token_endpoint_auth_methods_supported is captured.""" + httpx_mock.add_response( + url="https://example.com/.well-known/oauth-protected-resource", + status_code=404, + ) + httpx_mock.add_response( + url="https://example.com/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + "token_endpoint_auth_methods_supported": ["client_secret_basic", "none"], + }, + ) + meta = discover_oauth_metadata("https://example.com/mcp", httpx.Client()) + assert meta.token_endpoint_auth_methods_supported == ["client_secret_basic", "none"] + + def test_missing_field_is_none(self, httpx_mock): + """Older AS metadata without the field → None (public client default).""" + httpx_mock.add_response( + url="https://example.com/.well-known/oauth-protected-resource", + status_code=404, + ) + httpx_mock.add_response( + url="https://example.com/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": "https://example.com/authorize", + "token_endpoint": "https://example.com/token", + }, + ) + meta = discover_oauth_metadata("https://example.com/mcp", httpx.Client()) + assert meta.token_endpoint_auth_methods_supported is None + + +# --- client_secret_basic in exchange_code --- + + +class TestExchangeCodeBasicAuth: + META = OAuthMetadata( + authorization_endpoint="https://as.example.com/authorize", + token_endpoint="https://as.example.com/token", + ) + + def test_basic_auth_header_sent(self, httpx_mock): + """client_secret_basic: credentials in Authorization header, not body.""" + httpx_mock.add_response( + url="https://as.example.com/token", + json={"access_token": "at"}, + ) + client = httpx.Client() + exchange_code( + self.META, + "my_client", + "my_secret", + "code", + "verifier", + "http://127.0.0.1:9/cb", + client, + auth_method="client_secret_basic", + ) + req = httpx_mock.get_requests()[0] + expected = base64.b64encode(b"my_client:my_secret").decode() + assert req.headers.get("authorization") == f"Basic {expected}" + assert b"client_id" not in req.content + assert b"client_secret" not in req.content + + def test_basic_auth_with_special_chars_percent_encoded(self, httpx_mock): + """RFC 6749 §2.3.1: client_id / client_secret are percent-encoded.""" + httpx_mock.add_response( + url="https://as.example.com/token", + json={"access_token": "at"}, + ) + client = httpx.Client() + exchange_code( + self.META, + "c:id", + "s:ecret", + "code", + "v", + "http://127.0.0.1:9/cb", + client, + auth_method="client_secret_basic", + ) + req = httpx_mock.get_requests()[0] + # percent-encode colons: "c%3Aid:s%3Aecret" + expected = base64.b64encode(b"c%3Aid:s%3Aecret").decode() + assert req.headers.get("authorization") == f"Basic {expected}" + + def test_client_secret_post_sends_credentials_in_body(self, httpx_mock): + httpx_mock.add_response( + url="https://as.example.com/token", + json={"access_token": "at"}, + ) + client = httpx.Client() + exchange_code( + self.META, + "cid", + "csec", + "code", + "v", + "http://127.0.0.1:9/cb", + client, + auth_method="client_secret_post", + ) + req = httpx_mock.get_requests()[0] + assert b"client_id=cid" in req.content + assert b"client_secret=csec" in req.content + assert "authorization" not in req.headers + + +# --- client_secret_basic in refresh_access_token --- + + +class TestRefreshTokenBasicAuth: + def test_basic_auth_header_sent(self, httpx_mock): + httpx_mock.add_response( + url="https://as.example.com/token", + json={"access_token": "new_at"}, + ) + client = httpx.Client() + refresh_access_token( + "https://as.example.com/token", + "my_client", + "my_secret", + "rt", + client, + auth_method="client_secret_basic", + ) + req = httpx_mock.get_requests()[0] + expected = base64.b64encode(b"my_client:my_secret").decode() + assert req.headers.get("authorization") == f"Basic {expected}" + assert b"client_id" not in req.content + assert b"client_secret" not in req.content + + def test_none_method_sends_credentials_in_body(self, httpx_mock): + httpx_mock.add_response( + url="https://as.example.com/token", + json={"access_token": "new_at"}, + ) + client = httpx.Client() + refresh_access_token( + "https://as.example.com/token", + "cid", + "csec", + "rt", + client, + auth_method="none", + ) + req = httpx_mock.get_requests()[0] + assert b"client_id=cid" in req.content + assert b"client_secret=csec" in req.content + assert "authorization" not in req.headers + + +# --- register_client picks auth method from metadata --- + + +class TestRegisterClientAuthMethod: + def test_picks_client_secret_basic_from_metadata(self, httpx_mock): + """AS that only supports client_secret_basic → DCR registers with that method.""" + httpx_mock.add_response( + url="https://as.example.com/register", + json={"client_id": "cid", "client_secret": "csec"}, + ) + meta = OAuthMetadata( + authorization_endpoint="https://as.example.com/authorize", + token_endpoint="https://as.example.com/token", + registration_endpoint="https://as.example.com/register", + token_endpoint_auth_methods_supported=["client_secret_basic"], + ) + reg = register_client(meta, "http://127.0.0.1:9/cb", httpx.Client()) + assert reg.auth_method == "client_secret_basic" + body = json.loads(httpx_mock.get_requests()[0].content) + assert body["token_endpoint_auth_method"] == "client_secret_basic" + + def test_defaults_to_none_when_field_absent(self, httpx_mock): + httpx_mock.add_response( + url="https://as.example.com/register", + json={"client_id": "cid"}, + ) + meta = OAuthMetadata( + authorization_endpoint="https://as.example.com/authorize", + token_endpoint="https://as.example.com/token", + registration_endpoint="https://as.example.com/register", + ) + reg = register_client(meta, "http://127.0.0.1:9/cb", httpx.Client()) + assert reg.auth_method == "none" + + def test_prefers_none_when_both_none_and_basic_supported(self, httpx_mock): + httpx_mock.add_response( + url="https://as.example.com/register", + json={"client_id": "cid"}, + ) + meta = OAuthMetadata( + authorization_endpoint="https://as.example.com/authorize", + token_endpoint="https://as.example.com/token", + registration_endpoint="https://as.example.com/register", + token_endpoint_auth_methods_supported=["client_secret_basic", "none"], + ) + reg = register_client(meta, "http://127.0.0.1:9/cb", httpx.Client()) + assert reg.auth_method == "none" + + +# --- token_endpoint_auth_method persisted and reused --- + + +class TestTokenEndpointAuthMethodPersistence: + def test_auth_method_stored_in_token_data(self): + """_token_response_to_data persists auth_method for subsequent refreshes.""" + meta = OAuthMetadata( + authorization_endpoint="https://as.example.com/authorize", + token_endpoint="https://as.example.com/token", + ) + data = _token_response_to_data( + {"access_token": "at", "expires_in": 3600}, + meta, + "cid", + "csec", + auth_method="client_secret_basic", + ) + assert data.token_endpoint_auth_method == "client_secret_basic" + + def test_default_auth_method_is_none(self): + meta = OAuthMetadata( + authorization_endpoint="https://as.example.com/authorize", + token_endpoint="https://as.example.com/token", + ) + data = _token_response_to_data( + {"access_token": "at"}, + meta, + "cid", + None, + ) + assert data.token_endpoint_auth_method == "none" + + def test_legacy_token_loads_with_none_default(self, tmp_path, monkeypatch): + """TokenData(**old_entry) without token_endpoint_auth_method defaults to 'none'.""" + import json as _json + from mcp_stdio.token_store import load_token + + store = tmp_path / "tokens.json" + monkeypatch.setattr("mcp_stdio.token_store._STORE_DIR", tmp_path) + monkeypatch.setattr("mcp_stdio.token_store._STORE_FILE", store) + + # Write a token entry WITHOUT token_endpoint_auth_method (legacy format) + store.write_text(_json.dumps({ + "https://example.com/mcp": { + "access_token": "at", + "token_type": "Bearer", + "expires_at": None, + "refresh_token": "rt", + "scope": None, + "client_id": "cid", + "client_secret": None, + "client_secret_expires_at": None, + "token_endpoint": "https://example.com/token", + "authorization_endpoint": "https://example.com/authorize", + "registration_endpoint": None, + } + })) + loaded = load_token("https://example.com/mcp") + assert loaded is not None + assert loaded.token_endpoint_auth_method == "none" + + def test_refresh_cached_token_uses_stored_auth_method( + self, tmp_path, monkeypatch, httpx_mock + ): + """refresh_cached_token passes token_endpoint_auth_method from cache.""" + from mcp_stdio.token_store import save_token + + monkeypatch.setattr("mcp_stdio.token_store._STORE_DIR", tmp_path) + monkeypatch.setattr("mcp_stdio.token_store._STORE_FILE", tmp_path / "tokens.json") + + save_token( + "https://example.com/mcp", + TokenData( + access_token="stale", + expires_at=time.time() - 1, + refresh_token="rt", + client_id="cid", + client_secret="csec", + token_endpoint="https://example.com/token", + authorization_endpoint="https://example.com/authorize", + token_endpoint_auth_method="client_secret_basic", + ), + ) + httpx_mock.add_response( + url="https://example.com/token", + json={"access_token": "new_at", "expires_in": 3600}, + ) + data = refresh_cached_token("https://example.com/mcp", httpx.Client()) + assert data is not None + assert data.access_token == "new_at" + # Verify Basic auth header was used + req = httpx_mock.get_requests()[0] + expected = base64.b64encode(b"cid:csec").decode() + assert req.headers.get("authorization") == f"Basic {expected}" + assert b"client_id" not in req.content + # Persisted method is preserved in the refreshed token + assert data.token_endpoint_auth_method == "client_secret_basic" + + +# --- client_secret_basic in device authorization request (Step 1) --- + + +class TestDeviceAuthStepOneBasicAuth: + def test_device_authorization_request_uses_basic_auth( + self, httpx_mock, tmp_path, monkeypatch + ): + """client_secret_basic: DA request (Step 1) puts credentials in Authorization header.""" + monkeypatch.setattr("mcp_stdio.token_store._STORE_DIR", tmp_path) + monkeypatch.setattr("mcp_stdio.token_store._STORE_FILE", tmp_path / "tokens.json") + + httpx_mock.add_response(url=DEVICE_AUTH_URL, json=_da_response()) + httpx_mock.add_response( + url=TOKEN_URL, + json={"access_token": "at", "token_type": "Bearer", "expires_in": 3600}, + ) + + # Provide cached client with client_secret_basic already selected. + cached = TokenData( + access_token="stale", + client_id="da_cid", + client_secret="da_secret", + token_endpoint="https://api.example.com/token", + authorization_endpoint=AUTH_URL, + token_endpoint_auth_method="client_secret_basic", + ) + client = httpx.Client() + _run_device_authorization_flow( + MCP_URL, client, metadata=_device_meta(), cached=cached + ) + + # First request is the device authorization request (Step 1). + da_req = httpx_mock.get_requests()[0] + assert da_req.url == DEVICE_AUTH_URL + expected = base64.b64encode(b"da_cid:da_secret").decode() + assert da_req.headers.get("authorization") == f"Basic {expected}" + assert b"client_id" not in da_req.content + assert b"client_secret" not in da_req.content