```{exercise}
:label: jax_nn_ex1

Try to reduce the MSE on the validation data without significantly increasing the computational load.

You should hold constant both the number of epochs and the total number of parameters in the network.

Currently, the network has 4 layers with output dimension $k=10$, giving a total of:
- Layer 0: $1 \times 10 + 10 = 20$ parameters (weights + biases)
- Layer 1: $10 \times 10 + 10 = 110$ parameters
- Layer 2: $10 \times 10 + 10 = 110$ parameters
- Layer 3: $10 \times 1 + 1 = 11$ parameters
- Total: $251$ parameters

You can experiment with:
- Changing the network architecture
- Trying different activation functions (e.g., `jax.nn.relu`, `jax.nn.gelu`, `jax.nn.sigmoid`, `jax.nn.elu`)
- Modifying the optimizer (e.g., different learning rates, learning rate schedules, momentum, other Optax optimizers)
- Experimenting with different weight initialization strategies
- Modifying the loss function (e.g., adding regularization)


Which combination gives you the lowest validation MSE?

Let's implement and test several strategies.

**Strategy 1: Deeper Network + LR Schedule + L2 Regularization**

Let's try a deeper network (6 layers) combined with a learning rate schedule and L2 regularization.

With layer sizes $1 \to 6 \to 6 \to 6 \to 6 \to 6 \to 1$, the parameter count is
$(1 \times 6 + 6) + 4 \times (6 \times 6 + 6) + (6 \times 1 + 1) = 12 + 4 \times 42 + 7 = 187 < 251$,
so the constraint on the total number of parameters is respected:

```{code-cell} ipython3
# Strategy 1: Deeper network + LR schedule + L2 regularization
# Define deeper network architecture
def initialize_deep_params(
key: jax.Array, # JAX random key
k: int = 6, # Layer width
num_hidden: int = 5 # Number of hidden layers
):
" Initialize parameters for deeper network with k=6. "
layer_sizes = tuple([1] + [k] * num_hidden + [1])
config_deep = Config(layer_sizes=layer_sizes)
return initialize_network(key, config_deep)

θ_deep = initialize_deep_params(param_key)
config_deep = Config(layer_sizes=(1, 6, 6, 6, 6, 6, 1))

def train_deep_with_schedule_and_l2(
deep_l2_runtime = time() - start_time

deep_l2_mse = loss_fn(θ_deep_l2, x_validate, y_validate)
print(f"Strategy 3 - Deeper network + LR schedule + L2 regularization")
print(f"Strategy 1 - Deeper network + LR schedule + L2 regularization")
print(f" Runtime: {deep_l2_runtime:.2f}s")
print(f" Validation MSE: {deep_l2_mse:.6f}")
print(f" Improvement over ADAM: {optax_adam_mse - deep_l2_mse:.6f}")
```
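
For reference, here is a minimal sketch of the two extra ingredients used in this strategy: an L2 (ridge) penalty added to the MSE loss and an exponentially decaying learning rate. It assumes the forward map `f`, the parameter pytree `θ` (a list of `(W, b)` layers), and the `loss_fn` defined earlier in the lecture; the penalty weight `lambda_l2` is an illustrative choice.

```python
import jax
import jax.numpy as jnp
import optax

lambda_l2 = 0.001  # illustrative penalty weight

def loss_fn_l2(θ, x, y):
    # Standard MSE loss plus an L2 penalty on the weight matrices
    mse = jnp.mean((f(θ, x) - y)**2)
    l2_penalty = sum(jnp.sum(W**2) for W, b in θ)
    return mse + lambda_l2 * l2_penalty

loss_gradient_l2 = jax.jit(jax.grad(loss_fn_l2))

# Exponentially decaying learning rate, passed directly to the optimizer
schedule = optax.exponential_decay(
    init_value=0.003, transition_steps=1000, decay_rate=0.5
)
solver = optax.adam(schedule)
```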

**Strategy 2: Baseline + Armijo Line Search**

Let's implement gradient descent with [Armijo line search](https://en.wikipedia.org/wiki/Backtracking_line_search), which adaptively selects the step size at each iteration.
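
Concretely, write the loss as $\ell(\theta)$ and let $g = \nabla \ell(\theta)$ be its gradient.
Starting from a trial step size $\alpha$, the backtracking rule accepts $\alpha$ once

$$
\ell(\theta - \alpha g) \leq \ell(\theta) - c \, \alpha \, \| g \|^2,
$$

and otherwise shrinks $\alpha$ by a fixed factor.
In the code below, $c = 0.001$, the shrink factor is $0.5$, and at most 20 backtracking steps are taken: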

```{code-cell} ipython3
# Strategy 2: Baseline architecture + Armijo line search
# Line search parameters
line_search_init_value = 0.01
line_search_backtrack_factor = 0.5
line_search_armijo_constant = 0.001
max_backtrack_steps = 20

@partial(jax.jit, static_argnames=['config'])
def train_jax_armijo_ls(
θ: list, # Initial parameters (pytree)
x: jnp.ndarray, # Training input data
y: jnp.ndarray, # Training target data
config: Config # contains configuration data
):
" Train baseline model with L2 regularization. "
"""
Train model using gradient descent with Armijo line search.

The Armijo line search adaptively finds a suitable step size at each
iteration by ensuring sufficient decrease in the loss function.
"""
epochs = config.epochs

# Line search parameters
init_alpha = line_search_init_value
backtrack_factor = line_search_backtrack_factor
_armijo_constant = line_search_armijo_constant

def update_step(current_theta, x_data, y_data):
current_loss = loss_fn(current_theta, x_data, y_data)
grad = loss_gradient(current_theta, x_data, y_data)

# Calculate squared Euclidean norm of the gradient for Armijo condition
grad_norm_sq = jax.tree_util.tree_reduce(
lambda a, b: a + jnp.sum(b**2), grad, initializer=0.0
)

# Define the condition for the while_loop
def cond_fn(loop_args):
alpha_val, current_loss_val, grad_sq_sum, theta_orig, x_in, y_in, step_count = loop_args
loss_threshold = current_loss_val - _armijo_constant * alpha_val * grad_sq_sum
theta_candidate = jax.tree.map(lambda p, g_leaf: p - alpha_val * g_leaf, theta_orig, grad)
loss_candidate = loss_fn(theta_candidate, x_in, y_in)
return (loss_candidate > loss_threshold) & (step_count < max_backtrack_steps)

# Define the body for the while_loop
def body_fn(loop_args):
alpha_val, current_loss_val, grad_sq_sum, theta_orig, x_in, y_in, step_count = loop_args
new_alpha = alpha_val * backtrack_factor
new_step_count = step_count + 1
return (new_alpha, current_loss_val, grad_sq_sum, theta_orig, x_in, y_in, new_step_count)

# Execute the Armijo line search using jax.lax.while_loop
final_alpha, _, _, _, _, _, _ = jax.lax.while_loop(
cond_fn,
body_fn,
(init_alpha, current_loss, grad_norm_sq, current_theta, x_data, y_data, 0)
)

# Update parameters with the chosen step size
theta_new = jax.tree.map(lambda p, g_leaf: p - final_alpha * g_leaf, current_theta, grad)
return theta_new

# Main training loop (epochs)
θ_final = jax.lax.fori_loop(0, epochs, lambda i, current_theta: update_step(current_theta, x, y), θ)
return θ_final

# Warmup
θ = initialize_network(param_key, config)
train_jax_armijo_ls(θ, x_train, y_train, config)

# Actual run
θ = initialize_network(param_key, config)
start_time = time()
θ_armijo = train_jax_armijo_ls(θ, x_train, y_train, config)
θ_armijo[0].W.block_until_ready()
armijo_runtime = time() - start_time

armijo_mse = loss_fn(θ_armijo, x_validate, y_validate)
print(f"Strategy 2 - Baseline + Armijo Line Search")
print(f" Runtime: {armijo_runtime:.2f}s")
print(f" Validation MSE: {armijo_mse:.6f}")
print(f" Improvement over ADAM: {optax_adam_mse - armijo_mse:.6f}")
```
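
For intuition, here is the same backtracking rule written as plain Python, without `jax.lax.while_loop` or JIT. This is only a sketch for a NumPy parameter vector and a scalar loss; the jitted version above applies the identical rule to the full parameter pytree.

```python
import numpy as np

def armijo_step(loss, grad_fn, theta,
                alpha=0.01, shrink=0.5, c=0.001, max_backtracks=20):
    # One gradient-descent update with Armijo backtracking (sketch)
    g = grad_fn(theta)
    base = loss(theta)
    for _ in range(max_backtracks):
        # Accept the step once it gives a sufficient decrease in the loss
        if loss(theta - alpha * g) <= base - c * alpha * np.sum(g**2):
            break
        alpha *= shrink
    return theta - alpha * g

# Usage: minimize a simple quadratic
theta = np.array([2.0, -3.0])
for _ in range(200):
    theta = armijo_step(lambda t: np.sum(t**2), lambda t: 2 * t, theta)
print(theta)  # close to [0., 0.]
```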

**Results Summary**

Let's compare all strategies:

```{code-cell} ipython3
strategies_results = {
'Strategy': [
'Baseline (ADAM + tanh)',
'1. Deeper network + LR schedule + L2',
'2. Baseline + Armijo Line Search'
],
'Runtime (s)': [
optax_adam_runtime,
deep_l2_runtime,
armijo_runtime
],
'Validation MSE': [
optax_adam_mse,
deep_l2_mse,
armijo_mse
],
'Improvement': [
0.0,
float(optax_adam_mse - deep_l2_mse),
float(optax_adam_mse - armijo_mse)
]
}

df_strategies = pd.DataFrame(strategies_results)
print(df_strategies.to_string(index=False))
```


In terms of reducing the loss on the validation data, the current winner is the
Armijo line search strategy.

The Armijo backtracking line search is an adaptive step size method that
dynamically adjusts the learning rate at each iteration to ensure sufficient
decrease in the loss function.

Unlike fixed learning rates or predetermined schedules, it adapts to the local
geometry of the loss landscape.

This strategy and its code were contributed by [Matyas Farkas](https://www.matyasfarkas.eu/).