Skip to content

Commit

Permalink
feat: add a volumes argument to local.DockerRunner
Browse files Browse the repository at this point in the history
Signed-off-by: champon1020 <[email protected]>
  • Loading branch information
champon1020 committed Jan 30, 2025
1 parent 37a7b4e commit 886f16d
Show file tree
Hide file tree
Showing 4 changed files with 289 additions and 212 deletions.
17 changes: 13 additions & 4 deletions sdk/python/kfp/local/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Objects for configuring local execution."""

import abc
import dataclasses
from dataclasses import field
import os
from typing import Union
from typing import Dict, Union

from kfp import local

Expand All @@ -42,13 +44,20 @@ class SubprocessRunner:
Args:
use_venv: Whether to run the subprocess in a virtual environment. If True, dependencies will be installed in the virtual environment. If False, dependencies will be installed in the current environment. Using a virtual environment is recommended.
"""

use_venv: bool = True


@dataclasses.dataclass
class DockerRunner:
"""Runner that indicates that local tasks should be run as a Docker
container."""
container.
Args:
volumes: Additional volumes you wnat to mount to task containers.
"""

volumes: Dict[str, Dict[str, str]] = field(default_factory=dict)

def __post_init__(self):
try:
Expand All @@ -67,7 +76,7 @@ def __new__(
runner: SubprocessRunner,
pipeline_root: str,
raise_on_error: bool,
) -> 'LocalExecutionConfig':
) -> "LocalExecutionConfig":
# singleton pattern
cls.instance = super(LocalExecutionConfig, cls).__new__(cls)
return cls.instance
Expand Down Expand Up @@ -98,7 +107,7 @@ def validate(cls):
def init(
# annotate with subclasses, not parent class, for more helpful ref docs
runner: Union[SubprocessRunner, DockerRunner],
pipeline_root: str = './local_outputs',
pipeline_root: str = "./local_outputs",
raise_on_error: bool = True,
) -> None:
"""Initializes a local execution session.
Expand Down
35 changes: 18 additions & 17 deletions sdk/python/kfp/local/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for config.py."""

import os
import unittest
from unittest import mock
Expand All @@ -28,40 +29,40 @@ def setUp(self):
def test_local_runner_config_init(self):
"""Test instance attributes with one constructor call."""
config.LocalExecutionConfig(
pipeline_root='my/local/root',
pipeline_root="my/local/root",
runner=local.SubprocessRunner(use_venv=True),
raise_on_error=True,
)

instance = config.LocalExecutionConfig.instance

self.assertEqual(instance.pipeline_root, 'my/local/root')
self.assertEqual(instance.pipeline_root, "my/local/root")
self.assertEqual(instance.runner, local.SubprocessRunner(use_venv=True))
self.assertIs(instance.raise_on_error, True)

def test_local_runner_config_is_singleton(self):
"""Test instance attributes with multiple constructor calls."""
config.LocalExecutionConfig(
pipeline_root='my/local/root',
pipeline_root="my/local/root",
runner=local.SubprocessRunner(),
raise_on_error=True,
)
config.LocalExecutionConfig(
pipeline_root='other/local/root',
pipeline_root="other/local/root",
runner=local.SubprocessRunner(use_venv=False),
raise_on_error=False,
)

instance = config.LocalExecutionConfig.instance

self.assertEqual(instance.pipeline_root, 'other/local/root')
self.assertEqual(instance.pipeline_root, "other/local/root")
self.assertEqual(instance.runner,
local.SubprocessRunner(use_venv=False))
self.assertFalse(instance.raise_on_error, False)

def test_validate_success(self):
config.LocalExecutionConfig(
pipeline_root='other/local/root',
pipeline_root="other/local/root",
runner=local.SubprocessRunner(use_venv=False),
raise_on_error=False,
)
Expand All @@ -70,7 +71,7 @@ def test_validate_success(self):
def test_validate_fail(self):
with self.assertRaisesRegex(
RuntimeError,
r"Local environment not initialized. Please run 'kfp\.local\.init\(\)' before executing tasks locally\."
r"Local environment not initialized. Please run 'kfp\.local\.init\(\)' before executing tasks locally\.",
):
config.LocalExecutionConfig.validate()

Expand All @@ -83,31 +84,31 @@ def setUp(self):
def test_init_more_than_once(self):
"""Tests config instance attributes with one init() call."""
local.init(
pipeline_root='my/local/root',
pipeline_root="my/local/root",
runner=local.SubprocessRunner(use_venv=True),
)

instance = config.LocalExecutionConfig.instance

self.assertEqual(instance.pipeline_root, 'my/local/root')
self.assertEqual(instance.pipeline_root, "my/local/root")
self.assertEqual(instance.runner, local.SubprocessRunner(use_venv=True))

def test_init_more_than_once(self):
"""Test config instance attributes with multiple init() calls."""
local.init(
pipeline_root='my/local/root',
pipeline_root="my/local/root",
runner=local.SubprocessRunner(),
)
local.init(
pipeline_root='other/local/root',
pipeline_root="other/local/root",
runner=local.SubprocessRunner(use_venv=False),
raise_on_error=False,
)

instance = config.LocalExecutionConfig.instance

self.assertEqual(instance.pipeline_root,
os.path.abspath('other/local/root'))
os.path.abspath("other/local/root"))
self.assertEqual(instance.runner,
local.SubprocessRunner(use_venv=False))
self.assertFalse(instance.raise_on_error, False)
Expand All @@ -116,21 +117,21 @@ def test_runner_validation(self):
"""Test config instance attributes with multiple init() calls."""
with self.assertRaisesRegex(
ValueError,
r'Got unknown runner foo of type str\. Runner should be one of the following types: SubprocessRunner\.'
r"Got unknown runner foo of type str\. Runner should be one of the following types: SubprocessRunner\.",
):
local.init(runner='foo')
local.init(runner="foo")


class TestDockerRunner(unittest.TestCase):

def test_import_error(self):
with mock.patch.dict('sys.modules', {'docker': None}):
with mock.patch.dict("sys.modules", {"docker": None}):
with self.assertRaisesRegex(
ImportError,
r"Package 'docker' must be installed to use 'DockerRunner'\. Install it using 'pip install docker'\."
r"Package 'docker' must be installed to use 'DockerRunner'\. Install it using 'pip install docker'\.",
):
local.DockerRunner()


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
27 changes: 17 additions & 10 deletions sdk/python/kfp/local/docker_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,20 @@ def get_volumes_to_mount(self) -> Dict[str, Any]:
raise ValueError(
"'pipeline_root' should be an absolute path to correctly construct the volume mount specification."
)
return {self.pipeline_root: {'bind': self.pipeline_root, 'mode': 'rw'}}
default_volume = {
self.pipeline_root: {
"bind": self.pipeline_root,
"mode": "rw"
}
}
return default_volume | self.runner.volumes

def run(self) -> status.Status:
"""Runs the Docker container and returns the status."""
# nest docker import in case not available in user env so that
# this module is runnable, even if not using DockerRunner
import docker

client = docker.from_env()
try:
volumes = self.get_volumes_to_mount()
Expand All @@ -65,13 +72,13 @@ def run(self) -> status.Status:


def add_latest_tag_if_not_present(image: str) -> str:
if ':' not in image:
image = f'{image}:latest'
if ":" not in image:
image = f"{image}:latest"
return image


def run_docker_container(
client: 'docker.DockerClient',
client: "docker.DockerClient",
image: str,
command: List[str],
volumes: Dict[str, Any],
Expand All @@ -80,12 +87,12 @@ def run_docker_container(
image_exists = any(
image in existing_image.tags for existing_image in client.images.list())
if image_exists:
print(f'Found image {image!r}\n')
print(f"Found image {image!r}\n")
else:
print(f'Pulling image {image!r}')
repository, tag = image.split(':')
print(f"Pulling image {image!r}")
repository, tag = image.split(":")
client.images.pull(repository=repository, tag=tag)
print('Image pull complete\n')
print("Image pull complete\n")
container = client.containers.run(
image=image,
command=command,
Expand All @@ -97,5 +104,5 @@ def run_docker_container(
for line in container.logs(stream=True):
# the inner logs should already have trailing \n
# we do not need to add another
print(line.decode(), end='')
return container.wait()['StatusCode']
print(line.decode(), end="")
return container.wait()["StatusCode"]
Loading

0 comments on commit 886f16d

Please sign in to comment.