Skip to content

Commit ef005d7

Browse files
juanuribe28Tensorflow Cloud maintainers
authored and
Tensorflow Cloud maintainers
committed
Remove models_entry_point dependency on tensorflow_cloud.
PiperOrigin-RevId: 389661558
1 parent b6de193 commit ef005d7

File tree

5 files changed

+78
-67
lines changed

5 files changed

+78
-67
lines changed

src/python/tensorflow_cloud/core/experimental/constants.py

-17
This file was deleted.

src/python/tensorflow_cloud/core/experimental/models.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616

1717
import os
1818
import pickle
19-
import shutil
2019
from typing import Any, Dict, Optional
2120
import uuid
2221

23-
from . import constants
2422
from .. import machine_config
2523
from .. import run
2624
import tensorflow as tf
@@ -37,6 +35,10 @@
3735
import importlib_resources as pkg_resources
3836
# pylint: enable=g-import-not-at-top
3937

38+
_PARAMS_FILE_NAME_FORMAT = '{}_params'
39+
_ENTRY_POINT_FORMAT = '{}.py'
40+
_ENTRY_POINT_TEMPLATE = 'models_entry_point.py'
41+
4042

4143
def run_models(dataset_name: str,
4244
model_name: str,
@@ -267,18 +269,36 @@ def run_experiment_cloud(run_experiment_kwargs: Dict[str, Any],
267269
dict(distribution_strategy=distribution_strategy))
268270
file_id = str(uuid.uuid4())
269271
params_file = save_params(run_experiment_kwargs, file_id)
272+
entry_point = copy_entry_point(file_id, params_file)
270273

271-
with pkg_resources.path(__package__, 'models_entry_point.py') as path:
272-
entry_point = f'{file_id}.py'
273-
shutil.copyfile(str(path), entry_point)
274-
run_kwargs.update(dict(entry_point=entry_point,
275-
distribution_strategy=None))
276-
info = run.run(**run_kwargs)
274+
run_kwargs.update(dict(entry_point=entry_point,
275+
distribution_strategy=None))
276+
info = run.run(**run_kwargs)
277277
os.remove(entry_point)
278278
os.remove(params_file)
279279
return info
280280

281281

282+
def copy_entry_point(file_id, params_file):
283+
"""Copy models_entry_point and add params file name."""
284+
lines = get_original_lines()
285+
entry_point = _ENTRY_POINT_FORMAT.format(file_id)
286+
with open(entry_point, 'w') as entry_file:
287+
for line in lines:
288+
if line.startswith('PARAMS_FILE_NAME = '):
289+
entry_file.write(f"PARAMS_FILE_NAME = '{params_file}'\n")
290+
else:
291+
entry_file.write(line)
292+
return entry_point
293+
294+
295+
def get_original_lines():
296+
"""Gets the file lines of models_entry_point.py as a list of strings."""
297+
with pkg_resources.open_text(__package__, _ENTRY_POINT_TEMPLATE) as file:
298+
lines = file.readlines()
299+
return lines
300+
301+
282302
def get_distribution_strategy_str(run_kwargs):
283303
"""Gets the name of a distribution strategy based on cloud run config."""
284304
if ('worker_count' in run_kwargs
@@ -297,7 +317,7 @@ def get_distribution_strategy_str(run_kwargs):
297317

298318
def save_params(params, file_id):
299319
"""Pickles the params object using the file_id as prefix."""
300-
file_name = constants.PARAMS_FILE_NAME_FORMAT.format(file_id)
320+
file_name = _PARAMS_FILE_NAME_FORMAT.format(file_id)
301321
with open(file_name, 'xb') as f:
302322
pickle.dump(params, f)
303323
return file_name

src/python/tensorflow_cloud/core/experimental/models_entry_point.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,20 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
"""Entry point file for run_experiment_cloud."""
15+
"""Entry point template file for run_experiment_cloud."""
1616

17-
import os
1817
import pickle
1918

2019
import tensorflow as tf
2120

22-
from tensorflow_cloud.core.experimental import constants
2321
from official.core import train_lib
2422

23+
# PARAMS_FILE_NAME provides the name of the file that cointains the
24+
# run_experiment_kwargs used to call run_experiment. In models.py, when copying
25+
# this file, the value of PARAMS_FILE_NAME is updated to contain the actual name
26+
# of the file.
27+
PARAMS_FILE_NAME = 'file_name'
28+
2529

2630
def load_params(file_name):
2731
with open(file_name, 'rb') as f:
@@ -50,9 +54,7 @@ def get_one_device():
5054

5155

5256
def main():
53-
prefix, _ = os.path.splitext(os.path.basename(__file__))
54-
file_name = constants.PARAMS_FILE_NAME_FORMAT.format(prefix)
55-
run_experiment_kwargs = load_params(file_name)
57+
run_experiment_kwargs = load_params(PARAMS_FILE_NAME)
5658
strategy_str = run_experiment_kwargs['distribution_strategy']
5759
strategy = _DISTRIBUTION_STRATEGIES[strategy_str]()
5860
run_experiment_kwargs.update(dict(

src/python/tensorflow_cloud/core/experimental/tests/unit/models_entry_point_test.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import mock
1919
import tensorflow as tf
2020

21-
from tensorflow_cloud.core.experimental import constants
2221
from tensorflow_cloud.core.experimental import models_entry_point
2322
from official.core import base_task
2423
from official.core import config_definitions
@@ -64,9 +63,8 @@ def tearDown(self):
6463

6564
def test_main(self):
6665
models_entry_point.main()
67-
file_name = constants.PARAMS_FILE_NAME_FORMAT.format(
68-
'models_entry_point')
69-
self.load_params.assert_called_with(file_name)
66+
67+
self.load_params.assert_called_with(models_entry_point.PARAMS_FILE_NAME)
7068
self.run_experiment_kwargs.update(dict(
7169
distribution_strategy='one_device_strategy'))
7270
self.run_experiment.assert_called_with(**self.run_experiment_kwargs)

src/python/tensorflow_cloud/core/experimental/tests/unit/models_test.py

+39-31
Original file line numberDiff line numberDiff line change
@@ -15,27 +15,16 @@
1515
"""Tests for the models experimental module."""
1616

1717
import os
18-
import pathlib
19-
import shutil
2018
import uuid
2119
from absl.testing import absltest
2220
import mock
2321

2422
import tensorflow as tf
2523
from tensorflow_cloud.core import machine_config
2624
from tensorflow_cloud.core import run
27-
from tensorflow_cloud.core.experimental import constants
2825
from tensorflow_cloud.core.experimental import models
2926
from official.vision.image_classification.efficientnet import efficientnet_model
3027

31-
# pylint: disable=g-import-not-at-top
32-
try:
33-
import importlib.resources as pkg_resources
34-
except ImportError:
35-
# Backported for python<3.7
36-
import importlib_resources as pkg_resources
37-
# pylint: enable=g-import-not-at-top
38-
3928

4029
class ModelsTest(absltest.TestCase):
4130

@@ -82,11 +71,9 @@ def setup_run_models(self):
8271
autospec=True,
8372
).start()
8473

85-
def setup_run_experiment_cloud(self, path):
86-
self.file_id = 'test'
87-
88-
self.params_file = constants.PARAMS_FILE_NAME_FORMAT.format(
89-
self.file_id)
74+
def setup_run_experiment_cloud(self, file_id):
75+
self.params_file = models._PARAMS_FILE_NAME_FORMAT.format(
76+
file_id)
9077

9178
self.save_params = mock.patch.object(
9279
models,
@@ -95,11 +82,11 @@ def setup_run_experiment_cloud(self, path):
9582
return_value=self.params_file,
9683
).start()
9784

98-
self.path = mock.patch.object(
99-
pkg_resources,
100-
'path',
85+
self.copy_entry_point = mock.patch.object(
86+
models,
87+
'copy_entry_point',
10188
autospec=True,
102-
return_value=pathlib.Path(path),
89+
return_value=models._ENTRY_POINT_FORMAT.format(file_id),
10390
).start()
10491

10592
self.remove = mock.patch.object(
@@ -111,13 +98,7 @@ def setup_run_experiment_cloud(self, path):
11198
self.uuid4 = mock.patch.object(
11299
uuid,
113100
'uuid4',
114-
return_value=self.file_id,
115-
).start()
116-
117-
self.copyfile = mock.patch.object(
118-
shutil,
119-
'copyfile',
120-
autospec=True,
101+
return_value=file_id,
121102
).start()
122103

123104
def tearDown(self):
@@ -201,18 +182,45 @@ def test_run_models_remote(self):
201182

202183
def test_run_experiment_cloud(self):
203184
self.setup_run(remote=False)
204-
path_str = '/test'
205-
self.setup_run_experiment_cloud(path_str)
185+
file_id = 'test_id'
186+
self.setup_run_experiment_cloud(file_id)
206187
run_experiment_kwargs = dict()
207188
models.run_experiment_cloud(
208189
run_experiment_kwargs=run_experiment_kwargs)
209-
entry_point = f'{self.file_id}.py'
210-
self.copyfile.assert_called_with(path_str, entry_point)
190+
entry_point = models._ENTRY_POINT_FORMAT.format(file_id)
211191
self.run.assert_called_with(entry_point=entry_point,
212192
distribution_strategy=None)
213193
self.remove.assert_any_call(entry_point)
214194
self.remove.assert_any_call(self.params_file)
215195

196+
def setup_copy_entry_point(self):
197+
self.get_original_lines = mock.patch.object(
198+
models,
199+
'get_original_lines',
200+
autospec=True,
201+
return_value=['PARAMS_FILE_NAME = not this', 'do not change'],
202+
).start()
203+
204+
self.open = mock.mock_open()
205+
mock.patch(
206+
'builtins.open',
207+
self.open,
208+
).start()
209+
210+
def test_copy_entry_point(self):
211+
self.setup_copy_entry_point()
212+
file_id = 'file_id'
213+
params_file = 'params_file'
214+
models.copy_entry_point(file_id, params_file)
215+
216+
self.open.assert_called_once_with(
217+
models._ENTRY_POINT_FORMAT.format(file_id),
218+
'w')
219+
entry_file = self.open()
220+
entry_file.write.assert_any_call(
221+
f"PARAMS_FILE_NAME = '{params_file}'\n")
222+
entry_file.write.assert_any_call('do not change')
223+
216224
def test_get_distribution_strategy_tpu(self):
217225
run_kwargs = dict(
218226
worker_count=1,

0 commit comments

Comments
 (0)