Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,10 @@ public sealed class ClientOAuthOptions
/// </para>
/// </remarks>
public IDictionary<string, string> AdditionalAuthorizationParameters { get; set; } = new Dictionary<string, string>();

/// <summary>
/// Gets or sets the token cache to use for storing and retrieving tokens beyond the lifetime of the transport.
/// If none is provided, tokens will be cached with the transport.
/// </summary>
public ITokenCache? TokenCache { get; set; }
}
39 changes: 25 additions & 14 deletions src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ internal sealed partial class ClientOAuthProvider
private string? _clientId;
private string? _clientSecret;

private TokenContainer? _token;
private ITokenCache _tokenCache;
private AuthorizationServerMetadata? _authServerMetadata;

/// <summary>
Expand Down Expand Up @@ -85,6 +85,7 @@ public ClientOAuthProvider(
_dcrClientUri = options.DynamicClientRegistration?.ClientUri;
_dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken;
_dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate;
_tokenCache = options.TokenCache ?? new InMemoryTokenCache();
}

/// <summary>
Expand Down Expand Up @@ -138,20 +139,22 @@ public ClientOAuthProvider(
{
ThrowIfNotBearerScheme(scheme);

var tokens = await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false);

// Return the token if it's valid
if (_token != null && _token.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5))
if (tokens != null && tokens.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(5))
{
return _token.AccessToken;
return tokens.AccessToken;
}

// Try to refresh the token if we have a refresh token
if (_token?.RefreshToken != null && _authServerMetadata != null)
if (tokens?.RefreshToken != null && _authServerMetadata != null)
{
var newToken = await RefreshTokenAsync(_token.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
if (newToken != null)
var newTokens = await RefreshTokenAsync(tokens.RefreshToken, resourceUri, _authServerMetadata, cancellationToken).ConfigureAwait(false);
if (newTokens != null)
{
_token = newToken;
return _token.AccessToken;
await _tokenCache.StoreTokensAsync(newTokens, cancellationToken).ConfigureAwait(false);
return newTokens.AccessToken;
}
}

Expand Down Expand Up @@ -230,14 +233,14 @@ private async Task PerformOAuthAuthorizationAsync(
}

// Perform the OAuth flow
var token = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false);
var tokens = await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false);

if (token is null)
if (tokens is null)
{
ThrowFailedToHandleUnauthorizedResponse($"The {nameof(AuthorizationRedirectDelegate)} returned a null or empty token.");
}

_token = token;
await _tokenCache.StoreTokensAsync(tokens, cancellationToken).ConfigureAwait(false);
LogOAuthAuthorizationCompleted();
}

Expand Down Expand Up @@ -409,15 +412,23 @@ private async Task<TokenContainer> FetchTokenAsync(HttpRequestMessage request, C
httpResponse.EnsureSuccessStatusCode();

using var stream = await httpResponse.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenContainer, cancellationToken).ConfigureAwait(false);
var tokenResponse = await JsonSerializer.DeserializeAsync(stream, McpJsonUtilities.JsonContext.Default.TokenResponse, cancellationToken).ConfigureAwait(false);

if (tokenResponse is null)
{
ThrowFailedToHandleUnauthorizedResponse($"The token endpoint '{request.RequestUri}' returned an empty response.");
}

tokenResponse.ObtainedAt = DateTimeOffset.UtcNow;
return tokenResponse;
return new()
{
AccessToken = tokenResponse.AccessToken,
RefreshToken = tokenResponse.RefreshToken,
ExpiresIn = tokenResponse.ExpiresIn,
ExtExpiresIn = tokenResponse.ExtExpiresIn,
TokenType = tokenResponse.TokenType,
Scope = tokenResponse.Scope,
ObtainedAt = DateTimeOffset.UtcNow,
};
}

/// <summary>
Expand Down
17 changes: 17 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/ITokenCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace ModelContextProtocol.Authentication;

/// <summary>
/// Allows the client to cache access tokens beyond the lifetime of the transport.
/// </summary>
public interface ITokenCache
{
/// <summary>
/// Cache the token. After a new access token is acquired, this method is invoked to store it.
/// </summary>
ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken);

/// <summary>
/// Get the cached token. This method is invoked for every request.
/// </summary>
ValueTask<TokenContainer?> GetTokensAsync(CancellationToken cancellationToken);
}
27 changes: 27 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/InMemoryTokenCache.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

namespace ModelContextProtocol.Authentication;

/// <summary>
/// Caches the token in-memory within this instance.
/// </summary>
internal class InMemoryTokenCache : ITokenCache
{
private TokenContainer? _tokens;

/// <summary>
/// Cache the token.
/// </summary>
public ValueTask StoreTokensAsync(TokenContainer tokens, CancellationToken cancellationToken)
{
_tokens = tokens;
return default;
}

/// <summary>
/// Get the cached token.
/// </summary>
public ValueTask<TokenContainer?> GetTokensAsync(CancellationToken cancellationToken)
{
return new ValueTask<TokenContainer?>(_tokens);
}
}
14 changes: 2 additions & 12 deletions src/ModelContextProtocol.Core/Authentication/TokenContainer.cs
Original file line number Diff line number Diff line change
@@ -1,57 +1,47 @@
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Authentication;

/// <summary>
/// Represents a token response from the OAuth server.
/// Represents a cacheable combination of tokens ready to be used for authentication.
/// </summary>
internal sealed class TokenContainer
public class TokenContainer
{
/// <summary>
/// Gets or sets the access token.
/// </summary>
[JsonPropertyName("access_token")]
public string AccessToken { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the refresh token.
/// </summary>
[JsonPropertyName("refresh_token")]
public string? RefreshToken { get; set; }

/// <summary>
/// Gets or sets the number of seconds until the access token expires.
/// </summary>
[JsonPropertyName("expires_in")]
public int ExpiresIn { get; set; }

/// <summary>
/// Gets or sets the extended expiration time in seconds.
/// </summary>
[JsonPropertyName("ext_expires_in")]
public int ExtExpiresIn { get; set; }

/// <summary>
/// Gets or sets the token type (typically "Bearer").
/// </summary>
[JsonPropertyName("token_type")]
public string TokenType { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the scope of the access token.
/// </summary>
[JsonPropertyName("scope")]
public string Scope { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the timestamp when the token was obtained.
/// </summary>
[JsonIgnore]
public DateTimeOffset ObtainedAt { get; set; }

/// <summary>
/// Gets the timestamp when the token expires, calculated from ObtainedAt and ExpiresIn.
/// </summary>
[JsonIgnore]
public DateTimeOffset ExpiresAt => ObtainedAt.AddSeconds(ExpiresIn);
}
45 changes: 45 additions & 0 deletions src/ModelContextProtocol.Core/Authentication/TokenResponse.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using System.Text.Json.Serialization;

namespace ModelContextProtocol.Authentication;

/// <summary>
/// Represents a token response from the OAuth server.
/// </summary>
internal sealed class TokenResponse
{
/// <summary>
/// Gets or sets the access token.
/// </summary>
[JsonPropertyName("access_token")]
public string AccessToken { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the refresh token.
/// </summary>
[JsonPropertyName("refresh_token")]
public string? RefreshToken { get; set; }

/// <summary>
/// Gets or sets the number of seconds until the access token expires.
/// </summary>
[JsonPropertyName("expires_in")]
public int ExpiresIn { get; set; }

/// <summary>
/// Gets or sets the extended expiration time in seconds.
/// </summary>
[JsonPropertyName("ext_expires_in")]
public int ExtExpiresIn { get; set; }

/// <summary>
/// Gets or sets the token type (typically "Bearer").
/// </summary>
[JsonPropertyName("token_type")]
public string TokenType { get; set; } = string.Empty;

/// <summary>
/// Gets or sets the scope of the access token.
/// </summary>
[JsonPropertyName("scope")]
public string Scope { get; set; } = string.Empty;
}
2 changes: 1 addition & 1 deletion src/ModelContextProtocol.Core/McpJsonUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ internal static bool IsValidMcpToolSchema(JsonElement element)

[JsonSerializable(typeof(ProtectedResourceMetadata))]
[JsonSerializable(typeof(AuthorizationServerMetadata))]
[JsonSerializable(typeof(TokenContainer))]
[JsonSerializable(typeof(TokenResponse))]
[JsonSerializable(typeof(DynamicClientRegistrationRequest))]
[JsonSerializable(typeof(DynamicClientRegistrationResponse))]

Expand Down
Loading