34 changes: 28 additions & 6 deletions docs/source/JAX_for_LLM_pretraining.ipynb
@@ -18,6 +18,9 @@
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
"</table>"
]
},
@@ -95,7 +98,7 @@
"id": "Rcji_799n4eA"
},
"source": [
"**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.\n",
"**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs.\n",
"\n",
"Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of 8 (eight) devices."
]
@@ -214,6 +217,20 @@
"import time"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import flax\n",
"from pkg_resources import parse_version\n",
"\n",
"if parse_version(flax.__version__) >= parse_version(\"0.12\"):\n",
" flax.config.update('flax_always_shard_variable', False)\n",
" print('Disabling Flax variable eager sharding for backward compatibility...')"
]
},
{
"cell_type": "markdown",
"metadata": {
@@ -242,7 +259,7 @@
"- `devices`: This will take the value of [`jax.experimental.mesh_utils((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), enabling us to build a device mesh. It is a NumPy ndarray with JAX devices (a list of devices from the JAX backend as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices))..\n",
"- `axis_names`, where:\n",
" - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and\n",
" - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism, mapping to the TPU v2 cores.\n",
" - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallism\n",
"\n",
"This matches the structure in the Kaggle TPU v5e setup.\n",
"\n",
@@ -266,7 +283,8 @@
"# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))\n",
"\n",
"### For free-tier Colab TPU, which only has a single TPU core\n",
"# mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))"
"if jax.device_count() == 1:\n",
" mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), (\"batch\", \"model\"))"
]
},
{
@@ -467,7 +485,7 @@
" # and obtain logits for each token in the vocabulary (for next token prediction).\n",
" outputs = self.output_layer(x)\n",
" return outputs\n",
" \n",
"\n",
" @nnx.jit\n",
" def sample_from(self, logits):\n",
" logits, indices = jax.lax.top_k(logits, k=top_k)\n",
@@ -523,7 +541,9 @@
"embed_dim = 256\n",
"num_heads = 8\n",
"feed_forward_dim = 256\n",
"batch_size = 192 # You can set a bigger batch size if you use Kaggle's TPU v5e-8\n",
"batch_size = 192 * jax.device_count() / 2 # divide by 2 in case of model parallelism\n",
"if jax.device_count() == 1:\n",
" batch_size = 192\n",
"num_epochs = 1\n",
"top_k = 10"
]
@@ -1138,7 +1158,9 @@
"id": "3813cbf2",
"metadata": {},
"source": [
"## Profiling for hyperparameter tuning"
"## Profiling for hyperparameter tuning\n",
"\n",
"**Note:** this section assume multiple TPU cores. Free-tier Colab TPU v5e-1 cannot run here."
]
},
{
33 changes: 25 additions & 8 deletions docs/source/JAX_for_LLM_pretraining.md
@@ -22,6 +22,9 @@ kernelspec:
<td>
<a target="_blank" href="https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
</td>
<td>
<a target="_blank" href="https://github.com/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
</td>
</table>

+++ {"id": "NIOXoY1xgiww"}
@@ -56,7 +59,7 @@ outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068

+++ {"id": "Rcji_799n4eA"}

**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.
**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs.

Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below lists the available devices: 8 (eight) on a Kaggle TPU v5e-8, or 1 (one) on the free-tier Colab TPU v5e-1.

@@ -108,6 +111,15 @@ import tiktoken
import time
```
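
Newer Flax releases (0.12 and later) eagerly shard model variables as they are created; the next cell turns this off so that the explicit sharding setup used later in this tutorial keeps working as before. On older Flax versions the check is a no-op.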

```{code-cell}
import flax
from pkg_resources import parse_version

if parse_version(flax.__version__) >= parse_version("0.12"):
flax.config.update('flax_always_shard_variable', False)
print('Disabling Flax variable eager sharding for backward compatibility...')
```

+++ {"id": "rPyt7MV6prz1"}

## Define the miniGPT model with Flax and JAX automatic parallelism
@@ -132,7 +144,7 @@ Our `Mesh` will have two arguments:
- `devices`: This takes the value of [`mesh_utils.create_device_mesh((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), which builds the device mesh as a NumPy ndarray of JAX devices (the list of devices from the JAX backend, as obtained from [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices)).
- `axis_names`, where:
- `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and
- `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor paralleism, mapping to the TPU v2 cores.
  - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism.

This matches the structure in the Kaggle TPU v5e setup.

@@ -150,7 +162,8 @@ mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))

### For free-tier Colab TPU, which only has a single TPU core
# mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model"))
if jax.device_count() == 1:
mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model"))
```
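
To make the mesh's axis names concrete, here is a minimal illustrative sketch (not used by the training code below) of how a sharding built from this mesh describes an array layout across the devices:

```{code-cell}
from jax.sharding import NamedSharding, PartitionSpec

# Illustrative sketch: split an array's leading (batch) dimension across the
# 'batch' mesh axis and replicate it across the 'model' axis.
data_sharding = NamedSharding(mesh, PartitionSpec('batch', None))
print(data_sharding)
```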

+++ {"id": "_ZKdhNo98NgG"}
Expand Down Expand Up @@ -331,7 +344,7 @@ class MiniGPT(nnx.Module):
# and obtain logits for each token in the vocabulary (for next token prediction).
outputs = self.output_layer(x)
return outputs

@nnx.jit
def sample_from(self, logits):
logits, indices = jax.lax.top_k(logits, k=top_k)
@@ -377,7 +390,9 @@ maxlen = 256
embed_dim = 256
num_heads = 8
feed_forward_dim = 256
batch_size = 192 # You can set a bigger batch size if you use Kaggle's TPU v5e-8
batch_size = 192 * jax.device_count() // 2  # floor division keeps the batch size an integer; the 2-way model axis halves the data-parallel device count
if jax.device_count() == 1:
batch_size = 192
num_epochs = 1
top_k = 10
```
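
Note that `jax.device_count() // 2` equals the number of devices along the `batch` mesh axis in the (4, 2) layout above, so the global batch size always works out to 192 examples per batch-axis device.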
@@ -482,7 +497,7 @@ rng = jax.random.PRNGKey(0)

start_prompt = "Once upon a time"
start_tokens = tokenizer.encode(start_prompt)[:maxlen]
print(f"Initial generated text:")
print("Initial generated text:")
generated_text = model.generate_text(maxlen, start_tokens)

metrics_history = {
@@ -521,13 +536,13 @@ for epoch in range(num_epochs):
)
start_time = time.time()

print(f"Generated text:")
print("Generated text:")
generated_text = model.generate_text(maxlen, start_tokens)

step += 1

# Final text generation
print(f"Final generated text:")
print("Final generated text:")
generated_text = model.generate_text(maxlen, start_tokens)
```

@@ -581,6 +596,8 @@ checkpointer.save('/content/save', state)

## Profiling for hyperparameter tuning

**Note:** This section assumes multiple TPU cores. It will not run on the free-tier Colab TPU v5e-1, which exposes only a single core.

```{code-cell}
!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard
```
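
For orientation, a trace capture in JAX typically follows the pattern sketched below. This cell is illustrative only (`/tmp/jax-trace` is a placeholder log directory); the notebook's own profiling cells, not shown in this diff, build on the same idea.

```{code-cell}
import jax
import jax.numpy as jnp

# Illustrative sketch: capture a trace of a small computation for
# TensorBoard's profile plugin. '/tmp/jax-trace' is a placeholder directory.
with jax.profiler.trace('/tmp/jax-trace'):
    x = jnp.ones((1024, 1024))
    (x @ x).block_until_ready()
```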