diff --git a/tests/stages/deduplication/semantic/test_pairwise_io.py b/tests/stages/deduplication/semantic/test_pairwise_io.py index e8d16b1d4d..689463c76c 100644 --- a/tests/stages/deduplication/semantic/test_pairwise_io.py +++ b/tests/stages/deduplication/semantic/test_pairwise_io.py @@ -110,3 +110,37 @@ def test_process_finds_all_centroid_files(self, tmp_path: Path): assert result[2].task_id == "pairwise_centroid_2" assert result[2]._metadata == {"centroid_id": 2, "filetype": "parquet"} assert result[2].data == [str(centroid_2_dir / "file4.parquet"), str(centroid_2_dir / "file5.parquet")] + + def test_process_restores_protocol_for_remote_listings(self): + """Remote fsspec listings may strip protocols; tasks should keep full URLs.""" + + class FakeRemoteFs: + def unstrip_protocol(self, path: str) -> str: + return path if path.startswith("gs://") else f"gs://{path}" + + def ls(self, _path: str) -> list[str]: + return ["bucket/kmeans/centroid=7"] + + def expand_path(self, path: str, recursive: bool = False) -> list[str]: + assert recursive is False + return [path] + + def isdir(self, _path: str) -> bool: + return True + + def find(self, path: str, maxdepth: int | None, withdirs: bool, detail: bool) -> list[str]: + assert path == "gs://bucket/kmeans/centroid=7" + assert maxdepth == 1 + assert withdirs is False + assert detail is False + return ["bucket/kmeans/centroid=7/part.0.parquet"] + + stage = ClusterWiseFilePartitioningStage("gs://bucket/kmeans") + stage.fs = FakeRemoteFs() + stage.path_normalizer = stage.fs.unstrip_protocol + + empty_task = _EmptyTask(task_id="test", dataset_name="test", data=None) + result = stage.process(empty_task) + + assert len(result) == 1 + assert result[0].data == ["gs://bucket/kmeans/centroid=7/part.0.parquet"]