Skip to content

Commit

Permalink
Fix scatter over workflows by advancing iteration over all steps (#187)
Browse files Browse the repository at this point in the history
* Fix scatter over workflows by advancing iteration of all steps of a scatter.
  • Loading branch information
tetron authored Sep 12, 2016
1 parent 2b3d2ec commit bff8a26
Showing 1 changed file with 46 additions and 25 deletions.
71 changes: 46 additions & 25 deletions cwltool/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,13 @@ def valueFromFunc(k, v): # type: (Any, Any) -> Any
# https://github.com/python/mypy/issues/797
**kwargs)
elif method == "flat_crossproduct":
jobs = flat_crossproduct_scatter(step, inputobj,
scatter,
cast(Callable[[Any], Any],
jobs = cast(Generator,
flat_crossproduct_scatter(step, inputobj,
scatter,
cast(Callable[[Any], Any],
# known bug in mypy
# https://github.com/python/mypy/issues/797
callback), 0, **kwargs)
callback), 0, **kwargs))
else:
_logger.debug(u"[job %s] job input %s", step.name, json.dumps(inputobj, indent=4))
inputobj = postScatterEval(inputobj)
Expand All @@ -332,7 +333,7 @@ def run(self, **kwargs):
_logger.debug(u"[%s] workflow starting", self.name)

def job(self, joborder, output_callback, **kwargs):
# type: (Dict[Text, Any], Callable[[Any, Any], Any], **Any) -> Generator[WorkflowJob, None, None]
# type: (Dict[Text, Any], Callable[[Any, Any], Any], **Any) -> Generator
self.state = {}
self.processStatus = "success"

Expand Down Expand Up @@ -405,7 +406,7 @@ def __init__(self, toolpath_object, **kwargs):
# TODO: statically validate data links instead of doing it at runtime.

def job(self, joborder, output_callback, **kwargs):
# type: (Dict[Text, Text], Callable[[Any, Any], Any], **Any) -> Generator[WorkflowJob, None, None]
# type: (Dict[Text, Text], Callable[[Any, Any], Any], **Any) -> Generator
builder = self._init_job(joborder, **kwargs)
wj = WorkflowJob(self, **kwargs)
yield wj
Expand Down Expand Up @@ -577,9 +578,25 @@ def setTotal(self, total): # type: (int) -> None
if self.completed == self.total:
self.output_callback(self.dest, self.processStatus)

def parallel_steps(steps, rc, kwargs):  # type: (List[Generator], ReceiveScatterOutput, Dict[str, Any]) -> Generator
    """Advance every generator in ``steps`` round-robin until ``rc`` is complete.

    Each pass over ``steps`` drains a step for as long as it produces truthy
    items, yielding each one to the caller.  A falsy item makes this function
    move on to the next step without yielding it (presumably the step is
    signalling that it cannot proceed yet -- confirm against the job
    generators that feed this function).

    :param steps: iterables (generators) producing the jobs of one scatter.
    :param rc: object exposing ``completed``, ``total`` and ``processStatus``;
        iteration continues while ``rc.completed < rc.total``.
    :param kwargs: runtime options; only ``on_error`` is read here.  With the
        default value ``"stop"``, no further steps are advanced once
        ``rc.processStatus`` is anything other than ``"success"``.
    :returns: a generator yielding the truthy items of the steps, plus a
        single ``None`` after any full pass that produced nothing while
        ``rc`` is still incomplete.
    """
    # The error policy is read once up front instead of on every step and
    # every job; only rc.processStatus needs to be re-checked as jobs finish.
    stop_on_error = kwargs.get("on_error", "stop") == "stop"

    def should_stop():  # type: () -> bool
        return stop_on_error and rc.processStatus != "success"

    while rc.completed < rc.total:
        made_progress = False
        for step in steps:
            if should_stop():
                break
            for j in step:
                # Re-check after pulling each item: advancing the step may
                # have changed rc.processStatus.  The pulled item is dropped.
                if should_stop():
                    break
                if j:
                    made_progress = True
                    yield j
                else:
                    # Falsy item: leave this step for now, try the next one.
                    break
        # Nothing could be yielded on this pass but the scatter is not done:
        # hand control back to the caller with None rather than spinning.
        if not made_progress and rc.completed < rc.total:
            yield None

def dotproduct_scatter(process, joborder, scatter_keys, output_callback, **kwargs):
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator[WorkflowJob, None, None]
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator
l = None
for s in scatter_keys:
if l is None:
Expand All @@ -593,21 +610,23 @@ def dotproduct_scatter(process, joborder, scatter_keys, output_callback, **kwarg

rc = ReceiveScatterOutput(output_callback, output)

steps = []
for n in range(0, l):
jo = copy.copy(joborder)
for s in scatter_keys:
jo[s] = joborder[s][n]

jo = kwargs["postScatterEval"](jo)

for j in process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs):
yield j
steps.append(process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs))

rc.setTotal(l)

return parallel_steps(steps, rc, kwargs)


def nested_crossproduct_scatter(process, joborder, scatter_keys, output_callback, **kwargs):
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator[WorkflowJob, None, None]
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Callable[..., Any], **Any) -> Generator
scatter_key = scatter_keys[0]
l = len(joborder[scatter_key])
output = {} # type: Dict[Text,List[Text]]
Expand All @@ -616,25 +635,24 @@ def nested_crossproduct_scatter(process, joborder, scatter_keys, output_callback

rc = ReceiveScatterOutput(output_callback, output)

steps = []
for n in range(0, l):
jo = copy.copy(joborder)
jo[scatter_key] = joborder[scatter_key][n]

if len(scatter_keys) == 1:
jo = kwargs["postScatterEval"](jo)
for j in process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs):
yield j
steps.append(process.job(jo, functools.partial(rc.receive_scatter_output, n), **kwargs))
else:
for j in nested_crossproduct_scatter(process, jo,
steps.append(nested_crossproduct_scatter(process, jo,
scatter_keys[1:], cast( # known bug with mypy
# https://github.com/python/mypy/issues/797
# https://github.com/python/mypy/issues/797
Callable[[Any], Any],
functools.partial(rc.receive_scatter_output, n)),
**kwargs):
yield j
functools.partial(rc.receive_scatter_output, n)), **kwargs))

rc.setTotal(l)

return parallel_steps(steps, rc, kwargs)

def crossproduct_size(joborder, scatter_keys):
# type: (Dict[Text, Any], List[Text]) -> int
Expand All @@ -650,7 +668,7 @@ def crossproduct_size(joborder, scatter_keys):
return sum

def flat_crossproduct_scatter(process, joborder, scatter_keys, output_callback, startindex, **kwargs):
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Union[ReceiveScatterOutput,Callable[..., Any]], int, **Any) -> Generator[WorkflowJob, None, None]
# type: (WorkflowJobStep, Dict[Text, Any], List[Text], Union[ReceiveScatterOutput,Callable[..., Any]], int, **Any) -> Union[List[Generator], Generator]
scatter_key = scatter_keys[0]
l = len(joborder[scatter_key])
rc = None # type: ReceiveScatterOutput
Expand All @@ -665,20 +683,23 @@ def flat_crossproduct_scatter(process, joborder, scatter_keys, output_callback,
else:
raise Exception("Unhandled code path. Please report this.")

steps = []
put = startindex
for n in range(0, l):
jo = copy.copy(joborder)
jo[scatter_key] = joborder[scatter_key][n]

if len(scatter_keys) == 1:
jo = kwargs["postScatterEval"](jo)
for j in process.job(jo, functools.partial(rc.receive_scatter_output, put), **kwargs):
yield j
steps.append(process.job(jo, functools.partial(rc.receive_scatter_output, put), **kwargs))
put += 1
else:
for j in flat_crossproduct_scatter(process, jo, scatter_keys[1:], rc, put, **kwargs):
if j:
put += 1
yield j
add = flat_crossproduct_scatter(process, jo, scatter_keys[1:], rc, put, **kwargs)
put += len(cast(List[Generator], add))
steps.extend(add)

rc.setTotal(put)
if startindex == 0 and not isinstance(output_callback, ReceiveScatterOutput):
rc.setTotal(put)
return parallel_steps(steps, rc, kwargs)
else:
return steps

0 comments on commit bff8a26

Please sign in to comment.