From 5943a8764f83a80509f164c0b97478bc3d37b9bb Mon Sep 17 00:00:00 2001 From: Wei Wei <5577741+windmaple@users.noreply.github.com> Date: Thu, 16 Oct 2025 12:00:59 +0800 Subject: [PATCH 1/2] Dynamically set bs and mesh in miniGPT tutorial based on TPU platform --- docs/source/JAX_for_LLM_pretraining.ipynb | 36 +++++++++++++++++++---- docs/source/JAX_for_LLM_pretraining.md | 35 +++++++++++++++++----- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/docs/source/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb index f87ff51..ea260d1 100644 --- a/docs/source/JAX_for_LLM_pretraining.ipynb +++ b/docs/source/JAX_for_LLM_pretraining.ipynb @@ -18,6 +18,9 @@ " \n", " Run in Google Colab\n", " \n", + " \n", + " View source on GitHub\n", + " \n", "" ] }, @@ -95,11 +98,27 @@ "id": "Rcji_799n4eA" }, "source": [ - "**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.\n", + "**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs.\n", "\n", "Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "if os.path.exists('/content/'):\n", + " platform = \"Colab\"\n", + "elif os.path.exists('/kaggle/'):\n", + " platform = \"Kaggle\"\n", + "else:\n", + " # Assume using Cloud TPU otherwise\n", + " platform = \"GCP\"" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -242,7 +261,7 @@ "- `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..\n", "- `axis_names`, where:\n", " - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and\n", - " - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism, mapping to the TPU v2 cores.\n", + " - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism\n", "\n", "This matches the structure in the Kaggle TPU v5e setup.\n", "\n", @@ -266,7 +285,8 @@ "# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))\n", "\n", "### For free-tier Colab TPU, which only has a single TPU core\n", - "# mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))" + "if platform == 'Colab':\n", + " mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))" ] }, { @@ -467,7 +487,7 @@ " # and obtain logits for each token in the vocabulary (for next token prediction).\n", " outputs = self.output_layer(x)\n", " return outputs\n", - " \n", + "\n", " @nnx.jit\n", " def sample_from(self, logits):\n", " logits, indices = jax.lax.top_k(logits, k=top_k)\n", @@ -523,7 +543,9 @@ "embed_dim = 256\n", "num_heads = 8\n", "feed_forward_dim = 256\n", - "batch_size = 192 # You can set a bigger batch size if you use Kaggle's TPU v5e-8\n", + "batch_size = 192*8\n", + "if platform == \"Colab\":\n", + " batch_size = 192\n", "num_epochs = 1\n", "top_k = 10" ] @@ -1138,7 +1160,9 @@ "id": "3813cbf2", "metadata": {}, "source": [ - "## Profiling for hyperparameter tuning" + "## Profiling for hyperparameter tuning\n", + "\n", + "**Note:** this section assume multiple TPU cores. Free-tier Colab TPU v5e-1 cannot run here." ] }, { diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md index 8130590..bb1126a 100644 --- a/docs/source/JAX_for_LLM_pretraining.md +++ b/docs/source/JAX_for_LLM_pretraining.md @@ -22,6 +22,9 @@ kernelspec: Run in Google Colab + + View source on GitHub + +++ {"id": "NIOXoY1xgiww"} @@ -56,10 +59,21 @@ outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068 +++ {"id": "Rcji_799n4eA"} -**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator. +**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs. Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices. +```{code-cell} +import os +if os.path.exists('/content/'): + platform = "Colab" +elif os.path.exists('/kaggle/'): + platform = "Kaggle" +else: + # Assume using Cloud TPU otherwise + platform = "GCP" +``` + ```{code-cell} --- colab: @@ -132,7 +146,7 @@ Our `Mesh` will have two arguments: - `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices)).. - `axis_names`, where: - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and - - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism, mapping to the TPU v2 cores. + - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism This matches the structure in the Kaggle TPU v5e setup. @@ -150,7 +164,8 @@ mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model')) # mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model')) ### For free-tier Colab TPU, which only has a single TPU core -# mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model")) +if platform == 'Colab': + mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model")) ``` +++ {"id": "_ZKdhNo98NgG"} @@ -331,7 +346,7 @@ class MiniGPT(nnx.Module): # and obtain logits for each token in the vocabulary (for next token prediction). outputs = self.output_layer(x) return outputs - + @nnx.jit def sample_from(self, logits): logits, indices = jax.lax.top_k(logits, k=top_k) @@ -377,7 +392,9 @@ maxlen = 256 embed_dim = 256 num_heads = 8 feed_forward_dim = 256 -batch_size = 192 # You can set a bigger batch size if you use Kaggle's TPU v5e-8 +batch_size = 192*8 +if platform == "Colab": + batch_size = 192 num_epochs = 1 top_k = 10 ``` @@ -482,7 +499,7 @@ rng = jax.random.PRNGKey(0) start_prompt = "Once upon a time" start_tokens = tokenizer.encode(start_prompt)[:maxlen] -print(f"Initial generated text:") +print("Initial generated text:") generated_text = model.generate_text(maxlen, start_tokens) metrics_history = { @@ -521,13 +538,13 @@ for epoch in range(num_epochs): ) start_time = time.time() - print(f"Generated text:") + print("Generated text:") generated_text = model.generate_text(maxlen, start_tokens) step += 1 # Final text generation -print(f"Final generated text:") +print("Final generated text:") generated_text = model.generate_text(maxlen, start_tokens) ``` @@ -581,6 +598,8 @@ checkpointer.save('/content/save', state) ## Profiling for hyperparameter tuning +**Note:** this section assume multiple TPU cores. Free-tier Colab TPU v5e-1 cannot run here. + ```{code-cell} !pip install -Uq tensorboard-plugin-profile tensorflow tensorboard ``` From eed5ee85a12c5ef0640aaf16bc7a5efeb274f770 Mon Sep 17 00:00:00 2001 From: Wei Wei <5577741+windmaple@users.noreply.github.com> Date: Fri, 17 Oct 2025 12:36:53 +0800 Subject: [PATCH 2/2] Update miniGPT notebook: 1. use jax.device_count() to determine mesh and bs 2. disable new variable eager sharding in Flax 0.12 3. fix typo --- docs/source/JAX_for_LLM_pretraining.ipynb | 38 +++++++++++------------ docs/source/JAX_for_LLM_pretraining.md | 28 ++++++++--------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/docs/source/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb index ea260d1..32304b1 100644 --- a/docs/source/JAX_for_LLM_pretraining.ipynb +++ b/docs/source/JAX_for_LLM_pretraining.ipynb @@ -103,22 +103,6 @@ "Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices." ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "if os.path.exists('/content/'):\n", - " platform = \"Colab\"\n", - "elif os.path.exists('/kaggle/'):\n", - " platform = \"Kaggle\"\n", - "else:\n", - " # Assume using Cloud TPU otherwise\n", - " platform = \"GCP\"" - ] - }, { "cell_type": "code", "execution_count": 2, @@ -233,6 +217,20 @@ "import time" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import flax\n", + "from pkg_resources import parse_version\n", + "\n", + "if parse_version(flax.__version__) >= parse_version(\"0.12\"):\n", + " flax.config.update('flax_always_shard_variable', False)\n", + " print('Disabling Flax variable eager sharding for backward compatibility...')" + ] + }, { "cell_type": "markdown", "metadata": { @@ -261,7 +259,7 @@ "- `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..\n", "- `axis_names`, where:\n", " - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and\n", - " - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism\n", + " - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallism\n", "\n", "This matches the structure in the Kaggle TPU v5e setup.\n", "\n", @@ -285,7 +283,7 @@ "# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))\n", "\n", "### For free-tier Colab TPU, which only has a single TPU core\n", - "if platform == 'Colab':\n", + "if jax.device_count() == 1:\n", " mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))" ] }, @@ -543,8 +541,8 @@ "embed_dim = 256\n", "num_heads = 8\n", "feed_forward_dim = 256\n", - "batch_size = 192*8\n", - "if platform == \"Colab\":\n", + "batch_size = 192 * jax.device_count() / 2 # divide by 2 in case of model parallelism\n", + "if jax.device_count() == 1:\n", " batch_size = 192\n", "num_epochs = 1\n", "top_k = 10" diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md index bb1126a..741e2ee 100644 --- a/docs/source/JAX_for_LLM_pretraining.md +++ b/docs/source/JAX_for_LLM_pretraining.md @@ -63,17 +63,6 @@ outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068 Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices. -```{code-cell} -import os -if os.path.exists('/content/'): - platform = "Colab" -elif os.path.exists('/kaggle/'): - platform = "Kaggle" -else: - # Assume using Cloud TPU otherwise - platform = "GCP" -``` - ```{code-cell} --- colab: @@ -122,6 +111,15 @@ import tiktoken import time ``` +```{code-cell} +import flax +from pkg_resources import parse_version + +if parse_version(flax.__version__) >= parse_version("0.12"): + flax.config.update('flax_always_shard_variable', False) + print('Disabling Flax variable eager sharding for backward compatibility...') +``` + +++ {"id": "rPyt7MV6prz1"} ## Define the miniGPT model with Flax and JAX automatic parallelism @@ -146,7 +144,7 @@ Our `Mesh` will have two arguments: - `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices)).. - `axis_names`, where: - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and - - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism + - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallism This matches the structure in the Kaggle TPU v5e setup. @@ -164,7 +162,7 @@ mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model')) # mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model')) ### For free-tier Colab TPU, which only has a single TPU core -if platform == 'Colab': +if jax.device_count() == 1: mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model")) ``` @@ -392,8 +390,8 @@ maxlen = 256 embed_dim = 256 num_heads = 8 feed_forward_dim = 256 -batch_size = 192*8 -if platform == "Colab": +batch_size = 192 * jax.device_count() / 2 # divide by 2 in case of model parallelism +if jax.device_count() == 1: batch_size = 192 num_epochs = 1 top_k = 10