diff --git a/CHANGELOG.md b/CHANGELOG.md index e0ae1aed..1b24a933 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Unreleased +- Introduced type annotations + # 4.1.2, 18 Jan 2021 - Correctly pass boto3 resource to writers (PR [#576](https://github.com/RaRe-Technologies/smart_open/pull/576), [@jackluo923](https://github.com/jackluo923)) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..7c0e145d --- /dev/null +++ b/mypy.ini @@ -0,0 +1,31 @@ +[mypy] + +[mypy-smart_open.tests.*] +ignore_missing_imports = True + +# +# Third party libraries below +# +[mypy-azure.*] +ignore_missing_imports = True + +[mypy-boto3.*] +ignore_missing_imports = True + +[mypy-botocore.*] +ignore_missing_imports = True + +[mypy-google.*] +ignore_missing_imports = True + +[mypy-moto.*] +ignore_missing_imports = True + +[mypy-paramiko.*] +ignore_missing_imports = True + +[mypy-requests_kerberos.*] +ignore_missing_imports = True + +[mypy-responses.*] +ignore_missing_imports = True diff --git a/smart_open/azure.py b/smart_open/azure.py index a2ec54fc..75056c8f 100644 --- a/smart_open/azure.py +++ b/smart_open/azure.py @@ -9,9 +9,16 @@ """Implements file-like objects for reading and writing to/from Azure Blob Storage.""" import base64 +import collections import io import logging +from typing import ( + Dict, + IO, + List, +) + import smart_open.bytebuffer import smart_open.constants @@ -37,8 +44,10 @@ https://docs.microsoft.com/en-us/rest/api/storageservices/understanding-block-blobs--append-blobs--and-page-blobs """ +Uri = collections.namedtuple('Uri', 'scheme container_id blob_id') + -def parse_uri(uri_as_string): +def parse_uri(uri_as_string: str) -> Uri: sr = smart_open.utils.safe_urlsplit(uri_as_string) assert sr.scheme == SCHEME first = sr.netloc @@ -52,54 +61,47 @@ def parse_uri(uri_as_string): container_id = first blob_id = second - return dict(scheme=SCHEME, container_id=container_id, blob_id=blob_id) + return Uri(scheme=SCHEME, container_id=container_id, blob_id=blob_id) -def open_uri(uri, mode, transport_params): +def open_uri(uri: str, mode: str, transport_params: Dict) -> IO[bytes]: parsed_uri = parse_uri(uri) kwargs = smart_open.utils.check_kwargs(open, transport_params) - return open(parsed_uri['container_id'], parsed_uri['blob_id'], mode, **kwargs) + return open(parsed_uri.container_id, parsed_uri.blob_id, mode, **kwargs) def open( - container_id, - blob_id, - mode, - client=None, # type: azure.storage.blob.BlobServiceClient - buffer_size=DEFAULT_BUFFER_SIZE, - min_part_size=_DEFAULT_MIN_PART_SIZE - ): + container_id: str, + blob_id: str, + mode: str, + client: 'azure.storage.blob.BlobServiceClient' = None, + buffer_size: int = DEFAULT_BUFFER_SIZE, + min_part_size: int = _DEFAULT_MIN_PART_SIZE, +) -> IO[bytes]: """Open an Azure Blob Storage blob for reading or writing. Parameters ---------- - container_id: str - The name of the container this object resides in. - blob_id: str - The name of the blob within the bucket. - mode: str - The mode for opening the object. Must be either "rb" or "wb". - client: azure.storage.blob.BlobServiceClient - The Azure Blob Storage client to use when working with azure-storage-blob. - buffer_size: int, optional - The buffer size to use when performing I/O. For reading only. - min_part_size: int, optional - The minimum part size for multipart uploads. 
For writing only. + :param container_id: The name of the container this object resides in. + :param blob_id: The name of the blob within the bucket. + :param mode: The mode for opening the object. Must be either "rb" or "wb". + :param client: The Azure Blob Storage client to use when working with azure-storage-blob. + :param buffer_size: The buffer size to use when performing I/O. For reading only. + :param min_part_size: The minimum part size for multipart uploads. For writing only. """ if not client: raise ValueError('you must specify the client to connect to Azure') if mode == smart_open.constants.READ_BINARY: - return Reader( + return Reader( # type: ignore container_id, blob_id, client, buffer_size=buffer_size, - line_terminator=smart_open.constants.BINARY_NEWLINE, ) elif mode == smart_open.constants.WRITE_BINARY: - return Writer( + return Writer( # type: ignore container_id, blob_id, client, @@ -112,30 +114,29 @@ def open( class _RawReader(object): """Read an Azure Blob Storage file.""" - def __init__(self, blob, size): - # type: (azure.storage.blob.BlobClient, int) -> None + def __init__(self, blob: 'azure.storage.blob.BlobClient', size: int) -> None: self._blob = blob self._size = size self._position = 0 - def seek(self, position): + def seek(self, position: int) -> int: """Seek to the specified position (byte offset) in the Azure Blob Storage blob. - :param int position: The byte offset from the beginning of the blob. + :param position: The byte offset from the beginning of the blob. Returns the position after seeking. """ self._position = position return self._position - def read(self, size=-1): + def read(self, size: int = -1) -> bytes: if self._position >= self._size: return b'' binary = self._download_blob_chunk(size) self._position += len(binary) return binary - def _download_blob_chunk(self, size): + def _download_blob_chunk(self, size: int) -> bytes: if self._size == self._position: # # When reading, we can't seek to the first byte of an empty file. @@ -162,16 +163,16 @@ class Reader(io.BufferedIOBase): """ def __init__( - self, - container, - blob, - client, # type: azure.storage.blob.BlobServiceClient - buffer_size=DEFAULT_BUFFER_SIZE, - line_terminator=smart_open.constants.BINARY_NEWLINE, - ): - self._container_client = client.get_container_client(container) - # type: azure.storage.blob.ContainerClient - + self, + container: str, + blob: str, + client: 'azure.storage.blob.BlobServiceClient', + buffer_size: int = DEFAULT_BUFFER_SIZE, + line_terminator: bytes = smart_open.constants.BINARY_NEWLINE, + ) -> None: + self._container_client: azure.storage.blob.ContainerClient = client.get_container_client(container) + + self.name = blob self._blob = self._container_client.get_blob_client(blob) if self._blob is None: raise azure.core.exceptions.ResourceNotFoundError( @@ -190,7 +191,7 @@ def __init__( # # This member is part of the io.BufferedIOBase interface. # - self.raw = None + self.raw = None # type: ignore # # Override some methods from io.IOBase. @@ -309,7 +310,7 @@ def readline(self, limit=-1): # # Internal methods. 
# - def _read_from_buffer(self, size=-1): + def _read_from_buffer(self, size: int = -1) -> bytes: """Remove at most size bytes from our buffer and return them.""" # logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._current_part)) size = size if size >= 0 else len(self._current_part) @@ -318,13 +319,14 @@ def _read_from_buffer(self, size=-1): # logger.debug('part: %r', part) return part - def _fill_buffer(self, size=-1): + def _fill_buffer(self, size: int = -1) -> bool: size = max(size, self._current_part._chunk_size) while len(self._current_part) < size and not self._position == self._size: bytes_read = self._current_part.fill(self._raw_reader) if bytes_read == 0: logger.debug('reached EOF while filling buffer') return True + return False def __enter__(self): return self @@ -349,28 +351,28 @@ class Writer(io.BufferedIOBase): Implements the io.BufferedIOBase interface of the standard library.""" def __init__( - self, - container, - blob, - client, # type: azure.storage.blob.BlobServiceClient - min_part_size=_DEFAULT_MIN_PART_SIZE, - ): + self, + container: str, + blob: str, + client: 'azure.storage.blob.BlobServiceClient', + min_part_size: int = _DEFAULT_MIN_PART_SIZE, + ) -> None: self._client = client - self._container_client = self._client.get_container_client(container) - # type: azure.storage.blob.ContainerClient - self._blob = self._container_client.get_blob_client(blob) # type: azure.storage.blob.BlobClient + self._container_client: azure.storage.blob.ContainerClient = self._client.get_container_client(container) # noqa + self.name = blob + self._blob: azure.storage.blob.BlobClient = self._container_client.get_blob_client(blob) self._min_part_size = min_part_size self._total_size = 0 self._total_parts = 0 self._bytes_uploaded = 0 self._current_part = io.BytesIO() - self._block_list = [] + self._block_list: List['azure.storage.blob.BlobBlock'] = [] # # This member is part of the io.BufferedIOBase interface. # - self.raw = None + self.raw = None # type: ignore def flush(self): pass @@ -423,7 +425,7 @@ def write(self, b): return len(b) - def _upload_part(self): + def _upload_part(self) -> None: part_num = self._total_parts + 1 content_length = self._current_part.tell() range_stop = self._bytes_uploaded + content_length - 1 diff --git a/smart_open/compression.py b/smart_open/compression.py index aa8b689c..0e925aec 100644 --- a/smart_open/compression.py +++ b/smart_open/compression.py @@ -9,10 +9,17 @@ import logging import os.path +from typing import ( + Callable, + Dict, + IO, +) + logger = logging.getLogger(__name__) -_COMPRESSOR_REGISTRY = {} +Compressor = Callable[[IO, str], IO] +_COMPRESSOR_REGISTRY: Dict[str, Compressor] = {} def get_supported_extensions(): diff --git a/smart_open/doctools.py b/smart_open/doctools.py index daa2bc01..55f75d0e 100644 --- a/smart_open/doctools.py +++ b/smart_open/doctools.py @@ -17,86 +17,80 @@ import os.path import re -from . import compression -from . import transport +from typing import ( + Callable, + List, + Tuple, +) + +from smart_open import ( + compression, + transport, +) PLACEHOLDER = ' smart_open/doctools.py magic goes here' -def extract_kwargs(docstring): +def extract_kwargs(function: Callable) -> List[Tuple[str, str, List[str]]]: """Extract keyword argument documentation from a function's docstring. Parameters ---------- - docstring: str - The docstring to extract keyword arguments from. + :param function: The function to extract keyword arguments from. 
Returns ------- - list of (str, str, list str) - - str - The name of the keyword argument. - str - Its type. - str - Its documentation as a list of lines. - - Notes - ----- - The implementation is rather fragile. It expects the following: - - 1. The parameters are under an underlined Parameters section - 2. Keyword parameters have the literal ", optional" after the type - 3. Names and types are not indented - 4. Descriptions are indented with 4 spaces - 5. The Parameters section ends with an empty line. + A list containing a tuple for each keyword argument: its name, type, + and documentation as a list of lines. Examples -------- - >>> docstring = '''The foo function. - ... Parameters - ... ---------- - ... bar: str, optional - ... This parameter is the bar. - ... baz: int, optional - ... This parameter is the baz. + >>> def fun(bar: str = 'bar', baz: int = 0) -> str: + ... '''The foo function. + ... :param bar: This parameter is the bar. + ... It does stuff. + ... :param baz: This parameter is the baz. + ... ''' ... - ... ''' - >>> kwargs = extract_kwargs(docstring) + >>> kwargs = extract_kwargs(fun) >>> kwargs[0] - ('bar', 'str, optional', ['This parameter is the bar.']) + ('bar', 'str', ['This parameter is the bar.', 'It does stuff.']) """ + docstring = getattr(function, '__doc__') if not docstring: return [] - lines = inspect.cleandoc(docstring).split('\n') - retval = [] - # - # 1. Find the underlined 'Parameters' section - # 2. Once there, continue parsing parameters until we hit an empty line + # NB v.annotation can either be a class or a string. # - while lines and lines[0] != 'Parameters': - lines.pop(0) + signature = inspect.signature(function) + types = { + k: getattr(v.annotation, '__name__', v.annotation) + for (k, v) in signature.parameters.items() + } + lines = inspect.cleandoc(docstring).split('\n') - if not lines: - return [] + def g(): + name = None + description = None + + for line in lines: + if line.startswith(':param '): + if name and description: + yield name, types[name], description - lines.pop(0) - lines.pop(0) + name, tmp_description = line[6:].split(':', 1) + name = name.strip() + description = [tmp_description.strip()] + elif line and line[0].isspace() and description: + description.append(line.strip()) - while lines and lines[0]: - name, type_ = lines.pop(0).split(':', 1) - description = [] - while lines and lines[0].startswith(' '): - description.append(lines.pop(0).strip()) - if 'optional' in type_: - retval.append((name.strip(), type_.strip(), description)) + if name and description: + yield name, types[name], description - return retval + return list(g()) def to_docstring(kwargs, lpad=''): @@ -132,7 +126,7 @@ def to_docstring(kwargs, lpad=''): """ buf = io.StringIO() for name, type_, description in kwargs: - buf.write('%s%s: %s\n' % (lpad, name, type_)) + buf.write('%s:param %s %s:\n' % (lpad, type_, name)) for line in description: buf.write('%s %s\n' % (lpad, line)) return buf.getvalue() @@ -168,7 +162,7 @@ def extract_examples_from_readme_rst(indent=' '): return indent + 'See README.rst' -def tweak_open_docstring(f): +def tweak_open_docstring(f: Callable) -> None: buf = io.StringIO() seen = set() @@ -180,16 +174,23 @@ def tweak_open_docstring(f): for scheme, submodule in sorted(transport._REGISTRY.items()): if scheme == transport.NO_SCHEME or submodule in seen: continue + seen.add(submodule) + if not submodule.__doc__ or not hasattr(submodule, 'open'): + continue + relpath = os.path.relpath(submodule.__file__, start=root_path) heading = '%s (%s)' % 
(scheme, relpath) print(' %s' % heading) print(' %s' % ('~' * len(heading))) + + assert submodule.__doc__ print(' %s' % submodule.__doc__.split('\n')[0]) print() - kwargs = extract_kwargs(submodule.open.__doc__) + assert hasattr(submodule, 'open') + kwargs = extract_kwargs(submodule.open) # type: ignore if kwargs: print(to_docstring(kwargs, lpad=u' ')) diff --git a/smart_open/gcs.py b/smart_open/gcs.py index 8cf2edde..b2072af4 100644 --- a/smart_open/gcs.py +++ b/smart_open/gcs.py @@ -8,6 +8,7 @@ """Implements file-like objects for reading and writing to/from GCS.""" +import collections import io import logging @@ -18,6 +19,14 @@ except ImportError: MISSING_DEPS = True +from typing import ( + Dict, + IO, + Optional, + Tuple, + Union, +) + import smart_open.bytebuffer import smart_open.utils @@ -46,32 +55,33 @@ _UPLOAD_COMPLETE_STATUS_CODES = (200, 201) -def _make_range_string(start, stop=None, end=None): +Uri = collections.namedtuple('Uri', 'scheme bucket_id blob_id') + + +def _make_range_string(start: int, stop: Optional[int] = None, end: Optional[int] = None) -> str: # # GCS seems to violate RFC-2616 (see utils.make_range_string), so we # need a separate implementation. # # https://cloud.google.com/storage/docs/xml-api/resumable-upload#step_3upload_the_file_blocks # + end_str = str(end) if end is None: - end = _UNKNOWN + end_str = _UNKNOWN if stop is None: - return 'bytes %d-/%s' % (start, end) - return 'bytes %d-%d/%s' % (start, stop, end) + return 'bytes %d-/%s' % (start, end_str) + return 'bytes %d-%d/%s' % (start, stop, end_str) class UploadFailedError(Exception): - def __init__(self, message, status_code, text): + def __init__(self, message: str, status_code: int, text: str) -> None: """Raise when a multi-part upload to GCS returns a failed response status code. Parameters ---------- - message: str - The error message to display. - status_code: int - The status code returned from the upload response. - text: str - The text returned from the upload response. + :param message: The error message to display. + :param status_code: The status code returned from the upload response. + :param text: The text returned from the upload response. """ super(UploadFailedError, self).__init__(message) @@ -91,46 +101,42 @@ def _fail(response, part_num, content_length, total_size, headers): raise UploadFailedError(msg, response.status_code, response.text) -def parse_uri(uri_as_string): +def parse_uri(uri_as_string: str) -> Uri: sr = smart_open.utils.safe_urlsplit(uri_as_string) assert sr.scheme == SCHEME bucket_id = sr.netloc blob_id = sr.path.lstrip('/') - return dict(scheme=SCHEME, bucket_id=bucket_id, blob_id=blob_id) + return Uri(scheme=SCHEME, bucket_id=bucket_id, blob_id=blob_id) -def open_uri(uri, mode, transport_params): +def open_uri(uri: str, mode: str, transport_params: Dict) -> IO[bytes]: parsed_uri = parse_uri(uri) kwargs = smart_open.utils.check_kwargs(open, transport_params) - return open(parsed_uri['bucket_id'], parsed_uri['blob_id'], mode, **kwargs) + return open(parsed_uri.bucket_id, parsed_uri.blob_id, mode, **kwargs) def open( - bucket_id, - blob_id, - mode, - buffer_size=DEFAULT_BUFFER_SIZE, - min_part_size=_MIN_MIN_PART_SIZE, - client=None, # type: google.cloud.storage.Client - ): + bucket_id: str, + blob_id: str, + mode: str, + buffer_size: int = DEFAULT_BUFFER_SIZE, + min_part_size: int = _MIN_MIN_PART_SIZE, + client: Optional['google.cloud.storage.Client'] = None, +) -> IO[bytes]: """Open an GCS blob for reading or writing. 
Parameters ---------- - bucket_id: str - The name of the bucket this object resides in. - blob_id: str - The name of the blob within the bucket. - mode: str - The mode for opening the object. Must be either "rb" or "wb". - buffer_size: int, optional - The buffer size to use when performing I/O. For reading only. - min_part_size: int, optional - The minimum part size for multipart uploads. For writing only. - client: google.cloud.storage.Client, optional - The GCS client to use when working with google-cloud-storage. + :param bucket_id: The name of the bucket this object resides in. + :param blob_id: The name of the blob within the bucket. + :param mode: The mode for opening the object. Must be either "rb" or "wb". + :param buffer_size: The buffer size to use when performing I/O. For reading only. + :param min_part_size: The minimum part size for multipart uploads. For writing only. + :param client: The GCS client to use when working with google-cloud-storage. """ + fileobj: Union[Reader, Writer, None] = None + if mode == constants.READ_BINARY: fileobj = Reader( bucket_id, @@ -149,15 +155,20 @@ def open( else: raise NotImplementedError('GCS support for mode %r not implemented' % mode) - fileobj.name = blob_id - return fileobj + assert hasattr(fileobj, 'name') + + # + # FIXME: not sure why mypy is unhappy about the line below. + # Both Writer and Reader inherit from io.BufferedIOBase, so they should + # behave like IO objects as far as typing is concerned. + # + return fileobj # type: ignore class _RawReader(object): """Read an GCS object.""" - def __init__(self, gcs_blob, size): - # type: (google.cloud.storage.Blob, int) -> None + def __init__(self, gcs_blob: 'google.cloud.storage.Blob', size: int) -> None: self._blob = gcs_blob self._size = size self._position = 0 @@ -179,7 +190,7 @@ def read(self, size=-1): self._position += len(binary) return binary - def _download_blob_chunk(self, size): + def _download_blob_chunk(self, size: int) -> bytes: start = position = self._position if position == self._size: # @@ -204,17 +215,18 @@ class Reader(io.BufferedIOBase): """ def __init__( - self, - bucket, - key, - buffer_size=DEFAULT_BUFFER_SIZE, - line_terminator=constants.BINARY_NEWLINE, - client=None, # type: google.cloud.storage.Client + self, + bucket: str, + key: str, + buffer_size: int = DEFAULT_BUFFER_SIZE, + line_terminator: bytes = constants.BINARY_NEWLINE, + client: Optional['google.cloud.storage.Client'] = None, ): if client is None: client = google.cloud.storage.Client() - self._blob = client.bucket(bucket).get_blob(key) # type: google.cloud.storage.Blob + self.name = key + self._blob: google.cloud.storage.Blob = client.bucket(bucket).get_blob(key) if self._blob is None: raise google.cloud.exceptions.NotFound('blob %s not found in %s' % (key, bucket)) @@ -231,7 +243,7 @@ def __init__( # # This member is part of the io.BufferedIOBase interface. # - self.raw = None + self.raw = None # type: ignore # # Override some methods from io.IOBase. @@ -359,7 +371,7 @@ def readline(self, limit=-1): # # Internal methods. 
# - def _read_from_buffer(self, size=-1): + def _read_from_buffer(self, size: int = -1) -> bytes: """Remove at most size bytes from our buffer and return them.""" # logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._current_part)) size = size if size >= 0 else len(self._current_part) @@ -368,7 +380,7 @@ def _read_from_buffer(self, size=-1): # logger.debug('part: %r', part) return part - def _fill_buffer(self, size=-1): + def _fill_buffer(self, size: int = -1) -> None: size = size if size >= 0 else self._current_part._chunk_size while len(self._current_part) < size and not self._eof: bytes_read = self._current_part.fill(self._raw_reader) @@ -391,16 +403,17 @@ class Writer(io.BufferedIOBase): Implements the io.BufferedIOBase interface of the standard library.""" def __init__( - self, - bucket, - blob, - min_part_size=_DEFAULT_MIN_PART_SIZE, - client=None, # type: google.cloud.storage.Client + self, + bucket: str, + blob: str, + min_part_size: int = _DEFAULT_MIN_PART_SIZE, + client: Optional['google.cloud.storage.Client'] = None, ): + self.name = blob if client is None: client = google.cloud.storage.Client() self._client = client - self._blob = self._client.bucket(bucket).blob(blob) # type: google.cloud.storage.Blob + self._blob: google.cloud.storage.Blob = self._client.bucket(bucket).blob(blob) assert min_part_size % _REQUIRED_CHUNK_MULTIPLE == 0, 'min part size must be a multiple of 256KB' assert min_part_size >= _MIN_MIN_PART_SIZE, 'min part size must be greater than 256KB' self._min_part_size = min_part_size @@ -420,7 +433,7 @@ def __init__( # # This member is part of the io.BufferedIOBase interface. # - self.raw = None + self.raw = None # type: ignore def flush(self): pass @@ -489,7 +502,7 @@ def terminate(self): # # Internal methods. 
# - def _upload_part(self, is_last=False): + def _upload_part(self, is_last: bool = False) -> None: part_num = self._total_parts + 1 # @@ -505,6 +518,8 @@ def _upload_part(self, is_last=False): # content_length = self._current_part.tell() remainder = content_length % self._min_part_size + + end: Optional[int] = None if is_last: end = self._bytes_uploaded + content_length elif remainder == 0: @@ -531,10 +546,10 @@ def _upload_part(self, is_last=False): headers=headers, ) + expected: Tuple = _UPLOAD_INCOMPLETE_STATUS_CODES if is_last: expected = _UPLOAD_COMPLETE_STATUS_CODES - else: - expected = _UPLOAD_INCOMPLETE_STATUS_CODES + if response.status_code not in expected: _fail(response, part_num, content_length, self._total_size, headers) logger.debug("upload of part #%i finished" % part_num) @@ -548,7 +563,7 @@ def _upload_part(self, is_last=False): self._current_part = io.BytesIO(self._current_part.read()) self._current_part.seek(0, io.SEEK_END) - def _upload_empty_part(self): + def _upload_empty_part(self) -> None: logger.debug("creating empty file") headers = {'Content-Length': '0'} response = self._session.put(self._resumable_upload_url, headers=headers) diff --git a/smart_open/hdfs.py b/smart_open/hdfs.py index a4d892cd..a61c8de4 100644 --- a/smart_open/hdfs.py +++ b/smart_open/hdfs.py @@ -14,6 +14,7 @@ """ +import collections import io import logging import subprocess @@ -30,8 +31,10 @@ 'hdfs://path/file', ) +Uri = collections.namedtuple('Uri', 'scheme uri_path') -def parse_uri(uri_as_string): + +def parse_uri(uri_as_string: str) -> Uri: split_uri = urllib.parse.urlsplit(uri_as_string) assert split_uri.scheme == SCHEME @@ -40,15 +43,15 @@ def parse_uri(uri_as_string): if not uri_path: raise RuntimeError("invalid HDFS URI: %r" % uri_as_string) - return dict(scheme=SCHEME, uri_path=uri_path) + return Uri(scheme=SCHEME, uri_path=uri_path) def open_uri(uri, mode, transport_params): utils.check_kwargs(open, transport_params) parsed_uri = parse_uri(uri) - fobj = open(parsed_uri['uri_path'], mode) - fobj.name = parsed_uri['uri_path'].split('/')[-1] + fobj = open(parsed_uri.uri_path, mode) + fobj.name = parsed_uri.uri_path.split('/')[-1] return fobj diff --git a/smart_open/http.py b/smart_open/http.py index f68ca93c..5ea7dbfe 100644 --- a/smart_open/http.py +++ b/smart_open/http.py @@ -7,6 +7,7 @@ # """Implements file-like objects for reading from http.""" +import collections import io import logging import os.path @@ -25,6 +26,7 @@ logger = logging.getLogger(__name__) +Uri = collections.namedtuple('Uri', 'scheme uri_path') _HEADERS = {'Accept-Encoding': 'identity'} """The headers we send to the server with every HTTP request. @@ -35,13 +37,13 @@ """ -def parse_uri(uri_as_string): +def parse_uri(uri_as_string: str) -> Uri: split_uri = urllib.parse.urlsplit(uri_as_string) assert split_uri.scheme in SCHEMES uri_path = split_uri.netloc + split_uri.path uri_path = "/" + uri_path.lstrip("/") - return dict(scheme=split_uri.scheme, uri_path=uri_path) + return Uri(scheme=split_uri.scheme, uri_path=uri_path) def open_uri(uri, mode, transport_params): diff --git a/smart_open/local_file.py b/smart_open/local_file.py index e5f5c5aa..9cdcc93b 100644 --- a/smart_open/local_file.py +++ b/smart_open/local_file.py @@ -6,6 +6,7 @@ # from the MIT License (MIT). 
# """Implements the transport for the file:// schema.""" +import collections import io import os.path @@ -20,18 +21,19 @@ 'file:///home/user/file.bz2', ) +Uri = collections.namedtuple('Uri', 'scheme uri_path') open = io.open -def parse_uri(uri_as_string): +def parse_uri(uri_as_string: str) -> Uri: local_path = extract_local_path(uri_as_string) - return dict(scheme=SCHEME, uri_path=local_path) + return Uri(scheme=SCHEME, uri_path=local_path) def open_uri(uri_as_string, mode, transport_params): parsed_uri = parse_uri(uri_as_string) - fobj = io.open(parsed_uri['uri_path'], mode) + fobj = io.open(parsed_uri.uri_path, mode) return fobj diff --git a/smart_open/s3.py b/smart_open/s3.py index 279803af..60bfbe78 100644 --- a/smart_open/s3.py +++ b/smart_open/s3.py @@ -7,11 +7,24 @@ # """Implements file-like objects for reading and writing from/to AWS S3.""" +import collections import io import functools import logging import time +from typing import ( + Any, + Callable, + Dict, + IO, + Iterator, + List, + Optional, + Tuple, + Union, +) + try: import boto3 import botocore.client @@ -26,6 +39,8 @@ from smart_open import constants +Kwargs = Dict[str, Any] + logger = logging.getLogger(__name__) DEFAULT_MIN_PART_SIZE = 50 * 1024**2 @@ -51,8 +66,22 @@ # Returned by AWS when we try to seek beyond EOF. _OUT_OF_RANGE = 'InvalidRange' +Uri = collections.namedtuple( + 'Uri', + [ + 'scheme', + 'bucket_id', + 'key_id', + 'port', + 'host', + 'ordinary_calling_format', + 'access_id', + 'access_secret', + ] +) + -def parse_uri(uri_as_string): +def parse_uri(uri_as_string: str) -> Uri: # # Restrictions on bucket names and labels: # @@ -92,15 +121,15 @@ def parse_uri(uri_as_string): if '@' in head and ':' in head: ordinary_calling_format = True host_port, bucket_id = head.split('@') - host, port = host_port.split(':', 1) - port = int(port) + host, port_str = host_port.split(':', 1) + port = int(port_str) elif '@' in head: ordinary_calling_format = True host, bucket_id = head.split('@') else: bucket_id = head - return dict( + return Uri( scheme=split_uri.scheme, bucket_id=bucket_id, key_id=key_id, @@ -112,7 +141,7 @@ def parse_uri(uri_as_string): ) -def _consolidate_params(uri, transport_params): +def _consolidate_params(uri: Uri, transport_params: Kwargs) -> Tuple[Uri, Kwargs]: """Consolidates the parsed Uri with the additional parameters. This is necessary because the user can pass some of the parameters can in @@ -128,28 +157,28 @@ def _consolidate_params(uri, transport_params): transport_params = dict(transport_params) session = transport_params.get('session') - if session is not None and (uri['access_id'] or uri['access_secret']): + if session is not None and (uri.access_id or uri.access_secret): logger.warning( 'ignoring credentials parsed from URL because they conflict with ' 'transport_params["session"]. Set transport_params["session"] to None ' 'to suppress this warning.' 
) - uri.update(access_id=None, access_secret=None) - elif (uri['access_id'] and uri['access_secret']): + uri = uri._replace(access_id=None, access_secret=None) + elif (uri.access_id and uri.access_secret): transport_params['session'] = boto3.Session( - aws_access_key_id=uri['access_id'], - aws_secret_access_key=uri['access_secret'], + aws_access_key_id=uri.access_id, + aws_secret_access_key=uri.access_secret, ) - uri.update(access_id=None, access_secret=None) + uri = uri._replace(access_id=None, access_secret=None) - if uri['host'] != DEFAULT_HOST: - endpoint_url = 'https://%(host)s:%(port)d' % uri + if uri.host != DEFAULT_HOST: + endpoint_url = 'https://%s:%d' % (uri.host, uri.port) _override_endpoint_url(transport_params, endpoint_url) return uri, transport_params -def _override_endpoint_url(transport_params, url): +def _override_endpoint_url(transport_params: Kwargs, url: str) -> None: try: resource_kwargs = transport_params['resource_kwargs'] except KeyError: @@ -164,73 +193,70 @@ def _override_endpoint_url(transport_params, url): resource_kwargs.update(endpoint_url=url) -def open_uri(uri, mode, transport_params): +def open_uri(uri: str, mode: str, transport_params: Kwargs) -> IO[bytes]: parsed_uri = parse_uri(uri) parsed_uri, transport_params = _consolidate_params(parsed_uri, transport_params) kwargs = smart_open.utils.check_kwargs(open, transport_params) - return open(parsed_uri['bucket_id'], parsed_uri['key_id'], mode, **kwargs) + return open(parsed_uri.bucket_id, parsed_uri.key_id, mode, **kwargs) def open( - bucket_id, - key_id, - mode, - version_id=None, - buffer_size=DEFAULT_BUFFER_SIZE, - min_part_size=DEFAULT_MIN_PART_SIZE, - session=None, - resource=None, - resource_kwargs=None, - multipart_upload_kwargs=None, - multipart_upload=True, - singlepart_upload_kwargs=None, - object_kwargs=None, - defer_seek=False, -): + bucket_id: str, + key_id: str, + mode: str, + version_id: Optional[str] = None, + buffer_size: int = DEFAULT_BUFFER_SIZE, + min_part_size: int = DEFAULT_MIN_PART_SIZE, + session: Optional['boto3.Session'] = None, + resource: Optional['boto3.resource'] = None, + resource_kwargs: dict = None, + multipart_upload_kwargs: Optional[dict] = None, + multipart_upload: bool = True, + singlepart_upload_kwargs: Optional[dict] = None, + object_kwargs: Optional[dict] = None, + defer_seek: bool = False, +) -> IO[bytes]: """Open an S3 object for reading or writing. Parameters ---------- - bucket_id: str - The name of the bucket this object resides in. - key_id: str + :param bucket_id: The name of the bucket this object resides in. + :param key_id: The name of the key within the bucket. - mode: str + :param mode: The mode for opening the object. Must be either "rb" or "wb". - buffer_size: int, optional + :param buffer_size: The buffer size to use when performing I/O. - min_part_size: int, optional + :param min_part_size: The minimum part size for multipart uploads. For writing only. - session: object, optional + :param session: The S3 session to use when working with boto3. If you don't specify this, then smart_open will create a new session for you. - resource: object, optional + :param resource: The S3 resource to use when working with boto3. If you don't specify this, then smart_open will create a new resource for you. - resource_kwargs: dict, optional + :param resource_kwargs: Keyword arguments to use when creating the S3 resource for reading or writing. Will be ignored if you specify the resource object explicitly. 
- multipart_upload_kwargs: dict, optional + :param multipart_upload_kwargs: Additional parameters to pass to boto3's initiate_multipart_upload function. For writing only. - singlepart_upload_kwargs: dict, optional + :param singlepart_upload_kwargs: Additional parameters to pass to boto3's S3.Object.put function when using single part upload. For writing only. - multipart_upload: bool, optional - Default: `True` + :param multipart_upload: If set to `True`, will use multipart upload for writing to S3. If set to `False`, S3 upload will use the S3 Single-Part Upload API, which is more ideal for small file sizes. For writing only. - version_id: str, optional + :param version_id: Version of the object, used when reading object. If None, will fetch the most recent version. - object_kwargs: dict, optional + :param object_kwargs: Additional parameters to pass to boto3's object.get function. Used during reading only. - defer_seek: boolean, optional - Default: `False` + :param defer_seek: If set to `True` on a file opened for reading, GetObject will not be called until the first seek() or read(). Avoids redundant API queries when seeking before reading. @@ -242,6 +268,8 @@ def open( if (mode == constants.WRITE_BINARY) and (version_id is not None): raise ValueError("version_id must be None when writing") + fileobj: Union[Reader, SinglepartWriter, MultipartWriter, None] = None + if mode == constants.READ_BINARY: fileobj = Reader( bucket_id, @@ -254,34 +282,33 @@ def open( object_kwargs=object_kwargs, defer_seek=defer_seek, ) + elif mode == constants.WRITE_BINARY and multipart_upload: + fileobj = MultipartWriter( + bucket_id, + key_id, + min_part_size=min_part_size, + session=session, + upload_kwargs=multipart_upload_kwargs, + resource=resource, + resource_kwargs=resource_kwargs, + ) elif mode == constants.WRITE_BINARY: - if multipart_upload: - fileobj = MultipartWriter( - bucket_id, - key_id, - min_part_size=min_part_size, - session=session, - resource=resource, - upload_kwargs=multipart_upload_kwargs, - resource_kwargs=resource_kwargs, - ) - else: - fileobj = SinglepartWriter( - bucket_id, - key_id, - session=session, - resource=resource, - upload_kwargs=singlepart_upload_kwargs, - resource_kwargs=resource_kwargs, - ) + fileobj = SinglepartWriter( + bucket_id, + key_id, + session=session, + upload_kwargs=singlepart_upload_kwargs, + resource=resource, + resource_kwargs=resource_kwargs, + ) else: assert False, 'unexpected mode: %r' % mode - fileobj.name = key_id - return fileobj + assert fileobj + return fileobj # type: ignore -def _get(s3_object, version=None, **kwargs): +def _get(s3_object: 'boto3.s3.Object', version: Optional[str] = None, **kwargs) -> Any: if version is not None: kwargs['VersionId'] = version try: @@ -292,14 +319,14 @@ def _get(s3_object, version=None, **kwargs): s3_object.bucket_name, s3_object.key, version, error ) ) - wrapped_error.backend_error = error + wrapped_error.backend_error = error # type: ignore raise wrapped_error from error -def _unwrap_ioerror(ioe): +def _unwrap_ioerror(ioe: IOError) -> Optional[Dict]: """Given an IOError from _get, return the 'Error' dictionary from boto.""" try: - return ioe.backend_error.response['Error'] + return ioe.backend_error.response['Error'] # type: ignore except (AttributeError, KeyError): return None @@ -312,28 +339,20 @@ class _SeekableRawReader(object): def __init__( self, - s3_object, - version_id=None, - object_kwargs=None, - ): + s3_object: 'boto3.s3.Object', + version_id: Optional[str] = None, + object_kwargs: Optional[Kwargs] = 
None, + ) -> None: self._object = s3_object - self._content_length = None + self._content_length: Optional[int] = None self._version_id = version_id self._position = 0 - self._body = None + self._body: Optional[io.BytesIO] = None self._object_kwargs = object_kwargs if object_kwargs else {} - def seek(self, offset, whence=constants.WHENCE_START): - """Seek to the specified position. - - :param int offset: The offset in bytes. - :param int whence: Where the offset is from. - - :returns: the position after seeking. - :rtype: int - """ + def seek(self, offset: int, whence: int = constants.WHENCE_START) -> int: if whence not in constants.WHENCE_CHOICES: - raise ValueError('invalid whence, expected one of %r' % constants.WHENCE_CHOICES) + raise ValueError('invalid whence, expected one of %r' % list(constants.WHENCE_CHOICES)) # # Close old body explicitly. @@ -367,13 +386,15 @@ def seek(self, offset, whence=constants.WHENCE_START): if reached_eof: self._body = io.BytesIO() + + assert self._content_length self._position = self._content_length else: self._open_body(start, stop) return self._position - def _open_body(self, start=None, stop=None): + def _open_body(self, start: Optional[int] = None, stop: Optional[int] = None) -> None: """Open a connection to download the specified range of bytes. Store the open file handle in self._body. @@ -420,11 +441,13 @@ def _open_body(self, start=None, stop=None): self._position = start self._body = response['Body'] - def read(self, size=-1): + def read(self, size: int = -1) -> bytes: """Read from the continuous connection with the remote peer.""" if self._body is None: # This is necessary for the very first read() after __init__(). self._open_body() + + assert self._content_length if self._position >= self._content_length: return b'' @@ -514,17 +537,18 @@ class Reader(io.BufferedIOBase): def __init__( self, - bucket, - key, - version_id=None, - buffer_size=DEFAULT_BUFFER_SIZE, - line_terminator=constants.BINARY_NEWLINE, - session=None, - resource=None, - resource_kwargs=None, - object_kwargs=None, - defer_seek=False, - ): + bucket: str, + key: str, + version_id: Optional[str] = None, + buffer_size: int = DEFAULT_BUFFER_SIZE, + line_terminator: bytes = constants.BINARY_NEWLINE, + session: Optional['boto3.Session'] = None, + resource: Optional['boto3.resource'] = None, + resource_kwargs: Optional[Kwargs] = None, + object_kwargs: Optional[Kwargs] = None, + defer_seek: bool = False, + ) -> None: + self.name = key self._buffer_size = buffer_size if resource_kwargs is None: @@ -551,7 +575,7 @@ def __init__( # # This member is part of the io.BufferedIOBase interface. # - self.raw = None + self.raw = None # type: ignore if not defer_seek: self.seek(0) @@ -636,8 +660,8 @@ def seekable(self): def seek(self, offset, whence=constants.WHENCE_START): """Seek to the specified position. - :param int offset: The offset in bytes. - :param int whence: Where the offset is from. + :param offset: The offset in bytes. + :param whence: Where the offset is from. Returns the position after seeking.""" # Convert relative offset to absolute, since self._raw_reader @@ -668,7 +692,7 @@ def terminate(self): """Do nothing.""" pass - def to_boto3(self): + def to_boto3(self) -> 'boto3.s3.Object': """Create an **independent** `boto3.s3.Object` instance that points to the same resource as this instance. @@ -688,14 +712,14 @@ def to_boto3(self): # # Internal methods. 
# - def _read_from_buffer(self, size=-1): + def _read_from_buffer(self, size: int = -1) -> bytes: """Remove at most size bytes from our buffer and return them.""" size = size if size >= 0 else len(self._buffer) part = self._buffer.read(size) self._current_pos += len(part) return part - def _fill_buffer(self, size=-1): + def _fill_buffer(self, size: int = -1) -> None: size = max(size, self._buffer._chunk_size) while len(self._buffer) < size and not self._eof: bytes_read = self._buffer.fill(self._raw_reader) @@ -736,14 +760,14 @@ class MultipartWriter(io.BufferedIOBase): def __init__( self, - bucket, - key, - min_part_size=DEFAULT_MIN_PART_SIZE, - session=None, - resource=None, - resource_kwargs=None, - upload_kwargs=None, - ): + bucket: str, + key: str, + min_part_size: int = DEFAULT_MIN_PART_SIZE, + session: Optional['boto3.Session'] = None, + resource: Optional['boto3.resource'] = None, + resource_kwargs: Optional[Kwargs] = None, + upload_kwargs: Optional[Kwargs] = None, + ) -> None: if min_part_size < MIN_MIN_PART_SIZE: logger.warning("S3 requires minimum part size >= 5MB; \ multipart upload may fail") @@ -770,12 +794,12 @@ def __init__( self._buf = io.BytesIO() self._total_bytes = 0 self._total_parts = 0 - self._parts = [] + self._parts: List[Dict] = [] # # This member is part of the io.BufferedIOBase interface. # - self.raw = None + self.raw = None # type: ignore def flush(self): pass @@ -846,7 +870,7 @@ def terminate(self): self._mp.abort() self._mp = None - def to_boto3(self): + def to_boto3(self) -> 'boto3.s3.Object': """Create an **independent** `boto3.s3.Object` instance that points to the same resource as this instance. @@ -860,7 +884,7 @@ def to_boto3(self): # # Internal methods. # - def _upload_next_part(self): + def _upload_next_part(self) -> None: part_num = self._total_parts + 1 logger.info( "%s: uploading part_num: %i, %i bytes (total %.3fGB)", @@ -923,14 +947,15 @@ class SinglepartWriter(io.BufferedIOBase): the data be written to S3 and the buffer is released.""" def __init__( - self, - bucket, - key, - session=None, - resource=None, - resource_kwargs=None, - upload_kwargs=None, - ): + self, + bucket: str, + key: str, + session: Optional['boto3.Session'] = None, + resource: Optional['boto3.resource'] = None, + resource_kwargs: Optional[Kwargs] = None, + upload_kwargs: Optional[Kwargs] = None, + ) -> None: + self.name = key _initialize_boto3(self, session, resource, resource_kwargs) @@ -951,7 +976,7 @@ def __init__( # # This member is part of the io.BufferedIOBase interface. # - self.raw = None + self.raw = None # type: ignore def flush(self): pass @@ -992,7 +1017,7 @@ def tell(self): def detach(self): raise io.UnsupportedOperation("detach() not supported") - def write(self, b): + def write(self, b: bytes) -> int: """Write the given buffer (bytes, bytearray, memoryview or any buffer interface implementation) into the buffer. Content of the buffer will be written to S3 on close as a single-part upload. 
@@ -1003,7 +1028,7 @@ def write(self, b): return len(b) - def terminate(self): + def terminate(self) -> None: """Nothing to cancel in single-part uploads.""" return @@ -1036,16 +1061,17 @@ def __repr__(self): def _retry_if_failed( - partial, - attempts=_UPLOAD_ATTEMPTS, - sleep_seconds=_SLEEP_SECONDS, - exceptions=None): + partial: Callable, + attempts: int = _UPLOAD_ATTEMPTS, + sleep_seconds: int = _SLEEP_SECONDS, + exceptions: Optional[List[Exception]] = None, +) -> Any: if exceptions is None: - exceptions = (botocore.exceptions.EndpointConnectionError, ) + exceptions = [botocore.exceptions.EndpointConnectionError] for attempt in range(attempts): try: return partial() - except exceptions: + except tuple(exceptions): # type: ignore logger.critical( 'Unable to connect to the endpoint. Check your network connection. ' 'Sleeping and retrying %d more times ' @@ -1069,38 +1095,38 @@ def _accept_all(key): def iter_bucket( - bucket_name, - prefix='', - accept_key=None, - key_limit=None, - workers=16, - retries=3, - **session_kwargs): + bucket_name: str, + prefix: str = '', + accept_key: Callable = None, + key_limit: Optional[int] = None, + workers: int = 16, + retries: int = 3, + **session_kwargs: Any, +) -> Iterator[Tuple[str, bytes]]: """ Iterate and download all S3 objects under `s3://bucket_name/prefix`. Parameters ---------- - bucket_name: str + :param bucket_name: The name of the bucket. - prefix: str, optional - Limits the iteration to keys starting with the prefix. - accept_key: callable, optional + :param prefix: + Limits the iteration to keys starting with the prefix. + :param accept_key: This is a function that accepts a key name (unicode string) and returns True/False, signalling whether the given key should be downloaded. The default behavior is to accept all keys. - key_limit: int, optional + :param key_limit: If specified, the iterator will stop after yielding this many results. - workers: int, optional + :param workers: The number of subprocesses to use. - retries: int, optional + :param retries: The number of time to retry a failed download. - session_kwargs: dict, optional + :param session_kwargs: Keyword arguments to pass when creating a new session. For a list of available names and values, see: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session - Yields ------ str @@ -1135,7 +1161,7 @@ def iter_bucket( # before moving on. Works for boto3 as well as boto. 
# try: - bucket_name = bucket_name.name + bucket_name = bucket_name.name # type: ignore except AttributeError: pass @@ -1169,10 +1195,10 @@ def iter_bucket( def _list_bucket( - bucket_name, - prefix='', - accept_key=lambda k: True, - **session_kwargs): + bucket_name: str, + prefix: str = '', + accept_key=lambda k: True, +**session_kwargs) -> Iterator[str]: session = boto3.session.Session(**session_kwargs) client = session.client('s3') ctoken = None @@ -1199,7 +1225,12 @@ def _list_bucket( break -def _download_key(key_name, bucket_name=None, retries=3, **session_kwargs): +def _download_key( + key_name: str, + bucket_name: Optional[str] = None, + retries: int = 3, + **session_kwargs, +) -> Optional[Tuple[str, bytes]]: if bucket_name is None: raise ValueError('bucket_name may not be None') @@ -1225,8 +1256,10 @@ def _download_key(key_name, bucket_name=None, retries=3, **session_kwargs): else: return key_name, content_bytes + return None + -def _download_fileobj(bucket, key_name): +def _download_fileobj(bucket: 'boto3.s3.Bucket', key_name: str) -> bytes: # # This is a separate function only because it makes it easier to inject # exceptions during tests. diff --git a/smart_open/smart_open_lib.py b/smart_open/smart_open_lib.py index bf25f5cc..76bb4948 100644 --- a/smart_open/smart_open_lib.py +++ b/smart_open/smart_open_lib.py @@ -31,9 +31,12 @@ # import smart_open.local_file as so_file -from smart_open import compression -from smart_open import doctools -from smart_open import transport +from smart_open import ( + compression, + doctools, + transport, + utils, +) # # For backwards compatibility and keeping old unit tests happy. @@ -42,12 +45,22 @@ from smart_open.utils import check_kwargs as _check_kwargs # noqa: F401 from smart_open.utils import inspect_kwargs as _inspect_kwargs # noqa: F401 +from typing import ( + Any, + Callable, + Dict, + IO, + Optional, + Tuple, + Union, +) + logger = logging.getLogger(__name__) DEFAULT_ENCODING = locale.getpreferredencoding(do_setlocale=False) -def _sniff_scheme(uri_as_string): +def _sniff_scheme(uri_as_string: str) -> str: """Returns the scheme of the URL only, as a string.""" # # urlsplit doesn't work on Windows -- it parses the drive as the scheme... @@ -59,34 +72,17 @@ def _sniff_scheme(uri_as_string): return urllib.parse.urlsplit(uri_as_string).scheme -def parse_uri(uri_as_string): +def parse_uri(uri_as_string: str) -> Tuple: """ Parse the given URI from a string. - Parameters - ---------- - uri_as_string: str - The URI to parse. - - Returns - ------- - collections.namedtuple - The parsed URI. - Notes ----- smart_open/doctools.py magic goes here """ scheme = _sniff_scheme(uri_as_string) submodule = transport.get_transport(scheme) - as_dict = submodule.parse_uri(uri_as_string) - - # - # The conversion to a namedtuple is just to keep the old tests happy while - # I'm still refactoring. 
- # - Uri = collections.namedtuple('Uri', sorted(as_dict.keys())) - return Uri(**as_dict) + return submodule.parse_uri(uri_as_string) # @@ -98,17 +94,17 @@ def parse_uri(uri_as_string): def open( - uri, - mode='r', - buffering=-1, - encoding=None, - errors=None, - newline=None, - closefd=True, - opener=None, - ignore_ext=False, - transport_params=None, - ): + uri: Union[str, IO], + mode: str = 'r', + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, + closefd: bool = True, + opener: Optional[Any] = None, + ignore_ext: bool = False, + transport_params: Dict[str, Any] = None, +) -> IO: r"""Open the URI object, returning a file-like object. The URI is usually a string in a variety of formats. @@ -293,14 +289,14 @@ def transfer(char): def _shortcut_open( - uri, - mode, - ignore_ext=False, - buffering=-1, - encoding=None, - errors=None, - newline=None, - ): + uri: Union[str, IO], + mode: str, + ignore_ext: bool = False, + buffering: int = -1, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, +) -> Optional[IO]: """Try to open the URI using the standard library io.open function. This can be much faster than the alternative of opening in binary mode and @@ -313,10 +309,9 @@ def _shortcut_open( If it is not possible to use the built-in open for the specified URI, returns None. - :param str uri: A string indicating what to open. - :param str mode: The mode to pass to the open function. + :param uri: A string indicating what to open. + :param mode: The mode to pass to the open function. :returns: The opened file - :rtype: file """ if not isinstance(uri, str): return None @@ -343,19 +338,18 @@ def _shortcut_open( if errors and 'b' not in mode: open_kwargs['errors'] = errors - return _builtin_open(local_path, mode, buffering=buffering, **open_kwargs) + return _builtin_open(local_path, mode, buffering=buffering, **open_kwargs) # type: ignore -def _open_binary_stream(uri, mode, transport_params): +def _open_binary_stream(uri: Union[str, IO], mode: str, transport_params: Dict[str, Any]) -> IO: """Open an arbitrary URI in the specified binary mode. Not all modes are supported for all protocols. :arg uri: The URI to open. May be a string, or something else. - :arg str mode: The mode to open with. Must be rb, wb or ab. + :arg mode: The mode to open with. Must be rb, wb or ab. :arg transport_params: Keyword argumens for the transport layer. :returns: A named file object - :rtype: file-like object with a .name attribute """ if mode not in ('rb', 'rb+', 'wb', 'wb+', 'ab', 'ab+'): # @@ -372,8 +366,8 @@ def _open_binary_stream(uri, mode, transport_params): # if there is no such an attribute, we return "unknown" - this # effectively disables any compression if not hasattr(uri, 'name'): - uri.name = getattr(uri, 'name', 'unknown') - return uri + uri.name = getattr(uri, 'name', 'unknown') # type: ignore + return uri # type: ignore if not isinstance(uri, str): raise TypeError("don't know how to handle uri %s" % repr(uri)) @@ -387,17 +381,22 @@ def _open_binary_stream(uri, mode, transport_params): return fobj -def _encoding_wrapper(fileobj, mode, encoding=None, errors=None, newline=None): +def _encoding_wrapper( + fileobj: IO[bytes], + mode: str, + encoding: Optional[str] = None, + errors: Optional[str] = None, + newline: Optional[str] = None, +) -> IO[Union[bytes, str]]: """Decode bytes into text, if necessary. If mode specifies binary access, does nothing, unless the encoding is specified. 
A non-null encoding implies text mode. :arg fileobj: must quack like a filehandle object. - :arg str mode: is the mode which was originally requested by the user. - :arg str encoding: The text encoding to use. If mode is binary, overrides mode. - :arg str errors: The method to use when handling encoding/decoding errors. - :returns: a file object + :arg mode: is the mode which was originally requested by the user. + :arg encoding: The text encoding to use. If mode is binary, overrides mode. + :arg errors: The method to use when handling encoding/decoding errors. """ logger.debug('encoding_wrapper: %r', locals()) @@ -439,10 +438,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): _patch_pathlib(self.old_impl) -def _patch_pathlib(func): +def _patch_pathlib(func: Callable) -> Callable: """Replace `Path.open` with `func`""" old_impl = pathlib.Path.open - pathlib.Path.open = func + pathlib.Path.open = func # type: ignore return old_impl diff --git a/smart_open/ssh.py b/smart_open/ssh.py index fa762eb6..c4e1c336 100644 --- a/smart_open/ssh.py +++ b/smart_open/ssh.py @@ -22,11 +22,18 @@ """ +import collections import getpass import logging import urllib.parse import warnings +from typing import ( + Any, + Dict, + Tuple, +) + import smart_open.utils logger = logging.getLogger(__name__) @@ -34,7 +41,7 @@ # # Global storage for SSH connections. # -_SSH = {} +_SSH: Dict[Tuple[str, str], Any] = {} SCHEMES = ("ssh", "scp", "sftp") """Supported URL schemes.""" @@ -48,15 +55,17 @@ 'sftp://username@host/path/file', ) +Uri = collections.namedtuple('Uri', 'scheme uri_path user host port password') + def _unquote(text): return text and urllib.parse.unquote(text) -def parse_uri(uri_as_string): +def parse_uri(uri_as_string: str) -> Uri: split_uri = urllib.parse.urlsplit(uri_as_string) assert split_uri.scheme in SCHEMES - return dict( + return Uri( scheme=split_uri.scheme, uri_path=_unquote(split_uri.path), user=_unquote(split_uri.username), @@ -68,7 +77,7 @@ def parse_uri(uri_as_string): def open_uri(uri, mode, transport_params): smart_open.utils.check_kwargs(open, transport_params) - parsed_uri = parse_uri(uri) + parsed_uri = parse_uri(uri)._asdict() uri_path = parsed_uri.pop('uri_path') parsed_uri.pop('scheme') return open(uri_path, mode, transport_params=transport_params, **parsed_uri) diff --git a/smart_open/tests/fixtures/missing_deps_transport.py b/smart_open/tests/fixtures/missing_deps_transport.py index ea686598..a141d919 100644 --- a/smart_open/tests/fixtures/missing_deps_transport.py +++ b/smart_open/tests/fixtures/missing_deps_transport.py @@ -4,7 +4,7 @@ try: - import this_module_does_not_exist_but_we_need_it # noqa + import this_module_does_not_exist_but_we_need_it # type: ignore # noqa except ImportError: MISSING_DEPS = True diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py index 583f48e4..f5873433 100644 --- a/smart_open/tests/test_gcs.py +++ b/smart_open/tests/test_gcs.py @@ -5,6 +5,7 @@ # This code is distributed under the terms and conditions # from the MIT License (MIT). 
# +import collections import gzip import inspect import io @@ -13,15 +14,11 @@ import time import uuid import unittest -try: - from unittest import mock -except ImportError: - import mock import warnings -from collections import OrderedDict import google.cloud import google.api_core.exceptions +import mock import smart_open import smart_open.constants @@ -49,7 +46,7 @@ class FakeBucket(object): def __init__(self, client, name=None): self.client = client # type: FakeClient self.name = name - self.blobs = OrderedDict() + self.blobs = collections.OrderedDict() self._exists = True # @@ -239,8 +236,8 @@ def __init__(self, credentials=None): if credentials is None: credentials = FakeCredentials(self) self._credentials = credentials # type: FakeCredentials - self.uploads = OrderedDict() - self.__buckets = OrderedDict() + self.uploads = collections.OrderedDict() + self.__buckets = collections.OrderedDict() def bucket(self, bucket_id): try: diff --git a/smart_open/utils.py b/smart_open/utils.py index 94bac9eb..5bde7891 100644 --- a/smart_open/utils.py +++ b/smart_open/utils.py @@ -12,6 +12,11 @@ import logging import urllib.parse +from typing import ( + Optional, + Tuple, +) + logger = logging.getLogger(__name__) @@ -71,7 +76,7 @@ def check_kwargs(kallable, kwargs): return supported_kwargs -def clamp(value, minval, maxval): +def clamp(value: int, minval: int, maxval: int) -> int: """Clamp a numeric value to a specific range. Parameters @@ -94,22 +99,16 @@ def clamp(value, minval, maxval): return max(min(value, maxval), minval) -def make_range_string(start=None, stop=None): +def make_range_string(start: Optional[int] = None, stop: Optional[int] = None) -> str: """Create a byte range specifier in accordance with RFC-2616. Parameters ---------- - start: int, optional + :param start: The start of the byte range. If unspecified, stop indicated offset from EOF. - stop: int, optional + :param stop: The end of the byte range. If unspecified, indicates EOF. - - Returns - ------- - str - A byte range specifier. - """ # # https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.35 @@ -119,17 +118,12 @@ def make_range_string(start=None, stop=None): return 'bytes=%s-%s' % ('' if start is None else start, '' if stop is None else stop) -def parse_content_range(content_range): +def parse_content_range(content_range: str) -> Tuple[str, int, int, int]: """Extract units, start, stop, and length from a content range header like "bytes 0-846981/846982". Assumes a properly formatted content-range header from S3. See werkzeug.http.parse_content_range_header for a more robust version. - Parameters - ---------- - content_range: str - The content-range header to parse. - Returns ------- tuple (units: str, start: int, stop: int, length: int) @@ -142,7 +136,7 @@ def parse_content_range(content_range): return units, int(start), int(stop), int(length) -def safe_urlsplit(url): +def safe_urlsplit(url: str) -> urllib.parse.SplitResult: """This is a hack to prevent the regular urlsplit from splitting around question marks. A question mark (?) 
in a URL typically indicates the start of a diff --git a/smart_open/webhdfs.py b/smart_open/webhdfs.py index 369173c1..37687de0 100644 --- a/smart_open/webhdfs.py +++ b/smart_open/webhdfs.py @@ -12,6 +12,7 @@ """ +import collections import io import logging import urllib.parse @@ -35,9 +36,11 @@ MIN_PART_SIZE = 50 * 1024**2 # minimum part size for HDFS multipart uploads +Uri = collections.namedtuple('Uri', 'scheme uri') + def parse_uri(uri_as_str): - return dict(scheme=SCHEME, uri=uri_as_str) + return Uri(scheme=SCHEME, uri=uri_as_str) def open_uri(uri, mode, transport_params):
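
The change that cuts across nearly every transport module in this patch is that parse_uri now returns a module-level Uri namedtuple instead of a plain dict, so callers switch from key lookups to attribute access. A minimal usage sketch (not part of the patch; the bucket and key names are made up) against the patched smart_open.s3 module:

    import smart_open.s3

    # parse_uri now returns the Uri namedtuple declared at module level in s3.py.
    uri = smart_open.s3.parse_uri('s3://example-bucket/path/to/key.txt')

    # Attribute access replaces the old dict-style lookups such as uri['bucket_id'].
    print(uri.scheme, uri.bucket_id, uri.key_id)

    # Callers that still want a dict can convert back with the standard
    # namedtuple helper, as the patched ssh.open_uri does.
    params = uri._asdict()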