diff --git a/fsspec/implementations/tests/test_zip.py b/fsspec/implementations/tests/test_zip.py index 59559c0b9..74696ee2c 100644 --- a/fsspec/implementations/tests/test_zip.py +++ b/fsspec/implementations/tests/test_zip.py @@ -1,5 +1,6 @@ import collections.abc import os.path +import zipfile from pathlib import Path from shutil import make_archive @@ -161,6 +162,17 @@ def zip_file_fixture(tmp_path): return Path(make_archive(zip_file, "zip", data_dir)) +@pytest.fixture(name="zip_file2") +def zip_file_fixture2(tmp_path): + file_path = tmp_path / "zip_file2.zip" + + with zipfile.ZipFile(file_path, "w") as z: + z.writestr("a/b/c", "") + z.writestr("a/b/d/e", "") + + return file_path + + def _assert_all_except_context_dependent_variables(result, expected_result): for path in expected_result: assert result[path] @@ -480,3 +492,43 @@ def test_find_returns_expected_result_recursion_depth_set(zip_file): ] assert result == expected_result + + +@pytest.mark.parametrize( + "args,expected_result", + [ + pytest.param( + ("a/b", 1), + ["a/b/c"], + id="find-maxdepth-correct-depth", + ), + pytest.param( + ("a/b", None, True), + ["a/b", "a/b/c", "a/b/d", "a/b/d/e"], + id="find-withdirs-should-not-include-parents", + ), + pytest.param( + ("a/b", 1, True), + ["a/b", "a/b/c", "a/b/d"], + id="find-withdirs-maxdepth", + ), + pytest.param( + ("/a//b///", 1, True), + ["a/b", "a/b/c", "a/b/d"], + id="find-ill-formed-path", + ), + pytest.param( + ("\\a\\\\b\\", 1, True), + ["a/b", "a/b/c", "a/b/d"], + id="find-ill-formed-path-windows", + ), + pytest.param( + (Path("\\a\\\\b\\"), 1, True), + ["a/b", "a/b/c", "a/b/d"], + id="find-using-pathobj", + ), + ], +) +def test_find_generic(zip_file2, args, expected_result): + zip_file_system = ZipFileSystem(zip_file2) + assert zip_file_system.find(*args) == expected_result diff --git a/fsspec/implementations/zip.py b/fsspec/implementations/zip.py index 6db3ae278..d55d54665 100644 --- a/fsspec/implementations/zip.py +++ b/fsspec/implementations/zip.py @@ -138,14 +138,17 @@ def find(self, path, maxdepth=None, withdirs=False, detail=False, **kwargs): if maxdepth is not None and maxdepth < 1: raise ValueError("maxdepth must be at least 1") + def to_parts(_path: str): + return list(filter(None, _path.replace("\\", "/").split("/"))) + + if not isinstance(path, str): + path = str(path) + # Remove the leading slash, as the zip file paths are always # given without a leading slash path = path.lstrip("/") - path_parts = list(filter(lambda s: bool(s), path.split("/"))) - - def _matching_starts(file_path): - file_parts = filter(lambda s: bool(s), file_path.split("/")) - return all(a == b for a, b in zip(path_parts, file_parts)) + path_parts = to_parts(path) + path_depth = len(path_parts) self._get_dirs() @@ -157,21 +160,22 @@ def _matching_starts(file_path): return result if detail else [path] for file_path, file_info in self.dir_cache.items(): - if not (path == "" or _matching_starts(file_path)): + if len(file_parts := to_parts(file_path)) < path_depth or any( + a != b for a, b in zip(path_parts, file_parts) + ): + # skip parent folders and mismatching paths continue if file_info["type"] == "directory": - if withdirs: - if file_path not in result: - result[file_path.strip("/")] = file_info + if withdirs and file_path not in result: + result[file_path.strip("/")] = file_info continue if file_path not in result: result[file_path] = file_info if detail else None if maxdepth: - path_depth = path.count("/") result = { - k: v for k, v in result.items() if k.count("/") - path_depth < maxdepth + k: v for k, v in result.items() if k.count("/") < maxdepth + path_depth } return result if detail else sorted(result)