diff --git a/docs/source/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb
index f87ff51..32304b1 100644
--- a/docs/source/JAX_for_LLM_pretraining.ipynb
+++ b/docs/source/JAX_for_LLM_pretraining.ipynb
@@ -18,6 +18,9 @@
     " \n",
     " Run in Google Colab\n",
     " \n",
+    " \n",
+    " View source on GitHub\n",
+    " \n",
     ""
    ]
   },
@@ -95,7 +98,7 @@
     "id": "Rcji_799n4eA"
    },
    "source": [
-    "**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.\n",
+    "**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs.\n",
     "\n",
     "Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices."
    ]
@@ -214,6 +217,20 @@
     "import time"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import flax\n",
+    "from packaging.version import Version\n",
+    "\n",
+    "if Version(flax.__version__) >= Version(\"0.12\"):\n",
+    "    flax.config.update('flax_always_shard_variable', False)\n",
+    "    print('Disabling Flax variable eager sharding for backward compatibility...')"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
@@ -242,7 +259,7 @@
     "- `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..\n",
     "- `axis_names`, where:\n",
     "  - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and\n",
-    "  - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism, mapping to the TPU v2 cores.\n",
+    "  - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism.\n",
     "\n",
     "This matches the structure in the Kaggle TPU v5e setup.\n",
     "\n",
@@ -266,7 +283,8 @@
     "# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))\n",
     "\n",
     "### For free-tier Colab TPU, which only has a single TPU core\n",
-    "# mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))"
+    "if jax.device_count() == 1:\n",
+    "    mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))"
    ]
   },
  {
@@ -467,7 +485,7 @@
     "        # and obtain logits for each token in the vocabulary (for next token prediction).\n",
     "        outputs = self.output_layer(x)\n",
     "        return outputs\n",
-    "    \n",
+    "\n",
     "    @nnx.jit\n",
     "    def sample_from(self, logits):\n",
     "        logits, indices = jax.lax.top_k(logits, k=top_k)\n",
@@ -523,7 +541,9 @@
     "embed_dim = 256\n",
     "num_heads = 8\n",
     "feed_forward_dim = 256\n",
-    "batch_size = 192 # You can set a bigger batch size if you use Kaggle's TPU v5e-8\n",
+    "batch_size = 192 * jax.device_count() // 2  # divide by 2 to account for 2-way model parallelism\n",
+    "if jax.device_count() == 1:\n",
+    "    batch_size = 192\n",
     "num_epochs = 1\n",
     "top_k = 10"
@@ -1138,7 +1158,9 @@
    "id": "3813cbf2",
    "metadata": {},
    "source": [
-    "## Profiling for hyperparameter tuning"
+    "## Profiling for hyperparameter tuning\n",
+    "\n",
+    "**Note:** This section assumes multiple TPU cores; it cannot be run on the free-tier Colab TPU v5e-1."
    ]
   },
diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md
index 8130590..741e2ee 100644
--- a/docs/source/JAX_for_LLM_pretraining.md
+++ b/docs/source/JAX_for_LLM_pretraining.md
@@ -22,6 +22,9 @@ kernelspec:
   Run in Google Colab
 
+
+  View source on GitHub
+
 
 +++ {"id": "NIOXoY1xgiww"}
@@ -56,7 +59,7 @@ outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068
 
 +++ {"id": "Rcji_799n4eA"}
 
-**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.
+**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs.
 
 Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices.
@@ -108,6 +111,15 @@ import tiktoken
 import time
 ```
 
+```{code-cell}
+import flax
+from packaging.version import Version
+
+if Version(flax.__version__) >= Version("0.12"):
+    flax.config.update('flax_always_shard_variable', False)
+    print('Disabling Flax variable eager sharding for backward compatibility...')
+```
+
 +++ {"id": "rPyt7MV6prz1"}
 
 ## Define the miniGPT model with Flax and JAX automatic parallelism
@@ -132,7 +144,7 @@ Our `Mesh` will have two arguments:
 - `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..
 - `axis_names`, where:
   - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and
-  - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism, mapping to the TPU v2 cores.
+  - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism.
 
 This matches the structure in the Kaggle TPU v5e setup.
 
@@ -150,7 +162,8 @@ mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
 # mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
 
 ### For free-tier Colab TPU, which only has a single TPU core
-# mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model"))
+if jax.device_count() == 1:
+    mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model"))
 ```
 
 +++ {"id": "_ZKdhNo98NgG"}
@@ -331,7 +344,7 @@ class MiniGPT(nnx.Module):
         # and obtain logits for each token in the vocabulary (for next token prediction).
         outputs = self.output_layer(x)
         return outputs
-    
+
     @nnx.jit
     def sample_from(self, logits):
         logits, indices = jax.lax.top_k(logits, k=top_k)
@@ -377,7 +390,9 @@ maxlen = 256
 embed_dim = 256
 num_heads = 8
 feed_forward_dim = 256
-batch_size = 192 # You can set a bigger batch size if you use Kaggle's TPU v5e-8
+batch_size = 192 * jax.device_count() // 2  # divide by 2 to account for 2-way model parallelism
+if jax.device_count() == 1:
+    batch_size = 192
 num_epochs = 1
 top_k = 10
 ```
@@ -482,7 +497,7 @@ rng = jax.random.PRNGKey(0)
 start_prompt = "Once upon a time"
 start_tokens = tokenizer.encode(start_prompt)[:maxlen]
 
-print(f"Initial generated text:")
+print("Initial generated text:")
 generated_text = model.generate_text(maxlen, start_tokens)
 
 metrics_history = {
@@ -521,13 +536,13 @@ for epoch in range(num_epochs):
         )
         start_time = time.time()
 
-        print(f"Generated text:")
+        print("Generated text:")
         generated_text = model.generate_text(maxlen, start_tokens)
 
     step += 1
 
 # Final text generation
-print(f"Final generated text:")
+print("Final generated text:")
 generated_text = model.generate_text(maxlen, start_tokens)
 ```
@@ -581,6 +596,8 @@ checkpointer.save('/content/save', state)
 
 ## Profiling for hyperparameter tuning
 
+**Note:** This section assumes multiple TPU cores; it cannot be run on the free-tier Colab TPU v5e-1.
+
 ```{code-cell}
 !pip install -Uq tensorboard-plugin-profile tensorflow tensorboard
 ```
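
Note that both files keep the unconditional `mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))` line that runs before the new single-device branch, so on a one-device runtime `create_device_mesh((4, 2))` raises before `if jax.device_count() == 1:` is ever reached. Below is a minimal sketch of picking the mesh shape before building any mesh; the helper name `make_mesh` and the odd-device fallback are illustrative assumptions, not part of this change:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

def make_mesh() -> Mesh:
    # Choose a ('batch', 'model') mesh shape that matches however many
    # devices are actually present, so create_device_mesh never requests
    # more devices than exist.
    n = jax.device_count()
    if n == 1:
        shape = (1, 1)       # free-tier Colab TPU v5e-1
    elif n % 2 == 0:
        shape = (n // 2, 2)  # e.g. Kaggle TPU v5e-8 -> (4, 2): 4-way data, 2-way tensor parallelism
    else:
        shape = (n, 1)       # odd device counts: fall back to pure data parallelism
    return Mesh(mesh_utils.create_device_mesh(shape), ("batch", "model"))

mesh = make_mesh()
```

Computing the shape from `jax.device_count()` first would also give the `batch_size = 192 * jax.device_count() // 2` override a natural home, since the same device count drives both decisions.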