1
- //! reference: https://html.spec.whatwg.org/multipage/server-sent-events.html
2
- use std:: { pin:: Pin , sync:: Arc } ;
1
+ //! Reference: <https://html.spec.whatwg.org/multipage/server-sent-events.html>
2
+ use std:: {
3
+ pin:: Pin ,
4
+ sync:: { Arc , RwLock } ,
5
+ } ;
3
6
4
7
use futures:: { StreamExt , future:: BoxFuture } ;
5
8
use http:: Uri ;
6
- use sse_stream:: Error as SseError ;
9
+ use sse_stream:: { Error as SseError , Sse } ;
7
10
use thiserror:: Error ;
8
11
9
12
use super :: {
@@ -54,9 +57,13 @@ pub trait SseClient: Clone + Send + Sync + 'static {
54
57
) -> impl Future < Output = Result < BoxedSseResponse , SseTransportError < Self :: Error > > > + Send + ' _ ;
55
58
}
56
59
60
+ /// Helper that refreshes the POST endpoint whenever the server emits
61
+ /// control frames during SSE reconnect; used together with
62
+ /// [`SseAutoReconnectStream`].
57
63
struct SseClientReconnect < C > {
58
64
pub client : C ,
59
65
pub uri : Uri ,
66
+ pub message_endpoint : Arc < RwLock < Uri > > ,
60
67
}
61
68
62
69
impl < C : SseClient > SseStreamReconnect for SseClientReconnect < C > {
@@ -68,6 +75,37 @@ impl<C: SseClient> SseStreamReconnect for SseClientReconnect<C> {
68
75
let last_event_id = last_event_id. map ( |s| s. to_owned ( ) ) ;
69
76
Box :: pin ( async move { client. get_stream ( uri, last_event_id, None ) . await } )
70
77
}
78
+
79
+ fn handle_control_event ( & mut self , event : & Sse ) -> Result < ( ) , Self :: Error > {
80
+ if event. event . as_deref ( ) != Some ( "endpoint" ) {
81
+ return Ok ( ( ) ) ;
82
+ }
83
+ let Some ( data) = event. data . as_ref ( ) else {
84
+ return Ok ( ( ) ) ;
85
+ } ;
86
+ // Servers typically resend the message POST endpoint (often with a new
87
+ // sessionId) when a stream reconnects. Reuse `message_endpoint` helper
88
+ // to resolve it and update the shared URI.
89
+ let new_endpoint = message_endpoint ( self . uri . clone ( ) , data. clone ( ) )
90
+ . map_err ( SseTransportError :: InvalidUri ) ?;
91
+ * self
92
+ . message_endpoint
93
+ . write ( )
94
+ . expect ( "message endpoint lock poisoned" ) = new_endpoint;
95
+ Ok ( ( ) )
96
+ }
97
+
98
+ fn handle_stream_error (
99
+ & mut self ,
100
+ error : & ( dyn std:: error:: Error + ' static ) ,
101
+ last_event_id : Option < & str > ,
102
+ ) {
103
+ tracing:: warn!(
104
+ uri = %self . uri,
105
+ last_event_id = last_event_id. unwrap_or( "" ) ,
106
+ "sse stream error: {error}"
107
+ ) ;
108
+ }
71
109
}
72
110
type ServerMessageStream < C > = Pin < Box < SseAutoReconnectStream < SseClientReconnect < C > > > > ;
73
111
@@ -81,7 +119,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
81
119
///
82
120
/// ## Using reqwest
83
121
///
84
- /// ```rust
122
+ /// ```rust,ignore
85
123
/// use rmcp::transport::SseClientTransport;
86
124
///
87
125
/// // Enable the reqwest feature in Cargo.toml:
@@ -95,7 +133,7 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
95
133
///
96
134
/// ## Using a custom HTTP client
97
135
///
98
- /// ```rust
136
+ /// ```rust,ignore
99
137
/// use rmcp::transport::sse_client::{SseClient, SseClientTransport, SseClientConfig};
100
138
/// use std::sync::Arc;
101
139
/// use futures::stream::BoxStream;
@@ -154,7 +192,9 @@ type ServerMessageStream<C> = Pin<Box<SseAutoReconnectStream<SseClientReconnect<
154
192
pub struct SseClientTransport < C : SseClient > {
155
193
client : C ,
156
194
config : SseClientConfig ,
157
- message_endpoint : Uri ,
195
+ /// Current POST endpoint; refreshed when the server sends new endpoint
196
+ /// control frames.
197
+ message_endpoint : Arc < RwLock < Uri > > ,
158
198
stream : Option < ServerMessageStream < C > > ,
159
199
}
160
200
@@ -168,8 +208,16 @@ impl<C: SseClient> Transport<RoleClient> for SseClientTransport<C> {
168
208
item : crate :: service:: TxJsonRpcMessage < RoleClient > ,
169
209
) -> impl Future < Output = Result < ( ) , Self :: Error > > + Send + ' static {
170
210
let client = self . client . clone ( ) ;
171
- let uri = self . message_endpoint . clone ( ) ;
172
- async move { client. post_message ( uri, item, None ) . await }
211
+ let message_endpoint = self . message_endpoint . clone ( ) ;
212
+ async move {
213
+ let uri = {
214
+ let guard = message_endpoint
215
+ . read ( )
216
+ . expect ( "message endpoint lock poisoned" ) ;
217
+ guard. clone ( )
218
+ } ;
219
+ client. post_message ( uri, item, None ) . await
220
+ }
173
221
}
174
222
async fn close ( & mut self ) -> Result < ( ) , Self :: Error > {
175
223
self . stream . take ( ) ;
@@ -194,7 +242,7 @@ impl<C: SseClient> SseClientTransport<C> {
194
242
let sse_endpoint = config. sse_endpoint . as_ref ( ) . parse :: < http:: Uri > ( ) ?;
195
243
196
244
let mut sse_stream = client. get_stream ( sse_endpoint. clone ( ) , None , None ) . await ?;
197
- let message_endpoint = if let Some ( endpoint) = config. use_message_endpoint . clone ( ) {
245
+ let initial_message_endpoint = if let Some ( endpoint) = config. use_message_endpoint . clone ( ) {
198
246
let ep = endpoint. parse :: < http:: Uri > ( ) ?;
199
247
let mut sse_endpoint_parts = sse_endpoint. clone ( ) . into_parts ( ) ;
200
248
sse_endpoint_parts. path_and_query = ep. into_parts ( ) . path_and_query ;
@@ -214,12 +262,14 @@ impl<C: SseClient> SseClientTransport<C> {
214
262
break message_endpoint ( sse_endpoint. clone ( ) , ep) ?;
215
263
}
216
264
} ;
265
+ let message_endpoint = Arc :: new ( RwLock :: new ( initial_message_endpoint) ) ;
217
266
218
267
let stream = Box :: pin ( SseAutoReconnectStream :: new (
219
268
sse_stream,
220
269
SseClientReconnect {
221
270
client : client. clone ( ) ,
222
271
uri : sse_endpoint. clone ( ) ,
272
+ message_endpoint : message_endpoint. clone ( ) ,
223
273
} ,
224
274
config. retry_policy . clone ( ) ,
225
275
) ) ;
@@ -274,7 +324,7 @@ pub struct SseClientConfig {
274
324
/// and the server send the message endpoint event as `message?session_id=123`,
275
325
/// then the message endpoint will be `http://example.com/message`.
276
326
///
277
- /// This follow the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/zh-CN /docs/Web/API/URL/URL)
327
+ /// This follows the rules of JavaScript's [`new URL(url, base)`](https://developer.mozilla.org/en-US /docs/Web/API/URL/URL)
278
328
pub sse_endpoint : Arc < str > ,
279
329
pub retry_policy : Arc < dyn SseRetryPolicy > ,
280
330
/// if this is settled, the client will use this endpoint to send message and skip get the endpoint event
@@ -293,8 +343,40 @@ impl Default for SseClientConfig {
293
343
294
344
#[ cfg( test) ]
295
345
mod tests {
346
+ use futures:: StreamExt ;
347
+ use serde_json:: { Value , json} ;
348
+
296
349
use super :: * ;
297
350
351
+ #[ derive( Clone ) ]
352
+ struct DummyClient ;
353
+
354
+ #[ derive( Debug , thiserror:: Error ) ]
355
+ #[ error( "dummy error" ) ]
356
+ struct DummyError ;
357
+
358
+ impl SseClient for DummyClient {
359
+ type Error = DummyError ;
360
+
361
+ async fn post_message (
362
+ & self ,
363
+ _uri : Uri ,
364
+ _message : ClientJsonRpcMessage ,
365
+ _auth_token : Option < String > ,
366
+ ) -> Result < ( ) , SseTransportError < Self :: Error > > {
367
+ Ok ( ( ) )
368
+ }
369
+
370
+ async fn get_stream (
371
+ & self ,
372
+ _uri : Uri ,
373
+ _last_event_id : Option < String > ,
374
+ _auth_token : Option < String > ,
375
+ ) -> Result < BoxedSseResponse , SseTransportError < Self :: Error > > {
376
+ unreachable ! ( "get_stream should not be called in this test" )
377
+ }
378
+ }
379
+
298
380
#[ test]
299
381
fn test_message_endpoint ( ) {
300
382
let base_url = "https://localhost/sse" . parse :: < http:: Uri > ( ) . unwrap ( ) ;
@@ -319,4 +401,58 @@ mod tests {
319
401
. unwrap ( ) ;
320
402
assert_eq ! ( result. to_string( ) , "http://example.com/xxx?sessionId=x" ) ;
321
403
}
404
+
405
+ #[ test]
406
+ fn handle_endpoint_control_event_updates_uri ( ) {
407
+ let initial_endpoint = "https://example.com/message?sessionId=old"
408
+ . parse :: < Uri > ( )
409
+ . unwrap ( ) ;
410
+ let shared_endpoint = Arc :: new ( RwLock :: new ( initial_endpoint) ) ;
411
+ let mut reconnect = SseClientReconnect {
412
+ client : DummyClient ,
413
+ uri : "https://example.com/sse" . parse :: < Uri > ( ) . unwrap ( ) ,
414
+ message_endpoint : shared_endpoint. clone ( ) ,
415
+ } ;
416
+
417
+ let control_event = Sse :: default ( )
418
+ . event ( "endpoint" )
419
+ . data ( "/message?sessionId=new" ) ;
420
+
421
+ reconnect. handle_control_event ( & control_event) . unwrap ( ) ;
422
+
423
+ let guard = shared_endpoint. read ( ) . expect ( "lock poisoned" ) ;
424
+ assert_eq ! (
425
+ guard. to_string( ) ,
426
+ "https://example.com/message?sessionId=new"
427
+ ) ;
428
+ }
429
+
430
+ #[ tokio:: test]
431
+ async fn control_event_frames_are_skipped ( ) {
432
+ let payload = json ! ( {
433
+ "jsonrpc" : "2.0" ,
434
+ "id" : 1 ,
435
+ "result" : { "ok" : true }
436
+ } )
437
+ . to_string ( ) ;
438
+
439
+ let events = vec ! [
440
+ Ok ( Sse :: default ( )
441
+ . event( "endpoint" )
442
+ . data( "/message?sessionId=reconnect" ) ) ,
443
+ Ok ( Sse :: default ( ) . event( "message" ) . data( payload. clone( ) ) ) ,
444
+ ] ;
445
+
446
+ let sse_src: BoxedSseResponse = futures:: stream:: iter ( events) . boxed ( ) ;
447
+ let reconn_stream = SseAutoReconnectStream :: never_reconnect ( sse_src, DummyError ) ;
448
+ futures:: pin_mut!( reconn_stream) ;
449
+
450
+ let message = reconn_stream. next ( ) . await . expect ( "stream item" ) . unwrap ( ) ;
451
+ let actual: Value = serde_json:: to_value ( message) . expect ( "serialize actual message" ) ;
452
+ // We only need to assert that a valid JSON-RPC response came through after
453
+ // skipping control frames. The exact `result` shape depends on the SDK's
454
+ // typed result enums and is not asserted here.
455
+ assert_eq ! ( actual. get( "jsonrpc" ) , Some ( & Value :: String ( "2.0" . into( ) ) ) ) ;
456
+ assert_eq ! ( actual. get( "id" ) , Some ( & Value :: Number ( 1u64 . into( ) ) ) ) ;
457
+ }
322
458
}
0 commit comments