Skip to content

Commit 6f629c2

Browse files
authored
Support chaining for zip and tar paths (#440)
* upath._chain: fix unnesting for correct roundtripping * tests: add chained tarfile tests * upath: correct use of vfspath and .path and .__str__ * upath.implementations.cloud: fix .path override * tests: add chained zip tests * upath._chain: populate target_options if target_protocol is set * upath._flavour: fix windows join behavior
1 parent 746ea07 commit 6f629c2

File tree

8 files changed

+136
-36
lines changed

8 files changed

+136
-36
lines changed

upath/_chain.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import warnings
55
from collections import defaultdict
6+
from collections import deque
67
from collections.abc import MutableMapping
78
from collections.abc import Sequence
89
from collections.abc import Set
@@ -105,7 +106,26 @@ def replace(
105106
return type(self)(*segments, index=index)
106107

107108
def to_list(self) -> list[ChainSegment]:
108-
return list(self._segments)
109+
"""return a list of chain segments unnesting target_* segments"""
110+
queue = deque(self._segments)
111+
segments = []
112+
while queue:
113+
segment = queue.popleft()
114+
if (
115+
"target_protocol" in segment.storage_options
116+
and "fo" in segment.storage_options
117+
):
118+
storage_options = segment.storage_options.copy()
119+
target_options = storage_options.pop("target_options", {})
120+
target_protocol = storage_options.pop("target_protocol")
121+
fo = storage_options.pop("fo")
122+
queue.appendleft(ChainSegment(fo, target_protocol, target_options))
123+
segments.append(
124+
ChainSegment(segment.path, segment.protocol, storage_options)
125+
)
126+
elif not segments or segment != segments[-1]:
127+
segments.append(segment)
128+
return segments
109129

110130
@classmethod
111131
def from_list(cls, segments: list[ChainSegment], index: int = 0) -> Self:
@@ -124,9 +144,9 @@ def nest(self) -> ChainSegment:
124144
urls = _prev
125145
_prev = urls
126146
if i == len(chain) - 1:
127-
inkwargs = dict(**kw, **inkwargs)
147+
inkwargs = {**kw, **inkwargs}
128148
continue
129-
inkwargs["target_options"] = dict(**kw, **inkwargs)
149+
inkwargs["target_options"] = {**kw, **inkwargs}
130150
inkwargs["target_protocol"] = protocol
131151
inkwargs["fo"] = urls # codespell:ignore fo
132152
urlpath, protocol, _ = chain[0]
@@ -186,6 +206,8 @@ def unchain(self, path: str, kwargs: dict[str, Any]) -> list[ChainSegment]:
186206
kws.update(kwargs)
187207
kw = dict(**extra_kwargs)
188208
kw.update(kws)
209+
if "target_protocol" in kw:
210+
kw.setdefault("target_options", {})
189211
bit = flavour.strip_protocol(bit) or flavour.root_marker
190212
if (
191213
protocol in {"blockcache", "filecache", "simplecache"}
@@ -229,7 +251,8 @@ def chain(self, segments: Sequence[ChainSegment]) -> tuple[str, dict[str, Any]]:
229251
continue
230252
urlpaths.append(urlpath)
231253
# TODO: ensure roundtrip with unchain behavior
232-
kwargs[segment.protocol] = segment.storage_options
254+
if segment.storage_options:
255+
kwargs[segment.protocol] = segment.storage_options
233256
return self.link.join(urlpaths), kwargs
234257

235258

upath/_flavour.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,11 @@ def join(self, path: JoinablePathLike, *paths: JoinablePathLike) -> str:
287287
if not paths:
288288
return self.strip_protocol(path) or self.root_marker
289289
if self.local_file:
290-
return os.path.join(
290+
p = os.path.join(
291291
self.strip_protocol(path),
292292
*map(self.stringify_path, paths),
293293
)
294+
return p if os.name != "nt" else p.replace("\\", "/")
294295
if self.netloc_is_anchor:
295296
drv, p0 = self.splitdrive(path)
296297
pN = list(map(self.stringify_path, paths))

upath/core.py

Lines changed: 42 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -402,27 +402,38 @@ def __init__(
402402
**args0._chain.nest().storage_options,
403403
**storage_options,
404404
}
405-
str_args0 = str(args0)
405+
str_args0 = args0.__vfspath__()
406406

407407
else:
408408
if hasattr(args0, "__fspath__") and args0.__fspath__ is not None:
409409
str_args0 = args0.__fspath__()
410+
elif hasattr(args0, "__vfspath__") and args0.__vfspath__ is not None:
411+
str_args0 = args0.__vfspath__()
412+
elif isinstance(args0, str):
413+
str_args0 = args0
410414
else:
411-
str_args0 = str(args0)
415+
raise TypeError(
416+
"argument should be a UPath, str, "
417+
f"or support __vfspath__ or __fspath__, not {type(args0)!r}"
418+
)
412419
storage_options = type(self)._parse_storage_options(
413420
str_args0, protocol, storage_options
414421
)
415-
if len(args) > 1:
416-
str_args0 = WrappedFileSystemFlavour.from_protocol(protocol).join(
417-
str_args0, *args[1:]
418-
)
419422
else:
420423
str_args0 = "."
421424

422425
segments = chain_parser.unchain(
423426
str_args0, {"protocol": protocol, **storage_options}
424427
)
425-
self._chain = Chain.from_list(segments)
428+
chain = Chain.from_list(segments)
429+
if len(args) > 1:
430+
chain = chain.replace(
431+
path=WrappedFileSystemFlavour.from_protocol(protocol).join(
432+
chain.active_path,
433+
*args[1:],
434+
)
435+
)
436+
self._chain = chain
426437
self._chain_parser = chain_parser
427438
self._raw_urlpaths = args
428439
self._relative_base = None
@@ -638,9 +649,6 @@ def with_segments(self, *pathsegments: JoinablePathLike) -> Self:
638649
return new_instance
639650

640651
def __str__(self) -> str:
641-
return self.__vfspath__()
642-
643-
def __vfspath__(self) -> str:
644652
if self._relative_base is not None:
645653
active_path = self._chain.active_path
646654
stripped_base = self.parser.strip_protocol(
@@ -658,10 +666,19 @@ def __vfspath__(self) -> str:
658666
else:
659667
return self._chain_parser.chain(self._chain.to_list())[0]
660668

669+
def __vfspath__(self) -> str:
670+
if self._relative_base is not None:
671+
return self.__str__()
672+
else:
673+
return self.path
674+
661675
def __repr__(self) -> str:
676+
cls_name = type(self).__name__
677+
path = self.__vfspath__()
662678
if self._relative_base is not None:
663-
return f"<relative {type(self).__name__} {str(self)!r}>"
664-
return f"{type(self).__name__}({self.path!r}, protocol={self._protocol!r})"
679+
return f"<relative {cls_name} {path!r}>"
680+
else:
681+
return f"{cls_name}({path!r}, protocol={self._protocol!r})"
665682

666683
# === JoinablePath overrides ======================================
667684

@@ -704,9 +721,9 @@ def with_name(self, name: str) -> Self:
704721
split = self.parser.split
705722
if self.parser.sep in name: # `split(name)[0]`
706723
raise ValueError(f"Invalid name {name!r}")
707-
path = str(self)
708-
path = path.removesuffix(split(path)[1]) + name
709-
return self.with_segments(path)
724+
_path = self.__vfspath__()
725+
_path = _path.removesuffix(split(_path)[1]) + name
726+
return self.with_segments(_path)
710727

711728
@property
712729
def anchor(self) -> str:
@@ -780,7 +797,7 @@ def iterdir(self) -> Iterator[Self]:
780797
continue
781798
# only want the path name with iterdir
782799
_, _, name = name.removesuffix(sep).rpartition(self.parser.sep)
783-
yield base.with_segments(str(base), name)
800+
yield base.with_segments(base.path, name)
784801

785802
def __open_reader__(self) -> BinaryIO:
786803
return self.fs.open(self.path, mode="rb")
@@ -1045,7 +1062,7 @@ def glob(
10451062
self = self.absolute()
10461063
path_pattern = self.joinpath(pattern).path
10471064
sep = self.parser.sep
1048-
base = self.fs._strip_protocol(self.path)
1065+
base = self.path
10491066
for name in self.fs.glob(path_pattern):
10501067
name = name.removeprefix(base).removeprefix(sep)
10511068
yield self.joinpath(name)
@@ -1075,7 +1092,7 @@ def rglob(
10751092
if _FSSPEC_HAS_WORKING_GLOB:
10761093
r_path_pattern = self.joinpath("**", pattern).path
10771094
sep = self.parser.sep
1078-
base = self.fs._strip_protocol(self.path)
1095+
base = self.path
10791096
for name in self.fs.glob(r_path_pattern):
10801097
name = name.removeprefix(base).removeprefix(sep)
10811098
yield self.joinpath(name)
@@ -1084,7 +1101,7 @@ def rglob(
10841101
path_pattern = self.joinpath(pattern).path
10851102
r_path_pattern = self.joinpath("**", pattern).path
10861103
sep = self.parser.sep
1087-
base = self.fs._strip_protocol(self.path)
1104+
base = self.path
10881105
seen = set()
10891106
for p in (path_pattern, r_path_pattern):
10901107
for name in self.fs.glob(p):
@@ -1134,7 +1151,7 @@ def __eq__(self, other: object) -> bool:
11341151
return False
11351152

11361153
return (
1137-
self.path == other.path
1154+
self.__vfspath__() == other.__vfspath__()
11381155
and self.protocol == other.protocol
11391156
and self.storage_options == other.storage_options
11401157
)
@@ -1145,29 +1162,27 @@ def __hash__(self) -> int:
11451162
Note: in the future, if hash collisions become an issue, we
11461163
can add `fsspec.utils.tokenize(storage_options)`
11471164
"""
1148-
if self._relative_base is not None:
1149-
return hash((self.protocol, str(self)))
1150-
return hash((self.protocol, self.path))
1165+
return hash((self.protocol, self.__vfspath__()))
11511166

11521167
def __lt__(self, other: object) -> bool:
11531168
if not isinstance(other, UPath) or self.parser is not other.parser:
11541169
return NotImplemented
1155-
return self.path < other.path
1170+
return self.__vfspath__() < other.__vfspath__()
11561171

11571172
def __le__(self, other: object) -> bool:
11581173
if not isinstance(other, UPath) or self.parser is not other.parser:
11591174
return NotImplemented
1160-
return self.path <= other.path
1175+
return self.__vfspath__() <= other.__vfspath__()
11611176

11621177
def __gt__(self, other: object) -> bool:
11631178
if not isinstance(other, UPath) or self.parser is not other.parser:
11641179
return NotImplemented
1165-
return self.path > other.path
1180+
return self.__vfspath__() > other.__vfspath__()
11661181

11671182
def __ge__(self, other: object) -> bool:
11681183
if not isinstance(other, UPath) or self.parser is not other.parser:
11691184
return NotImplemented
1170-
return self.path >= other.path
1185+
return self.__vfspath__() >= other.__vfspath__()
11711186

11721187
def resolve(self, strict: bool = False) -> Self:
11731188
if self._relative_base is not None:

upath/implementations/cloud.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def root(self) -> str:
6060
return ""
6161
return self.parser.sep
6262

63-
def __vfspath__(self) -> str:
64-
path = super().__vfspath__()
63+
def __str__(self) -> str:
64+
path = super().__str__()
6565
if self._relative_base is None:
6666
drive = self.parser.splitdrive(path)[0]
6767
if drive and path == f"{self.protocol}://{drive}":
@@ -71,7 +71,11 @@ def __vfspath__(self) -> str:
7171
@property
7272
def path(self) -> str:
7373
self_path = super().path
74-
if self._relative_base is None and self.parser.sep not in self_path:
74+
if (
75+
self._relative_base is None
76+
and self_path
77+
and self.parser.sep not in self_path
78+
):
7579
return self_path + self.root
7680
return self_path
7781

upath/tests/cases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def test_as_uri(self):
512512
# test that we can reconstruct the path from the uri
513513
p0 = self.path
514514
uri = p0.as_uri()
515-
p1 = UPath(uri, **p0.fs.storage_options)
515+
p1 = UPath(uri, **p0.storage_options)
516516
assert p0 == p1
517517

518518
def test_protocol(self):

upath/tests/implementations/test_tar.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,19 @@ def test_move_memory(self, clear_fsspec_memory_cache):
8282
@pytest.mark.skip(reason="Only testing read on TarPath")
8383
def test_move_into_memory(self, clear_fsspec_memory_cache):
8484
pass
85+
86+
87+
@pytest.fixture(scope="function")
88+
def tarred_testdir_file_in_memory(tarred_testdir_file, clear_fsspec_memory_cache):
89+
p = UPath(tarred_testdir_file, protocol="file")
90+
t = p.move(UPath("memory:///mytarfile.tar"))
91+
assert t.protocol == "memory"
92+
assert t.exists()
93+
yield t.as_uri()
94+
95+
96+
class TestChainedTarPath(TestTarPath):
97+
98+
@pytest.fixture(autouse=True)
99+
def path(self, tarred_testdir_file_in_memory):
100+
self.path = UPath("tar://::memory:///mytarfile.tar")

upath/tests/implementations/test_zip.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,25 @@ def test_write_text(self):
145145
@pytest.mark.skip(reason="fsspec zipfile filesystem is either read xor write mode")
146146
def test_fsspec_compat(self):
147147
pass
148+
149+
150+
@pytest.fixture(scope="function")
151+
def zipped_testdir_file_in_memory(zipped_testdir_file, clear_fsspec_memory_cache):
152+
p = UPath(zipped_testdir_file, protocol="file")
153+
t = p.move(UPath("memory:///myzipfile.zip"))
154+
assert t.protocol == "memory"
155+
assert t.exists()
156+
yield t.as_uri()
157+
158+
159+
class TestChainedZipPath(TestZipPath):
160+
161+
@pytest.fixture(autouse=True)
162+
def path(self, zipped_testdir_file_in_memory, request):
163+
try:
164+
(mode,) = request.param
165+
except (ValueError, TypeError, AttributeError):
166+
mode = "r"
167+
self.path = UPath(
168+
"zip://", fo="/myzipfile.zip", mode=mode, target_protocol="memory"
169+
)

upath/tests/test_chain.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from fsspec.implementations.memory import MemoryFileSystem
66

77
from upath import UPath
8+
from upath._chain import FSSpecChainParser
89

910

1011
@pytest.mark.parametrize(
@@ -115,3 +116,21 @@ def test_write_file(clear_memory_fs):
115116
pth = UPath("simplecache::memory://abc.txt")
116117
pth.write_bytes(b"hello world")
117118
assert clear_memory_fs.cat_file("abc.txt") == b"hello world"
119+
120+
121+
@pytest.mark.parametrize(
122+
"urlpath",
123+
[
124+
"memory:///file.txt",
125+
"simplecache::memory:///tmp",
126+
"zip://file.txt::memory:///tmp.zip",
127+
"zip://a/b/c.txt::simplecache::memory:///zipfile.zip",
128+
"simplecache::zip://a/b/c.txt::tar://blah.zip::memory:///file.tar",
129+
],
130+
)
131+
def test_chain_parser_roundtrip(urlpath: str):
132+
parser = FSSpecChainParser()
133+
segments = parser.unchain(urlpath, {})
134+
rechained, kw = parser.chain(segments)
135+
assert rechained == urlpath
136+
assert kw == {}

0 commit comments

Comments
 (0)