Start adding S3 support for VRTStack
scottstanie committed Mar 2, 2024
1 parent 5478a89 commit f6a4de5
Showing 3 changed files with 279 additions and 3 deletions.
199 changes: 199 additions & 0 deletions src/dolphin/io/_paths.py
@@ -0,0 +1,199 @@
from __future__ import annotations

import copy
import logging
import re
from pathlib import Path
from typing import Protocol, Union
from urllib.parse import ParseResult, urlparse

__all__ = ["S3Path"]


logger = logging.getLogger(__name__)


class GeneralPath(Protocol):
    """A protocol to handle paths that can be either local or S3 paths."""

    # `parent` and `suffix` are properties (not methods) so that both
    # `pathlib.Path` and `S3Path` below satisfy this protocol.
    @property
    def parent(self):
        ...

    @property
    def suffix(self):
        ...

    def read_text(self):
        ...

    def __truediv__(self, other):
        ...

    def __str__(self) -> str:
        ...

    def __fspath__(self) -> str:
        return str(self)


class S3Path(GeneralPath):
"""A convenience class to handle paths on S3.
This class relies on `pathlib.Path` for operations using `urllib` to parse the url.
If passing a url with a trailing slash, the slash will be preserved
when converting back to string.
Note that pure path manipulation functions do *not* require `boto3`,
but functions which interact with S3 (e.g. `exists()`, `.read_text()`) do.
Attributes
----------
bucket : str
Name of bucket in the url
path : pathlib.Path
The URL path after s3://<bucket>/
key : str
Alias of `path` converted to a string
Examples
--------
>>> from orca.paths import S3Path
>>> s3_path = S3Path("s3://bucket/path/to/file.txt")
>>> str(s3_path)
's3://bucket/path/to/file.txt'
>>> s3_path.parent
S3Path("s3://bucket/path/to/")
>>> str(s3_path.parent)
's3://bucket/path/to'
"""

def __init__(self, s3_url: Union[str, "S3Path"]):
"""Create an S3Path.
Parameters
----------
s3_url : str or S3Path
The S3 url to parse.
"""
# Names come from the urllib.parse.ParseResult
if isinstance(s3_url, S3Path):
self._scheme: str = s3_url._scheme
self._netloc: str = s3_url._netloc
self.bucket: str = s3_url.bucket
self.path: Path = s3_url.path
self._trailing_slash: str = s3_url._trailing_slash
else:
parsed: ParseResult = urlparse(s3_url)
self._scheme = parsed.scheme
self._netloc = self.bucket = parsed.netloc
self._parsed = parsed
self.path = Path(parsed.path)
self._trailing_slash = "/" if s3_url.endswith("/") else ""

if self._scheme != "s3":
raise ValueError(f"{s3_url} is not an S3 url")

@classmethod
def from_bucket_key(cls, bucket: str, key: str):
"""Create a `S3Path` from the bucket name and key/prefix.
Matches API of some Boto3 functions which use this format.
Parameters
----------
bucket : str
Name of S3 bucket.
key : str
S3 url of path after the bucket.
"""
return cls(f"s3://{bucket}/{key}")

    def get_path(self) -> str:
        # Add the scheme, double slash, and netloc back to the front of the path
        return f"{self._scheme}://{self._netloc}{self.path.as_posix()}{self._trailing_slash}"

@property
def key(self) -> str:
"""Name of key/prefix within the bucket with leading slash removed."""
return f"{str(self.path.as_posix()).lstrip('/')}{self._trailing_slash}"

@property
    def parent(self):
        # A parent is a prefix, so it will always end in a slash.
        # Add the scheme and netloc back to the front; the constructor
        # guarantees the scheme is always "s3".
        parent_path = self.path.parent
        return S3Path(f"{self._scheme}://{self._netloc}{parent_path.as_posix()}/")

@property
def suffix(self):
return self.path.suffix

def _get_client(self):
import boto3

return boto3.client("s3")

def exists(self) -> bool:
"""Whether this path exists on S3."""
client = self._get_client()
resp = client.list_objects_v2(
Bucket=self.bucket,
Prefix=self.key,
MaxKeys=1,
)
return resp.get("KeyCount") == 1

def read_text(self) -> str:
"""Download/read the S3 file as text."""
return self._download_as_bytes().decode()

def read_bytes(self) -> bytes:
"""Download/read the S3 file as bytes."""
return self._download_as_bytes()

def _download_as_bytes(self) -> bytes:
"""Download file to a `BytesIO` buffer to read as bytes."""
from io import BytesIO

client = self._get_client()

bio = BytesIO()
client.download_fileobj(self.bucket, self.key, bio)
bio.seek(0)
out = bio.read()
bio.close()
return out

def __truediv__(self, other):
new = copy.deepcopy(self)
new.path = self.path / other
new._trailing_slash = "/" if str(other).endswith("/") else ""
return new

def __repr__(self):
return f'S3Path("{self.get_path()}")'

def __str__(self):
return self.get_path()

    def glob(self, pattern: str) -> list[str]:
        """List keys under this prefix matching `pattern` (e.g. "*.tif")."""
        # Note: `_s3.py` defines `list_bucket_boto3`, not `list_bucket`
        from ._s3 import list_bucket_boto3

        full_pattern = str(self) + pattern
        logger.debug(f"Searching {full_pattern}")
        return list_bucket_boto3(full_bucket_glob=full_pattern)


def fix_s3_url(url):
    """Fix an S3 URL that has been altered by `pathlib`.

    Replaces `s3:/my-bucket/...` with `s3://my-bucket/...`.
    """
    return re.sub(r"s3:/((?!/).*)", r"s3://\1", str(url))
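As a quick illustration of the pure-path behavior above (no `boto3` required), here is a hypothetical session; the bucket and file names are made up:

from dolphin.io._paths import S3Path, fix_s3_url

p = S3Path("s3://my-bucket/slcs/20221107.h5")
print(p.bucket)   # my-bucket
print(p.key)      # slcs/20221107.h5
print(p.suffix)   # .h5
print(p.parent)   # s3://my-bucket/slcs/
print(p.parent / "20221119.h5")  # s3://my-bucket/slcs/20221119.h5

# pathlib.Path collapses the double slash in "s3://"; fix_s3_url restores it
print(fix_s3_url("s3:/my-bucket/slcs"))  # s3://my-bucket/slcs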
17 changes: 14 additions & 3 deletions src/dolphin/io/_readers.py
@@ -28,6 +28,7 @@
from dolphin.io._blocks import iter_blocks

from ._background import _DEFAULT_TIMEOUT, BackgroundReader
from ._paths import S3Path

logger = logging.getLogger(__name__)

@@ -742,7 +743,10 @@ def __init__(

# files: list[Filename] = [Path(f) for f in file_list]
self._use_abs_path = use_abs_path
if use_abs_path:
files: list[Filename | S3Path]
if any(str(f).startswith("s3://") for f in file_list):
files = [S3Path(str(f)) for f in file_list]
elif use_abs_path:
files = [utils._resolve_gdal_path(p) for p in file_list]
else:
files = list(file_list)
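With this change, a file list containing any s3:// urls is converted to `S3Path` objects up front. A hypothetical construction (the urls are made up, and the `outfile` keyword is assumed from `VRTStack`'s existing API rather than shown in this diff):

from dolphin.io._readers import VRTStack

file_list = [
    "s3://my-bucket/slcs/20221107.h5",  # hypothetical urls
    "s3://my-bucket/slcs/20221119.h5",
]
stack = VRTStack(file_list, outfile="slc_stack.vrt")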
@@ -822,12 +826,19 @@ def _write(self):
ds = None

@property
def _gdal_file_strings(self):
def _gdal_file_strings(self) -> list[str]:
"""Get the GDAL-compatible paths to write to the VRT.
If we're not using .h5 or .nc, this will just be the file_list as is.
"""
return [io.format_nc_filename(f, self.subdataset) for f in self.file_list]
out = []
for f in self.file_list:
if isinstance(f, S3Path):
s = str(f).replace("s3://", "/vsis3/")
else:
s = io.format_nc_filename(f, self.subdataset)
out.append(s)
return out

def __fspath__(self):
# Allows os.fspath() to work on the object, enabling rasterio.open()
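The `/vsis3/` rewrite in `_gdal_file_strings` above is what lets GDAL read the VRT's source rasters directly from the bucket. A minimal sketch of the same conversion (the url is made up; credentials come from the usual AWS environment variables or config):

from osgeo import gdal

url = "s3://my-bucket/stack/20221107.tif"  # hypothetical
vsi_path = url.replace("s3://", "/vsis3/")  # same rewrite as _gdal_file_strings

ds = gdal.Open(vsi_path)
print(ds.RasterXSize, ds.RasterYSize)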
66 changes: 66 additions & 0 deletions src/dolphin/io/_s3.py
@@ -0,0 +1,66 @@
from __future__ import annotations

# fnmatch/re apply glob patterns client-side, since S3 `Prefix`
# filtering is literal
import fnmatch
import re

import boto3


def list_bucket_boto3(
bucket: str | None = None,
prefix: str | None = None,
suffix: str | None = None,
full_bucket_glob: str | None = None,
aws_profile: str | None = None,
) -> list[str]:
"""List items in a bucket using boto3.
Parameters
----------
bucket : str, optional
Name of the bucket.
prefix : str, optional
Prefix to filter by.
suffix : str, optional
Suffix to filter by.
full_bucket_glob : str, optional
Full glob to filter by.
aws_profile : str, optional
AWS profile to use.
Returns
-------
List[str]
List of items in the bucket.
"""
session = (
boto3.Session(profile_name=aws_profile) if aws_profile else boto3.Session()
)
s3 = session.client("s3")
out: list[str] = []

    # Determine the prefix for listing objects
    glob_pattern: str | None = None
    if full_bucket_glob:
        # If full_bucket_glob is provided, extract the bucket and key pattern
        if full_bucket_glob.startswith("s3://"):
            full_bucket_glob = full_bucket_glob[5:]  # Remove 's3://'
        bucket, *rest = full_bucket_glob.split("/", 1)
        glob_pattern = rest[0] if rest else ""
        # S3 `Prefix` is matched literally, so list using only the part
        # before the first wildcard, then filter with `fnmatch` below
        prefix = re.split(r"[*?\[]", glob_pattern, maxsplit=1)[0]

# Ensure bucket is specified
if not bucket:
raise ValueError("Bucket name must be specified")

paginator = s3.get_paginator("list_objects_v2")
operation_parameters = {"Bucket": bucket}
if prefix:
operation_parameters["Prefix"] = prefix

page_iterator = paginator.paginate(**operation_parameters)

    for page in page_iterator:
        for item in page.get("Contents", []):
            key = item["Key"]
            if suffix and not key.endswith(suffix):
                continue
            if glob_pattern and not fnmatch.fnmatch(key, glob_pattern):
                continue
            out.append(key)
return out
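
A hypothetical listing session tying the pieces together (bucket and keys are made up; this does require AWS credentials):

from dolphin.io._paths import S3Path
from dolphin.io._s3 import list_bucket_boto3

# Directly: list every key under a prefix ending in .tif
keys = list_bucket_boto3(bucket="my-bucket", prefix="stack/", suffix=".tif")

# Or through the path object; returns the matching keys
s3_dir = S3Path("s3://my-bucket/stack/")
tifs = s3_dir.glob("*.tif")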
