Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/flyte/extras/shell/_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import shlex
from typing import Any, Tuple

from ._types import FlagSpec, Stderr, Stdout, _classify_input
from ._types import FlagSpec, Stderr, Stdout, _classify_input, _is_optional

_PLACEHOLDER_RE = re.compile(r"\{(inputs|flags|outputs)\.([a-zA-Z_][a-zA-Z0-9_]*)\}")
_DICT_SEP = "\x1e"
Expand All @@ -20,6 +20,7 @@ def _render_command(
output_data_dir: pathlib.Path,
) -> Tuple[str, list[str]]:
kinds = {name: _classify_input(name, tp) for name, tp in inputs.items()}
optionals = {name: _is_optional(tp)[0] for name, tp in inputs.items()}

preamble_lines: list[str] = []
positional_templates: list[str] = []
Expand All @@ -33,14 +34,18 @@ def alloc_slot(name: str) -> str:

idx = len(positional_templates) + 1
positional_templates.append(f"{{{{.inputs.{name}}}}}")
var = f"_VAL_{name.upper()}"
preamble_lines.append(f'{var}="${idx}"')
var = f"_VAL_{name}"
# Brace the positional index: bash parses `$10` as `$1` + `"0"`,
# so any task with 10+ scalar/bool inputs would silently bind
# later variables to the wrong values. `${10}` is the only form
# that works for indices ≥ 10.
preamble_lines.append(f'{var}="${{{idx}}}"')
slot_var_for[name] = var
return var

def ensure_dict_decoded(name: str) -> str:
val_var = alloc_slot(name)
arr_var = f"_ARR_{name.upper()}"
arr_var = f"_ARR_{name}"

if name not in dict_decoded:
dict_decoded.add(name)
Expand Down Expand Up @@ -77,7 +82,7 @@ def render_flag_ref(name: str) -> str:

kind = kinds[name]
spec = flag_specs[name]
flag_var = f"_FLAG_{name.upper()}"
flag_var = f"_FLAG_{name}"

if name not in flag_emitted:
flag_emitted.add(name)
Expand All @@ -90,6 +95,7 @@ def render_flag_ref(name: str) -> str:
alloc_slot,
ensure_dict_decoded,
input_data_dir,
is_optional=optionals[name],
)
)
if kind in ("list_file", "dict_str"):
Expand Down Expand Up @@ -139,6 +145,7 @@ def _emit_flag_setter(
alloc_slot,
ensure_dict_decoded,
input_data_dir: pathlib.Path,
is_optional: bool = False,
) -> str:
flag = spec.flag
sep = spec.separator
Expand All @@ -155,6 +162,12 @@ def _emit_flag_setter(
)
if kind in ("file", "dir"):
path = input_data_dir / name
if is_optional:
return (
f"if [ -e {shlex.quote(str(path))} ]; then "
f"{flag_var}={shlex.quote(flag + sep + str(path))}; "
f'else {flag_var}=""; fi'
)
return f"{flag_var}={shlex.quote(flag + sep + str(path))}"
if kind == "list_file":
dirpath = input_data_dir / name
Expand Down
88 changes: 87 additions & 1 deletion src/flyte/extras/shell/_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class _Shell:
outputs: dict[str, Any]
script: str
flag_aliases: dict[str, FlagSpec] = field(default_factory=dict)
defaults: dict[str, Any] = field(default_factory=dict)
shell: str = "/bin/bash"
debug: bool = False
input_data_dir: pathlib.Path = pathlib.Path("/var/inputs")
Expand Down Expand Up @@ -186,6 +187,9 @@ async def _unpack_outputs(self, raw: Any) -> Any:
return unpacked[0] if single else tuple(unpacked)

async def _prepare_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
if self.defaults:
kwargs = {**self.defaults, **kwargs}

out: dict[str, Any] = {}
for name, tp in self.inputs.items():
is_opt, _ = _is_optional(tp)
Expand Down Expand Up @@ -267,6 +271,60 @@ async def _get_output(self, output_directory: pathlib.Path) -> Tuple[Any, ...]:
) from e


def _validate_defaults(defaults: dict[str, Any], inputs: dict[str, Type]) -> dict[str, Any]:
"""Validate that every default key exists in ``inputs`` and that the
value's Python type matches the declared input type.

A ``None`` default is rejected — the same effect is achieved by
declaring the input as ``T | None`` and omitting it from ``defaults``.
"""
for name, value in defaults.items():
if name not in inputs:
raise KeyError(f"defaults references {name!r} which is not declared in inputs.")
if value is None:
raise ValueError(
f"defaults[{name!r}] = None is redundant; declare the input as "
f"`T | None` and omit it from defaults instead."
)

_, inner = _is_optional(inputs[name])
kind = _classify_input(name, inputs[name])

if kind == "file":
if not isinstance(value, File):
raise TypeError(f"defaults[{name!r}]: expected File, got {type(value).__name__}.")
elif kind == "dir":
if not isinstance(value, Dir):
raise TypeError(f"defaults[{name!r}]: expected Dir, got {type(value).__name__}.")
elif kind == "list_file":
if not isinstance(value, list):
raise TypeError(f"defaults[{name!r}]: expected list[File], got {type(value).__name__}.")
if not all(isinstance(item, File) for item in value):
raise TypeError(f"defaults[{name!r}]: list[File] requires every item to be a File.")
elif kind == "bool":
if not isinstance(value, bool):
raise TypeError(f"defaults[{name!r}]: expected bool, got {type(value).__name__}.")
elif kind == "dict_str":
if not isinstance(value, dict):
raise TypeError(f"defaults[{name!r}]: expected dict[str, str], got {type(value).__name__}.")
for k, v in value.items():
if not isinstance(k, str) or not isinstance(v, str):
raise TypeError(f"defaults[{name!r}]: dict[str, str] requires string keys and values.")
elif kind == "scalar":
if inner is int:
if not isinstance(value, int) or isinstance(value, bool):
raise TypeError(f"defaults[{name!r}]: expected int, got {type(value).__name__}.")
elif inner is float:
if not isinstance(value, (int, float)) or isinstance(value, bool):
raise TypeError(f"defaults[{name!r}]: expected float, got {type(value).__name__}.")
elif inner is str:
if not isinstance(value, str):
raise TypeError(f"defaults[{name!r}]: expected str, got {type(value).__name__}.")
else:
raise AssertionError(inner)
return dict(defaults)


def _truncate(s: str, limit: int = 4000) -> str:
if len(s) <= limit:
return s if s.endswith("\n") else s + "\n"
Expand Down Expand Up @@ -297,6 +355,7 @@ def create(
outputs: Optional[dict[str, Any]] = None,
script: str,
flag_aliases: Optional[dict[str, Union[str, Tuple[str, listMode], FlagSpec]]] = None,
defaults: Optional[dict[str, Any]] = None,
shell: str = "/bin/bash",
debug: bool = False,
resources: Optional[flyte.Resources] = None,
Expand Down Expand Up @@ -397,14 +456,27 @@ class for behaviour the type system can't express:
travel through bash positional args (``$1``, ``$2``) so they
survive arbitrary content (single quotes, tabs, dollar signs)
without escaping, and the wrapper already emits scalar
references as ``"${_VAL_X}"`` (quoted, single token). Wrapping
references as ``"${_VAL_name}"`` (quoted, single token). Wrapping
them in ``"..."`` again breaks out of the wrapper's quoting and
re-enables word splitting.
flag_aliases: Per-input override for ``{flags.<name>}`` rendering.
Values may be a string (just the flag, default join mode) or
``(flag, list_mode)`` to pick a list rendering mode (``"join"``,
``"repeat"``, ``"comma"``) or ``(flag, dict_mode)``
for dicts (``"pairs"``, ``"equals"``).
defaults: Per-input fallback value used when the caller omits that
input at call time. Lets you mark inputs as "optional at call
site" while still emitting their flag, independent of the
``T | None`` axis. The interaction with ``T | None`` is:

==================== ========================= =================================
Type In ``defaults`` Behavior when caller omits
==================== ========================= =================================
``T`` No ``TypeError`` at submit time
``T`` Yes Default used; flag emitted
``T | None`` No Empty value; flag suppressed
``T | None`` Yes Default used; flag emitted
==================== ========================= =================================
shell: Shell binary to use. Defaults to ``/bin/bash``.
debug: If True, container prints the rendered script to stderr
before running. Invaluable when authoring a new wrapper.
Expand Down Expand Up @@ -460,6 +532,17 @@ async def pipeline(a: File, b: list[File]) -> list[File]:
inputs = inputs or {}
outputs = outputs or {}

# Sanity-check the generated helper names used by the shell renderer.
generated_helpers: dict[str, str] = {}
for n in inputs:
for helper in (f"_VAL_{n}", f"_FLAG_{n}", f"_ARR_{n}"):
if helper in generated_helpers:
raise ValueError(
f"Input names {generated_helpers[helper]!r} and {n!r} collide in generated shell helper "
f"name {helper!r}. Rename one input."
)
generated_helpers[helper] = n

for n, t in inputs.items():
_classify_input(n, t)

Expand All @@ -471,6 +554,8 @@ async def pipeline(a: File, b: list[File]) -> list[File]:
raise KeyError(f"flag_aliases references {n!r} which is not declared in inputs.")
coerced_aliases[n] = FlagSpec.coerce(n, alias)

validated_defaults: dict[str, Any] = _validate_defaults(defaults or {}, inputs)

if not isinstance(image, (str, flyte.Image)):
raise TypeError(f"image must be a URI string or a flyte.Image, got {type(image).__name__}.")

Expand All @@ -481,6 +566,7 @@ async def pipeline(a: File, b: list[File]) -> list[File]:
outputs=dict(outputs),
script=script,
flag_aliases=coerced_aliases,
defaults=validated_defaults,
shell=shell,
debug=debug,
resources=resources,
Expand Down
Loading
Loading