Skip to content

Commit

Permalink
IMDS retries (#3975)
Browse files Browse the repository at this point in the history
## Motivation and Context
<!--- Why is this change required? What problem does it solve? -->
<!--- If it fixes an open issue, please link to the issue here -->
Addressing awslabs/aws-sdk-rust#1233

## Description
<!--- Describe your changes in detail -->
Add ability to configure the retry classifier on the IMDS client.

## Testing
<!--- Please describe in detail how you tested your changes -->
<!--- Include details of your testing environment, and the tests you ran
to see how your change affects other areas of the code, etc. -->
Added a new unit test to ensure that a user-defined retry classifier is
being used.

## Checklist
<!--- If a checkbox below is not applicable, then please DELETE it
rather than leaving it unchecked -->
- [x] For changes to the smithy-rs codegen or runtime crates, I have
created a changelog entry Markdown file in the `.changelog` directory,
specifying "client," "server," or both in the `applies_to` key.

----

_By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice._
  • Loading branch information
landonxjames authored Jan 22, 2025
1 parent 4ad631a commit acb693f
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 9 deletions.
10 changes: 10 additions & 0 deletions .changelog/imds-retries.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
applies_to: ["client"]
authors: ["landonxjames"]
references: ["aws-sdk-rust#1233"]
breaking: false
new_feature: true
bug_fix: false
---

Allow IMDS clients to be configured with a user-provided `SharedRetryClassifier`.
2 changes: 1 addition & 1 deletion aws/rust-runtime/aws-config/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions aws/rust-runtime/aws-config/external-types.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ allowed_external_types = [
"aws_smithy_runtime_api::client::identity::ResolveIdentity",
"aws_smithy_runtime_api::client::orchestrator::HttpResponse",
"aws_smithy_runtime_api::client::result::SdkError",
"aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry",
"aws_smithy_runtime_api::client::retries::classifiers::SharedRetryClassifier",
"aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig",
"aws_smithy_types::body::SdkBody",
"aws_smithy_types::checksum_config::RequestChecksumCalculation",
Expand Down
165 changes: 157 additions & 8 deletions aws/rust-runtime/aws-config/src/imds/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(21_600);
const DEFAULT_ATTEMPTS: u32 = 4;
const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1);
const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1);
const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30);
const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10);

fn user_agent() -> AwsUserAgent {
AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION))
Expand Down Expand Up @@ -238,6 +240,7 @@ impl ImdsCommonRuntimePlugin {
config: &ProviderConfig,
endpoint_resolver: ImdsEndpointResolver,
retry_config: RetryConfig,
retry_classifier: SharedRetryClassifier,
timeout_config: TimeoutConfig,
) -> Self {
let mut layer = Layer::new("ImdsCommonRuntimePlugin");
Expand All @@ -254,7 +257,7 @@ impl ImdsCommonRuntimePlugin {
.with_http_client(config.http_client())
.with_endpoint_resolver(Some(endpoint_resolver))
.with_interceptor(UserAgentInterceptor::new())
.with_retry_classifier(SharedRetryClassifier::new(ImdsResponseRetryClassifier))
.with_retry_classifier(retry_classifier)
.with_retry_strategy(Some(StandardRetryStrategy::new()))
.with_time_source(Some(config.time_source()))
.with_sleep_impl(config.sleep_impl()),
Expand Down Expand Up @@ -322,7 +325,10 @@ pub struct Builder {
token_ttl: Option<Duration>,
connect_timeout: Option<Duration>,
read_timeout: Option<Duration>,
operation_timeout: Option<Duration>,
operation_attempt_timeout: Option<Duration>,
config: Option<ProviderConfig>,
retry_classifier: Option<SharedRetryClassifier>,
}

impl Builder {
Expand Down Expand Up @@ -398,6 +404,32 @@ impl Builder {
self
}

/// Override the operation timeout for IMDS
///
/// The value is passed to `TimeoutConfig::operation_timeout` when the client is built.
///
/// This value defaults to 30 seconds
pub fn operation_timeout(mut self, timeout: Duration) -> Self {
    self.operation_timeout = Some(timeout);
    self
}

/// Override the operation attempt timeout for IMDS
///
/// The value is passed to `TimeoutConfig::operation_attempt_timeout` when the client is built.
///
/// This value defaults to 10 seconds
pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self {
    self.operation_attempt_timeout = Some(timeout);
    self
}

/// Override the retry classifier for IMDS
///
/// This defaults to only retrying on server errors and 401s. The [`ImdsResponseRetryClassifier`] in this
/// module offers some configuration options and can be wrapped by [`SharedRetryClassifier::new()`] for use
/// here, or you can create your own fully customized [`SharedRetryClassifier`].
pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self {
    self.retry_classifier = Some(retry_classifier);
    self
}

/* TODO(https://github.com/awslabs/aws-sdk-rust/issues/339): Support customizing the port explicitly */
/*
pub fn port(mut self, port: u32) -> Self {
Expand All @@ -411,6 +443,11 @@ impl Builder {
let timeout_config = TimeoutConfig::builder()
.connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT))
.read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT))
.operation_attempt_timeout(
self.operation_attempt_timeout
.unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT),
)
.operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT))
.build();
let endpoint_source = self
.endpoint
Expand All @@ -421,10 +458,14 @@ impl Builder {
};
let retry_config = RetryConfig::standard()
.with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS));
// Build the default classifier lazily: `unwrap_or` would construct (and immediately drop) a
// default `SharedRetryClassifier` even when the caller supplied their own.
let retry_classifier = self
    .retry_classifier
    .unwrap_or_else(|| SharedRetryClassifier::new(ImdsResponseRetryClassifier::default()));
let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new(
&config,
endpoint_resolver,
retry_config,
retry_classifier,
timeout_config,
));
let operation = Operation::builder()
Expand Down Expand Up @@ -549,8 +590,20 @@ impl ResolveEndpoint for ImdsEndpointResolver {
/// - 403 (IMDS disabled): **Not Retryable**
/// - 404 (Not found): **Not Retryable**
/// - >=500 (server error): **Retryable**
#[derive(Clone, Debug)]
struct ImdsResponseRetryClassifier;
/// - Timeouts: Not retried by default, but this is configurable via [Self::with_retry_connect_timeouts()]
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct ImdsResponseRetryClassifier {
/// Whether connection timeouts are classified as retryable. The derived `Default` leaves this
/// `false`; it is changed through `with_retry_connect_timeouts`.
retry_connect_timeouts: bool,
}

impl ImdsResponseRetryClassifier {
    /// Set whether the IMDS client should retry on connection timeouts.
    ///
    /// Consumes the classifier and returns it with the flag replaced, so the call can be
    /// chained directly off `ImdsResponseRetryClassifier::default()`.
    pub fn with_retry_connect_timeouts(self, retry_connect_timeouts: bool) -> Self {
        let mut classifier = self;
        classifier.retry_connect_timeouts = retry_connect_timeouts;
        classifier
    }
}

impl ClassifyRetry for ImdsResponseRetryClassifier {
fn name(&self) -> &'static str {
Expand All @@ -567,7 +620,10 @@ impl ClassifyRetry for ImdsResponseRetryClassifier {
// This catch-all includes successful responses that fail to parse. These should not be retried.
_ => RetryAction::NoActionIndicated,
}
} else if self.retry_connect_timeouts {
RetryAction::server_error()
} else {
// This is the default behavior.
// Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default
// credentials provider chain to fail to provide credentials.
// Also don't retry non-responses.
Expand All @@ -593,7 +649,9 @@ pub(crate) mod test {
HttpRequest, HttpResponse, OrchestratorError,
};
use aws_smithy_runtime_api::client::result::ConnectorError;
use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction};
use aws_smithy_runtime_api::client::retries::classifiers::{
ClassifyRetry, RetryAction, SharedRetryClassifier,
};
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::error::display::DisplayErrorContext;
use aws_types::os_shim_internal::{Env, Fs};
Expand All @@ -603,6 +661,7 @@ pub(crate) mod test {
use std::collections::HashMap;
use std::error::Error;
use std::io;
use std::time::SystemTime;
use std::time::{Duration, UNIX_EPOCH};
use tracing_test::traced_test;

Expand Down Expand Up @@ -933,7 +992,7 @@ pub(crate) mod test {
let mut ctx = InterceptorContext::new(Input::doesnt_matter());
ctx.set_output_or_error(Ok(Output::doesnt_matter()));
ctx.set_response(imds_response("").map(|_| SdkBody::empty()));
let classifier = ImdsResponseRetryClassifier;
let classifier = ImdsResponseRetryClassifier::default();
assert_eq!(
RetryAction::NoActionIndicated,
classifier.classify_retry(&ctx)
Expand All @@ -950,6 +1009,65 @@ pub(crate) mod test {
);
}

/// A retry classifier supplied through the builder is the one the client consults
#[tokio::test]
async fn user_provided_retry_classifier() {
    /// Classifier that forbids a retry for every attempt it is asked about.
    #[derive(Clone, Debug)]
    struct UserProvidedRetryClassifier;

    impl ClassifyRetry for UserProvidedRetryClassifier {
        fn name(&self) -> &'static str {
            "UserProvidedRetryClassifier"
        }

        // Never retry, whatever the context holds
        fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction {
            RetryAction::RetryForbidden
        }
    }

    // Scripted traffic: a token fetch, a 401 on the metadata call, then a second token
    // fetch and a successful metadata call that only a retrying client would reach.
    let unauthorized = http::Response::builder()
        .status(401)
        .body(SdkBody::empty())
        .unwrap();
    let http_client = StaticReplayClient::new(vec![
        ReplayEvent::new(
            token_request("http://169.254.169.254", 21600),
            token_response(0, TOKEN_A),
        ),
        ReplayEvent::new(
            imds_request("http://169.254.169.254/latest/metadata", TOKEN_A),
            unauthorized,
        ),
        ReplayEvent::new(
            token_request("http://169.254.169.254", 21600),
            token_response(21600, TOKEN_B),
        ),
        ReplayEvent::new(
            imds_request("http://169.254.169.254/latest/metadata", TOKEN_B),
            imds_response("ok"),
        ),
    ]);

    let provider_config = ProviderConfig::no_configuration()
        .with_sleep_impl(InstantSleep::unlogged())
        .with_http_client(http_client.clone());
    let imds_client = super::Client::builder()
        .configure(&provider_config)
        .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier))
        .build();

    let err = imds_client
        .get("/latest/metadata")
        .await
        .expect_err("Client should error");

    // The call must surface the initial 401 rather than retrying through to the 200,
    // because the user-provided classifier forbids every retry.
    assert_full_error_contains!(err, "401");
}

// since tokens are sent as headers, the tokens need to be valid header values
#[tokio::test]
async fn invalid_token() {
Expand Down Expand Up @@ -989,9 +1107,6 @@ pub(crate) mod test {
#[cfg(feature = "rustls")]
async fn one_second_connect_timeout() {
use crate::imds::client::ImdsError;
use aws_smithy_types::error::display::DisplayErrorContext;
use std::time::SystemTime;

let client = Client::builder()
// 240.* can never be resolved
.endpoint("http://240.0.0.0")
Expand Down Expand Up @@ -1023,6 +1138,40 @@ pub(crate) mod test {
);
}

/// Retry classifier properly retries timeouts when configured to.
///
/// With the default 30s operation timeout this would take ~30s to fail, so the operation
/// timeout is lowered to 1s here and the call is expected to fail after between 1s and 2s.
#[tokio::test]
async fn retry_connect_timeouts() {
    let http_client = StaticReplayClient::new(vec![]);
    let imds_client = super::Client::builder()
        .retry_classifier(SharedRetryClassifier::new(
            ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true),
        ))
        .configure(&ProviderConfig::no_configuration().with_http_client(http_client.clone()))
        .operation_timeout(Duration::from_secs(1))
        .endpoint("http://240.0.0.0")
        .expect("valid uri")
        .build();

    // Measure with the monotonic clock: `SystemTime::elapsed` returns an error (and the
    // `unwrap` would panic) if the wall clock is adjusted backwards while the test runs.
    let started = std::time::Instant::now();
    let _res = imds_client
        .get("/latest/metadata")
        .await
        .expect_err("240.0.0.0 will never resolve");
    let time_elapsed: Duration = started.elapsed();

    assert!(
        time_elapsed > Duration::from_secs(1),
        "time_elapsed should be greater than 1s but was {:?}",
        time_elapsed
    );

    assert!(
        time_elapsed < Duration::from_secs(2),
        "time_elapsed should be less than 2s but was {:?}",
        time_elapsed
    );
}

#[derive(Debug, Deserialize)]
struct ImdsConfigTest {
env: HashMap<String, String>,
Expand Down

0 comments on commit acb693f

Please sign in to comment.