144 changes: 81 additions & 63 deletions fsspec/implementations/cached.py
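Context for the one-line hunk below: `CachingFileSystem.__getattribute__` serves only the names in this allow-list from the caching layer itself and proxies every other attribute to the wrapped target filesystem, so the new `_get_cached_file_before_open` helper must be registered there or lookups would fall through to `self.fs`. A simplified sketch of that dispatch pattern (illustrative only, not fsspec's actual implementation):

    class CachedFS:
        # Names served by the caching layer itself; everything else is
        # looked up on the wrapped target filesystem.
        _local_names = {"fs", "_open", "_get_cached_file_before_open"}

        def __init__(self, fs):
            object.__setattr__(self, "fs", fs)

        def __getattribute__(self, item):
            if item in CachedFS._local_names:
                return object.__getattribute__(self, item)
            return getattr(object.__getattribute__(self, "fs"), item)

        def _open(self, path):
            # the caching layer's own open(), never proxied
            return f"cached open of {path}"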
@@ -427,6 +427,7 @@ def ls(self, path, detail=True):
     def __getattribute__(self, item):
         if item in {
             "load_cache",
+            "_get_cached_file_before_open",
             "_open",
             "save_cache",
             "close_and_update",
@@ -678,46 +679,12 @@ def cat(
             out = out[paths[0]]
         return out

-    def _open(self, path, mode="rb", **kwargs):
-        path = self._strip_protocol(path)
-        if "r" not in mode:
-            hash = self._mapper(path)
-            fn = os.path.join(self.storage[-1], hash)
-            user_specified_kwargs = {
-                k: v
-                for k, v in kwargs.items()
-                # those kwargs were added by open(), we don't want them
-                if k not in ["autocommit", "block_size", "cache_options"]
-            }
-            return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs)
-        detail = self._check_file(path)
-        if detail:
-            detail, fn = detail
-            _, blocks = detail["fn"], detail["blocks"]
-            if blocks is True:
-                logger.debug("Opening local copy of %s", path)
-
-                # In order to support downstream filesystems to be able to
-                # infer the compression from the original filename, like
-                # the `TarFileSystem`, let's extend the `io.BufferedReader`
-                # fileobject protocol by adding a dedicated attribute
-                # `original`.
-                f = open(fn, mode)
-                f.original = detail.get("original")
-                return f
-            else:
-                raise ValueError(
-                    f"Attempt to open partially cached file {path}"
-                    f" as a wholly cached file"
-                )
-        else:
-            fn = self._make_local_details(path)
-            kwargs["mode"] = mode
-
+    def _get_cached_file_before_open(self, path, **kwargs):
+        fn = self._make_local_details(path)
         # call target filesystems open
         self._mkcache()
         if self.compression:
-            with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
+            with self.fs._open(path, mode="rb", **kwargs) as f, open(fn, "wb") as f2:
                 if isinstance(f, AbstractBufferedFile):
                     # want no type of caching if just downloading whole thing
                     f.cache = BaseCache(0, f.cache.fetcher, f.size)
@@ -735,7 +702,47 @@ def _open(self, path, mode="rb", **kwargs):
         else:
             self.fs.get_file(path, fn)
         self.save_cache()
-        return self._open(path, mode)
+
+    def _open(self, path, mode="rb", **kwargs):
+        path = self._strip_protocol(path)
+        # For read (or append) modes, try to download from the remote first
+        if "r" in mode or "a" in mode:
+            if not self._check_file(path):
+                if self.fs.exists(path):
+                    self._get_cached_file_before_open(path, **kwargs)
+                elif "r" in mode:
+                    raise FileNotFoundError(path)
+
+            detail, fn = self._check_file(path)
+            _, blocks = detail["fn"], detail["blocks"]
+            if blocks is True:
+                logger.debug("Opening local copy of %s", path)
+            else:
+                raise ValueError(
+                    f"Attempt to open partially cached file {path}"
+                    f" as a wholly cached file"
+                )
+
+            # Just reading does not need special file handling
+            if "r" in mode and "+" not in mode:
+                # In order to support downstream filesystems to be able to
+                # infer the compression from the original filename, like
+                # the `TarFileSystem`, let's extend the `io.BufferedReader`
+                # fileobject protocol by adding a dedicated attribute
+                # `original`.
+                f = open(fn, mode)
+                f.original = detail.get("original")
+                return f
+
+        hash = self._mapper(path)
+        fn = os.path.join(self.storage[-1], hash)
+        user_specified_kwargs = {
+            k: v
+            for k, v in kwargs.items()
+            # those kwargs were added by open(), we don't want them
+            if k not in ["autocommit", "block_size", "cache_options"]
+        }
+        return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs)


 class SimpleCacheFileSystem(WholeFileCacheFileSystem):
@@ -894,37 +901,16 @@ def cat_ranges(
             paths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs
         )

-    def _open(self, path, mode="rb", **kwargs):
-        path = self._strip_protocol(path)
+    def _get_cached_file_before_open(self, path, **kwargs):
         sha = self._mapper(path)

-        if "r" not in mode:
-            fn = os.path.join(self.storage[-1], sha)
-            user_specified_kwargs = {
-                k: v
-                for k, v in kwargs.items()
-                if k not in ["autocommit", "block_size", "cache_options"]
-            }  # those were added by open()
-            return LocalTempFile(
-                self,
-                path,
-                mode=mode,
-                autocommit=not self._intrans,
-                fn=fn,
-                **user_specified_kwargs,
-            )
-        fn = self._check_file(path)
-        if fn:
-            return open(fn, mode)
-
         fn = os.path.join(self.storage[-1], sha)
         logger.debug("Copying %s to local cache", path)
-        kwargs["mode"] = mode

         self._mkcache()
         self._cache_size = None

         if self.compression:
-            with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
+            with self.fs._open(path, mode="rb", **kwargs) as f, open(fn, "wb") as f2:
                 if isinstance(f, AbstractBufferedFile):
                     # want no type of caching if just downloading whole thing
                     f.cache = BaseCache(0, f.cache.fetcher, f.size)
@@ -941,7 +927,39 @@ def _open(self, path, mode="rb", **kwargs):
                     f2.write(data)
         else:
             self.fs.get_file(path, fn)
-        return self._open(path, mode)
+
+    def _open(self, path, mode="rb", **kwargs):
+        path = self._strip_protocol(path)
+        sha = self._mapper(path)
+
+        # For read (or append) modes, try to download from the remote first
+        if "r" in mode or "a" in mode:
+            if not self._check_file(path):
+                # append does not require an existing file, but read does
+                if self.fs.exists(path):
+                    self._get_cached_file_before_open(path, **kwargs)
+                elif "r" in mode:
+                    raise FileNotFoundError(path)
+
+            fn = self._check_file(path)
+            # Just reading does not need special file handling
+            if "r" in mode and "+" not in mode:
+                return open(fn, mode)
+
+        fn = os.path.join(self.storage[-1], sha)
+        user_specified_kwargs = {
+            k: v
+            for k, v in kwargs.items()
+            if k not in ["autocommit", "block_size", "cache_options"]
+        }  # those were added by open()
+        return LocalTempFile(
+            self,
+            path,
+            mode=mode,
+            autocommit=not self._intrans,
+            fn=fn,
+            **user_specified_kwargs,
+        )


 class LocalTempFile:
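The net effect of the `cached.py` refactor above: downloading a whole remote file into the local store now lives in `_get_cached_file_before_open`, while `_open` first ensures a cached copy exists for "r"/"a" modes, returns a plain local file handle for read-only modes, and hands every writable mode ("w", "a", "r+") a `LocalTempFile` that is committed back to the target filesystem on close. A minimal usage sketch of the behaviour this enables, mirroring the new tests below (the in-memory target and file name are illustrative):

    import fsspec

    # Seed a file on the (here: in-memory) target filesystem.
    with fsspec.open("memory://afile", "wb") as f:
        f.write(b"hello")

    # Append through the cache: the remote copy is downloaded first, the
    # write lands in a LocalTempFile, and the result is uploaded on close.
    with fsspec.open("simplecache::memory://afile", mode="ab") as f:
        f.write(b"world")

    with fsspec.open("memory://afile", "rb") as f:
        assert f.read() == b"helloworld"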
52 changes: 52 additions & 0 deletions fsspec/implementations/tests/test_cached.py
@@ -1091,6 +1091,58 @@ def test_cached_write(protocol):
     assert not os.path.exists(fn)


+@pytest.mark.parametrize("protocol", ["simplecache", "filecache"])
+def test_cached_append_text(protocol):
+    fn = "memory://afile"
+    with fsspec.open(fn, "w") as f:
+        f.write("hello")
+    with fsspec.open(f"{protocol}::{fn}", mode="a") as f:
+        assert isinstance(f.buffer, LocalTempFile)
+        f.write("world")
+    with fsspec.open(fn, "r") as f:
+        assert f.read() == "helloworld"
+
+
+@pytest.mark.parametrize("protocol", ["simplecache", "filecache"])
+def test_cached_append_binary(protocol):
+    fn = "memory://afile"
+    with fsspec.open(fn, "wb") as f:
+        f.write(b"hello")
+    with fsspec.open(f"{protocol}::{fn}", mode="ab") as f:
+        assert isinstance(f, LocalTempFile)
+        f.write(b"world")
+    with fsspec.open(fn, "rb") as f:
+        assert f.read() == b"helloworld"
+
+
+@pytest.mark.parametrize("protocol", ["simplecache", "filecache"])
+def test_cached_update_text(protocol):
+    fn = "memory://afile"
+    with fsspec.open(fn, "w") as f:
+        f.write("hello")
+    with fsspec.open(f"{protocol}::{fn}", mode="r+") as f:
+        assert isinstance(f.buffer, LocalTempFile)
+        assert f.read() == "hello"
+        f.seek(1)
+        f.write("world")
+    with fsspec.open(fn, "r") as f:
+        assert f.read() == "hworld"
+
+
+@pytest.mark.parametrize("protocol", ["simplecache", "filecache"])
+def test_cached_update_binary(protocol):
+    fn = "memory://afile"
+    with fsspec.open(fn, "wb") as f:
+        f.write(b"hello")
+    with fsspec.open(f"{protocol}::{fn}", mode="r+b") as f:
+        assert isinstance(f, LocalTempFile)
+        assert f.read() == b"hello"
+        f.seek(1)
+        f.write(b"world")
+    with fsspec.open(fn, "rb") as f:
+        assert f.read() == b"hworld"
+
+
 def test_expiry():
     import time

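A detail worth noting in these tests: the binary-mode variants assert `isinstance(f, LocalTempFile)` directly, while the text-mode variants assert on `f.buffer`. fsspec opens text modes by wrapping the underlying binary file in an `io.TextIOWrapper`, whose `.buffer` attribute exposes the raw binary object. A quick illustration (assuming the same `memory://afile` setup as in the tests above, and assuming the text wrapper is a plain `io.TextIOWrapper`):

    import io
    import fsspec

    with fsspec.open("memory://afile", "w") as f:
        f.write("hello")

    with fsspec.open("simplecache::memory://afile", mode="a") as f:
        # Text mode: f is a TextIOWrapper around the binary LocalTempFile.
        assert isinstance(f, io.TextIOWrapper)
        f.write("!")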