
Commit 7438002

jereliu authored and edward-bot committed
Removes unnecessary ViT-GP hyper-parameters.
PiperOrigin-RevId: 388484029
1 parent 194a984 commit 7438002

File tree

edward2/jax/nn/random_feature.py
edward2/jax/nn/random_feature_test.py
edward2/jax/nn/utils.py
edward2/jax/nn/utils_test.py

4 files changed: +158 −6 lines changed

edward2/jax/nn/random_feature.py

Lines changed: 10 additions & 2 deletions

@@ -27,6 +27,9 @@
   [3]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
        Machines. In _Neural Information Processing Systems_, 2007.
        https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
+  [4]: Zhiyun Lu, Eugene Ie, Fei Sha. Uncertainty Estimation with Infinitesimal
+       Jackknife. _arXiv preprint arXiv:2006.07584_, 2020.
+       https://arxiv.org/abs/2006.07584
 """
 import dataclasses
 import functools

@@ -47,8 +50,13 @@

 # Default config for random features.
 default_rbf_activation = jnp.cos
-default_rbf_kernel_init = nn.initializers.normal(stddev=1.)
 default_rbf_bias_init = nn.initializers.uniform(scale=2. * jnp.pi)
+# Using "he_normal" style random feature distribution. Effectively, this is
+# equivalent to approximating a RBF kernel but with the input standardized by
+# its dimensionality (i.e., input_scaled = input * sqrt(2. / dim_input)) and
+# empirically leads to better performance for neural network inputs.
+default_rbf_kernel_init = nn.initializers.variance_scaling(
+    scale=2.0, mode='fan_in', distribution='normal')

 # Default field value for kwargs, to be used for data class declaration.
 default_kwarg_dict = lambda: dataclasses.field(default_factory=dict)

@@ -149,7 +157,7 @@ class RandomFourierFeatures(nn.Module):
     dtype: the dtype of the computation (default: float32).
   """
   features: int
-  feature_scale: Optional[jnp.float32] = None
+  feature_scale: Optional[jnp.float32] = 1.
   activation: Callable[[Array], Array] = default_rbf_activation
   kernel_init: Initializer = default_rbf_kernel_init
   bias_init: Initializer = default_rbf_bias_init
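
Note: the "he_normal"-style default above is just a rescaled standard normal. With distribution='normal', variance_scaling draws W ~ N(0, scale / fan_in), i.e. N(0, 2 / dim_input) here, which is the same draw as N(0, 1) weights applied to inputs scaled by sqrt(2 / dim_input). A minimal sketch of that equivalence (shapes below are arbitrary, chosen only for illustration):

import jax
import jax.numpy as jnp
import flax.linen as nn

key = jax.random.PRNGKey(0)
dim_input, num_features = 8, 1024

# New default: he_normal-style weights, W ~ N(0, 2 / dim_input).
w_he = nn.initializers.variance_scaling(
    scale=2.0, mode='fan_in', distribution='normal')(key, (dim_input, num_features))
# Classic RBF weights, W ~ N(0, 1), drawn from the same key.
w_std = nn.initializers.normal(stddev=1.)(key, (dim_input, num_features))

x = jnp.ones((3, dim_input))
# Both initializers scale the same underlying normal draw, so these should match.
print(jnp.allclose(x @ w_he, (x * jnp.sqrt(2. / dim_input)) @ w_std))  # True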

edward2/jax/nn/random_feature_test.py

Lines changed: 19 additions & 3 deletions

@@ -21,6 +21,8 @@

 import edward2.jax as ed

+import flax.linen as nn
+
 import jax
 import jax.numpy as jnp
 import numpy as np

@@ -94,6 +96,10 @@ def setUp(self):
     self.x_test = _generate_normal_data(
         self.num_test_sample, self.num_data_dim, seed=21)

+    # Uses classic RBF random feature distribution.
+    self.hidden_kwargs = dict(
+        kernel_init=nn.initializers.normal(stddev=1.), feature_scale=None)
+
     self.rbf_approx_maximum_tol = 5e-3
     self.rbf_approx_average_tol = 5e-4
     self.primal_dual_maximum_diff = 1e-6

@@ -105,6 +111,7 @@ def one_step_rfgp_result(self, train_data, test_data, **eval_kwargs):
         features=1,
         hidden_features=self.num_random_features,
         normalize_input=False,
+        hidden_kwargs=self.hidden_kwargs,
         covmat_kwargs=dict(ridge_penalty=self.ridge_penalty))

     # Computes posterior covariance on test data.

@@ -231,13 +238,19 @@ def setUp(self):
     self.x_test = _generate_normal_data(self.num_train_sample,
                                         self.num_data_dim)

+    # Uses classic RBF random feature distribution.
+    self.hidden_kwargs = dict(
+        kernel_init=nn.initializers.normal(stddev=1.), feature_scale=None)
+
     self.kernel_approx_tolerance = dict(atol=5e-2, rtol=1e-2)

   def test_random_feature_mutable_collection(self):
     """Tests if RFF variables are properly nested under a mutable collection."""
     rng = jax.random.PRNGKey(self.seed)
     rff_layer = ed.nn.RandomFourierFeatures(
-        features=self.num_random_features, collection_name=self.collection_name)
+        features=self.num_random_features,
+        collection_name=self.collection_name,
+        **self.hidden_kwargs)

     # Computes forward pass with mutable collection specified.
     init_vars = rff_layer.init(rng, self.x_train)

@@ -260,7 +273,8 @@ def test_random_feature_mutable_collection(self):
   def test_random_feature_nd_input(self, input_shape):
     rng = jax.random.PRNGKey(self.seed)
     x = jnp.ones(input_shape)
-    rff_layer = ed.nn.RandomFourierFeatures(features=self.num_random_features)
+    rff_layer = ed.nn.RandomFourierFeatures(
+        features=self.num_random_features, **self.hidden_kwargs)
     y, _ = rff_layer.init_with_output(rng, x)

     expected_output_shape = input_shape[:-1] + (self.num_random_features,)

@@ -270,7 +284,9 @@ def test_random_feature_kernel_approximation(self):
     """Tests if default RFF layer approximates a RBF kernel matrix."""
     rng = jax.random.PRNGKey(self.seed)
     rff_layer = ed.nn.RandomFourierFeatures(
-        features=self.num_random_features, collection_name=self.collection_name)
+        features=self.num_random_features,
+        collection_name=self.collection_name,
+        **self.hidden_kwargs)

     # Extracts random features by computing forward pass.
     init_vars = rff_layer.init(rng, self.x_train)
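
Note: as these tests show, the pre-change behavior stays available by overriding the new defaults. A minimal sketch of doing the same in user code (the feature count is arbitrary; kernel_init and feature_scale are the RandomFourierFeatures fields shown in the diff above):

import flax.linen as nn
import edward2.jax as ed

# Recovers the classic RBF random features: N(0, 1) kernel weights and no
# input scaling, matching the defaults before this commit.
rff_layer = ed.nn.RandomFourierFeatures(
    features=1024,
    kernel_init=nn.initializers.normal(stddev=1.),
    feature_scale=None)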

edward2/jax/nn/utils.py

Lines changed: 54 additions & 1 deletion

@@ -15,11 +15,12 @@

 """JAX layer and utils."""

-from typing import Iterable, Callable
+from typing import Callable, Iterable, Optional

 from jax import random
 import jax.numpy as jnp

+Array = jnp.ndarray
 DType = type(jnp.float32)
 InitializeFn = Callable[[jnp.ndarray, Iterable[int], DType], jnp.ndarray]

@@ -48,3 +49,55 @@ def initializer(key, shape, dtype=jnp.float32):
     x = random.normal(key, shape, dtype) * (-random_sign_init) + 1.0
     return x.astype(dtype)
   return initializer
+
+
+def mean_field_logits(logits: Array,
+                      covmat: Optional[Array] = None,
+                      mean_field_factor: float = 1.,
+                      likelihood: str = 'logistic'):
+  """Adjusts the model logits so its softmax approximates the posterior mean [4].
+
+  Arguments:
+    logits: A float ndarray of shape (batch_size, num_classes).
+    covmat: A float ndarray of shape (batch_size,). If None then it is assumed
+      to be a vector of 1.'s.
+    mean_field_factor: The scale factor for the mean-field approximation, used
+      to adjust the influence of posterior variance in the posterior mean
+      approximation. If covmat is None then it is used as the scaling
+      parameter for temperature scaling.
+    likelihood: Name of the likelihood for integration in the
+      Gaussian-approximated latent posterior. Must be one of ('logistic',
+      'binary_logistic', 'poisson').
+
+  Returns:
+    A float ndarray of uncertainty-adjusted logits, shape
+    (batch_size, num_classes).
+
+  Raises:
+    ValueError: If likelihood is not one of ('logistic', 'binary_logistic',
+      'poisson').
+  """
+  if likelihood not in ('logistic', 'binary_logistic', 'poisson'):
+    raise ValueError(
+        f"likelihood must be one of ('logistic', 'binary_logistic', "
+        f"'poisson'), got {likelihood}.")
+
+  if mean_field_factor < 0:
+    return logits
+
+  # Defines predictive variance.
+  variances = 1. if covmat is None else covmat
+
+  # Computes scaling coefficient for mean-field approximation.
+  if likelihood == 'poisson':
+    logits_scale = jnp.exp(-variances * mean_field_factor / 2.)  # pylint:disable=invalid-unary-operand-type
+  else:
+    logits_scale = jnp.sqrt(1. + variances * mean_field_factor)
+
+  # Pads logits_scale to compatible dimension.
+  while logits_scale.ndim < logits.ndim:
+    logits_scale = jnp.expand_dims(logits_scale, axis=-1)
+
+  return logits / logits_scale
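
Note: for the logistic likelihood the adjustment computes logits / sqrt(1 + mean_field_factor * variance), following the mean-field approximation of [4]; with covmat=None it reduces to temperature scaling with temperature sqrt(1 + mean_field_factor). A minimal usage sketch (batch shape and variance values are made up):

import jax
import jax.numpy as jnp
import edward2.jax as ed

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
covmat = jnp.full((4,), 1.5)  # per-example posterior variances

# Mean-field adjustment: divides logits by sqrt(1 + 2. * 1.5) == 2.
adjusted = ed.nn.utils.mean_field_logits(logits, covmat, mean_field_factor=2.)

# Plain temperature scaling: divides logits by sqrt(1 + 3.) == 2.
tempered = ed.nn.utils.mean_field_logits(
    logits, covmat=None, mean_field_factor=3.)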

edward2/jax/nn/utils_test.py

Lines changed: 75 additions & 0 deletions

@@ -0,0 +1,75 @@
+# coding=utf-8
+# Copyright 2021 The Edward2 Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for utils."""
+from absl.testing import absltest
+from absl.testing import parameterized
+
+import edward2.jax as ed
+
+import jax
+import jax.numpy as jnp
+
+import numpy as np
+import tensorflow as tf
+
+
+class MeanFieldLogitsTest(parameterized.TestCase, tf.test.TestCase):
+
+  def testMeanFieldLogitsLikelihood(self):
+    """Tests if scaling is correct under different likelihoods."""
+    batch_size = 10
+    num_classes = 12
+    variance = 1.5
+    mean_field_factor = 2.
+
+    rng_key = jax.random.PRNGKey(0)
+    logits = jax.random.normal(rng_key, (batch_size, num_classes))
+    covmat = jnp.ones(batch_size) * variance
+
+    logits_logistic = ed.nn.utils.mean_field_logits(
+        logits, covmat, mean_field_factor=mean_field_factor)
+    logits_poisson = ed.nn.utils.mean_field_logits(
+        logits,
+        covmat,
+        mean_field_factor=mean_field_factor,
+        likelihood='poisson')
+
+    self.assertAllClose(logits_logistic, logits / 2., atol=1e-4)
+    self.assertAllClose(logits_poisson, logits * np.exp(1.5), atol=1e-4)
+
+  def testMeanFieldLogitsTemperatureScaling(self):
+    """Tests using mean_field_logits as a temperature scaling method."""
+    batch_size = 10
+    num_classes = 12
+
+    rng_key = jax.random.PRNGKey(0)
+    logits = jax.random.normal(rng_key, (batch_size, num_classes))
+
+    # Tests that there's no change to logits when mean_field_factor < 0.
+    logits_no_change = ed.nn.utils.mean_field_logits(
+        logits, covmat=None, mean_field_factor=-1)
+
+    # Tests that mean_field_logits functions as a temperature scaling method
+    # when mean_field_factor > 0, with temperature = sqrt(1. + mean_field_factor).
+    logits_scale_by_two = ed.nn.utils.mean_field_logits(
+        logits, covmat=None, mean_field_factor=3.)
+
+    self.assertAllClose(logits_no_change, logits, atol=1e-4)
+    self.assertAllClose(logits_scale_by_two, logits / 2., atol=1e-4)
+
+
+if __name__ == '__main__':
+  absltest.main()
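
Note: the expected values in these tests follow directly from the scaling rules in mean_field_logits; a quick arithmetic check (plain NumPy, constants taken from the tests above):

import numpy as np

# Logistic: scale = sqrt(1 + variance * mean_field_factor) = sqrt(1 + 1.5 * 2.) = 2,
# so the adjusted logits equal logits / 2.
assert np.isclose(np.sqrt(1. + 1.5 * 2.), 2.)

# Poisson: scale = exp(-variance * mean_field_factor / 2) = exp(-1.5), and
# dividing by that scale multiplies the logits by exp(1.5).
assert np.isclose(1. / np.exp(-1.5 * 2. / 2.), np.exp(1.5))

# Temperature scaling with covmat=None: temperature = sqrt(1 + 3.) = 2.
assert np.isclose(np.sqrt(1. + 3.), 2.)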
