diff --git a/caikit_ray_backend/blocks/ray_train.py b/caikit_ray_backend/blocks/ray_train.py index abff48c..4ea8bcc 100644 --- a/caikit_ray_backend/blocks/ray_train.py +++ b/caikit_ray_backend/blocks/ray_train.py @@ -218,6 +218,9 @@ def train( error.value_check("", num_gpus > 0) env_vars["requested_gpus"] = num_gpus + training_timeout = self.config.get("training_timeout", 60) + env_vars["training_timeout"] = float(training_timeout) + # Serialize **kwargs and add them to environment variables my_kwargs = {} for key, value in kwargs.items(): diff --git a/caikit_ray_backend/ray_submitter.py b/caikit_ray_backend/ray_submitter.py index f7cad5a..0160b7a 100644 --- a/caikit_ray_backend/ray_submitter.py +++ b/caikit_ray_backend/ray_submitter.py @@ -76,13 +76,20 @@ def main(): if model_path: error.type_check("", str, model_path=model_path) - # Finally kick off trainig + timeout = runtime_env.get("training_timeout", float(60)) + + # Finally kick off training with alog.ContextTimer(log.debug, "Done training %s in: ", module_class): - ray.get( - ray_training_tasks.train_and_save.options( - num_cpus=num_cpus, num_gpus=num_gpus - ).remote(module_class, model_path, *args, **kwargs) - ) + task = ray_training_tasks.train_and_save.options( + num_cpus=num_cpus, num_gpus=num_gpus + ).remote(module_class, model_path, *args, **kwargs) + ready, _ = ray.wait([task], timeout=timeout) + if ready: + ray.get(task) + else: + ray.cancel(task) + log.error("Task did not complete before time out.") + raise TimeoutError("Task did not complete before time out.") if __name__ == "__main__": diff --git a/tests/test_ray_backend.py b/tests/test_ray_backend.py index 2527467..b693362 100644 --- a/tests/test_ray_backend.py +++ b/tests/test_ray_backend.py @@ -16,6 +16,7 @@ """ # Standard from datetime import datetime +import logging import os import time @@ -46,7 +47,10 @@ def jsonl_file_data_stream(): def test_job_submission_client(mock_ray_cluster, jsonl_file_data_stream): - config = {"connection": {"address": mock_ray_cluster.address}} + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 30.0, + } trainer = RayJobTrainModule(config, "ray_backend") args = [jsonl_file_data_stream] @@ -82,7 +86,10 @@ def test_job_submission_client(mock_ray_cluster, jsonl_file_data_stream): def test_wait(mock_ray_cluster, jsonl_file_data_stream): - config = {"connection": {"address": mock_ray_cluster.address}} + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 30.0, + } trainer = RayJobTrainModule(config, "ray_backend") args = [jsonl_file_data_stream] @@ -101,7 +108,10 @@ def test_wait(mock_ray_cluster, jsonl_file_data_stream): def test_load(mock_ray_cluster, jsonl_file_data_stream): - config = {"connection": {"address": mock_ray_cluster.address}} + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 30.0, + } trainer = RayJobTrainModule(config, "ray_backend") args = [jsonl_file_data_stream] @@ -118,7 +128,10 @@ def test_load(mock_ray_cluster, jsonl_file_data_stream): def test_cancel(mock_ray_cluster, jsonl_file_data_stream): - config = {"connection": {"address": mock_ray_cluster.address}} + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 30.0, + } trainer = RayJobTrainModule(config, "ray_backend") args = [jsonl_file_data_stream] @@ -142,6 +155,27 @@ def test_cancel(mock_ray_cluster, jsonl_file_data_stream): assert status == TrainingStatus.CANCELED +def test_timeout(mock_ray_cluster, jsonl_file_data_stream): + config = { + "connection": {"address": mock_ray_cluster.address}, + "training_timeout": 0.25, + } + trainer = RayJobTrainModule(config, "ray_backend") + + args = [jsonl_file_data_stream] + model_future = trainer.train( + SampleModule, + *args, + save_path="/tmp", + ) + + time.sleep(3) + + status = model_future.get_info().status + print("Final status was", status) + assert status == TrainingStatus.ERRORED + + ## Test Ray Backend diff --git a/tox.ini b/tox.ini index 7bbfd9f..8849eb5 100644 --- a/tox.ini +++ b/tox.ini @@ -11,6 +11,7 @@ deps = pytest>=6.2.5,<7.0 pytest-cov>=2.10.1,<3.0 pytest-html>=3.1.1,<4.0 + pytest-catchlog tls_test_tools>=0.1.1 wheel>=0.38.4 tests/fixtures @@ -20,6 +21,7 @@ passenv = LOG_FORMATTER LOG_THREAD_ID LOG_CHANNEL_WIDTH +env_name=dev commands = pytest --cov=caikit --cov-report=html:coverage-{env_name} --cov-report=xml:coverage-{env_name}.xml {posargs:tests} ; Unclear: We probably want to test wheel packaging @@ -32,7 +34,8 @@ description = format with pre-commit deps = pre-commit>=3.0.4,<4.0 commands = ./scripts/fmt.sh allowlist_externals = ./scripts/fmt.sh -skip_install = True # Skip package install since fmt doesn't need to execute code, for ⚡⚡⚡ +skip_install = True +# Skip package install since fmt doesn't need to execute code, for ⚡⚡⚡ basepython = py39 [testenv:lint]