Skip to content

Add observability to Bedrock Titan Embedding model #3014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@
import org.springframework.web.client.ResponseErrorHandler;

/**
* {@link AutoConfiguration Auto-configuration} for AI Retry.
* {@link AutoConfiguration Auto-configuration} for AI Retry. Provides beans for retry
* template and response error handling. Handles transient and non-transient exceptions
* based on HTTP status codes.
*
* @author Christian Tzolov
* Author: Christian Tzolov
*/
@AutoConfiguration
@ConditionalOnClass(RetryUtils.class)
Expand All @@ -63,9 +65,10 @@ public RetryTemplate retryTemplate(SpringAiRetryProperties properties) {
.withListener(new RetryListener() {

@Override
public <T extends Object, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
public <T, E extends Throwable> void onError(RetryContext context, RetryCallback<T, E> callback,
Throwable throwable) {
logger.warn("Retry error. Retry count: {}, Exception: {}", context.getRetryCount(),
throwable.getMessage(), throwable);
}
})
.build();
Expand All @@ -84,29 +87,35 @@ public boolean hasError(@NonNull ClientHttpResponse response) throws IOException

@Override
public void handleError(@NonNull ClientHttpResponse response) throws IOException {
if (response.getStatusCode().isError()) {
String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
String message = String.format("%s - %s", response.getStatusCode().value(), error);

// Explicitly configured transient codes
if (properties.getOnHttpCodes().contains(response.getStatusCode().value())) {
throw new TransientAiException(message);
}

// onClientErrors - If true, do not throw a NonTransientAiException,
// and do not attempt retry for 4xx client error codes, false by
// default.
if (!properties.isOnClientErrors() && response.getStatusCode().is4xxClientError()) {
throw new NonTransientAiException(message);
}

// Explicitly configured non-transient codes
if (!CollectionUtils.isEmpty(properties.getExcludeOnHttpCodes())
&& properties.getExcludeOnHttpCodes().contains(response.getStatusCode().value())) {
throw new NonTransientAiException(message);
}
if (!response.getStatusCode().isError()) {
return;
}

String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
if (error == null || error.isEmpty()) {
error = "No response body available";
}

String message = String.format("HTTP %s - %s", response.getStatusCode().value(), error);

// Explicitly configured transient codes
if (properties.getOnHttpCodes().contains(response.getStatusCode().value())) {
throw new TransientAiException(message);
}

// Handle client errors (4xx)
if (!properties.isOnClientErrors() && response.getStatusCode().is4xxClientError()) {
throw new NonTransientAiException(message);
}

// Explicitly configured non-transient codes
if (!CollectionUtils.isEmpty(properties.getExcludeOnHttpCodes())
&& properties.getExcludeOnHttpCodes().contains(response.getStatusCode().value())) {
throw new NonTransientAiException(message);
}

// Default to transient exception
throw new TransientAiException(message);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
Expand Down Expand Up @@ -110,6 +116,12 @@
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<dependency>
<groupId>io.micrometer</groupId>
<artifactId>micrometer-observation</artifactId>
<version>1.15.0-RC1</version>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.springframework.ai.model.bedrock.titan.autoconfigure;

import com.fasterxml.jackson.databind.ObjectMapper;

import io.micrometer.observation.ObservationRegistry;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.regions.providers.AwsRegionProvider;

Expand Down Expand Up @@ -56,16 +58,35 @@ public class BedrockTitanEmbeddingAutoConfiguration {
public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(AwsCredentialsProvider credentialsProvider,
AwsRegionProvider regionProvider, BedrockTitanEmbeddingProperties properties,
BedrockAwsConnectionProperties awsProperties, ObjectMapper objectMapper) {

// Validate required properties
if (properties.getModel() == null || awsProperties.getTimeout() == null) {
throw new IllegalArgumentException("Required properties for TitanEmbeddingBedrockApi are missing.");
}

return new TitanEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
objectMapper, awsProperties.getTimeout());
}

@Bean
@ConditionalOnMissingBean
public ObservationRegistry observationRegistry() {
return ObservationRegistry.create();
}

@Bean
@ConditionalOnMissingBean
@ConditionalOnBean(TitanEmbeddingBedrockApi.class)
public BedrockTitanEmbeddingModel titanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingApi,
BedrockTitanEmbeddingProperties properties) {
return new BedrockTitanEmbeddingModel(titanEmbeddingApi).withInputType(properties.getInputType());
BedrockTitanEmbeddingProperties properties, ObservationRegistry observationRegistry) {

// Validate required properties
if (properties.getInputType() == null) {
throw new IllegalArgumentException("InputType property for BedrockTitanEmbeddingModel is missing.");
}

return new BedrockTitanEmbeddingModel(titanEmbeddingApi, observationRegistry)
.withInputType(properties.getInputType());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.util.Assert;

import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.Observation;

/**
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
* Bedrock Titan Embedding API. Titan Embedding supports text and image (encoded in
Expand All @@ -51,13 +54,17 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel {

private final TitanEmbeddingBedrockApi embeddingApi;

private final ObservationRegistry observationRegistry;

/**
* Titan Embedding API input types. Could be either text or image (encoded in base64).
*/
private InputType inputType = InputType.TEXT;

public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) {
public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi,
ObservationRegistry observationRegistry) {
this.embeddingApi = titanEmbeddingBedrockApi;
this.observationRegistry = observationRegistry;
}

/**
Expand All @@ -78,17 +85,42 @@ public float[] embed(Document document) {
public EmbeddingResponse call(EmbeddingRequest request) {
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
if (request.getInstructions().size() != 1) {
logger.warn(
"Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
logger.warn("Titan Embedding does not support batch embedding. Multiple API calls will be made.");
}

List<Embedding> embeddings = new ArrayList<>();
var indexCounter = new AtomicInteger(0);

for (String inputContent : request.getInstructions()) {
var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest);
embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));

try {
TitanEmbeddingResponse response = Observation
.createNotStarted("bedrock.embedding", this.observationRegistry)
.lowCardinalityKeyValue("model", "titan")
.lowCardinalityKeyValue("input_type", this.inputType.name().toLowerCase())
.highCardinalityKeyValue("input_length", String.valueOf(inputContent.length()))
.observe(() -> {
TitanEmbeddingResponse r = this.embeddingApi.embedding(apiRequest);
Assert.notNull(r, "Embedding API returned null response");
return r;
});

if (response.embedding() == null || response.embedding().length == 0) {
logger.warn("Empty embedding vector returned for input at index {}. Skipping.", indexCounter.get());
continue;
}

embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));
}
catch (Exception ex) {
logger.error("Titan API embedding failed for input at index {}: {}", indexCounter.get(),
summarizeInput(inputContent), ex);
throw ex; // Optional: Continue instead of throwing if you want partial
// success
}
}

return new EmbeddingResponse(embeddings);
}

Expand Down Expand Up @@ -117,6 +149,13 @@ public int dimensions() {

}

private String summarizeInput(String input) {
if (this.inputType == InputType.IMAGE) {
return "[image content omitted, length=" + input.length() + "]";
}
return input.length() > 100 ? input.substring(0, 100) + "..." : input;
}

public enum InputType {

TEXT, IMAGE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,18 @@

import static org.assertj.core.api.Assertions.assertThat;

import io.micrometer.observation.tck.TestObservationRegistry;

@SpringBootTest
@RequiresAwsCredentials
class BedrockTitanEmbeddingModelIT {

@Autowired
private BedrockTitanEmbeddingModel embeddingModel;

@Autowired
TestObservationRegistry observationRegistry;

@Test
void singleEmbedding() {
assertThat(this.embeddingModel).isNotNull();
Expand Down Expand Up @@ -82,8 +87,9 @@ public TitanEmbeddingBedrockApi titanEmbeddingApi() {
}

@Bean
public BedrockTitanEmbeddingModel titanEmbedding(TitanEmbeddingBedrockApi titanEmbeddingApi) {
return new BedrockTitanEmbeddingModel(titanEmbeddingApi);
public BedrockTitanEmbeddingModel titanEmbedding(TitanEmbeddingBedrockApi titanEmbeddingApi,
TestObservationRegistry observationRegistry) {
return new BedrockTitanEmbeddingModel(titanEmbeddingApi, observationRegistry);
}

}
Expand Down