diff --git a/examples/advanced/artifact_example.py b/examples/advanced/artifact_example.py new file mode 100644 index 000000000..ebed9ca05 --- /dev/null +++ b/examples/advanced/artifact_example.py @@ -0,0 +1,81 @@ +from typing import Tuple + +import flyte +import flyte.artifacts as artifacts + +env = flyte.TaskEnvironment("artifact_example") + + +@env.task +def create_artifact() -> str: + result = "This is my artifact content" + metadata = artifacts.Metadata( + name="my_artifact", version="1.0", description="An example artifact created in create_artifact task" + ) + return artifacts.new(result, metadata) + + +@env.task +def model_artifact() -> str: + result = "This is my model artifact content" + card = artifacts.Card.create_from( + content="
This is a sample model card.
", + format="html", + card_type="model", + ) + + metadata = artifacts.Metadata.create_model_metadata( + name="my_model_artifact", + version="1.0", + description="An example model artifact created in model_artifact task", + framework="PyTorch", + model_type="Neural Network", + architecture="ResNet50", + task="Image Classification", + modality=("image",), + serial_format="pt", + short_description="A ResNet50 model for image classification tasks.", + card=card, + ) + return artifacts.new(result, metadata) + + +@env.task +def call_artifact() -> Tuple[str, str]: + x = create_artifact() + print(x) + y = model_artifact() + print(y) + return x, y + + +@env.task +async def use_artifact(v: str) -> str: + print(f"Using artifact with content: {v}") + return f"Artifact used with content: {v}" + + +@env.task +async def use_multiple_artifacts(v: list[str]) -> str: + print(f"Using multiple artifacts with contents: {v}") + return f"Multiple artifacts used with contents: {v}" + + +if __name__ == "__main__": + flyte.init() + v = flyte.run(call_artifact) + print(v.outputs()) + + from flyte.remote import Artifact + + artifact_instance = Artifact.get("my_artifact", version="1.0") + v2 = flyte.run(use_artifact, v=artifact_instance) + print(v2.outputs()) + + artifact_list = [Artifact.get("my_artifact", version="1.0"), Artifact.get("my_artifact", version="1.0")] + v3 = flyte.run(use_multiple_artifacts, v=artifact_list) + print(v3.outputs()) + + artifact_list_via_prefix = list(Artifact.listall("my_artifact", version="1.0")) + v4 = flyte.run(use_multiple_artifacts, v=artifact_list_via_prefix) + print(v4.outputs()) diff --git a/src/flyte/_run.py b/src/flyte/_run.py index b889edf56..d52af6d37 100644 --- a/src/flyte/_run.py +++ b/src/flyte/_run.py @@ -519,13 +519,25 @@ async def _run_local(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: ) with ctx.replace_task_context(tctx): + new_kwargs = {} + if kwargs: + from flyte.remote import Artifact + + for k, v in kwargs.items(): + if isinstance(v, Artifact): + new_kwargs[k] = v.pb2["data"] + elif isinstance(v, list) and len(v) > 0: + new_kwargs[k] = [item.pb2["data"] if isinstance(item, Artifact) else item for item in v] + else: + new_kwargs[k] = v + print(new_kwargs) # make the local version always runs on a different thread, returns a wrapped future. if obj._call_as_synchronous: - fut = controller.submit_sync(obj, *args, **kwargs) + fut = controller.submit_sync(obj, *args, **new_kwargs) awaitable = asyncio.wrap_future(fut) outputs = await awaitable else: - outputs = await controller.submit(obj, *args, **kwargs) + outputs = await controller.submit(obj, *args, **new_kwargs) class _LocalRun(Run): def __init__(self, outputs: Tuple[Any, ...] | Any): diff --git a/src/flyte/artifacts/__init__.py b/src/flyte/artifacts/__init__.py new file mode 100644 index 000000000..4c39e649a --- /dev/null +++ b/src/flyte/artifacts/__init__.py @@ -0,0 +1,37 @@ +""" +Artifacts module + +This module provides a wrapper method to mark certain outputs as artifacts with associated metadata. + +Usage example: +```python +import flyte.artifacts as artifacts + +@env.task +def my_task() -> MyType: + result = MyType(...) + metadata = artifacts.Metadata(name="my_artifact", version="1.0", description="An example artifact") + return artifacts.new(result, metadata) +``` + +Launching with known artifacts: +```python +flyte.run(main, x=flyte.remote.Artifact.get("name", version="1.0")) +``` + +Retireve a set of artifacts and pass them as a list +```python +from flyte.remote import Artifact +flyte.run(main, x=[Artifact.get("name1", version="1.0"), Artifact.get("name2", version="2.0")]) +``` +OR +```python +flyte.run(main, x=flyte.remote.Artifact.list("name_prefix", partition_match="x")) +``` +""" + +from ._card import Card, CardFormat, CardType +from ._metadata import Metadata +from ._wrapper import Artifact, new + +__all__ = ["Artifact", "Card", "CardFormat", "CardType", "Metadata", "new"] diff --git a/src/flyte/artifacts/_card.py b/src/flyte/artifacts/_card.py new file mode 100644 index 000000000..863ab7738 --- /dev/null +++ b/src/flyte/artifacts/_card.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import pathlib +import tempfile +from dataclasses import dataclass +from typing import Literal + +import flyte +from flyte import storage, syncify + +CardType = Literal["model", "data", "generic"] +CardFormat = Literal["html", "md", "json", "yaml", "csv", "tsv", "png", "jpg", "jpeg"] + + +@dataclass(frozen=True, kw_only=True) +class Card(object): + uri: str + format: CardFormat = "html" + card_type: CardType = "generic" + + @syncify.syncify + @classmethod + async def create_from( + cls, + *, + content: str | None = None, + local_path: pathlib.Path | None = None, + format: CardFormat = "html", + card_type: CardType = "generic", + ) -> Card: + """ + Upload a card either from raw content or from a local file path. + + :param content: Raw content of the card to be uploaded. + :param local_path: Local file path of the card to be uploaded. + :param format: Format of the card (e.g., 'html', 'md', + 'json', 'yaml', 'csv', 'tsv', 'png', 'jpg', 'jpeg'). + :param card_type: Type of the card (e.g., 'model', 'data', 'generic'). + """ + if content: + with tempfile.NamedTemporaryFile(mode="w", suffix=f".{format}", delete=False) as temp_file: + temp_file.write(content) + temp_path = pathlib.Path(temp_file.name) + return await _upload_card_from_local(temp_path, format=format, card_type=card_type) + if local_path: + return await _upload_card_from_local(local_path, format=format, card_type=card_type) + raise ValueError("Either content or local_path must be provided to upload a card.") + + +async def _upload_card_from_local( + local_path: pathlib.Path, format: CardFormat = "html", card_type: CardType = "generic" +) -> Card: + # Implement upload. If in task context, upload to current metadata location, if not, upload using control plane. + uri = "" + ctx = flyte.ctx() + if ctx: + output_path = ctx.output_path + "/" + f"{card_type}.{format}" + uri = await storage.put(str(local_path), output_path) + else: + import flyte.remote as remote + + _, uri = await remote.upload_file.aio(local_path) + return Card(uri=uri, format=format, card_type=card_type) diff --git a/src/flyte/artifacts/_metadata.py b/src/flyte/artifacts/_metadata.py new file mode 100644 index 000000000..4becec77c --- /dev/null +++ b/src/flyte/artifacts/_metadata.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass +from typing import Optional, Tuple + +from ._card import Card + + +@dataclass(frozen=True, kw_only=True) +class Metadata: + """Structured metadata for Flyte artifacts.""" + + # Core tracking fields + name: str + version: Optional[str] = None + description: Optional[str] = None + data: Optional[typing.Mapping[str, str]] = None + card: Optional[Card] = None + + @classmethod + def create_model_metadata( + cls, + *, + name: str, + version: Optional[str] = None, + description: Optional[str] = None, + card: Optional[Card] = None, + framework: Optional[str] = None, + model_type: Optional[str] = None, + architecture: Optional[str] = None, + task: Optional[str] = None, + modality: Tuple[str, ...] = ("text",), + serial_format: str = "safetensors", + short_description: Optional[str] = None, + ) -> Metadata: + """ + Helper method to create ModelMetadata. This method sets the data keys specific to models. + """ + return cls( + name=name, + version=version, + description=description, + data={ + "framework": framework or "", + "model_type": model_type or "", + "architecture": architecture or "", + "task": task or "", + "modality": ",".join(modality) if modality else "", + "serial_format": serial_format or "", + "short_description": short_description or "", + }, + card=card, + ) diff --git a/src/flyte/artifacts/_wrapper.py b/src/flyte/artifacts/_wrapper.py new file mode 100644 index 000000000..35b5e4576 --- /dev/null +++ b/src/flyte/artifacts/_wrapper.py @@ -0,0 +1,116 @@ +from typing import Any, Protocol, TypeVar, runtime_checkable + +from typing_extensions import ParamSpec + +from ._metadata import Metadata + +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T") +P = ParamSpec("P") + + +@runtime_checkable +class Artifact(Protocol[T_co]): + """Protocol for objects wrapped with Flyte metadata.""" + + _flyte_metadata: Metadata + + def get_flyte_metadata(self) -> Metadata: + """Get the Flyte metadata associated with this artifact.""" + ... + + +class ArtifactWrapper: + """Zero-copy wrapper that preserves the original object interface.""" + + __slots__ = ("_flyte_metadata", "_obj") + + def __init__(self, obj: T_co, metadata: Metadata) -> None: + object.__setattr__(self, "_obj", obj) + object.__setattr__(self, "_flyte_metadata", metadata) + + def __getattr__(self, name: str) -> Any: + return getattr(self._obj, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("_obj", "_flyte_metadata"): + object.__setattr__(self, name, value) + else: + setattr(self._obj, name, value) + + def __delattr__(self, name: str) -> None: + if name in ("_obj", "_flyte_metadata"): + raise AttributeError(f"Cannot delete {name}") + delattr(self._obj, name) + + def __getattribute__(self, name: str) -> Any: + """Override attribute access to make __class__ return the wrapped object's type.""" + if name == "__class__": + return type(object.__getattribute__(self, "_obj")) + elif name in ( + "_obj", + "_flyte_metadata", + "get_flyte_metadata", + "__call__", + "__repr__", + "__str__", + "__bool__", + "__len__", + "__iter__", + "__getitem__", + "__setitem__", + "__contains__", + ): + return object.__getattribute__(self, name) + else: + return getattr(object.__getattribute__(self, "_obj"), name) + + def get_flyte_metadata(self) -> Metadata: + """Get a copy of the Flyte metadata.""" + import copy + + return copy.deepcopy(self._flyte_metadata) + + # Forward common special methods for better compatibility + def __str__(self) -> str: + return str(self._obj) + + def __repr__(self) -> str: + return f"Artifact[{type(self._obj).__name__}]({self._obj})" + + def __bool__(self) -> bool: + return bool(self._obj) + + def __len__(self) -> int: + return len(self._obj) + + def __iter__(self): + return iter(self._obj) + + def __call__(self, *args: Any, **kwargs: Any): + return self._obj(*args, **kwargs) + + def __getitem__(self, key): + return self._obj[key] + + def __setitem__(self, key, value): + self._obj[key] = value + + def __contains__(self, item): + return item in self._obj + + +def new(obj: T, metadata: Metadata) -> T: + """ + Wrap an object with Flyte metadata while preserving its type interface. + + Args: + obj: The object to wrap + metadata: Metadata to associate with the object + + Returns: + A zero-copy wrapper that behaves exactly like the original object + but carries additional Flyte metadata accessible via get_flyte_metadata() + """ + wrapper = ArtifactWrapper(obj, metadata) + return wrapper # type: ignore[return-value] diff --git a/src/flyte/remote/__init__.py b/src/flyte/remote/__init__.py index 67bc40827..eb1626993 100644 --- a/src/flyte/remote/__init__.py +++ b/src/flyte/remote/__init__.py @@ -8,6 +8,7 @@ "ActionInputs", "ActionOutputs", "App", + "Artifact", "Project", "Run", "RunDetails", @@ -23,6 +24,7 @@ ] from ._action import Action, ActionDetails, ActionInputs, ActionOutputs +from ._artifact import Artifact from ._app import App from ._client.auth import create_channel from ._data import upload_dir, upload_file diff --git a/src/flyte/remote/_artifact.py b/src/flyte/remote/_artifact.py new file mode 100644 index 000000000..605a545a0 --- /dev/null +++ b/src/flyte/remote/_artifact.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Any, AsyncIterator, Literal + +from flyte.artifacts import Artifact as CoreArtifact # Ensure the core Artifact is imported +from flyte.remote._common import ToJSONMixin +from flyte.syncify import syncify + + +@dataclass +class Artifact(ToJSONMixin): + """ + A class representing a project in the Union API. + """ + + pb2: Any # Replace 'Any' with the actual protobuf type when available + + @syncify + @classmethod + async def get(cls, name: str, version: str | Literal["latest"] = "latest") -> Artifact: + """ + Get an artifact by its name and version. + + :param name: The name of the artifact. + :param version: The version of the artifact. + """ + if name == "my_artifact": + return Artifact( + pb2={ + "metadata": { + "name": "my_artifact", + "version": "1.0", + "description": "An example artifact", + }, + "data": "This is my artifact content", + }, + ) + raise NotImplementedError("Artifact retrieval not yet implemented.") + + @syncify + @classmethod + async def listall( + cls, + name: str | None = None, + created_after: datetime | None = None, + limit: int = -1, + **partition_match: str, + ) -> AsyncIterator[Artifact]: + """ + List artifacts by name prefix and optional partition match. + + :param name: The name prefix of the artifacts. + :param created_after: Filter artifacts created after this datetime. + :param limit: The maximum number of artifacts to return. -1 for no limit. + :param partition_match: Key-value pairs to filter artifacts by partition. + :return: A list of artifacts. + """ + if name == "my_artifact" and partition_match.get("version") == "1.0": + yield Artifact( + pb2={ + "metadata": { + "name": "my_artifact", + "version": "1.0", + "description": "An example artifact", + }, + "data": "This is my artifact content 1", + }, + ) + yield Artifact( + pb2={ + "metadata": { + "name": "my_artifact", + "version": "1.0", + "description": "An example artifact", + }, + "data": "This is my artifact content 2", + }, + ) + return + raise NotImplementedError("Artifact listing not yet implemented.") + + @syncify + @classmethod + async def create(cls, artifact: CoreArtifact) -> Artifact: + """ + Create a new artifact in the remote system. + + :param artifact: The core Artifact instance to create remotely. + :return: The created Artifact instance with remote metadata. + """ + raise NotImplementedError("Artifact creation not yet implemented.") + + @syncify + async def delete(self) -> None: + """ + Delete this artifact from the remote system. + """ + raise NotImplementedError("Artifact deletion not yet implemented.") diff --git a/src/flyte/types/_type_engine.py b/src/flyte/types/_type_engine.py index 847f5c3da..62067c08c 100644 --- a/src/flyte/types/_type_engine.py +++ b/src/flyte/types/_type_engine.py @@ -38,6 +38,7 @@ from pydantic import BaseModel from typing_extensions import Annotated, get_args, get_origin +import flyte.artifacts._wrapper import flyte.storage as storage from flyte._logging import logger from flyte._utils.helpers import load_proto_from_file @@ -1110,6 +1111,8 @@ def to_literal_checks(cls, python_val: typing.Any, python_type: Type[T], expecte async def to_literal( cls, python_val: typing.Any, python_type: Type[T], expected: types_pb2.LiteralType ) -> literals_pb2.Literal: + if isinstance(python_val, flyte.artifacts._wrapper.ArtifactWrapper): + python_val = python_val._obj transformer = cls.get_transformer(python_type) if transformer.type_assertions_enabled: