81 changes: 79 additions & 2 deletions src/flyte/_code_bundle/bundle.py
@@ -2,10 +2,14 @@

import asyncio
import gzip
import hashlib
import logging
import os
import pathlib
import random
import sqlite3
import tempfile
import time
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, Type

@@ -28,6 +32,56 @@
_pickled_file_extension = ".pkl.gz"
_tar_file_extension = ".tar.gz"

_BUNDLE_CACHE_TTL_DAYS = 1


def _scoped_digest(digest: str) -> str:
"""Return a digest scoped to the current endpoint/project/domain."""
from flyte._persistence._db import _cache_scope

raw = f"{_cache_scope()}:{digest}"
return hashlib.sha256(raw.encode()).hexdigest()


def _read_bundle_cache(digest: str) -> tuple[str, str] | None:
"""Look up a previously uploaded bundle by its file digest. Returns (hash_digest, remote_path) or None."""
from flyte._persistence._db import LocalDB

try:
Contributor: we should not crash if the database is not found

conn = LocalDB.get_sync()
cutoff = time.time() - _BUNDLE_CACHE_TTL_DAYS * 86400
row = conn.execute(
"SELECT hash_digest, remote_path FROM bundle_cache WHERE digest = ? AND created_at > ?",
(_scoped_digest(digest), cutoff),
).fetchone()
# Prune expired entries ~5% of the time to avoid doing it on every read
if random.random() < 0.05:
with LocalDB._write_lock:
conn.execute("DELETE FROM bundle_cache WHERE created_at <= ?", (cutoff,))
conn.commit()
if row:
return row[0], row[1]
except (OSError, sqlite3.Error) as e:
logger.debug(f"Failed to read bundle cache: {e}")
return None


def _write_bundle_cache(digest: str, hash_digest: str, remote_path: str) -> None:
"""Persist a successfully uploaded bundle to the SQLite cache."""
from flyte._persistence._db import LocalDB

try:
conn = LocalDB.get_sync()
with LocalDB._write_lock:
conn.execute(
"INSERT OR REPLACE INTO bundle_cache (digest, hash_digest, remote_path, created_at) "
"VALUES (?, ?, ?, ?)",
(_scoped_digest(digest), hash_digest, remote_path, time.time()),
)
conn.commit()
except (OSError, sqlite3.Error) as e:
logger.debug(f"Failed to write bundle cache: {e}")


class _PklCache:
_pkl_cache: ClassVar[AsyncLRUCache[str, str]] = AsyncLRUCache[str, str](maxsize=100)
@@ -125,6 +179,7 @@ async def build_code_bundle(
dryrun: bool = False,
copy_bundle_to: pathlib.Path | None = None,
copy_style: CopyFiles = "loaded_modules",
skip_cache: bool = False,
) -> CodeBundle:
"""
Build the code bundle for the current environment.
@@ -135,14 +190,13 @@
:param dryrun: If dryrun is enabled, files will not be uploaded to the control plane.
:param copy_bundle_to: If set, the bundle will be copied to this path. This is used for testing purposes.
:param copy_style: What to put into the tarball. (either all, or loaded_modules. if none, skip this function)
:param skip_cache: If true, skip the persistent SQLite cache lookup and always rebuild/re-upload.

:return: The code bundle, which contains the path where the code was zipped to.
"""
if copy_style == "none":
raise ValueError("If copy_style is 'none', just don't make a code bundle")

status.step("Bundling code...")
logger.debug("Building code bundle.")
from flyte.remote import upload_file

if not ignore:
@@ -163,6 +217,16 @@
if logger.getEffectiveLevel() <= logging.INFO:
print_ls_tree(from_dir, files)

# Check persistent cache before creating the tar bundle to avoid unnecessary work
if not dryrun and not skip_cache:
cached = _read_bundle_cache(digest)
if cached:
hash_digest, remote_path = cached
status.success("Code bundle found in cache, skipping upload")
logger.debug(f"Code bundle cache hit: {remote_path}")
return CodeBundle(tgz=remote_path, destination=extract_dir, computed_version=hash_digest, files=files)

status.step("Bundling code...")
logger.debug("Building code bundle.")
with tempfile.TemporaryDirectory() as tmp_dir:
bundle_path, tar_size, archive_size = create_bundle(
@@ -173,6 +237,7 @@
status.step("Uploading code bundle...")
hash_digest, remote_path = await upload_file.aio(bundle_path)
logger.debug(f"Code bundle uploaded to {remote_path}")
_write_bundle_cache(digest, hash_digest, remote_path)
else:
if copy_bundle_to:
remote_path = str(copy_bundle_to / bundle_path.name)
@@ -198,6 +263,7 @@ async def build_code_bundle_from_relative_paths(
extract_dir: str = ".",
dryrun: bool = False,
copy_bundle_to: pathlib.Path | None = None,
skip_cache: bool = False,
) -> CodeBundle:
"""
Build a code bundle from a list of relative paths.
@@ -207,6 +273,7 @@
working directory.
:param dryrun: If dryrun is enabled, files will not be uploaded to the control plane.
:param copy_bundle_to: If set, the bundle will be copied to this path. This is used for testing purposes.
:param skip_cache: If true, skip the persistent SQLite cache lookup and always rebuild/re-upload.
:return: The code bundle, which contains the path where the code was zipped to.
"""
status.step("Bundling code...")
@@ -218,6 +285,15 @@
if logger.getEffectiveLevel() <= logging.INFO:
print_ls_tree(from_dir, files)

# Check persistent cache before creating the tar bundle to avoid unnecessary work
if not dryrun and not skip_cache:
cached = _read_bundle_cache(digest)
if cached:
hash_digest, remote_path = cached
status.success("Code bundle found in cache, skipping upload")
logger.debug(f"Code bundle cache hit: {remote_path}")
return CodeBundle(tgz=remote_path, destination=extract_dir, computed_version=hash_digest, files=files)

logger.debug("Building code bundle.")
with tempfile.TemporaryDirectory() as tmp_dir:
bundle_path, tar_size, archive_size = create_bundle(from_dir, pathlib.Path(tmp_dir), files, digest)
@@ -226,6 +302,7 @@
status.step("Uploading code bundle...")
hash_digest, remote_path = await upload_file.aio(bundle_path)
logger.debug(f"Code bundle uploaded to {remote_path}")
_write_bundle_cache(digest, hash_digest, remote_path)
else:
remote_path = "na"
if copy_bundle_to:
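As an aside, here is a minimal self-contained sketch of the TTL pattern that _read_bundle_cache and _write_bundle_cache implement above, assuming an in-memory SQLite database with a made-up digest and remote path (the real schema lives in _db.py):

# Illustrative sketch, not part of the diff: a short TTL bounds how stale the cache can be.
import sqlite3
import time

TTL_SECONDS = 1 * 86400  # mirrors _BUNDLE_CACHE_TTL_DAYS = 1

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE bundle_cache (digest TEXT PRIMARY KEY, remote_path TEXT, created_at REAL)")
conn.execute(
    "INSERT OR REPLACE INTO bundle_cache VALUES (?, ?, ?)",
    ("abc123", "s3://bucket/bundle.tar.gz", time.time()),
)

# Reads only honor rows newer than the cutoff, so a bundle deleted from remote
# storage can be wrongly reported as cached for at most TTL_SECONDS.
cutoff = time.time() - TTL_SECONDS
row = conn.execute(
    "SELECT remote_path FROM bundle_cache WHERE digest = ? AND created_at > ?",
    ("abc123", cutoff),
).fetchone()
print(row)  # ('s3://bucket/bundle.tar.gz',)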
8 changes: 7 additions & 1 deletion src/flyte/_internal/imagebuild/docker_builder.py
@@ -40,6 +40,7 @@
ImageChecker,
LocalDockerCommandImageChecker,
LocalPodmanCommandImageChecker,
PersistentCacheImageChecker,
)
from flyte._internal.imagebuild.utils import (
copy_files_to_context,
@@ -582,7 +583,12 @@ class DockerImageBuilder(ImageBuilder):

def get_checkers(self) -> Optional[typing.List[typing.Type[ImageChecker]]]:
# Can get a public token for docker.io but ghcr requires a pat, so harder to get the manifest anonymously
return [LocalDockerCommandImageChecker, LocalPodmanCommandImageChecker, DockerAPIImageChecker]
return [
PersistentCacheImageChecker,
LocalDockerCommandImageChecker,
LocalPodmanCommandImageChecker,
DockerAPIImageChecker,
]

async def build_image(
self, image: Image, dry_run: bool = False, wait: bool = True, force: bool = False
85 changes: 82 additions & 3 deletions src/flyte/_internal/imagebuild/image_builder.py
@@ -1,7 +1,11 @@
from __future__ import annotations

import asyncio
import hashlib
import json
import random
import sqlite3
import time
import typing
from importlib.metadata import entry_points
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Tuple
@@ -13,8 +17,11 @@
from flyte._image import Architecture, Image
from flyte._initialize import _get_init_config
from flyte._logging import logger
from flyte._persistence._db import LocalDB
from flyte._status import status

_IMAGE_CACHE_TTL_DAYS = 30

if TYPE_CHECKING:
from flyte._build import ImageBuild

@@ -36,7 +43,15 @@ class ImageChecker(Protocol):
@classmethod
async def image_exists(
cls, repository: str, tag: str, arch: Tuple[Architecture, ...] = ("linux/amd64",)
) -> Optional[str]: ...
) -> Optional[str]:
"""
Check whether an image exists in a registry or cache.

Returns the image URI if found, or None if the image definitively does not exist.
Raise an exception if existence cannot be determined (e.g. cache miss, network failure)
so the next checker in the chain gets a chance.
"""
...


class DockerAPIImageChecker(ImageChecker):
@@ -93,6 +108,65 @@ async def image_exists(
return None


def _cache_key(repository: str, tag: str, arch: Tuple[str, ...]) -> str:
"""Return a stable cache key for an image, scoped to the current endpoint/project/domain."""
from flyte._persistence._db import _cache_scope

raw = f"{_cache_scope()}:{repository}:{tag}:{','.join(sorted(arch))}"
return hashlib.sha256(raw.encode()).hexdigest()


def _read_image_cache(repository: str, tag: str, arch: Tuple[str, ...]) -> Optional[str]:
"""Look up a previously verified image URI by repository, tag, and arch. Returns image_uri or None."""
try:
conn = LocalDB.get_sync()
cutoff = time.time() - _IMAGE_CACHE_TTL_DAYS * 86400
row = conn.execute(
"SELECT image_uri FROM image_cache WHERE key = ? AND created_at > ?",
(_cache_key(repository, tag, arch), cutoff),
).fetchone()
# Prune expired entries ~5% of the time to avoid doing it on every read
if random.random() < 0.05:
with LocalDB._write_lock:
conn.execute("DELETE FROM image_cache WHERE created_at <= ?", (cutoff,))
conn.commit()
if row:
return row[0]
except (OSError, sqlite3.Error) as e:
logger.debug(f"Failed to read image cache: {e}")
return None


def _write_image_cache(repository: str, tag: str, arch: Tuple[str, ...], image_uri: str) -> None:
"""Persist a verified image URI to the SQLite cache."""
try:
conn = LocalDB.get_sync()
with LocalDB._write_lock:
conn.execute(
"INSERT OR REPLACE INTO image_cache (key, image_uri, created_at) VALUES (?, ?, ?)",
(_cache_key(repository, tag, arch), image_uri, time.time()),
)
conn.commit()
except (OSError, sqlite3.Error) as e:
logger.debug(f"Failed to write image cache: {e}")


class PersistentCacheImageChecker(ImageChecker):
"""Check if image was previously verified and cached in SQLite (~0ms)."""

@classmethod
async def image_exists(
cls, repository: str, tag: str, arch: Tuple[Architecture, ...] = ("linux/amd64",)
) -> Optional[str]:
uri = _read_image_cache(repository, tag, arch)
if uri:
logger.debug(f"Image {uri} found in persistent cache")
return uri
# Cache miss — raise so the next checker in the chain gets a chance.
# Returning None would mean "image definitely doesn't exist".
raise LookupError(f"Image {repository}:{tag} not found in persistent cache")


class LocalDockerCommandImageChecker(ImageChecker):
command_name: ClassVar[str] = "docker"

@@ -174,12 +248,17 @@ async def image_exists(image: Image) -> Optional[str]:
image_uri = await checker.image_exists(repository, tag, tuple(image.platform))
if image_uri:
logger.debug(f"Image {image_uri} in registry")
return image_uri
# Persist to disk so future process invocations skip network checks
if checker is not PersistentCacheImageChecker:
Member Author: One problem is that the persistent cache never invalidates. If an image is deleted from the registry, the cache will still say it exists, and the build will be skipped.

Member Author: But it makes the UX much better.

Contributor: We can solve this by putting a very short TTL on the cache; this is why I am suggesting SQLite. The data is tiny anyway, and one row is enough.

Member Author: Short TTL sounds good to me. I'll update it to use SQLite.

Member Author: Done. Added a TTL and a cache for the code bundle.

_write_image_cache(repository, tag, tuple(image.platform), image_uri)
return image_uri
# Checker ran successfully and returned None — image not found
return None
except Exception as e:
logger.debug(f"Error checking image existence with {checker.__name__}: {e}")
continue

# If all checkers fail, then assume the image exists. This is current flytekit behavior
# All checkers raised exceptions (e.g. network failures) — assume image exists
status.info(f"All checkers failed to check existence of {image.uri}, assuming it exists")
return image.uri

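For clarity, a minimal sketch of the checker-chain contract that the ImageChecker docstring above describes, using stand-in checker classes rather than the real implementations: return a URI on a hit, return None for a definitive miss, and raise when undecided so the next checker gets a chance.

# Stand-in checkers to illustrate the fallthrough contract; not the real implementations.
import asyncio
from typing import Optional

class CacheChecker:
    @classmethod
    async def image_exists(cls, repo: str, tag: str) -> Optional[str]:
        # Cache miss: raise, because returning None would mean "definitely absent".
        raise LookupError(f"{repo}:{tag} not in cache")

class RegistryChecker:
    @classmethod
    async def image_exists(cls, repo: str, tag: str) -> Optional[str]:
        return f"{repo}:{tag}"  # pretend the registry confirmed the image

async def resolve(repo: str, tag: str) -> Optional[str]:
    for checker in (CacheChecker, RegistryChecker):
        try:
            # A hit (str) or a definitive miss (None) both end the chain.
            return await checker.image_exists(repo, tag)
        except Exception:
            continue  # undecided: fall through to the next checker
    return f"{repo}:{tag}"  # every checker was undecided: assume the image exists

print(asyncio.run(resolve("ghcr.io/org/app", "v1")))  # ghcr.io/org/app:v1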
45 changes: 39 additions & 6 deletions src/flyte/_persistence/_db.py
@@ -15,13 +15,44 @@
DEFAULT_CACHE_DIR = "~/.flyte"
CACHE_LOCATION = "local-cache/cache.db"


def _cache_scope() -> str:
"""Return a stable string identifying the current endpoint+project+domain.

Used to scope image/bundle cache entries so that different environments
don't collide.
"""
config = auto()
endpoint = config.platform.endpoint or ""
project = config.task.project or ""
domain = config.task.domain or ""
return f"{endpoint}:{project}:{domain}"


_TASK_CACHE_DDL = """
CREATE TABLE IF NOT EXISTS task_cache (
key TEXT PRIMARY KEY,
value BLOB
)
"""

_IMAGE_CACHE_DDL = """
CREATE TABLE IF NOT EXISTS image_cache (
key TEXT PRIMARY KEY,
image_uri TEXT NOT NULL,
created_at REAL NOT NULL
)
"""

_BUNDLE_CACHE_DDL = """
CREATE TABLE IF NOT EXISTS bundle_cache (
digest TEXT PRIMARY KEY,
hash_digest TEXT NOT NULL,
remote_path TEXT NOT NULL,
created_at REAL NOT NULL
)
"""

_RUNS_DDL = """
CREATE TABLE IF NOT EXISTS runs (
run_name TEXT NOT NULL,
@@ -49,6 +80,8 @@
"""


_ALL_TABLE_DDLS = [_TASK_CACHE_DDL, _RUNS_DDL, _IMAGE_CACHE_DDL, _BUNDLE_CACHE_DDL]

_RUNS_INDEXES = [
"CREATE INDEX IF NOT EXISTS idx_runs_action_start ON runs (action_name, start_time)",
"CREATE INDEX IF NOT EXISTS idx_runs_status_start ON runs (status, start_time)",
@@ -114,16 +147,16 @@ async def initialize():
async def _initialize_async():
db_path = LocalDB._get_db_path()
conn = await aiosqlite.connect(db_path)
await conn.execute(_TASK_CACHE_DDL)
await conn.execute(_RUNS_DDL)
for ddl in _ALL_TABLE_DDLS:
await conn.execute(ddl)
for idx_stmt in _RUNS_INDEXES:
await conn.execute(idx_stmt)
await conn.commit()
LocalDB._conn = conn
# Also open a sync connection for sync callers
sync_conn = sqlite3.connect(db_path, check_same_thread=False)
sync_conn.execute(_TASK_CACHE_DDL)
sync_conn.execute(_RUNS_DDL)
for ddl in _ALL_TABLE_DDLS:
sync_conn.execute(ddl)
_migrate_sync(sync_conn)
LocalDB._conn_sync = sync_conn
LocalDB._initialized = True
@@ -140,8 +173,8 @@ def initialize_sync():
def _initialize_sync_inner():
db_path = LocalDB._get_db_path()
conn = sqlite3.connect(db_path, check_same_thread=False)
conn.execute(_TASK_CACHE_DDL)
conn.execute(_RUNS_DDL)
for ddl in _ALL_TABLE_DDLS:
conn.execute(ddl)
_migrate_sync(conn)
LocalDB._conn_sync = conn
LocalDB._initialized = True
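Finally, a small sketch of what the _cache_scope prefix buys, with made-up endpoint/project/domain values: because the scope is hashed together with the digest, the same bundle digest yields different cache keys in different environments, so entries cannot collide across a dev and a prod cluster.

# Made-up values; mirrors the _scoped_digest / _cache_key construction above.
import hashlib

def scoped_key(endpoint: str, project: str, domain: str, digest: str) -> str:
    scope = f"{endpoint}:{project}:{domain}"  # what _cache_scope returns
    return hashlib.sha256(f"{scope}:{digest}".encode()).hexdigest()

digest = "abc123"
dev = scoped_key("dns:///dev.example.com", "flytesnacks", "development", digest)
prod = scoped_key("dns:///prod.example.com", "flytesnacks", "production", digest)
print(dev != prod)  # True: same digest, distinct rows per environment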