diff --git a/.env.example b/.env.example
index 02824e7..7356491 100644
--- a/.env.example
+++ b/.env.example
@@ -5,3 +5,9 @@ OSS_ACCESS_KEY_ID=
 OSS_ACCESS_KEY_SECRET=
 ENDPOINT=
 BUCKET=
+
+# S3
+BUCKET=
+AWS_ACCESS_KEY_ID=
+AWS_SECRET_ACCESS_KEY=
+AWS_DEFAULT_REGION=
diff --git a/README.md b/README.md
index b2951d5..7aac81b 100644
--- a/README.md
+++ b/README.md
@@ -83,6 +83,31 @@ MINIO_ACCESS_KEY=
 MINIO_SECRET_KEY=
 ```
 
+### [S3](https://aws.amazon.com/s3/)
+
+Usage:
+
+```python
+client = StoreFactory.new_client(
+    provider="S3", bucket=
+)
+
+# Use endpoint when accessing S3 via a PrivateLink interface endpoint.
+# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/s3-example-privatelink.html
+client = StoreFactory.new_client(
+    provider="S3", bucket=, endpoint=
+)
+```
+
+Required environment variables:
+
+```yaml
+AWS_ACCESS_KEY_ID=
+AWS_SECRET_ACCESS_KEY=
+# If a region is not specified, the bucket is created in the S3 default region (us-east-1).
+AWS_DEFAULT_REGION=
+```
+
 ## Development
 
 Once you want to run the integration tests, you should have a `.env` file locally, similar to the `.env.example`.
diff --git a/omnistore/objstore/constant.py b/omnistore/objstore/constant.py
index 1e1d371..b9a592d 100644
--- a/omnistore/objstore/constant.py
+++ b/omnistore/objstore/constant.py
@@ -1,2 +1,3 @@
 OBJECT_STORE_OSS = "OSS"
 OBJECT_STORE_MINIO = "MINIO"
+OBJECT_STORE_S3 = "S3"
diff --git a/omnistore/objstore/objstore_factory.py b/omnistore/objstore/objstore_factory.py
index e2e60ba..c40fd39 100644
--- a/omnistore/objstore/objstore_factory.py
+++ b/omnistore/objstore/objstore_factory.py
@@ -1,6 +1,7 @@
 from omnistore.objstore.aliyun_oss import OSS
-from omnistore.objstore.constant import OBJECT_STORE_OSS, OBJECT_STORE_MINIO
+from omnistore.objstore.constant import OBJECT_STORE_OSS, OBJECT_STORE_MINIO, OBJECT_STORE_S3
 from omnistore.objstore.minio import MinIO
+from omnistore.objstore.s3 import S3
 from omnistore.store import Store
 
 
@@ -8,10 +9,11 @@ class StoreFactory:
     ObjStores = {
         OBJECT_STORE_OSS: OSS,
         OBJECT_STORE_MINIO: MinIO,
+        OBJECT_STORE_S3: S3,
     }
 
     @classmethod
-    def new_client(cls, provider: str, endpoint: str, bucket: str) -> Store:
+    def new_client(cls, provider: str, endpoint: str = None, bucket: str = None) -> Store:
         objstore = cls.ObjStores[provider]
         if not objstore:
             raise KeyError(f"Unknown object store provider {provider}")
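With this change, `new_client` takes `endpoint` and `bucket` as optional keyword arguments, so an S3 client can be built from environment variables alone. A minimal sketch of the intended call sites, mirroring the README section above (illustrative only, not part of the patch; the bucket name and the PrivateLink URL are placeholders):

```python
from omnistore.objstore import StoreFactory
from omnistore.objstore.constant import OBJECT_STORE_S3

# Credentials and region are read from AWS_ACCESS_KEY_ID,
# AWS_SECRET_ACCESS_KEY and AWS_DEFAULT_REGION.
client = StoreFactory.new_client(provider=OBJECT_STORE_S3, bucket="example-bucket")

# Optionally route traffic through a PrivateLink interface endpoint
# instead of the public S3 endpoint (placeholder URL).
client = StoreFactory.new_client(
    provider=OBJECT_STORE_S3,
    bucket="example-bucket",
    endpoint="https://bucket.vpce-0123456789abcdef0.s3.us-east-1.vpce.amazonaws.com",
)
```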
diff --git a/omnistore/objstore/s3.py b/omnistore/objstore/s3.py
new file mode 100644
index 0000000..5e48575
--- /dev/null
+++ b/omnistore/objstore/s3.py
@@ -0,0 +1,114 @@
+import io
+import os
+from pathlib import Path
+
+import boto3
+from botocore.exceptions import ClientError
+
+from omnistore.objstore.objstore import ObjStore
+
+
+class S3(ObjStore):
+    def __init__(self, bucket: str, endpoint: str = None):
+        """
+        Construct a new client to communicate with the AWS S3 provider.
+
+        AWS credentials are expected to be provided via environment variables:
+        - AWS_ACCESS_KEY_ID
+        - AWS_SECRET_ACCESS_KEY
+        - AWS_DEFAULT_REGION
+        """
+        region = os.environ.get("AWS_DEFAULT_REGION")
+
+        # If a region is not specified, the bucket is created in the S3 default region (us-east-1).
+        # If the user explicitly provides an endpoint_url, it takes precedence over the region's default endpoint.
+        kwargs = {}
+        if endpoint:
+            kwargs['endpoint_url'] = endpoint
+        if region:
+            kwargs['region_name'] = region
+
+        self.client = boto3.client('s3', **kwargs)
+        self.resource = boto3.resource('s3', **kwargs)
+        self.bucket_name = bucket
+
+        # Make sure the bucket exists
+        try:
+            self.client.head_bucket(Bucket=bucket)
+        except ClientError as e:
+            # If the bucket doesn't exist, create it
+            if e.response['Error']['Code'] == '404':
+                kwargs = {}
+                # Regions other than us-east-1 need an explicit LocationConstraint; us-east-1 must be omitted or S3 rejects the request.
+                if region and region != "us-east-1":
+                    kwargs['CreateBucketConfiguration'] = {
+                        "LocationConstraint": region
+                    }
+                self.client.create_bucket(Bucket=bucket, **kwargs)
+            else:
+                raise
+
+    def create_dir(self, dirname: str):
+        if not dirname.endswith("/"):
+            dirname += "/"
+        empty_stream = io.BytesIO(b"")
+        self.client.put_object(Bucket=self.bucket_name, Key=dirname, Body=empty_stream)
+
+    def delete_dir(self, dirname: str):
+        if not dirname.endswith("/"):
+            dirname += "/"
+
+        bucket = self.resource.Bucket(self.bucket_name)
+        bucket.objects.filter(Prefix=dirname).delete()
+
+    def upload(self, src: str, dest: str):
+        self.client.upload_file(src, self.bucket_name, dest)
+
+    def upload_dir(self, src_dir: str, dest_dir: str):
+        for file in Path(src_dir).rglob("*"):
+            if file.is_file():
+                dest_path = f"{dest_dir}/{file.relative_to(src_dir)}"
+                self.upload(str(file), dest_path)
+            elif file.is_dir():
+                self.create_dir(f"{dest_dir}/{file.relative_to(src_dir)}/")
+
+    def download(self, src: str, dest: str):
+        self.client.download_file(self.bucket_name, src, dest)
+
+    def download_dir(self, src_dir: str, dest_dir: str):
+        if not src_dir.endswith("/"):
+            src_dir += "/"
+        path = Path(dest_dir)
+        if not path.exists():
+            path.mkdir(parents=True)
+
+        paginator = self.client.get_paginator('list_objects_v2')
+        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=src_dir)
+
+        for page in pages:
+            if 'Contents' not in page:
+                continue
+
+            for obj in page['Contents']:
+                key = obj['Key']
+                if key.endswith('/'):  # Skip directory placeholder objects
+                    continue
+
+                file_path = Path(dest_dir, Path(key).relative_to(src_dir))
+                if not file_path.parent.exists():
+                    file_path.parent.mkdir(parents=True, exist_ok=True)
+
+                self.download(key, str(file_path))
+
+    def delete(self, filename: str):
+        self.client.delete_object(Bucket=self.bucket_name, Key=filename)
+
+    def exists(self, filename: str):
+        try:
+            self.client.head_object(Bucket=self.bucket_name, Key=filename)
+            return True
+        except ClientError as e:
+            if e.response['Error']['Code'] == '404':
+                return False
+            else:
+                raise
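The constructor's bucket bootstrap relies on an S3 quirk worth calling out: `create_bucket` must omit `CreateBucketConfiguration` entirely for us-east-1, while every other region requires an explicit `LocationConstraint`. A standalone sketch of that logic (the `ensure_bucket` helper is hypothetical, not part of the patch):

```python
from typing import Optional

import boto3
from botocore.exceptions import ClientError


def ensure_bucket(client, bucket: str, region: Optional[str]) -> None:
    """Create `bucket` if a HEAD probe reports that it does not exist."""
    try:
        client.head_bucket(Bucket=bucket)
    except ClientError as e:
        if e.response["Error"]["Code"] != "404":
            raise  # e.g. 403: the bucket exists but is not accessible
        kwargs = {}
        if region and region != "us-east-1":
            # us-east-1 is the default and must not be passed explicitly;
            # S3 rejects a literal us-east-1 LocationConstraint.
            kwargs["CreateBucketConfiguration"] = {"LocationConstraint": region}
        client.create_bucket(Bucket=bucket, **kwargs)


if __name__ == "__main__":
    ensure_bucket(boto3.client("s3"), "example-bucket", "eu-west-1")
```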
diff --git a/tests/integration_tests/objstore/test_s3.py b/tests/integration_tests/objstore/test_s3.py
new file mode 100644
index 0000000..90f379a
--- /dev/null
+++ b/tests/integration_tests/objstore/test_s3.py
@@ -0,0 +1,65 @@
+import os
+import shutil
+
+import pytest
+from dotenv import load_dotenv
+
+from omnistore.objstore import StoreFactory
+from omnistore.objstore.constant import OBJECT_STORE_S3
+
+load_dotenv()
+
+class TestS3:
+    @pytest.fixture(scope="module", autouse=True)
+    def setup_and_teardown(self):
+        print("Setting up the test environment.")
+        try:
+            os.makedirs("./test-tmp", exist_ok=True)
+        except Exception as e:
+            print(f"An error occurred: {e}")
+
+        yield
+
+        print("Tearing down the test environment.")
+        shutil.rmtree("./test-tmp")
+
+    def test_upload_and_download_files(self):
+        bucket = os.getenv("BUCKET")
+
+        client = StoreFactory.new_client(
+            provider=OBJECT_STORE_S3, bucket=bucket
+        )
+        assert False == client.exists("foo.txt")
+
+        with open("./test-tmp/foo.txt", "w") as file:
+            file.write("test")
+
+        client.upload("./test-tmp/foo.txt", "foo.txt")
+        assert True == client.exists("foo.txt")
+
+        client.download("foo.txt", "./test-tmp/bar.txt")
+        assert True == os.path.exists("./test-tmp/bar.txt")
+
+        client.delete("foo.txt")
+        assert False == client.exists("foo.txt")
+
+    def test_upload_and_download_dir(self):
+        bucket = os.getenv("BUCKET")
+
+        client = StoreFactory.new_client(
+            provider=OBJECT_STORE_S3, bucket=bucket
+        )
+        assert False == client.exists("test/111/foo.txt")
+
+        os.makedirs("./test-tmp/test/111", exist_ok=True)
+        with open("./test-tmp/test/111/foo.txt", "w") as file:
+            file.write("test")
+
+        client.upload_dir("./test-tmp/test", "test")
+        assert True == client.exists("test/111/foo.txt")
+
+        client.download_dir("test", "./test-tmp/test1")
+        assert True == os.path.exists("./test-tmp/test1/111/foo.txt")
+
+        client.delete_dir("test")
+        assert False == client.exists("test/111/foo.txt")
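Taken together with the README section, the integration test above doubles as usage documentation. An end-to-end sketch (assumes the variables from `.env.example` are set; file names and object keys are placeholders):

```python
import os

from dotenv import load_dotenv

from omnistore.objstore import StoreFactory
from omnistore.objstore.constant import OBJECT_STORE_S3

load_dotenv()  # loads BUCKET and the AWS_* variables from .env

client = StoreFactory.new_client(provider=OBJECT_STORE_S3, bucket=os.getenv("BUCKET"))

with open("hello.txt", "w") as f:
    f.write("hello, omnistore")

client.upload("hello.txt", "demo/hello.txt")  # single object
assert client.exists("demo/hello.txt")

client.download("demo/hello.txt", "hello-copy.txt")
client.delete_dir("demo")  # removes every object under the demo/ prefix
```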