diff --git a/README.md b/README.md
index 3f05d18..7a6c9d2 100644
--- a/README.md
+++ b/README.md
@@ -105,6 +105,7 @@ usage: s3-pit-restore [-h] -b BUCKET [-B DEST_BUCKET] [-d DEST]
                       [-P DEST_PREFIX] [-p PREFIX] [-t TIMESTAMP]
                       [-f FROM_TIMESTAMP] [-e] [-v] [--dry-run] [--debug]
                       [--test] [--max-workers MAX_WORKERS]
+                      [--avoid-duplicates]
                       [--sse {AES256,aws:kms}]
 
 optional arguments:
@@ -129,6 +130,7 @@ optional arguments:
   --test                s3 pit restore testing
   --max-workers MAX_WORKERS
                         max number of concurrent download requests
+  --avoid-duplicates    avoid copying files whose latest version already matches the version being restored
   --sse ALGORITHM       specify what SSE algorithm you would like to use for the copy
 ```
 
diff --git a/s3-pit-restore b/s3-pit-restore
index c905d8f..41b5a49 100755
--- a/s3-pit-restore
+++ b/s3-pit-restore
@@ -34,6 +34,8 @@ import os, sys, time, signal, argparse, boto3, botocore, \
 from datetime import datetime, timezone
 from dateutil.parser import parse
 from s3transfer.manager import TransferConfig
+from botocore.exceptions import ClientError
+from collections import defaultdict
 
 args = None
 executor = None
@@ -123,6 +125,45 @@ class TestS3PitRestore(unittest.TestCase):
         self.assertNotEqual(bucket_versioning.status, None)
         print("enabled!")
 
+    def get_versions(self, s3, path, dest, isDestPrefix = False):
+        bucket = s3.Bucket(dest)
+        base_path = os.path.basename(os.path.normpath(path))
+        # If isDestPrefix is true, collect the versions under the dest_prefix folder; otherwise, collect the versions under the bucket's root folder.
+        if isDestPrefix:
+            base_path = os.path.join(args.dest_prefix, base_path)
+        resp = bucket.meta.client.list_object_versions(Bucket=dest, Prefix=base_path)
+
+        fileList = defaultdict(list)
+        for obj in [*resp['Versions']]:
+            if obj['Key'][-1] != '/':
+                fileList[obj['Key']].append(obj['VersionId'])
+        return fileList
+
+    def compare_versions(self, version_before, version_after):
+        # When dest_prefix is set the restore goes to a separate location (the dest_prefix path), so version_after differs from an in-place restore at the root and needs its own validation condition.
+        if args.dest_prefix:
+            if len(version_after) == 1:
+                return True
+            return False
+        for key in version_before.keys():
+            version_before_arr = version_before[key]
+            version_after_arr = version_after[key]
+            if (len(version_after_arr) == 1) and (version_after_arr[0] == version_before_arr[0]):
+                continue
+            elif len(version_after_arr) == 3:
+                continue
+            else:
+                return False
+        return True
+
+    def s3_clean_test_files(self, s3, path):
+        base_path = os.path.basename(os.path.normpath(path))
+        bucket = s3.Bucket(args.bucket)
+        bucket.object_versions.filter(Prefix=base_path).delete()
+        if args.dest_prefix:
+            base_path = os.path.join(args.dest_prefix, base_path)
+            bucket.object_versions.filter(Prefix=base_path).delete()
+
     def test_restore(self):
         contents_before = [ str(uuid.uuid4()) for n in range(2048) ]
         contents_after = [ str(uuid.uuid4()) for n in range(2048) ]
@@ -178,6 +219,56 @@ class TestS3PitRestore(unittest.TestCase):
         print("Restoring and checking for dmarker_restore test")
         self.assertTrue(self.check_tree(path, content))
 
+    def test_avoid_duplicates(self):
+        if (args.bucket != args.dest_bucket):
+            print('test_avoid_duplicates is only applicable when performing an in-place restore.')
+            return self.assertTrue(False)
+        contents_before = [ str(uuid.uuid4()) for n in range(2) ]
+        contents_after = [ str(uuid.uuid4()) for n in range(1) ]
+        path = os.path.join(os.path.abspath(args.dest), "test-s3-pit-avoid-duplicates")
+        s3 = boto3.resource('s3', endpoint_url=args.endpoint_url)
+        self.check_versioning(s3)
+
+        print("Before starting the avoid_duplicates test...")
+        self.remove_tree(path)
+
+        time.sleep(1)
+        time_before = datetime.now(timezone.utc)
+        time.sleep(1)
+        self.generate_tree(path, contents_before)
+        self.upload_directory(s3, path, args.bucket)
+        self.remove_tree(path)
+
+        print("Getting file versions for first upload.")
+        version_before = self.get_versions(s3, path, args.bucket)
+
+        print("Uploading and overwriting...")
+        time.sleep(1)
+        time_after = datetime.now(timezone.utc)
+        time.sleep(1)
+        self.generate_tree(path, contents_after)
+        self.upload_directory(s3, path, args.bucket)
+        self.remove_tree(path)
+
+        args.from_timestamp = str(time_before)
+        args.timestamp = str(time_after)
+        args.avoid_duplicates = True
+        args.prefix = os.path.basename(os.path.normpath(path))
+        print("Restoring objects")
+        do_restore()
+        args.avoid_duplicates = False
+
+        print("Getting file versions after the restore.")
+        version_after = {}
+        if args.dest_prefix:
+            version_after = self.get_versions(s3, path, args.bucket, True)
+        else:
+            version_after = self.get_versions(s3, path, args.bucket)
+        print('Deleting test files uploaded to s3')
+        self.s3_clean_test_files(s3, path)
+        print('Comparing versions...')
+        self.assertTrue(self.compare_versions(version_before, version_after))
+
 def signal_handler(signal, frame):
     executor.shutdown(wait=False)
     for future in list(futures.keys()):
@@ -238,6 +329,8 @@ def handled_by_standard(obj):
     return True
 
 def handled_by_copy(obj):
+    if args.avoid_duplicates and not needs_copy(obj):
+        return True
     if args.dry_run:
         print_obj(obj)
         return True
@@ -246,6 +339,20 @@
     futures[future] = obj
     return True
 
+def needs_copy(obj):
+    try:
+        destination_object_data = client.head_object(Bucket=args.dest_bucket, Key=obj["Key"])
+    except ClientError as error:
+        if error.response['ResponseMetadata']['HTTPStatusCode'] == 404:
+            return True
+        else:
+            raise error
+    # ETag comparison won't work for files uploaded with different multipart chunk sizes
+    if args.bucket != args.dest_bucket:
+        return obj["ETag"] != destination_object_data["ETag"]
+    else:
+        return obj["VersionId"] != destination_object_data["VersionId"]
+
 def download_file(obj):
     transfer.download_file(args.bucket, obj["Key"], obj["Key"], extra_args={"VersionId": obj["VersionId"]})
     unixtime = time.mktime(obj["LastModified"].timetuple())
@@ -404,6 +511,7 @@ if __name__=='__main__':
     parser.add_argument('--debug', help='enable debug output', action='store_true')
     parser.add_argument('--test', help='s3 pit restore testing', action='store_true')
    parser.add_argument('--max-workers', help='max number of concurrent download requests', default=10, type=int)
+    parser.add_argument('--avoid-duplicates', help='avoid copying files whose latest version already matches the version being restored', action='store_true')
     parser.add_argument('--sse', choices=['AES256', 'aws:kms'], help='Specify server-side encryption')
 
     args = parser.parse_args()
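Usage sketch for reviewers (bucket name, prefix and timestamps below are placeholders, not values from this patch): an in-place point-in-time restore that skips objects already at the requested version could be invoked like this, with `-B` pointing at the same bucket as `-b` so the copy path is used:

```
s3-pit-restore -b my-bucket -B my-bucket -p my-prefix/ -f "2023-06-01 00:00:00+00:00" -t "2023-06-02 00:00:00+00:00" --avoid-duplicates
```

With `--avoid-duplicates`, objects whose current version already matches the version selected by the timestamp window are skipped instead of being copied onto themselves, which avoids piling up redundant versions in a versioned bucket. This mirrors what `test_avoid_duplicates` checks: objects that were never overwritten keep a single version, while the overwritten object gains the restored copy.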
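Background on the ETag caveat in `needs_copy`: for multipart uploads, S3 is commonly observed to build the ETag from the MD5 digests of the individual parts plus a part count, so identical content uploaded with different chunk sizes ends up with different ETags (observed behaviour, not a documented contract). A self-contained sketch, independent of this patch, that reproduces the effect locally:

```python
import hashlib

def multipart_etag(data: bytes, chunk_size: int) -> str:
    # S3-style multipart ETag: MD5 of the concatenated per-part MD5 digests,
    # followed by "-<number of parts>".
    parts = [hashlib.md5(data[i:i + chunk_size]).digest()
             for i in range(0, len(data), chunk_size)]
    return hashlib.md5(b"".join(parts)).hexdigest() + "-" + str(len(parts))

payload = b"x" * (20 * 1024 * 1024)                 # the same 20 MiB of content
print(multipart_etag(payload, 8 * 1024 * 1024))     # 3 parts -> one ETag
print(multipart_etag(payload, 16 * 1024 * 1024))    # 2 parts -> a different ETag
```

This is why `needs_copy` only falls back to ETag comparison for cross-bucket restores and prefers `VersionId` for in-place ones.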