-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
426 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,225 @@ | ||
diff --git a/sdk/python/kfp/cli/compile_.py b/sdk/python/kfp/cli/compile_.py | ||
index 2bd3bab18..d1e84fd72 100644 | ||
--- a/sdk/python/kfp/cli/compile_.py | ||
+++ b/sdk/python/kfp/cli/compile_.py | ||
@@ -133,12 +133,19 @@ def parse_parameters(parameters: Optional[str]) -> Dict: | ||
is_flag=True, | ||
default=False, | ||
help='Whether to disable type checking.') | ||
+@click.option( | ||
+ '--disable-execution-caching-by-default', | ||
+ is_flag=True, | ||
+ default=False, | ||
+ envvar='KFP_DISABLE_EXECUTION_CACHING_BY_DEFAULT', | ||
+ help='Whether to disable execution caching by default.') | ||
def compile_( | ||
py: str, | ||
output: str, | ||
function_name: Optional[str] = None, | ||
pipeline_parameters: Optional[str] = None, | ||
disable_type_check: bool = False, | ||
+ disable_execution_caching_by_default: bool = False, | ||
) -> None: | ||
"""Compiles a pipeline or component written in a .py file.""" | ||
pipeline_func = collect_pipeline_or_component_func( | ||
@@ -149,7 +156,8 @@ def compile_( | ||
pipeline_func=pipeline_func, | ||
pipeline_parameters=parsed_parameters, | ||
package_path=package_path, | ||
- type_check=not disable_type_check) | ||
+ type_check=not disable_type_check, | ||
+ execution_caching_default=not disable_execution_caching_by_default) | ||
|
||
click.echo(package_path) | ||
|
||
diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py | ||
index a77f606e8..9db1abaa0 100644 | ||
--- a/sdk/python/kfp/compiler/compiler.py | ||
+++ b/sdk/python/kfp/compiler/compiler.py | ||
@@ -22,7 +22,7 @@ from typing import Any, Dict, Optional | ||
from kfp.compiler import pipeline_spec_builder as builder | ||
from kfp.dsl import base_component | ||
from kfp.dsl.types import type_utils | ||
- | ||
+from kfp.dsl.pipeline_context import Pipeline | ||
|
||
class Compiler: | ||
"""Compiles pipelines composed using the KFP SDK DSL to a YAML pipeline | ||
@@ -53,10 +53,12 @@ class Compiler: | ||
pipeline_name: Optional[str] = None, | ||
pipeline_parameters: Optional[Dict[str, Any]] = None, | ||
type_check: bool = True, | ||
+ execution_caching_default: bool = True | ||
) -> None: | ||
"""Compiles the pipeline or component function into IR YAML. | ||
|
||
Args: | ||
+ execution_caching_default: | ||
pipeline_func: Pipeline function constructed with the ``@dsl.pipeline`` or component constructed with the ``@dsl.component`` decorator. | ||
package_path: Output YAML file path. For example, ``'~/my_pipeline.yaml'`` or ``'~/my_component.yaml'``. | ||
pipeline_name: Name of the pipeline. | ||
@@ -72,11 +74,12 @@ class Compiler: | ||
'`Callable` constructed with @dsl.pipeline ' | ||
f'decorator. Got: {type(pipeline_func)}') | ||
|
||
- pipeline_spec = builder.modify_pipeline_spec_with_override( | ||
- pipeline_spec=pipeline_func.pipeline_spec, | ||
- pipeline_name=pipeline_name, | ||
- pipeline_parameters=pipeline_parameters, | ||
- ) | ||
+ with Pipeline(execution_caching_default=execution_caching_default): | ||
+ pipeline_spec = builder.modify_pipeline_spec_with_override( | ||
+ pipeline_spec=pipeline_func.pipeline_spec, | ||
+ pipeline_name=pipeline_name, | ||
+ pipeline_parameters=pipeline_parameters, | ||
+ ) | ||
|
||
builder.write_pipeline_spec_to_file( | ||
pipeline_spec=pipeline_spec, | ||
diff --git a/sdk/python/kfp/compiler/pipeline_spec_builder.py b/sdk/python/kfp/compiler/pipeline_spec_builder.py | ||
index 6d2a0cfa9..2a53d6277 100644 | ||
--- a/sdk/python/kfp/compiler/pipeline_spec_builder.py | ||
+++ b/sdk/python/kfp/compiler/pipeline_spec_builder.py | ||
@@ -86,6 +86,7 @@ def build_task_spec_for_task( | ||
task: pipeline_task.PipelineTask, | ||
parent_component_inputs: pipeline_spec_pb2.ComponentInputsSpec, | ||
tasks_in_current_dag: List[str], | ||
+ execution_caching_default: bool = True, | ||
) -> pipeline_spec_pb2.PipelineTaskSpec: | ||
"""Builds PipelineTaskSpec for a pipeline task. | ||
|
||
@@ -106,6 +107,7 @@ def build_task_spec_for_task( | ||
producer task. | ||
|
||
Args: | ||
+ execution_caching_default: | ||
task: The task to build a PipelineTaskSpec for. | ||
parent_component_inputs: The task's parent component's input specs. | ||
tasks_in_current_dag: The list of tasks names for tasks in the same dag. | ||
@@ -122,7 +124,9 @@ def build_task_spec_for_task( | ||
pipeline_task_spec.component_ref.name = ( | ||
utils.sanitize_component_name(task.name)) | ||
pipeline_task_spec.caching_options.enable_cache = ( | ||
- task._task_spec.enable_caching) | ||
+ task._task_spec.enable_caching if task._task_spec.enable_caching is not None | ||
+ else execution_caching_default | ||
+ ) | ||
|
||
if task._task_spec.retry_policy is not None: | ||
pipeline_task_spec.retry_policy.CopyFrom( | ||
@@ -1218,10 +1222,12 @@ def build_spec_by_group( | ||
name_to_for_loop_group: Mapping[str, tasks_group.ParallelFor], | ||
platform_spec: pipeline_spec_pb2.PlatformSpec, | ||
is_compiled_component: bool, | ||
+ execution_caching_default: bool = True, | ||
) -> None: | ||
"""Generates IR spec given a TasksGroup. | ||
|
||
Args: | ||
+ execution_caching_default: | ||
pipeline_spec: The pipeline_spec to update in place. | ||
deployment_config: The deployment_config to hold all executors. The | ||
spec is updated in place. | ||
@@ -1276,6 +1282,7 @@ def build_spec_by_group( | ||
task=subgroup, | ||
parent_component_inputs=group_component_spec.input_definitions, | ||
tasks_in_current_dag=tasks_in_current_dag, | ||
+ execution_caching_default=execution_caching_default, | ||
) | ||
task_name_to_task_spec[subgroup.name] = subgroup_task_spec | ||
subgroup_component_spec = build_component_spec_for_task( | ||
@@ -1850,10 +1857,12 @@ def create_pipeline_spec( | ||
pipeline: pipeline_context.Pipeline, | ||
component_spec: structures.ComponentSpec, | ||
pipeline_outputs: Optional[Any] = None, | ||
+ execution_caching_default: bool = True, | ||
) -> Tuple[pipeline_spec_pb2.PipelineSpec, pipeline_spec_pb2.PlatformSpec]: | ||
"""Creates a pipeline spec object. | ||
|
||
Args: | ||
+ execution_caching_default: | ||
pipeline: The instantiated pipeline object. | ||
component_spec: The component spec structures. | ||
pipeline_outputs: The pipeline outputs via return. | ||
@@ -1932,6 +1941,7 @@ def create_pipeline_spec( | ||
name_to_for_loop_group=name_to_for_loop_group, | ||
platform_spec=platform_spec, | ||
is_compiled_component=False, | ||
+ execution_caching_default=execution_caching_default, | ||
) | ||
|
||
build_exit_handler_groups_recursively( | ||
diff --git a/sdk/python/kfp/dsl/component_factory.py b/sdk/python/kfp/dsl/component_factory.py | ||
index c649424ba..babaebacd 100644 | ||
--- a/sdk/python/kfp/dsl/component_factory.py | ||
+++ b/sdk/python/kfp/dsl/component_factory.py | ||
@@ -676,6 +676,7 @@ def create_graph_component_from_func( | ||
name: Optional[str] = None, | ||
description: Optional[str] = None, | ||
display_name: Optional[str] = None, | ||
+ execution_caching_default: bool = True, | ||
) -> graph_component.GraphComponent: | ||
"""Implementation for the @pipeline decorator. | ||
|
||
@@ -692,6 +693,7 @@ def create_graph_component_from_func( | ||
component_spec=component_spec, | ||
pipeline_func=func, | ||
display_name=display_name, | ||
+ execution_caching_default=execution_caching_default, | ||
) | ||
|
||
|
||
diff --git a/sdk/python/kfp/dsl/graph_component.py b/sdk/python/kfp/dsl/graph_component.py | ||
index 2b09927df..3217d35b4 100644 | ||
--- a/sdk/python/kfp/dsl/graph_component.py | ||
+++ b/sdk/python/kfp/dsl/graph_component.py | ||
@@ -37,9 +37,11 @@ class GraphComponent(base_component.BaseComponent): | ||
component_spec: structures.ComponentSpec, | ||
pipeline_func: Callable, | ||
display_name: Optional[str] = None, | ||
+ execution_caching_default: bool = True, | ||
): | ||
super().__init__(component_spec=component_spec) | ||
self.pipeline_func = pipeline_func | ||
+ self.execution_caching_default = execution_caching_default | ||
|
||
args_list = [] | ||
signature = inspect.signature(pipeline_func) | ||
@@ -54,7 +56,7 @@ class GraphComponent(base_component.BaseComponent): | ||
)) | ||
|
||
with pipeline_context.Pipeline( | ||
- self.component_spec.name) as dsl_pipeline: | ||
+ self.component_spec.name, execution_caching_default=execution_caching_default) as dsl_pipeline: | ||
pipeline_outputs = pipeline_func(*args_list) | ||
|
||
if not dsl_pipeline.tasks: | ||
@@ -69,6 +71,7 @@ class GraphComponent(base_component.BaseComponent): | ||
pipeline=dsl_pipeline, | ||
component_spec=self.component_spec, | ||
pipeline_outputs=pipeline_outputs, | ||
+ execution_caching_default=self.execution_caching_default, | ||
) | ||
|
||
pipeline_root = getattr(pipeline_func, 'pipeline_root', None) | ||
diff --git a/sdk/python/kfp/dsl/pipeline_context.py b/sdk/python/kfp/dsl/pipeline_context.py | ||
index 4881bc568..2ab1e0da8 100644 | ||
--- a/sdk/python/kfp/dsl/pipeline_context.py | ||
+++ b/sdk/python/kfp/dsl/pipeline_context.py | ||
@@ -101,13 +101,15 @@ class Pipeline: | ||
"""Gets the default pipeline.""" | ||
return Pipeline._default_pipeline | ||
|
||
- def __init__(self, name: str): | ||
+ def __init__(self, name: str, execution_caching_default: bool = True): | ||
"""Creates a new instance of Pipeline. | ||
|
||
Args: | ||
name: The name of the pipeline. | ||
+ execution_caching_default: Whether caching is enabled for the tasks by default. | ||
""" | ||
self.name = name | ||
+ self.execution_caching_default = execution_caching_default | ||
self.tasks = {} | ||
# Add the root group. | ||
self.groups = [ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# PIPELINE DEFINITION | ||
# Name: tiny-pipeline | ||
components: | ||
comp-my-component: | ||
executorLabel: exec-my-component | ||
deploymentSpec: | ||
executors: | ||
exec-my-component: | ||
container: | ||
args: | ||
- --executor_input | ||
- '{{$}}' | ||
- --function_to_execute | ||
- my_component | ||
command: | ||
- sh | ||
- -c | ||
- "\nif ! [ -x \"$(command -v pip)\" ]; then\n python3 -m ensurepip ||\ | ||
\ python3 -m ensurepip --user || apt-get install python3-pip\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1\ | ||
\ python3 -m pip install --quiet --no-warn-script-location 'kfp==2.9.0'\ | ||
\ '--no-deps' 'typing-extensions>=3.7.4,<5; python_version<\"3.9\"' && \"\ | ||
$0\" \"$@\"\n" | ||
- sh | ||
- -ec | ||
- 'program_path=$(mktemp -d) | ||
printf "%s" "$0" > "$program_path/ephemeral_component.py" | ||
_KFP_RUNTIME=true python3 -m kfp.dsl.executor_main --component_module_path "$program_path/ephemeral_component.py" "$@" | ||
' | ||
- "\nimport kfp\nfrom kfp import dsl\nfrom kfp.dsl import *\nfrom typing import\ | ||
\ *\n\ndef my_component():\n pass\n\n" | ||
image: python:3.9 | ||
pipelineInfo: | ||
name: tiny-pipeline | ||
root: | ||
dag: | ||
tasks: | ||
my-component: | ||
cachingOptions: | ||
enableCache: true | ||
componentRef: | ||
name: comp-my-component | ||
taskInfo: | ||
name: my-component | ||
schemaVersion: 2.1.0 | ||
sdkVersion: kfp-2.9.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from kfp import compiler, dsl | ||
|
||
common_base_image = "registry.redhat.io/ubi8/python-39@sha256:3523b184212e1f2243e76d8094ab52b01ea3015471471290d011625e1763af61" | ||
|
||
|
||
@dsl.component(base_image=common_base_image) | ||
def print_message(message: str): | ||
"""Prints a message""" | ||
print(message + " (step 1)") | ||
|
||
|
||
@dsl.pipeline(name="version-test-pipeline", description="Pipeline that prints a hello message") | ||
def version_test_pipeline(message: str = "Hello world"): | ||
print_message_task = print_message(message=message) | ||
|
||
|
||
if __name__ == "__main__": | ||
compiler.Compiler().compile(version_test_pipeline, | ||
package_path=__file__.replace(".py", "_compiled.yaml")) |
Oops, something went wrong.