Skip to content

Commit a147e12

Browse files
committed
make fmt & make mypy
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
1 parent 732e2aa commit a147e12

13 files changed

Lines changed: 67 additions & 125 deletions

examples/shell/01_basic.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88

99
import sys
1010
import tempfile
11+
from pathlib import Path
1112

1213
import flyte
14+
from flyte._image import PythonWheels
1315
from flyte.extras import shell
1416
from flyte.io import File
15-
from flyte._image import PythonWheels
16-
from pathlib import Path
1717

1818
# Wrap `head` — emits the first N lines of an input file.
1919
head_task = shell.create(
@@ -60,8 +60,6 @@ async def take_first_lines(src: File, n: int) -> File:
6060
f.write("\n".join(f"line {i}" for i in range(1, 21)))
6161
path = f.name
6262

63-
run = flyte.with_runcontext(mode=mode).run(
64-
take_first_lines, File.from_local_sync(path), 5
65-
)
63+
run = flyte.with_runcontext(mode=mode).run(take_first_lines, File.from_local_sync(path), 5)
6664
print(run.url if mode == "remote" else run)
6765
print(f"Output: {run.outputs()}")

examples/shell/02_lists_and_globs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ async def split_concatenated(parts: list[File], chunk_lines: int) -> list[File]:
7777

7878
parts: list[File] = []
7979
for i in range(3):
80-
with tempfile.NamedTemporaryFile(
81-
mode="w", suffix=f"_p{i}.txt", delete=False
82-
) as f:
80+
with tempfile.NamedTemporaryFile(mode="w", suffix=f"_p{i}.txt", delete=False) as f:
8381
f.write("\n".join(f"part-{i}/line-{j}" for j in range(10)))
8482
parts.append(File.from_local_sync(f.name))
8583

examples/shell/04_dict_inputs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@
4646
image="debian:12-slim",
4747
inputs={"opts": dict[str, str]},
4848
outputs={"argv": Stdout(type=str)},
49-
flag_aliases={
50-
"opts": ""
51-
}, # pairs mode is the default; no per-key prefix; keys already include `--`
49+
flag_aliases={"opts": ""}, # pairs mode is the default; no per-key prefix; keys already include `--`
5250
script=r"""
5351
echo {flags.opts}
5452
""",
@@ -117,7 +115,7 @@ async def dict_demo() -> tuple[str, str, str, str]:
117115
opts={
118116
"--memory": "4G",
119117
"--threads": "8",
120-
"--label": "it's a test", # single quote — safe via positional args
118+
"--label": "it's a test", # single quote — safe via positional args
121119
}
122120
)
123121
equals_argv = await echo_equals(

examples/shell/05_bool_and_optional.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,7 @@ async def bool_demo() -> tuple[str, str, str]:
7171
# All flags on.
7272
a = await report(title="full", verbose=True, case_insensitive=True, threads=8)
7373
# Only verbose; threads omitted (None).
74-
b = await report(
75-
title="partial", verbose=True, case_insensitive=False, threads=None
76-
)
74+
b = await report(title="partial", verbose=True, case_insensitive=False, threads=None)
7775
# All flags off.
7876
c = await report(title="bare", verbose=False, case_insensitive=False, threads=None)
7977
return a, b, c

examples/shell/08_dirs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ async def summarize_files(src: Dir) -> Dir:
7979
for i, body in enumerate(["one\ntwo\nthree\n", "alpha\nbeta\n", "single\n"]):
8080
(tmp / f"file_{i}.txt").write_text(body)
8181

82-
run = flyte.with_runcontext(mode=mode).run(
83-
summarize_files, Dir.from_local_sync(str(tmp))
84-
)
82+
run = flyte.with_runcontext(mode=mode).run(summarize_files, Dir.from_local_sync(str(tmp)))
8583
print(run.url if mode == "remote" else run)
8684
print(f"Output: {run.outputs()}")

examples/shell/10_resources_and_image.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@
3838
# A layered Image — base + extra apt package. shell.create builds this on
3939
# first call via flyte.build (using the configured builder) and passes the
4040
# resolved URI down to ContainerTask.
41-
image_with_jq = flyte.Image.from_debian_base(install_flyte=False).with_apt_packages(
42-
"jq"
43-
)
41+
image_with_jq = flyte.Image.from_debian_base(install_flyte=False).with_apt_packages("jq")
4442

4543

4644
pretty_print_json = shell.create(

examples/shell/11_list_flag_modes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ async def list_mode_demo(parts: list[File]) -> tuple[str, str, str]:
7474

7575
parts: list[File] = []
7676
for i in range(3):
77-
with tempfile.NamedTemporaryFile(
78-
mode="w", suffix=f"_{i}.txt", delete=False
79-
) as f:
77+
with tempfile.NamedTemporaryFile(mode="w", suffix=f"_{i}.txt", delete=False) as f:
8078
f.write(f"part-{i}\n")
8179
parts.append(File.from_local_sync(f.name))
8280

src/flyte/extras/shell/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,36 +57,36 @@
5757
from __future__ import annotations
5858

5959
from ._render import _DICT_SEP, _render_command
60-
from ._runtime import _Shell, _read_process_result, create
60+
from ._runtime import _read_process_result, _Shell, create
6161
from ._types import (
6262
DictMode,
6363
FlagSpec,
6464
Glob,
6565
OutputSpec,
6666
Stderr,
6767
Stdout,
68-
listMode,
6968
_classify_input,
7069
_is_list_of,
7170
_is_optional,
7271
_validate_outputs,
72+
listMode,
7373
)
7474

7575
__all__ = [
76+
"_DICT_SEP",
7677
"DictMode",
7778
"FlagSpec",
7879
"Glob",
7980
"OutputSpec",
8081
"Stderr",
8182
"Stdout",
82-
"create",
83-
"listMode",
84-
"_DICT_SEP",
8583
"_Shell",
8684
"_classify_input",
8785
"_is_list_of",
8886
"_is_optional",
8987
"_read_process_result",
9088
"_render_command",
9189
"_validate_outputs",
90+
"create",
91+
"listMode",
9292
]

src/flyte/extras/shell/_render.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from ._types import FlagSpec, Stderr, Stdout, _classify_input
99

10-
1110
_PLACEHOLDER_RE = re.compile(r"\{(inputs|flags|outputs)\.([a-zA-Z_][a-zA-Z0-9_]*)\}")
1211
_DICT_SEP = "\x1e"
1312

@@ -57,9 +56,7 @@ def render_input_ref(name: str) -> str:
5756
kind = kinds.get(name)
5857

5958
if kind is None:
60-
raise KeyError(
61-
f"{{inputs.{name}}} used in script but {name!r} is not declared in inputs."
62-
)
59+
raise KeyError(f"{{inputs.{name}}} used in script but {name!r} is not declared in inputs.")
6360
if kind in ("file", "dir"):
6461
return str(input_data_dir / name)
6562
if kind == "list_file":
@@ -108,10 +105,7 @@ def replace(match: re.Match) -> str:
108105
return render_flag_ref(name)
109106
if ns == "outputs":
110107
if name not in outputs:
111-
raise KeyError(
112-
f"{{outputs.{name}}} references an unknown output. "
113-
f"Declared outputs: {list(outputs)}"
114-
)
108+
raise KeyError(f"{{outputs.{name}}} references an unknown output. Declared outputs: {list(outputs)}")
115109
spec = outputs[name]
116110
if isinstance(spec, (Stdout, Stderr)):
117111
raise KeyError(
@@ -151,11 +145,7 @@ def _emit_flag_setter(
151145

152146
if kind == "bool":
153147
val_var = alloc_slot(name)
154-
return (
155-
f'if [ "${{{val_var}}}" = "true" ]; then '
156-
f"{flag_var}={shlex.quote(flag)}; "
157-
f'else {flag_var}=""; fi'
158-
)
148+
return f'if [ "${{{val_var}}}" = "true" ]; then {flag_var}={shlex.quote(flag)}; else {flag_var}=""; fi'
159149
if kind == "scalar":
160150
val_var = alloc_slot(name)
161151
return (

src/flyte/extras/shell/_runtime.py

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111

1212
from ._render import _DICT_SEP, _render_command
1313
from ._types import (
14+
_SCALAR_TYPES,
1415
FlagSpec,
1516
Glob,
1617
Stderr,
1718
Stdout,
18-
_ProcessResult,
19-
_SCALAR_TYPES,
2019
_classify_input,
2120
_is_dict_str_str,
2221
_is_optional,
22+
_ProcessResult,
2323
_validate_outputs,
2424
listMode,
2525
)
@@ -51,8 +51,8 @@ class _Shell:
5151
_env: Optional["flyte.TaskEnvironment"] = field(default=None, repr=False, compare=False)
5252
_resolved_image_uri: Optional[str] = field(default=None, repr=False, compare=False)
5353

54-
def _container_inputs(self) -> dict[str, Type]:
55-
wired: dict[str, Type] = {}
54+
def _container_inputs(self) -> dict[str, Any]:
55+
wired: dict[str, Any] = {}
5656
for name, tp in self.inputs.items():
5757
is_opt, inner = _is_optional(tp)
5858
if _is_dict_str_str(inner):
@@ -81,10 +81,7 @@ def _build_command(self) -> list[str]:
8181
script=self.script,
8282
inputs=self.inputs,
8383
outputs=self.outputs,
84-
flag_specs={
85-
name: FlagSpec.coerce(name, self.flag_aliases.get(name))
86-
for name in self.inputs
87-
},
84+
flag_specs={name: FlagSpec.coerce(name, self.flag_aliases.get(name)) for name in self.inputs},
8885
input_data_dir=self.input_data_dir,
8986
output_data_dir=self.output_data_dir,
9087
)
@@ -100,17 +97,13 @@ def _build_command(self) -> list[str]:
10097
mkdirs = [
10198
f"mkdir -p {shlex.quote(str(self.output_data_dir / name))}"
10299
for name, spec in self.outputs.items()
103-
if isinstance(spec, Glob)
104-
or (isinstance(spec, type) and issubclass(spec, Dir))
100+
if isinstance(spec, Glob) or (isinstance(spec, type) and issubclass(spec, Dir))
105101
]
106102
mkdir_preamble = "; ".join(mkdirs) + ";" if mkdirs else ""
107103

108104
debug_preamble = ""
109105
if self.debug:
110-
debug_preamble = (
111-
'echo "--- shell task: rendered script ---" >&2; '
112-
f"cat <<'_EOF_' >&2\n{body}\n_EOF_\n"
113-
)
106+
debug_preamble = f"echo \"--- shell task: rendered script ---\" >&2; cat <<'_EOF_' >&2\n{body}\n_EOF_\n"
114107

115108
wrapped = (
116109
f"{mkdir_preamble} "
@@ -187,9 +180,7 @@ async def _unpack_outputs(self, raw: Any) -> Any:
187180
if isinstance(spec, Glob) and isinstance(value, Dir):
188181
local = await value.download() if hasattr(value, "download") else value.path
189182
matched = sorted(pathlib.Path(str(local)).glob(spec.pattern))
190-
unpacked.append(
191-
[await File.from_local(str(p)) for p in matched if p.is_file()]
192-
)
183+
unpacked.append([await File.from_local(str(p)) for p in matched if p.is_file()])
193184
else:
194185
unpacked.append(value)
195186
return unpacked[0] if single else tuple(unpacked)
@@ -223,20 +214,15 @@ async def _prepare_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
223214
for k, v in value.items():
224215
if _DICT_SEP in k or _DICT_SEP in v:
225216
raise ValueError(
226-
f"dict input {name!r}: keys/values cannot contain the "
227-
f"record-separator byte (\\x1e)."
217+
f"dict input {name!r}: keys/values cannot contain the record-separator byte (\\x1e)."
228218
)
229219
parts.append(k)
230220
parts.append(v)
231221
out[name] = _DICT_SEP.join(parts)
232222
continue
233223

234224
if is_opt and kind in ("scalar", "bool"):
235-
out[name] = (
236-
"true"
237-
if (kind == "bool" and value)
238-
else "false" if kind == "bool" else str(value)
239-
)
225+
out[name] = "true" if (kind == "bool" and value) else "false" if kind == "bool" else str(value)
240226
continue
241227

242228
out[name] = value
@@ -310,9 +296,7 @@ def create(
310296
inputs: Optional[dict[str, Type]] = None,
311297
outputs: Optional[dict[str, Any]] = None,
312298
script: str,
313-
flag_aliases: Optional[
314-
dict[str, Union[str, Tuple[str, listMode], FlagSpec]]
315-
] = None,
299+
flag_aliases: Optional[dict[str, Union[str, Tuple[str, listMode], FlagSpec]]] = None,
316300
shell: str = "/bin/bash",
317301
debug: bool = False,
318302
resources: Optional[flyte.Resources] = None,
@@ -484,16 +468,11 @@ async def pipeline(a: File, b: list[File]) -> list[File]:
484468
coerced_aliases: dict[str, FlagSpec] = {}
485469
for n, alias in (flag_aliases or {}).items():
486470
if n not in inputs:
487-
raise KeyError(
488-
f"flag_aliases references {n!r} which is not declared in inputs."
489-
)
471+
raise KeyError(f"flag_aliases references {n!r} which is not declared in inputs.")
490472
coerced_aliases[n] = FlagSpec.coerce(n, alias)
491473

492474
if not isinstance(image, (str, flyte.Image)):
493-
raise TypeError(
494-
f"image must be a URI string or a flyte.Image, got "
495-
f"{type(image).__name__}."
496-
)
475+
raise TypeError(f"image must be a URI string or a flyte.Image, got {type(image).__name__}.")
497476

498477
return _Shell(
499478
name=name,

0 commit comments

Comments
 (0)