diff --git a/.changes/next-release/bugfix-AmazonS3-34da391.json b/.changes/next-release/bugfix-AmazonS3-34da391.json new file mode 100644 index 000000000000..2d177a6c2849 --- /dev/null +++ b/.changes/next-release/bugfix-AmazonS3-34da391.json @@ -0,0 +1,6 @@ +{ + "type": "bugfix", + "category": "Amazon S3", + "contributor": "", + "description": "Added additional validations for multipart download operations in the Java multipart S3 client" +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java index 1ccae234631d..7466bda5b2a3 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriber.java @@ -23,6 +23,7 @@ import org.reactivestreams.Subscription; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; @@ -60,6 +61,11 @@ public class MultipartDownloaderSubscriber implements Subscriber "Sending GetObjectRequest for next part with partNumber=" + nextPartToGet); CompletableFuture getObjectFuture = s3.getObject(actualRequest, asyncResponseTransformer); + getObjectCallCount.incrementAndGet(); getObjectFutures.add(getObjectFuture); getObjectFuture.whenComplete((response, error) -> { if (error != null) { log.debug(() -> "Error encountered during GetObjectRequest with partNumber=" + nextPartToGet); - onError(error); + handleError(error); return; } requestMoreIfNeeded(response); @@ -166,6 +174,7 @@ private void requestMoreIfNeeded(GetObjectResponse response) { if (totalParts != null && totalParts > 1 && totalComplete < totalParts) { subscription.request(1); } else { + validatePartsCount(); log.debug(() -> String.format("Completing multipart download after a total of %d parts downloaded.", totalParts)); subscription.cancel(); } @@ -174,6 +183,13 @@ private void requestMoreIfNeeded(GetObjectResponse response) { @Override public void onError(Throwable t) { + handleError(t); + } + + /** + * The method used by the Subscriber itself when error occured. + */ + private void handleError(Throwable t) { CompletableFuture partFuture; while ((partFuture = getObjectFutures.poll()) != null) { partFuture.cancel(true); @@ -198,4 +214,14 @@ private GetObjectRequest nextRequest(int nextPartToGet) { } }); } + + private void validatePartsCount() { + int actualGetCount = getObjectCallCount.get(); + if (totalParts != null && actualGetCount != totalParts) { + String errorMessage = String.format("PartsCount validation failed. Expected %d, downloaded %d parts.", totalParts, + actualGetCount); + SdkClientException exception = SdkClientException.create(errorMessage); + handleError(exception); + } + } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberPartCountValidationTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberPartCountValidationTest.java new file mode 100644 index 000000000000..8ff1eeac0ad5 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartDownloaderSubscriberPartCountValidationTest.java @@ -0,0 +1,120 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + + +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +public class MultipartDownloaderSubscriberPartCountValidationTest { + @Mock + private S3AsyncClient s3Client; + + @Mock + private Subscription subscription; + + @Mock + private AsyncResponseTransformer responseTransformer; + + private GetObjectRequest getObjectRequest; + private MultipartDownloaderSubscriber subscriber; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + getObjectRequest = GetObjectRequest.builder() + .bucket("test-bucket") + .key("test-key") + .build(); + } + + @Test + void callCountMatchesTotalParts_shouldPass() throws InterruptedException { + subscriber = new MultipartDownloaderSubscriber(s3Client, getObjectRequest); + GetObjectResponse response1 = createMockResponse(3, "etag1"); + GetObjectResponse response2 = createMockResponse(3, "etag2"); + GetObjectResponse response3 = createMockResponse(3, "etag3"); + + CompletableFuture future1 = CompletableFuture.completedFuture(response1); + CompletableFuture future2 = CompletableFuture.completedFuture(response2); + CompletableFuture future3 = CompletableFuture.completedFuture(response3); + + when(s3Client.getObject(any(GetObjectRequest.class), eq(responseTransformer))) + .thenReturn(future1, future2, future3); + + subscriber.onSubscribe(subscription); + subscriber.onNext(responseTransformer); + subscriber.onNext(responseTransformer); + subscriber.onNext(responseTransformer); + Thread.sleep(100); + + subscriber.onComplete(); + + assertDoesNotThrow(() -> subscriber.future().get(1, TimeUnit.SECONDS)); + } + + @Test + void callCountMoreThanTotalParts_shouldThrowException() throws InterruptedException { + subscriber = new MultipartDownloaderSubscriber(s3Client, getObjectRequest, 3); + GetObjectResponse response1 = createMockResponse(2, "etag1"); + + CompletableFuture future1 = CompletableFuture.completedFuture(response1); + + when(s3Client.getObject(any(GetObjectRequest.class), eq(responseTransformer))) + .thenReturn(future1); + + subscriber.onSubscribe(subscription); + subscriber.onNext(responseTransformer); + Thread.sleep(100); + + subscriber.onComplete(); + + ExecutionException exception = assertThrows(ExecutionException.class, + () -> subscriber.future().get(1, TimeUnit.SECONDS)); + assertTrue(exception.getCause() instanceof SdkClientException); + assertTrue(exception.getCause().getMessage().contains("PartsCount validation failed")); + assertTrue(exception.getCause().getMessage().contains("Expected 2, downloaded 4 parts")); + + } + + private GetObjectResponse createMockResponse(int partsCount, String etag) { + GetObjectResponse.Builder builder = GetObjectResponse.builder() + .eTag(etag) + .contentLength(1024L); + + builder.partsCount(partsCount); + return builder.build(); + } + +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java index 9fd58e6a5fa0..449f245203c9 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java @@ -34,6 +34,7 @@ import static org.junit.jupiter.params.provider.Arguments.arguments; import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.internalErrorBody; import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.transformersSuppliers; +import static software.amazon.awssdk.services.s3.multipart.S3MultipartExecutionAttribute.MULTIPART_DOWNLOAD_RESUME_CONTEXT; import com.github.tomakehurst.wiremock.http.Fault; import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; @@ -58,6 +59,7 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.SplittingTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.internal.async.ByteArrayAsyncResponseTransformer; import software.amazon.awssdk.core.internal.async.FileAsyncResponseTransformer; import software.amazon.awssdk.core.internal.async.InputStreamResponseTransformer; @@ -67,6 +69,7 @@ import software.amazon.awssdk.regions.Region; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.utils.AsyncResponseTransformerTestSupplier; @@ -144,6 +147,47 @@ public void errorOnThirdPart_shouldCompleteExceptionallyOnlyPartsGreaterThan } } + @ParameterizedTest + @MethodSource("partSizeAndTransformerParams") + public void partCountValidationFailure_shouldThrowException( + AsyncResponseTransformerTestSupplier supplier, + int partSize) { + + // To trigger the partCount failure, the resumeContext is used to initialize the actualGetCount larger than the + // totalPart number set in the response. This won't happen in real scenario, just to test if the error can be surfaced + // to the user if the validation fails. + MultipartDownloadResumeContext resumeContext = new MultipartDownloadResumeContext(); + resumeContext.addCompletedPart(1); + resumeContext.addCompletedPart(2); + resumeContext.addCompletedPart(3); + resumeContext.addToBytesToLastCompletedParts(3 * partSize); + + GetObjectRequest request = GetObjectRequest.builder() + .bucket(BUCKET) + .key(KEY) + .overrideConfiguration(config -> config + .putExecutionAttribute( + MULTIPART_DOWNLOAD_RESUME_CONTEXT, + resumeContext)) + .build(); + + util.stubForPart(BUCKET, KEY, 4, 2, partSize); + + // Skip the lazy transformer since the error won't surface unless the content is consumed + AsyncResponseTransformer transformer = supplier.transformer(); + if (transformer instanceof InputStreamResponseTransformer || transformer instanceof PublisherAsyncResponseTransformer) { + return; + } + + assertThatThrownBy(() -> { + T res = multipartClient.getObject(request, transformer).join(); + supplier.body(res); + }).isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(SdkClientException.class) + .hasMessageContaining("PartsCount validation failed. Expected 2, downloaded 4 parts"); + + } + @ParameterizedTest @MethodSource("nonRetryableResponseTransformers") public void errorOnFirstPart_shouldFail(AsyncResponseTransformerTestSupplier supplier) { diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java index ac667912a33f..e12a7cee35a0 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java @@ -66,6 +66,7 @@ public byte[] stubForPart(String testBucket, String testKey,int part, int totalP aResponse() .withHeader("x-amz-mp-parts-count", totalPart + "") .withHeader("ETag", eTag) + .withHeader("Content-Length", String.valueOf(body.length)) .withBody(body))); return body; }