Skip to content

Commit a736ea0

Browse files
committed
Fix some unit tests
Signed-off-by: JermaineHua <[email protected]>
1 parent d18d478 commit a736ea0

File tree

6 files changed

+41
-76
lines changed

6 files changed

+41
-76
lines changed

mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/client/transport/WebFluxSseClientTransportProvider.java

+1-33
Original file line numberDiff line numberDiff line change
@@ -263,39 +263,7 @@ public WebFluxSseClientTransport() {
263263
*/
264264
@Override
265265
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
266-
Flux<ServerSentEvent<String>> events = eventStream();
267-
inboundSubscription = events
268-
.concatMap(event -> Mono.just(event).<McpSchema.JSONRPCMessage>handle((e, s) -> {
269-
if (ENDPOINT_EVENT_TYPE.equals(event.event())) {
270-
String messageEndpointUri = event.data();
271-
if (messageEndpointSink.tryEmitValue(messageEndpointUri).isSuccess()) {
272-
session.setId(event.id());
273-
s.complete();
274-
}
275-
else {
276-
// TODO: clarify with the spec if multiple events can be
277-
// received
278-
s.error(new McpError("Failed to handle SSE endpoint event"));
279-
}
280-
}
281-
else if (MESSAGE_EVENT_TYPE.equals(event.event())) {
282-
try {
283-
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper,
284-
event.data());
285-
s.next(message);
286-
}
287-
catch (IOException ioException) {
288-
s.error(ioException);
289-
}
290-
}
291-
else {
292-
s.error(new McpError("Received unrecognized SSE event type: " + event.event()));
293-
}
294-
}).flatMap(message -> session.handle(message)))
295-
.subscribe();
296-
297-
// The connection is established once the server sends the endpoint event
298-
return messageEndpointSink.asMono().then();
266+
return connect();
299267
}
300268

301269
@Override

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportProvider.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,11 @@ public HttpClientSseClientTransport(FlowSseClient sseClient, HttpClient httpClie
381381
*/
382382
@Override
383383
public Mono<Void> connect(Function<Mono<JSONRPCMessage>, Mono<JSONRPCMessage>> handler) {
384+
return connect();
385+
}
386+
387+
@Override
388+
public Mono<Void> connect() {
384389
CompletableFuture<Void> future = new CompletableFuture<>();
385390
connectionFuture.set(future);
386391
sseClient.subscribe(baseUri + sseEndpoint, new FlowSseClient.SseEventHandler() {
@@ -394,7 +399,6 @@ public void onEvent(SseEvent event) {
394399
if (ENDPOINT_EVENT_TYPE.equals(event.type())) {
395400
String endpoint = event.data();
396401
messageEndpoint.set(endpoint);
397-
session.setId(getSessionIdFromUrl(endpoint));
398402
closeLatch.countDown();
399403
future.complete(null);
400404
}

mcp/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransportProvider.java

+1-32
Original file line numberDiff line numberDiff line change
@@ -179,38 +179,7 @@ public StdioClientTransport(ServerParameters params, ObjectMapper objectMapper)
179179
*/
180180
@Override
181181
public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
182-
return Mono.<Void>fromRunnable(() -> {
183-
handleIncomingMessages(handler);
184-
handleIncomingErrors();
185-
186-
// Prepare command and environment
187-
List<String> fullCommand = new ArrayList<>();
188-
fullCommand.add(params.getCommand());
189-
fullCommand.addAll(params.getArgs());
190-
191-
ProcessBuilder processBuilder = this.getProcessBuilder();
192-
processBuilder.command(fullCommand);
193-
processBuilder.environment().putAll(params.getEnv());
194-
195-
// Start the process
196-
try {
197-
this.process = processBuilder.start();
198-
}
199-
catch (IOException e) {
200-
throw new RuntimeException("Failed to start process with command: " + fullCommand, e);
201-
}
202-
203-
// Validate process streams
204-
if (this.process.getInputStream() == null || process.getOutputStream() == null) {
205-
this.process.destroy();
206-
throw new RuntimeException("Process input or output stream is null");
207-
}
208-
209-
// Start threads
210-
startInboundProcessing();
211-
startOutboundProcessing();
212-
startErrorProcessing();
213-
}).subscribeOn(Schedulers.boundedElastic());
182+
return connect();
214183
}
215184

216185
@Override

mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java

+4
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
package io.modelcontextprotocol.client;
66

77
import java.time.Duration;
8+
import java.util.Map;
89
import java.util.concurrent.CountDownLatch;
910
import java.util.concurrent.TimeUnit;
1011
import java.util.concurrent.atomic.AtomicReference;
1112

1213
import io.modelcontextprotocol.client.transport.ServerParameters;
1314
import io.modelcontextprotocol.client.transport.StdioClientTransportProvider;
15+
import io.modelcontextprotocol.spec.McpClientSession;
1416
import io.modelcontextprotocol.spec.McpClientTransport;
1517
import io.modelcontextprotocol.spec.McpClientTransportProvider;
1618
import org.junit.jupiter.api.Test;
@@ -49,6 +51,8 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException {
4951
AtomicReference<String> receivedError = new AtomicReference<>();
5052

5153
McpClientTransportProvider transportProvider = createMcpClientTransportProvider();
54+
transportProvider.setSessionFactory(
55+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
5256
McpClientTransport transport = transportProvider.getSession().getTransport();
5357
StepVerifier.create(transport.connect(msg -> msg)).verifyComplete();
5458

mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ void setUp() {
124124
transportProvider.setSessionFactory(
125125
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
126126
transport = new TestHttpClientSseClientTransport(transportProvider.getSession().getTransport());
127-
transport.connect(Function.identity()).block();
127+
transport.connect().block();
128128
}
129129

130130
@AfterEach
@@ -334,6 +334,9 @@ void testCustomizeClient() {
334334
customizerCalled.set(true);
335335
})
336336
.build();
337+
customizedTransport.setSessionFactory(
338+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
339+
customizedTransport.getSession();
337340

338341
// Verify the customizer was called
339342
assertThat(customizerCalled.get()).isTrue();
@@ -364,6 +367,9 @@ void testCustomizeRequest() {
364367
headerValue.set(request.headers().firstValue("X-Custom-Header").orElse(null));
365368
})
366369
.build();
370+
customizedTransport.setSessionFactory(
371+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
372+
customizedTransport.getSession();
367373

368374
// Verify the customizer was called
369375
assertThat(customizerCalled.get()).isTrue();
@@ -393,6 +399,9 @@ void testChainedCustomizations() {
393399
requestCustomizerCalled.set(true);
394400
})
395401
.build();
402+
customizedTransport.setSessionFactory(
403+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
404+
customizedTransport.getSession();
396405

397406
// Verify both customizers were called
398407
assertThat(clientCustomizerCalled.get()).isTrue();

mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java

+20-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import com.fasterxml.jackson.core.type.TypeReference;
1111
import io.modelcontextprotocol.MockMcpClientTransport;
12+
import io.modelcontextprotocol.MockMcpClientTransportProvider;
1213
import org.junit.jupiter.api.AfterEach;
1314
import org.junit.jupiter.api.BeforeEach;
1415
import org.junit.jupiter.api.Test;
@@ -41,13 +42,18 @@ class McpClientSessionTests {
4142

4243
private McpClientSession session;
4344

44-
private MockMcpClientTransport transport;
45+
private MockMcpClientTransportProvider.MockMcpClientTransport transport;
46+
47+
private MockMcpClientTransportProvider transportProvider;
4548

4649
@BeforeEach
4750
void setUp() {
48-
transport = new MockMcpClientTransport();
49-
session = new McpClientSession(TIMEOUT, transport, Map.of(),
50-
Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params))));
51+
transportProvider = new MockMcpClientTransportProvider();
52+
transportProvider.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, Map.of(),
53+
Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))));
54+
session = transportProvider.getSession();
55+
transport = transportProvider.getTransport();
56+
5157
}
5258

5359
@AfterEach
@@ -139,8 +145,11 @@ void testRequestHandling() {
139145
String echoMessage = "Hello MCP!";
140146
Map<String, McpClientSession.RequestHandler<?>> requestHandlers = Map.of(ECHO_METHOD,
141147
params -> Mono.just(params));
142-
transport = new MockMcpClientTransport();
143-
session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of());
148+
transportProvider = new MockMcpClientTransportProvider();
149+
transportProvider
150+
.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()));
151+
session = transportProvider.getSession();
152+
transport = transportProvider.getTransport();
144153

145154
// Simulate incoming request
146155
McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD,
@@ -159,9 +168,11 @@ void testRequestHandling() {
159168
void testNotificationHandling() {
160169
Sinks.One<Object> receivedParams = Sinks.one();
161170

162-
transport = new MockMcpClientTransport();
163-
session = new McpClientSession(TIMEOUT, transport, Map.of(),
164-
Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))));
171+
transportProvider = new MockMcpClientTransportProvider();
172+
transportProvider.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, Map.of(),
173+
Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))));
174+
session = transportProvider.getSession();
175+
transport = transportProvider.getTransport();
165176

166177
// Simulate incoming notification from the server
167178
Map<String, Object> notificationParams = Map.of("status", "ready");

0 commit comments

Comments
 (0)