Commit message:

* Add trace optimizer.
* Clean up docstrings.
* Move jnp to np.
* Add epsilon to avoid infinities.
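The last bullet refers to the learning-rate computation in the new TraceOptimizer below: the rate is set proportional to 1 / Tr(NTK), so a vanishing trace would otherwise produce an infinite value. A minimal sketch of the guarded computation (the helper name is illustrative, not part of the diff):

import jax.numpy as np

# Hypothetical helper, not part of this commit: it restates the guarded
# learning-rate computation used by the new TraceOptimizer.
def guarded_lr(scale_factor: float, ntk: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Offsetting the NTK trace by eps keeps the rate finite when the
    # trace is zero.
    return scale_factor / (np.trace(ntk) + eps)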
Showing 8 changed files with 837 additions and 69 deletions.
@@ -0,0 +1,98 @@
""" | ||
ZnRND: A Zincwarecode package. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the Zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
Summary | ||
------- | ||
Module for testing the trace optimizer | ||
""" | ||
import jax.numpy as np | ||
from neural_tangents import stax | ||
|
||
from znrnd.accuracy_functions import LabelAccuracy | ||
from znrnd.data import MNISTGenerator | ||
from znrnd.loss_functions import CrossEntropyLoss | ||
from znrnd.models import NTModel | ||
from znrnd.optimizers import TraceOptimizer | ||
|
||
|
||
class TestTraceOptimizer: | ||
""" | ||
Test suite for optimizers. | ||
""" | ||
|
||
def test_optimizer_instantiation(self): | ||
""" | ||
Unit test for the trace optimizer | ||
""" | ||
# Test default settings | ||
my_optimizer = TraceOptimizer(scale_factor=100.0) | ||
assert my_optimizer.scale_factor == 100.0 | ||
assert my_optimizer.rescale_interval == 1 | ||
|
||
# Test custom settings | ||
my_optimizer = TraceOptimizer(scale_factor=50.0, rescale_interval=5) | ||
assert my_optimizer.scale_factor == 50.0 | ||
assert my_optimizer.rescale_interval == 5 | ||
|
||
    def test_apply_operation(self):
        """
        Test the apply operation of the optimizer.
        """
        # Set parameters.
        scale_factor = 10
        rescale_interval = 1

        # Use MNIST data.
        data = MNISTGenerator(ds_size=10)

        # Define the optimizer.
        optimizer = TraceOptimizer(
            scale_factor=scale_factor, rescale_interval=rescale_interval
        )

        # Use a small dense model.
        network = stax.serial(
            stax.Flatten(), stax.Dense(5), stax.Relu(), stax.Dense(10)
        )

        # Define the model.
        model = NTModel(
            loss_fn=CrossEntropyLoss(),
            optimizer=optimizer,
            input_shape=(1, 28, 28, 1),
            nt_module=network,
            accuracy_fn=LabelAccuracy(),
            batch_size=5,
            training_threshold=0.01,
        )

        # Get theoretical values. The optimizer adds a small epsilon to the
        # trace before dividing, so the expected value includes it as well.
        ntk = model.compute_ntk(data.train_ds["inputs"], normalize=False)["empirical"]
        expected_lr = scale_factor / (np.trace(ntk) + 1e-8)

        # Compute actual values.
        actual_lr = optimizer.apply_optimizer(
            model_state=model.model_state,
            data_set=data.train_ds["inputs"],
            ntk_fn=model.compute_ntk,
            epoch=1,
        ).opt_state

        assert actual_lr.hyperparams["learning_rate"] == expected_lr
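In formula form, the assertion checks the trace-scaling rule this commit introduces,

    learning_rate = scale_factor / (Tr(NTK) + eps)

with scale_factor = 10 and eps = 1e-8 in this test.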
(Large diffs are not rendered by default; the remaining changed files are omitted below.)
@@ -0,0 +1,30 @@
""" | ||
ZnRND: A zincwarecode package. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
Summary | ||
------- | ||
Init file for the optimizer package | ||
""" | ||
from znrnd.optimizers.trace_optimizer import TraceOptimizer | ||
|
||
__all__ = [TraceOptimizer.__name__] |
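This re-export is what lets downstream code import the optimizer from the subpackage root rather than the module path, as the test file above already does:

from znrnd.optimizers import TraceOptimizer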
@@ -0,0 +1,102 @@
""" | ||
ZnRND: A zincwarecode package. | ||
License | ||
------- | ||
This program and the accompanying materials are made available under the terms | ||
of the Eclipse Public License v2.0 which accompanies this distribution, and is | ||
available at https://www.eclipse.org/legal/epl-v20.html | ||
SPDX-License-Identifier: EPL-2.0 | ||
Copyright Contributors to the zincwarecode Project. | ||
Contact Information | ||
------------------- | ||
email: [email protected] | ||
github: https://github.com/zincware | ||
web: https://zincwarecode.com/ | ||
Citation | ||
-------- | ||
If you use this module please cite us with: | ||
Summary | ||
------- | ||
Module for the trace optimizer. | ||
""" | ||
from dataclasses import dataclass | ||
from typing import Callable | ||
|
||
import jax.numpy as np | ||
import optax | ||
from flax.training.train_state import TrainState | ||
|
||
|
||
@dataclass
class TraceOptimizer:
    """
    Class implementation of the trace optimizer.

    Attributes
    ----------
    scale_factor : float
        Scale factor to apply to the optimizer.
    rescale_interval : int
        Number of epochs to wait before re-scaling the learning rate.
    """

    scale_factor: float
    rescale_interval: int = 1

    @optax.inject_hyperparams
    def optimizer(self, learning_rate):
        return optax.sgd(learning_rate)

    def apply_optimizer(
        self,
        model_state: TrainState,
        data_set: np.ndarray,
        ntk_fn: Callable,
        epoch: int,
    ):
        """
        Apply the optimizer to a model state.

        Parameters
        ----------
        model_state : TrainState
            Current state of the model.
        data_set : np.ndarray
            Data set to use in the computation.
        ntk_fn : Callable
            Function to use for the NTK computation.
        epoch : int
            Current epoch.

        Returns
        -------
        new_state : TrainState
            New state of the model.
        """
        eps = 1e-8
        # Check if the update should be performed.
        if epoch % self.rescale_interval == 0:
            # Compute the NTK trace.
            ntk = ntk_fn(data_set, normalize=False)["empirical"]
            trace = np.trace(ntk)

            # Create the new optimizer; eps guards against division by zero
            # when the trace vanishes.
            new_optimizer = self.optimizer(self.scale_factor / (trace + eps))

            # Create the new state.
            new_state = TrainState.create(
                apply_fn=model_state.apply_fn,
                params=model_state.params,
                tx=new_optimizer,
            )
        else:
            # If no update is needed, return the old state.
            new_state = model_state

        return new_state
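For orientation, a minimal sketch of how apply_optimizer is intended to sit in a training loop, assuming the NTModel interface used in the test above; model, train_inputs, and n_epochs are hypothetical stand-ins, not names from this commit:

# Illustrative wiring only; `model`, `train_inputs`, and `n_epochs` are
# assumed names, not part of this commit.
optimizer = TraceOptimizer(scale_factor=10.0, rescale_interval=5)
for epoch in range(1, n_epochs + 1):
    # Every `rescale_interval` epochs, the NTK trace is recomputed and the
    # SGD transform rebuilt with learning_rate = scale_factor / (trace + eps).
    model.model_state = optimizer.apply_optimizer(
        model_state=model.model_state,
        data_set=train_inputs,
        ntk_fn=model.compute_ntk,
        epoch=epoch,
    )
    # ... one epoch of training with the rescaled optimizer ...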