diff --git a/OAUTH_README.md b/OAUTH_README.md new file mode 100644 index 0000000..2b02a9e --- /dev/null +++ b/OAUTH_README.md @@ -0,0 +1,156 @@ +# Model Context Protocol OAuth Authorization + +This document describes the OAuth 2.1 authorization implementation for Model Context Protocol (MCP), following the [MCP 2025-03-26 Authorization Specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/). + +## Features + +- Full support for OAuth 2.1 authorization flow +- PKCE support for enhanced security +- Authorization server metadata discovery +- Dynamic client registration +- Automatic token refresh +- Authorized SSE transport implementation + +## Usage Guide + +### 1. Enable Features + +Enable the auth feature in Cargo.toml: + +```toml +[dependencies] +rmcp = { version = "0.1", features = ["auth", "transport-sse"] } +``` + +### 2. Create Authorization Manager + +```rust +use std::sync::Arc; +use rmcp::transport::auth::AuthorizationManager; + +async fn main() -> anyhow::Result<()> { + // Create authorization manager + let auth_manager = Arc::new(AuthorizationManager::new("https://api.example.com/mcp").await?); + + Ok(()) +} +``` + +### 3. Create Authorization Session and Get Authorization + +```rust +use rmcp::transport::auth::AuthorizationSession; + +async fn get_authorization(auth_manager: Arc) -> anyhow::Result<()> { + // Create authorization session + let session = AuthorizationSession::new( + auth_manager.clone(), + &["mcp"], // Requested scopes + "http://localhost:8080/callback", // Redirect URI + ).await?; + + // Get authorization URL and guide user to open it + let auth_url = session.get_authorization_url(); + println!("Please open the following URL in your browser for authorization:\n{}", auth_url); + + // Handle callback - In real applications, this is typically done in a callback server + let auth_code = "Authorization code obtained from browser after user authorization"; + let credentials = session.handle_callback(auth_code).await?; + + println!("Authorization successful, access token: {}", credentials.access_token); + + Ok(()) +} +``` + +### 4. Use Authorized SSE Transport + +```rust +use rmcp::{ServiceExt, model::ClientInfo, transport::create_authorized_transport}; + +async fn connect_with_auth(auth_manager: Arc) -> anyhow::Result<()> { + // Create authorized SSE transport + let transport = create_authorized_transport( + "https://api.example.com/mcp", + auth_manager.clone() + ).await?; + + // Create client + let client_service = ClientInfo::default(); + let client = client_service.serve(transport).await?; + + // Use client to call APIs + let tools = client.peer().list_all_tools().await?; + + for tool in tools { + println!("Tool: {} - {}", tool.name, tool.description); + } + + Ok(()) +} +``` + +### 5. Use Authorized HTTP Client + +```rust +use rmcp::transport::auth::AuthorizedHttpClient; + +async fn make_authorized_request(auth_manager: Arc) -> anyhow::Result<()> { + // Create authorized HTTP client + let client = AuthorizedHttpClient::new(auth_manager, None); + + // Send authorized request + let response = client.get("https://api.example.com/resources").await?; + let resources = response.json::>().await?; + + println!("Number of resources: {}", resources.len()); + + Ok(()) +} +``` + +## Complete Example + +Please refer to `examples/oauth_client.rs` for a complete usage example. + +## Running the Example + +```bash +# Set server URL (optional) +export MCP_SERVER_URL=https://api.example.com/mcp + +# Run example +cargo run --bin oauth-client +``` + +## Authorization Flow Description + +1. **Metadata Discovery**: Client attempts to get authorization server metadata from `/.well-known/oauth-authorization-server` +2. **Client Registration**: If supported, client dynamically registers itself +3. **Authorization Request**: Build authorization URL with PKCE and guide user to access +4. **Authorization Code Exchange**: After user authorization, exchange authorization code for access token +5. **Token Usage**: Use access token for API calls +6. **Token Refresh**: Automatically use refresh token to get new access token when current one expires + +## Security Considerations + +- All tokens are securely stored in memory +- PKCE implementation prevents authorization code interception attacks +- Automatic token refresh support reduces user intervention +- Only accepts HTTPS connections or secure local callback URIs + +## Troubleshooting + +If you encounter authorization issues, check the following: + +1. Ensure server supports OAuth 2.1 authorization +2. Verify callback URI matches server's allowed redirect URIs +3. Check network connection and firewall settings +4. Verify server supports metadata discovery or dynamic client registration + +## References + +- [MCP Authorization Specification](https://spec.modelcontextprotocol.io/specification/2025-03-26/basic/authorization/) +- [OAuth 2.1 Specification Draft](https://oauth.net/2.1/) +- [RFC 8414: OAuth 2.0 Authorization Server Metadata](https://datatracker.ietf.org/doc/html/rfc8414) +- [RFC 7591: OAuth 2.0 Dynamic Client Registration Protocol](https://datatracker.ietf.org/doc/html/rfc7591) \ No newline at end of file diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 738f244..088c72f 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -23,6 +23,10 @@ tracing = { version = "0.1" } tokio-util = { version = "0.7" } pin-project-lite = "0.2" paste = { version = "1", optional = true } + +# oauth2 support +oauth2 = { version = "5.0", optional = true } + # for auto generate schema schemars = { version = "0.8", optional = true } @@ -70,6 +74,8 @@ transport-sse-server = [ ] # transport-ws = ["transport-io", "dep:tokio-tungstenite"] tower = ["dep:tower-service"] +auth = ["dep:oauth2", "dep:reqwest", "dep:url"] + [dev-dependencies] tokio = { version = "1", features = ["full"] } schemars = { version = "0.8" } diff --git a/crates/rmcp/src/transport.rs b/crates/rmcp/src/transport.rs index cbfbf9f..9b49f76 100644 --- a/crates/rmcp/src/transport.rs +++ b/crates/rmcp/src/transport.rs @@ -56,6 +56,11 @@ pub mod sse; #[cfg(feature = "transport-sse")] pub use sse::SseTransport; +#[cfg(all(feature = "transport-sse", feature = "auth"))] +pub mod sse_auth; +#[cfg(all(feature = "transport-sse", feature = "auth"))] +pub use sse_auth::{AuthorizedSseClient, create_authorized_transport}; + // #[cfg(feature = "tower")] // pub mod tower; @@ -64,6 +69,11 @@ pub mod sse_server; #[cfg(feature = "transport-sse-server")] pub use sse_server::SseServer; +#[cfg(feature = "auth")] +pub mod auth; +#[cfg(feature = "auth")] +pub use auth::{AuthError, AuthorizationManager, AuthorizationSession, AuthorizedHttpClient}; + // #[cfg(feature = "transport-ws")] // pub mod ws; diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs new file mode 100644 index 0000000..6fdd613 --- /dev/null +++ b/crates/rmcp/src/transport/auth.rs @@ -0,0 +1,555 @@ +use std::{sync::Arc, time::Duration}; + +use oauth2::{ + AccessToken, AuthUrl, AuthorizationCode, AuthorizationRequest, ClientId, ClientSecret, + CsrfToken, EmptyExtraTokenFields, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, + RefreshToken, RefreshTokenRequest, Scope, StandardTokenResponse, TokenResponse, TokenType, + TokenUrl, + basic::{BasicClient, BasicTokenType}, +}; +use reqwest::{Client as HttpClient, IntoUrl, StatusCode, Url, header::AUTHORIZATION}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio::{ + sync::{Mutex, RwLock}, + time::{self, Instant}, +}; +use tracing::{debug, error}; + +/// Auth error +#[derive(Debug, Error)] +pub enum AuthError { + #[error("OAuth authorization required")] + AuthorizationRequired, + + #[error("OAuth authorization failed: {0}")] + AuthorizationFailed(String), + + #[error("OAuth token exchange failed: {0}")] + TokenExchangeFailed(String), + + #[error("OAuth token refresh failed: {0}")] + TokenRefreshFailed(String), + + #[error("HTTP error: {0}")] + HttpError(#[from] reqwest::Error), + + #[error("OAuth error: {0}")] + OAuthError(String), + + #[error("Metadata error: {0}")] + MetadataError(String), + + #[error("URL parse error: {0}")] + UrlError(#[from] url::ParseError), + + #[error("No authorization support detected")] + NoAuthorizationSupport, + + #[error("Internal error: {0}")] + InternalError(String), + + #[error("Invalid token type: {0}")] + InvalidTokenType(String), + + #[error("Token expired")] + TokenExpired, + + #[error("Invalid scope: {0}")] + InvalidScope(String), + + #[error("Registration failed: {0}")] + RegistrationFailed(String), +} + +/// oauth2 metadata +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct AuthorizationMetadata { + pub authorization_endpoint: String, + pub token_endpoint: String, + pub registration_endpoint: String, + pub issuer: Option, + pub jwks_uri: Option, + pub scopes_supported: Option>, +} + +/// oauth2 client config +#[derive(Debug, Clone)] +pub struct OAuthClientConfig { + pub client_id: String, + pub client_secret: Option, + pub scopes: Vec, + pub redirect_uri: String, +} + +/// oauth2 auth manager +pub struct AuthorizationManager { + http_client: HttpClient, + metadata: Option, + oauth_client: Option< + oauth2::Client< + oauth2::StandardErrorResponse, + StandardTokenResponse, + oauth2::StandardTokenIntrospectionResponse, + oauth2::StandardRevocableToken, + oauth2::StandardErrorResponse, + oauth2::EndpointSet, + oauth2::EndpointNotSet, + oauth2::EndpointNotSet, + oauth2::EndpointNotSet, + oauth2::EndpointSet, + >, + >, + credentials: RwLock>>, + pkce_verifier: RwLock>, + base_url: Url, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientRegistrationRequest { + pub client_name: String, + pub redirect_uris: Vec, + pub grant_types: Vec, + pub token_endpoint_auth_method: String, + pub response_types: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientRegistrationResponse { + pub client_id: String, + pub client_secret: String, + pub client_name: String, + pub redirect_uris: Vec, +} + +impl AuthorizationManager { + /// create new auth manager + pub async fn new(base_url: U) -> Result { + let base_url = base_url.into_url()?; + let http_client = HttpClient::builder() + .timeout(Duration::from_secs(30)) + .build() + .map_err(|e| AuthError::InternalError(e.to_string()))?; + + let mut manager = Self { + http_client, + metadata: None, + oauth_client: None, + credentials: RwLock::new(None), + pkce_verifier: RwLock::new(None), + base_url, + }; + + // try to discover oauth2 metadata + if let Ok(metadata) = manager.discover_metadata().await { + manager.metadata = Some(metadata); + } + + Ok(manager) + } + + /// discover oauth2 metadata + pub async fn discover_metadata(&self) -> Result { + // according to the specification, the metadata should be located at "/.well-known/oauth-authorization-server" + let mut discovery_url = self.base_url.clone(); + discovery_url.set_path("/.well-known/oauth-authorization-server"); + debug!("discovery url: {:?}", discovery_url); + let response = self + .http_client + .get(discovery_url) + .header("MCP-Protocol-Version", "2024-11-05") + .send() + .await?; + + if response.status() == StatusCode::OK { + let metadata = response + .json::() + .await + .map_err(|e| { + AuthError::MetadataError(format!("Failed to parse metadata: {}", e)) + })?; + debug!("metadata: {:?}", metadata); + Ok(metadata) + } else { + // fallback to default endpoints + let mut auth_base = self.base_url.clone(); + // discard the path part, only keep scheme, host, port + auth_base.set_path(""); + + Ok(AuthorizationMetadata { + authorization_endpoint: format!("{}/authorize", auth_base), + token_endpoint: format!("{}/token", auth_base), + registration_endpoint: format!("{}/register", auth_base), + issuer: None, + jwks_uri: None, + scopes_supported: None, + }) + } + } + + /// configure oauth2 client with client credentials + pub fn configure_client(&mut self, config: OAuthClientConfig) -> Result<(), AuthError> { + if self.metadata.is_none() { + return Err(AuthError::NoAuthorizationSupport); + } + + let metadata = self.metadata.as_ref().unwrap(); + + let auth_url = AuthUrl::new(metadata.authorization_endpoint.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid authorization URL: {}", e)))?; + + let token_url = TokenUrl::new(metadata.token_endpoint.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid token URL: {}", e)))?; + + // debug!("token url: {:?}", token_url); + let client_id = ClientId::new(config.client_id); + let redirect_url = RedirectUrl::new(config.redirect_uri.clone()) + .map_err(|e| AuthError::OAuthError(format!("Invalid re URL: {}", e)))?; + + debug!("client_id: {:?}", client_id); + let mut client_builder = BasicClient::new(client_id.clone()) + .set_auth_uri(auth_url) + .set_token_uri(token_url) + .set_redirect_uri(redirect_url); + + if let Some(secret) = config.client_secret { + client_builder = client_builder.set_client_secret(ClientSecret::new(secret)); + } + + self.oauth_client = Some(client_builder); + Ok(()) + } + + /// dynamic register oauth2 client + pub async fn register_client( + &mut self, + name: &str, + redirect_uri: &str, + ) -> Result { + if self.metadata.is_none() { + error!("No authorization support detected"); + return Err(AuthError::NoAuthorizationSupport); + } + + let metadata = self.metadata.as_ref().unwrap(); + let registration_url = metadata.registration_endpoint.clone(); + + debug!("registration url: {:?}", registration_url); + // prepare registration request + let registration_request = ClientRegistrationRequest { + client_name: name.to_string(), + redirect_uris: vec![redirect_uri.to_string()], + grant_types: vec![ + "authorization_code".to_string(), + "refresh_token".to_string(), + ], + token_endpoint_auth_method: "none".to_string(), // public client + response_types: vec!["code".to_string()], + }; + + debug!("registration request: {:?}", registration_request); + + let response = match self + .http_client + .post(registration_url) + .json(®istration_request) + .send() + .await + { + Ok(response) => response, + Err(e) => { + error!("Registration request failed: {}", e); + return Err(AuthError::RegistrationFailed(format!( + "HTTP request error: {}", + e + ))); + } + }; + + if !response.status().is_success() { + let status = response.status(); + let error_text = match response.text().await { + Ok(text) => text, + Err(_) => "cannot get error details".to_string(), + }; + + error!("Registration failed: HTTP {} - {}", status, error_text); + return Err(AuthError::RegistrationFailed(format!( + "HTTP {}: {}", + status, error_text + ))); + } + debug!("registration response: {:?}", response); + let reg_response = match response.json::().await { + Ok(response) => response, + Err(e) => { + error!("Failed to parse registration response: {}", e); + return Err(AuthError::RegistrationFailed(format!( + "analyze response error: {}", + e + ))); + } + }; + + let config = OAuthClientConfig { + client_id: reg_response.client_id, + client_secret: Some(reg_response.client_secret), + redirect_uri: redirect_uri.to_string(), + scopes: vec![], + }; + + self.configure_client(config.clone())?; + Ok(config) + } + + /// generate authorization url + pub async fn get_authorization_url(&self, scopes: &[&str]) -> Result { + let oauth_client = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + // generate pkce challenge + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + // build authorization request + let mut auth_request = oauth_client + .authorize_url(CsrfToken::new_random) + .set_pkce_challenge(pkce_challenge); + + // add request scopes + for scope in scopes { + auth_request = auth_request.add_scope(Scope::new(scope.to_string())); + } + + let (auth_url, _csrf_token) = auth_request.url(); + + // store pkce verifier for later use + *self.pkce_verifier.write().await = Some(pkce_verifier); + debug!("set pkce verifier: {:?}", self.pkce_verifier.read().await); + + Ok(auth_url.to_string()) + } + + /// exchange authorization code for access token + pub async fn exchange_code_for_token( + &self, + code: &str, + ) -> Result, AuthError> { + debug!("start exchange code for token: {:?}", code); + let oauth_client = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + let pkce_verifier = self + .pkce_verifier + .write() + .await + .take() + .ok_or_else(|| AuthError::InternalError("PKCE verifier not found".to_string()))?; + let http_client = reqwest::ClientBuilder::new() + .redirect(reqwest::redirect::Policy::none()) + .build() + .map_err(|e| AuthError::InternalError(e.to_string()))?; + debug!("client_id: {:?}", oauth_client.client_id()); + // exchange token + let token_result = oauth_client + .exchange_code(AuthorizationCode::new(code.to_string())) + .add_extra_param("client_id", oauth_client.client_id().to_string()) + .set_pkce_verifier(pkce_verifier) + .request_async(&http_client) + .await + .map_err(|e| AuthError::TokenExchangeFailed(e.to_string()))?; + + debug!("exchange token result: {:?}", token_result); + // store credentials + *self.credentials.write().await = Some(token_result.clone()); + + Ok(token_result) + } + + /// get access token, if expired, refresh it automatically + pub async fn get_access_token(&self) -> Result { + let credentials = self.credentials.read().await; + + if let Some(creds) = credentials.as_ref() { + // check if the token is expired + if let Some(expires_in) = creds.expires_in() { + if expires_in <= Duration::from_secs(0) { + // token expired, try to refresh + drop(credentials); // release the lock + let new_creds = self.refresh_token().await?; + return Ok(new_creds.access_token().secret().to_string()); + } + } + + Ok(creds.access_token().secret().to_string()) + } else { + Err(AuthError::AuthorizationRequired) + } + } + + /// refresh access token + pub async fn refresh_token( + &self, + ) -> Result, AuthError> { + let oauth_client = self + .oauth_client + .as_ref() + .ok_or_else(|| AuthError::InternalError("OAuth client not configured".to_string()))?; + + let current_credentials = self + .credentials + .read() + .await + .clone() + .ok_or_else(|| AuthError::AuthorizationRequired)?; + + let refresh_token = current_credentials.refresh_token().ok_or_else(|| { + AuthError::TokenRefreshFailed("No refresh token available".to_string()) + })?; + + // refresh token + let token_result = oauth_client + .exchange_refresh_token(&RefreshToken::new(refresh_token.secret().to_string())) + .request_async(&self.http_client) + .await + .map_err(|e| AuthError::TokenRefreshFailed(e.to_string()))?; + + // store new credentials + *self.credentials.write().await = Some(token_result.clone()); + + Ok(token_result) + } + + /// prepare request, add authorization header + pub async fn prepare_request( + &self, + request: reqwest::RequestBuilder, + ) -> Result { + let token = self.get_access_token().await?; + Ok(request.header(AUTHORIZATION, format!("Bearer {}", token))) + } + + /// handle response, check if need to re-authorize + pub async fn handle_response( + &self, + response: reqwest::Response, + ) -> Result { + if response.status() == StatusCode::UNAUTHORIZED { + // 401 Unauthorized, need to re-authorize + Err(AuthError::AuthorizationRequired) + } else { + Ok(response) + } + } +} + +/// oauth2 authorization session, for guiding user to complete the authorization process +pub struct AuthorizationSession { + pub auth_manager: Arc>, + pub auth_url: String, + pub redirect_uri: String, +} + +impl AuthorizationSession { + /// create new authorization session + pub async fn new( + auth_manager: Arc>, + scopes: &[&str], + redirect_uri: &str, + ) -> Result { + // set redirect uri + let config = OAuthClientConfig { + client_id: "mcp-client".to_string(), // temporary id, will be updated by dynamic registration + client_secret: None, + scopes: scopes.iter().map(|s| s.to_string()).collect(), + redirect_uri: redirect_uri.to_string(), + }; + + // try to dynamic register client + let config = match auth_manager + .lock() + .await + .register_client("MCP Client", redirect_uri) + .await + { + Ok(config) => config, + Err(e) => { + eprintln!("Dynamic registration failed: {}", e); + // fallback to default config + config + } + }; + // reset client config + auth_manager.lock().await.configure_client(config)?; + let auth_url = auth_manager + .lock() + .await + .get_authorization_url(scopes) + .await?; + + Ok(Self { + auth_manager, + auth_url, + redirect_uri: redirect_uri.to_string(), + }) + } + + /// get authorization url + pub fn get_authorization_url(&self) -> &str { + &self.auth_url + } + + /// handle authorization code callback + pub async fn handle_callback( + &self, + code: &str, + ) -> Result, AuthError> { + self.auth_manager + .lock() + .await + .exchange_code_for_token(code) + .await + } +} + +/// http client extension, automatically add authorization header +pub struct AuthorizedHttpClient { + auth_manager: Arc, + inner_client: HttpClient, +} + +impl AuthorizedHttpClient { + /// create new authorized http client + pub fn new(auth_manager: Arc, client: Option) -> Self { + let inner_client = client.unwrap_or_else(|| HttpClient::new()); + Self { + auth_manager, + inner_client, + } + } + + /// send authorized request + pub async fn request( + &self, + method: reqwest::Method, + url: U, + ) -> Result { + let request = self.inner_client.request(method, url); + self.auth_manager.prepare_request(request).await + } + + /// send get request + pub async fn get(&self, url: U) -> Result { + let request = self.request(reqwest::Method::GET, url).await?; + let response = request.send().await?; + self.auth_manager.handle_response(response).await + } + + /// send post request + pub async fn post(&self, url: U) -> Result { + self.request(reqwest::Method::POST, url).await + } +} diff --git a/crates/rmcp/src/transport/sse_auth.rs b/crates/rmcp/src/transport/sse_auth.rs new file mode 100644 index 0000000..e8e6957 --- /dev/null +++ b/crates/rmcp/src/transport/sse_auth.rs @@ -0,0 +1,207 @@ +use std::{sync::Arc, time::Duration}; + +use futures::{ + Future, Sink, Stream, StreamExt, future::BoxFuture, sink::SinkExt as FuturesSinkExt, + stream::BoxStream, +}; +use reqwest::{ + Client as HttpClient, IntoUrl, Url, + header::{ACCEPT, AUTHORIZATION, HeaderValue}, +}; +use sse_stream::{Error as SseError, Sse, SseStream}; +use thiserror::Error; +use tokio::sync::Mutex; + +use super::{ + auth::{AuthError, AuthorizationManager}, + sse::{SseClient, SseTransport, SseTransportError, SseTransportRetryConfig}, +}; +use crate::model::{ClientJsonRpcMessage, ServerJsonRpcMessage}; + +// SSE MIME type +const MIME_TYPE: &str = "text/event-stream"; +const HEADER_LAST_EVENT_ID: &str = "Last-Event-ID"; + +/// sse client with oauth2 authorization +#[derive(Clone)] +pub struct AuthorizedSseClient { + http_client: HttpClient, + sse_url: Url, + auth_manager: Arc>, + retry_config: SseTransportRetryConfig, +} + +impl AuthorizedSseClient { + /// create new authorized sse client + pub fn new( + url: U, + auth_manager: Arc>, + retry_config: Option, + ) -> Result> + where + U: IntoUrl, + { + let url = url.into_url().map_err(SseTransportError::from)?; + Ok(Self { + http_client: HttpClient::default(), + sse_url: url, + auth_manager, + retry_config: retry_config.unwrap_or_default(), + }) + } + + /// create authorized sse client with custom http client + pub async fn new_with_client( + url: U, + client: HttpClient, + auth_manager: Arc>, + retry_config: Option, + ) -> Result> + where + U: IntoUrl, + { + let url = url.into_url().map_err(SseTransportError::from)?; + Ok(Self { + http_client: client, + sse_url: url, + auth_manager, + retry_config: retry_config.unwrap_or_default(), + }) + } + + /// get access token, support retry + async fn get_token_with_retry(&self) -> Result> { + let mut retries = 0; + let max_retries = self.retry_config.max_times; + let base_delay = self.retry_config.min_duration; + + loop { + match self.auth_manager.lock().await.get_access_token().await { + Ok(token) => return Ok(token), + Err(AuthError::AuthorizationRequired) => { + return Err(SseTransportError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + "Authorization required", + ))); + } + Err(_e) => { + if retries >= max_retries.unwrap_or(0) { + return Err(SseTransportError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + "Authorization required", + ))); + } + retries += 1; + // todo: need to optimize + let delay = base_delay.as_millis(); + tokio::time::sleep(Duration::from_millis(delay as u64)).await; + } + } + } + } +} + +impl SseClient for AuthorizedSseClient { + fn connect( + &self, + last_event_id: Option, + ) -> BoxFuture< + 'static, + Result>, SseTransportError>, + > { + let client = self.http_client.clone(); + let sse_url = self.sse_url.as_ref().to_string(); + let last_event_id = last_event_id.clone(); + let auth_manager = self.auth_manager.clone(); + + let fut = async move { + // get access token + let token = auth_manager.lock().await.get_access_token().await?; + + // build request + let mut request_builder = client + .get(&sse_url) + .header(ACCEPT, MIME_TYPE) + .header(AUTHORIZATION, format!("Bearer {}", token)); + + if let Some(last_event_id) = last_event_id { + request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); + } + + let response = request_builder.send().await?; + let response = response.error_for_status()?; + + match response.headers().get(reqwest::header::CONTENT_TYPE) { + Some(ct) => { + if !ct.as_bytes().starts_with(MIME_TYPE.as_bytes()) { + return Err(SseTransportError::UnexpectedContentType(Some(ct.clone()))); + } + } + None => { + return Err(SseTransportError::UnexpectedContentType(None)); + } + } + + let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); + Ok(event_stream) + }; + + Box::pin(fut) + } + + fn post( + &self, + session_id: &str, + message: ClientJsonRpcMessage, + ) -> BoxFuture<'static, Result<(), SseTransportError>> { + let client = self.http_client.clone(); + let sse_url = self.sse_url.clone(); + let session_id = session_id.to_string(); + let auth_manager = self.auth_manager.clone(); + + Box::pin(async move { + // get access token + let token = auth_manager + .lock() + .await + .get_access_token() + .await + .map_err(|e| SseTransportError::::from(e))?; + + let uri = sse_url.join(&session_id).map_err(SseTransportError::from)?; + let request_builder = client + .post(uri.as_ref()) + .header(AUTHORIZATION, format!("Bearer {}", token)) + .json(&message); + + request_builder + .send() + .await + .and_then(|resp| resp.error_for_status()) + .map_err(SseTransportError::from) + .map(drop) + }) + } +} + +impl From for SseTransportError { + fn from(err: AuthError) -> Self { + SseTransportError::Io(std::io::Error::new( + std::io::ErrorKind::Other, + err.to_string(), + )) + } +} + +/// create authorized sse transport +pub async fn create_authorized_transport( + url: U, + auth_manager: Arc>, + retry_config: Option, +) -> Result, SseTransportError> +where + U: IntoUrl, +{ + let client = AuthorizedSseClient::new(url, auth_manager, retry_config)?; + SseTransport::start_with_client(client).await +} diff --git a/examples/clients/Cargo.toml b/examples/clients/Cargo.toml index 52226ed..312def3 100644 --- a/examples/clients/Cargo.toml +++ b/examples/clients/Cargo.toml @@ -11,7 +11,8 @@ rmcp = { path = "../../crates/rmcp", features = [ "client", "transport-sse", "transport-child-process", - "tower" + "tower", + "auth" ] } tokio = { version = "1", features = ["full"] } serde = { version = "1.0", features = ["derive"] } @@ -21,8 +22,9 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } rand = "0.8" futures = "0.3" anyhow = "1.0" - +url = "2.4" tower = "0.5" +axum = "0.8" [[example]] name = "clients_sse" @@ -40,3 +42,6 @@ path = "src/everything_stdio.rs" name = "clients_collection" path = "src/collection.rs" +[[example]] +name = "oauth_client" +path = "src/oauth_client.rs" \ No newline at end of file diff --git a/examples/clients/src/oauth_client.rs b/examples/clients/src/oauth_client.rs new file mode 100644 index 0000000..25c3531 --- /dev/null +++ b/examples/clients/src/oauth_client.rs @@ -0,0 +1,273 @@ +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use anyhow::{Context, Result}; +use axum::{ + Router, + extract::{Query, State}, + response::{Html, Redirect}, + routing::get, +}; +use rmcp::{ + ServiceExt, + model::ClientInfo, + transport::{ + auth::{AuthError, AuthorizationManager, AuthorizationSession}, + create_authorized_transport, + sse::SseTransportRetryConfig, + }, +}; +use serde::Deserialize; +use tokio::{ + io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter}, + sync::{Mutex, oneshot}, +}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +const MCP_SERVER_URL: &str = "http://localhost:3000/mcp"; +const MCP_REDIRECT_URI: &str = "http://localhost:8080/callback"; +const MCP_SSE_URL: &str = "http://localhost:3000/mcp/sse"; +const CALLBACK_PORT: u16 = 8080; + +#[derive(Clone)] +struct AppState { + auth_session: Arc, + code_receiver: Arc>>>, +} + +#[derive(Debug, Deserialize)] +struct CallbackParams { + code: String, + state: Option, +} + +async fn callback_handler( + Query(params): Query, + State(state): State, +) -> Html { + tracing::info!("Received callback with code: {}", params.code); + + // Send the code to the main thread + if let Some(sender) = state.code_receiver.lock().await.take() { + let _ = sender.send(params.code); + } + + // Return success page + Html(format!( + r#" + + + + OAuth Authorization Success + + + +
+
+

Authorization Successful

+

You have successfully authorized the MCP client. You can now close this window and return to the application.

+
+ + + "# + )) +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Get server URL + let server_url = MCP_SERVER_URL.to_string(); + tracing::info!("Using MCP server URL: {}", server_url); + + // Configure retry settings + let retry_config = SseTransportRetryConfig { + max_times: Some(3), + min_duration: Duration::from_secs(1), + }; + + // Initialize authorization manager + let auth_manager = AuthorizationManager::new(&server_url) + .await + .context("Failed to initialize authorization manager")?; + let auth_manager_arc = Arc::new(Mutex::new(auth_manager)); + + // Create authorization session + let session = AuthorizationSession::new( + auth_manager_arc.clone(), + &["mcp", "profile", "email"], + &MCP_REDIRECT_URI, + ) + .await + .context("Failed to create authorization session")?; + + let session_arc = Arc::new(session); + + // Create channel for receiving authorization code + let (code_sender, code_receiver) = oneshot::channel::(); + + // Create app state + let app_state = AppState { + auth_session: session_arc.clone(), + code_receiver: Arc::new(Mutex::new(Some(code_sender))), + }; + + // Start HTTP server for handling callbacks + let app = Router::new() + .route("/callback", get(callback_handler)) + .with_state(app_state); + + let addr = SocketAddr::from(([127, 0, 0, 1], CALLBACK_PORT)); + tracing::info!("Starting callback server at: http://{}", addr); + + // Start server in a separate task + tokio::spawn(async move { + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + let result = axum::serve(listener, app).await; + + if let Err(e) = result { + tracing::error!("Callback server error: {}", e); + } + }); + + // Output authorization URL to user + let mut output = BufWriter::new(tokio::io::stdout()); + output.write_all(b"\n=== MCP OAuth Client ===\n\n").await?; + output + .write_all(b"Please open the following URL in your browser to authorize:\n\n") + .await?; + output + .write_all(session_arc.get_authorization_url().as_bytes()) + .await?; + output + .write_all(b"\n\nWaiting for browser callback, please do not close this window...\n") + .await?; + output.flush().await?; + + // Wait for authorization code + tracing::info!("Waiting for authorization code..."); + let auth_code = code_receiver + .await + .context("Failed to get authorization code")?; + tracing::info!("Received authorization code: {}", auth_code); + // Exchange code for access token + tracing::info!("Exchanging authorization code for access token..."); + let credentials = match session_arc.handle_callback(&auth_code).await { + Ok(creds) => { + tracing::info!("Successfully obtained access token"); + creds + } + Err(e) => { + tracing::error!("Failed to obtain access token: {}", e); + return Err(anyhow::anyhow!("Authorization failed: {}", e)); + } + }; + tracing::info!("Access token: {:?}", credentials); + + output + .write_all(b"\nAuthorization successful! Access token obtained.\n\n") + .await?; + output.flush().await?; + + // Create authorized transport + tracing::info!("Establishing authorized connection to MCP server..."); + let transport = match create_authorized_transport( + MCP_SSE_URL.to_string(), + auth_manager_arc, + Some(retry_config), + ) + .await + { + Ok(t) => t, + Err(e) => { + tracing::error!("Failed to create authorized transport: {}", e); + return Err(anyhow::anyhow!("Connection failed: {}", e)); + } + }; + + // Create client and connect to MCP server + let client_service = ClientInfo::default(); + let client = client_service.serve(transport).await?; + tracing::info!("Successfully connected to MCP server"); + + // Test API requests + output + .write_all(b"Fetching available tools from server...\n") + .await?; + output.flush().await?; + + match client.peer().list_all_tools().await { + Ok(tools) => { + output + .write_all(format!("Available tools: {}\n\n", tools.len()).as_bytes()) + .await?; + for tool in tools { + output + .write_all( + format!( + "- {} ({})\n", + tool.name, + tool.description.unwrap_or_default() + ) + .as_bytes(), + ) + .await?; + } + } + Err(e) => { + output + .write_all(format!("Error fetching tools: {}\n", e).as_bytes()) + .await?; + } + } + + output + .write_all(b"\nFetching available prompts from server...\n") + .await?; + output.flush().await?; + + match client.peer().list_all_prompts().await { + Ok(prompts) => { + output + .write_all(format!("Available prompts: {}\n\n", prompts.len()).as_bytes()) + .await?; + for prompt in prompts { + output + .write_all(format!("- {}\n", prompt.name).as_bytes()) + .await?; + } + } + Err(e) => { + output + .write_all(format!("Error fetching prompts: {}\n", e).as_bytes()) + .await?; + } + } + + output + .write_all(b"\nConnection established successfully. You are now authenticated with the MCP server.\n") + .await?; + output.flush().await?; + + // Keep the program running, wait for user input to exit + output.write_all(b"\nPress Enter to exit...\n").await?; + output.flush().await?; + + let mut input = String::new(); + let mut reader = BufReader::new(tokio::io::stdin()); + reader.read_line(&mut input).await?; + + Ok(()) +} diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 63d9d6d..6fae7ed 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -1,5 +1,3 @@ - - [package] name = "mcp-server-examples" version = "0.1.5" @@ -7,7 +5,7 @@ edition = "2024" publish = false [dependencies] -rmcp= { path = "../../crates/rmcp", features = ["server", "transport-sse-server", "transport-io"] } +rmcp= { path = "../../crates/rmcp", features = ["server", "transport-sse-server", "transport-io", "auth"] } tokio = { version = "1", features = ["macros", "rt", "rt-multi-thread", "io-std", "signal"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -19,9 +17,14 @@ tracing-subscriber = { version = "0.3", features = [ "fmt", ] } futures = "0.3" -rand = { version = "0.9" } +rand = { version = "0.8", features = ["std"] } axum = { version = "0.8", features = ["macros"] } schemars = { version = "0.8", optional = true } +reqwest = { version = "0.12", features = ["json"] } +chrono = "0.4" +uuid = { version = "1.6", features = ["v4", "serde"] } +serde_urlencoded = "0.7" +hyper = "1.3" # [dev-dependencies.'cfg(target_arch="linux")'.dependencies] [dev-dependencies] @@ -43,4 +46,12 @@ path = "src/axum_router.rs" [[example]] name = "servers_generic_server" -path = "src/generic_service.rs" \ No newline at end of file +path = "src/generic_service.rs" + +[[example]] +name = "servers_auth_sse" +path = "src/auth_sse.rs" + +[[example]] +name = "mcp_oauth_server" +path = "src/mcp_oauth_server.rs" \ No newline at end of file diff --git a/examples/servers/src/auth_sse.rs b/examples/servers/src/auth_sse.rs new file mode 100644 index 0000000..6389328 --- /dev/null +++ b/examples/servers/src/auth_sse.rs @@ -0,0 +1,214 @@ +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +use anyhow::Result; +use axum::{ + Json, Router, + extract::{Path, State}, + http::{HeaderMap, Request, StatusCode}, + middleware::{self, Next}, + response::{Html, Response}, + routing::get, +}; +use rmcp::transport::{SseServer, sse_server::SseServerConfig}; +use tokio::sync::Mutex; +use tokio_util::sync::CancellationToken; +mod common; +use common::{calculator::Calculator, counter::Counter}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; +// A simple token store +struct TokenStore { + valid_tokens: Vec, +} + +impl TokenStore { + fn new() -> Self { + // For demonstration purposes, use more secure token management in production + Self { + valid_tokens: vec!["demo-token".to_string(), "test-token".to_string()], + } + } + + fn is_valid(&self, token: &str) -> bool { + self.valid_tokens.contains(&token.to_string()) + } +} + +// Extract authorization token +fn extract_token(headers: &HeaderMap) -> Option { + headers + .get("Authorization") + .and_then(|value| value.to_str().ok()) + .and_then(|auth_header| { + if auth_header.starts_with("Bearer ") { + Some(auth_header[7..].to_string()) + } else { + None + } + }) +} + +// Authorization middleware +async fn auth_middleware( + State(token_store): State>, + headers: HeaderMap, + request: Request, + next: Next, +) -> Result { + match extract_token(&headers) { + Some(token) if token_store.is_valid(&token) => { + // Token is valid, proceed with the request + Ok(next.run(request).await) + } + _ => { + // Token is invalid, return 401 error + Err(StatusCode::UNAUTHORIZED) + } + } +} + +// Root path handler +async fn index() -> Html<&'static str> { + Html( + r#" + + + + RMCP Authorized SSE Server + + + +

RMCP Authorized SSE Server

+

This is a Server-Sent Events server example that requires OAuth authorization.

+ +

Available Endpoints:

+
    +
  • /api/health - Health check
  • +
  • /api/token/{token_id} - Get test token (available: demo, test)
  • +
  • /sse - SSE connection endpoint (requires authorization)
  • +
  • /message - Message sending endpoint (requires authorization)
  • +
+ +

Usage:

+
+# Get a token
+curl http://127.0.0.1:8000/api/token/demo
+
+# Connect to SSE using the token
+curl -H "Authorization: Bearer demo-token" http://127.0.0.1:8000/sse
+        
+ + + "#, + ) +} + +// Health check endpoint +async fn health_check() -> &'static str { + "OK" +} + +// Token generation endpoint (simplified example) +async fn get_token(Path(token_id): Path) -> Result, StatusCode> { + // In a real application, you should authenticate the user and generate a real token + if token_id == "demo" || token_id == "test" { + let token = format!("{}-token", token_id); + Ok(Json(serde_json::json!({ + "access_token": token, + "token_type": "Bearer", + "expires_in": 3600 + }))) + } else { + Err(StatusCode::UNAUTHORIZED) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Create token store + let token_store = Arc::new(TokenStore::new()); + + // Set up port + let addr = BIND_ADDRESS.parse::()?; + + // Create SSE server configuration + let sse_config = SseServerConfig { + bind: addr, + sse_path: "/sse".to_string(), + post_path: "/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: Some(Duration::from_secs(15)), + }; + + // Create SSE server + let (sse_server, sse_router) = SseServer::new(sse_config); + + // Create API routes + let api_routes = Router::new() + .route("/health", get(health_check)) + .route("/token/{token_id}", get(get_token)); + + // Create protected SSE routes (require authorization) + let protected_sse_router = sse_router.layer(middleware::from_fn_with_state( + token_store.clone(), + auth_middleware, + )); + + // Create main router, public endpoints don't require authorization + let app = Router::new() + .route("/", get(index)) + .nest("/api", api_routes) + .merge(protected_sse_router) + .with_state(()); + + // Start server and register service + let listener = tokio::net::TcpListener::bind(addr).await?; + let ct = sse_server.config.ct.clone(); + + // Start SSE server with Counter service + sse_server.with_service(Counter::new); + + // Handle signals for graceful shutdown + let cancel_token = ct.clone(); + tokio::spawn(async move { + match tokio::signal::ctrl_c().await { + Ok(()) => { + println!("Received Ctrl+C, shutting down server..."); + cancel_token.cancel(); + } + Err(err) => { + eprintln!("Unable to listen for Ctrl+C signal: {}", err); + } + } + }); + + // Start HTTP server + tracing::info!("Server started on {}", addr); + let server = axum::serve(listener, app).with_graceful_shutdown(async move { + // Wait for cancellation signal + ct.cancelled().await; + println!("Server is shutting down..."); + }); + + if let Err(e) = server.await { + eprintln!("Server error: {}", e); + } + + println!("Server has been shut down"); + Ok(()) +} diff --git a/examples/servers/src/mcp_oauth_server.rs b/examples/servers/src/mcp_oauth_server.rs new file mode 100644 index 0000000..a128216 --- /dev/null +++ b/examples/servers/src/mcp_oauth_server.rs @@ -0,0 +1,798 @@ +use std::{collections::HashMap, net::SocketAddr, sync::Arc, time::Duration, usize}; + +use anyhow::Result; +use axum::{ + Json, Router, + body::Body, + extract::{Form, Path, Query, State}, + http::{HeaderMap, Method, Request, StatusCode, Uri}, + middleware::{self, Next}, + response::{Html, IntoResponse, Redirect, Response}, + routing::{get, post}, +}; +use hyper::body; +use rand::{Rng, distributions::Alphanumeric}; +use reqwest::Client as HttpClient; +use rmcp::transport::{ + SseServer, + auth::{ + AuthorizationMetadata, ClientRegistrationRequest, ClientRegistrationResponse, + OAuthClientConfig, + }, + sse_server::SseServerConfig, +}; +use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock; +use tokio_util::sync::CancellationToken; +use tracing::{debug, error, info}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use uuid::Uuid; +// Import Counter tool for MCP service +mod common; +use common::counter::Counter; + +const BIND_ADDRESS: &str = "127.0.0.1:3000"; + +// MCP OAuth Store for managing tokens and sessions +#[derive(Clone, Debug)] +struct McpOAuthStore { + clients: Arc>>, + auth_sessions: Arc>>, + access_tokens: Arc>>, + http_client: HttpClient, +} + +impl McpOAuthStore { + fn new() -> Self { + let mut clients = HashMap::new(); + clients.insert( + "mcp-client".to_string(), + OAuthClientConfig { + client_id: "mcp-client".to_string(), + client_secret: Some("mcp-client-secret".to_string()), + scopes: vec!["profile".to_string(), "email".to_string()], + redirect_uri: "http://localhost:8080/callback".to_string(), + }, + ); + + Self { + clients: Arc::new(RwLock::new(clients)), + auth_sessions: Arc::new(RwLock::new(HashMap::new())), + access_tokens: Arc::new(RwLock::new(HashMap::new())), + http_client: HttpClient::builder() + .timeout(Duration::from_secs(30)) + .build() + .expect("Failed to create HTTP client"), + } + } + + async fn validate_client( + &self, + client_id: &str, + redirect_uri: &str, + ) -> Option { + let clients = self.clients.read().await; + if let Some(client) = clients.get(client_id) { + if client.redirect_uri.contains(&redirect_uri.to_string()) { + return Some(client.clone()); + } + } + None + } + + async fn create_auth_session( + &self, + client_id: String, + redirect_uri: String, + scope: Option, + state: Option, + session_id: String, + ) -> String { + let session = AuthSession { + id: session_id.clone(), + client_id, + redirect_uri, + scope, + state, + created_at: chrono::Utc::now(), + auth_token: None, + }; + + self.auth_sessions + .write() + .await + .insert(session_id.clone(), session); + session_id + } + + async fn update_auth_session_token( + &self, + session_id: &str, + token: AuthToken, + ) -> Result<(), String> { + let mut sessions = self.auth_sessions.write().await; + if let Some(session) = sessions.get_mut(session_id) { + session.auth_token = Some(token); + Ok(()) + } else { + Err("Session not found".to_string()) + } + } + + async fn create_mcp_token(&self, session_id: &str) -> Result { + let sessions = self.auth_sessions.read().await; + if let Some(session) = sessions.get(session_id) { + if let Some(auth_token) = &session.auth_token { + let access_token = format!("mcp-token-{}", Uuid::new_v4()); + let token = McpAccessToken { + access_token: access_token.clone(), + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: format!("mcp-refresh-{}", Uuid::new_v4()), + scope: session.scope.clone(), + auth_token: auth_token.clone(), + client_id: session.client_id.clone(), + }; + + self.access_tokens + .write() + .await + .insert(access_token.clone(), token.clone()); + Ok(token) + } else { + Err("No third-party token available for session".to_string()) + } + } else { + Err("Session not found".to_string()) + } + } + + async fn validate_token(&self, token: &str) -> Option { + self.access_tokens.read().await.get(token).cloned() + } +} + +#[derive(Clone, Debug)] +struct AuthSession { + id: String, + client_id: String, + redirect_uri: String, + scope: Option, + state: Option, + created_at: chrono::DateTime, + auth_token: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +struct AuthToken { + access_token: String, + token_type: String, + expires_in: u64, + refresh_token: String, + scope: Option, +} + +#[derive(Clone, Debug, Serialize)] +struct McpAccessToken { + access_token: String, + token_type: String, + expires_in: u64, + refresh_token: String, + scope: Option, + auth_token: AuthToken, + client_id: String, +} + +#[derive(Debug, Deserialize)] +struct AuthorizeQuery { + response_type: String, + client_id: String, + redirect_uri: String, + scope: Option, + state: Option, +} + +#[derive(Debug, Deserialize)] +struct AuthCallbackQuery { + code: String, + state: Option, + session_id: String, +} + +#[derive(Debug, Deserialize, Serialize)] +struct TokenRequest { + grant_type: String, + code: String, + #[serde(default)] + client_id: String, + #[serde(default)] + client_secret: String, + redirect_uri: String, + #[serde(default)] + code_verifier: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +struct UserInfo { + sub: String, + name: String, + email: String, + username: String, +} + +fn generate_random_string(length: usize) -> String { + rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(length) + .map(char::from) + .collect() +} + +// Root path handler +async fn index() -> Html<&'static str> { + Html( + r#" + + + + MCP OAuth Server + + + +

MCP OAuth Server

+

This is an MCP server with OAuth 2.0 integration to a third-party authorization server.

+ +

Available Endpoints:

+ +
+

Authorization Endpoint

+

GET /oauth/authorize

+

Parameters:

+
    +
  • response_type - Must be "code"
  • +
  • client_id - Client identifier (e.g., "mcp-client")
  • +
  • redirect_uri - URI to redirect after authorization
  • +
  • scope - Optional requested scope
  • +
  • state - Optional state value for CSRF prevention
  • +
+
+ +
+

Token Endpoint

+

POST /oauth/token

+

Parameters:

+
    +
  • grant_type - Must be "authorization_code"
  • +
  • code - The authorization code
  • +
  • client_id - Client identifier
  • +
  • client_secret - Client secret
  • +
  • redirect_uri - Redirect URI used in authorization request
  • +
+
+ +
+

MCP SSE Endpoints

+

/mcp/sse - SSE connection endpoint (requires OAuth token)

+

/mcp/message - Message endpoint (requires OAuth token)

+
+ +
+

OAuth Flow:

+
    +
  1. MCP Client initiates OAuth flow with this MCP Server
  2. +
  3. MCP Server redirects to Third-Party OAuth Server
  4. +
  5. User authenticates with Third-Party Server
  6. +
  7. Third-Party Server redirects back to MCP Server with auth code
  8. +
  9. MCP Server exchanges the code for a third-party access token
  10. +
  11. MCP Server generates its own token bound to the third-party session
  12. +
  13. MCP Server completes the OAuth flow with the MCP Client
  14. +
+
+ + + "#, + ) +} + +// Initial OAuth authorize endpoint +async fn oauth_authorize( + Query(params): Query, + State(state): State>, +) -> impl IntoResponse { + if let Some(client) = state + .validate_client(¶ms.client_id, ¶ms.redirect_uri) + .await + { + // create authorize page for user to approve + let html = format!( + r#" + + + + MCP OAuth + + + +

MCP OAuth Server

+
+
+

{client_id} requests access to your account.

+

requested scopes: {scopes}

+
+ +
+ + + + + +
+ + +
+
+
+ + + "#, + client_id = params.client_id, + redirect_uri = params.redirect_uri, + scope = params.scope.clone().unwrap_or_default(), + state = params.state.clone().unwrap_or_default(), + scopes = params + .scope + .clone() + .unwrap_or_else(|| "basic access".to_string()), + ); + + Html(html).into_response() + } else { + // invalid client_id or redirect_uri + ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_request", + "error_description": "invalid client id or redirect uri" + })), + ) + .into_response() + } +} + +// handle approval of authorization +#[derive(Debug, Deserialize)] +struct ApprovalForm { + client_id: String, + redirect_uri: String, + scope: String, + state: String, + approved: String, +} + +async fn oauth_approve( + State(state): State>, + Form(form): Form, +) -> impl IntoResponse { + if form.approved != "true" { + // user rejected the authorization request + let redirect_url = format!( + "{}?error=access_denied&error_description={}{}", + form.redirect_uri, + "user rejected the authorization request", + form.state + .is_empty() + .then_some("") + .unwrap_or(&format!("&state={}", form.state)) + ); + return Redirect::to(&redirect_url).into_response(); + } + + // user approved the authorization request, generate authorization code + let session_id = Uuid::new_v4().to_string(); + let auth_code = format!("mcp-code-{}", session_id); + + // create new session record authorization information + let session_id = state + .create_auth_session( + form.client_id.clone(), + form.redirect_uri.clone(), + Some(form.scope.clone()), + Some(form.state.clone()), + session_id.clone(), + ) + .await; + + // create token + let created_token = AuthToken { + access_token: format!("tp-token-{}", Uuid::new_v4()), + token_type: "Bearer".to_string(), + expires_in: 3600, + refresh_token: format!("tp-refresh-{}", Uuid::new_v4()), + scope: Some(form.scope), + }; + + // update session token + if let Err(e) = state + .update_auth_session_token(&session_id, created_token) + .await + { + error!("Failed to update session token: {}", e); + } + + // redirect back to client, with authorization code + let redirect_url = format!( + "{}?code={}{}", + form.redirect_uri, + auth_code, + form.state + .is_empty() + .then_some("") + .unwrap_or(&format!("&state={}", form.state)) + ); + + info!("authorization approved, redirecting to: {}", redirect_url); + Redirect::to(&redirect_url).into_response() +} + +// Handle token request from the MCP client +async fn oauth_token( + State(state): State>, + request: axum::http::Request, +) -> impl IntoResponse { + info!("Received token request"); + + let bytes = match axum::body::to_bytes(request.into_body(), usize::MAX).await { + Ok(bytes) => bytes, + Err(e) => { + error!("can't read request body: {}", e); + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_request", + "error_description": "can't read request body" + })), + ) + .into_response(); + } + }; + + let body_str = String::from_utf8_lossy(&bytes); + info!("request body: {}", body_str); + + let token_req = match serde_urlencoded::from_bytes::(&bytes) { + Ok(form) => { + info!("successfully parsed form data: {:?}", form); + form + } + Err(e) => { + error!("can't parse form data: {}", e); + return ( + StatusCode::UNPROCESSABLE_ENTITY, + Json(serde_json::json!({ + "error": "invalid_request", + "error_description": format!("can't parse form data: {}", e) + })), + ) + .into_response(); + } + }; + + + if token_req.grant_type != "authorization_code" { + info!("unsupported grant type: {}", token_req.grant_type); + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "unsupported_grant_type", + "error_description": "only authorization_code is supported" + })), + ) + .into_response(); + } + + // get session_id from code + if !token_req.code.starts_with("mcp-code-") { + info!("invalid authorization code: {}", token_req.code); + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_grant", + "error_description": "invalid authorization code" + })), + ) + .into_response(); + } + + // handle empty client_id + let client_id = if token_req.client_id.is_empty() { + "mcp-client".to_string() + } else { + token_req.client_id.clone() + }; + + // validate client + match state + .validate_client(&client_id, &token_req.redirect_uri) + .await + { + Some(_) => { + let session_id = token_req.code.replace("mcp-code-", ""); + info!("got session id: {}", session_id); + + // create mcp access token + match state.create_mcp_token(&session_id).await { + Ok(token) => { + info!("successfully created access token"); + ( + StatusCode::OK, + Json(serde_json::json!({ + "access_token": token.access_token, + "token_type": token.token_type, + "expires_in": token.expires_in, + "refresh_token": token.refresh_token, + "scope": token.scope, + })), + ) + .into_response() + } + Err(e) => { + error!("failed to create access token: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": "server_error", + "error_description": format!("failed to create access token: {}", e) + })), + ) + .into_response() + } + } + } + None => { + info!( + "invalid client id or redirect uri: {} / {}", + client_id, token_req.redirect_uri + ); + ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_client", + "error_description": "invalid client id or redirect uri" + })), + ) + .into_response() + } + } +} + +// Auth middleware for SSE connections +async fn validate_token_middleware( + State(token_store): State>, + request: Request, +) -> Result, StatusCode> { + // Extract the access token from the Authorization header + let auth_header = request.headers().get("Authorization"); + let token = match auth_header { + Some(header) => { + let header_str = header.to_str().unwrap_or(""); + if header_str.starts_with("Bearer ") { + header_str[7..].to_string() + } else { + return Err(StatusCode::UNAUTHORIZED); + } + } + None => { + return Err(StatusCode::UNAUTHORIZED); + } + }; + + // Validate the token + match token_store.validate_token(&token).await { + Some(_) => Ok(Some(token)), + None => Err(StatusCode::UNAUTHORIZED), + } +} + +// handle oauth server metadata request +async fn oauth_authorization_server() -> impl IntoResponse { + let metadata = AuthorizationMetadata { + authorization_endpoint: format!("http://{}/oauth/authorize", BIND_ADDRESS), + token_endpoint: format!("http://{}/oauth/token", BIND_ADDRESS), + scopes_supported: Some(vec!["profile".to_string(), "email".to_string()]), + registration_endpoint: format!("http://{}/oauth/register", BIND_ADDRESS), + issuer: Some(format!("{}", BIND_ADDRESS)), + jwks_uri: Some(format!("http://{}/oauth/jwks", BIND_ADDRESS)), + }; + debug!("metadata: {:?}", metadata); + (StatusCode::OK, Json(metadata)) +} + +// handle client registration request +async fn oauth_register( + State(state): State>, + Json(req): Json, +) -> impl IntoResponse { + debug!("register request: {:?}", req); + if req.redirect_uris.is_empty() { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({ + "error": "invalid_request", + "error_description": "at least one redirect uri is required" + })), + ) + .into_response(); + } + + // generate client id and secret + let client_id = format!("client-{}", Uuid::new_v4()); + let client_secret = generate_random_string(32); + + let client = OAuthClientConfig { + client_id: client_id.clone(), + client_secret: Some(client_secret.clone()), + redirect_uri: req.redirect_uris[0].clone(), + scopes: vec![], + }; + + state + .clients + .write() + .await + .insert(client_id.clone(), client); + + // return client information + let response = ClientRegistrationResponse { + client_id, + client_secret, + client_name: req.client_name, + redirect_uris: req.redirect_uris, + }; + + (StatusCode::CREATED, Json(response)).into_response() +} + +// Log all HTTP requests +async fn log_request(request: Request, next: Next) -> Response { + let method = request.method().clone(); + let uri = request.uri().clone(); + let version = request.version(); + + // Log headers + let headers = request.headers().clone(); + let mut header_log = String::new(); + for (key, value) in headers.iter() { + let value_str = match value.to_str() { + Ok(v) => v, + Err(_) => "", + }; + header_log.push_str(&format!("\n {}: {}", key, value_str)); + } + + // Try to get request body for form submissions + let content_type = headers + .get("content-type") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + let request_info = if content_type.contains("application/x-www-form-urlencoded") + || content_type.contains("application/json") + { + format!( + "{} {} {:?}{}\nContent-Type: {}", + method, uri, version, header_log, content_type + ) + } else { + format!("{} {} {:?}{}", method, uri, version, header_log) + }; + + info!("REQUEST: {}", request_info); + + // Call the actual handler + let response = next.run(request).await; + + // Log response status + let status = response.status(); + info!("RESPONSE: {} for {} {}", status, method, uri); + + response +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "debug".to_string().into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // Create the OAuth store + let oauth_store = Arc::new(McpOAuthStore::new()); + + // Set up port + let addr = BIND_ADDRESS.parse::()?; + + // Create SSE server configuration for MCP + let sse_config = SseServerConfig { + bind: addr.clone(), + sse_path: "/mcp/sse".to_string(), + post_path: "/mcp/message".to_string(), + ct: CancellationToken::new(), + sse_keep_alive: Some(Duration::from_secs(15)), + }; + + // Create SSE server + let (sse_server, sse_router) = SseServer::new(sse_config); + + // Create protected SSE routes (require authorization) + // let protected_sse_router = sse_router.layer(middleware::from_fn_with_state( + // oauth_store.clone(), + // validate_token_middleware, + // )); + + // Create HTTP router with request logging middleware + let app = Router::new() + .route("/", get(index)) + .route("/mcp", get(index)) + .route( + "/.well-known/oauth-authorization-server", + get(oauth_authorization_server), + ) + .route("/oauth/authorize", get(oauth_authorize)) + .route("/oauth/approve", post(oauth_approve)) + .route("/oauth/token", post(oauth_token)) + .route("/oauth/register", post(oauth_register)) + // .merge(protected_sse_router) + .with_state(oauth_store.clone()) + .layer(middleware::from_fn(log_request)); + + let app = app.merge(sse_router.with_state(())); + // Register token validation middleware for SSE + let cancel_token = sse_server.config.ct.clone(); + // Handle Ctrl+C + let cancel_token2 = sse_server.config.ct.clone(); + // Start SSE server with Counter service + sse_server.with_service(Counter::new); + + // Start HTTP server + info!("MCP OAuth Server started on {}", addr); + let listener = tokio::net::TcpListener::bind(addr).await?; + let server = axum::serve(listener, app).with_graceful_shutdown(async move { + cancel_token.cancelled().await; + info!("Server is shutting down"); + }); + + tokio::spawn(async move { + match tokio::signal::ctrl_c().await { + Ok(()) => { + info!("Received Ctrl+C, shutting down"); + cancel_token2.cancel(); + } + Err(e) => error!("Failed to listen for Ctrl+C: {}", e), + } + }); + + if let Err(e) = server.await { + error!("Server error: {}", e); + } + + Ok(()) +}