Commit f4b56e7
feature: conditionals for Argo Workflows (#2550)
Adds conditional support to Argo Workflows.

TODO:
- [x] nesting a conditional as the last step inside a foreach is not yet working. This is now out of scope for the initial release
- [x] test `@parallel` with conditionals
- [x] test regular flows with Argo for possible regressions

Open issues to solve in another PR:
- conditionals as the last step in a foreach. Requires a thorough input-paths rework
- `@parallel` as a switch step, where it leads to more than one possible parallel-join step. This needs to either be fixed or blocked. The issue exists with local conditional runs as well, so it is outside the scope of this PR
1 parent 2401e34 commit f4b56e7

File tree

3 files changed: +278 −22 lines changed

metaflow/plugins/argo/argo_workflows.py

Lines changed: 251 additions & 22 deletions
@@ -143,6 +143,7 @@ def __init__(
 
         self.name = name
         self.graph = graph
+        self._parse_conditional_branches()
        self.flow = flow
        self.code_package_metadata = code_package_metadata
        self.code_package_sha = code_package_sha
@@ -920,6 +921,121 @@ def _compile_workflow_template(self):
             )
         )
 
+    # Visit every node and record information on conditional step structure
+    def _parse_conditional_branches(self):
+        self.conditional_nodes = set()
+        self.conditional_join_nodes = set()
+        self.matching_conditional_join_dict = {}
+
+        node_conditional_parents = {}
+        node_conditional_branches = {}
+
+        def _visit(node, seen, conditional_branch, conditional_parents=None):
+            if not node.type == "split-switch" and not (
+                conditional_branch and conditional_parents
+            ):
+                # skip regular non-conditional nodes entirely
+                return
+
+            if node.type == "split-switch":
+                conditional_branch = conditional_branch + [node.name]
+                node_conditional_branches[node.name] = conditional_branch
+
+                conditional_parents = (
+                    [node.name]
+                    if not conditional_parents
+                    else conditional_parents + [node.name]
+                )
+                node_conditional_parents[node.name] = conditional_parents
+
+            if conditional_parents and not node.type == "split-switch":
+                node_conditional_parents[node.name] = conditional_parents
+                conditional_branch = conditional_branch + [node.name]
+                node_conditional_branches[node.name] = conditional_branch
+
+                self.conditional_nodes.add(node.name)
+
+            if conditional_branch and conditional_parents:
+                for n in node.out_funcs:
+                    child = self.graph[n]
+                    if n not in seen:
+                        _visit(
+                            child, seen + [n], conditional_branch, conditional_parents
+                        )
+
+        # First we visit all nodes to determine conditional parents and branches
+        for n in self.graph:
+            _visit(n, [], [])
+
+        # Then we traverse again in order to determine conditional join nodes, and matching conditional join info
+        for node in self.graph:
+            if node_conditional_parents.get(node.name, False):
+                # do the required postprocessing for anything requiring node.in_funcs
+
+                # check that in previous parsing we have not closed all conditional in_funcs.
+                # If so, this step can not be conditional either
+                is_conditional = any(
+                    in_func in self.conditional_nodes
+                    or self.graph[in_func].type == "split-switch"
+                    for in_func in node.in_funcs
+                )
+                if is_conditional:
+                    self.conditional_nodes.add(node.name)
+                else:
+                    if node.name in self.conditional_nodes:
+                        self.conditional_nodes.remove(node.name)
+
+                # does this node close the latest conditional parent branches?
+                conditional_in_funcs = [
+                    in_func
+                    for in_func in node.in_funcs
+                    if node_conditional_branches.get(in_func, False)
+                ]
+                closed_conditional_parents = []
+                for last_split_switch in node_conditional_parents.get(node.name, [])[
+                    ::-1
+                ]:
+                    last_conditional_split_nodes = self.graph[
+                        last_split_switch
+                    ].out_funcs
+                    # p needs to be in at least one conditional_branch for it to be closed.
+                    if all(
+                        any(
+                            p in node_conditional_branches.get(in_func, [])
+                            for in_func in conditional_in_funcs
+                        )
+                        for p in last_conditional_split_nodes
+                    ):
+                        closed_conditional_parents.append(last_split_switch)
+
+                        self.conditional_join_nodes.add(node.name)
+                        self.matching_conditional_join_dict[last_split_switch] = (
+                            node.name
+                        )
+
+                # Did we close all conditionals? Then this branch and all its children are not conditional anymore (unless a new conditional branch is encountered).
+                if not [
+                    p
+                    for p in node_conditional_parents.get(node.name, [])
+                    if p not in closed_conditional_parents
+                ]:
+                    if node.name in self.conditional_nodes:
+                        self.conditional_nodes.remove(node.name)
+                    node_conditional_parents[node.name] = []
+                    for p in node.out_funcs:
+                        if p in self.conditional_nodes:
+                            self.conditional_nodes.remove(p)
+                        node_conditional_parents[p] = []
+
+    def _is_conditional_node(self, node):
+        return node.name in self.conditional_nodes
+
+    def _is_conditional_join_node(self, node):
+        return node.name in self.conditional_join_nodes
+
+    def _matching_conditional_join(self, node):
+        return self.matching_conditional_join_dict.get(node.name, None)
+
     # Visit every node and yield the uber DAGTemplate(s).
     def _dag_templates(self):
         def _visit(
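To make the bookkeeping above concrete, here is a sketch of what the parse would record for a minimal flow with a single switch. The step names are invented for illustration, and the values follow from reading the code above rather than from anything stated in the commit:

```python
# Hypothetical graph: start -> decide (split-switch) -> branch_a | branch_b -> merge -> end
#
# After _parse_conditional_branches() runs, the recorded state would be roughly:
#   self.conditional_nodes              == {"branch_a", "branch_b"}
#   self.conditional_join_nodes         == {"merge"}
#   self.matching_conditional_join_dict == {"decide": "merge"}
#
# "merge" is first marked conditional (both of its in_funcs are), then dropped
# again once it is found to close every open conditional parent, so only the
# branch bodies remain in conditional_nodes.
```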
@@ -941,19 +1057,15 @@ def _visit(
             dag_tasks = []
             if templates is None:
                 templates = []
+
             if exit_node is not None and exit_node is node.name:
                 return templates, dag_tasks
             if node.name == "start":
                 # Start node has no dependencies.
                 dag_task = DAGTask(self._sanitize(node.name)).template(
                     self._sanitize(node.name)
                 )
-            if node.type == "split-switch":
-                raise ArgoWorkflowsException(
-                    "Deploying flows with switch statement "
-                    "to Argo Workflows is not supported currently."
-                )
-            elif (
+            if (
                 node.is_inside_foreach
                 and self.graph[node.in_funcs[0]].type == "foreach"
                 and not self.graph[node.in_funcs[0]].parallel_foreach
@@ -1044,15 +1156,32 @@ def _visit(
             else:
                 # Every other node needs only input-paths
                 parameters = [
-                    Parameter("input-paths").value(
-                        compress_list(
-                            [
-                                "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
-                                % (n, self._sanitize(n))
-                                for n in node.in_funcs
-                            ],
-                            # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution.
-                            zlibmin=inf,
+                    (
+                        Parameter("input-paths").value(
+                            compress_list(
+                                [
+                                    "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
+                                    % (n, self._sanitize(n))
+                                    for n in node.in_funcs
+                                ],
+                                # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution.
+                                zlibmin=inf,
+                            )
+                        )
+                        if not self._is_conditional_join_node(node)
+                        # The value fetching for input-paths from conditional steps has to be quite involved
+                        # in order to avoid issues with replacements due to missing step outputs.
+                        # NOTE: we differentiate the input-path expression only for conditional joins so we can still utilize the list compression,
+                        # but do not have to rework all decompress usage due to the need for a custom separator
+                        else Parameter("input-paths").value(
+                            compress_list(
+                                [
+                                    "argo-{{workflow.name}}/%s/{{=(get(tasks['%s']?.outputs?.parameters, 'task-id') ?? 'no-task')}}"
+                                    % (n, self._sanitize(n))
+                                    for n in node.in_funcs
+                                ],
+                                separator="%",  # non-default separator is required due to commas in the value expression
+                            )
                         )
                     )
                 ]
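For a conditional join, the per-parent entries built above would look roughly like the following before compress_list packs them with the custom "%" separator. Step names are invented; this is a sketch of the rendered strings, not output copied from a run:

```python
entries = [
    "argo-{{workflow.name}}/branch_a/"
    "{{=(get(tasks['branch-a']?.outputs?.parameters, 'task-id') ?? 'no-task')}}",
    "argo-{{workflow.name}}/branch_b/"
    "{{=(get(tasks['branch-b']?.outputs?.parameters, 'task-id') ?? 'no-task')}}",
]
# The expr lookup falls back to the literal 'no-task' for the branch that was
# skipped, so Argo's parameter substitution never trips over a missing output.
```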
@@ -1087,15 +1216,43 @@ def _visit(
                     ]
                 )
 
+                conditional_deps = [
+                    "%s.Succeeded" % self._sanitize(in_func)
+                    for in_func in node.in_funcs
+                    if self._is_conditional_node(self.graph[in_func])
+                ]
+                required_deps = [
+                    "%s.Succeeded" % self._sanitize(in_func)
+                    for in_func in node.in_funcs
+                    if not self._is_conditional_node(self.graph[in_func])
+                ]
+                both_conditions = required_deps and conditional_deps
+
+                depends_str = "{required}{_and}{conditional}".format(
+                    required=("(%s)" if both_conditions else "%s")
+                    % " && ".join(required_deps),
+                    _and=" && " if both_conditions else "",
+                    conditional=("(%s)" if both_conditions else "%s")
+                    % " || ".join(conditional_deps),
+                )
                 dag_task = (
                     DAGTask(self._sanitize(node.name))
-                    .dependencies(
-                        [self._sanitize(in_func) for in_func in node.in_funcs]
-                    )
+                    .depends(depends_str)
                     .template(self._sanitize(node.name))
                     .arguments(Arguments().parameters(parameters))
                 )
 
+                # Add conditional if this is the first step in a conditional branch
+                if (
+                    self._is_conditional_node(node)
+                    and self.graph[node.in_funcs[0]].type == "split-switch"
+                ):
+                    in_func = node.in_funcs[0]
+                    dag_task.when(
+                        "{{tasks.%s.outputs.parameters.switch-step}}==%s"
+                        % (self._sanitize(in_func), node.name)
+                    )
+
             dag_tasks.append(dag_task)
             # End the workflow if we have reached the end of the flow
             if node.type == "end":
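As a sketch of what the dependency logic above produces, consider a step whose in_funcs are one regular step plus two conditional ones (names invented):

```python
# required_deps    == ["prep.Succeeded"]
# conditional_deps == ["branch-a.Succeeded", "branch-b.Succeeded"]
# depends_str      == "(prep.Succeeded) && (branch-a.Succeeded || branch-b.Succeeded)"
#
# And for the first step of a conditional branch, the when-clause added above
# would render as something like:
#   "{{tasks.decide.outputs.parameters.switch-step}}==branch_a"
```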
@@ -1121,6 +1278,23 @@ def _visit(
                     dag_tasks,
                     parent_foreach,
                 )
+            elif node.type == "split-switch":
+                for n in node.out_funcs:
+                    _visit(
+                        self.graph[n],
+                        self._matching_conditional_join(node),
+                        templates,
+                        dag_tasks,
+                        parent_foreach,
+                    )
+
+                return _visit(
+                    self.graph[self._matching_conditional_join(node)],
+                    exit_node,
+                    templates,
+                    dag_tasks,
+                    parent_foreach,
+                )
             # For foreach nodes generate a new sub DAGTemplate
             # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
             elif node.type == "foreach":
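For orientation, the split-switch branch above mirrors the existing split handling: each outgoing branch is walked only up to the matching conditional join (passed in as the exit node), and traversal then resumes from that join. For the hypothetical decide/branch_a/branch_b/merge graph sketched earlier, this would presumably produce:

```python
# visit order: decide -> branch_a (stops at merge) -> branch_b (stops at merge)
#              -> merge -> end
# All tasks are appended to the same dag_tasks list; unlike foreach, a switch
# does not spawn a sub DAGTemplate.
```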
@@ -1148,7 +1322,7 @@ def _visit(
                 #
                 foreach_task = (
                     DAGTask(foreach_template_name)
-                    .dependencies([self._sanitize(node.name)])
+                    .depends(f"{self._sanitize(node.name)}.Succeeded")
                     .template(foreach_template_name)
                     .arguments(
                         Arguments().parameters(
@@ -1193,6 +1367,16 @@ def _visit(
                                 % self._sanitize(node.name)
                             )
                         )
+                # Add conditional if this is the first step in a conditional branch
+                if self._is_conditional_node(node) and not any(
+                    self._is_conditional_node(self.graph[in_func])
+                    for in_func in node.in_funcs
+                ):
+                    in_func = node.in_funcs[0]
+                    foreach_task.when(
+                        "{{tasks.%s.outputs.parameters.switch-step}}==%s"
+                        % (self._sanitize(in_func), node.name)
+                    )
                 dag_tasks.append(foreach_task)
                 templates, dag_tasks_1 = _visit(
                     self.graph[node.out_funcs[0]],
@@ -1236,7 +1420,22 @@ def _visit(
                                         self.graph[node.matching_join].in_funcs[0]
                                     )
                                 }
-                            )
+                                if not self._is_conditional_join_node(
+                                    self.graph[node.matching_join]
+                                )
+                                else
+                                # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed.
+                                # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
+                                {
+                                    "expression": "get((%s)?.parameters, 'task-id')"
+                                    % " ?? ".join(
+                                        f"tasks['{self._sanitize(func)}']?.outputs"
+                                        for func in self.graph[
+                                            node.matching_join
+                                        ].in_funcs
+                                    )
+                                }
+                            ),
                         ]
                         if not node.parallel_foreach
                         else [
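The expression built above, for a join fed by two conditional steps (names invented), would come out roughly as:

```python
# valueFrom == {
#     "expression": "get((tasks['branch-a']?.outputs ?? tasks['branch-b']?.outputs)"
#                   "?.parameters, 'task-id')"
# }
# i.e. the task-id parameter is read from whichever of the conditional tasks
# actually ran, using expr's optional chaining and nil-coalescing operators.
```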
@@ -1269,7 +1468,7 @@ def _visit(
                 join_foreach_task = (
                     DAGTask(self._sanitize(self.graph[node.matching_join].name))
                     .template(self._sanitize(self.graph[node.matching_join].name))
-                    .dependencies([foreach_template_name])
+                    .depends(f"{foreach_template_name}.Succeeded")
                     .arguments(
                         Arguments().parameters(
                             (
@@ -1568,10 +1767,25 @@ def _container_templates(self):
                     ]
                 )
                 input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
+            # Only for static joins and conditional_joins
+            elif self._is_conditional_join_node(node) and not (
+                node.type == "join"
+                and self.graph[node.split_parents[-1]].type == "foreach"
+            ):
+                input_paths = (
+                    "$(python -m metaflow.plugins.argo.conditional_input_paths %s)"
+                    % input_paths
+                )
             elif (
                 node.type == "join"
                 and self.graph[node.split_parents[-1]].type == "foreach"
             ):
+                # foreach-joins straight out of conditional branches are not yet supported
+                if self._is_conditional_join_node(node):
+                    raise ArgoWorkflowsException(
+                        "Conditionals steps that transition directly into a join step are not currently supported. "
+                        "As a workaround, you can add a normal step after the conditional steps that transitions to a join step."
+                    )
                 # Set aggregated input-paths for a for-each join
                 foreach_step = next(
                     n for n in node.in_funcs if self.graph[n].is_inside_foreach
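`metaflow.plugins.argo.conditional_input_paths` is presumably one of the other files added in this commit, but its body is not part of this diff. A minimal sketch of the idea, assuming it simply drops the 'no-task' placeholders produced by the expr fallback above and re-emits the remaining paths (the function name and output format here are guesses, not the actual module):

```python
import sys


def strip_skipped(compressed_paths: str, separator: str = "%") -> str:
    # Hypothetical helper: drop entries for branches that never ran, i.e. the
    # ones whose task-id resolved to the 'no-task' fallback literal.
    kept = [
        p for p in compressed_paths.split(separator) if not p.endswith("/no-task")
    ]
    return separator.join(kept)


if __name__ == "__main__":
    print(strip_skipped(sys.argv[1]))
```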
@@ -1814,7 +2028,7 @@ def _container_templates(self):
                     [Parameter("num-parallel"), Parameter("task-id-entropy")]
                 )
                 else:
-                    # append this only for joins of foreaches, not static splits
+                    # append these only for joins of foreaches, not static splits
                     inputs.append(Parameter("split-cardinality"))
             # check if the node is a @parallel node.
             elif node.parallel_step:
@@ -1849,6 +2063,13 @@ def _container_templates(self):
             # are derived at runtime.
             if not (node.name == "end" or node.parallel_step):
                 outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
+
+            # If this step is a split-switch one, we need to output the switch step name
+            if node.type == "split-switch":
+                outputs.append(
+                    Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"})
+                )
+
             if node.type == "foreach":
                 # Emit split cardinality from foreach task
                 outputs.append(
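For reference, the extra output declared above is expected to render along these lines in the step's outputs (illustrative only); the task container presumably writes the name of the chosen next step to /mnt/out/switch_step, which is what the when-clauses in the DAG compare against:

```python
# Roughly equivalent to appending this to the template's output parameters:
#   {"name": "switch-step", "valueFrom": {"path": "/mnt/out/switch_step"}}
```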
@@ -3981,6 +4202,10 @@ def dependencies(self, dependencies):
         self.payload["dependencies"] = dependencies
         return self
 
+    def depends(self, depends: str):
+        self.payload["depends"] = depends
+        return self
+
     def template(self, template):
         # Template reference
         self.payload["template"] = template
@@ -3992,6 +4217,10 @@ def inline(self, template):
         self.payload["inline"] = template.to_json()
         return self
 
+    def when(self, when: str):
+        self.payload["when"] = when
+        return self
+
     def with_param(self, with_param):
         self.payload["withParam"] = with_param
         return self
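A minimal usage sketch of the two new builder methods, mirroring how the DAG construction above calls them (task and template names invented):

```python
task = (
    DAGTask("branch-a")
    .template("branch-a")
    .depends("decide.Succeeded")
    .when("{{tasks.decide.outputs.parameters.switch-step}}==branch_a")
)
# The payload now carries "depends" and "when" alongside "template", matching
# the Argo Workflows DAGTask fields of the same names.
```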
