diff --git a/avex/utils/utils.py b/avex/utils/utils.py index c201771..61e9365 100644 --- a/avex/utils/utils.py +++ b/avex/utils/utils.py @@ -6,7 +6,9 @@ from __future__ import annotations +import hashlib import io +import json import logging import os import tempfile @@ -19,6 +21,153 @@ logger = logging.getLogger(__name__) +CACHE_META_SUFFIX = ".avex_cache_meta.json" + + +def _cache_meta_path(cache_path: Path) -> Path: + """Return the sidecar metadata path for a cached file. + + Returns + ------- + Path + Sidecar metadata path for ``cache_path``. + """ + return cache_path.with_name(cache_path.name + CACHE_META_SUFFIX) + + +def _read_cache_meta(cache_path: Path) -> dict[str, Any] | None: # noqa: ANN401 + """Read cache sidecar metadata if present. + + Returns + ------- + dict[str, Any] | None + Parsed metadata if the sidecar file exists and can be read, otherwise + ``None``. + """ + meta_path = _cache_meta_path(cache_path) + if not meta_path.exists(): + return None + try: + return json.loads(meta_path.read_text(encoding="utf-8")) + except Exception as e: # pragma: no cover - defensive + logger.debug("Failed to read cache meta %s: %s", meta_path, e) + return None + + +def _write_cache_meta(cache_path: Path, meta: dict[str, Any]) -> None: # noqa: ANN401 + """Write cache sidecar metadata.""" + meta_path = _cache_meta_path(cache_path) + try: + meta_path.write_text(json.dumps(meta, sort_keys=True), encoding="utf-8") + except Exception as e: # pragma: no cover - defensive + logger.debug("Failed to write cache meta %s: %s", meta_path, e) + + +def _remote_version_token(fs: Any, path: AnyPathT) -> str | None: # noqa: ANN401 + """Best-effort remote version token without downloading. + + Uses fsspec filesystem metadata (`fs.info`) when available. Different backends + expose different fields; we normalize a token from whichever stable identifiers + are present (e.g. etag, md5/crc32c/sha256, generation/versionId). + + Returns + ------- + str | None + A stable-ish version token derived from remote metadata when available, + otherwise ``None``. + """ + try: + info = fs.info(str(path)) + except Exception as e: # pragma: no cover - defensive + logger.debug("Failed to stat remote path %s: %s", path, e) + return None + + if not isinstance(info, dict): + return None + + candidates: list[str] = [] + for key in ( + "etag", + "ETag", + "md5", + "md5Hash", + "crc32c", + "sha256", + "generation", + "versionId", + "last_modified", + "mtime", + "size", + ): + if key in info and info[key] is not None: + candidates.append(f"{key}={info[key]}") + + if not candidates: + return None + + return "|".join(candidates) + + +def _download_atomically(fs: Any, src: AnyPathT, dst: Path) -> None: # noqa: ANN401 + """Download `src` to `dst` via a temp file + atomic rename. + + This avoids TOCTOU races (check-then-download) and prevents partially written + cache files from being treated as valid on subsequent runs. + """ + dst.parent.mkdir(parents=True, exist_ok=True) + + with tempfile.NamedTemporaryFile( + dir=dst.parent, + delete=False, + suffix=".tmp", + ) as tmp: + tmp_path = Path(tmp.name) + + try: + fs.get(str(src), str(tmp_path)) + tmp_path.replace(dst) # atomic on POSIX (same filesystem) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception as e: # pragma: no cover - defensive + logger.debug("Failed to clean up temp file %s: %s", tmp_path, e) + raise + + +def _safe_cache_path(cache_root: Path, bucket: str, filename: str) -> Path: + """Build a cache path without trusting user-controlled segments. + + Parameters + ---------- + cache_root: + Root directory for cached files. + bucket: + Bucket / repo identifier from the URI. + filename: + Target cached filename. + + Returns + ------- + Path + A cache path guaranteed (after resolve) to live under ``cache_root``. + + Raises + ------ + ValueError + If the resolved cache path would fall outside ``cache_root``. + """ + bucket_digest = hashlib.sha256(bucket.encode("utf-8")).hexdigest()[:16] + candidate = cache_root / bucket_digest / filename + + # Defense-in-depth: ensure the resolved path stays within cache_root. + root_resolved = cache_root.resolve() + candidate_resolved = candidate.resolve() + if root_resolved not in candidate_resolved.parents and candidate_resolved != root_resolved: + msg = f"Unsafe cache path resolved outside cache_root: {candidate_resolved}" + raise ValueError(msg) + + return candidate + def _get_local_path_for_cloud_file( path: AnyPathT, @@ -49,21 +198,74 @@ def _get_local_path_for_cloud_file( if is_cloud_path: if cache_mode in ["use", "force"]: - # Use cache - if "ESP_CACHE_HOME" in os.environ: - cache_path = Path(os.environ["ESP_CACHE_HOME"]) / path.name - else: - cache_path = Path.home() / ".cache" / "esp" / path.name + cache_root = ( + Path(os.environ["ESP_CACHE_HOME"]) if "ESP_CACHE_HOME" in os.environ else Path.home() / ".cache" / "esp" + ) + + # Avoid collisions across repos / buckets / nested paths by hashing the full URI. + # This keeps caching predictable even if different repos share the same filename. + uri = str(path) + digest = hashlib.sha256(uri.encode("utf-8")).hexdigest()[:16] + suffix = Path(path.name).suffix + cached_name = f"{Path(path.name).stem}-{digest}{suffix}" + + cache_path = _safe_cache_path(cache_root, path.bucket, cached_name) if not cache_path.exists() or cache_mode == "force": download_msg = ( "Force downloading" if cache_mode == "force" else "Cache file does not exist, downloading" ) logger.info(f"{download_msg} to {cache_path}...") - cache_path.parent.mkdir(parents=True, exist_ok=True) - fs.get(str(path), str(cache_path)) + try: + _download_atomically(fs, path, cache_path) + except (OSError, PermissionError) as e: + logger.warning( + "Caching is enabled but cache directory is not writable (%s); " + "falling back to direct cloud read for %s.", + e, + path, + ) + return None + token = _remote_version_token(fs, path) + if token is not None: + _write_cache_meta( + cache_path, + {"remote_version_token": token, "source_uri": str(path)}, + ) else: - logger.debug(f"Found {cache_path}, using local cache.") + # Cache exists. If we can cheaply validate remote metadata, do so. + cached_meta = _read_cache_meta(cache_path) + cached_token = cached_meta.get("remote_version_token") if isinstance(cached_meta, dict) else None + remote_token = _remote_version_token(fs, path) + + if cached_token is None or remote_token is None: + logger.info( + "Cannot validate cache for %s (cached_token=%s, remote_token=%s); using local cache.", + path, + cached_token is not None, + remote_token is not None, + ) + elif cached_token != remote_token: + logger.info( + "Remote object changed for %s; re-downloading to refresh cache.", + path, + ) + try: + _download_atomically(fs, path, cache_path) + except (OSError, PermissionError) as e: + logger.warning( + "Caching is enabled but cache directory is not writable (%s); " + "falling back to direct cloud read for %s.", + e, + path, + ) + return None + _write_cache_meta( + cache_path, + {"remote_version_token": remote_token, "source_uri": str(path)}, + ) + else: + logger.debug(f"Found {cache_path}, using local cache.") return cache_path else: # No caching - return None to indicate direct cloud read @@ -194,7 +396,7 @@ def _load_safetensor( def universal_torch_load( f: str | os.PathLike | AnyPathT, *, - cache_mode: Literal["none", "use", "force"] = "none", + cache_mode: Literal["none", "use", "force"] = "use", **kwargs: Any, # noqa: ANN401 ) -> Any: # noqa: ANN401 """ @@ -220,7 +422,7 @@ def universal_torch_load( "none": No caching (use cloud storage directly) "use": Use cache if available, download if not "force": Force redownload even if cache exists - Defaults to "none". + Defaults to "use". **kwargs: Additional keyword arguments passed to torch.load() or safetensors.torch.load_file(). For safetensors, supports 'device' parameter to load tensors on specific device. diff --git a/tests/unittests/test_cache_validation.py b/tests/unittests/test_cache_validation.py new file mode 100644 index 0000000..29be8aa --- /dev/null +++ b/tests/unittests/test_cache_validation.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import pytest + +from avex.io.paths import PureGSPath +from avex.utils.utils import _get_local_path_for_cloud_file + + +class FakeFS: + """Minimal fsspec-like FS for cache validation tests.""" + + def __init__(self, *, token: str) -> None: + self._token = token + self.get_calls: list[tuple[str, str]] = [] + + def info(self, _path: str) -> dict[str, Any]: + return {"etag": self._token, "size": 123} + + def get(self, src: str, dst: str) -> None: + self.get_calls.append((src, dst)) + Path(dst).write_bytes(b"dummy") + + def set_token(self, token: str) -> None: + self._token = token + + +def test_cache_mode_none_returns_none(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path)) + fs = FakeFS(token="t1") + path = PureGSPath("gs://bucket/file.pt") + + out = _get_local_path_for_cloud_file(path, fs, "none") + + assert out is None + assert fs.get_calls == [] + + +def test_cache_use_downloads_then_reuses_when_token_same(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path)) + fs = FakeFS(token="t1") + path = PureGSPath("gs://bucket/file.pt") + + p1 = _get_local_path_for_cloud_file(path, fs, "use") + assert p1 is not None and p1.exists() + assert len(fs.get_calls) == 1 + + p2 = _get_local_path_for_cloud_file(path, fs, "use") + assert p2 == p1 + assert len(fs.get_calls) == 1 # no re-download + + +def test_cache_use_redownloads_when_remote_token_changes(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path)) + fs = FakeFS(token="t1") + path = PureGSPath("gs://bucket/file.pt") + + p1 = _get_local_path_for_cloud_file(path, fs, "use") + assert p1 is not None and p1.exists() + assert len(fs.get_calls) == 1 + + fs.set_token("t2") + p2 = _get_local_path_for_cloud_file(path, fs, "use") + assert p2 == p1 + assert len(fs.get_calls) == 2 # refreshed due to token mismatch + + +def test_cache_force_always_redownloads(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path)) + fs = FakeFS(token="t1") + path = PureGSPath("gs://bucket/file.pt") + + p1 = _get_local_path_for_cloud_file(path, fs, "force") + assert p1 is not None and p1.exists() + assert len(fs.get_calls) == 1 + + p2 = _get_local_path_for_cloud_file(path, fs, "force") + assert p2 == p1 + assert len(fs.get_calls) == 2 + + +def test_failed_download_does_not_leave_corrupt_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path)) + + class FailingFS(FakeFS): + def get(self, src: str, dst: str) -> None: # type: ignore[override] + self.get_calls.append((src, dst)) + # Simulate a partial write then failure. + Path(dst).write_bytes(b"partial") + raise RuntimeError("network error") + + fs = FailingFS(token="t1") + path = PureGSPath("gs://bucket/file.pt") + + with pytest.raises(RuntimeError, match="network error"): + _ = _get_local_path_for_cloud_file(path, fs, "use") + + # Final cache file should not exist (atomic rename prevents corrupt cache). + # (Directory name is hashed; just ensure no completed cache artifact exists.) + assert not any(p.is_file() and p.suffix != ".tmp" for p in tmp_path.rglob("*")) + + +def test_bucket_is_hashed_in_cache_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path)) + fs = FakeFS(token="t1") + # Crafted "bucket" that could be problematic if used directly. + path = PureGSPath("gs://../file.pt") + + out = _get_local_path_for_cloud_file(path, fs, "use") + + assert out is not None + # Should be cached under a hash directory, not "..". + assert ".." not in out.parts + assert out.resolve().is_relative_to(tmp_path.resolve()) + + +def test_cache_unwritable_falls_back_to_direct_read(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + # Make cache root non-writable. + cache_root = tmp_path / "cache" + cache_root.mkdir() + cache_root.chmod(0o500) # read/execute only + monkeypatch.setenv("ESP_CACHE_HOME", str(cache_root)) + + fs = FakeFS(token="t1") + path = PureGSPath("gs://bucket/file.pt") + + out = _get_local_path_for_cloud_file(path, fs, "use") + assert out is None