Skip to content

Commit dee2bb5

Browse files
committed
Switch to DiskStore KV implementation
1 parent a41600d commit dee2bb5

File tree

16 files changed

+369
-813
lines changed

16 files changed

+369
-813
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ dependencies = [
1515
"pydantic[email]>=2.11.7",
1616
"pyperclip>=1.9.0",
1717
"openapi-core>=0.19.5",
18+
"kv-store-adapter[disk,memory]>=0.1.1",
1819
]
1920

2021
requires-python = ">=3.10"
@@ -117,7 +118,7 @@ testpaths = ["tests"]
117118
python_files = ["test_*.py", "*_test.py"]
118119
python_classes = ["Test*"]
119120
python_functions = ["test_*"]
120-
addopts = ["--inline-snapshot=disable"]
121+
addopts = ["--inline-snapshot=fix,create"]
121122

122123
[tool.ty.src]
123124
include = ["src", "tests"]

src/fastmcp/client/auth/oauth.py

Lines changed: 79 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,35 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import time
45
import webbrowser
56
from asyncio import Future
67
from collections.abc import AsyncGenerator
7-
from datetime import datetime, timedelta, timezone
8+
from datetime import datetime
89
from pathlib import Path
9-
from typing import Any, Literal
10+
from typing import Any
1011
from urllib.parse import urlparse
1112

1213
import anyio
1314
import httpx
15+
from kv_store_adapter.adapters.pydantic import PydanticAdapter
16+
from kv_store_adapter.stores.disk import DiskStore
17+
from kv_store_adapter.types import KVStoreProtocol
1418
from mcp.client.auth import OAuthClientProvider, TokenStorage
1519
from mcp.shared.auth import (
1620
OAuthClientInformationFull,
1721
OAuthClientMetadata,
22+
OAuthToken,
1823
)
19-
from mcp.shared.auth import (
20-
OAuthToken as OAuthToken,
21-
)
22-
from pydantic import AnyHttpUrl, BaseModel, TypeAdapter, ValidationError
24+
from pydantic import AnyHttpUrl, BaseModel
2325
from uvicorn.server import Server
2426

25-
from fastmcp import settings as fastmcp_global_settings
27+
from fastmcp import settings
2628
from fastmcp.client.oauth_callback import (
2729
create_oauth_callback_server,
2830
)
2931
from fastmcp.utilities.http import find_available_port
3032
from fastmcp.utilities.logging import get_logger
31-
from fastmcp.utilities.storage import JSONFileStorage
3233

3334
__all__ = ["OAuth"]
3435

@@ -41,174 +42,6 @@ class ClientNotFoundError(Exception):
4142
pass
4243

4344

44-
class StoredToken(BaseModel):
45-
"""Token storage format with absolute expiry time."""
46-
47-
token_payload: OAuthToken
48-
expires_at: datetime | None
49-
50-
51-
# Create TypeAdapter at module level for efficient parsing
52-
stored_token_adapter = TypeAdapter(StoredToken)
53-
54-
55-
def default_cache_dir() -> Path:
56-
return fastmcp_global_settings.home / "oauth-mcp-client-cache"
57-
58-
59-
class FileTokenStorage(TokenStorage):
60-
"""
61-
File-based token storage implementation for OAuth credentials and tokens.
62-
Implements the mcp.client.auth.TokenStorage protocol.
63-
64-
Each instance is tied to a specific server URL for proper token isolation.
65-
Uses JSONFileStorage internally for consistent file handling.
66-
"""
67-
68-
def __init__(self, server_url: str, cache_dir: Path | None = None):
69-
"""Initialize storage for a specific server URL."""
70-
self.server_url = server_url
71-
# Use JSONFileStorage for actual file operations
72-
self._storage = JSONFileStorage(cache_dir or default_cache_dir())
73-
74-
@staticmethod
75-
def get_base_url(url: str) -> str:
76-
"""Extract the base URL (scheme + host) from a URL."""
77-
parsed = urlparse(url)
78-
return f"{parsed.scheme}://{parsed.netloc}"
79-
80-
def _get_storage_key(self, file_type: Literal["client_info", "tokens"]) -> str:
81-
"""Get the storage key for the specified data type.
82-
83-
JSONFileStorage will handle making the key filesystem-safe.
84-
"""
85-
base_url = self.get_base_url(self.server_url)
86-
return f"{base_url}_{file_type}"
87-
88-
def _get_file_path(self, file_type: Literal["client_info", "tokens"]) -> Path:
89-
"""Get the file path for the specified cache file type.
90-
91-
This method is kept for backward compatibility with tests that access _get_file_path.
92-
"""
93-
key = self._get_storage_key(file_type)
94-
return self._storage._get_file_path(key)
95-
96-
async def get_tokens(self) -> OAuthToken | None:
97-
"""Load tokens from file storage."""
98-
key = self._get_storage_key("tokens")
99-
data = await self._storage.get(key)
100-
101-
if data is None:
102-
return None
103-
104-
try:
105-
# Parse and validate as StoredToken
106-
stored = stored_token_adapter.validate_python(data)
107-
108-
# Check if token is expired
109-
if stored.expires_at is not None:
110-
now = datetime.now(timezone.utc)
111-
if now >= stored.expires_at:
112-
logger.debug(
113-
f"Token expired for {self.get_base_url(self.server_url)}"
114-
)
115-
return None
116-
117-
# Recalculate expires_in to be correct relative to now
118-
if stored.token_payload.expires_in is not None:
119-
remaining = stored.expires_at - now
120-
stored.token_payload.expires_in = max(
121-
0, int(remaining.total_seconds())
122-
)
123-
124-
return stored.token_payload
125-
126-
except ValidationError as e:
127-
logger.debug(
128-
f"Could not validate tokens for {self.get_base_url(self.server_url)}: {e}"
129-
)
130-
return None
131-
132-
async def set_tokens(self, tokens: OAuthToken) -> None:
133-
"""Save tokens to file storage."""
134-
key = self._get_storage_key("tokens")
135-
136-
# Calculate absolute expiry time if expires_in is present
137-
expires_at = None
138-
if tokens.expires_in is not None:
139-
expires_at = datetime.now(timezone.utc) + timedelta(
140-
seconds=tokens.expires_in
141-
)
142-
143-
# Create StoredToken and save using storage
144-
# Note: JSONFileStorage will wrap this in {"data": ..., "timestamp": ...}
145-
stored = StoredToken(token_payload=tokens, expires_at=expires_at)
146-
await self._storage.set(key, stored.model_dump(mode="json"))
147-
logger.debug(f"Saved tokens for {self.get_base_url(self.server_url)}")
148-
149-
async def get_client_info(self) -> OAuthClientInformationFull | None:
150-
"""Load client information from file storage."""
151-
key = self._get_storage_key("client_info")
152-
data = await self._storage.get(key)
153-
154-
if data is None:
155-
return None
156-
157-
try:
158-
client_info = OAuthClientInformationFull.model_validate(data)
159-
# Check if we have corresponding valid tokens
160-
# If no tokens exist, the OAuth flow was incomplete and we should
161-
# force a fresh client registration
162-
tokens = await self.get_tokens()
163-
if tokens is None:
164-
logger.debug(
165-
f"No tokens found for client info at {self.get_base_url(self.server_url)}. "
166-
"OAuth flow may have been incomplete. Clearing client info to force fresh registration."
167-
)
168-
# Clear the incomplete client info
169-
await self._storage.delete(key)
170-
return None
171-
172-
return client_info
173-
except ValidationError as e:
174-
logger.debug(
175-
f"Could not validate client info for {self.get_base_url(self.server_url)}: {e}"
176-
)
177-
return None
178-
179-
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
180-
"""Save client information to file storage."""
181-
key = self._get_storage_key("client_info")
182-
await self._storage.set(key, client_info.model_dump(mode="json"))
183-
logger.debug(f"Saved client info for {self.get_base_url(self.server_url)}")
184-
185-
def clear(self) -> None:
186-
"""Clear all cached data for this server.
187-
188-
Note: This is a synchronous method for backward compatibility.
189-
Uses direct file operations instead of async storage methods.
190-
"""
191-
file_types: list[Literal["client_info", "tokens"]] = ["client_info", "tokens"]
192-
for file_type in file_types:
193-
# Use the file path directly for synchronous deletion
194-
path = self._get_file_path(file_type)
195-
path.unlink(missing_ok=True)
196-
logger.debug(f"Cleared OAuth cache for {self.get_base_url(self.server_url)}")
197-
198-
@classmethod
199-
def clear_all(cls, cache_dir: Path | None = None) -> None:
200-
"""Clear all cached data for all servers."""
201-
cache_dir = cache_dir or default_cache_dir()
202-
if not cache_dir.exists():
203-
return
204-
205-
file_types: list[Literal["client_info", "tokens"]] = ["client_info", "tokens"]
206-
for file_type in file_types:
207-
for file in cache_dir.glob(f"*_{file_type}.json"):
208-
file.unlink(missing_ok=True)
209-
logger.info("Cleared all OAuth client cache data.")
210-
211-
21245
async def check_if_auth_required(
21346
mcp_url: str, httpx_kwargs: dict[str, Any] | None = None
21447
) -> bool:
@@ -239,6 +72,68 @@ async def check_if_auth_required(
23972
return True
24073

24174

75+
class TokenStorageAdapter(TokenStorage):
76+
_server_url: str
77+
_kv_store_protocol: KVStoreProtocol
78+
_storage_oauth_token: PydanticAdapter[OAuthToken]
79+
_storage_client_info: PydanticAdapter[OAuthClientInformationFull]
80+
81+
def __init__(self, kv_store_protocol: KVStoreProtocol, server_url: str):
82+
self._server_url = server_url
83+
self._kv_store_protocol = kv_store_protocol
84+
self._storage_oauth_token = PydanticAdapter[OAuthToken](
85+
store_protocol=kv_store_protocol, pydantic_model=OAuthToken
86+
)
87+
self._storage_client_info = PydanticAdapter[OAuthClientInformationFull](
88+
store_protocol=kv_store_protocol, pydantic_model=OAuthClientInformationFull
89+
)
90+
91+
def _get_token_cache_key(self) -> str:
92+
return f"{self._server_url}/tokens"
93+
94+
def _get_client_info_cache_key(self) -> str:
95+
return f"{self._server_url}/client_info"
96+
97+
async def clear(self) -> None:
98+
await self._storage_oauth_token.delete(
99+
collection="oauth-mcp-client-cache", key=self._get_token_cache_key()
100+
)
101+
await self._storage_client_info.delete(
102+
collection="oauth-mcp-client-cache", key=self._get_client_info_cache_key()
103+
)
104+
105+
async def get_tokens(self) -> OAuthToken | None:
106+
return await self._storage_oauth_token.get(
107+
collection="oauth-mcp-client-cache", key=self._get_token_cache_key()
108+
)
109+
110+
async def set_tokens(self, tokens: OAuthToken) -> None:
111+
await self._storage_oauth_token.put(
112+
collection="oauth-mcp-client-cache",
113+
key=self._get_token_cache_key(),
114+
value=tokens,
115+
ttl=tokens.expires_in,
116+
)
117+
118+
async def get_client_info(self) -> OAuthClientInformationFull | None:
119+
return await self._storage_client_info.get(
120+
collection="oauth-mcp-client-cache", key=self._get_client_info_cache_key()
121+
)
122+
123+
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
124+
ttl: int | None = None
125+
126+
if client_info.client_secret_expires_at:
127+
ttl = client_info.client_secret_expires_at - int(time.time())
128+
129+
await self._storage_client_info.put(
130+
collection="oauth-mcp-client-cache",
131+
key=self._get_client_info_cache_key(),
132+
value=client_info,
133+
ttl=ttl,
134+
)
135+
136+
242137
class OAuth(OAuthClientProvider):
243138
"""
244139
OAuth client provider for MCP servers with browser-based authentication.
@@ -252,7 +147,7 @@ def __init__(
252147
mcp_url: str,
253148
scopes: str | list[str] | None = None,
254149
client_name: str = "FastMCP Client",
255-
token_storage_cache_dir: Path | None = None,
150+
token_storage: KVStoreProtocol | None = None,
256151
additional_client_metadata: dict[str, Any] | None = None,
257152
callback_port: int | None = None,
258153
):
@@ -264,7 +159,7 @@ def __init__(
264159
scopes: OAuth scopes to request. Can be a
265160
space-separated string or a list of strings.
266161
client_name: Name for this client during registration
267-
token_storage_cache_dir: Directory for FileTokenStorage
162+
token_storage: KVStoreProtocol for token storage, the default disk store is used if not provided
268163
additional_client_metadata: Extra fields for OAuthClientMetadata
269164
callback_port: Fixed port for OAuth callback (default: random available port)
270165
"""
@@ -294,8 +189,10 @@ def __init__(
294189
)
295190

296191
# Create server-specific token storage
297-
storage = FileTokenStorage(
298-
server_url=server_base_url, cache_dir=token_storage_cache_dir
192+
token_storage = token_storage or settings.data_store
193+
194+
self.token_storage_adapter: TokenStorageAdapter = TokenStorageAdapter(
195+
kv_store_protocol=token_storage, server_url=server_base_url
299196
)
300197

301198
# Store server_base_url for use in callback_handler
@@ -305,7 +202,7 @@ def __init__(
305202
super().__init__(
306203
server_url=server_base_url,
307204
client_metadata=client_metadata,
308-
storage=storage,
205+
storage=self.token_storage_adapter,
309206
redirect_handler=self.redirect_handler,
310207
callback_handler=self.callback_handler,
311208
)
@@ -399,23 +296,7 @@ async def async_auth_flow(
399296

400297
# Clear cached state and retry once
401298
self._initialized = False
402-
403-
# Try to clear storage if it supports it
404-
if hasattr(self.context.storage, "clear"):
405-
try:
406-
self.context.storage.clear()
407-
except Exception as e:
408-
logger.warning(f"Failed to clear OAuth storage cache: {e}")
409-
# Can't retry without clearing cache, re-raise original error
410-
raise ClientNotFoundError(
411-
"OAuth client not found and cache could not be cleared"
412-
) from e
413-
else:
414-
logger.warning(
415-
"Storage does not support clear() - cannot retry with fresh credentials"
416-
)
417-
# Can't retry without clearing cache, re-raise original error
418-
raise
299+
await self.token_storage_adapter.clear()
419300

420301
gen = super().async_auth_flow(request)
421302
response = None

0 commit comments

Comments
 (0)