20
20
import java .util .HashSet ;
21
21
import java .util .List ;
22
22
import java .util .Map ;
23
- import java .util .Optional ;
24
23
import java .util .Set ;
25
24
import java .util .stream .Collectors ;
26
25
34
33
import org .springframework .ai .anthropic .api .AnthropicApi .ContentBlock .ContentBlockType ;
35
34
import org .springframework .ai .anthropic .api .AnthropicApi .Role ;
36
35
import org .springframework .ai .anthropic .metadata .AnthropicChatResponseMetadata ;
36
+ import org .springframework .ai .chat .messages .AssistantMessage ;
37
+ import org .springframework .ai .chat .messages .Message ;
37
38
import org .springframework .ai .chat .messages .MessageType ;
39
+ import org .springframework .ai .chat .messages .ToolResponseMessage ;
38
40
import org .springframework .ai .chat .metadata .ChatGenerationMetadata ;
39
41
import org .springframework .ai .chat .model .ChatModel ;
40
42
import org .springframework .ai .chat .model .ChatResponse ;
41
43
import org .springframework .ai .chat .model .Generation ;
42
44
import org .springframework .ai .chat .prompt .ChatOptions ;
43
45
import org .springframework .ai .chat .prompt .Prompt ;
44
46
import org .springframework .ai .model .ModelOptionsUtils ;
45
- import org .springframework .ai .model .function .AbstractFunctionCallSupport ;
47
+ import org .springframework .ai .model .function .AbstractToolCallSupport ;
46
48
import org .springframework .ai .model .function .FunctionCallbackContext ;
47
49
import org .springframework .ai .retry .RetryUtils ;
48
50
import org .springframework .http .ResponseEntity ;
49
51
import org .springframework .retry .support .RetryTemplate ;
50
52
import org .springframework .util .Assert ;
51
53
import org .springframework .util .CollectionUtils ;
54
+ import org .springframework .util .StringUtils ;
52
55
53
56
import reactor .core .publisher .Flux ;
57
+ import reactor .core .publisher .Mono ;
54
58
55
59
/**
56
60
* The {@link ChatModel} implementation for the Anthropic service.
60
64
* @author Mariusz Bernacki
61
65
* @since 1.0.0
62
66
*/
63
- public class AnthropicChatModel extends
64
- AbstractFunctionCallSupport <AnthropicApi .AnthropicMessage , AnthropicApi .ChatCompletionRequest , ResponseEntity <AnthropicApi .ChatCompletionResponse >>
65
- implements ChatModel {
67
+ public class AnthropicChatModel extends AbstractToolCallSupport <ChatCompletionResponse > implements ChatModel {
66
68
67
69
private static final Logger logger = LoggerFactory .getLogger (AnthropicChatModel .class );
68
70
69
- public static final String DEFAULT_MODEL_NAME = AnthropicApi .ChatModel .CLAUDE_3_OPUS .getValue ();
71
+ public static final String DEFAULT_MODEL_NAME = AnthropicApi .ChatModel .CLAUDE_3_5_SONNET .getValue ();
70
72
71
73
public static final Integer DEFAULT_MAX_TOKENS = 500 ;
72
74
@@ -148,7 +150,14 @@ public ChatResponse call(Prompt prompt) {
148
150
ChatCompletionRequest request = createRequest (prompt , false );
149
151
150
152
return this .retryTemplate .execute (ctx -> {
151
- ResponseEntity <ChatCompletionResponse > completionEntity = this .callWithFunctionSupport (request );
153
+ ResponseEntity <ChatCompletionResponse > completionEntity = this .anthropicApi .chatCompletionEntity (request );
154
+
155
+ if (this .isToolFunctionCall (completionEntity .getBody ())) {
156
+ List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
157
+ completionEntity .getBody ());
158
+ return this .call (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
159
+ }
160
+
152
161
return toChatResponse (completionEntity .getBody ());
153
162
});
154
163
}
@@ -162,14 +171,52 @@ public Flux<ChatResponse> stream(Prompt prompt) {
162
171
163
172
Flux <ChatCompletionResponse > response = this .anthropicApi .chatCompletionStream (request );
164
173
165
- return response
166
- .switchMap (chatCompletionResponse -> handleFunctionCallOrReturnStream (request ,
167
- Flux .just (ResponseEntity .of (Optional .of (chatCompletionResponse )))))
168
- .map (ResponseEntity ::getBody )
169
- .map (this ::toChatResponse );
174
+ return response .switchMap (chatCompletionResponse -> {
175
+
176
+ if (this .isToolFunctionCall (chatCompletionResponse )) {
177
+ List <Message > toolCallMessageConversation = this .handleToolCallRequests (prompt .getInstructions (),
178
+ chatCompletionResponse );
179
+ return this .stream (new Prompt (toolCallMessageConversation , prompt .getOptions ()));
180
+ }
181
+
182
+ return Mono .just (chatCompletionResponse ).map (this ::toChatResponse );
183
+ });
170
184
});
171
185
}
172
186
187
+ private List <Message > handleToolCallRequests (List <Message > previousMessages ,
188
+ ChatCompletionResponse chatCompletionResponse ) {
189
+
190
+ AnthropicMessage anthropicAssistantMessage = new AnthropicMessage (chatCompletionResponse .content (),
191
+ Role .ASSISTANT );
192
+
193
+ List <ContentBlock > toolToUseList = anthropicAssistantMessage .content ()
194
+ .stream ()
195
+ .filter (c -> c .type () == ContentBlock .ContentBlockType .TOOL_USE )
196
+ .toList ();
197
+
198
+ List <AssistantMessage .ToolCall > toolCalls = new ArrayList <>();
199
+
200
+ for (ContentBlock toolToUse : toolToUseList ) {
201
+
202
+ var functionCallId = toolToUse .id ();
203
+ var functionName = toolToUse .name ();
204
+ var functionArguments = ModelOptionsUtils .toJsonString (toolToUse .input ());
205
+
206
+ toolCalls .add (new AssistantMessage .ToolCall (functionCallId , "function" , functionName , functionArguments ));
207
+ }
208
+
209
+ AssistantMessage assistantMessage = new AssistantMessage ("" , Map .of (), toolCalls );
210
+ ToolResponseMessage toolResponseMessage = this .executeFuncitons (assistantMessage );
211
+
212
+ // History
213
+ List <Message > toolCallMessageConversation = new ArrayList <>(previousMessages );
214
+ toolCallMessageConversation .add (assistantMessage );
215
+ toolCallMessageConversation .add (toolResponseMessage );
216
+
217
+ return toolCallMessageConversation ;
218
+ }
219
+
173
220
private ChatResponse toChatResponse (ChatCompletionResponse chatCompletion ) {
174
221
if (chatCompletion == null ) {
175
222
logger .warn ("Null chat completion returned" );
@@ -203,18 +250,45 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
203
250
204
251
List <AnthropicMessage > userMessages = prompt .getInstructions ()
205
252
.stream ()
206
- .filter (m -> m .getMessageType () != MessageType .SYSTEM )
207
- .map (m -> {
208
- List <ContentBlock > contents = new ArrayList <>(List .of (new ContentBlock (m .getContent ())));
209
- if (!CollectionUtils .isEmpty (m .getMedia ())) {
210
- List <ContentBlock > mediaContent = m .getMedia ()
253
+ .filter (message -> message .getMessageType () != MessageType .SYSTEM )
254
+ .map (message -> {
255
+ if (message .getMessageType () == MessageType .USER ) {
256
+ List <ContentBlock > contents = new ArrayList <>(List .of (new ContentBlock (message .getContent ())));
257
+ if (!CollectionUtils .isEmpty (message .getMedia ())) {
258
+ List <ContentBlock > mediaContent = message .getMedia ()
259
+ .stream ()
260
+ .map (media -> new ContentBlock (media .getMimeType ().toString (),
261
+ this .fromMediaData (media .getData ())))
262
+ .toList ();
263
+ contents .addAll (mediaContent );
264
+ }
265
+ return new AnthropicMessage (contents , Role .valueOf (message .getMessageType ().name ()));
266
+ }
267
+ else if (message .getMessageType () == MessageType .ASSISTANT ) {
268
+ AssistantMessage assistantMessage = (AssistantMessage ) message ;
269
+ List <ContentBlock > contentBlocks = new ArrayList <>();
270
+ if (StringUtils .hasText (message .getContent ())) {
271
+ contentBlocks .add (new ContentBlock (message .getContent ()));
272
+ }
273
+ if (!CollectionUtils .isEmpty (assistantMessage .getToolCalls ())) {
274
+ for (AssistantMessage .ToolCall toolCall : assistantMessage .getToolCalls ()) {
275
+ contentBlocks .add (new ContentBlock (ContentBlockType .TOOL_USE , toolCall .id (),
276
+ toolCall .name (), ModelOptionsUtils .jsonToMap (toolCall .arguments ())));
277
+ }
278
+ }
279
+ return new AnthropicMessage (contentBlocks , Role .ASSISTANT );
280
+ }
281
+ else if (message .getMessageType () == MessageType .TOOL ) {
282
+ List <ContentBlock > toolResponses = ((ToolResponseMessage ) message ).getResponses ()
211
283
.stream ()
212
- .map (media -> new ContentBlock (media . getMimeType (). toString (),
213
- this . fromMediaData ( media . getData () )))
284
+ .map (toolResponse -> new ContentBlock (ContentBlockType . TOOL_RESULT , toolResponse . id (),
285
+ toolResponse . responseData ( )))
214
286
.toList ();
215
- contents .addAll (mediaContent );
287
+ return new AnthropicMessage (toolResponses , Role .USER );
288
+ }
289
+ else {
290
+ throw new IllegalArgumentException ("Unsupported message type: " + message .getMessageType ());
216
291
}
217
- return new AnthropicMessage (contents , Role .valueOf (m .getMessageType ().name ()));
218
292
})
219
293
.toList ();
220
294
@@ -265,74 +339,17 @@ private List<AnthropicApi.Tool> getFunctionTools(Set<String> functionNames) {
265
339
}).toList ();
266
340
}
267
341
268
- @ Override
269
- protected ChatCompletionRequest doCreateToolResponseRequest (ChatCompletionRequest previousRequest ,
270
- AnthropicMessage responseMessage , List <AnthropicMessage > conversationHistory ) {
271
-
272
- List <ContentBlock > toolToUseList = responseMessage .content ()
273
- .stream ()
274
- .filter (c -> c .type () == ContentBlock .ContentBlockType .TOOL_USE )
275
- .toList ();
276
-
277
- List <ContentBlock > toolResults = new ArrayList <>();
278
-
279
- for (ContentBlock toolToUse : toolToUseList ) {
280
-
281
- var functionCallId = toolToUse .id ();
282
- var functionName = toolToUse .name ();
283
- var functionArguments = toolToUse .input ();
284
-
285
- if (!this .functionCallbackRegister .containsKey (functionName )) {
286
- throw new IllegalStateException ("No function callback found for function name: " + functionName );
287
- }
288
-
289
- String functionResponse = this .functionCallbackRegister .get (functionName )
290
- .call (ModelOptionsUtils .toJsonString (functionArguments ));
291
-
292
- toolResults .add (new ContentBlock (ContentBlockType .TOOL_RESULT , functionCallId , functionResponse ));
293
- }
294
-
295
- // Add the function response to the conversation.
296
- conversationHistory .add (new AnthropicMessage (toolResults , Role .USER ));
297
-
298
- // Recursively call chatCompletionWithTools until the model doesn't call a
299
- // functions anymore.
300
- return ChatCompletionRequest .from (previousRequest ).withMessages (conversationHistory ).build ();
301
- }
302
-
303
- @ Override
304
- protected List <AnthropicMessage > doGetUserMessages (ChatCompletionRequest request ) {
305
- return request .messages ();
306
- }
307
-
308
- @ Override
309
- protected AnthropicMessage doGetToolResponseMessage (ResponseEntity <ChatCompletionResponse > response ) {
310
- return new AnthropicMessage (response .getBody ().content (), Role .ASSISTANT );
311
- }
312
-
313
- @ Override
314
- protected ResponseEntity <ChatCompletionResponse > doChatCompletion (ChatCompletionRequest request ) {
315
- return this .anthropicApi .chatCompletionEntity (request );
316
- }
317
-
318
342
@ SuppressWarnings ("null" )
319
343
@ Override
320
- protected boolean isToolFunctionCall (ResponseEntity < ChatCompletionResponse > response ) {
321
- if (response == null || response . getBody () == null || CollectionUtils .isEmpty (response . getBody () .content ())) {
344
+ protected boolean isToolFunctionCall (ChatCompletionResponse response ) {
345
+ if (response == null || CollectionUtils .isEmpty (response .content ())) {
322
346
return false ;
323
347
}
324
- return response .getBody ()
325
- .content ()
348
+ return response .content ()
326
349
.stream ()
327
350
.anyMatch (content -> content .type () == ContentBlock .ContentBlockType .TOOL_USE );
328
351
}
329
352
330
- @ Override
331
- protected Flux <ResponseEntity <ChatCompletionResponse >> doChatCompletionStream (ChatCompletionRequest request ) {
332
-
333
- return this .anthropicApi .chatCompletionStream (request ).map (Optional ::ofNullable ).map (ResponseEntity ::of );
334
- }
335
-
336
353
@ Override
337
354
public ChatOptions getDefaultOptions () {
338
355
return AnthropicChatOptions .fromOptions (this .defaultOptions );
0 commit comments