From 719f64b996a5cebf723d8a2571ff33cf2a6b9461 Mon Sep 17 00:00:00 2001 From: champon1020 Date: Fri, 31 Jan 2025 02:03:43 +0900 Subject: [PATCH] feat: add a volumes argument to local.DockerRunner Signed-off-by: champon1020 --- sdk/python/kfp/local/config.py | 13 +- sdk/python/kfp/local/config_test.py | 7 +- sdk/python/kfp/local/docker_task_handler.py | 9 +- .../kfp/local/docker_task_handler_test.py | 160 ++++++++++++------ 4 files changed, 133 insertions(+), 56 deletions(-) diff --git a/sdk/python/kfp/local/config.py b/sdk/python/kfp/local/config.py index 9ea01d18369..2d56598448a 100755 --- a/sdk/python/kfp/local/config.py +++ b/sdk/python/kfp/local/config.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. """Objects for configuring local execution.""" + import abc import dataclasses +from dataclasses import field import os -from typing import Union +from typing import Dict, Union from kfp import local @@ -42,13 +44,20 @@ class SubprocessRunner: Args: use_venv: Whether to run the subprocess in a virtual environment. If True, dependencies will be installed in the virtual environment. If False, dependencies will be installed in the current environment. Using a virtual environment is recommended. """ + use_venv: bool = True @dataclasses.dataclass class DockerRunner: """Runner that indicates that local tasks should be run as a Docker - container.""" + container. + + Args: + volumes: Additional volumes you wnat to mount to task containers. + """ + + volumes: Dict[str, Dict[str, str]] = field(default_factory=dict) def __post_init__(self): try: diff --git a/sdk/python/kfp/local/config_test.py b/sdk/python/kfp/local/config_test.py index ad71cf5a6ac..8f389f1b1fb 100755 --- a/sdk/python/kfp/local/config_test.py +++ b/sdk/python/kfp/local/config_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tests for config.py.""" + import os import unittest from unittest import mock @@ -70,7 +71,7 @@ def test_validate_success(self): def test_validate_fail(self): with self.assertRaisesRegex( RuntimeError, - r"Local environment not initialized. Please run 'kfp\.local\.init\(\)' before executing tasks locally\." + r"Local environment not initialized. Please run 'kfp\.local\.init\(\)' before executing tasks locally\.", ): config.LocalExecutionConfig.validate() @@ -116,7 +117,7 @@ def test_runner_validation(self): """Test config instance attributes with multiple init() calls.""" with self.assertRaisesRegex( ValueError, - r'Got unknown runner foo of type str\. Runner should be one of the following types: SubprocessRunner\.' + r'Got unknown runner foo of type str\. Runner should be one of the following types: SubprocessRunner\.', ): local.init(runner='foo') @@ -127,7 +128,7 @@ def test_import_error(self): with mock.patch.dict('sys.modules', {'docker': None}): with self.assertRaisesRegex( ImportError, - r"Package 'docker' must be installed to use 'DockerRunner'\. Install it using 'pip install docker'\." + r"Package 'docker' must be installed to use 'DockerRunner'\. Install it using 'pip install docker'\.", ): local.DockerRunner() diff --git a/sdk/python/kfp/local/docker_task_handler.py b/sdk/python/kfp/local/docker_task_handler.py index 08d48b35b73..af5776834b0 100755 --- a/sdk/python/kfp/local/docker_task_handler.py +++ b/sdk/python/kfp/local/docker_task_handler.py @@ -43,13 +43,20 @@ def get_volumes_to_mount(self) -> Dict[str, Any]: raise ValueError( "'pipeline_root' should be an absolute path to correctly construct the volume mount specification." ) - return {self.pipeline_root: {'bind': self.pipeline_root, 'mode': 'rw'}} + default_volume = { + self.pipeline_root: { + 'bind': self.pipeline_root, + 'mode': 'rw' + } + } + return default_volume | self.runner.volumes def run(self) -> status.Status: """Runs the Docker container and returns the status.""" # nest docker import in case not available in user env so that # this module is runnable, even if not using DockerRunner import docker + client = docker.from_env() try: volumes = self.get_volumes_to_mount() diff --git a/sdk/python/kfp/local/docker_task_handler_test.py b/sdk/python/kfp/local/docker_task_handler_test.py index 71f8be21361..c8e10e56e2a 100755 --- a/sdk/python/kfp/local/docker_task_handler_test.py +++ b/sdk/python/kfp/local/docker_task_handler_test.py @@ -87,7 +87,8 @@ def test_cwd_volume(self): volumes={current_test_dir: { 'bind': '/localdir', 'mode': 'ro' - }}) + }}, + ) class TestDockerTaskHandler(DockerMockTestCase): @@ -101,12 +102,40 @@ def test_get_volumes_to_mount(self): ) volumes = handler.get_volumes_to_mount() self.assertEqual( - volumes, { + volumes, + { os.path.abspath('my_root'): { 'bind': os.path.abspath('my_root'), - 'mode': 'rw' + 'mode': 'rw', } - }) + }, + ) + + def test_get_volumes_to_mount_with_custom_volumes(self): + handler = docker_task_handler.DockerTaskHandler( + image='alpine', + full_command=['echo', 'foo'], + pipeline_root=os.path.abspath('my_root'), + runner=local.DockerRunner( + volumes={'/path/to/dir': { + 'bind': '/mnt/vol', + 'mode': 'rw' + }}), + ) + volumes = handler.get_volumes_to_mount() + self.assertEqual( + volumes, + { + os.path.abspath('my_root'): { + 'bind': os.path.abspath('my_root'), + 'mode': 'rw', + }, + '/path/to/dir': { + 'bind': '/mnt/vol', + 'mode': 'rw' + }, + }, + ) def test_run(self): handler = docker_task_handler.DockerTaskHandler( @@ -126,7 +155,7 @@ def test_run(self): volumes={ os.path.abspath('my_root'): { 'bind': os.path.abspath('my_root'), - 'mode': 'rw' + 'mode': 'rw', } }, ) @@ -134,7 +163,7 @@ def test_run(self): def test_pipeline_root_relpath(self): with self.assertRaisesRegex( ValueError, - r"'pipeline_root' should be an absolute path to correctly construct the volume mount specification\." + r"'pipeline_root' should be an absolute path to correctly construct the volume mount specification\.", ): docker_task_handler.DockerTaskHandler( image='alpine', @@ -246,11 +275,14 @@ def artifact_maker(x: str,): kwargs['image'], 'alpine:latest', ) - self.assertEqual(kwargs['command'], [ - 'sh', - '-c', - 'echo prefix-foo', - ]) + self.assertEqual( + kwargs['command'], + [ + 'sh', + '-c', + 'echo prefix-foo', + ], + ) self.assertTrue(kwargs['detach']) self.assertTrue(kwargs['stdout']) self.assertTrue(kwargs['stderr']) @@ -270,8 +302,10 @@ def comp(x: Optional[str] = None): dsl.IfPresentPlaceholder( input_name='x', then=['echo', x], - else_=['echo', 'No input provided!']) - ]) + else_=['echo', 'No input provided!'], + ) + ], + ) comp() @@ -282,10 +316,13 @@ def comp(x: Optional[str] = None): kwargs['image'], 'alpine:3.19.0', ) - self.assertEqual(kwargs['command'], [ - 'echo', - 'No input provided!', - ]) + self.assertEqual( + kwargs['command'], + [ + 'echo', + 'No input provided!', + ], + ) def test_if_present_with_string_provided(self): @@ -297,8 +334,10 @@ def comp(x: Optional[str] = None): dsl.IfPresentPlaceholder( input_name='x', then=['echo', x], - else_=['echo', 'No artifact provided!']) - ]) + else_=['echo', 'No artifact provided!'], + ) + ], + ) comp(x='foo') @@ -309,10 +348,13 @@ def comp(x: Optional[str] = None): kwargs['image'], 'alpine:3.19.0', ) - self.assertEqual(kwargs['command'], [ - 'echo', - 'foo', - ]) + self.assertEqual( + kwargs['command'], + [ + 'echo', + 'foo', + ], + ) def test_if_present_single_element_with_string_omitted(self): @@ -326,8 +368,9 @@ def comp(x: Optional[str] = None): input_name='x', then=x, else_='No artifact provided!', - ) - ]) + ), + ], + ) comp() @@ -338,10 +381,13 @@ def comp(x: Optional[str] = None): kwargs['image'], 'alpine:3.19.0', ) - self.assertEqual(kwargs['command'], [ - 'echo', - 'No artifact provided!', - ]) + self.assertEqual( + kwargs['command'], + [ + 'echo', + 'No artifact provided!', + ], + ) def test_if_present_single_element_with_string_provided(self): @@ -355,8 +401,9 @@ def comp(x: Optional[str] = None): input_name='x', then=x, else_='No artifact provided!', - ) - ]) + ), + ], + ) comp(x='foo') @@ -367,10 +414,13 @@ def comp(x: Optional[str] = None): kwargs['image'], 'alpine:3.19.0', ) - self.assertEqual(kwargs['command'], [ - 'echo', - 'foo', - ]) + self.assertEqual( + kwargs['command'], + [ + 'echo', + 'foo', + ], + ) def test_concat_placeholder(self): @@ -378,7 +428,8 @@ def test_concat_placeholder(self): def comp(x: Optional[str] = None): return dsl.ContainerSpec( image='alpine', - command=[dsl.ConcatPlaceholder(['prefix-', x, '-suffix'])]) + command=[dsl.ConcatPlaceholder(['prefix-', x, '-suffix'])], + ) comp() @@ -400,8 +451,9 @@ def comp(x: Optional[str] = None): command=[ 'echo', dsl.ConcatPlaceholder( - ['a', dsl.ConcatPlaceholder(['b', x, 'd'])]) - ]) + ['a', dsl.ConcatPlaceholder(['b', x, 'd'])]), + ], + ) comp(x='c') @@ -429,10 +481,12 @@ def comp(x: Optional[str] = None): dsl.IfPresentPlaceholder( input_name='x', then='one thing', - else_='another thing') - ]) - ]) - ]) + else_='another thing', + ), + ]), + ]), + ], + ) comp(x='c') @@ -460,10 +514,12 @@ def comp(x: Optional[str] = None): dsl.IfPresentPlaceholder( input_name='x', then='one thing', - else_='another thing') - ]) - ]) - ]) + else_='another thing', + ), + ]), + ]), + ], + ) comp() @@ -487,8 +543,10 @@ def comp(x: Optional[str] = None): dsl.IfPresentPlaceholder( input_name='x', then=dsl.ConcatPlaceholder([x]), - else_=dsl.ConcatPlaceholder(['something', ' ', 'else'])) - ]) + else_=dsl.ConcatPlaceholder(['something', ' ', 'else']), + ), + ], + ) comp(x='something') @@ -512,8 +570,10 @@ def comp(x: Optional[str] = None): dsl.IfPresentPlaceholder( input_name='x', then=dsl.ConcatPlaceholder([x]), - else_=dsl.ConcatPlaceholder(['another', ' ', 'thing'])) - ]) + else_=dsl.ConcatPlaceholder(['another', ' ', 'thing']), + ), + ], + ) comp()