diff --git a/src/client/auth.test.ts b/src/client/auth.test.ts index c38f19331..b31b4ab44 100644 --- a/src/client/auth.test.ts +++ b/src/client/auth.test.ts @@ -1587,7 +1587,8 @@ describe('OAuth Authorization', () => { // Mock provider methods for authorization flow (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -1657,7 +1658,8 @@ describe('OAuth Authorization', () => { // Mock provider methods for token exchange (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (mockProvider.codeVerifier as jest.Mock).mockResolvedValue('test-verifier'); (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); @@ -1723,7 +1725,8 @@ describe('OAuth Authorization', () => { // Mock provider methods for token refresh (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (mockProvider.tokens as jest.Mock).mockResolvedValue({ access_token: 'old-access', @@ -1789,7 +1792,8 @@ describe('OAuth Authorization', () => { // Mock provider methods (providerWithCustomValidation.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (providerWithCustomValidation.tokens as jest.Mock).mockResolvedValue(undefined); (providerWithCustomValidation.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -1844,7 +1848,8 @@ describe('OAuth Authorization', () => { // Mock provider methods (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -1902,7 +1907,8 @@ describe('OAuth Authorization', () => { // Mock provider methods (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -1969,7 +1975,8 @@ describe('OAuth Authorization', () => { // Mock provider methods for token exchange (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (mockProvider.codeVerifier as jest.Mock).mockResolvedValue('test-verifier'); (mockProvider.saveTokens as jest.Mock).mockResolvedValue(undefined); @@ -2032,7 +2039,8 @@ describe('OAuth Authorization', () => { // Mock provider methods for token refresh (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (mockProvider.tokens as jest.Mock).mockResolvedValue({ access_token: 'old-access', @@ -2093,7 +2101,8 @@ describe('OAuth Authorization', () => { // Mock provider methods (mockProvider.clientInformation as jest.Mock).mockResolvedValue({ client_id: 'test-client', - client_secret: 'test-secret' + client_secret: 'test-secret', + redirect_uris: ['http://localhost:3000/callback'] }); (mockProvider.tokens as jest.Mock).mockResolvedValue(undefined); (mockProvider.saveCodeVerifier as jest.Mock).mockResolvedValue(undefined); @@ -2155,7 +2164,8 @@ describe('OAuth Authorization', () => { }, clientInformation: jest.fn().mockResolvedValue({ client_id: 'client123', - client_secret: 'secret123' + client_secret: 'secret123', + redirect_uris: ['http://localhost:3000/callback'] }), tokens: jest.fn().mockResolvedValue(undefined), saveTokens: jest.fn(), diff --git a/src/client/auth.ts b/src/client/auth.ts index 3c04f7cb5..1c45f5691 100644 --- a/src/client/auth.ts +++ b/src/client/auth.ts @@ -2,7 +2,6 @@ import pkceChallenge from 'pkce-challenge'; import { LATEST_PROTOCOL_VERSION } from '../types.js'; import { OAuthClientMetadata, - OAuthClientInformation, OAuthTokens, OAuthMetadata, OAuthClientInformationFull, @@ -56,7 +55,7 @@ export interface OAuthClientProvider { * server, or returns `undefined` if the client is not registered with the * server. */ - clientInformation(): OAuthClientInformation | undefined | Promise; + clientInformation(): OAuthClientInformationFull | undefined | Promise; /** * If implemented, this permits the OAuth client to dynamically register with @@ -149,6 +148,10 @@ export class UnauthorizedError extends Error { type ClientAuthMethod = 'client_secret_basic' | 'client_secret_post' | 'none'; +function isClientAuthMethod(method: string): method is ClientAuthMethod { + return ['client_secret_basic', 'client_secret_post', 'none'].includes(method); +} + const AUTHORIZATION_CODE_RESPONSE_TYPE = 'code'; const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256'; @@ -164,7 +167,7 @@ const AUTHORIZATION_CODE_CHALLENGE_METHOD = 'S256'; * @param supportedMethods - Authentication methods supported by the authorization server * @returns The selected authentication method */ -function selectClientAuthMethod(clientInformation: OAuthClientInformation, supportedMethods: string[]): ClientAuthMethod { +function selectClientAuthMethod(clientInformation: OAuthClientInformationFull, supportedMethods: string[]): ClientAuthMethod { const hasClientSecret = clientInformation.client_secret !== undefined; // If server doesn't specify supported methods, use RFC 6749 defaults @@ -172,6 +175,15 @@ function selectClientAuthMethod(clientInformation: OAuthClientInformation, suppo return hasClientSecret ? 'client_secret_post' : 'none'; } + // Prefer the method returned by the server during client registration if valid and supported + if ( + clientInformation.token_endpoint_auth_method && + isClientAuthMethod(clientInformation.token_endpoint_auth_method) && + supportedMethods.includes(clientInformation.token_endpoint_auth_method) + ) { + return clientInformation.token_endpoint_auth_method; + } + // Try methods in priority order (most secure first) if (hasClientSecret && supportedMethods.includes('client_secret_basic')) { return 'client_secret_basic'; @@ -205,7 +217,7 @@ function selectClientAuthMethod(clientInformation: OAuthClientInformation, suppo */ function applyClientAuthentication( method: ClientAuthMethod, - clientInformation: OAuthClientInformation, + clientInformation: OAuthClientInformationFull, headers: Headers, params: URLSearchParams ): void { @@ -790,7 +802,7 @@ export async function startAuthorization( resource }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationFull; redirectUrl: string | URL; scope?: string; state?: string; @@ -873,7 +885,7 @@ export async function exchangeAuthorization( fetchFn }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationFull; authorizationCode: string; codeVerifier: string; redirectUri: string | URL; @@ -952,7 +964,7 @@ export async function refreshAuthorization( fetchFn }: { metadata?: AuthorizationServerMetadata; - clientInformation: OAuthClientInformation; + clientInformation: OAuthClientInformationFull; refreshToken: string; resource?: URL; addClientAuthentication?: OAuthClientProvider['addClientAuthentication']; diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 9e4b73e92..dcc0cca1c 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -355,7 +355,11 @@ describe('SSEClientTransport', () => { get clientMetadata() { return { redirect_uris: ['http://localhost/callback'] }; }, - clientInformation: jest.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), + clientInformation: jest.fn(() => ({ + client_id: 'test-client-id', + client_secret: 'test-client-secret', + redirect_uris: ['http://localhost/callback'] + })), tokens: jest.fn(), saveTokens: jest.fn(), redirectToAuthorization: jest.fn(), @@ -1159,7 +1163,8 @@ describe('SSEClientTransport', () => { const clientInfo = config.clientRegistered ? { client_id: 'test-client-id', - client_secret: 'test-client-secret' + client_secret: 'test-client-secret', + redirect_uris: ['http://localhost/callback'] } : undefined; diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 3c6a9ec4d..794119bb1 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -15,7 +15,11 @@ describe('StreamableHTTPClientTransport', () => { get clientMetadata() { return { redirect_uris: ['http://localhost/callback'] }; }, - clientInformation: jest.fn(() => ({ client_id: 'test-client-id', client_secret: 'test-client-secret' })), + clientInformation: jest.fn(() => ({ + client_id: 'test-client-id', + client_secret: 'test-client-secret', + redirect_uris: ['http://localhost/callback'] + })), tokens: jest.fn(), saveTokens: jest.fn(), redirectToAuthorization: jest.fn(), diff --git a/src/examples/client/simpleOAuthClient.ts b/src/examples/client/simpleOAuthClient.ts index 354886050..2289a4362 100644 --- a/src/examples/client/simpleOAuthClient.ts +++ b/src/examples/client/simpleOAuthClient.ts @@ -6,7 +6,7 @@ import { URL } from 'node:url'; import { exec } from 'node:child_process'; import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; -import { OAuthClientInformation, OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js'; +import { OAuthClientInformationFull, OAuthClientMetadata, OAuthTokens } from '../../shared/auth.js'; import { CallToolRequest, ListToolsRequest, CallToolResultSchema, ListToolsResultSchema } from '../../types.js'; import { OAuthClientProvider, UnauthorizedError } from '../../client/auth.js'; @@ -46,7 +46,7 @@ class InMemoryOAuthClientProvider implements OAuthClientProvider { return this._clientMetadata; } - clientInformation(): OAuthClientInformation | undefined { + clientInformation(): OAuthClientInformationFull | undefined { return this._clientInformation; }