Skip to content

Commit 2576617

Browse files
Support append and update for cached file systems (#1980)
Fixes #1977
Co-authored-by: Martin Durant <[email protected]>
1 parent cea9d7c commit 2576617

File tree

2 files changed

+133
-63
lines changed

2 files changed

+133
-63
lines changed

fsspec/implementations/cached.py

Lines changed: 81 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -427,6 +427,7 @@ def ls(self, path, detail=True):
427427
def __getattribute__(self, item):
428428
if item in {
429429
"load_cache",
430+
"_get_cached_file_before_open",
430431
"_open",
431432
"save_cache",
432433
"close_and_update",
@@ -678,46 +679,12 @@ def cat(
678679
out = out[paths[0]]
679680
return out
680681

681-
def _open(self, path, mode="rb", **kwargs):
682-
path = self._strip_protocol(path)
683-
if "r" not in mode:
684-
hash = self._mapper(path)
685-
fn = os.path.join(self.storage[-1], hash)
686-
user_specified_kwargs = {
687-
k: v
688-
for k, v in kwargs.items()
689-
# those kwargs were added by open(), we don't want them
690-
if k not in ["autocommit", "block_size", "cache_options"]
691-
}
692-
return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs)
693-
detail = self._check_file(path)
694-
if detail:
695-
detail, fn = detail
696-
_, blocks = detail["fn"], detail["blocks"]
697-
if blocks is True:
698-
logger.debug("Opening local copy of %s", path)
699-
700-
# In order to support downstream filesystems to be able to
701-
# infer the compression from the original filename, like
702-
# the `TarFileSystem`, let's extend the `io.BufferedReader`
703-
# fileobject protocol by adding a dedicated attribute
704-
# `original`.
705-
f = open(fn, mode)
706-
f.original = detail.get("original")
707-
return f
708-
else:
709-
raise ValueError(
710-
f"Attempt to open partially cached file {path}"
711-
f" as a wholly cached file"
712-
)
713-
else:
714-
fn = self._make_local_details(path)
715-
kwargs["mode"] = mode
716-
682+
def _get_cached_file_before_open(self, path, **kwargs):
683+
fn = self._make_local_details(path)
717684
# call target filesystems open
718685
self._mkcache()
719686
if self.compression:
720-
with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
687+
with self.fs._open(path, mode="rb", **kwargs) as f, open(fn, "wb") as f2:
721688
if isinstance(f, AbstractBufferedFile):
722689
# want no type of caching if just downloading whole thing
723690
f.cache = BaseCache(0, f.cache.fetcher, f.size)
@@ -735,7 +702,47 @@ def _open(self, path, mode="rb", **kwargs):
735702
else:
736703
self.fs.get_file(path, fn)
737704
self.save_cache()
738-
return self._open(path, mode)
705+
706+
def _open(self, path, mode="rb", **kwargs):
707+
path = self._strip_protocol(path)
708+
# For read (or append), (try) download from remote
709+
if "r" in mode or "a" in mode:
710+
if not self._check_file(path):
711+
if self.fs.exists(path):
712+
self._get_cached_file_before_open(path, **kwargs)
713+
elif "r" in mode:
714+
raise FileNotFoundError(path)
715+
716+
detail, fn = self._check_file(path)
717+
_, blocks = detail["fn"], detail["blocks"]
718+
if blocks is True:
719+
logger.debug("Opening local copy of %s", path)
720+
else:
721+
raise ValueError(
722+
f"Attempt to open partially cached file {path}"
723+
f" as a wholly cached file"
724+
)
725+
726+
# Just reading does not need special file handling
727+
if "r" in mode and "+" not in mode:
728+
# In order to support downstream filesystems to be able to
729+
# infer the compression from the original filename, like
730+
# the `TarFileSystem`, let's extend the `io.BufferedReader`
731+
# fileobject protocol by adding a dedicated attribute
732+
# `original`.
733+
f = open(fn, mode)
734+
f.original = detail.get("original")
735+
return f
736+
737+
hash = self._mapper(path)
738+
fn = os.path.join(self.storage[-1], hash)
739+
user_specified_kwargs = {
740+
k: v
741+
for k, v in kwargs.items()
742+
# those kwargs were added by open(), we don't want them
743+
if k not in ["autocommit", "block_size", "cache_options"]
744+
}
745+
return LocalTempFile(self, path, mode=mode, fn=fn, **user_specified_kwargs)
739746

740747

741748
class SimpleCacheFileSystem(WholeFileCacheFileSystem):
@@ -894,37 +901,16 @@ def cat_ranges(
894901
paths, starts, ends, max_gap=max_gap, on_error=on_error, **kwargs
895902
)
896903

897-
def _open(self, path, mode="rb", **kwargs):
898-
path = self._strip_protocol(path)
904+
def _get_cached_file_before_open(self, path, **kwargs):
899905
sha = self._mapper(path)
900-
901-
if "r" not in mode:
902-
fn = os.path.join(self.storage[-1], sha)
903-
user_specified_kwargs = {
904-
k: v
905-
for k, v in kwargs.items()
906-
if k not in ["autocommit", "block_size", "cache_options"]
907-
} # those were added by open()
908-
return LocalTempFile(
909-
self,
910-
path,
911-
mode=mode,
912-
autocommit=not self._intrans,
913-
fn=fn,
914-
**user_specified_kwargs,
915-
)
916-
fn = self._check_file(path)
917-
if fn:
918-
return open(fn, mode)
919-
920906
fn = os.path.join(self.storage[-1], sha)
921907
logger.debug("Copying %s to local cache", path)
922-
kwargs["mode"] = mode
923908

924909
self._mkcache()
925910
self._cache_size = None
911+
926912
if self.compression:
927-
with self.fs._open(path, **kwargs) as f, open(fn, "wb") as f2:
913+
with self.fs._open(path, mode="rb", **kwargs) as f, open(fn, "wb") as f2:
928914
if isinstance(f, AbstractBufferedFile):
929915
# want no type of caching if just downloading whole thing
930916
f.cache = BaseCache(0, f.cache.fetcher, f.size)
@@ -941,7 +927,39 @@ def _open(self, path, mode="rb", **kwargs):
941927
f2.write(data)
942928
else:
943929
self.fs.get_file(path, fn)
944-
return self._open(path, mode)
930+
931+
def _open(self, path, mode="rb", **kwargs):
932+
path = self._strip_protocol(path)
933+
sha = self._mapper(path)
934+
935+
# For read (or append), (try) download from remote
936+
if "r" in mode or "a" in mode:
937+
if not self._check_file(path):
938+
# append does not require an existing file but read does
939+
if self.fs.exists(path):
940+
self._get_cached_file_before_open(path, **kwargs)
941+
elif "r" in mode:
942+
raise FileNotFoundError(path)
943+
944+
fn = self._check_file(path)
945+
# Just reading does not need special file handling
946+
if "r" in mode and "+" not in mode:
947+
return open(fn, mode)
948+
949+
fn = os.path.join(self.storage[-1], sha)
950+
user_specified_kwargs = {
951+
k: v
952+
for k, v in kwargs.items()
953+
if k not in ["autocommit", "block_size", "cache_options"]
954+
} # those were added by open()
955+
return LocalTempFile(
956+
self,
957+
path,
958+
mode=mode,
959+
autocommit=not self._intrans,
960+
fn=fn,
961+
**user_specified_kwargs,
962+
)
945963

946964

947965
class LocalTempFile:

fsspec/implementations/tests/test_cached.py

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1091,6 +1091,58 @@ def test_cached_write(protocol):
10911091
assert not os.path.exists(fn)
10921092

10931093

1094+
@pytest.mark.parametrize("protocol", ["simplecache", "filecache"])
1095+
def test_cached_append_text(protocol):
1096+
fn = "memory://afile"
1097+
with fsspec.open(fn, "w") as f:
1098+
f.write("hello")
1099+
with fsspec.open(f"{protocol}::{fn}", mode="a") as f:
1100+
assert isinstance(f.buffer, LocalTempFile)
1101+
f.write("world")
1102+
with fsspec.open(fn, "r") as f:
1103+
assert f.read() == "helloworld"
1104+
1105+
1106+
@pytest.mark.parametrize("protocol", ["simplecache", "filecache"])
1107+
def test_cached_append_binary(protocol):
1108+
fn = "memory://afile"
1109+
with fsspec.open(fn, "wb") as f:
1110+
f.write(b"hello")
1111+
with fsspec.open(f"{protocol}::{fn}", mode="ab") as f:
1112+
assert isinstance(f, LocalTempFile)
1113+
f.write(b"world")
1114+
with fsspec.open(fn, "rb") as f:
1115+
assert f.read() == b"helloworld"
1116+
1117+
1118+
@pytest.mark.parametrize("protocol", ["simplecache", "filecache"])
1119+
def test_cached_update_text(protocol):
1120+
fn = "memory://afile"
1121+
with fsspec.open(fn, "w") as f:
1122+
f.write("hello")
1123+
with fsspec.open(f"{protocol}::{fn}", mode="r+") as f:
1124+
assert isinstance(f.buffer, LocalTempFile)
1125+
assert f.read() == "hello"
1126+
f.seek(1)
1127+
f.write("world")
1128+
with fsspec.open(fn, "r") as f:
1129+
assert f.read() == "hworld"
1130+
1131+
1132+
@pytest.mark.parametrize("protocol", ["simplecache", "filecache"])
1133+
def test_cached_update_binary(protocol):
1134+
fn = "memory://afile"
1135+
with fsspec.open(fn, "wb") as f:
1136+
f.write(b"hello")
1137+
with fsspec.open(f"{protocol}::{fn}", mode="r+b") as f:
1138+
assert isinstance(f, LocalTempFile)
1139+
assert f.read() == b"hello"
1140+
f.seek(1)
1141+
f.write(b"world")
1142+
with fsspec.open(fn, "rb") as f:
1143+
assert f.read() == b"hworld"
1144+
1145+
10941146
def test_expiry():
10951147
import time
10961148

0 commit comments

Comments
 (0)