Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Unreleased

- Introduced type annotations

# 4.1.2, 18 Jan 2021

- Correctly pass boto3 resource to writers (PR [#576](https://github.com/RaRe-Technologies/smart_open/pull/576), [@jackluo923](https://github.com/jackluo923))
Expand Down Expand Up @@ -35,6 +37,7 @@ to install the AWS dependencies only, or
pip install smart_open[all]

to install all dependencies, including AWS, GCS, etc.

# 2.2.1, 1 Oct 2020

Expand Down
31 changes: 31 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
[mypy]

[mypy-smart_open.tests.*]
ignore_missing_imports = True

#
# Third-party libraries below
#
[mypy-azure.*]
ignore_missing_imports = True

[mypy-boto3.*]
ignore_missing_imports = True

[mypy-botocore.*]
ignore_missing_imports = True

[mypy-google.*]
ignore_missing_imports = True

[mypy-moto.*]
ignore_missing_imports = True

[mypy-paramiko.*]
ignore_missing_imports = True

[mypy-requests_kerberos.*]
ignore_missing_imports = True

[mypy-responses.*]
ignore_missing_imports = True
116 changes: 59 additions & 57 deletions smart_open/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,16 @@
"""Implements file-like objects for reading and writing to/from Azure Blob Storage."""

import base64
import collections
import io
import logging

from typing import (
Dict,
IO,
List,
)

import smart_open.bytebuffer
import smart_open.constants

Expand All @@ -37,8 +44,10 @@
https://docs.microsoft.com/en-us/rest/api/storageservices/understanding-block-blobs--append-blobs--and-page-blobs
"""

Uri = collections.namedtuple('Uri', 'scheme container_id blob_id')


def parse_uri(uri_as_string):
def parse_uri(uri_as_string: str) -> Uri:
sr = smart_open.utils.safe_urlsplit(uri_as_string)
assert sr.scheme == SCHEME
first = sr.netloc
Expand All @@ -52,54 +61,47 @@ def parse_uri(uri_as_string):
container_id = first
blob_id = second

return dict(scheme=SCHEME, container_id=container_id, blob_id=blob_id)
return Uri(scheme=SCHEME, container_id=container_id, blob_id=blob_id)


def open_uri(uri, mode, transport_params):
def open_uri(uri: str, mode: str, transport_params: Dict) -> IO[bytes]:
parsed_uri = parse_uri(uri)
kwargs = smart_open.utils.check_kwargs(open, transport_params)
return open(parsed_uri['container_id'], parsed_uri['blob_id'], mode, **kwargs)
return open(parsed_uri.container_id, parsed_uri.blob_id, mode, **kwargs)


def open(
container_id,
blob_id,
mode,
client=None, # type: azure.storage.blob.BlobServiceClient
buffer_size=DEFAULT_BUFFER_SIZE,
min_part_size=_DEFAULT_MIN_PART_SIZE
):
container_id: str,
blob_id: str,
mode: str,
client: 'azure.storage.blob.BlobServiceClient' = None,
buffer_size: int = DEFAULT_BUFFER_SIZE,
min_part_size: int = _DEFAULT_MIN_PART_SIZE,
) -> IO[bytes]:
"""Open an Azure Blob Storage blob for reading or writing.

Parameters
----------
container_id: str
The name of the container this object resides in.
blob_id: str
The name of the blob within the bucket.
mode: str
The mode for opening the object. Must be either "rb" or "wb".
client: azure.storage.blob.BlobServiceClient
The Azure Blob Storage client to use when working with azure-storage-blob.
buffer_size: int, optional
The buffer size to use when performing I/O. For reading only.
min_part_size: int, optional
The minimum part size for multipart uploads. For writing only.
:param container_id: The name of the container this object resides in.
:param blob_id: The name of the blob within the container.
:param mode: The mode for opening the object. Must be either "rb" or "wb".
:param client: The Azure Blob Storage client to use when working with azure-storage-blob.
:param buffer_size: The buffer size to use when performing I/O. For reading only.
:param min_part_size: The minimum part size for multipart uploads. For writing only.

"""
if not client:
raise ValueError('you must specify the client to connect to Azure')

if mode == smart_open.constants.READ_BINARY:
return Reader(
return Reader( # type: ignore
container_id,
blob_id,
client,
buffer_size=buffer_size,
line_terminator=smart_open.constants.BINARY_NEWLINE,
)
elif mode == smart_open.constants.WRITE_BINARY:
return Writer(
return Writer( # type: ignore
container_id,
blob_id,
client,
Expand All @@ -112,30 +114,29 @@ def open(
class _RawReader(object):
"""Read an Azure Blob Storage file."""

def __init__(self, blob, size):
# type: (azure.storage.blob.BlobClient, int) -> None
def __init__(self, blob: 'azure.storage.blob.BlobClient', size: int) -> None:
self._blob = blob
self._size = size
self._position = 0

def seek(self, position):
def seek(self, position: int) -> int:
"""Seek to the specified position (byte offset) in the Azure Blob Storage blob.

:param int position: The byte offset from the beginning of the blob.
:param position: The byte offset from the beginning of the blob.

Returns the position after seeking.
"""
self._position = position
return self._position

def read(self, size=-1):
def read(self, size: int = -1) -> bytes:
if self._position >= self._size:
return b''
binary = self._download_blob_chunk(size)
self._position += len(binary)
return binary

def _download_blob_chunk(self, size):
def _download_blob_chunk(self, size: int) -> bytes:
if self._size == self._position:
#
# When reading, we can't seek to the first byte of an empty file.
Expand All @@ -162,16 +163,16 @@ class Reader(io.BufferedIOBase):

"""
def __init__(
self,
container,
blob,
client, # type: azure.storage.blob.BlobServiceClient
buffer_size=DEFAULT_BUFFER_SIZE,
line_terminator=smart_open.constants.BINARY_NEWLINE,
):
self._container_client = client.get_container_client(container)
# type: azure.storage.blob.ContainerClient

self,
container: str,
blob: str,
client: 'azure.storage.blob.BlobServiceClient',
buffer_size: int = DEFAULT_BUFFER_SIZE,
line_terminator: bytes = smart_open.constants.BINARY_NEWLINE,
) -> None:
self._container_client: azure.storage.blob.ContainerClient = client.get_container_client(container)

self.name = blob
self._blob = self._container_client.get_blob_client(blob)
if self._blob is None:
raise azure.core.exceptions.ResourceNotFoundError(
Expand All @@ -190,7 +191,7 @@ def __init__(
#
# This member is part of the io.BufferedIOBase interface.
#
self.raw = None
self.raw = None # type: ignore

#
# Override some methods from io.IOBase.
Expand Down Expand Up @@ -309,7 +310,7 @@ def readline(self, limit=-1):
#
# Internal methods.
#
def _read_from_buffer(self, size=-1):
def _read_from_buffer(self, size: int = -1) -> bytes:
"""Remove at most size bytes from our buffer and return them."""
# logger.debug('reading %r bytes from %r byte-long buffer', size, len(self._current_part))
size = size if size >= 0 else len(self._current_part)
Expand All @@ -318,13 +319,14 @@ def _read_from_buffer(self, size=-1):
# logger.debug('part: %r', part)
return part

def _fill_buffer(self, size=-1):
def _fill_buffer(self, size: int = -1) -> bool:
size = max(size, self._current_part._chunk_size)
while len(self._current_part) < size and not self._position == self._size:
bytes_read = self._current_part.fill(self._raw_reader)
if bytes_read == 0:
logger.debug('reached EOF while filling buffer')
return True
return False

def __enter__(self):
return self
Expand All @@ -349,28 +351,28 @@ class Writer(io.BufferedIOBase):
Implements the io.BufferedIOBase interface of the standard library."""

def __init__(
self,
container,
blob,
client, # type: azure.storage.blob.BlobServiceClient
min_part_size=_DEFAULT_MIN_PART_SIZE,
):
self,
container: str,
blob: str,
client: 'azure.storage.blob.BlobServiceClient',
min_part_size: int = _DEFAULT_MIN_PART_SIZE,
) -> None:
self._client = client
self._container_client = self._client.get_container_client(container)
# type: azure.storage.blob.ContainerClient
self._blob = self._container_client.get_blob_client(blob) # type: azure.storage.blob.BlobClient
self._container_client: azure.storage.blob.ContainerClient = self._client.get_container_client(container) # noqa
self.name = blob
self._blob: azure.storage.blob.BlobClient = self._container_client.get_blob_client(blob)
self._min_part_size = min_part_size

self._total_size = 0
self._total_parts = 0
self._bytes_uploaded = 0
self._current_part = io.BytesIO()
self._block_list = []
self._block_list: List['azure.storage.blob.BlobBlock'] = []

#
# This member is part of the io.BufferedIOBase interface.
#
self.raw = None
self.raw = None # type: ignore

def flush(self):
pass
Expand Down Expand Up @@ -423,7 +425,7 @@ def write(self, b):

return len(b)

def _upload_part(self):
def _upload_part(self) -> None:
part_num = self._total_parts + 1
content_length = self._current_part.tell()
range_stop = self._bytes_uploaded + content_length - 1
Expand Down
9 changes: 8 additions & 1 deletion smart_open/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@
import logging
import os.path

from typing import (
Callable,
Dict,
IO,
)

logger = logging.getLogger(__name__)


_COMPRESSOR_REGISTRY = {}
Compressor = Callable[[IO, str], IO]
_COMPRESSOR_REGISTRY: Dict[str, Compressor] = {}


def get_supported_extensions():
Expand Down
Loading