From 6a393f79292f68cc13ec49446a8de3dffdbd9646 Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Sat, 12 Apr 2025 19:28:48 +0800 Subject: [PATCH 1/6] chore: try to add oauth2 support add oauth2 client support Signed-off-by: jokemanfire --- OAUTH_README.md | 156 ++++++++++ crates/rmcp/Cargo.toml | 6 + crates/rmcp/src/transport.rs | 10 + crates/rmcp/src/transport/auth.rs | 428 ++++++++++++++++++++++++++ crates/rmcp/src/transport/sse_auth.rs | 183 +++++++++++ examples/clients/Cargo.toml | 8 +- examples/clients/src/oauth_client.rs | 90 ++++++ 7 files changed, 879 insertions(+), 2 deletions(-) create mode 100644 OAUTH_README.md create mode 100644 crates/rmcp/src/transport/auth.rs create mode 100644 crates/rmcp/src/transport/sse_auth.rs create mode 100644 examples/clients/src/oauth_client.rs diff --git a/OAUTH_README.md b/OAUTH_README.md new file mode 100644 index 0000000..2b02a9e --- /dev/null +++ b/OAUTH_README.md @@ -0,0 +1,156 @@ +# Model Context Protocol OAuth Authorization + +This document describes the OAuth 2.1 authorization implementation for Model Context Protocol (MCP), following the [MCP 2025-03-26 Authorization Specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/). + +## Features + +- Full support for OAuth 2.1 authorization flow +- PKCE support for enhanced security +- Authorization server metadata discovery +- Dynamic client registration +- Automatic token refresh +- Authorized SSE transport implementation + +## Usage Guide + +### 1. Enable Features + +Enable the auth feature in Cargo.toml: + +```toml +[dependencies] +rmcp = { version = "0.1", features = ["auth", "transport-sse"] } +``` + +### 2. Create Authorization Manager + +```rust +use std::sync::Arc; +use rmcp::transport::auth::AuthorizationManager; + +async fn main() -> anyhow::Result<()> { + // Create authorization manager + let auth_manager = Arc::new(AuthorizationManager::new("https://api.example.com/mcp").await?); + + Ok(()) +} +``` + +### 3. Create Authorization Session and Get Authorization + +```rust +use rmcp::transport::auth::AuthorizationSession; + +async fn get_authorization(auth_manager: Arc) -> anyhow::Result<()> { + // Create authorization session + let session = AuthorizationSession::new( + auth_manager.clone(), + &["mcp"], // Requested scopes + "http://localhost:8080/callback", // Redirect URI + ).await?; + + // Get authorization URL and guide user to open it + let auth_url = session.get_authorization_url(); + println!("Please open the following URL in your browser for authorization:\n{}", auth_url); + + // Handle callback - In real applications, this is typically done in a callback server + let auth_code = "Authorization code obtained from browser after user authorization"; + let credentials = session.handle_callback(auth_code).await?; + + println!("Authorization successful, access token: {}", credentials.access_token); + + Ok(()) +} +``` + +### 4. Use Authorized SSE Transport + +```rust +use rmcp::{ServiceExt, model::ClientInfo, transport::create_authorized_transport}; + +async fn connect_with_auth(auth_manager: Arc) -> anyhow::Result<()> { + // Create authorized SSE transport + let transport = create_authorized_transport( + "https://api.example.com/mcp", + auth_manager.clone() + ).await?; + + // Create client + let client_service = ClientInfo::default(); + let client = client_service.serve(transport).await?; + + // Use client to call APIs + let tools = client.peer().list_all_tools().await?; + + for tool in tools { + println!("Tool: {} - {}", tool.name, tool.description); + } + + Ok(()) +} +``` + +### 5. Use Authorized HTTP Client + +```rust +use rmcp::transport::auth::AuthorizedHttpClient; + +async fn make_authorized_request(auth_manager: Arc) -> anyhow::Result<()> { + // Create authorized HTTP client + let client = AuthorizedHttpClient::new(auth_manager, None); + + // Send authorized request + let response = client.get("https://api.example.com/resources").await?; + let resources = response.json::>().await?; + + println!("Number of resources: {}", resources.len()); + + Ok(()) +} +``` + +## Complete Example + +Please refer to `examples/oauth_client.rs` for a complete usage example. + +## Running the Example + +```bash +# Set server URL (optional) +export MCP_SERVER_URL=https://api.example.com/mcp + +# Run example +cargo run --bin oauth-client +``` + +## Authorization Flow Description + +1. **Metadata Discovery**: Client attempts to get authorization server metadata from `/.well-known/oauth-authorization-server` +2. **Client Registration**: If supported, client dynamically registers itself +3. **Authorization Request**: Build authorization URL with PKCE and guide user to access +4. **Authorization Code Exchange**: After user authorization, exchange authorization code for access token +5. **Token Usage**: Use access token for API calls +6. **Token Refresh**: Automatically use refresh token to get new access token when current one expires + +## Security Considerations + +- All tokens are securely stored in memory +- PKCE implementation prevents authorization code interception attacks +- Automatic token refresh support reduces user intervention +- Only accepts HTTPS connections or secure local callback URIs + +## Troubleshooting + +If you encounter authorization issues, check the following: + +1. Ensure server supports OAuth 2.1 authorization +2. Verify callback URI matches server's allowed redirect URIs +3. Check network connection and firewall settings +4. Verify server supports metadata discovery or dynamic client registration + +## References + +- [MCP Authorization Specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) +- [OAuth 2.1 Specification Draft](https://oauth.net/2.1/) +- [RFC 8414: OAuth 2.0 Authorization Server Metadata](https://datatracker.ietf.org/doc/html/rfc8414) +- [RFC 7591: OAuth 2.0 Dynamic Client Registration Protocol](https://datatracker.ietf.org/doc/html/rfc7591) \ No newline at end of file diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 738f244..74bb0e1 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -23,6 +23,10 @@ tracing = { version = "0.1" } tokio-util = { version = "0.7" } pin-project-lite = "0.2" paste = { version = "1", optional = true } + +# oauth2 support +oauth2 = { version = "4.3", optional = true } + # for auto generate schema schemars = { version = "0.8", optional = true } @@ -70,6 +74,8 @@ transport-sse-server = [ ] # transport-ws = ["transport-io", "dep:tokio-tungstenite"] tower = ["dep:tower-service"] +auth = ["dep:oauth2", "dep:reqwest", "dep:url"] + [dev-dependencies] tokio = { version = "1", features = ["full"] } schemars = { version = "0.8" } diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index cbfbf9f..3158478 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -56,6 +56,11 @@ pub mod sse; #[cfg(feature = "transport-sse")] pub use sse::SseTransport; +#[cfg(all(feature = "transport-sse", feature = "auth"))] +pub mod sse_auth; +#[cfg(all(feature = "transport-sse", feature = "auth"))] +pub use sse_auth::{AuthorizedSseClient, create_authorized_transport}; + // #[cfg(feature = "tower")] // pub mod tower; @@ -64,6 +69,11 @@ pub mod sse_server; #[cfg(feature = "transport-sse-server")] pub use sse_server::SseServer; +#[cfg(feature = "auth")] +pub mod auth; +#[cfg(feature = "auth")] +pub use auth::{AuthorizationManager, AuthorizationSession, AuthorizedHttpClient, AuthError}; + // #[cfg(feature = "transport-ws")] // pub mod ws; diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs new file mode 100644 index 0000000..7108486 --- /dev/null +++ b/crates/rmcp/src/transport/auth.rs @@ -0,0 +1,428 @@ +use std::sync::Arc; +use std::time::Duration; +use futures::future::BoxFuture; +use oauth2::basic::BasicTokenType; +use oauth2::{ + AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, + RedirectUrl, TokenResponse, Scope, AuthUrl, TokenUrl, RefreshToken, + StandardTokenResponse, TokenType, AccessToken, EmptyExtraTokenFields, + basic::BasicClient, reqwest::http_client, RefreshTokenRequest, AuthorizationRequest +}; +use reqwest::{Client as HttpClient, header::AUTHORIZATION, StatusCode, Url, IntoUrl}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio::sync::{Mutex, RwLock}; +use tokio::time::{self, Instant}; + +/// 错误定义 +#[derive(Debug, Error)] +pub enum AuthError { + #[error("OAuth authorization required")] + AuthorizationRequired, + + #[error("OAuth authorization failed: {0}")] + AuthorizationFailed(String), + + #[error("OAuth token exchange failed: {0}")] + TokenExchangeFailed(String), + + #[error("OAuth token refresh failed: {0}")] + TokenRefreshFailed(String), + + #[error("HTTP error: {0}")] + HttpError(#[from] reqwest::Error), + + #[error("OAuth error: {0}")] + OAuthError(String), + + #[error("Metadata error: {0}")] + MetadataError(String), + + #[error("URL parse error: {0}")] + UrlError(#[from] url::ParseError), + + #[error("No authorization support detected")] + NoAuthorizationSupport, + + #[error("Internal error: {0}")] + InternalError(String), + + #[error("Invalid token type: {0}")] + InvalidTokenType(String), + + #[error("Token expired")] + TokenExpired, + + #[error("Invalid scope: {0}")] + InvalidScope(String), + + #[error("Registration failed: {0}")] + RegistrationFailed(String), +} + +/// oauth2 metadata +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AuthorizationMetadata { + pub authorization_endpoint: String, + pub token_endpoint: String, + pub registration_endpoint: Option, + pub issuer: Option, + pub jwks_uri: Option, + pub scopes_supported: Option>, +} + +/// oauth2 client config +#[derive(Debug, Clone)] +pub struct OAuthClientConfig { + pub client_id: String, + pub client_secret: Option, + pub redirect_uri: String, + pub scopes: Vec, +} + +/// oauth2 auth manager +pub struct AuthorizationManager { + http_client: HttpClient, + metadata: Option, + oauth_client: Option, + credentials: RwLock>>, + pkce_verifier: RwLock>, + base_url: Url, +} + +impl AuthorizationManager { + /// create new auth manager + pub async fn new(base_url: U) -> Result { + let base_url = base_url.into_url()?; + let http_client = HttpClient::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|e| AuthError::InternalError(e.to_string()))?; + + let mut manager = Self { + http_client, + metadata: None, + oauth_client: None, + credentials: RwLock::new(None), + pkce_verifier: RwLock::new(None), + base_url, + }; + + // try to discover oauth2 metadata + if let Ok(metadata) = manager.discover_metadata().await { + manager.metadata = Some(metadata); + } + + Ok(manager) + } + + /// discover oauth2 metadata + pub async fn discover_metadata(&self) -> Result { + // according to the specification, the metadata should be located at "/.well-known/oauth-authorization-server" + let mut discovery_url = self.base_url.clone(); + discovery_url.set_path("/.well-known/oauth-authorization-server"); + + let response = self.http_client + .get(discovery_url) + .header("MCP-Protocol-Version", "2024-11-05") + .send() + .await?; + + if response.status() == StatusCode::OK { + let metadata = response.json::().await + .map_err(|e| AuthError::MetadataError(format!("Failed to parse metadata: {}", e)))?; + Ok(metadata) + } else { + // fallback to default endpoints + let mut auth_base = self.base_url.clone(); + // discard the path part, only keep scheme, host, port + auth_base.set_path(""); + + Ok(AuthorizationMetadata { + authorization_endpoint: format!("{}/authorize", auth_base), + token_endpoint: format!("{}/token", auth_base), + registration_endpoint: Some(format!("{}/register", auth_base)), + issuer: None, + jwks_uri: None, + scopes_supported: None, + }) + } + } + + /// configure oauth2 client with client credentials + pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> { + if self.metadata.is_none() { + return Err(AuthError::NoAuthorizationSupport); + } + + let metadata = self.metadata.as_ref().unwrap(); + + let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?; + + let token_url = TokenUrl::new(metadata.token_endpoint.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?; + + let client_id = ClientId::new(config.client_id); + let redirect_url = RedirectUrl::new(config.redirect_uri) + .map_err(|e| AuthError::OAuthError(format!("Invalid redirect URL: {}", e)))?; + + let mut client_builder = BasicClient::new(client_id.clone(), None, auth_url.clone(), Some(token_url.clone())) + .set_redirect_uri(redirect_url.clone()); + + if let Some(secret) = config.client_secret { + client_builder = BasicClient::new(client_id, Some(ClientSecret::new(secret)), auth_url, Some(token_url)) + .set_redirect_uri(redirect_url); + } + + self.oauth_client = Some(client_builder); + Ok(()) + } + + /// dynamic register oauth2 client + pub async fn register_client(&mut self, name: &str, redirect_uri: &str) -> Result { + if self.metadata.is_none() { + return Err(AuthError::NoAuthorizationSupport); + } + + let metadata = self.metadata.as_ref().unwrap(); + let registration_url = metadata.registration_endpoint.as_ref() + .ok_or_else(|| AuthError::NoAuthorizationSupport)?; + + // prepare registration request + let registration_request = serde_json::json!({ + "client_name": name, + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "token_endpoint_auth_method": "none", // public client + "response_types": ["code"], + }); + + let response = self.http_client + .post(registration_url) + .json(®istration_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(AuthError::OAuthError(format!( + "Client registration failed: HTTP {}", response.status() + ))); + } + + #[derive(Deserialize)] + struct RegistrationResponse { + client_id: String, + client_secret: Option, + } + + let reg_response = response.json::().await + .map_err(|e| AuthError::OAuthError(format!("Failed to parse registration response: {}", e)))?; + + let config = OAuthClientConfig { + client_id: reg_response.client_id, + client_secret: reg_response.client_secret, + redirect_uri: redirect_uri.to_string(), + scopes: vec![], + }; + + self.configure_client(config.clone())?; + Ok(config) + } + + /// generate authorization url + pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result { + let oauth_client = self.oauth_client.as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + // generate pkce challenge + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + // build authorization request + let mut auth_request = oauth_client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); + + // add request scopes + for scope in scopes { + auth_request = auth_request.add_scope(Scope::new(scope.to_string())); + } + + let (auth_url, _csrf_token) = auth_request.url(); + + // store pkce verifier for later use + *self.pkce_verifier.write().await = Some(pkce_verifier); + + Ok(auth_url.to_string()) + } + + /// exchange authorization code for access token + pub async fn exchange_code_for_token(&self, code: &str) -> Result, AuthError> { + let oauth_client = self.oauth_client.as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + let pkce_verifier = self.pkce_verifier.write().await.take().unwrap(); + + // exchange token + let token_result = oauth_client + .exchange_code(AuthorizationCode::new(code.to_string())) + .set_pkce_verifier(pkce_verifier) + .request(http_client) + .map_err(|e| AuthError::TokenExchangeFailed(e.to_string()))?; + + // store credentials + *self.credentials.write().await = Some(token_result.clone()); + + Ok(token_result) + } + + /// get access token, if expired, refresh it automatically + pub async fn get_access_token(&self) -> Result { + let credentials = self.credentials.read().await; + + if let Some(creds) = credentials.as_ref() { + // check if the token is expired + if let Some(expires_in) = creds.expires_in() { + if expires_in <= Duration::from_secs(0) { + // token expired, try to refresh + drop(credentials); // release the lock + let new_creds = self.refresh_token().await?; + return Ok(new_creds.access_token().secret().to_string()); + } + } + + Ok(creds.access_token().secret().to_string()) + } else { + Err(AuthError::AuthorizationRequired) + } + } + + /// refresh access token + pub async fn refresh_token(&self) -> Result, AuthError> { + let oauth_client = self.oauth_client.as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + let current_credentials = self.credentials.read().await.clone() + .ok_or_else(|| AuthError::AuthorizationRequired)?; + + let refresh_token = current_credentials.refresh_token() + .ok_or_else(|| AuthError::TokenRefreshFailed("No refresh token available".to_string()))?; + + // refresh token + let token_result = oauth_client + .exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string())) + .request(http_client) + .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; + + + // store new credentials + *self.credentials.write().await = Some(token_result.clone()); + + Ok(token_result) + } + + /// prepare request, add authorization header + pub async fn prepare_request(&self, mut request: reqwest::RequestBuilder) -> Result { + let token = self.get_access_token().await?; + Ok(request.header(AUTHORIZATION, format!("Bearer {}", token))) + } + + /// handle response, check if need to re-authorize + pub async fn handle_response(&self, response: reqwest::Response) -> Result { + if response.status() == StatusCode::UNAUTHORIZED { + // 401 Unauthorized, need to re-authorize + Err(AuthError::AuthorizationRequired) + } else { + Ok(response) + } + } +} + +/// oauth2 authorization session, for guiding user to complete the authorization process +pub struct AuthorizationSession { + pub auth_manager: Arc>, + pub auth_url: String, + pub redirect_uri: String, + pub pkce_verifier: PkceCodeVerifier, +} + +impl AuthorizationSession { + /// create new authorization session + pub async fn new( + auth_manager: Arc>, + scopes: &[&str], + redirect_uri: &str, + ) -> Result { + // set redirect uri + let config = OAuthClientConfig { + client_id: "mcp-client".to_string(), // temporary id, will be updated by dynamic registration + client_secret: None, + redirect_uri: redirect_uri.to_string(), + scopes: scopes.iter().map(|s| s.to_string()).collect(), + }; + + // try to dynamic register client + let config = match auth_manager.lock().await.register_client("MCP Client", redirect_uri).await { + Ok(config) => config, + Err(e) => { + eprintln!("Dynamic registration failed: {}", e); + // fallback to default config + config + } + }; + + let auth_url= auth_manager.lock().await.get_authorization_url(scopes).await?; + let pkce_verifier = auth_manager.lock().await.pkce_verifier.write().await.take().unwrap(); + Ok(Self { + auth_manager, + auth_url, + redirect_uri: redirect_uri.to_string(), + pkce_verifier, + }) + } + + /// get authorization url + pub fn get_authorization_url(&self) -> &str { + &self.auth_url + } + + /// handle authorization code callback + pub async fn handle_callback(&self, code: &str) -> Result, AuthError> { + self.auth_manager.lock().await.exchange_code_for_token(code).await + } +} + +/// http client extension, automatically add authorization header +pub struct AuthorizedHttpClient { + auth_manager: Arc, + inner_client: HttpClient, +} + +impl AuthorizedHttpClient { + /// create new authorized http client + pub fn new(auth_manager: Arc, client: Option) -> Self { + let inner_client = client.unwrap_or_else(|| HttpClient::new()); + Self { + auth_manager, + inner_client, + } + } + + /// send authorized request + pub async fn request(&self, method: reqwest::Method, url: U) -> Result { + let request = self.inner_client.request(method, url); + self.auth_manager.prepare_request(request).await + } + + /// send get request + pub async fn get(&self, url: U) -> Result { + let request = self.request(reqwest::Method::GET, url).await?; + let response = request.send().await?; + self.auth_manager.handle_response(response).await + } + + /// send post request + pub async fn post(&self, url: U) -> Result { + self.request(reqwest::Method::POST, url).await + } +} \ No newline at end of file diff --git a/crates/rmcp/src/transport/sse_auth.rs b/crates/rmcp/src/transport/sse_auth.rs new file mode 100644 index 0000000..3a4798f --- /dev/null +++ b/crates/rmcp/src/transport/sse_auth.rs @@ -0,0 +1,183 @@ +use std::sync::Arc; +use std::time::Duration; + +use futures::{Future, Sink, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; +use futures::sink::SinkExt as FuturesSinkExt; +use reqwest::{ + Client as HttpClient, IntoUrl, Url, + header::{ACCEPT, AUTHORIZATION, HeaderValue}, +}; +use sse_stream::{Error as SseError, Sse, SseStream}; +use thiserror::Error; +use tokio::sync::Mutex; + +use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; +use super::auth::{AuthorizationManager, AuthError}; +use super::sse::{SseTransportError, SseClient, SseTransport, SseTransportRetryConfig}; + +// SSE MIME type +const MIME_TYPE: &str = "text/event-stream"; +const HEADER_LAST_EVENT_ID: &str = "Last-Event-ID"; + +/// sse client with oauth2 authorization +#[derive(Clone)] +pub struct AuthorizedSseClient { + http_client: HttpClient, + sse_url: Url, + auth_manager: Arc>, + retry_config: SseTransportRetryConfig, +} + +impl AuthorizedSseClient { + /// create new authorized sse client + pub fn new( + url: U, + auth_manager: Arc>, + retry_config: Option, + ) -> Result> + where + U: IntoUrl, + { + let url = url.into_url().map_err(SseTransportError::from)?; + Ok(Self { + http_client: HttpClient::default(), + sse_url: url, + auth_manager, + retry_config: retry_config.unwrap_or_default(), + }) + } + + /// create authorized sse client with custom http client + pub async fn new_with_client( + url: U, + client: HttpClient, + auth_manager: Arc>, + retry_config: Option, + ) -> Result> + where + U: IntoUrl, + { + let url = url.into_url().map_err(SseTransportError::from)?; + Ok(Self { + http_client: client, + sse_url: url, + auth_manager, + retry_config: retry_config.unwrap_or_default(), + }) + } + + /// get access token, support retry + async fn get_token_with_retry(&self) -> Result> { + let mut retries = 0; + let max_retries = self.retry_config.max_times; + let base_delay = self.retry_config.min_duration; + + loop { + match self.auth_manager.lock().await.get_access_token().await { + Ok(token) => return Ok(token), + Err(AuthError::AuthorizationRequired) => { + return Err(SseTransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, "Authorization required"))); + } + Err(_e) => { + if retries >= max_retries.unwrap_or(0) { + return Err(SseTransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, "Authorization required"))); + } + retries += 1; + // todo: need to optimize + let delay = base_delay.as_millis(); + tokio::time::sleep(Duration::from_millis(delay as u64)).await; + } + } + } + } +} + +impl SseClient for AuthorizedSseClient { + fn connect(&self, last_event_id: Option) -> BoxFuture<'static, Result>, SseTransportError>> { + let client = self.http_client.clone(); + let sse_url = self.sse_url.as_ref().to_string(); + let last_event_id = last_event_id.clone(); + let auth_manager = self.auth_manager.clone(); + + let fut = async move { + // get access token + let token = auth_manager.lock().await.get_access_token().await?; + + // build request + let mut request_builder = client.get(&sse_url) + .header(ACCEPT, MIME_TYPE) + .header(AUTHORIZATION, format!("Bearer {}", token)); + + if let Some(last_event_id) = last_event_id { + request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); + } + + let response = request_builder.send().await?; + let response = response.error_for_status()?; + + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(ct) => { + if !ct.as_bytes().starts_with(MIME_TYPE.as_bytes()) { + return Err(SseTransportError::UnexpectedContentType(Some(ct.clone()))); + } + } + None => { + return Err(SseTransportError::UnexpectedContentType(None)); + } + } + + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + Ok(event_stream) + }; + + Box::pin(fut) + } + + fn post( + &self, + session_id: &str, + message: ClientJsonRpcMessage, + ) -> BoxFuture<'static, Result<(), SseTransportError>> { + let client = self.http_client.clone(); + let sse_url = self.sse_url.clone(); + let session_id = session_id.to_string(); + let auth_manager = self.auth_manager.clone(); + + Box::pin(async move { + // get access token + let token = auth_manager.lock().await.get_access_token().await + .map_err(|e| SseTransportError::::from(e))?; + + let uri = sse_url.join(&session_id).map_err(SseTransportError::from)?; + let request_builder = client.post(uri.as_ref()) + .header(AUTHORIZATION, format!("Bearer {}", token)) + .json(&message); + + request_builder + .send() + .await + .and_then(|resp| resp.error_for_status()) + .map_err(SseTransportError::from) + .map(drop) + }) + } +} + +impl From for SseTransportError { + fn from(err: AuthError) -> Self { + SseTransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, err.to_string())) + } +} + +/// create authorized sse transport +pub async fn create_authorized_transport( + url: U, + auth_manager: Arc>, + retry_config: Option, +) -> Result, SseTransportError> +where + U: IntoUrl, +{ + let client = AuthorizedSseClient::new(url, auth_manager, retry_config)?; + SseTransport::start_with_client(client).await +} \ No newline at end of file diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index 52226ed..a84a11d 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -11,7 +11,8 @@ rmcp = { path = "../../crates/rmcp", features = [ "client", "transport-sse", "transport-child-process", - "tower" + "tower", + "auth" ] } tokio = { version = "1", features = ["full"] } serde = { version = "1.0", features = ["derive"] } @@ -21,7 +22,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } rand = "0.8" futures = "0.3" anyhow = "1.0" - +url = "2.4" tower = "0.5" [[example]] @@ -40,3 +41,6 @@ path = "src/everything_stdio.rs" name = "clients_collection" path = "src/collection.rs" +[[example]] +name = "oauth_client" +path = "src/oauth_client.rs" \ No newline at end of file diff --git a/examples/clients/src/oauth_client.rs b/examples/clients/src/oauth_client.rs new file mode 100644 index 0000000..2e07825 --- /dev/null +++ b/examples/clients/src/oauth_client.rs @@ -0,0 +1,90 @@ +use std::{sync::Arc, time::Duration}; + +use anyhow::Result; +use rmcp::{ + RoleClient, ServiceExt, + model::ClientInfo, + transport::{ + auth::{AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient}, + create_authorized_transport, + sse::SseTransportRetryConfig, + }, +}; +use tokio::{ + io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, + sync::Mutex, +}; +use url::Url; + +#[tokio::main] +async fn main() -> Result<()> { + // init logger + tracing_subscriber::fmt::init(); + + // server url + let server_url = + std::env::var("MCP_SERVER_URL").unwrap_or_else(|_| "http://localhost:3000/mcp".to_string()); + + // retry config + let retry_config = SseTransportRetryConfig { + max_times: Some(3), + min_duration: Duration::from_secs(1), + }; + + // init auth manager + let auth_manager = AuthorizationManager::new(&server_url).await?; + let auth_manager_arc = Arc::new(Mutex::new(auth_manager)); + + // create authorization session + let session = AuthorizationSession::new( + auth_manager_arc.clone(), + &["mcp"], // request scopes + "http://localhost:8080/callback", // redirect uri + ) + .await?; + + // output authorization url + let mut output = BufWriter::new(tokio::io::stdout()); + output + .write_all(b"please open the following URL in your browser:\n") + .await?; + output + .write_all(session.get_authorization_url().as_bytes()) + .await?; + output + .write_all(b"\nplease input the authorization code:\n") + .await?; + output.flush().await?; + + // read authorization code + let mut auth_code = String::new(); + let mut reader = BufReader::new(tokio::io::stdin()); + reader.read_line(&mut auth_code).await?; + let auth_code = auth_code.trim(); + + // exchange access token + let credentials = session.handle_callback(auth_code).await?; + tracing::info!("Successfully obtained access token"); + + // create authorized sse transport, use retry config + let transport = + create_authorized_transport(&server_url, auth_manager_arc, Some(retry_config)).await?; + + // create client + let client_service = ClientInfo::default(); + let client = client_service.serve(transport).await?; + + // test api request + let tools = client.peer().list_all_tools().await?; + tracing::info!("Available tools: {tools:#?}"); + + // get prompt list + let prompts = client.peer().list_all_prompts().await?; + tracing::info!("Available prompts: {prompts:#?}"); + + // get resource list + let resources = client.peer().list_all_resources().await?; + tracing::info!("Available resources: {resources:#?}"); + + Ok(()) +} From 7cdc64036f771478c8224da20e2f363d5032df2b Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Tue, 22 Apr 2025 14:40:25 +0800 Subject: [PATCH 2/6] chore: add oauth2 server example Signed-off-by: jokemanfire --- crates/rmcp/src/transport/auth.rs | 3 +- examples/clients/src/oauth_client.rs | 7 +- examples/servers/Cargo.toml | 10 +- examples/servers/src/auth_sse.rs | 217 +++++++++++++++++++++++++++ 4 files changed, 227 insertions(+), 10 deletions(-) create mode 100644 examples/servers/src/auth_sse.rs diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 7108486..52e44eb 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,6 +1,5 @@ use std::sync::Arc; use std::time::Duration; -use futures::future::BoxFuture; use oauth2::basic::BasicTokenType; use oauth2::{ AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, @@ -14,7 +13,7 @@ use thiserror::Error; use tokio::sync::{Mutex, RwLock}; use tokio::time::{self, Instant}; -/// 错误定义 +/// Auth error #[derive(Debug, Error)] pub enum AuthError { #[error("OAuth authorization required")] diff --git a/examples/clients/src/oauth_client.rs b/examples/clients/src/oauth_client.rs index 2e07825..bae4429 100644 --- a/examples/clients/src/oauth_client.rs +++ b/examples/clients/src/oauth_client.rs @@ -2,19 +2,18 @@ use std::{sync::Arc, time::Duration}; use anyhow::Result; use rmcp::{ - RoleClient, ServiceExt, + ServiceExt, model::ClientInfo, transport::{ - auth::{AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient}, + auth::{AuthorizationManager, AuthorizationSession}, create_authorized_transport, sse::SseTransportRetryConfig, }, }; use tokio::{ - io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, + io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, sync::Mutex, }; -use url::Url; #[tokio::main] async fn main() -> Result<()> { diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 63d9d6d..99b145c 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -1,5 +1,3 @@ - - [package] name = "mcp-server-examples" version = "0.1.5" @@ -7,7 +5,7 @@ edition = "2024" publish = false [dependencies] -rmcp= { path = "../../crates/rmcp", features = ["server", "transport-sse-server", "transport-io"] } +rmcp= { path = "../../crates/rmcp", features = ["server", "transport-sse-server", "transport-io", "auth"] } tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "io-std", "signal"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -43,4 +41,8 @@ path = "src/axum_router.rs" [[example]] name = "servers_generic_server" -path = "src/generic_service.rs" \ No newline at end of file +path = "src/generic_service.rs" + +[[example]] +name = "servers_auth_sse" +path = "src/auth_sse.rs" \ No newline at end of file diff --git a/examples/servers/src/auth_sse.rs b/examples/servers/src/auth_sse.rs new file mode 100644 index 0000000..2f10edb --- /dev/null +++ b/examples/servers/src/auth_sse.rs @@ -0,0 +1,217 @@ +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use anyhow::Result; +use axum::{ + extract::{Path, State}, + http::{HeaderMap, Request, StatusCode}, + middleware::{self, Next}, + response::{Html, Response}, + routing::get, + Json, Router, +}; +use rmcp::{ + service::ServiceExt, + transport::{auth::AuthError, SseServer, sse_server::SseServerConfig}, + ServerHandler, tool, +}; +use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; +mod common; +use common::{calculator::Calculator, counter::Counter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; +// A simple token store +struct TokenStore { + valid_tokens: Vec, +} + +impl TokenStore { + fn new() -> Self { + // For demonstration purposes, use more secure token management in production + Self { + valid_tokens: vec!["demo-token".to_string(), "test-token".to_string()], + } + } + + fn is_valid(&self, token: &str) -> bool { + self.valid_tokens.contains(&token.to_string()) + } +} + +// Extract authorization token +fn extract_token(headers: &HeaderMap) -> Option { + headers + .get("Authorization") + .and_then(|value| value.to_str().ok()) + .and_then(|auth_header| { + if auth_header.starts_with("Bearer ") { + Some(auth_header[7..].to_string()) + } else { + None + } + }) +} + +// Authorization middleware +async fn auth_middleware( + State(token_store): State>, + headers: HeaderMap, + request: Request, + next: Next, +) -> Result { + match extract_token(&headers) { + Some(token) if token_store.is_valid(&token) => { + // Token is valid, proceed with the request + Ok(next.run(request).await) + } + _ => { + // Token is invalid, return 401 error + Err(StatusCode::UNAUTHORIZED) + } + } +} + +// Root path handler +async fn index() -> Html<&'static str> { + Html(r#" + + + + RMCP Authorized SSE Server + + + +

RMCP Authorized SSE Server

+

This is a Server-Sent Events server example that requires OAuth authorization.

+ +

Available Endpoints:

+
    +
  • /api/health - Health check
  • +
  • /api/token/{token_id} - Get test token (available: demo, test)
  • +
  • /sse - SSE connection endpoint (requires authorization)
  • +
  • /message - Message sending endpoint (requires authorization)
  • +
+ +

Usage:

+
+# Get a token
+curl http://127.0.0.1:8000/api/token/demo
+
+# Connect to SSE using the token
+curl -H "Authorization: Bearer demo-token" http://127.0.0.1:8000/sse
+        
+ + + "#) +} + +// Health check endpoint +async fn health_check() -> &'static str { + "OK" +} + +// Token generation endpoint (simplified example) +async fn get_token(Path(token_id): Path) -> Result, StatusCode> { + // In a real application, you should authenticate the user and generate a real token + if token_id == "demo" || token_id == "test" { + let token = format!("{}-token", token_id); + Ok(Json(serde_json::json!({ + "access_token": token, + "token_type": "Bearer", + "expires_in": 3600 + }))) + } else { + Err(StatusCode::UNAUTHORIZED) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Create token store + let token_store = Arc::new(TokenStore::new()); + + // Set up port + let addr = BIND_ADDRESS.parse::()?; + + // Create SSE server configuration + let sse_config = SseServerConfig { + bind: addr, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: Some(Duration::from_secs(15)), + }; + + // Create SSE server + let (sse_server, sse_router) = SseServer::new(sse_config); + + // Create API routes + let api_routes = Router::new() + .route("/health", get(health_check)) + .route("/token/{token_id}", get(get_token)); + + // Create protected SSE routes (require authorization) + let protected_sse_router = sse_router + .layer(middleware::from_fn_with_state( + token_store.clone(), + auth_middleware, + )); + + // Create main router, public endpoints don't require authorization + let app = Router::new() + .route("/", get(index)) + .nest("/api", api_routes) + .merge(protected_sse_router) + .with_state(()); + + // Start server and register service + let listener = tokio::net::TcpListener::bind(addr).await?; + let ct = sse_server.config.ct.clone(); + + // Start SSE server with Counter service + sse_server.with_service(Counter::new); + + // Handle signals for graceful shutdown + let cancel_token = ct.clone(); + tokio::spawn(async move { + match tokio::signal::ctrl_c().await { + Ok(()) => { + println!("Received Ctrl+C, shutting down server..."); + cancel_token.cancel(); + } + Err(err) => { + eprintln!("Unable to listen for Ctrl+C signal: {}", err); + } + } + }); + + // Start HTTP server + tracing::info!("Server started on {}", addr); + let server = axum::serve(listener, app).with_graceful_shutdown(async move { + // Wait for cancellation signal + ct.cancelled().await; + println!("Server is shutting down..."); + }); + + if let Err(e) = server.await { + eprintln!("Server error: {}", e); + } + + println!("Server has been shut down"); + Ok(()) +} \ No newline at end of file From c20878331cefb879a91f8c6849daa84d220ff9e4 Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Tue, 22 Apr 2025 20:25:44 +0800 Subject: [PATCH 3/6] chore: complete the oauth2 before exchange code Signed-off-by: jokemanfire --- crates/rmcp/src/transport.rs | 2 +- crates/rmcp/src/transport/auth.rs | 377 ++++++++----- crates/rmcp/src/transport/sse_auth.rs | 72 ++- examples/clients/Cargo.toml | 3 +- examples/clients/src/oauth_client.rs | 264 +++++++-- examples/servers/Cargo.toml | 11 +- examples/servers/src/auth_sse.rs | 37 +- examples/servers/src/mcp_oauth_server.rs | 674 +++++++++++++++++++++++ 8 files changed, 1220 insertions(+), 220 deletions(-) create mode 100644 examples/servers/src/mcp_oauth_server.rs diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index 3158478..9b49f76 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -72,7 +72,7 @@ pub use sse_server::SseServer; #[cfg(feature = "auth")] pub mod auth; #[cfg(feature = "auth")] -pub use auth::{AuthorizationManager, AuthorizationSession, AuthorizedHttpClient, AuthError}; +pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient}; // #[cfg(feature = "transport-ws")] // pub mod ws; diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 52e44eb..4d8c0d4 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,60 +1,64 @@ -use std::sync::Arc; -use std::time::Duration; -use oauth2::basic::BasicTokenType; +use std::{sync::Arc, time::Duration}; + use oauth2::{ - AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, - RedirectUrl, TokenResponse, Scope, AuthUrl, TokenUrl, RefreshToken, - StandardTokenResponse, TokenType, AccessToken, EmptyExtraTokenFields, - basic::BasicClient, reqwest::http_client, RefreshTokenRequest, AuthorizationRequest + AccessToken, AuthUrl, AuthorizationCode, AuthorizationRequest, ClientId, ClientSecret, + CsrfToken, EmptyExtraTokenFields, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, + RefreshToken, RefreshTokenRequest, Scope, StandardTokenResponse, TokenResponse, TokenType, + TokenUrl, + basic::{BasicClient, BasicTokenType}, + reqwest::http_client, }; -use reqwest::{Client as HttpClient, header::AUTHORIZATION, StatusCode, Url, IntoUrl}; +use reqwest::{Client as HttpClient, IntoUrl, StatusCode, Url, header::AUTHORIZATION}; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::sync::{Mutex, RwLock}; -use tokio::time::{self, Instant}; +use tokio::{ + sync::{Mutex, RwLock}, + time::{self, Instant}, +}; +use tracing::{debug, error}; /// Auth error #[derive(Debug, Error)] pub enum AuthError { #[error("OAuth authorization required")] AuthorizationRequired, - + #[error("OAuth authorization failed: {0}")] AuthorizationFailed(String), - + #[error("OAuth token exchange failed: {0}")] TokenExchangeFailed(String), - + #[error("OAuth token refresh failed: {0}")] TokenRefreshFailed(String), - + #[error("HTTP error: {0}")] HttpError(#[from] reqwest::Error), - + #[error("OAuth error: {0}")] OAuthError(String), - + #[error("Metadata error: {0}")] MetadataError(String), - + #[error("URL parse error: {0}")] UrlError(#[from] url::ParseError), - + #[error("No authorization support detected")] NoAuthorizationSupport, - + #[error("Internal error: {0}")] InternalError(String), - + #[error("Invalid token type: {0}")] InvalidTokenType(String), - + #[error("Token expired")] TokenExpired, - + #[error("Invalid scope: {0}")] InvalidScope(String), - + #[error("Registration failed: {0}")] RegistrationFailed(String), } @@ -64,7 +68,7 @@ pub enum AuthError { pub struct AuthorizationMetadata { pub authorization_endpoint: String, pub token_endpoint: String, - pub registration_endpoint: Option, + pub registration_endpoint: String, pub issuer: Option, pub jwks_uri: Option, pub scopes_supported: Option>, @@ -75,8 +79,8 @@ pub struct AuthorizationMetadata { pub struct OAuthClientConfig { pub client_id: String, pub client_secret: Option, - pub redirect_uri: String, pub scopes: Vec, + pub redirect_uri: String, } /// oauth2 auth manager @@ -89,6 +93,23 @@ pub struct AuthorizationManager { base_url: Url, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientRegistrationRequest { + pub client_name: String, + pub redirect_uris: Vec, + pub grant_types: Vec, + pub token_endpoint_auth_method: String, + pub response_types: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientRegistrationResponse { + pub client_id: String, + pub client_secret: String, + pub client_name: String, + pub redirect_uris: Vec, +} + impl AuthorizationManager { /// create new auth manager pub async fn new(base_url: U) -> Result { @@ -97,7 +118,7 @@ impl AuthorizationManager { .timeout(Duration::from_secs(30)) .build() .map_err(|e| AuthError::InternalError(e.to_string()))?; - + let mut manager = Self { http_client, metadata: None, @@ -106,179 +127,233 @@ impl AuthorizationManager { pkce_verifier: RwLock::new(None), base_url, }; - + // try to discover oauth2 metadata if let Ok(metadata) = manager.discover_metadata().await { manager.metadata = Some(metadata); } - + Ok(manager) } - + /// discover oauth2 metadata pub async fn discover_metadata(&self) -> Result { // according to the specification, the metadata should be located at "/.well-known/oauth-authorization-server" let mut discovery_url = self.base_url.clone(); discovery_url.set_path("/.well-known/oauth-authorization-server"); - - let response = self.http_client + debug!("discovery url: {:?}", discovery_url); + let response = self + .http_client .get(discovery_url) .header("MCP-Protocol-Version", "2024-11-05") .send() .await?; - + if response.status() == StatusCode::OK { - let metadata = response.json::().await - .map_err(|e| AuthError::MetadataError(format!("Failed to parse metadata: {}", e)))?; + let metadata = response + .json::() + .await + .map_err(|e| { + AuthError::MetadataError(format!("Failed to parse metadata: {}", e)) + })?; + debug!("metadata: {:?}", metadata); Ok(metadata) } else { // fallback to default endpoints let mut auth_base = self.base_url.clone(); // discard the path part, only keep scheme, host, port auth_base.set_path(""); - + Ok(AuthorizationMetadata { authorization_endpoint: format!("{}/authorize", auth_base), token_endpoint: format!("{}/token", auth_base), - registration_endpoint: Some(format!("{}/register", auth_base)), + registration_endpoint: format!("{}/register", auth_base), issuer: None, jwks_uri: None, scopes_supported: None, }) } } - + /// configure oauth2 client with client credentials pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> { if self.metadata.is_none() { return Err(AuthError::NoAuthorizationSupport); } - + let metadata = self.metadata.as_ref().unwrap(); - + let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone()) .map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?; - + let token_url = TokenUrl::new(metadata.token_endpoint.clone()) .map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?; - + let client_id = ClientId::new(config.client_id); - let redirect_url = RedirectUrl::new(config.redirect_uri) - .map_err(|e| AuthError::OAuthError(format!("Invalid redirect URL: {}", e)))?; - - let mut client_builder = BasicClient::new(client_id.clone(), None, auth_url.clone(), Some(token_url.clone())) - .set_redirect_uri(redirect_url.clone()); - + let redirect_url = RedirectUrl::new(config.redirect_uri.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid registry URL: {}", e)))?; + + let mut client_builder = BasicClient::new( + client_id.clone(), + None, + auth_url.clone(), + Some(token_url.clone()), + ) + .set_redirect_uri(redirect_url.clone()); + if let Some(secret) = config.client_secret { - client_builder = BasicClient::new(client_id, Some(ClientSecret::new(secret)), auth_url, Some(token_url)) - .set_redirect_uri(redirect_url); + client_builder = BasicClient::new( + client_id, + Some(ClientSecret::new(secret)), + auth_url, + Some(token_url), + ) + .set_redirect_uri(redirect_url); } - + self.oauth_client = Some(client_builder); Ok(()) } - + /// dynamic register oauth2 client - pub async fn register_client(&mut self, name: &str, redirect_uri: &str) -> Result { + pub async fn register_client( + &mut self, + name: &str, + redirect_uri: &str, + ) -> Result { if self.metadata.is_none() { + error!("No authorization support detected"); return Err(AuthError::NoAuthorizationSupport); } - + let metadata = self.metadata.as_ref().unwrap(); - let registration_url = metadata.registration_endpoint.as_ref() - .ok_or_else(|| AuthError::NoAuthorizationSupport)?; - + let registration_url = metadata.registration_endpoint.clone(); + + debug!("registration url: {:?}", registration_url); // prepare registration request - let registration_request = serde_json::json!({ - "client_name": name, - "redirect_uris": [redirect_uri], - "grant_types": ["authorization_code", "refresh_token"], - "token_endpoint_auth_method": "none", // public client - "response_types": ["code"], - }); - - let response = self.http_client + let registration_request = ClientRegistrationRequest { + client_name: name.to_string(), + redirect_uris: vec![redirect_uri.to_string()], + grant_types: vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ], + token_endpoint_auth_method: "none".to_string(), // public client + response_types: vec!["code".to_string()], + }; + + debug!("registration request: {:?}", registration_request); + + let response = match self + .http_client .post(registration_url) .json(®istration_request) .send() - .await?; - + .await + { + Ok(response) => response, + Err(e) => { + error!("Registration request failed: {}", e); + return Err(AuthError::RegistrationFailed(format!( + "HTTP request error: {}", + e + ))); + } + }; + if !response.status().is_success() { - return Err(AuthError::OAuthError(format!( - "Client registration failed: HTTP {}", response.status() + let status = response.status(); + let error_text = match response.text().await { + Ok(text) => text, + Err(_) => "cannot get error details".to_string(), + }; + + error!("Registration failed: HTTP {} - {}", status, error_text); + return Err(AuthError::RegistrationFailed(format!( + "HTTP {}: {}", + status, error_text ))); } - - #[derive(Deserialize)] - struct RegistrationResponse { - client_id: String, - client_secret: Option, - } - - let reg_response = response.json::().await - .map_err(|e| AuthError::OAuthError(format!("Failed to parse registration response: {}", e)))?; - + + let reg_response = match response.json::().await { + Ok(response) => response, + Err(e) => { + error!("Failed to parse registration response: {}", e); + return Err(AuthError::RegistrationFailed(format!( + "analyze response error: {}", + e + ))); + } + }; + let config = OAuthClientConfig { client_id: reg_response.client_id, - client_secret: reg_response.client_secret, + client_secret: Some(reg_response.client_secret), redirect_uri: redirect_uri.to_string(), scopes: vec![], }; - + self.configure_client(config.clone())?; Ok(config) } - + /// generate authorization url pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result { - let oauth_client = self.oauth_client.as_ref() + let oauth_client = self + .oauth_client + .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; - + // generate pkce challenge let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - + // build authorization request let mut auth_request = oauth_client .authorize_url(CsrfToken::new_random) .set_pkce_challenge(pkce_challenge); - + // add request scopes for scope in scopes { auth_request = auth_request.add_scope(Scope::new(scope.to_string())); } - + let (auth_url, _csrf_token) = auth_request.url(); // store pkce verifier for later use *self.pkce_verifier.write().await = Some(pkce_verifier); - + Ok(auth_url.to_string()) } - + /// exchange authorization code for access token - pub async fn exchange_code_for_token(&self, code: &str) -> Result, AuthError> { - let oauth_client = self.oauth_client.as_ref() + pub async fn exchange_code_for_token( + &self, + code: &str, + ) -> Result, AuthError> { + let oauth_client = self + .oauth_client + .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; - + let pkce_verifier = self.pkce_verifier.write().await.take().unwrap(); - + // exchange token let token_result = oauth_client .exchange_code(AuthorizationCode::new(code.to_string())) .set_pkce_verifier(pkce_verifier) .request(http_client) .map_err(|e| AuthError::TokenExchangeFailed(e.to_string()))?; - + // store credentials *self.credentials.write().await = Some(token_result.clone()); - + Ok(token_result) } - + /// get access token, if expired, refresh it automatically pub async fn get_access_token(&self) -> Result { let credentials = self.credentials.read().await; - + if let Some(creds) = credentials.as_ref() { // check if the token is expired if let Some(expires_in) = creds.expires_in() { @@ -289,45 +364,59 @@ impl AuthorizationManager { return Ok(new_creds.access_token().secret().to_string()); } } - + Ok(creds.access_token().secret().to_string()) } else { Err(AuthError::AuthorizationRequired) } } - + /// refresh access token - pub async fn refresh_token(&self) -> Result, AuthError> { - let oauth_client = self.oauth_client.as_ref() + pub async fn refresh_token( + &self, + ) -> Result, AuthError> { + let oauth_client = self + .oauth_client + .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; - - let current_credentials = self.credentials.read().await.clone() + + let current_credentials = self + .credentials + .read() + .await + .clone() .ok_or_else(|| AuthError::AuthorizationRequired)?; - - let refresh_token = current_credentials.refresh_token() - .ok_or_else(|| AuthError::TokenRefreshFailed("No refresh token available".to_string()))?; - + + let refresh_token = current_credentials.refresh_token().ok_or_else(|| { + AuthError::TokenRefreshFailed("No refresh token available".to_string()) + })?; + // refresh token let token_result = oauth_client .exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string())) .request(http_client) .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; - - + // store new credentials *self.credentials.write().await = Some(token_result.clone()); - + Ok(token_result) } - + /// prepare request, add authorization header - pub async fn prepare_request(&self, mut request: reqwest::RequestBuilder) -> Result { + pub async fn prepare_request( + &self, + mut request: reqwest::RequestBuilder, + ) -> Result { let token = self.get_access_token().await?; Ok(request.header(AUTHORIZATION, format!("Bearer {}", token))) } - + /// handle response, check if need to re-authorize - pub async fn handle_response(&self, response: reqwest::Response) -> Result { + pub async fn handle_response( + &self, + response: reqwest::Response, + ) -> Result { if response.status() == StatusCode::UNAUTHORIZED { // 401 Unauthorized, need to re-authorize Err(AuthError::AuthorizationRequired) @@ -348,7 +437,7 @@ pub struct AuthorizationSession { impl AuthorizationSession { /// create new authorization session pub async fn new( - auth_manager: Arc>, + auth_manager: Arc>, scopes: &[&str], redirect_uri: &str, ) -> Result { @@ -356,12 +445,17 @@ impl AuthorizationSession { let config = OAuthClientConfig { client_id: "mcp-client".to_string(), // temporary id, will be updated by dynamic registration client_secret: None, - redirect_uri: redirect_uri.to_string(), scopes: scopes.iter().map(|s| s.to_string()).collect(), + redirect_uri: redirect_uri.to_string(), }; - + // try to dynamic register client - let config = match auth_manager.lock().await.register_client("MCP Client", redirect_uri).await { + let config = match auth_manager + .lock() + .await + .register_client("MCP Client", redirect_uri) + .await + { Ok(config) => config, Err(e) => { eprintln!("Dynamic registration failed: {}", e); @@ -369,9 +463,21 @@ impl AuthorizationSession { config } }; - - let auth_url= auth_manager.lock().await.get_authorization_url(scopes).await?; - let pkce_verifier = auth_manager.lock().await.pkce_verifier.write().await.take().unwrap(); + // reset client config + auth_manager.lock().await.configure_client(config)?; + let auth_url = auth_manager + .lock() + .await + .get_authorization_url(scopes) + .await?; + let pkce_verifier = auth_manager + .lock() + .await + .pkce_verifier + .write() + .await + .take() + .unwrap(); Ok(Self { auth_manager, auth_url, @@ -379,15 +485,22 @@ impl AuthorizationSession { pkce_verifier, }) } - + /// get authorization url pub fn get_authorization_url(&self) -> &str { &self.auth_url } - + /// handle authorization code callback - pub async fn handle_callback(&self, code: &str) -> Result, AuthError> { - self.auth_manager.lock().await.exchange_code_for_token(code).await + pub async fn handle_callback( + &self, + code: &str, + ) -> Result, AuthError> { + self.auth_manager + .lock() + .await + .exchange_code_for_token(code) + .await } } @@ -406,22 +519,26 @@ impl AuthorizedHttpClient { inner_client, } } - + /// send authorized request - pub async fn request(&self, method: reqwest::Method, url: U) -> Result { + pub async fn request( + &self, + method: reqwest::Method, + url: U, + ) -> Result { let request = self.inner_client.request(method, url); self.auth_manager.prepare_request(request).await } - + /// send get request pub async fn get(&self, url: U) -> Result { let request = self.request(reqwest::Method::GET, url).await?; let response = request.send().await?; self.auth_manager.handle_response(response).await } - + /// send post request pub async fn post(&self, url: U) -> Result { self.request(reqwest::Method::POST, url).await } -} \ No newline at end of file +} diff --git a/crates/rmcp/src/transport/sse_auth.rs b/crates/rmcp/src/transport/sse_auth.rs index 3a4798f..e8e6957 100644 --- a/crates/rmcp/src/transport/sse_auth.rs +++ b/crates/rmcp/src/transport/sse_auth.rs @@ -1,8 +1,9 @@ -use std::sync::Arc; -use std::time::Duration; +use std::{sync::Arc, time::Duration}; -use futures::{Future, Sink, Stream, StreamExt, future::BoxFuture, stream::BoxStream}; -use futures::sink::SinkExt as FuturesSinkExt; +use futures::{ + Future, Sink, Stream, StreamExt, future::BoxFuture, sink::SinkExt as FuturesSinkExt, + stream::BoxStream, +}; use reqwest::{ Client as HttpClient, IntoUrl, Url, header::{ACCEPT, AUTHORIZATION, HeaderValue}, @@ -11,9 +12,11 @@ use sse_stream::{Error as SseError, Sse, SseStream}; use thiserror::Error; use tokio::sync::Mutex; +use super::{ + auth::{AuthError, AuthorizationManager}, + sse::{SseClient, SseTransport, SseTransportError, SseTransportRetryConfig}, +}; use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; -use super::auth::{AuthorizationManager, AuthError}; -use super::sse::{SseTransportError, SseClient, SseTransport, SseTransportRetryConfig}; // SSE MIME type const MIME_TYPE: &str = "text/event-stream"; @@ -76,11 +79,17 @@ impl AuthorizedSseClient { match self.auth_manager.lock().await.get_access_token().await { Ok(token) => return Ok(token), Err(AuthError::AuthorizationRequired) => { - return Err(SseTransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, "Authorization required"))); + return Err(SseTransportError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + "Authorization required", + ))); } Err(_e) => { if retries >= max_retries.unwrap_or(0) { - return Err(SseTransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, "Authorization required"))); + return Err(SseTransportError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + "Authorization required", + ))); } retries += 1; // todo: need to optimize @@ -93,28 +102,35 @@ impl AuthorizedSseClient { } impl SseClient for AuthorizedSseClient { - fn connect(&self, last_event_id: Option) -> BoxFuture<'static, Result>, SseTransportError>> { + fn connect( + &self, + last_event_id: Option, + ) -> BoxFuture< + 'static, + Result>, SseTransportError>, + > { let client = self.http_client.clone(); let sse_url = self.sse_url.as_ref().to_string(); let last_event_id = last_event_id.clone(); let auth_manager = self.auth_manager.clone(); - + let fut = async move { // get access token let token = auth_manager.lock().await.get_access_token().await?; - + // build request - let mut request_builder = client.get(&sse_url) + let mut request_builder = client + .get(&sse_url) .header(ACCEPT, MIME_TYPE) .header(AUTHORIZATION, format!("Bearer {}", token)); - + if let Some(last_event_id) = last_event_id { request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); } - + let response = request_builder.send().await?; let response = response.error_for_status()?; - + match response.headers().get(reqwest::header::CONTENT_TYPE) { Some(ct) => { if !ct.as_bytes().starts_with(MIME_TYPE.as_bytes()) { @@ -125,11 +141,11 @@ impl SseClient for AuthorizedSseClient { return Err(SseTransportError::UnexpectedContentType(None)); } } - + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); Ok(event_stream) }; - + Box::pin(fut) } @@ -142,17 +158,22 @@ impl SseClient for AuthorizedSseClient { let sse_url = self.sse_url.clone(); let session_id = session_id.to_string(); let auth_manager = self.auth_manager.clone(); - + Box::pin(async move { // get access token - let token = auth_manager.lock().await.get_access_token().await + let token = auth_manager + .lock() + .await + .get_access_token() + .await .map_err(|e| SseTransportError::::from(e))?; - + let uri = sse_url.join(&session_id).map_err(SseTransportError::from)?; - let request_builder = client.post(uri.as_ref()) + let request_builder = client + .post(uri.as_ref()) .header(AUTHORIZATION, format!("Bearer {}", token)) .json(&message); - + request_builder .send() .await @@ -165,7 +186,10 @@ impl SseClient for AuthorizedSseClient { impl From for SseTransportError { fn from(err: AuthError) -> Self { - SseTransportError::Io(std::io::Error::new(std::io::ErrorKind::Other, err.to_string())) + SseTransportError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + err.to_string(), + )) } } @@ -180,4 +204,4 @@ where { let client = AuthorizedSseClient::new(url, auth_manager, retry_config)?; SseTransport::start_with_client(client).await -} \ No newline at end of file +} diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index a84a11d..4420551 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -1,5 +1,3 @@ - - [package] name = "mcp-client-examples" version = "0.1.5" @@ -24,6 +22,7 @@ futures = "0.3" anyhow = "1.0" url = "2.4" tower = "0.5" +axum = "0.8" [[example]] name = "clients_sse" diff --git a/examples/clients/src/oauth_client.rs b/examples/clients/src/oauth_client.rs index bae4429..add8b76 100644 --- a/examples/clients/src/oauth_client.rs +++ b/examples/clients/src/oauth_client.rs @@ -1,89 +1,271 @@ -use std::{sync::Arc, time::Duration}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; -use anyhow::Result; +use anyhow::{Context, Result}; +use axum::{ + Router, + extract::{Query, State}, + response::{Html, Redirect}, + routing::get, +}; use rmcp::{ ServiceExt, model::ClientInfo, transport::{ - auth::{AuthorizationManager, AuthorizationSession}, + auth::{AuthError, AuthorizationManager, AuthorizationSession}, create_authorized_transport, sse::SseTransportRetryConfig, }, }; +use serde::Deserialize; use tokio::{ io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, - sync::Mutex, + sync::{Mutex, oneshot}, }; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +const MCP_SERVER_URL: &str = "http://localhost:3000/mcp"; +const MCP_REDIRECT_URI: &str = "http://localhost:8080/callback"; +const CALLBACK_PORT: u16 = 8080; + +#[derive(Clone)] +struct AppState { + auth_session: Arc, + code_receiver: Arc>>>, +} + +#[derive(Debug, Deserialize)] +struct CallbackParams { + code: String, + state: Option, +} + +async fn callback_handler( + Query(params): Query, + State(state): State, +) -> Html { + tracing::info!("Received callback with code: {}", params.code); + + // Send the code to the main thread + if let Some(sender) = state.code_receiver.lock().await.take() { + let _ = sender.send(params.code); + } + + // Return success page + Html(format!( + r#" + + + + OAuth Authorization Success + + + +
+
+

Authorization Successful

+

You have successfully authorized the MCP client. You can now close this window and return to the application.

+
+ + + "# + )) +} #[tokio::main] async fn main() -> Result<()> { - // init logger - tracing_subscriber::fmt::init(); + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); - // server url - let server_url = - std::env::var("MCP_SERVER_URL").unwrap_or_else(|_| "http://localhost:3000/mcp".to_string()); + // Get server URL + let server_url = MCP_SERVER_URL.to_string(); + tracing::info!("Using MCP server URL: {}", server_url); - // retry config + // Configure retry settings let retry_config = SseTransportRetryConfig { max_times: Some(3), min_duration: Duration::from_secs(1), }; - // init auth manager - let auth_manager = AuthorizationManager::new(&server_url).await?; + // Initialize authorization manager + let auth_manager = AuthorizationManager::new(&server_url) + .await + .context("Failed to initialize authorization manager")?; let auth_manager_arc = Arc::new(Mutex::new(auth_manager)); - // create authorization session + // Create authorization session let session = AuthorizationSession::new( auth_manager_arc.clone(), - &["mcp"], // request scopes - "http://localhost:8080/callback", // redirect uri + &["mcp", "profile", "email"], + &MCP_REDIRECT_URI, ) - .await?; + .await + .context("Failed to create authorization session")?; + + let session_arc = Arc::new(session); - // output authorization url + // Create channel for receiving authorization code + let (code_sender, code_receiver) = oneshot::channel::(); + + // Create app state + let app_state = AppState { + auth_session: session_arc.clone(), + code_receiver: Arc::new(Mutex::new(Some(code_sender))), + }; + + // Start HTTP server for handling callbacks + let app = Router::new() + .route("/callback", get(callback_handler)) + .with_state(app_state); + + let addr = SocketAddr::from(([127, 0, 0, 1], CALLBACK_PORT)); + tracing::info!("Starting callback server at: http://{}", addr); + + // Start server in a separate task + tokio::spawn(async move { + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + let result = axum::serve(listener, app).await; + + if let Err(e) = result { + tracing::error!("Callback server error: {}", e); + } + }); + + // Output authorization URL to user let mut output = BufWriter::new(tokio::io::stdout()); + output.write_all(b"\n=== MCP OAuth Client ===\n\n").await?; output - .write_all(b"please open the following URL in your browser:\n") + .write_all(b"Please open the following URL in your browser to authorize:\n\n") .await?; output - .write_all(session.get_authorization_url().as_bytes()) + .write_all(session_arc.get_authorization_url().as_bytes()) .await?; output - .write_all(b"\nplease input the authorization code:\n") + .write_all(b"\n\nWaiting for browser callback, please do not close this window...\n") .await?; output.flush().await?; - // read authorization code - let mut auth_code = String::new(); - let mut reader = BufReader::new(tokio::io::stdin()); - reader.read_line(&mut auth_code).await?; - let auth_code = auth_code.trim(); + // Wait for authorization code + tracing::info!("Waiting for authorization code..."); + let auth_code = code_receiver + .await + .context("Failed to get authorization code")?; - // exchange access token - let credentials = session.handle_callback(auth_code).await?; - tracing::info!("Successfully obtained access token"); + // Exchange code for access token + tracing::info!("Exchanging authorization code for access token..."); + let credentials = match session_arc.handle_callback(&auth_code).await { + Ok(creds) => { + tracing::info!("Successfully obtained access token"); + creds + } + Err(e) => { + tracing::error!("Failed to obtain access token: {}", e); + return Err(anyhow::anyhow!("Authorization failed: {}", e)); + } + }; - // create authorized sse transport, use retry config - let transport = - create_authorized_transport(&server_url, auth_manager_arc, Some(retry_config)).await?; + output + .write_all(b"\nAuthorization successful! Access token obtained.\n\n") + .await?; + output.flush().await?; + + // Create authorized transport + tracing::info!("Establishing authorized connection to MCP server..."); + let transport = match create_authorized_transport( + &server_url, + auth_manager_arc, + Some(retry_config), + ) + .await + { + Ok(t) => t, + Err(e) => { + tracing::error!("Failed to create authorized transport: {}", e); + return Err(anyhow::anyhow!("Connection failed: {}", e)); + } + }; - // create client + // Create client and connect to MCP server let client_service = ClientInfo::default(); let client = client_service.serve(transport).await?; + tracing::info!("Successfully connected to MCP server"); + + // Test API requests + output + .write_all(b"Fetching available tools from server...\n") + .await?; + output.flush().await?; + + match client.peer().list_all_tools().await { + Ok(tools) => { + output + .write_all(format!("Available tools: {}\n\n", tools.len()).as_bytes()) + .await?; + for tool in tools { + output + .write_all( + format!( + "- {} ({})\n", + tool.name, + tool.description.unwrap_or_default() + ) + .as_bytes(), + ) + .await?; + } + } + Err(e) => { + output + .write_all(format!("Error fetching tools: {}\n", e).as_bytes()) + .await?; + } + } - // test api request - let tools = client.peer().list_all_tools().await?; - tracing::info!("Available tools: {tools:#?}"); + output + .write_all(b"\nFetching available prompts from server...\n") + .await?; + output.flush().await?; + + match client.peer().list_all_prompts().await { + Ok(prompts) => { + output + .write_all(format!("Available prompts: {}\n\n", prompts.len()).as_bytes()) + .await?; + for prompt in prompts { + output + .write_all(format!("- {}\n", prompt.name).as_bytes()) + .await?; + } + } + Err(e) => { + output + .write_all(format!("Error fetching prompts: {}\n", e).as_bytes()) + .await?; + } + } + + output + .write_all(b"\nConnection established successfully. You are now authenticated with the MCP server.\n") + .await?; + output.flush().await?; - // get prompt list - let prompts = client.peer().list_all_prompts().await?; - tracing::info!("Available prompts: {prompts:#?}"); + // Keep the program running, wait for user input to exit + output.write_all(b"\nPress Enter to exit...\n").await?; + output.flush().await?; - // get resource list - let resources = client.peer().list_all_resources().await?; - tracing::info!("Available resources: {resources:#?}"); + let mut input = String::new(); + let mut reader = BufReader::new(tokio::io::stdin()); + reader.read_line(&mut input).await?; Ok(()) } diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 99b145c..6be6e7e 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -17,9 +17,12 @@ tracing-subscriber = { version = "0.3", features = [ "fmt", ] } futures = "0.3" -rand = { version = "0.9" } +rand = { version = "0.8", features = ["std"] } axum = { version = "0.8", features = ["macros"] } schemars = { version = "0.8", optional = true } +reqwest = { version = "0.12", features = ["json"] } +chrono = "0.4" +uuid = { version = "1.6", features = ["v4", "serde"] } # [dev-dependencies.'cfg(target_arch="linux")'.dependencies] [dev-dependencies] @@ -45,4 +48,8 @@ path = "src/generic_service.rs" [[example]] name = "servers_auth_sse" -path = "src/auth_sse.rs" \ No newline at end of file +path = "src/auth_sse.rs" + +[[example]] +name = "mcp_oauth_server" +path = "src/mcp_oauth_server.rs" \ No newline at end of file diff --git a/examples/servers/src/auth_sse.rs b/examples/servers/src/auth_sse.rs index 2f10edb..6389328 100644 --- a/examples/servers/src/auth_sse.rs +++ b/examples/servers/src/auth_sse.rs @@ -2,18 +2,14 @@ use std::{net::SocketAddr, sync::Arc, time::Duration}; use anyhow::Result; use axum::{ + Json, Router, extract::{Path, State}, http::{HeaderMap, Request, StatusCode}, middleware::{self, Next}, response::{Html, Response}, routing::get, - Json, Router, -}; -use rmcp::{ - service::ServiceExt, - transport::{auth::AuthError, SseServer, sse_server::SseServerConfig}, - ServerHandler, tool, }; +use rmcp::transport::{SseServer, sse_server::SseServerConfig}; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; mod common; @@ -74,7 +70,8 @@ async fn auth_middleware( // Root path handler async fn index() -> Html<&'static str> { - Html(r#" + Html( + r#" @@ -108,7 +105,8 @@ curl -H "Authorization: Bearer demo-token" http://127.0.0.1:8000/sse - "#) + "#, + ) } // Health check endpoint @@ -135,12 +133,12 @@ async fn get_token(Path(token_id): Path) -> Result Result<()> { // Initialize logging tracing_subscriber::registry() - .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "debug".to_string().into()), - ) - .with(tracing_subscriber::fmt::layer()) - .init(); + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); // Create token store let token_store = Arc::new(TokenStore::new()); @@ -166,11 +164,10 @@ async fn main() -> Result<()> { .route("/token/{token_id}", get(get_token)); // Create protected SSE routes (require authorization) - let protected_sse_router = sse_router - .layer(middleware::from_fn_with_state( - token_store.clone(), - auth_middleware, - )); + let protected_sse_router = sse_router.layer(middleware::from_fn_with_state( + token_store.clone(), + auth_middleware, + )); // Create main router, public endpoints don't require authorization let app = Router::new() @@ -214,4 +211,4 @@ async fn main() -> Result<()> { println!("Server has been shut down"); Ok(()) -} \ No newline at end of file +} diff --git a/examples/servers/src/mcp_oauth_server.rs b/examples/servers/src/mcp_oauth_server.rs new file mode 100644 index 0000000..e716355 --- /dev/null +++ b/examples/servers/src/mcp_oauth_server.rs @@ -0,0 +1,674 @@ +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; + +use anyhow::Result; +use axum::{ + Json, Router, + extract::{Form, Path, Query, State}, + http::{HeaderMap, StatusCode, Uri}, + response::{Html, IntoResponse, Redirect, Response}, + routing::{get, post}, +}; +use rand::{Rng, distributions::Alphanumeric}; +use reqwest::Client as HttpClient; +use rmcp::transport::{ + SseServer, + auth::{ + AuthorizationMetadata, ClientRegistrationRequest, ClientRegistrationResponse, + OAuthClientConfig, + }, + sse_server::SseServerConfig, +}; +use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use uuid::Uuid; +// Import Counter tool for MCP service +mod common; +use common::counter::Counter; + +const BIND_ADDRESS: &str = "127.0.0.1:3000"; + +// MCP OAuth Store for managing tokens and sessions +#[derive(Clone, Debug)] +struct McpOAuthStore { + clients: Arc>>, + auth_sessions: Arc>>, + access_tokens: Arc>>, + http_client: HttpClient, +} + +impl McpOAuthStore { + fn new() -> Self { + let mut clients = HashMap::new(); + clients.insert( + "mcp-client".to_string(), + OAuthClientConfig { + client_id: "mcp-client".to_string(), + client_secret: Some("mcp-client-secret".to_string()), + scopes: vec!["profile".to_string(), "email".to_string()], + redirect_uri: "http://localhost:8080/callback".to_string(), + }, + ); + + Self { + clients: Arc::new(RwLock::new(clients)), + auth_sessions: Arc::new(RwLock::new(HashMap::new())), + access_tokens: Arc::new(RwLock::new(HashMap::new())), + http_client: HttpClient::builder() + .timeout(Duration::from_secs(30)) + .build() + .expect("Failed to create HTTP client"), + } + } + + async fn validate_client( + &self, + client_id: &str, + redirect_uri: &str, + ) -> Option { + let clients = self.clients.read().await; + if let Some(client) = clients.get(client_id) { + if client.redirect_uri.contains(&redirect_uri.to_string()) { + return Some(client.clone()); + } + } + None + } + + async fn create_auth_session( + &self, + client_id: String, + redirect_uri: String, + scope: Option, + state: Option, + ) -> String { + let session_id = generate_random_string(16); + let session = AuthSession { + id: session_id.clone(), + client_id, + redirect_uri, + scope, + state, + created_at: chrono::Utc::now(), + third_party_token: None, + }; + + self.auth_sessions + .write() + .await + .insert(session_id.clone(), session); + session_id + } + + async fn get_auth_session(&self, session_id: &str) -> Option { + self.auth_sessions.read().await.get(session_id).cloned() + } + + async fn update_auth_session_token( + &self, + session_id: &str, + token: ThirdPartyToken, + ) -> Result<(), String> { + let mut sessions = self.auth_sessions.write().await; + if let Some(session) = sessions.get_mut(session_id) { + session.third_party_token = Some(token); + Ok(()) + } else { + Err("Session not found".to_string()) + } + } + + async fn create_mcp_token(&self, session_id: &str) -> Result { + let sessions = self.auth_sessions.read().await; + if let Some(session) = sessions.get(session_id) { + if let Some(third_party_token) = &session.third_party_token { + let access_token = format!("mcp-token-{}", Uuid::new_v4()); + let token = McpAccessToken { + access_token: access_token.clone(), + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: format!("mcp-refresh-{}", Uuid::new_v4()), + scope: session.scope.clone(), + third_party_token: third_party_token.clone(), + client_id: session.client_id.clone(), + }; + + self.access_tokens + .write() + .await + .insert(access_token.clone(), token.clone()); + Ok(token) + } else { + Err("No third-party token available for session".to_string()) + } + } else { + Err("Session not found".to_string()) + } + } + + async fn validate_token(&self, token: &str) -> Option { + self.access_tokens.read().await.get(token).cloned() + } +} + +#[derive(Clone, Debug)] +struct AuthSession { + id: String, + client_id: String, + redirect_uri: String, + scope: Option, + state: Option, + created_at: chrono::DateTime, + third_party_token: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +struct ThirdPartyToken { + access_token: String, + token_type: String, + expires_in: u64, + refresh_token: String, + scope: Option, +} + +#[derive(Clone, Debug, Serialize)] +struct McpAccessToken { + access_token: String, + token_type: String, + expires_in: u64, + refresh_token: String, + scope: Option, + third_party_token: ThirdPartyToken, + client_id: String, +} + +#[derive(Debug, Deserialize)] +struct AuthorizeQuery { + response_type: String, + client_id: String, + redirect_uri: String, + scope: Option, + state: Option, +} + +#[derive(Debug, Deserialize)] +struct AuthCallbackQuery { + code: String, + state: Option, + session_id: String, +} + +#[derive(Debug, Deserialize)] +struct TokenRequest { + grant_type: String, + code: String, + client_id: String, + client_secret: String, + redirect_uri: String, +} + +#[derive(Debug, Deserialize, Serialize)] +struct ThirdPartyTokenRequest { + grant_type: String, + code: String, + client_id: String, + client_secret: String, + redirect_uri: String, +} + +#[derive(Debug, Deserialize, Serialize)] +struct UserInfo { + sub: String, + name: String, + email: String, + username: String, +} + +fn generate_random_string(length: usize) -> String { + rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(length) + .map(char::from) + .collect() +} + +// Root path handler +async fn index() -> Html<&'static str> { + Html( + r#" + + + + MCP OAuth Server + + + +

MCP OAuth Server

+

This is an MCP server with OAuth 2.0 integration to a third-party authorization server.

+ +

Available Endpoints:

+ +
+

Authorization Endpoint

+

GET /oauth/authorize

+

Parameters:

+
    +
  • response_type - Must be "code"
  • +
  • client_id - Client identifier (e.g., "mcp-client")
  • +
  • redirect_uri - URI to redirect after authorization
  • +
  • scope - Optional requested scope
  • +
  • state - Optional state value for CSRF prevention
  • +
+
+ +
+

Token Endpoint

+

POST /oauth/token

+

Parameters:

+
    +
  • grant_type - Must be "authorization_code"
  • +
  • code - The authorization code
  • +
  • client_id - Client identifier
  • +
  • client_secret - Client secret
  • +
  • redirect_uri - Redirect URI used in authorization request
  • +
+
+ +
+

MCP SSE Endpoints

+

/mcp/sse - SSE connection endpoint (requires OAuth token)

+

/mcp/message - Message endpoint (requires OAuth token)

+
+ +
+

OAuth Flow:

+
    +
  1. MCP Client initiates OAuth flow with this MCP Server
  2. +
  3. MCP Server redirects to Third-Party OAuth Server
  4. +
  5. User authenticates with Third-Party Server
  6. +
  7. Third-Party Server redirects back to MCP Server with auth code
  8. +
  9. MCP Server exchanges the code for a third-party access token
  10. +
  11. MCP Server generates its own token bound to the third-party session
  12. +
  13. MCP Server completes the OAuth flow with the MCP Client
  14. +
+
+ + + "#, + ) +} + +// Initial OAuth authorize endpoint +async fn oauth_authorize( + Query(params): Query, + State(state): State, +) -> impl IntoResponse { + if let Some(client) = state + .validate_client(¶ms.client_id, ¶ms.redirect_uri) + .await + { + // create authorize page for user to approve + let html = format!( + r#" + + + + MCP OAuth + + + +

MCP OAuth Server

+
+
+

{client_id} requests access to your account.

+

requested scopes: {scopes}

+
+ +
+ + + + + +
+ + +
+
+
+ + + "#, + client_id = params.client_id, + redirect_uri = params.redirect_uri, + scope = params.scope.clone().unwrap_or_default(), + state = params.state.clone().unwrap_or_default(), + scopes = params + .scope + .clone() + .unwrap_or_else(|| "basic access".to_string()), + ); + + Html(html).into_response() + } else { + // invalid client_id or redirect_uri + ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_request", + "error_description": "invalid client id or redirect uri" + })), + ) + .into_response() + } +} + +// handle approval of authorization +#[derive(Debug, Deserialize)] +struct ApprovalForm { + client_id: String, + redirect_uri: String, + scope: String, + state: String, + approved: String, +} + +async fn oauth_approve( + State(state): State, + Form(form): Form, +) -> impl IntoResponse { + if form.approved != "true" { + // user rejected the authorization request + let redirect_url = format!( + "{}?error=access_denied&error_description={}{}", + form.redirect_uri, + "user rejected the authorization request", + form.state + .is_empty() + .then_some("") + .unwrap_or(&format!("&state={}", form.state)) + ); + return Redirect::to(&redirect_url).into_response(); + } + + // user approved the authorization request, generate authorization code + let auth_code = format!("mcp-code-{}", Uuid::new_v4().to_string()); + + // create new session record authorization information + let session_id = state + .create_auth_session( + form.client_id, + form.redirect_uri.clone(), + Some(form.scope), + Some(form.state.clone()), + ) + .await; + + // redirect back to client, with authorization code + let redirect_url = format!( + "{}?code={}{}", + form.redirect_uri, + auth_code, + form.state + .is_empty() + .then_some("") + .unwrap_or(&format!("&state={}", form.state)) + ); + + info!("authorization approved, redirecting to: {}", redirect_url); + Redirect::to(&redirect_url).into_response() +} + +// Handle token request from the MCP client +async fn oauth_token( + State(state): State, + Form(token_req): Form, +) -> impl IntoResponse { + if token_req.grant_type != "authorization_code" { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "unsupported_grant_type" + })), + ) + .into_response(); + } + + // Validate the client + if let Some(_client) = state + .validate_client(&token_req.client_id, &token_req.redirect_uri) + .await + { + // The code we generated earlier is "mcp-code-{session_id}" + if !token_req.code.starts_with("mcp-code-") { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_grant", + "error_description": "Invalid authorization code" + })), + ) + .into_response(); + } + + let session_id = token_req.code.replace("mcp-code-", ""); + + // Create an MCP access token bound to the third-party token + match state.create_mcp_token(&session_id).await { + Ok(token) => { + // Return the token + ( + StatusCode::OK, + Json(serde_json::json!({ + "access_token": token.access_token, + "token_type": token.token_type, + "expires_in": token.expires_in, + "refresh_token": token.refresh_token, + "scope": token.scope, + })), + ) + .into_response() + } + Err(e) => { + error!("Failed to create MCP token: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "server_error", + "error_description": "Failed to create access token" + })), + ) + .into_response() + } + } + } else { + // Invalid client credentials + ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_client" + })), + ) + .into_response() + } +} + +// Auth middleware for SSE connections +async fn validate_token_middleware( + headers: HeaderMap, + State(state): State, +) -> Result, StatusCode> { + // Extract the access token from the Authorization header + let auth_header = headers.get("Authorization"); + let token = match auth_header { + Some(header) => { + let header_str = header.to_str().unwrap_or(""); + if header_str.starts_with("Bearer ") { + header_str[7..].to_string() + } else { + return Err(StatusCode::UNAUTHORIZED); + } + } + None => { + return Err(StatusCode::UNAUTHORIZED); + } + }; + + // Validate the token + match state.validate_token(&token).await { + Some(_) => Ok(Some(token)), + None => Err(StatusCode::UNAUTHORIZED), + } +} + +// handle oauth server metadata request +async fn oauth_authorization_server() -> impl IntoResponse { + let metadata = AuthorizationMetadata { + authorization_endpoint: format!("http://{}/oauth/authorize", BIND_ADDRESS), + token_endpoint: format!("http://{}/oauth/token", BIND_ADDRESS), + scopes_supported: Some(vec!["profile".to_string(), "email".to_string()]), + registration_endpoint: format!("http://{}/oauth/register", BIND_ADDRESS), + issuer: Some(format!("{}", BIND_ADDRESS)), + jwks_uri: Some(format!("http://{}/oauth/jwks", BIND_ADDRESS)), + }; + debug!("metadata: {:?}", metadata); + (StatusCode::OK, Json(metadata)) +} + +// handle client registration request +async fn oauth_register( + State(state): State, + Json(req): Json, +) -> impl IntoResponse { + debug!("register request: {:?}", req); + if req.redirect_uris.is_empty() { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_request", + "error_description": "at least one redirect uri is required" + })), + ) + .into_response(); + } + + // generate client id and secret + let client_id = format!("client-{}", Uuid::new_v4()); + let client_secret = generate_random_string(32); + + let client = OAuthClientConfig { + client_id: client_id.clone(), + client_secret: Some(client_secret.clone()), + redirect_uri: req.redirect_uris[0].clone(), + scopes: vec![], + }; + + state + .clients + .write() + .await + .insert(client_id.clone(), client); + + // return client information + let response = ClientRegistrationResponse { + client_id, + client_secret, + client_name: req.client_name, + redirect_uris: req.redirect_uris, + }; + + (StatusCode::CREATED, Json(response)).into_response() +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Create the OAuth store + let oauth_store = McpOAuthStore::new(); + + // Set up port + let addr = BIND_ADDRESS.parse::()?; + + // Create SSE server configuration for MCP + let sse_config = SseServerConfig { + bind: addr.clone(), + sse_path: "/mcp/sse".to_string(), + post_path: "/mcp/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: Some(Duration::from_secs(15)), + }; + + // Create SSE server + let (sse_server, sse_router) = SseServer::new(sse_config); + + // Create HTTP router + let app = Router::new() + .route("/", get(index)) + .route("/mcp", get(index)) + .route( + "/.well-known/oauth-authorization-server", + get(oauth_authorization_server), + ) + .route("/oauth/authorize", get(oauth_authorize)) + .route("/oauth/approve", post(oauth_approve)) + .route("/oauth/token", post(oauth_token)) + .route("/oauth/register", post(oauth_register)) + .with_state(oauth_store.clone()); + + let app = app.merge(sse_router.with_state(())); + // Register token validation middleware for SSE + let cancel_token = sse_server.config.ct.clone(); + // Handle Ctrl+C + let cancel_token2 = sse_server.config.ct.clone(); + // Start SSE server with Counter service + sse_server.with_service(Counter::new); + + // Start HTTP server + info!("MCP OAuth Server started on {}", addr); + let listener = tokio::net::TcpListener::bind(addr).await?; + let server = axum::serve(listener, app).with_graceful_shutdown(async move { + cancel_token.cancelled().await; + info!("Server is shutting down"); + }); + + tokio::spawn(async move { + match tokio::signal::ctrl_c().await { + Ok(()) => { + info!("Received Ctrl+C, shutting down"); + cancel_token2.cancel(); + } + Err(e) => error!("Failed to listen for Ctrl+C: {}", e), + } + }); + + if let Err(e) = server.await { + error!("Server error: {}", e); + } + + Ok(()) +} From 2a72ae986b06a738884a2dfef379540c91b74fe5 Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Wed, 23 Apr 2025 11:01:10 +0800 Subject: [PATCH 4/6] fix: pkg is nil Signed-off-by: jokemanfire --- crates/rmcp/src/transport/auth.rs | 21 ++++++++------------- examples/clients/Cargo.toml | 2 ++ 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 4d8c0d4..2ace0c3 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -191,7 +191,7 @@ impl AuthorizationManager { let client_id = ClientId::new(config.client_id); let redirect_url = RedirectUrl::new(config.redirect_uri.clone()) - .map_err(|e| AuthError::OAuthError(format!("Invalid registry URL: {}", e)))?; + .map_err(|e| AuthError::OAuthError(format!("Invalid redirect URL: {}", e)))?; let mut client_builder = BasicClient::new( client_id.clone(), @@ -321,6 +321,7 @@ impl AuthorizationManager { // store pkce verifier for later use *self.pkce_verifier.write().await = Some(pkce_verifier); + debug!("set pkce verifier: {:?}", self.pkce_verifier.read().await); Ok(auth_url.to_string()) } @@ -335,7 +336,9 @@ impl AuthorizationManager { .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; - let pkce_verifier = self.pkce_verifier.write().await.take().unwrap(); + let pkce_verifier = self.pkce_verifier.write().await.take().ok_or_else(|| { + AuthError::InternalError("PKCE verifier not found".to_string()) + })?; // exchange token let token_result = oauth_client @@ -344,6 +347,7 @@ impl AuthorizationManager { .request(http_client) .map_err(|e| AuthError::TokenExchangeFailed(e.to_string()))?; + debug!("exchange token result: {:?}", token_result); // store credentials *self.credentials.write().await = Some(token_result.clone()); @@ -406,7 +410,7 @@ impl AuthorizationManager { /// prepare request, add authorization header pub async fn prepare_request( &self, - mut request: reqwest::RequestBuilder, + request: reqwest::RequestBuilder, ) -> Result { let token = self.get_access_token().await?; Ok(request.header(AUTHORIZATION, format!("Bearer {}", token))) @@ -431,7 +435,6 @@ pub struct AuthorizationSession { pub auth_manager: Arc>, pub auth_url: String, pub redirect_uri: String, - pub pkce_verifier: PkceCodeVerifier, } impl AuthorizationSession { @@ -470,19 +473,11 @@ impl AuthorizationSession { .await .get_authorization_url(scopes) .await?; - let pkce_verifier = auth_manager - .lock() - .await - .pkce_verifier - .write() - .await - .take() - .unwrap(); + Ok(Self { auth_manager, auth_url, redirect_uri: redirect_uri.to_string(), - pkce_verifier, }) } diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index 4420551..312def3 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -1,3 +1,5 @@ + + [package] name = "mcp-client-examples" version = "0.1.5" From ee36c6ff4c65acacf2a7e864659f755fe88f44c8 Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Wed, 23 Apr 2025 16:05:18 +0800 Subject: [PATCH 5/6] fix: the basic verify is ok Signed-off-by: jokemanfire --- crates/rmcp/Cargo.toml | 2 +- crates/rmcp/src/transport/auth.rs | 64 +++-- examples/clients/src/oauth_client.rs | 6 +- examples/servers/Cargo.toml | 2 + examples/servers/src/mcp_oauth_server.rs | 304 ++++++++++++++++------- 5 files changed, 261 insertions(+), 117 deletions(-) diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 74bb0e1..088c72f 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -25,7 +25,7 @@ pin-project-lite = "0.2" paste = { version = "1", optional = true } # oauth2 support -oauth2 = { version = "4.3", optional = true } +oauth2 = { version = "5.0", optional = true } # for auto generate schema schemars = { version = "0.8", optional = true } diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 2ace0c3..6fdd613 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -6,7 +6,6 @@ use oauth2::{ RefreshToken, RefreshTokenRequest, Scope, StandardTokenResponse, TokenResponse, TokenType, TokenUrl, basic::{BasicClient, BasicTokenType}, - reqwest::http_client, }; use reqwest::{Client as HttpClient, IntoUrl, StatusCode, Url, header::AUTHORIZATION}; use serde::{Deserialize, Serialize}; @@ -87,7 +86,20 @@ pub struct OAuthClientConfig { pub struct AuthorizationManager { http_client: HttpClient, metadata: Option, - oauth_client: Option, + oauth_client: Option< + oauth2::Client< + oauth2::StandardErrorResponse, + StandardTokenResponse, + oauth2::StandardTokenIntrospectionResponse, + oauth2::StandardRevocableToken, + oauth2::StandardErrorResponse, + oauth2::EndpointSet, + oauth2::EndpointNotSet, + oauth2::EndpointNotSet, + oauth2::EndpointNotSet, + oauth2::EndpointSet, + >, + >, credentials: RwLock>>, pkce_verifier: RwLock>, base_url: Url, @@ -189,26 +201,19 @@ impl AuthorizationManager { let token_url = TokenUrl::new(metadata.token_endpoint.clone()) .map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?; + // debug!("token url: {:?}", token_url); let client_id = ClientId::new(config.client_id); let redirect_url = RedirectUrl::new(config.redirect_uri.clone()) - .map_err(|e| AuthError::OAuthError(format!("Invalid redirect URL: {}", e)))?; + .map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {}", e)))?; - let mut client_builder = BasicClient::new( - client_id.clone(), - None, - auth_url.clone(), - Some(token_url.clone()), - ) - .set_redirect_uri(redirect_url.clone()); + debug!("client_id: {:?}", client_id); + let mut client_builder = BasicClient::new(client_id.clone()) + .set_auth_uri(auth_url) + .set_token_uri(token_url) + .set_redirect_uri(redirect_url); if let Some(secret) = config.client_secret { - client_builder = BasicClient::new( - client_id, - Some(ClientSecret::new(secret)), - auth_url, - Some(token_url), - ) - .set_redirect_uri(redirect_url); + client_builder = client_builder.set_client_secret(ClientSecret::new(secret)); } self.oauth_client = Some(client_builder); @@ -274,7 +279,7 @@ impl AuthorizationManager { status, error_text ))); } - + debug!("registration response: {:?}", response); let reg_response = match response.json::().await { Ok(response) => response, Err(e) => { @@ -331,20 +336,30 @@ impl AuthorizationManager { &self, code: &str, ) -> Result, AuthError> { + debug!("start exchange code for token: {:?}", code); let oauth_client = self .oauth_client .as_ref() .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; - let pkce_verifier = self.pkce_verifier.write().await.take().ok_or_else(|| { - AuthError::InternalError("PKCE verifier not found".to_string()) - })?; - + let pkce_verifier = self + .pkce_verifier + .write() + .await + .take() + .ok_or_else(|| AuthError::InternalError("PKCE verifier not found".to_string()))?; + let http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .map_err(|e| AuthError::InternalError(e.to_string()))?; + debug!("client_id: {:?}", oauth_client.client_id()); // exchange token let token_result = oauth_client .exchange_code(AuthorizationCode::new(code.to_string())) + .add_extra_param("client_id", oauth_client.client_id().to_string()) .set_pkce_verifier(pkce_verifier) - .request(http_client) + .request_async(&http_client) + .await .map_err(|e| AuthError::TokenExchangeFailed(e.to_string()))?; debug!("exchange token result: {:?}", token_result); @@ -398,7 +413,8 @@ impl AuthorizationManager { // refresh token let token_result = oauth_client .exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string())) - .request(http_client) + .request_async(&self.http_client) + .await .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; // store new credentials diff --git a/examples/clients/src/oauth_client.rs b/examples/clients/src/oauth_client.rs index add8b76..25c3531 100644 --- a/examples/clients/src/oauth_client.rs +++ b/examples/clients/src/oauth_client.rs @@ -25,6 +25,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; const MCP_SERVER_URL: &str = "http://localhost:3000/mcp"; const MCP_REDIRECT_URI: &str = "http://localhost:8080/callback"; +const MCP_SSE_URL: &str = "http://localhost:3000/mcp/sse"; const CALLBACK_PORT: u16 = 8080; #[derive(Clone)] @@ -160,7 +161,7 @@ async fn main() -> Result<()> { let auth_code = code_receiver .await .context("Failed to get authorization code")?; - + tracing::info!("Received authorization code: {}", auth_code); // Exchange code for access token tracing::info!("Exchanging authorization code for access token..."); let credentials = match session_arc.handle_callback(&auth_code).await { @@ -173,6 +174,7 @@ async fn main() -> Result<()> { return Err(anyhow::anyhow!("Authorization failed: {}", e)); } }; + tracing::info!("Access token: {:?}", credentials); output .write_all(b"\nAuthorization successful! Access token obtained.\n\n") @@ -182,7 +184,7 @@ async fn main() -> Result<()> { // Create authorized transport tracing::info!("Establishing authorized connection to MCP server..."); let transport = match create_authorized_transport( - &server_url, + MCP_SSE_URL.to_string(), auth_manager_arc, Some(retry_config), ) diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 6be6e7e..6fae7ed 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -23,6 +23,8 @@ schemars = { version = "0.8", optional = true } reqwest = { version = "0.12", features = ["json"] } chrono = "0.4" uuid = { version = "1.6", features = ["v4", "serde"] } +serde_urlencoded = "0.7" +hyper = "1.3" # [dev-dependencies.'cfg(target_arch="linux")'.dependencies] [dev-dependencies] diff --git a/examples/servers/src/mcp_oauth_server.rs b/examples/servers/src/mcp_oauth_server.rs index e716355..a128216 100644 --- a/examples/servers/src/mcp_oauth_server.rs +++ b/examples/servers/src/mcp_oauth_server.rs @@ -1,13 +1,16 @@ -use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration, usize}; use anyhow::Result; use axum::{ Json, Router, + body::Body, extract::{Form, Path, Query, State}, - http::{HeaderMap, StatusCode, Uri}, + http::{HeaderMap, Method, Request, StatusCode, Uri}, + middleware::{self, Next}, response::{Html, IntoResponse, Redirect, Response}, routing::{get, post}, }; +use hyper::body; use rand::{Rng, distributions::Alphanumeric}; use reqwest::Client as HttpClient; use rmcp::transport::{ @@ -83,8 +86,8 @@ impl McpOAuthStore { redirect_uri: String, scope: Option, state: Option, + session_id: String, ) -> String { - let session_id = generate_random_string(16); let session = AuthSession { id: session_id.clone(), client_id, @@ -92,7 +95,7 @@ impl McpOAuthStore { scope, state, created_at: chrono::Utc::now(), - third_party_token: None, + auth_token: None, }; self.auth_sessions @@ -102,18 +105,14 @@ impl McpOAuthStore { session_id } - async fn get_auth_session(&self, session_id: &str) -> Option { - self.auth_sessions.read().await.get(session_id).cloned() - } - async fn update_auth_session_token( &self, session_id: &str, - token: ThirdPartyToken, + token: AuthToken, ) -> Result<(), String> { let mut sessions = self.auth_sessions.write().await; if let Some(session) = sessions.get_mut(session_id) { - session.third_party_token = Some(token); + session.auth_token = Some(token); Ok(()) } else { Err("Session not found".to_string()) @@ -123,7 +122,7 @@ impl McpOAuthStore { async fn create_mcp_token(&self, session_id: &str) -> Result { let sessions = self.auth_sessions.read().await; if let Some(session) = sessions.get(session_id) { - if let Some(third_party_token) = &session.third_party_token { + if let Some(auth_token) = &session.auth_token { let access_token = format!("mcp-token-{}", Uuid::new_v4()); let token = McpAccessToken { access_token: access_token.clone(), @@ -131,7 +130,7 @@ impl McpOAuthStore { expires_in: 3600, refresh_token: format!("mcp-refresh-{}", Uuid::new_v4()), scope: session.scope.clone(), - third_party_token: third_party_token.clone(), + auth_token: auth_token.clone(), client_id: session.client_id.clone(), }; @@ -161,11 +160,11 @@ struct AuthSession { scope: Option, state: Option, created_at: chrono::DateTime, - third_party_token: Option, + auth_token: Option, } #[derive(Clone, Debug, Serialize, Deserialize)] -struct ThirdPartyToken { +struct AuthToken { access_token: String, token_type: String, expires_in: u64, @@ -180,7 +179,7 @@ struct McpAccessToken { expires_in: u64, refresh_token: String, scope: Option, - third_party_token: ThirdPartyToken, + auth_token: AuthToken, client_id: String, } @@ -200,22 +199,17 @@ struct AuthCallbackQuery { session_id: String, } -#[derive(Debug, Deserialize)] -struct TokenRequest { - grant_type: String, - code: String, - client_id: String, - client_secret: String, - redirect_uri: String, -} - #[derive(Debug, Deserialize, Serialize)] -struct ThirdPartyTokenRequest { +struct TokenRequest { grant_type: String, code: String, + #[serde(default)] client_id: String, + #[serde(default)] client_secret: String, redirect_uri: String, + #[serde(default)] + code_verifier: Option, } #[derive(Debug, Deserialize, Serialize)] @@ -309,7 +303,7 @@ async fn index() -> Html<&'static str> { // Initial OAuth authorize endpoint async fn oauth_authorize( Query(params): Query, - State(state): State, + State(state): State>, ) -> impl IntoResponse { if let Some(client) = state .validate_client(¶ms.client_id, ¶ms.redirect_uri) @@ -391,7 +385,7 @@ struct ApprovalForm { } async fn oauth_approve( - State(state): State, + State(state): State>, Form(form): Form, ) -> impl IntoResponse { if form.approved != "true" { @@ -409,18 +403,37 @@ async fn oauth_approve( } // user approved the authorization request, generate authorization code - let auth_code = format!("mcp-code-{}", Uuid::new_v4().to_string()); + let session_id = Uuid::new_v4().to_string(); + let auth_code = format!("mcp-code-{}", session_id); // create new session record authorization information let session_id = state .create_auth_session( - form.client_id, + form.client_id.clone(), form.redirect_uri.clone(), - Some(form.scope), + Some(form.scope.clone()), Some(form.state.clone()), + session_id.clone(), ) .await; + // create token + let created_token = AuthToken { + access_token: format!("tp-token-{}", Uuid::new_v4()), + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: format!("tp-refresh-{}", Uuid::new_v4()), + scope: Some(form.scope), + }; + + // update session token + if let Err(e) = state + .update_auth_session_token(&session_id, created_token) + .await + { + error!("Failed to update session token: {}", e); + } + // redirect back to client, with authorization code let redirect_url = format!( "{}?code={}{}", @@ -438,85 +451,142 @@ async fn oauth_approve( // Handle token request from the MCP client async fn oauth_token( - State(state): State, - Form(token_req): Form, + State(state): State>, + request: axum::http::Request, ) -> impl IntoResponse { + info!("Received token request"); + + let bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await { + Ok(bytes) => bytes, + Err(e) => { + error!("can't read request body: {}", e); + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_request", + "error_description": "can't read request body" + })), + ) + .into_response(); + } + }; + + let body_str = String::from_utf8_lossy(&bytes); + info!("request body: {}", body_str); + + let token_req = match serde_urlencoded::from_bytes::(&bytes) { + Ok(form) => { + info!("successfully parsed form data: {:?}", form); + form + } + Err(e) => { + error!("can't parse form data: {}", e); + return ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(serde_json::json!({ + "error": "invalid_request", + "error_description": format!("can't parse form data: {}", e) + })), + ) + .into_response(); + } + }; + + if token_req.grant_type != "authorization_code" { + info!("unsupported grant type: {}", token_req.grant_type); return ( StatusCode::BAD_REQUEST, Json(serde_json::json!({ - "error": "unsupported_grant_type" + "error": "unsupported_grant_type", + "error_description": "only authorization_code is supported" })), ) .into_response(); } - // Validate the client - if let Some(_client) = state - .validate_client(&token_req.client_id, &token_req.redirect_uri) + // get session_id from code + if !token_req.code.starts_with("mcp-code-") { + info!("invalid authorization code: {}", token_req.code); + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_grant", + "error_description": "invalid authorization code" + })), + ) + .into_response(); + } + + // handle empty client_id + let client_id = if token_req.client_id.is_empty() { + "mcp-client".to_string() + } else { + token_req.client_id.clone() + }; + + // validate client + match state + .validate_client(&client_id, &token_req.redirect_uri) .await { - // The code we generated earlier is "mcp-code-{session_id}" - if !token_req.code.starts_with("mcp-code-") { - return ( + Some(_) => { + let session_id = token_req.code.replace("mcp-code-", ""); + info!("got session id: {}", session_id); + + // create mcp access token + match state.create_mcp_token(&session_id).await { + Ok(token) => { + info!("successfully created access token"); + ( + StatusCode::OK, + Json(serde_json::json!({ + "access_token": token.access_token, + "token_type": token.token_type, + "expires_in": token.expires_in, + "refresh_token": token.refresh_token, + "scope": token.scope, + })), + ) + .into_response() + } + Err(e) => { + error!("failed to create access token: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "server_error", + "error_description": format!("failed to create access token: {}", e) + })), + ) + .into_response() + } + } + } + None => { + info!( + "invalid client id or redirect uri: {} / {}", + client_id, token_req.redirect_uri + ); + ( StatusCode::BAD_REQUEST, Json(serde_json::json!({ - "error": "invalid_grant", - "error_description": "Invalid authorization code" + "error": "invalid_client", + "error_description": "invalid client id or redirect uri" })), ) - .into_response(); - } - - let session_id = token_req.code.replace("mcp-code-", ""); - - // Create an MCP access token bound to the third-party token - match state.create_mcp_token(&session_id).await { - Ok(token) => { - // Return the token - ( - StatusCode::OK, - Json(serde_json::json!({ - "access_token": token.access_token, - "token_type": token.token_type, - "expires_in": token.expires_in, - "refresh_token": token.refresh_token, - "scope": token.scope, - })), - ) - .into_response() - } - Err(e) => { - error!("Failed to create MCP token: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(serde_json::json!({ - "error": "server_error", - "error_description": "Failed to create access token" - })), - ) - .into_response() - } + .into_response() } - } else { - // Invalid client credentials - ( - StatusCode::BAD_REQUEST, - Json(serde_json::json!({ - "error": "invalid_client" - })), - ) - .into_response() } } // Auth middleware for SSE connections async fn validate_token_middleware( - headers: HeaderMap, - State(state): State, + State(token_store): State>, + request: Request, ) -> Result, StatusCode> { // Extract the access token from the Authorization header - let auth_header = headers.get("Authorization"); + let auth_header = request.headers().get("Authorization"); let token = match auth_header { Some(header) => { let header_str = header.to_str().unwrap_or(""); @@ -532,7 +602,7 @@ async fn validate_token_middleware( }; // Validate the token - match state.validate_token(&token).await { + match token_store.validate_token(&token).await { Some(_) => Ok(Some(token)), None => Err(StatusCode::UNAUTHORIZED), } @@ -554,7 +624,7 @@ async fn oauth_authorization_server() -> impl IntoResponse { // handle client registration request async fn oauth_register( - State(state): State, + State(state): State>, Json(req): Json, ) -> impl IntoResponse { debug!("register request: {:?}", req); @@ -597,6 +667,52 @@ async fn oauth_register( (StatusCode::CREATED, Json(response)).into_response() } +// Log all HTTP requests +async fn log_request(request: Request, next: Next) -> Response { + let method = request.method().clone(); + let uri = request.uri().clone(); + let version = request.version(); + + // Log headers + let headers = request.headers().clone(); + let mut header_log = String::new(); + for (key, value) in headers.iter() { + let value_str = match value.to_str() { + Ok(v) => v, + Err(_) => "", + }; + header_log.push_str(&format!("\n {}: {}", key, value_str)); + } + + // Try to get request body for form submissions + let content_type = headers + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + let request_info = if content_type.contains("application/x-www-form-urlencoded") + || content_type.contains("application/json") + { + format!( + "{} {} {:?}{}\nContent-Type: {}", + method, uri, version, header_log, content_type + ) + } else { + format!("{} {} {:?}{}", method, uri, version, header_log) + }; + + info!("REQUEST: {}", request_info); + + // Call the actual handler + let response = next.run(request).await; + + // Log response status + let status = response.status(); + info!("RESPONSE: {} for {} {}", status, method, uri); + + response +} + #[tokio::main] async fn main() -> Result<()> { // Initialize logging @@ -609,7 +725,7 @@ async fn main() -> Result<()> { .init(); // Create the OAuth store - let oauth_store = McpOAuthStore::new(); + let oauth_store = Arc::new(McpOAuthStore::new()); // Set up port let addr = BIND_ADDRESS.parse::()?; @@ -626,7 +742,13 @@ async fn main() -> Result<()> { // Create SSE server let (sse_server, sse_router) = SseServer::new(sse_config); - // Create HTTP router + // Create protected SSE routes (require authorization) + // let protected_sse_router = sse_router.layer(middleware::from_fn_with_state( + // oauth_store.clone(), + // validate_token_middleware, + // )); + + // Create HTTP router with request logging middleware let app = Router::new() .route("/", get(index)) .route("/mcp", get(index)) @@ -638,7 +760,9 @@ async fn main() -> Result<()> { .route("/oauth/approve", post(oauth_approve)) .route("/oauth/token", post(oauth_token)) .route("/oauth/register", post(oauth_register)) - .with_state(oauth_store.clone()); + // .merge(protected_sse_router) + .with_state(oauth_store.clone()) + .layer(middleware::from_fn(log_request)); let app = app.merge(sse_router.with_state(())); // Register token validation middleware for SSE From ec767bac001d70187d256ef043fba2d9c8a87397 Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Sat, 26 Apr 2025 16:03:01 +0800 Subject: [PATCH 6/6] fix: the example in oauth2 is ok Signed-off-by: jokemanfire --- README.md | 4 + crates/rmcp/src/transport/auth.rs | 49 ++++++------ crates/rmcp/src/transport/sse_auth.rs | 55 +++---------- OAUTH_README.md => docs/OAUTH_SUPPORT.md | 23 +++--- examples/clients/src/oauth_client.rs | 15 ++-- examples/servers/Cargo.toml | 2 +- examples/servers/src/auth_sse.rs | 18 +++-- examples/servers/src/mcp_oauth_server.rs | 99 ++++++++++-------------- 8 files changed, 116 insertions(+), 149 deletions(-) rename OAUTH_README.md => docs/OAUTH_SUPPORT.md (91%) diff --git a/README.md b/README.md index 65b2ab2..a174a40 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,10 @@ For many cases you need to manage several service in a collection, you can call let service = service.into_dyn(); ``` +### OAuth Support + +See [docs/OAUTH_SUPPORT.md](docs/OAUTH_SUPPORT.md) for details. + ### Examples See [examples](examples/README.md) diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index 6fdd613..42e0097 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,19 +1,15 @@ use std::{sync::Arc, time::Duration}; use oauth2::{ - AccessToken, AuthUrl, AuthorizationCode, AuthorizationRequest, ClientId, ClientSecret, - CsrfToken, EmptyExtraTokenFields, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, - RefreshToken, RefreshTokenRequest, Scope, StandardTokenResponse, TokenResponse, TokenType, - TokenUrl, + AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, + PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, RefreshToken, Scope, StandardTokenResponse, + TokenResponse, TokenUrl, basic::{BasicClient, BasicTokenType}, }; use reqwest::{Client as HttpClient, IntoUrl, StatusCode, Url, header::AUTHORIZATION}; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tokio::{ - sync::{Mutex, RwLock}, - time::{self, Instant}, -}; +use tokio::sync::{Mutex, RwLock}; use tracing::{debug, error}; /// Auth error @@ -82,24 +78,31 @@ pub struct OAuthClientConfig { pub redirect_uri: String, } +// add type aliases for oauth2 types +type OAuthErrorResponse = oauth2::StandardErrorResponse; +type OAuthTokenResponse = StandardTokenResponse; +type OAuthTokenIntrospection = + oauth2::StandardTokenIntrospectionResponse; +type OAuthRevocableToken = oauth2::StandardRevocableToken; +type OAuthRevocationError = oauth2::StandardErrorResponse; +type OAuthClient = oauth2::Client< + OAuthErrorResponse, + OAuthTokenResponse, + OAuthTokenIntrospection, + OAuthRevocableToken, + OAuthRevocationError, + oauth2::EndpointSet, + oauth2::EndpointNotSet, + oauth2::EndpointNotSet, + oauth2::EndpointNotSet, + oauth2::EndpointSet, +>; + /// oauth2 auth manager pub struct AuthorizationManager { http_client: HttpClient, metadata: Option, - oauth_client: Option< - oauth2::Client< - oauth2::StandardErrorResponse, - StandardTokenResponse, - oauth2::StandardTokenIntrospectionResponse, - oauth2::StandardRevocableToken, - oauth2::StandardErrorResponse, - oauth2::EndpointSet, - oauth2::EndpointNotSet, - oauth2::EndpointNotSet, - oauth2::EndpointNotSet, - oauth2::EndpointSet, - >, - >, + oauth_client: Option, credentials: RwLock>>, pkce_verifier: RwLock>, base_url: Url, @@ -524,7 +527,7 @@ pub struct AuthorizedHttpClient { impl AuthorizedHttpClient { /// create new authorized http client pub fn new(auth_manager: Arc, client: Option) -> Self { - let inner_client = client.unwrap_or_else(|| HttpClient::new()); + let inner_client = client.unwrap_or_default(); Self { auth_manager, inner_client, diff --git a/crates/rmcp/src/transport/sse_auth.rs b/crates/rmcp/src/transport/sse_auth.rs index e8e6957..0e3e89c 100644 --- a/crates/rmcp/src/transport/sse_auth.rs +++ b/crates/rmcp/src/transport/sse_auth.rs @@ -1,22 +1,18 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; -use futures::{ - Future, Sink, Stream, StreamExt, future::BoxFuture, sink::SinkExt as FuturesSinkExt, - stream::BoxStream, -}; +use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use reqwest::{ Client as HttpClient, IntoUrl, Url, - header::{ACCEPT, AUTHORIZATION, HeaderValue}, + header::{ACCEPT, AUTHORIZATION}, }; use sse_stream::{Error as SseError, Sse, SseStream}; -use thiserror::Error; use tokio::sync::Mutex; use super::{ auth::{AuthError, AuthorizationManager}, sse::{SseClient, SseTransport, SseTransportError, SseTransportRetryConfig}, }; -use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; +use crate::model::ClientJsonRpcMessage; // SSE MIME type const MIME_TYPE: &str = "text/event-stream"; @@ -28,7 +24,7 @@ pub struct AuthorizedSseClient { http_client: HttpClient, sse_url: Url, auth_manager: Arc>, - retry_config: SseTransportRetryConfig, + _retry_config: Option, // TODO retry config may be used by authorized transport for token } impl AuthorizedSseClient { @@ -46,7 +42,7 @@ impl AuthorizedSseClient { http_client: HttpClient::default(), sse_url: url, auth_manager, - retry_config: retry_config.unwrap_or_default(), + _retry_config: retry_config, }) } @@ -65,40 +61,9 @@ impl AuthorizedSseClient { http_client: client, sse_url: url, auth_manager, - retry_config: retry_config.unwrap_or_default(), + _retry_config: retry_config, }) } - - /// get access token, support retry - async fn get_token_with_retry(&self) -> Result> { - let mut retries = 0; - let max_retries = self.retry_config.max_times; - let base_delay = self.retry_config.min_duration; - - loop { - match self.auth_manager.lock().await.get_access_token().await { - Ok(token) => return Ok(token), - Err(AuthError::AuthorizationRequired) => { - return Err(SseTransportError::Io(std::io::Error::new( - std::io::ErrorKind::Other, - "Authorization required", - ))); - } - Err(_e) => { - if retries >= max_retries.unwrap_or(0) { - return Err(SseTransportError::Io(std::io::Error::new( - std::io::ErrorKind::Other, - "Authorization required", - ))); - } - retries += 1; - // todo: need to optimize - let delay = base_delay.as_millis(); - tokio::time::sleep(Duration::from_millis(delay as u64)).await; - } - } - } - } } impl SseClient for AuthorizedSseClient { @@ -166,7 +131,7 @@ impl SseClient for AuthorizedSseClient { .await .get_access_token() .await - .map_err(|e| SseTransportError::::from(e))?; + .map_err(SseTransportError::::from)?; let uri = sse_url.join(&session_id).map_err(SseTransportError::from)?; let request_builder = client @@ -203,5 +168,7 @@ where U: IntoUrl, { let client = AuthorizedSseClient::new(url, auth_manager, retry_config)?; - SseTransport::start_with_client(client).await + let mut transport = SseTransport::start_with_client(client).await?; + transport.retry_config = retry_config.unwrap_or_default(); + Ok(transport) } diff --git a/OAUTH_README.md b/docs/OAUTH_SUPPORT.md similarity index 91% rename from OAUTH_README.md rename to docs/OAUTH_SUPPORT.md index 2b02a9e..aa4ac62 100644 --- a/OAUTH_README.md +++ b/docs/OAUTH_SUPPORT.md @@ -24,7 +24,7 @@ rmcp = { version = "0.1", features = ["auth", "transport-sse"] } ### 2. Create Authorization Manager -```rust +```rust ignore use std::sync::Arc; use rmcp::transport::auth::AuthorizationManager; @@ -38,7 +38,7 @@ async fn main() -> anyhow::Result<()> { ### 3. Create Authorization Session and Get Authorization -```rust +```rust ignore use rmcp::transport::auth::AuthorizationSession; async fn get_authorization(auth_manager: Arc) -> anyhow::Result<()> { @@ -72,7 +72,8 @@ async fn connect_with_auth(auth_manager: Arc) -> anyhow::R // Create authorized SSE transport let transport = create_authorized_transport( "https://api.example.com/mcp", - auth_manager.clone() + auth_manager.clone(), + None ).await?; // Create client @@ -110,17 +111,19 @@ async fn make_authorized_request(auth_manager: Arc) -> any ``` ## Complete Example +client: Please refer to `examples/clients/src/oauth_client.rs` for a complete usage example. +server: Please refer to `examples/servers/src/mcp_oauth_server.rs` for a complete usage example. +### Running the Example in server +```bash +# Run example +cargo run --example mcp_oauth_server +``` -Please refer to `examples/oauth_client.rs` for a complete usage example. - -## Running the Example +### Running the Example in client ```bash -# Set server URL (optional) -export MCP_SERVER_URL=https://api.example.com/mcp - # Run example -cargo run --bin oauth-client +cargo run --example oauth-client ``` ## Authorization Flow Description diff --git a/examples/clients/src/oauth_client.rs b/examples/clients/src/oauth_client.rs index 25c3531..b6dc56f 100644 --- a/examples/clients/src/oauth_client.rs +++ b/examples/clients/src/oauth_client.rs @@ -4,14 +4,14 @@ use anyhow::{Context, Result}; use axum::{ Router, extract::{Query, State}, - response::{Html, Redirect}, + response::Html, routing::get, }; use rmcp::{ ServiceExt, model::ClientInfo, transport::{ - auth::{AuthError, AuthorizationManager, AuthorizationSession}, + auth::{AuthorizationManager, AuthorizationSession}, create_authorized_transport, sse::SseTransportRetryConfig, }, @@ -30,13 +30,13 @@ const CALLBACK_PORT: u16 = 8080; #[derive(Clone)] struct AppState { - auth_session: Arc, code_receiver: Arc>>>, } #[derive(Debug, Deserialize)] struct CallbackParams { code: String, + #[allow(dead_code)] state: Option, } @@ -52,7 +52,7 @@ async fn callback_handler( } // Return success page - Html(format!( + Html( r#" @@ -73,8 +73,8 @@ async fn callback_handler( - "# - )) + "#.to_string() + ) } #[tokio::main] @@ -108,7 +108,7 @@ async fn main() -> Result<()> { let session = AuthorizationSession::new( auth_manager_arc.clone(), &["mcp", "profile", "email"], - &MCP_REDIRECT_URI, + MCP_REDIRECT_URI, ) .await .context("Failed to create authorization session")?; @@ -120,7 +120,6 @@ async fn main() -> Result<()> { // Create app state let app_state = AppState { - auth_session: session_arc.clone(), code_receiver: Arc::new(Mutex::new(Some(code_sender))), }; diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 6fae7ed..bafe3b1 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -24,7 +24,7 @@ reqwest = { version = "0.12", features = ["json"] } chrono = "0.4" uuid = { version = "1.6", features = ["v4", "serde"] } serde_urlencoded = "0.7" -hyper = "1.3" + # [dev-dependencies.'cfg(target_arch="linux")'.dependencies] [dev-dependencies] diff --git a/examples/servers/src/auth_sse.rs b/examples/servers/src/auth_sse.rs index 6389328..ecc32a3 100644 --- a/examples/servers/src/auth_sse.rs +++ b/examples/servers/src/auth_sse.rs @@ -1,3 +1,10 @@ +/// This example shows how to use the RMCP SSE server with OAuth authorization. +/// Use the inspector to view this server https://github.com/modelcontextprotocol/inspector +/// The default index page is available at http://127.0.0.1:8000/ +/// # Get a token +/// curl http://127.0.0.1:8000/api/token/demo +/// # Connect to SSE using the token +/// curl -H "Authorization: Bearer demo-token" http://127.0.0.1:8000/sse use std::{net::SocketAddr, sync::Arc, time::Duration}; use anyhow::Result; @@ -10,10 +17,9 @@ use axum::{ routing::get, }; use rmcp::transport::{SseServer, sse_server::SseServerConfig}; -use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; mod common; -use common::{calculator::Calculator, counter::Counter}; +use common::counter::Counter; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; const BIND_ADDRESS: &str = "127.0.0.1:8000"; @@ -41,11 +47,9 @@ fn extract_token(headers: &HeaderMap) -> Option { .get("Authorization") .and_then(|value| value.to_str().ok()) .and_then(|auth_header| { - if auth_header.starts_with("Bearer ") { - Some(auth_header[7..].to_string()) - } else { - None - } + auth_header + .strip_prefix("Bearer ") + .map(|stripped| stripped.to_string()) }) } diff --git a/examples/servers/src/mcp_oauth_server.rs b/examples/servers/src/mcp_oauth_server.rs index a128216..86b5a0a 100644 --- a/examples/servers/src/mcp_oauth_server.rs +++ b/examples/servers/src/mcp_oauth_server.rs @@ -1,18 +1,16 @@ -use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration, usize}; +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration}; use anyhow::Result; use axum::{ Json, Router, body::Body, - extract::{Form, Path, Query, State}, - http::{HeaderMap, Method, Request, StatusCode, Uri}, + extract::{Form, Query, State}, + http::{Request, StatusCode}, middleware::{self, Next}, response::{Html, IntoResponse, Redirect, Response}, routing::{get, post}, }; -use hyper::body; use rand::{Rng, distributions::Alphanumeric}; -use reqwest::Client as HttpClient; use rmcp::transport::{ SseServer, auth::{ @@ -33,13 +31,12 @@ use common::counter::Counter; const BIND_ADDRESS: &str = "127.0.0.1:3000"; -// MCP OAuth Store for managing tokens and sessions +// A easy way to manage MCP OAuth Store for managing tokens and sessions #[derive(Clone, Debug)] struct McpOAuthStore { clients: Arc>>, auth_sessions: Arc>>, access_tokens: Arc>>, - http_client: HttpClient, } impl McpOAuthStore { @@ -59,10 +56,6 @@ impl McpOAuthStore { clients: Arc::new(RwLock::new(clients)), auth_sessions: Arc::new(RwLock::new(HashMap::new())), access_tokens: Arc::new(RwLock::new(HashMap::new())), - http_client: HttpClient::builder() - .timeout(Duration::from_secs(30)) - .build() - .expect("Failed to create HTTP client"), } } @@ -83,18 +76,15 @@ impl McpOAuthStore { async fn create_auth_session( &self, client_id: String, - redirect_uri: String, scope: Option, state: Option, session_id: String, ) -> String { let session = AuthSession { - id: session_id.clone(), client_id, - redirect_uri, scope, - state, - created_at: chrono::Utc::now(), + _state: state, + _created_at: chrono::Utc::now(), auth_token: None, }; @@ -152,17 +142,18 @@ impl McpOAuthStore { } } +// a simple session record for auth session #[derive(Clone, Debug)] struct AuthSession { - id: String, client_id: String, - redirect_uri: String, scope: Option, - state: Option, - created_at: chrono::DateTime, + _state: Option, + _created_at: chrono::DateTime, auth_token: Option, } +// a simple token record for auth token +// not used oauth2 token for avoid include oauth2 crate in this example #[derive(Clone, Debug, Serialize, Deserialize)] struct AuthToken { access_token: String, @@ -172,6 +163,8 @@ struct AuthToken { scope: Option, } +// a simple token record for mcp token , +// not used oauth2 token for avoid include oauth2 crate in this example #[derive(Clone, Debug, Serialize)] struct McpAccessToken { access_token: String, @@ -185,6 +178,7 @@ struct McpAccessToken { #[derive(Debug, Deserialize)] struct AuthorizeQuery { + #[allow(dead_code)] response_type: String, client_id: String, redirect_uri: String, @@ -192,13 +186,6 @@ struct AuthorizeQuery { state: Option, } -#[derive(Debug, Deserialize)] -struct AuthCallbackQuery { - code: String, - state: Option, - session_id: String, -} - #[derive(Debug, Deserialize, Serialize)] struct TokenRequest { grant_type: String, @@ -305,7 +292,8 @@ async fn oauth_authorize( Query(params): Query, State(state): State>, ) -> impl IntoResponse { - if let Some(client) = state + debug!("doing oauth_authorize"); + if let Some(_client) = state .validate_client(¶ms.client_id, ¶ms.redirect_uri) .await { @@ -394,10 +382,11 @@ async fn oauth_approve( "{}?error=access_denied&error_description={}{}", form.redirect_uri, "user rejected the authorization request", - form.state - .is_empty() - .then_some("") - .unwrap_or(&format!("&state={}", form.state)) + if form.state.is_empty() { + "".to_string() + } else { + format!("&state={}", form.state) + } ); return Redirect::to(&redirect_url).into_response(); } @@ -410,7 +399,6 @@ async fn oauth_approve( let session_id = state .create_auth_session( form.client_id.clone(), - form.redirect_uri.clone(), Some(form.scope.clone()), Some(form.state.clone()), session_id.clone(), @@ -439,10 +427,11 @@ async fn oauth_approve( "{}?code={}{}", form.redirect_uri, auth_code, - form.state - .is_empty() - .then_some("") - .unwrap_or(&format!("&state={}", form.state)) + if form.state.is_empty() { + "".to_string() + } else { + format!("&state={}", form.state) + } ); info!("authorization approved, redirecting to: {}", redirect_url); @@ -492,7 +481,6 @@ async fn oauth_token( } }; - if token_req.grant_type != "authorization_code" { info!("unsupported grant type: {}", token_req.grant_type); return ( @@ -584,27 +572,29 @@ async fn oauth_token( async fn validate_token_middleware( State(token_store): State>, request: Request, -) -> Result, StatusCode> { + next: Next, +) -> Response { + debug!("validate_token_middleware"); // Extract the access token from the Authorization header let auth_header = request.headers().get("Authorization"); let token = match auth_header { Some(header) => { let header_str = header.to_str().unwrap_or(""); - if header_str.starts_with("Bearer ") { - header_str[7..].to_string() + if let Some(stripped) = header_str.strip_prefix("Bearer ") { + stripped.to_string() } else { - return Err(StatusCode::UNAUTHORIZED); + return StatusCode::UNAUTHORIZED.into_response(); } } None => { - return Err(StatusCode::UNAUTHORIZED); + return StatusCode::UNAUTHORIZED.into_response(); } }; // Validate the token match token_store.validate_token(&token).await { - Some(_) => Ok(Some(token)), - None => Err(StatusCode::UNAUTHORIZED), + Some(_) => next.run(request).await, + None => StatusCode::UNAUTHORIZED.into_response(), } } @@ -615,7 +605,7 @@ async fn oauth_authorization_server() -> impl IntoResponse { token_endpoint: format!("http://{}/oauth/token", BIND_ADDRESS), scopes_supported: Some(vec!["profile".to_string(), "email".to_string()]), registration_endpoint: format!("http://{}/oauth/register", BIND_ADDRESS), - issuer: Some(format!("{}", BIND_ADDRESS)), + issuer: Some(BIND_ADDRESS.to_string()), jwks_uri: Some(format!("http://{}/oauth/jwks", BIND_ADDRESS)), }; debug!("metadata: {:?}", metadata); @@ -677,10 +667,7 @@ async fn log_request(request: Request, next: Next) -> Response { let headers = request.headers().clone(); let mut header_log = String::new(); for (key, value) in headers.iter() { - let value_str = match value.to_str() { - Ok(v) => v, - Err(_) => "", - }; + let value_str = value.to_str().unwrap_or(""); header_log.push_str(&format!("\n {}: {}", key, value_str)); } @@ -732,7 +719,7 @@ async fn main() -> Result<()> { // Create SSE server configuration for MCP let sse_config = SseServerConfig { - bind: addr.clone(), + bind: addr, sse_path: "/mcp/sse".to_string(), post_path: "/mcp/message".to_string(), ct: CancellationToken::new(), @@ -743,10 +730,10 @@ async fn main() -> Result<()> { let (sse_server, sse_router) = SseServer::new(sse_config); // Create protected SSE routes (require authorization) - // let protected_sse_router = sse_router.layer(middleware::from_fn_with_state( - // oauth_store.clone(), - // validate_token_middleware, - // )); + let protected_sse_router = sse_router.layer(middleware::from_fn_with_state( + oauth_store.clone(), + validate_token_middleware, + )); // Create HTTP router with request logging middleware let app = Router::new() @@ -764,7 +751,7 @@ async fn main() -> Result<()> { .with_state(oauth_store.clone()) .layer(middleware::from_fn(log_request)); - let app = app.merge(sse_router.with_state(())); + let app = app.merge(protected_sse_router); // Register token validation middleware for SSE let cancel_token = sse_server.config.ct.clone(); // Handle Ctrl+C