Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(backend): Resolves issue when using task dependent on ParallelFor completion #11476

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def test_set_description_through_pipeline_decorator(self):

@dsl.pipeline(description='Prefer me.')
def my_pipeline():
"""Don't prefer me"""
"""Don't prefer me."""
VALID_PRODUCER_COMPONENT_SAMPLE(input_param='input')

self.assertEqual(my_pipeline.pipeline_spec.pipeline_info.description,
Expand All @@ -441,7 +441,8 @@ def test_set_description_through_pipeline_docstring_long(self):
def my_pipeline():
"""Docstring-specified description.

More information about this pipeline."""
More information about this pipeline.
"""
VALID_PRODUCER_COMPONENT_SAMPLE(input_param='input')

self.assertEqual(
Expand Down Expand Up @@ -2429,6 +2430,7 @@ def pipeline_with_multiline_definition(
sample_input1: bool = True,
sample_input2: str = 'string') -> str:
"""docstring short description.

docstring long description. docstring long description.
"""
op1 = my_comp(string=sample_input2, model=sample_input1)
Expand All @@ -2455,10 +2457,9 @@ def pipeline_with_multiline_definition(
def pipeline_with_multiline_definition(
sample_input1: bool = True,
sample_input2: str = 'string') -> str:
"""
docstring long description.
docstring long description.
docstring long description.
"""docstring long description.

docstring long description. docstring long description.
"""
op1 = my_comp(string=sample_input2, model=sample_input1)
result = op1.output
Expand Down Expand Up @@ -2487,8 +2488,8 @@ def test_idempotency_on_comment_with_multiline_docstring(self):
def my_pipeline(sample_input1: bool = True,
sample_input2: str = 'string') -> str:
"""docstring short description.
docstring long description.
docstring long description.

docstring long description. docstring long description.
"""
op1 = my_comp(string=sample_input2, model=sample_input1)
result = op1.output
Expand Down Expand Up @@ -4144,7 +4145,7 @@ def my_pipeline(
string: str,
in_artifact: Input[Artifact],
) -> Outputs:
"""Pipeline description. Returns
"""Pipeline description. Returns.

Args:
string: Return Pipeline input string. Returns
Expand Down Expand Up @@ -4607,7 +4608,9 @@ class TestDslOneOf(unittest.TestCase):
# To help narrow the tests further (we already test lots of aspects in the following cases), we choose focus on the dsl.OneOf behavior, not the conditional logic if If/Elif/Else. This is more verbose, but more maintainable and the behavior under test is clearer.

def test_if_else_returned(self):
"""Uses If and Else branches, parameters passed to dsl.OneOf, dsl.OneOf returned from a pipeline, and different output keys on dsl.OneOf channels."""
"""Uses If and Else branches, parameters passed to dsl.OneOf, dsl.OneOf
returned from a pipeline, and different output keys on dsl.OneOf
channels."""

@dsl.pipeline
def roll_die_pipeline() -> str:
Expand Down Expand Up @@ -4668,7 +4671,9 @@ def roll_die_pipeline() -> str:
)

def test_if_elif_else_returned(self):
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, dsl.OneOf returned from a pipeline, and different output keys on dsl.OneOf channels."""
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf,
dsl.OneOf returned from a pipeline, and different output keys on
dsl.OneOf channels."""

@dsl.pipeline
def roll_die_pipeline() -> str:
Expand Down Expand Up @@ -4743,7 +4748,9 @@ def roll_die_pipeline() -> str:
)

def test_if_elif_else_consumed(self):
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, dsl.OneOf passed to a consumer task, and different output keys on dsl.OneOf channels."""
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf,
dsl.OneOf passed to a consumer task, and different output keys on
dsl.OneOf channels."""

@dsl.pipeline
def roll_die_pipeline():
Expand Down Expand Up @@ -4820,7 +4827,9 @@ def roll_die_pipeline():
)

def test_if_else_consumed_and_returned(self):
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf, and dsl.OneOf passed to a consumer task and returned from the pipeline."""
"""Uses If, Elif, and Else branches, parameters passed to dsl.OneOf,
and dsl.OneOf passed to a consumer task and returned from the
pipeline."""

@dsl.pipeline
def flip_coin_pipeline() -> str:
Expand Down Expand Up @@ -4893,7 +4902,8 @@ def flip_coin_pipeline() -> str:
)

def test_if_else_consumed_and_returned_artifacts(self):
"""Uses If, Elif, and Else branches, artifacts passed to dsl.OneOf, and dsl.OneOf passed to a consumer task and returned from the pipeline."""
"""Uses If, Elif, and Else branches, artifacts passed to dsl.OneOf, and
dsl.OneOf passed to a consumer task and returned from the pipeline."""

@dsl.pipeline
def flip_coin_pipeline() -> Artifact:
Expand Down Expand Up @@ -5060,7 +5070,8 @@ def flip_coin_pipeline(execute_pipeline: bool):
print_task_2.outputs['a'])

def test_deeply_nested_consumed(self):
"""Uses If, Elif, Else, and OneOf deeply nested within multiple dub-DAGs."""
"""Uses If, Elif, Else, and OneOf deeply nested within multiple dub-
DAGs."""

@dsl.pipeline
def flip_coin_pipeline(execute_pipeline: bool):
Expand Down Expand Up @@ -5159,7 +5170,8 @@ def flip_coin_pipeline(execute_pipeline: bool):
print_task_2.outputs['a'])

def test_oneof_in_condition(self):
"""Tests that dsl.OneOf's channel can be consumed in a downstream group nested one level"""
"""Tests that dsl.OneOf's channel can be consumed in a downstream group
nested one level."""

@dsl.pipeline
def roll_die_pipeline(repeat_on: str = 'Got heads!'):
Expand Down Expand Up @@ -5212,7 +5224,8 @@ def roll_die_pipeline(repeat_on: str = 'Got heads!'):
)

def test_consumed_in_nested_groups(self):
"""Tests that dsl.OneOf's channel can be consumed in a downstream group nested multiple levels"""
"""Tests that dsl.OneOf's channel can be consumed in a downstream group
nested multiple levels."""

@dsl.pipeline
def roll_die_pipeline(
Expand Down
4 changes: 3 additions & 1 deletion sdk/python/kfp/compiler/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,9 @@ def get_dependencies(
# then make this validation dsl.Collected-aware
elif isinstance(upstream_parent_group, tasks_group.ParallelFor):
upstream_tasks_that_downstream_consumers_from = [
channel.task.name for channel in task._channel_inputs
channel.task.name
for channel in task._channel_inputs
if channel.task
]
has_data_exchange = upstream_task.name in upstream_tasks_that_downstream_consumers_from
# don't raise for .after
Expand Down
Loading