diff --git a/.changelog/imds-retries.md b/.changelog/imds-retries.md new file mode 100644 index 0000000000..114f331df5 --- /dev/null +++ b/.changelog/imds-retries.md @@ -0,0 +1,10 @@ +--- +applies_to: ["client"] +authors: ["landonxjames"] +references: ["aws-sdk-rust#1233"] +breaking: false +new_feature: true +bug_fix: false +--- + +Allow IMDS clients to be configured with a user-provided `SharedRetryClassifier`. diff --git a/aws/rust-runtime/aws-config/Cargo.lock b/aws/rust-runtime/aws-config/Cargo.lock index 8b0811a3c8..faf4684a1d 100644 --- a/aws/rust-runtime/aws-config/Cargo.lock +++ b/aws/rust-runtime/aws-config/Cargo.lock @@ -45,7 +45,7 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.14" +version = "1.5.15" dependencies = [ "aws-credential-types", "aws-runtime", diff --git a/aws/rust-runtime/aws-config/external-types.toml b/aws/rust-runtime/aws-config/external-types.toml index 3e2c6726d1..362f2ea292 100644 --- a/aws/rust-runtime/aws-config/external-types.toml +++ b/aws/rust-runtime/aws-config/external-types.toml @@ -32,6 +32,8 @@ allowed_external_types = [ "aws_smithy_runtime_api::client::identity::ResolveIdentity", "aws_smithy_runtime_api::client::orchestrator::HttpResponse", "aws_smithy_runtime_api::client::result::SdkError", + "aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry", + "aws_smithy_runtime_api::client::retries::classifiers::SharedRetryClassifier", "aws_smithy_runtime_api::client::stalled_stream_protection::StalledStreamProtectionConfig", "aws_smithy_types::body::SdkBody", "aws_smithy_types::checksum_config::RequestChecksumCalculation", diff --git a/aws/rust-runtime/aws-config/src/imds/client.rs b/aws/rust-runtime/aws-config/src/imds/client.rs index 49408fcb4f..3f23cfdc99 100644 --- a/aws/rust-runtime/aws-config/src/imds/client.rs +++ b/aws/rust-runtime/aws-config/src/imds/client.rs @@ -52,6 +52,8 @@ const DEFAULT_TOKEN_TTL: Duration = Duration::from_secs(21_600); const DEFAULT_ATTEMPTS: u32 = 4; const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(1); const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(1); +const DEFAULT_OPERATION_TIMEOUT: Duration = Duration::from_secs(30); +const DEFAULT_OPERATION_ATTEMPT_TIMEOUT: Duration = Duration::from_secs(10); fn user_agent() -> AwsUserAgent { AwsUserAgent::new_from_environment(Env::real(), ApiMetadata::new("imds", PKG_VERSION)) @@ -238,6 +240,7 @@ impl ImdsCommonRuntimePlugin { config: &ProviderConfig, endpoint_resolver: ImdsEndpointResolver, retry_config: RetryConfig, + retry_classifier: SharedRetryClassifier, timeout_config: TimeoutConfig, ) -> Self { let mut layer = Layer::new("ImdsCommonRuntimePlugin"); @@ -254,7 +257,7 @@ impl ImdsCommonRuntimePlugin { .with_http_client(config.http_client()) .with_endpoint_resolver(Some(endpoint_resolver)) .with_interceptor(UserAgentInterceptor::new()) - .with_retry_classifier(SharedRetryClassifier::new(ImdsResponseRetryClassifier)) + .with_retry_classifier(retry_classifier) .with_retry_strategy(Some(StandardRetryStrategy::new())) .with_time_source(Some(config.time_source())) .with_sleep_impl(config.sleep_impl()), @@ -322,7 +325,10 @@ pub struct Builder { token_ttl: Option, connect_timeout: Option, read_timeout: Option, + operation_timeout: Option, + operation_attempt_timeout: Option, config: Option, + retry_classifier: Option, } impl Builder { @@ -398,6 +404,32 @@ impl Builder { self } + /// Override the operation timeout for IMDS + /// + /// This value defaults to 1 second + pub fn operation_timeout(mut self, timeout: Duration) -> Self { + self.operation_timeout = Some(timeout); + self + } + + /// Override the operation attempt timeout for IMDS + /// + /// This value defaults to 1 second + pub fn operation_attempt_timeout(mut self, timeout: Duration) -> Self { + self.operation_attempt_timeout = Some(timeout); + self + } + + /// Override the retry classifier for IMDS + /// + /// This defaults to only retrying on server errors and 401s. The [ImdsResponseRetryClassifier] in this + /// module offers some configuration options and can be wrapped by[SharedRetryClassifier::new()] for use + /// here or you can create your own fully customized [SharedRetryClassifier]. + pub fn retry_classifier(mut self, retry_classifier: SharedRetryClassifier) -> Self { + self.retry_classifier = Some(retry_classifier); + self + } + /* TODO(https://github.com/awslabs/aws-sdk-rust/issues/339): Support customizing the port explicitly */ /* pub fn port(mut self, port: u32) -> Self { @@ -411,6 +443,11 @@ impl Builder { let timeout_config = TimeoutConfig::builder() .connect_timeout(self.connect_timeout.unwrap_or(DEFAULT_CONNECT_TIMEOUT)) .read_timeout(self.read_timeout.unwrap_or(DEFAULT_READ_TIMEOUT)) + .operation_attempt_timeout( + self.operation_attempt_timeout + .unwrap_or(DEFAULT_OPERATION_ATTEMPT_TIMEOUT), + ) + .operation_timeout(self.operation_timeout.unwrap_or(DEFAULT_OPERATION_TIMEOUT)) .build(); let endpoint_source = self .endpoint @@ -421,10 +458,14 @@ impl Builder { }; let retry_config = RetryConfig::standard() .with_max_attempts(self.max_attempts.unwrap_or(DEFAULT_ATTEMPTS)); + let retry_classifier = self.retry_classifier.unwrap_or(SharedRetryClassifier::new( + ImdsResponseRetryClassifier::default(), + )); let common_plugin = SharedRuntimePlugin::new(ImdsCommonRuntimePlugin::new( &config, endpoint_resolver, retry_config, + retry_classifier, timeout_config, )); let operation = Operation::builder() @@ -549,8 +590,20 @@ impl ResolveEndpoint for ImdsEndpointResolver { /// - 403 (IMDS disabled): **Not Retryable** /// - 404 (Not found): **Not Retryable** /// - >=500 (server error): **Retryable** -#[derive(Clone, Debug)] -struct ImdsResponseRetryClassifier; +/// - Timeouts: Not retried by default, but this is configurable via [Self::with_retry_connect_timeouts()] +#[derive(Clone, Debug, Default)] +#[non_exhaustive] +pub struct ImdsResponseRetryClassifier { + retry_connect_timeouts: bool, +} + +impl ImdsResponseRetryClassifier { + /// Indicate whether the IMDS client should retry on connection timeouts + pub fn with_retry_connect_timeouts(mut self, retry_connect_timeouts: bool) -> Self { + self.retry_connect_timeouts = retry_connect_timeouts; + self + } +} impl ClassifyRetry for ImdsResponseRetryClassifier { fn name(&self) -> &'static str { @@ -567,7 +620,10 @@ impl ClassifyRetry for ImdsResponseRetryClassifier { // This catch-all includes successful responses that fail to parse. These should not be retried. _ => RetryAction::NoActionIndicated, } + } else if self.retry_connect_timeouts { + RetryAction::server_error() } else { + // This is the default behavior. // Don't retry timeouts for IMDS, or else it will take ~30 seconds for the default // credentials provider chain to fail to provide credentials. // Also don't retry non-responses. @@ -593,7 +649,9 @@ pub(crate) mod test { HttpRequest, HttpResponse, OrchestratorError, }; use aws_smithy_runtime_api::client::result::ConnectorError; - use aws_smithy_runtime_api::client::retries::classifiers::{ClassifyRetry, RetryAction}; + use aws_smithy_runtime_api::client::retries::classifiers::{ + ClassifyRetry, RetryAction, SharedRetryClassifier, + }; use aws_smithy_types::body::SdkBody; use aws_smithy_types::error::display::DisplayErrorContext; use aws_types::os_shim_internal::{Env, Fs}; @@ -603,6 +661,7 @@ pub(crate) mod test { use std::collections::HashMap; use std::error::Error; use std::io; + use std::time::SystemTime; use std::time::{Duration, UNIX_EPOCH}; use tracing_test::traced_test; @@ -933,7 +992,7 @@ pub(crate) mod test { let mut ctx = InterceptorContext::new(Input::doesnt_matter()); ctx.set_output_or_error(Ok(Output::doesnt_matter())); ctx.set_response(imds_response("").map(|_| SdkBody::empty())); - let classifier = ImdsResponseRetryClassifier; + let classifier = ImdsResponseRetryClassifier::default(); assert_eq!( RetryAction::NoActionIndicated, classifier.classify_retry(&ctx) @@ -950,6 +1009,65 @@ pub(crate) mod test { ); } + /// User provided retry classifier works + #[tokio::test] + async fn user_provided_retry_classifier() { + #[derive(Clone, Debug)] + struct UserProvidedRetryClassifier; + + impl ClassifyRetry for UserProvidedRetryClassifier { + fn name(&self) -> &'static str { + "UserProvidedRetryClassifier" + } + + // Don't retry anything + fn classify_retry(&self, _ctx: &InterceptorContext) -> RetryAction { + RetryAction::RetryForbidden + } + } + + let events = vec![ + ReplayEvent::new( + token_request("http://169.254.169.254", 21600), + token_response(0, TOKEN_A), + ), + ReplayEvent::new( + imds_request("http://169.254.169.254/latest/metadata", TOKEN_A), + http::Response::builder() + .status(401) + .body(SdkBody::empty()) + .unwrap(), + ), + ReplayEvent::new( + token_request("http://169.254.169.254", 21600), + token_response(21600, TOKEN_B), + ), + ReplayEvent::new( + imds_request("http://169.254.169.254/latest/metadata", TOKEN_B), + imds_response("ok"), + ), + ]; + let http_client = StaticReplayClient::new(events); + + let imds_client = super::Client::builder() + .configure( + &ProviderConfig::no_configuration() + .with_sleep_impl(InstantSleep::unlogged()) + .with_http_client(http_client.clone()), + ) + .retry_classifier(SharedRetryClassifier::new(UserProvidedRetryClassifier)) + .build(); + + let res = imds_client + .get("/latest/metadata") + .await + .expect_err("Client should error"); + + // Assert that the operation errored on the initial 401 and did not retry and get + // the 200 (since the user provided retry classifier never retries) + assert_full_error_contains!(res, "401"); + } + // since tokens are sent as headers, the tokens need to be valid header values #[tokio::test] async fn invalid_token() { @@ -989,9 +1107,6 @@ pub(crate) mod test { #[cfg(feature = "rustls")] async fn one_second_connect_timeout() { use crate::imds::client::ImdsError; - use aws_smithy_types::error::display::DisplayErrorContext; - use std::time::SystemTime; - let client = Client::builder() // 240.* can never be resolved .endpoint("http://240.0.0.0") @@ -1023,6 +1138,40 @@ pub(crate) mod test { ); } + /// Retry classifier properly retries timeouts when configured to (meaning it takes ~30s to fail) + #[tokio::test] + async fn retry_connect_timeouts() { + let http_client = StaticReplayClient::new(vec![]); + let imds_client = super::Client::builder() + .retry_classifier(SharedRetryClassifier::new( + ImdsResponseRetryClassifier::default().with_retry_connect_timeouts(true), + )) + .configure(&ProviderConfig::no_configuration().with_http_client(http_client.clone())) + .operation_timeout(Duration::from_secs(1)) + .endpoint("http://240.0.0.0") + .expect("valid uri") + .build(); + + let now = SystemTime::now(); + let _res = imds_client + .get("/latest/metadata") + .await + .expect_err("240.0.0.0 will never resolve"); + let time_elapsed: Duration = now.elapsed().unwrap(); + + assert!( + time_elapsed > Duration::from_secs(1), + "time_elapsed should be greater than 1s but was {:?}", + time_elapsed + ); + + assert!( + time_elapsed < Duration::from_secs(2), + "time_elapsed should be less than 2s but was {:?}", + time_elapsed + ); + } + #[derive(Debug, Deserialize)] struct ImdsConfigTest { env: HashMap,