diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 8eb49035..592c690c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -131,6 +131,8 @@ jobs:
           aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket-for-spawn
           aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket-for-checksum
           aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket-for-copy-if-not-exists
+          aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket-for-multipart-copy-large
+          aws --endpoint-url=http://localhost:4566 s3 mb s3://test-bucket-for-multipart-copy-small
           aws --endpoint-url=http://localhost:4566 s3api create-bucket --bucket test-object-lock --object-lock-enabled-for-bucket
           KMS_KEY=$(aws --endpoint-url=http://localhost:4566 kms create-key --description "test key")
diff --git a/src/aws/builder.rs b/src/aws/builder.rs
index e49145a4..0393a3b4 100644
--- a/src/aws/builder.rs
+++ b/src/aws/builder.rs
@@ -42,6 +42,11 @@ use url::Url;
 /// Default metadata endpoint
 static DEFAULT_METADATA_ENDPOINT: &str = "http://169.254.169.254";
 
+/// AWS S3 does not support copy operations larger than 5 GiB in a single request. See
+/// [CopyObject](https://docs.aws.amazon.com/AmazonS3/latest/userguide/copy-object.html) for more
+/// details.
+const MAX_SINGLE_REQUEST_COPY_SIZE: u64 = 5 * 1024 * 1024 * 1024;
+
 /// A specialized `Error` for object store-related errors
 #[derive(Debug, thiserror::Error)]
 enum Error {
@@ -189,6 +194,10 @@ pub struct AmazonS3Builder {
     request_payer: ConfigValue<bool>,
     /// The [`HttpConnector`] to use
     http_connector: Option<Arc<dyn HttpConnector>>,
+    /// Threshold (bytes) above which copy uses multipart copy. If not set, defaults to 5 GiB.
+    multipart_copy_threshold: Option<ConfigValue<u64>>,
+    /// Preferred multipart copy part size (bytes). If not set, defaults to 5 GiB.
+    multipart_copy_part_size: Option<ConfigValue<u64>>,
 }
 
 /// Configuration keys for [`AmazonS3Builder`]
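For orientation, a minimal sketch of how a caller might set the two new builder knobs; the bucket name and the 100 MiB values are placeholders, and the usual AWS_* environment configuration is assumed:

    use object_store::aws::AmazonS3Builder;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Placeholder bucket and sizes; credentials/region come from the environment
        let _store = AmazonS3Builder::from_env()
            .with_bucket_name("example-bucket")
            .with_multipart_copy_threshold(100 * 1024 * 1024) // multipart copy above 100 MiB
            .with_multipart_copy_part_size(100 * 1024 * 1024) // copy in 100 MiB parts
            .build()?;
        Ok(())
    }

Both settings fall back to MAX_SINGLE_REQUEST_COPY_SIZE (5 GiB) when unset, so default behavior only changes for objects that a single CopyObject request could not handle anyway.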
@@ -423,6 +432,10 @@ pub enum AmazonS3ConfigKey {
     /// Encryption options
     Encryption(S3EncryptionConfigKey),
 
+    /// Threshold (bytes) to switch to multipart copy
+    MultipartCopyThreshold,
+    /// Preferred multipart copy part size (bytes)
+    MultipartCopyPartSize,
 }
 
 impl AsRef<str> for AmazonS3ConfigKey {
@@ -455,6 +468,8 @@ impl AsRef<str> for AmazonS3ConfigKey {
             Self::RequestPayer => "aws_request_payer",
             Self::Client(opt) => opt.as_ref(),
             Self::Encryption(opt) => opt.as_ref(),
+            Self::MultipartCopyThreshold => "aws_multipart_copy_threshold",
+            Self::MultipartCopyPartSize => "aws_multipart_copy_part_size",
         }
     }
 }
@@ -499,6 +514,12 @@ impl FromStr for AmazonS3ConfigKey {
             "aws_conditional_put" | "conditional_put" => Ok(Self::ConditionalPut),
             "aws_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging),
             "aws_request_payer" | "request_payer" => Ok(Self::RequestPayer),
+            "aws_multipart_copy_threshold" | "multipart_copy_threshold" => {
+                Ok(Self::MultipartCopyThreshold)
+            }
+            "aws_multipart_copy_part_size" | "multipart_copy_part_size" => {
+                Ok(Self::MultipartCopyPartSize)
+            }
             // Backwards compatibility
             "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)),
             "aws_server_side_encryption" | "server_side_encryption" => Ok(Self::Encryption(
@@ -666,6 +687,12 @@ impl AmazonS3Builder {
                     self.encryption_customer_key_base64 = Some(value.into())
                 }
             },
+            AmazonS3ConfigKey::MultipartCopyThreshold => {
+                self.multipart_copy_threshold = Some(ConfigValue::Deferred(value.into()))
+            }
+            AmazonS3ConfigKey::MultipartCopyPartSize => {
+                self.multipart_copy_part_size = Some(ConfigValue::Deferred(value.into()))
+            }
         };
         self
     }
@@ -733,6 +760,14 @@ impl AmazonS3Builder {
                     self.encryption_customer_key_base64.clone()
                 }
             },
+            AmazonS3ConfigKey::MultipartCopyThreshold => self
+                .multipart_copy_threshold
+                .as_ref()
+                .map(|x| x.to_string()),
+            AmazonS3ConfigKey::MultipartCopyPartSize => self
+                .multipart_copy_part_size
+                .as_ref()
+                .map(|x| x.to_string()),
         }
     }
@@ -1029,6 +1064,18 @@ impl AmazonS3Builder {
         self
     }
 
+    /// Set threshold (bytes) above which copy uses multipart copy
+    pub fn with_multipart_copy_threshold(mut self, threshold_bytes: u64) -> Self {
+        self.multipart_copy_threshold = Some(ConfigValue::Parsed(threshold_bytes));
+        self
+    }
+
+    /// Set preferred multipart copy part size (bytes)
+    pub fn with_multipart_copy_part_size(mut self, part_size_bytes: u64) -> Self {
+        self.multipart_copy_part_size = Some(ConfigValue::Parsed(part_size_bytes));
+        self
+    }
+
     /// Create a [`AmazonS3`] instance from the provided values,
     /// consuming `self`.
     pub fn build(mut self) -> Result<AmazonS3> {
@@ -1185,6 +1232,17 @@ impl AmazonS3Builder {
             S3EncryptionHeaders::default()
         };
 
+        let multipart_copy_threshold = self
+            .multipart_copy_threshold
+            .map(|val| val.get())
+            .transpose()?
+            .unwrap_or(MAX_SINGLE_REQUEST_COPY_SIZE);
+        let multipart_copy_part_size = self
+            .multipart_copy_part_size
+            .map(|val| val.get())
+            .transpose()?
+            .unwrap_or(MAX_SINGLE_REQUEST_COPY_SIZE);
+
         let config = S3Config {
             region,
             bucket,
@@ -1201,6 +1259,8 @@ impl AmazonS3Builder {
             conditional_put: self.conditional_put.get()?,
             encryption_headers,
             request_payer: self.request_payer.get()?,
+            multipart_copy_threshold,
+            multipart_copy_part_size,
         };
 
         let http_client = http.connect(&config.client_options)?;
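The new keys follow the existing FromStr/AsRef round-trip convention, so they also work through the generic string-based configuration path; a sketch, with key spellings taken from this diff:

    use std::str::FromStr;
    use object_store::aws::{AmazonS3Builder, AmazonS3ConfigKey};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Both the bare and the aws_-prefixed spellings parse to the same key
        let key = AmazonS3ConfigKey::from_str("multipart_copy_threshold")?;
        assert_eq!(key.as_ref(), "aws_multipart_copy_threshold");

        // String values are stored as ConfigValue::Deferred and parsed as u64
        // inside build(), so a malformed value surfaces there, not here
        let _builder = AmazonS3Builder::new().with_config(key, "104857600"); // 100 MiB
        Ok(())
    }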
diff --git a/src/aws/client.rs b/src/aws/client.rs
index bd9618ed..dd0c8679 100644
--- a/src/aws/client.rs
+++ b/src/aws/client.rs
@@ -140,6 +140,7 @@ impl From<Error> for crate::Error {
 pub(crate) enum PutPartPayload<'a> {
     Part(PutPayload),
     Copy(&'a Path),
+    CopyRange(&'a Path, std::ops::Range<u64>),
 }
 
 impl Default for PutPartPayload<'_> {
@@ -209,6 +210,10 @@ pub(crate) struct S3Config {
     pub conditional_put: S3ConditionalPut,
     pub request_payer: bool,
     pub(super) encryption_headers: S3EncryptionHeaders,
+    /// Threshold in bytes above which copy will use multipart copy
+    pub multipart_copy_threshold: u64,
+    /// Preferred multipart copy part size in bytes
+    pub multipart_copy_part_size: u64,
 }
 
 impl S3Config {
@@ -681,7 +686,10 @@ impl S3Client {
         part_idx: usize,
         data: PutPartPayload<'_>,
     ) -> Result<PartId> {
-        let is_copy = matches!(data, PutPartPayload::Copy(_));
+        let is_copy = matches!(
+            data,
+            PutPartPayload::Copy(_) | PutPartPayload::CopyRange(_, _)
+        );
         let part = (part_idx + 1).to_string();
 
         let mut request = self
@@ -695,6 +703,18 @@ impl S3Client {
                 "x-amz-copy-source",
                 &format!("{}/{}", self.config.bucket, encode_path(path)),
             ),
+            PutPartPayload::CopyRange(path, range) => {
+                // AWS expects an inclusive end offset in the copy range header
+                let start = range.start;
+                let end_inclusive = range.end.saturating_sub(1);
+                let range_value = format!("bytes={}-{}", start, end_inclusive);
+                request
+                    .header(
+                        "x-amz-copy-source",
+                        &format!("{}/{}", self.config.bucket, encode_path(path)),
+                    )
+                    .header("x-amz-copy-source-range", &range_value)
+            }
         };
 
         if self
@@ -1000,6 +1020,8 @@ mod tests {
             conditional_put: Default::default(),
             encryption_headers: Default::default(),
             request_payer: false,
+            multipart_copy_threshold: 5 * 1024 * 1024 * 1024,
+            multipart_copy_part_size: 5 * 1024 * 1024 * 1024,
         };
 
         let client = S3Client::new(config, HttpClient::new(reqwest::Client::new()));
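One detail worth calling out in the CopyRange arm above: Rust ranges are half-open, while `x-amz-copy-source-range` takes inclusive offsets. A standalone restatement of the mapping (the helper name is illustrative, not part of the diff):

    /// Format a half-open byte range as the inclusive header value S3 expects
    fn copy_source_range_header(range: std::ops::Range<u64>) -> String {
        format!("bytes={}-{}", range.start, range.end.saturating_sub(1))
    }

    fn main() {
        assert_eq!(copy_source_range_header(0..1024), "bytes=0-1023");
        assert_eq!(copy_source_range_header(1024..2000), "bytes=1024-1999");
    }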
diff --git a/src/aws/mod.rs b/src/aws/mod.rs
index dd2cf6f0..afa92e75 100644
--- a/src/aws/mod.rs
+++ b/src/aws/mod.rs
@@ -101,6 +101,69 @@ impl AmazonS3 {
     fn path_url(&self, path: &Path) -> String {
         self.client.config.path_url(path)
     }
+
+    /// Construct the payloads for a multipart copy operation.
+    fn multipart_copy_payloads<'a>(&self, from: &'a Path, size: u64) -> Vec<PutPartPayload<'a>> {
+        let part_size = self.client.config.multipart_copy_part_size;
+        if size <= part_size {
+            return vec![PutPartPayload::Copy(from)];
+        }
+        let mut payloads = Vec::new();
+        let mut offset = 0;
+        while offset < size {
+            let end = if size - offset <= part_size {
+                size
+            } else {
+                offset + part_size
+            };
+            payloads.push(PutPartPayload::CopyRange(from, offset..end));
+            offset = end;
+        }
+        payloads
+    }
+
+    /// Perform a multipart copy operation
+    ///
+    /// If the multipart upload fails, this function makes a best effort attempt to clean it up.
+    /// It's the caller's responsibility to add a lifecycle rule if guaranteed cleanup is required,
+    /// as we cannot protect against an ill-timed process crash.
+    async fn copy_multipart(
+        &self,
+        from: &Path,
+        to: &Path,
+        size: u64,
+        mode: CompleteMultipartMode,
+    ) -> Result<()> {
+        // Perform multipart copy using UploadPartCopy
+        let upload_id = self
+            .client
+            .create_multipart(to, PutMultipartOptions::default())
+            .await?;
+
+        let mut parts = Vec::new();
+        for (idx, payload) in self
+            .multipart_copy_payloads(from, size)
+            .into_iter()
+            .enumerate()
+        {
+            match self.client.put_part(to, &upload_id, idx, payload).await {
+                Ok(part) => parts.push(part),
+                Err(e) => {
+                    let _ = self.client.abort_multipart(to, &upload_id).await;
+                    return Err(e);
+                }
+            };
+        }
+        if let Err(err) = self
+            .client
+            .complete_multipart(to, &upload_id, parts, mode)
+            .await
+        {
+            let _ = self.client.abort_multipart(to, &upload_id).await;
+            return Err(err);
+        }
+        Ok(())
+    }
 }
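To make the range-splitting loop in multipart_copy_payloads concrete, here is the same rule restated as a free function (illustrative only): every range spans exactly part_size bytes except the last, which holds the remainder.

    /// Split `size` bytes into half-open ranges of at most `part_size` bytes
    fn split_ranges(size: u64, part_size: u64) -> Vec<std::ops::Range<u64>> {
        let mut ranges = Vec::new();
        let mut offset = 0;
        while offset < size {
            // Equivalent to the diff's `if size - offset <= part_size { size } else { offset + part_size }`
            let end = u64::min(offset + part_size, size);
            ranges.push(offset..end);
            offset = end;
        }
        ranges
    }

    fn main() {
        // e.g. 12 units with 5-unit parts => 5 + 5 + 2
        assert_eq!(split_ranges(12, 5), vec![0..5, 5..10, 10..12]);
    }

Note that S3 requires every part except the last to be at least 5 MiB, so a very small configured part size can fail for multi-part copies; the 5 GiB default stays well clear of that limit.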
@@ -305,14 +368,31 @@ impl ObjectStore for AmazonS3 {
             mode,
             extensions: _,
         } = options;
+        // Determine source size to decide between single CopyObject and multipart copy
+        let head_meta = self
+            .client
+            .get_opts(
+                from,
+                GetOptions {
+                    head: true,
+                    ..Default::default()
+                },
+            )
+            .await?
+            .meta;
         match mode {
             CopyMode::Overwrite => {
-                self.client
-                    .copy_request(from, to)
-                    .idempotent(true)
-                    .send()
-                    .await?;
+                if head_meta.size <= self.client.config.multipart_copy_threshold {
+                    self.client
+                        .copy_request(from, to)
+                        .idempotent(true)
+                        .send()
+                        .await?;
+                } else {
+                    self.copy_multipart(from, to, head_meta.size, CompleteMultipartMode::Overwrite)
+                        .await?;
+                }
                 Ok(())
             }
             CopyMode::Create => {
@@ -322,45 +402,16 @@ impl ObjectStore for AmazonS3 {
                     }
                     Some(S3CopyIfNotExists::HeaderWithStatus(k, v, status)) => (k, v, *status),
                     Some(S3CopyIfNotExists::Multipart) => {
-                        let upload_id = self
-                            .client
-                            .create_multipart(to, PutMultipartOptions::default())
-                            .await?;
-
-                        let res = async {
-                            let part_id = self
-                                .client
-                                .put_part(to, &upload_id, 0, PutPartPayload::Copy(from))
-                                .await?;
-                            match self
-                                .client
-                                .complete_multipart(
-                                    to,
-                                    &upload_id,
-                                    vec![part_id],
-                                    CompleteMultipartMode::Create,
-                                )
-                                .await
-                            {
-                                Err(e @ Error::Precondition { .. }) => Err(Error::AlreadyExists {
+                        return self
+                            .copy_multipart(from, to, head_meta.size, CompleteMultipartMode::Create)
+                            .await
+                            .map_err(|err| match err {
+                                Error::Precondition { .. } => Error::AlreadyExists {
                                     path: to.to_string(),
-                                    source: Box::new(e),
-                                }),
-                                Ok(_) => Ok(()),
-                                Err(e) => Err(e),
-                            }
-                        }
-                        .await;
-
-                        // If the multipart upload failed, make a best effort attempt to
-                        // clean it up. It's the caller's responsibility to add a
-                        // lifecycle rule if guaranteed cleanup is required, as we
-                        // cannot protect against an ill-timed process crash.
-                        if res.is_err() {
-                            let _ = self.client.abort_multipart(to, &upload_id).await;
-                        }
-
-                        return res;
+                                    source: Box::new(err),
+                                },
+                                other => other,
+                            });
                     }
                     None => {
                         return Err(Error::NotSupported {
@@ -513,6 +564,7 @@ mod tests {
     use crate::tests::*;
     use base64::Engine;
     use base64::prelude::BASE64_STANDARD;
+    use bytes::BytesMut;
     use http::HeaderMap;
 
     const NON_EXISTENT_NAME: &str = "nonexistentname";
@@ -574,6 +626,63 @@ mod tests {
         store.delete(&dst).await.unwrap();
     }
 
+    #[tokio::test]
+    async fn large_file_copy_multipart() {
+        maybe_skip_integration!();
+
+        let bucket = "test-bucket-for-multipart-copy-large";
+        let store = AmazonS3Builder::from_env()
+            .with_bucket_name(bucket)
+            .with_multipart_copy_threshold(5 * 1024 * 1024)
+            .with_multipart_copy_part_size(5 * 1024 * 1024)
+            .build()
+            .unwrap();
+
+        let mut payload = BytesMut::zeroed(10 * 1024 * 1024);
+        rand::fill(&mut payload[..]);
+
+        let src = Path::parse("src.bin").unwrap();
+        let dst = Path::parse("dst.bin").unwrap();
+        store
+            .put(&src, PutPayload::from(payload.clone().freeze()))
+            .await
+            .unwrap();
+        store.copy(&src, &dst).await.unwrap();
+        let copied = store.get(&dst).await.unwrap();
+        let content = copied.bytes().await.unwrap();
+        assert_eq!(content, payload);
+        store.delete(&src).await.unwrap();
+        store.delete(&dst).await.unwrap();
+    }
+
+    #[tokio::test]
+    async fn small_file_copy_single_part() {
+        maybe_skip_integration!();
+
+        let bucket = "test-bucket-for-multipart-copy-small";
+        let store = AmazonS3Builder::from_env()
+            .with_bucket_name(bucket)
+            .with_multipart_copy_threshold(5 * 1024 * 1024) // trigger multipart copy
+            .with_multipart_copy_part_size(50 * 1024 * 1024) // but only use one part
+            .build()
+            .unwrap();
+
+        let src = Path::parse("src.bin").unwrap();
+        let dst = Path::parse("dst.bin").unwrap();
+        let mut payload = BytesMut::zeroed(10 * 1024 * 1024);
+        rand::fill(&mut payload[..]);
+        store
+            .put(&src, PutPayload::from(payload.clone().freeze()))
+            .await
+            .unwrap();
+        store.copy(&src, &dst).await.unwrap();
+        let copied = store.get(&dst).await.unwrap();
+        let content = copied.bytes().await.unwrap();
+        assert_eq!(content, payload);
+        store.delete(&src).await.unwrap();
+        store.delete(&dst).await.unwrap();
+    }
+
     #[tokio::test]
     async fn write_multipart_file_with_signature_object_lock() {
         maybe_skip_integration!();
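From the caller's perspective nothing changes in the copy API; only the transport differs by object size. A hedged end-to-end sketch (bucket and paths are placeholders; the 5 MiB threshold is chosen for the demo, matching the integration tests above):

    use object_store::aws::AmazonS3Builder;
    use object_store::{path::Path, ObjectStore};

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let store = AmazonS3Builder::from_env()
            .with_bucket_name("example-bucket")
            .with_multipart_copy_threshold(5 * 1024 * 1024)
            .build()?;

        // copy_opts issues one HEAD request for the source size; above the
        // threshold it copies via UploadPartCopy, otherwise via CopyObject
        store
            .copy(&Path::from("src.bin"), &Path::from("dst.bin"))
            .await?;
        Ok(())
    }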
@@ -914,4 +1023,43 @@ mod tests {
         shutdown_tx.send(()).ok();
         thread_handle.join().expect("runtime thread panicked");
     }
+
+    #[test]
+    fn test_multipart_copy_payloads_single() {
+        let store = AmazonS3Builder::default()
+            .with_bucket_name(NON_EXISTENT_NAME)
+            .with_multipart_copy_part_size(1024)
+            .build()
+            .unwrap();
+        let path = Path::from("test.txt");
+        let payloads = store.multipart_copy_payloads(&path, 1024);
+        assert_eq!(payloads.len(), 1);
+        let PutPartPayload::Copy(payload_path) = payloads[0] else {
+            panic!("expected Copy payload");
+        };
+        assert_eq!(payload_path, &path);
+    }
+    #[test]
+    fn test_multipart_copy_payloads_multiple() {
+        let store = AmazonS3Builder::default()
+            .with_bucket_name(NON_EXISTENT_NAME)
+            .with_multipart_copy_part_size(1024)
+            .build()
+            .unwrap();
+        let path = Path::from("test.txt");
+        let payloads = store.multipart_copy_payloads(&path, 2000);
+        assert_eq!(payloads.len(), 2);
+        let mut payloads = payloads.into_iter();
+        let Some(PutPartPayload::CopyRange(payload_path, range)) = payloads.next() else {
+            panic!("expected CopyRange payload");
+        };
+        assert_eq!(payload_path, &path);
+        assert_eq!(range, 0..1024);
+        let Some(PutPartPayload::CopyRange(payload_path, range)) = payloads.next() else {
+            panic!("expected CopyRange payload");
+        };
+        assert_eq!(payload_path, &path);
+        assert_eq!(range, 1024..2000);
+        assert!(payloads.next().is_none());
+    }
 }
diff --git a/src/config.rs b/src/config.rs
index 29a389d4..b042e209 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -112,6 +112,15 @@ impl Parse for u32 {
     }
 }
 
+impl Parse for u64 {
+    fn parse(v: &str) -> Result<Self> {
+        Self::from_str(v).map_err(|_| Error::Generic {
+            store: "Config",
+            source: format!("failed to parse \"{v}\" as u64").into(),
+        })
+    }
+}
+
 impl Parse for HeaderValue {
     fn parse(v: &str) -> Result<Self> {
         Self::from_str(v).map_err(|_| Error::Generic {
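For completeness, the behavior the new Parse impl gives deferred string values, restated as a standalone sketch (parse_u64 is illustrative; the error message text is taken from the diff):

    /// Mirrors `impl Parse for u64`: plain base-10 integers only
    fn parse_u64(v: &str) -> Result<u64, String> {
        v.parse::<u64>()
            .map_err(|_| format!("failed to parse \"{v}\" as u64"))
    }

    fn main() {
        assert_eq!(parse_u64("5368709120"), Ok(5 * 1024 * 1024 * 1024));
        assert!(parse_u64("5 GiB").is_err()); // no unit suffixes
        assert!(parse_u64("-1").is_err()); // unsigned only
    }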