Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions examples/advanced/artifact_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Tuple

import flyte
import flyte.artifacts as artifacts

# Task environment named "artifact_example"; every task in this example is
# registered on it through the @env.task decorator.
env = flyte.TaskEnvironment("artifact_example")


@env.task
def create_artifact() -> str:
    """Return a string output tagged as the ``my_artifact`` artifact, version ``1.0``."""
    payload = "This is my artifact content"
    return artifacts.new(
        payload,
        artifacts.Metadata(
            name="my_artifact",
            version="1.0",
            description="An example artifact created in create_artifact task",
        ),
    )


@env.task
def model_artifact() -> str:
    """Return a string output tagged as a model artifact with an HTML model card.

    The card is created first from inline HTML, then referenced from the model
    metadata that is attached to the returned value.
    """
    payload = "This is my model artifact content"

    # Create the card from raw HTML content before building the metadata that points to it.
    model_card = artifacts.Card.create_from(
        content="<h1>Model Card</h1><p>This is a sample model card.</p>",
        format="html",
        card_type="model",
    )

    model_metadata = artifacts.Metadata.create_model_metadata(
        name="my_model_artifact",
        version="1.0",
        description="An example model artifact created in model_artifact task",
        framework="PyTorch",
        model_type="Neural Network",
        architecture="ResNet50",
        task="Image Classification",
        modality=("image",),
        serial_format="pt",
        short_description="A ResNet50 model for image classification tasks.",
        card=model_card,
    )
    return artifacts.new(payload, model_metadata)


@env.task
def call_artifact() -> Tuple[str, str]:
    """Run both artifact-producing tasks, print each result, and return them as a pair."""
    plain_output = create_artifact()
    print(plain_output)

    model_output = model_artifact()
    print(model_output)

    return plain_output, model_output


@env.task
async def use_artifact(v: str) -> str:
    """Print the received artifact content and return a message that echoes it.

    :param v: Content of the artifact passed as the task input.
    """
    print(f"Using artifact with content: {v}")
    summary = f"Artifact used with content: {v}"
    return summary


@env.task
async def use_multiple_artifacts(v: list[str]) -> str:
    """Print the received list of artifact contents and return a message that echoes it.

    :param v: Contents of the artifacts passed as the task input.
    """
    print(f"Using multiple artifacts with contents: {v}")
    summary = f"Multiple artifacts used with contents: {v}"
    return summary


if __name__ == "__main__":
    flyte.init()

    # 1. Produce the artifacts by running the task that calls both producers.
    producer_run = flyte.run(call_artifact)
    print(producer_run.outputs())

    from flyte.remote import Artifact

    # 2. Fetch one artifact by name/version and pass it as a task input.
    single_artifact = Artifact.get("my_artifact", version="1.0")
    single_run = flyte.run(use_artifact, v=single_artifact)
    print(single_run.outputs())

    # 3. Pass an explicit list of fetched artifacts as a task input.
    fetched_artifacts = [Artifact.get("my_artifact", version="1.0") for _ in range(2)]
    list_run = flyte.run(use_multiple_artifacts, v=fetched_artifacts)
    print(list_run.outputs())

    # 4. Pass the materialized result of a listing call as a task input.
    listed_artifacts = list(Artifact.listall("my_artifact", version="1.0"))
    listing_run = flyte.run(use_multiple_artifacts, v=listed_artifacts)
    print(listing_run.outputs())
16 changes: 14 additions & 2 deletions src/flyte/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,13 +519,25 @@ async def _run_local(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs:
)

with ctx.replace_task_context(tctx):
new_kwargs = {}
if kwargs:
from flyte.remote import Artifact

for k, v in kwargs.items():
if isinstance(v, Artifact):
new_kwargs[k] = v.pb2["data"]
elif isinstance(v, list) and len(v) > 0:
new_kwargs[k] = [item.pb2["data"] if isinstance(item, Artifact) else item for item in v]
else:
new_kwargs[k] = v
print(new_kwargs)
# make the local version always runs on a different thread, returns a wrapped future.
if obj._call_as_synchronous:
fut = controller.submit_sync(obj, *args, **kwargs)
fut = controller.submit_sync(obj, *args, **new_kwargs)
awaitable = asyncio.wrap_future(fut)
outputs = await awaitable
else:
outputs = await controller.submit(obj, *args, **kwargs)
outputs = await controller.submit(obj, *args, **new_kwargs)

class _LocalRun(Run):
def __init__(self, outputs: Tuple[Any, ...] | Any):
Expand Down
37 changes: 37 additions & 0 deletions src/flyte/artifacts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Artifacts module

This module provides a wrapper method to mark certain outputs as artifacts with associated metadata.

Usage example:
```python
import flyte.artifacts as artifacts

@env.task
def my_task() -> MyType:
result = MyType(...)
metadata = artifacts.Metadata(name="my_artifact", version="1.0", description="An example artifact")
return artifacts.new(result, metadata)
```

Launching with known artifacts:
```python
flyte.run(main, x=flyte.remote.Artifact.get("name", version="1.0"))
```

Retrieve a set of artifacts and pass them as a list:
```python
from flyte.remote import Artifact
flyte.run(main, x=[Artifact.get("name1", version="1.0"), Artifact.get("name2", version="2.0")])
```
OR
```python
flyte.run(main, x=flyte.remote.Artifact.list("name_prefix", partition_match="x"))
```
"""

from ._card import Card, CardFormat, CardType
from ._metadata import Metadata
from ._wrapper import Artifact, new

__all__ = ["Artifact", "Card", "CardFormat", "CardType", "Metadata", "new"]
63 changes: 63 additions & 0 deletions src/flyte/artifacts/_card.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import pathlib
import tempfile
from dataclasses import dataclass
from typing import Literal

import flyte
from flyte import storage, syncify

# Allowed values for the kind of card (used as the type of ``Card.card_type``).
CardType = Literal["model", "data", "generic"]
# Allowed values for the card's file format (used as the type of ``Card.format``;
# also used as the temp-file suffix when a card is created from raw content).
CardFormat = Literal["html", "md", "json", "yaml", "csv", "tsv", "png", "jpg", "jpeg"]


@dataclass(frozen=True, kw_only=True)
class Card(object):
    """Immutable reference to an uploaded card document.

    Instances are normally obtained through :meth:`create_from`, which uploads
    the card content and records the resulting location in ``uri``.
    """

    # Location of the uploaded card, as returned by the upload call.
    uri: str
    # File format of the card content.
    format: CardFormat = "html"
    # Kind of card.
    card_type: CardType = "generic"

    @syncify.syncify
    @classmethod
    async def create_from(
        cls,
        *,
        content: str | None = None,
        local_path: pathlib.Path | None = None,
        format: CardFormat = "html",
        card_type: CardType = "generic",
    ) -> Card:
        """
        Upload a card either from raw content or from a local file path.

        If both ``content`` and ``local_path`` are given, ``content`` takes precedence.
        An empty ``content`` string is treated the same as not providing content.

        :param content: Raw content of the card to be uploaded.
        :param local_path: Local file path of the card to be uploaded.
        :param format: Format of the card (e.g., 'html', 'md',
            'json', 'yaml', 'csv', 'tsv', 'png', 'jpg', 'jpeg').
        :param card_type: Type of the card (e.g., 'model', 'data', 'generic').
        :return: A Card pointing at the uploaded location.
        :raises ValueError: If neither ``content`` nor ``local_path`` is provided.
        """
        if content:
            # delete=False keeps the file on disk after it is closed so the upload can read it;
            # we therefore remove it ourselves once the upload has finished (or failed).
            # UTF-8 is used explicitly so non-ASCII content does not depend on the locale encoding.
            temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=f".{format}", delete=False, encoding="utf-8")
            temp_path = pathlib.Path(temp_file.name)
            try:
                # Close (and thereby flush) the file before handing its path to the uploader.
                with temp_file:
                    temp_file.write(content)
                return await _upload_card_from_local(temp_path, format=format, card_type=card_type)
            finally:
                temp_path.unlink(missing_ok=True)
        if local_path:
            return await _upload_card_from_local(local_path, format=format, card_type=card_type)
        raise ValueError("Either content or local_path must be provided to upload a card.")


async def _upload_card_from_local(
    local_path: pathlib.Path, format: CardFormat = "html", card_type: CardType = "generic"
) -> Card:
    """Upload the file at ``local_path`` and return a Card describing the uploaded copy.

    When ``flyte.ctx()`` yields a context, the file is stored under that context's
    ``output_path`` as ``<card_type>.<format>``. Otherwise the file is sent through
    ``flyte.remote.upload_file``.

    :param local_path: Path of the local file to upload.
    :param format: Card format recorded on the returned Card and used in the stored file name.
    :param card_type: Card type recorded on the returned Card and used in the stored file name.
    :return: Card whose ``uri`` is the location reported by the upload call.
    """
    task_ctx = flyte.ctx()
    if not task_ctx:
        # No context available: upload through the remote API instead.
        import flyte.remote as remote

        _, uploaded_uri = await remote.upload_file.aio(local_path)
        return Card(uri=uploaded_uri, format=format, card_type=card_type)

    destination = task_ctx.output_path + f"/{card_type}.{format}"
    uploaded_uri = await storage.put(str(local_path), destination)
    return Card(uri=uploaded_uri, format=format, card_type=card_type)
54 changes: 54 additions & 0 deletions src/flyte/artifacts/_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from typing import Optional, Tuple

from ._card import Card


@dataclass(frozen=True, kw_only=True)
class Metadata:
"""Structured metadata for Flyte artifacts."""

# Core tracking fields
name: str
version: Optional[str] = None
description: Optional[str] = None
data: Optional[typing.Mapping[str, str]] = None
card: Optional[Card] = None

@classmethod
def create_model_metadata(
cls,
*,
name: str,
version: Optional[str] = None,
description: Optional[str] = None,
card: Optional[Card] = None,
framework: Optional[str] = None,
model_type: Optional[str] = None,
architecture: Optional[str] = None,
task: Optional[str] = None,
modality: Tuple[str, ...] = ("text",),
serial_format: str = "safetensors",
short_description: Optional[str] = None,
) -> Metadata:
"""
Helper method to create ModelMetadata. This method sets the data keys specific to models.
"""
return cls(
name=name,
version=version,
description=description,
data={
"framework": framework or "",
"model_type": model_type or "",
"architecture": architecture or "",
"task": task or "",
"modality": ",".join(modality) if modality else "",
"serial_format": serial_format or "",
"short_description": short_description or "",
},
card=card,
)
116 changes: 116 additions & 0 deletions src/flyte/artifacts/_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from typing import Any, Protocol, TypeVar, runtime_checkable

from typing_extensions import ParamSpec

from ._metadata import Metadata

# Covariant type variable parameterizing the Artifact protocol (and used to
# annotate the wrapped object in ArtifactWrapper.__init__).
T_co = TypeVar("T_co", covariant=True)
# Invariant type variable tying the argument type of ``new`` to its return type.
T = TypeVar("T")
# NOTE(review): P is not referenced anywhere in the visible part of this module — confirm it is needed.
P = ParamSpec("P")


@runtime_checkable
class Artifact(Protocol[T_co]):
    """Structural type for any object that carries Flyte metadata."""

    # Metadata attached to the object.
    _flyte_metadata: Metadata

    def get_flyte_metadata(self) -> Metadata:
        """Return the Flyte metadata attached to this artifact."""


class ArtifactWrapper:
    """Zero-copy wrapper that preserves the original object interface.

    The wrapper stores the wrapped object and its metadata in two slots and
    forwards every other attribute access to the wrapped object. ``__class__``
    is reported as the wrapped object's type.
    """

    __slots__ = ("_flyte_metadata", "_obj")

    def __init__(self, obj: T_co, metadata: Metadata) -> None:
        # Use object.__setattr__ so the slots are written directly on the wrapper.
        object.__setattr__(self, "_obj", obj)
        object.__setattr__(self, "_flyte_metadata", metadata)

    def __getattr__(self, name: str) -> Any:
        """Fallback lookup, invoked only after ``__getattribute__`` raised AttributeError.

        The wrapper's own slots are never forwarded: if they are missing (for
        example on an instance created via ``__new__`` without ``__init__``),
        raise AttributeError instead of recursing through ``self._obj``.
        """
        if name in ("_obj", "_flyte_metadata"):
            raise AttributeError(name)
        # Read the slot via object.__getattribute__ so an unset slot surfaces as a
        # plain AttributeError rather than re-entering this method indefinitely.
        return getattr(object.__getattribute__(self, "_obj"), name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Set the wrapper's own slots locally; forward everything else to the wrapped object."""
        if name in ("_obj", "_flyte_metadata"):
            object.__setattr__(self, name, value)
        else:
            setattr(self._obj, name, value)

    def __delattr__(self, name: str) -> None:
        """Delete an attribute on the wrapped object; the wrapper's own slots cannot be deleted."""
        if name in ("_obj", "_flyte_metadata"):
            raise AttributeError(f"Cannot delete {name}")
        delattr(self._obj, name)

    def __getattribute__(self, name: str) -> Any:
        """Override attribute access to make __class__ return the wrapped object's type."""
        if name == "__class__":
            return type(object.__getattribute__(self, "_obj"))
        elif name in (
            "_obj",
            "_flyte_metadata",
            "get_flyte_metadata",
            "__call__",
            "__repr__",
            "__str__",
            "__bool__",
            "__len__",
            "__iter__",
            "__getitem__",
            "__setitem__",
            "__contains__",
        ):
            # Names implemented by the wrapper itself are resolved on the wrapper.
            return object.__getattribute__(self, name)
        else:
            # Everything else is looked up on the wrapped object.
            return getattr(object.__getattribute__(self, "_obj"), name)

    def get_flyte_metadata(self) -> Metadata:
        """Get a copy of the Flyte metadata."""
        import copy

        return copy.deepcopy(self._flyte_metadata)

    # Forward common special methods for better compatibility
    def __str__(self) -> str:
        return str(self._obj)

    def __repr__(self) -> str:
        return f"Artifact[{type(self._obj).__name__}]({self._obj})"

    def __bool__(self) -> bool:
        return bool(self._obj)

    def __len__(self) -> int:
        return len(self._obj)

    def __iter__(self):
        return iter(self._obj)

    def __call__(self, *args: Any, **kwargs: Any):
        return self._obj(*args, **kwargs)

    def __getitem__(self, key):
        return self._obj[key]

    def __setitem__(self, key, value):
        self._obj[key] = value

    def __contains__(self, item):
        return item in self._obj


def new(obj: T, metadata: Metadata) -> T:
    """
    Attach Flyte metadata to an object by wrapping it in an ``ArtifactWrapper``.

    Args:
        obj: The object to wrap
        metadata: Metadata to associate with the object

    Returns:
        The wrapper around ``obj``; it is typed as the original object's type and
        exposes the attached metadata via get_flyte_metadata()
    """
    # The wrapper is returned under the declared type T, hence the ignore below.
    return ArtifactWrapper(obj, metadata)  # type: ignore[return-value]
2 changes: 2 additions & 0 deletions src/flyte/remote/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"ActionInputs",
"ActionOutputs",
"App",
"Artifact",
"Project",
"Run",
"RunDetails",
Expand All @@ -23,6 +24,7 @@
]

from ._action import Action, ActionDetails, ActionInputs, ActionOutputs
from ._artifact import Artifact
from ._app import App
from ._client.auth import create_channel
from ._data import upload_dir, upload_file
Expand Down
Loading
Loading