Skip to content

fix: Fixed the issue where the returnDirect attribute was not effective on the MCP server side. #3787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.mcp;

import io.modelcontextprotocol.client.McpAsyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import java.util.Map;
Expand All @@ -28,6 +29,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.ToolMetadata;

/**
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
Expand Down Expand Up @@ -55,16 +57,21 @@
* }</pre>
*
* @author Christian Tzolov
* @author Sun Yuhan
* @see ToolCallback
* @see McpAsyncClient
* @see Tool
*/
public class AsyncMcpToolCallback implements ToolCallback {

private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

private final McpAsyncClient asyncMcpClient;

private final Tool tool;

private final ToolMetadata toolMetadata;

/**
* Creates a new {@code AsyncMcpToolCallback} instance.
* @param mcpClient the MCP client to use for tool execution
Expand All @@ -73,6 +80,14 @@ public class AsyncMcpToolCallback implements ToolCallback {
public AsyncMcpToolCallback(McpAsyncClient mcpClient, Tool tool) {
this.asyncMcpClient = mcpClient;
this.tool = tool;
McpSchema.ToolAnnotations annotations = tool.annotations();
Boolean returnDirect = (annotations != null) ? annotations.returnDirect() : null;
if (returnDirect != null) {
this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build();
}
else {
this.toolMetadata = DEFAULT_TOOL_METADATA;
}
}

/**
Expand Down Expand Up @@ -130,4 +145,9 @@ public String call(String toolArguments, ToolContext toolContext) {
return this.call(toolArguments);
}

@Override
public ToolMetadata getToolMetadata() {
return this.toolMetadata;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import io.modelcontextprotocol.server.McpSyncServerExchange;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.Role;
import org.springframework.ai.tool.metadata.ToolMetadata;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

Expand Down Expand Up @@ -59,6 +60,7 @@
* </ul>
*
* @author Christian Tzolov
* @author Sun Yuhan
*/
public final class McpToolUtils {

Expand Down Expand Up @@ -166,9 +168,15 @@ public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(To
*/
public static McpServerFeatures.SyncToolSpecification toSyncToolSpecification(ToolCallback toolCallback,
MimeType mimeType) {
boolean returnDirect = Optional.ofNullable(toolCallback.getToolMetadata())
.map(ToolMetadata::returnDirect)
.orElse(false);
McpSchema.ToolAnnotations toolAnnotations = new McpSchema.ToolAnnotations(null, null, null, null, null,
returnDirect);

var tool = new McpSchema.Tool(toolCallback.getToolDefinition().name(),
toolCallback.getToolDefinition().description(), toolCallback.getToolDefinition().inputSchema());
toolCallback.getToolDefinition().description(), toolCallback.getToolDefinition().inputSchema(),
toolAnnotations);

return new McpServerFeatures.SyncToolSpecification(tool, (exchange, request) -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.mcp;

import io.modelcontextprotocol.client.McpSyncClient;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.CallToolRequest;
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
import io.modelcontextprotocol.spec.McpSchema.Tool;
Expand All @@ -30,6 +31,7 @@
import org.springframework.ai.tool.definition.DefaultToolDefinition;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.tool.execution.ToolExecutionException;
import org.springframework.ai.tool.metadata.ToolMetadata;

/**
* Implementation of {@link ToolCallback} that adapts MCP tools to Spring AI's tool
Expand Down Expand Up @@ -57,6 +59,7 @@
* }</pre>
*
* @author Christian Tzolov
* @author Sun Yuhan
* @see ToolCallback
* @see McpSyncClient
* @see Tool
Expand All @@ -65,10 +68,14 @@ public class SyncMcpToolCallback implements ToolCallback {

private static final Logger logger = LoggerFactory.getLogger(SyncMcpToolCallback.class);

private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

private final McpSyncClient mcpClient;

private final Tool tool;

private final ToolMetadata toolMetadata;

/**
* Creates a new {@code SyncMcpToolCallback} instance.
* @param mcpClient the MCP client to use for tool execution
Expand All @@ -77,7 +84,14 @@ public class SyncMcpToolCallback implements ToolCallback {
public SyncMcpToolCallback(McpSyncClient mcpClient, Tool tool) {
this.mcpClient = mcpClient;
this.tool = tool;

McpSchema.ToolAnnotations annotations = tool.annotations();
Boolean returnDirect = (annotations != null) ? annotations.returnDirect() : null;
if (returnDirect != null) {
this.toolMetadata = ToolMetadata.builder().returnDirect(returnDirect).build();
}
else {
this.toolMetadata = DEFAULT_TOOL_METADATA;
}
}

/**
Expand Down Expand Up @@ -141,4 +155,9 @@ public String call(String toolArguments, ToolContext toolContext) {
return this.call(toolArguments);
}

@Override
public ToolMetadata getToolMetadata() {
return this.toolMetadata;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import reactor.core.publisher.Mono;

import org.springframework.ai.tool.execution.ToolExecutionException;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
Expand All @@ -22,6 +24,27 @@ class AsyncMcpToolCallbackTest {
@Mock
private McpSchema.Tool tool;

@Test
void getToolDefinitionShouldReturnCorrectDefinition() {

var clientInfo = new McpSchema.Implementation("testClient", "1.0.0");
var toolAnnotations = new McpSchema.ToolAnnotations(null, false, false, false, false, true);

when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
when(this.tool.name()).thenReturn("testTool");
when(this.tool.description()).thenReturn("Test tool description");
when(this.tool.annotations()).thenReturn(toolAnnotations);

AsyncMcpToolCallback callback = new AsyncMcpToolCallback(this.mcpClient, this.tool);

var toolDefinition = callback.getToolDefinition();
var toolMetadata = callback.getToolMetadata();

assertThat(toolDefinition.name()).isEqualTo(clientInfo.name() + "_testTool");
assertThat(toolDefinition.description()).isEqualTo("Test tool description");
assertThat(toolMetadata.returnDirect()).isEqualTo(true);
}

@Test
void callShouldThrowOnError() {
when(this.tool.name()).thenReturn("testTool");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.mockito.junit.jupiter.MockitoExtension;

import org.springframework.ai.chat.model.ToolContext;
import org.springframework.ai.content.Content;
import org.springframework.ai.tool.execution.ToolExecutionException;

import static org.assertj.core.api.Assertions.assertThat;
Expand All @@ -53,16 +52,21 @@ class SyncMcpToolCallbackTests {
void getToolDefinitionShouldReturnCorrectDefinition() {

var clientInfo = new Implementation("testClient", "1.0.0");
var toolAnnotations = new McpSchema.ToolAnnotations(null, false, false, false, false, true);

when(this.mcpClient.getClientInfo()).thenReturn(clientInfo);
when(this.tool.name()).thenReturn("testTool");
when(this.tool.description()).thenReturn("Test tool description");
when(this.tool.annotations()).thenReturn(toolAnnotations);

SyncMcpToolCallback callback = new SyncMcpToolCallback(this.mcpClient, this.tool);

var toolDefinition = callback.getToolDefinition();
var toolMetadata = callback.getToolMetadata();

assertThat(toolDefinition.name()).isEqualTo(clientInfo.name() + "_testTool");
assertThat(toolDefinition.description()).isEqualTo("Test tool description");
assertThat(toolMetadata.returnDirect()).isEqualTo(true);
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.modelcontextprotocol.spec.McpSchema.TextContent;
import io.modelcontextprotocol.spec.McpSchema.Tool;
import org.junit.jupiter.api.Test;
import org.springframework.ai.tool.metadata.ToolMetadata;
import reactor.test.StepVerifier;

import org.springframework.ai.tool.ToolCallback;
Expand Down Expand Up @@ -199,8 +200,10 @@ private ToolCallback createMockToolCallback(String name, String result) {
.description("Test tool")
.inputSchema("{}")
.build();
ToolMetadata metadata = ToolMetadata.builder().build();
when(callback.getToolDefinition()).thenReturn(definition);
when(callback.call(anyString(), any())).thenReturn(result);
when(callback.getToolMetadata()).thenReturn(metadata);
return callback;
}

Expand All @@ -211,8 +214,10 @@ private ToolCallback createMockToolCallback(String name, RuntimeException error)
.description("Test tool")
.inputSchema("{}")
.build();
ToolMetadata metadata = ToolMetadata.builder().build();
when(callback.getToolDefinition()).thenReturn(definition);
when(callback.call(anyString(), any())).thenThrow(error);
when(callback.getToolMetadata()).thenReturn(metadata);
return callback;
}

Expand Down