Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ usage: s3-pit-restore [-h] -b BUCKET [-B DEST_BUCKET] [-d DEST]
[-P DEST_PREFIX] [-p PREFIX] [-t TIMESTAMP]
[-f FROM_TIMESTAMP] [-e] [-v] [--dry-run] [--debug]
[--test] [--max-workers MAX_WORKERS]
[--avoid-duplicates]
[--sse {AES256,aws:kms}]

optional arguments:
Expand All @@ -129,6 +130,7 @@ optional arguments:
--test s3 pit restore testing
--max-workers MAX_WORKERS
max number of concurrent download requests
--avoid-duplicates tries to avoid copying files that are already at the latest version
--sse ALGORITHM
specify what SSE algorithm you would like to use for the copy
```
Expand Down
108 changes: 108 additions & 0 deletions s3-pit-restore
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import os, sys, time, signal, argparse, boto3, botocore, \
from datetime import datetime, timezone
from dateutil.parser import parse
from s3transfer.manager import TransferConfig
from botocore.exceptions import ClientError
from collections import defaultdict

args = None
executor = None
Expand Down Expand Up @@ -123,6 +125,45 @@ class TestS3PitRestore(unittest.TestCase):
self.assertNotEqual(bucket_versioning.status, None)
print("enabled!")

def get_versions(self, s3, path, dest, isDestPrefix=False):
    """Map every non-folder key under the test prefix in *dest* to its version ids.

    s3: boto3 S3 resource.
    path: local test-tree path; its basename is the S3 prefix to inspect.
    dest: name of the bucket whose versions are listed.
    isDestPrefix: when True, inspect args.dest_prefix/<basename> (the location
        a --dest-prefix restore writes to) instead of the bucket root.
    Returns a defaultdict(list) of key -> [VersionId, ...].
    """
    bucket = s3.Bucket(dest)
    base_path = os.path.basename(os.path.normpath(path))
    if isDestPrefix:
        base_path = os.path.join(args.dest_prefix, base_path)
    # NOTE(review): list_object_versions returns a single page (up to ~1000
    # versions); adequate for these small test trees -- confirm before reuse.
    resp = bucket.meta.client.list_object_versions(Bucket=dest, Prefix=base_path)

    file_versions = defaultdict(list)
    # 'Versions' is absent from the response when nothing matches the prefix;
    # treat that as an empty listing instead of raising KeyError.
    for obj in resp.get('Versions', []):
        if obj['Key'][-1] != '/':  # skip folder-placeholder objects
            file_versions[obj['Key']].append(obj['VersionId'])
    return file_versions

def compare_versions(self, version_before, version_after):
    """Validate the version history produced by an --avoid-duplicates restore.

    When --dest-prefix is set the restore writes to a fresh location, so the
    listing taken there must contain exactly one entry.  For an in-place
    restore, each key from before the overwrite is acceptable in one of two
    states: still a single version identical to the original (the restore
    skipped it), or carrying three versions (original, overwrite, restored
    copy).  Anything else means the restore misbehaved.
    """
    if args.dest_prefix:
        return len(version_after) == 1

    for key, before in version_before.items():
        after = version_after[key]
        skipped = len(after) == 1 and after[0] == before[0]
        restored = len(after) == 3
        if not (skipped or restored):
            return False
    return True

def s3_clean_test_files(self, s3, path):
    """Delete every object version the test left in args.bucket.

    Removes all versions under the test tree's basename and, when
    --dest-prefix is in use, also under the prefixed copy that the
    restore created.
    """
    prefix = os.path.basename(os.path.normpath(path))
    bucket = s3.Bucket(args.bucket)
    prefixes = [prefix]
    if args.dest_prefix:
        prefixes.append(os.path.join(args.dest_prefix, prefix))
    for p in prefixes:
        bucket.object_versions.filter(Prefix=p).delete()

def test_restore(self):
contents_before = [ str(uuid.uuid4()) for n in range(2048) ]
contents_after = [ str(uuid.uuid4()) for n in range(2048) ]
Expand Down Expand Up @@ -178,6 +219,56 @@ class TestS3PitRestore(unittest.TestCase):
print("Restoring and checking for dmarker_restore test")
self.assertTrue(self.check_tree(path, content))

def test_avoid_duplicates(self):
    """End-to-end check of the --avoid-duplicates flag on an in-place restore.

    Uploads a test tree, re-uploads it with different contents, then restores
    the time window between the two uploads with --avoid-duplicates enabled
    and checks the resulting per-key version counts via compare_versions.
    """
    # The duplicate check only makes sense when restoring into the source
    # bucket itself; fail outright for cross-bucket configurations.
    if (args.bucket != args.dest_bucket):
        print('test_avoid_duplicates is applicable when performing an inplace restore.')
        return self.assertTrue(False)
    contents_before = [ str(uuid.uuid4()) for n in range(2) ]
    contents_after = [ str(uuid.uuid4()) for n in range(1) ]
    path = os.path.join(os.path.abspath(args.dest), "test-s3-pit-avoid-duplicates")
    s3 = boto3.resource('s3', endpoint_url=args.endpoint_url)
    self.check_versioning(s3)

    print("Before starting the avoid_duplicates test...")
    self.remove_tree(path)

    # The sleeps bracket the timestamp so object mtimes fall strictly on
    # either side of the restore window boundaries.
    time.sleep(1)
    time_before = datetime.now(timezone.utc)
    time.sleep(1)
    self.generate_tree(path, contents_before)
    self.upload_directory(s3, path, args.bucket)
    self.remove_tree(path)

    print("Getting file versions for first upload.")
    version_before = self.get_versions(s3, path, args.bucket)

    print("Upload and overwriting...")
    time.sleep(1)
    time_after = datetime.now(timezone.utc)
    time.sleep(1)
    self.generate_tree(path, contents_after)
    self.upload_directory(s3, path, args.bucket)
    self.remove_tree(path)

    # Drive do_restore() through the shared global args: restore the window
    # [time_before, time_after] with duplicate avoidance switched on.
    args.from_timestamp = str(time_before)
    args.timestamp = str(time_after)
    args.avoid_duplicates = True
    args.prefix = os.path.basename(os.path.normpath(path))
    print("Restoring objects")
    do_restore()
    args.avoid_duplicates = False  # reset shared args so later tests are unaffected

    print("Getting file versions after the restore.")
    version_after = {}
    if args.dest_prefix:
        # The restore wrote under dest_prefix, so list versions there instead
        # of at the bucket root.
        version_after = self.get_versions(s3, path, args.bucket, True)
    else:
        version_after = self.get_versions(s3, path, args.bucket)
    print('Deleting test files uploaded to s3')
    self.s3_clean_test_files(s3, path)
    print('Comparing versions...')
    self.assertTrue(self.compare_versions(version_before, version_after))

def signal_handler(signal, frame):
executor.shutdown(wait=False)
for future in list(futures.keys()):
Expand Down Expand Up @@ -238,6 +329,8 @@ def handled_by_standard(obj):
return True

def handled_by_copy(obj):
if args.avoid_duplicates and not needs_copy(obj):
return True
if args.dry_run:
print_obj(obj)
return True
Expand All @@ -246,6 +339,20 @@ def handled_by_copy(obj):
futures[future] = obj
return True

def needs_copy(obj):
    """Return True when source object version *obj* must still be copied
    to the destination bucket.

    obj: dict carrying at least 'Key', 'ETag' and 'VersionId' (as produced
    by an S3 object-version listing).

    A 404 from HEAD on the destination means the key is absent, so the copy
    is always needed.  Cross-bucket restores compare ETags; in-place restores
    compare the VersionId against the destination's current version.
    Raises any non-404 ClientError from head_object.
    """
    try:
        destination_object_data = client.head_object(Bucket=args.dest_bucket, Key=obj["Key"])
    except ClientError as error:
        if error.response['ResponseMetadata']['HTTPStatusCode'] == 404:
            # Destination object does not exist: it must be copied.
            return True
        # Propagate unexpected errors with the original traceback intact.
        raise
    if args.bucket != args.dest_bucket:
        # Won't work for files uploaded with different multipart chunk sizes:
        # a multipart ETag depends on the part size, not just the content.
        return obj["ETag"] != destination_object_data["ETag"]
    # In-place restore: skip when the destination's current version is already
    # the one we would restore.
    return obj["VersionId"] != destination_object_data["VersionId"]

def download_file(obj):
transfer.download_file(args.bucket, obj["Key"], obj["Key"], extra_args={"VersionId": obj["VersionId"]})
unixtime = time.mktime(obj["LastModified"].timetuple())
Expand Down Expand Up @@ -404,6 +511,7 @@ if __name__=='__main__':
parser.add_argument('--debug', help='enable debug output', action='store_true')
parser.add_argument('--test', help='s3 pit restore testing', action='store_true')
parser.add_argument('--max-workers', help='max number of concurrent download requests', default=10, type=int)
parser.add_argument('--avoid-duplicates', help='avoids copying files if the latest version is the version that matches timestamp requested', action='store_true')
parser.add_argument('--sse', choices=['AES256', 'aws:kms'], help='Specify server-side encryption')
args = parser.parse_args()

Expand Down