@@ -143,6 +143,7 @@ def __init__(
 
         self.name = name
         self.graph = graph
+        self._parse_conditional_branches()
         self.flow = flow
         self.code_package_metadata = code_package_metadata
         self.code_package_sha = code_package_sha
@@ -920,6 +921,121 @@ def _compile_workflow_template(self):
             )
         )
 
+    # Visit every node and record information on conditional step structure
+    def _parse_conditional_branches(self):
+        self.conditional_nodes = set()
+        self.conditional_join_nodes = set()
+        self.matching_conditional_join_dict = {}
+
+        node_conditional_parents = {}
+        node_conditional_branches = {}
+
+        def _visit(node, seen, conditional_branch, conditional_parents=None):
+            if not node.type == "split-switch" and not (
+                conditional_branch and conditional_parents
+            ):
+                # skip regular non-conditional nodes entirely
+                return
+
+            if node.type == "split-switch":
+                conditional_branch = conditional_branch + [node.name]
+                node_conditional_branches[node.name] = conditional_branch
+
+                conditional_parents = (
+                    [node.name]
+                    if not conditional_parents
+                    else conditional_parents + [node.name]
+                )
+                node_conditional_parents[node.name] = conditional_parents
+
+            if conditional_parents and not node.type == "split-switch":
+                node_conditional_parents[node.name] = conditional_parents
+                conditional_branch = conditional_branch + [node.name]
+                node_conditional_branches[node.name] = conditional_branch
+
+                self.conditional_nodes.add(node.name)
+
+            if conditional_branch and conditional_parents:
+                for n in node.out_funcs:
+                    child = self.graph[n]
+                    if n not in seen:
+                        _visit(
+                            child, seen + [n], conditional_branch, conditional_parents
+                        )
+
+        # First we visit all nodes to determine conditional parents and branches
+        for n in self.graph:
+            _visit(n, [], [])
+
+        # Then we traverse again in order to determine conditional join nodes and matching conditional join info
+        for node in self.graph:
+            if node_conditional_parents.get(node.name, False):
+                # do the required postprocessing for anything requiring node.in_funcs
+
+                # check that in previous parsing we have not closed all conditional in_funcs.
+                # If so, this step cannot be conditional either
+                is_conditional = any(
+                    in_func in self.conditional_nodes
+                    or self.graph[in_func].type == "split-switch"
+                    for in_func in node.in_funcs
+                )
+                if is_conditional:
+                    self.conditional_nodes.add(node.name)
+                else:
+                    if node.name in self.conditional_nodes:
+                        self.conditional_nodes.remove(node.name)
+
+                # does this node close the latest conditional parent branches?
+                conditional_in_funcs = [
+                    in_func
+                    for in_func in node.in_funcs
+                    if node_conditional_branches.get(in_func, False)
+                ]
+                closed_conditional_parents = []
+                for last_split_switch in node_conditional_parents.get(node.name, [])[
+                    ::-1
+                ]:
+                    last_conditional_split_nodes = self.graph[
+                        last_split_switch
+                    ].out_funcs
+                    # p needs to be in at least one conditional_branch for it to be closed.
+                    if all(
+                        any(
+                            p in node_conditional_branches.get(in_func, [])
+                            for in_func in conditional_in_funcs
+                        )
+                        for p in last_conditional_split_nodes
+                    ):
+                        closed_conditional_parents.append(last_split_switch)
+
+                        self.conditional_join_nodes.add(node.name)
+                        self.matching_conditional_join_dict[last_split_switch] = (
+                            node.name
+                        )
+
+                # Did we close all conditionals? Then this branch and all its children are not conditional anymore (unless a new conditional branch is encountered).
+                if not [
+                    p
+                    for p in node_conditional_parents.get(node.name, [])
+                    if p not in closed_conditional_parents
+                ]:
+                    if node.name in self.conditional_nodes:
+                        self.conditional_nodes.remove(node.name)
+                    node_conditional_parents[node.name] = []
+                    for p in node.out_funcs:
+                        if p in self.conditional_nodes:
+                            self.conditional_nodes.remove(p)
+                        node_conditional_parents[p] = []
+
+    def _is_conditional_node(self, node):
+        return node.name in self.conditional_nodes
+
+    def _is_conditional_join_node(self, node):
+        return node.name in self.conditional_join_nodes
+
+    def _matching_conditional_join(self, node):
+        return self.matching_conditional_join_dict.get(node.name, None)
+
     # Visit every node and yield the uber DAGTemplate(s).
     def _dag_templates(self):
         def _visit(
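
For orientation, the following is a hedged sketch (not part of the patch, all step names hypothetical) of the bookkeeping `_parse_conditional_branches` is meant to produce for a small graph of the shape `start -> decide (split-switch) -> {heavy, light} -> collect -> end`:

    # Hypothetical flow: start -> decide (split-switch) -> heavy | light -> collect -> end.
    # The branch steps may or may not run, so they are recorded as conditional nodes;
    # "collect" is the first step where the branches reconverge, so it is recorded as
    # the conditional join that closes the "decide" switch.
    conditional_nodes = {"heavy", "light"}
    conditional_join_nodes = {"collect"}
    matching_conditional_join_dict = {"decide": "collect"}

    # Everything from the closing join onwards ("collect", "end") is unconditional again.
    assert "collect" not in conditional_nodes and "end" not in conditional_nodes
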
@@ -941,19 +1057,15 @@ def _visit(
             dag_tasks = []
             if templates is None:
                 templates = []
+
             if exit_node is not None and exit_node is node.name:
                 return templates, dag_tasks
             if node.name == "start":
                 # Start node has no dependencies.
                 dag_task = DAGTask(self._sanitize(node.name)).template(
                     self._sanitize(node.name)
                 )
-            if node.type == "split-switch":
-                raise ArgoWorkflowsException(
-                    "Deploying flows with switch statement "
-                    "to Argo Workflows is not supported currently."
-                )
-            elif (
+            if (
                 node.is_inside_foreach
                 and self.graph[node.in_funcs[0]].type == "foreach"
                 and not self.graph[node.in_funcs[0]].parallel_foreach
@@ -1044,15 +1156,32 @@ def _visit(
             else:
                 # Every other node needs only input-paths
                 parameters = [
-                    Parameter("input-paths").value(
-                        compress_list(
-                            [
-                                "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
-                                % (n, self._sanitize(n))
-                                for n in node.in_funcs
-                            ],
-                            # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution.
-                            zlibmin=inf,
+                    (
+                        Parameter("input-paths").value(
+                            compress_list(
+                                [
+                                    "argo-{{workflow.name}}/%s/{{tasks.%s.outputs.parameters.task-id}}"
+                                    % (n, self._sanitize(n))
+                                    for n in node.in_funcs
+                                ],
+                                # NOTE: We set zlibmin to infinite because zlib compression for the Argo input-paths breaks template value substitution.
+                                zlibmin=inf,
+                            )
+                        )
+                        if not self._is_conditional_join_node(node)
+                        # The value fetching for input-paths from conditional steps has to be quite involved
+                        # in order to avoid issues with replacements due to missing step outputs.
+                        # NOTE: we differentiate the input-path expression only for conditional joins so we can still utilize the list compression,
+                        # but do not have to rework all decompress usage due to the need for a custom separator
+                        else Parameter("input-paths").value(
+                            compress_list(
+                                [
+                                    "argo-{{workflow.name}}/%s/{{=(get(tasks['%s']?.outputs?.parameters, 'task-id') ?? 'no-task')}}"
+                                    % (n, self._sanitize(n))
+                                    for n in node.in_funcs
+                                ],
+                                separator="%",  # non-default separator is required due to commas in the value expression
+                            )
                         )
                     )
                 ]
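
For a hypothetical conditional join whose parents are branch steps `heavy` and `light`, the per-parent templates built in the `else` branch above look like this; Argo's `??` operator substitutes the `no-task` placeholder for a branch that was skipped, so template substitution never fails:

    # Sketch of the values fed into compress_list(...) for in_funcs = ["heavy", "light"]:
    paths = [
        "argo-{{workflow.name}}/heavy/{{=(get(tasks['heavy']?.outputs?.parameters, 'task-id') ?? 'no-task')}}",
        "argo-{{workflow.name}}/light/{{=(get(tasks['light']?.outputs?.parameters, 'task-id') ?? 'no-task')}}",
    ]
    # compress_list(..., separator="%") joins these with a non-default separator
    # because the embedded expressions themselves contain commas.
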
@@ -1087,15 +1216,43 @@ def _visit(
                         ]
                     )
 
+            conditional_deps = [
+                "%s.Succeeded" % self._sanitize(in_func)
+                for in_func in node.in_funcs
+                if self._is_conditional_node(self.graph[in_func])
+            ]
+            required_deps = [
+                "%s.Succeeded" % self._sanitize(in_func)
+                for in_func in node.in_funcs
+                if not self._is_conditional_node(self.graph[in_func])
+            ]
+            both_conditions = required_deps and conditional_deps
+
+            depends_str = "{required}{_and}{conditional}".format(
+                required=("(%s)" if both_conditions else "%s")
+                % " && ".join(required_deps),
+                _and=" && " if both_conditions else "",
+                conditional=("(%s)" if both_conditions else "%s")
+                % " || ".join(conditional_deps),
+            )
             dag_task = (
                 DAGTask(self._sanitize(node.name))
-                .dependencies(
-                    [self._sanitize(in_func) for in_func in node.in_funcs]
-                )
+                .depends(depends_str)
                 .template(self._sanitize(node.name))
                 .arguments(Arguments().parameters(parameters))
             )
 
+            # Add conditional if this is the first step in a conditional branch
+            if (
+                self._is_conditional_node(node)
+                and self.graph[node.in_funcs[0]].type == "split-switch"
+            ):
+                in_func = node.in_funcs[0]
+                dag_task.when(
+                    "{{tasks.%s.outputs.parameters.switch-step}}==%s"
+                    % (self._sanitize(in_func), node.name)
+                )
+
             dag_tasks.append(dag_task)
             # End the workflow if we have reached the end of the flow
             if node.type == "end":
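
To make the depends expression concrete: for a hypothetical node whose in_funcs are one unconditional step `prep` plus two conditional branch steps `heavy` and `light`, the formatting above yields the following Argo depends string:

    # Mirrors the format() call above for required_deps = ["prep.Succeeded"] and
    # conditional_deps = ["heavy.Succeeded", "light.Succeeded"]; both lists are
    # non-empty, so each side is parenthesized.
    depends_str = "(prep.Succeeded) && (heavy.Succeeded || light.Succeeded)"
    # With no conditional parents this degrades to plain "prep.Succeeded", and with
    # only conditional parents to "heavy.Succeeded || light.Succeeded".
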
@@ -1121,6 +1278,23 @@ def _visit(
                     dag_tasks,
                     parent_foreach,
                 )
+            elif node.type == "split-switch":
+                for n in node.out_funcs:
+                    _visit(
+                        self.graph[n],
+                        self._matching_conditional_join(node),
+                        templates,
+                        dag_tasks,
+                        parent_foreach,
+                    )
+
+                return _visit(
+                    self.graph[self._matching_conditional_join(node)],
+                    exit_node,
+                    templates,
+                    dag_tasks,
+                    parent_foreach,
+                )
             # For foreach nodes generate a new sub DAGTemplate
             # We do this for "regular" foreaches (ie. `self.next(self.a, foreach=)`)
             elif node.type == "foreach":
@@ -1148,7 +1322,7 @@ def _visit(
                 #
                 foreach_task = (
                     DAGTask(foreach_template_name)
-                    .dependencies([self._sanitize(node.name)])
+                    .depends(f"{self._sanitize(node.name)}.Succeeded")
                     .template(foreach_template_name)
                     .arguments(
                         Arguments().parameters(
@@ -1193,6 +1367,16 @@ def _visit(
                         % self._sanitize(node.name)
                     )
                 )
+                # Add conditional if this is the first step in a conditional branch
+                if self._is_conditional_node(node) and not any(
+                    self._is_conditional_node(self.graph[in_func])
+                    for in_func in node.in_funcs
+                ):
+                    in_func = node.in_funcs[0]
+                    foreach_task.when(
+                        "{{tasks.%s.outputs.parameters.switch-step}}==%s"
+                        % (self._sanitize(in_func), node.name)
+                    )
                 dag_tasks.append(foreach_task)
                 templates, dag_tasks_1 = _visit(
                     self.graph[node.out_funcs[0]],
@@ -1236,7 +1420,22 @@ def _visit(
                                             self.graph[node.matching_join].in_funcs[0]
                                         )
                                     }
-                                )
+                                    if not self._is_conditional_join_node(
+                                        self.graph[node.matching_join]
+                                    )
+                                    else
+                                    # Note: If the nodes leading to the join are conditional, then we need to use an expression to pick the outputs from the task that executed.
+                                    # ref for operators: https://github.com/expr-lang/expr/blob/master/docs/language-definition.md
+                                    {
+                                        "expression": "get((%s)?.parameters, 'task-id')"
+                                        % " ?? ".join(
+                                            f"tasks['{self._sanitize(func)}']?.outputs"
+                                            for func in self.graph[
+                                                node.matching_join
+                                            ].in_funcs
+                                        )
+                                    }
+                                ),
                             ]
                             if not node.parallel_foreach
                             else [
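
For reference, with two hypothetical conditional steps `heavy` and `light` feeding the matching join, the valueFrom expression assembled above renders as follows, picking the task-id from whichever branch actually ran (see https://github.com/expr-lang/expr/blob/master/docs/language-definition.md for the `??` and `?.` operators):

    # Rendered form of the "expression" valueFrom for in_funcs = ["heavy", "light"]:
    expression = (
        "get((tasks['heavy']?.outputs ?? tasks['light']?.outputs)?.parameters, 'task-id')"
    )
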
@@ -1269,7 +1468,7 @@ def _visit(
                 join_foreach_task = (
                     DAGTask(self._sanitize(self.graph[node.matching_join].name))
                     .template(self._sanitize(self.graph[node.matching_join].name))
-                    .dependencies([foreach_template_name])
+                    .depends(f"{foreach_template_name}.Succeeded")
                     .arguments(
                         Arguments().parameters(
                             (
@@ -1568,10 +1767,25 @@ def _container_templates(self):
                     ]
                 )
                 input_paths = "%s/_parameters/%s" % (run_id, task_id_params)
+            # Only for static joins and conditional joins
+            elif self._is_conditional_join_node(node) and not (
+                node.type == "join"
+                and self.graph[node.split_parents[-1]].type == "foreach"
+            ):
+                input_paths = (
+                    "$(python -m metaflow.plugins.argo.conditional_input_paths %s)"
+                    % input_paths
+                )
             elif (
                 node.type == "join"
                 and self.graph[node.split_parents[-1]].type == "foreach"
             ):
+                # foreach-joins straight out of conditional branches are not yet supported
+                if self._is_conditional_join_node(node):
+                    raise ArgoWorkflowsException(
+                        "Conditional steps that transition directly into a join step are not currently supported. "
+                        "As a workaround, you can add a normal step after the conditional steps that transitions to a join step."
+                    )
                 # Set aggregated input-paths for a for-each join
                 foreach_step = next(
                     n for n in node.in_funcs if self.graph[n].is_inside_foreach
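
The `conditional_input_paths` helper module invoked above is not shown in this hunk; presumably it rewrites the input-paths value at runtime so that the placeholders left by skipped branches do not reach the join step. A minimal sketch of that idea only, with a hypothetical filter_skipped helper that is not the module's actual implementation:

    # Hypothetical illustration: drop the 'no-task' placeholders that the Argo
    # ?? fallback inserts for branches that never ran, keeping real task paths.
    def filter_skipped(paths):
        return [p for p in paths if not p.endswith("/no-task")]

    print(filter_skipped(["argo-wf/heavy/t-123", "argo-wf/light/no-task"]))
    # ['argo-wf/heavy/t-123']
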
@@ -1814,7 +2028,7 @@ def _container_templates(self):
                         [Parameter("num-parallel"), Parameter("task-id-entropy")]
                     )
                 else:
-                    # append this only for joins of foreaches, not static splits
+                    # append these only for joins of foreaches, not static splits
                     inputs.append(Parameter("split-cardinality"))
             # check if the node is a @parallel node.
             elif node.parallel_step:
@@ -1849,6 +2063,13 @@ def _container_templates(self):
             # are derived at runtime.
             if not (node.name == "end" or node.parallel_step):
                 outputs = [Parameter("task-id").valueFrom({"path": "/mnt/out/task_id"})]
+
+            # If this step is a split-switch one, we need to output the switch step name
+            if node.type == "split-switch":
+                outputs.append(
+                    Parameter("switch-step").valueFrom({"path": "/mnt/out/switch_step"})
+                )
+
             if node.type == "foreach":
                 # Emit split cardinality from foreach task
                 outputs.append(
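
In rendered form, the extra output declared above for a split-switch step would surface in the step's template outputs roughly as follows (illustrative, mirroring how the existing task-id parameter is emitted as an Argo output parameter):

    # Hedged illustration of the additional output parameter for a split-switch step:
    # the chosen branch name is written to /mnt/out/switch_step by the task and then
    # exposed so downstream `when` clauses can compare against it.
    switch_output = {
        "name": "switch-step",
        "valueFrom": {"path": "/mnt/out/switch_step"},
    }
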
@@ -3981,6 +4202,10 @@ def dependencies(self, dependencies):
         self.payload["dependencies"] = dependencies
         return self
 
+    def depends(self, depends: str):
+        self.payload["depends"] = depends
+        return self
+
     def template(self, template):
         # Template reference
         self.payload["template"] = template
@@ -3992,6 +4217,10 @@ def inline(self, template):
         self.payload["inline"] = template.to_json()
         return self
 
+    def when(self, when: str):
+        self.payload["when"] = when
+        return self
+
     def with_param(self, with_param):
         self.payload["withParam"] = with_param
         return self
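
Taken together, the new depends()/when() setters let a conditional branch task render along these lines in the workflow spec (illustrative values only, for a hypothetical switch step "decide" and branch step "heavy"; not emitted verbatim anywhere in the patch):

    # The branch task waits on the switch step and only fires when the switch-step
    # output parameter selected it.
    task = {
        "name": "heavy",
        "template": "heavy",
        "depends": "decide.Succeeded",
        "when": "{{tasks.decide.outputs.parameters.switch-step}}==heavy",
    }
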