Skip to content

Commit a79009e

Browse files
SinaChavoshiTensorflow Cloud maintainers
authored and
Tensorflow Cloud maintainers
committed
Add integration tests for Distributing Cloud Tuner.
Fix a few issues: - Duplicate epoch reports to Oracle - Excessive cash discovery Info/Error log PiperOrigin-RevId: 339757281
1 parent 651ce41 commit a79009e

File tree

4 files changed

+208
-20
lines changed

4 files changed

+208
-20
lines changed
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# Lint as: python3
2+
# Copyright 2020 Google LLC. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Integration tests for Distributing Cloud Tuner."""
16+
17+
import contextlib
18+
import io
19+
import os
20+
import re
21+
import kerastuner
22+
import tensorflow as tf
23+
from tensorflow import keras
24+
from tensorflow_cloud.tuner import optimizer_client
25+
from tensorflow_cloud.tuner.tuner import DistributingCloudTuner
26+
27+
# If input dataset is created outside tuner.search(),
28+
# it requires eager execution even in TF 1.x.
29+
if tf.version.VERSION.split(".")[0] == "1":
30+
tf.compat.v1.enable_eager_execution()
31+
32+
# The project id to use to run tests.
33+
_PROJECT_ID = os.environ["PROJECT_ID"]
34+
35+
# The GCP region in which the end-to-end test is run.
36+
_REGION = os.environ["REGION"]
37+
38+
# Study ID for testing
39+
_STUDY_ID_BASE = "dct_{}".format((os.environ["BUILD_ID"]).replace("-", "_"))
40+
41+
# The base docker image to use for the remote environment.
42+
_DOCKER_IMAGE = os.environ["DOCKER_IMAGE"]
43+
44+
# The staging bucket to use to copy the model and data for the remote run.
45+
_REMOTE_DIR = os.path.join("gs://", os.environ["TEST_BUCKET"], _STUDY_ID_BASE)
46+
47+
# The search space for hyperparameters
48+
_HPS = kerastuner.engine.hyperparameters.HyperParameters()
49+
_HPS.Float("learning_rate", min_value=1e-4, max_value=1e-2, sampling="log")
50+
_HPS.Int("num_layers", 2, 10)
51+
52+
53+
def _load_data(dir_path=None):
54+
"""Loads and prepares data."""
55+
56+
mnist_file_path = None
57+
if dir_path:
58+
mnist_file_path = os.path.join(dir_path, "mnist.npz")
59+
60+
(x, y), (val_x, val_y) = keras.datasets.mnist.load_data(mnist_file_path)
61+
x = x.astype("float32") / 255.0
62+
val_x = val_x.astype("float32") / 255.0
63+
64+
return ((x[:10000], y[:10000]), (val_x, val_y))
65+
66+
67+
def _build_model(hparams):
68+
# Note that CloudTuner does not support adding hyperparameters in
69+
# the model building function. Instead, the search space is configured
70+
# by passing a hyperparameters argument when instantiating (constructing)
71+
# the tuner.
72+
model = keras.Sequential()
73+
model.add(keras.layers.Flatten(input_shape=(28, 28)))
74+
75+
# Build the model with number of layers from the hyperparameters
76+
for _ in range(hparams.get("num_layers")):
77+
model.add(keras.layers.Dense(units=64, activation="relu"))
78+
model.add(keras.layers.Dense(10, activation="softmax"))
79+
80+
# Compile the model with learning rate from the hyperparameters
81+
model.compile(
82+
optimizer=keras.optimizers.Adam(lr=hparams.get("learning_rate")),
83+
loss="sparse_categorical_crossentropy",
84+
metrics=["acc"],
85+
)
86+
return model
87+
88+
89+
class _DistributingCloudTunerIntegrationTestBase(tf.test.TestCase):
90+
91+
def setUp(self):
92+
super(_DistributingCloudTunerIntegrationTestBase, self).setUp()
93+
self._study_id = None
94+
95+
def _assert_output(self, fn, regex_str):
96+
stdout = io.StringIO()
97+
with contextlib.redirect_stdout(stdout):
98+
fn()
99+
output = stdout.getvalue()
100+
self.assertRegex(output, re.compile(regex_str, re.DOTALL))
101+
102+
def _assert_results_summary(self, fn):
103+
self._assert_output(
104+
fn, ".*Results summary.*Trial summary.*Hyperparameters.*")
105+
106+
def _delete_dir(self, path) -> None:
107+
"""Deletes a directory if exists."""
108+
if tf.io.gfile.isdir(path):
109+
tf.io.gfile.rmtree(path)
110+
111+
def tearDown(self):
112+
super(_DistributingCloudTunerIntegrationTestBase, self).tearDown()
113+
114+
# Delete the study used in the test, if present
115+
if self._study_id:
116+
service = optimizer_client.create_or_load_study(
117+
_PROJECT_ID, _REGION, self._study_id, None)
118+
service.delete_study()
119+
120+
tf.keras.backend.clear_session()
121+
122+
# Delete log files, saved_models and other training assets
123+
self._delete_dir(_REMOTE_DIR)
124+
125+
126+
class DistributingCloudTunerIntegrationTest(
127+
_DistributingCloudTunerIntegrationTestBase):
128+
129+
def setUp(self):
130+
super(DistributingCloudTunerIntegrationTest, self).setUp()
131+
(self._x, self._y), (self._val_x, self._val_y) = _load_data(
132+
self.get_temp_dir())
133+
134+
def testCloudTunerHyperparameters(self):
135+
"""Test case to configure Distributing Tuner with HyperParameters."""
136+
study_id = "{}_hyperparameters".format(_STUDY_ID_BASE)
137+
self._study_id = study_id
138+
139+
tuner = DistributingCloudTuner(
140+
_build_model,
141+
project_id=_PROJECT_ID,
142+
region=_REGION,
143+
objective="acc",
144+
hyperparameters=_HPS,
145+
max_trials=2,
146+
study_id=study_id,
147+
directory=_REMOTE_DIR,
148+
container_uri=_DOCKER_IMAGE
149+
)
150+
151+
tuner.search(
152+
x=self._x,
153+
y=self._y,
154+
epochs=2,
155+
validation_data=(self._val_x, self._val_y),
156+
)
157+
158+
self._assert_results_summary(tuner.results_summary)
159+
160+
if __name__ == "__main__":
161+
tf.test.main()

src/python/tensorflow_cloud/tuner/tests/unit/tuner_test.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,9 @@ def test_add_model_checkpoint_callback(self, mock_super_tuner):
539539
auto_spec=True)
540540
@mock.patch.object(tf_utils, "get_tensorboard_log_watcher_from_path",
541541
auto_spec=True)
542+
@mock.patch.object(tf.io.gfile, "makedirs", auto_spec=True)
542543
def test_remote_run_trial_with_successful_job(
543-
self, mock_log_watcher, mock_is_running, mock_super_tuner,
544+
self, mock_tf_io, mock_log_watcher, mock_is_running, mock_super_tuner,
544545
mock_job_status, mock_cloud_fit):
545546
remote_tuner = self._remote_tuner(
546547
None, None, self._study_config, max_trials=10)
@@ -573,21 +574,23 @@ def test_remote_run_trial_with_successful_job(
573574
image_uri=self._container_uri,
574575
job_id=self._job_id)
575576

576-
log_path = remote_tuner._get_tensorboard_log_dir(
577-
self._test_trial.trial_id)
577+
log_path = os.path.join(remote_tuner._get_tensorboard_log_dir(
578+
self._test_trial.trial_id), "train")
578579
mock_log_watcher.assert_called_with(log_path)
579580
self.assertEqual(
580581
2, remote_tuner._get_remote_training_metrics.call_count)
582+
mock_tf_io.assert_called_with(log_path)
581583

582584
@mock.patch.object(cloud_fit_client, "cloud_fit", auto_spec=True)
583585
@mock.patch.object(google_api_client,
584586
"wait_for_api_training_job_completion", auto_spec=True)
585587
@mock.patch.object(super_tuner.Tuner, "__init__", auto_spec=True)
586588
@mock.patch.object(google_api_client, "is_api_training_job_running",
587589
auto_spec=True)
590+
@mock.patch.object(tf.io.gfile, "makedirs", auto_spec=True)
588591
def test_remote_run_trial_with_failed_job(
589-
self, mock_is_running, mock_super_tuner,
590-
mock_job_status, mock_cloud_fit):
592+
self, mock_tf_io, mock_is_running, mock_super_tuner, mock_job_status,
593+
mock_cloud_fit):
591594

592595
remote_tuner = self._remote_tuner(
593596
None, None, self._study_config, max_trials=10)
@@ -609,8 +612,9 @@ def test_remote_run_trial_with_failed_job(
609612
@mock.patch.object(super_tuner.Tuner, "__init__", auto_spec=True)
610613
@mock.patch.object(google_api_client, "is_api_training_job_running",
611614
auto_spec=True)
615+
@mock.patch.object(tf.io.gfile, "makedirs", auto_spec=True)
612616
def test_remote_run_trial_with_oracle_canceling_job(
613-
self, mock_is_running, mock_super_tuner,
617+
self, mock_tf_io, mock_is_running, mock_super_tuner,
614618
mock_job_status, mock_cloud_fit, mock_stop_job):
615619

616620
remote_tuner = self._remote_tuner(
@@ -656,7 +660,7 @@ def test_get_remote_training_metrics(self, mock_super_tuner):
656660
log_reader = tf_utils.get_tensorboard_log_watcher_from_path(log_dir)
657661
results = remote_tuner._get_remote_training_metrics(log_reader, {})
658662

659-
self.assertLen(results.completed_epoch_metrics, 3)
663+
self.assertLen(results.completed_epoch_metrics, 2)
660664
self.assertIn("accuracy", results.completed_epoch_metrics[0])
661665
self.assertIn("loss", results.completed_epoch_metrics[0])
662666
self.assertEqual(

src/python/tensorflow_cloud/tuner/tuner.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
# metrics from remote training Tensorboard logs during training with:
4646
# - 'completed_epoch_metrics'- a list of epoch metrics for completed epochs.
4747
# - 'partial_epoch_metrics' - Any incomplete epoch metrics for the last epoch.
48+
# If training has completed this will contain metrics for the final epoch of
49+
# training.
50+
4851
_TrainingMetrics = collections.namedtuple("_TrainingMetrics", [
4952
"completed_epoch_metrics", "partial_epoch_metrics"])
5053

@@ -568,7 +571,11 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
568571

569572
# Create an instance of tensorboard DirectoryWatcher to retrieve the
570573
# logs for this trial run
571-
log_path = self._get_tensorboard_log_dir(trial.trial_id)
574+
log_path = os.path.join(
575+
self._get_tensorboard_log_dir(trial.trial_id), "train")
576+
577+
# Tensorboard log watcher expects the path to exist
578+
tf.io.gfile.makedirs(log_path)
572579

573580
# TODO(b/170687807) Switch from using "{}".format() to f-string
574581
tf.get_logger().info(
@@ -590,11 +597,12 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
590597

591598
for epoch_metrics in training_metrics.completed_epoch_metrics:
592599
# TODO(b/169197272) Validate metrics contain oracle objective
593-
trial.status = self.oracle.update_trial(
594-
trial_id=trial.trial_id,
595-
metrics=epoch_metrics,
596-
step=epoch)
597-
epoch += 1
600+
if epoch_metrics:
601+
trial.status = self.oracle.update_trial(
602+
trial_id=trial.trial_id,
603+
metrics=epoch_metrics,
604+
step=epoch)
605+
epoch += 1
598606

599607
if trial.status == "STOPPED":
600608
google_api_client.stop_aip_training_job(
@@ -617,11 +625,19 @@ def run_trial(self, trial, *fit_args, **fit_kwargs):
617625
for epoch_metrics in training_metrics.completed_epoch_metrics:
618626
# TODO(b/169197272) Validate metrics contain oracle objective
619627
# TODO(b/170907612) Support submit partial results to Oracle
628+
if epoch_metrics:
629+
self.oracle.update_trial(
630+
trial_id=trial.trial_id,
631+
metrics=epoch_metrics,
632+
step=epoch)
633+
epoch += 1
634+
635+
# submit final epoch metrics
636+
if training_metrics.partial_epoch_metrics:
620637
self.oracle.update_trial(
621638
trial_id=trial.trial_id,
622-
metrics=epoch_metrics,
639+
metrics=training_metrics.partial_epoch_metrics,
623640
step=epoch)
624-
epoch += 1
625641

626642
def _get_job_spec_from_config(self, job_id: Text) -> Dict[Text, Any]:
627643
"""Creates a request dictionary for the CAIP training service.
@@ -680,7 +696,9 @@ def _get_remote_training_metrics(
680696
- 'completed_epoch_metrics'- a list of epoch metrics for completed
681697
epochs.
682698
- 'partial_epoch_metrics' - Any incomplete epoch metrics for the
683-
last epoch.
699+
last epoch. Once training completes, the final epoch metrics
700+
will be stored here, this is not included in
701+
completed_epoch_metrics.
684702
"""
685703
completed_epoch_metrics = []
686704
for event in log_reader.Load():
@@ -699,7 +717,6 @@ def _get_remote_training_metrics(
699717
# the unrelated Objectives.
700718
partial_epoch_metrics[metric] = tf.make_ndarray(
701719
event.summary.value[0].tensor)
702-
completed_epoch_metrics.append(partial_epoch_metrics)
703720
return _TrainingMetrics(completed_epoch_metrics, partial_epoch_metrics)
704721

705722
def load_model(self, trial):

src/python/tensorflow_cloud/utils/google_api_client.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def wait_for_api_training_job_completion(job_id: Text, project_id: Text)->bool:
5555
"""
5656
# Wait for AIP Training job to finish
5757
job_name = "projects/{}/jobs/{}".format(project_id, job_id)
58-
api_client = discovery.build("ml", "v1")
58+
# Disable cache_discovery to remove excessive info logs see:
59+
# https://github.com/googleapis/google-api-python-client/issues/299
60+
api_client = discovery.build("ml", "v1", cache_discovery=False)
5961

6062
request = api_client.projects().jobs().get(name=job_name)
6163

@@ -94,7 +96,9 @@ def is_api_training_job_running(job_id: Text, project_id: Text)->bool:
9496
cancelled.
9597
"""
9698
job_name = "projects/{}/jobs/{}".format(project_id, job_id)
97-
api_client = discovery.build("ml", "v1")
99+
# Disable cache_discovery to remove excessive info logs see:
100+
# https://github.com/googleapis/google-api-python-client/issues/299
101+
api_client = discovery.build("ml", "v1", cache_discovery=False)
98102

99103
logging.info("Retrieving status for job %s.", job_name)
100104

@@ -112,7 +116,9 @@ def stop_aip_training_job(job_id: Text, project_id: Text):
112116
project_id: Project under which the AIP Training job is running.
113117
"""
114118
job_name = "projects/{}/jobs/{}".format(project_id, job_id)
115-
api_client = discovery.build("ml", "v1")
119+
# Disable cache_discovery to remove excessive info logs see:
120+
# https://github.com/googleapis/google-api-python-client/issues/299
121+
api_client = discovery.build("ml", "v1", cache_discovery=False)
116122

117123
logging.info("Canceling the job %s.", job_name)
118124

0 commit comments

Comments
 (0)