diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 9f2e8a9ebdb..184bbb4464e 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -29,6 +29,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -1793,7 +1794,7 @@ public record ChunkChoice(// @formatter:off @JsonIgnoreProperties(ignoreUnknown = true) public record Embedding(// @formatter:off @JsonProperty("index") Integer index, - @JsonProperty("embedding") float[] embedding, + @JsonProperty("embedding") @JsonDeserialize(using = OpenAiEmbeddingDeserializer.class) float[] embedding, @JsonProperty("object") String object) { // @formatter:on /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiEmbeddingDeserializer.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiEmbeddingDeserializer.java new file mode 100644 index 00000000000..e6ede495398 --- /dev/null +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiEmbeddingDeserializer.java @@ -0,0 +1,72 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.api; + +import com.fasterxml.jackson.core.JacksonException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Base64; + +/** + * Used to deserialize the `embedding` field returned by the model. + *

+ * Supports two input formats: + *

    + *
  1. {@code float[]} - returned directly as-is.
  2. + *
  3. A Base64-encoded string representing a float array stored in little-endian format. + * The string is first decoded into a byte array, then converted into a + * {@code float[]}.
  4. + *
+ * + * @author Sun Yuhan + */ +public class OpenAiEmbeddingDeserializer extends JsonDeserializer { + + @Override + public float[] deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) + throws IOException, JacksonException { + JsonToken token = jsonParser.currentToken(); + if (token == JsonToken.START_ARRAY) { + return jsonParser.readValueAs(float[].class); + } + else if (token == JsonToken.VALUE_STRING) { + String base64 = jsonParser.getValueAsString(); + byte[] decodedBytes = Base64.getDecoder().decode(base64); + + ByteBuffer byteBuffer = ByteBuffer.wrap(decodedBytes); + byteBuffer.order(ByteOrder.LITTLE_ENDIAN); + + int floatCount = decodedBytes.length / Float.BYTES; + float[] embeddingArray = new float[floatCount]; + + for (int i = 0; i < floatCount; i++) { + embeddingArray[i] = byteBuffer.getFloat(); + } + return embeddingArray; + } + else { + throw new IOException("Illegal embedding: " + token); + } + } + +} diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiEmbeddingDeserializerTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiEmbeddingDeserializerTests.java new file mode 100644 index 00000000000..8319cfd2acb --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiEmbeddingDeserializerTests.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.api; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Base64; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link OpenAiEmbeddingDeserializer} + * + * @author Sun Yuhan + */ +class OpenAiEmbeddingDeserializerTests { + + private final OpenAiEmbeddingDeserializer deserializer = new OpenAiEmbeddingDeserializer(); + + private final ObjectMapper mapper = new ObjectMapper(); + + @Test + void testDeserializeFloatArray() throws Exception { + JsonParser parser = mock(JsonParser.class); + DeserializationContext context = mock(DeserializationContext.class); + + when(parser.currentToken()).thenReturn(JsonToken.START_ARRAY); + float[] expected = new float[] { 1.0f, 2.0f, 3.0f }; + when(parser.readValueAs(float[].class)).thenReturn(expected); + + float[] result = deserializer.deserialize(parser, context); + assertArrayEquals(expected, result); + } + + @Test + void testDeserializeBase64String() throws Exception { + float[] original = new float[] { 4.2f, -1.5f, 0.0f }; + ByteBuffer buffer = ByteBuffer.allocate(original.length * Float.BYTES); + buffer.order(ByteOrder.LITTLE_ENDIAN); + for (float v : original) { + buffer.putFloat(v); + } + String base64 = Base64.getEncoder().encodeToString(buffer.array()); + + JsonParser parser = mock(JsonParser.class); + DeserializationContext context = mock(DeserializationContext.class); + + when(parser.currentToken()).thenReturn(JsonToken.VALUE_STRING); + when(parser.getValueAsString()).thenReturn(base64); + + float[] result = deserializer.deserialize(parser, context); + + assertArrayEquals(original, result, 0.0001f); + } + + @Test + void testDeserializeIllegalToken() { + JsonParser parser = mock(JsonParser.class); + DeserializationContext context = mock(DeserializationContext.class); + + when(parser.currentToken()).thenReturn(JsonToken.VALUE_NUMBER_INT); + + IOException e = assertThrows(IOException.class, () -> deserializer.deserialize(parser, context)); + assertTrue(e.getMessage().contains("Illegal embedding")); + } + + @Test + void testDeserializeEmbeddingWithFloatArray() throws Exception { + String json = """ + { + "index": 1, + "embedding": [1.0, 2.0, 3.0], + "object": "embedding" + } + """; + OpenAiApi.Embedding embedding = mapper.readValue(json, OpenAiApi.Embedding.class); + assertEquals(1, embedding.index()); + assertArrayEquals(new float[] { 1.0f, 2.0f, 3.0f }, embedding.embedding(), 0.0001f); + assertEquals("embedding", embedding.object()); + } + + @Test + void testDeserializeEmbeddingWithBase64String() throws Exception { + float[] original = new float[] { 4.2f, -1.5f, 0.0f }; + ByteBuffer buffer = ByteBuffer.allocate(original.length * Float.BYTES); + buffer.order(ByteOrder.LITTLE_ENDIAN); + for (float v : original) + buffer.putFloat(v); + String base64 = Base64.getEncoder().encodeToString(buffer.array()); + + String json = """ + { + "index": 2, + "embedding": "%s", + "object": "embedding" + } + """.formatted(base64); + + OpenAiApi.Embedding embedding = mapper.readValue(json, OpenAiApi.Embedding.class); + assertEquals(2, embedding.index()); + assertArrayEquals(original, embedding.embedding(), 0.0001f); + assertEquals("embedding", embedding.object()); + } + + @Test + void testDeserializeEmbeddingWithWrongType() { + String json = """ + { + "index": 3, + "embedding": 123, + "object": "embedding" + } + """; + JsonProcessingException ex = assertThrows(JsonProcessingException.class, () -> { + mapper.readValue(json, OpenAiApi.Embedding.class); + }); + assertTrue(ex.getMessage().contains("Illegal embedding")); + } + +}