Skip to content

Commit f783ad7

Browse files
jereliuedward-bot
authored andcommitted
Removes unnecessary ViT-GP hyper-parameters.
Due to [pull #489](#489) to `edward2.jax.nn.RandomFeatureGaussianProcess`. Some of the special hyper-parameter configs are no longer needed. Therefore we remove them to simplify the model API. PiperOrigin-RevId: 388484029
1 parent 8de4e03 commit f783ad7

File tree

3 files changed

+132
-1
lines changed

3 files changed

+132
-1
lines changed

edward2/jax/nn/random_feature.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
[3]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
2828
Machines. In _Neural Information Processing Systems_, 2007.
2929
https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
30+
[4]: Zhiyun Lu, Eugene Ie, Fei Sha. Uncertainty Estimation with Infinitesimal
31+
Jackknife. _arXiv preprint arXiv:2006.07584_, 2020.
32+
https://arxiv.org/abs/2006.07584
3033
"""
3134
import dataclasses
3235
import functools

edward2/jax/nn/utils.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515

1616
"""JAX layer and utils."""
1717

18-
from typing import Iterable, Callable
18+
from typing import Callable, Iterable, Optional
1919

2020
from jax import random
2121
import jax.numpy as jnp
2222

23+
Array = jnp.ndarray
2324
DType = type(jnp.float32)
2425
InitializeFn = Callable[[jnp.ndarray, Iterable[int], DType], jnp.ndarray]
2526

@@ -48,3 +49,55 @@ def initializer(key, shape, dtype=jnp.float32):
4849
x = random.normal(key, shape, dtype) * (-random_sign_init) + 1.0
4950
return x.astype(dtype)
5051
return initializer
52+
53+
54+
def mean_field_logits(logits: Array,
55+
covmat: Optional[Array] = None,
56+
mean_field_factor: float = 1.,
57+
likelihood: str = 'logistic'):
58+
"""Adjust the model logits so its softmax approximates the posterior mean [4].
59+
60+
Arguments:
61+
logits: A float ndarray of shape (batch_size, num_classes).
62+
covmat: A float ndarray of shape (batch_size, ). If None then it is assumed
63+
to be a vector of 1.'s.
64+
mean_field_factor: The scale factor for mean-field approximation, used to
65+
adjust the influence of posterior variance in posterior mean
66+
approximation. If covmat=None then it is used as the scaling parameter for
67+
temperature scaling.
68+
likelihood: name of the likelihood for integration in Gaussian-approximated
69+
latent posterior. Must be one of ('logistic', 'binary_logistic',
70+
'poisson').
71+
72+
Returns:
73+
A float ndarray of uncertainty-adjusted logits, shape
74+
(batch_size, num_classes).
75+
76+
Raises:
77+
(ValueError) If likelihood is not one of ('logistic', 'binary_logistic',
78+
'poisson').
79+
"""
80+
if likelihood not in ('logistic', 'binary_logistic', 'poisson'):
81+
raise ValueError(
82+
f'Likelihood" must be one of (\'logistic\', \'binary_logistic\', \'poisson\'), got {likelihood}.'
83+
)
84+
85+
if mean_field_factor < 0:
86+
return logits
87+
88+
# Defines predictive variance.
89+
variances = 1. if covmat is None else covmat
90+
91+
# Computes scaling coefficient for mean-field approximation.
92+
if likelihood == 'poisson':
93+
logits_scale = jnp.exp(-variances * mean_field_factor / 2.) # pylint:disable=invalid-unary-operand-type
94+
else:
95+
logits_scale = jnp.sqrt(1. + variances * mean_field_factor)
96+
97+
# Pads logits_scale to compatible dimension.
98+
while logits_scale.ndim < logits.ndim:
99+
logits_scale = jnp.expand_dims(logits_scale, axis=-1)
100+
101+
return logits / logits_scale
102+
103+

edward2/jax/nn/utils_test.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# coding=utf-8
2+
# Copyright 2021 The Edward2 Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Tests for utils."""
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
20+
import edward2.jax as ed
21+
22+
import jax
23+
import jax.numpy as jnp
24+
25+
import numpy as np
26+
import tensorflow as tf
27+
28+
29+
class MeanFieldLogitsTest(parameterized.TestCase, tf.test.TestCase):
30+
31+
def testMeanFieldLogitsLikelihood(self):
32+
"""Tests if scaling is correct under different likelihood."""
33+
batch_size = 10
34+
num_classes = 12
35+
variance = 1.5
36+
mean_field_factor = 2.
37+
38+
rng_key = jax.random.PRNGKey(0)
39+
logits = jax.random.normal(rng_key, (batch_size, num_classes))
40+
covmat = jnp.ones(batch_size) * variance
41+
42+
logits_logistic = ed.nn.utils.mean_field_logits(
43+
logits, covmat, mean_field_factor=mean_field_factor)
44+
logits_poisson = ed.nn.utils.mean_field_logits(
45+
logits,
46+
covmat,
47+
mean_field_factor=mean_field_factor,
48+
likelihood='poisson')
49+
50+
self.assertAllClose(logits_logistic, logits / 2., atol=1e-4)
51+
self.assertAllClose(logits_poisson, logits * np.exp(1.5), atol=1e-4)
52+
53+
def testMeanFieldLogitsTemperatureScaling(self):
54+
"""Tests using mean_field_logits as temperature scaling method."""
55+
batch_size = 10
56+
num_classes = 12
57+
58+
rng_key = jax.random.PRNGKey(0)
59+
logits = jax.random.normal(rng_key, (batch_size, num_classes))
60+
61+
# Test if there's no change to logits when mean_field_factor < 0.
62+
logits_no_change = ed.nn.utils.mean_field_logits(
63+
logits, covmat=None, mean_field_factor=-1)
64+
65+
# Test if mean_field_logits functions as a temperature scaling method when
66+
# mean_field_factor > 0, with temperature = sqrt(1. + mean_field_factor).
67+
logits_scale_by_two = ed.nn.utils.mean_field_logits(
68+
logits, covmat=None, mean_field_factor=3.)
69+
70+
self.assertAllClose(logits_no_change, logits, atol=1e-4)
71+
self.assertAllClose(logits_scale_by_two, logits / 2., atol=1e-4)
72+
73+
74+
if __name__ == '__main__':
75+
absltest.main()

0 commit comments

Comments
 (0)