Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 212 additions & 10 deletions avex/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from __future__ import annotations

import hashlib
import io
import json
import logging
import os
import tempfile
Expand All @@ -19,6 +21,153 @@

logger = logging.getLogger(__name__)

CACHE_META_SUFFIX = ".avex_cache_meta.json"


def _cache_meta_path(cache_path: Path) -> Path:
"""Return the sidecar metadata path for a cached file.

Returns
-------
Path
Sidecar metadata path for ``cache_path``.
"""
return cache_path.with_name(cache_path.name + CACHE_META_SUFFIX)


def _read_cache_meta(cache_path: Path) -> dict[str, Any] | None:  # noqa: ANN401
    """Read cache sidecar metadata if present.

    Uses EAFP (attempt the read, handle the missing-file case) instead of an
    ``exists()``-then-read check, which avoids a TOCTOU race if the sidecar
    is removed between the check and the read.

    Returns
    -------
    dict[str, Any] | None
        Parsed metadata if the sidecar file exists and can be read, otherwise
        ``None``.
    """
    meta_path = _cache_meta_path(cache_path)
    try:
        return json.loads(meta_path.read_text(encoding="utf-8"))
    except FileNotFoundError:
        # Missing sidecar is the common, expected case: not worth logging.
        return None
    except Exception as e:  # pragma: no cover - defensive
        logger.debug("Failed to read cache meta %s: %s", meta_path, e)
        return None


def _write_cache_meta(cache_path: Path, meta: dict[str, Any]) -> None:  # noqa: ANN401
    """Persist cache sidecar metadata next to the cached file.

    Failures are logged at debug level and otherwise swallowed: the sidecar
    is an optimization, never a hard requirement.
    """
    sidecar = _cache_meta_path(cache_path)
    try:
        serialized = json.dumps(meta, sort_keys=True)
        sidecar.write_text(serialized, encoding="utf-8")
    except Exception as err:  # pragma: no cover - defensive
        logger.debug("Failed to write cache meta %s: %s", sidecar, err)


def _remote_version_token(fs: Any, path: AnyPathT) -> str | None: # noqa: ANN401
"""Best-effort remote version token without downloading.

Uses fsspec filesystem metadata (`fs.info`) when available. Different backends
expose different fields; we normalize a token from whichever stable identifiers
are present (e.g. etag, md5/crc32c/sha256, generation/versionId).

Returns
-------
str | None
A stable-ish version token derived from remote metadata when available,
otherwise ``None``.
"""
try:
info = fs.info(str(path))
except Exception as e: # pragma: no cover - defensive
logger.debug("Failed to stat remote path %s: %s", path, e)
return None

if not isinstance(info, dict):
return None

candidates: list[str] = []
for key in (
"etag",
"ETag",
"md5",
"md5Hash",
"crc32c",
"sha256",
"generation",
"versionId",
"last_modified",
"mtime",
"size",
):
if key in info and info[key] is not None:
candidates.append(f"{key}={info[key]}")

if not candidates:
return None

return "|".join(candidates)


def _download_atomically(fs: Any, src: AnyPathT, dst: Path) -> None: # noqa: ANN401
"""Download `src` to `dst` via a temp file + atomic rename.

This avoids TOCTOU races (check-then-download) and prevents partially written
cache files from being treated as valid on subsequent runs.
"""
dst.parent.mkdir(parents=True, exist_ok=True)

with tempfile.NamedTemporaryFile(
dir=dst.parent,
delete=False,
suffix=".tmp",
) as tmp:
tmp_path = Path(tmp.name)

try:
fs.get(str(src), str(tmp_path))
tmp_path.replace(dst) # atomic on POSIX (same filesystem)
except Exception:
try:
tmp_path.unlink(missing_ok=True)
except Exception as e: # pragma: no cover - defensive
logger.debug("Failed to clean up temp file %s: %s", tmp_path, e)
raise


def _safe_cache_path(cache_root: Path, bucket: str, filename: str) -> Path:
"""Build a cache path without trusting user-controlled segments.

Parameters
----------
cache_root:
Root directory for cached files.
bucket:
Bucket / repo identifier from the URI.
filename:
Target cached filename.

Returns
-------
Path
A cache path guaranteed (after resolve) to live under ``cache_root``.

Raises
------
ValueError
If the resolved cache path would fall outside ``cache_root``.
"""
bucket_digest = hashlib.sha256(bucket.encode("utf-8")).hexdigest()[:16]
candidate = cache_root / bucket_digest / filename

# Defense-in-depth: ensure the resolved path stays within cache_root.
root_resolved = cache_root.resolve()
candidate_resolved = candidate.resolve()
if root_resolved not in candidate_resolved.parents and candidate_resolved != root_resolved:
msg = f"Unsafe cache path resolved outside cache_root: {candidate_resolved}"
raise ValueError(msg)

return candidate


def _get_local_path_for_cloud_file(
path: AnyPathT,
Expand Down Expand Up @@ -49,21 +198,74 @@ def _get_local_path_for_cloud_file(

if is_cloud_path:
if cache_mode in ["use", "force"]:
# Use cache
if "ESP_CACHE_HOME" in os.environ:
cache_path = Path(os.environ["ESP_CACHE_HOME"]) / path.name
else:
cache_path = Path.home() / ".cache" / "esp" / path.name
cache_root = (
Path(os.environ["ESP_CACHE_HOME"]) if "ESP_CACHE_HOME" in os.environ else Path.home() / ".cache" / "esp"
)

# Avoid collisions across repos / buckets / nested paths by hashing the full URI.
# This keeps caching predictable even if different repos share the same filename.
uri = str(path)
digest = hashlib.sha256(uri.encode("utf-8")).hexdigest()[:16]
suffix = Path(path.name).suffix
cached_name = f"{Path(path.name).stem}-{digest}{suffix}"

cache_path = _safe_cache_path(cache_root, path.bucket, cached_name)

if not cache_path.exists() or cache_mode == "force":
download_msg = (
"Force downloading" if cache_mode == "force" else "Cache file does not exist, downloading"
)
logger.info(f"{download_msg} to {cache_path}...")
cache_path.parent.mkdir(parents=True, exist_ok=True)
fs.get(str(path), str(cache_path))
try:
_download_atomically(fs, path, cache_path)
except (OSError, PermissionError) as e:
logger.warning(
"Caching is enabled but cache directory is not writable (%s); "
"falling back to direct cloud read for %s.",
e,
path,
)
return None
token = _remote_version_token(fs, path)
if token is not None:
_write_cache_meta(
cache_path,
{"remote_version_token": token, "source_uri": str(path)},
)
else:
logger.debug(f"Found {cache_path}, using local cache.")
# Cache exists. If we can cheaply validate remote metadata, do so.
cached_meta = _read_cache_meta(cache_path)
cached_token = cached_meta.get("remote_version_token") if isinstance(cached_meta, dict) else None
remote_token = _remote_version_token(fs, path)

if cached_token is None or remote_token is None:
logger.info(
"Cannot validate cache for %s (cached_token=%s, remote_token=%s); using local cache.",
path,
cached_token is not None,
remote_token is not None,
)
elif cached_token != remote_token:
logger.info(
"Remote object changed for %s; re-downloading to refresh cache.",
path,
)
try:
_download_atomically(fs, path, cache_path)
except (OSError, PermissionError) as e:
logger.warning(
"Caching is enabled but cache directory is not writable (%s); "
"falling back to direct cloud read for %s.",
e,
path,
)
return None
_write_cache_meta(
cache_path,
{"remote_version_token": remote_token, "source_uri": str(path)},
)
else:
logger.debug(f"Found {cache_path}, using local cache.")
return cache_path
else:
# No caching - return None to indicate direct cloud read
Expand Down Expand Up @@ -194,7 +396,7 @@ def _load_safetensor(
def universal_torch_load(
f: str | os.PathLike | AnyPathT,
*,
cache_mode: Literal["none", "use", "force"] = "none",
cache_mode: Literal["none", "use", "force"] = "use",
**kwargs: Any, # noqa: ANN401
) -> Any: # noqa: ANN401
"""
Expand All @@ -220,7 +422,7 @@ def universal_torch_load(
"none": No caching (use cloud storage directly)
"use": Use cache if available, download if not
"force": Force redownload even if cache exists
Defaults to "none".
Defaults to "use".
**kwargs: Additional keyword arguments passed to torch.load() or safetensors.torch.load_file().
For safetensors, supports 'device' parameter to load tensors on specific device.

Expand Down
130 changes: 130 additions & 0 deletions tests/unittests/test_cache_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

import pytest

from avex.io.paths import PureGSPath
from avex.utils.utils import _get_local_path_for_cloud_file


class FakeFS:
    """Minimal fsspec-like FS stub for cache validation tests.

    Records every ``get`` call and serves ``info`` metadata whose etag is the
    current token, so tests can flip the token to simulate remote changes.
    """

    def __init__(self, *, token: str) -> None:
        # Current "remote" version token, surfaced via info().
        self._token = token
        # (src, dst) pairs for every download request, in call order.
        self.get_calls: list[tuple[str, str]] = []

    def info(self, _path: str) -> dict[str, Any]:
        """Return fsspec-style metadata with a stable etag."""
        return {"etag": self._token, "size": 123}

    def get(self, src: str, dst: str) -> None:
        """Pretend to download by writing fixed bytes to ``dst``."""
        self.get_calls.append((src, dst))
        with open(dst, "wb") as fh:
            fh.write(b"dummy")

    def set_token(self, token: str) -> None:
        """Simulate the remote object changing versions."""
        self._token = token


def test_cache_mode_none_returns_none(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """cache_mode='none' must bypass the cache entirely and download nothing."""
    monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path))
    fake_fs = FakeFS(token="t1")
    remote = PureGSPath("gs://bucket/file.pt")

    result = _get_local_path_for_cloud_file(remote, fake_fs, "none")

    assert result is None
    assert not fake_fs.get_calls


def test_cache_use_downloads_then_reuses_when_token_same(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """With an unchanged remote token, the second 'use' call is a cache hit."""
    monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path))
    fake_fs = FakeFS(token="t1")
    remote = PureGSPath("gs://bucket/file.pt")

    first = _get_local_path_for_cloud_file(remote, fake_fs, "use")
    assert first is not None
    assert first.exists()
    assert len(fake_fs.get_calls) == 1

    second = _get_local_path_for_cloud_file(remote, fake_fs, "use")
    assert second == first
    assert len(fake_fs.get_calls) == 1  # cache hit: no second download


def test_cache_use_redownloads_when_remote_token_changes(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """A changed remote token must trigger a cache refresh on the next 'use'."""
    monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path))
    fake_fs = FakeFS(token="t1")
    remote = PureGSPath("gs://bucket/file.pt")

    first = _get_local_path_for_cloud_file(remote, fake_fs, "use")
    assert first is not None
    assert first.exists()
    assert len(fake_fs.get_calls) == 1

    fake_fs.set_token("t2")
    second = _get_local_path_for_cloud_file(remote, fake_fs, "use")
    assert second == first
    assert len(fake_fs.get_calls) == 2  # refreshed due to token mismatch


def test_cache_force_always_redownloads(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """cache_mode='force' must re-download even when the cache file exists."""
    monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path))
    fake_fs = FakeFS(token="t1")
    remote = PureGSPath("gs://bucket/file.pt")

    first = _get_local_path_for_cloud_file(remote, fake_fs, "force")
    assert first is not None
    assert first.exists()
    assert len(fake_fs.get_calls) == 1

    second = _get_local_path_for_cloud_file(remote, fake_fs, "force")
    assert second == first
    assert len(fake_fs.get_calls) == 2


def test_failed_download_does_not_leave_corrupt_cache(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """A failing download must not leave a completed cache artifact behind."""
    monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path))

    class FailingFS(FakeFS):
        """FakeFS whose downloads write partial data and then raise."""

        def get(self, src: str, dst: str) -> None:  # type: ignore[override]
            self.get_calls.append((src, dst))
            Path(dst).write_bytes(b"partial")  # simulate an interrupted transfer
            raise RuntimeError("network error")

    failing_fs = FailingFS(token="t1")
    remote = PureGSPath("gs://bucket/file.pt")

    with pytest.raises(RuntimeError, match="network error"):
        _get_local_path_for_cloud_file(remote, failing_fs, "use")

    # Atomic rename means no completed (non-.tmp) cache artifact may remain.
    # (Directory name is hashed, so scan the whole cache root.)
    leftovers = [p for p in tmp_path.rglob("*") if p.is_file() and p.suffix != ".tmp"]
    assert not leftovers


def test_bucket_is_hashed_in_cache_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """A hostile bucket segment must never surface literally in the cache path."""
    monkeypatch.setenv("ESP_CACHE_HOME", str(tmp_path))
    fake_fs = FakeFS(token="t1")
    # Bucket that would escape the cache root if used verbatim.
    remote = PureGSPath("gs://../file.pt")

    cached = _get_local_path_for_cloud_file(remote, fake_fs, "use")

    assert cached is not None
    assert ".." not in cached.parts  # bucket was hashed, not embedded
    assert cached.resolve().is_relative_to(tmp_path.resolve())


def test_cache_unwritable_falls_back_to_direct_read(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """An unwritable cache dir should degrade to a direct cloud read (None)."""
    import os

    # chmod-based permission checks are a no-op for root and on non-POSIX
    # platforms, which would make this test silently assert the wrong thing.
    if os.name != "posix" or os.geteuid() == 0:
        pytest.skip("requires POSIX permissions and a non-root user")

    # Make cache root non-writable.
    cache_root = tmp_path / "cache"
    cache_root.mkdir()
    cache_root.chmod(0o500)  # read/execute only
    monkeypatch.setenv("ESP_CACHE_HOME", str(cache_root))

    try:
        fs = FakeFS(token="t1")
        path = PureGSPath("gs://bucket/file.pt")

        out = _get_local_path_for_cloud_file(path, fs, "use")
        assert out is None
    finally:
        # Restore write permission so tmp_path cleanup can remove the dir.
        cache_root.chmod(0o700)
Loading