From 1a829d0319d12a32458cdb545d710cd0215de5d8 Mon Sep 17 00:00:00 2001 From: Dennis Kawurek Date: Fri, 23 May 2025 22:35:37 +0200 Subject: [PATCH 1/4] fix: Resolve URIs --- .../WebFluxSseServerTransportProvider.java | 4 + mcp-spring/mcp-spring-webmvc/pom.xml | 8 +- .../WebMvcSseServerTransportProvider.java | 7 + .../WebMvcSseCustomContextPathTests.java | 4 +- .../WebMvcSseCustomPathIntegrationTests.java | 145 ++++++++++++++++++ .../HttpClientSseClientTransport.java | 9 +- .../io/modelcontextprotocol/util/Utils.java | 21 +++ 7 files changed, 191 insertions(+), 7 deletions(-) create mode 100644 mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..95728c81 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -154,6 +154,10 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } + this.objectMapper = objectMapper; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 48d1c346..a703e145 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -71,6 +71,12 @@ ${junit.version} test + + org.junit.jupiter + junit-jupiter-params + ${junit.version} + test + org.mockito mockito-core @@ -128,7 +134,7 @@ test - + \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..c8c6f7f6 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -6,6 +6,7 @@ import java.io.IOException; import java.time.Duration; +import java.util.Collections; import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; @@ -152,7 +153,13 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } this.objectMapper = objectMapper; this.baseUrl = baseUrl; diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 1b5218cc..6de6624b 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -49,8 +49,8 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) - .sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT + CUSTOM_CONTEXT_PATH) + .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) .build(); clientBuilder = McpClient.sync(clientTransport); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java new file mode 100644 index 00000000..a830a780 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java @@ -0,0 +1,145 @@ +package io.modelcontextprotocol.server; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import reactor.core.publisher.Mono; +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +public class WebMvcSseCustomPathIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private WebMvcSseServerTransportProvider mcpServerTransportProvider; + + McpClient.SyncSpec clientBuilder; + + private TomcatTestUtil.TomcatServer tomcatServer; + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcSseServerTransportProvider transportProvider(org.springframework.core.env.Environment env) { + String baseUrl = env.getProperty("test.baseUrl"); + String messageEndpoint = env.getProperty("test.messageEndpoint"); + String sseEndpoint = env.getProperty("test.sseEndpoint"); + + return new WebMvcSseServerTransportProvider(new ObjectMapper(), baseUrl, messageEndpoint, sseEndpoint); + } + + @Bean + public RouterFunction routerFunction(WebMvcSseServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + @ParameterizedTest(name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" : {displayName} ") + @MethodSource("provideCustomEndpoints") + public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint) { + System.setProperty("test.baseUrl", baseUrl); + System.setProperty("test.messageEndpoint", messageEndpoint); + System.setProperty("test.sseEndpoint", sseEndpoint); + + tomcatServer = TomcatTestUtil.createTomcatServer(baseUrl, PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT + baseUrl) + .sseEndpoint(sseEndpoint) + .build()); + + McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool1 = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), + (exchange, request) -> Mono.just(callResponse)); + + mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); + + var server = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + assertThat(client.initialize()).isNotNull(); + assertThat(client.listTools().tools()).contains(tool1.tool()); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + server.close(); + } + + private static Stream provideCustomEndpoints() { + String[] baseUrls = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" }; + String[] messageEndpoints = { "/message", "/another/sse", "/" }; + String[] sseEndpoints = { "/sse", "/another/sse", "/" }; + String[] contextPath = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" }; + + return Stream.of(baseUrls) + .flatMap(baseUrl -> Stream.of(messageEndpoints) + .flatMap(messageEndpoint -> Stream.of(sseEndpoints) + .map(sseEndpoint -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint)) + + )); + } + + @AfterEach + public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + +} \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index 99cf2a62..a1cf6388 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -175,9 +175,9 @@ public HttpClientSseClientTransport(HttpClient.Builder clientBuilder, HttpReques HttpClientSseClientTransport(HttpClient httpClient, HttpRequest.Builder requestBuilder, String baseUri, String sseEndpoint, ObjectMapper objectMapper) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); - Assert.hasText(baseUri, "baseUri must not be empty"); - Assert.hasText(sseEndpoint, "sseEndpoint must not be empty"); - Assert.notNull(httpClient, "httpClient must not be null"); + Assert.notNull(baseUri, "baseUri must not be null"); + Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); Assert.notNull(requestBuilder, "requestBuilder must not be null"); this.baseUri = URI.create(baseUri); this.sseEndpoint = sseEndpoint; @@ -341,7 +341,8 @@ public Mono connect(Function, Mono> h CompletableFuture future = new CompletableFuture<>(); connectionFuture.set(future); - URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); + URI clientUri = Utils.resolveSseUri(this.baseUri, this.sseEndpoint); + logger.debug("Subscribing to {}", clientUri); sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override public void onEvent(SseEvent event) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 8e654e59..cec853f0 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,6 +4,7 @@ package io.modelcontextprotocol.util; +import java.net.URL; import reactor.util.annotation.Nullable; import java.net.URI; @@ -78,6 +79,26 @@ public static URI resolveUri(URI baseUrl, String endpointUrl) { } } + public static URI resolveSseUri(URI baseUrl, String endpointUrl) { + String sanitizedEndpoint = stripLeadingSlash(endpointUrl); + URI endpointUri = URI.create(sanitizedEndpoint); + if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { + throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); + } + + URI res = ensureTrailingSlash(baseUrl).resolve(endpointUri); + return res; + } + + private static String stripLeadingSlash(String url) { + return url.startsWith("/") ? url.substring(1) : url; + } + + private static URI ensureTrailingSlash(URI uri) { + String uriString = uri.toString(); + return !uriString.endsWith("/") ? URI.create(uriString.concat("/")) : uri; + } + /** * Checks if the given absolute endpoint URI falls under the base URI. It validates * the scheme, authority (host and port), and ensures that the base path is a prefix From c9ca95674050690706068dcddcc3cbfd50568c45 Mon Sep 17 00:00:00 2001 From: Dennis Kawurek Date: Sun, 25 May 2025 21:37:01 +0200 Subject: [PATCH 2/4] fix: Make context path in transportprovider work. --- .../WebFluxSseServerTransportProvider.java | 26 +++- .../WebFluxSseCustomPathIntegrationTests.java | 125 ++++++++++++++++++ .../WebMvcSseServerTransportProvider.java | 21 ++- .../WebMvcSseCustomContextPathTests.java | 6 +- .../WebMvcSseCustomPathIntegrationTests.java | 35 +++-- .../io/modelcontextprotocol/util/Utils.java | 4 +- 6 files changed, 188 insertions(+), 29 deletions(-) create mode 100644 mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 95728c81..2fcf3e96 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -82,6 +82,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ public static final String DEFAULT_SSE_ENDPOINT = "/sse"; + public static final String DEFAULT_CONTEXT_PATH = ""; + public static final String DEFAULT_BASE_URL = ""; private final ObjectMapper objectMapper; @@ -92,6 +94,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv */ private final String baseUrl; + private final String contextPath; + private final String messageEndpoint; private final String sseEndpoint; @@ -134,7 +138,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * @throws IllegalArgumentException if either parameter is null */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); + this(objectMapper, DEFAULT_CONTEXT_PATH, DEFAULT_BASE_URL, messageEndpoint, sseEndpoint); } /** @@ -147,9 +151,10 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * setup. Must not be null. * @throws IllegalArgumentException if either parameter is null */ - public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String contextPath, String baseUrl, + String messageEndpoint, String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(contextPath, "Context path must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); @@ -157,14 +162,18 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU if (baseUrl.endsWith("/")) { baseUrl = baseUrl.substring(0, baseUrl.length() - 1); } + if (contextPath.endsWith("/")) { + contextPath = contextPath.substring(0, contextPath.length() - 1); + } this.objectMapper = objectMapper; + this.contextPath = contextPath; this.baseUrl = baseUrl; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) + .GET(this.baseUrl + this.sseEndpoint, this::handleSseConnection) + .POST(this.baseUrl + this.messageEndpoint, this::handleMessage) .build(); } @@ -275,7 +284,7 @@ private Mono handleSseConnection(ServerRequest request) { logger.debug("Sending initial endpoint event to session: {}", sessionId); sink.next(ServerSentEvent.builder() .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) + .data(this.contextPath + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId) .build()); sink.onCancel(() -> { logger.debug("Session {} cancelled", sessionId); @@ -395,6 +404,8 @@ public static class Builder { private ObjectMapper objectMapper; + private String contextPath = DEFAULT_CONTEXT_PATH; + private String baseUrl = DEFAULT_BASE_URL; private String messageEndpoint; @@ -461,7 +472,8 @@ public WebFluxSseServerTransportProvider build() { Assert.notNull(objectMapper, "ObjectMapper must be set"); Assert.notNull(messageEndpoint, "Message endpoint must be set"); - return new WebFluxSseServerTransportProvider(objectMapper, baseUrl, messageEndpoint, sseEndpoint); + return new WebFluxSseServerTransportProvider(objectMapper, contextPath, baseUrl, messageEndpoint, + sseEndpoint); } } diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java new file mode 100644 index 00000000..e5374931 --- /dev/null +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java @@ -0,0 +1,125 @@ +package io.modelcontextprotocol.server; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.transport.WebFluxSseServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.reactive.function.server.RequestPredicates; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; +import org.springframework.web.reactive.function.server.ServerResponse; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.web.reactive.function.server.RequestPredicates.path; +import static org.springframework.web.reactive.function.server.RouterFunctions.nest; +import static org.springframework.web.reactive.function.server.RouterFunctions.route; + +public class WebFluxSseCustomPathIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private DisposableServer httpServer; + + private WebFluxSseServerTransportProvider mcpServerTransportProvider; + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest( + name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" contextPath = \"{3}\" : {displayName} ") + @MethodSource("provideCustomEndpoints") + public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint, + String contextPath) { + + this.mcpServerTransportProvider = new WebFluxSseServerTransportProvider(new ObjectMapper(), contextPath, + baseUrl, messageEndpoint, sseEndpoint); + + RouterFunction router = this.mcpServerTransportProvider.getRouterFunction(); + RouterFunction nestedRouter = (RouterFunction) nest(path(contextPath), router); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(nestedRouter); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + + this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); + + var c = contextPath; + var b = baseUrl; + var s = sseEndpoint; + if (baseUrl.endsWith("/")) { + b = b.substring(0, b.length() - 1); + } + if (contextPath.endsWith("/")) { + c = c.substring(0, c.length() - 1); + } + + var clientBuilder = McpClient + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .sseEndpoint(c + b + s) + .build()); + + McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult( + List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + McpServerFeatures.AsyncToolSpecification tool1 = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), + (exchange, request) -> Mono.just(callResponse)); + + var server = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(McpSchema.ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + assertThat(client.initialize()).isNotNull(); + assertThat(client.listTools().tools()).contains(tool1.tool()); + + McpSchema.CallToolResult response = client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + server.close(); + + } + + private static Stream provideCustomEndpoints() { + String[] baseUrls = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" }; + String[] messageEndpoints = { "/message", "/another/sse", "/" }; + String[] sseEndpoints = { "/sse", "/another/sse", "/" }; + String[] contextPaths = { "", "/mcp", "/root/mcp", "/", "/mcp/", "/root/mcp/" }; + + return Stream.of(baseUrls) + .flatMap(baseUrl -> Stream.of(messageEndpoints) + .flatMap(messageEndpoint -> Stream.of(sseEndpoints) + .flatMap(sseEndpoint -> Stream.of(contextPaths) + .map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath))))); + } + + @AfterEach + public void after() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index c8c6f7f6..2775f335 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -94,6 +94,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private final String baseUrl; + private final String contextPath; + private final RouterFunction routerFunction; private McpServerSession.Factory sessionFactory; @@ -133,13 +135,14 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * @throws IllegalArgumentException if any parameter is null */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messageEndpoint, String sseEndpoint) { - this(objectMapper, "", messageEndpoint, sseEndpoint); + this(objectMapper, "", "", messageEndpoint, sseEndpoint); } /** * Constructs a new WebMvcSseServerTransportProvider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of messages. + * @param contextPath The context path under which the server runs. * @param baseUrl The base URL for the message endpoint, used to construct the full * endpoint URL for clients. * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC @@ -148,9 +151,10 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag * @param sseEndpoint The endpoint URI where clients establish their SSE connections. * @throws IllegalArgumentException if any parameter is null */ - public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, - String sseEndpoint) { + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String contextPath, String baseUrl, + String messageEndpoint, String sseEndpoint) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(contextPath, "Context path must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); @@ -161,13 +165,18 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr baseUrl = baseUrl.substring(0, baseUrl.length() - 1); } + if (contextPath.endsWith("/")) { + contextPath = contextPath.substring(0, contextPath.length() - 1); + } + this.objectMapper = objectMapper; this.baseUrl = baseUrl; + this.contextPath = contextPath; this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() - .GET(this.sseEndpoint, this::handleSseConnection) - .POST(this.messageEndpoint, this::handleMessage) + .GET(this.baseUrl + this.sseEndpoint, this::handleSseConnection) + .POST(this.baseUrl + this.messageEndpoint, this::handleMessage) .build(); } @@ -276,7 +285,7 @@ private ServerResponse handleSseConnection(ServerRequest request) { try { sseBuilder.id(sessionId) .event(ENDPOINT_EVENT_TYPE) - .data(this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); + .data(this.contextPath + this.baseUrl + this.messageEndpoint + "?sessionId=" + sessionId); } catch (Exception e) { logger.error("Failed to send initial endpoint event: {}", e.getMessage()); diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 6de6624b..1f26f430 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -49,8 +49,8 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT + CUSTOM_CONTEXT_PATH) - .sseEndpoint(WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) + var clientTransport = HttpClientSseClientTransport.builder("http://localhost:" + PORT) + .sseEndpoint(CUSTOM_CONTEXT_PATH + WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT) .build(); clientBuilder = McpClient.sync(clientTransport); @@ -91,7 +91,7 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, + return new WebMvcSseServerTransportProvider(new ObjectMapper(), "", CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java index a830a780..4fa66f4d 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java @@ -51,8 +51,10 @@ public WebMvcSseServerTransportProvider transportProvider(org.springframework.co String baseUrl = env.getProperty("test.baseUrl"); String messageEndpoint = env.getProperty("test.messageEndpoint"); String sseEndpoint = env.getProperty("test.sseEndpoint"); + String contextPath = env.getProperty("test.contextPath"); - return new WebMvcSseServerTransportProvider(new ObjectMapper(), baseUrl, messageEndpoint, sseEndpoint); + return new WebMvcSseServerTransportProvider(new ObjectMapper(), contextPath, baseUrl, messageEndpoint, + sseEndpoint); } @Bean @@ -62,14 +64,17 @@ public RouterFunction routerFunction(WebMvcSseServerTransportPro } - @ParameterizedTest(name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" : {displayName} ") + @ParameterizedTest( + name = "baseUrl = \"{0}\" messageEndpoint = \"{1}\" sseEndpoint = \"{2}\" contextPath = \"{3}\" : {displayName} ") @MethodSource("provideCustomEndpoints") - public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint) { + public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, String sseEndpoint, + String contextPath) { System.setProperty("test.baseUrl", baseUrl); System.setProperty("test.messageEndpoint", messageEndpoint); System.setProperty("test.sseEndpoint", sseEndpoint); + System.setProperty("test.contextPath", contextPath); - tomcatServer = TomcatTestUtil.createTomcatServer(baseUrl, PORT, TestConfig.class); + tomcatServer = TomcatTestUtil.createTomcatServer(contextPath, PORT, TestConfig.class); try { tomcatServer.tomcat().start(); @@ -79,9 +84,18 @@ public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, Stri throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT + baseUrl) - .sseEndpoint(sseEndpoint) - .build()); + var c = contextPath; + var b = baseUrl; + var s = sseEndpoint; + if (baseUrl.endsWith("/")) { + b = b.substring(0, b.length() - 1); + } + if (contextPath.endsWith("/")) { + c = c.substring(0, c.length() - 1); + } + + clientBuilder = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(c + b + s).build()); McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult( List.of(new McpSchema.TextContent("CALL RESPONSE")), null); @@ -113,14 +127,13 @@ private static Stream provideCustomEndpoints() { String[] baseUrls = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" }; String[] messageEndpoints = { "/message", "/another/sse", "/" }; String[] sseEndpoints = { "/sse", "/another/sse", "/" }; - String[] contextPath = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" }; + String[] contextPaths = { "", "/mcp", "/root/mcp", "/", "/mcp/", "/root/mcp/" }; return Stream.of(baseUrls) .flatMap(baseUrl -> Stream.of(messageEndpoints) .flatMap(messageEndpoint -> Stream.of(sseEndpoints) - .map(sseEndpoint -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint)) - - )); + .flatMap(sseEndpoint -> Stream.of(contextPaths) + .map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath))))); } @AfterEach diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index cec853f0..f24bee29 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -70,12 +70,12 @@ public static boolean isEmpty(@Nullable Map map) { * base URL or URI is malformed */ public static URI resolveUri(URI baseUrl, String endpointUrl) { - URI endpointUri = URI.create(endpointUrl); + URI endpointUri = URI.create(endpointUrl.startsWith("/") ? endpointUrl.substring(1) : endpointUrl); if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); } else { - return baseUrl.resolve(endpointUri); + return ensureTrailingSlash(baseUrl).resolve(endpointUri); } } From b29953ca3ace9d85a880bd9811154e85275efe0d Mon Sep 17 00:00:00 2001 From: Dennis Kawurek Date: Sat, 31 May 2025 22:20:44 +0200 Subject: [PATCH 3/4] chore: Cleanup --- .../WebFluxSseServerTransportProvider.java | 25 ++++-- .../WebFluxSseCustomPathIntegrationTests.java | 79 ++++++++++++------ .../WebMvcSseServerTransportProvider.java | 13 +-- .../WebMvcSseCustomContextPathTests.java | 2 +- .../WebMvcSseCustomPathIntegrationTests.java | 82 +++++++++++++------ .../HttpClientSseClientTransport.java | 2 +- .../io/modelcontextprotocol/util/Utils.java | 38 ++++----- 7 files changed, 150 insertions(+), 91 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 2fcf3e96..23105638 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -12,6 +12,7 @@ import io.modelcontextprotocol.spec.McpServerTransport; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.Exceptions; @@ -145,6 +146,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa * Constructs a new WebFlux SSE server transport provider instance. * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization * of MCP messages. Must not be null. + * @param contextPath The context path of the server. * @param baseUrl webflux message base path * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC * messages. This endpoint will be communicated to clients during SSE connection @@ -159,16 +161,9 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String conte Assert.notNull(messageEndpoint, "Message endpoint must not be null"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); - if (baseUrl.endsWith("/")) { - baseUrl = baseUrl.substring(0, baseUrl.length() - 1); - } - if (contextPath.endsWith("/")) { - contextPath = contextPath.substring(0, contextPath.length() - 1); - } - this.objectMapper = objectMapper; - this.contextPath = contextPath; - this.baseUrl = baseUrl; + this.contextPath = Utils.removeTrailingSlash(contextPath); + this.baseUrl = Utils.removeTrailingSlash(baseUrl); this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() @@ -438,6 +433,18 @@ public Builder basePath(String baseUrl) { return this; } + /** + * Sets the context path under which the server is running. + * @param contextPath the context path. + * @return this builder instance. + * @throws IllegalArgumentException if contextPath is null + */ + public Builder contextPath(String contextPath) { + Assert.notNull(contextPath, "contextPath must not be null"); + this.contextPath = contextPath; + return this; + } + /** * Sets the endpoint URI where clients should send their JSON-RPC messages. * @param messageEndpoint The message endpoint URI. Must not be null. diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java index e5374931..6597c10f 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java @@ -12,7 +12,6 @@ import org.springframework.http.server.reactive.HttpHandler; import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; import org.springframework.web.reactive.function.client.WebClient; -import org.springframework.web.reactive.function.server.RequestPredicates; import org.springframework.web.reactive.function.server.RouterFunction; import org.springframework.web.reactive.function.server.RouterFunctions; import org.springframework.web.reactive.function.server.ServerResponse; @@ -22,14 +21,16 @@ import java.util.List; import java.util.Map; -import java.util.function.Supplier; import java.util.stream.Stream; import static org.assertj.core.api.Assertions.assertThat; import static org.springframework.web.reactive.function.server.RequestPredicates.path; import static org.springframework.web.reactive.function.server.RouterFunctions.nest; -import static org.springframework.web.reactive.function.server.RouterFunctions.route; +/** + * Tests the {@link WebFluxSseServerTransportProvider} with different values for the + * endpoint. + */ public class WebFluxSseCustomPathIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -56,25 +57,18 @@ public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, Stri baseUrl, messageEndpoint, sseEndpoint); RouterFunction router = this.mcpServerTransportProvider.getRouterFunction(); + // wrap the context path around the router function RouterFunction nestedRouter = (RouterFunction) nest(path(contextPath), router); HttpHandler httpHandler = RouterFunctions.toHttpHandler(nestedRouter); ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); this.httpServer = HttpServer.create().port(PORT).handle(adapter).bindNow(); - var c = contextPath; - var b = baseUrl; - var s = sseEndpoint; - if (baseUrl.endsWith("/")) { - b = b.substring(0, b.length() - 1); - } - if (contextPath.endsWith("/")) { - c = c.substring(0, c.length() - 1); - } + var endpoint = buildSseEndpoint(contextPath, baseUrl, sseEndpoint); var clientBuilder = McpClient .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .sseEndpoint(c + b + s) + .sseEndpoint(endpoint) .build()); McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult( @@ -102,24 +96,63 @@ public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, Stri } - private static Stream provideCustomEndpoints() { - String[] baseUrls = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" }; - String[] messageEndpoints = { "/message", "/another/sse", "/" }; - String[] sseEndpoints = { "/sse", "/another/sse", "/" }; - String[] contextPaths = { "", "/mcp", "/root/mcp", "/", "/mcp/", "/root/mcp/" }; + /** + * This is a helper function for the tests which builds the SSE endpoint to pass to the client transport. + * + * @param contextPath context path of the server. + * @param baseUrl base url of the sse endpoint. + * @param sseEndpoint the sse endpoint. + * @return the created sse endpoint. + */ + private String buildSseEndpoint(String contextPath, String baseUrl, String sseEndpoint) { + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } + if (contextPath.endsWith("/")) { + contextPath = contextPath.substring(0, contextPath.length() - 1); + } - return Stream.of(baseUrls) - .flatMap(baseUrl -> Stream.of(messageEndpoints) - .flatMap(messageEndpoint -> Stream.of(sseEndpoints) - .flatMap(sseEndpoint -> Stream.of(contextPaths) - .map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath))))); + return contextPath + baseUrl + sseEndpoint; } @AfterEach public void after() { + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } if (httpServer != null) { httpServer.disposeNow(); } } + /** + * Provides a stream of custom endpoints. This generates all possible combinations for + * allowed endpoint values. + * + *

+ * Each combination is returned as an {@link Arguments} object containing four + * parameters in the following order: + *

+ *
    + *
  1. Base URL (String)
  2. + *
  3. Message endpoint (String)
  4. + *
  5. SSE endpoint (String)
  6. + *
  7. Context path (String)
  8. + *
+ * @return a {@link Stream} of {@link Arguments} objects, each containing four String + * parameters representing different endpoint combinations for parameterized testing + */ + private static Stream provideCustomEndpoints() { + String[] baseUrls = { "", "/", "/v1", "/v1/" }; + String[] messageEndpoints = { "/", "/message", "/message/" }; + String[] sseEndpoints = { "/", "/sse", "/sse/" }; + String[] contextPaths = { "", "/", "/mcp", "/mcp/" }; + + return Stream.of(baseUrls) + .flatMap(baseUrl -> Stream.of(messageEndpoints) + .flatMap(messageEndpoint -> Stream.of(sseEndpoints) + .flatMap(sseEndpoint -> Stream.of(contextPaths) + .map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath))))); + } + } \ No newline at end of file diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 2775f335..fd720298 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -19,6 +19,7 @@ import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.util.Assert; +import io.modelcontextprotocol.util.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; @@ -161,17 +162,9 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String contex Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); - if (baseUrl.endsWith("/")) { - baseUrl = baseUrl.substring(0, baseUrl.length() - 1); - } - - if (contextPath.endsWith("/")) { - contextPath = contextPath.substring(0, contextPath.length() - 1); - } - this.objectMapper = objectMapper; - this.baseUrl = baseUrl; - this.contextPath = contextPath; + this.contextPath = Utils.removeTrailingSlash(contextPath); + this.baseUrl = Utils.removeTrailingSlash(baseUrl); this.messageEndpoint = messageEndpoint; this.sseEndpoint = sseEndpoint; this.routerFunction = RouterFunctions.route() diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java index 1f26f430..06a01f79 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomContextPathTests.java @@ -91,7 +91,7 @@ static class TestConfig { @Bean public WebMvcSseServerTransportProvider webMvcSseServerTransportProvider() { - return new WebMvcSseServerTransportProvider(new ObjectMapper(), "", CUSTOM_CONTEXT_PATH, MESSAGE_ENDPOINT, + return new WebMvcSseServerTransportProvider(new ObjectMapper(), CUSTOM_CONTEXT_PATH, "", MESSAGE_ENDPOINT, WebMvcSseServerTransportProvider.DEFAULT_SSE_ENDPOINT); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java index 4fa66f4d..3a40abcf 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java @@ -4,10 +4,14 @@ import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; import io.modelcontextprotocol.spec.McpSchema; + import java.util.List; import java.util.Map; import java.util.stream.Stream; + +import org.springframework.core.env.Environment; import reactor.core.publisher.Mono; + import static org.assertj.core.api.Assertions.assertThat; import org.apache.catalina.LifecycleException; @@ -24,14 +28,16 @@ import com.fasterxml.jackson.databind.ObjectMapper; +/** + * Tests the {@link WebMvcSseServerTransportProvider} with different values for the + * endpoint. + */ public class WebMvcSseCustomPathIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); private WebMvcSseServerTransportProvider mcpServerTransportProvider; - McpClient.SyncSpec clientBuilder; - private TomcatTestUtil.TomcatServer tomcatServer; String emptyJsonSchema = """ @@ -47,7 +53,7 @@ public class WebMvcSseCustomPathIntegrationTests { static class TestConfig { @Bean - public WebMvcSseServerTransportProvider transportProvider(org.springframework.core.env.Environment env) { + public WebMvcSseServerTransportProvider transportProvider(Environment env) { String baseUrl = env.getProperty("test.baseUrl"); String messageEndpoint = env.getProperty("test.messageEndpoint"); String sseEndpoint = env.getProperty("test.sseEndpoint"); @@ -84,18 +90,10 @@ public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, Stri throw new RuntimeException("Failed to start Tomcat", e); } - var c = contextPath; - var b = baseUrl; - var s = sseEndpoint; - if (baseUrl.endsWith("/")) { - b = b.substring(0, b.length() - 1); - } - if (contextPath.endsWith("/")) { - c = c.substring(0, c.length() - 1); - } + var endpoint = buildSseEndpoint(contextPath, baseUrl, sseEndpoint); - clientBuilder = McpClient - .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(c + b + s).build()); + var clientBuilder = McpClient + .sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).sseEndpoint(endpoint).build()); McpSchema.CallToolResult callResponse = new McpSchema.CallToolResult( List.of(new McpSchema.TextContent("CALL RESPONSE")), null); @@ -123,17 +121,23 @@ public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, Stri server.close(); } - private static Stream provideCustomEndpoints() { - String[] baseUrls = { "", "/v1", "/api/v1", "/", "/v1/", "/api/v1/" }; - String[] messageEndpoints = { "/message", "/another/sse", "/" }; - String[] sseEndpoints = { "/sse", "/another/sse", "/" }; - String[] contextPaths = { "", "/mcp", "/root/mcp", "/", "/mcp/", "/root/mcp/" }; + /** + * This is a helper function for the tests which builds the SSE endpoint to pass to the client transport. + * + * @param contextPath context path of the server. + * @param baseUrl base url of the sse endpoint. + * @param sseEndpoint the sse endpoint. + * @return the created sse endpoint. + */ + private String buildSseEndpoint(String contextPath, String baseUrl, String sseEndpoint) { + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } + if (contextPath.endsWith("/")) { + contextPath = contextPath.substring(0, contextPath.length() - 1); + } - return Stream.of(baseUrls) - .flatMap(baseUrl -> Stream.of(messageEndpoints) - .flatMap(messageEndpoint -> Stream.of(sseEndpoints) - .flatMap(sseEndpoint -> Stream.of(contextPaths) - .map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath))))); + return contextPath + baseUrl + sseEndpoint; } @AfterEach @@ -155,4 +159,34 @@ public void after() { } } + /** + * Provides a stream of custom endpoints. This generates all possible combinations for + * allowed endpoint values. + * + *

+ * Each combination is returned as an {@link Arguments} object containing four + * parameters in the following order: + *

+ *
    + *
  1. Base URL (String)
  2. + *
  3. Message endpoint (String)
  4. + *
  5. SSE endpoint (String)
  6. + *
  7. Context path (String)
  8. + *
+ * @return a {@link Stream} of {@link Arguments} objects, each containing four String + * parameters representing different endpoint combinations for parameterized testing + */ + private static Stream provideCustomEndpoints() { + String[] baseUrls = { "", "/", "/v1", "/v1/" }; + String[] messageEndpoints = { "/", "/message", "/message/" }; + String[] sseEndpoints = { "/", "/sse", "/sse/" }; + String[] contextPaths = { "", "/", "/mcp", "/mcp/" }; + + return Stream.of(baseUrls) + .flatMap(baseUrl -> Stream.of(messageEndpoints) + .flatMap(messageEndpoint -> Stream.of(sseEndpoints) + .flatMap(sseEndpoint -> Stream.of(contextPaths) + .map(contextPath -> Arguments.of(baseUrl, messageEndpoint, sseEndpoint, contextPath))))); + } + } \ No newline at end of file diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java index a1cf6388..9d64e32b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/transport/HttpClientSseClientTransport.java @@ -341,7 +341,7 @@ public Mono connect(Function, Mono> h CompletableFuture future = new CompletableFuture<>(); connectionFuture.set(future); - URI clientUri = Utils.resolveSseUri(this.baseUri, this.sseEndpoint); + URI clientUri = Utils.resolveUri(this.baseUri, this.sseEndpoint); logger.debug("Subscribing to {}", clientUri); sseClient.subscribe(clientUri.toString(), new FlowSseClient.SseEventHandler() { @Override diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index f24bee29..4477d625 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -4,7 +4,6 @@ package io.modelcontextprotocol.util; -import java.net.URL; import reactor.util.annotation.Nullable; import java.net.URI; @@ -54,6 +53,19 @@ public static boolean isEmpty(@Nullable Map map) { return (map == null || map.isEmpty()); } + + /** + * Removes the trailing slash character of the given String. + * @param str the String to remove the trailing slash + * @return the modified String. + */ + public static String removeTrailingSlash(String str) { + if (str.endsWith("/")) { + str = str.substring(0, str.length() - 1); + } + return str; + } + /** * Resolves the given endpoint URL against the base URL. *
    @@ -70,33 +82,13 @@ public static boolean isEmpty(@Nullable Map map) { * base URL or URI is malformed */ public static URI resolveUri(URI baseUrl, String endpointUrl) { - URI endpointUri = URI.create(endpointUrl.startsWith("/") ? endpointUrl.substring(1) : endpointUrl); + URI endpointUri = URI.create(endpointUrl); if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); } else { - return ensureTrailingSlash(baseUrl).resolve(endpointUri); - } - } - - public static URI resolveSseUri(URI baseUrl, String endpointUrl) { - String sanitizedEndpoint = stripLeadingSlash(endpointUrl); - URI endpointUri = URI.create(sanitizedEndpoint); - if (endpointUri.isAbsolute() && !isUnderBaseUri(baseUrl, endpointUri)) { - throw new IllegalArgumentException("Absolute endpoint URL does not match the base URL."); + return baseUrl.resolve(endpointUri); } - - URI res = ensureTrailingSlash(baseUrl).resolve(endpointUri); - return res; - } - - private static String stripLeadingSlash(String url) { - return url.startsWith("/") ? url.substring(1) : url; - } - - private static URI ensureTrailingSlash(URI uri) { - String uriString = uri.toString(); - return !uriString.endsWith("/") ? URI.create(uriString.concat("/")) : uri; } /** From 35d73ee7d5d23cebc65a297bf661a3a335693ee3 Mon Sep 17 00:00:00 2001 From: Dennis Kawurek Date: Sat, 31 May 2025 22:35:51 +0200 Subject: [PATCH 4/4] chore: Additional assertion and apply formatter --- .../server/transport/WebFluxSseServerTransportProvider.java | 2 ++ .../server/WebFluxSseCustomPathIntegrationTests.java | 4 ++-- .../server/WebMvcSseCustomPathIntegrationTests.java | 4 ++-- mcp/src/main/java/io/modelcontextprotocol/util/Utils.java | 1 - 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 23105638..f3268afa 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -159,7 +159,9 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String conte Assert.notNull(contextPath, "Context path must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); + Assert.hasText(messageEndpoint, "Message endpoint must not be empty"); Assert.notNull(sseEndpoint, "SSE endpoint must not be null"); + Assert.hasText(sseEndpoint, "SSE endpoint must not be empty"); this.objectMapper = objectMapper; this.contextPath = Utils.removeTrailingSlash(contextPath); diff --git a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java index 6597c10f..5cd90a5c 100644 --- a/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java +++ b/mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/server/WebFluxSseCustomPathIntegrationTests.java @@ -97,8 +97,8 @@ public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, Stri } /** - * This is a helper function for the tests which builds the SSE endpoint to pass to the client transport. - * + * This is a helper function for the tests which builds the SSE endpoint to pass to + * the client transport. * @param contextPath context path of the server. * @param baseUrl base url of the sse endpoint. * @param sseEndpoint the sse endpoint. diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java index 3a40abcf..7b39f132 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseCustomPathIntegrationTests.java @@ -122,8 +122,8 @@ public void testCustomizedEndpoints(String baseUrl, String messageEndpoint, Stri } /** - * This is a helper function for the tests which builds the SSE endpoint to pass to the client transport. - * + * This is a helper function for the tests which builds the SSE endpoint to pass to + * the client transport. * @param contextPath context path of the server. * @param baseUrl base url of the sse endpoint. * @param sseEndpoint the sse endpoint. diff --git a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java index 4477d625..c45e3150 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java +++ b/mcp/src/main/java/io/modelcontextprotocol/util/Utils.java @@ -53,7 +53,6 @@ public static boolean isEmpty(@Nullable Map map) { return (map == null || map.isEmpty()); } - /** * Removes the trailing slash character of the given String. * @param str the String to remove the trailing slash