Skip to content

Commit

Permalink
Clarify multi_worker_with_keras.ipynb example code (#1)
Browse files Browse the repository at this point in the history
Suggesting that the "mnist.py" file be renamed to something else like "mnist_setup.py" since the resulting "import mnist" is confusingly in conflict with the mnist package. This threw me off until I re-read the instructions, since I had placed the code in one file which resulted in errors executing mnist.mnist_dataset and mnist.build_and_compile_cnn_model.
  • Loading branch information
Obliman authored Jan 25, 2022
1 parent 83f5d3c commit 2d9aa29
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions site/en/tutorials/distribute/multi_worker_with_keras.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
"id": "fLW6D2TzvC-4"
},
"source": [
"Next, create an `mnist.py` file with a simple model and dataset setup. This Python file will be used by the worker-processes in this tutorial:"
"Next, create an `mnist_setup.py` file with a simple model and dataset setup. This Python file will be used by the worker-processes in this tutorial:"
]
},
{
Expand All @@ -205,7 +205,7 @@
},
"outputs": [],
"source": [
"%%writefile mnist.py\n",
"%%writefile mnist_setup.py\n",
"\n",
"import os\n",
"import tensorflow as tf\n",
Expand Down Expand Up @@ -256,11 +256,11 @@
},
"outputs": [],
"source": [
"import mnist\n",
"import mnist_setup\n",
"\n",
"batch_size = 64\n",
"single_worker_dataset = mnist.mnist_dataset(batch_size)\n",
"single_worker_model = mnist.build_and_compile_cnn_model()\n",
"single_worker_dataset = mnist_setup.mnist_dataset(batch_size)\n",
"single_worker_model = mnist_setup.build_and_compile_cnn_model()\n",
"single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)"
]
},
Expand Down Expand Up @@ -492,7 +492,7 @@
"source": [
"with strategy.scope():\n",
" # Model building/compiling need to be within `strategy.scope()`.\n",
" multi_worker_model = mnist.build_and_compile_cnn_model()"
" multi_worker_model = mnist_setup.build_and_compile_cnn_model()"
]
},
{
Expand All @@ -512,7 +512,7 @@
"source": [
"To actually run with `MultiWorkerMirroredStrategy` you'll need to run worker processes and pass a `TF_CONFIG` to them.\n",
"\n",
"Like the `mnist.py` file written earlier, here is the `main.py` that each of the workers will run:"
"Like the `mnist_setup.py` file written earlier, here is the `main.py` that each of the workers will run:"
]
},
{
Expand All @@ -529,7 +529,7 @@
"import json\n",
"\n",
"import tensorflow as tf\n",
"import mnist\n",
"import mnist_setup\n",
"\n",
"per_worker_batch_size = 64\n",
"tf_config = json.loads(os.environ['TF_CONFIG'])\n",
Expand All @@ -538,11 +538,11 @@
"strategy = tf.distribute.MultiWorkerMirroredStrategy()\n",
"\n",
"global_batch_size = per_worker_batch_size * num_workers\n",
"multi_worker_dataset = mnist.mnist_dataset(global_batch_size)\n",
"multi_worker_dataset = mnist_setup.mnist_dataset(global_batch_size)\n",
"\n",
"with strategy.scope():\n",
" # Model building/compiling need to be within `strategy.scope()`.\n",
" multi_worker_model = mnist.build_and_compile_cnn_model()\n",
" multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n",
"\n",
"\n",
"multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)"
Expand Down Expand Up @@ -820,7 +820,7 @@
"options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF\n",
"\n",
"global_batch_size = 64\n",
"multi_worker_dataset = mnist.mnist_dataset(batch_size=64)\n",
"multi_worker_dataset = mnist_setup.mnist_dataset(batch_size=64)\n",
"dataset_no_auto_shard = multi_worker_dataset.with_options(options)"
]
},
Expand Down Expand Up @@ -1146,7 +1146,7 @@
"\n",
"callbacks = [tf.keras.callbacks.BackupAndRestore(backup_dir='/tmp/backup')]\n",
"with strategy.scope():\n",
" multi_worker_model = mnist.build_and_compile_cnn_model()\n",
" multi_worker_model = mnist_setup.build_and_compile_cnn_model()\n",
"multi_worker_model.fit(multi_worker_dataset,\n",
" epochs=3,\n",
" steps_per_epoch=70,\n",
Expand Down

0 comments on commit 2d9aa29

Please sign in to comment.