From b15758cb5ca0d7647800d122835963f604d1196c Mon Sep 17 00:00:00 2001 From: Kyle Hazen Date: Tue, 12 May 2026 21:10:16 -0700 Subject: [PATCH 1/3] shell: add defaults param, fix 3 rendering bugs, add bedtools example MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit shell.create now accepts `defaults={name: value}` — per-input fallback values applied when the caller omits an input. Works independently of `T | None`: required + no default raises, optional + no default suppresses the flag, either + default uses the default. Three rendering bugs surfaced while building a bedtools intersect wrapper: 1. Positional-arg indices >= 10 were emitted as bare `$10`, which bash parses as `$1` followed by literal "0". Any task with 10+ scalar/bool inputs had its 10..N values silently corrupted. Fixed by always braced `${10}`. 2. Optional File / Dir flags emitted unconditionally — the renderer hardcoded `-flag /var/inputs/name` regardless of whether the caller supplied the file, so omitting an optional file flag pointed the tool at a missing path. Now guarded with `if [ -e ]; then ...; else =""; fi`. 3. Inputs whose names differed only in case (e.g. `c` vs `C`, common in bio CLIs like bedtools / samtools) collided on the same `_VAL_*` / `_FLAG_*` bash variable and silently overwrote each other. create() now rejects this at declaration time with an error naming both offending inputs. Adds 28 unit tests across the four-cell defaults matrix, optional File/Dir flag emission, case-collision detection, and the positional-arg brace regression. New examples: - examples/shell/12_bedtools_intersect_example.py — three intersect queries in parallel against a small BED fixture. - examples/shell/modules/bedtools_intersect.py — typed wrapper around bedtools intersect; reference for bio-CLI shell-extra modules. Signed-off-by: Kyle Hazen --- .../shell/12_bedtools_intersect_example.py | 100 ++++++++ examples/shell/modules/bedtools_intersect.py | 69 +++++ src/flyte/extras/shell/_render.py | 17 +- src/flyte/extras/shell/_runtime.py | 111 ++++++++ tests/flyte/extras/test_shell.py | 240 +++++++++++++++++- 5 files changed, 530 insertions(+), 7 deletions(-) create mode 100644 examples/shell/12_bedtools_intersect_example.py create mode 100644 examples/shell/modules/bedtools_intersect.py diff --git a/examples/shell/12_bedtools_intersect_example.py b/examples/shell/12_bedtools_intersect_example.py new file mode 100644 index 000000000..bea64dc26 --- /dev/null +++ b/examples/shell/12_bedtools_intersect_example.py @@ -0,0 +1,100 @@ +"""bedtools intersect — three common overlap queries against a peaks file. + +This example consumes ``modules/bedtools_intersect.py`` (a typed shell wrapper +around the ``bedtools intersect`` CLI) and exercises three of its most-used +flag combinations on a small BED fixture: + +- ``wa=True`` — write each A feature that has *any* overlap in B. +- ``v=True`` — write each A feature that has *no* overlap in B (set diff). +- ``c=True`` — write each A feature with a trailing count of B overlaps. + +Fixture (4 "genes" in A, 3 "peaks" in B, all on chr1): + + A (genes) B (peaks) + chr1 100-200 gene1 chr1 150-180 peak1 <- overlaps gene1 + chr1 300-400 gene2 chr1 350-450 peak2 <- overlaps gene2 + chr1 500-600 gene3 chr1 900-950 peak3 + chr1 700-800 gene4 + +Expected: +- wa -> gene1, gene2 +- v -> gene3, gene4 +- c -> gene1\\t1, gene2\\t1, gene3\\t0, gene4\\t0 + +Run locally:: + + uv run python 12_bedtools_intersect_example.py +""" + +import asyncio +import tempfile + +import flyte +from flyte.io import File + +from modules.bedtools_intersect import bedtools_intersect + + +env = flyte.TaskEnvironment( + name="bedtools_intersect_example", + depends_on=[bedtools_intersect.env], +) + + +@env.task +async def intersect_demo(genes: File, peaks: File) -> tuple[File, File, File]: + overlapping, non_overlapping, counts = await asyncio.gather( + bedtools_intersect(a=genes, b=[peaks], wa=True), + bedtools_intersect(a=genes, b=[peaks], v=True), + bedtools_intersect(a=genes, b=[peaks], count_overlaps=True), + ) + return overlapping, non_overlapping, counts + + +GENES_BED = ( + "chr1\t100\t200\tgene1\t0\t+\n" + "chr1\t300\t400\tgene2\t0\t+\n" + "chr1\t500\t600\tgene3\t0\t+\n" + "chr1\t700\t800\tgene4\t0\t+\n" +) + +PEAKS_BED = ( + "chr1\t150\t180\tpeak1\t0\t+\n" + "chr1\t350\t450\tpeak2\t0\t+\n" + "chr1\t900\t950\tpeak3\t0\t+\n" +) + + + + +# Fixtures +# mkdir -p /tmp/bedtools-fixtures && \ +# printf 'chr1\t100\t200\tgene1\t0\t+\nchr1\t300\t400\tgene2\t0\t+\nchr1\t500\t600\tgene3\t0\t+\nchr1\t700\t800\tgene4\t0\t+\n' > /tmp/bedtools-fixtures/genes.bed && \ +# printf 'chr1\t150\t180\tpeak1\t0\t+\nchr1\t350\t450\tpeak2\t0\t+\nchr1\t900\t950\tpeak3\t0\t+\n' > /tmp/bedtools-fixtures/peaks.bed && \ +# ls -la /tmp/bedtools-fixtures/ + +if __name__ == "__main__": + + flyte.init() + + with tempfile.NamedTemporaryFile(mode="w", suffix=".bed", delete=False) as f: + f.write(GENES_BED) + genes_path = f.name + + with tempfile.NamedTemporaryFile(mode="w", suffix=".bed", delete=False) as f: + f.write(PEAKS_BED) + peaks_path = f.name + + run = flyte.with_runcontext().run( + intersect_demo, + File.from_local_sync(genes_path), + File.from_local_sync(peaks_path), + ) + + print(run) + + out = run.outputs() + overlapping_path = out.o0.download_sync("./overlapping.bed") + non_overlapping_path = out.o1.download_sync("./non_overlapping.bed") + counts_path = out.o2.download_sync("./counts.bed") + print(f"Wrote: {overlapping_path}, {non_overlapping_path}, {counts_path}") diff --git a/examples/shell/modules/bedtools_intersect.py b/examples/shell/modules/bedtools_intersect.py new file mode 100644 index 000000000..70214c8ca --- /dev/null +++ b/examples/shell/modules/bedtools_intersect.py @@ -0,0 +1,69 @@ +from flyte.extras import shell +from flyte.io import File + +IMAGE = "quay.io/biocontainers/bedtools:2.31.1--hf5e1c6e_0" + +# Inputs use descriptive Python names where the bedtools CLI has case-only +# collisions (`-c`/`-C`, `-s`/`-S`, `-f`/`-F`). The shell renderer builds +# bash variable names by uppercasing the Python name, so two inputs whose +# names differ only in case would collide on the same `_FLAG_*` slot. +# `flag_aliases` maps each descriptive Python name to the actual CLI flag. +bedtools_intersect = shell.create( + name="bedtools_intersect", + image=IMAGE, + inputs={ + "a": File, + "b": list[File], + "wa": bool | None, + "wb": bool | None, + "loj": bool | None, + "wo": bool | None, + "wao": bool | None, + "u": bool | None, + "count_overlaps": bool | None, + "count_per_file": bool | None, + "v": bool | None, + "same_strand": bool | None, + "opposite_strand": bool | None, + "frac_a": float | None, + "frac_b": float | None, + "r": bool | None, + "e": bool | None, + "ubam": bool | None, + "bed": bool | None, + "sorted": bool | None, + "nonamecheck": bool | None, + "g": File | None, + "names": str | None, + "filenames": bool | None, + "sortout": bool | None, + "split": bool | None, + "header": bool | None, + "nobuf": bool | None, + "iobuf": str | None, + }, + outputs={"out": File}, + flag_aliases={ + "b": "-b", + "count_overlaps": "-c", + "count_per_file": "-C", + "same_strand": "-s", + "opposite_strand": "-S", + "frac_a": "-f", + "frac_b": "-F", + }, + script=r""" + bedtools intersect \ + -a {inputs.a} \ + {flags.b} \ + {flags.wa} {flags.wb} {flags.loj} {flags.wo} {flags.wao} \ + {flags.u} {flags.count_overlaps} {flags.count_per_file} {flags.v} \ + {flags.same_strand} {flags.opposite_strand} \ + {flags.frac_a} {flags.frac_b} {flags.r} {flags.e} \ + {flags.ubam} {flags.bed} \ + {flags.sorted} {flags.nonamecheck} {flags.g} \ + {flags.names} {flags.filenames} {flags.sortout} \ + {flags.split} {flags.header} {flags.nobuf} {flags.iobuf} \ + > {outputs.out} + """, +) diff --git a/src/flyte/extras/shell/_render.py b/src/flyte/extras/shell/_render.py index abf53b7d6..870dc738d 100644 --- a/src/flyte/extras/shell/_render.py +++ b/src/flyte/extras/shell/_render.py @@ -5,7 +5,7 @@ import shlex from typing import Any, Tuple -from ._types import FlagSpec, Stderr, Stdout, _classify_input +from ._types import FlagSpec, Stderr, Stdout, _classify_input, _is_optional _PLACEHOLDER_RE = re.compile(r"\{(inputs|flags|outputs)\.([a-zA-Z_][a-zA-Z0-9_]*)\}") _DICT_SEP = "\x1e" @@ -20,6 +20,7 @@ def _render_command( output_data_dir: pathlib.Path, ) -> Tuple[str, list[str]]: kinds = {name: _classify_input(name, tp) for name, tp in inputs.items()} + optionals = {name: _is_optional(tp)[0] for name, tp in inputs.items()} preamble_lines: list[str] = [] positional_templates: list[str] = [] @@ -34,7 +35,11 @@ def alloc_slot(name: str) -> str: idx = len(positional_templates) + 1 positional_templates.append(f"{{{{.inputs.{name}}}}}") var = f"_VAL_{name.upper()}" - preamble_lines.append(f'{var}="${idx}"') + # Brace the positional index: bash parses `$10` as `$1` + `"0"`, + # so any task with 10+ scalar/bool inputs would silently bind + # later variables to the wrong values. `${10}` is the only form + # that works for indices ≥ 10. + preamble_lines.append(f'{var}="${{{idx}}}"') slot_var_for[name] = var return var @@ -90,6 +95,7 @@ def render_flag_ref(name: str) -> str: alloc_slot, ensure_dict_decoded, input_data_dir, + is_optional=optionals[name], ) ) if kind in ("list_file", "dict_str"): @@ -139,6 +145,7 @@ def _emit_flag_setter( alloc_slot, ensure_dict_decoded, input_data_dir: pathlib.Path, + is_optional: bool = False, ) -> str: flag = spec.flag sep = spec.separator @@ -155,6 +162,12 @@ def _emit_flag_setter( ) if kind in ("file", "dir"): path = input_data_dir / name + if is_optional: + return ( + f'if [ -e {shlex.quote(str(path))} ]; then ' + f'{flag_var}={shlex.quote(flag + sep + str(path))}; ' + f'else {flag_var}=""; fi' + ) return f"{flag_var}={shlex.quote(flag + sep + str(path))}" if kind == "list_file": dirpath = input_data_dir / name diff --git a/src/flyte/extras/shell/_runtime.py b/src/flyte/extras/shell/_runtime.py index 209e62083..2ce2121e4 100644 --- a/src/flyte/extras/shell/_runtime.py +++ b/src/flyte/extras/shell/_runtime.py @@ -35,6 +35,7 @@ class _Shell: outputs: dict[str, Any] script: str flag_aliases: dict[str, FlagSpec] = field(default_factory=dict) + defaults: dict[str, Any] = field(default_factory=dict) shell: str = "/bin/bash" debug: bool = False input_data_dir: pathlib.Path = pathlib.Path("/var/inputs") @@ -186,6 +187,9 @@ async def _unpack_outputs(self, raw: Any) -> Any: return unpacked[0] if single else tuple(unpacked) async def _prepare_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]: + if self.defaults: + kwargs = {**self.defaults, **kwargs} + out: dict[str, Any] = {} for name, tp in self.inputs.items(): is_opt, _ = _is_optional(tp) @@ -267,6 +271,47 @@ async def _get_output(self, output_directory: pathlib.Path) -> Tuple[Any, ...]: ) from e +def _validate_defaults(defaults: dict[str, Any], inputs: dict[str, Type]) -> dict[str, Any]: + """Validate that every default key exists in ``inputs`` and that the + value's Python type matches the declared input type. + + A ``None`` default is rejected — the same effect is achieved by + declaring the input as ``T | None`` and omitting it from ``defaults``. + """ + for name, value in defaults.items(): + if name not in inputs: + raise KeyError(f"defaults references {name!r} which is not declared in inputs.") + if value is None: + raise ValueError( + f"defaults[{name!r}] = None is redundant; declare the input as " + f"`T | None` and omit it from defaults instead." + ) + + _, inner = _is_optional(inputs[name]) + + if inner is bool: + if not isinstance(value, bool): + raise TypeError(f"defaults[{name!r}]: expected bool, got {type(value).__name__}.") + elif inner is int: + if not isinstance(value, int) or isinstance(value, bool): + raise TypeError(f"defaults[{name!r}]: expected int, got {type(value).__name__}.") + elif inner is float: + if not isinstance(value, (int, float)) or isinstance(value, bool): + raise TypeError(f"defaults[{name!r}]: expected float, got {type(value).__name__}.") + elif inner is str: + if not isinstance(value, str): + raise TypeError(f"defaults[{name!r}]: expected str, got {type(value).__name__}.") + elif _is_dict_str_str(inner): + if not isinstance(value, dict): + raise TypeError(f"defaults[{name!r}]: expected dict[str, str], got {type(value).__name__}.") + for k, v in value.items(): + if not isinstance(k, str) or not isinstance(v, str): + raise TypeError( + f"defaults[{name!r}]: dict[str, str] requires string keys and values." + ) + return dict(defaults) + + def _truncate(s: str, limit: int = 4000) -> str: if len(s) <= limit: return s if s.endswith("\n") else s + "\n" @@ -297,6 +342,7 @@ def create( outputs: Optional[dict[str, Any]] = None, script: str, flag_aliases: Optional[dict[str, Union[str, Tuple[str, listMode], FlagSpec]]] = None, + defaults: Optional[dict[str, Any]] = None, shell: str = "/bin/bash", debug: bool = False, resources: Optional[flyte.Resources] = None, @@ -346,6 +392,28 @@ def create( - ``int``, ``float``, ``str``, ``bool`` scalars - ``T | None`` of any of the above (``None`` collapses to empty) + **Optional File / Dir flags emit conditionally.** When + ``{flags.}`` references an input typed as + ``File | None`` / ``Dir | None`` and the caller omits it, the + renderer guards the emission with ``if [ -e ]``: no + flag is added to the command when the file isn't staged. + Non-optional ``File`` / ``Dir`` inputs are still emitted + unconditionally (the caller is contractually required to + supply them). + + **Case-colliding input names are rejected at create() time.** + The renderer builds bash variable names by uppercasing the + Python identifier — so two inputs whose names differ only in + case (e.g. ``c`` and ``C``, common in bio CLIs like samtools + ``-h``/``-H`` or bedtools ``-c``/``-C``) would silently share + ``_VAL_*`` / ``_FLAG_*`` slots and overwrite each other. + ``create()`` raises ``ValueError`` listing both names; the + fix is to give one of them a descriptive Python name and + map back to the CLI flag with ``flag_aliases``:: + + inputs={"count_overlaps": bool | None, "count_per_file": bool | None} + flag_aliases={"count_overlaps": "-c", "count_per_file": "-C"} + **Recipes for things that look like they need a richer dict but don't:** @@ -405,6 +473,30 @@ class for behaviour the type system can't express: ``(flag, list_mode)`` to pick a list rendering mode (``"join"``, ``"repeat"``, ``"comma"``) or ``(flag, dict_mode)`` for dicts (``"pairs"``, ``"equals"``). + defaults: Per-input fallback value used when the caller omits that + input at call time. Lets you mark inputs as "optional at call + site" while still emitting their flag, independent of the + ``T | None`` axis. The interaction with ``T | None`` is: + + ==================== ========================= ================================= + Type In ``defaults`` Behavior when caller omits + ==================== ========================= ================================= + ``T`` No ``TypeError`` at submit time + ``T`` Yes Default used; flag emitted + ``T | None`` No Empty value; flag suppressed + ``T | None`` Yes Default used; flag emitted + ==================== ========================= ================================= + + Validation at ``create()`` time: + + - Keys must be present in ``inputs``. + - ``None`` values are rejected — use ``T | None`` and omit + from ``defaults`` instead. + - Value's Python type must match the declared input type + (``bool`` for ``bool``, ``int``/``float`` for ``float``, etc.). + - File/Dir/list[File] default values are accepted without + value-shape checking — supply only fully-constructed + :class:`~flyte.io.File` / :class:`~flyte.io.Dir` instances. shell: Shell binary to use. Defaults to ``/bin/bash``. debug: If True, container prints the rendered script to stderr before running. Invaluable when authoring a new wrapper. @@ -460,6 +552,22 @@ async def pipeline(a: File, b: list[File]) -> list[File]: inputs = inputs or {} outputs = outputs or {} + # Reject inputs whose names differ only in case — the renderer builds + # bash variable names by uppercasing the Python name, so names like + # `c` and `C` would silently share `_VAL_C` / `_FLAG_C` and clobber + # each other. Common in bio CLIs (samtools -h/-H, bedtools -c/-C, etc.). + seen_upper: dict[str, str] = {} + for n in inputs: + up = n.upper() + if up in seen_upper: + raise ValueError( + f"Input names {seen_upper[up]!r} and {n!r} collide on bash variable " + f"'_VAL_{up}' / '_FLAG_{up}' — they differ only in case. " + f"Rename one (e.g. 'count_overlaps') and use flag_aliases to keep " + f"the CLI flag (e.g. flag_aliases={{'count_overlaps': '-c'}})." + ) + seen_upper[up] = n + for n, t in inputs.items(): _classify_input(n, t) @@ -471,6 +579,8 @@ async def pipeline(a: File, b: list[File]) -> list[File]: raise KeyError(f"flag_aliases references {n!r} which is not declared in inputs.") coerced_aliases[n] = FlagSpec.coerce(n, alias) + validated_defaults: dict[str, Any] = _validate_defaults(defaults or {}, inputs) + if not isinstance(image, (str, flyte.Image)): raise TypeError(f"image must be a URI string or a flyte.Image, got {type(image).__name__}.") @@ -481,6 +591,7 @@ async def pipeline(a: File, b: list[File]) -> list[File]: outputs=dict(outputs), script=script, flag_aliases=coerced_aliases, + defaults=validated_defaults, shell=shell, debug=debug, resources=resources, diff --git a/tests/flyte/extras/test_shell.py b/tests/flyte/extras/test_shell.py index 82c81764c..6f13c170c 100644 --- a/tests/flyte/extras/test_shell.py +++ b/tests/flyte/extras/test_shell.py @@ -225,7 +225,7 @@ def test_scalar_goes_to_positional_arg(self): body, positional = _render_full("echo {inputs.x}", {"x": int}) assert positional == ["{{.inputs.x}}"] # The body binds positional $1 to _VAL_X and references it quoted. - assert '_VAL_X="$1"' in body + assert '_VAL_X="${1}"' in body assert '"${_VAL_X}"' in body # No propeller template appears inside the body string itself. assert "{{.inputs.x}}" not in body @@ -310,6 +310,52 @@ def test_list_file_flag_repeat(self): assert "for _f" in out assert "-I" in out + # ---- optional File / Dir flags emit conditionally ---- + + def test_required_file_flag_unconditional(self): + # Non-optional File flag is hardcoded — the caller is contractually + # required to supply it, so no runtime existence check is needed. + out = _render("tool {flags.ref}", {"ref": File}) + assert "_FLAG_REF=" in out + assert "/var/inputs/ref" in out + # No conditional guarding the assignment. + assert "if [ -e " not in out + + def test_optional_file_flag_guarded(self): + # Optional File flag must be guarded — if the caller didn't supply + # the file, /var/inputs/ won't exist and the tool would fail + # trying to open it. + out = _render("tool {flags.sites}", {"sites": File | None}) + assert "if [ -e /var/inputs/sites ]" in out + assert "_FLAG_SITES=-sites /var/inputs/sites" in out or "_FLAG_SITES='-sites /var/inputs/sites'" in out + assert '_FLAG_SITES=""' in out # else branch + + def test_optional_dir_flag_guarded(self): + # Same conditional emission for optional Dir flags. + out = _render("tool {flags.cache}", {"cache": Dir | None}) + assert "if [ -e /var/inputs/cache ]" in out + assert '_FLAG_CACHE=""' in out + + def test_required_dir_flag_unconditional(self): + out = _render("tool {flags.workdir}", {"workdir": Dir}) + assert "/var/inputs/workdir" in out + assert "if [ -e " not in out + + def test_positional_index_braced_for_two_digit_indices(self): + # Regression: bash parses `$10` as `$1` followed by literal `"0"`, + # silently binding 10+-indexed inputs to the wrong values. Indices + # must always be braced as `${10}` for correctness, regardless of + # whether they happen to be single-digit. + inputs = {f"x{i}": str for i in range(15)} + out = _render(" ".join(f"{{inputs.{n}}}" for n in inputs), inputs) + # Index 1 — must still work (the fix uses braces uniformly). + assert '_VAL_X0="${1}"' in out + # Index 10 — this is where the bug bit. Must be braced. + assert '_VAL_X9="${10}"' in out + # And the bare two-digit form must NOT appear anywhere. + assert '"$10"' not in out + assert '"$15"' not in out + # --------------------------------------------------------------------------- # Output collector resolution @@ -547,6 +593,54 @@ def test_flag_aliases_must_match_inputs(self): flag_aliases={"missing": "-m"}, ) + # ---- case-colliding input names rejected ---- + + def test_case_collision_lower_then_upper_rejected(self): + # `c` and `C` would both render to bash vars `_VAL_C` / `_FLAG_C` + # and silently overwrite each other. Common in bio CLIs. + with pytest.raises(ValueError, match="collide on bash variable"): + shell.create( + name="bad", + image="alpine:3.18", + inputs={"c": bool, "C": bool}, + outputs={"o": File}, + script="true", + ) + + def test_case_collision_message_names_both_inputs(self): + # Error must name *both* colliding inputs so the author can find them. + with pytest.raises(ValueError) as exc_info: + shell.create( + name="bad", + image="alpine:3.18", + inputs={"foo": bool, "FOO": bool}, + outputs={"o": File}, + script="true", + ) + msg = str(exc_info.value) + assert "'foo'" in msg and "'FOO'" in msg + + def test_case_collision_mixed_case_rejected(self): + # Not just exact lower/upper — any `.upper()` collision is rejected. + with pytest.raises(ValueError, match="collide on bash variable"): + shell.create( + name="bad", + image="alpine:3.18", + inputs={"my_flag": bool, "My_Flag": bool}, + outputs={"o": File}, + script="true", + ) + + def test_no_collision_distinct_uppercase_forms(self): + # Distinct uppercased forms — no collision, must not raise. + shell.create( + name="ok", + image="alpine:3.18", + inputs={"a": bool, "b": bool, "ab": bool}, + outputs={"o": File}, + script="true", + ) + def test_full_bedtools_shape_validates(self): # End-to-end create() with the full bedtools example shape — no exec. task = shell.create( @@ -778,6 +872,142 @@ def test_optional_dict_default_empty_string(self): assert result["opts"] == "" +# --------------------------------------------------------------------------- +# Defaults — four-cell matrix of {required, optional} × {has default, none} +# --------------------------------------------------------------------------- + + +class TestDefaults: + def _task(self, inputs, defaults=None): + return shell.create( + name="t", + image="alpine:3.18", + inputs=inputs, + outputs={"o": File}, + script="true", + defaults=defaults, + ) + + # ---- four-cell matrix ---- + + def test_required_no_default_missing_raises(self): + task = self._task({"wa": bool}) + with pytest.raises(TypeError, match="Missing required input: 'wa'"): + asyncio.run(task._prepare_kwargs({})) + + def test_required_with_default_uses_default(self): + task = self._task({"wa": bool}, defaults={"wa": False}) + result = asyncio.run(task._prepare_kwargs({})) + # Non-optional bool flows as the native Python bool — ContainerTask + # lower-cases it to the bash "false" string at template-render time. + assert result["wa"] is False + + def test_optional_no_default_missing_empty_string(self): + task = self._task({"wa": bool | None}) + result = asyncio.run(task._prepare_kwargs({})) + assert result["wa"] == "" + + def test_optional_with_default_uses_default(self): + task = self._task({"wa": bool | None}, defaults={"wa": True}) + result = asyncio.run(task._prepare_kwargs({})) + # Optional scalars/bools are wired as str — defaults flow through + # the same conversion path as caller-supplied values. + assert result["wa"] == "true" + + # ---- caller-supplied value always wins ---- + + def test_caller_value_overrides_default(self): + task = self._task({"threads": int}, defaults={"threads": 4}) + result = asyncio.run(task._prepare_kwargs({"threads": 8})) + assert result["threads"] == 8 + + def test_caller_none_overrides_default_for_optional(self): + # Explicit None from caller must win even when a default exists — + # otherwise there's no way to opt out of an emitted flag. + task = self._task({"threads": int | None}, defaults={"threads": 4}) + result = asyncio.run(task._prepare_kwargs({"threads": None})) + assert result["threads"] == "" + + # ---- default values for various kinds ---- + + def test_default_for_optional_int(self): + task = self._task({"n": int | None}, defaults={"n": 42}) + result = asyncio.run(task._prepare_kwargs({})) + assert result["n"] == "42" + + def test_default_for_optional_float(self): + task = self._task({"f": float | None}, defaults={"f": 0.5}) + result = asyncio.run(task._prepare_kwargs({})) + assert result["f"] == "0.5" + + def test_default_for_optional_str(self): + task = self._task({"s": str | None}, defaults={"s": "hello"}) + result = asyncio.run(task._prepare_kwargs({})) + assert result["s"] == "hello" + + def test_default_for_optional_dict(self): + task = self._task( + {"opts": dict[str, str] | None}, defaults={"opts": {"-k": "v"}} + ) + result = asyncio.run(task._prepare_kwargs({})) + # Dict defaults flow through the record-separator packing path. + assert result["opts"].split(_DICT_SEP) == ["-k", "v"] + + # ---- create()-time validation ---- + + def test_validate_unknown_key_rejected(self): + with pytest.raises(KeyError, match="not declared in inputs"): + self._task({"wa": bool}, defaults={"unknown": True}) + + def test_validate_none_default_rejected(self): + with pytest.raises(ValueError, match="redundant"): + self._task({"wa": bool | None}, defaults={"wa": None}) + + def test_validate_bool_type_mismatch(self): + with pytest.raises(TypeError, match="expected bool"): + self._task({"wa": bool}, defaults={"wa": "yes"}) + + def test_validate_int_type_mismatch_rejects_bool(self): + # bool is a subclass of int — reject it for int defaults to keep + # `True` from quietly meaning `1`. + with pytest.raises(TypeError, match="expected int"): + self._task({"n": int}, defaults={"n": True}) + + def test_validate_float_accepts_int(self): + # Lenient: int → float coercion is the obvious user intent. + task = self._task({"f": float | None}, defaults={"f": 5}) + result = asyncio.run(task._prepare_kwargs({})) + assert result["f"] == "5" + + def test_validate_str_type_mismatch(self): + with pytest.raises(TypeError, match="expected str"): + self._task({"s": str | None}, defaults={"s": 42}) + + def test_validate_dict_type_mismatch(self): + with pytest.raises(TypeError, match="expected dict"): + self._task({"opts": dict[str, str] | None}, defaults={"opts": "not a dict"}) + + def test_validate_dict_non_string_value_rejected(self): + with pytest.raises(TypeError, match="string keys and values"): + self._task( + {"opts": dict[str, str] | None}, defaults={"opts": {"k": 42}} # type: ignore[dict-item] + ) + + # ---- no defaults parameter at all (backward compat) ---- + + def test_no_defaults_param_is_backward_compatible(self): + task = shell.create( + name="t", + image="alpine:3.18", + inputs={"wa": bool | None}, + outputs={"o": File}, + script="true", + ) + assert task.defaults == {} + result = asyncio.run(task._prepare_kwargs({})) + assert result["wa"] == "" + + # --------------------------------------------------------------------------- # Stdout / Stderr collectors # --------------------------------------------------------------------------- @@ -834,8 +1064,8 @@ def test_each_scalar_gets_distinct_positional_slot(self): {"a": str, "b": int}, ) assert positional == ["{{.inputs.a}}", "{{.inputs.b}}"] - assert '_VAL_A="$1"' in body - assert '_VAL_B="$2"' in body + assert '_VAL_A="${1}"' in body + assert '_VAL_B="${2}"' in body def test_same_input_referenced_twice_reuses_slot(self): body, positional = _render_full( @@ -844,7 +1074,7 @@ def test_same_input_referenced_twice_reuses_slot(self): ) # x referenced twice — single positional slot. assert positional == ["{{.inputs.x}}"] - assert body.count('_VAL_X="$1"') == 1 + assert body.count('_VAL_X="${1}"') == 1 def test_inputs_and_flags_for_same_var_share_slot(self): body, positional = _render_full( @@ -853,7 +1083,7 @@ def test_inputs_and_flags_for_same_var_share_slot(self): ) assert positional == ["{{.inputs.f}}"] # _VAL_F bound once, used by both the flag setter and the inputs ref. - assert body.count('_VAL_F="$1"') == 1 + assert body.count('_VAL_F="${1}"') == 1 class TestBuildCommandArgvLayout: From d4aecf91961ddb7fef7e8299974f3e5bfc80c235 Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 13 May 2026 10:52:42 +0530 Subject: [PATCH 2/3] deleted bedtools examples & removed uppercasing inputs Signed-off-by: Samhita Alla --- .../shell/12_bedtools_intersect_example.py | 100 ------------ examples/shell/modules/bedtools_intersect.py | 69 -------- src/flyte/extras/shell/_render.py | 6 +- src/flyte/extras/shell/_runtime.py | 115 +++++++------ tests/flyte/extras/test_shell.py | 151 +++++++++++------- 5 files changed, 147 insertions(+), 294 deletions(-) delete mode 100644 examples/shell/12_bedtools_intersect_example.py delete mode 100644 examples/shell/modules/bedtools_intersect.py diff --git a/examples/shell/12_bedtools_intersect_example.py b/examples/shell/12_bedtools_intersect_example.py deleted file mode 100644 index bea64dc26..000000000 --- a/examples/shell/12_bedtools_intersect_example.py +++ /dev/null @@ -1,100 +0,0 @@ -"""bedtools intersect — three common overlap queries against a peaks file. - -This example consumes ``modules/bedtools_intersect.py`` (a typed shell wrapper -around the ``bedtools intersect`` CLI) and exercises three of its most-used -flag combinations on a small BED fixture: - -- ``wa=True`` — write each A feature that has *any* overlap in B. -- ``v=True`` — write each A feature that has *no* overlap in B (set diff). -- ``c=True`` — write each A feature with a trailing count of B overlaps. - -Fixture (4 "genes" in A, 3 "peaks" in B, all on chr1): - - A (genes) B (peaks) - chr1 100-200 gene1 chr1 150-180 peak1 <- overlaps gene1 - chr1 300-400 gene2 chr1 350-450 peak2 <- overlaps gene2 - chr1 500-600 gene3 chr1 900-950 peak3 - chr1 700-800 gene4 - -Expected: -- wa -> gene1, gene2 -- v -> gene3, gene4 -- c -> gene1\\t1, gene2\\t1, gene3\\t0, gene4\\t0 - -Run locally:: - - uv run python 12_bedtools_intersect_example.py -""" - -import asyncio -import tempfile - -import flyte -from flyte.io import File - -from modules.bedtools_intersect import bedtools_intersect - - -env = flyte.TaskEnvironment( - name="bedtools_intersect_example", - depends_on=[bedtools_intersect.env], -) - - -@env.task -async def intersect_demo(genes: File, peaks: File) -> tuple[File, File, File]: - overlapping, non_overlapping, counts = await asyncio.gather( - bedtools_intersect(a=genes, b=[peaks], wa=True), - bedtools_intersect(a=genes, b=[peaks], v=True), - bedtools_intersect(a=genes, b=[peaks], count_overlaps=True), - ) - return overlapping, non_overlapping, counts - - -GENES_BED = ( - "chr1\t100\t200\tgene1\t0\t+\n" - "chr1\t300\t400\tgene2\t0\t+\n" - "chr1\t500\t600\tgene3\t0\t+\n" - "chr1\t700\t800\tgene4\t0\t+\n" -) - -PEAKS_BED = ( - "chr1\t150\t180\tpeak1\t0\t+\n" - "chr1\t350\t450\tpeak2\t0\t+\n" - "chr1\t900\t950\tpeak3\t0\t+\n" -) - - - - -# Fixtures -# mkdir -p /tmp/bedtools-fixtures && \ -# printf 'chr1\t100\t200\tgene1\t0\t+\nchr1\t300\t400\tgene2\t0\t+\nchr1\t500\t600\tgene3\t0\t+\nchr1\t700\t800\tgene4\t0\t+\n' > /tmp/bedtools-fixtures/genes.bed && \ -# printf 'chr1\t150\t180\tpeak1\t0\t+\nchr1\t350\t450\tpeak2\t0\t+\nchr1\t900\t950\tpeak3\t0\t+\n' > /tmp/bedtools-fixtures/peaks.bed && \ -# ls -la /tmp/bedtools-fixtures/ - -if __name__ == "__main__": - - flyte.init() - - with tempfile.NamedTemporaryFile(mode="w", suffix=".bed", delete=False) as f: - f.write(GENES_BED) - genes_path = f.name - - with tempfile.NamedTemporaryFile(mode="w", suffix=".bed", delete=False) as f: - f.write(PEAKS_BED) - peaks_path = f.name - - run = flyte.with_runcontext().run( - intersect_demo, - File.from_local_sync(genes_path), - File.from_local_sync(peaks_path), - ) - - print(run) - - out = run.outputs() - overlapping_path = out.o0.download_sync("./overlapping.bed") - non_overlapping_path = out.o1.download_sync("./non_overlapping.bed") - counts_path = out.o2.download_sync("./counts.bed") - print(f"Wrote: {overlapping_path}, {non_overlapping_path}, {counts_path}") diff --git a/examples/shell/modules/bedtools_intersect.py b/examples/shell/modules/bedtools_intersect.py deleted file mode 100644 index 70214c8ca..000000000 --- a/examples/shell/modules/bedtools_intersect.py +++ /dev/null @@ -1,69 +0,0 @@ -from flyte.extras import shell -from flyte.io import File - -IMAGE = "quay.io/biocontainers/bedtools:2.31.1--hf5e1c6e_0" - -# Inputs use descriptive Python names where the bedtools CLI has case-only -# collisions (`-c`/`-C`, `-s`/`-S`, `-f`/`-F`). The shell renderer builds -# bash variable names by uppercasing the Python name, so two inputs whose -# names differ only in case would collide on the same `_FLAG_*` slot. -# `flag_aliases` maps each descriptive Python name to the actual CLI flag. -bedtools_intersect = shell.create( - name="bedtools_intersect", - image=IMAGE, - inputs={ - "a": File, - "b": list[File], - "wa": bool | None, - "wb": bool | None, - "loj": bool | None, - "wo": bool | None, - "wao": bool | None, - "u": bool | None, - "count_overlaps": bool | None, - "count_per_file": bool | None, - "v": bool | None, - "same_strand": bool | None, - "opposite_strand": bool | None, - "frac_a": float | None, - "frac_b": float | None, - "r": bool | None, - "e": bool | None, - "ubam": bool | None, - "bed": bool | None, - "sorted": bool | None, - "nonamecheck": bool | None, - "g": File | None, - "names": str | None, - "filenames": bool | None, - "sortout": bool | None, - "split": bool | None, - "header": bool | None, - "nobuf": bool | None, - "iobuf": str | None, - }, - outputs={"out": File}, - flag_aliases={ - "b": "-b", - "count_overlaps": "-c", - "count_per_file": "-C", - "same_strand": "-s", - "opposite_strand": "-S", - "frac_a": "-f", - "frac_b": "-F", - }, - script=r""" - bedtools intersect \ - -a {inputs.a} \ - {flags.b} \ - {flags.wa} {flags.wb} {flags.loj} {flags.wo} {flags.wao} \ - {flags.u} {flags.count_overlaps} {flags.count_per_file} {flags.v} \ - {flags.same_strand} {flags.opposite_strand} \ - {flags.frac_a} {flags.frac_b} {flags.r} {flags.e} \ - {flags.ubam} {flags.bed} \ - {flags.sorted} {flags.nonamecheck} {flags.g} \ - {flags.names} {flags.filenames} {flags.sortout} \ - {flags.split} {flags.header} {flags.nobuf} {flags.iobuf} \ - > {outputs.out} - """, -) diff --git a/src/flyte/extras/shell/_render.py b/src/flyte/extras/shell/_render.py index 870dc738d..49e7a0601 100644 --- a/src/flyte/extras/shell/_render.py +++ b/src/flyte/extras/shell/_render.py @@ -34,7 +34,7 @@ def alloc_slot(name: str) -> str: idx = len(positional_templates) + 1 positional_templates.append(f"{{{{.inputs.{name}}}}}") - var = f"_VAL_{name.upper()}" + var = f"_VAL_{name}" # Brace the positional index: bash parses `$10` as `$1` + `"0"`, # so any task with 10+ scalar/bool inputs would silently bind # later variables to the wrong values. `${10}` is the only form @@ -45,7 +45,7 @@ def alloc_slot(name: str) -> str: def ensure_dict_decoded(name: str) -> str: val_var = alloc_slot(name) - arr_var = f"_ARR_{name.upper()}" + arr_var = f"_ARR_{name}" if name not in dict_decoded: dict_decoded.add(name) @@ -82,7 +82,7 @@ def render_flag_ref(name: str) -> str: kind = kinds[name] spec = flag_specs[name] - flag_var = f"_FLAG_{name.upper()}" + flag_var = f"_FLAG_{name}" if name not in flag_emitted: flag_emitted.add(name) diff --git a/src/flyte/extras/shell/_runtime.py b/src/flyte/extras/shell/_runtime.py index 2ce2121e4..7f80f2454 100644 --- a/src/flyte/extras/shell/_runtime.py +++ b/src/flyte/extras/shell/_runtime.py @@ -288,27 +288,58 @@ def _validate_defaults(defaults: dict[str, Any], inputs: dict[str, Type]) -> dic ) _, inner = _is_optional(inputs[name]) - - if inner is bool: + kind = _classify_input(name, inputs[name]) + + if kind == "file": + if not isinstance(value, File): + raise TypeError( + f"defaults[{name!r}]: expected File, got {type(value).__name__}." + ) + elif kind == "dir": + if not isinstance(value, Dir): + raise TypeError( + f"defaults[{name!r}]: expected Dir, got {type(value).__name__}." + ) + elif kind == "list_file": + if not isinstance(value, list): + raise TypeError( + f"defaults[{name!r}]: expected list[File], got {type(value).__name__}." + ) + if not all(isinstance(item, File) for item in value): + raise TypeError( + f"defaults[{name!r}]: list[File] requires every item to be a File." + ) + elif kind == "bool": if not isinstance(value, bool): raise TypeError(f"defaults[{name!r}]: expected bool, got {type(value).__name__}.") - elif inner is int: - if not isinstance(value, int) or isinstance(value, bool): - raise TypeError(f"defaults[{name!r}]: expected int, got {type(value).__name__}.") - elif inner is float: - if not isinstance(value, (int, float)) or isinstance(value, bool): - raise TypeError(f"defaults[{name!r}]: expected float, got {type(value).__name__}.") - elif inner is str: - if not isinstance(value, str): - raise TypeError(f"defaults[{name!r}]: expected str, got {type(value).__name__}.") - elif _is_dict_str_str(inner): + elif kind == "dict_str": if not isinstance(value, dict): - raise TypeError(f"defaults[{name!r}]: expected dict[str, str], got {type(value).__name__}.") + raise TypeError( + f"defaults[{name!r}]: expected dict[str, str], got {type(value).__name__}." + ) for k, v in value.items(): if not isinstance(k, str) or not isinstance(v, str): raise TypeError( f"defaults[{name!r}]: dict[str, str] requires string keys and values." ) + elif kind == "scalar": + if inner is int: + if not isinstance(value, int) or isinstance(value, bool): + raise TypeError( + f"defaults[{name!r}]: expected int, got {type(value).__name__}." + ) + elif inner is float: + if not isinstance(value, (int, float)) or isinstance(value, bool): + raise TypeError( + f"defaults[{name!r}]: expected float, got {type(value).__name__}." + ) + elif inner is str: + if not isinstance(value, str): + raise TypeError( + f"defaults[{name!r}]: expected str, got {type(value).__name__}." + ) + else: + raise AssertionError(inner) return dict(defaults) @@ -392,28 +423,6 @@ def create( - ``int``, ``float``, ``str``, ``bool`` scalars - ``T | None`` of any of the above (``None`` collapses to empty) - **Optional File / Dir flags emit conditionally.** When - ``{flags.}`` references an input typed as - ``File | None`` / ``Dir | None`` and the caller omits it, the - renderer guards the emission with ``if [ -e ]``: no - flag is added to the command when the file isn't staged. - Non-optional ``File`` / ``Dir`` inputs are still emitted - unconditionally (the caller is contractually required to - supply them). - - **Case-colliding input names are rejected at create() time.** - The renderer builds bash variable names by uppercasing the - Python identifier — so two inputs whose names differ only in - case (e.g. ``c`` and ``C``, common in bio CLIs like samtools - ``-h``/``-H`` or bedtools ``-c``/``-C``) would silently share - ``_VAL_*`` / ``_FLAG_*`` slots and overwrite each other. - ``create()`` raises ``ValueError`` listing both names; the - fix is to give one of them a descriptive Python name and - map back to the CLI flag with ``flag_aliases``:: - - inputs={"count_overlaps": bool | None, "count_per_file": bool | None} - flag_aliases={"count_overlaps": "-c", "count_per_file": "-C"} - **Recipes for things that look like they need a richer dict but don't:** @@ -465,7 +474,7 @@ class for behaviour the type system can't express: travel through bash positional args (``$1``, ``$2``) so they survive arbitrary content (single quotes, tabs, dollar signs) without escaping, and the wrapper already emits scalar - references as ``"${_VAL_X}"`` (quoted, single token). Wrapping + references as ``"${_VAL_name}"`` (quoted, single token). Wrapping them in ``"..."`` again breaks out of the wrapper's quoting and re-enables word splitting. flag_aliases: Per-input override for ``{flags.}`` rendering. @@ -486,17 +495,6 @@ class for behaviour the type system can't express: ``T | None`` No Empty value; flag suppressed ``T | None`` Yes Default used; flag emitted ==================== ========================= ================================= - - Validation at ``create()`` time: - - - Keys must be present in ``inputs``. - - ``None`` values are rejected — use ``T | None`` and omit - from ``defaults`` instead. - - Value's Python type must match the declared input type - (``bool`` for ``bool``, ``int``/``float`` for ``float``, etc.). - - File/Dir/list[File] default values are accepted without - value-shape checking — supply only fully-constructed - :class:`~flyte.io.File` / :class:`~flyte.io.Dir` instances. shell: Shell binary to use. Defaults to ``/bin/bash``. debug: If True, container prints the rendered script to stderr before running. Invaluable when authoring a new wrapper. @@ -552,21 +550,16 @@ async def pipeline(a: File, b: list[File]) -> list[File]: inputs = inputs or {} outputs = outputs or {} - # Reject inputs whose names differ only in case — the renderer builds - # bash variable names by uppercasing the Python name, so names like - # `c` and `C` would silently share `_VAL_C` / `_FLAG_C` and clobber - # each other. Common in bio CLIs (samtools -h/-H, bedtools -c/-C, etc.). - seen_upper: dict[str, str] = {} + # Sanity-check the generated helper names used by the shell renderer. + generated_helpers: dict[str, str] = {} for n in inputs: - up = n.upper() - if up in seen_upper: - raise ValueError( - f"Input names {seen_upper[up]!r} and {n!r} collide on bash variable " - f"'_VAL_{up}' / '_FLAG_{up}' — they differ only in case. " - f"Rename one (e.g. 'count_overlaps') and use flag_aliases to keep " - f"the CLI flag (e.g. flag_aliases={{'count_overlaps': '-c'}})." - ) - seen_upper[up] = n + for helper in (f"_VAL_{n}", f"_FLAG_{n}", f"_ARR_{n}"): + if helper in generated_helpers: + raise ValueError( + f"Input names {generated_helpers[helper]!r} and {n!r} collide in generated shell helper " + f"name {helper!r}. Rename one input." + ) + generated_helpers[helper] = n for n, t in inputs.items(): _classify_input(n, t) diff --git a/tests/flyte/extras/test_shell.py b/tests/flyte/extras/test_shell.py index 6f13c170c..79a112560 100644 --- a/tests/flyte/extras/test_shell.py +++ b/tests/flyte/extras/test_shell.py @@ -224,21 +224,21 @@ def test_scalar_goes_to_positional_arg(self): # with single quotes / tabs / specials survive without escaping. body, positional = _render_full("echo {inputs.x}", {"x": int}) assert positional == ["{{.inputs.x}}"] - # The body binds positional $1 to _VAL_X and references it quoted. - assert '_VAL_X="${1}"' in body - assert '"${_VAL_X}"' in body + # The body binds positional $1 to _VAL_x and references it quoted. + assert '_VAL_x="${1}"' in body + assert '"${_VAL_x}"' in body # No propeller template appears inside the body string itself. assert "{{.inputs.x}}" not in body def test_scalar_referenced_as_quoted_token(self): body, _ = _render_full("echo {inputs.s}", {"s": str}) # Reference is quoted so spaces / tabs in the value stay one bash token. - assert '"${_VAL_S}"' in body + assert '"${_VAL_s}"' in body def test_bool_flag_renders_conditional(self): out = _render("foo {flags.wa}", {"wa": bool}) assert "if [" in out - assert "_FLAG_WA" in out + assert "_FLAG_wa" in out assert "-wa" in out def test_flag_alias_overrides_default(self): @@ -316,7 +316,7 @@ def test_required_file_flag_unconditional(self): # Non-optional File flag is hardcoded — the caller is contractually # required to supply it, so no runtime existence check is needed. out = _render("tool {flags.ref}", {"ref": File}) - assert "_FLAG_REF=" in out + assert "_FLAG_ref=" in out assert "/var/inputs/ref" in out # No conditional guarding the assignment. assert "if [ -e " not in out @@ -327,14 +327,14 @@ def test_optional_file_flag_guarded(self): # trying to open it. out = _render("tool {flags.sites}", {"sites": File | None}) assert "if [ -e /var/inputs/sites ]" in out - assert "_FLAG_SITES=-sites /var/inputs/sites" in out or "_FLAG_SITES='-sites /var/inputs/sites'" in out - assert '_FLAG_SITES=""' in out # else branch + assert "_FLAG_sites=-sites /var/inputs/sites" in out or "_FLAG_sites='-sites /var/inputs/sites'" in out + assert '_FLAG_sites=""' in out # else branch def test_optional_dir_flag_guarded(self): # Same conditional emission for optional Dir flags. out = _render("tool {flags.cache}", {"cache": Dir | None}) assert "if [ -e /var/inputs/cache ]" in out - assert '_FLAG_CACHE=""' in out + assert '_FLAG_cache=""' in out def test_required_dir_flag_unconditional(self): out = _render("tool {flags.workdir}", {"workdir": Dir}) @@ -349,9 +349,9 @@ def test_positional_index_braced_for_two_digit_indices(self): inputs = {f"x{i}": str for i in range(15)} out = _render(" ".join(f"{{inputs.{n}}}" for n in inputs), inputs) # Index 1 — must still work (the fix uses braces uniformly). - assert '_VAL_X0="${1}"' in out + assert '_VAL_x0="${1}"' in out # Index 10 — this is where the bug bit. Must be braced. - assert '_VAL_X9="${10}"' in out + assert '_VAL_x9="${10}"' in out # And the bare two-digit form must NOT appear anywhere. assert '"$10"' not in out assert '"$15"' not in out @@ -593,46 +593,38 @@ def test_flag_aliases_must_match_inputs(self): flag_aliases={"missing": "-m"}, ) - # ---- case-colliding input names rejected ---- + # ---- case-preserving helper names ---- - def test_case_collision_lower_then_upper_rejected(self): - # `c` and `C` would both render to bash vars `_VAL_C` / `_FLAG_C` - # and silently overwrite each other. Common in bio CLIs. - with pytest.raises(ValueError, match="collide on bash variable"): - shell.create( - name="bad", - image="alpine:3.18", - inputs={"c": bool, "C": bool}, - outputs={"o": File}, - script="true", - ) + def test_case_distinct_inputs_are_allowed(self): + shell.create( + name="ok_case", + image="alpine:3.18", + inputs={"c": bool, "C": bool}, + outputs={"o": File}, + script="true", + ) - def test_case_collision_message_names_both_inputs(self): - # Error must name *both* colliding inputs so the author can find them. - with pytest.raises(ValueError) as exc_info: - shell.create( - name="bad", - image="alpine:3.18", - inputs={"foo": bool, "FOO": bool}, - outputs={"o": File}, - script="true", - ) - msg = str(exc_info.value) - assert "'foo'" in msg and "'FOO'" in msg + def test_mixed_case_distinct_inputs_are_allowed(self): + shell.create( + name="ok_mixed_case", + image="alpine:3.18", + inputs={"my_flag": bool, "My_Flag": bool}, + outputs={"o": File}, + script="true", + ) - def test_case_collision_mixed_case_rejected(self): - # Not just exact lower/upper — any `.upper()` collision is rejected. - with pytest.raises(ValueError, match="collide on bash variable"): - shell.create( - name="bad", - image="alpine:3.18", - inputs={"my_flag": bool, "My_Flag": bool}, - outputs={"o": File}, - script="true", - ) + def test_helper_names_preserve_input_case(self): + body = shell.create( + name="case_helpers", + image="alpine:3.18", + inputs={"c": bool, "C": bool}, + outputs={"o": File}, + script="tool {flags.c} {flags.C} > {outputs.o}", + )._build_command()[2] + assert "_FLAG_c" in body + assert "_FLAG_C" in body - def test_no_collision_distinct_uppercase_forms(self): - # Distinct uppercased forms — no collision, must not raise. + def test_distinct_case_sensitive_names_still_validate(self): shell.create( name="ok", image="alpine:3.18", @@ -671,9 +663,9 @@ def test_full_bedtools_shape_validates(self): body = cmd[2] assert "/var/inputs/a" in body assert "/var/inputs/b/*" in body - assert "_FLAG_WA" in body - assert "_FLAG_LOJ" in body - assert "_FLAG_NAMES" in body + assert "_FLAG_wa" in body + assert "_FLAG_loj" in body + assert "_FLAG_names" in body assert "/var/outputs/_returncode" in body def test_debug_mode_emits_script_dump(self): @@ -688,7 +680,7 @@ def test_debug_mode_emits_script_dump(self): body = task._build_command()[2] assert "rendered script" in body assert "cat <<'_EOF_' >&2" in body - assert '( echo "${_VAL_X}" > /var/outputs/o' not in body + assert '( echo "${_VAL_x}" > /var/outputs/o' not in body def test_debug_mode_dump_flows_through_declared_stderr(self): task = shell.create( @@ -794,9 +786,9 @@ def test_inputs_ref_renders_array_expansion(self): body, positional = _render_full("tool {inputs.opts}", {"opts": dict[str, str]}) # Dict gets a positional slot, then a decode preamble allocates an array. assert positional == ["{{.inputs.opts}}"] - assert "_ARR_OPTS=" in body + assert "_ARR_opts=" in body assert "IFS=" in body - assert '"${_ARR_OPTS[@]}"' in body + assert '"${_ARR_opts[@]}"' in body def test_flags_pairs_mode_default(self): body, _ = _render_full( @@ -804,8 +796,8 @@ def test_flags_pairs_mode_default(self): {"opts": dict[str, str]}, ) # Default mode is pairs — keys/values become separate argv tokens. - assert "_FLAG_OPTS=" in body - assert '"${_FLAG_OPTS[@]}"' in body + assert "_FLAG_opts=" in body + assert '"${_FLAG_opts[@]}"' in body def test_flags_equals_mode(self): body, _ = _render_full( @@ -953,6 +945,24 @@ def test_default_for_optional_dict(self): # Dict defaults flow through the record-separator packing path. assert result["opts"].split(_DICT_SEP) == ["-k", "v"] + def test_default_for_optional_file(self): + f = File(path="/tmp/example.txt") + task = self._task({"src": File | None}, defaults={"src": f}) + result = asyncio.run(task._prepare_kwargs({})) + assert result["src"] is f + + def test_default_for_optional_dir(self): + d = Dir(path="/tmp/example_dir") + task = self._task({"src": Dir | None}, defaults={"src": d}) + result = asyncio.run(task._prepare_kwargs({})) + assert result["src"] is d + + def test_default_for_optional_list_of_files(self): + files = [File(path="/tmp/a.txt"), File(path="/tmp/b.txt")] + task = self._task({"parts": list[File] | None}, defaults={"parts": files}) + result = asyncio.run(task._prepare_kwargs({})) + assert result["parts"] == files + # ---- create()-time validation ---- def test_validate_unknown_key_rejected(self): @@ -993,6 +1003,25 @@ def test_validate_dict_non_string_value_rejected(self): {"opts": dict[str, str] | None}, defaults={"opts": {"k": 42}} # type: ignore[dict-item] ) + def test_validate_file_type_mismatch(self): + with pytest.raises(TypeError, match="expected File"): + self._task({"src": File | None}, defaults={"src": "/tmp/example.txt"}) + + def test_validate_dir_type_mismatch(self): + with pytest.raises(TypeError, match="expected Dir"): + self._task({"src": Dir | None}, defaults={"src": "/tmp/example_dir"}) + + def test_validate_list_of_files_type_mismatch(self): + with pytest.raises(TypeError, match=r"expected list\[File\]"): + self._task({"parts": list[File] | None}, defaults={"parts": "not a list"}) + + def test_validate_list_of_files_item_type_mismatch(self): + with pytest.raises(TypeError, match=r"list\[File\] requires every item to be a File"): + self._task( + {"parts": list[File] | None}, + defaults={"parts": [File(path="/tmp/a.txt"), "/tmp/b.txt"]}, + ) + # ---- no defaults parameter at all (backward compat) ---- def test_no_defaults_param_is_backward_compatible(self): @@ -1046,7 +1075,7 @@ class TestScalarValuesSurviveShellSpecials: Scalar values go through bash positional args, never through inline shell substitution. The body never contains the literal value — only a - `"${_VAL_X}"` reference. Propeller substitutes the literal value into the + `"${_VAL_x}"` reference. Propeller substitutes the literal value into the argv slot at runtime; bash sees it as a verbatim string. """ @@ -1064,8 +1093,8 @@ def test_each_scalar_gets_distinct_positional_slot(self): {"a": str, "b": int}, ) assert positional == ["{{.inputs.a}}", "{{.inputs.b}}"] - assert '_VAL_A="${1}"' in body - assert '_VAL_B="${2}"' in body + assert '_VAL_a="${1}"' in body + assert '_VAL_b="${2}"' in body def test_same_input_referenced_twice_reuses_slot(self): body, positional = _render_full( @@ -1074,7 +1103,7 @@ def test_same_input_referenced_twice_reuses_slot(self): ) # x referenced twice — single positional slot. assert positional == ["{{.inputs.x}}"] - assert body.count('_VAL_X="${1}"') == 1 + assert body.count('_VAL_x="${1}"') == 1 def test_inputs_and_flags_for_same_var_share_slot(self): body, positional = _render_full( @@ -1082,8 +1111,8 @@ def test_inputs_and_flags_for_same_var_share_slot(self): {"f": str}, ) assert positional == ["{{.inputs.f}}"] - # _VAL_F bound once, used by both the flag setter and the inputs ref. - assert body.count('_VAL_F="${1}"') == 1 + # _VAL_f bound once, used by both the flag setter and the inputs ref. + assert body.count('_VAL_f="${1}"') == 1 class TestBuildCommandArgvLayout: From b5336067e630ebac9c7ddced69c33a595bf8731c Mon Sep 17 00:00:00 2001 From: Samhita Alla Date: Wed, 13 May 2026 10:59:41 +0530 Subject: [PATCH 3/3] make fmt Signed-off-by: Samhita Alla --- src/flyte/extras/shell/_render.py | 4 ++-- src/flyte/extras/shell/_runtime.py | 36 ++++++++---------------------- tests/flyte/extras/test_shell.py | 9 ++++---- 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/src/flyte/extras/shell/_render.py b/src/flyte/extras/shell/_render.py index 49e7a0601..5ba3e81ab 100644 --- a/src/flyte/extras/shell/_render.py +++ b/src/flyte/extras/shell/_render.py @@ -164,8 +164,8 @@ def _emit_flag_setter( path = input_data_dir / name if is_optional: return ( - f'if [ -e {shlex.quote(str(path))} ]; then ' - f'{flag_var}={shlex.quote(flag + sep + str(path))}; ' + f"if [ -e {shlex.quote(str(path))} ]; then " + f"{flag_var}={shlex.quote(flag + sep + str(path))}; " f'else {flag_var}=""; fi' ) return f"{flag_var}={shlex.quote(flag + sep + str(path))}" diff --git a/src/flyte/extras/shell/_runtime.py b/src/flyte/extras/shell/_runtime.py index 7f80f2454..fe2017919 100644 --- a/src/flyte/extras/shell/_runtime.py +++ b/src/flyte/extras/shell/_runtime.py @@ -292,52 +292,34 @@ def _validate_defaults(defaults: dict[str, Any], inputs: dict[str, Type]) -> dic if kind == "file": if not isinstance(value, File): - raise TypeError( - f"defaults[{name!r}]: expected File, got {type(value).__name__}." - ) + raise TypeError(f"defaults[{name!r}]: expected File, got {type(value).__name__}.") elif kind == "dir": if not isinstance(value, Dir): - raise TypeError( - f"defaults[{name!r}]: expected Dir, got {type(value).__name__}." - ) + raise TypeError(f"defaults[{name!r}]: expected Dir, got {type(value).__name__}.") elif kind == "list_file": if not isinstance(value, list): - raise TypeError( - f"defaults[{name!r}]: expected list[File], got {type(value).__name__}." - ) + raise TypeError(f"defaults[{name!r}]: expected list[File], got {type(value).__name__}.") if not all(isinstance(item, File) for item in value): - raise TypeError( - f"defaults[{name!r}]: list[File] requires every item to be a File." - ) + raise TypeError(f"defaults[{name!r}]: list[File] requires every item to be a File.") elif kind == "bool": if not isinstance(value, bool): raise TypeError(f"defaults[{name!r}]: expected bool, got {type(value).__name__}.") elif kind == "dict_str": if not isinstance(value, dict): - raise TypeError( - f"defaults[{name!r}]: expected dict[str, str], got {type(value).__name__}." - ) + raise TypeError(f"defaults[{name!r}]: expected dict[str, str], got {type(value).__name__}.") for k, v in value.items(): if not isinstance(k, str) or not isinstance(v, str): - raise TypeError( - f"defaults[{name!r}]: dict[str, str] requires string keys and values." - ) + raise TypeError(f"defaults[{name!r}]: dict[str, str] requires string keys and values.") elif kind == "scalar": if inner is int: if not isinstance(value, int) or isinstance(value, bool): - raise TypeError( - f"defaults[{name!r}]: expected int, got {type(value).__name__}." - ) + raise TypeError(f"defaults[{name!r}]: expected int, got {type(value).__name__}.") elif inner is float: if not isinstance(value, (int, float)) or isinstance(value, bool): - raise TypeError( - f"defaults[{name!r}]: expected float, got {type(value).__name__}." - ) + raise TypeError(f"defaults[{name!r}]: expected float, got {type(value).__name__}.") elif inner is str: if not isinstance(value, str): - raise TypeError( - f"defaults[{name!r}]: expected str, got {type(value).__name__}." - ) + raise TypeError(f"defaults[{name!r}]: expected str, got {type(value).__name__}.") else: raise AssertionError(inner) return dict(defaults) diff --git a/tests/flyte/extras/test_shell.py b/tests/flyte/extras/test_shell.py index 79a112560..84a229df3 100644 --- a/tests/flyte/extras/test_shell.py +++ b/tests/flyte/extras/test_shell.py @@ -865,7 +865,7 @@ def test_optional_dict_default_empty_string(self): # --------------------------------------------------------------------------- -# Defaults — four-cell matrix of {required, optional} × {has default, none} +# Defaults — four-cell matrix of {required, optional} x {has default, none} # --------------------------------------------------------------------------- @@ -938,9 +938,7 @@ def test_default_for_optional_str(self): assert result["s"] == "hello" def test_default_for_optional_dict(self): - task = self._task( - {"opts": dict[str, str] | None}, defaults={"opts": {"-k": "v"}} - ) + task = self._task({"opts": dict[str, str] | None}, defaults={"opts": {"-k": "v"}}) result = asyncio.run(task._prepare_kwargs({})) # Dict defaults flow through the record-separator packing path. assert result["opts"].split(_DICT_SEP) == ["-k", "v"] @@ -1000,7 +998,8 @@ def test_validate_dict_type_mismatch(self): def test_validate_dict_non_string_value_rejected(self): with pytest.raises(TypeError, match="string keys and values"): self._task( - {"opts": dict[str, str] | None}, defaults={"opts": {"k": 42}} # type: ignore[dict-item] + {"opts": dict[str, str] | None}, + defaults={"opts": {"k": 42}}, # type: ignore[dict-item] ) def test_validate_file_type_mismatch(self):