Skip to content

Commit 5efe8e6

Browse files
committed
Merge branch 'refs/heads/fix/non-blocking-context_test' into fix/non-blocking-context
# Conflicts: # mcp/src/main/java/io/modelcontextprotocol/spec/McpClientSession.java
2 parents 56facbb + 77e871a commit 5efe8e6

30 files changed

+2079
-1479
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransport.java

-408
This file was deleted.

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportProvider.java

+494
Large diffs are not rendered by default.

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

+7-7
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313
import com.fasterxml.jackson.databind.ObjectMapper;
1414
import io.modelcontextprotocol.client.McpClient;
15-
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
16-
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
15+
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransportProvider;
16+
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
1717
import io.modelcontextprotocol.server.McpServer;
1818
import io.modelcontextprotocol.server.McpServerFeatures;
1919
import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider;
@@ -78,14 +78,14 @@ public void before() {
7878
this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow();
7979

8080
clientBulders.put("httpclient",
81-
McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT)
81+
McpClient.sync(HttpClientSseClientTransportProvider.builder("http://localhost:" + PORT)
8282
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
8383
.build()));
8484
clientBulders.put("webflux",
85-
McpClient
86-
.sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
87-
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
88-
.build()));
85+
McpClient.sync(WebFluxSseClientTransportProvider
86+
.builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
87+
.sseEndpoint(CUSTOM_SSE_ENDPOINT)
88+
.build()));
8989

9090
}
9191

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpAsyncClientTests.java

+4-7
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,15 @@
66

77
import java.time.Duration;
88

9-
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
10-
import io.modelcontextprotocol.spec.McpClientTransport;
9+
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
10+
import io.modelcontextprotocol.spec.McpClientTransportProvider;
1111
import org.junit.jupiter.api.Timeout;
1212
import org.testcontainers.containers.GenericContainer;
1313
import org.testcontainers.containers.wait.strategy.Wait;
1414

1515
import org.springframework.web.reactive.function.client.WebClient;
1616

1717
/**
18-
* Tests for the {@link McpAsyncClient} with {@link WebFluxSseClientTransport}.
19-
*
2018
* @author Christian Tzolov
2119
*/
2220
@Timeout(15) // Giving extra time beyond the client timeout
@@ -31,9 +29,8 @@ class WebFluxSseMcpAsyncClientTests extends AbstractMcpAsyncClientTests {
3129
.withExposedPorts(3001)
3230
.waitingFor(Wait.forHttp("/").forStatusCode(404));
3331

34-
@Override
35-
protected McpClientTransport createMcpTransport() {
36-
return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build();
32+
protected McpClientTransportProvider createMcpClientTransportProvider() {
33+
return new WebFluxSseClientTransportProvider(WebClient.builder().baseUrl(host));
3734
}
3835

3936
@Override

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/WebFluxSseMcpSyncClientTests.java

+4-6
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,15 @@
66

77
import java.time.Duration;
88

9-
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
10-
import io.modelcontextprotocol.spec.McpClientTransport;
9+
import io.modelcontextprotocol.client.transport.WebFluxSseClientTransportProvider;
10+
import io.modelcontextprotocol.spec.McpClientTransportProvider;
1111
import org.junit.jupiter.api.Timeout;
1212
import org.testcontainers.containers.GenericContainer;
1313
import org.testcontainers.containers.wait.strategy.Wait;
1414

1515
import org.springframework.web.reactive.function.client.WebClient;
1616

1717
/**
18-
* Tests for the {@link McpSyncClient} with {@link WebFluxSseClientTransport}.
19-
*
2018
* @author Christian Tzolov
2119
*/
2220
@Timeout(15) // Giving extra time beyond the client timeout
@@ -32,8 +30,8 @@ class WebFluxSseMcpSyncClientTests extends AbstractMcpSyncClientTests {
3230
.waitingFor(Wait.forHttp("/").forStatusCode(404));
3331

3432
@Override
35-
protected McpClientTransport createMcpTransport() {
36-
return WebFluxSseClientTransport.builder(WebClient.builder().baseUrl(host)).build();
33+
protected McpClientTransportProvider createMcpClientTransportProvider() {
34+
return new WebFluxSseClientTransportProvider(WebClient.builder().baseUrl(host));
3735
}
3836

3937
@Override

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportTests.java

+64-36
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
import java.util.concurrent.atomic.AtomicInteger;
1010
import java.util.function.Function;
1111

12+
import com.fasterxml.jackson.core.type.TypeReference;
1213
import com.fasterxml.jackson.databind.ObjectMapper;
14+
import io.modelcontextprotocol.spec.McpClientSession;
15+
import io.modelcontextprotocol.spec.McpClientTransport;
1316
import io.modelcontextprotocol.spec.McpSchema;
1417
import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest;
1518
import org.junit.jupiter.api.AfterEach;
@@ -31,8 +34,6 @@
3134
import static org.assertj.core.api.Assertions.assertThatThrownBy;
3235

3336
/**
34-
* Tests for the {@link WebFluxSseClientTransport} class.
35-
*
3637
* @author Christian Tzolov
3738
*/
3839
@Timeout(15)
@@ -46,20 +47,22 @@ class WebFluxSseClientTransportTests {
4647
.withExposedPorts(3001)
4748
.waitingFor(Wait.forHttp("/").forStatusCode(404));
4849

49-
private TestSseClientTransport transport;
50+
private TestSseClientTransportProvider transportProvider;
51+
52+
private McpClientTransport transport;
5053

5154
private WebClient.Builder webClientBuilder;
5255

5356
private ObjectMapper objectMapper;
5457

5558
// Test class to access protected methods
56-
static class TestSseClientTransport extends WebFluxSseClientTransport {
59+
static class TestSseClientTransportProvider extends WebFluxSseClientTransportProvider {
5760

5861
private final AtomicInteger inboundMessageCount = new AtomicInteger(0);
5962

6063
private Sinks.Many<ServerSentEvent<String>> events = Sinks.many().unicast().onBackpressureBuffer();
6164

62-
public TestSseClientTransport(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
65+
public TestSseClientTransportProvider(WebClient.Builder webClientBuilder, ObjectMapper objectMapper) {
6366
super(webClientBuilder, objectMapper);
6467
}
6568

@@ -69,7 +72,7 @@ protected Flux<ServerSentEvent<String>> eventStream() {
6972
}
7073

7174
public String getLastEndpoint() {
72-
return messageEndpointSink.asMono().block();
75+
return ((WebFluxSseClientTransport) getSession().getTransport()).messageEndpointSink.asMono().block();
7376
}
7477

7578
public int getInboundMessageCount() {
@@ -99,7 +102,10 @@ void setUp() {
99102
startContainer();
100103
webClientBuilder = WebClient.builder().baseUrl(host);
101104
objectMapper = new ObjectMapper();
102-
transport = new TestSseClientTransport(webClientBuilder, objectMapper);
105+
transportProvider = new TestSseClientTransportProvider(webClientBuilder, objectMapper);
106+
transportProvider.setSessionFactory(
107+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
108+
transport = transportProvider.getSession().getTransport();
103109
transport.connect(Function.identity()).block();
104110
}
105111

@@ -117,44 +123,62 @@ void cleanup() {
117123

118124
@Test
119125
void testEndpointEventHandling() {
120-
assertThat(transport.getLastEndpoint()).startsWith("/message?");
126+
assertThat(transportProvider.getLastEndpoint()).startsWith("/message?");
121127
}
122128

123129
@Test
124130
void constructorValidation() {
125-
assertThatThrownBy(() -> new WebFluxSseClientTransport(null)).isInstanceOf(IllegalArgumentException.class)
131+
assertThatThrownBy(() -> new WebFluxSseClientTransportProvider(null))
132+
.isInstanceOf(IllegalArgumentException.class)
126133
.hasMessageContaining("WebClient.Builder must not be null");
127134

128-
assertThatThrownBy(() -> new WebFluxSseClientTransport(webClientBuilder, null))
135+
assertThatThrownBy(() -> new WebFluxSseClientTransportProvider(webClientBuilder, null))
129136
.isInstanceOf(IllegalArgumentException.class)
130137
.hasMessageContaining("ObjectMapper must not be null");
131138
}
132139

133140
@Test
134141
void testBuilderPattern() {
135142
// Test default builder
136-
WebFluxSseClientTransport transport1 = WebFluxSseClientTransport.builder(webClientBuilder).build();
137-
assertThatCode(() -> transport1.closeGracefully().block()).doesNotThrowAnyException();
143+
WebFluxSseClientTransportProvider transportProvider1 = WebFluxSseClientTransportProvider
144+
.builder(webClientBuilder)
145+
.build();
146+
transportProvider1.setSessionFactory(
147+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
148+
transportProvider1.getSession();
149+
assertThatCode(() -> transportProvider1.closeGracefully().block()).doesNotThrowAnyException();
138150

139151
// Test builder with custom ObjectMapper
140152
ObjectMapper customMapper = new ObjectMapper();
141-
WebFluxSseClientTransport transport2 = WebFluxSseClientTransport.builder(webClientBuilder)
153+
WebFluxSseClientTransportProvider transportProvider2 = WebFluxSseClientTransportProvider
154+
.builder(webClientBuilder)
142155
.objectMapper(customMapper)
143156
.build();
144-
assertThatCode(() -> transport2.closeGracefully().block()).doesNotThrowAnyException();
157+
transportProvider2.setSessionFactory(
158+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
159+
transportProvider2.getSession();
160+
assertThatCode(() -> transportProvider2.closeGracefully().block()).doesNotThrowAnyException();
145161

146162
// Test builder with custom SSE endpoint
147-
WebFluxSseClientTransport transport3 = WebFluxSseClientTransport.builder(webClientBuilder)
163+
WebFluxSseClientTransportProvider transportProvider3 = WebFluxSseClientTransportProvider
164+
.builder(webClientBuilder)
148165
.sseEndpoint("/custom-sse")
149166
.build();
150-
assertThatCode(() -> transport3.closeGracefully().block()).doesNotThrowAnyException();
167+
transportProvider3.setSessionFactory(
168+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
169+
transportProvider3.getSession();
170+
assertThatCode(() -> transportProvider3.closeGracefully().block()).doesNotThrowAnyException();
151171

152172
// Test builder with all custom parameters
153-
WebFluxSseClientTransport transport4 = WebFluxSseClientTransport.builder(webClientBuilder)
173+
WebFluxSseClientTransportProvider transportProvider4 = WebFluxSseClientTransportProvider
174+
.builder(webClientBuilder)
154175
.objectMapper(customMapper)
155176
.sseEndpoint("/custom-sse")
156177
.build();
157-
assertThatCode(() -> transport4.closeGracefully().block()).doesNotThrowAnyException();
178+
transportProvider4.setSessionFactory(
179+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
180+
transportProvider4.getSession();
181+
assertThatCode(() -> transportProvider4.closeGracefully().block()).doesNotThrowAnyException();
158182
}
159183

160184
@Test
@@ -164,7 +188,7 @@ void testMessageProcessing() {
164188
Map.of("key", "value"));
165189

166190
// Simulate receiving the message
167-
transport.simulateMessageEvent("""
191+
transportProvider.simulateMessageEvent("""
168192
{
169193
"jsonrpc": "2.0",
170194
"method": "test-method",
@@ -176,13 +200,13 @@ void testMessageProcessing() {
176200
// Subscribe to messages and verify
177201
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
178202

179-
assertThat(transport.getInboundMessageCount()).isEqualTo(1);
203+
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
180204
}
181205

182206
@Test
183207
void testResponseMessageProcessing() {
184208
// Simulate receiving a response message
185-
transport.simulateMessageEvent("""
209+
transportProvider.simulateMessageEvent("""
186210
{
187211
"jsonrpc": "2.0",
188212
"id": "test-id",
@@ -197,13 +221,13 @@ void testResponseMessageProcessing() {
197221
// Verify message handling
198222
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
199223

200-
assertThat(transport.getInboundMessageCount()).isEqualTo(1);
224+
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
201225
}
202226

203227
@Test
204228
void testErrorMessageProcessing() {
205229
// Simulate receiving an error message
206-
transport.simulateMessageEvent("""
230+
transportProvider.simulateMessageEvent("""
207231
{
208232
"jsonrpc": "2.0",
209233
"id": "test-id",
@@ -221,13 +245,13 @@ void testErrorMessageProcessing() {
221245
// Verify message handling
222246
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
223247

224-
assertThat(transport.getInboundMessageCount()).isEqualTo(1);
248+
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
225249
}
226250

227251
@Test
228252
void testNotificationMessageProcessing() {
229253
// Simulate receiving a notification message (no id)
230-
transport.simulateMessageEvent("""
254+
transportProvider.simulateMessageEvent("""
231255
{
232256
"jsonrpc": "2.0",
233257
"method": "update",
@@ -236,7 +260,7 @@ void testNotificationMessageProcessing() {
236260
""");
237261

238262
// Verify the notification was processed
239-
assertThat(transport.getInboundMessageCount()).isEqualTo(1);
263+
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(1);
240264
}
241265

242266
@Test
@@ -252,27 +276,31 @@ void testGracefulShutdown() {
252276
StepVerifier.create(transport.sendMessage(testMessage)).verifyComplete();
253277

254278
// Message count should remain 0 after shutdown
255-
assertThat(transport.getInboundMessageCount()).isEqualTo(0);
279+
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(0);
256280
}
257281

258282
@Test
259283
void testRetryBehavior() {
260284
// Create a WebClient that simulates connection failures
261285
WebClient.Builder failingWebClientBuilder = WebClient.builder().baseUrl("http://non-existent-host");
262286

263-
WebFluxSseClientTransport failingTransport = WebFluxSseClientTransport.builder(failingWebClientBuilder).build();
287+
WebFluxSseClientTransportProvider failingTransportProvider = WebFluxSseClientTransportProvider
288+
.builder(failingWebClientBuilder)
289+
.build();
290+
failingTransportProvider.setSessionFactory(
291+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
264292

265293
// Verify that the transport attempts to reconnect
266294
StepVerifier.create(Mono.delay(Duration.ofSeconds(2))).expectNextCount(1).verifyComplete();
267295

268296
// Clean up
269-
failingTransport.closeGracefully().block();
297+
failingTransportProvider.getSession().getTransport().closeGracefully().block();
270298
}
271299

272300
@Test
273301
void testMultipleMessageProcessing() {
274302
// Simulate receiving multiple messages in sequence
275-
transport.simulateMessageEvent("""
303+
transportProvider.simulateMessageEvent("""
276304
{
277305
"jsonrpc": "2.0",
278306
"method": "method1",
@@ -281,7 +309,7 @@ void testMultipleMessageProcessing() {
281309
}
282310
""");
283311

284-
transport.simulateMessageEvent("""
312+
transportProvider.simulateMessageEvent("""
285313
{
286314
"jsonrpc": "2.0",
287315
"method": "method2",
@@ -301,13 +329,13 @@ void testMultipleMessageProcessing() {
301329
StepVerifier.create(transport.sendMessage(message1).then(transport.sendMessage(message2))).verifyComplete();
302330

303331
// Verify message count
304-
assertThat(transport.getInboundMessageCount()).isEqualTo(2);
332+
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(2);
305333
}
306334

307335
@Test
308336
void testMessageOrderPreservation() {
309337
// Simulate receiving messages in a specific order
310-
transport.simulateMessageEvent("""
338+
transportProvider.simulateMessageEvent("""
311339
{
312340
"jsonrpc": "2.0",
313341
"method": "first",
@@ -316,7 +344,7 @@ void testMessageOrderPreservation() {
316344
}
317345
""");
318346

319-
transport.simulateMessageEvent("""
347+
transportProvider.simulateMessageEvent("""
320348
{
321349
"jsonrpc": "2.0",
322350
"method": "second",
@@ -325,7 +353,7 @@ void testMessageOrderPreservation() {
325353
}
326354
""");
327355

328-
transport.simulateMessageEvent("""
356+
transportProvider.simulateMessageEvent("""
329357
{
330358
"jsonrpc": "2.0",
331359
"method": "third",
@@ -335,7 +363,7 @@ void testMessageOrderPreservation() {
335363
""");
336364

337365
// Verify message count and order
338-
assertThat(transport.getInboundMessageCount()).isEqualTo(3);
366+
assertThat(transportProvider.getInboundMessageCount()).isEqualTo(3);
339367
}
340368

341369
}

0 commit comments

Comments
 (0)