From 5c8c3daf95ef5e1d148327a89ae54605a56aba78 Mon Sep 17 00:00:00 2001 From: YunKui Lu Date: Sat, 5 Jul 2025 00:43:27 +0800 Subject: [PATCH] feat(zhipuai): Support glm-4.1v-thinking-flash model - Add reasoning_content field to ChatCompletionMessage - Add ZhiPuAiAssistantMessage as a subclass of AssistantMessage to support returning chain-of-thought (CoT) content - Add integration tests - Fix "RestClient.Builder bean not found" exception in the ZhiPuAI image auto-configuration Signed-off-by: YunKui Lu --- .../ZhiPuAiImageAutoConfiguration.java | 9 +- .../ai/zhipuai/ZhiPuAiAssistantMessage.java | 80 ++++++++++++++ .../ai/zhipuai/ZhiPuAiChatModel.java | 12 ++- .../ai/zhipuai/api/ZhiPuAiApi.java | 15 ++- .../ZhiPuAiStreamFunctionCallingHelper.java | 6 +- .../api/ZhiPuAiApiToolFunctionCallIT.java | 2 +- .../ai/zhipuai/chat/ZhiPuAiChatModelIT.java | 100 ++++++++++++++++-- 7 files changed, 205 insertions(+), 19 deletions(-) create mode 100644 models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiAssistantMessage.java diff --git a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java index 4bf5a17e2d7..ea89f6198cb 100644 --- a/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java +++ b/auto-configurations/models/spring-ai-autoconfigure-model-zhipuai/src/main/java/org/springframework/ai/model/zhipuai/autoconfigure/ZhiPuAiImageAutoConfiguration.java @@ -22,6 +22,7 @@ import org.springframework.ai.zhipuai.ZhiPuAiImageModel; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; +import org.springframework.beans.factory.ObjectProvider; 
import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -51,8 +52,8 @@ public class ZhiPuAiImageAutoConfiguration { @Bean @ConditionalOnMissingBean public ZhiPuAiImageModel zhiPuAiImageModel(ZhiPuAiConnectionProperties commonProperties, - ZhiPuAiImageProperties imageProperties, RestClient.Builder restClientBuilder, RetryTemplate retryTemplate, - ResponseErrorHandler responseErrorHandler) { + ZhiPuAiImageProperties imageProperties, ObjectProvider restClientBuilderProvider, + RetryTemplate retryTemplate, ResponseErrorHandler responseErrorHandler) { String apiKey = StringUtils.hasText(imageProperties.getApiKey()) ? imageProperties.getApiKey() : commonProperties.getApiKey(); @@ -63,7 +64,9 @@ public ZhiPuAiImageModel zhiPuAiImageModel(ZhiPuAiConnectionProperties commonPro Assert.hasText(apiKey, "ZhiPuAI API key must be set"); Assert.hasText(baseUrl, "ZhiPuAI base URL must be set"); - var zhiPuAiImageApi = new ZhiPuAiImageApi(baseUrl, apiKey, restClientBuilder, responseErrorHandler); + // TODO add ZhiPuAiApi support for image + var zhiPuAiImageApi = new ZhiPuAiImageApi(baseUrl, apiKey, + restClientBuilderProvider.getIfAvailable(RestClient::builder), responseErrorHandler); return new ZhiPuAiImageModel(zhiPuAiImageApi, imageProperties.getOptions(), retryTemplate); } diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiAssistantMessage.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiAssistantMessage.java new file mode 100644 index 00000000000..db4ec584c9e --- /dev/null +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiAssistantMessage.java @@ -0,0 +1,80 @@ +/* + * Copyright 2025-2025 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.zhipuai; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.content.Media; + +/** + * @author YunKui Lu + */ +public class ZhiPuAiAssistantMessage extends AssistantMessage { + + /** + * The CoT content of the message. + */ + private String reasoningContent; + + public ZhiPuAiAssistantMessage(String content) { + super(content); + } + + public ZhiPuAiAssistantMessage(String content, String reasoningContent, Map properties, + List toolCalls, List media) { + super(content, properties, toolCalls, media); + this.reasoningContent = reasoningContent; + } + + public String getReasoningContent() { + return reasoningContent; + } + + public ZhiPuAiAssistantMessage setReasoningContent(String reasoningContent) { + this.reasoningContent = reasoningContent; + return this; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ZhiPuAiAssistantMessage that)) { + return false; + } + if (!super.equals(o)) { + return false; + } + return Objects.equals(reasoningContent, that.reasoningContent); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), reasoningContent); + } + + @Override + public String toString() { + return "ZhiPuAiAssistantMessage{" + "media=" + media + ", messageType=" + 
messageType + ", metadata=" + metadata + + ", reasoningContent='" + reasoningContent + '\'' + ", textContent='" + textContent + '\'' + '}'; + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 01402acc36a..74f627e83db 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -227,7 +227,11 @@ private static Generation buildGeneration(Choice choice, Map met toolCall.function().name(), toolCall.function().arguments())) .toList(); - var assistantMessage = new AssistantMessage(choice.message().content(), metadata, toolCalls); + String textContent = choice.message().content(); + String reasoningContent = choice.message().reasoningContent(); + + var assistantMessage = new ZhiPuAiAssistantMessage(textContent, reasoningContent, metadata, toolCalls, + List.of()); String finishReason = (choice.finishReason() != null ? 
choice.finishReason().name() : ""); var generationMetadata = ChatGenerationMetadata.builder().finishReason(finishReason).build(); return new Generation(assistantMessage, generationMetadata); @@ -510,7 +514,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { }).toList(); } return List.of(new ChatCompletionMessage(assistantMessage.getText(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls)); + ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; @@ -521,7 +525,7 @@ else if (message.getMessageType() == MessageType.TOOL) { return toolMessage.getResponses() .stream() .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), - tr.id(), null)) + tr.id(), null, null)) .toList(); } else { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java index 10213b9734b..d632249bb3d 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiApi.java @@ -379,7 +379,15 @@ public enum ChatModel implements ChatModelDescription { GLM_4_Flash("glm-4-flash"), - GLM_3_Turbo("GLM-3-Turbo"); // @formatter:on + GLM_3_Turbo("GLM-3-Turbo"), + + // --- Visual Reasoning Models --- + + GLM_4_Thinking_FlashX("glm-4.1v-thinking-flashx"), + + GLM_4_Thinking_Flash("glm-4.1v-thinking-flash"), + + ; // @formatter:on public final String value; @@ -774,7 +782,8 @@ public record ChatCompletionMessage(// @formatter:off @JsonProperty("role") Role role, @JsonProperty("name") String name, @JsonProperty("tool_call_id") String toolCallId, - @JsonProperty("tool_calls") List toolCalls) { // @formatter:on + @JsonProperty("tool_calls") List 
toolCalls, + @JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on /** * Create a chat completion message with the given content and role. All other @@ -783,7 +792,7 @@ public record ChatCompletionMessage(// @formatter:off * @param role The role of the author of this message. */ public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null); + this(content, role, null, null, null, null); } /** diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java index e4629e94b49..e3afcf6407e 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/api/ZhiPuAiStreamFunctionCallingHelper.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -85,6 +85,8 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { String content = (current.content() != null ? current.content() : (previous.content() != null) ? previous.content() : ""); + String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent() + : (previous.reasoningContent() != null ? previous.reasoningContent() : "")); Role role = (current.role() != null ? current.role() : previous.role()); role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null String name = (current.name() != null ? 
current.name() : previous.name()); @@ -118,7 +120,7 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti toolCalls.add(lastPreviousTooCall); } } - return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls); + return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, reasoningContent); } private ToolCall merge(ToolCall previous, ToolCall current) { diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java index 05e4341ba2d..4d154750f17 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/api/ZhiPuAiApiToolFunctionCallIT.java @@ -124,7 +124,7 @@ public void toolFunctionCall() { // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, - functionName, toolCall.id(), null)); + functionName, toolCall.id(), null, null)); } } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java index c18e22c0e0a..15f5ac4195c 100644 --- a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/chat/ZhiPuAiChatModelIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -49,6 +49,7 @@ import org.springframework.ai.converter.ListOutputConverter; import org.springframework.ai.converter.MapOutputConverter; import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.zhipuai.ZhiPuAiAssistantMessage; import org.springframework.ai.zhipuai.ZhiPuAiChatOptions; import org.springframework.ai.zhipuai.ZhiPuAiTestConfiguration; import org.springframework.ai.zhipuai.api.MockWeatherService; @@ -60,6 +61,7 @@ import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.Resource; import org.springframework.util.MimeTypeUtils; +import org.springframework.util.StringUtils; import static org.assertj.core.api.Assertions.assertThat; @@ -312,7 +314,29 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "glm-4v" }) + @ValueSource(strings = { "glm-4.1v-thinking-flash" }) + void reasonerMultiModalityEmbeddedImageThinkingModel(String modelName) throws IOException { + var imageData = new ClassPathResource("/test.png"); + + var userMessage = UserMessage.builder() + .text("Explain what do you see on this picture?") + .media(List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData))) + .build(); + + var response = this.chatModel + .call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().model(modelName).build())); + + logger.info(response.getResult().getOutput().getText()); + assertThat(response.getResult().getOutput().getText()).containsAnyOf("bananas", "apple", "bowl", "basket", + "fruit stand"); + + logger.info(((ZhiPuAiAssistantMessage) response.getResult().getOutput()).getReasoningContent()); + assertThat(((ZhiPuAiAssistantMessage) response.getResult().getOutput()).getReasoningContent()) + .containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "glm-4v", "glm-4.1v-thinking-flash" }) void 
multiModalityImageUrl(String modelName) throws IOException { var userMessage = UserMessage.builder() @@ -331,8 +355,9 @@ void multiModalityImageUrl(String modelName) throws IOException { "fruit stand"); } - @Test - void streamingMultiModalityImageUrl() throws IOException { + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "glm-4.1v-thinking-flash" }) + void reasonerMultiModalityImageUrl(String modelName) throws IOException { var userMessage = UserMessage.builder() .text("Explain what do you see on this picture?") @@ -342,8 +367,32 @@ void streamingMultiModalityImageUrl() throws IOException { .build())) .build(); - Flux response = this.streamingChatModel.stream(new Prompt(List.of(userMessage), - ZhiPuAiChatOptions.builder().model(ZhiPuAiApi.ChatModel.GLM_4V.getValue()).build())); + ChatResponse response = this.chatModel + .call(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().model(modelName).build())); + + logger.info(response.getResult().getOutput().getText()); + assertThat(response.getResult().getOutput().getText()).containsAnyOf("bananas", "apple", "bowl", "basket", + "fruit stand"); + + logger.info(((ZhiPuAiAssistantMessage) response.getResult().getOutput()).getReasoningContent()); + assertThat(((ZhiPuAiAssistantMessage) response.getResult().getOutput()).getReasoningContent()) + .containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "glm-4v" }) + void streamingMultiModalityImageUrl(String modelName) throws IOException { + + var userMessage = UserMessage.builder() + .text("Explain what do you see on this picture?") + .media(List.of(Media.builder() + .mimeType(MimeTypeUtils.IMAGE_PNG) + .data(URI.create("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")) + .build())) + .build(); + + Flux response = this.streamingChatModel + .stream(new Prompt(List.of(userMessage), 
ZhiPuAiChatOptions.builder().model(modelName).build())); String content = Objects.requireNonNull(response.collectList().block()) .stream() @@ -356,6 +405,45 @@ void streamingMultiModalityImageUrl() throws IOException { assertThat(content).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "glm-4.1v-thinking-flash" }) + void reasonerStreamingMultiModalityImageUrl(String modelName) throws IOException { + + var userMessage = UserMessage.builder() + .text("Explain what do you see on this picture?") + .media(List.of(Media.builder() + .mimeType(MimeTypeUtils.IMAGE_PNG) + .data(URI.create("https://docs.spring.io/spring-ai/reference/_images/multimodal.test.png")) + .build())) + .build(); + + Flux response = this.streamingChatModel + .stream(new Prompt(List.of(userMessage), ZhiPuAiChatOptions.builder().model(modelName).build())); + + List streamingMessages = Objects.requireNonNull(response.collectList().block()) + .stream() + .map(ChatResponse::getResults) + .flatMap(List::stream) + .map(m -> (ZhiPuAiAssistantMessage) m.getOutput()) + .toList(); + + String reasoningContent = streamingMessages.stream() + .map(ZhiPuAiAssistantMessage::getReasoningContent) + .filter(StringUtils::hasText) + .collect(Collectors.joining()); + + String content = streamingMessages.stream() + .map(AssistantMessage::getText) + .filter(StringUtils::hasText) + .collect(Collectors.joining()); + + logger.info("CoT: {}", reasoningContent); + assertThat(reasoningContent).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); + + logger.info("Response: {}", content); + assertThat(content).containsAnyOf("bananas", "apple", "bowl", "basket", "fruit stand"); + } + record ActorsFilmsRecord(String actor, List movies) { }