Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions sdk/python/kfp/dsl/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,29 @@ def assign_input_and_output_artifacts(self) -> None:
is_list_of_artifacts = (
type_annotations.is_Input_Output_artifact_annotation(
annotation) and
type_annotations.is_list_of_artifacts(annotation.__origin__)
type_annotations.is_list_of_artifacts(
annotation.__args__[0])
) or type_annotations.is_list_of_artifacts(annotation)
if is_list_of_artifacts:
# Get the annotation of the inner type of the list
# to use when creating the artifacts
inner_annotation = type_annotations.get_inner_type(
annotation)
# For Input[List[Dataset]], we need to extract Dataset,
# not List[Dataset]
if type_annotations.is_Input_Output_artifact_annotation(
annotation):
# Strip Input/Output wrapper first
inner_type = (
type_annotations.strip_Input_or_Output_marker(
annotation))
# Then extract inner type from List if it's a list
if type_annotations.is_list_of_artifacts(inner_type):
inner_annotation = (
type_annotations.get_inner_type(inner_type))
else:
inner_annotation = inner_type
else:
inner_annotation = type_annotations.get_inner_type(
annotation)

self.input_artifacts[name] = [
self.make_artifact(
Expand Down Expand Up @@ -124,6 +140,11 @@ def make_artifact(
)
else:
artifact_cls = annotation
# Handle case where annotation is a List type
# (e.g., List[Dataset])
# Extract the inner type if it's a list
if type_annotations.is_list_of_artifacts(artifact_cls):
artifact_cls = type_annotations.get_inner_type(artifact_cls)
return create_artifact_instance(
runtime_artifact, fallback_artifact_cls=artifact_cls)

Expand Down Expand Up @@ -424,6 +445,13 @@ def create_artifact_instance(
) -> type:
"""Creates an artifact class instances from a runtime artifact
dictionary."""
# Handle case where fallback_artifact_cls is a List type
# (e.g., List[Dataset])
# Extract the inner type if it's a list
if type_annotations.is_list_of_artifacts(fallback_artifact_cls):
fallback_artifact_cls = type_annotations.get_inner_type(
fallback_artifact_cls)

schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '')
artifact_cls = artifact_types._SCHEMA_TITLE_TO_TYPE.get(
schema_title, fallback_artifact_cls)
Expand Down
48 changes: 48 additions & 0 deletions sdk/python/kfp/dsl/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1604,6 +1604,54 @@ def test_func(input_list: Input[List[Artifact]]):

self.assertDictEqual(output_metadata, {})

def test_list_of_datasets_input(self):
"""Test for Input[List[Dataset]] which previously caused a crash."""
executor_input = """\
{
"inputs": {
"artifacts": {
"input_list": {
"artifacts": [
{
"metadata": {},
"name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0",
"type": {
"schemaTitle": "system.Dataset"
},
"uri": "gs://some-bucket/output/input_list/0"
},
{
"metadata": {},
"name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/1",
"type": {
"schemaTitle": "system.Dataset"
},
"uri": "gs://some-bucket/output/input_list/1"
}
]
}
}
},
"outputs": {
"outputFile": "%(test_dir)s/output_metadata.json"
}
}
"""

def test_func(input_list: Input[List[Dataset]]):
self.assertEqual(len(input_list), 2)
self.assertIsInstance(input_list[0], Dataset)
self.assertIsInstance(input_list[1], Dataset)
self.assertEqual(
input_list[0].name,
'projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0'
)

output_metadata = self.execute_and_load_output_metadata(
test_func, executor_input)

self.assertDictEqual(output_metadata, {})

def test_list_of_artifacts_input_pythonic(self):
executor_input = """\
{
Expand Down
Loading