Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/learning_jax/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Learning-JAX
Slide decks, coding exercises, and quick references for learning the JAX AI Stack. The coding exercises are designed to be runnable in a free Colab instance.

For more comprehensive documentation please see the individual websites:

* https://jaxstack.ai
* https://jax.dev
* https://flax.readthedocs.io
* https://orbax.readthedocs.io
* https://optax.readthedocs.io
* https://google-grain.readthedocs.io
* https://chex.readthedocs.io

[Join our growing community on Discord](https://goo.gle/jax-community) and connect with other developers!
1 change: 1 addition & 0 deletions docs/learning_jax/code-exercises/1 - JAX AI Stack.ipynb

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1jbsvh_ZWvXFaK-0FGsVYjoDc4QcBDV10","timestamp":1755113770723}],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Introduction to Flax NNX\n","\n","Welcome to the Flax NNX Colab Notebook! This notebook provides hands-on exercises designed to help PyTorch users transition to Flax NNX and the JAX ecosystem.\n","\n","We'll cover core concepts and build simple models."],"metadata":{"id":"wK90mE1fmGuk"}},{"cell_type":"code","source":["!pip install -Uq flax optax"],"metadata":{"id":"Pke20hU-A1iQ"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"C7OTedGZjhPz"},"outputs":[],"source":["# @title Exercise 1: Understanding Modules and Parameters (Coding Exercise)\n","\n","# Instructions:\n","# 1. Create a simple NNX Module called `MyLinearLayer`.\n","# 2. It should have an `nnx.Param` called `weight` (initialized randomly with shape [input_size, output_size]).\n","# 3. It should have an `nnx.Param` called `bias` (initialized with zeros with shape [output_size]).\n","# 4. The forward pass (`__call__` method) should perform a linear transformation: `x @ self.weight.value + self.bias.value`.\n","# 5. Instantiate the layer with `input_size=10` and `output_size=5`.\n","# 6. Print the shape of the `weight` and `bias` parameters.\n","\n","from flax import nnx\n","import jax\n","import jax.numpy as jnp\n","\n","class MyLinearLayer(nnx.Module):\n"," def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):\n","\n"," pass # FILL IN THIS PART\n","\n"," def __call__(self, x: jax.Array):\n"," pass # FILL IN THIS PART\n","\n","# Instantiate the layer\n","key = jax.random.PRNGKey(0)\n","linear_layer = MyLinearLayer(\n"," input_size='FILL IN THIS PART',\n"," output_size='FILL IN THIS PART',\n"," rngs=nnx.Rngs(key))\n","\n","# Print the shapes of the parameters\n","print(\"Weight shape:\", 'FILL IN THIS PART')\n","print(\"Bias shape:\", 'FILL IN THIS PART')\n","\n","# Example usage:\n","dummy_input = jnp.ones((1, 10))\n","output = linear_layer(dummy_input)\n","print(\"Output shape:\", output.shape)"]},{"cell_type":"code","source":["# @title Exercise 1 Solution\n","\n","# from flax import nnx\n","# import jax\n","# import jax.numpy as jnp\n","\n","# class MyLinearLayer(nnx.Module):\n","# def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):\n","# self.weight = nnx.Param(jax.random.normal(rngs.params(), (input_size, output_size)))\n","# self.bias = nnx.Param(jnp.zeros((output_size,)))\n","\n","# def __call__(self, x: jax.Array):\n","# return x @ self.weight.value + self.bias.value\n","\n","# # Instantiate the layer\n","# key = jax.random.PRNGKey(0)\n","# linear_layer = MyLinearLayer(input_size=10, output_size=5, rngs=nnx.Rngs(key))\n","\n","# # Print the shapes of the parameters\n","# print(\"Weight shape:\", linear_layer.weight.value.shape)\n","# print(\"Bias shape:\", linear_layer.bias.value.shape)\n","\n","# # Example usage:\n","# dummy_input = jnp.ones((1, 10))\n","# output = linear_layer(dummy_input)\n","# print(\"Output shape:\", output.shape)"],"metadata":{"id":"QeaDLUu_lXMA","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 2: State Management (Coding Exercise)\n","# Instructions:\n","# 1. Create an NNX Module called `CounterModule`.\n","# 2. It should have a Python instance attribute called `count` initialized to 0.\n","# 3. The `__call__` method should increment the `count` by 1 and return the new value.\n","# 4. Instantiate the module.\n","# 5. Call the module multiple times and print the returned value.\n","# 6. Use `nnx.split` and `nnx.merge` to save and load the module's state. Verify that the counter resumes from where it left off.\n","\n","from flax import nnx\n","import jax.numpy as jnp\n","\n","class CounterModule(nnx.Module):\n"," def __init__(self):\n"," pass # FILL IN THIS PART\n","\n"," def __call__(self):\n"," pass # FILL IN THIS PART\n","\n","# Instantiate the module\n","pass # FILL IN THIS PART. Name it \"counter\"\n","\n","# Call the module and print the value\n","print(\"First call:\", counter())\n","print(\"Second call:\", counter())\n","\n","# Split the module into graphdef and state.\n","# Remember that state is an nnx.Variable\n","graphdef, state = # FILL IN THIS PART\n","\n","# Merge the graphdef and state to create a new module\n","new_counter = # FILL IN THIS PART\n","\n","# Call the new module and print the value\n","print(\"After split and merge, first call:\", new_counter())\n","print(\"After split and merge, second call:\", new_counter())"],"metadata":{"id":"Qa51jundpavu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 2 Solution\n","\n","# from flax import nnx\n","# import jax.numpy as jnp\n","\n","# class CounterModule(nnx.Module):\n","# def __init__(self):\n","# self.count = 0\n","\n","# def __call__(self):\n","# self.count += 1\n","# return self.count\n","\n","# # Instantiate the module\n","# counter = CounterModule()\n","\n","# # Call the module and print the value\n","# print(\"First call:\", counter())\n","# print(\"Second call:\", counter())\n","\n","# # Split the module into graphdef and state\n","# graphdef, state = nnx.split(counter, nnx.Variable)\n","\n","# # Merge the graphdef and state to create a new module\n","# new_counter = nnx.merge(graphdef, state)\n","\n","# # Call the new module and print the value\n","# print(\"After split and merge, first call:\", new_counter())\n","# print(\"After split and merge, second call:\", new_counter())"],"metadata":{"id":"jVh1M8fYppnC","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 3: Explicit Random Number Generation (Coding Exercise)\n","\n","# Instructions:\n","# 1. Create an NNX Module called `RandomNormalLayer`.\n","# 2. Its `__init__` method should receive a `size` argument defining the size of the random vector to generate.\n","# 3. The `__init__` method should receive a `rngs: nnx.Rngs` argument that is used to generate a random normal tensor\n","# using jax.random.normal and assign the tensor to `self.random_vector`.\n","# 4. The `__call__` method should return the value of `self.random_vector` (a new random normal tensor).\n","# 5. Instantiate the layer with a size of 10, passing in the rngs parameter with a jax.random.PRNGKey.\n","# 6. Call the module twice and observe that the returned values are different.\n","\n","from flax import nnx\n","import jax\n","import jax.numpy as jnp\n","\n","# CREATE RandomNormalLayer\n","\n","# Instantiate the module\n","key = # USE jax.random.PRNGKey to create a new key\n","random_layer = RandomNormalLayer(size='SIZE HERE', rngs=nnx.Rngs(key))\n","\n","# Call the module and print the value\n","print(\"First call:\", random_layer())\n","print(\"Second call:\", random_layer())"],"metadata":{"id":"QKKiri2rptgl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 3 Solution\n","\n","# from flax import nnx\n","# import jax\n","# import jax.numpy as jnp\n","\n","# class RandomNormalLayer(nnx.Module):\n","# def __init__(self, size: int, *, rngs: nnx.Rngs):\n","# self.random_vector = nnx.Param(jax.random.normal(rngs.params(), (size,)))\n","\n","# def __call__(self):\n","# return self.random_vector.value\n","\n","# # Instantiate the module\n","# key = jax.random.PRNGKey(0)\n","# random_layer = RandomNormalLayer(size=10, rngs=nnx.Rngs(key))\n","\n","# # Call the module and print the value\n","# print(\"First call:\", random_layer())\n","# print(\"Second call:\", random_layer())"],"metadata":{"id":"RG420Ks9roNq","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 4: Building a Simple CNN (Coding Exercise)\n","\n","# Instructions:\n","# 1. Create an NNX Module representing a simple CNN with the following layers:\n","# - Convolutional layer (nnx.Conv) with 32 filters, kernel size 3, and stride 1.\n","# - ReLU activation.\n","# - Max pooling layer (nnx.max_pool) with window size 2 and stride 2.\n","# - Flatten layer (jax.numpy.reshape).\n","# - Linear layer (nnx.Linear) to map to 10 output classes.\n","# 2. Initialize the CNN with appropriate input and output shapes.\n","# 3. Perform a forward pass with a dummy input and print the output shape.\n","\n","from flax import nnx\n","import jax\n","import jax.numpy as jnp\n","import jax.lax\n","\n","class SimpleCNN(nnx.Module):\n"," def __init__(self, num_classes: int, *, rngs: nnx.Rngs):\n"," self.conv = nnx.Conv('STRIDE', 'FILTERS', kernel_size=('X, X'), rngs=rngs)\n"," self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)\n","\n"," def __call__(self, x: jax.Array):\n"," x = self.conv(x)\n"," print(f'{x.shape = }') # For debug\n"," x = nnx.relu(x)\n"," print(f'{x.shape = }') # For debug\n"," x = nnx.max_pool(x, window_shape=('X, X'), strides=('X, X'))\n"," print(f'{x.shape = }') # For debug\n"," x = x.reshape(x.shape[0], -1) # flatten\n"," print(f'{x.shape = }') # For debug\n"," x = self.linear(x)\n"," return x\n","\n","# Instantiate the CNN\n","key = jax.random.PRNGKey(0)\n","cnn = SimpleCNN(num_classes='OUTPUT CLASSES', rngs=nnx.Rngs(key))\n","\n","# Dummy input\n","dummy_input = jnp.ones((1, 28, 28, 1))\n","\n","# Forward pass\n","output = cnn(dummy_input)\n","print(\"Output shape:\", output.shape)"],"metadata":{"id":"zafWVwtE3xgF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 4 Solution\n","\n","# from flax import nnx\n","# import jax\n","# import jax.numpy as jnp\n","# import jax.lax\n","\n","# class SimpleCNN(nnx.Module):\n","# def __init__(self, num_classes: int, *, rngs: nnx.Rngs):\n","# self.conv = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n","# self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)\n","\n","# def __call__(self, x: jax.Array):\n","# x = self.conv(x)\n","# print(f'{x.shape = }')\n","# x = nnx.relu(x)\n","# print(f'{x.shape = }')\n","# x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))\n","# print(f'{x.shape = }')\n","# x = x.reshape(x.shape[0], -1) # flatten\n","# print(f'{x.shape = }')\n","# x = self.linear(x)\n","# return x\n","\n","# # Instantiate the CNN\n","# key = jax.random.PRNGKey(0)\n","# cnn = SimpleCNN(num_classes=10, rngs=nnx.Rngs(key))\n","\n","# # Dummy input\n","# dummy_input = jnp.ones((1, 28, 28, 1))\n","\n","# # Forward pass\n","# output = cnn(dummy_input)\n","# print(\"Output shape:\", output.shape)"],"metadata":{"id":"_XHKL8ZaYla4","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 5: Training Loop with Optax (Coding Exercise)\n","\n","# Instructions:\n","# 1. Define a simple model (e.g., a linear layer).\n","# 2. Create an nnx.Optimizer, making sure to specify which variable types to\n","# update using the now required wrt argument (e.g., wrt=nnx.Param).\n","# 3. Implement a training step function that:\n","# - Calculates the loss (e.g., mean squared error).\n","# - Computes gradients using `nnx.value_and_grad`.\n","# - Updates the model's state using `optimizer.update(model, grads)`.\n","# 4. Run the training loop for a few steps.\n","\n","from flax import nnx\n","import jax\n","import jax.numpy as jnp\n","import optax\n","\n","# Define a simple model\n","class LinearModel(nnx.Module):\n"," def __init__(self, *, rngs: nnx.Rngs):\n"," self.linear = 'LINEAR LAYER HERE'\n","\n"," def __call__(self, x: jax.Array):\n"," return self.linear(x)\n","\n","# Instantiate the model\n","key = jax.random.PRNGKey(0)\n","model = LinearModel(rngs=nnx.Rngs(key))\n","\n","# Create an Optax optimizer\n","tx = 'OPTAX SGD HERE'\n","optimizer = nnx.Optimizer('WRAP THE OPTIMIZER')\n","\n","# Dummy data\n","x = jnp.array([[2.0]])\n","y = jnp.array([[4.0]])\n","\n","# Training step function\n","@nnx.jit\n","def train_step(model, optimizer, x, y):\n"," def loss_fn(model):\n"," y_pred = model(x)\n"," return jnp.mean((y_pred - y) ** 2)\n","\n"," loss, grads = nnx.value_and_grad(loss_fn)(model)\n"," optimizer.update(model, grads)\n"," return loss, model\n","\n","# Training loop\n","num_steps = 10\n","for i in range(num_steps):\n"," loss, model = train_step(model, optimizer, x, y)\n"," print(f\"Step {i+1}, Loss: {loss}\")\n","\n","print(\"Trained model output:\", model(x))"],"metadata":{"id":"Sf4P1AEO3_Rp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 5 Solution\n","\n","# from flax import nnx\n","# import jax\n","# import jax.numpy as jnp\n","# import optax\n","\n","# # Define a simple model\n","# class LinearModel(nnx.Module):\n","# def __init__(self, *, rngs: nnx.Rngs):\n","# self.linear = nnx.Linear(in_features=1, out_features=1, rngs=rngs)\n","\n","# def __call__(self, x: jax.Array):\n","# return self.linear(x)\n","\n","# # Instantiate the model\n","# key = jax.random.PRNGKey(0)\n","# model = LinearModel(rngs=nnx.Rngs(key))\n","\n","# # Create an Optax optimizer\n","# tx = optax.sgd(learning_rate=0.01)\n","# optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)\n","\n","# # Dummy data\n","# x = jnp.array([[2.0]])\n","# y = jnp.array([[4.0]])\n","\n","# # Training step function\n","# @nnx.jit\n","# def train_step(model, optimizer, x, y):\n","# def loss_fn(model):\n","# y_pred = model(x)\n","# return jnp.mean((y_pred - y) ** 2)\n","\n","# loss, grads = nnx.value_and_grad(loss_fn)(model)\n","# optimizer.update(model, grads)\n","# return loss, model\n","\n","# # Training loop\n","# num_steps = 10\n","# for i in range(num_steps):\n","# loss, model = train_step(model, optimizer, x, y)\n","# print(f\"Step {i+1}, Loss: {loss}\")\n","\n","# print(\"Trained model output:\", model(x))"],"metadata":{"id":"CaLOsG6paLam"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Congratulations!\n","You've now worked through the fundamentals of Flax NNX!\n","\n","Remember to consult the official documentation for more in-depth details:\n","\n","* Flax NNX: (Part of the Flax documentation) https://flax.readthedocs.io\n","* JAX: https://jax.readthedocs.io\n","\n","Keep practicing, and happy JAXing!\n","\n","Please send us feedback at https://goo.gle/jax-training-feedback"],"metadata":{"id":"khX7Io6749dt"}},{"cell_type":"markdown","source":[],"metadata":{"id":"_S3rApFP3hum"}}]}
Loading