Skip to content

Commit 0bbccf8

Browse files
ashakirintzolov
authored andcommitted
fix: #2146 getting default AWS region using DefaultAwsRegionProviderChain
Signed-off-by: Andrei Shakirin <[email protected]>
1 parent cd3fc2f commit 0bbccf8

File tree

4 files changed

+168
-3
lines changed

4 files changed

+168
-3
lines changed

models/spring-ai-bedrock-converse/src/main/java/org/springframework/ai/bedrock/converse/BedrockProxyChatModel.java

+8
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@
4040
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
4141
import software.amazon.awssdk.core.SdkBytes;
4242
import software.amazon.awssdk.core.document.Document;
43+
import software.amazon.awssdk.core.exception.SdkClientException;
4344
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
4445
import software.amazon.awssdk.regions.Region;
46+
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
4547
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
4648
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
4749
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
@@ -788,6 +790,12 @@ public static final class Builder {
788790
private BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;
789791

790792
private Builder() {
793+
try {
794+
region = DefaultAwsRegionProviderChain.builder().build().getRegion();
795+
}
796+
catch (SdkClientException e) {
797+
logger.warn("Failed to load region from DefaultAwsRegionProviderChain, using US_EAST_1", e);
798+
}
791799
}
792800

793801
public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.bedrock.converse;
18+
19+
import org.junit.jupiter.api.Test;
20+
import org.junit.jupiter.api.extension.ExtendWith;
21+
import org.mockito.Answers;
22+
import org.mockito.Mock;
23+
import org.mockito.MockedStatic;
24+
import org.mockito.junit.jupiter.MockitoExtension;
25+
import software.amazon.awssdk.core.exception.SdkClientException;
26+
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
27+
28+
import static org.mockito.Mockito.mockStatic;
29+
import static org.mockito.Mockito.when;
30+
31+
@ExtendWith(MockitoExtension.class)
32+
class BedrockProxyChatModelTest {
33+
34+
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
35+
private DefaultAwsRegionProviderChain.Builder awsRegionProviderBuilder;
36+
37+
@Test
38+
void shouldIgnoreExceptionAndUseDefault() {
39+
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
40+
when(awsRegionProviderBuilder.build().getRegion())
41+
.thenThrow(SdkClientException.builder().message("failed load").build());
42+
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
43+
BedrockProxyChatModel.builder().build();
44+
}
45+
}
46+
47+
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/api/AbstractBedrockApi.java

+16-3
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,16 @@
3030
import com.fasterxml.jackson.databind.ObjectMapper;
3131
import org.slf4j.Logger;
3232
import org.slf4j.LoggerFactory;
33+
import org.springframework.util.ObjectUtils;
3334
import reactor.core.publisher.Flux;
3435
import reactor.core.publisher.Sinks;
3536
import reactor.core.publisher.Sinks.EmitFailureHandler;
3637
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
3738
import software.amazon.awssdk.auth.credentials.ProfileCredentialsProvider;
3839
import software.amazon.awssdk.core.SdkBytes;
40+
import software.amazon.awssdk.core.exception.SdkClientException;
3941
import software.amazon.awssdk.regions.Region;
42+
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
4043
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
4144
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
4245
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
@@ -148,14 +151,12 @@ public AbstractBedrockApi(String modelId, AwsCredentialsProvider credentialsProv
148151

149152
Assert.hasText(modelId, "Model id must not be empty");
150153
Assert.notNull(credentialsProvider, "Credentials provider must not be null");
151-
Assert.notNull(region, "Region must not be empty");
152154
Assert.notNull(objectMapper, "Object mapper must not be null");
153155
Assert.notNull(timeout, "Timeout must not be null");
154156

155157
this.modelId = modelId;
156158
this.objectMapper = objectMapper;
157-
this.region = region;
158-
159+
this.region = getRegion(region);
159160

160161
this.client = BedrockRuntimeClient.builder()
161162
.region(this.region)
@@ -339,5 +340,17 @@ public record AmazonBedrockInvocationMetrics(
339340
@JsonProperty("outputTokenCount") Long outputTokenCount,
340341
@JsonProperty("invocationLatency") Long invocationLatency) {
341342
}
343+
344+
private Region getRegion(Region region) {
345+
if (ObjectUtils.isEmpty(region)) {
346+
try {
347+
return DefaultAwsRegionProviderChain.builder().build().getRegion();
348+
} catch (SdkClientException e) {
349+
throw new IllegalArgumentException("Region is empty and cannot be loaded from DefaultAwsRegionProviderChain: " + e.getMessage(), e);
350+
}
351+
} else {
352+
return region;
353+
}
354+
}
342355
}
343356
// @formatter:on
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright 2023-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.bedrock.api;
18+
19+
import com.fasterxml.jackson.databind.ObjectMapper;
20+
import org.junit.jupiter.api.Test;
21+
import org.junit.jupiter.api.extension.ExtendWith;
22+
import org.mockito.Answers;
23+
import org.mockito.Mock;
24+
import org.mockito.MockedStatic;
25+
import org.mockito.junit.jupiter.MockitoExtension;
26+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
27+
import software.amazon.awssdk.core.exception.SdkClientException;
28+
import software.amazon.awssdk.regions.Region;
29+
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
30+
31+
import java.time.Duration;
32+
33+
import static org.assertj.core.api.Assertions.assertThat;
34+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
35+
import static org.mockito.Mockito.*;
36+
37+
@ExtendWith(MockitoExtension.class)
38+
class AbstractBedrockApiTest {
39+
40+
@Mock(answer = Answers.RETURNS_DEEP_STUBS)
41+
private DefaultAwsRegionProviderChain.Builder awsRegionProviderBuilder;
42+
43+
@Mock
44+
private AwsCredentialsProvider awsCredentialsProvider = mock(AwsCredentialsProvider.class);
45+
46+
@Mock
47+
private ObjectMapper objectMapper = mock(ObjectMapper.class);
48+
49+
@Test
50+
void shouldLoadRegionFromAwsDefaults() {
51+
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
52+
when(awsRegionProviderBuilder.build().getRegion()).thenReturn(Region.AF_SOUTH_1);
53+
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
54+
AbstractBedrockApi<Object, Object, Object> testBedrockApi = new TestBedrockApi("modelId",
55+
awsCredentialsProvider, null, objectMapper, Duration.ofMinutes(5));
56+
assertThat(testBedrockApi.getRegion()).isEqualTo(Region.AF_SOUTH_1);
57+
}
58+
}
59+
60+
@Test
61+
void shouldThrowIllegalArgumentIfAwsDefaultsFailed() {
62+
try (MockedStatic<DefaultAwsRegionProviderChain> mocked = mockStatic(DefaultAwsRegionProviderChain.class)) {
63+
when(awsRegionProviderBuilder.build().getRegion())
64+
.thenThrow(SdkClientException.builder().message("failed load").build());
65+
mocked.when(DefaultAwsRegionProviderChain::builder).thenReturn(awsRegionProviderBuilder);
66+
assertThatThrownBy(() -> new TestBedrockApi("modelId", awsCredentialsProvider, null, objectMapper,
67+
Duration.ofMinutes(5)))
68+
.isInstanceOf(IllegalArgumentException.class)
69+
.hasMessageContaining("failed load");
70+
}
71+
}
72+
73+
private static class TestBedrockApi extends AbstractBedrockApi<Object, Object, Object> {
74+
75+
protected TestBedrockApi(String modelId, AwsCredentialsProvider credentialsProvider, Region region,
76+
ObjectMapper objectMapper, Duration timeout) {
77+
super(modelId, credentialsProvider, region, objectMapper, timeout);
78+
}
79+
80+
@Override
81+
protected Object embedding(Object request) {
82+
return null;
83+
}
84+
85+
@Override
86+
protected Object chatCompletion(Object request) {
87+
return null;
88+
}
89+
90+
@Override
91+
protected Object internalInvocation(Object request, Class<Object> clazz) {
92+
return null;
93+
}
94+
95+
}
96+
97+
}

0 commit comments

Comments
 (0)