Skip to content
This repository was archived by the owner on Apr 10, 2024. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lucid/misc/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from lucid.misc.io.showing import show
from lucid.misc.io.loading import load
from lucid.misc.io.saving import save, CaptureSaveContext, batch_save
from lucid.misc.io.scoping import io_scope, scope_url
12 changes: 12 additions & 0 deletions lucid/misc/io/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,18 @@ def _decompress_xz(handle, **kwargs):
}


modes = {
".png": "wb",
".jpg": "wb",
".jpeg": "wb",
".webp": "wb",
".npy": "wb",
".npz": "wb",
".json": "w",
".txt": "w",
".pb": "wb",
}

unsafe_loaders = {
".pickle": _load_pickle,
".pkl": _load_pickle,
Expand Down
8 changes: 8 additions & 0 deletions lucid/misc/io/reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

from lucid.misc.io.writing import write_handle
from lucid.misc.io.scoping import scope_url
from lucid.misc.io.util import isazure

import blobfile

# create logger with module name, e.g. lucid.misc.io.reading
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -93,6 +95,12 @@ def read_handle(url, cache=None, mode="rb"):
"""
url = scope_url(url)

if isazure(url):
handle = blobfile.BlobFile(url, mode)
yield handle
handle.close()
return

scheme = urlparse(url).scheme

if cache == "purge":
Expand Down
21 changes: 19 additions & 2 deletions lucid/misc/io/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from lucid.misc.io.writing import write_handle
from lucid.misc.io.serialize_array import _normalize_array
from lucid.misc.io.scoping import current_io_scopes, set_io_scopes

from lucid.misc.io.util import isazure

# create logger with module name, e.g. lucid.misc.io.saving
log = logging.getLogger(__name__)
Expand Down Expand Up @@ -227,6 +227,18 @@ def compress_xz(handle, **kwargs):
".pb": save_pb,
}

modes = {
".png": "wb",
".jpg": "wb",
".jpeg": "wb",
".webp": "wb",
".npy": "wb",
".npz": "wb",
".json": "w",
".txt": "w",
".pb": "wb",
}

unsafe_savers = {
".pickle": save_pickle,
".pkl": save_pickle,
Expand Down Expand Up @@ -255,6 +267,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona

# Determine context
# Is this a handle? What is the extension? Are we saving to GCS?

is_handle = hasattr(url_or_handle, "write") and hasattr(url_or_handle, "name")
if is_handle:
path = url_or_handle.name
Expand Down Expand Up @@ -292,7 +305,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
else:
handle_provider = write_handle

with handle_provider(url_or_handle) as handle:
with handle_provider(url_or_handle, mode = modes[ext]) as handle:
with compressor(handle) as compressed_handle:
result = saver(thing, compressed_handle, **kwargs)

Expand All @@ -309,6 +322,7 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona

# capture save if a save context is available
save_context = save_context if save_context is not None else CaptureSaveContext.current_save_context()

if save_context:
log.debug(
"capturing save: resulted in {} -> {} in save_context {}".format(
Expand All @@ -320,6 +334,9 @@ def save(thing, url_or_handle, allow_unsafe_formats=False, save_context: Optiona
if result is not None and "url" in result and result["url"].startswith("gs://"):
result["serve"] = "https://storage.googleapis.com/{}".format(result["url"][5:])

if isazure(result["url"]):
result["serve"] = url

return result


Expand Down
11 changes: 10 additions & 1 deletion lucid/misc/io/writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from contextlib import contextmanager
from urllib.parse import urlparse
from tensorflow import gfile
import blobfile

from lucid.misc.io.scoping import scope_url
from lucid.misc.io.util import isazure

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -60,11 +62,18 @@ def write(data, url, mode="wb"):

_write_to_path(data, url, mode=mode)


@contextmanager
def write_handle(path, mode=None):
path = scope_url(path)

if isazure(path):
if mode is None:
mode = "w"
handle = blobfile.BlobFile(path, mode)
yield handle
handle.close()
return

if _supports_make_dirs(path):
gfile.MakeDirs(os.path.dirname(path))

Expand Down