diff --git a/mcp-spring/mcp-spring-webmvc/pom.xml b/mcp-spring/mcp-spring-webmvc/pom.xml index 4c6d37bf..4b8c684a 100644 --- a/mcp-spring/mcp-spring-webmvc/pom.xml +++ b/mcp-spring/mcp-spring-webmvc/pom.xml @@ -28,6 +28,12 @@ 0.11.0-SNAPSHOT + + org.springframework + spring-webmvc + ${springframework.version} + + io.modelcontextprotocol.sdk mcp-test @@ -35,10 +41,11 @@ test - - org.springframework - spring-webmvc - ${springframework.version} + + io.modelcontextprotocol.sdk + mcp-spring-webflux + 0.11.0-SNAPSHOT + test diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index 114eff60..5aa89d52 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -8,6 +8,7 @@ import java.time.Duration; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; @@ -339,6 +340,12 @@ private class WebMvcMcpSessionTransport implements McpServerTransport { private final SseBuilder sseBuilder; + /** + * Lock to ensure thread-safe access to the SSE builder when sending messages. + * This prevents concurrent modifications that could lead to corrupted SSE events. + */ + private final ReentrantLock sseBuilderLock = new ReentrantLock(); + /** * Creates a new session transport with the specified ID and SSE builder. * @param sessionId The unique identifier for this session @@ -358,6 +365,7 @@ private class WebMvcMcpSessionTransport implements McpServerTransport { @Override public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.fromRunnable(() -> { + sseBuilderLock.lock(); try { String jsonText = objectMapper.writeValueAsString(message); sseBuilder.id(sessionId).event(MESSAGE_EVENT_TYPE).data(jsonText); @@ -367,6 +375,9 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message) { logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); sseBuilder.error(e); } + finally { + sseBuilderLock.unlock(); + } }); } @@ -390,6 +401,7 @@ public T unmarshalFrom(Object data, TypeReference typeRef) { public Mono closeGracefully() { return Mono.fromRunnable(() -> { logger.debug("Closing session transport: {}", sessionId); + sseBuilderLock.lock(); try { sseBuilder.complete(); logger.debug("Successfully completed SSE builder for session {}", sessionId); @@ -397,6 +409,9 @@ public Mono closeGracefully() { catch (Exception e) { logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); } + finally { + sseBuilderLock.unlock(); + } }); } @@ -405,6 +420,7 @@ public Mono closeGracefully() { */ @Override public void close() { + sseBuilderLock.lock(); try { sseBuilder.complete(); logger.debug("Successfully completed SSE builder for session {}", sessionId); @@ -412,6 +428,9 @@ public void close() { catch (Exception e) { logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); } + finally { + sseBuilderLock.unlock(); + } } } diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java new file mode 100644 index 00000000..d14a51d8 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -0,0 +1,654 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.RouterFunctions; +import org.springframework.web.servlet.function.ServerRequest; +import org.springframework.web.servlet.function.ServerResponse; +import org.springframework.web.servlet.function.ServerResponse.SseBuilder; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.DefaultMcpTransportContext; +import io.modelcontextprotocol.server.McpTransportContext; +import io.modelcontextprotocol.server.McpTransportContextExtractor; +import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerTransport; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import reactor.core.publisher.Mono; + +/** + * Server-side implementation of the Model Context Protocol (MCP) streamable transport + * layer using HTTP with Server-Sent Events (SSE) through Spring WebMVC. This + * implementation provides a bridge between synchronous WebMVC operations and reactive + * programming patterns to maintain compatibility with the reactive transport interface. + * + *

+ * This is the non-reactive version of + * {@link io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider} + * + * @author Christian Tzolov + * @author Dariusz Jędrzejczyk + * @see McpStreamableServerTransportProvider + * @see RouterFunction + */ +public class WebMvcStreamableServerTransportProvider implements McpStreamableServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(WebMvcStreamableServerTransportProvider.class); + + /** + * Event type for JSON-RPC messages sent through the SSE connection. + */ + public static final String MESSAGE_EVENT_TYPE = "message"; + + /** + * Event type for sending the message endpoint URI to clients. + */ + public static final String ENDPOINT_EVENT_TYPE = "endpoint"; + + /** + * Default base URL for the message endpoint. + */ + public static final String DEFAULT_BASE_URL = ""; + + /** + * The endpoint URI where clients should send their JSON-RPC messages. Defaults to + * "/mcp". + */ + private final String mcpEndpoint; + + /** + * Flag indicating whether DELETE requests are disallowed on the endpoint. + */ + private final boolean disallowDelete; + + private final ObjectMapper objectMapper; + + private final RouterFunction routerFunction; + + private McpStreamableServerSession.Factory sessionFactory; + + /** + * Map of active client sessions, keyed by mcp-session-id. + */ + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + private McpTransportContextExtractor contextExtractor; + + // private Function contextExtractor = req -> new + // DefaultMcpTransportContext(); + + /** + * Flag indicating if the transport is shutting down. + */ + private volatile boolean isClosing = false; + + /** + * Constructs a new WebMvcStreamableServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param mcpEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP. This endpoint will handle GET, POST, and DELETE requests. + * @param disallowDelete Whether to disallow DELETE requests on the endpoint. + * @throws IllegalArgumentException if any parameter is null + */ + private WebMvcStreamableServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + boolean disallowDelete, McpTransportContextExtractor contextExtractor) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + Assert.notNull(contextExtractor, "McpTransportContextExtractor must not be null"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.disallowDelete = disallowDelete; + this.contextExtractor = contextExtractor; + this.routerFunction = RouterFunctions.route() + .GET(this.mcpEndpoint, this::handleGet) + .POST(this.mcpEndpoint, this::handlePost) + .DELETE(this.mcpEndpoint, this::handleDelete) + .build(); + } + + @Override + public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + /** + * Broadcasts a notification to all connected clients through their SSE connections. + * If any errors occur during sending to a particular client, they are logged but + * don't prevent sending to other clients. + * @param method The method name for the notification + * @param params The parameters for the notification + * @return A Mono that completes when the broadcast attempt is finished + */ + @Override + public Mono notifyClients(String method, Object params) { + if (this.sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + + return Mono.fromRunnable(() -> { + this.sessions.values().parallelStream().forEach(session -> { + try { + session.sendNotification(method, params).block(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage()); + } + }); + }); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when all cleanup operations are finished + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + this.isClosing = true; + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + + this.sessions.values().parallelStream().forEach(session -> { + try { + session.closeGracefully().block(); + } + catch (Exception e) { + logger.error("Failed to close session {}: {}", session.getId(), e.getMessage()); + } + }); + + this.sessions.clear(); + logger.debug("Graceful shutdown completed"); + }); + } + + /** + * Returns the RouterFunction that defines the HTTP endpoints for this transport. The + * router function handles three endpoints: + *

    + *
  • GET [mcpEndpoint] - For establishing SSE connections and message replay
  • + *
  • POST [mcpEndpoint] - For receiving JSON-RPC messages from clients
  • + *
  • DELETE [mcpEndpoint] - For session deletion (if enabled)
  • + *
+ * @return The configured RouterFunction for handling HTTP requests + */ + public RouterFunction getRouterFunction() { + return this.routerFunction; + } + + /** + * Setup the listening SSE connections and message replay. + * @param request The incoming server request + * @return A ServerResponse configured for SSE communication, or an error response + */ + private ServerResponse handleGet(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM)) { + return ServerResponse.badRequest().body("Invalid Accept header. Expected TEXT_EVENT_STREAM"); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + logger.debug("Handling GET request for session: {}", sessionId); + + try { + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onTimeout(() -> { + logger.debug("SSE connection timed out for session: {}", sessionId); + }); + + WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( + sessionId, sseBuilder); + + // Check if this is a replay request + if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { + String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + + try { + session.replay(lastId) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .toIterable() + .forEach(message -> { + try { + sessionTransport.sendMessage(message) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to replay message: {}", e.getMessage()); + sseBuilder.error(e); + } + }); + } + catch (Exception e) { + logger.error("Failed to replay messages: {}", e.getMessage()); + sseBuilder.error(e); + } + } + else { + // Establish new listening stream + McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + .listeningStream(sessionTransport); + + sseBuilder.onComplete(() -> { + logger.debug("SSE connection completed for session: {}", sessionId); + listeningStream.close(); + }); + } + }, Duration.ZERO); + } + catch (Exception e) { + logger.error("Failed to handle GET request for session {}: {}", sessionId, e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).build(); + } + } + + /** + * Handles POST requests for incoming JSON-RPC messages from clients. + * @param request The incoming server request containing the JSON-RPC message + * @return A ServerResponse indicating success or appropriate error status + */ + private ServerResponse handlePost(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + List acceptHeaders = request.headers().asHttpHeaders().getAccept(); + if (!acceptHeaders.contains(MediaType.TEXT_EVENT_STREAM) + || !acceptHeaders.contains(MediaType.APPLICATION_JSON)) { + return ServerResponse.badRequest() + .body(new McpError("Invalid Accept headers. Expected TEXT_EVENT_STREAM and APPLICATION_JSON")); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + try { + String body = request.body(String.class); + McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); + + // Handle initialization request + if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest + && jsonrpcRequest.method().equals(McpSchema.METHOD_INITIALIZE)) { + McpSchema.InitializeRequest initializeRequest = objectMapper.convertValue(jsonrpcRequest.params(), + new TypeReference() { + }); + McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory + .startSession(initializeRequest); + this.sessions.put(init.session().getId(), init.session()); + + try { + McpSchema.InitializeResult initResult = init.initResult().block(); + + return ServerResponse.ok() + .contentType(MediaType.APPLICATION_JSON) + .header(HttpHeaders.MCP_SESSION_ID, init.session().getId()) + .body(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, jsonrpcRequest.id(), initResult, + null)); + } + catch (Exception e) { + logger.error("Failed to initialize session: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + // Handle other messages that require a session + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().body(new McpError("Session ID missing")); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.status(HttpStatus.NOT_FOUND) + .body(new McpError("Session not found: " + sessionId)); + } + + if (message instanceof McpSchema.JSONRPCResponse jsonrpcResponse) { + session.accept(jsonrpcResponse) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + else if (message instanceof McpSchema.JSONRPCNotification jsonrpcNotification) { + session.accept(jsonrpcNotification) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + return ServerResponse.accepted().build(); + } + else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { + // For streaming responses, we need to return SSE + return ServerResponse.sse(sseBuilder -> { + sseBuilder.onComplete(() -> { + logger.debug("Request response stream completed for session: {}", sessionId); + }); + sseBuilder.onTimeout(() -> { + logger.debug("Request response stream timed out for session: {}", sessionId); + }); + + WebMvcStreamableMcpSessionTransport sessionTransport = new WebMvcStreamableMcpSessionTransport( + sessionId, sseBuilder); + + try { + session.responseStream(jsonrpcRequest, sessionTransport) + .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) + .block(); + } + catch (Exception e) { + logger.error("Failed to handle request stream: {}", e.getMessage()); + sseBuilder.error(e); + } + }, Duration.ZERO); + } + else { + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR) + .body(new McpError("Unknown message type")); + } + } + catch (IllegalArgumentException | IOException e) { + logger.error("Failed to deserialize message: {}", e.getMessage()); + return ServerResponse.badRequest().body(new McpError("Invalid message format")); + } + catch (Exception e) { + logger.error("Error handling message: {}", e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Handles DELETE requests for session deletion. + * @param request The incoming server request + * @return A ServerResponse indicating success or appropriate error status + */ + private ServerResponse handleDelete(ServerRequest request) { + if (this.isClosing) { + return ServerResponse.status(HttpStatus.SERVICE_UNAVAILABLE).body("Server is shutting down"); + } + + if (this.disallowDelete) { + return ServerResponse.status(HttpStatus.METHOD_NOT_ALLOWED).build(); + } + + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); + + if (!request.headers().asHttpHeaders().containsKey(HttpHeaders.MCP_SESSION_ID)) { + return ServerResponse.badRequest().body("Session ID required in mcp-session-id header"); + } + + String sessionId = request.headers().asHttpHeaders().getFirst(HttpHeaders.MCP_SESSION_ID); + McpStreamableServerSession session = this.sessions.get(sessionId); + + if (session == null) { + return ServerResponse.notFound().build(); + } + + try { + session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); + this.sessions.remove(sessionId); + return ServerResponse.ok().build(); + } + catch (Exception e) { + logger.error("Failed to delete session {}: {}", sessionId, e.getMessage()); + return ServerResponse.status(HttpStatus.INTERNAL_SERVER_ERROR).body(new McpError(e.getMessage())); + } + } + + /** + * Implementation of McpStreamableServerTransport for WebMVC SSE sessions. This class + * handles the transport-level communication for a specific client session. + * + *

+ * This class is thread-safe and uses a ReentrantLock to synchronize access to the + * underlying SSE builder to prevent race conditions when multiple threads attempt to + * send messages concurrently. + */ + private class WebMvcStreamableMcpSessionTransport implements McpStreamableServerTransport { + + private final String sessionId; + + private final SseBuilder sseBuilder; + + private final ReentrantLock lock = new ReentrantLock(); + + private volatile boolean closed = false; + + /** + * Creates a new session transport with the specified ID and SSE builder. + * @param sessionId The unique identifier for this session + * @param sseBuilder The SSE builder for sending server events to the client + */ + WebMvcStreamableMcpSessionTransport(String sessionId, SseBuilder sseBuilder) { + this.sessionId = sessionId; + this.sseBuilder = sseBuilder; + logger.debug("Streamable session transport {} initialized with SSE builder", sessionId); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection. + * @param message The JSON-RPC message to send + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message) { + return sendMessage(message, null); + } + + /** + * Sends a JSON-RPC message to the client through the SSE connection with a + * specific message ID. + * @param message The JSON-RPC message to send + * @param messageId The message ID for SSE event identification + * @return A Mono that completes when the message has been sent + */ + @Override + public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId) { + return Mono.fromRunnable(() -> { + if (this.closed) { + logger.debug("Attempted to send message to closed session: {}", this.sessionId); + return; + } + + this.lock.lock(); + try { + if (this.closed) { + logger.debug("Session {} was closed during message send attempt", this.sessionId); + return; + } + + String jsonText = objectMapper.writeValueAsString(message); + this.sseBuilder.id(messageId != null ? messageId : this.sessionId) + .event(MESSAGE_EVENT_TYPE) + .data(jsonText); + logger.debug("Message sent to session {} with ID {}", this.sessionId, messageId); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); + try { + this.sseBuilder.error(e); + } + catch (Exception errorException) { + logger.error("Failed to send error to SSE builder for session {}: {}", this.sessionId, + errorException.getMessage()); + } + } + finally { + this.lock.unlock(); + } + }); + } + + /** + * Converts data from one type to another using the configured ObjectMapper. + * @param data The source data object to convert + * @param typeRef The target type reference + * @return The converted object of type T + * @param The target type + */ + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + /** + * Initiates a graceful shutdown of the transport. + * @return A Mono that completes when the shutdown is complete + */ + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + WebMvcStreamableMcpSessionTransport.this.close(); + }); + } + + /** + * Closes the transport immediately. + */ + @Override + public void close() { + this.lock.lock(); + try { + if (this.closed) { + logger.debug("Session transport {} already closed", this.sessionId); + return; + } + + this.closed = true; + + this.sseBuilder.complete(); + logger.debug("Successfully completed SSE builder for session {}", sessionId); + } + catch (Exception e) { + logger.warn("Failed to complete SSE builder for session {}: {}", sessionId, e.getMessage()); + } + finally { + this.lock.unlock(); + } + } + + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of {@link WebMvcStreamableServerTransportProvider}. + */ + public static class Builder { + + private ObjectMapper objectMapper; + + private String mcpEndpoint = "/mcp"; + + private boolean disallowDelete = false; + + private McpTransportContextExtractor contextExtractor = (serverRequest, context) -> context; + + /** + * Sets the ObjectMapper to use for JSON serialization/deserialization of MCP + * messages. + * @param objectMapper The ObjectMapper instance. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if objectMapper is null + */ + public Builder objectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the endpoint URI where clients should send their JSON-RPC messages. + * @param mcpEndpoint The MCP endpoint URI. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if mcpEndpoint is null + */ + public Builder mcpEndpoint(String mcpEndpoint) { + Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets whether to disallow DELETE requests on the endpoint. + * @param disallowDelete true to disallow DELETE requests, false otherwise + * @return this builder instance + */ + public Builder disallowDelete(boolean disallowDelete) { + this.disallowDelete = disallowDelete; + return this; + } + + /** + * Sets the context extractor that allows providing the MCP feature + * implementations to inspect HTTP transport level metadata that was present at + * HTTP request processing time. This allows to extract custom headers and other + * useful data for use during execution later on in the process. + * @param contextExtractor The contextExtractor to fill in a + * {@link McpTransportContext}. + * @return this builder instance + * @throws IllegalArgumentException if contextExtractor is null + */ + public Builder contextExtractor(McpTransportContextExtractor contextExtractor) { + Assert.notNull(contextExtractor, "contextExtractor must not be null"); + this.contextExtractor = contextExtractor; + return this; + } + + /** + * Builds a new instance of {@link WebMvcStreamableServerTransportProvider} with + * the configured settings. + * @return A new WebMvcStreamableServerTransportProvider instance + * @throws IllegalStateException if required parameters are not set + */ + public WebMvcStreamableServerTransportProvider build() { + Assert.notNull(this.objectMapper, "ObjectMapper must be set"); + Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); + + return new WebMvcStreamableServerTransportProvider(this.objectMapper, this.mcpEndpoint, this.disallowDelete, + this.contextExtractor); + } + + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java new file mode 100644 index 00000000..66349216 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableAsyncServerTransportTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import reactor.netty.DisposableServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebMcpStreamableAsyncServerTransportTests extends AbstractMcpAsyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MCP_ENDPOINT = "/mcp"; + + private DisposableServer httpServer; + + private AnnotationConfigWebApplicationContext appContext; + + private Tomcat tomcat; + + private McpStreamableServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .mcpEndpoint(MCP_ENDPOINT) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected McpServer.AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java new file mode 100644 index 00000000..cab487f1 --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMcpStreamableSyncServerTransportTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server; + +import org.apache.catalina.Context; +import org.apache.catalina.LifecycleException; +import org.apache.catalina.startup.Tomcat; +import org.junit.jupiter.api.Timeout; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.context.support.AnnotationConfigWebApplicationContext; +import org.springframework.web.servlet.DispatcherServlet; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; +import reactor.netty.DisposableServer; + +/** + * Tests for {@link McpAsyncServer} using {@link WebFluxSseServerTransportProvider}. + * + * @author Christian Tzolov + */ +@Timeout(15) // Giving extra time beyond the client timeout +class WebMcpStreamableSyncServerTransportTests extends AbstractMcpSyncServerTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MCP_ENDPOINT = "/mcp"; + + private DisposableServer httpServer; + + private AnnotationConfigWebApplicationContext appContext; + + private Tomcat tomcat; + + private McpStreamableServerTransportProvider transportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcSseServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .mcpEndpoint(MCP_ENDPOINT) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private McpStreamableServerTransportProvider createMcpTransportProvider() { + // Set up Tomcat first + tomcat = new Tomcat(); + tomcat.setPort(PORT); + + // Set Tomcat base directory to java.io.tmpdir to avoid permission issues + String baseDir = System.getProperty("java.io.tmpdir"); + tomcat.setBaseDir(baseDir); + + // Use the same directory for document base + Context context = tomcat.addContext("", baseDir); + + // Create and configure Spring WebMvc context + appContext = new AnnotationConfigWebApplicationContext(); + appContext.register(TestConfig.class); + appContext.setServletContext(context.getServletContext()); + appContext.refresh(); + + // Get the transport from Spring context + transportProvider = appContext.getBean(McpStreamableServerTransportProvider.class); + + // Create DispatcherServlet with our Spring context + DispatcherServlet dispatcherServlet = new DispatcherServlet(appContext); + + // Add servlet to Tomcat and get the wrapper + var wrapper = Tomcat.addServlet(context, "dispatcherServlet", dispatcherServlet); + wrapper.setLoadOnStartup(1); + context.addServletMappingDecoded("/*", "dispatcherServlet"); + + try { + tomcat.start(); + tomcat.getConnector(); // Create and start the connector + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + return transportProvider; + } + + @Override + protected McpServer.SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(createMcpTransportProvider()); + } + + @Override + protected void onStart() { + } + + @Override + protected void onClose() { + if (httpServer != null) { + httpServer.disposeNow(); + } + } + +} diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java index 9f2d6abf..45f6b94f 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcSseIntegrationTests.java @@ -4,67 +4,32 @@ package io.modelcontextprotocol.server; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; import java.time.Duration; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; import org.apache.catalina.LifecycleException; import org.apache.catalina.LifecycleState; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestClient; +import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.servlet.config.annotation.EnableWebMvc; import org.springframework.web.servlet.function.RouterFunction; import org.springframework.web.servlet.function.ServerResponse; import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.WebFluxSseClientTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SingleSessionSyncSpecification; import io.modelcontextprotocol.server.transport.WebMvcSseServerTransportProvider; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpSchema.CallToolResult; -import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; -import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; -import io.modelcontextprotocol.spec.McpSchema.InitializeResult; -import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; -import io.modelcontextprotocol.spec.McpSchema.Role; -import io.modelcontextprotocol.spec.McpSchema.Root; -import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; -import io.modelcontextprotocol.spec.McpSchema.Tool; -import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; -import reactor.test.StepVerifier; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestClient; -import org.springframework.web.servlet.config.annotation.EnableWebMvc; -import org.springframework.web.servlet.function.RouterFunction; -import org.springframework.web.servlet.function.ServerResponse; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.awaitility.Awaitility.await; -import static org.mockito.Mockito.mock; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; - -import net.javacrumbs.jsonunit.core.Option; - -class WebMvcSseIntegrationTests { +class WebMvcSseIntegrationTests extends AbstractMcpClientServerIntegrationTests { private static final int PORT = TestUtil.findAvailablePort(); @@ -72,7 +37,17 @@ class WebMvcSseIntegrationTests { private WebMvcSseServerTransportProvider mcpServerTransportProvider; - McpClient.SyncSpec clientBuilder; + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", + McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + port).build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", McpClient + .sync(WebFluxSseClientTransport.builder(WebClient.builder().baseUrl("http://localhost:" + port)).build())); + } @Configuration @EnableWebMvc @@ -105,7 +80,7 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilder = McpClient.sync(HttpClientSseClientTransport.builder("http://localhost:" + PORT).build()); + prepareClients(PORT, MESSAGE_ENDPOINT); // Get the transport from Spring context mcpServerTransportProvider = tomcatServer.appContext().getBean(WebMvcSseServerTransportProvider.class); @@ -133,1102 +108,14 @@ public void after() { } } - // --------------------------------------- - // Sampling Tests - // --------------------------------------- - @Test - void testCreateMessageWithoutSamplingCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - //@formatter:off - var server = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try ( - // Create client without sampling capabilities - var client = clientBuilder - .clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) - .build()) {//@formatter:on - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with sampling capabilities"); - } - } - server.close(); - } - - @Test - void testCreateMessageSuccess() { - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var createMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - //@formatter:off - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try ( - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build()) {//@formatter:on - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull().isEqualTo(callResponse); - } - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutSuccess() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(4)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.close(); - mcpServer.close(); - } - - @Test - void testCreateMessageWithRequestTimeoutFail() throws InterruptedException { - - // Client - - Function samplingHandler = request -> { - assertThat(request.messages()).hasSize(1); - assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", - CreateMessageResult.StopReason.STOP_SEQUENCE); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().sampling().build()) - .sampling(samplingHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var craeteMessageRequest = McpSchema.CreateMessageRequest.builder() - .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, - new McpSchema.TextContent("Test message")))) - .modelPreferences(ModelPreferences.builder() - .hints(List.of()) - .costPriority(1.0) - .speedPriority(1.0) - .intelligencePriority(1.0) - .build()) - .build(); - - StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.role()).isEqualTo(Role.USER); - assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); - assertThat(result.model()).isEqualTo("MockModelName"); - assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.close(); - mcpServer.close(); - } - - // --------------------------------------- - // Elicitation Tests - // --------------------------------------- - @Test - void testCreateElicitationWithoutElicitationCapabilities() { - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); - - return Mono.just(mock(CallToolResult.class)); - }) - .build(); - - var server = McpServer.async(mcpServerTransportProvider).serverInfo("test-server", "1.0.0").tools(tool).build(); - - try ( - // Create client without elicitation capabilities - var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { - - assertThat(client.initialize()).isNotNull(); - - try { - client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class) - .hasMessage("Client must be configured with elicitation capabilities"); - } - } - server.closeGracefully().block(); - } - - @Test - void testCreateElicitationSuccess() { - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - } - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutSuccess() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(3)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull(); - assertThat(response).isEqualTo(callResponse); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - @Test - void testCreateElicitationWithRequestTimeoutFail() { - - // Client - - Function elicitationHandler = request -> { - assertThat(request.message()).isNotEmpty(); - assertThat(request.requestedSchema()).isNotNull(); - try { - TimeUnit.SECONDS.sleep(2); - } - catch (InterruptedException e) { - throw new RuntimeException(e); - } - return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, - Map.of("message", request.message())); - }; - - var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) - .capabilities(ClientCapabilities.builder().elicitation().build()) - .elicitation(elicitationHandler) - .build(); - - // Server - - CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), - null); - - McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - var elicitationRequest = McpSchema.ElicitRequest.builder() - .message("Test message") - .requestedSchema( - Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) - .build(); - - StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { - assertThat(result).isNotNull(); - assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); - assertThat(result.content().get("message")).isEqualTo("Test message"); - }).verifyComplete(); - - return Mono.just(callResponse); - }) - .build(); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .requestTimeout(Duration.ofSeconds(1)) - .tools(tool) - .build(); - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThatExceptionOfType(McpError.class).isThrownBy(() -> { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - }).withMessageContaining("Timeout"); - - mcpClient.closeGracefully(); - mcpServer.closeGracefully().block(); - } - - // --------------------------------------- - // Roots Tests - // --------------------------------------- - @Test - void testRootsSuccess() { - List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - - // Remove a root - mcpClient.removeRoot(roots.get(0).uri()); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); - }); - - // Add a new root - var root3 = new Root("uri3://", "root3"); - mcpClient.addRoot(root3); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithoutCapability() { - - McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - - exchange.listRoots(); // try to list roots - - return mock(CallToolResult.class); - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider).rootsChangeHandler((exchange, rootsUpdate) -> { - }).tools(tool).build(); - - try ( - // Create client without roots capability - // No roots capability - var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - // Attempt to list roots should fail - try { - mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - } - catch (McpError e) { - assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); - } - } - - mcpServer.close(); - } - - @Test - void testRootsNotificationWithEmptyRootsList() { - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(List.of()) // Empty roots list - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsWithMultipleHandlers() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef1 = new AtomicReference<>(); - AtomicReference> rootsRef2 = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - assertThat(mcpClient.initialize()).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef1.get()).containsAll(roots); - assertThat(rootsRef2.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - @Test - void testRootsServerCloseWithActiveSubscription() { - List roots = List.of(new Root("uri1://", "root1")); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) - .build(); - - try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) - .roots(roots) - .build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - mcpClient.rootsListChangedNotification(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(roots); - }); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tools Tests - // --------------------------------------- - - String emptyJsonSchema = """ - { - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": {} - } - """; - - @Test - void testToolCallSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); - - assertThat(response).isNotNull().isEqualTo(callResponse); - } - - mcpServer.close(); - } - - @Test - void testThrowingToolCallIsCaughtBeforeTimeout() { - McpSyncServer mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(new McpServerFeatures.SyncToolSpecification( - new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> { - // We trigger a timeout on blocking read, raising an exception - Mono.never().block(Duration.ofSeconds(1)); - return null; - })) - .build(); - - try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // We expect the tool call to fail immediately with the exception raised by - // the offending tool - // instead of getting back a timeout. - assertThatExceptionOfType(McpError.class) - .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) - .withMessageContaining("Timeout on blocking read"); - } - - mcpServer.close(); - } - - @Test - void testToolListChangeHandlingSuccess() { - - var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); - McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema)) - .callHandler((exchange, request) -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - return callResponse; - }) - .build(); - - AtomicReference> rootsRef = new AtomicReference<>(); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool1) - .build(); - - try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { - // perform a blocking call to a remote service - String response = RestClient.create() - .get() - .uri("https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md") - .retrieve() - .body(String.class); - assertThat(response).isNotBlank(); - rootsRef.set(toolsUpdate); - }).build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - assertThat(rootsRef.get()).isNull(); - - assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); - - mcpServer.notifyToolsListChanged(); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); - }); - - // Remove a tool - mcpServer.removeTool("tool1"); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).isEmpty(); - }); - - // Add a new tool - McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() - .tool(new McpSchema.Tool("tool2", "tool2 description", emptyJsonSchema)) - .callHandler((exchange, request) -> callResponse) - .build(); - - mcpServer.addTool(tool2); - - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); - }); - } - - mcpServer.close(); - } - - @Test - void testInitialize() { - - var mcpServer = McpServer.sync(mcpServerTransportProvider).build(); - - try (var mcpClient = clientBuilder.build()) { - - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - } - - mcpServer.close(); - } - - @Test - void testPingSuccess() { - // Create server with a tool that uses ping functionality - AtomicReference executionOrder = new AtomicReference<>(""); - - McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( - new McpSchema.Tool("ping-async-test", "Test ping async behavior", emptyJsonSchema), - (exchange, request) -> { - - executionOrder.set(executionOrder.get() + "1"); - - // Test async ping behavior - return exchange.ping().doOnNext(result -> { - - assertThat(result).isNotNull(); - // Ping should return an empty object or map - assertThat(result).isInstanceOf(Map.class); - - executionOrder.set(executionOrder.get() + "2"); - assertThat(result).isNotNull(); - }).then(Mono.fromCallable(() -> { - executionOrder.set(executionOrder.get() + "3"); - return new CallToolResult("Async ping test completed", false); - })); - }); - - var mcpServer = McpServer.async(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - - // Initialize client - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call the tool that tests ping async behavior - CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); - assertThat(result).isNotNull(); - assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); - - // Verify execution order - assertThat(executionOrder.get()).isEqualTo("123"); - } - - mcpServer.close(); - } - - // --------------------------------------- - // Tool Structured Output Schema Tests - // --------------------------------------- - - @Test - void testStructuredOutputValidationSuccess() { - // Create a tool with output schema - Map outputSchema = Map.of( - "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", - Map.of("type", "string"), "timestamp", Map.of("type", "string")), - "required", List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - String expression = (String) request.getOrDefault("expression", "2 + 3"); - double result = evaluateExpression(expression); - return CallToolResult.builder() - .structuredContent( - Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Verify tool is listed with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call tool with valid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - - // In WebMVC, structured content is returned properly - if (response.structuredContent() != null) { - assertThat(response.structuredContent()).containsEntry("result", 5.0) - .containsEntry("operation", "2 + 3") - .containsEntry("timestamp", "2024-01-01T10:00:00Z"); - } - else { - // Fallback to checking content if structured content is not available - assertThat(response.content()).isNotEmpty(); - } - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputValidationFailure() { - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", - List.of("result", "operation")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return invalid structured output. Result should be number, missing - // operation - return CallToolResult.builder() - .addTextContent("Invalid calculation") - .structuredContent(Map.of("result", "not-a-number", "extra", "field")) - .build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool with invalid structured output - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).contains("Validation failed"); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputMissingStructuredContent() { - // Create a tool with output schema - Map outputSchema = Map.of("type", "object", "properties", - Map.of("result", Map.of("type", "number")), "required", List.of("result")); - - Tool calculatorTool = Tool.builder() - .name("calculator") - .description("Performs mathematical calculations") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification tool = new McpServerFeatures.SyncToolSpecification(calculatorTool, - (exchange, request) -> { - // Return result without structured content but tool has output schema - return CallToolResult.builder().addTextContent("Calculation completed").build(); - }); - - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .tools(tool) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Call tool that should return structured content but doesn't - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isTrue(); - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - - String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); - assertThat(errorMessage).isEqualTo( - "Response missing structured content which is expected when calling tool with non-empty outputSchema"); - } - - mcpServer.close(); - } - - @Test - void testStructuredOutputRuntimeToolAddition() { - // Start server without tools - var mcpServer = McpServer.sync(mcpServerTransportProvider) - .serverInfo("test-server", "1.0.0") - .capabilities(ServerCapabilities.builder().tools(true).build()) - .build(); - - try (var mcpClient = clientBuilder.build()) { - InitializeResult initResult = mcpClient.initialize(); - assertThat(initResult).isNotNull(); - - // Initially no tools - assertThat(mcpClient.listTools().tools()).isEmpty(); - - // Add tool with output schema at runtime - Map outputSchema = Map.of("type", "object", "properties", - Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", - List.of("message", "count")); - - Tool dynamicTool = Tool.builder() - .name("dynamic-tool") - .description("Dynamically added tool") - .outputSchema(outputSchema) - .build(); - - McpServerFeatures.SyncToolSpecification toolSpec = new McpServerFeatures.SyncToolSpecification(dynamicTool, - (exchange, request) -> { - int count = (Integer) request.getOrDefault("count", 1); - return CallToolResult.builder() - .addTextContent("Dynamic tool executed " + count + " times") - .structuredContent(Map.of("message", "Dynamic execution", "count", count)) - .build(); - }); - - // Add tool to server - mcpServer.addTool(toolSpec); - - // Wait for tool list change notification - await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { - assertThat(mcpClient.listTools().tools()).hasSize(1); - }); - - // Verify tool was added with output schema - var toolsList = mcpClient.listTools(); - assertThat(toolsList.tools()).hasSize(1); - assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); - // Note: outputSchema might be null in sync server, but validation still works - - // Call dynamically added tool - CallToolResult response = mcpClient - .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); - - assertThat(response).isNotNull(); - assertThat(response.isError()).isFalse(); - - assertThat(response.content()).hasSize(1); - assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); - assertThat(((McpSchema.TextContent) response.content().get(0)).text()) - .isEqualTo("Dynamic tool executed 3 times"); - - assertThat(response.structuredContent()).isNotNull(); - assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) - .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) - .isObject() - .isEqualTo(json(""" - {"count":3,"message":"Dynamic execution"}""")); - } - - mcpServer.close(); + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(mcpServerTransportProvider); } - private double evaluateExpression(String expression) { - // Simple expression evaluator for testing - return switch (expression) { - case "2 + 3" -> 5.0; - case "10 * 2" -> 20.0; - case "7 + 8" -> 15.0; - case "5 + 3" -> 8.0; - default -> 0.0; - }; + @Override + protected SingleSessionSyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(mcpServerTransportProvider); } } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java new file mode 100644 index 00000000..f99b016f --- /dev/null +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -0,0 +1,165 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol.server; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.time.Duration; + +import org.apache.catalina.LifecycleException; +import org.apache.catalina.LifecycleState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.web.reactive.function.client.WebClient; +import org.springframework.web.servlet.config.annotation.EnableWebMvc; +import org.springframework.web.servlet.function.RouterFunction; +import org.springframework.web.servlet.function.ServerResponse; + +import com.fasterxml.jackson.databind.ObjectMapper; + +import io.modelcontextprotocol.AbstractMcpClientServerIntegrationTests; +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; +import io.modelcontextprotocol.client.transport.WebClientStreamableHttpTransport; +import io.modelcontextprotocol.server.McpServer.AsyncSpecification; +import io.modelcontextprotocol.server.McpServer.SyncSpecification; +import io.modelcontextprotocol.server.transport.WebMvcStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import reactor.core.scheduler.Schedulers; + +class WebMvcStreamableIntegrationTests extends AbstractMcpClientServerIntegrationTests { + + private static final int PORT = TestUtil.findAvailablePort(); + + private static final String MESSAGE_ENDPOINT = "/mcp/message"; + + private WebMvcStreamableServerTransportProvider mcpServerTransportProvider; + + @Configuration + @EnableWebMvc + static class TestConfig { + + @Bean + public WebMvcStreamableServerTransportProvider webMvcStreamableServerTransportProvider() { + return WebMvcStreamableServerTransportProvider.builder() + .objectMapper(new ObjectMapper()) + .mcpEndpoint(MESSAGE_ENDPOINT) + .build(); + } + + @Bean + public RouterFunction routerFunction( + WebMvcStreamableServerTransportProvider transportProvider) { + return transportProvider.getRouterFunction(); + } + + } + + private TomcatTestUtil.TomcatServer tomcatServer; + + @BeforeEach + public void before() { + + tomcatServer = TomcatTestUtil.createTomcatServer("", PORT, TestConfig.class); + + try { + tomcatServer.tomcat().start(); + assertThat(tomcatServer.tomcat().getServer().getState()).isEqualTo(LifecycleState.STARTED); + } + catch (Exception e) { + throw new RuntimeException("Failed to start Tomcat", e); + } + + clientBuilders + .put("httpclient", + McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(MESSAGE_ENDPOINT) + .build())); + + // Get the transport from Spring context + this.mcpServerTransportProvider = tomcatServer.appContext() + .getBean(WebMvcStreamableServerTransportProvider.class); + + } + + @Override + protected AsyncSpecification prepareAsyncServerBuilder() { + return McpServer.async(this.mcpServerTransportProvider); + } + + @Override + protected SyncSpecification prepareSyncServerBuilder() { + return McpServer.sync(this.mcpServerTransportProvider); + } + + @AfterEach + public void after() { + reactor.netty.http.HttpResources.disposeLoopsAndConnections(); + if (mcpServerTransportProvider != null) { + mcpServerTransportProvider.closeGracefully().block(); + } + Schedulers.shutdownNow(); + if (tomcatServer.appContext() != null) { + tomcatServer.appContext().close(); + } + if (tomcatServer.tomcat() != null) { + try { + tomcatServer.tomcat().stop(); + tomcatServer.tomcat().destroy(); + } + catch (LifecycleException e) { + throw new RuntimeException("Failed to stop Tomcat", e); + } + } + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + server.closeGracefully(); + } + + @Override + protected void prepareClients(int port, String mcpEndpoint) { + + clientBuilders.put("httpclient", McpClient + .sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + port).endpoint(mcpEndpoint).build()) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + + clientBuilders.put("webflux", + McpClient.sync(WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + port)) + .endpoint(mcpEndpoint) + .build())); + } + +} diff --git a/mcp-test/pom.xml b/mcp-test/pom.xml index f24d9fab..cc34e96d 100644 --- a/mcp-test/pom.xml +++ b/mcp-test/pom.xml @@ -91,6 +91,13 @@ ${logback.version} + + net.javacrumbs.json-unit + json-unit-assertj + ${json-unit-assertj.version} + + + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java new file mode 100644 index 00000000..d3d4fc07 --- /dev/null +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -0,0 +1,1274 @@ +/* + * Copyright 2024 - 2024 the original author or authors. + */ +package io.modelcontextprotocol; + +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.awaitility.Awaitility.await; +import static org.mockito.Mockito.mock; + +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import io.modelcontextprotocol.client.McpClient; +import io.modelcontextprotocol.server.McpServer; +import io.modelcontextprotocol.server.McpServerFeatures; +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CallToolResult; +import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import io.modelcontextprotocol.spec.McpSchema.ElicitRequest; +import io.modelcontextprotocol.spec.McpSchema.ElicitResult; +import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import io.modelcontextprotocol.spec.McpSchema.ModelPreferences; +import io.modelcontextprotocol.spec.McpSchema.Role; +import io.modelcontextprotocol.spec.McpSchema.Root; +import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; +import io.modelcontextprotocol.spec.McpSchema.Tool; +import net.javacrumbs.jsonunit.core.Option; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public abstract class AbstractMcpClientServerIntegrationTests { + + protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + + abstract protected void prepareClients(int port, String mcpEndpoint); + + abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); + + abstract protected McpServer.SyncSpecification prepareSyncServerBuilder(); + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void simple(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1000)) + .build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .requestTimeout(Duration.ofSeconds(1000)) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + } + server.closeGracefully(); + } + + // --------------------------------------- + // Sampling Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithoutSamplingCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + exchange.createMessage(mock(McpSchema.CreateMessageRequest.class)).block(); + return Mono.just(mock(CallToolResult.class)); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without sampling capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with sampling capabilities"); + } + } + server.closeGracefully(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + //@formatter:off + var mcpServer = prepareAsyncServerBuilder() + .serverInfo("test-server", "1.0.0") + .tools(tool) + .build(); + + try ( + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build()) {//@formatter:on + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(callResponse); + } + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutSuccess(String clientType) throws InterruptedException { + + // Client + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + // Server + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(4)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.close(); + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateMessageWithRequestTimeoutFail(String clientType) throws InterruptedException { + + var clientBuilder = clientBuilders.get(clientType); + + Function samplingHandler = request -> { + assertThat(request.messages()).hasSize(1); + assertThat(request.messages().get(0).content()).isInstanceOf(McpSchema.TextContent.class); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new CreateMessageResult(Role.USER, new McpSchema.TextContent("Test message"), "MockModelName", + CreateMessageResult.StopReason.STOP_SEQUENCE); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().sampling().build()) + .sampling(samplingHandler) + .build(); + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var createMessageRequest = McpSchema.CreateMessageRequest.builder() + .messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER, + new McpSchema.TextContent("Test message")))) + .modelPreferences(ModelPreferences.builder() + .hints(List.of()) + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0) + .build()) + .build(); + + StepVerifier.create(exchange.createMessage(createMessageRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.role()).isEqualTo(Role.USER); + assertThat(result.content()).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content()).text()).isEqualTo("Test message"); + assertThat(result.model()).isEqualTo("MockModelName"); + assertThat(result.stopReason()).isEqualTo(CreateMessageResult.StopReason.STOP_SEQUENCE); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("Timeout"); + + mcpClient.close(); + mcpServer.close(); + } + + // --------------------------------------- + // Elicitation Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithoutElicitationCapabilities(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + exchange.createElicitation(mock(McpSchema.ElicitRequest.class)).block(); + + return Mono.just(mock(CallToolResult.class)); + }) + .build(); + + var server = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try ( + // Create client without elicitation capabilities + var client = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")).build()) { + + assertThat(client.initialize()).isNotNull(); + + try { + client.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class) + .hasMessage("Client must be configured with elicitation capabilities"); + } + } + server.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0").tools(tool).build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + } + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + try { + TimeUnit.SECONDS.sleep(2); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new McpSchema.ElicitResult(McpSchema.ElicitResult.Action.ACCEPT, + Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + CallToolResult callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), + null); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = McpSchema.ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + StepVerifier.create(exchange.createElicitation(elicitationRequest)).consumeNextWith(result -> { + assertThat(result).isNotNull(); + assertThat(result.action()).isEqualTo(McpSchema.ElicitResult.Action.ACCEPT); + assertThat(result.content().get("message")).isEqualTo("Test message"); + }).verifyComplete(); + + return Mono.just(callResponse); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(3)) + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull(); + assertThat(response).isEqualTo(callResponse); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testCreateElicitationWithRequestTimeoutFail(String clientType) { + + var latch = new CountDownLatch(1); + + var clientBuilder = clientBuilders.get(clientType); + + Function elicitationHandler = request -> { + assertThat(request.message()).isNotEmpty(); + assertThat(request.requestedSchema()).isNotNull(); + + try { + if (!latch.await(2, TimeUnit.SECONDS)) { + throw new RuntimeException("Timeout waiting for elicitation processing"); + } + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + return new ElicitResult(ElicitResult.Action.ACCEPT, Map.of("message", request.message())); + }; + + var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample client", "0.0.0")) + .capabilities(ClientCapabilities.builder().elicitation().build()) + .elicitation(elicitationHandler) + .build(); + + CallToolResult callResponse = new CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + + AtomicReference resultRef = new AtomicReference<>(); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + var elicitationRequest = ElicitRequest.builder() + .message("Test message") + .requestedSchema( + Map.of("type", "object", "properties", Map.of("message", Map.of("type", "string")))) + .build(); + + return exchange.createElicitation(elicitationRequest) + .doOnNext(resultRef::set) + .then(Mono.just(callResponse)); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .requestTimeout(Duration.ofSeconds(1)) // 1 second. + .tools(tool) + .build(); + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatExceptionOfType(McpError.class).isThrownBy(() -> { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + }).withMessageContaining("within 1000ms"); + + ElicitResult elicitResult = resultRef.get(); + assertThat(elicitResult).isNull(); + + mcpClient.closeGracefully(); + mcpServer.closeGracefully().block(); + } + + // --------------------------------------- + // Roots Tests + // --------------------------------------- + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1"), new Root("uri2://", "root2")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + + // Remove a root + mcpClient.removeRoot(roots.get(0).uri()); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1))); + }); + + // Add a new root + var root3 = new Root("uri3://", "root3"); + mcpClient.addRoot(root3); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(roots.get(1), root3)); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithoutCapability(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + exchange.listRoots(); // try to list roots + + return mock(CallToolResult.class); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().rootsChangeHandler((exchange, rootsUpdate) -> { + }).tools(tool).build(); + + try ( + // Create client without roots capability + // No roots capability + var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().build()).build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + // Attempt to list roots should fail + try { + mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + } + catch (McpError e) { + assertThat(e).isInstanceOf(McpError.class).hasMessage("Roots not supported"); + } + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsNotificationWithEmptyRootsList(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(List.of()) // Empty roots list + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsWithMultipleHandlers(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef1 = new AtomicReference<>(); + AtomicReference> rootsRef2 = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef1.set(rootsUpdate)) + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef2.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + assertThat(mcpClient.initialize()).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef1.get()).containsAll(roots); + assertThat(rootsRef2.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testRootsServerCloseWithActiveSubscription(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + List roots = List.of(new Root("uri1://", "root1")); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder() + .rootsChangeHandler((exchange, rootsUpdate) -> rootsRef.set(rootsUpdate)) + .build(); + + try (var mcpClient = clientBuilder.capabilities(ClientCapabilities.builder().roots(true).build()) + .roots(roots) + .build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + mcpClient.rootsListChangedNotification(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(roots); + }); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tools Tests + // --------------------------------------- + + String emptyJsonSchema = """ + { + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": {} + } + """; + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolCallSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + return callResponse; + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + CallToolResult response = mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of())); + + assertThat(response).isNotNull().isEqualTo(callResponse); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + McpSyncServer mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool1") + .description("tool1 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + // We trigger a timeout on blocking read, raising an exception + Mono.never().block(Duration.ofSeconds(1)); + return null; + }) + .build()) + .build(); + + try (var mcpClient = clientBuilder.requestTimeout(Duration.ofMillis(6666)).build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // We expect the tool call to fail immediately with the exception raised by + // the offending tool + // instead of getting back a timeout. + assertThatExceptionOfType(McpError.class) + .isThrownBy(() -> mcpClient.callTool(new McpSchema.CallToolRequest("tool1", Map.of()))) + .withMessageContaining("Timeout on blocking read"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testToolListChangeHandlingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var callResponse = new McpSchema.CallToolResult(List.of(new McpSchema.TextContent("CALL RESPONSE")), null); + McpServerFeatures.SyncToolSpecification tool1 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder().name("tool1").description("tool1 description").inputSchema(emptyJsonSchema).build()) + .callHandler((exchange, request) -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + return callResponse; + }) + .build(); + + AtomicReference> rootsRef = new AtomicReference<>(); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool1) + .build(); + + try (var mcpClient = clientBuilder.toolsChangeConsumer(toolsUpdate -> { + // perform a blocking call to a remote service + try { + HttpResponse response = HttpClient.newHttpClient() + .send(HttpRequest.newBuilder() + .uri(URI.create( + "https://raw.githubusercontent.com/modelcontextprotocol/java-sdk/refs/heads/main/README.md")) + .GET() + .build(), HttpResponse.BodyHandlers.ofString()); + String responseBody = response.body(); + assertThat(responseBody).isNotBlank(); + } + catch (Exception e) { + e.printStackTrace(); + } + + rootsRef.set(toolsUpdate); + }).build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThat(rootsRef.get()).isNull(); + + assertThat(mcpClient.listTools().tools()).contains(tool1.tool()); + + mcpServer.notifyToolsListChanged(); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool1.tool())); + }); + + // Remove a tool + mcpServer.removeTool("tool1"); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).isEmpty(); + }); + + // Add a new tool + McpServerFeatures.SyncToolSpecification tool2 = McpServerFeatures.SyncToolSpecification.builder() + .tool(Tool.builder() + .name("tool2") + .description("tool2 description") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> callResponse) + .build(); + + mcpServer.addTool(tool2); + + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(rootsRef.get()).containsAll(List.of(tool2.tool())); + }); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testInitialize(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mcpServer = prepareSyncServerBuilder().build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testPingSuccess(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create server with a tool that uses ping functionality + AtomicReference executionOrder = new AtomicReference<>(""); + + McpServerFeatures.AsyncToolSpecification tool = McpServerFeatures.AsyncToolSpecification.builder() + .tool(Tool.builder() + .name("ping-async-test") + .description("Test ping async behavior") + .inputSchema(emptyJsonSchema) + .build()) + .callHandler((exchange, request) -> { + + executionOrder.set(executionOrder.get() + "1"); + + // Test async ping behavior + return exchange.ping().doOnNext(result -> { + + assertThat(result).isNotNull(); + // Ping should return an empty object or map + assertThat(result).isInstanceOf(Map.class); + + executionOrder.set(executionOrder.get() + "2"); + assertThat(result).isNotNull(); + }).then(Mono.fromCallable(() -> { + executionOrder.set(executionOrder.get() + "3"); + return new CallToolResult("Async ping test completed", false); + })); + }) + .build(); + + var mcpServer = prepareAsyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that tests ping async behavior + CallToolResult result = mcpClient.callTool(new McpSchema.CallToolRequest("ping-async-test", Map.of())); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Async ping test completed"); + + // Verify execution order + assertThat(executionOrder.get()).isEqualTo("123"); + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tool Structured Output Schema Tests + // --------------------------------------- + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationSuccess(String clientType) { + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of( + "type", "object", "properties", Map.of("result", Map.of("type", "number"), "operation", + Map.of("type", "string"), "timestamp", Map.of("type", "string")), + "required", List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + String expression = (String) request.arguments().getOrDefault("expression", "2 + 3"); + double result = evaluateExpression(expression); + return CallToolResult.builder() + .structuredContent( + Map.of("result", result, "operation", expression, "timestamp", "2024-01-01T10:00:00Z")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Verify tool is listed with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("calculator"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call tool with valid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + // In WebMVC, structured content is returned properly + if (response.structuredContent() != null) { + assertThat(response.structuredContent()).containsEntry("result", 5.0) + .containsEntry("operation", "2 + 3") + .containsEntry("timestamp", "2024-01-01T10:00:00Z"); + } + else { + // Fallback to checking content if structured content is not available + assertThat(response.content()).isNotEmpty(); + } + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"result":5.0,"operation":"2 + 3","timestamp":"2024-01-01T10:00:00Z"}""")); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputValidationFailure(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number"), "operation", Map.of("type", "string")), "required", + List.of("result", "operation")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return invalid structured output. Result should be number, missing + // operation + return CallToolResult.builder() + .addTextContent("Invalid calculation") + .structuredContent(Map.of("result", "not-a-number", "extra", "field")) + .build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool with invalid structured output + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).contains("Validation failed"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputMissingStructuredContent(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Create a tool with output schema + Map outputSchema = Map.of("type", "object", "properties", + Map.of("result", Map.of("type", "number")), "required", List.of("result")); + + Tool calculatorTool = Tool.builder() + .name("calculator") + .description("Performs mathematical calculations") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification tool = McpServerFeatures.SyncToolSpecification.builder() + .tool(calculatorTool) + .callHandler((exchange, request) -> { + // Return result without structured content but tool has output schema + return CallToolResult.builder().addTextContent("Calculation completed").build(); + }) + .build(); + + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tool) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call tool that should return structured content but doesn't + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("calculator", Map.of("expression", "2 + 3"))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isTrue(); + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + + String errorMessage = ((McpSchema.TextContent) response.content().get(0)).text(); + assertThat(errorMessage).isEqualTo( + "Response missing structured content which is expected when calling tool with non-empty outputSchema"); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testStructuredOutputRuntimeToolAddition(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Start server without tools + var mcpServer = prepareSyncServerBuilder().serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().tools(true).build()) + .build(); + + try (var mcpClient = clientBuilder.build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Initially no tools + assertThat(mcpClient.listTools().tools()).isEmpty(); + + // Add tool with output schema at runtime + Map outputSchema = Map.of("type", "object", "properties", + Map.of("message", Map.of("type", "string"), "count", Map.of("type", "integer")), "required", + List.of("message", "count")); + + Tool dynamicTool = Tool.builder() + .name("dynamic-tool") + .description("Dynamically added tool") + .outputSchema(outputSchema) + .build(); + + McpServerFeatures.SyncToolSpecification toolSpec = McpServerFeatures.SyncToolSpecification.builder() + .tool(dynamicTool) + .callHandler((exchange, request) -> { + int count = (Integer) request.arguments().getOrDefault("count", 1); + return CallToolResult.builder() + .addTextContent("Dynamic tool executed " + count + " times") + .structuredContent(Map.of("message", "Dynamic execution", "count", count)) + .build(); + }) + .build(); + + // Add tool to server + mcpServer.addTool(toolSpec); + + // Wait for tool list change notification + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + assertThat(mcpClient.listTools().tools()).hasSize(1); + }); + + // Verify tool was added with output schema + var toolsList = mcpClient.listTools(); + assertThat(toolsList.tools()).hasSize(1); + assertThat(toolsList.tools().get(0).name()).isEqualTo("dynamic-tool"); + // Note: outputSchema might be null in sync server, but validation still works + + // Call dynamically added tool + CallToolResult response = mcpClient + .callTool(new McpSchema.CallToolRequest("dynamic-tool", Map.of("count", 3))); + + assertThat(response).isNotNull(); + assertThat(response.isError()).isFalse(); + + assertThat(response.content()).hasSize(1); + assertThat(response.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) response.content().get(0)).text()) + .isEqualTo("Dynamic tool executed 3 times"); + + assertThat(response.structuredContent()).isNotNull(); + assertThatJson(response.structuredContent()).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo(json(""" + {"count":3,"message":"Dynamic execution"}""")); + } + + mcpServer.close(); + } + + private double evaluateExpression(String expression) { + // Simple expression evaluator for testing + return switch (expression) { + case "2 + 3" -> 5.0; + case "10 * 2" -> 20.0; + case "7 + 8" -> 15.0; + case "5 + 3" -> 8.0; + default -> 0.0; + }; + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index fcd42a43..1b3eee3c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -296,7 +296,7 @@ private McpNotificationHandler asyncRootsListChangedNotificationHandler( List, Mono>> rootsChangeConsumers) { return (exchange, params) -> exchange.listRoots() .flatMap(listRootsResult -> Flux.fromIterable(rootsChangeConsumers) - .flatMap(consumer -> consumer.apply(exchange, listRootsResult.roots())) + .flatMap(consumer -> Mono.defer(() -> consumer.apply(exchange, listRootsResult.roots()))) .onErrorResume(error -> { logger.error("Error handling roots list change notification", error); return Mono.empty(); @@ -506,7 +506,7 @@ private McpRequestHandler toolsCallRequestHandler() { return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolSpecification.map(tool -> tool.callHandler().apply(exchange, callToolRequest)) + return toolSpecification.map(tool -> Mono.defer(() -> tool.callHandler().apply(exchange, callToolRequest))) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } @@ -634,7 +634,7 @@ private McpRequestHandler resourcesReadRequestHand .findFirst() .orElseThrow(() -> new McpError("Resource not found: " + resourceUri)); - return specification.readHandler().apply(exchange, resourceRequest); + return Mono.defer(() -> specification.readHandler().apply(exchange, resourceRequest)); }; } @@ -740,7 +740,7 @@ private McpRequestHandler promptsGetRequestHandler() return Mono.error(new McpError("Prompt not found: " + promptRequest.name())); } - return specification.promptHandler().apply(exchange, promptRequest); + return Mono.defer(() -> specification.promptHandler().apply(exchange, promptRequest)); }; } @@ -845,7 +845,7 @@ private McpRequestHandler completionCompleteRequestHan return Mono.error(new McpError("AsyncCompletionSpecification not found: " + request.ref())); } - return specification.completionHandler().apply(exchange, request); + return Mono.defer(() -> specification.completionHandler().apply(exchange, request)); }; } diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java index 7a1e9077..0ba8bf92 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/AbstractMcpAsyncServerTests.java @@ -4,9 +4,17 @@ package io.modelcontextprotocol.server; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + import java.time.Duration; import java.util.List; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -18,18 +26,9 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerTransportProvider; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatCode; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - /** * Test suite for the {@link McpAsyncServer} that can be used with different * {@link io.modelcontextprotocol.spec.McpServerTransportProvider} implementations. @@ -65,10 +64,7 @@ void tearDown() { // --------------------------------------- // Server Lifecycle Tests // --------------------------------------- - - @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "sse", "streamable" }) - void testConstructorWithInvalidArguments(String serverType) { + void testConstructorWithInvalidArguments() { assertThatThrownBy(() -> McpServer.async((McpServerTransportProvider) null)) .isInstanceOf(IllegalArgumentException.class) .hasMessage("Transport provider must not be null");