diff --git a/sdk/python/kfp/dsl/executor.py b/sdk/python/kfp/dsl/executor.py index d4886488c4d..93ffa6241ab 100644 --- a/sdk/python/kfp/dsl/executor.py +++ b/sdk/python/kfp/dsl/executor.py @@ -68,13 +68,29 @@ def assign_input_and_output_artifacts(self) -> None: is_list_of_artifacts = ( type_annotations.is_Input_Output_artifact_annotation( annotation) and - type_annotations.is_list_of_artifacts(annotation.__origin__) + type_annotations.is_list_of_artifacts( + annotation.__args__[0]) ) or type_annotations.is_list_of_artifacts(annotation) if is_list_of_artifacts: # Get the annotation of the inner type of the list # to use when creating the artifacts - inner_annotation = type_annotations.get_inner_type( - annotation) + # For Input[List[Dataset]], we need to extract Dataset, + # not List[Dataset] + if type_annotations.is_Input_Output_artifact_annotation( + annotation): + # Strip Input/Output wrapper first + inner_type = ( + type_annotations.strip_Input_or_Output_marker( + annotation)) + # Then extract inner type from List if it's a list + if type_annotations.is_list_of_artifacts(inner_type): + inner_annotation = ( + type_annotations.get_inner_type(inner_type)) + else: + inner_annotation = inner_type + else: + inner_annotation = type_annotations.get_inner_type( + annotation) self.input_artifacts[name] = [ self.make_artifact( @@ -124,6 +140,11 @@ def make_artifact( ) else: artifact_cls = annotation + # Handle case where annotation is a List type + # (e.g., List[Dataset]) + # Extract the inner type if it's a list + if type_annotations.is_list_of_artifacts(artifact_cls): + artifact_cls = type_annotations.get_inner_type(artifact_cls) return create_artifact_instance( runtime_artifact, fallback_artifact_cls=artifact_cls) @@ -424,6 +445,13 @@ def create_artifact_instance( ) -> type: """Creates an artifact class instances from a runtime artifact dictionary.""" + # Handle case where fallback_artifact_cls is a List type + # (e.g., List[Dataset]) + # Extract the inner type if it's a list + if type_annotations.is_list_of_artifacts(fallback_artifact_cls): + fallback_artifact_cls = type_annotations.get_inner_type( + fallback_artifact_cls) + schema_title = runtime_artifact.get('type', {}).get('schemaTitle', '') artifact_cls = artifact_types._SCHEMA_TITLE_TO_TYPE.get( schema_title, fallback_artifact_cls) diff --git a/sdk/python/kfp/dsl/executor_test.py b/sdk/python/kfp/dsl/executor_test.py index 8492111c8c1..a4a9a595082 100644 --- a/sdk/python/kfp/dsl/executor_test.py +++ b/sdk/python/kfp/dsl/executor_test.py @@ -1604,6 +1604,54 @@ def test_func(input_list: Input[List[Artifact]]): self.assertDictEqual(output_metadata, {}) + def test_list_of_datasets_input(self): + """Test for Input[List[Dataset]] which previously caused a crash.""" + executor_input = """\ + { + "inputs": { + "artifacts": { + "input_list": { + "artifacts": [ + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0", + "type": { + "schemaTitle": "system.Dataset" + }, + "uri": "gs://some-bucket/output/input_list/0" + }, + { + "metadata": {}, + "name": "projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/1", + "type": { + "schemaTitle": "system.Dataset" + }, + "uri": "gs://some-bucket/output/input_list/1" + } + ] + } + } + }, + "outputs": { + "outputFile": "%(test_dir)s/output_metadata.json" + } + } + """ + + def test_func(input_list: Input[List[Dataset]]): + self.assertEqual(len(input_list), 2) + self.assertIsInstance(input_list[0], Dataset) + self.assertIsInstance(input_list[1], Dataset) + self.assertEqual( + input_list[0].name, + 'projects/123/locations/us-central1/metadataStores/default/artifacts/input_list/0' + ) + + output_metadata = self.execute_and_load_output_metadata( + test_func, executor_input) + + self.assertDictEqual(output_metadata, {}) + def test_list_of_artifacts_input_pythonic(self): executor_input = """\ {