From 6bf335831de0cfe2c341a9e3cde8b54cb4cc5314 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Tue, 16 Sep 2025 18:16:02 -0700 Subject: [PATCH 1/8] wip Signed-off-by: Ketan Umare --- examples/advanced/artifact_example.py | 26 ++++++ src/flyte/artifacts/__init__.py | 35 ++++++++ src/flyte/artifacts/_wrapper.py | 107 +++++++++++++++++++++++++ src/flyte/types/_type_engine.py | 3 + tests/flyte/test_deploy_image_cache.py | 11 ++- tests/flyte/test_image_cache.py | 12 +-- 6 files changed, 185 insertions(+), 9 deletions(-) create mode 100644 examples/advanced/artifact_example.py create mode 100644 src/flyte/artifacts/__init__.py create mode 100644 src/flyte/artifacts/_wrapper.py diff --git a/examples/advanced/artifact_example.py b/examples/advanced/artifact_example.py new file mode 100644 index 000000000..7534fcc5e --- /dev/null +++ b/examples/advanced/artifact_example.py @@ -0,0 +1,26 @@ +import flyte +import flyte.artifacts as artifacts + +env = flyte.TaskEnvironment("artifact_example") + + +@env.task +def create_artifact() -> str: + result = "This is my artifact content" + metadata = artifacts.Metadata( + name="my_artifact", version="1.0", description="An example artifact created in create_artifact task" + ) + return artifacts.new(result, metadata) + + +@env.task +def call_artifact(artifact: str) -> str: + x = create_artifact() + print(x) + return artifact + + +if __name__ == "__main__": + flyte.init() + v = flyte.run(call_artifact, artifact="hello") + print(v.outputs()) diff --git a/src/flyte/artifacts/__init__.py b/src/flyte/artifacts/__init__.py new file mode 100644 index 000000000..c511ed434 --- /dev/null +++ b/src/flyte/artifacts/__init__.py @@ -0,0 +1,35 @@ +""" +Artifacts module + +This module provides a wrapper method to mark certain outputs as artifacts with associated metadata. + +Usage example: +```python +import flyte.artifacts as artifacts + +@env.task +def my_task() -> MyType: + result = MyType(...) 
+ metadata = artifacts.Metadata(name="my_artifact", version="1.0", description="An example artifact") + return artifacts.new(result, metadata) +``` + +Launching with known artifacts: +```python +flyte.run(main, x=flyte.remote.Artifact.get("name", version="1.0")) +``` + +Retrieve a set of artifacts and pass them as a list +```python +from flyte.remote import Artifact +flyte.run(main, x=[Artifact.get("name1", version="1.0"), Artifact.get("name2", version="2.0")]) +``` +OR +```python +flyte.run(main, x=flyte.remote.Artifact.list("name_prefix", partition_match="x")) +``` +""" + +from ._wrapper import Metadata, new + +__all__ = ["Metadata", "new"] diff --git a/src/flyte/artifacts/_wrapper.py b/src/flyte/artifacts/_wrapper.py new file mode 100644 index 000000000..3fdc8213d --- /dev/null +++ b/src/flyte/artifacts/_wrapper.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass +from typing import Any, Optional, Protocol, TypeVar, runtime_checkable + +from typing_extensions import ParamSpec + +T = TypeVar("T", contravariant=True) +P = ParamSpec("P", bound=Any) + + +@dataclass +class Metadata: + """Structured metadata for Flyte artifacts.""" + + # Core tracking fields + name: Optional[str] = None + version: Optional[str] = None + description: Optional[str] = None + + +@runtime_checkable +class Artifact(Protocol[T]): + """Protocol for objects wrapped with Flyte metadata.""" + + _flyte_metadata: Metadata + + def get_flyte_metadata(self) -> Metadata: + """Get the Flyte metadata associated with this artifact.""" + ... 
+ + +class ArtifactWrapper: + """Zero-copy wrapper that preserves the original object interface.""" + + __slots__ = ("_flyte_metadata", "_obj") + + def __init__(self, obj: T, metadata: Metadata) -> None: + object.__setattr__(self, "_obj", obj) + object.__setattr__(self, "_flyte_metadata", metadata) + + def __getattr__(self, name: str) -> Any: + return getattr(self._obj, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("_obj", "_flyte_metadata"): + object.__setattr__(self, name, value) + else: + setattr(self._obj, name, value) + + def __delattr__(self, name: str) -> None: + if name in ("_obj", "_flyte_metadata"): + raise AttributeError(f"Cannot delete {name}") + delattr(self._obj, name) + + @property + def __class__(self): + """Make isinstance checks work with the wrapped object's type.""" + return type(self._obj) + + def get_flyte_metadata(self) -> Metadata: + """Get a copy of the Flyte metadata.""" + import copy + + return copy.deepcopy(self._flyte_metadata) + + # Forward common special methods for better compatibility + def __str__(self) -> str: + return str(self._obj) + + def __repr__(self) -> str: + return f"Artifact[{type(self._obj).__name__}]({self._obj})" + + def __bool__(self) -> bool: + return bool(self._obj) + + def __len__(self) -> int: + return len(self._obj) + + def __iter__(self): + return iter(self._obj) + + def __call__(self, *args: P.args, **kwargs: P.kwargs): + return self._obj(*args, **kwargs) + + def __getitem__(self, key): + return self._obj[key] + + def __setitem__(self, key, value): + self._obj[key] = value + + def __contains__(self, item): + return item in self._obj + + +def new(obj: T, metadata: Metadata | None = None) -> T: + """ + Wrap an object with Flyte metadata while preserving its type interface. 
+ + Args: + obj: The object to wrap + metadata: Metadata to associate with the object + + Returns: + A zero-copy wrapper that behaves exactly like the original object + but carries additional Flyte metadata accessible via get_flyte_metadata() + """ + wrapper = ArtifactWrapper(obj, metadata or Metadata()) + return wrapper # type: ignore[return-value] diff --git a/src/flyte/types/_type_engine.py b/src/flyte/types/_type_engine.py index b464f1367..134b5038e 100644 --- a/src/flyte/types/_type_engine.py +++ b/src/flyte/types/_type_engine.py @@ -38,6 +38,7 @@ from pydantic import BaseModel from typing_extensions import Annotated, get_args, get_origin +import flyte.artifacts._wrapper import flyte.storage as storage from flyte._logging import logger from flyte._utils.helpers import load_proto_from_file @@ -1083,6 +1084,8 @@ def to_literal_checks(cls, python_val: typing.Any, python_type: Type[T], expecte async def to_literal( cls, python_val: typing.Any, python_type: Type[T], expected: types_pb2.LiteralType ) -> literals_pb2.Literal: + if isinstance(python_val, flyte.artifacts._wrapper.ArtifactWrapper): + python_val = python_val._obj transformer = cls.get_transformer(python_type) if transformer.type_assertions_enabled: diff --git a/tests/flyte/test_deploy_image_cache.py b/tests/flyte/test_deploy_image_cache.py index 3cdbb852e..e1f329af6 100644 --- a/tests/flyte/test_deploy_image_cache.py +++ b/tests/flyte/test_deploy_image_cache.py @@ -8,10 +8,13 @@ from flyte._task_environment import TaskEnvironment -@pytest.mark.parametrize("python_version,expected_py_version", [ - (None, "{}.{}".format(sys.version_info.major, sys.version_info.minor)), # Use local python version - ((3, 10), "3.10"), -]) +@pytest.mark.parametrize( + "python_version,expected_py_version", + [ + (None, "{}.{}".format(sys.version_info.major, sys.version_info.minor)), # Use local python version + ((3, 10), "3.10"), + ], +) @pytest.mark.asyncio async def test_create_image_cache_lookup(python_version, 
expected_py_version): """Test that _build_images creates the correct nested dictionary structure for ImageCache.""" diff --git a/tests/flyte/test_image_cache.py b/tests/flyte/test_image_cache.py index 7ee6e938e..8a131d118 100644 --- a/tests/flyte/test_image_cache.py +++ b/tests/flyte/test_image_cache.py @@ -27,12 +27,14 @@ def test_image_cache_serialization_round_trip(): def test_image_cache_deserialize(): - test_data = ImageCache(image_lookup={ - "auto": {"3.12": "registry.example.com/auto:latest"}, - "test_id": {"3.11": "registry.example.com/test:latest"} - }) + test_data = ImageCache( + image_lookup={ + "auto": {"3.12": "registry.example.com/auto:latest"}, + "test_id": {"3.11": "registry.example.com/test:latest"}, + } + ) serialized = test_data.to_transport - + restored = ImageCache.from_transport(serialized) assert "auto" in restored.image_lookup assert isinstance(restored.image_lookup["auto"], dict) From 4fcf79cf76479d61a48fd4515a3bf12c87341d7f Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Tue, 16 Sep 2025 21:44:47 -0700 Subject: [PATCH 2/8] Updated usage example Signed-off-by: Ketan Umare --- examples/advanced/artifact_example.py | 32 +++++++++++-- src/flyte/artifacts/__init__.py | 4 +- src/flyte/artifacts/_wrapper.py | 42 ++++++++++++----- src/flyte/remote/__init__.py | 2 + src/flyte/remote/_artifact.py | 67 +++++++++++++++++++++++++++ 5 files changed, 130 insertions(+), 17 deletions(-) create mode 100644 src/flyte/remote/_artifact.py diff --git a/examples/advanced/artifact_example.py b/examples/advanced/artifact_example.py index 7534fcc5e..d3ddd8f2f 100644 --- a/examples/advanced/artifact_example.py +++ b/examples/advanced/artifact_example.py @@ -14,13 +14,39 @@ def create_artifact() -> str: @env.task -def call_artifact(artifact: str) -> str: +def call_artifact() -> str: x = create_artifact() print(x) - return artifact + return x + + +@env.task +async def use_artifact(v: str) -> str: + print(f"Using artifact with content: {v}") + return f"Artifact used 
with content: {v}" + + +@env.task +async def use_multiple_artifacts(v: list[str]) -> str: + print(f"Using multiple artifacts with contents: {v}") + return f"Multiple artifacts used with contents: {v}" if __name__ == "__main__": flyte.init() - v = flyte.run(call_artifact, artifact="hello") + v = flyte.run(call_artifact) print(v.outputs()) + + from flyte.remote import Artifact + + artifact_instance = Artifact.get("my_artifact", version="1.0") + v2 = flyte.run(use_artifact, v=artifact_instance) + print(v2.outputs()) + + artifact_list = [Artifact.get("my_artifact", version="1.0"), Artifact.get("my_artifact", version="1.0")] + v3 = flyte.run(use_multiple_artifacts, v=artifact_list) + print(v3.outputs()) + + artifact_list_via_prefix = Artifact.list("my_artifact", partition_match="1.0") + v4 = flyte.run(use_multiple_artifacts, v=artifact_list_via_prefix) + print(v4.outputs()) diff --git a/src/flyte/artifacts/__init__.py b/src/flyte/artifacts/__init__.py index c511ed434..e8fb06575 100644 --- a/src/flyte/artifacts/__init__.py +++ b/src/flyte/artifacts/__init__.py @@ -30,6 +30,6 @@ def my_task() -> MyType: ``` """ -from ._wrapper import Metadata, new +from ._wrapper import Artifact, Metadata, new -__all__ = ["Metadata", "new"] +__all__ = ["Artifact", "Metadata", "new"] diff --git a/src/flyte/artifacts/_wrapper.py b/src/flyte/artifacts/_wrapper.py index 3fdc8213d..9335f4f6c 100644 --- a/src/flyte/artifacts/_wrapper.py +++ b/src/flyte/artifacts/_wrapper.py @@ -3,8 +3,9 @@ from typing_extensions import ParamSpec -T = TypeVar("T", contravariant=True) -P = ParamSpec("P", bound=Any) +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T") +P = ParamSpec("P") @dataclass @@ -12,13 +13,13 @@ class Metadata: """Structured metadata for Flyte artifacts.""" # Core tracking fields - name: Optional[str] = None + name: str version: Optional[str] = None description: Optional[str] = None @runtime_checkable -class Artifact(Protocol[T]): +class Artifact(Protocol[T_co]): """Protocol for objects 
wrapped with Flyte metadata.""" _flyte_metadata: Metadata @@ -33,7 +34,7 @@ class ArtifactWrapper: __slots__ = ("_flyte_metadata", "_obj") - def __init__(self, obj: T, metadata: Metadata) -> None: + def __init__(self, obj: T_co, metadata: Metadata) -> None: object.__setattr__(self, "_obj", obj) object.__setattr__(self, "_flyte_metadata", metadata) @@ -51,10 +52,27 @@ def __delattr__(self, name: str) -> None: raise AttributeError(f"Cannot delete {name}") delattr(self._obj, name) - @property - def __class__(self): - """Make isinstance checks work with the wrapped object's type.""" - return type(self._obj) + def __getattribute__(self, name: str) -> Any: + """Override attribute access to make __class__ return the wrapped object's type.""" + if name == "__class__": + return type(object.__getattribute__(self, "_obj")) + elif name in ( + "_obj", + "_flyte_metadata", + "get_flyte_metadata", + "__call__", + "__repr__", + "__str__", + "__bool__", + "__len__", + "__iter__", + "__getitem__", + "__setitem__", + "__contains__", + ): + return object.__getattribute__(self, name) + else: + return getattr(object.__getattribute__(self, "_obj"), name) def get_flyte_metadata(self) -> Metadata: """Get a copy of the Flyte metadata.""" @@ -78,7 +96,7 @@ def __len__(self) -> int: def __iter__(self): return iter(self._obj) - def __call__(self, *args: P.args, **kwargs: P.kwargs): + def __call__(self, *args: Any, **kwargs: Any): return self._obj(*args, **kwargs) def __getitem__(self, key): @@ -91,7 +109,7 @@ def __contains__(self, item): return item in self._obj -def new(obj: T, metadata: Metadata | None = None) -> T: +def new(obj: T, metadata: Metadata) -> T: """ Wrap an object with Flyte metadata while preserving its type interface. 
@@ -103,5 +121,5 @@ def new(obj: T, metadata: Metadata | None = None) -> T: A zero-copy wrapper that behaves exactly like the original object but carries additional Flyte metadata accessible via get_flyte_metadata() """ - wrapper = ArtifactWrapper(obj, metadata or Metadata()) + wrapper = ArtifactWrapper(obj, metadata) return wrapper # type: ignore[return-value] diff --git a/src/flyte/remote/__init__.py b/src/flyte/remote/__init__.py index f27eb8b98..e89714386 100644 --- a/src/flyte/remote/__init__.py +++ b/src/flyte/remote/__init__.py @@ -7,6 +7,7 @@ "ActionDetails", "ActionInputs", "ActionOutputs", + "Artifact", "Project", "Run", "RunDetails", @@ -19,6 +20,7 @@ ] from ._action import Action, ActionDetails, ActionInputs, ActionOutputs +from ._artifact import Artifact from ._client.auth import create_channel from ._data import upload_dir, upload_file from ._project import Project diff --git a/src/flyte/remote/_artifact.py b/src/flyte/remote/_artifact.py new file mode 100644 index 000000000..7ab056ae4 --- /dev/null +++ b/src/flyte/remote/_artifact.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Any, AsyncIterator, Dict, Literal + +from flyte.artifacts import Artifact as CoreArtifact # Ensure the core Artifact is imported +from flyte.remote._common import ToJSONMixin +from flyte.syncify import syncify + + +@dataclass +class Artifact(ToJSONMixin): + """ + A class representing an artifact in the Union API. + """ + + pb2: Any # Replace 'Any' with the actual protobuf type when available + + @syncify + @classmethod + async def get(cls, name: str, version: str | Literal["latest"] = "latest") -> Artifact: + """ + Get an artifact by its name and version. + + :param name: The name of the artifact. + :param version: The version of the artifact. 
+ """ + raise NotImplementedError("Artifact retrieval not yet implemented.") + + @syncify + @classmethod + async def listall( + cls, + name: str | None = None, + created_after: datetime | None = None, + limit: int = -1, + **partition_match: Dict[str, str], + ) -> AsyncIterator[Artifact]: + """ + List artifacts by name prefix and optional partition match. + + :param name: The name prefix of the artifacts. + :param created_after: Filter artifacts created after this datetime. + :param limit: The maximum number of artifacts to return. -1 for no limit. + :param partition_match: Key-value pairs to filter artifacts by partition. + :return: A list of artifacts. + """ + raise NotImplementedError("Artifact listing not yet implemented.") + + @syncify + @classmethod + async def create(cls, artifact: CoreArtifact) -> Artifact: + """ + Create a new artifact in the remote system. + + :param artifact: The core Artifact instance to create remotely. + :return: The created Artifact instance with remote metadata. + """ + raise NotImplementedError("Artifact creation not yet implemented.") + + @syncify + async def delete(self) -> None: + """ + Delete this artifact from the remote system. 
+ """ + raise NotImplementedError("Artifact deletion not yet implemented.") From 7cb9ddf0c6d4221cefcf2a8f1e180d7f54205ab1 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Tue, 16 Sep 2025 21:56:35 -0700 Subject: [PATCH 3/8] Updated Signed-off-by: Ketan Umare --- examples/advanced/artifact_example.py | 2 +- src/flyte/_run.py | 16 ++++++++++-- src/flyte/remote/_artifact.py | 37 +++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/examples/advanced/artifact_example.py b/examples/advanced/artifact_example.py index d3ddd8f2f..53e5a330f 100644 --- a/examples/advanced/artifact_example.py +++ b/examples/advanced/artifact_example.py @@ -47,6 +47,6 @@ async def use_multiple_artifacts(v: list[str]) -> str: v3 = flyte.run(use_multiple_artifacts, v=artifact_list) print(v3.outputs()) - artifact_list_via_prefix = Artifact.list("my_artifact", partition_match="1.0") + artifact_list_via_prefix = list(Artifact.listall("my_artifact", version="1.0")) v4 = flyte.run(use_multiple_artifacts, v=artifact_list_via_prefix) print(v4.outputs()) diff --git a/src/flyte/_run.py b/src/flyte/_run.py index 644b9a4c7..398081a08 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -443,13 +443,25 @@ async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.k mode="local", ) with ctx.replace_task_context(tctx): + new_kwargs = {} + if kwargs: + from flyte.remote import Artifact + + for k, v in kwargs.items(): + if isinstance(v, Artifact): + new_kwargs[k] = v.pb2["data"] + elif isinstance(v, list) and len(v) > 0: + new_kwargs[k] = [item.pb2["data"] if isinstance(item, Artifact) else item for item in v] + else: + new_kwargs[k] = v + print(new_kwargs) # make the local version always runs on a different thread, returns a wrapped future. 
if obj._call_as_synchronous: - fut = controller.submit_sync(obj, *args, **kwargs) + fut = controller.submit_sync(obj, *args, **new_kwargs) awaitable = asyncio.wrap_future(fut) outputs = await awaitable else: - outputs = await controller.submit(obj, *args, **kwargs) + outputs = await controller.submit(obj, *args, **new_kwargs) class _LocalRun(Run): def __init__(self, outputs: Tuple[Any, ...] | Any): diff --git a/src/flyte/remote/_artifact.py b/src/flyte/remote/_artifact.py index 7ab056ae4..605a545a0 100644 --- a/src/flyte/remote/_artifact.py +++ b/src/flyte/remote/_artifact.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncIterator, Dict, Literal +from typing import Any, AsyncIterator, Literal from flyte.artifacts import Artifact as CoreArtifact # Ensure the core Artifact is imported from flyte.remote._common import ToJSONMixin @@ -26,6 +26,17 @@ async def get(cls, name: str, version: str | Literal["latest"] = "latest") -> Ar :param name: The name of the artifact. :param version: The version of the artifact. """ + if name == "my_artifact": + return Artifact( + pb2={ + "metadata": { + "name": "my_artifact", + "version": "1.0", + "description": "An example artifact", + }, + "data": "This is my artifact content", + }, + ) raise NotImplementedError("Artifact retrieval not yet implemented.") @syncify @@ -35,7 +46,7 @@ async def listall( name: str | None = None, created_after: datetime | None = None, limit: int = -1, - **partition_match: Dict[str, str], + **partition_match: str, ) -> AsyncIterator[Artifact]: """ List artifacts by name prefix and optional partition match. @@ -46,6 +57,28 @@ async def listall( :param partition_match: Key-value pairs to filter artifacts by partition. :return: A list of artifacts. 
""" + if name == "my_artifact" and partition_match.get("version") == "1.0": + yield Artifact( + pb2={ + "metadata": { + "name": "my_artifact", + "version": "1.0", + "description": "An example artifact", + }, + "data": "This is my artifact content 1", + }, + ) + yield Artifact( + pb2={ + "metadata": { + "name": "my_artifact", + "version": "1.0", + "description": "An example artifact", + }, + "data": "This is my artifact content 2", + }, + ) + return raise NotImplementedError("Artifact listing not yet implemented.") @syncify From a2cd1ac353f10cd3c6db40fe69afb384784dd206 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Fri, 19 Sep 2025 18:08:13 -0700 Subject: [PATCH 4/8] Updated Signed-off-by: Ketan Umare --- src/flyte/remote/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flyte/remote/__init__.py b/src/flyte/remote/__init__.py index 67ba66fab..896fdacba 100644 --- a/src/flyte/remote/__init__.py +++ b/src/flyte/remote/__init__.py @@ -7,8 +7,8 @@ "ActionDetails", "ActionInputs", "ActionOutputs", - "Phase", "Artifact", + "Phase", "Project", "Run", "RunDetails", From 425a7d376707ccd8895598f4d8dbcd9bb2c7b7f9 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Fri, 19 Sep 2025 22:14:28 -0700 Subject: [PATCH 5/8] more wip Signed-off-by: Ketan Umare --- src/flyte/artifacts/_card.py | 60 +++++++++++++++++++++++++++++++++ src/flyte/artifacts/_wrapper.py | 43 +++++++++++++++-------- 2 files changed, 89 insertions(+), 14 deletions(-) create mode 100644 src/flyte/artifacts/_card.py diff --git a/src/flyte/artifacts/_card.py b/src/flyte/artifacts/_card.py new file mode 100644 index 000000000..35b911a00 --- /dev/null +++ b/src/flyte/artifacts/_card.py @@ -0,0 +1,60 @@ + +@dataclass +class Card(object): + text: str + card_type: CardType = field(default=CardType.UNKNOWN, init=False) + + def serialize_to_string(self, ctx: FlyteContext, variable_name: str) -> typing.Tuple[str, str]: + # only upload if we're running a real task execution + if 
ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: + if ctx.user_space_params and ctx.user_space_params.output_metadata_prefix: + output_location = ctx.user_space_params.output_metadata_prefix + reader = StringIO(self.text) + to_path = ctx.file_access.put_raw_data( + reader, upload_prefix=output_location, file_name=f"card_{variable_name}", skip_raw_data_prefix=True + ) + logger.debug( + f"Artifact card detected for {variable_name}, attempting to upload under {output_location}" + ) + logger.info(f"Card uploaded to {to_path} for {variable_name}") + + c = artifacts_pb2.Card( + uri=to_path, + type=self.card_type, + ) + + s = c.SerializeToString() + encoded = b64encode(s).decode("utf-8") + + return _CARD_METADATA_KEY, encoded + + logger.debug(f"Artifact card found on {variable_name}, but not uploading, starts with {self.text[0:100]}") + return _CARD_METADATA_KEY, self.text[0:100] + + +@dataclass +class DataCard(Card): + """ + :param text: DataCard contents. + :param card_type: + """ + + card_type: CardType = CardType.DATASET + + @classmethod + def from_obj(cls, card_obj: typing.Any) -> Card: + return cls(text=str(card_obj)) + + +@dataclass +class ModelCard(Card): + """ + :param text: ModelCard contents. 
+ :param card_type: + """ + + card_type: CardType = CardType.MODEL + + @classmethod + def from_obj(cls, card_obj: typing.Any) -> Card: + return cls(text=str(card_obj)) diff --git a/src/flyte/artifacts/_wrapper.py b/src/flyte/artifacts/_wrapper.py index 9335f4f6c..28f25974e 100644 --- a/src/flyte/artifacts/_wrapper.py +++ b/src/flyte/artifacts/_wrapper.py @@ -1,5 +1,6 @@ +import typing from dataclasses import dataclass -from typing import Any, Optional, Protocol, TypeVar, runtime_checkable +from typing import Any, Optional, Protocol, TypeVar, runtime_checkable, Tuple from typing_extensions import ParamSpec @@ -8,7 +9,7 @@ P = ParamSpec("P") -@dataclass +@dataclass(frozen=True, kw_only=True) class Metadata: """Structured metadata for Flyte artifacts.""" @@ -16,6 +17,20 @@ class Metadata: name: str version: Optional[str] = None description: Optional[str] = None + data: Optional[typing.Mapping[str, str]] = None + card: Optional[Card] = None + + +@dataclass(frozen=True, kw_only=True) +class ModelMetadata(Metadata): + """Metadata specific to machine learning models.""" + framework: Optional[str] = None + model_type: Optional[str] = None + architecture: Optional[str] = None + task: Optional[str] = None + modality: Tuple[str, ...] 
= ("text",) + serial_format: str = "safetensors" + short_description: Optional[str] = None @runtime_checkable @@ -57,18 +72,18 @@ def __getattribute__(self, name: str) -> Any: if name == "__class__": return type(object.__getattribute__(self, "_obj")) elif name in ( - "_obj", - "_flyte_metadata", - "get_flyte_metadata", - "__call__", - "__repr__", - "__str__", - "__bool__", - "__len__", - "__iter__", - "__getitem__", - "__setitem__", - "__contains__", + "_obj", + "_flyte_metadata", + "get_flyte_metadata", + "__call__", + "__repr__", + "__str__", + "__bool__", + "__len__", + "__iter__", + "__getitem__", + "__setitem__", + "__contains__", ): return object.__getattribute__(self, name) else: From d4688625c39193946a56acaaa63e45c39da7dbce Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Sat, 20 Sep 2025 22:18:17 -0700 Subject: [PATCH 6/8] Artifacts e2e Signed-off-by: Ketan Umare --- src/flyte/artifacts/__init__.py | 6 +- src/flyte/artifacts/_card.py | 108 ++++++++++++++++--------------- src/flyte/artifacts/_metadata.py | 54 ++++++++++++++++ src/flyte/artifacts/_wrapper.py | 54 +++++----------- 4 files changed, 128 insertions(+), 94 deletions(-) create mode 100644 src/flyte/artifacts/_metadata.py diff --git a/src/flyte/artifacts/__init__.py b/src/flyte/artifacts/__init__.py index e8fb06575..4c39e649a 100644 --- a/src/flyte/artifacts/__init__.py +++ b/src/flyte/artifacts/__init__.py @@ -30,6 +30,8 @@ def my_task() -> MyType: ``` """ -from ._wrapper import Artifact, Metadata, new +from ._card import Card, CardFormat, CardType +from ._metadata import Metadata +from ._wrapper import Artifact, new -__all__ = ["Artifact", "Metadata", "new"] +__all__ = ["Artifact", "Card", "CardFormat", "CardType", "Metadata", "new"] diff --git a/src/flyte/artifacts/_card.py b/src/flyte/artifacts/_card.py index 35b911a00..14f704911 100644 --- a/src/flyte/artifacts/_card.py +++ b/src/flyte/artifacts/_card.py @@ -1,60 +1,62 @@ +from __future__ import annotations -@dataclass -class Card(object): - 
text: str - card_type: CardType = field(default=CardType.UNKNOWN, init=False) - - def serialize_to_string(self, ctx: FlyteContext, variable_name: str) -> typing.Tuple[str, str]: - # only upload if we're running a real task execution - if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: - if ctx.user_space_params and ctx.user_space_params.output_metadata_prefix: - output_location = ctx.user_space_params.output_metadata_prefix - reader = StringIO(self.text) - to_path = ctx.file_access.put_raw_data( - reader, upload_prefix=output_location, file_name=f"card_{variable_name}", skip_raw_data_prefix=True - ) - logger.debug( - f"Artifact card detected for {variable_name}, attempting to upload under {output_location}" - ) - logger.info(f"Card uploaded to {to_path} for {variable_name}") - - c = artifacts_pb2.Card( - uri=to_path, - type=self.card_type, - ) - - s = c.SerializeToString() - encoded = b64encode(s).decode("utf-8") - - return _CARD_METADATA_KEY, encoded - - logger.debug(f"Artifact card found on {variable_name}, but not uploading, starts with {self.text[0:100]}") - return _CARD_METADATA_KEY, self.text[0:100] - - -@dataclass -class DataCard(Card): - """ - :param text: DataCard contents. - :param card_type: - """ - - card_type: CardType = CardType.DATASET +import pathlib +import tempfile +from dataclasses import dataclass +from typing import Literal - @classmethod - def from_obj(cls, card_obj: typing.Any) -> Card: - return cls(text=str(card_obj)) +import flyte +from flyte import storage +CardType = Literal["model", "data", "generic"] +CardFormat = Literal["html", "md", "json", "yaml", "csv", "tsv", "png", "jpg", "jpeg"] -@dataclass -class ModelCard(Card): - """ - :param text: ModelCard contents. 
- :param card_type: - """ - card_type: CardType = CardType.MODEL +@dataclass(frozen=True, kw_only=True) +class Card(object): + uri: str + format: CardFormat = "html" + card_type: CardType = "generic" @classmethod - def from_obj(cls, card_obj: typing.Any) -> Card: - return cls(text=str(card_obj)) + async def create_from( + cls, + *, + content: str | None = None, + local_path: pathlib.Path | None = None, + format: CardFormat = "html", + card_type: CardType = "generic", + ) -> Card: + """ + Upload a card either from raw content or from a local file path. + + :param content: Raw content of the card to be uploaded. + :param local_path: Local file path of the card to be uploaded. + :param format: Format of the card (e.g., 'html', 'md', + 'json', 'yaml', 'csv', 'tsv', 'png', 'jpg', 'jpeg'). + :param card_type: Type of the card (e.g., 'model', 'data', 'generic'). + """ + if content: + with tempfile.NamedTemporaryFile(mode="w", suffix=f".{format}", delete=False) as temp_file: + temp_file.write(content) + temp_path = pathlib.Path(temp_file.name) + return await _upload_card_from_local(temp_path, format=format, card_type=card_type) + if local_path: + return await _upload_card_from_local(local_path, format=format, card_type=card_type) + raise ValueError("Either content or local_path must be provided to upload a card.") + + +async def _upload_card_from_local( + local_path: pathlib.Path, format: CardFormat = "html", card_type: CardType = "generic" +) -> Card: + # Implement upload. If in task context, upload to current metadata location, if not, upload using control plane. 
+ uri = "" + ctx = flyte.ctx() + if ctx and ctx.is_in_cluster(): + output_path = ctx.output_path + "/" + f"{card_type}.{format}" + uri = await storage.put(str(local_path), output_path) + else: + import flyte.remote as remote + + _, uri = await remote.upload_file.aio(local_path) + return Card(uri=uri, format=format, card_type=card_type) diff --git a/src/flyte/artifacts/_metadata.py b/src/flyte/artifacts/_metadata.py new file mode 100644 index 000000000..4becec77c --- /dev/null +++ b/src/flyte/artifacts/_metadata.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass +from typing import Optional, Tuple + +from ._card import Card + + +@dataclass(frozen=True, kw_only=True) +class Metadata: + """Structured metadata for Flyte artifacts.""" + + # Core tracking fields + name: str + version: Optional[str] = None + description: Optional[str] = None + data: Optional[typing.Mapping[str, str]] = None + card: Optional[Card] = None + + @classmethod + def create_model_metadata( + cls, + *, + name: str, + version: Optional[str] = None, + description: Optional[str] = None, + card: Optional[Card] = None, + framework: Optional[str] = None, + model_type: Optional[str] = None, + architecture: Optional[str] = None, + task: Optional[str] = None, + modality: Tuple[str, ...] = ("text",), + serial_format: str = "safetensors", + short_description: Optional[str] = None, + ) -> Metadata: + """ + Helper method to create ModelMetadata. This method sets the data keys specific to models. 
+ """ + return cls( + name=name, + version=version, + description=description, + data={ + "framework": framework or "", + "model_type": model_type or "", + "architecture": architecture or "", + "task": task or "", + "modality": ",".join(modality) if modality else "", + "serial_format": serial_format or "", + "short_description": short_description or "", + }, + card=card, + ) diff --git a/src/flyte/artifacts/_wrapper.py b/src/flyte/artifacts/_wrapper.py index 28f25974e..35b5e4576 100644 --- a/src/flyte/artifacts/_wrapper.py +++ b/src/flyte/artifacts/_wrapper.py @@ -1,38 +1,14 @@ -import typing -from dataclasses import dataclass -from typing import Any, Optional, Protocol, TypeVar, runtime_checkable, Tuple +from typing import Any, Protocol, TypeVar, runtime_checkable from typing_extensions import ParamSpec +from ._metadata import Metadata + T_co = TypeVar("T_co", covariant=True) T = TypeVar("T") P = ParamSpec("P") -@dataclass(frozen=True, kw_only=True) -class Metadata: - """Structured metadata for Flyte artifacts.""" - - # Core tracking fields - name: str - version: Optional[str] = None - description: Optional[str] = None - data: Optional[typing.Mapping[str, str]] = None - card: Optional[Card] = None - - -@dataclass(frozen=True, kw_only=True) -class ModelMetadata(Metadata): - """Metadata specific to machine learning models.""" - framework: Optional[str] = None - model_type: Optional[str] = None - architecture: Optional[str] = None - task: Optional[str] = None - modality: Tuple[str, ...] 
= ("text",) - serial_format: str = "safetensors" - short_description: Optional[str] = None - - @runtime_checkable class Artifact(Protocol[T_co]): """Protocol for objects wrapped with Flyte metadata.""" @@ -72,18 +48,18 @@ def __getattribute__(self, name: str) -> Any: if name == "__class__": return type(object.__getattribute__(self, "_obj")) elif name in ( - "_obj", - "_flyte_metadata", - "get_flyte_metadata", - "__call__", - "__repr__", - "__str__", - "__bool__", - "__len__", - "__iter__", - "__getitem__", - "__setitem__", - "__contains__", + "_obj", + "_flyte_metadata", + "get_flyte_metadata", + "__call__", + "__repr__", + "__str__", + "__bool__", + "__len__", + "__iter__", + "__getitem__", + "__setitem__", + "__contains__", ): return object.__getattribute__(self, name) else: From a815467829a73638b64700f26935dcd76e80e95a Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Sun, 21 Sep 2025 15:22:51 -0700 Subject: [PATCH 7/8] artifact-improvements Signed-off-by: Ketan Umare --- examples/advanced/artifact_example.py | 33 +++++++++++++++++++++++++-- src/flyte/artifacts/_card.py | 5 ++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/examples/advanced/artifact_example.py b/examples/advanced/artifact_example.py index 53e5a330f..ebed9ca05 100644 --- a/examples/advanced/artifact_example.py +++ b/examples/advanced/artifact_example.py @@ -1,3 +1,5 @@ +from typing import Tuple + import flyte import flyte.artifacts as artifacts @@ -14,10 +16,37 @@ def create_artifact() -> str: @env.task -def call_artifact() -> str: +def model_artifact() -> str: + result = "This is my model artifact content" + card = artifacts.Card.create_from( + content="

<h1>Model Card</h1><p>This is a sample model card.</p>

", + format="html", + card_type="model", + ) + + metadata = artifacts.Metadata.create_model_metadata( + name="my_model_artifact", + version="1.0", + description="An example model artifact created in model_artifact task", + framework="PyTorch", + model_type="Neural Network", + architecture="ResNet50", + task="Image Classification", + modality=("image",), + serial_format="pt", + short_description="A ResNet50 model for image classification tasks.", + card=card, + ) + return artifacts.new(result, metadata) + + +@env.task +def call_artifact() -> Tuple[str, str]: x = create_artifact() print(x) - return x + y = model_artifact() + print(y) + return x, y @env.task diff --git a/src/flyte/artifacts/_card.py b/src/flyte/artifacts/_card.py index 14f704911..863ab7738 100644 --- a/src/flyte/artifacts/_card.py +++ b/src/flyte/artifacts/_card.py @@ -6,7 +6,7 @@ from typing import Literal import flyte -from flyte import storage +from flyte import storage, syncify CardType = Literal["model", "data", "generic"] CardFormat = Literal["html", "md", "json", "yaml", "csv", "tsv", "png", "jpg", "jpeg"] @@ -18,6 +18,7 @@ class Card(object): format: CardFormat = "html" card_type: CardType = "generic" + @syncify.syncify @classmethod async def create_from( cls, @@ -52,7 +53,7 @@ async def _upload_card_from_local( # Implement upload. If in task context, upload to current metadata location, if not, upload using control plane. 
uri = "" ctx = flyte.ctx() - if ctx and ctx.is_in_cluster(): + if ctx: output_path = ctx.output_path + "/" + f"{card_type}.{format}" uri = await storage.put(str(local_path), output_path) else: From e7c2f2a3f41f2abd45dca95ac0b9cbb63ce48654 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 6 Oct 2025 16:55:46 -0700 Subject: [PATCH 8/8] Updated Signed-off-by: Ketan Umare --- src/flyte/_image.py | 1 - src/flyte/_utils/module_loader.py | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/flyte/_image.py b/src/flyte/_image.py index 10eef097e..2ebee36cc 100644 --- a/src/flyte/_image.py +++ b/src/flyte/_image.py @@ -12,7 +12,6 @@ import rich.repr from packaging.version import Version - if TYPE_CHECKING: from flyte import Secret, SecretRequest diff --git a/src/flyte/_utils/module_loader.py b/src/flyte/_utils/module_loader.py index f73667db8..16552b39d 100644 --- a/src/flyte/_utils/module_loader.py +++ b/src/flyte/_utils/module_loader.py @@ -17,6 +17,7 @@ def load_python_modules(path: Path, recursive: bool = False) -> Tuple[List[str], :return: List of loaded module names, and list of file paths that failed to load """ from rich.progress import BarColumn, Progress, TextColumn, TimeElapsedColumn, TimeRemainingColumn + loaded_modules = [] failed_paths = [] @@ -32,12 +33,12 @@ def load_python_modules(path: Path, recursive: bool = False) -> Tuple[List[str], python_files = glob.glob(str(path / pattern), recursive=recursive) with Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - "[progress.percentage]{task.percentage:>3.0f}%", - TimeElapsedColumn(), - TimeRemainingColumn(), - TextColumn("• {task.fields[current_file]}"), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeElapsedColumn(), + TimeRemainingColumn(), + TextColumn("• {task.fields[current_file]}"), ) as progress: task = progress.add_task(f"Loading {len(python_files)} files", 
total=len(python_files), current_file="") for file_path in python_files: