diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java index a7dac4c05..dd6217dc5 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpAsyncClient.java @@ -267,6 +267,16 @@ public class McpAsyncClient { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_MESSAGE, asyncLoggingNotificationHandler(loggingConsumersFinal)); + // Utility Progress Notification + List>> progressConsumersFinal = new ArrayList<>(); + progressConsumersFinal + .add((notification) -> Mono.fromRunnable(() -> logger.debug("Progress: {}", notification))); + if (!Utils.isEmpty(features.progressConsumers())) { + progressConsumersFinal.addAll(features.progressConsumers()); + } + notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_PROGRESS, + asyncProgressNotificationHandler(progressConsumersFinal)); + this.transport.setExceptionHandler(this::handleException); this.sessionSupplier = () -> new McpClientSession(requestTimeout, transport, requestHandlers, notificationHandlers); @@ -985,6 +995,20 @@ private NotificationHandler asyncLoggingNotificationHandler( }; } + private NotificationHandler asyncProgressNotificationHandler( + List>> progressConsumers) { + + return params -> { + McpSchema.ProgressNotification progressNotification = transport.unmarshalFrom(params, + new TypeReference() { + }); + + return Flux.fromIterable(progressConsumers) + .flatMap(consumer -> consumer.apply(progressNotification)) + .then(); + }; + } + /** * Sets the minimum logging level for messages received from the server. The client * will only receive log messages at or above the specified severity level. diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java index d8925b005..ce603a0fa 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClient.java @@ -177,6 +177,8 @@ class SyncSpec { private final List> loggingConsumers = new ArrayList<>(); + private final List> progressConsumers = new ArrayList<>(); + private Function samplingHandler; private Function elicitationHandler; @@ -377,6 +379,36 @@ public SyncSpec loggingConsumers(List progressConsumer) { + Assert.notNull(progressConsumer, "Progress consumer must not be null"); + this.progressConsumers.add(progressConsumer); + return this; + } + + /** + * Adds a multiple consumers to be notified of progress notifications from the + * server. This allows the client to track long-running operations and provide + * feedback to users. + * @param progressConsumers A list of consumers that receives progress + * notifications. Must not be null. + * @return This builder instance for method chaining + * @throws IllegalArgumentException if progressConsumer is null + */ + public SyncSpec progressConsumers(List> progressConsumers) { + Assert.notNull(progressConsumers, "Progress consumers must not be null"); + this.progressConsumers.addAll(progressConsumers); + return this; + } + /** * Create an instance of {@link McpSyncClient} with the provided configurations or * sensible defaults. @@ -385,7 +417,8 @@ public SyncSpec loggingConsumers(List>> loggingConsumers = new ArrayList<>(); + private final List>> progressConsumers = new ArrayList<>(); + private Function> samplingHandler; private Function> elicitationHandler; @@ -663,8 +698,8 @@ public McpAsyncClient build() { return new McpAsyncClient(this.transport, this.requestTimeout, this.initializationTimeout, new McpClientFeatures.Async(this.clientInfo, this.capabilities, this.roots, this.toolsChangeConsumers, this.resourcesChangeConsumers, this.resourcesUpdateConsumers, - this.promptsChangeConsumers, this.loggingConsumers, this.samplingHandler, - this.elicitationHandler)); + this.promptsChangeConsumers, this.loggingConsumers, this.progressConsumers, + this.samplingHandler, this.elicitationHandler)); } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java index bd1a0985a..ed3f48b9b 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/client/McpClientFeatures.java @@ -59,6 +59,7 @@ class McpClientFeatures { * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. */ @@ -68,6 +69,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesUpdateConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, + List>> progressConsumers, Function> samplingHandler, Function> elicitationHandler) { @@ -79,6 +81,7 @@ record Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progressconsumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. */ @@ -89,6 +92,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c List, Mono>> resourcesUpdateConsumers, List, Mono>> promptsChangeConsumers, List>> loggingConsumers, + List>> progressConsumers, Function> samplingHandler, Function> elicitationHandler) { @@ -106,6 +110,7 @@ public Async(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities c this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of(); this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); + this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; } @@ -149,6 +154,12 @@ public static Async fromSync(Sync syncSpec) { .subscribeOn(Schedulers.boundedElastic())); } + List>> progressConsumers = new ArrayList<>(); + for (Consumer consumer : syncSpec.progressConsumers()) { + progressConsumers.add(p -> Mono.fromRunnable(() -> consumer.accept(p)) + .subscribeOn(Schedulers.boundedElastic())); + } + Function> samplingHandler = r -> Mono .fromCallable(() -> syncSpec.samplingHandler().apply(r)) .subscribeOn(Schedulers.boundedElastic()); @@ -159,7 +170,7 @@ public static Async fromSync(Sync syncSpec) { return new Async(syncSpec.clientInfo(), syncSpec.clientCapabilities(), syncSpec.roots(), toolsChangeConsumers, resourcesChangeConsumers, resourcesUpdateConsumers, promptsChangeConsumers, - loggingConsumers, samplingHandler, elicitationHandler); + loggingConsumers, progressConsumers, samplingHandler, elicitationHandler); } } @@ -174,6 +185,7 @@ public static Async fromSync(Sync syncSpec) { * @param resourcesChangeConsumers the resources change consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. */ @@ -183,6 +195,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili List>> resourcesUpdateConsumers, List>> promptsChangeConsumers, List> loggingConsumers, + List> progressConsumers, Function samplingHandler, Function elicitationHandler) { @@ -196,6 +209,7 @@ public record Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabili * @param resourcesUpdateConsumers the resource update consumers. * @param promptsChangeConsumers the prompts change consumers. * @param loggingConsumers the logging consumers. + * @param progressConsumers the progress consumers. * @param samplingHandler the sampling handler. * @param elicitationHandler the elicitation handler. */ @@ -205,6 +219,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl List>> resourcesUpdateConsumers, List>> promptsChangeConsumers, List> loggingConsumers, + List> progressConsumers, Function samplingHandler, Function elicitationHandler) { @@ -222,6 +237,7 @@ public Sync(McpSchema.Implementation clientInfo, McpSchema.ClientCapabilities cl this.resourcesUpdateConsumers = resourcesUpdateConsumers != null ? resourcesUpdateConsumers : List.of(); this.promptsChangeConsumers = promptsChangeConsumers != null ? promptsChangeConsumers : List.of(); this.loggingConsumers = loggingConsumers != null ? loggingConsumers : List.of(); + this.progressConsumers = progressConsumers != null ? progressConsumers : List.of(); this.samplingHandler = samplingHandler; this.elicitationHandler = elicitationHandler; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 02ad955b9..dfb77afd6 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -360,7 +360,7 @@ private McpServerSession.RequestHandler toolsCallRequestHandler( return Mono.error(new McpError("Tool not found: " + callToolRequest.name())); } - return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest.arguments())) + return toolSpecification.map(tool -> tool.call().apply(exchange, callToolRequest)) .orElse(Mono.error(new McpError("Tool not found: " + callToolRequest.name()))); }; } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 2fd95a10d..aea869bc2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -139,6 +139,16 @@ public Mono listRoots(String cursor) { LIST_ROOTS_RESULT_TYPE_REF); } + public Mono notification(String method, Object params) { + if (method == null || method.isEmpty()) { + return Mono.error(new McpError("Method must not be null or empty")); + } + if (params == null) { + return Mono.error(new McpError("Params must not be null")); + } + return this.session.sendNotification(method, params); + } + /** * Send a logging message notification to the client. Messages below the current * minimum logging level will be filtered out. diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java index d6ec2cc30..fc38d2915 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServer.java @@ -309,7 +309,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * Example usage:
{@code
 		 * .tool(
 		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     (exchange, args) -> Mono.fromSupplier(() -> calculate(args))
+		 *     (exchange, request) -> Mono.fromSupplier(() -> calculate(request))
 		 *         .map(result -> new CallToolResult("Result: " + result))
 		 * )
 		 * }
@@ -323,7 +323,7 @@ public AsyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabi * @throws IllegalArgumentException if tool or handler is null */ public AsyncSpecification tool(McpSchema.Tool tool, - BiFunction, Mono> handler) { + BiFunction> handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); @@ -801,7 +801,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * Example usage:
{@code
 		 * .tool(
 		 *     new Tool("calculator", "Performs calculations", schema),
-		 *     (exchange, args) -> new CallToolResult("Result: " + calculate(args))
+		 *     (exchange, request) -> new CallToolResult("Result: " + calculate(request))
 		 * )
 		 * }
* @param tool The tool definition including name, description, and schema. Must @@ -814,7 +814,7 @@ public SyncSpecification capabilities(McpSchema.ServerCapabilities serverCapabil * @throws IllegalArgumentException if tool or handler is null */ public SyncSpecification tool(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> handler) { + BiFunction handler) { Assert.notNull(tool, "Tool must not be null"); Assert.notNull(handler, "Handler must not be null"); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java index 8311f5d41..95e55bae1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpServerFeatures.java @@ -222,8 +222,8 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * .required("expression") * .property("expression", JsonSchemaType.STRING) * ), - * (exchange, args) -> { - * String expr = (String) args.get("expression"); + * (exchange, request) -> { + * String expr = (String) request.arguments().get("expression"); * return Mono.fromSupplier(() -> evaluate(expr)) * .map(result -> new CallToolResult("Result: " + result)); * } @@ -237,7 +237,7 @@ record Sync(McpSchema.Implementation serverInfo, McpSchema.ServerCapabilities se * connected client. The second arguments is a map of tool arguments. */ public record AsyncToolSpecification(McpSchema.Tool tool, - BiFunction, Mono> call) { + BiFunction> call) { static AsyncToolSpecification fromSync(SyncToolSpecification tool) { // FIXME: This is temporary, proper validation should be implemented @@ -413,7 +413,7 @@ static AsyncCompletionSpecification fromSync(SyncCompletionSpecification complet * client. The second arguments is a map of arguments passed to the tool. */ public record SyncToolSpecification(McpSchema.Tool tool, - BiFunction, McpSchema.CallToolResult> call) { + BiFunction call) { } /** diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java index 597130946..e1a17924d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java @@ -23,6 +23,7 @@ import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.util.annotation.Nullable; /** * Based on the JSON-RPC 2.0 @@ -57,6 +58,8 @@ private McpSchema() { public static final String METHOD_PING = "ping"; + public static final String METHOD_NOTIFICATION_PROGRESS = "notifications/progress"; + // Tool Methods public static final String METHOD_TOOLS_LIST = "tools/list"; @@ -867,15 +870,22 @@ private static JsonSchema parseSchema(String schema) { * tools/list. * @param arguments Arguments to pass to the tool. These must conform to the tool's * input schema. + * @param _meta Optional metadata about the request. This can include additional + * information like `progressToken` */ @JsonInclude(JsonInclude.Include.NON_ABSENT) @JsonIgnoreProperties(ignoreUnknown = true) public record CallToolRequest(// @formatter:off @JsonProperty("name") String name, - @JsonProperty("arguments") Map arguments) implements Request { + @JsonProperty("arguments") Map arguments, + @Nullable @JsonProperty("_meta") Map _meta) implements Request { public CallToolRequest(String name, String jsonArguments) { - this(name, parseJsonArguments(jsonArguments)); + this(name, parseJsonArguments(jsonArguments), null); + } + + public CallToolRequest(String name, Map arguments) { + this(name, arguments, null); } private static Map parseJsonArguments(String jsonArguments) { @@ -1309,11 +1319,23 @@ public record PaginatedResult(@JsonProperty("nextCursor") String nextCursor) { // --------------------------- // Progress and Logging // --------------------------- + + /** + * The Model Context Protocol (MCP) supports optional progress tracking for + * long-running operations through notification messages. Either side can send + * progress notifications to provide updates about operation status. + * + * @param progressToken The original progress token + * @param progress The current progress value so far + * @param total An optional “total” value + * @param message An optional “message” value + */ @JsonIgnoreProperties(ignoreUnknown = true) public record ProgressNotification(// @formatter:off @JsonProperty("progressToken") String progressToken, - @JsonProperty("progress") double progress, - @JsonProperty("total") Double total) { + @JsonProperty("progress") Double progress, + @JsonProperty("total") Double total, + @JsonProperty("message") String message) { }// @formatter:on /** diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java index dc9d1cfab..a93e017ff 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/HttpServletSseServerTransportProviderIntegrationTests.java @@ -958,4 +958,75 @@ void testLoggingNotification() { mcpServer.close(); } + // --------------------------------------- + // Progress Tests + // --------------------------------------- + @Test + void testProgressNotification() { + // Create a list to store received logging notifications + List receivedNotifications = new ArrayList<>(); + + // Create server with a tool that sends logging notifications + McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification( + new McpSchema.Tool("progress-test", "Test progress notifications", emptyJsonSchema), + (exchange, request) -> { + + var progressToken = (String) request._meta().get("progressToken"); + + exchange + .notification(McpSchema.METHOD_NOTIFICATION_PROGRESS, + new McpSchema.ProgressNotification(progressToken, 0.1, 1.0, "Test progress 1/10")) + .block(); + + exchange + .notification(McpSchema.METHOD_NOTIFICATION_PROGRESS, + new McpSchema.ProgressNotification(progressToken, 0.5, 1.0, "Test progress 5/10")) + .block(); + + exchange + .notification(McpSchema.METHOD_NOTIFICATION_PROGRESS, + new McpSchema.ProgressNotification(progressToken, 1.0, 1.0, "Test progress 10/10")) + .block(); + + return Mono.just(new CallToolResult("Progress test completed", false)); + }); + + var mcpServer = McpServer.async(mcpServerTransportProvider) + .serverInfo("test-server", "1.0.0") + .capabilities(ServerCapabilities.builder().logging().tools(true).build()) + .tools(tool) + .build(); + try ( + // Create client with progress notification handler + var mcpClient = clientBuilder.progressConsumer(receivedNotifications::add).build()) { + + // Initialize client + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // Call the tool that sends progress notifications + CallToolResult result = mcpClient.callTool( + new McpSchema.CallToolRequest("progress-test", Map.of(), Map.of("progressToken", "test-token"))); + assertThat(result).isNotNull(); + assertThat(result.content().get(0)).isInstanceOf(McpSchema.TextContent.class); + assertThat(((McpSchema.TextContent) result.content().get(0)).text()).isEqualTo("Progress test completed"); + + // Wait for notifications to be processed + await().atMost(Duration.ofSeconds(5)).untilAsserted(() -> { + + System.out.println("Received notifications: " + receivedNotifications); + + // Should have received 3 notifications + assertThat(receivedNotifications).hasSize(3); + + // Check the progress notifications + assertThat(receivedNotifications.stream().map(McpSchema.ProgressNotification::progressToken)) + .containsExactlyInAnyOrder("test-token", "test-token", "test-token"); + assertThat(receivedNotifications.stream().map(McpSchema.ProgressNotification::progress)) + .containsExactlyInAnyOrder(0.1, 0.5, 1.0); + }); + } + mcpServer.close(); + } + } diff --git a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java index df8176a4b..bb09cf662 100644 --- a/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java +++ b/mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java @@ -186,6 +186,24 @@ void testJSONRPCRequest() throws Exception { {"jsonrpc":"2.0","method":"method_name","id":1,"params":{"key":"value"}}""")); } + @Test + void testJSONRPCRequestWithMeta() throws Exception { + Map params = new HashMap<>(); + params.put("key", "value"); + params.put("_meta", Map.of("progressToken", "abc123")); + + McpSchema.JSONRPCRequest request = new McpSchema.JSONRPCRequest(McpSchema.JSONRPC_VERSION, "method_name", 1, + params); + + String value = mapper.writeValueAsString(request); + assertThatJson(value).when(Option.IGNORING_ARRAY_ORDER) + .when(Option.IGNORING_EXTRA_ARRAY_ITEMS) + .isObject() + .isEqualTo( + json(""" + {"jsonrpc":"2.0","method":"method_name","id":1,"params":{"key":"value"},"_meta":{"progressToken":"abc123"}}""")); + } + @Test void testJSONRPCNotification() throws Exception { Map params = new HashMap<>();