Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/stages/deduplication/semantic/test_pairwise_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,37 @@ def test_process_finds_all_centroid_files(self, tmp_path: Path):
assert result[2].task_id == "pairwise_centroid_2"
assert result[2]._metadata == {"centroid_id": 2, "filetype": "parquet"}
assert result[2].data == [str(centroid_2_dir / "file4.parquet"), str(centroid_2_dir / "file5.parquet")]

def test_process_restores_protocol_for_remote_listings(self):
"""Remote fsspec listings may strip protocols; tasks should keep full URLs."""

class FakeRemoteFs:
def unstrip_protocol(self, path: str) -> str:
return path if path.startswith("gs://") else f"gs://{path}"

def ls(self, _path: str) -> list[str]:
return ["bucket/kmeans/centroid=7"]

def expand_path(self, path: str, recursive: bool = False) -> list[str]:
assert recursive is False
return [path]

def isdir(self, _path: str) -> bool:
return True

def find(self, path: str, maxdepth: int | None, withdirs: bool, detail: bool) -> list[str]:
assert path == "gs://bucket/kmeans/centroid=7"
assert maxdepth == 1
assert withdirs is False
assert detail is False
return ["bucket/kmeans/centroid=7/part.0.parquet"]
Comment on lines +131 to +136

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 maxdepth == 1 assertion is wrong and will cause the test to fail. process() calls get_all_file_paths_under(..., recurse_subdirectories=True, ...), and inside _gather_file_records the depth is computed as maxdepth=None if recurse_subdirectories else 1 — so with recurse_subdirectories=True the value passed to find is None, not 1.

Suggested change
def find(self, path: str, maxdepth: int | None, withdirs: bool, detail: bool) -> list[str]:
assert path == "gs://bucket/kmeans/centroid=7"
assert maxdepth == 1
assert withdirs is False
assert detail is False
return ["bucket/kmeans/centroid=7/part.0.parquet"]
def find(self, path: str, maxdepth: int | None, withdirs: bool, detail: bool) -> list[str]:
assert path == "gs://bucket/kmeans/centroid=7"
assert maxdepth is None
assert withdirs is False
assert detail is False
return ["bucket/kmeans/centroid=7/part.0.parquet"]


stage = ClusterWiseFilePartitioningStage("gs://bucket/kmeans")
stage.fs = FakeRemoteFs()
stage.path_normalizer = stage.fs.unstrip_protocol

empty_task = _EmptyTask(task_id="test", dataset_name="test", data=None)
result = stage.process(empty_task)

assert len(result) == 1
assert result[0].data == ["gs://bucket/kmeans/centroid=7/part.0.parquet"]