Add trace optimizer to ZnRND (#73)
* Add trace optimizer.

* Clean up docstrings.

* Move jnp to np.

* Add epsilon to avoid infinities (see the sketch below).
SamTov authored Jan 4, 2023
1 parent 36b921a commit eef920d
Showing 8 changed files with 837 additions and 69 deletions.
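
For orientation before the diffs: the new optimizer rescales the SGD learning rate using the trace of the empirical neural tangent kernel (NTK). Below is a minimal sketch of that rule, assuming JAX's default float32 precision; the function name is illustrative and not part of the package.

import jax.numpy as np


def trace_scaled_lr(ntk: np.ndarray, scale_factor: float, eps: float = 1e-8) -> float:
    """Learning rate used by the trace optimizer: scale_factor / (trace(NTK) + eps)."""
    # The epsilon guards against division by zero when the NTK trace vanishes,
    # the "avoid infinities" fix named in the commit message.
    return scale_factor / (np.trace(ntk) + eps)


# Example: an NTK with trace 2.0 and scale_factor 10.0 gives a learning rate of 5.0.
print(trace_scaled_lr(np.eye(2), scale_factor=10.0))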
98 changes: 98 additions & 0 deletions CI/unit_tests/optimizers/test_trace_optimizer.py
@@ -0,0 +1,98 @@
"""
ZnRND: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Module for testing the trace optimizer
"""
import jax.numpy as np
from neural_tangents import stax

from znrnd.accuracy_functions import LabelAccuracy
from znrnd.data import MNISTGenerator
from znrnd.loss_functions import CrossEntropyLoss
from znrnd.models import NTModel
from znrnd.optimizers import TraceOptimizer


class TestTraceOptimizer:
"""
Test suite for optimizers.
"""

def test_optimizer_instantiation(self):
"""
Unit test for the trace optimizer
"""
# Test default settings
my_optimizer = TraceOptimizer(scale_factor=100.0)
assert my_optimizer.scale_factor == 100.0
assert my_optimizer.rescale_interval == 1

# Test custom settings
my_optimizer = TraceOptimizer(scale_factor=50.0, rescale_interval=5)
assert my_optimizer.scale_factor == 50.0
assert my_optimizer.rescale_interval == 5

def test_apply_operation(self):
"""
Test the apply operation of the optimizer.
"""
# Set parameters.
scale_factor = 10
rescale_interval = 1

# Use MNIST data
data = MNISTGenerator(ds_size=10)

# Define the optimizer
optimizer = TraceOptimizer(
scale_factor=scale_factor, rescale_interval=rescale_interval
)

# Use small dense model
network = stax.serial(
stax.Flatten(), stax.Dense(5), stax.Relu(), stax.Dense(10)
)
# Define the model
model = NTModel(
loss_fn=CrossEntropyLoss(),
optimizer=optimizer,
input_shape=(1, 28, 28, 1),
nt_module=network,
accuracy_fn=LabelAccuracy(),
batch_size=5,
training_threshold=0.01,
)

# Get theoretical values
ntk = model.compute_ntk(data.train_ds["inputs"], normalize=False)["empirical"]
expected_lr = scale_factor / np.trace(ntk)

# Compute actual values
actual_lr = optimizer.apply_optimizer(
model_state=model.model_state,
data_set=data.train_ds["inputs"],
ntk_fn=model.compute_ntk,
epoch=1,
).opt_state

assert actual_lr.hyperparams["learning_rate"] == expected_lr
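
A note on the final assertion: apply_optimizer adds eps = 1e-8 to the trace (see trace_optimizer.py below), while expected_lr omits it. Exact equality still holds because, at JAX's default float32 precision, the epsilon is below the resolution of any realistic NTK trace. A purely illustrative check:

import jax.numpy as np

trace = np.asarray(250.0)  # float32 by default in JAX; representative NTK trace magnitude
assert 10.0 / (trace + 1e-8) == 10.0 / trace  # eps vanishes in float32 here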
616 changes: 571 additions & 45 deletions examples/CIFAR10.ipynb

Large diffs are not rendered by default.

24 changes: 10 additions & 14 deletions examples/Computing-Entropy.ipynb
@@ -30,9 +30,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/samueltovey/miniconda3/envs/zincware/lib/python3.8/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.1\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n",
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
"2023-01-02 15:09:42.140673: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2023-01-02 15:09:45.639478: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/slurm/lib:/software/opt/focal/x86_64/spack/2021.12/spack/opt/spack/linux-ubuntu20.04-x86_64_v2/gcc-11.2.0/cudnn-8.2.4.15-11.4-r5srvd2bjed7zlr75cesfus3nwsjprw6/lib64:/software/opt/focal/x86_64/spack/2021.12/spack/opt/spack/linux-ubuntu20.04-x86_64_v2/gcc-11.2.0/cuda-11.4.2-jefqkwdwi245u5nbdg5tw3ufrucvsnag/lib64:/opt/slurm/lib:\n",
"2023-01-02 15:09:45.639991: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/slurm/lib:/software/opt/focal/x86_64/spack/2021.12/spack/opt/spack/linux-ubuntu20.04-x86_64_v2/gcc-11.2.0/cudnn-8.2.4.15-11.4-r5srvd2bjed7zlr75cesfus3nwsjprw6/lib64:/software/opt/focal/x86_64/spack/2021.12/spack/opt/spack/linux-ubuntu20.04-x86_64_v2/gcc-11.2.0/cuda-11.4.2-jefqkwdwi245u5nbdg5tw3ufrucvsnag/lib64:/opt/slurm/lib:\n",
"2023-01-02 15:09:45.640027: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
"2023-01-02 15:09:53.541520: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
"/home/s/S.Tovey/miniconda3/envs/zincware/lib/python3.8/site-packages/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.\n",
" PyTreeDef = type(jax.tree_structure(None))\n"
]
},
{
@@ -201,16 +206,7 @@
"execution_count": 7,
"id": "b17ba4b0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/samueltovey/miniconda3/envs/zincware/lib/python3.8/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n",
" warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n"
]
}
],
"outputs": [],
"source": [
"# Step 2\n",
"\n",
Expand Down Expand Up @@ -253,7 +249,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"0.70654\n"
"0.73508817\n"
]
}
],
6 changes: 3 additions & 3 deletions examples/MNIST-Example.ipynb
@@ -81,7 +81,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...\u001B[0m\n"
"\u001b[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...\u001b[0m\n"
]
},
{
@@ -96,7 +96,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[1mDataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.\u001B[0m\n"
"\u001b[1mDataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.\u001b[0m\n"
]
}
],
@@ -25785,7 +25785,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.8.12"
}
},
"nbformat": 4,
27 changes: 21 additions & 6 deletions znrnd/models/jax_model.py
@@ -9,15 +9,17 @@
Description: Parent class for the Jax-based models.
"""
import logging
from typing import Callable, Tuple
from typing import Callable, Tuple, Union

import jax.numpy as np
import jax.random
import numpy as onp
import optax
from flax.training.train_state import TrainState
from tqdm import trange

from znrnd.accuracy_functions.accuracy_function import AccuracyFunction
from znrnd.optimizers.trace_optimizer import TraceOptimizer
from znrnd.utils.prng import PRNGKey

logger = logging.getLogger(__name__)
@@ -31,13 +33,14 @@ class JaxModel:
def __init__(
self,
loss_fn: Callable,
optimizer: Callable,
optimizer: Union[Callable, TraceOptimizer],
input_shape: tuple,
training_threshold: float,
accuracy_fn: AccuracyFunction = None,
seed: int = None,
):
"""Construct a znrnd model.
"""
Construct a znrnd model.
Parameters
----------
@@ -104,9 +107,13 @@ def _create_train_state(
"""
params = self._init_params(kernel_init, bias_init)

return TrainState.create(
apply_fn=self.apply_fn, params=params, tx=self.optimizer
)
        # Set a dummy optimizer for the trace-optimizer case.
if isinstance(self.optimizer, TraceOptimizer):
optimizer = optax.sgd(1.0)
else:
optimizer = self.optimizer

return TrainState.create(apply_fn=self.apply_fn, params=params, tx=optimizer)

def _train_step(self, state: TrainState, batch: dict):
"""
@@ -344,6 +351,14 @@ def train_model(
for i in loading_bar:
loading_bar.set_description(f"Epoch: {i}")

if isinstance(self.optimizer, TraceOptimizer):
state = self.optimizer.apply_optimizer(
model_state=state,
data_set=train_ds["inputs"],
ntk_fn=self.compute_ntk,
epoch=i,
)

state, train_metrics = self._train_epoch(
state, train_ds, batch_size=batch_size
)
3 changes: 2 additions & 1 deletion znrnd/models/nt_model.py
@@ -35,6 +35,7 @@

from znrnd.accuracy_functions.accuracy_function import AccuracyFunction
from znrnd.models.jax_model import JaxModel
from znrnd.optimizers.trace_optimizer import TraceOptimizer
from znrnd.utils import normalize_covariance_matrix

logger = logging.getLogger(__name__)
@@ -48,7 +49,7 @@ class NTModel(JaxModel):
def __init__(
self,
loss_fn: Callable,
optimizer: Callable,
optimizer: Union[Callable, TraceOptimizer],
input_shape: tuple,
training_threshold: float = 0.01,
nt_module: serial = None,
30 changes: 30 additions & 0 deletions znrnd/optimizers/__init__.py
@@ -0,0 +1,30 @@
"""
ZnRND: A zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Init file for the optimizer package
"""
from znrnd.optimizers.trace_optimizer import TraceOptimizer

__all__ = [TraceOptimizer.__name__]
102 changes: 102 additions & 0 deletions znrnd/optimizers/trace_optimizer.py
@@ -0,0 +1,102 @@
"""
ZnRND: A zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Module for the trace optimizer.
"""
from dataclasses import dataclass
from typing import Callable

import jax.numpy as np
import optax
from flax.training.train_state import TrainState


@dataclass
class TraceOptimizer:
"""
    Class implementation of the trace optimizer.
Attributes
----------
scale_factor : float
Scale factor to apply to the optimizer.
rescale_interval : int
Number of epochs to wait before re-scaling the learning rate.
"""

scale_factor: float
    rescale_interval: int = 1

@optax.inject_hyperparams
def optimizer(self, learning_rate):
return optax.sgd(learning_rate)

def apply_optimizer(
self,
model_state: TrainState,
data_set: np.ndarray,
ntk_fn: Callable,
epoch: int,
):
"""
Apply the optimizer to a model state.
Parameters
----------
model_state : TrainState
Current state of the model
        data_set : np.ndarray
Data-set to use in the computation.
ntk_fn : Callable
Function to use for the NTK computation
epoch : int
Current epoch
Returns
-------
new_state : TrainState
New state of the model
"""
eps = 1e-8
# Check if the update should be performed.
if epoch % self.rescale_interval == 0:
# Compute the ntk trace.
ntk = ntk_fn(data_set, normalize=False)["empirical"]
trace = np.trace(ntk)

# Create the new optimizer.
new_optimizer = self.optimizer(self.scale_factor / (trace + eps))

# Create the new state
new_state = TrainState.create(
apply_fn=model_state.apply_fn,
params=model_state.params,
tx=new_optimizer,
)
else:
# If no update is needed, return the old state.
new_state = model_state

return new_state
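
For context, a hedged usage sketch mirroring the train_model change in znrnd/models/jax_model.py above; model and train_ds are assumed to be set up as in the unit test, not defined here.

from znrnd.optimizers import TraceOptimizer

optimizer = TraceOptimizer(scale_factor=10.0, rescale_interval=5)
state = model.model_state  # assumed: an NTModel constructed with this optimizer

for epoch in range(1, 11):
    # Every rescale_interval epochs, apply_optimizer recomputes the NTK trace
    # and rebuilds the TrainState with the rescaled learning rate; otherwise
    # it returns the state unchanged.
    state = optimizer.apply_optimizer(
        model_state=state,
        data_set=train_ds["inputs"],
        ntk_fn=model.compute_ntk,
        epoch=epoch,
    )
    # ... run the usual training epoch with state ...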
