Skip to content

Commit 36425e1

Browse files
committed
fix(sse-client): consume control frames; refresh message endpoint
1 parent ddef4ce commit 36425e1

File tree

2 files changed

+188
-15
lines changed

2 files changed

+188
-15
lines changed

crates/rmcp/src/transport/common/client_side_sse.rs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,29 @@ impl<E: std::error::Error + Send> SseStreamReconnect for NeverReconnect<E> {
9898
}
9999
}
100100

101+
/// Abstraction for SSE reconnection logic. Implementors can hook into
102+
/// [`handle_control_event`](Self::handle_control_event) to consume control
103+
/// frames (e.g. `event: endpoint`) that arrive when a server restarts an SSE
104+
/// stream. The default implementation is a no-op, keeping existing behaviour
105+
/// intact.
101106
pub(crate) trait SseStreamReconnect {
102107
type Error: std::error::Error;
103108
type Future: Future<Output = Result<BoxedSseResponse, Self::Error>> + Send;
104109
fn retry_connection(&mut self, last_event_id: Option<&str>) -> Self::Future;
110+
fn handle_control_event(&mut self, _event: &Sse) -> Result<(), Self::Error> {
111+
Ok(())
112+
}
113+
fn handle_stream_error(
114+
&mut self,
115+
error: &(dyn std::error::Error + 'static),
116+
last_event_id: Option<&str>,
117+
) {
118+
if let Some(id) = last_event_id {
119+
tracing::warn!(%id, "sse stream error: {error}");
120+
} else {
121+
tracing::warn!("sse stream error: {error}");
122+
}
123+
}
105124
}
106125

107126
pin_project_lite::pin_project! {
@@ -189,14 +208,31 @@ where
189208
*this.server_retry_interval =
190209
Some(Duration::from_millis(new_server_retry));
191210
}
192-
if let Some(event_id) = sse.id {
193-
*this.last_event_id = Some(event_id);
211+
if let Some(ref event_id) = sse.id {
212+
*this.last_event_id = Some(event_id.clone());
213+
}
214+
// Only treat blank/`message` events as JSON-RPC payloads.
215+
// Other control frames (endpoint, ping, etc.) are passed to
216+
// the reconnection handler.
217+
let is_message_event =
218+
matches!(sse.event.as_deref(), None | Some("") | Some("message"));
219+
if !is_message_event {
220+
match this.connector.handle_control_event(&sse) {
221+
Ok(()) => return self.poll_next(cx),
222+
Err(e) => {
223+
this.state.set(SseAutoReconnectStreamState::Terminated);
224+
return Poll::Ready(Some(Err(e)));
225+
}
226+
}
194227
}
195228
if let Some(data) = sse.data {
196229
match serde_json::from_str::<ServerJsonRpcMessage>(&data) {
197230
Err(e) => {
198-
// not sure should this be a hard error
199-
tracing::warn!("failed to deserialize server message: {e}");
231+
// Downgrade to debug to avoid noisy logs when servers emit
232+
// non-JSON payloads as message frames. Include last_event_id
233+
// to aid troubleshooting while keeping default behaviour.
234+
let last_id = this.last_event_id.as_deref().unwrap_or("");
235+
tracing::debug!(last_event_id=%last_id, "failed to deserialize server message: {e}");
200236
return self.poll_next(cx);
201237
}
202238
Ok(message) => {
@@ -208,7 +244,8 @@ where
208244
}
209245
}
210246
Some(Err(e)) => {
211-
tracing::warn!("sse stream error: {e}");
247+
this.connector
248+
.handle_stream_error(&e, this.last_event_id.as_deref());
212249
let retrying = this
213250
.connector
214251
.retry_connection(this.last_event_id.as_deref());

crates/rmcp/src/transport/sse_client.rs

Lines changed: 146 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
//! reference: https://html.spec.whatwg.org/multipage/server-sent-events.html
2-
use std::{pin::Pin, sync::Arc};
1+
//! Reference: <https://html.spec.whatwg.org/multipage/server-sent-events.html>
2+
use std::{
3+
pin::Pin,
4+
sync::{Arc, RwLock},
5+
};
36

47
use futures::{StreamExt, future::BoxFuture};
58
use http::Uri;
6-
use sse_stream::Error as SseError;
9+
use sse_stream::{Error as SseError, Sse};
710
use thiserror::Error;
811

912
use super::{
@@ -54,9 +57,13 @@ pub trait SseClient: Clone + Send + Sync + 'static {
5457
) -> impl Future<Output = Result<BoxedSseResponse, SseTransportError<Self::Error>>> + Send + '_;
5558
}
5659

60+
/// Helper that refreshes the POST endpoint whenever the server emits
61+
/// control frames during SSE reconnect; used together with
62+
/// [`SseAutoReconnectStream`].
5763
struct SseClientReconnect<C> {
5864
pub client: C,
5965
pub uri: Uri,
66+
pub message_endpoint: Arc<RwLock<Uri>>,
6067
}
6168

6269
impl<C: SseClient> SseStreamReconnect for SseClientReconnect<C> {
@@ -68,6 +75,37 @@ impl<C: SseClient> SseStreamReconnect for SseClientReconnect<C> {
6875
let last_event_id = last_event_id.map(|s| s.to_owned());
6976
Box::pin(async move { client.get_stream(uri, last_event_id, None).await })
7077
}
78+
79+
fn handle_control_event(&mut self, event: &Sse) -> Result<(), Self::Error> {
80+
if event.event.as_deref() != Some("endpoint") {
81+
return Ok(());
82+
}
83+
let Some(data) = event.data.as_ref() else {
84+
return Ok(());
85+
};
86+
// Servers typically resend the message POST endpoint (often with a new
87+
// sessionId) when a stream reconnects. Reuse `message_endpoint` helper
88+
// to resolve it and update the shared URI.
89+
let new_endpoint = message_endpoint(self.uri.clone(), data.clone())
90+
.map_err(SseTransportError::InvalidUri)?;
91+
*self
92+
.message_endpoint
93+
.write()
94+
.expect("message endpoint lock poisoned") = new_endpoint;
95+
Ok(())
96+
}
97+
98+
fn handle_stream_error(
99+
&mut self,
100+
error: &(dyn std::error::Error + 'static),
101+
last_event_id: Option<&str>,
102+
) {
103+
tracing::warn!(
104+
uri = %self.uri,
105+
last_event_id = last_event_id.unwrap_or(""),
106+
"sse stream error: {error}"
107+
);
108+
}
71109
}
72110
type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<C>>>>;
73111

@@ -81,7 +119,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
81119
///
82120
/// ## Using reqwest
83121
///
84-
/// ```rust
122+
/// ```rust,ignore
85123
/// use rmcp::transport::SseClientTransport;
86124
///
87125
/// // Enable the reqwest feature in Cargo.toml:
@@ -95,7 +133,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
95133
///
96134
/// ## Using a custom HTTP client
97135
///
98-
/// ```rust
136+
/// ```rust,ignore
99137
/// use rmcp::transport::sse_client::{SseClient, SseClientTransport, SseClientConfig};
100138
/// use std::sync::Arc;
101139
/// use futures::stream::BoxStream;
@@ -154,7 +192,9 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
154192
pub struct SseClientTransport<C: SseClient> {
155193
client: C,
156194
config: SseClientConfig,
157-
message_endpoint: Uri,
195+
/// Current POST endpoint; refreshed when the server sends new endpoint
196+
/// control frames.
197+
message_endpoint: Arc<RwLock<Uri>>,
158198
stream: Option<ServerMessageStream<C>>,
159199
}
160200

@@ -168,8 +208,16 @@ impl<C: SseClient> Transport<RoleClient> for SseClientTransport<C> {
168208
item: crate::service::TxJsonRpcMessage<RoleClient>,
169209
) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
170210
let client = self.client.clone();
171-
let uri = self.message_endpoint.clone();
172-
async move { client.post_message(uri, item, None).await }
211+
let message_endpoint = self.message_endpoint.clone();
212+
async move {
213+
let uri = {
214+
let guard = message_endpoint
215+
.read()
216+
.expect("message endpoint lock poisoned");
217+
guard.clone()
218+
};
219+
client.post_message(uri, item, None).await
220+
}
173221
}
174222
async fn close(&mut self) -> Result<(), Self::Error> {
175223
self.stream.take();
@@ -194,7 +242,7 @@ impl<C: SseClient> SseClientTransport<C> {
194242
let sse_endpoint = config.sse_endpoint.as_ref().parse::<http::Uri>()?;
195243

196244
let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?;
197-
let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
245+
let initial_message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
198246
let ep = endpoint.parse::<http::Uri>()?;
199247
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
200248
sse_endpoint_parts.path_and_query = ep.into_parts().path_and_query;
@@ -214,12 +262,14 @@ impl<C: SseClient> SseClientTransport<C> {
214262
break message_endpoint(sse_endpoint.clone(), ep)?;
215263
}
216264
};
265+
let message_endpoint = Arc::new(RwLock::new(initial_message_endpoint));
217266

218267
let stream = Box::pin(SseAutoReconnectStream::new(
219268
sse_stream,
220269
SseClientReconnect {
221270
client: client.clone(),
222271
uri: sse_endpoint.clone(),
272+
message_endpoint: message_endpoint.clone(),
223273
},
224274
config.retry_policy.clone(),
225275
));
@@ -274,7 +324,7 @@ pub struct SseClientConfig {
274324
/// and the server send the message endpoint event as `message?session_id=123`,
275325
/// then the message endpoint will be `http://example.com/message`.
276326
///
277-
/// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN/docs/Web/API/URL/URL)
327+
/// This follows the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/en-US/docs/Web/API/URL/URL)
278328
pub sse_endpoint: Arc<str>,
279329
pub retry_policy: Arc<dyn SseRetryPolicy>,
280330
/// if this is settled, the client will use this endpoint to send message and skip get the endpoint event
@@ -293,8 +343,40 @@ impl Default for SseClientConfig {
293343

294344
#[cfg(test)]
295345
mod tests {
346+
use futures::StreamExt;
347+
use serde_json::{Value, json};
348+
296349
use super::*;
297350

351+
#[derive(Clone)]
352+
struct DummyClient;
353+
354+
#[derive(Debug, thiserror::Error)]
355+
#[error("dummy error")]
356+
struct DummyError;
357+
358+
impl SseClient for DummyClient {
359+
type Error = DummyError;
360+
361+
async fn post_message(
362+
&self,
363+
_uri: Uri,
364+
_message: ClientJsonRpcMessage,
365+
_auth_token: Option<String>,
366+
) -> Result<(), SseTransportError<Self::Error>> {
367+
Ok(())
368+
}
369+
370+
async fn get_stream(
371+
&self,
372+
_uri: Uri,
373+
_last_event_id: Option<String>,
374+
_auth_token: Option<String>,
375+
) -> Result<BoxedSseResponse, SseTransportError<Self::Error>> {
376+
unreachable!("get_stream should not be called in this test")
377+
}
378+
}
379+
298380
#[test]
299381
fn test_message_endpoint() {
300382
let base_url = "https://localhost/sse".parse::<http::Uri>().unwrap();
@@ -319,4 +401,58 @@ mod tests {
319401
.unwrap();
320402
assert_eq!(result.to_string(), "http://example.com/xxx?sessionId=x");
321403
}
404+
405+
#[test]
406+
fn handle_endpoint_control_event_updates_uri() {
407+
let initial_endpoint = "https://example.com/message?sessionId=old"
408+
.parse::<Uri>()
409+
.unwrap();
410+
let shared_endpoint = Arc::new(RwLock::new(initial_endpoint));
411+
let mut reconnect = SseClientReconnect {
412+
client: DummyClient,
413+
uri: "https://example.com/sse".parse::<Uri>().unwrap(),
414+
message_endpoint: shared_endpoint.clone(),
415+
};
416+
417+
let control_event = Sse::default()
418+
.event("endpoint")
419+
.data("/message?sessionId=new");
420+
421+
reconnect.handle_control_event(&control_event).unwrap();
422+
423+
let guard = shared_endpoint.read().expect("lock poisoned");
424+
assert_eq!(
425+
guard.to_string(),
426+
"https://example.com/message?sessionId=new"
427+
);
428+
}
429+
430+
#[tokio::test]
431+
async fn control_event_frames_are_skipped() {
432+
let payload = json!({
433+
"jsonrpc": "2.0",
434+
"id": 1,
435+
"result": {"ok": true}
436+
})
437+
.to_string();
438+
439+
let events = vec![
440+
Ok(Sse::default()
441+
.event("endpoint")
442+
.data("/message?sessionId=reconnect")),
443+
Ok(Sse::default().event("message").data(payload.clone())),
444+
];
445+
446+
let sse_src: BoxedSseResponse = futures::stream::iter(events).boxed();
447+
let reconn_stream = SseAutoReconnectStream::never_reconnect(sse_src, DummyError);
448+
futures::pin_mut!(reconn_stream);
449+
450+
let message = reconn_stream.next().await.expect("stream item").unwrap();
451+
let actual: Value = serde_json::to_value(message).expect("serialize actual message");
452+
// We only need to assert that a valid JSON-RPC response came through after
453+
// skipping control frames. The exact `result` shape depends on the SDK's
454+
// typed result enums and is not asserted here.
455+
assert_eq!(actual.get("jsonrpc"), Some(&Value::String("2.0".into())));
456+
assert_eq!(actual.get("id"), Some(&Value::Number(1u64.into())));
457+
}
322458
}

0 commit comments

Comments
 (0)