Skip to content

Commit 7a77925

Browse files
committed
Fix some unit tests
Signed-off-by: JermaineHua <[email protected]>
1 parent d18d478 commit 7a77925

File tree

4 files changed

+78
-11
lines changed

4 files changed

+78
-11
lines changed

mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportProvider.java

+44-1
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,50 @@ public void onEvent(SseEvent event) {
394394
if (ENDPOINT_EVENT_TYPE.equals(event.type())) {
395395
String endpoint = event.data();
396396
messageEndpoint.set(endpoint);
397-
session.setId(getSessionIdFromUrl(endpoint));
397+
closeLatch.countDown();
398+
future.complete(null);
399+
}
400+
else if (MESSAGE_EVENT_TYPE.equals(event.type())) {
401+
JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, event.data());
402+
session.handle(message).subscribe();
403+
}
404+
else {
405+
logger.error("Received unrecognized SSE event type: {}", event.type());
406+
}
407+
}
408+
catch (IOException e) {
409+
logger.error("Error processing SSE event", e);
410+
future.completeExceptionally(e);
411+
}
412+
}
413+
414+
@Override
415+
public void onError(Throwable error) {
416+
if (!isClosing) {
417+
logger.error("SSE connection error", error);
418+
future.completeExceptionally(error);
419+
}
420+
}
421+
});
422+
423+
return Mono.fromFuture(future);
424+
}
425+
426+
@Override
427+
public Mono<Void> connect() {
428+
CompletableFuture<Void> future = new CompletableFuture<>();
429+
connectionFuture.set(future);
430+
sseClient.subscribe(baseUri + sseEndpoint, new FlowSseClient.SseEventHandler() {
431+
@Override
432+
public void onEvent(SseEvent event) {
433+
if (isClosing) {
434+
return;
435+
}
436+
437+
try {
438+
if (ENDPOINT_EVENT_TYPE.equals(event.type())) {
439+
String endpoint = event.data();
440+
messageEndpoint.set(endpoint);
398441
closeLatch.countDown();
399442
future.complete(null);
400443
}

mcp/src/test/java/io/modelcontextprotocol/client/StdioMcpSyncClientTests.java

+4
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
package io.modelcontextprotocol.client;
66

77
import java.time.Duration;
8+
import java.util.Map;
89
import java.util.concurrent.CountDownLatch;
910
import java.util.concurrent.TimeUnit;
1011
import java.util.concurrent.atomic.AtomicReference;
1112

1213
import io.modelcontextprotocol.client.transport.ServerParameters;
1314
import io.modelcontextprotocol.client.transport.StdioClientTransportProvider;
15+
import io.modelcontextprotocol.spec.McpClientSession;
1416
import io.modelcontextprotocol.spec.McpClientTransport;
1517
import io.modelcontextprotocol.spec.McpClientTransportProvider;
1618
import org.junit.jupiter.api.Test;
@@ -49,6 +51,8 @@ void customErrorHandlerShouldReceiveErrors() throws InterruptedException {
4951
AtomicReference<String> receivedError = new AtomicReference<>();
5052

5153
McpClientTransportProvider transportProvider = createMcpClientTransportProvider();
54+
transportProvider.setSessionFactory(
55+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
5256
McpClientTransport transport = transportProvider.getSession().getTransport();
5357
StepVerifier.create(transport.connect(msg -> msg)).verifyComplete();
5458

mcp/src/test/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransportTests.java

+10-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ void setUp() {
124124
transportProvider.setSessionFactory(
125125
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
126126
transport = new TestHttpClientSseClientTransport(transportProvider.getSession().getTransport());
127-
transport.connect(Function.identity()).block();
127+
transport.connect().block();
128128
}
129129

130130
@AfterEach
@@ -334,6 +334,9 @@ void testCustomizeClient() {
334334
customizerCalled.set(true);
335335
})
336336
.build();
337+
customizedTransport.setSessionFactory(
338+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
339+
customizedTransport.getSession();
337340

338341
// Verify the customizer was called
339342
assertThat(customizerCalled.get()).isTrue();
@@ -364,6 +367,9 @@ void testCustomizeRequest() {
364367
headerValue.set(request.headers().firstValue("X-Custom-Header").orElse(null));
365368
})
366369
.build();
370+
customizedTransport.setSessionFactory(
371+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
372+
customizedTransport.getSession();
367373

368374
// Verify the customizer was called
369375
assertThat(customizerCalled.get()).isTrue();
@@ -393,6 +399,9 @@ void testChainedCustomizations() {
393399
requestCustomizerCalled.set(true);
394400
})
395401
.build();
402+
customizedTransport.setSessionFactory(
403+
(transport) -> new McpClientSession(Duration.ofSeconds(5), transport, Map.of(), Map.of()));
404+
customizedTransport.getSession();
396405

397406
// Verify both customizers were called
398407
assertThat(clientCustomizerCalled.get()).isTrue();

mcp/src/test/java/io/modelcontextprotocol/spec/McpClientSessionTests.java

+20-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import com.fasterxml.jackson.core.type.TypeReference;
1111
import io.modelcontextprotocol.MockMcpClientTransport;
12+
import io.modelcontextprotocol.MockMcpClientTransportProvider;
1213
import org.junit.jupiter.api.AfterEach;
1314
import org.junit.jupiter.api.BeforeEach;
1415
import org.junit.jupiter.api.Test;
@@ -41,13 +42,18 @@ class McpClientSessionTests {
4142

4243
private McpClientSession session;
4344

44-
private MockMcpClientTransport transport;
45+
private MockMcpClientTransportProvider.MockMcpClientTransport transport;
46+
47+
private MockMcpClientTransportProvider transportProvider;
4548

4649
@BeforeEach
4750
void setUp() {
48-
transport = new MockMcpClientTransport();
49-
session = new McpClientSession(TIMEOUT, transport, Map.of(),
50-
Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params))));
51+
transportProvider = new MockMcpClientTransportProvider();
52+
transportProvider.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, Map.of(),
53+
Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> logger.info("Status update: " + params)))));
54+
session = transportProvider.getSession();
55+
transport = transportProvider.getTransport();
56+
5157
}
5258

5359
@AfterEach
@@ -139,8 +145,11 @@ void testRequestHandling() {
139145
String echoMessage = "Hello MCP!";
140146
Map<String, McpClientSession.RequestHandler<?>> requestHandlers = Map.of(ECHO_METHOD,
141147
params -> Mono.just(params));
142-
transport = new MockMcpClientTransport();
143-
session = new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of());
148+
transportProvider = new MockMcpClientTransportProvider();
149+
transportProvider
150+
.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, requestHandlers, Map.of()));
151+
session = transportProvider.getSession();
152+
transport = transportProvider.getTransport();
144153

145154
// Simulate incoming request
146155
McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, ECHO_METHOD,
@@ -159,9 +168,11 @@ void testRequestHandling() {
159168
void testNotificationHandling() {
160169
Sinks.One<Object> receivedParams = Sinks.one();
161170

162-
transport = new MockMcpClientTransport();
163-
session = new McpClientSession(TIMEOUT, transport, Map.of(),
164-
Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params))));
171+
transportProvider = new MockMcpClientTransportProvider();
172+
transportProvider.setSessionFactory((transport) -> new McpClientSession(TIMEOUT, transport, Map.of(),
173+
Map.of(TEST_NOTIFICATION, params -> Mono.fromRunnable(() -> receivedParams.tryEmitValue(params)))));
174+
session = transportProvider.getSession();
175+
transport = transportProvider.getTransport();
165176

166177
// Simulate incoming notification from the server
167178
Map<String, Object> notificationParams = Map.of("status", "ready");

0 commit comments

Comments
 (0)