Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"pydantic[email]>=2.11.7",
"pyperclip>=1.9.0",
"openapi-core>=0.19.5",
"py-key-value-aio[disk,memory]>=0.2.1",
"websockets>=15.0.1",
]

Expand Down
279 changes: 80 additions & 199 deletions src/fastmcp/client/auth/oauth.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,33 @@
from __future__ import annotations

import asyncio
import time
import webbrowser
from asyncio import Future
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
from typing import Any
from urllib.parse import urlparse

import anyio
import httpx
from key_value.aio.adapters.pydantic import PydanticAdapter
from key_value.aio.protocols import AsyncKeyValue
from key_value.aio.stores.memory import MemoryStore
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthToken,
)
from mcp.shared.auth import (
OAuthToken as OAuthToken,
)
from pydantic import AnyHttpUrl, BaseModel, TypeAdapter, ValidationError
from pydantic import AnyHttpUrl
from typing_extensions import override
from uvicorn.server import Server

from fastmcp import settings as fastmcp_global_settings
from fastmcp.client.oauth_callback import (
create_oauth_callback_server,
)
from fastmcp.utilities.http import find_available_port
from fastmcp.utilities.logging import get_logger
from fastmcp.utilities.storage import JSONFileStorage

__all__ = ["OAuth"]

Expand All @@ -41,174 +40,6 @@ class ClientNotFoundError(Exception):
pass


class StoredToken(BaseModel):
"""Token storage format with absolute expiry time."""

token_payload: OAuthToken
expires_at: datetime | None


# Create TypeAdapter at module level for efficient parsing
stored_token_adapter = TypeAdapter(StoredToken)


def default_cache_dir() -> Path:
return fastmcp_global_settings.home / "oauth-mcp-client-cache"


class FileTokenStorage(TokenStorage):
"""
File-based token storage implementation for OAuth credentials and tokens.
Implements the mcp.client.auth.TokenStorage protocol.

Each instance is tied to a specific server URL for proper token isolation.
Uses JSONFileStorage internally for consistent file handling.
"""

def __init__(self, server_url: str, cache_dir: Path | None = None):
"""Initialize storage for a specific server URL."""
self.server_url = server_url
# Use JSONFileStorage for actual file operations
self._storage = JSONFileStorage(cache_dir or default_cache_dir())

@staticmethod
def get_base_url(url: str) -> str:
"""Extract the base URL (scheme + host) from a URL."""
parsed = urlparse(url)
return f"{parsed.scheme}://{parsed.netloc}"

def _get_storage_key(self, file_type: Literal["client_info", "tokens"]) -> str:
"""Get the storage key for the specified data type.

JSONFileStorage will handle making the key filesystem-safe.
"""
base_url = self.get_base_url(self.server_url)
return f"{base_url}_{file_type}"

def _get_file_path(self, file_type: Literal["client_info", "tokens"]) -> Path:
"""Get the file path for the specified cache file type.

This method is kept for backward compatibility with tests that access _get_file_path.
"""
key = self._get_storage_key(file_type)
return self._storage._get_file_path(key)

async def get_tokens(self) -> OAuthToken | None:
"""Load tokens from file storage."""
key = self._get_storage_key("tokens")
data = await self._storage.get(key)

if data is None:
return None

try:
# Parse and validate as StoredToken
stored = stored_token_adapter.validate_python(data)

# Check if token is expired
if stored.expires_at is not None:
now = datetime.now(timezone.utc)
if now >= stored.expires_at:
logger.debug(
f"Token expired for {self.get_base_url(self.server_url)}"
)
return None

# Recalculate expires_in to be correct relative to now
if stored.token_payload.expires_in is not None:
remaining = stored.expires_at - now
stored.token_payload.expires_in = max(
0, int(remaining.total_seconds())
)

return stored.token_payload

except ValidationError as e:
logger.debug(
f"Could not validate tokens for {self.get_base_url(self.server_url)}: {e}"
)
return None

async def set_tokens(self, tokens: OAuthToken) -> None:
"""Save tokens to file storage."""
key = self._get_storage_key("tokens")

# Calculate absolute expiry time if expires_in is present
expires_at = None
if tokens.expires_in is not None:
expires_at = datetime.now(timezone.utc) + timedelta(
seconds=tokens.expires_in
)

# Create StoredToken and save using storage
# Note: JSONFileStorage will wrap this in {"data": ..., "timestamp": ...}
stored = StoredToken(token_payload=tokens, expires_at=expires_at)
await self._storage.set(key, stored.model_dump(mode="json"))
logger.debug(f"Saved tokens for {self.get_base_url(self.server_url)}")

async def get_client_info(self) -> OAuthClientInformationFull | None:
"""Load client information from file storage."""
key = self._get_storage_key("client_info")
data = await self._storage.get(key)

if data is None:
return None

try:
client_info = OAuthClientInformationFull.model_validate(data)
# Check if we have corresponding valid tokens
# If no tokens exist, the OAuth flow was incomplete and we should
# force a fresh client registration
tokens = await self.get_tokens()
if tokens is None:
logger.debug(
f"No tokens found for client info at {self.get_base_url(self.server_url)}. "
"OAuth flow may have been incomplete. Clearing client info to force fresh registration."
)
# Clear the incomplete client info
await self._storage.delete(key)
return None

return client_info
except ValidationError as e:
logger.debug(
f"Could not validate client info for {self.get_base_url(self.server_url)}: {e}"
)
return None

async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
"""Save client information to file storage."""
key = self._get_storage_key("client_info")
await self._storage.set(key, client_info.model_dump(mode="json"))
logger.debug(f"Saved client info for {self.get_base_url(self.server_url)}")

def clear(self) -> None:
"""Clear all cached data for this server.

Note: This is a synchronous method for backward compatibility.
Uses direct file operations instead of async storage methods.
"""
file_types: list[Literal["client_info", "tokens"]] = ["client_info", "tokens"]
for file_type in file_types:
# Use the file path directly for synchronous deletion
path = self._get_file_path(file_type)
path.unlink(missing_ok=True)
logger.debug(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")

@classmethod
def clear_all(cls, cache_dir: Path | None = None) -> None:
"""Clear all cached data for all servers."""
cache_dir = cache_dir or default_cache_dir()
if not cache_dir.exists():
return

file_types: list[Literal["client_info", "tokens"]] = ["client_info", "tokens"]
for file_type in file_types:
for file in cache_dir.glob(f"*_{file_type}.json"):
file.unlink(missing_ok=True)
logger.info("Cleared all OAuth client cache data.")


async def check_if_auth_required(
mcp_url: str, httpx_kwargs: dict[str, Any] | None = None
) -> bool:
Expand Down Expand Up @@ -239,6 +70,70 @@ async def check_if_auth_required(
return True


class TokenStorageAdapter(TokenStorage):
_server_url: str
_key_value_store: AsyncKeyValue
_storage_oauth_token: PydanticAdapter[OAuthToken]
_storage_client_info: PydanticAdapter[OAuthClientInformationFull]

def __init__(self, async_key_value: AsyncKeyValue, server_url: str):
self._server_url = server_url
self._key_value_store = async_key_value
self._storage_oauth_token = PydanticAdapter[OAuthToken](
default_collection="mcp-oauth-token",
key_value=async_key_value,
pydantic_model=OAuthToken,
raise_on_validation_error=True,
)
self._storage_client_info = PydanticAdapter[OAuthClientInformationFull](
default_collection="mcp-oauth-client-info",
key_value=async_key_value,
pydantic_model=OAuthClientInformationFull,
raise_on_validation_error=True,
)

def _get_token_cache_key(self) -> str:
return f"{self._server_url}/tokens"

def _get_client_info_cache_key(self) -> str:
return f"{self._server_url}/client_info"

async def clear(self) -> None:
await self._storage_oauth_token.delete(key=self._get_token_cache_key())
await self._storage_client_info.delete(key=self._get_client_info_cache_key())

@override
async def get_tokens(self) -> OAuthToken | None:
return await self._storage_oauth_token.get(key=self._get_token_cache_key())

@override
async def set_tokens(self, tokens: OAuthToken) -> None:
await self._storage_oauth_token.put(
key=self._get_token_cache_key(),
value=tokens,
ttl=tokens.expires_in,
)

@override
async def get_client_info(self) -> OAuthClientInformationFull | None:
return await self._storage_client_info.get(
key=self._get_client_info_cache_key()
)

@override
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
ttl: int | None = None

if client_info.client_secret_expires_at:
ttl = client_info.client_secret_expires_at - int(time.time())

await self._storage_client_info.put(
key=self._get_client_info_cache_key(),
value=client_info,
ttl=ttl,
)


class OAuth(OAuthClientProvider):
"""
OAuth client provider for MCP servers with browser-based authentication.
Expand All @@ -252,7 +147,7 @@ def __init__(
mcp_url: str,
scopes: str | list[str] | None = None,
client_name: str = "FastMCP Client",
token_storage_cache_dir: Path | None = None,
token_storage: AsyncKeyValue | None = None,
Copy link

Copilot AI Oct 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Breaking change: The parameter name changed from token_storage_cache_dir to token_storage and the type changed from Path | None to AsyncKeyValue | None. This will break existing code that passes the old parameter name or a Path object.

Copilot uses AI. Check for mistakes.
additional_client_metadata: dict[str, Any] | None = None,
callback_port: int | None = None,
):
Expand All @@ -264,7 +159,7 @@ def __init__(
scopes: OAuth scopes to request. Can be a
space-separated string or a list of strings.
client_name: Name for this client during registration
token_storage_cache_dir: Directory for FileTokenStorage
token_storage: An AsyncKeyValue-compatible token store, tokens are stored in memory if not provided
additional_client_metadata: Extra fields for OAuthClientMetadata
callback_port: Fixed port for OAuth callback (default: random available port)
"""
Expand Down Expand Up @@ -294,8 +189,10 @@ def __init__(
)

# Create server-specific token storage
storage = FileTokenStorage(
server_url=server_base_url, cache_dir=token_storage_cache_dir
token_storage = token_storage or MemoryStore()

self.token_storage_adapter: TokenStorageAdapter = TokenStorageAdapter(
async_key_value=token_storage, server_url=server_base_url
)

# Store server_base_url for use in callback_handler
Expand All @@ -305,7 +202,7 @@ def __init__(
super().__init__(
server_url=server_base_url,
client_metadata=client_metadata,
storage=storage,
storage=self.token_storage_adapter,
redirect_handler=self.redirect_handler,
callback_handler=self.callback_handler,
)
Expand Down Expand Up @@ -399,23 +296,7 @@ async def async_auth_flow(

# Clear cached state and retry once
self._initialized = False

# Try to clear storage if it supports it
if hasattr(self.context.storage, "clear"):
try:
self.context.storage.clear()
except Exception as e:
logger.warning(f"Failed to clear OAuth storage cache: {e}")
# Can't retry without clearing cache, re-raise original error
raise ClientNotFoundError(
"OAuth client not found and cache could not be cleared"
) from e
else:
logger.warning(
"Storage does not support clear() - cannot retry with fresh credentials"
)
# Can't retry without clearing cache, re-raise original error
raise
await self.token_storage_adapter.clear()

gen = super().async_auth_flow(request)
response = None
Expand Down
Loading