Skip to content

Commit ee19ee1

Browse files
committed
Add high-level function calling support for Anthropic
- simplify and unify the AbstractToolCallSupport::executeFuncitons. Fix filed typos - remove old classes
1 parent 7e98fb7 commit ee19ee1

File tree

6 files changed

+131
-610
lines changed

6 files changed

+131
-610
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

+98-81
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import java.util.HashSet;
2121
import java.util.List;
2222
import java.util.Map;
23-
import java.util.Optional;
2423
import java.util.Set;
2524
import java.util.stream.Collectors;
2625

@@ -34,23 +33,28 @@
3433
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlock.ContentBlockType;
3534
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
3635
import org.springframework.ai.anthropic.metadata.AnthropicChatResponseMetadata;
36+
import org.springframework.ai.chat.messages.AssistantMessage;
37+
import org.springframework.ai.chat.messages.Message;
3738
import org.springframework.ai.chat.messages.MessageType;
39+
import org.springframework.ai.chat.messages.ToolResponseMessage;
3840
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
3941
import org.springframework.ai.chat.model.ChatModel;
4042
import org.springframework.ai.chat.model.ChatResponse;
4143
import org.springframework.ai.chat.model.Generation;
4244
import org.springframework.ai.chat.prompt.ChatOptions;
4345
import org.springframework.ai.chat.prompt.Prompt;
4446
import org.springframework.ai.model.ModelOptionsUtils;
45-
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
47+
import org.springframework.ai.model.function.AbstractToolCallSupport;
4648
import org.springframework.ai.model.function.FunctionCallbackContext;
4749
import org.springframework.ai.retry.RetryUtils;
4850
import org.springframework.http.ResponseEntity;
4951
import org.springframework.retry.support.RetryTemplate;
5052
import org.springframework.util.Assert;
5153
import org.springframework.util.CollectionUtils;
54+
import org.springframework.util.StringUtils;
5255

5356
import reactor.core.publisher.Flux;
57+
import reactor.core.publisher.Mono;
5458

5559
/**
5660
* The {@link ChatModel} implementation for the Anthropic service.
@@ -60,13 +64,11 @@
6064
* @author Mariusz Bernacki
6165
* @since 1.0.0
6266
*/
63-
public class AnthropicChatModel extends
64-
AbstractFunctionCallSupport<AnthropicApi.AnthropicMessage, AnthropicApi.ChatCompletionRequest, ResponseEntity<AnthropicApi.ChatCompletionResponse>>
65-
implements ChatModel {
67+
public class AnthropicChatModel extends AbstractToolCallSupport<ChatCompletionResponse> implements ChatModel {
6668

6769
private static final Logger logger = LoggerFactory.getLogger(AnthropicChatModel.class);
6870

69-
public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_OPUS.getValue();
71+
public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue();
7072

7173
public static final Integer DEFAULT_MAX_TOKENS = 500;
7274

@@ -148,7 +150,14 @@ public ChatResponse call(Prompt prompt) {
148150
ChatCompletionRequest request = createRequest(prompt, false);
149151

150152
return this.retryTemplate.execute(ctx -> {
151-
ResponseEntity<ChatCompletionResponse> completionEntity = this.callWithFunctionSupport(request);
153+
ResponseEntity<ChatCompletionResponse> completionEntity = this.anthropicApi.chatCompletionEntity(request);
154+
155+
if (this.isToolFunctionCall(completionEntity.getBody())) {
156+
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
157+
completionEntity.getBody());
158+
return this.call(new Prompt(toolCallMessageConversation, prompt.getOptions()));
159+
}
160+
152161
return toChatResponse(completionEntity.getBody());
153162
});
154163
}
@@ -162,14 +171,52 @@ public Flux<ChatResponse> stream(Prompt prompt) {
162171

163172
Flux<ChatCompletionResponse> response = this.anthropicApi.chatCompletionStream(request);
164173

165-
return response
166-
.switchMap(chatCompletionResponse -> handleFunctionCallOrReturnStream(request,
167-
Flux.just(ResponseEntity.of(Optional.of(chatCompletionResponse)))))
168-
.map(ResponseEntity::getBody)
169-
.map(this::toChatResponse);
174+
return response.switchMap(chatCompletionResponse -> {
175+
176+
if (this.isToolFunctionCall(chatCompletionResponse)) {
177+
List<Message> toolCallMessageConversation = this.handleToolCallRequests(prompt.getInstructions(),
178+
chatCompletionResponse);
179+
return this.stream(new Prompt(toolCallMessageConversation, prompt.getOptions()));
180+
}
181+
182+
return Mono.just(chatCompletionResponse).map(this::toChatResponse);
183+
});
170184
});
171185
}
172186

187+
private List<Message> handleToolCallRequests(List<Message> previousMessages,
188+
ChatCompletionResponse chatCompletionResponse) {
189+
190+
AnthropicMessage anthropicAssistantMessage = new AnthropicMessage(chatCompletionResponse.content(),
191+
Role.ASSISTANT);
192+
193+
List<ContentBlock> toolToUseList = anthropicAssistantMessage.content()
194+
.stream()
195+
.filter(c -> c.type() == ContentBlock.ContentBlockType.TOOL_USE)
196+
.toList();
197+
198+
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
199+
200+
for (ContentBlock toolToUse : toolToUseList) {
201+
202+
var functionCallId = toolToUse.id();
203+
var functionName = toolToUse.name();
204+
var functionArguments = ModelOptionsUtils.toJsonString(toolToUse.input());
205+
206+
toolCalls.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
207+
}
208+
209+
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
210+
ToolResponseMessage toolResponseMessage = this.executeFuncitons(assistantMessage);
211+
212+
// History
213+
List<Message> toolCallMessageConversation = new ArrayList<>(previousMessages);
214+
toolCallMessageConversation.add(assistantMessage);
215+
toolCallMessageConversation.add(toolResponseMessage);
216+
217+
return toolCallMessageConversation;
218+
}
219+
173220
private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion) {
174221
if (chatCompletion == null) {
175222
logger.warn("Null chat completion returned");
@@ -203,18 +250,45 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
203250

204251
List<AnthropicMessage> userMessages = prompt.getInstructions()
205252
.stream()
206-
.filter(m -> m.getMessageType() != MessageType.SYSTEM)
207-
.map(m -> {
208-
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(m.getContent())));
209-
if (!CollectionUtils.isEmpty(m.getMedia())) {
210-
List<ContentBlock> mediaContent = m.getMedia()
253+
.filter(message -> message.getMessageType() != MessageType.SYSTEM)
254+
.map(message -> {
255+
if (message.getMessageType() == MessageType.USER) {
256+
List<ContentBlock> contents = new ArrayList<>(List.of(new ContentBlock(message.getContent())));
257+
if (!CollectionUtils.isEmpty(message.getMedia())) {
258+
List<ContentBlock> mediaContent = message.getMedia()
259+
.stream()
260+
.map(media -> new ContentBlock(media.getMimeType().toString(),
261+
this.fromMediaData(media.getData())))
262+
.toList();
263+
contents.addAll(mediaContent);
264+
}
265+
return new AnthropicMessage(contents, Role.valueOf(message.getMessageType().name()));
266+
}
267+
else if (message.getMessageType() == MessageType.ASSISTANT) {
268+
AssistantMessage assistantMessage = (AssistantMessage) message;
269+
List<ContentBlock> contentBlocks = new ArrayList<>();
270+
if (StringUtils.hasText(message.getContent())) {
271+
contentBlocks.add(new ContentBlock(message.getContent()));
272+
}
273+
if (!CollectionUtils.isEmpty(assistantMessage.getToolCalls())) {
274+
for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
275+
contentBlocks.add(new ContentBlock(ContentBlockType.TOOL_USE, toolCall.id(),
276+
toolCall.name(), ModelOptionsUtils.jsonToMap(toolCall.arguments())));
277+
}
278+
}
279+
return new AnthropicMessage(contentBlocks, Role.ASSISTANT);
280+
}
281+
else if (message.getMessageType() == MessageType.TOOL) {
282+
List<ContentBlock> toolResponses = ((ToolResponseMessage) message).getResponses()
211283
.stream()
212-
.map(media -> new ContentBlock(media.getMimeType().toString(),
213-
this.fromMediaData(media.getData())))
284+
.map(toolResponse -> new ContentBlock(ContentBlockType.TOOL_RESULT, toolResponse.id(),
285+
toolResponse.responseData()))
214286
.toList();
215-
contents.addAll(mediaContent);
287+
return new AnthropicMessage(toolResponses, Role.USER);
288+
}
289+
else {
290+
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
216291
}
217-
return new AnthropicMessage(contents, Role.valueOf(m.getMessageType().name()));
218292
})
219293
.toList();
220294

@@ -265,74 +339,17 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
265339
}).toList();
266340
}
267341

268-
@Override
269-
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
270-
AnthropicMessage responseMessage, List<AnthropicMessage> conversationHistory) {
271-
272-
List<ContentBlock> toolToUseList = responseMessage.content()
273-
.stream()
274-
.filter(c -> c.type() == ContentBlock.ContentBlockType.TOOL_USE)
275-
.toList();
276-
277-
List<ContentBlock> toolResults = new ArrayList<>();
278-
279-
for (ContentBlock toolToUse : toolToUseList) {
280-
281-
var functionCallId = toolToUse.id();
282-
var functionName = toolToUse.name();
283-
var functionArguments = toolToUse.input();
284-
285-
if (!this.functionCallbackRegister.containsKey(functionName)) {
286-
throw new IllegalStateException("No function callback found for function name: " + functionName);
287-
}
288-
289-
String functionResponse = this.functionCallbackRegister.get(functionName)
290-
.call(ModelOptionsUtils.toJsonString(functionArguments));
291-
292-
toolResults.add(new ContentBlock(ContentBlockType.TOOL_RESULT, functionCallId, functionResponse));
293-
}
294-
295-
// Add the function response to the conversation.
296-
conversationHistory.add(new AnthropicMessage(toolResults, Role.USER));
297-
298-
// Recursively call chatCompletionWithTools until the model doesn't call a
299-
// functions anymore.
300-
return ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
301-
}
302-
303-
@Override
304-
protected List<AnthropicMessage> doGetUserMessages(ChatCompletionRequest request) {
305-
return request.messages();
306-
}
307-
308-
@Override
309-
protected AnthropicMessage doGetToolResponseMessage(ResponseEntity<ChatCompletionResponse> response) {
310-
return new AnthropicMessage(response.getBody().content(), Role.ASSISTANT);
311-
}
312-
313-
@Override
314-
protected ResponseEntity<ChatCompletionResponse> doChatCompletion(ChatCompletionRequest request) {
315-
return this.anthropicApi.chatCompletionEntity(request);
316-
}
317-
318342
@SuppressWarnings("null")
319343
@Override
320-
protected boolean isToolFunctionCall(ResponseEntity<ChatCompletionResponse> response) {
321-
if (response == null || response.getBody() == null || CollectionUtils.isEmpty(response.getBody().content())) {
344+
protected boolean isToolFunctionCall(ChatCompletionResponse response) {
345+
if (response == null || CollectionUtils.isEmpty(response.content())) {
322346
return false;
323347
}
324-
return response.getBody()
325-
.content()
348+
return response.content()
326349
.stream()
327350
.anyMatch(content -> content.type() == ContentBlock.ContentBlockType.TOOL_USE);
328351
}
329352

330-
@Override
331-
protected Flux<ResponseEntity<ChatCompletionResponse>> doChatCompletionStream(ChatCompletionRequest request) {
332-
333-
return this.anthropicApi.chatCompletionStream(request).map(Optional::ofNullable).map(ResponseEntity::of);
334-
}
335-
336353
@Override
337354
public ChatOptions getDefaultOptions() {
338355
return AnthropicChatOptions.fromOptions(this.defaultOptions);

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java

+27-21
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,20 @@
1515
*/
1616
package org.springframework.ai.openai;
1717

18+
import java.util.ArrayList;
19+
import java.util.Base64;
20+
import java.util.HashSet;
21+
import java.util.List;
22+
import java.util.Map;
23+
import java.util.Set;
24+
import java.util.concurrent.ConcurrentHashMap;
25+
1826
import org.slf4j.Logger;
1927
import org.slf4j.LoggerFactory;
2028
import org.springframework.ai.chat.messages.AssistantMessage;
2129
import org.springframework.ai.chat.messages.Message;
2230
import org.springframework.ai.chat.messages.MessageType;
2331
import org.springframework.ai.chat.messages.ToolResponseMessage;
24-
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
2532
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
2633
import org.springframework.ai.chat.metadata.RateLimit;
2734
import org.springframework.ai.chat.model.ChatModel;
@@ -50,17 +57,10 @@
5057
import org.springframework.util.Assert;
5158
import org.springframework.util.CollectionUtils;
5259
import org.springframework.util.MimeType;
60+
5361
import reactor.core.publisher.Flux;
5462
import reactor.core.publisher.Mono;
5563

56-
import java.util.ArrayList;
57-
import java.util.Base64;
58-
import java.util.HashSet;
59-
import java.util.List;
60-
import java.util.Map;
61-
import java.util.Set;
62-
import java.util.concurrent.ConcurrentHashMap;
63-
6464
/**
6565
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI}
6666
* backed by {@link OpenAiApi}.
@@ -266,12 +266,12 @@ private List<Message> handleToolCallRequests(List<Message> previousMessages, Cha
266266
AssistantMessage assistantMessage = new AssistantMessage(nativeAssistantMessage.content(), Map.of(),
267267
assistantToolCalls);
268268

269-
List<ToolResponseMessage> toolResponseMessages = this.executeFuncitons(assistantMessage, false);
269+
ToolResponseMessage toolResponseMessage = this.executeFuncitons(assistantMessage);
270270

271271
// History
272272
List<Message> messages = new ArrayList<>(previousMessages);
273273
messages.add(assistantMessage);
274-
messages.addAll(toolResponseMessages);
274+
messages.add(toolResponseMessage);
275275

276276
return messages;
277277
}
@@ -321,8 +321,8 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
321321
content = contentList;
322322
}
323323

324-
return new ChatCompletionMessage(content,
325-
ChatCompletionMessage.Role.valueOf(message.getMessageType().name()));
324+
return List.of(new ChatCompletionMessage(content,
325+
ChatCompletionMessage.Role.valueOf(message.getMessageType().name())));
326326
}
327327
else if (message.getMessageType() == MessageType.ASSISTANT) {
328328
var assistantMessage = (AssistantMessage) message;
@@ -333,21 +333,27 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
333333
return new ToolCall(toolCall.id(), toolCall.type(), function);
334334
}).toList();
335335
}
336-
return new ChatCompletionMessage(assistantMessage.getContent(), ChatCompletionMessage.Role.ASSISTANT,
337-
null, null, toolCalls);
336+
return List.of(new ChatCompletionMessage(assistantMessage.getContent(),
337+
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls));
338338
}
339339
else if (message.getMessageType() == MessageType.TOOL) {
340340
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
341-
Assert.isTrue(toolMessage.getResponses().size() == 1,
342-
"ToolResponseMessage must have exactly one response");
343-
ToolResponse response = toolMessage.getResponses().get(0);
344-
return new ChatCompletionMessage(response.respoinse(), ChatCompletionMessage.Role.TOOL, response.name(),
345-
response.id(), null);
341+
342+
toolMessage.getResponses().forEach(response -> {
343+
Assert.isTrue(response.id() != null, "ToolResponseMessage must have an id");
344+
Assert.isTrue(response.name() != null, "ToolResponseMessage must have a name");
345+
});
346+
347+
return toolMessage.getResponses()
348+
.stream()
349+
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
350+
tr.id(), null))
351+
.toList();
346352
}
347353
else {
348354
throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType());
349355
}
350-
}).toList();
356+
}).flatMap(List::stream).toList();
351357

352358
ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
353359

models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,12 @@ public List<Message> handleToolCallRequests(List<Message> previousMessages, Gene
198198

199199
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), assistantToolCalls);
200200

201-
List<ToolResponseMessage> toolResponseMessages = this.executeFuncitons(assistantMessage, true);
201+
ToolResponseMessage toolResponseMessage = this.executeFuncitons(assistantMessage);
202202

203203
// History
204204
List<Message> toolCallMessageConversation = new ArrayList<>(previousMessages);
205205
toolCallMessageConversation.add(assistantMessage);
206-
toolCallMessageConversation.addAll(toolResponseMessages);
206+
toolCallMessageConversation.add(toolResponseMessage);
207207
return toolCallMessageConversation;
208208
}
209209

@@ -420,7 +420,7 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) {
420420
.map(response -> Part.newBuilder()
421421
.setFunctionResponse(FunctionResponse.newBuilder()
422422
.setName(response.name())
423-
.setResponse(jsonToStruct(response.respoinse()))
423+
.setResponse(jsonToStruct(response.responseData()))
424424
.build())
425425
.build())
426426
.toList();

0 commit comments

Comments
 (0)