9
9
import java .util .concurrent .atomic .AtomicInteger ;
10
10
import java .util .function .Function ;
11
11
12
+ import com .fasterxml .jackson .core .type .TypeReference ;
12
13
import com .fasterxml .jackson .databind .ObjectMapper ;
14
+ import io .modelcontextprotocol .spec .McpClientSession ;
15
+ import io .modelcontextprotocol .spec .McpClientTransport ;
13
16
import io .modelcontextprotocol .spec .McpSchema ;
14
17
import io .modelcontextprotocol .spec .McpSchema .JSONRPCRequest ;
15
18
import org .junit .jupiter .api .AfterEach ;
31
34
import static org .assertj .core .api .Assertions .assertThatThrownBy ;
32
35
33
36
/**
34
- * Tests for the {@link WebFluxSseClientTransport} class.
35
- *
36
37
* @author Christian Tzolov
37
38
*/
38
39
@ Timeout (15 )
@@ -46,20 +47,22 @@ class WebFluxSseClientTransportTests {
46
47
.withExposedPorts (3001 )
47
48
.waitingFor (Wait .forHttp ("/" ).forStatusCode (404 ));
48
49
49
- private TestSseClientTransport transport ;
50
+ private TestSseClientTransportProvider transportProvider ;
51
+
52
+ private McpClientTransport transport ;
50
53
51
54
private WebClient .Builder webClientBuilder ;
52
55
53
56
private ObjectMapper objectMapper ;
54
57
55
58
// Test class to access protected methods
56
- static class TestSseClientTransport extends WebFluxSseClientTransport {
59
+ static class TestSseClientTransportProvider extends WebFluxSseClientTransportProvider {
57
60
58
61
private final AtomicInteger inboundMessageCount = new AtomicInteger (0 );
59
62
60
63
private Sinks .Many <ServerSentEvent <String >> events = Sinks .many ().unicast ().onBackpressureBuffer ();
61
64
62
- public TestSseClientTransport (WebClient .Builder webClientBuilder , ObjectMapper objectMapper ) {
65
+ public TestSseClientTransportProvider (WebClient .Builder webClientBuilder , ObjectMapper objectMapper ) {
63
66
super (webClientBuilder , objectMapper );
64
67
}
65
68
@@ -69,7 +72,7 @@ protected Flux<ServerSentEvent<String>> eventStream() {
69
72
}
70
73
71
74
public String getLastEndpoint () {
72
- return messageEndpointSink .asMono ().block ();
75
+ return (( WebFluxSseClientTransport ) getSession (). getTransport ()). messageEndpointSink .asMono ().block ();
73
76
}
74
77
75
78
public int getInboundMessageCount () {
@@ -99,7 +102,10 @@ void setUp() {
99
102
startContainer ();
100
103
webClientBuilder = WebClient .builder ().baseUrl (host );
101
104
objectMapper = new ObjectMapper ();
102
- transport = new TestSseClientTransport (webClientBuilder , objectMapper );
105
+ transportProvider = new TestSseClientTransportProvider (webClientBuilder , objectMapper );
106
+ transportProvider .setSessionFactory (
107
+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
108
+ transport = transportProvider .getSession ().getTransport ();
103
109
transport .connect (Function .identity ()).block ();
104
110
}
105
111
@@ -117,44 +123,62 @@ void cleanup() {
117
123
118
124
@ Test
119
125
void testEndpointEventHandling () {
120
- assertThat (transport .getLastEndpoint ()).startsWith ("/message?" );
126
+ assertThat (transportProvider .getLastEndpoint ()).startsWith ("/message?" );
121
127
}
122
128
123
129
@ Test
124
130
void constructorValidation () {
125
- assertThatThrownBy (() -> new WebFluxSseClientTransport (null )).isInstanceOf (IllegalArgumentException .class )
131
+ assertThatThrownBy (() -> new WebFluxSseClientTransportProvider (null ))
132
+ .isInstanceOf (IllegalArgumentException .class )
126
133
.hasMessageContaining ("WebClient.Builder must not be null" );
127
134
128
- assertThatThrownBy (() -> new WebFluxSseClientTransport (webClientBuilder , null ))
135
+ assertThatThrownBy (() -> new WebFluxSseClientTransportProvider (webClientBuilder , null ))
129
136
.isInstanceOf (IllegalArgumentException .class )
130
137
.hasMessageContaining ("ObjectMapper must not be null" );
131
138
}
132
139
133
140
@ Test
134
141
void testBuilderPattern () {
135
142
// Test default builder
136
- WebFluxSseClientTransport transport1 = WebFluxSseClientTransport .builder (webClientBuilder ).build ();
137
- assertThatCode (() -> transport1 .closeGracefully ().block ()).doesNotThrowAnyException ();
143
+ WebFluxSseClientTransportProvider transportProvider1 = WebFluxSseClientTransportProvider
144
+ .builder (webClientBuilder )
145
+ .build ();
146
+ transportProvider1 .setSessionFactory (
147
+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
148
+ transportProvider1 .getSession ();
149
+ assertThatCode (() -> transportProvider1 .closeGracefully ().block ()).doesNotThrowAnyException ();
138
150
139
151
// Test builder with custom ObjectMapper
140
152
ObjectMapper customMapper = new ObjectMapper ();
141
- WebFluxSseClientTransport transport2 = WebFluxSseClientTransport .builder (webClientBuilder )
153
+ WebFluxSseClientTransportProvider transportProvider2 = WebFluxSseClientTransportProvider
154
+ .builder (webClientBuilder )
142
155
.objectMapper (customMapper )
143
156
.build ();
144
- assertThatCode (() -> transport2 .closeGracefully ().block ()).doesNotThrowAnyException ();
157
+ transportProvider2 .setSessionFactory (
158
+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
159
+ transportProvider2 .getSession ();
160
+ assertThatCode (() -> transportProvider2 .closeGracefully ().block ()).doesNotThrowAnyException ();
145
161
146
162
// Test builder with custom SSE endpoint
147
- WebFluxSseClientTransport transport3 = WebFluxSseClientTransport .builder (webClientBuilder )
163
+ WebFluxSseClientTransportProvider transportProvider3 = WebFluxSseClientTransportProvider
164
+ .builder (webClientBuilder )
148
165
.sseEndpoint ("/custom-sse" )
149
166
.build ();
150
- assertThatCode (() -> transport3 .closeGracefully ().block ()).doesNotThrowAnyException ();
167
+ transportProvider3 .setSessionFactory (
168
+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
169
+ transportProvider3 .getSession ();
170
+ assertThatCode (() -> transportProvider3 .closeGracefully ().block ()).doesNotThrowAnyException ();
151
171
152
172
// Test builder with all custom parameters
153
- WebFluxSseClientTransport transport4 = WebFluxSseClientTransport .builder (webClientBuilder )
173
+ WebFluxSseClientTransportProvider transportProvider4 = WebFluxSseClientTransportProvider
174
+ .builder (webClientBuilder )
154
175
.objectMapper (customMapper )
155
176
.sseEndpoint ("/custom-sse" )
156
177
.build ();
157
- assertThatCode (() -> transport4 .closeGracefully ().block ()).doesNotThrowAnyException ();
178
+ transportProvider4 .setSessionFactory (
179
+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
180
+ transportProvider4 .getSession ();
181
+ assertThatCode (() -> transportProvider4 .closeGracefully ().block ()).doesNotThrowAnyException ();
158
182
}
159
183
160
184
@ Test
@@ -164,7 +188,7 @@ void testMessageProcessing() {
164
188
Map .of ("key" , "value" ));
165
189
166
190
// Simulate receiving the message
167
- transport .simulateMessageEvent ("""
191
+ transportProvider .simulateMessageEvent ("""
168
192
{
169
193
"jsonrpc": "2.0",
170
194
"method": "test-method",
@@ -176,13 +200,13 @@ void testMessageProcessing() {
176
200
// Subscribe to messages and verify
177
201
StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
178
202
179
- assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
203
+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (1 );
180
204
}
181
205
182
206
@ Test
183
207
void testResponseMessageProcessing () {
184
208
// Simulate receiving a response message
185
- transport .simulateMessageEvent ("""
209
+ transportProvider .simulateMessageEvent ("""
186
210
{
187
211
"jsonrpc": "2.0",
188
212
"id": "test-id",
@@ -197,13 +221,13 @@ void testResponseMessageProcessing() {
197
221
// Verify message handling
198
222
StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
199
223
200
- assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
224
+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (1 );
201
225
}
202
226
203
227
@ Test
204
228
void testErrorMessageProcessing () {
205
229
// Simulate receiving an error message
206
- transport .simulateMessageEvent ("""
230
+ transportProvider .simulateMessageEvent ("""
207
231
{
208
232
"jsonrpc": "2.0",
209
233
"id": "test-id",
@@ -221,13 +245,13 @@ void testErrorMessageProcessing() {
221
245
// Verify message handling
222
246
StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
223
247
224
- assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
248
+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (1 );
225
249
}
226
250
227
251
@ Test
228
252
void testNotificationMessageProcessing () {
229
253
// Simulate receiving a notification message (no id)
230
- transport .simulateMessageEvent ("""
254
+ transportProvider .simulateMessageEvent ("""
231
255
{
232
256
"jsonrpc": "2.0",
233
257
"method": "update",
@@ -236,7 +260,7 @@ void testNotificationMessageProcessing() {
236
260
""" );
237
261
238
262
// Verify the notification was processed
239
- assertThat (transport .getInboundMessageCount ()).isEqualTo (1 );
263
+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (1 );
240
264
}
241
265
242
266
@ Test
@@ -252,27 +276,31 @@ void testGracefulShutdown() {
252
276
StepVerifier .create (transport .sendMessage (testMessage )).verifyComplete ();
253
277
254
278
// Message count should remain 0 after shutdown
255
- assertThat (transport .getInboundMessageCount ()).isEqualTo (0 );
279
+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (0 );
256
280
}
257
281
258
282
@ Test
259
283
void testRetryBehavior () {
260
284
// Create a WebClient that simulates connection failures
261
285
WebClient .Builder failingWebClientBuilder = WebClient .builder ().baseUrl ("http://non-existent-host" );
262
286
263
- WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport .builder (failingWebClientBuilder ).build ();
287
+ WebFluxSseClientTransportProvider failingTransportProvider = WebFluxSseClientTransportProvider
288
+ .builder (failingWebClientBuilder )
289
+ .build ();
290
+ failingTransportProvider .setSessionFactory (
291
+ (transport ) -> new McpClientSession (Duration .ofSeconds (5 ), transport , Map .of (), Map .of ()));
264
292
265
293
// Verify that the transport attempts to reconnect
266
294
StepVerifier .create (Mono .delay (Duration .ofSeconds (2 ))).expectNextCount (1 ).verifyComplete ();
267
295
268
296
// Clean up
269
- failingTransport .closeGracefully ().block ();
297
+ failingTransportProvider . getSession (). getTransport () .closeGracefully ().block ();
270
298
}
271
299
272
300
@ Test
273
301
void testMultipleMessageProcessing () {
274
302
// Simulate receiving multiple messages in sequence
275
- transport .simulateMessageEvent ("""
303
+ transportProvider .simulateMessageEvent ("""
276
304
{
277
305
"jsonrpc": "2.0",
278
306
"method": "method1",
@@ -281,7 +309,7 @@ void testMultipleMessageProcessing() {
281
309
}
282
310
""" );
283
311
284
- transport .simulateMessageEvent ("""
312
+ transportProvider .simulateMessageEvent ("""
285
313
{
286
314
"jsonrpc": "2.0",
287
315
"method": "method2",
@@ -301,13 +329,13 @@ void testMultipleMessageProcessing() {
301
329
StepVerifier .create (transport .sendMessage (message1 ).then (transport .sendMessage (message2 ))).verifyComplete ();
302
330
303
331
// Verify message count
304
- assertThat (transport .getInboundMessageCount ()).isEqualTo (2 );
332
+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (2 );
305
333
}
306
334
307
335
@ Test
308
336
void testMessageOrderPreservation () {
309
337
// Simulate receiving messages in a specific order
310
- transport .simulateMessageEvent ("""
338
+ transportProvider .simulateMessageEvent ("""
311
339
{
312
340
"jsonrpc": "2.0",
313
341
"method": "first",
@@ -316,7 +344,7 @@ void testMessageOrderPreservation() {
316
344
}
317
345
""" );
318
346
319
- transport .simulateMessageEvent ("""
347
+ transportProvider .simulateMessageEvent ("""
320
348
{
321
349
"jsonrpc": "2.0",
322
350
"method": "second",
@@ -325,7 +353,7 @@ void testMessageOrderPreservation() {
325
353
}
326
354
""" );
327
355
328
- transport .simulateMessageEvent ("""
356
+ transportProvider .simulateMessageEvent ("""
329
357
{
330
358
"jsonrpc": "2.0",
331
359
"method": "third",
@@ -335,7 +363,7 @@ void testMessageOrderPreservation() {
335
363
""" );
336
364
337
365
// Verify message count and order
338
- assertThat (transport .getInboundMessageCount ()).isEqualTo (3 );
366
+ assertThat (transportProvider .getInboundMessageCount ()).isEqualTo (3 );
339
367
}
340
368
341
369
}
0 commit comments