11from __future__ import annotations
22
33import asyncio
4+ import time
45import webbrowser
56from asyncio import Future
67from collections .abc import AsyncGenerator
7- from datetime import datetime , timedelta , timezone
8+ from datetime import datetime
89from pathlib import Path
9- from typing import Any , Literal
10+ from typing import Any
1011from urllib .parse import urlparse
1112
1213import anyio
1314import httpx
15+ from kv_store_adapter .adapters .pydantic import PydanticAdapter
16+ from kv_store_adapter .stores .disk import DiskStore
17+ from kv_store_adapter .types import KVStoreProtocol
1418from mcp .client .auth import OAuthClientProvider , TokenStorage
1519from mcp .shared .auth import (
1620 OAuthClientInformationFull ,
1721 OAuthClientMetadata ,
22+ OAuthToken ,
1823)
19- from mcp .shared .auth import (
20- OAuthToken as OAuthToken ,
21- )
22- from pydantic import AnyHttpUrl , BaseModel , TypeAdapter , ValidationError
24+ from pydantic import AnyHttpUrl , BaseModel
2325from uvicorn .server import Server
2426
25- from fastmcp import settings as fastmcp_global_settings
27+ from fastmcp import settings
2628from fastmcp .client .oauth_callback import (
2729 create_oauth_callback_server ,
2830)
2931from fastmcp .utilities .http import find_available_port
3032from fastmcp .utilities .logging import get_logger
31- from fastmcp .utilities .storage import JSONFileStorage
3233
3334__all__ = ["OAuth" ]
3435
@@ -41,174 +42,6 @@ class ClientNotFoundError(Exception):
4142 pass
4243
4344
44- class StoredToken (BaseModel ):
45- """Token storage format with absolute expiry time."""
46-
47- token_payload : OAuthToken
48- expires_at : datetime | None
49-
50-
51- # Create TypeAdapter at module level for efficient parsing
52- stored_token_adapter = TypeAdapter (StoredToken )
53-
54-
55- def default_cache_dir () -> Path :
56- return fastmcp_global_settings .home / "oauth-mcp-client-cache"
57-
58-
59- class FileTokenStorage (TokenStorage ):
60- """
61- File-based token storage implementation for OAuth credentials and tokens.
62- Implements the mcp.client.auth.TokenStorage protocol.
63-
64- Each instance is tied to a specific server URL for proper token isolation.
65- Uses JSONFileStorage internally for consistent file handling.
66- """
67-
68- def __init__ (self , server_url : str , cache_dir : Path | None = None ):
69- """Initialize storage for a specific server URL."""
70- self .server_url = server_url
71- # Use JSONFileStorage for actual file operations
72- self ._storage = JSONFileStorage (cache_dir or default_cache_dir ())
73-
74- @staticmethod
75- def get_base_url (url : str ) -> str :
76- """Extract the base URL (scheme + host) from a URL."""
77- parsed = urlparse (url )
78- return f"{ parsed .scheme } ://{ parsed .netloc } "
79-
80- def _get_storage_key (self , file_type : Literal ["client_info" , "tokens" ]) -> str :
81- """Get the storage key for the specified data type.
82-
83- JSONFileStorage will handle making the key filesystem-safe.
84- """
85- base_url = self .get_base_url (self .server_url )
86- return f"{ base_url } _{ file_type } "
87-
88- def _get_file_path (self , file_type : Literal ["client_info" , "tokens" ]) -> Path :
89- """Get the file path for the specified cache file type.
90-
91- This method is kept for backward compatibility with tests that access _get_file_path.
92- """
93- key = self ._get_storage_key (file_type )
94- return self ._storage ._get_file_path (key )
95-
96- async def get_tokens (self ) -> OAuthToken | None :
97- """Load tokens from file storage."""
98- key = self ._get_storage_key ("tokens" )
99- data = await self ._storage .get (key )
100-
101- if data is None :
102- return None
103-
104- try :
105- # Parse and validate as StoredToken
106- stored = stored_token_adapter .validate_python (data )
107-
108- # Check if token is expired
109- if stored .expires_at is not None :
110- now = datetime .now (timezone .utc )
111- if now >= stored .expires_at :
112- logger .debug (
113- f"Token expired for { self .get_base_url (self .server_url )} "
114- )
115- return None
116-
117- # Recalculate expires_in to be correct relative to now
118- if stored .token_payload .expires_in is not None :
119- remaining = stored .expires_at - now
120- stored .token_payload .expires_in = max (
121- 0 , int (remaining .total_seconds ())
122- )
123-
124- return stored .token_payload
125-
126- except ValidationError as e :
127- logger .debug (
128- f"Could not validate tokens for { self .get_base_url (self .server_url )} : { e } "
129- )
130- return None
131-
132- async def set_tokens (self , tokens : OAuthToken ) -> None :
133- """Save tokens to file storage."""
134- key = self ._get_storage_key ("tokens" )
135-
136- # Calculate absolute expiry time if expires_in is present
137- expires_at = None
138- if tokens .expires_in is not None :
139- expires_at = datetime .now (timezone .utc ) + timedelta (
140- seconds = tokens .expires_in
141- )
142-
143- # Create StoredToken and save using storage
144- # Note: JSONFileStorage will wrap this in {"data": ..., "timestamp": ...}
145- stored = StoredToken (token_payload = tokens , expires_at = expires_at )
146- await self ._storage .set (key , stored .model_dump (mode = "json" ))
147- logger .debug (f"Saved tokens for { self .get_base_url (self .server_url )} " )
148-
149- async def get_client_info (self ) -> OAuthClientInformationFull | None :
150- """Load client information from file storage."""
151- key = self ._get_storage_key ("client_info" )
152- data = await self ._storage .get (key )
153-
154- if data is None :
155- return None
156-
157- try :
158- client_info = OAuthClientInformationFull .model_validate (data )
159- # Check if we have corresponding valid tokens
160- # If no tokens exist, the OAuth flow was incomplete and we should
161- # force a fresh client registration
162- tokens = await self .get_tokens ()
163- if tokens is None :
164- logger .debug (
165- f"No tokens found for client info at { self .get_base_url (self .server_url )} . "
166- "OAuth flow may have been incomplete. Clearing client info to force fresh registration."
167- )
168- # Clear the incomplete client info
169- await self ._storage .delete (key )
170- return None
171-
172- return client_info
173- except ValidationError as e :
174- logger .debug (
175- f"Could not validate client info for { self .get_base_url (self .server_url )} : { e } "
176- )
177- return None
178-
179- async def set_client_info (self , client_info : OAuthClientInformationFull ) -> None :
180- """Save client information to file storage."""
181- key = self ._get_storage_key ("client_info" )
182- await self ._storage .set (key , client_info .model_dump (mode = "json" ))
183- logger .debug (f"Saved client info for { self .get_base_url (self .server_url )} " )
184-
185- def clear (self ) -> None :
186- """Clear all cached data for this server.
187-
188- Note: This is a synchronous method for backward compatibility.
189- Uses direct file operations instead of async storage methods.
190- """
191- file_types : list [Literal ["client_info" , "tokens" ]] = ["client_info" , "tokens" ]
192- for file_type in file_types :
193- # Use the file path directly for synchronous deletion
194- path = self ._get_file_path (file_type )
195- path .unlink (missing_ok = True )
196- logger .debug (f"Cleared OAuth cache for { self .get_base_url (self .server_url )} " )
197-
198- @classmethod
199- def clear_all (cls , cache_dir : Path | None = None ) -> None :
200- """Clear all cached data for all servers."""
201- cache_dir = cache_dir or default_cache_dir ()
202- if not cache_dir .exists ():
203- return
204-
205- file_types : list [Literal ["client_info" , "tokens" ]] = ["client_info" , "tokens" ]
206- for file_type in file_types :
207- for file in cache_dir .glob (f"*_{ file_type } .json" ):
208- file .unlink (missing_ok = True )
209- logger .info ("Cleared all OAuth client cache data." )
210-
211-
21245async def check_if_auth_required (
21346 mcp_url : str , httpx_kwargs : dict [str , Any ] | None = None
21447) -> bool :
@@ -239,6 +72,68 @@ async def check_if_auth_required(
23972 return True
24073
24174
75+ class TokenStorageAdapter (TokenStorage ):
76+ _server_url : str
77+ _kv_store_protocol : KVStoreProtocol
78+ _storage_oauth_token : PydanticAdapter [OAuthToken ]
79+ _storage_client_info : PydanticAdapter [OAuthClientInformationFull ]
80+
81+ def __init__ (self , kv_store_protocol : KVStoreProtocol , server_url : str ):
82+ self ._server_url = server_url
83+ self ._kv_store_protocol = kv_store_protocol
84+ self ._storage_oauth_token = PydanticAdapter [OAuthToken ](
85+ store_protocol = kv_store_protocol , pydantic_model = OAuthToken
86+ )
87+ self ._storage_client_info = PydanticAdapter [OAuthClientInformationFull ](
88+ store_protocol = kv_store_protocol , pydantic_model = OAuthClientInformationFull
89+ )
90+
91+ def _get_token_cache_key (self ) -> str :
92+ return f"{ self ._server_url } /tokens"
93+
94+ def _get_client_info_cache_key (self ) -> str :
95+ return f"{ self ._server_url } /client_info"
96+
97+ async def clear (self ) -> None :
98+ await self ._storage_oauth_token .delete (
99+ collection = "oauth-mcp-client-cache" , key = self ._get_token_cache_key ()
100+ )
101+ await self ._storage_client_info .delete (
102+ collection = "oauth-mcp-client-cache" , key = self ._get_client_info_cache_key ()
103+ )
104+
105+ async def get_tokens (self ) -> OAuthToken | None :
106+ return await self ._storage_oauth_token .get (
107+ collection = "oauth-mcp-client-cache" , key = self ._get_token_cache_key ()
108+ )
109+
110+ async def set_tokens (self , tokens : OAuthToken ) -> None :
111+ await self ._storage_oauth_token .put (
112+ collection = "oauth-mcp-client-cache" ,
113+ key = self ._get_token_cache_key (),
114+ value = tokens ,
115+ ttl = tokens .expires_in ,
116+ )
117+
118+ async def get_client_info (self ) -> OAuthClientInformationFull | None :
119+ return await self ._storage_client_info .get (
120+ collection = "oauth-mcp-client-cache" , key = self ._get_client_info_cache_key ()
121+ )
122+
123+ async def set_client_info (self , client_info : OAuthClientInformationFull ) -> None :
124+ ttl : int | None = None
125+
126+ if client_info .client_secret_expires_at :
127+ ttl = client_info .client_secret_expires_at - int (time .time ())
128+
129+ await self ._storage_client_info .put (
130+ collection = "oauth-mcp-client-cache" ,
131+ key = self ._get_client_info_cache_key (),
132+ value = client_info ,
133+ ttl = ttl ,
134+ )
135+
136+
242137class OAuth (OAuthClientProvider ):
243138 """
244139 OAuth client provider for MCP servers with browser-based authentication.
@@ -252,7 +147,7 @@ def __init__(
252147 mcp_url : str ,
253148 scopes : str | list [str ] | None = None ,
254149 client_name : str = "FastMCP Client" ,
255- token_storage_cache_dir : Path | None = None ,
150+ token_storage : KVStoreProtocol | None = None ,
256151 additional_client_metadata : dict [str , Any ] | None = None ,
257152 callback_port : int | None = None ,
258153 ):
@@ -264,7 +159,7 @@ def __init__(
264159 scopes: OAuth scopes to request. Can be a
265160 space-separated string or a list of strings.
266161 client_name: Name for this client during registration
267- token_storage_cache_dir: Directory for FileTokenStorage
162+ token_storage: KVStoreProtocol for token storage, the default disk store is used if not provided
268163 additional_client_metadata: Extra fields for OAuthClientMetadata
269164 callback_port: Fixed port for OAuth callback (default: random available port)
270165 """
@@ -294,8 +189,10 @@ def __init__(
294189 )
295190
296191 # Create server-specific token storage
297- storage = FileTokenStorage (
298- server_url = server_base_url , cache_dir = token_storage_cache_dir
192+ token_storage = token_storage or settings .data_store
193+
194+ self .token_storage_adapter : TokenStorageAdapter = TokenStorageAdapter (
195+ kv_store_protocol = token_storage , server_url = server_base_url
299196 )
300197
301198 # Store server_base_url for use in callback_handler
@@ -305,7 +202,7 @@ def __init__(
305202 super ().__init__ (
306203 server_url = server_base_url ,
307204 client_metadata = client_metadata ,
308- storage = storage ,
205+ storage = self . token_storage_adapter ,
309206 redirect_handler = self .redirect_handler ,
310207 callback_handler = self .callback_handler ,
311208 )
@@ -399,23 +296,7 @@ async def async_auth_flow(
399296
400297 # Clear cached state and retry once
401298 self ._initialized = False
402-
403- # Try to clear storage if it supports it
404- if hasattr (self .context .storage , "clear" ):
405- try :
406- self .context .storage .clear ()
407- except Exception as e :
408- logger .warning (f"Failed to clear OAuth storage cache: { e } " )
409- # Can't retry without clearing cache, re-raise original error
410- raise ClientNotFoundError (
411- "OAuth client not found and cache could not be cleared"
412- ) from e
413- else :
414- logger .warning (
415- "Storage does not support clear() - cannot retry with fresh credentials"
416- )
417- # Can't retry without clearing cache, re-raise original error
418- raise
299+ await self .token_storage_adapter .clear ()
419300
420301 gen = super ().async_auth_flow (request )
421302 response = None
0 commit comments