Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ hyper = { version = "1.2", default-features = false, optional = true }
md-5 = { version = "0.10.6", default-features = false, optional = true }
quick-xml = { version = "0.38.0", features = ["serialize", "overlapped-lists"], optional = true }
rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"], optional = true }
reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "http2"], optional = true }
reqwest = { version = "0.12", default-features = false, features = ["http2", "rustls-tls-no-provider"], optional = true }
ring = { version = "0.17", default-features = false, features = ["std"], optional = true }
rustls-pki-types = { version = "1.9", default-features = false, features = ["std"], optional = true }
serde = { version = "1.0", default-features = false, features = ["derive"], optional = true }
Expand All @@ -71,12 +71,24 @@ wasm-bindgen-futures = "0.4.18"

[features]
default = ["fs"]
cloud = ["serde", "serde_json", "quick-xml", "hyper", "reqwest", "reqwest/stream", "chrono/serde", "base64", "rand", "ring", "http-body-util", "form_urlencoded", "serde_urlencoded"]
azure = ["cloud", "httparse"]
cloud-no-crypto = ["serde", "serde_json", "quick-xml", "hyper", "reqwest", "reqwest/stream", "chrono/serde", "base64", "rand","http-body-util", "form_urlencoded", "serde_urlencoded"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a massive fan of the feature explosion, but this was the only way to avoid needing to cut a breaking release.

cloud = ["ring", "rustls-pki-types", "cloud-no-crypto", "reqwest?/rustls-tls-native-roots"]


azure-no-crypto = ["cloud-no-crypto", "httparse"]
azure = ["cloud", "azure-no-crypto"]

fs = ["walkdir"]
gcp = ["cloud", "rustls-pki-types"]
aws = ["cloud", "md-5"]
http = ["cloud"]

gcp-no-crypto = ["cloud-no-crypto"]
gcp = ["cloud", "gcp-no-crypto"]

aws-no-crypto = ["cloud-no-crypto", "md-5"]
aws = ["cloud", "aws-no-crypto"]

http-no-crypto = ["cloud-no-crypto"]
http = ["cloud", "http-no-crypto"]

tls-webpki-roots = ["reqwest?/rustls-tls-webpki-roots"]
integration = ["rand"]

Expand Down Expand Up @@ -105,4 +117,4 @@ features = ["js"]
[[test]]
name = "get_range_file"
path = "tests/get_range_file.rs"
required-features = ["fs"]
required-features = ["fs"]
12 changes: 11 additions & 1 deletion src/aws/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use crate::aws::{
AmazonS3, AwsCredential, AwsCredentialProvider, Checksum, S3ConditionalPut, S3CopyIfNotExists,
STORE,
};
use crate::client::{HttpConnector, TokenCredentialProvider, http_connector};
use crate::client::{CryptoProvider, HttpConnector, TokenCredentialProvider, http_connector};
use crate::config::ConfigValue;
use crate::{ClientConfigKey, ClientOptions, Result, RetryConfig, StaticCredentialProvider};
use base64::Engine;
Expand Down Expand Up @@ -171,6 +171,8 @@ pub struct AmazonS3Builder {
client_options: ClientOptions,
/// Credentials
credentials: Option<AwsCredentialProvider>,
/// The [`CryptoProvider`] to use
crypto: Option<Arc<dyn CryptoProvider>>,
/// Skip signing requests
skip_signature: ConfigValue<bool>,
/// Copy if not exists
Expand Down Expand Up @@ -843,6 +845,12 @@ impl AmazonS3Builder {
self
}

/// The [`CryptoProvider`] to use
pub fn with_crypto_provider(mut self, provider: Arc<dyn CryptoProvider>) -> Self {
self.crypto = Some(provider);
self
}

/// Sets what protocol is allowed. If `allow_http` is :
/// * false (default): Only HTTPS are allowed
/// * true: HTTP and HTTPS are allowed
Expand Down Expand Up @@ -1150,6 +1158,7 @@ impl AmazonS3Builder {
endpoint: endpoint.clone(),
region: region.clone(),
credentials: Arc::clone(&credentials),
crypto: self.crypto.clone(),
},
http.connect(&self.client_options)?,
self.retry_config.clone(),
Expand Down Expand Up @@ -1190,6 +1199,7 @@ impl AmazonS3Builder {
bucket,
bucket_endpoint,
credentials,
crypto: self.crypto,
session_provider,
retry_config: self.retry_config,
client_options: self.client_options,
Expand Down
92 changes: 56 additions & 36 deletions src/aws/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ use crate::client::s3::{
CompleteMultipartUpload, CompleteMultipartUploadResult, CopyPartResult,
InitiateMultipartUploadResult, ListResponse, PartMetadata,
};
use crate::client::{GetOptionsExt, HttpClient, HttpError, HttpResponse};
use crate::client::{
CryptoProvider, DigestAlgorithm, GetOptionsExt, HttpClient, HttpError, HttpResponse,
crypto_provider,
};
use crate::list::{PaginatedListOptions, PaginatedListResult};
use crate::multipart::PartId;
use crate::{
Expand All @@ -52,8 +55,6 @@ use itertools::Itertools;
use md5::{Digest, Md5};
use percent_encoding::{PercentEncode, utf8_percent_encode};
use quick_xml::events::{self as xml_events};
use ring::digest;
use ring::digest::Context;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

Expand Down Expand Up @@ -198,6 +199,7 @@ pub(crate) struct S3Config {
pub bucket: String,
pub bucket_endpoint: String,
pub credentials: AwsCredentialProvider,
pub crypto: Option<Arc<dyn CryptoProvider>>,
pub session_provider: Option<AwsCredentialProvider>,
pub retry_config: RetryConfig,
pub client_options: ClientOptions,
Expand All @@ -216,19 +218,18 @@ impl S3Config {
format!("{}/{}", self.bucket_endpoint, encode_path(path))
}

async fn get_session_credential(&self) -> Result<SessionCredential<'_>> {
let credential = match self.skip_signature {
async fn get_session_credential(&self) -> Result<Option<SessionCredential<'_>>> {
Ok(match self.skip_signature {
false => {
let provider = self.session_provider.as_ref().unwrap_or(&self.credentials);
Some(provider.get_credential().await?)
let credential = provider.get_credential().await?;
Some(SessionCredential {
credential,
session_token: self.session_provider.is_some(),
config: self,
})
}
true => None,
};

Ok(SessionCredential {
credential,
session_token: self.session_provider.is_some(),
config: self,
})
}

Expand All @@ -243,27 +244,32 @@ impl S3Config {
pub(crate) fn is_s3_express(&self) -> bool {
self.session_provider.is_some()
}

pub(crate) fn crypto(&self) -> Result<&dyn CryptoProvider> {
crypto_provider(self.crypto.as_deref())
}
}

struct SessionCredential<'a> {
credential: Option<Arc<AwsCredential>>,
credential: Arc<AwsCredential>,
session_token: bool,
config: &'a S3Config,
}

impl SessionCredential<'_> {
fn authorizer(&self) -> Option<AwsAuthorizer<'_>> {
fn authorizer(&self) -> Result<AwsAuthorizer<'_>> {
let mut authorizer =
AwsAuthorizer::new(self.credential.as_deref()?, "s3", &self.config.region)
AwsAuthorizer::new(self.credential.as_ref(), "s3", &self.config.region)
.with_sign_payload(self.config.sign_payload)
.with_request_payer(self.config.request_payer);
.with_request_payer(self.config.request_payer)
.with_crypto(self.config.crypto()?);

if self.session_token {
let token = HeaderName::from_static("x-amz-s3session-token");
authorizer = authorizer.with_token_header(token)
}

Some(authorizer)
Ok(authorizer)
}
}

Expand Down Expand Up @@ -296,7 +302,7 @@ pub(crate) struct Request<'a> {
path: &'a Path,
config: &'a S3Config,
builder: HttpRequestBuilder,
payload_sha256: Option<digest::Digest>,
payload_sha256: Option<[u8; 32]>,
payload: Option<PutPayload>,
use_session_creds: bool,
idempotent: bool,
Expand Down Expand Up @@ -397,13 +403,13 @@ impl Request<'_> {
Self { builder, ..self }
}

pub(crate) fn with_payload(mut self, payload: PutPayload) -> Self {
pub(crate) fn with_payload(mut self, payload: PutPayload) -> Result<Self> {
if (!self.config.skip_signature && self.config.sign_payload)
|| self.config.checksum.is_some()
{
let mut sha256 = Context::new(&digest::SHA256);
payload.iter().for_each(|x| sha256.update(x));
let payload_sha256 = sha256.finish();
let mut ctx = self.config.crypto()?.digest(DigestAlgorithm::Sha256)?;
payload.iter().for_each(|x| ctx.update(x));
let payload_sha256 = ctx.finish()?.try_into().unwrap();

if let Some(Checksum::SHA256) = self.config.checksum {
self.builder = self
Expand All @@ -416,24 +422,28 @@ impl Request<'_> {
let content_length = payload.content_length();
self.builder = self.builder.header(CONTENT_LENGTH, content_length);
self.payload = Some(payload);
self
Ok(self)
}

pub(crate) async fn send(self) -> Result<HttpResponse, RequestError> {
let credential = match self.use_session_creds {
true => self.config.get_session_credential().await?,
false => SessionCredential {
credential: self.config.get_credential().await?,
session_token: false,
config: self.config,
},
false => {
let credential = self.config.get_credential().await?;
credential.map(|credential| SessionCredential {
credential,
session_token: false,
config: self.config,
})
}
};
let authorizer = credential.as_ref().map(|x| x.authorizer()).transpose()?;

let sha = self.payload_sha256.as_ref().map(|x| x.as_ref());

let path = self.path.as_ref();
self.builder
.with_aws_sigv4(credential.authorizer(), sha)
.with_aws_sigv4(authorizer, sha)?
.retryable(&self.config.retry_config)
.retry_on_conflict(self.retry_on_conflict)
.idempotent(self.idempotent)
Expand Down Expand Up @@ -493,6 +503,7 @@ impl S3Client {
}

let credential = self.config.get_session_credential().await?;
let authorizer = credential.as_ref().map(|x| x.authorizer()).transpose()?;
let url = format!("{}?delete", self.config.bucket_endpoint);

let mut buffer = Vec::new();
Expand Down Expand Up @@ -536,7 +547,11 @@ impl S3Client {

let mut builder = self.client.request(Method::POST, url);

let digest = digest::digest(&digest::SHA256, &body);
let crypto = self.config.crypto()?;
let mut ctx = crypto.digest(DigestAlgorithm::Sha256)?;
ctx.update(body.as_ref());
let digest = ctx.finish()?;

builder = builder.header(SHA256_CHECKSUM, BASE64_STANDARD.encode(digest));

// S3 *requires* DeleteObjects to include a Content-MD5 header:
Expand All @@ -550,7 +565,7 @@ impl S3Client {
let response = builder
.header(CONTENT_TYPE, "application/xml")
.body(body)
.with_aws_sigv4(credential.authorizer(), Some(digest.as_ref()))
.with_aws_sigv4(authorizer, Some(digest))?
.send_retry(&self.config.retry_config)
.await
.map_err(|source| Error::DeleteObjectsRequest {
Expand Down Expand Up @@ -690,7 +705,7 @@ impl S3Client {
.idempotent(true);

request = match data {
PutPartPayload::Part(payload) => request.with_payload(payload),
PutPartPayload::Part(payload) => request.with_payload(payload)?,
PutPartPayload::Copy(path) => request.header(
"x-amz-copy-source",
&format!("{}/{}", self.config.bucket, encode_path(path)),
Expand Down Expand Up @@ -775,14 +790,15 @@ impl S3Client {
let body = quick_xml::se::to_string(&request).unwrap();

let credential = self.config.get_session_credential().await?;
let authorizer = credential.as_ref().map(|x| x.authorizer()).transpose()?;
let url = self.config.path_url(location);

let request = self
.client
.post(url)
.query(&[("uploadId", upload_id)])
.body(body)
.with_aws_sigv4(credential.authorizer(), None);
.with_aws_sigv4(authorizer, None)?;

let request = match mode {
CompleteMultipartMode::Overwrite => request,
Expand Down Expand Up @@ -821,11 +837,12 @@ impl S3Client {
#[cfg(test)]
pub(crate) async fn get_object_tagging(&self, path: &Path) -> Result<HttpResponse> {
let credential = self.config.get_session_credential().await?;
let authorizer = credential.as_ref().map(|x| x.authorizer()).transpose()?;
let url = format!("{}?tagging", self.config.path_url(path));
let response = self
.client
.request(Method::GET, url)
.with_aws_sigv4(credential.authorizer(), None)
.with_aws_sigv4(authorizer, None)?
.send_retry(&self.config.retry_config)
.await
.map_err(|e| e.error(STORE, path.to_string()))?;
Expand Down Expand Up @@ -856,6 +873,7 @@ impl GetClient for S3Client {
options: GetOptions,
) -> Result<HttpResponse> {
let credential = self.config.get_session_credential().await?;
let authorizer = credential.as_ref().map(|x| x.authorizer()).transpose()?;
let url = self.config.path_url(path);
let method = match options.head {
true => Method::HEAD,
Expand All @@ -878,7 +896,7 @@ impl GetClient for S3Client {

let response = builder
.with_get_options(options)
.with_aws_sigv4(credential.authorizer(), None)
.with_aws_sigv4(authorizer, None)?
.retryable_request()
.send(ctx)
.await
Expand All @@ -897,6 +915,7 @@ impl ListClient for Arc<S3Client> {
opts: PaginatedListOptions,
) -> Result<PaginatedListResult> {
let credential = self.config.get_session_credential().await?;
let authorizer = credential.as_ref().map(|x| x.authorizer()).transpose()?;
let url = self.config.bucket_endpoint.clone();

let mut query = Vec::with_capacity(4);
Expand Down Expand Up @@ -930,7 +949,7 @@ impl ListClient for Arc<S3Client> {
.request(Method::GET, &url)
.extensions(opts.extensions)
.query(&query)
.with_aws_sigv4(credential.authorizer(), None)
.with_aws_sigv4(authorizer, None)?
.send_retry(&self.config.retry_config)
.await
.map_err(|source| Error::ListRequest { source })?
Expand Down Expand Up @@ -1000,6 +1019,7 @@ mod tests {
conditional_put: Default::default(),
encryption_headers: Default::default(),
request_payer: false,
crypto: None,
};

let client = S3Client::new(config, HttpClient::new(reqwest::Client::new()));
Expand Down
Loading