Skip to content

Commit cc25b14

Browse files
committed
Google genai response multimodality support
Author: 楚孔响 <[email protected]> Signed-off-by: 楚孔响 <[email protected]> Signed-off-by: ckx521 <[email protected]>
1 parent eda3c74 commit cc25b14

File tree

3 files changed

+63
-5
lines changed

3 files changed

+63
-5
lines changed

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
import org.springframework.util.Assert;
9292
import org.springframework.util.CollectionUtils;
9393
import org.springframework.util.StringUtils;
94+
import org.springframework.util.MimeType;
9495

9596
/**
9697
* Google GenAI Chat Model implementation that provides access to Google's Gemini language
@@ -626,7 +627,19 @@ protected List<Generation> responseCandidateToGeneration(Candidate candidate) {
626627
.parts()
627628
.orElse(List.of())
628629
.stream()
629-
.map(part -> new AssistantMessage(part.text().orElse(""), messageMetadata))
630+
.map(part -> {
631+
// Multimodality Response Support
632+
List<Media> media = part.inlineData()
633+
.filter(blob -> blob.data().isPresent() && blob.mimeType().isPresent())
634+
.map(blob -> Media
635+
.builder()
636+
.mimeType(MimeType.valueOf(blob.mimeType().get()))
637+
.data(blob.data().get())
638+
.build())
639+
.map(List::of)
640+
.orElse(List.of());
641+
return new AssistantMessage(part.text().orElse(""), messageMetadata, List.of(), media);
642+
})
630643
.map(assistantMessage -> new Generation(assistantMessage, chatGenerationMetadata))
631644
.toList();
632645
}
@@ -725,6 +738,10 @@ GeminiRequest createGeminiRequest(Prompt prompt) {
725738
configBuilder.systemInstruction(systemContents.get(0));
726739
}
727740

741+
if (!CollectionUtils.isEmpty(requestOptions.getResponseModalities())) {
742+
configBuilder.responseModalities(requestOptions.getResponseModalities());
743+
}
744+
728745
GenerateContentConfig config = configBuilder.build();
729746

730747
// Create message contents
@@ -850,7 +867,7 @@ public static final class Builder {
850867
private GoogleGenAiChatOptions defaultOptions = GoogleGenAiChatOptions.builder()
851868
.temperature(0.7)
852869
.topP(1.0)
853-
.model(GoogleGenAiChatModel.ChatModel.GEMINI_2_0_FLASH)
870+
.model(ChatModel.GEMINI_2_0_FLASH)
854871
.build();
855872

856873
private ToolCallingManager toolCallingManager;

models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatOptions.java

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ public class GoogleGenAiChatOptions implements ToolCallingChatOptions {
113113
*/
114114
private @JsonProperty("thinkingBudget") Integer thinkingBudget;
115115

116+
/**
117+
* Optional. Response Modalities.
118+
* @see com.google.genai.types.Modality.Known
119+
*/
120+
private @JsonProperty("responseModalities") List<String> responseModalities = new ArrayList<>();
121+
116122
/**
117123
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
118124
* completion requests.
@@ -174,6 +180,7 @@ public static GoogleGenAiChatOptions fromOptions(GoogleGenAiChatOptions fromOpti
174180
options.setToolContext(fromOptions.getToolContext());
175181
options.setThinkingBudget(fromOptions.getThinkingBudget());
176182
options.setLabels(fromOptions.getLabels());
183+
options.setResponseModalities(fromOptions.getResponseModalities());
177184
return options;
178185
}
179186

@@ -355,6 +362,15 @@ public void setToolContext(Map<String, Object> toolContext) {
355362
this.toolContext = toolContext;
356363
}
357364

365+
public List<String> getResponseModalities() {
366+
return responseModalities;
367+
}
368+
369+
public void setResponseModalities(List<String> responseModalities) {
370+
Assert.notNull(responseModalities, "responseModalities cannot be null");
371+
this.responseModalities = responseModalities;
372+
}
373+
358374
@Override
359375
public boolean equals(Object o) {
360376
if (this == o) {
@@ -376,15 +392,17 @@ public boolean equals(Object o) {
376392
&& Objects.equals(this.toolNames, that.toolNames)
377393
&& Objects.equals(this.safetySettings, that.safetySettings)
378394
&& Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled)
379-
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels);
395+
&& Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.labels, that.labels)
396+
&& Objects.equals(this.responseModalities, that.responseModalities);
380397
}
381398

382399
@Override
383400
public int hashCode() {
384401
return Objects.hash(this.stopSequences, this.temperature, this.topP, this.topK, this.candidateCount,
385402
this.frequencyPenalty, this.presencePenalty, this.thinkingBudget, this.maxOutputTokens, this.model,
386403
this.responseMimeType, this.toolCallbacks, this.toolNames, this.googleSearchRetrieval,
387-
this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels);
404+
this.safetySettings, this.internalToolExecutionEnabled, this.toolContext, this.labels,
405+
this.responseModalities);
388406
}
389407

390408
@Override
@@ -396,7 +414,7 @@ public String toString() {
396414
+ this.model + '\'' + ", responseMimeType='" + this.responseMimeType + '\'' + ", toolCallbacks="
397415
+ this.toolCallbacks + ", toolNames=" + this.toolNames + ", googleSearchRetrieval="
398416
+ this.googleSearchRetrieval + ", safetySettings=" + this.safetySettings + ", labels=" + this.labels
399-
+ '}';
417+
+ ", responseModalities=" + this.responseModalities + '}';
400418
}
401419

402420
@Override
@@ -530,6 +548,18 @@ public Builder labels(Map<String, String> labels) {
530548
return this;
531549
}
532550

551+
public Builder responseModalities(List<String> responseModalities) {
552+
Assert.notNull(responseModalities, "responseModalities must not be null");
553+
this.options.responseModalities = responseModalities;
554+
return this;
555+
}
556+
557+
public Builder responseModalitie(String responseModalitie) {
558+
Assert.hasText(responseModalitie, "responseModalitie must not be empty");
559+
this.options.responseModalities.add(responseModalitie);
560+
return this;
561+
}
562+
533563
public GoogleGenAiChatOptions build() {
534564
return this.options;
535565
}

models/spring-ai-google-genai/src/test/java/org/springframework/ai/google/genai/GoogleGenAiChatOptionsTest.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
package org.springframework.ai.google.genai;
1818

19+
import java.util.List;
1920
import java.util.Map;
2021

22+
import com.google.genai.types.Modality;
2123
import org.junit.jupiter.api.Test;
2224

2325
import static org.assertj.core.api.Assertions.assertThat;
@@ -153,4 +155,13 @@ public void testToStringWithLabels() {
153155
assertThat(toString).contains("test-model");
154156
}
155157

158+
@Test
159+
public void testResponseMultimodality() {
160+
GoogleGenAiChatOptions options = GoogleGenAiChatOptions.builder()
161+
.responseModalities(List.of(Modality.Known.TEXT.name(), Modality.Known.IMAGE.name()))
162+
.build();
163+
String toString = options.toString();
164+
assertThat(toString).contains("responseModalities=[TEXT, IMAGE]");
165+
}
166+
156167
}

0 commit comments

Comments
 (0)