closeGracefully() {
*/
@Override
public void close() {
+ sseBuilderLock.lock();
try {
sseBuilder.complete();
logger.debug("Successfully completed SSE builder for session {}", sessionId);
@@ -412,6 +428,9 @@ public void close() {
catch (Exception e) {
logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage());
}
+ finally {
+ sseBuilderLock.unlock();
+ }
}
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java
new file mode 100644
index 00000000..d14a51d8
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java
@@ -0,0 +1,654 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server.transport;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.locks.ReentrantLock;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
+import org.springframework.web.servlet.function.RouterFunction;
+import org.springframework.web.servlet.function.RouterFunctions;
+import org.springframework.web.servlet.function.ServerRequest;
+import org.springframework.web.servlet.function.ServerResponse;
+import org.springframework.web.servlet.function.ServerResponse.SseBuilder;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.server.DefaultMcpTransportContext;
+import io.modelcontextprotocol.server.McpTransportContext;
+import io.modelcontextprotocol.server.McpTransportContextExtractor;
+import io.modelcontextprotocol.spec.HttpHeaders;
+import io.modelcontextprotocol.spec.McpError;
+import io.modelcontextprotocol.spec.McpSchema;
+import io.modelcontextprotocol.spec.McpStreamableServerSession;
+import io.modelcontextprotocol.spec.McpStreamableServerTransport;
+import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
+import io.modelcontextprotocol.util.Assert;
+import reactor.core.publisher.Mono;
+
+/**
+ * Server-side implementation of the Model Context Protocol (MCP) streamable transport
+ * layer using HTTP with Server-Sent Events (SSE) through Spring WebMVC. This
+ * implementation provides a bridge between synchronous WebMVC operations and reactive
+ * programming patterns to maintain compatibility with the reactive transport interface.
+ *
+ *
+ * This is the non-reactive version of
+ * {@link io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider}
+ *
+ * @author Christian Tzolov
+ * @author Dariusz Jędrzejczyk
+ * @see McpStreamableServerTransportProvider
+ * @see RouterFunction
+ */
+public class WebMvcStreamableServerTransportProvider implements McpStreamableServerTransportProvider {
+
+ private static final Logger logger = LoggerFactory.getLogger(WebMvcStreamableServerTransportProvider.class);
+
+ /**
+ * Event type for JSON-RPC messages sent through the SSE connection.
+ */
+ public static final String MESSAGE_EVENT_TYPE = "message";
+
+ /**
+ * Event type for sending the message endpoint URI to clients.
+ */
+ public static final String ENDPOINT_EVENT_TYPE = "endpoint";
+
+ /**
+ * Default base URL for the message endpoint.
+ */
+ public static final String DEFAULT_BASE_URL = "";
+
+ /**
+ * The endpoint URI where clients should send their JSON-RPC messages. Defaults to
+ * "/mcp".
+ */
+ private final String mcpEndpoint;
+
+ /**
+ * Flag indicating whether DELETE requests are disallowed on the endpoint.
+ */
+ private final boolean disallowDelete;
+
+ private final ObjectMapper objectMapper;
+
+ private final RouterFunction routerFunction;
+
+ private McpStreamableServerSession.Factory sessionFactory;
+
+ /**
+ * Map of active client sessions, keyed by mcp-session-id.
+ */
+ private final ConcurrentHashMap sessions = new ConcurrentHashMap<>();
+
+ private McpTransportContextExtractor contextExtractor;
+
+ // private Function contextExtractor = req -> new
+ // DefaultMcpTransportContext();
+
+ /**
+ * Flag indicating if the transport is shutting down.
+ */
+ private volatile boolean isClosing = false;
+
+ /**
+ * Constructs a new WebMvcStreamableServerTransportProvider instance.
+ * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
+ * of messages.
+ * @param baseUrl The base URL for the message endpoint, used to construct the full
+ * endpoint URL for clients.
+ * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC
+ * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests.
+ * @param disallowDelete Whether to disallow DELETE requests on the endpoint.
+ * @throws IllegalArgumentException if any parameter is null
+ */
+ private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint,
+ boolean disallowDelete, McpTransportContextExtractor contextExtractor) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
+ Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null");
+
+ this.objectMapper = objectMapper;
+ this.mcpEndpoint = mcpEndpoint;
+ this.disallowDelete = disallowDelete;
+ this.contextExtractor = contextExtractor;
+ this.routerFunction = RouterFunctions.route()
+ .GET(this.mcpEndpoint, this::handleGet)
+ .POST(this.mcpEndpoint, this::handlePost)
+ .DELETE(this.mcpEndpoint, this::handleDelete)
+ .build();
+ }
+
+ @Override
+ public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) {
+ this.sessionFactory = sessionFactory;
+ }
+
+ /**
+ * Broadcasts a notification to all connected clients through their SSE connections.
+ * If any errors occur during sending to a particular client, they are logged but
+ * don't prevent sending to other clients.
+ * @param method The method name for the notification
+ * @param params The parameters for the notification
+ * @return A Mono that completes when the broadcast attempt is finished
+ */
+ @Override
+ public Mono notifyClients(String method, Object params) {
+ if (this.sessions.isEmpty()) {
+ logger.debug("No active sessions to broadcast message to");
+ return Mono.empty();
+ }
+
+ logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
+
+ return Mono.fromRunnable(() -> {
+ this.sessions.values().parallelStream().forEach(session -> {
+ try {
+ session.sendNotification(method, params).block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage());
+ }
+ });
+ });
+ }
+
+ /**
+ * Initiates a graceful shutdown of the transport.
+ * @return A Mono that completes when all cleanup operations are finished
+ */
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ this.isClosing = true;
+ logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
+
+ this.sessions.values().parallelStream().forEach(session -> {
+ try {
+ session.closeGracefully().block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to close session {}: {}", session.getId(), e.getMessage());
+ }
+ });
+
+ this.sessions.clear();
+ logger.debug("Graceful shutdown completed");
+ });
+ }
+
+ /**
+ * Returns the RouterFunction that defines the HTTP endpoints for this transport. The
+ * router function handles three endpoints:
+ *
+ * - GET [mcpEndpoint] - For establishing SSE connections and message replay
+ * - POST [mcpEndpoint] - For receiving JSON-RPC messages from clients
+ * - DELETE [mcpEndpoint] - For session deletion (if enabled)
+ *
+ * @return The configured RouterFunction for handling HTTP requests
+ */
+ public RouterFunction getRouterFunction() {
+ return this.routerFunction;
+ }
+
+ /**
+ * Setup the listening SSE connections and message replay.
+ * @param request The incoming server request
+ * @return A ServerResponse configured for SSE communication, or an error response
+ */
+ private ServerResponse handleGet(ServerRequest request) {
+ if (this.isClosing) {
+ return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
+ }
+
+ List acceptHeaders = request.headers().asHttpHeaders().getAccept();
+ if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) {
+ return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM");
+ }
+
+ McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
+
+ if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
+ return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");
+ }
+
+ String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID);
+ McpStreamableServerSession session = this.sessions.get(sessionId);
+
+ if (session == null) {
+ return ServerResponse.notFound().build();
+ }
+
+ logger.debug("Handling GET request for session: {}", sessionId);
+
+ try {
+ return ServerResponse.sse(sseBuilder -> {
+ sseBuilder.onTimeout(() -> {
+ logger.debug("SSE connection timed out for session: {}", sessionId);
+ });
+
+ WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport(
+ sessionId, sseBuilder);
+
+ // Check if this is a replay request
+ if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) {
+ String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID);
+
+ try {
+ session.replay(lastId)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .toIterable()
+ .forEach(message -> {
+ try {
+ sessionTransport.sendMessage(message)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to replay message: {}", e.getMessage());
+ sseBuilder.error(e);
+ }
+ });
+ }
+ catch (Exception e) {
+ logger.error("Failed to replay messages: {}", e.getMessage());
+ sseBuilder.error(e);
+ }
+ }
+ else {
+ // Establish new listening stream
+ McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session
+ .listeningStream(sessionTransport);
+
+ sseBuilder.onComplete(() -> {
+ logger.debug("SSE connection completed for session: {}", sessionId);
+ listeningStream.close();
+ });
+ }
+ }, Duration.ZERO);
+ }
+ catch (Exception e) {
+ logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage());
+ return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build();
+ }
+ }
+
+ /**
+ * Handles POST requests for incoming JSON-RPC messages from clients.
+ * @param request The incoming server request containing the JSON-RPC message
+ * @return A ServerResponse indicating success or appropriate error status
+ */
+ private ServerResponse handlePost(ServerRequest request) {
+ if (this.isClosing) {
+ return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
+ }
+
+ List acceptHeaders = request.headers().asHttpHeaders().getAccept();
+ if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)
+ || !acceptHeaders.contains(MediaType.APPLICATION_JSON)) {
+ return ServerResponse.badRequest()
+ .body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON"));
+ }
+
+ McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
+
+ try {
+ String body = request.body(String.class);
+ McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
+
+ // Handle initialization request
+ if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest
+ && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) {
+ McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(),
+ new TypeReference() {
+ });
+ McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
+ .startSession(initializeRequest);
+ this.sessions.put(init.session().getId(), init.session());
+
+ try {
+ McpSchema.InitializeResult initResult = init.initResult().block();
+
+ return ServerResponse.ok()
+ .contentType(MediaType.APPLICATION_JSON)
+ .header(HttpHeaders.MCP_SESSION_ID, init.session().getId())
+ .body(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult,
+ null));
+ }
+ catch (Exception e) {
+ logger.error("Failed to initialize session: {}", e.getMessage());
+ return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage()));
+ }
+ }
+
+ // Handle other messages that require a session
+ if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
+ return ServerResponse.badRequest().body(new McpError("Session ID missing"));
+ }
+
+ String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID);
+ McpStreamableServerSession session = this.sessions.get(sessionId);
+
+ if (session == null) {
+ return ServerResponse.status(HttpStatus.NOT_FOUND)
+ .body(new McpError("Session not found: " + sessionId));
+ }
+
+ if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) {
+ session.accept(jsonrpcResponse)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ return ServerResponse.accepted().build();
+ }
+ else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) {
+ session.accept(jsonrpcNotification)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ return ServerResponse.accepted().build();
+ }
+ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) {
+ // For streaming responses, we need to return SSE
+ return ServerResponse.sse(sseBuilder -> {
+ sseBuilder.onComplete(() -> {
+ logger.debug("Request response stream completed for session: {}", sessionId);
+ });
+ sseBuilder.onTimeout(() -> {
+ logger.debug("Request response stream timed out for session: {}", sessionId);
+ });
+
+ WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport(
+ sessionId, sseBuilder);
+
+ try {
+ session.responseStream(jsonrpcRequest, sessionTransport)
+ .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext))
+ .block();
+ }
+ catch (Exception e) {
+ logger.error("Failed to handle request stream: {}", e.getMessage());
+ sseBuilder.error(e);
+ }
+ }, Duration.ZERO);
+ }
+ else {
+ return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR)
+ .body(new McpError("Unknown message type"));
+ }
+ }
+ catch (IllegalArgumentException | IOException e) {
+ logger.error("Failed to deserialize message: {}", e.getMessage());
+ return ServerResponse.badRequest().body(new McpError("Invalid message format"));
+ }
+ catch (Exception e) {
+ logger.error("Error handling message: {}", e.getMessage());
+ return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage()));
+ }
+ }
+
+ /**
+ * Handles DELETE requests for session deletion.
+ * @param request The incoming server request
+ * @return A ServerResponse indicating success or appropriate error status
+ */
+ private ServerResponse handleDelete(ServerRequest request) {
+ if (this.isClosing) {
+ return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down");
+ }
+
+ if (this.disallowDelete) {
+ return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build();
+ }
+
+ McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext());
+
+ if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) {
+ return ServerResponse.badRequest().body("Session ID required in mcp-session-id header");
+ }
+
+ String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID);
+ McpStreamableServerSession session = this.sessions.get(sessionId);
+
+ if (session == null) {
+ return ServerResponse.notFound().build();
+ }
+
+ try {
+ session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();
+ this.sessions.remove(sessionId);
+ return ServerResponse.ok().build();
+ }
+ catch (Exception e) {
+ logger.error("Failed to delete session {}: {}", sessionId, e.getMessage());
+ return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage()));
+ }
+ }
+
+ /**
+ * Implementation of McpStreamableServerTransport for WebMVC SSE sessions. This class
+ * handles the transport-level communication for a specific client session.
+ *
+ *
+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the
+ * underlying SSE builder to prevent race conditions when multiple threads attempt to
+ * send messages concurrently.
+ */
+ private class WebMvcStreamableMcpSessionTransport implements McpStreamableServerTransport {
+
+ private final String sessionId;
+
+ private final SseBuilder sseBuilder;
+
+ private final ReentrantLock lock = new ReentrantLock();
+
+ private volatile boolean closed = false;
+
+ /**
+ * Creates a new session transport with the specified ID and SSE builder.
+ * @param sessionId The unique identifier for this session
+ * @param sseBuilder The SSE builder for sending server events to the client
+ */
+ WebMvcStreamableMcpSessionTransport(String sessionId, SseBuilder sseBuilder) {
+ this.sessionId = sessionId;
+ this.sseBuilder = sseBuilder;
+ logger.debug("Streamable session transport {} initialized with SSE builder", sessionId);
+ }
+
+ /**
+ * Sends a JSON-RPC message to the client through the SSE connection.
+ * @param message The JSON-RPC message to send
+ * @return A Mono that completes when the message has been sent
+ */
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message) {
+ return sendMessage(message, null);
+ }
+
+ /**
+ * Sends a JSON-RPC message to the client through the SSE connection with a
+ * specific message ID.
+ * @param message The JSON-RPC message to send
+ * @param messageId The message ID for SSE event identification
+ * @return A Mono that completes when the message has been sent
+ */
+ @Override
+ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) {
+ return Mono.fromRunnable(() -> {
+ if (this.closed) {
+ logger.debug("Attempted to send message to closed session: {}", this.sessionId);
+ return;
+ }
+
+ this.lock.lock();
+ try {
+ if (this.closed) {
+ logger.debug("Session {} was closed during message send attempt", this.sessionId);
+ return;
+ }
+
+ String jsonText = objectMapper.writeValueAsString(message);
+ this.sseBuilder.id(messageId != null ? messageId : this.sessionId)
+ .event(MESSAGE_EVENT_TYPE)
+ .data(jsonText);
+ logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId);
+ }
+ catch (Exception e) {
+ logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
+ try {
+ this.sseBuilder.error(e);
+ }
+ catch (Exception errorException) {
+ logger.error("Failed to send error to SSE builder for session {}: {}", this.sessionId,
+ errorException.getMessage());
+ }
+ }
+ finally {
+ this.lock.unlock();
+ }
+ });
+ }
+
+ /**
+ * Converts data from one type to another using the configured ObjectMapper.
+ * @param data The source data object to convert
+ * @param typeRef The target type reference
+ * @return The converted object of type T
+ * @param The target type
+ */
+ @Override
+ public T unmarshalFrom(Object data, TypeReference typeRef) {
+ return objectMapper.convertValue(data, typeRef);
+ }
+
+ /**
+ * Initiates a graceful shutdown of the transport.
+ * @return A Mono that completes when the shutdown is complete
+ */
+ @Override
+ public Mono closeGracefully() {
+ return Mono.fromRunnable(() -> {
+ WebMvcStreamableMcpSessionTransport.this.close();
+ });
+ }
+
+ /**
+ * Closes the transport immediately.
+ */
+ @Override
+ public void close() {
+ this.lock.lock();
+ try {
+ if (this.closed) {
+ logger.debug("Session transport {} already closed", this.sessionId);
+ return;
+ }
+
+ this.closed = true;
+
+ this.sseBuilder.complete();
+ logger.debug("Successfully completed SSE builder for session {}", sessionId);
+ }
+ catch (Exception e) {
+ logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage());
+ }
+ finally {
+ this.lock.unlock();
+ }
+ }
+
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /**
+ * Builder for creating instances of {@link WebMvcStreamableServerTransportProvider}.
+ */
+ public static class Builder {
+
+ private ObjectMapper objectMapper;
+
+ private String mcpEndpoint = "/mcp";
+
+ private boolean disallowDelete = false;
+
+ private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context;
+
+ /**
+ * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP
+ * messages.
+ * @param objectMapper The ObjectMapper instance. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if objectMapper is null
+ */
+ public Builder objectMapper(ObjectMapper objectMapper) {
+ Assert.notNull(objectMapper, "ObjectMapper must not be null");
+ this.objectMapper = objectMapper;
+ return this;
+ }
+
+ /**
+ * Sets the endpoint URI where clients should send their JSON-RPC messages.
+ * @param mcpEndpoint The MCP endpoint URI. Must not be null.
+ * @return this builder instance
+ * @throws IllegalArgumentException if mcpEndpoint is null
+ */
+ public Builder mcpEndpoint(String mcpEndpoint) {
+ Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
+ this.mcpEndpoint = mcpEndpoint;
+ return this;
+ }
+
+ /**
+ * Sets whether to disallow DELETE requests on the endpoint.
+ * @param disallowDelete true to disallow DELETE requests, false otherwise
+ * @return this builder instance
+ */
+ public Builder disallowDelete(boolean disallowDelete) {
+ this.disallowDelete = disallowDelete;
+ return this;
+ }
+
+ /**
+ * Sets the context extractor that allows providing the MCP feature
+ * implementations to inspect HTTP transport level metadata that was present at
+ * HTTP request processing time. This allows to extract custom headers and other
+ * useful data for use during execution later on in the process.
+ * @param contextExtractor The contextExtractor to fill in a
+ * {@link McpTransportContext}.
+ * @return this builder instance
+ * @throws IllegalArgumentException if contextExtractor is null
+ */
+ public Builder contextExtractor(McpTransportContextExtractor contextExtractor) {
+ Assert.notNull(contextExtractor, "contextExtractor must not be null");
+ this.contextExtractor = contextExtractor;
+ return this;
+ }
+
+ /**
+ * Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with
+ * the configured settings.
+ * @return A new WebMvcStreamableServerTransportProvider instance
+ * @throws IllegalStateException if required parameters are not set
+ */
+ public WebMvcStreamableServerTransportProvider build() {
+ Assert.notNull(this.objectMapper, "ObjectMapper must be set");
+ Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
+
+ return new WebMvcStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, this.disallowDelete,
+ this.contextExtractor);
+ }
+
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java
new file mode 100644
index 00000000..66349216
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server;
+
+import org.apache.catalina.Context;
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.startup.Tomcat;
+import org.junit.jupiter.api.Timeout;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
+import org.springframework.web.servlet.DispatcherServlet;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
+import org.springframework.web.servlet.function.RouterFunction;
+import org.springframework.web.servlet.function.ServerResponse;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
+import reactor.netty.DisposableServer;
+
+/**
+ * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}.
+ *
+ * @author Christian Tzolov
+ */
+@Timeout(15) // Giving extra time beyond the client timeout
+class WebMcpStreamableAsyncServerTransportTests extends AbstractMcpAsyncServerTests {
+
+ private static final int PORT = TestUtil.findAvailablePort();
+
+ private static final String MCP_ENDPOINT = "/mcp";
+
+ private DisposableServer httpServer;
+
+ private AnnotationConfigWebApplicationContext appContext;
+
+ private Tomcat tomcat;
+
+ private McpStreamableServerTransportProvider transportProvider;
+
+ @Configuration
+ @EnableWebMvc
+ static class TestConfig {
+
+ @Bean
+ public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() {
+ return WebMvcStreamableServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .mcpEndpoint(MCP_ENDPOINT)
+ .build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(
+ WebMvcStreamableServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ }
+
+ private McpStreamableServerTransportProvider createMcpTransportProvider() {
+ // Set up Tomcat first
+ tomcat = new Tomcat();
+ tomcat.setPort(PORT);
+
+ // Set Tomcat base directory to java.io.tmpdir to avoid permission issues
+ String baseDir = System.getProperty("java.io.tmpdir");
+ tomcat.setBaseDir(baseDir);
+
+ // Use the same directory for document base
+ Context context = tomcat.addContext("", baseDir);
+
+ // Create and configure Spring WebMvc context
+ appContext = new AnnotationConfigWebApplicationContext();
+ appContext.register(TestConfig.class);
+ appContext.setServletContext(context.getServletContext());
+ appContext.refresh();
+
+ // Get the transport from Spring context
+ transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class);
+
+ // Create DispatcherServlet with our Spring context
+ DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
+
+ // Add servlet to Tomcat and get the wrapper
+ var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
+ wrapper.setLoadOnStartup(1);
+ context.addServletMappingDecoded("/*", "dispatcherServlet");
+
+ try {
+ tomcat.start();
+ tomcat.getConnector(); // Create and start the connector
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+
+ return transportProvider;
+ }
+
+ @Override
+ protected McpServer.AsyncSpecification> prepareAsyncServerBuilder() {
+ return McpServer.async(createMcpTransportProvider());
+ }
+
+ @Override
+ protected void onStart() {
+ }
+
+ @Override
+ protected void onClose() {
+ if (httpServer != null) {
+ httpServer.disposeNow();
+ }
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java
new file mode 100644
index 00000000..cab487f1
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2024-2024 the original author or authors.
+ */
+
+package io.modelcontextprotocol.server;
+
+import org.apache.catalina.Context;
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.startup.Tomcat;
+import org.junit.jupiter.api.Timeout;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
+import org.springframework.web.servlet.DispatcherServlet;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
+import org.springframework.web.servlet.function.RouterFunction;
+import org.springframework.web.servlet.function.ServerResponse;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
+import reactor.netty.DisposableServer;
+
+/**
+ * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}.
+ *
+ * @author Christian Tzolov
+ */
+@Timeout(15) // Giving extra time beyond the client timeout
+class WebMcpStreamableSyncServerTransportTests extends AbstractMcpSyncServerTests {
+
+ private static final int PORT = TestUtil.findAvailablePort();
+
+ private static final String MCP_ENDPOINT = "/mcp";
+
+ private DisposableServer httpServer;
+
+ private AnnotationConfigWebApplicationContext appContext;
+
+ private Tomcat tomcat;
+
+ private McpStreamableServerTransportProvider transportProvider;
+
+ @Configuration
+ @EnableWebMvc
+ static class TestConfig {
+
+ @Bean
+ public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() {
+ return WebMvcStreamableServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .mcpEndpoint(MCP_ENDPOINT)
+ .build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(
+ WebMvcStreamableServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ }
+
+ private McpStreamableServerTransportProvider createMcpTransportProvider() {
+ // Set up Tomcat first
+ tomcat = new Tomcat();
+ tomcat.setPort(PORT);
+
+ // Set Tomcat base directory to java.io.tmpdir to avoid permission issues
+ String baseDir = System.getProperty("java.io.tmpdir");
+ tomcat.setBaseDir(baseDir);
+
+ // Use the same directory for document base
+ Context context = tomcat.addContext("", baseDir);
+
+ // Create and configure Spring WebMvc context
+ appContext = new AnnotationConfigWebApplicationContext();
+ appContext.register(TestConfig.class);
+ appContext.setServletContext(context.getServletContext());
+ appContext.refresh();
+
+ // Get the transport from Spring context
+ transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class);
+
+ // Create DispatcherServlet with our Spring context
+ DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext);
+
+ // Add servlet to Tomcat and get the wrapper
+ var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet);
+ wrapper.setLoadOnStartup(1);
+ context.addServletMappingDecoded("/*", "dispatcherServlet");
+
+ try {
+ tomcat.start();
+ tomcat.getConnector(); // Create and start the connector
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+
+ return transportProvider;
+ }
+
+ @Override
+ protected McpServer.SyncSpecification> prepareSyncServerBuilder() {
+ return McpServer.sync(createMcpTransportProvider());
+ }
+
+ @Override
+ protected void onStart() {
+ }
+
+ @Override
+ protected void onClose() {
+ if (httpServer != null) {
+ httpServer.disposeNow();
+ }
+ }
+
+}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
index 9f2d6abf..45f6b94f 100644
--- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java
@@ -4,67 +4,32 @@
package io.modelcontextprotocol.server;
import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
-import static org.awaitility.Awaitility.await;
-import static org.mockito.Mockito.mock;
import java.time.Duration;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.Function;
import org.apache.catalina.LifecycleException;
import org.apache.catalina.LifecycleState;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
-import org.springframework.web.client.RestClient;
+import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.web.servlet.config.annotation.EnableWebMvc;
import org.springframework.web.servlet.function.RouterFunction;
import org.springframework.web.servlet.function.ServerResponse;
import com.fasterxml.jackson.databind.ObjectMapper;
+import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests;
import io.modelcontextprotocol.client.McpClient;
import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport;
+import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport;
+import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
+import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification;
import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider;
-import io.modelcontextprotocol.spec.McpError;
-import io.modelcontextprotocol.spec.McpSchema;
-import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
-import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities;
-import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
-import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
-import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
-import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
-import io.modelcontextprotocol.spec.McpSchema.Role;
-import io.modelcontextprotocol.spec.McpSchema.Root;
-import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
-import io.modelcontextprotocol.spec.McpSchema.Tool;
-import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
-import reactor.test.StepVerifier;
-import org.springframework.context.annotation.Bean;
-import org.springframework.context.annotation.Configuration;
-import org.springframework.web.client.RestClient;
-import org.springframework.web.servlet.config.annotation.EnableWebMvc;
-import org.springframework.web.servlet.function.RouterFunction;
-import org.springframework.web.servlet.function.ServerResponse;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
-import static org.awaitility.Awaitility.await;
-import static org.mockito.Mockito.mock;
-import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson;
-import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json;
-
-import net.javacrumbs.jsonunit.core.Option;
-
-class WebMvcSseIntegrationTests {
+class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests {
private static final int PORT = TestUtil.findAvailablePort();
@@ -72,7 +37,17 @@ class WebMvcSseIntegrationTests {
private WebMvcSseServerTransportProvider mcpServerTransportProvider;
- McpClient.SyncSpec clientBuilder;
+ @Override
+ protected void prepareClients(int port, String mcpEndpoint) {
+
+ clientBuilders.put("httpclient",
+ McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port).build())
+ .initializationTimeout(Duration.ofHours(10))
+ .requestTimeout(Duration.ofHours(10)));
+
+ clientBuilders.put("webflux", McpClient
+ .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + port)).build()));
+ }
@Configuration
@EnableWebMvc
@@ -105,7 +80,7 @@ public void before() {
throw new RuntimeException("Failed to start Tomcat", e);
}
- clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build());
+ prepareClients(PORT, MESSAGE_ENDPOINT);
// Get the transport from Spring context
mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class);
@@ -133,1102 +108,14 @@ public void after() {
}
}
- // ---------------------------------------
- // Sampling Tests
- // ---------------------------------------
- @Test
- void testCreateMessageWithoutSamplingCapabilities() {
-
- McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
- exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block();
- return Mono.just(mock(CallToolResult.class));
- })
- .build();
-
- //@formatter:off
- var server = McpServer.async(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .tools(tool)
- .build();
-
- try (
- // Create client without sampling capabilities
- var client = clientBuilder
- .clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
- .build()) {//@formatter:on
-
- assertThat(client.initialize()).isNotNull();
-
- try {
- client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
- }
- catch (McpError e) {
- assertThat(e).isInstanceOf(McpError.class)
- .hasMessage("Client must be configured with sampling capabilities");
- }
- }
- server.close();
- }
-
- @Test
- void testCreateMessageSuccess() {
-
- Function samplingHandler = request -> {
- assertThat(request.messages()).hasSize(1);
- assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
-
- return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
- CreateMessageResult.StopReason.STOP_SEQUENCE);
- };
-
- CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
- null);
-
- McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
-
- var createMessageRequest = McpSchema.CreateMessageRequest.builder()
- .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
- new McpSchema.TextContent("Test message"))))
- .modelPreferences(ModelPreferences.builder()
- .hints(List.of())
- .costPriority(1.0)
- .speedPriority(1.0)
- .intelligencePriority(1.0)
- .build())
- .build();
-
- StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.role()).isEqualTo(Role.USER);
- assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
- assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
- assertThat(result.model()).isEqualTo("MockModelName");
- assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
- }).verifyComplete();
-
- return Mono.just(callResponse);
- })
- .build();
-
- //@formatter:off
- var mcpServer = McpServer.async(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .tools(tool)
- .build();
-
- try (
- var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
- .capabilities(ClientCapabilities.builder().sampling().build())
- .sampling(samplingHandler)
- .build()) {//@formatter:on
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
-
- assertThat(response).isNotNull().isEqualTo(callResponse);
- }
- mcpServer.close();
- }
-
- @Test
- void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException {
-
- // Client
-
- Function samplingHandler = request -> {
- assertThat(request.messages()).hasSize(1);
- assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
- try {
- TimeUnit.SECONDS.sleep(2);
- }
- catch (InterruptedException e) {
- throw new RuntimeException(e);
- }
- return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
- CreateMessageResult.StopReason.STOP_SEQUENCE);
- };
-
- var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
- .capabilities(ClientCapabilities.builder().sampling().build())
- .sampling(samplingHandler)
- .build();
-
- // Server
-
- CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
- null);
-
- McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
-
- var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
- .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
- new McpSchema.TextContent("Test message"))))
- .modelPreferences(ModelPreferences.builder()
- .hints(List.of())
- .costPriority(1.0)
- .speedPriority(1.0)
- .intelligencePriority(1.0)
- .build())
- .build();
-
- StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.role()).isEqualTo(Role.USER);
- assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
- assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
- assertThat(result.model()).isEqualTo("MockModelName");
- assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
- }).verifyComplete();
-
- return Mono.just(callResponse);
- })
- .build();
-
- var mcpServer = McpServer.async(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .requestTimeout(Duration.ofSeconds(4))
- .tools(tool)
- .build();
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
-
- assertThat(response).isNotNull();
- assertThat(response).isEqualTo(callResponse);
-
- mcpClient.close();
- mcpServer.close();
- }
-
- @Test
- void testCreateMessageWithRequestTimeoutFail() throws InterruptedException {
-
- // Client
-
- Function samplingHandler = request -> {
- assertThat(request.messages()).hasSize(1);
- assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class);
- try {
- TimeUnit.SECONDS.sleep(2);
- }
- catch (InterruptedException e) {
- throw new RuntimeException(e);
- }
- return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName",
- CreateMessageResult.StopReason.STOP_SEQUENCE);
- };
-
- var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
- .capabilities(ClientCapabilities.builder().sampling().build())
- .sampling(samplingHandler)
- .build();
-
- // Server
-
- CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
- null);
-
- McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
-
- var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
- .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
- new McpSchema.TextContent("Test message"))))
- .modelPreferences(ModelPreferences.builder()
- .hints(List.of())
- .costPriority(1.0)
- .speedPriority(1.0)
- .intelligencePriority(1.0)
- .build())
- .build();
-
- StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.role()).isEqualTo(Role.USER);
- assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class);
- assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message");
- assertThat(result.model()).isEqualTo("MockModelName");
- assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE);
- }).verifyComplete();
-
- return Mono.just(callResponse);
- })
- .build();
-
- var mcpServer = McpServer.async(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .requestTimeout(Duration.ofSeconds(1))
- .tools(tool)
- .build();
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
- mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
- }).withMessageContaining("Timeout");
-
- mcpClient.close();
- mcpServer.close();
- }
-
- // ---------------------------------------
- // Elicitation Tests
- // ---------------------------------------
- @Test
- void testCreateElicitationWithoutElicitationCapabilities() {
-
- McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
-
- exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block();
-
- return Mono.just(mock(CallToolResult.class));
- })
- .build();
-
- var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build();
-
- try (
- // Create client without elicitation capabilities
- var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) {
-
- assertThat(client.initialize()).isNotNull();
-
- try {
- client.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
- }
- catch (McpError e) {
- assertThat(e).isInstanceOf(McpError.class)
- .hasMessage("Client must be configured with elicitation capabilities");
- }
- }
- server.closeGracefully().block();
- }
-
- @Test
- void testCreateElicitationSuccess() {
-
- Function elicitationHandler = request -> {
- assertThat(request.message()).isNotEmpty();
- assertThat(request.requestedSchema()).isNotNull();
-
- return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT,
- Map.of("message", request.message()));
- };
-
- CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
- null);
-
- McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
-
- var elicitationRequest = McpSchema.ElicitRequest.builder()
- .message("Test message")
- .requestedSchema(
- Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string"))))
- .build();
-
- StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT);
- assertThat(result.content().get("message")).isEqualTo("Test message");
- }).verifyComplete();
-
- return Mono.just(callResponse);
- })
- .build();
-
- var mcpServer = McpServer.async(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .tools(tool)
- .build();
-
- try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
- .capabilities(ClientCapabilities.builder().elicitation().build())
- .elicitation(elicitationHandler)
- .build()) {
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
-
- assertThat(response).isNotNull();
- assertThat(response).isEqualTo(callResponse);
- }
- mcpServer.closeGracefully().block();
- }
-
- @Test
- void testCreateElicitationWithRequestTimeoutSuccess() {
-
- // Client
-
- Function elicitationHandler = request -> {
- assertThat(request.message()).isNotEmpty();
- assertThat(request.requestedSchema()).isNotNull();
- try {
- TimeUnit.SECONDS.sleep(2);
- }
- catch (InterruptedException e) {
- throw new RuntimeException(e);
- }
- return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT,
- Map.of("message", request.message()));
- };
-
- var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
- .capabilities(ClientCapabilities.builder().elicitation().build())
- .elicitation(elicitationHandler)
- .build();
-
- // Server
-
- CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
- null);
-
- McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
-
- var elicitationRequest = McpSchema.ElicitRequest.builder()
- .message("Test message")
- .requestedSchema(
- Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string"))))
- .build();
-
- StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT);
- assertThat(result.content().get("message")).isEqualTo("Test message");
- }).verifyComplete();
-
- return Mono.just(callResponse);
- })
- .build();
-
- var mcpServer = McpServer.async(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .requestTimeout(Duration.ofSeconds(3))
- .tools(tool)
- .build();
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
-
- assertThat(response).isNotNull();
- assertThat(response).isEqualTo(callResponse);
-
- mcpClient.closeGracefully();
- mcpServer.closeGracefully().block();
- }
-
- @Test
- void testCreateElicitationWithRequestTimeoutFail() {
-
- // Client
-
- Function elicitationHandler = request -> {
- assertThat(request.message()).isNotEmpty();
- assertThat(request.requestedSchema()).isNotNull();
- try {
- TimeUnit.SECONDS.sleep(2);
- }
- catch (InterruptedException e) {
- throw new RuntimeException(e);
- }
- return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT,
- Map.of("message", request.message()));
- };
-
- var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0"))
- .capabilities(ClientCapabilities.builder().elicitation().build())
- .elicitation(elicitationHandler)
- .build();
-
- // Server
-
- CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")),
- null);
-
- McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
-
- var elicitationRequest = McpSchema.ElicitRequest.builder()
- .message("Test message")
- .requestedSchema(
- Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string"))))
- .build();
-
- StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> {
- assertThat(result).isNotNull();
- assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT);
- assertThat(result.content().get("message")).isEqualTo("Test message");
- }).verifyComplete();
-
- return Mono.just(callResponse);
- })
- .build();
-
- var mcpServer = McpServer.async(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .requestTimeout(Duration.ofSeconds(1))
- .tools(tool)
- .build();
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- assertThatExceptionOfType(McpError.class).isThrownBy(() -> {
- mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
- }).withMessageContaining("Timeout");
-
- mcpClient.closeGracefully();
- mcpServer.closeGracefully().block();
- }
-
- // ---------------------------------------
- // Roots Tests
- // ---------------------------------------
- @Test
- void testRootsSuccess() {
- List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2"));
-
- AtomicReference> rootsRef = new AtomicReference<>();
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
- .build();
-
- try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
- .roots(roots)
- .build()) {
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- assertThat(rootsRef.get()).isNull();
-
- mcpClient.rootsListChangedNotification();
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(roots);
- });
-
- // Remove a root
- mcpClient.removeRoot(roots.get(0).uri());
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(List.of(roots.get(1)));
- });
-
- // Add a new root
- var root3 = new Root("uri3://", "root3");
- mcpClient.addRoot(root3);
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3));
- });
- }
-
- mcpServer.close();
- }
-
- @Test
- void testRootsWithoutCapability() {
-
- McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
-
- exchange.listRoots(); // try to list roots
-
- return mock(CallToolResult.class);
- })
- .build();
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> {
- }).tools(tool).build();
-
- try (
- // Create client without roots capability
- // No roots capability
- var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) {
-
- assertThat(mcpClient.initialize()).isNotNull();
-
- // Attempt to list roots should fail
- try {
- mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
- }
- catch (McpError e) {
- assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported");
- }
- }
-
- mcpServer.close();
- }
-
- @Test
- void testRootsNotificationWithEmptyRootsList() {
- AtomicReference> rootsRef = new AtomicReference<>();
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
- .build();
-
- try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
- .roots(List.of()) // Empty roots list
- .build()) {
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- mcpClient.rootsListChangedNotification();
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).isEmpty();
- });
- }
-
- mcpServer.close();
- }
-
- @Test
- void testRootsWithMultipleHandlers() {
- List roots = List.of(new Root("uri1://", "root1"));
-
- AtomicReference> rootsRef1 = new AtomicReference<>();
- AtomicReference> rootsRef2 = new AtomicReference<>();
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate))
- .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate))
- .build();
-
- try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
- .roots(roots)
- .build()) {
-
- assertThat(mcpClient.initialize()).isNotNull();
-
- mcpClient.rootsListChangedNotification();
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef1.get()).containsAll(roots);
- assertThat(rootsRef2.get()).containsAll(roots);
- });
- }
-
- mcpServer.close();
- }
-
- @Test
- void testRootsServerCloseWithActiveSubscription() {
- List roots = List.of(new Root("uri1://", "root1"));
-
- AtomicReference> rootsRef = new AtomicReference<>();
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate))
- .build();
-
- try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build())
- .roots(roots)
- .build()) {
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- mcpClient.rootsListChangedNotification();
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(roots);
- });
- }
-
- mcpServer.close();
- }
-
- // ---------------------------------------
- // Tools Tests
- // ---------------------------------------
-
- String emptyJsonSchema = """
- {
- "$schema": "http://json-schema.org/draft-07/schema#",
- "type": "object",
- "properties": {}
- }
- """;
-
- @Test
- void testToolCallSuccess() {
-
- var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
- McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
- // perform a blocking call to a remote service
- String response = RestClient.create()
- .get()
- .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")
- .retrieve()
- .body(String.class);
- assertThat(response).isNotBlank();
- return callResponse;
- })
- .build();
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .capabilities(ServerCapabilities.builder().tools(true).build())
- .tools(tool1)
- .build();
-
- try (var mcpClient = clientBuilder.build()) {
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- assertThat(mcpClient.listTools().tools()).contains(tool1.tool());
-
- CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()));
-
- assertThat(response).isNotNull().isEqualTo(callResponse);
- }
-
- mcpServer.close();
- }
-
- @Test
- void testThrowingToolCallIsCaughtBeforeTimeout() {
- McpSyncServer mcpServer = McpServer.sync(mcpServerTransportProvider)
- .capabilities(ServerCapabilities.builder().tools(true).build())
- .tools(new McpServerFeatures.SyncToolSpecification(
- new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {
- // We trigger a timeout on blocking read, raising an exception
- Mono.never().block(Duration.ofSeconds(1));
- return null;
- }))
- .build();
-
- try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- // We expect the tool call to fail immediately with the exception raised by
- // the offending tool
- // instead of getting back a timeout.
- assertThatExceptionOfType(McpError.class)
- .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())))
- .withMessageContaining("Timeout on blocking read");
- }
-
- mcpServer.close();
- }
-
- @Test
- void testToolListChangeHandlingSuccess() {
-
- var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null);
- McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema))
- .callHandler((exchange, request) -> {
- // perform a blocking call to a remote service
- String response = RestClient.create()
- .get()
- .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")
- .retrieve()
- .body(String.class);
- assertThat(response).isNotBlank();
- return callResponse;
- })
- .build();
-
- AtomicReference> rootsRef = new AtomicReference<>();
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .capabilities(ServerCapabilities.builder().tools(true).build())
- .tools(tool1)
- .build();
-
- try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> {
- // perform a blocking call to a remote service
- String response = RestClient.create()
- .get()
- .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")
- .retrieve()
- .body(String.class);
- assertThat(response).isNotBlank();
- rootsRef.set(toolsUpdate);
- }).build()) {
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- assertThat(rootsRef.get()).isNull();
-
- assertThat(mcpClient.listTools().tools()).contains(tool1.tool());
-
- mcpServer.notifyToolsListChanged();
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(List.of(tool1.tool()));
- });
-
- // Remove a tool
- mcpServer.removeTool("tool1");
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).isEmpty();
- });
-
- // Add a new tool
- McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder()
- .tool(new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema))
- .callHandler((exchange, request) -> callResponse)
- .build();
-
- mcpServer.addTool(tool2);
-
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(rootsRef.get()).containsAll(List.of(tool2.tool()));
- });
- }
-
- mcpServer.close();
- }
-
- @Test
- void testInitialize() {
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider).build();
-
- try (var mcpClient = clientBuilder.build()) {
-
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
- }
-
- mcpServer.close();
- }
-
- @Test
- void testPingSuccess() {
- // Create server with a tool that uses ping functionality
- AtomicReference executionOrder = new AtomicReference<>("");
-
- McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
- new McpSchema.Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema),
- (exchange, request) -> {
-
- executionOrder.set(executionOrder.get() + "1");
-
- // Test async ping behavior
- return exchange.ping().doOnNext(result -> {
-
- assertThat(result).isNotNull();
- // Ping should return an empty object or map
- assertThat(result).isInstanceOf(Map.class);
-
- executionOrder.set(executionOrder.get() + "2");
- assertThat(result).isNotNull();
- }).then(Mono.fromCallable(() -> {
- executionOrder.set(executionOrder.get() + "3");
- return new CallToolResult("Async ping test completed", false);
- }));
- });
-
- var mcpServer = McpServer.async(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().tools(true).build())
- .tools(tool)
- .build();
-
- try (var mcpClient = clientBuilder.build()) {
-
- // Initialize client
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- // Call the tool that tests ping async behavior
- CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of()));
- assertThat(result).isNotNull();
- assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
- assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed");
-
- // Verify execution order
- assertThat(executionOrder.get()).isEqualTo("123");
- }
-
- mcpServer.close();
- }
-
- // ---------------------------------------
- // Tool Structured Output Schema Tests
- // ---------------------------------------
-
- @Test
- void testStructuredOutputValidationSuccess() {
- // Create a tool with output schema
- Map outputSchema = Map.of(
- "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation",
- Map.of("type", "string"), "timestamp", Map.of("type", "string")),
- "required", List.of("result", "operation"));
-
- Tool calculatorTool = Tool.builder()
- .name("calculator")
- .description("Performs mathematical calculations")
- .outputSchema(outputSchema)
- .build();
-
- McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool,
- (exchange, request) -> {
- String expression = (String) request.getOrDefault("expression", "2 + 3");
- double result = evaluateExpression(expression);
- return CallToolResult.builder()
- .structuredContent(
- Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z"))
- .build();
- });
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().tools(true).build())
- .tools(tool)
- .build();
-
- try (var mcpClient = clientBuilder.build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- // Verify tool is listed with output schema
- var toolsList = mcpClient.listTools();
- assertThat(toolsList.tools()).hasSize(1);
- assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator");
- // Note: outputSchema might be null in sync server, but validation still works
-
- // Call tool with valid structured output
- CallToolResult response = mcpClient
- .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")));
-
- assertThat(response).isNotNull();
- assertThat(response.isError()).isFalse();
-
- // In WebMVC, structured content is returned properly
- if (response.structuredContent() != null) {
- assertThat(response.structuredContent()).containsEntry("result", 5.0)
- .containsEntry("operation", "2 + 3")
- .containsEntry("timestamp", "2024-01-01T10:00:00Z");
- }
- else {
- // Fallback to checking content if structured content is not available
- assertThat(response.content()).isNotEmpty();
- }
-
- assertThat(response.structuredContent()).isNotNull();
- assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER)
- .when(Option.IGNORING_EXTRA_ARRAY_ITEMS)
- .isObject()
- .isEqualTo(json("""
- {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}"""));
- }
-
- mcpServer.close();
- }
-
- @Test
- void testStructuredOutputValidationFailure() {
- // Create a tool with output schema
- Map outputSchema = Map.of("type", "object", "properties",
- Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required",
- List.of("result", "operation"));
-
- Tool calculatorTool = Tool.builder()
- .name("calculator")
- .description("Performs mathematical calculations")
- .outputSchema(outputSchema)
- .build();
-
- McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool,
- (exchange, request) -> {
- // Return invalid structured output. Result should be number, missing
- // operation
- return CallToolResult.builder()
- .addTextContent("Invalid calculation")
- .structuredContent(Map.of("result", "not-a-number", "extra", "field"))
- .build();
- });
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().tools(true).build())
- .tools(tool)
- .build();
-
- try (var mcpClient = clientBuilder.build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- // Call tool with invalid structured output
- CallToolResult response = mcpClient
- .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")));
-
- assertThat(response).isNotNull();
- assertThat(response.isError()).isTrue();
- assertThat(response.content()).hasSize(1);
- assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
-
- String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text();
- assertThat(errorMessage).contains("Validation failed");
- }
-
- mcpServer.close();
- }
-
- @Test
- void testStructuredOutputMissingStructuredContent() {
- // Create a tool with output schema
- Map outputSchema = Map.of("type", "object", "properties",
- Map.of("result", Map.of("type", "number")), "required", List.of("result"));
-
- Tool calculatorTool = Tool.builder()
- .name("calculator")
- .description("Performs mathematical calculations")
- .outputSchema(outputSchema)
- .build();
-
- McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool,
- (exchange, request) -> {
- // Return result without structured content but tool has output schema
- return CallToolResult.builder().addTextContent("Calculation completed").build();
- });
-
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().tools(true).build())
- .tools(tool)
- .build();
-
- try (var mcpClient = clientBuilder.build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- // Call tool that should return structured content but doesn't
- CallToolResult response = mcpClient
- .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3")));
-
- assertThat(response).isNotNull();
- assertThat(response.isError()).isTrue();
- assertThat(response.content()).hasSize(1);
- assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
-
- String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text();
- assertThat(errorMessage).isEqualTo(
- "Response missing structured content which is expected when calling tool with non-empty outputSchema");
- }
-
- mcpServer.close();
- }
-
- @Test
- void testStructuredOutputRuntimeToolAddition() {
- // Start server without tools
- var mcpServer = McpServer.sync(mcpServerTransportProvider)
- .serverInfo("test-server", "1.0.0")
- .capabilities(ServerCapabilities.builder().tools(true).build())
- .build();
-
- try (var mcpClient = clientBuilder.build()) {
- InitializeResult initResult = mcpClient.initialize();
- assertThat(initResult).isNotNull();
-
- // Initially no tools
- assertThat(mcpClient.listTools().tools()).isEmpty();
-
- // Add tool with output schema at runtime
- Map outputSchema = Map.of("type", "object", "properties",
- Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required",
- List.of("message", "count"));
-
- Tool dynamicTool = Tool.builder()
- .name("dynamic-tool")
- .description("Dynamically added tool")
- .outputSchema(outputSchema)
- .build();
-
- McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool,
- (exchange, request) -> {
- int count = (Integer) request.getOrDefault("count", 1);
- return CallToolResult.builder()
- .addTextContent("Dynamic tool executed " + count + " times")
- .structuredContent(Map.of("message", "Dynamic execution", "count", count))
- .build();
- });
-
- // Add tool to server
- mcpServer.addTool(toolSpec);
-
- // Wait for tool list change notification
- await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> {
- assertThat(mcpClient.listTools().tools()).hasSize(1);
- });
-
- // Verify tool was added with output schema
- var toolsList = mcpClient.listTools();
- assertThat(toolsList.tools()).hasSize(1);
- assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool");
- // Note: outputSchema might be null in sync server, but validation still works
-
- // Call dynamically added tool
- CallToolResult response = mcpClient
- .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3)));
-
- assertThat(response).isNotNull();
- assertThat(response.isError()).isFalse();
-
- assertThat(response.content()).hasSize(1);
- assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class);
- assertThat(((McpSchema.TextContent) response.content().get(0)).text())
- .isEqualTo("Dynamic tool executed 3 times");
-
- assertThat(response.structuredContent()).isNotNull();
- assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER)
- .when(Option.IGNORING_EXTRA_ARRAY_ITEMS)
- .isObject()
- .isEqualTo(json("""
- {"count":3,"message":"Dynamic execution"}"""));
- }
-
- mcpServer.close();
+ @Override
+ protected AsyncSpecification> prepareAsyncServerBuilder() {
+ return McpServer.async(mcpServerTransportProvider);
}
- private double evaluateExpression(String expression) {
- // Simple expression evaluator for testing
- return switch (expression) {
- case "2 + 3" -> 5.0;
- case "10 * 2" -> 20.0;
- case "7 + 8" -> 15.0;
- case "5 + 3" -> 8.0;
- default -> 0.0;
- };
+ @Override
+ protected SingleSessionSyncSpecification prepareSyncServerBuilder() {
+ return McpServer.sync(mcpServerTransportProvider);
}
}
diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java
new file mode 100644
index 00000000..f99b016f
--- /dev/null
+++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java
@@ -0,0 +1,165 @@
+/*
+ * Copyright 2024 - 2024 the original author or authors.
+ */
+package io.modelcontextprotocol.server;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.time.Duration;
+
+import org.apache.catalina.LifecycleException;
+import org.apache.catalina.LifecycleState;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
+import org.springframework.web.servlet.function.RouterFunction;
+import org.springframework.web.servlet.function.ServerResponse;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests;
+import io.modelcontextprotocol.client.McpClient;
+import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport;
+import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport;
+import io.modelcontextprotocol.server.McpServer.AsyncSpecification;
+import io.modelcontextprotocol.server.McpServer.SyncSpecification;
+import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider;
+import io.modelcontextprotocol.spec.McpSchema;
+import reactor.core.scheduler.Schedulers;
+
+class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests {
+
+ private static final int PORT = TestUtil.findAvailablePort();
+
+ private static final String MESSAGE_ENDPOINT = "/mcp/message";
+
+ private WebMvcStreamableServerTransportProvider mcpServerTransportProvider;
+
+ @Configuration
+ @EnableWebMvc
+ static class TestConfig {
+
+ @Bean
+ public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransportProvider() {
+ return WebMvcStreamableServerTransportProvider.builder()
+ .objectMapper(new ObjectMapper())
+ .mcpEndpoint(MESSAGE_ENDPOINT)
+ .build();
+ }
+
+ @Bean
+ public RouterFunction routerFunction(
+ WebMvcStreamableServerTransportProvider transportProvider) {
+ return transportProvider.getRouterFunction();
+ }
+
+ }
+
+ private TomcatTestUtil.TomcatServer tomcatServer;
+
+ @BeforeEach
+ public void before() {
+
+ tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class);
+
+ try {
+ tomcatServer.tomcat().start();
+ assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED);
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to start Tomcat", e);
+ }
+
+ clientBuilders
+ .put("httpclient",
+ McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT)
+ .endpoint(MESSAGE_ENDPOINT)
+ .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10)));
+
+ clientBuilders.put("webflux",
+ McpClient.sync(WebClientStreamableHttpTransport
+ .builder(WebClient.builder().baseUrl("http://localhost:" + PORT))
+ .endpoint(MESSAGE_ENDPOINT)
+ .build()));
+
+ // Get the transport from Spring context
+ this.mcpServerTransportProvider = tomcatServer.appContext()
+ .getBean(WebMvcStreamableServerTransportProvider.class);
+
+ }
+
+ @Override
+ protected AsyncSpecification> prepareAsyncServerBuilder() {
+ return McpServer.async(this.mcpServerTransportProvider);
+ }
+
+ @Override
+ protected SyncSpecification> prepareSyncServerBuilder() {
+ return McpServer.sync(this.mcpServerTransportProvider);
+ }
+
+ @AfterEach
+ public void after() {
+ reactor.netty.http.HttpResources.disposeLoopsAndConnections();
+ if (mcpServerTransportProvider != null) {
+ mcpServerTransportProvider.closeGracefully().block();
+ }
+ Schedulers.shutdownNow();
+ if (tomcatServer.appContext() != null) {
+ tomcatServer.appContext().close();
+ }
+ if (tomcatServer.tomcat() != null) {
+ try {
+ tomcatServer.tomcat().stop();
+ tomcatServer.tomcat().destroy();
+ }
+ catch (LifecycleException e) {
+ throw new RuntimeException("Failed to stop Tomcat", e);
+ }
+ }
+ }
+
+ @ParameterizedTest(name = "{0} : {displayName} ")
+ @ValueSource(strings = { "httpclient", "webflux" })
+ void simple(String clientType) {
+
+ var clientBuilder = clientBuilders.get(clientType);
+
+ var server = McpServer.async(mcpServerTransportProvider)
+ .serverInfo("test-server", "1.0.0")
+ .requestTimeout(Duration.ofSeconds(1000))
+ .build();
+
+ try (
+ // Create client without sampling capabilities
+ var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0"))
+ .requestTimeout(Duration.ofSeconds(1000))
+ .build()) {
+
+ assertThat(client.initialize()).isNotNull();
+
+ }
+ server.closeGracefully();
+ }
+
+ @Override
+ protected void prepareClients(int port, String mcpEndpoint) {
+
+ clientBuilders.put("httpclient", McpClient
+ .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build())
+ .initializationTimeout(Duration.ofHours(10))
+ .requestTimeout(Duration.ofHours(10)));
+
+ clientBuilders.put("webflux",
+ McpClient.sync(WebClientStreamableHttpTransport
+ .builder(WebClient.builder().baseUrl("http://localhost:" + port))
+ .endpoint(mcpEndpoint)
+ .build()));
+ }
+
+}
diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml
index f24d9fab..cc34e96d 100644
--- a/mcp-test/pom.xml
+++ b/mcp-test/pom.xml
@@ -91,6 +91,13 @@
${logback.version}