Streamline neural network exercise strategies #268
Open: jstac wants to merge 4 commits into `main` from `jaxnnex`
@@ -715,22 +715,21 @@ Not surprisingly, Keras has more overhead from its abstraction layers.
 ```{exercise}
 :label: jax_nn_ex1
 
-Try to reduce the MSE on the validation data without significantly increasing the computational load.
+Try to reduce the MSE on the validation data without significantly increasing
+the computational load.
 
-You should hold constant both the number of epochs and the total number of parameters in the network.
+You should hold constant both the number of epochs and the total number of
+parameters in the network.
 
-Currently, the network has 4 layers with output dimension $k=10$, giving a total of:
-- Layer 0: $1 \times 10 + 10 = 20$ parameters (weights + biases)
-- Layer 1: $10 \times 10 + 10 = 110$ parameters
-- Layer 2: $10 \times 10 + 10 = 110$ parameters
-- Layer 3: $10 \times 1 + 1 = 11$ parameters
-- Total: $251$ parameters
+Currently, the network has 4 layers with output dimension $k=10$, giving a total
+of $251$ parameters.
 
 You can experiment with:
 - Changing the network architecture
-- Trying different activation functions (e.g., `jax.nn.relu`, `jax.nn.gelu`, `jax.nn.sigmoid`, `jax.nn.elu`)
-- Modifying the optimizer (e.g., different learning rates, learning rate schedules, momentum, other Optax optimizers)
+- Trying different activation functions
+- Modifying the optimizer (e.g., learning rates, learning rate schedules, momentum, etc.)
 - Experimenting with different weight initialization strategies
 - Modifying the loss function (e.g., adding regularization)
 
 Which combination gives you the lowest validation MSE?
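
For reference, the parameter counts quoted above and in the strategies below (251 for the baseline, 187 for the six-layer variant) follow from the dense-layer formula $n_{in} \times n_{out} + n_{out}$. A minimal sketch that checks the arithmetic (the helper below is illustrative, not code from the lecture):

```python
def count_params(layer_sizes):
    # Each dense layer mapping n_in -> n_out has n_in * n_out weights
    # plus n_out biases.
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]))

print(count_params((1, 10, 10, 10, 1)))     # baseline: 251
print(count_params((1, 6, 6, 6, 6, 6, 1)))  # deeper variant: 187
```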
@@ -743,106 +742,24 @@ Which combination gives you the lowest validation MSE?
 
 Let's implement and test several strategies.
 
-**Strategy 1: Deeper Network Architecture**
+**Strategy 1: Deeper Network + LR Schedule + L2 Regularization**
 
-Let's try a deeper network with 6 layers instead of 4, keeping total parameters ≤ 251:
+Let's try a deeper network (6 layers) combined with a learning rate schedule and L2 regularization:
 
 ```{code-cell} ipython3
-# Strategy 1: Deeper network (6 layers with k=6)
-# Layer sizes: 1→6→6→6→6→6→1
-# Parameters: (1×6+6) + 4×(6×6+6) + (6×1+1) = 12 + 4×42 + 7 = 187 < 251
-θ = initialize_network(param_key, config)
-
+# Strategy 1: Deeper network + LR schedule + L2 regularization
+# Define deeper network architecture
 def initialize_deep_params(
-    key: jax.Array,      # JAX random key
-    k: int = 6,          # Layer width
-    num_hidden: int = 5  # Number of hidden layers
+    key: jax.Array,
+    k: int = 6,
+    num_hidden: int = 5
 ):
     " Initialize parameters for deeper network with k=6. "
     layer_sizes = tuple([1] + [k] * num_hidden + [1])
     config_deep = Config(layer_sizes=layer_sizes)
     return initialize_network(key, config_deep)
 
 θ_deep = initialize_deep_params(param_key)
+config_deep = Config(layer_sizes=(1, 6, 6, 6, 6, 6, 1))
 
-# Warmup
-train_jax_optax_adam(θ_deep, x_train, y_train, config_deep)
-
-# Actual run
-θ_deep = initialize_deep_params(param_key)
-start_time = time()
-θ_deep = train_jax_optax_adam(θ_deep, x_train, y_train, config_deep)
-θ_deep[0].W.block_until_ready()
-deep_runtime = time() - start_time
-
-deep_mse = loss_fn(θ_deep, x_validate, y_validate)
-print(f"Strategy 1 - Deeper network (6 layers, k=6)")
-print(f" Total parameters: 187")
-print(f" Runtime: {deep_runtime:.2f}s")
-print(f" Validation MSE: {deep_mse:.6f}")
-print(f" Improvement over ADAM: {optax_adam_mse - deep_mse:.6f}")
-```
-
-**Strategy 2: Deeper Network + Learning Rate Schedule**
-
-Since the deeper network performed best, let's combine it with the learning rate schedule:
-
-```{code-cell} ipython3
-# Strategy 2: Deeper network + LR schedule
-θ_deep = initialize_deep_params(param_key)
-
-def train_deep_with_schedule(
-    θ: list,
-    x: jnp.ndarray,
-    y: jnp.ndarray,
-    config: Config
-):
-    " Train deeper network with learning rate schedule. "
-    epochs = config.epochs
-    schedule = optax.exponential_decay(
-        init_value=0.003,
-        transition_steps=1000,
-        decay_rate=0.5
-    )
-
-    solver = optax.adam(schedule)
-    opt_state = solver.init(θ)
-
-    def update(_, loop_state):
-        θ, opt_state = loop_state
-        grad = loss_gradient(θ, x, y)
-        updates, new_opt_state = solver.update(grad, opt_state, θ)
-        θ_new = optax.apply_updates(θ, updates)
-        return (θ_new, new_opt_state)
-
-    initial_loop_state = θ, opt_state
-    θ_final, _ = jax.lax.fori_loop(0, epochs, update, initial_loop_state)
-    return θ_final
-
-# Warmup
-train_deep_with_schedule(θ_deep, x_train, y_train, config_deep)
-
-# Actual run
-θ_deep = initialize_deep_params(param_key)
-start_time = time()
-θ_deep_schedule = train_deep_with_schedule(θ_deep, x_train, y_train, config_deep)
-θ_deep_schedule[0].W.block_until_ready()
-deep_schedule_runtime = time() - start_time
-
-deep_schedule_mse = loss_fn(θ_deep_schedule, x_validate, y_validate)
-print(f"Strategy 2 - Deeper network + LR schedule")
-print(f" Runtime: {deep_schedule_runtime:.2f}s")
-print(f" Validation MSE: {deep_schedule_mse:.6f}")
-print(f" Improvement over ADAM: {optax_adam_mse - deep_schedule_mse:.6f}")
-```
-
-**Strategy 3: Deeper Network + LR Schedule + L2 Regularization**
-
-Let's add L2 regularization (similar to ridge regression) to penalize complexity:
-
-```{code-cell} ipython3
-# Strategy 3: Deeper network + LR schedule + L2 regularization
-θ_deep = initialize_deep_params(param_key)
-
 def train_deep_with_schedule_and_l2(
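
The `optax.exponential_decay` schedule used above is a callable mapping the step count to a learning rate, shrinking by `decay_rate` every `transition_steps` steps. A short sketch of its behavior (the printed values in the comment are exact for these arguments):

```python
import optax

# Same arguments as in the training code above
schedule = optax.exponential_decay(
    init_value=0.003, transition_steps=1000, decay_rate=0.5
)
for step in (0, 1000, 2000, 4000):
    print(step, float(schedule(step)))  # 0.003, 0.0015, 0.00075, 0.0001875
```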
@@ -898,73 +815,99 @@ start_time = time()
 deep_l2_runtime = time() - start_time
 
 deep_l2_mse = loss_fn(θ_deep_l2, x_validate, y_validate)
-print(f"Strategy 3 - Deeper network + LR schedule + L2 regularization")
+print(f"Strategy 1 - Deeper network + LR schedule + L2 regularization")
 print(f" Runtime: {deep_l2_runtime:.2f}s")
 print(f" Validation MSE: {deep_l2_mse:.6f}")
 print(f" Improvement over ADAM: {optax_adam_mse - deep_l2_mse:.6f}")
 ```
 
-**Strategy 4: Baseline + L2 Regularization**
+**Strategy 2: Baseline + Armijo Line Search**
 
-Let's see if L2 regularization helps the baseline architecture:
+Let's implement gradient descent with [Armijo line search](https://en.wikipedia.org/wiki/Backtracking_line_search) for adaptive step size selection:
 
 ```{code-cell} ipython3
-# Strategy 4: Baseline architecture + L2 regularization
-θ = initialize_network(param_key, config)
+# Strategy 2: Baseline architecture + Armijo line search
+# Line search parameters
+line_search_init_value = 0.01
+line_search_backtrack_factor = 0.5
+line_search_armijo_constant = 0.001
+max_backtrack_steps = 20
 
-def train_baseline_with_l2(
-    θ: list,
-    x: jnp.ndarray,
-    y: jnp.ndarray,
-    config: Config,
-    lambda_l2: float = 0.001
+@partial(jax.jit, static_argnames=['config'])
+def train_jax_armijo_ls(
+    θ: list,         # Initial parameters (pytree)
+    x: jnp.ndarray,  # Training input data
+    y: jnp.ndarray,  # Training target data
+    config: Config   # contains configuration data
 ):
-    " Train baseline model with L2 regularization. "
+    """
+    Train model using gradient descent with Armijo line search.
+
+    The Armijo line search adaptively finds a suitable step size at each
+    iteration by ensuring sufficient decrease in the loss function.
+    """
     epochs = config.epochs
-    learning_rate = config.learning_rate
 
-    # Define regularized loss function
-    @jax.jit
-    def loss_fn_l2(θ, x, y):
-        # Standard MSE loss
-        mse = jnp.mean((f(θ, x) - y)**2)
-        # L2 penalty on weights (not biases)
-        l2_penalty = 0.0
-        for W, b in θ:
-            l2_penalty += jnp.sum(W**2)
-        return mse + lambda_l2 * l2_penalty
+    # Line search parameters
+    init_alpha = line_search_init_value
+    backtrack_factor = line_search_backtrack_factor
+    _armijo_constant = line_search_armijo_constant
 
-    loss_gradient_l2 = jax.jit(jax.grad(loss_fn_l2))
+    def update_step(current_theta, x_data, y_data):
+        current_loss = loss_fn(current_theta, x_data, y_data)
+        grad = loss_gradient(current_theta, x_data, y_data)
 
-    solver = optax.adam(learning_rate)
-    opt_state = solver.init(θ)
+        # Calculate squared Euclidean norm of the gradient for Armijo condition
+        grad_norm_sq = jax.tree_util.tree_reduce(
+            lambda a, b: a + jnp.sum(b**2), grad, initializer=0.0
+        )
 
-    def update(_, loop_state):
-        θ, opt_state = loop_state
-        grad = loss_gradient_l2(θ, x, y)
-        updates, new_opt_state = solver.update(grad, opt_state, θ)
-        θ_new = optax.apply_updates(θ, updates)
-        return (θ_new, new_opt_state)
+        # Define the condition for the while_loop
+        def cond_fn(loop_args):
+            alpha_val, current_loss_val, grad_sq_sum, theta_orig, x_in, y_in, step_count = loop_args
+            loss_threshold = current_loss_val - _armijo_constant * alpha_val * grad_sq_sum
+            theta_candidate = jax.tree.map(lambda p, g_leaf: p - alpha_val * g_leaf, theta_orig, grad)
+            loss_candidate = loss_fn(theta_candidate, x_in, y_in)
+            return (loss_candidate > loss_threshold) & (step_count < max_backtrack_steps)
+
+        # Define the body for the while_loop
+        def body_fn(loop_args):
+            alpha_val, current_loss_val, grad_sq_sum, theta_orig, x_in, y_in, step_count = loop_args
+            new_alpha = alpha_val * backtrack_factor
+            new_step_count = step_count + 1
+            return (new_alpha, current_loss_val, grad_sq_sum, theta_orig, x_in, y_in, new_step_count)
+
+        # Execute the Armijo line search using jax.lax.while_loop
+        final_alpha, _, _, _, _, _, _ = jax.lax.while_loop(
+            cond_fn,
+            body_fn,
+            (init_alpha, current_loss, grad_norm_sq, current_theta, x_data, y_data, 0)
+        )
 
-    initial_loop_state = θ, opt_state
-    θ_final, _ = jax.lax.fori_loop(0, epochs, update, initial_loop_state)
+        # Update parameters with the chosen step size
+        theta_new = jax.tree.map(lambda p, g_leaf: p - final_alpha * g_leaf, current_theta, grad)
+        return theta_new
+
+    # Main training loop (epochs)
+    θ_final = jax.lax.fori_loop(0, epochs, lambda i, current_theta: update_step(current_theta, x, y), θ)
     return θ_final
 
 # Warmup
-train_baseline_with_l2(θ, x_train, y_train, config)
+θ = initialize_network(param_key, config)
+train_jax_armijo_ls(θ, x_train, y_train, config)
 
 # Actual run
 θ = initialize_network(param_key, config)
 start_time = time()
-θ_baseline_l2 = train_baseline_with_l2(θ, x_train, y_train, config)
-θ_baseline_l2[0].W.block_until_ready()
-baseline_l2_runtime = time() - start_time
-
-baseline_l2_mse = loss_fn(θ_baseline_l2, x_validate, y_validate)
-print(f"Strategy 4 - Baseline + L2 regularization")
-print(f" Runtime: {baseline_l2_runtime:.2f}s")
-print(f" Validation MSE: {baseline_l2_mse:.6f}")
-print(f" Improvement over ADAM: {optax_adam_mse - baseline_l2_mse:.6f}")
+θ_armijo = train_jax_armijo_ls(θ, x_train, y_train, config)
+θ_armijo[0].W.block_until_ready()
+armijo_runtime = time() - start_time
+
+armijo_mse = loss_fn(θ_armijo, x_validate, y_validate)
+print(f"Strategy 2 - Baseline + Armijo Line Search")
+print(f" Runtime: {armijo_runtime:.2f}s")
+print(f" Validation MSE: {armijo_mse:.6f}")
+print(f" Improvement over ADAM: {optax_adam_mse - armijo_mse:.6f}")
 ```
 
 **Results Summary**
 
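One detail in the Armijo code worth noting: `jax.tree_util.tree_reduce` folds a binary function over every leaf array of a parameter pytree, which is how the squared gradient norm is accumulated across all weight matrices and bias vectors. A self-contained toy example (the pytree here is a stand-in, not the lecture's parameter structure):

```python
import jax.numpy as jnp
from jax import tree_util

# Toy gradient pytree with two leaves
grad = {"W": jnp.array([[1.0, 2.0]]), "b": jnp.array([3.0])}

# Accumulate the sum of squared entries over all leaves: 1 + 4 + 9 = 14
grad_norm_sq = tree_util.tree_reduce(
    lambda acc, leaf: acc + jnp.sum(leaf**2), grad, initializer=0.0
)
print(grad_norm_sq)  # 14.0
```
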
@@ -978,31 +921,23 @@ Let's compare all strategies:
 strategies_results = {
     'Strategy': [
         'Baseline (ADAM + tanh)',
-        '1. Deeper network (6 layers)',
-        '2. Deeper network + LR schedule',
-        '3. Strategy 2 + L2 regularization',
-        '4. Baseline + L2 regularization'
+        '1. Deeper network + LR schedule + L2',
+        '2. Baseline + Armijo Line Search'
     ],
     'Runtime (s)': [
         optax_adam_runtime,
-        deep_runtime,
-        deep_schedule_runtime,
         deep_l2_runtime,
-        baseline_l2_runtime
+        armijo_runtime
     ],
     'Validation MSE': [
         optax_adam_mse,
-        deep_mse,
-        deep_schedule_mse,
         deep_l2_mse,
-        baseline_l2_mse
+        armijo_mse
     ],
     'Improvement': [
         0.0,
-        float(optax_adam_mse - deep_mse),
-        float(optax_adam_mse - deep_schedule_mse),
         float(optax_adam_mse - deep_l2_mse),
-        float(optax_adam_mse - baseline_l2_mse)
+        float(optax_adam_mse - armijo_mse)
     ]
 }
 
@@ -1012,19 +947,17 @@ print(df_strategies.to_string(index=False))
 ```
 
 The experimental results reveal several lessons:
 
-1. Architecture matters: A deeper, narrower network outperformed the
-   baseline network, despite using fewer parameters (187 vs 251).
+In terms of reducing loss on the validation data, the current winner is the
+Armijo line search strategy.
 
-2. Combining strategies: Combining the deeper architecture with a learning
-   rate schedule showed that synergistic improvements are possible.
+The Armijo backtracking line search is an adaptive step size method that
+dynamically adjusts the learning rate at each iteration to ensure sufficient
+decrease in the loss function.
 
-3. Regularization helps: Adding L2 regularization (ridge penalty) can
-   improve performance by penalizing model complexity and reducing overfitting.
+Unlike fixed learning rates or predetermined schedules, it adapts to the local
+geometry of the loss landscape.
 
-4. Regularization vs architecture: Comparing strategies 3 and 4 shows whether
-   regularization is more effective with deeper architectures or simpler ones.
+This strategy and its code was contributed by [Matyas Farkas](https://www.matyasfarkas.eu/).

Suggested change (review comment on the last line):

-This strategy and its code was contributed by [Matyas Farkas](https://www.matyasfarkas.eu/).
+This strategy and its code were contributed by [Matyas Farkas](https://www.matyasfarkas.eu/).
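
Stripped of the JAX control-flow machinery, the backtracking rule in `train_jax_armijo_ls` reduces to a few lines: shrink the step size $\alpha$ until the candidate point satisfies $f(\theta - \alpha g) \le f(\theta) - c\,\alpha\,\lVert g\rVert^2$. A plain sketch (here `f` and `grad_f` are stand-in callables, not names from the lecture):

```python
import numpy as np

def armijo_step(f, grad_f, theta, alpha0=0.01, beta=0.5, c=1e-3, max_backtracks=20):
    # One gradient step with Armijo backtracking on the step size
    g = grad_f(theta)
    f0, g_sq = f(theta), float(np.sum(g**2))
    alpha = alpha0
    for _ in range(max_backtracks):
        # Accept alpha once the sufficient-decrease condition holds
        if f(theta - alpha * g) <= f0 - c * alpha * g_sq:
            break
        alpha *= beta  # otherwise backtrack
    return theta - alpha * g
```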