Add trace optimizer to ZnRND #73

Merged
merged 4 commits on Jan 4, 2023
98 changes: 98 additions & 0 deletions CI/unit_tests/optimizers/test_trace_optimizer.py
@@ -0,0 +1,98 @@
"""
ZnRND: A zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for testing the trace optimizer.
"""
import jax.numpy as np
from neural_tangents import stax

from znrnd.accuracy_functions import LabelAccuracy
from znrnd.data import MNISTGenerator
from znrnd.loss_functions import CrossEntropyLoss
from znrnd.models import NTModel
from znrnd.optimizers import TraceOptimizer


class TestTraceOptimizer:
"""
Test suite for the trace optimizer.
"""

def test_optimizer_instantiation(self):
"""
Unit test for trace optimizer instantiation.
"""
# Test default settings
my_optimizer = TraceOptimizer(scale_factor=100.0)
assert my_optimizer.scale_factor == 100.0
assert my_optimizer.rescale_interval == 1

# Test custom settings
my_optimizer = TraceOptimizer(scale_factor=50.0, rescale_interval=5)
assert my_optimizer.scale_factor == 50.0
assert my_optimizer.rescale_interval == 5

def test_apply_operation(self):
"""
Test the apply operation of the optimizer.
"""
# Set parameters.
scale_factor = 10
rescale_interval = 1

# Use MNIST data
data = MNISTGenerator(ds_size=10)

# Define the optimizer
optimizer = TraceOptimizer(
scale_factor=scale_factor, rescale_interval=rescale_interval
)

# Use small dense model
network = stax.serial(
stax.Flatten(), stax.Dense(5), stax.Relu(), stax.Dense(10)
)
# Define the model
model = NTModel(
loss_fn=CrossEntropyLoss(),
optimizer=optimizer,
input_shape=(1, 28, 28, 1),
nt_module=network,
accuracy_fn=LabelAccuracy(),
batch_size=5,
training_threshold=0.01,
)

# Get theoretical values
ntk = model.compute_ntk(data.train_ds["inputs"], normalize=False)["empirical"]
expected_lr = scale_factor / np.trace(ntk)

# Compute actual values
actual_lr = optimizer.apply_optimizer(
model_state=model.model_state,
data_set=data.train_ds["inputs"],
ntk_fn=model.compute_ntk,
epoch=1,
).opt_state

assert actual_lr.hyperparams["learning_rate"] == expected_lr
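For reference, the learning rate this test checks is the scaled inverse of the NTK trace, lr = scale_factor / Tr(NTK) (the implementation below adds a small eps to the denominator for numerical safety). A minimal self-contained sketch of the arithmetic, using a hypothetical 3×3 empirical NTK:

```python
import jax.numpy as np

# Hypothetical empirical NTK matrix, for illustration only.
ntk = np.array(
    [
        [2.0, 0.1, 0.0],
        [0.1, 1.5, 0.2],
        [0.0, 0.2, 1.0],
    ]
)

scale_factor = 10.0
learning_rate = scale_factor / np.trace(ntk)  # 10.0 / 4.5 ≈ 2.2222
```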
616 changes: 571 additions & 45 deletions examples/CIFAR10.ipynb

Large diffs are not rendered by default.

24 changes: 10 additions & 14 deletions examples/Computing-Entropy.ipynb
@@ -30,9 +30,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/samueltovey/miniconda3/envs/zincware/lib/python3.8/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.1\n",
" warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n",
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
"2023-01-02 15:09:42.140673: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2023-01-02 15:09:45.639478: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/slurm/lib:/software/opt/focal/x86_64/spack/2021.12/spack/opt/spack/linux-ubuntu20.04-x86_64_v2/gcc-11.2.0/cudnn-8.2.4.15-11.4-r5srvd2bjed7zlr75cesfus3nwsjprw6/lib64:/software/opt/focal/x86_64/spack/2021.12/spack/opt/spack/linux-ubuntu20.04-x86_64_v2/gcc-11.2.0/cuda-11.4.2-jefqkwdwi245u5nbdg5tw3ufrucvsnag/lib64:/opt/slurm/lib:\n",
"2023-01-02 15:09:45.639991: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/slurm/lib:/software/opt/focal/x86_64/spack/2021.12/spack/opt/spack/linux-ubuntu20.04-x86_64_v2/gcc-11.2.0/cudnn-8.2.4.15-11.4-r5srvd2bjed7zlr75cesfus3nwsjprw6/lib64:/software/opt/focal/x86_64/spack/2021.12/spack/opt/spack/linux-ubuntu20.04-x86_64_v2/gcc-11.2.0/cuda-11.4.2-jefqkwdwi245u5nbdg5tw3ufrucvsnag/lib64:/opt/slurm/lib:\n",
"2023-01-02 15:09:45.640027: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
"2023-01-02 15:09:53.541520: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n",
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n",
"/home/s/S.Tovey/miniconda3/envs/zincware/lib/python3.8/site-packages/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.\n",
" PyTreeDef = type(jax.tree_structure(None))\n"
]
},
{
@@ -201,16 +206,7 @@
"execution_count": 7,
"id": "b17ba4b0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/samueltovey/miniconda3/envs/zincware/lib/python3.8/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n",
" warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '\n"
]
}
],
"outputs": [],
"source": [
"# Step 2\n",
"\n",
Expand Down Expand Up @@ -253,7 +249,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"0.70654\n"
"0.73508817\n"
]
}
],
6 changes: 3 additions & 3 deletions examples/MNIST-Example.ipynb
@@ -81,7 +81,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...\u001B[0m\n"
"\u001b[1mDownloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to ~/tensorflow_datasets/mnist/3.0.1...\u001b[0m\n"
]
},
{
@@ -96,7 +96,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"\u001B[1mDataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.\u001B[0m\n"
"\u001b[1mDataset mnist downloaded and prepared to ~/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.\u001b[0m\n"
]
}
],
@@ -25785,7 +25785,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.8.12"
}
},
"nbformat": 4,
27 changes: 21 additions & 6 deletions znrnd/models/jax_model.py
@@ -9,15 +9,17 @@
Description: Parent class for the Jax-based models.
"""
import logging
from typing import Callable, Tuple
from typing import Callable, Tuple, Union

import jax.numpy as np
import jax.random
import numpy as onp
import optax
from flax.training.train_state import TrainState
from tqdm import trange

from znrnd.accuracy_functions.accuracy_function import AccuracyFunction
from znrnd.optimizers.trace_optimizer import TraceOptimizer
from znrnd.utils.prng import PRNGKey

logger = logging.getLogger(__name__)
@@ -31,13 +33,14 @@ class JaxModel:
def __init__(
self,
loss_fn: Callable,
optimizer: Callable,
optimizer: Union[Callable, TraceOptimizer],
input_shape: tuple,
training_threshold: float,
accuracy_fn: AccuracyFunction = None,
seed: int = None,
):
"""Construct a znrnd model.
"""
Construct a znrnd model.

Parameters
----------
@@ -104,9 +107,13 @@ def _create_train_state(
"""
params = self._init_params(kernel_init, bias_init)

return TrainState.create(
apply_fn=self.apply_fn, params=params, tx=self.optimizer
)
# Set dummy optimizer for case of trace optimizer.
if isinstance(self.optimizer, TraceOptimizer):
optimizer = optax.sgd(1.0)
else:
optimizer = self.optimizer

return TrainState.create(apply_fn=self.apply_fn, params=params, tx=optimizer)

def _train_step(self, state: TrainState, batch: dict):
"""
@@ -344,6 +351,14 @@ def train_model(
for i in loading_bar:
loading_bar.set_description(f"Epoch: {i}")

if isinstance(self.optimizer, TraceOptimizer):
state = self.optimizer.apply_optimizer(
model_state=state,
data_set=train_ds["inputs"],
ntk_fn=self.compute_ntk,
epoch=i,
)

state, train_metrics = self._train_epoch(
state, train_ds, batch_size=batch_size
)
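Taken together, the changes above mean a model can be trained with the trace optimizer without any other change to the training call. A rough end-to-end sketch, mirroring the unit test in this PR (the exact `train_model` keyword arguments are an assumption here, not taken from this diff):

```python
from neural_tangents import stax

from znrnd.data import MNISTGenerator
from znrnd.loss_functions import CrossEntropyLoss
from znrnd.models import NTModel
from znrnd.optimizers import TraceOptimizer

data = MNISTGenerator(ds_size=10)

model = NTModel(
    loss_fn=CrossEntropyLoss(),
    optimizer=TraceOptimizer(scale_factor=10.0, rescale_interval=2),
    input_shape=(1, 28, 28, 1),
    nt_module=stax.serial(
        stax.Flatten(), stax.Dense(5), stax.Relu(), stax.Dense(10)
    ),
)

# With rescale_interval=2, the NTK trace is recomputed and a new learning
# rate injected on every second epoch; the other epochs reuse the last one.
model.train_model(train_ds=data.train_ds, test_ds=data.test_ds, epochs=4)
```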
3 changes: 2 additions & 1 deletion znrnd/models/nt_model.py
@@ -35,6 +35,7 @@

from znrnd.accuracy_functions.accuracy_function import AccuracyFunction
from znrnd.models.jax_model import JaxModel
from znrnd.optimizers.trace_optimizer import TraceOptimizer
from znrnd.utils import normalize_covariance_matrix

logger = logging.getLogger(__name__)
@@ -48,7 +49,7 @@ class NTModel(JaxModel):
def __init__(
self,
loss_fn: Callable,
optimizer: Callable,
optimizer: Union[Callable, TraceOptimizer],
input_shape: tuple,
training_threshold: float = 0.01,
nt_module: serial = None,
30 changes: 30 additions & 0 deletions znrnd/optimizers/__init__.py
@@ -0,0 +1,30 @@
"""
ZnRND: A zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Init file for the optimizers package.
"""
from znrnd.optimizers.trace_optimizer import TraceOptimizer

__all__ = [TraceOptimizer.__name__]
102 changes: 102 additions & 0 deletions znrnd/optimizers/trace_optimizer.py
@@ -0,0 +1,102 @@
"""
ZnRND: A zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html

SPDX-License-Identifier: EPL-2.0

Copyright Contributors to the zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Module for the trace optimizer.
"""
from dataclasses import dataclass
from typing import Callable

import jax.numpy as np
import optax
from flax.training.train_state import TrainState


@dataclass
class TraceOptimizer:
"""
Class implementation of the trace optimizer.

Attributes
----------
scale_factor : float
Scale factor to apply to the optimizer.
rescale_interval : int
Number of epochs to wait before re-scaling the learning rate.
"""

scale_factor: float
rescale_interval: int = 1

@optax.inject_hyperparams
def optimizer(self, learning_rate):
return optax.sgd(learning_rate)

def apply_optimizer(
self,
model_state: TrainState,
data_set: np.ndarray,
ntk_fn: Callable,
epoch: int,
):
"""
Apply the optimizer to a model state.

Parameters
----------
model_state : TrainState
Current state of the model.
data_set : np.ndarray
Data set to use in the computation.
ntk_fn : Callable
Function to use for the NTK computation.
epoch : int
Current epoch.

Returns
-------
new_state : TrainState
New state of the model.
"""
eps = 1e-8
# Check if the update should be performed.
if epoch % self.rescale_interval == 0:
# Compute the ntk trace.
ntk = ntk_fn(data_set, normalize=False)["empirical"]
trace = np.trace(ntk)

# Create the new optimizer.
new_optimizer = self.optimizer(self.scale_factor / (trace + eps))

# Create the new state
new_state = TrainState.create(
apply_fn=model_state.apply_fn,
params=model_state.params,
tx=new_optimizer,
)
else:
# If no update is needed, return the old state.
new_state = model_state

return new_state
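The `optimizer` property above relies on `optax.inject_hyperparams`, which exposes the learning rate as an entry in the optimizer state so it can be overwritten at each rescale step and read back via `opt_state.hyperparams` (as the unit test does). A minimal standalone sketch of that optax pattern, independent of ZnRND:

```python
import optax

# Wrapping optax.sgd with inject_hyperparams stores the learning rate in
# the optimizer state instead of baking it into the update closure.
make_sgd = optax.inject_hyperparams(optax.sgd)

opt = make_sgd(learning_rate=0.5)
opt_state = opt.init({"w": 0.0})

print(opt_state.hyperparams["learning_rate"])  # 0.5
```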