From 5943a8764f83a80509f164c0b97478bc3d37b9bb Mon Sep 17 00:00:00 2001
From: Wei Wei <5577741+windmaple@users.noreply.github.com>
Date: Thu, 16 Oct 2025 12:00:59 +0800
Subject: [PATCH 1/2] Dynamically set bs and mesh in miniGPT tutorial based on
TPU platform
---
docs/source/JAX_for_LLM_pretraining.ipynb | 36 +++++++++++++++++++----
docs/source/JAX_for_LLM_pretraining.md | 35 +++++++++++++++++-----
2 files changed, 57 insertions(+), 14 deletions(-)
diff --git a/docs/source/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb
index f87ff51..ea260d1 100644
--- a/docs/source/JAX_for_LLM_pretraining.ipynb
+++ b/docs/source/JAX_for_LLM_pretraining.ipynb
@@ -18,6 +18,9 @@
"
\n",
" Run in Google Colab\n",
" | \n",
+ " \n",
+ " View source on GitHub\n",
+ " | \n",
""
]
},
@@ -95,11 +98,27 @@
"id": "Rcji_799n4eA"
},
"source": [
- "**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.\n",
+ "**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs.\n",
"\n",
"Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices."
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "if os.path.exists('/content/'):\n",
+ " platform = \"Colab\"\n",
+ "elif os.path.exists('/kaggle/'):\n",
+ " platform = \"Kaggle\"\n",
+ "else:\n",
+ " # Assume using Cloud TPU otherwise\n",
+ " platform = \"GCP\""
+ ]
+ },
{
"cell_type": "code",
"execution_count": 2,
@@ -242,7 +261,7 @@
"- `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..\n",
"- `axis_names`, where:\n",
" - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and\n",
- " - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism, mapping to the TPU v2 cores.\n",
+ " - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism\n",
"\n",
"This matches the structure in the Kaggle TPU v5e setup.\n",
"\n",
@@ -266,7 +285,8 @@
"# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))\n",
"\n",
"### For free-tier Colab TPU, which only has a single TPU core\n",
- "# mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))"
+ "if platform == 'Colab':\n",
+ " mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))"
]
},
{
@@ -467,7 +487,7 @@
" # and obtain logits for each token in the vocabulary (for next token prediction).\n",
" outputs = self.output_layer(x)\n",
" return outputs\n",
- " \n",
+ "\n",
" @nnx.jit\n",
" def sample_from(self, logits):\n",
" logits, indices = jax.lax.top_k(logits, k=top_k)\n",
@@ -523,7 +543,9 @@
"embed_dim = 256\n",
"num_heads = 8\n",
"feed_forward_dim = 256\n",
- "batch_size = 192 # You can set a bigger batch size if you use Kaggle's TPU v5e-8\n",
+ "batch_size = 192*8\n",
+ "if platform == \"Colab\":\n",
+ " batch_size = 192\n",
"num_epochs = 1\n",
"top_k = 10"
]
@@ -1138,7 +1160,9 @@
"id": "3813cbf2",
"metadata": {},
"source": [
- "## Profiling for hyperparameter tuning"
+ "## Profiling for hyperparameter tuning\n",
+ "\n",
+ "**Note:** This section assumes multiple TPU cores, so it cannot be run on the free-tier Colab TPU v5e-1."
]
},
{
diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md
index 8130590..bb1126a 100644
--- a/docs/source/JAX_for_LLM_pretraining.md
+++ b/docs/source/JAX_for_LLM_pretraining.md
@@ -22,6 +22,9 @@ kernelspec:
Run in Google Colab
|
+
+ View source on GitHub
+ |
+++ {"id": "NIOXoY1xgiww"}
@@ -56,10 +59,21 @@ outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068
+++ {"id": "Rcji_799n4eA"}
-**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.
+**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs.
Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices.
+```{code-cell}
+import os
+if os.path.exists('/content/'):
+ platform = "Colab"
+elif os.path.exists('/kaggle/'):
+ platform = "Kaggle"
+else:
+ # Assume using Cloud TPU otherwise
+ platform = "GCP"
+```
+
```{code-cell}
---
colab:
@@ -132,7 +146,7 @@ Our `Mesh` will have two arguments:
- `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..
- `axis_names`, where:
- `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and
- - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism, mapping to the TPU v2 cores.
+ - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism
This matches the structure in the Kaggle TPU v5e setup.
@@ -150,7 +164,8 @@ mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
### For free-tier Colab TPU, which only has a single TPU core
-# mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model"))
+if platform == 'Colab':
+ mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model"))
```
+++ {"id": "_ZKdhNo98NgG"}
@@ -331,7 +346,7 @@ class MiniGPT(nnx.Module):
# and obtain logits for each token in the vocabulary (for next token prediction).
outputs = self.output_layer(x)
return outputs
-
+
@nnx.jit
def sample_from(self, logits):
logits, indices = jax.lax.top_k(logits, k=top_k)
@@ -377,7 +392,9 @@ maxlen = 256
embed_dim = 256
num_heads = 8
feed_forward_dim = 256
-batch_size = 192 # You can set a bigger batch size if you use Kaggle's TPU v5e-8
+batch_size = 192*8
+if platform == "Colab":
+ batch_size = 192
num_epochs = 1
top_k = 10
```
@@ -482,7 +499,7 @@ rng = jax.random.PRNGKey(0)
start_prompt = "Once upon a time"
start_tokens = tokenizer.encode(start_prompt)[:maxlen]
-print(f"Initial generated text:")
+print("Initial generated text:")
generated_text = model.generate_text(maxlen, start_tokens)
metrics_history = {
@@ -521,13 +538,13 @@ for epoch in range(num_epochs):
)
start_time = time.time()
- print(f"Generated text:")
+ print("Generated text:")
generated_text = model.generate_text(maxlen, start_tokens)
step += 1
# Final text generation
-print(f"Final generated text:")
+print("Final generated text:")
generated_text = model.generate_text(maxlen, start_tokens)
```
@@ -581,6 +598,8 @@ checkpointer.save('/content/save', state)
## Profiling for hyperparameter tuning
+**Note:** This section assumes multiple TPU cores, so it cannot be run on the free-tier Colab TPU v5e-1.
+
```{code-cell}
!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard
```
From eed5ee85a12c5ef0640aaf16bc7a5efeb274f770 Mon Sep 17 00:00:00 2001
From: Wei Wei <5577741+windmaple@users.noreply.github.com>
Date: Fri, 17 Oct 2025 12:36:53 +0800
Subject: [PATCH 2/2] Update miniGPT notebook:
1. use jax.device_count() to determine mesh and bs
2. disable new variable eager sharding in Flax 0.12
3. fix typo
---
docs/source/JAX_for_LLM_pretraining.ipynb | 38 +++++++++++------------
docs/source/JAX_for_LLM_pretraining.md | 28 ++++++++---------
2 files changed, 31 insertions(+), 35 deletions(-)
diff --git a/docs/source/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb
index ea260d1..32304b1 100644
--- a/docs/source/JAX_for_LLM_pretraining.ipynb
+++ b/docs/source/JAX_for_LLM_pretraining.ipynb
@@ -103,22 +103,6 @@
"Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices."
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "if os.path.exists('/content/'):\n",
- " platform = \"Colab\"\n",
- "elif os.path.exists('/kaggle/'):\n",
- " platform = \"Kaggle\"\n",
- "else:\n",
- " # Assume using Cloud TPU otherwise\n",
- " platform = \"GCP\""
- ]
- },
{
"cell_type": "code",
"execution_count": 2,
@@ -233,6 +217,20 @@
"import time"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import flax\n",
+ "\n",
+ "# Compare (major, minor) directly; `pkg_resources` is deprecated in setuptools.\n",
+ "if tuple(int(part) for part in flax.__version__.split('.')[:2]) >= (0, 12):\n",
+ "    flax.config.update('flax_always_shard_variable', False)\n",
+ "    print('Disabling Flax variable eager sharding for backward compatibility...')"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -261,7 +259,7 @@
"- `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..\n",
"- `axis_names`, where:\n",
" - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and\n",
- " - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism\n",
+ " - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism\n",
"\n",
"This matches the structure in the Kaggle TPU v5e setup.\n",
"\n",
@@ -285,7 +283,7 @@
"# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))\n",
"\n",
"### For free-tier Colab TPU, which only has a single TPU core\n",
- "if platform == 'Colab':\n",
+ "if jax.device_count() == 1:\n",
" mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))"
]
},
@@ -543,8 +541,8 @@
"embed_dim = 256\n",
"num_heads = 8\n",
"feed_forward_dim = 256\n",
- "batch_size = 192*8\n",
- "if platform == \"Colab\":\n",
+ "batch_size = 192 * jax.device_count() // 2 # integer-divide by 2 in case of model parallelism\n",
+ "if jax.device_count() == 1:\n",
" batch_size = 192\n",
"num_epochs = 1\n",
"top_k = 10"
diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md
index bb1126a..741e2ee 100644
--- a/docs/source/JAX_for_LLM_pretraining.md
+++ b/docs/source/JAX_for_LLM_pretraining.md
@@ -63,17 +63,6 @@ outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068
Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices.
-```{code-cell}
-import os
-if os.path.exists('/content/'):
- platform = "Colab"
-elif os.path.exists('/kaggle/'):
- platform = "Kaggle"
-else:
- # Assume using Cloud TPU otherwise
- platform = "GCP"
-```
-
```{code-cell}
---
colab:
@@ -122,6 +111,15 @@ import tiktoken
import time
```
+```{code-cell}
+import flax
+
+# Compare (major, minor) directly; `pkg_resources` is deprecated in setuptools.
+if tuple(int(part) for part in flax.__version__.split('.')[:2]) >= (0, 12):
+    flax.config.update('flax_always_shard_variable', False)
+    print('Disabling Flax variable eager sharding for backward compatibility...')
+```
+
+++ {"id": "rPyt7MV6prz1"}
## Define the miniGPT model with Flax and JAX automatic parallelism
@@ -146,7 +144,7 @@ Our `Mesh` will have two arguments:
- `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..
- `axis_names`, where:
- `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and
- - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism
+ - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism
This matches the structure in the Kaggle TPU v5e setup.
@@ -164,7 +162,7 @@ mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
### For free-tier Colab TPU, which only has a single TPU core
-if platform == 'Colab':
+if jax.device_count() == 1:
mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model"))
```
@@ -392,8 +390,8 @@ maxlen = 256
embed_dim = 256
num_heads = 8
feed_forward_dim = 256
-batch_size = 192*8
-if platform == "Colab":
+batch_size = 192 * jax.device_count() // 2 # integer-divide by 2 in case of model parallelism
+if jax.device_count() == 1:
batch_size = 192
num_epochs = 1
top_k = 10