diff --git a/src/flyte/__init__.py b/src/flyte/__init__.py
index 86e002e9d..599d4a133 100644
--- a/src/flyte/__init__.py
+++ b/src/flyte/__init__.py
@@ -28,6 +28,7 @@
 from ._logging import logger
 from ._map import map
 from ._pod import PodTemplate
+from ._replay import replay, with_replaycontext
 from ._resources import AMD_GPU, GPU, HABANA_GAUDI, TPU, Device, DeviceClass, Neuron, Resources
 from ._retry import RetryStrategy
 from ._reusable_environment import ReusePolicy
@@ -97,11 +98,13 @@ def version() -> str:
     "init_passthrough",
     "logger",
     "map",
+    "replay",
     "run",
     "run_python_script",
     "serve",
     "trace",
     "version",
+    "with_replaycontext",
     "with_runcontext",
     "with_servecontext",
 ]
diff --git a/src/flyte/_replay.py b/src/flyte/_replay.py
new file mode 100644
--- /dev/null
+++ b/src/flyte/_replay.py
@@ -0,0 +1,303 @@
+from __future__ import annotations
+
+import pathlib
+from typing import TYPE_CHECKING, Any, cast
+
+from flyte._environment import Environment
+from flyte._initialize import get_client, get_init_config
+from flyte._run import Mode
+from flyte._task import F, P, R, TaskTemplate
+from flyte.models import (
+    ActionID,
+    RawDataPath,
+    SerializationContext,
+    TaskContext,
+)
+from flyte.syncify import syncify
+
+if TYPE_CHECKING:
+    from flyte.remote import Run
+
+
+class _Replayer:
+    """Internal class that handles replay logic."""
+
+    def __init__(
+        self,
+        mode: Mode | None = None,
+    ) -> None:
+        self._mode = mode
+
+    @syncify
+    async def replay(
+        self,
+        run_name: str,
+        action_name: str = "a0",
+        task_template: TaskTemplate | None = None,
+    ) -> Run:
+        """Execute the replay: fetch original run's inputs and RunSpec, then launch a new run.
+
+        :param run_name: Name of the run to replay.
+        :param action_name: Action whose inputs are replayed; "a0" is the root action.
+        :param task_template: Optional replacement task; required for local mode.
+        :return: The newly launched run.
+        :raises ValueError: If the mode is unknown, or local mode has no task_template.
+        :raises NotImplementedError: If hybrid mode is requested.
+        """
+        from flyte._initialize import ensure_client
+        from flyte.remote import ActionDetails, RunDetails
+
+        ensure_client()
+
+        # Determine mode: an explicit mode wins, else remote whenever a client is configured.
+        mode = self._mode
+        if mode is None:
+            mode = "remote" if get_client() is not None else "local"
+
+        # Validate up front so bad arguments fail before any network call.
+        if mode not in ("local", "remote", "hybrid"):
+            raise ValueError(f"Unknown mode: {mode}")
+
+        if mode == "local" and task_template is None:
+            raise ValueError(
+                "Local replay requires a task_template to be provided. "
+                "Without a TaskTemplate, there is no Python function to execute locally."
+            )
+
+        # Step 1: Fetch RunDetails to get the RunSpec and root action details
+        run_details = await RunDetails.get.aio(name=run_name)
+        original_run_spec = run_details.pb2.run_spec
+
+        # Step 2: Get the action details for the requested action
+        if action_name == "a0":
+            # Root action is already available in run_details
+            action_details = run_details.action_details
+        else:
+            action_details = await ActionDetails.get.aio(
+                run_name=run_name,
+                name=action_name,
+            )
+
+        # Step 3: Fetch raw proto inputs via get_action_data
+        from flyteidl2.workflow import run_service_pb2
+
+        resp = await get_client().run_service.get_action_data(
+            request=run_service_pb2.GetActionDataRequest(
+                action_id=action_details.pb2.id,
+            )
+        )
+        raw_inputs = resp.inputs
+
+        # Step 4: Dispatch by mode. Only remote replay needs a wire task_spec, so the image
+        # build and code bundling in _build_task_spec are skipped for the other modes.
+        if mode == "local":
+            return await self._replay_local(None, raw_inputs, action_details, task_template)
+        if mode == "hybrid":
+            return await self._replay_hybrid(None, raw_inputs, action_details, original_run_spec)
+
+        if task_template is None:
+            # Reuse the resolved task spec from the original action
+            task_spec = action_details.pb2.resolved_task_spec
+        else:
+            # Build a fresh task_spec from the provided template
+            task_spec = await self._build_task_spec(task_template)
+        return await self._replay_remote(task_spec, raw_inputs, original_run_spec)
+
+    async def _build_task_spec(self, task: TaskTemplate):
+        """Build a task_spec from a TaskTemplate, including code bundling and image building."""
+        import flyte.report
+        from flyte._code_bundle import build_code_bundle
+        from flyte._deploy import build_images, plan_deploy
+        from flyte._image import Image, resolve_code_bundle_layer
+        from flyte._internal.runtime.task_serde import translate_task_to_wire
+
+        cfg = get_init_config()
+
+        if task.parent_env is None:
+            raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")
+
+        parent_env = cast(Environment, task.parent_env())
+
+        for _env in plan_deploy(parent_env)[0].envs.values():
+            if isinstance(_env.image, Image):
+                _env.image = resolve_code_bundle_layer(_env.image, "loaded_modules", pathlib.Path(cfg.root_dir))
+
+        image_cache = await build_images.aio(parent_env)
+
+        code_bundle = await build_code_bundle(
+            from_dir=cfg.root_dir,
+            dryrun=False,
+            copy_bundle_to=None,
+            copy_style="loaded_modules",
+        )
+
+        version = code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
+        if not version:
+            raise ValueError("Version is required when running a task")
+
+        project = cfg.project
+        domain = cfg.domain
+        org = cfg.org
+
+        s_ctx = SerializationContext(
+            code_bundle=code_bundle,
+            version=version,
+            image_cache=image_cache,
+            root_dir=cfg.root_dir,
+        )
+        action = ActionID(name="{{.actionName}}", run_name="{{.runName}}", project=project, domain=domain, org=org)
+        tctx = TaskContext(
+            action=action,
+            code_bundle=code_bundle,
+            output_path="",
+            version=version,
+            raw_data_path=RawDataPath(path=""),
+            compiled_image_cache=image_cache,
+            run_base_dir="",
+            report=flyte.report.Report(name=action.name),
+        )
+        return translate_task_to_wire(task, s_ctx, default_inputs=None, task_context=tctx)
+
+    async def _replay_remote(self, task_spec, raw_inputs, original_run_spec) -> Run:
+        """Replay by creating a new remote run with the original RunSpec and inputs."""
+        from connectrpc.code import Code
+        from connectrpc.errors import ConnectError
+        from flyteidl2.common import identifier_pb2
+        from flyteidl2.dataproxy import dataproxy_service_pb2
+        from flyteidl2.workflow import run_service_pb2
+
+        import flyte.errors
+        from flyte.remote import Run
+
+        cfg = get_init_config()
+        project_id = identifier_pb2.ProjectIdentifier(
+            name=cfg.project,
+            domain=cfg.domain,
+            organization=cfg.org,
+        )
+
+        # Upload inputs via dataproxy
+        upload_req = dataproxy_service_pb2.UploadInputsRequest(
+            inputs=raw_inputs,
+            task_spec=task_spec,
+        )
+        upload_req.project_id.CopyFrom(project_id)
+        upload_resp = await get_client().dataproxy_service.upload_inputs(upload_req)
+
+        # Create run with original RunSpec
+        try:
+            resp = await get_client().run_service.create_run(
+                run_service_pb2.CreateRunRequest(
+                    project_id=project_id,
+                    task_spec=task_spec,
+                    offloaded_input_data=upload_resp.offloaded_input_data,
+                    run_spec=original_run_spec,
+                ),
+            )
+            return Run(pb2=resp.run)
+        except ConnectError as e:
+            # Chain every translated error to the ConnectError so the RPC cause stays visible.
+            if e.code == Code.UNAVAILABLE:
+                raise flyte.errors.RuntimeSystemError(
+                    "SystemUnavailableError",
+                    "Flyte system is currently unavailable. Check your configuration, or the service status.",
+                ) from e
+            elif e.code == Code.INVALID_ARGUMENT:
+                raise flyte.errors.RuntimeUserError("InvalidArgumentError", e.message) from e
+            elif e.code == Code.ALREADY_EXISTS:
+                raise flyte.errors.RuntimeUserError(
+                    "RunAlreadyExistsError",
+                    "A run with this name already exists. Please choose a different name.",
+                ) from e
+            else:
+                raise flyte.errors.RuntimeSystemError(
+                    "RunCreationError",
+                    f"Failed to create run: {e.message}",
+                ) from e
+
+    async def _replay_local(self, task_spec, raw_inputs, action_details, task_template) -> Run:
+        """Replay locally by converting inputs to native and executing the task."""
+        import flyte.types as types
+        from flyte._internal.runtime import convert
+        from flyte._run import run_task_locally
+
+        task = task_template
+        if task is None:
+            # replay() validates this too; raise rather than assert so the check survives `python -O`.
+            raise ValueError("Local replay requires a task_template to be provided.")
+
+        # Convert proto inputs to native Python types
+        native_iface = None
+        if action_details.pb2.HasField("task"):
+            iface = action_details.pb2.task.task_template.interface
+            native_iface = types.guess_interface(iface)
+
+        if native_iface is None:
+            native_iface = task.native_interface
+
+        native_inputs = await convert.convert_from_inputs_to_native(native_iface, convert.Inputs(raw_inputs))
+
+        # Inputs go through task_kwargs so an input named like one of run_task_locally's own
+        # options (``name``, ``tracker``, ...) reaches the task instead of being swallowed.
+        return await run_task_locally(task, run_label="replay-local", task_kwargs=native_inputs)
+
+    async def _replay_hybrid(self, task_spec, raw_inputs, action_details, original_run_spec) -> Any:
+        """Replay in hybrid mode: run parent locally, children remotely."""
+        # Hybrid replay is not yet implemented, so say exactly that.
+        raise NotImplementedError("Hybrid replay is not implemented yet. Use remote or local mode to replay runs.")
+
+
+def with_replaycontext(
+    mode: Mode | None = None,
+) -> _Replayer:
+    """
+    Create a replay context with the given mode.
+
+    Accepts the same modes as with_runcontext (local, remote, hybrid); hybrid replay is not implemented yet.
+
+    Example::
+
+        flyte.with_replaycontext(mode="remote").replay("my-run-name", action_name="a0")
+
+    :param mode: The execution mode - "local", "remote", or "hybrid". If not provided,
+        defaults to "remote" if a client is configured, else "local".
+    :return: A _Replayer with a .replay() method.
+    """
+    return _Replayer(mode=mode)
+
+
+@syncify
+async def replay(
+    run_name: str,
+    action_name: str = "a0",
+    task_template: TaskTemplate[P, R, F] | None = None,
+) -> Run:
+    """
+    Replay an existing run by re-executing it with the same inputs and RunSpec.
+
+    Retrieves the entire RunSpec and inputs from the original run/action, then launches
+    a new run. If task_template is not provided, the original remote task template is used.
+    If task_template is provided, the new code is bundled and used with the original inputs.
+
+    Example::
+
+        # Replay with original task template
+        flyte.replay("my-run-name")
+
+        # Replay a specific action
+        flyte.replay("my-run-name", action_name="a1")
+
+        # Replay with new code
+        flyte.replay("my-run-name", task_template=my_updated_task)
+
+    :param run_name: The name of the run to replay.
+    :param action_name: The name of the action to replay inputs from. Defaults to "a0" (root action).
+    :param task_template: Optional new TaskTemplate to use. If not provided, the original
+        remote task template is used.
+    :return: A Run object representing the new run.
+    """
+    return await _Replayer().replay.aio(
+        run_name=run_name,
+        action_name=action_name,
+        task_template=task_template,
+    )
diff --git a/src/flyte/_run.py b/src/flyte/_run.py
--- a/src/flyte/_run.py
+++ b/src/flyte/_run.py
@@ -87,6 +87,153 @@ def _get_main_run_mode() -> Mode | None:
     return _run_mode_var.get()
 
 
+async def run_task_locally(
+    task: TaskTemplate,
+    *args: Any,
+    name: Optional[str] = None,
+    metadata_path: Optional[str] = None,
+    raw_data_path_str: Optional[str] = None,
+    tracker: Any = None,
+    custom_context: Optional[Dict[str, str]] = None,
+    disable_run_cache: bool = False,
+    run_label: str = "local-run",
+    action: Optional[ActionID] = None,
+    task_kwargs: Optional[Dict[str, Any]] = None,
+    **kwargs: Any,
+) -> Any:
+    """Shared local execution logic used by both ``_Runner._run_local`` and ``_Replayer._replay_local``.
+
+    Returns a ``_LocalRun`` (a ``Run`` subclass) wrapping the task outputs.
+    Notifications are **not** handled here — callers that need them should wrap
+    this call in try/except.
+
+    :param task: Task to execute through the local controller.
+    :param args: Positional task inputs.
+    :param name: Action name; a random action is created when omitted. Ignored when ``action`` is given.
+    :param metadata_path: Base directory for run metadata; defaults to ``/tmp/flyte/metadata``.
+    :param raw_data_path_str: Raw data path; defaults to a per-action directory under ``/tmp/flyte/raw_data``.
+    :param tracker: Optional tracker placed on the context and handed to the ``RunRecorder``.
+    :param custom_context: Custom context for the task context; ``None`` becomes ``{}``.
+    :param disable_run_cache: Forwarded to the task context.
+    :param run_label: Run name carried by the returned ``_LocalRun``.
+    :param action: Pre-built action, for callers that need the run name before execution starts.
+    :param task_kwargs: Keyword task inputs. Prefer this over ``**kwargs``: an input named like one of
+        this function's own options (``name``, ``tracker``, ...) would otherwise never reach the task.
+    :param kwargs: Further keyword task inputs; ``task_kwargs`` wins on duplicate keys.
+    :return: A ``_LocalRun`` whose ``outputs()`` yields the task outputs.
+    """
+    from flyteidl2.common import identifier_pb2
+    from flyteidl2.task import common_pb2
+    from flyteidl2.workflow import run_definition_pb2
+
+    from flyte._internal.controllers import create_controller
+    from flyte._internal.controllers._local_controller import LocalController
+    from flyte.remote import ActionOutputs, Run
+    from flyte.report import Report
+
+    controller = cast(LocalController, create_controller("local"))
+    call_kwargs = {**kwargs, **(task_kwargs or {})}
+
+    if action is None:
+        action = ActionID.create_random() if name is None else ActionID(name=name)
+
+    _metadata_path: pathlib.Path
+    if metadata_path is None:
+        _metadata_path = pathlib.Path("/") / "tmp" / "flyte" / "metadata" / action.name
+    else:
+        _metadata_path = pathlib.Path(metadata_path) / action.name
+    output_path = _metadata_path / "a0"
+
+    if raw_data_path_str is None:
+        path = pathlib.Path("/") / "tmp" / "flyte" / "raw_data" / action.name
+        raw_data_path = RawDataPath(path=str(path))
+    else:
+        raw_data_path = RawDataPath(path=raw_data_path_str)
+
+    ctx = internal_ctx()
+    tctx = TaskContext(
+        action=action,
+        checkpoints=Checkpoints(
+            prev_checkpoint_path=internal_ctx().raw_data.path,
+            checkpoint_path=internal_ctx().raw_data.path,
+        ),
+        code_bundle=None,
+        output_path=str(output_path),
+        run_base_dir=str(_metadata_path),
+        version="na",
+        raw_data_path=raw_data_path,
+        compiled_image_cache=None,
+        report=Report(name=action.name),
+        mode="local",
+        custom_context=custom_context or {},
+        disable_run_cache=disable_run_cache,
+    )
+
+    if tracker is not None:
+        ctx = Context(ctx.data.replace(tracker=tracker))
+
+    from flyte._initialize import is_persistence_enabled
+    from flyte._persistence._recorder import RunRecorder
+
+    persist = is_persistence_enabled()
+    run_name = action.run_name or action.name
+
+    if persist:
+        RunRecorder.initialize_persistence()
+
+    recorder = RunRecorder(tracker=tracker, persist=persist, run_name=run_name)
+    controller.set_recorder(recorder)
+
+    recorder.record_root_start(task_name=task.name)
+
+    try:
+        with ctx.replace_task_context(tctx):
+            # submit_sync hands back a future; wrap it so it can be awaited here.
+            if task._call_as_synchronous:
+                fut = controller.submit_sync(task, *args, **call_kwargs)
+                awaitable = asyncio.wrap_future(fut)
+                outputs = await awaitable
+            else:
+                outputs = await controller.submit(task, *args, **call_kwargs)
+    except Exception as e:
+        recorder.record_root_failure(error=str(e))
+        raise
+
+    recorder.record_root_complete()
+
+    class _LocalRun(Run):
+        def __init__(self, outputs: Tuple[Any, ...] | Any):
+            self._outputs = ActionOutputs(common_pb2.Outputs(), outputs if isinstance(outputs, tuple) else (outputs,))
+            super().__init__(
+                pb2=run_definition_pb2.Run(
+                    action=run_definition_pb2.Action(
+                        id=identifier_pb2.ActionIdentifier(
+                            name="a0",
+                            run=identifier_pb2.RunIdentifier(name=run_label),
+                        )
+                    )
+                )
+            )
+
+        @property
+        def url(self) -> str:
+            return str(_metadata_path)
+
+        @syncify
+        async def wait(  # type: ignore[override]
+            self,
+            quiet: bool = False,
+            wait_for: Literal["terminal", "running"] = "terminal",
+        ) -> None:
+            pass
+
+        @syncify
+        async def outputs(self) -> ActionOutputs:  # type: ignore[override]
+            return self._outputs
+
+    return _LocalRun(outputs)
+
+
 class _Runner:
     def __init__(
         self,
@@ -620,124 +767,35 @@ async def _send_local_notifications(
         )
 
     async def _run_local(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Run:
-        from flyteidl2.common import identifier_pb2
-        from flyteidl2.task import common_pb2
-
-        from flyte._internal.controllers import create_controller
-        from flyte._internal.controllers._local_controller import LocalController
-        from flyte.remote import ActionOutputs, Run
-        from flyte.report import Report
-
-        controller = cast(LocalController, create_controller("local"))
-
-        if self._name is None:
-            action = ActionID.create_random()
-        else:
-            action = ActionID(name=self._name)
-
-        metadata_path = self._metadata_path
-        if metadata_path is None:
-            metadata_path = pathlib.Path("/") / "tmp" / "flyte" / "metadata" / action.name
-        else:
-            metadata_path = pathlib.Path(metadata_path) / action.name
-        output_path = metadata_path / "a0"
-        if self._raw_data_path is None:
-            path = pathlib.Path("/") / "tmp" / "flyte" / "raw_data" / action.name
-            raw_data_path = RawDataPath(path=str(path))
-        else:
-            raw_data_path = RawDataPath(path=self._raw_data_path)
-
-        ctx = internal_ctx()
-        tctx = TaskContext(
-            action=action,
-            checkpoints=Checkpoints(
-                prev_checkpoint_path=internal_ctx().raw_data.path,
-                checkpoint_path=internal_ctx().raw_data.path,
-            ),
-            code_bundle=None,
-            output_path=str(output_path),
-            run_base_dir=str(metadata_path),
-            version="na",
-            raw_data_path=raw_data_path,
-            compiled_image_cache=None,
-            report=Report(name=action.name),
-            mode="local",
-            custom_context=self._custom_context,
-            disable_run_cache=self._disable_run_cache,
-        )
-
-        if self._tracker is not None:
-            ctx = Context(ctx.data.replace(tracker=self._tracker))
-
-        from flyte._initialize import is_persistence_enabled
-        from flyte._persistence._recorder import RunRecorder
-
-        persist = is_persistence_enabled()
-        run_name = action.run_name or action.name
-
-        if persist:
-            RunRecorder.initialize_persistence()
-
-        recorder = RunRecorder(tracker=self._tracker, persist=persist, run_name=run_name)
-        controller.set_recorder(recorder)
-
-        recorder.record_root_start(task_name=obj.name)
+        # Build the action here so notifications report the real run name, as they did before
+        # the refactor, instead of the task name.
+        action = ActionID.create_random() if self._name is None else ActionID(name=self._name)
+        run_name = action.run_name or action.name
+
         try:
-            with ctx.replace_task_context(tctx):
-                # make the local version always runs on a different thread, returns a wrapped future.
-                if obj._call_as_synchronous:
-                    fut = controller.submit_sync(obj, *args, **kwargs)
-                    awaitable = asyncio.wrap_future(fut)
-                    outputs = await awaitable
-                else:
-                    outputs = await controller.submit(obj, *args, **kwargs)
+            result = await run_task_locally(
+                obj,
+                *args,
+                action=action,
+                metadata_path=self._metadata_path,
+                raw_data_path_str=self._raw_data_path,
+                tracker=self._tracker,
+                custom_context=self._custom_context,
+                disable_run_cache=self._disable_run_cache,
+                run_label="dry-run",
+                task_kwargs=kwargs,
+            )
         except Exception as e:
-            recorder.record_root_failure(error=str(e))
             if self._notifications:
                 await self._send_local_notifications(
                     phase=ActionPhase.FAILED, task_name=obj.name, run_name=run_name, error=str(e)
                 )
             raise
         else:
-            recorder.record_root_complete()
             if self._notifications:
                 await self._send_local_notifications(phase=ActionPhase.SUCCEEDED, task_name=obj.name, run_name=run_name)
-
-        class _LocalRun(Run):
-            def __init__(self, outputs: Tuple[Any, ...] | Any):
-                from flyteidl2.workflow import run_definition_pb2
-
-                self._outputs = ActionOutputs(
-                    common_pb2.Outputs(), outputs if isinstance(outputs, tuple) else (outputs,)
-                )
-                super().__init__(
-                    pb2=run_definition_pb2.Run(
-                        action=run_definition_pb2.Action(
-                            id=identifier_pb2.ActionIdentifier(
-                                name="a0",
-                                run=identifier_pb2.RunIdentifier(name="dry-run"),
-                            )
-                        )
-                    )
-                )
-
-            @property
-            def url(self) -> str:
-                return str(metadata_path)
-
-            @syncify
-            async def wait(  # type: ignore[override]
-                self,
-                quiet: bool = False,
-                wait_for: Literal["terminal", "running"] = "terminal",
-            ) -> None:
-                pass
-
-            @syncify
-            async def outputs(self) -> ActionOutputs:  # type: ignore[override]
-                return self._outputs
-
-        return _LocalRun(outputs)
+        return result
 
     @syncify  # type: ignore[arg-type]
     async def run(