From 0363d6d6de369194de563a6f49d7309a2a383242 Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Fri, 11 Jul 2025 09:17:29 +0800 Subject: [PATCH 1/2] fix: Fixed the issue where the returnDirect attribute of `@Tool` was not effective on the MCP server side. Signed-off-by: Sun Yuhan --- .../ai/mcp/AsyncMcpToolCallback.java | 20 ++++++++++++++++ .../springframework/ai/mcp/McpToolUtils.java | 6 ++++- .../ai/mcp/SyncMcpToolCallback.java | 21 ++++++++++++++++- .../ai/mcp/AsyncMcpToolCallbackTest.java | 23 +++++++++++++++++++ .../ai/mcp/SyncMcpToolCallbackTests.java | 6 ++++- 5 files changed, 73 insertions(+), 3 deletions(-) diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java index 5f8da416109..5d1702b48ac 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/AsyncMcpToolCallback.java @@ -17,6 +17,7 @@ package org.springframework.ai.mcp; import io.modelcontextprotocol.client.McpAsyncClient; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.Tool; import java.util.Map; @@ -28,6 +29,7 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.metadata.ToolMetadata; /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool @@ -55,16 +57,21 @@ * } * * @author Christian Tzolov + * @author Sun Yuhan * @see ToolCallback * @see McpAsyncClient * @see Tool */ public class AsyncMcpToolCallback implements ToolCallback { + private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build(); + private final McpAsyncClient asyncMcpClient; private final Tool tool; + private final ToolMetadata toolMetadata; + /** * Creates a new {@code AsyncMcpToolCallback} instance. * @param mcpClient the MCP client to use for tool execution @@ -73,6 +80,14 @@ public class AsyncMcpToolCallback implements ToolCallback { public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool) { this.asyncMcpClient = mcpClient; this.tool = tool; + McpSchema.ToolAnnotations annotations = tool.annotations(); + Boolean returnDirect = (annotations != null) ? annotations.returnDirect() : null; + if (returnDirect != null) { + this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build(); + } + else { + this.toolMetadata = DEFAULT_TOOL_METADATA; + } } /** @@ -130,4 +145,9 @@ public String call(String toolArguments, ToolContext toolContext) { return this.call(toolArguments); } + @Override + public ToolMetadata getToolMetadata() { + return this.toolMetadata; + } + } diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index 2f8f366d076..bc9c297f2ad 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -59,6 +59,7 @@ * * * @author Christian Tzolov + * @author Sun Yuhan */ public final class McpToolUtils { @@ -166,9 +167,12 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To */ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(ToolCallback toolCallback, MimeType mimeType) { + McpSchema.ToolAnnotations toolAnnotations = new McpSchema.ToolAnnotations(null, null, null, null, null, + toolCallback.getToolMetadata().returnDirect()); var tool = new McpSchema.Tool(toolCallback.getToolDefinition().name(), - toolCallback.getToolDefinition().description(), toolCallback.getToolDefinition().inputSchema()); + toolCallback.getToolDefinition().description(), toolCallback.getToolDefinition().inputSchema(), + toolAnnotations); return new McpServerFeatures.SyncToolSpecification(tool, (exchange, request) -> { try { diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java index fc61d801df1..110e5b51e48 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/SyncMcpToolCallback.java @@ -17,6 +17,7 @@ package org.springframework.ai.mcp; import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolRequest; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.Tool; @@ -30,6 +31,7 @@ import org.springframework.ai.tool.definition.DefaultToolDefinition; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.tool.execution.ToolExecutionException; +import org.springframework.ai.tool.metadata.ToolMetadata; /** * Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool @@ -57,6 +59,7 @@ * } * * @author Christian Tzolov + * @author Sun Yuhan * @see ToolCallback * @see McpSyncClient * @see Tool @@ -65,10 +68,14 @@ public class SyncMcpToolCallback implements ToolCallback { private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolCallback.class); + private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build(); + private final McpSyncClient mcpClient; private final Tool tool; + private final ToolMetadata toolMetadata; + /** * Creates a new {@code SyncMcpToolCallback} instance. * @param mcpClient the MCP client to use for tool execution @@ -77,7 +84,14 @@ public class SyncMcpToolCallback implements ToolCallback { public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) { this.mcpClient = mcpClient; this.tool = tool; - + McpSchema.ToolAnnotations annotations = tool.annotations(); + Boolean returnDirect = (annotations != null) ? annotations.returnDirect() : null; + if (returnDirect != null) { + this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build(); + } + else { + this.toolMetadata = DEFAULT_TOOL_METADATA; + } } /** @@ -141,4 +155,9 @@ public String call(String toolArguments, ToolContext toolContext) { return this.call(toolArguments); } + @Override + public ToolMetadata getToolMetadata() { + return this.toolMetadata; + } + } diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java index abf2c395ed9..14f25205453 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/AsyncMcpToolCallbackTest.java @@ -9,6 +9,8 @@ import reactor.core.publisher.Mono; import org.springframework.ai.tool.execution.ToolExecutionException; + +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -22,6 +24,27 @@ class AsyncMcpToolCallbackTest { @Mock private McpSchema.Tool tool; + @Test + void getToolDefinitionShouldReturnCorrectDefinition() { + + var clientInfo = new McpSchema.Implementation("testClient", "1.0.0"); + var toolAnnotations = new McpSchema.ToolAnnotations(null, false, false, false, false, true); + + when(this.mcpClient.getClientInfo()).thenReturn(clientInfo); + when(this.tool.name()).thenReturn("testTool"); + when(this.tool.description()).thenReturn("Test tool description"); + when(this.tool.annotations()).thenReturn(toolAnnotations); + + AsyncMcpToolCallback callback = new AsyncMcpToolCallback(this.mcpClient, this.tool); + + var toolDefinition = callback.getToolDefinition(); + var toolMetadata = callback.getToolMetadata(); + + assertThat(toolDefinition.name()).isEqualTo(clientInfo.name() + "_testTool"); + assertThat(toolDefinition.description()).isEqualTo("Test tool description"); + assertThat(toolMetadata.returnDirect()).isEqualTo(true); + } + @Test void callShouldThrowOnError() { when(this.tool.name()).thenReturn("testTool"); diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java index 99a901553ad..9b3a671f386 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/SyncMcpToolCallbackTests.java @@ -31,7 +31,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.ai.chat.model.ToolContext; -import org.springframework.ai.content.Content; import org.springframework.ai.tool.execution.ToolExecutionException; import static org.assertj.core.api.Assertions.assertThat; @@ -53,16 +52,21 @@ class SyncMcpToolCallbackTests { void getToolDefinitionShouldReturnCorrectDefinition() { var clientInfo = new Implementation("testClient", "1.0.0"); + var toolAnnotations = new McpSchema.ToolAnnotations(null, false, false, false, false, true); + when(this.mcpClient.getClientInfo()).thenReturn(clientInfo); when(this.tool.name()).thenReturn("testTool"); when(this.tool.description()).thenReturn("Test tool description"); + when(this.tool.annotations()).thenReturn(toolAnnotations); SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool); var toolDefinition = callback.getToolDefinition(); + var toolMetadata = callback.getToolMetadata(); assertThat(toolDefinition.name()).isEqualTo(clientInfo.name() + "_testTool"); assertThat(toolDefinition.description()).isEqualTo("Test tool description"); + assertThat(toolMetadata.returnDirect()).isEqualTo(true); } @Test From a6a45dcfc53261afafe6be8d806a3b88a4d60008 Mon Sep 17 00:00:00 2001 From: Sun Yuhan Date: Fri, 11 Jul 2025 09:42:15 +0800 Subject: [PATCH 2/2] fix: Improved compatibility with empty ToolMetadata and enhanced some unit tests. Signed-off-by: Sun Yuhan --- .../main/java/org/springframework/ai/mcp/McpToolUtils.java | 6 +++++- .../java/org/springframework/ai/mcp/ToolUtilsTests.java | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java index bc9c297f2ad..cbe7e1890e2 100644 --- a/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java +++ b/mcp/common/src/main/java/org/springframework/ai/mcp/McpToolUtils.java @@ -30,6 +30,7 @@ import io.modelcontextprotocol.server.McpSyncServerExchange; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.Role; +import org.springframework.ai.tool.metadata.ToolMetadata; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -167,8 +168,11 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To */ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(ToolCallback toolCallback, MimeType mimeType) { + boolean returnDirect = Optional.ofNullable(toolCallback.getToolMetadata()) + .map(ToolMetadata::returnDirect) + .orElse(false); McpSchema.ToolAnnotations toolAnnotations = new McpSchema.ToolAnnotations(null, null, null, null, null, - toolCallback.getToolMetadata().returnDirect()); + returnDirect); var tool = new McpSchema.Tool(toolCallback.getToolDefinition().name(), toolCallback.getToolDefinition().description(), toolCallback.getToolDefinition().inputSchema(), diff --git a/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java b/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java index 2bcbe305c5d..fc48085c7ba 100644 --- a/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java +++ b/mcp/common/src/test/java/org/springframework/ai/mcp/ToolUtilsTests.java @@ -32,6 +32,7 @@ import io.modelcontextprotocol.spec.McpSchema.TextContent; import io.modelcontextprotocol.spec.McpSchema.Tool; import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.metadata.ToolMetadata; import reactor.test.StepVerifier; import org.springframework.ai.tool.ToolCallback; @@ -199,8 +200,10 @@ private ToolCallback createMockToolCallback(String name, String result) { .description("Test tool") .inputSchema("{}") .build(); + ToolMetadata metadata = ToolMetadata.builder().build(); when(callback.getToolDefinition()).thenReturn(definition); when(callback.call(anyString(), any())).thenReturn(result); + when(callback.getToolMetadata()).thenReturn(metadata); return callback; } @@ -211,8 +214,10 @@ private ToolCallback createMockToolCallback(String name, RuntimeException error) .description("Test tool") .inputSchema("{}") .build(); + ToolMetadata metadata = ToolMetadata.builder().build(); when(callback.getToolDefinition()).thenReturn(definition); when(callback.call(anyString(), any())).thenThrow(error); + when(callback.getToolMetadata()).thenReturn(metadata); return callback; }