diff --git a/examples/stress/README.md b/examples/stress/README.md new file mode 100644 index 000000000..b36edf5a8 --- /dev/null +++ b/examples/stress/README.md @@ -0,0 +1,50 @@ +# Stress Examples + +This directory contains ad hoc stress and failure-mode examples for Flyte and Union dogfood testing. + +## Primary Entry Point + +Use [sleep_fanout_harness_wrapper.sh](sleep_fanout_harness_wrapper.sh) for multi-run `core-sleep` fanout tests: + +```bash +examples/stress/sleep_fanout_harness_wrapper.sh \ + --config ~/.flyte/config-dogfood.yaml \ + --total-runs 10 \ + --submit-concurrency 10 \ + --n-children 1000 \ + --sleep-duration 600 \ + --poll-interval 1 \ + --run-env _F_MAX_QPS=150 \ + --run-env _F_CTRL_WORKERS=20 \ + --run-env _F_P_CNC=1000 +``` + +This wrapper: +- submits many top-level `sleep_fanout` runs through `flyte run` +- tracks aggregate child visibility and running counts +- prints parent-run counts (`p_live`, `p_run`) and child creation rate (`create_rps`, `rps/p`) + +The underlying task definitions live in [sleep_fanout.py](sleep_fanout.py), and the local submit helper lives in [sleep_fanout_harness.py](sleep_fanout_harness.py). + +## Key Files + +- [sleep_fanout.py](sleep_fanout.py): `core-sleep` leaf task, parent fanout task, and swarm submit task definitions. +- [sleep_fanout_harness.py](sleep_fanout_harness.py): local async submit harness used by the wrapper. +- [runs_per_second.py](runs_per_second.py): launch-rate test helper. +- [fanout_concurrency.py](fanout_concurrency.py): simple fanout/concurrency experiment. +- [large_fanout.py](large_fanout.py): wide fanout example. 
+- [duplicate_action_id.py](duplicate_action_id.py): action-id collision / dedupe behavior probe. +- [crash_recovery_trace.py](crash_recovery_trace.py), [long_recovery.py](long_recovery.py), [fast_crasher.py](fast_crasher.py): controller and recovery failure scenarios. +- [cpu_gremlin.py](cpu_gremlin.py), [network_gremlin.py](network_gremlin.py): fault-injection style workload examples. +- [large_file_io.py](large_file_io.py), [large_dir_io.py](large_dir_io.py), [benchmark/large_io_comparison.py](benchmark/large_io_comparison.py): large I/O stress examples. +- [scale_test_same_image.py](scale_test_same_image.py), [scale_test_varied_images.py](scale_test_varied_images.py), [image_builds.py](image_builds.py): image build and scale tests. + +## Notes + +- `sleep_fanout` leaves use the `core-sleep` plugin, so the children run in leaseworker instead of creating task pods. +- Parent resource defaults for fanout are controlled in `sleep_fanout.py` via: + - `FLYTE_STRESS_FANOUT_CPU_REQUEST` + - `FLYTE_STRESS_FANOUT_CPU_LIMIT` + - `FLYTE_STRESS_FANOUT_MEMORY_REQUEST` + - `FLYTE_STRESS_FANOUT_MEMORY_LIMIT` +- Remote image contents come from the built wheel in `dist/`, not directly from local `src/`. If the wrapper warns that `src/flyte` is newer than the wheel, rebuild the wheel before relying on SDK changes in remote runs. 
diff --git a/examples/stress/sleep_fanout.py b/examples/stress/sleep_fanout.py index c1b637561..30b640e10 100644 --- a/examples/stress/sleep_fanout.py +++ b/examples/stress/sleep_fanout.py @@ -1,25 +1,104 @@ import asyncio +import os from datetime import timedelta import flyte import flyte.report from flyte.extras import Sleep +_STRESS_IMAGE_REGISTRY = os.getenv("FLYTE_STRESS_IMAGE_REGISTRY") +_STRESS_IMAGE_NAME = os.getenv("FLYTE_STRESS_IMAGE_NAME") +_STRESS_IMAGE_PLATFORMS = tuple( + p.strip() for p in os.getenv("FLYTE_STRESS_IMAGE_PLATFORMS", "linux/amd64").split(",") if p.strip() +) +_STRESS_RUNTIME_ENV = { + k: v + for k, v in { + "FLYTE_STRESS_IMAGE_REGISTRY": _STRESS_IMAGE_REGISTRY, + "FLYTE_STRESS_IMAGE_NAME": _STRESS_IMAGE_NAME, + "FLYTE_STRESS_IMAGE_PLATFORMS": ",".join(_STRESS_IMAGE_PLATFORMS), + }.items() + if v +} + +# Let remote runs redirect image builds to a writable registry without +# touching the task definitions. For dogfood, this can point at the shared ECR +# repo used for ad hoc SDK test images. Default to amd64-only so the first +# build is faster and matches the dogfood cluster architecture. +stress_image = flyte.Image.from_debian_base( + python_version=(3, 12), + registry=_STRESS_IMAGE_REGISTRY, + name=_STRESS_IMAGE_NAME, + platform=_STRESS_IMAGE_PLATFORMS, +) + + +def _fanout_resources() -> flyte.Resources: + # Default for the distributed harness shape. Override these env vars when + # testing a single huge parent that needs much more headroom. 
+ cpu_request = int(os.getenv("FLYTE_STRESS_FANOUT_CPU_REQUEST", "1")) + cpu_limit = int(os.getenv("FLYTE_STRESS_FANOUT_CPU_LIMIT", "2")) + memory_request = os.getenv("FLYTE_STRESS_FANOUT_MEMORY_REQUEST", "2Gi") + memory_limit = os.getenv("FLYTE_STRESS_FANOUT_MEMORY_LIMIT", "4Gi") + return flyte.Resources(cpu=(cpu_request, cpu_limit), memory=(memory_request, memory_limit)) + + +def _controller_tuning_env() -> dict[str, str]: + env: dict[str, str] = {} + for key in ( + "_F_MAX_QPS", + "_F_CTRL_WORKERS", + "_F_P_CNC", + "_U_USE_ACTIONS", + "_F_TRACE_SUBMIT", + "_F_TRACE_SUBMIT_LIMIT", + ): + value = os.getenv(key) + if value is not None: + env[key] = value + return env + + +def _nested_run_env() -> dict[str, str]: + return { + **_STRESS_RUNTIME_ENV, + **_controller_tuning_env(), + } + + +def _controller_tuning_summary() -> str: + env = _controller_tuning_env() + return ( + "controller_env " + f"_F_MAX_QPS={env.get('_F_MAX_QPS', '')} " + f"_F_CTRL_WORKERS={env.get('_F_CTRL_WORKERS', '')} " + f"_F_P_CNC={env.get('_F_P_CNC', '')} " + f"_U_USE_ACTIONS={env.get('_U_USE_ACTIONS', '')} " + f"_F_TRACE_SUBMIT={env.get('_F_TRACE_SUBMIT', '')} " + f"_F_TRACE_SUBMIT_LIMIT={env.get('_F_TRACE_SUBMIT_LIMIT', '')}" + ) + # Leaves run in leaseworker via the core-sleep plugin: no task pods are created, # so we can fan out wide without paying pod-startup cost. 
sleep_env = flyte.TaskEnvironment( name="sleep_fanout_leaf", + image=stress_image, + env_vars=_STRESS_RUNTIME_ENV, plugin_config=Sleep(), ) fanout_env = flyte.TaskEnvironment( name="sleep_fanout", - resources=flyte.Resources(cpu="50m", memory="200Mi"), + image=stress_image, + env_vars=_STRESS_RUNTIME_ENV, + resources=_fanout_resources(), depends_on=[sleep_env], ) swarm_env = flyte.TaskEnvironment( name="sleep_fanout_swarm", + image=stress_image, + env_vars=_STRESS_RUNTIME_ENV, resources=flyte.Resources(cpu=1, memory="500Mi"), depends_on=[fanout_env], ) @@ -41,8 +120,15 @@ async def sleep_fanout( All leaves run in leaseworker via the core-sleep plugin, so no task pods are created. """ + print( + f"fanout_inputs n_children={n_children} " + f"sleep_duration={sleep_duration} " + f"sleep_seconds={sleep_duration.total_seconds()}", + flush=True, + ) + print(_controller_tuning_summary(), flush=True) await asyncio.gather(*(sleep_leaf(duration=sleep_duration) for _ in range(n_children))) - print(f"Done. Total leaves: {n_children}") + print(f"Done. Total leaves: {n_children}", flush=True) return n_children @@ -60,13 +146,19 @@ async def submit_runs( from aiolimiter import AsyncLimiter limiter = AsyncLimiter(max_rps, 1) + child_run_env = _nested_run_env() - async def submit_one() -> str: + async def submit_one(idx: int) -> str: async with limiter: - run = await flyte.run.aio(sleep_fanout, n_children=n_children, sleep_duration=sleep_duration) + run = await flyte.run.aio( + sleep_fanout.override(env_vars=child_run_env), + n_children=n_children, + sleep_duration=sleep_duration, + ) + print(f"submitted_run idx={idx} url={run.url}", flush=True) return run.url - urls = await asyncio.gather(*(submit_one() for _ in range(n_runs))) + urls = await asyncio.gather(*(submit_one(i) for i in range(n_runs))) print(f"Swarm worker done. 
Submitted {len(urls)} runs at <= {max_rps} rps.") return list(urls) diff --git a/examples/stress/sleep_fanout_harness.py b/examples/stress/sleep_fanout_harness.py index 3e500cbb8..541408018 100644 --- a/examples/stress/sleep_fanout_harness.py +++ b/examples/stress/sleep_fanout_harness.py @@ -1,5 +1,5 @@ """ -Submit N copies of the sleep_fanout `main` task as fast as possible. +Submit N copies of the `sleep_fanout` task through the `flyte run` CLI. Each run spawns n_children core-sleep leaves in leaseworker (no task pods). Submissions are launched with a bounded semaphore to cap in-flight TCP @@ -9,34 +9,123 @@ import argparse import asyncio +import os +import pathlib +import re +import shutil import time from datetime import timedelta -from sleep_fanout import sleep_fanout as sleep_fanout_main +RUN_URL_RE = re.compile(r"URL:\s+(\S+/runs/[^/?\s]+)") +RUN_NAME_RE = re.compile(r"Created Run:\s+([^\s]+)") +RUNS_FILE = os.getenv("FLYTE_HARNESS_RUNS_FILE") +REPO_ROOT = pathlib.Path(__file__).resolve().parents[2] +LOCAL_SDK_SRC = REPO_ROOT / "src" +FLYTE_BIN = os.getenv("FLYTE_HARNESS_FLYTE_BIN") or shutil.which("flyte") or "flyte" +FORCE_LOCAL_SDK = os.getenv("FLYTE_HARNESS_FORCE_LOCAL_SDK", "").lower() in {"1", "true", "yes", "on"} -import flyte + +def _subprocess_env() -> dict[str, str]: + env = os.environ.copy() + if FORCE_LOCAL_SDK: + existing = env.get("PYTHONPATH", "") + local_src = str(LOCAL_SDK_SRC) + env["PYTHONPATH"] = f"{local_src}:{existing}" if existing else local_src + return env + + +def _append_run_name(path: str, name: str) -> None: + with open(path, "a", encoding="utf-8") as f: + f.write(f"{name}\n") async def submit_one(sem: asyncio.Semaphore, idx: int, n_children: int, sleep_duration: timedelta) -> str | None: async with sem: + os.environ.setdefault("_U_USE_ACTIONS", "1") + config = os.getenv("FLYTE_HARNESS_CONFIG", os.path.expanduser("~/.flyte/config-dogfood.yaml")) + image_builder = os.getenv("FLYTE_HARNESS_IMAGE_BUILDER", "remote") + project = 
os.getenv("FLYTE_HARNESS_PROJECT", "") + domain = os.getenv("FLYTE_HARNESS_DOMAIN", "") + run_env_keys = tuple( + k + for k in ( + "_F_MAX_QPS", + "_F_CTRL_WORKERS", + "_F_P_CNC", + "_U_USE_ACTIONS", + "_F_TRACE_SUBMIT", + "_F_TRACE_SUBMIT_LIMIT", + ) + if os.getenv(k) + ) + + cmd = [FLYTE_BIN, "-c", config, "--image-builder", image_builder, "run"] + if project: + cmd.extend(["-p", project]) + if domain: + cmd.extend(["-d", domain]) + for key in run_env_keys: + cmd.extend(["--env", f"{key}={os.environ[key]}"]) + cmd.extend( + [ + "examples/stress/sleep_fanout.py", + "sleep_fanout", + "--n_children", + str(n_children), + "--sleep_duration", + f"PT{int(sleep_duration.total_seconds())}S", + ] + ) + try: - run = await flyte.with_runcontext("remote").run.aio( - sleep_fanout_main, - n_children=n_children, - sleep_duration=sleep_duration, + proc = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + env=_subprocess_env(), ) - return run.url + + output_lines: list[str] = [] + assert proc.stdout is not None + while True: + line = await proc.stdout.readline() + if not line: + break + text = line.decode("utf-8", errors="replace").rstrip() + output_lines.append(text) + + rc = await proc.wait() + output = "\n".join(output_lines) + if rc != 0: + print(f"[{idx}] submit failed rc={rc} output={output!r}", flush=True) + return None + + url_match = RUN_URL_RE.search(output) + if url_match: + return url_match.group(1) + + name_match = RUN_NAME_RE.search(output) + if name_match: + return name_match.group(1) + + if "/runs/" in output: + print(f"[{idx}] submit failed: partial run URL parse failure output={output!r}", flush=True) + else: + print(f"[{idx}] submit failed: could not parse run id from output={output!r}", flush=True) + return None + return None except Exception as e: cause = getattr(e, "__cause__", None) print(f"[{idx}] submit failed: {type(e).__name__}: {e!r} cause={cause!r}", flush=True) return None -async def 
submit_many(total: int, concurrency: int, n_children: int, sleep_duration: timedelta) -> None: +async def submit_many(total: int, concurrency: int, n_children: int, sleep_duration: timedelta) -> int: sem = asyncio.Semaphore(concurrency) start = time.monotonic() submitted = 0 failed = 0 + runs_file_lock = asyncio.Lock() async def wrapped(i: int): nonlocal submitted, failed @@ -45,7 +134,10 @@ async def wrapped(i: int): failed += 1 else: submitted += 1 - print(f"[{i}] {name}", flush=True) + if RUNS_FILE: + async with runs_file_lock: + await asyncio.to_thread(_append_run_name, RUNS_FILE, name) + print(f"submitted_run idx={i} url={name}", flush=True) done = submitted + failed if done % 100 == 0: elapsed = time.monotonic() - start @@ -57,6 +149,7 @@ async def wrapped(i: int): elapsed = time.monotonic() - start rps = submitted / elapsed if elapsed > 0 else 0 print(f"\nDone. submitted={submitted} failed={failed} elapsed={elapsed:.2f}s rps={rps:.2f}") + return 1 if failed else 0 # python stress/sleep_fanout_harness.py --total 25000 --concurrency 500 --n_children 10 --sleep_seconds 10 @@ -68,8 +161,7 @@ def main() -> None: parser.add_argument("--sleep_seconds", type=int, default=10) args = parser.parse_args() - flyte.init_from_config() - asyncio.run( + rc = asyncio.run( submit_many( total=args.total, concurrency=args.concurrency, @@ -77,6 +169,7 @@ def main() -> None: sleep_duration=timedelta(seconds=args.sleep_seconds), ) ) + raise SystemExit(rc) if __name__ == "__main__": diff --git a/examples/stress/sleep_fanout_harness_wrapper.sh b/examples/stress/sleep_fanout_harness_wrapper.sh new file mode 100755 index 000000000..3511dea01 --- /dev/null +++ b/examples/stress/sleep_fanout_harness_wrapper.sh @@ -0,0 +1,711 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." 
&& pwd)" + +CONFIG="${HOME}/.flyte/config-dogfood.yaml" +PROJECT="" +DOMAIN="" +IMAGE_REGISTRY="${FLYTE_STRESS_IMAGE_REGISTRY:-376129846803.dkr.ecr.us-east-2.amazonaws.com/union}" +IMAGE_NAME="${FLYTE_STRESS_IMAGE_NAME:-dogfood}" +IMAGE_PLATFORMS="${FLYTE_STRESS_IMAGE_PLATFORMS:-linux/amd64}" +IMAGE_BUILDER="${FLYTE_STRESS_IMAGE_BUILDER:-remote}" +FANOUT_CPU_REQUEST="${FLYTE_STRESS_FANOUT_CPU_REQUEST:-1}" +FANOUT_CPU_LIMIT="${FLYTE_STRESS_FANOUT_CPU_LIMIT:-2}" +FANOUT_MEMORY_REQUEST="${FLYTE_STRESS_FANOUT_MEMORY_REQUEST:-2Gi}" +FANOUT_MEMORY_LIMIT="${FLYTE_STRESS_FANOUT_MEMORY_LIMIT:-4Gi}" +TOTAL_RUNS=20 +SUBMIT_CONCURRENCY=100 +N_CHILDREN=5000 +SLEEP_DURATION=800 +POLL_INTERVAL=2 +ABORT_REASON="wrapper interrupted" +RUN_ENV_KVS=() + +EXPECTED_TOTAL_CHILDREN=0 +CHILD_RUNS=() + +LAUNCH_PID="" +LAUNCH_LOG="" +LAUNCH_RC_FILE="" +RUNS_FILE="" +LAUNCH_DONE=0 +LAUNCH_RC=0 +STOPPING=0 +ABORT_SENT=0 + +SCRIPT_START_EPOCH="$(date +%s)" +FIRST_DISCOVERED_AT="" +FIRST_RUNNING_AT="" +ALL_VISIBLE_AT="" +TERMINAL_AT="" + +PEAK_SEEN=0 +PEAK_RUNNING=0 +PEAK_ACTIVE=0 +PEAK_CREATE_RPS=0 +PEAK_PARENT_LIVE=0 +PEAK_PARENT_RUNNING=0 +LAST_LAUNCH_STAGE="" +SDK_WHEEL_PATH="" +SDK_SRC_NEWER=0 + +usage() { + cat <<'EOF' +Usage: + examples/stress/sleep_fanout_harness_wrapper.sh [options] + +Options: + --config PATH Flyte config path. Default: ~/.flyte/config-dogfood.yaml + --project NAME Override project for get/abort. + --domain NAME Override domain for get/abort. + --image-registry VALUE Registry prefix for the task image. + --image-name VALUE Repository name for the task image. + --image-builder VALUE Flyte image builder to use for lookups. Default: remote + --run-env KEY=VALUE Export an env var into the local submit harness and propagate it to remote runs. + --total-runs INT Number of top-level sleep_fanout runs to submit. Default: 20 + --submit-concurrency INT Local submission concurrency. Default: 100 + --n-children INT Leaves per sleep_fanout run. 
Default: 5000 + --sleep-duration VALUE Sleep duration in seconds per leaf. Default: 800 + --poll-interval SEC Poll interval in seconds. Default: 2 + --abort-reason TEXT Reason passed to 'flyte abort run'. Default: wrapper interrupted + --help Show this message. +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --config) + CONFIG="$2" + shift 2 + ;; + --project) + PROJECT="$2" + shift 2 + ;; + --domain) + DOMAIN="$2" + shift 2 + ;; + --image-registry) + IMAGE_REGISTRY="$2" + shift 2 + ;; + --image-name) + IMAGE_NAME="$2" + shift 2 + ;; + --image-builder) + IMAGE_BUILDER="$2" + shift 2 + ;; + --run-env) + RUN_ENV_KVS+=("$2") + shift 2 + ;; + --total-runs) + TOTAL_RUNS="$2" + shift 2 + ;; + --submit-concurrency) + SUBMIT_CONCURRENCY="$2" + shift 2 + ;; + --n-children) + N_CHILDREN="$2" + shift 2 + ;; + --sleep-duration) + SLEEP_DURATION="$2" + shift 2 + ;; + --poll-interval) + POLL_INTERVAL="$2" + shift 2 + ;; + --abort-reason) + ABORT_REASON="$2" + shift 2 + ;; + --help) + usage + exit 0 + ;; + *) + echo "Unknown option: $1" >&2 + usage >&2 + exit 2 + ;; + esac +done + +EXPECTED_TOTAL_CHILDREN=$((TOTAL_RUNS * N_CHILDREN)) + +if ! command -v flyte >/dev/null 2>&1; then + echo "flyte is required but was not found in PATH." >&2 + exit 1 +fi + +if ! command -v jq >/dev/null 2>&1; then + echo "jq is required but was not found in PATH." >&2 + exit 1 +fi + +if ! command -v python >/dev/null 2>&1; then + echo "python is required but was not found in PATH." 
>&2 + exit 1 +fi + +CONFIG="${CONFIG/#\~/${HOME}}" +LAUNCH_LOG="$(mktemp "${TMPDIR:-/tmp}/sleep-fanout-harness.XXXXXX.log")" +LAUNCH_RC_FILE="$(mktemp "${TMPDIR:-/tmp}/sleep-fanout-harness.XXXXXX.rc")" +RUNS_FILE="$(mktemp "${TMPDIR:-/tmp}/sleep-fanout-runs.XXXXXX.txt")" + +cleanup() { + if [[ -n "${LAUNCH_LOG}" && -f "${LAUNCH_LOG}" ]]; then + rm -f "${LAUNCH_LOG}" + fi + if [[ -n "${LAUNCH_RC_FILE}" && -f "${LAUNCH_RC_FILE}" ]]; then + rm -f "${LAUNCH_RC_FILE}" + fi + if [[ -n "${RUNS_FILE}" && -f "${RUNS_FILE}" ]]; then + rm -f "${RUNS_FILE}" + fi +} +trap cleanup EXIT + +project_args=() +domain_args=() + +if [[ -n "${PROJECT}" ]]; then + project_args=(-p "${PROJECT}") +fi + +if [[ -n "${DOMAIN}" ]]; then + domain_args=(-d "${DOMAIN}") +fi + +flyte_cmd_json() { + COLUMNS=500 _U_USE_ACTIONS="${_U_USE_ACTIONS:-1}" flyte -c "${CONFIG}" --image-builder "${IMAGE_BUILDER}" -of json-raw "$@" \ + | perl -pe 's/\e\[[0-9;]*[A-Za-z]//g' +} + +is_terminal_phase() { + case "$1" in + ACTION_PHASE_SUCCEEDED|ACTION_PHASE_FAILED|ACTION_PHASE_ABORTED|ACTION_PHASE_TIMED_OUT) + return 0 + ;; + *) + return 1 + ;; + esac +} + +format_duration() { + local value="$1" + if [[ -z "${value}" ]]; then + echo "n/a" + return + fi + printf '%02dh:%02dm:%02ds' "$((value / 3600))" "$(((value % 3600) / 60))" "$((value % 60))" +} + +elapsed_from_start() { + local epoch="$1" + if [[ -z "${epoch}" ]]; then + echo "" + return + fi + echo "$((epoch - SCRIPT_START_EPOCH))" +} + +print_row() { + printf '%-8s %-12s %-8s %-8s %-18s %-14s %-8s %-10s %-8s %-10s %-8s %-8s\n' \ + "$(date +%H:%M:%S)" "$1" "$2" "$3" "$4" "$5" "$6" "$7" "$8" "$9" "${10}" "${11}" +} + +sanitize_run_name() { + local value="$1" + value="$(printf '%s' "${value}" | tr -d '\r')" + value="$(printf '%s' "${value}" | sed 's/[[:space:]]*$//')" + value="$(printf '%s' "${value}" | grep -Eo '[ur][[:alnum:]]{5,}' | head -n 1 || true)" + printf '%s' "${value}" +} + +detect_sdk_wheel_status() { + local wheel_path="" + local newest_src="" + + 
wheel_path="$( + find "${REPO_ROOT}/dist" -maxdepth 1 -type f -name 'flyte-*.whl' -print 2>/dev/null \ + | sort \ + | tail -n 1 || true + )" + SDK_WHEEL_PATH="${wheel_path}" + SDK_SRC_NEWER=0 + + if [[ -z "${wheel_path}" ]]; then + return + fi + + newest_src="$( + find "${REPO_ROOT}/src/flyte" -type f -newer "${wheel_path}" -print 2>/dev/null \ + | head -n 1 || true + )" + if [[ -n "${newest_src}" ]]; then + SDK_SRC_NEWER=1 + fi +} + +child_run_known() { + local target="$1" + local existing="" + for existing in "${CHILD_RUNS[@]}"; do + if [[ "${existing}" == "${target}" ]]; then + return 0 + fi + done + return 1 +} + +discover_child_runs() { + local child_run="" + + if [[ -s "${RUNS_FILE}" ]]; then + while IFS= read -r child_run; do + child_run="$(sanitize_run_name "${child_run}")" + [[ -z "${child_run}" ]] && continue + if ! child_run_known "${child_run}"; then + CHILD_RUNS+=("${child_run}") + fi + done < "${RUNS_FILE}" + fi + + if [[ -s "${LAUNCH_LOG}" ]]; then + while IFS= read -r child_run; do + child_run="$(sanitize_run_name "${child_run}")" + [[ -z "${child_run}" ]] && continue + if ! 
child_run_known "${child_run}"; then + CHILD_RUNS+=("${child_run}") + fi + done < <( + perl -ne ' + s/\e\[[0-9;]*[A-Za-z]//g; + s/\r/\n/g; + if (/submitted_run idx=\d+ url=.*\/runs\/([^\/?\s]+)/) { + print "$1\n"; + } elsif (/submitted_run idx=\d+ url=([ur][[:alnum:]]{5,})/) { + print "$1\n"; + } + ' "${LAUNCH_LOG}" + ) + fi +} + +fetch_actions_json_for_run() { + local run_name="$1" + flyte_cmd_json get action "${project_args[@]}" "${domain_args[@]}" "${run_name}" +} + +aggregate_child_runs_tsv() { + local discovered=0 + local roots_terminal=0 + local parent_live=0 + local parent_running=0 + local seen=0 + local queued=0 + local waiting=0 + local initializing=0 + local running=0 + local succeeded=0 + local failed=0 + local aborted=0 + local timed_out=0 + local run_name="" + local json="" + local root_phase="" + local c_seen=0 + local c_queued=0 + local c_waiting=0 + local c_initializing=0 + local c_running=0 + local c_succeeded=0 + local c_failed=0 + local c_aborted=0 + local c_timed_out=0 + local not_created=0 + local active=0 + + discovered="${#CHILD_RUNS[@]}" + for run_name in "${CHILD_RUNS[@]}"; do + if ! 
json="$(fetch_actions_json_for_run "${run_name}" 2>/dev/null)"; then + continue + fi + + IFS=$'\t' read -r root_phase c_seen c_queued c_waiting c_initializing c_running c_succeeded c_failed c_aborted c_timed_out \ + <<<"$(jq -r ' + [ .[] ] as $all + | ($all | map(select(.id.name == "a0")) | .[0]) as $root + | [ $all[] | select(.id.name != "a0") ] as $kids + | [ + ($root.status.phase // "MISSING"), + ($kids | length), + ($kids | map(select(.status.phase == "ACTION_PHASE_QUEUED")) | length), + ($kids | map(select(.status.phase == "ACTION_PHASE_WAITING_FOR_RESOURCES")) | length), + ($kids | map(select(.status.phase == "ACTION_PHASE_INITIALIZING")) | length), + ($kids | map(select(.status.phase == "ACTION_PHASE_RUNNING")) | length), + ($kids | map(select(.status.phase == "ACTION_PHASE_SUCCEEDED")) | length), + ($kids | map(select(.status.phase == "ACTION_PHASE_FAILED")) | length), + ($kids | map(select(.status.phase == "ACTION_PHASE_ABORTED")) | length), + ($kids | map(select(.status.phase == "ACTION_PHASE_TIMED_OUT")) | length) + ] + | @tsv + ' <<<"${json}")" + + if is_terminal_phase "${root_phase}"; then + roots_terminal=$((roots_terminal + 1)) + elif [[ "${root_phase}" != "MISSING" ]]; then + parent_live=$((parent_live + 1)) + if [[ "${root_phase}" == "ACTION_PHASE_RUNNING" ]]; then + parent_running=$((parent_running + 1)) + fi + fi + seen=$((seen + c_seen)) + queued=$((queued + c_queued)) + waiting=$((waiting + c_waiting)) + initializing=$((initializing + c_initializing)) + running=$((running + c_running)) + succeeded=$((succeeded + c_succeeded)) + failed=$((failed + c_failed)) + aborted=$((aborted + c_aborted)) + timed_out=$((timed_out + c_timed_out)) + done + + if (( EXPECTED_TOTAL_CHILDREN > seen )); then + not_created=$((EXPECTED_TOTAL_CHILDREN - seen)) + fi + active=$((queued + waiting + initializing + running)) + + printf '%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n' \ + "${discovered}/${TOTAL_RUNS}" \ + "${parent_live}" \ + 
"${parent_running}" \ + "${seen}/${EXPECTED_TOTAL_CHILDREN}" \ + "${not_created}" \ + "${queued}" \ + "${waiting}" \ + "${initializing}" \ + "${running}" \ + "${active}" \ + "${succeeded}" \ + "${failed}" \ + "${aborted}" \ + "${timed_out}" \ + "${roots_terminal}" +} + +launch_stage_from_log() { + if [[ ! -s "${LAUNCH_LOG}" ]]; then + return 1 + fi + + local stage="" + stage="$( + perl -pe 's/\e\[[0-9;]*[A-Za-z]//g; s/\r/\n/g' "${LAUNCH_LOG}" \ + | sed '/^[[:space:]]*$/d' \ + | grep -E '^(submitted=|Done\.|submitted_run|Error:|ERROR|Failed|failed)' \ + | tail -n 1 || true + )" + + if [[ -z "${stage}" ]]; then + stage="$(perl -pe 's/\e\[[0-9;]*[A-Za-z]//g; s/\r/\n/g' "${LAUNCH_LOG}" | sed '/^[[:space:]]*$/d' | tail -n 1)" + fi + + [[ -n "${stage}" ]] || return 1 + printf '%s' "${stage}" +} + +abort_remote_runs() { + local run_name="" + if [[ "${ABORT_SENT}" -eq 1 ]]; then + return + fi + ABORT_SENT=1 + + for run_name in "${CHILD_RUNS[@]}"; do + COLUMNS=500 _U_USE_ACTIONS="${_U_USE_ACTIONS:-1}" flyte -c "${CONFIG}" --image-builder "${IMAGE_BUILDER}" \ + abort run "${project_args[@]}" "${domain_args[@]}" --reason "${ABORT_REASON}" "${run_name}" >/dev/null 2>&1 || true + done +} + +handle_signal() { + local sig="$1" + if [[ "${STOPPING}" -eq 1 ]]; then + echo + echo "Received ${sig} again, exiting immediately." + exit 130 + fi + STOPPING=1 + + echo + echo "Received ${sig}, stopping local submissions and aborting discovered runs." 
+ if [[ -n "${LAUNCH_PID}" ]] && kill -0 "${LAUNCH_PID}" 2>/dev/null; then + kill "${LAUNCH_PID}" 2>/dev/null || true + fi + discover_child_runs + abort_remote_runs +} + +trap 'handle_signal INT' INT +trap 'handle_signal TERM' TERM + +cd "${REPO_ROOT}" +detect_sdk_wheel_status + +echo "Launching local multi-run harness" +echo " config: ${CONFIG}" +echo " total_runs: ${TOTAL_RUNS}" +echo " submit_concurrency: ${SUBMIT_CONCURRENCY}" +echo " n_children_per_run: ${N_CHILDREN}" +echo " total_children_expected: ${EXPECTED_TOTAL_CHILDREN}" +echo " sleep_duration: ${SLEEP_DURATION}" +echo " poll_interval: ${POLL_INTERVAL}s" +echo " image target: ${IMAGE_REGISTRY}/${IMAGE_NAME}" +echo " image builder: ${IMAGE_BUILDER}" +echo " image platforms: ${IMAGE_PLATFORMS}" +if [[ "${FLYTE_HARNESS_FORCE_LOCAL_SDK:-0}" == "1" || "${FLYTE_HARNESS_FORCE_LOCAL_SDK:-}" == "true" ]]; then + echo " sdk source: ${REPO_ROOT}/src forced via $(command -v flyte)" +else + echo " sdk source: installed flyte via $(command -v flyte)" +fi +if [[ -n "${SDK_WHEEL_PATH}" ]]; then + echo " sdk wheel: ${SDK_WHEEL_PATH}" + if [[ "${SDK_SRC_NEWER}" -eq 1 ]]; then + echo " warning: src/flyte is newer than the dist wheel; remote image will not include recent SDK src changes until you rebuild the wheel" + fi +else + echo " sdk wheel: " +fi +echo " fanout parent resources: cpu ${FANOUT_CPU_REQUEST}/${FANOUT_CPU_LIMIT}, memory ${FANOUT_MEMORY_REQUEST}/${FANOUT_MEMORY_LIMIT}" +echo " use_actions: ${_U_USE_ACTIONS:-1}" +if [[ -n "${PROJECT}" || -n "${DOMAIN}" ]]; then + echo " project/domain override: ${PROJECT:-} / ${DOMAIN:-}" +fi +if [[ "${#RUN_ENV_KVS[@]}" -gt 0 ]]; then + echo " run env overrides: ${RUN_ENV_KVS[*]}" +fi +echo +printf '%-8s %-12s %-8s %-8s %-18s %-14s %-8s %-10s %-8s %-10s %-8s %-8s\n' \ + "time" "runs" "p_live" "p_run" "seen_children" "not_created" "d_seen" "create_rps" "rps/p" "eta_fill" "running" "active" + +( + export _U_USE_ACTIONS="${_U_USE_ACTIONS:-1}" + export 
FLYTE_STRESS_IMAGE_REGISTRY="${IMAGE_REGISTRY}" + export FLYTE_STRESS_IMAGE_NAME="${IMAGE_NAME}" + export FLYTE_STRESS_IMAGE_PLATFORMS="${IMAGE_PLATFORMS}" + export FLYTE_HARNESS_CONFIG="${CONFIG}" + export FLYTE_HARNESS_IMAGE_BUILDER="${IMAGE_BUILDER}" + export FLYTE_HARNESS_PROJECT="${PROJECT}" + export FLYTE_HARNESS_DOMAIN="${DOMAIN}" + export FLYTE_HARNESS_RUNS_FILE="${RUNS_FILE}" + local_kv="" + for local_kv in "${RUN_ENV_KVS[@]}"; do + export "${local_kv}" + done + rc=0 + python examples/stress/sleep_fanout_harness.py \ + --total "${TOTAL_RUNS}" \ + --concurrency "${SUBMIT_CONCURRENCY}" \ + --n_children "${N_CHILDREN}" \ + --sleep_seconds "${SLEEP_DURATION}" || rc=$? + printf '%s\n' "${rc}" > "${LAUNCH_RC_FILE}" + exit "${rc}" +) >"${LAUNCH_LOG}" 2>&1 & +LAUNCH_PID=$! + +FINAL_DISCOVERED=0 +FINAL_SEEN=0 +FINAL_SUCCEEDED=0 +FINAL_FAILED=0 +FINAL_ABORTED=0 +FINAL_TIMED_OUT=0 +FINAL_ROOTS_TERMINAL=0 +FINAL_PARENT_LIVE=0 +FINAL_PARENT_RUNNING=0 +LAST_SAMPLE_TS="" +LAST_SAMPLE_SEEN="" + +while true; do + discover_child_runs + + if [[ -z "${FIRST_DISCOVERED_AT}" && "${#CHILD_RUNS[@]}" -gt 0 ]]; then + FIRST_DISCOVERED_AT="$(date +%s)" + fi + + if [[ "${LAUNCH_DONE}" -eq 0 && -s "${LAUNCH_RC_FILE}" ]]; then + LAUNCH_RC="$(tr -d '\r\n[:space:]' < "${LAUNCH_RC_FILE}")" + if [[ -z "${LAUNCH_RC}" ]]; then + LAUNCH_RC=1 + fi + wait "${LAUNCH_PID}" 2>/dev/null || true + LAUNCH_DONE=1 + fi + + if [[ "${#CHILD_RUNS[@]}" -gt 0 ]]; then + IFS=$'\t' read -r runs parent_live parent_running seen_children not_created queued waiting initializing running active succeeded failed aborted timed_out roots_terminal \ + <<<"$(aggregate_child_runs_tsv)" + + FINAL_DISCOVERED="${runs%%/*}" + FINAL_PARENT_LIVE="${parent_live}" + FINAL_PARENT_RUNNING="${parent_running}" + FINAL_SEEN="${seen_children%%/*}" + FINAL_SUCCEEDED="${succeeded}" + FINAL_FAILED="${failed}" + FINAL_ABORTED="${aborted}" + FINAL_TIMED_OUT="${timed_out}" + FINAL_ROOTS_TERMINAL="${roots_terminal}" + + if (( FINAL_SEEN > 
PEAK_SEEN )); then + PEAK_SEEN="${FINAL_SEEN}" + fi + if (( running > PEAK_RUNNING )); then + PEAK_RUNNING="${running}" + fi + if (( active > PEAK_ACTIVE )); then + PEAK_ACTIVE="${active}" + fi + if (( parent_live > PEAK_PARENT_LIVE )); then + PEAK_PARENT_LIVE="${parent_live}" + fi + if (( parent_running > PEAK_PARENT_RUNNING )); then + PEAK_PARENT_RUNNING="${parent_running}" + fi + if [[ -z "${FIRST_RUNNING_AT}" && "${running}" -gt 0 ]]; then + FIRST_RUNNING_AT="$(date +%s)" + fi + if [[ -z "${ALL_VISIBLE_AT}" && "${FINAL_SEEN}" -ge "${EXPECTED_TOTAL_CHILDREN}" ]]; then + ALL_VISIBLE_AT="$(date +%s)" + fi + + sample_ts="$(date +%s)" + delta_seen="0" + create_rps="n/a" + create_rps_per_parent="n/a" + eta_fill="n/a" + if [[ -n "${LAST_SAMPLE_TS}" && -n "${LAST_SAMPLE_SEEN}" ]]; then + delta_t=$((sample_ts - LAST_SAMPLE_TS)) + if (( delta_t > 0 )); then + delta_seen=$((FINAL_SEEN - LAST_SAMPLE_SEEN)) + if (( delta_seen < 0 )); then + delta_seen=0 + fi + create_rps="$(python - <<'PY' "${delta_seen}" "${delta_t}" +import sys +delta_seen = int(sys.argv[1]) +delta_t = int(sys.argv[2]) +print(f"{delta_seen / delta_t:.1f}") +PY +)" + if (( parent_running > 0 )); then + create_rps_per_parent="$(python - <<'PY' "${delta_seen}" "${delta_t}" "${parent_running}" +import sys +delta_seen = int(sys.argv[1]) +delta_t = int(sys.argv[2]) +parent_running = int(sys.argv[3]) +print(f"{(delta_seen / delta_t) / parent_running:.1f}") +PY +)" + fi + create_rps_int="$(python - <<'PY' "${delta_seen}" "${delta_t}" +import sys +delta_seen = int(sys.argv[1]) +delta_t = int(sys.argv[2]) +print(int(delta_seen / delta_t)) +PY +)" + if (( create_rps_int > PEAK_CREATE_RPS )); then + PEAK_CREATE_RPS="${create_rps_int}" + fi + if (( FINAL_SEEN < EXPECTED_TOTAL_CHILDREN && delta_seen > 0 )); then + eta_fill="$(python - <<'PY' "${EXPECTED_TOTAL_CHILDREN}" "${FINAL_SEEN}" "${delta_seen}" "${delta_t}" +import math +import sys +expected = int(sys.argv[1]) +seen = int(sys.argv[2]) +delta_seen = 
int(sys.argv[3]) +delta_t = int(sys.argv[4]) +remaining = expected - seen +seconds = math.ceil(remaining / (delta_seen / delta_t)) +h, rem = divmod(seconds, 3600) +m, s = divmod(rem, 60) +print(f"{h:02d}:{m:02d}:{s:02d}") +PY +)" + elif (( FINAL_SEEN >= EXPECTED_TOTAL_CHILDREN )); then + eta_fill="00:00:00" + fi + fi + fi + LAST_SAMPLE_TS="${sample_ts}" + LAST_SAMPLE_SEEN="${FINAL_SEEN}" + + print_row \ + "${runs}" \ + "${parent_live}" \ + "${parent_running}" \ + "${seen_children}" \ + "${not_created}" \ + "${delta_seen}" \ + "${create_rps}" \ + "${create_rps_per_parent}" \ + "${eta_fill}" \ + "${running}" \ + "${active}" + + if [[ "${LAUNCH_DONE}" -eq 1 ]] && (( FINAL_ROOTS_TERMINAL == FINAL_DISCOVERED )) && (( active == 0 )); then + TERMINAL_AT="$(date +%s)" + break + fi + else + if stage="$(launch_stage_from_log)"; then + if [[ "${stage}" != "${LAST_LAUNCH_STAGE}" ]]; then + LAST_LAUNCH_STAGE="${stage}" + echo "launch: ${stage}" + fi + fi + print_row "0/${TOTAL_RUNS}" 0 0 "0/${EXPECTED_TOTAL_CHILDREN}" "${EXPECTED_TOTAL_CHILDREN}" 0 0 0 0 0 0 + fi + + if [[ "${LAUNCH_DONE}" -eq 1 && "${LAUNCH_RC}" -ne 0 && "${#CHILD_RUNS[@]}" -eq 0 ]]; then + echo + echo "Local submit harness failed before any runs were discovered." 
>&2 + cat "${LAUNCH_LOG}" >&2 + exit "${LAUNCH_RC}" + fi + + sleep "${POLL_INTERVAL}" +done + +echo +echo "Aggregate Summary" +echo " runs_discovered: ${FINAL_DISCOVERED}/${TOTAL_RUNS}" +echo " total_expected_children: ${EXPECTED_TOTAL_CHILDREN}" +echo " child_run_roots_terminal: ${FINAL_ROOTS_TERMINAL}/${FINAL_DISCOVERED}" +echo " peak_parent_live: ${PEAK_PARENT_LIVE}" +echo " peak_parent_running: ${PEAK_PARENT_RUNNING}" +echo " children_seen: ${FINAL_SEEN}/${EXPECTED_TOTAL_CHILDREN}" +echo " succeeded: ${FINAL_SUCCEEDED}" +echo " failed: ${FINAL_FAILED}" +echo " aborted: ${FINAL_ABORTED}" +echo " timed_out: ${FINAL_TIMED_OUT}" +echo " peak_seen: ${PEAK_SEEN}/${EXPECTED_TOTAL_CHILDREN}" +echo " peak_running: ${PEAK_RUNNING}" +echo " peak_active: ${PEAK_ACTIVE}" +echo " peak_create_rps: ${PEAK_CREATE_RPS}" +echo " first_run_discovered: $(format_duration "$(elapsed_from_start "${FIRST_DISCOVERED_AT}")")" +echo " aggregate_first_running: $(format_duration "$(elapsed_from_start "${FIRST_RUNNING_AT}")")" +echo " aggregate_all_visible: $(format_duration "$(elapsed_from_start "${ALL_VISIBLE_AT}")")" +echo " aggregate_terminal: $(format_duration "$(elapsed_from_start "${TERMINAL_AT}")")" +echo " total_elapsed: $(format_duration "$(( $(date +%s) - SCRIPT_START_EPOCH ))")" diff --git a/src/flyte/_internal/controllers/remote/_controller.py b/src/flyte/_internal/controllers/remote/_controller.py index a37356924..fd35dce2c 100644 --- a/src/flyte/_internal/controllers/remote/_controller.py +++ b/src/flyte/_internal/controllers/remote/_controller.py @@ -4,6 +4,7 @@ import concurrent.futures import os import threading +import time from collections import defaultdict from collections.abc import Callable from contextlib import nullcontext @@ -153,6 +154,7 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg if tctx is None: raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized") current_action_id = tctx.action + 
trace_enabled = self._should_trace_sequence(_task_call_seq) # In the case of a regular code bundle, we will just pass it down as it is to the downstream tasks # It is not allowed to change the code bundle (for regular code bundles) in the middle of a run. @@ -167,11 +169,14 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg ) _ctx = ctx.new_in_driver_literal_conversion(True) if ctx.is_task_context() else nullcontext() + sdk_inputs_start = time.monotonic() with _ctx: inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs) + sdk_inputs_ms = (time.monotonic() - sdk_inputs_start) * 1000 root_dir = Path(code_bundle.destination).absolute() if code_bundle else Path.cwd() # Don't set output path in sec context because node executor will set it + sdk_serialize_start = time.monotonic() new_serialization_context = SerializationContext( project=current_action_id.project, domain=current_action_id.domain, @@ -189,12 +194,17 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path( tctx, task_spec, inputs_hash, _task_call_seq ) + sdk_serialize_ms = (time.monotonic() - sdk_serialize_start) * 1000 logger.info(f"Sub action {sub_action_id} output path {sub_action_output_path}") serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True) + serialized_input_bytes = len(serialized_inputs) inputs_uri = io.inputs_path(sub_action_output_path) + storage_put_start = time.monotonic() await upload_inputs_with_retry(serialized_inputs, inputs_uri, max_bytes=_task.max_inline_io_bytes) + storage_put_ms = (time.monotonic() - storage_put_start) * 1000 + sdk_cache_start = time.monotonic() md = task_spec.task_template.metadata ignored_input_vars = [] if len(md.cache_ignore_input_vars) > 0: @@ -210,6 +220,7 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg 
ignored_input_vars, inputs.proto_inputs, ) + sdk_cache_ms = (time.monotonic() - sdk_cache_start) * 1000 # Clear to free memory serialized_inputs = None # type: ignore @@ -233,13 +244,41 @@ async def _submit(self, _task_call_seq: int, _task: TaskTemplate, *args, **kwarg cache_key=cache_key, queue=_task.queue, ) + self._mark_action_for_trace(action.name) + if trace_enabled: + self._trace_log( + action.name, + "sdk_prepare", + kind="sdk_only", + seq=_task_call_seq, + task=_task.name, + sdk_inputs_ms=f"{sdk_inputs_ms:.1f}", + sdk_serialize_ms=f"{sdk_serialize_ms:.1f}", + sdk_cache_ms=f"{sdk_cache_ms:.1f}", + input_bytes=serialized_input_bytes, + ) + self._trace_log( + action.name, + "storage_put_inputs", + kind="storage_api", + elapsed_ms=f"{storage_put_ms:.1f}", + input_bytes=serialized_input_bytes, + ) try: logger.info( f"Submitting action Run:[{action.run_name}, Parent:[{action.parent_action_name}], " f"task:[{_task.name}], action:[{action.name}]" ) + submit_start = time.monotonic() n = await self.submit_action(action) + if trace_enabled: + self._trace_log( + action.name, + "submit_action_done", + kind="mixed", + elapsed_ms=f"{(time.monotonic() - submit_start) * 1000:.1f}", + ) logger.info(f"Action for task [{_task.name}] action id: {action.name}, completed!") except asyncio.CancelledError: # If the action is cancelled, we need to cancel the action on the server as well diff --git a/src/flyte/_internal/controllers/remote/_core.py b/src/flyte/_internal/controllers/remote/_core.py index b67bf13be..a5fcc9d19 100644 --- a/src/flyte/_internal/controllers/remote/_core.py +++ b/src/flyte/_internal/controllers/remote/_core.py @@ -4,6 +4,7 @@ import os import sys import threading +import time from asyncio import Event from typing import Awaitable, Coroutine, Optional @@ -65,7 +66,7 @@ def __init__( self._shared_queue: asyncio.Queue[Action] = asyncio.Queue(maxsize=10000) self._running = False self._resource_log_task = None - self._workers = workers + self._workers = 
int(os.getenv("_F_CTRL_WORKERS", str(workers))) self._max_retries = int(os.getenv("_F_MAX_RETRIES", max_system_retries)) self._resource_log_interval = resource_log_interval_sec self._min_backoff_on_err = min_backoff_on_err_sec @@ -77,6 +78,10 @@ def __init__( self._informer_start_wait_timeout = thread_wait_timeout_sec max_qps = int(os.getenv("_F_MAX_QPS", "100")) self._rate_limiter = AsyncLimiter(max_qps, 1.0) + self._trace_submit = os.getenv("_F_TRACE_SUBMIT", "").lower() in {"1", "true", "yes", "on"} + self._trace_submit_limit = int(os.getenv("_F_TRACE_SUBMIT_LIMIT", "10")) + self._trace_actions: set[str] = set() + self._trace_lock = threading.Lock() # Thread management self._thread = None @@ -86,6 +91,28 @@ def __init__( self._thread_com_lock = threading.Lock() self._start() + def _should_trace_sequence(self, seq: int) -> bool: + return self._trace_submit and seq <= self._trace_submit_limit + + def _mark_action_for_trace(self, action_name: str): + if not self._trace_submit: + return + with self._trace_lock: + if len(self._trace_actions) < self._trace_submit_limit: + self._trace_actions.add(action_name) + + def _trace_enabled_for(self, action_name: str) -> bool: + if not self._trace_submit: + return False + with self._trace_lock: + return action_name in self._trace_actions + + def _trace_log(self, action_name: str, phase: str, **fields): + if not self._trace_enabled_for(action_name): + return + payload = " ".join(f"{key}={value}" for key, value in fields.items()) + print(f"submit_trace action={action_name} phase={phase} {payload}".rstrip(), flush=True) + # ---------------- Public sync methods, we can add more sync methods if needed @log def submit_action_sync(self, action: Action) -> Action: @@ -277,6 +304,8 @@ async def _bg_finalize_informer( async def _bg_submit_action(self, action: Action) -> Action: """Submit a resource and await its completion, returning the final state""" logger.debug(f"{threading.current_thread().name} Submitting action {action.name}") + 
trace_enabled = self._trace_enabled_for(action.name) + informer_start = time.monotonic() informer = await self._informers.get_or_create( action.action_id.run, action.parent_action_name, @@ -286,11 +315,36 @@ async def _bg_submit_action(self, action: Action) -> Action: timeout=self._informer_start_wait_timeout, actions_service=self._actions_service, ) + if trace_enabled: + watch_api = "actions.watch_for_updates" if self._actions_service else "state.watch" + self._trace_log( + action.name, + "informer_ready", + kind="controlplane_api", + api=watch_api, + elapsed_ms=f"{(time.monotonic() - informer_start) * 1000:.1f}", + ) + queue_submit_start = time.monotonic() await informer.submit(action) + if trace_enabled: + self._trace_log( + action.name, + "queue_submit", + kind="sdk_only", + elapsed_ms=f"{(time.monotonic() - queue_submit_start) * 1000:.1f}", + ) logger.debug(f"{threading.current_thread().name} Waiting for completion of {action.name}") # Wait for completion + wait_start = time.monotonic() await informer.wait_for_action_completion(action.name) + if trace_enabled: + self._trace_log( + action.name, + "wait_for_completion", + kind="lifecycle_wait", + elapsed_ms=f"{(time.monotonic() - wait_start) * 1000:.1f}", + ) logger.info(f"{threading.current_thread().name} Action {action.name} completed") # Get final resource state and clean up @@ -346,7 +400,9 @@ async def _bg_launch(self, action: Action): Attempt to launch an action. """ if not action.is_started(): + limiter_wait_start = time.monotonic() async with self._rate_limiter: + limiter_wait_ms = (time.monotonic() - limiter_wait_start) * 1000 task: run_definition_pb2.TaskAction | None = None trace: run_definition_pb2.TraceAction | None = None if action.type == "task": @@ -375,6 +431,7 @@ async def _bg_launch(self, action: Action): trace = action.trace logger.debug(f"Attempting to launch action: {action.name}, actions? 
{bool(self._actions_service)}") + launch_start = time.monotonic() try: if self._actions_service: await self._actions_service.enqueue( @@ -406,6 +463,14 @@ async def _bg_launch(self, action: Action): timeout_ms=int(self._enqueue_timeout * 1000), ) logger.info(f"Successfully launched action: {action.name}") + self._trace_log( + action.name, + "enqueue_action", + kind="controlplane_api", + api="actions.enqueue" if self._actions_service else "queue.enqueue_action", + limiter_wait_ms=f"{limiter_wait_ms:.1f}", + elapsed_ms=f"{(time.monotonic() - launch_start) * 1000:.1f}", + ) except ConnectError as e: if e.code == Code.ALREADY_EXISTS: logger.info(f"Action {action.name} already exists, continuing to monitor.") diff --git a/src/flyte/remote/_client/auth/_token_client.py b/src/flyte/remote/_client/auth/_token_client.py index 214fdc5e0..9ea6198a2 100644 --- a/src/flyte/remote/_client/auth/_token_client.py +++ b/src/flyte/remote/_client/auth/_token_client.py @@ -17,6 +17,10 @@ error_slow_down = "slow_down" error_auth_pending = "authorization_pending" +_TOKEN_REQUEST_MAX_ATTEMPTS = 5 +_TOKEN_REQUEST_INITIAL_BACKOFF_SECONDS = 1.0 +_TOKEN_REQUEST_MAX_BACKOFF_SECONDS = 8.0 + # Grant Types class GrantType(str, enum.Enum): @@ -82,6 +86,36 @@ def get_basic_authorization_header(client_id: str, client_secret: str) -> str: return "Basic {}".format(base64.b64encode(concatenated.encode(utf_8)).decode(utf_8)) +async def _post_token_request( + http_session: httpx.AsyncClient, + token_endpoint: str, + *, + data: dict[str, str], + headers: dict[str, str], +) -> httpx.Response: + """POST to the token endpoint with bounded retries for transient transport failures.""" + for attempt in range(1, _TOKEN_REQUEST_MAX_ATTEMPTS + 1): + try: + return await http_session.post(token_endpoint, data=data, headers=headers) + except httpx.TransportError: + if attempt >= _TOKEN_REQUEST_MAX_ATTEMPTS: + raise + backoff = min( + _TOKEN_REQUEST_INITIAL_BACKOFF_SECONDS * (2 ** (attempt - 1)), + 
_TOKEN_REQUEST_MAX_BACKOFF_SECONDS, + ) + logger.warning( + "Token endpoint request failed for %s, retrying in %.1fs (%d/%d)", + token_endpoint, + backoff, + attempt, + _TOKEN_REQUEST_MAX_ATTEMPTS, + exc_info=True, + ) + await asyncio.sleep(backoff) + raise RuntimeError("unreachable") + + async def get_token( token_endpoint: str, http_session: httpx.AsyncClient, @@ -149,7 +183,7 @@ async def get_token( if refresh_token: body["refresh_token"] = refresh_token - response = await http_session.post(token_endpoint, data=body, headers=headers) + response = await _post_token_request(http_session, token_endpoint, data=body, headers=headers) if not response.is_success: j = response.json() diff --git a/tests/flyte/remote/test_auth_token_client.py b/tests/flyte/remote/test_auth_token_client.py new file mode 100644 index 000000000..8a66b29cc --- /dev/null +++ b/tests/flyte/remote/test_auth_token_client.py @@ -0,0 +1,62 @@ +from unittest.mock import AsyncMock, Mock + +import httpx +import pytest + +from flyte.remote._client.auth._token_client import get_token + + +def _success_response(access_token: str = "access-token", refresh_token: str | None = None, expires_in: int = 3600): + response = Mock(spec=httpx.Response) + response.is_success = True + payload = { + "access_token": access_token, + "expires_in": expires_in, + } + if refresh_token is not None: + payload["refresh_token"] = refresh_token + response.json.return_value = payload + return response + + +@pytest.mark.asyncio +async def test_get_token_retries_transient_transport_errors(monkeypatch): + session = Mock(spec=httpx.AsyncClient) + session.post = AsyncMock( + side_effect=[ + httpx.ConnectTimeout("connect timed out"), + _success_response(refresh_token="refresh-token"), + ] + ) + sleep = AsyncMock() + monkeypatch.setattr("flyte.remote._client.auth._token_client.asyncio.sleep", sleep) + + token, refresh_token, expires_in = await get_token( + token_endpoint="https://issuer.example.com/oauth/token", + http_session=session, 
+ client_id="client-id", + ) + + assert token == "access-token" + assert refresh_token == "refresh-token" + assert expires_in == 3600 + assert session.post.await_count == 2 + sleep.assert_awaited_once_with(1.0) + + +@pytest.mark.asyncio +async def test_get_token_raises_after_retry_budget_exhausted(monkeypatch): + session = Mock(spec=httpx.AsyncClient) + session.post = AsyncMock(side_effect=httpx.ConnectTimeout("connect timed out")) + sleep = AsyncMock() + monkeypatch.setattr("flyte.remote._client.auth._token_client.asyncio.sleep", sleep) + + with pytest.raises(httpx.ConnectTimeout): + await get_token( + token_endpoint="https://issuer.example.com/oauth/token", + http_session=session, + client_id="client-id", + ) + + assert session.post.await_count == 5 + assert sleep.await_count == 4