Commit 9a1180a

juanuribe28 authored and Tensorflow Cloud maintainers committed

Add integration tests for run_experiment_cloud wrapper.

PiperOrigin-RevId: 383893019
1 parent f1ae448, commit 9a1180a

File tree

8 files changed: +428 −123 lines changed

src/python/dependencies.py

Lines changed: 3 additions & 0 deletions

@@ -27,6 +27,8 @@ def make_required_install_packages():
         "tensorflow>=1.15.0,<3.0",
         "tensorflow_datasets",
         "tensorflow_transform",
+        "tf-models-official",
+        "importlib_resources ; python_version<'3.7'"
     ]


@@ -38,4 +40,5 @@ def make_required_test_packages():
         "numpy",
         "nbconvert",
         "tf-models-official",
+        "importlib_resources ; python_version<'3.7'"
     ]
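
The `importlib_resources ; python_version<'3.7'` entries use a pip environment marker, so the backport is only installed on interpreters older than Python 3.7. A minimal sketch of the compatibility import this dependency supports (the same pattern models.py adopts below):

# Minimal sketch: prefer the standard library, fall back to the backport
# that dependencies.py now installs for Python < 3.7.
try:
    import importlib.resources as pkg_resources  # Python >= 3.7
except ImportError:
    import importlib_resources as pkg_resources  # backport for Python < 3.7

# Both modules expose the same API, e.g. pkg_resources.path(package, resource).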

src/python/tensorflow_cloud/core/containerize.py

Lines changed: 1 addition & 1 deletion

@@ -285,7 +285,7 @@ def _get_file_path_map(self):
             self.entry_point = sys.argv[0]

         # Map entry_point directory to the dst directory.
-        if not self.called_from_notebook:
+        if not self.called_from_notebook or self.entry_point is not None:
             entry_point_dir, _ = os.path.split(self.entry_point)
             if not entry_point_dir:  # Current directory
                 entry_point_dir = "."

src/python/tensorflow_cloud/core/experimental/models.py

Lines changed: 48 additions & 44 deletions

@@ -15,17 +15,27 @@
 """Module that contains the `run_models` wrapper for training models from TF Model Garden."""

 import os
+import pickle
+import shutil
 from typing import Any, Dict, Optional
+import uuid

 from .. import machine_config
 from .. import run
 import tensorflow as tf
 import tensorflow_datasets as tfds

-from official.core import train_lib
 from official.vision.image_classification.efficientnet import efficientnet_model
 from official.vision.image_classification.resnet import resnet_model

+# pylint: disable=g-import-not-at-top
+try:
+    import importlib.resources as pkg_resources
+except ImportError:
+    # Backported for python<3.7
+    import importlib_resources as pkg_resources
+# pylint: enable=g-import-not-at-top
+

 def run_models(dataset_name: str,
                model_name: str,
@@ -251,48 +261,42 @@ def run_experiment_cloud(run_experiment_kwargs: Dict[str, Any],
     """
     if run_kwargs is None:
         run_kwargs = dict()
-
-    if run.remote():
-        default_machine_config = machine_config.COMMON_MACHINE_CONFIGS['T4_1X']
-        if 'chief_config' in run_kwargs:
-            chief_config = run_kwargs['chief_config']
-        else:
-            chief_config = default_machine_config
-        if 'worker_count' in run_kwargs:
-            worker_count = run_kwargs['worker_count']
+    distribution_strategy = get_distribution_strategy_str(run_kwargs)
+    run_experiment_kwargs.update(
+        dict(distribution_strategy=distribution_strategy))
+    file_id = str(uuid.uuid4())
+    params_file = save_params(run_experiment_kwargs, file_id)
+
+    with pkg_resources.path(__package__, 'models_entry_point.py') as path:
+        entry_point = f'{file_id}.py'
+        shutil.copyfile(str(path), entry_point)
+    run_kwargs.update(dict(entry_point=entry_point,
+                           distribution_strategy=None))
+    info = run.run(**run_kwargs)
+    os.remove(entry_point)
+    os.remove(params_file)
+    return info
+
+
+def get_distribution_strategy_str(run_kwargs):
+    """Gets the name of a distribution strategy based on cloud run config."""
+    if ('worker_count' in run_kwargs
+            and run_kwargs['worker_count'] > 0):
+        if ('worker_config' in run_kwargs
+                and machine_config.is_tpu_config(run_kwargs['worker_config'])):
+            return 'tpu'
         else:
-            worker_count = 0
-        if 'worker_config' in run_kwargs:
-            worker_config = run_kwargs['worker_config']
-        else:
-            worker_config = default_machine_config
-        distribution_strategy = get_distribution_strategy(chief_config,
-                                                          worker_count,
-                                                          worker_config)
-        run_experiment_kwargs.update(
-            dict(distribution_strategy=distribution_strategy))
-        model, _ = train_lib.run_experiment(**run_experiment_kwargs)
-        model.save(run_experiment_kwargs['model_dir'])
-
-    run_kwargs.update(dict(entry_point=None,
-                           distribution_strategy=None))
-    return run.run(**run_kwargs)
-
-
-def get_distribution_strategy(chief_config, worker_count, worker_config):
-    """Gets a tf distribution strategy based on the cloud run config."""
-    if worker_count > 0:
-        if machine_config.is_tpu_config(worker_config):
-            # TODO(b/194857231) Dependency conflict for using TPUs
-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-                tpu='local')
-            tf.config.experimental_connect_to_cluster(resolver)
-            tf.tpu.experimental.initialize_tpu_system(resolver)
-            return tf.distribute.TPUStrategy(resolver)
-        else:
-            # TODO(b/148619319) Saving model currently failing
-            return tf.distribute.MultiWorkerMirroredStrategy()
-    elif chief_config.accelerator_count > 1:
-        return tf.distribute.MirroredStrategy()
+            return 'multi_mirror'
+    elif ('chief_config' in run_kwargs
+          and run_kwargs['chief_config'].accelerator_count > 1):
+        return 'mirror'
     else:
-        return tf.distribute.OneDeviceStrategy(device='/gpu:0')
+        return 'one_device'
+
+
+def save_params(params, file_id):
+    """Pickles the params object using the file_id as prefix."""
+    file_name = f'{file_id}_params'
+    with open(file_name, 'xb') as f:
+        pickle.dump(params, f)
+    return file_name
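
run_experiment_cloud now pickles run_experiment_kwargs, copies models_entry_point.py to a uuid-named script, submits it through run.run, and removes both temporary files afterwards. A hedged usage sketch assembled from the integration test added below; the bucket name and machine configs are placeholders:

# Usage sketch only; argument values mirror the integration test in this
# commit, with 'my-test-bucket' as a placeholder GCS bucket.
import tensorflow_cloud as tfc
from tensorflow_cloud.core.experimental import models
from official.core import task_factory
from official.utils.testing import mock_task

params = mock_task.mock_experiment()
run_experiment_kwargs = dict(
    params=params,
    task=task_factory.get_task(params.task),
    mode='train_and_eval',
    model_dir='gs://my-test-bucket/saved_model',  # placeholder bucket path
)
run_kwargs = dict(
    chief_config=tfc.COMMON_MACHINE_CONFIGS['P100_4X'],  # maps to the 'mirror' strategy
    docker_config=tfc.DockerConfig(image_build_bucket='my-test-bucket'),
)
info = models.run_experiment_cloud(run_experiment_kwargs, run_kwargs)
print(info['job_id'])  # run.run() returns job details, including the AIP job id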

src/python/tensorflow_cloud/core/experimental/models_entry_point.py

Lines changed: 63 additions & 0 deletions

@@ -0,0 +1,63 @@
+# Lint as: python3
+# Copyright 2021 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Entry point file for run_experiment_cloud."""
+
+import os
+import pickle
+
+import tensorflow as tf
+
+from official.core import train_lib
+
+
+def load_params(file_name):
+    with open(file_name, 'rb') as f:
+        params = pickle.load(f)
+    return params
+
+
+def get_tpu_strategy():
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
+        tpu='local')
+    tf.config.experimental_connect_to_cluster(resolver)
+    tf.tpu.experimental.initialize_tpu_system(resolver)
+    return tf.distribute.TPUStrategy(resolver)
+
+
+def get_one_device():
+    return tf.distribute.OneDeviceStrategy(device='/gpu:0')
+
+_DISTRIBUTION_STRATEGIES = dict(
+    # TODO(b/194857231) Dependency conflict for using TPUs
+    tpu=get_tpu_strategy,
+    # TODO(b/148619319) Saving model currently failing for multi_mirror
+    multi_mirror=tf.distribute.MultiWorkerMirroredStrategy,
+    mirror=tf.distribute.MirroredStrategy,
+    one_device=get_one_device)
+
+
+def main():
+    prefix, _ = os.path.splitext(os.path.basename(__file__))
+    run_experiment_kwargs = load_params(f'{prefix}_params')
+    strategy_str = run_experiment_kwargs['distribution_strategy']
+    strategy = _DISTRIBUTION_STRATEGIES[strategy_str]()
+    run_experiment_kwargs.update(dict(
+        distribution_strategy=strategy))
+    model, _ = train_lib.run_experiment(**run_experiment_kwargs)
+    model.save(run_experiment_kwargs['model_dir'])
+
+
+if __name__ == '__main__':
+    main()
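
The wrapper and this entry point agree on file names purely by convention: save_params writes a file named after the uuid, the entry point is copied to a script with the same uuid prefix, and main() recovers that prefix from its own filename. A minimal sketch of the round trip; 'example-uuid' is an illustrative stand-in for the generated uuid4 string:

# Minimal sketch of the params round trip used by run_experiment_cloud
# and models_entry_point.py.
import os
import pickle

file_id = 'example-uuid'
with open(f'{file_id}_params', 'xb') as f:           # what save_params() writes
    pickle.dump({'distribution_strategy': 'one_device'}, f)

entry_point = f'{file_id}.py'                        # name of the copied entry point
prefix, _ = os.path.splitext(os.path.basename(entry_point))
with open(f'{prefix}_params', 'rb') as f:            # what load_params() reads back
    params = pickle.load(f)
assert params['distribution_strategy'] == 'one_device'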

Lines changed: 167 additions & 0 deletions

@@ -0,0 +1,167 @@
+# Lint as: python3
+# Copyright 2021 Google LLC. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Integration tests for calling run_experiment_cloud."""
+
+import os
+import uuid
+
+import tensorflow as tf
+import tensorflow_cloud as tfc
+from tensorflow_cloud.core.experimental import models
+from tensorflow_cloud.utils import google_api_client
+from official.core import task_factory
+from official.utils.testing import mock_task
+
+# The staging bucket to use for cloudbuild as well as save the model and data.
+_TEST_BUCKET = os.environ["TEST_BUCKET"]
+_PROJECT_ID = os.environ["PROJECT_ID"]
+_PARENT_IMAGE = "gcr.io/deeplearning-platform-release/tf2-gpu.2-5"
+_BASE_PATH = f"gs://{_TEST_BUCKET}/{uuid.uuid4()}"
+
+
+class RunExperimentCloudTest(tf.test.TestCase):
+
+    def setUp(self):
+        super(RunExperimentCloudTest, self).setUp()
+        self.test_data_path = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "../testdata/"
+        )
+        self.requirements_txt = os.path.join(self.test_data_path,
+                                             "requirements.txt")
+
+        self._test_config = {
+            "trainer": {
+                "checkpoint_interval": 10,
+                "steps_per_loop": 10,
+                "summary_interval": 10,
+                "train_steps": 10,
+                "validation_steps": 5,
+                "validation_interval": 10,
+                "continuous_eval_timeout": 1,
+                "validation_summary_subdir": "validation",
+                "optimizer_config": {
+                    "optimizer": {
+                        "type": "sgd",
+                    },
+                    "learning_rate": {
+                        "type": "constant"
+                    }
+                }
+            },
+        }
+
+        self.params = mock_task.mock_experiment()
+        self.params.override(self._test_config, is_strict=False)
+        self.run_experiment_kwargs = dict(
+            params=self.params,
+            task=task_factory.get_task(self.params.task),
+            mode="train_and_eval",
+        )
+        self.docker_config = tfc.DockerConfig(
+            parent_image=_PARENT_IMAGE,
+            image_build_bucket=_TEST_BUCKET
+        )
+
+    def tpu_strategy(self):
+        run_kwargs = dict(
+            chief_config=tfc.COMMON_MACHINE_CONFIGS["CPU"],
+            worker_count=1,
+            worker_config=tfc.COMMON_MACHINE_CONFIGS["TPU"],
+            requirements_txt=self.requirements_txt,
+            job_labels={
+                "job": "tpu_strategy",
+                "team": "run_experiment_cloud_tests",
+            },
+            docker_config=self.docker_config,
+        )
+        run_experiment_kwargs = dict(
+            model_dir=os.path.join(_BASE_PATH, "tpu", "saved_model"),
+            **self.run_experiment_kwargs,
+        )
+        return models.run_experiment_cloud(run_experiment_kwargs,
+                                           run_kwargs)
+
+    def multi_mirror_strategy(self):
+        run_kwargs = dict(
+            chief_config=tfc.COMMON_MACHINE_CONFIGS["P100_1X"],
+            worker_count=1,
+            worker_config=tfc.COMMON_MACHINE_CONFIGS["P100_1X"],
+            requirements_txt=self.requirements_txt,
+            job_labels={
+                "job": "multi_mirror_strategy",
+                "team": "run_experiment_cloud_tests",
+            },
+            docker_config=self.docker_config,
+        )
+        run_experiment_kwargs = dict(
+            model_dir=os.path.join(_BASE_PATH, "multi_mirror", "saved_model"),
+            **self.run_experiment_kwargs,
+        )
+        return models.run_experiment_cloud(run_experiment_kwargs,
+                                           run_kwargs)
+
+    def mirror_strategy(self):
+        run_kwargs = dict(
+            chief_config=tfc.COMMON_MACHINE_CONFIGS["P100_4X"],
+            requirements_txt=self.requirements_txt,
+            job_labels={
+                "job": "mirror",
+                "team": "run_experiment_cloud_tests",
+            },
+            docker_config=self.docker_config,
+        )
+        run_experiment_kwargs = dict(
+            model_dir=os.path.join(_BASE_PATH, "mirror", "saved_model"),
+            **self.run_experiment_kwargs,
+        )
+        return models.run_experiment_cloud(run_experiment_kwargs,
+                                           run_kwargs)
+
+    def one_device_strategy(self):
+        run_kwargs = dict(
+            requirements_txt=self.requirements_txt,
+            job_labels={
+                "job": "one_device",
+                "team": "run_experiment_cloud_tests",
+            },
+            docker_config=self.docker_config,
+        )
+        run_experiment_kwargs = dict(
+            model_dir=os.path.join(_BASE_PATH, "one_device", "saved_model"),
+            **self.run_experiment_kwargs,
+        )
+        # Using the default T4 GPU for this test.
+        return models.run_experiment_cloud(run_experiment_kwargs,
+                                           run_kwargs)
+
+    def test_run_experiment_cloud(self):
+        track_status = {
+            "one_device_strategy": self.one_device_strategy(),
+            "mirror_strategy": self.mirror_strategy(),
+            # TODO(b/148619319) Enable when bug is solved
+            # "multi_mirror_strategy": self.multi_mirror_strategy(),
+            # TODO(b/194857231) Enable when bug is solved
+            # "tpu_strategy": self.tpu_strategy(),
+        }
+
+        for test_name, ret_val in track_status.items():
+            self.assertTrue(
+                google_api_client.wait_for_aip_training_job_completion(
+                    ret_val["job_id"], _PROJECT_ID),
+                "Job {} generated from the test: {} has failed".format(
+                    ret_val["job_id"], test_name))
+
+if __name__ == "__main__":
+    tf.test.main()
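
The test reads its GCP configuration from environment variables at import time. A sketch of the expected environment; the bucket and project names below are placeholders:

# Sketch only: the integration test expects these variables before import.
# 'my-test-bucket' and 'my-gcp-project' are placeholder values.
import os

os.environ['TEST_BUCKET'] = 'my-test-bucket'   # GCS bucket for cloudbuild staging, model and data
os.environ['PROJECT_ID'] = 'my-gcp-project'    # GCP project that runs the AIP training jobs
# With the environment set, the test file can be executed directly,
# since it ends with tf.test.main().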

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+git+https://github.com/tensorflow/cloud.git@refs/pull/360/head#egg=tensorflow-cloud&subdirectory=src/python
+tf-models-official
