# Copyright (C) 2020 Intel Corporation
#
# BSD-3-Clause License
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS
# BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
# OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT
# OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
# EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Linear Variational Layers with reparameterization estimator to perform
# mean-field variational inference in Bayesian neural networks. Variational layers
# enable Monte Carlo approximation of the distribution over the 'kernel' and 'bias'.
#
# The Kullback-Leibler divergence between the surrogate posterior and the prior is
# computed and returned along with the output tensor of the linear operation; both are
# required to compute the Evidence Lower Bound (ELBO) loss for variational inference
# (see the note following this header for an illustrative ELBO composition).
#
# @authors: Ranganath Krishnan
#
# ======================================================================================
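#
# Note (illustrative, with assumed conventions, not a statement of this module's API
# beyond the (output, kl) return value): with sigma = log(1 + exp(rho)), a weight
# sample is drawn as w = mu + sigma * eps, eps ~ N(0, 1), and the KL term returned by
# forward() is typically folded into an ELBO-style loss such as
#     loss = nll(output, target) + kl / num_minibatches
# where 'nll' and 'num_minibatches' are placeholders chosen by the training script.
# A minimal usage sketch along these lines appears at the end of this file.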
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter
import math


class LinearVariational(Module):
    def __init__(self,
                 prior_mean,
                 prior_variance,
                 posterior_mu_init,
                 posterior_rho_init,
                 in_features,
                 out_features,
                 bias=True):
        super(LinearVariational, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.prior_mean = prior_mean
        self.prior_variance = prior_variance
        self.posterior_mu_init = posterior_mu_init  # init mean of the weight posterior
        self.posterior_rho_init = posterior_rho_init  # init rho of the weight posterior; sigma = log(1 + exp(rho))
        self.bias = bias

        self.mu_weight = Parameter(torch.Tensor(out_features, in_features))
        self.rho_weight = Parameter(torch.Tensor(out_features, in_features))
        self.register_buffer('eps_weight',
                             torch.Tensor(out_features, in_features))
        self.register_buffer('prior_weight_mu',
                             torch.Tensor(out_features, in_features))
        if bias:
            self.mu_bias = Parameter(torch.Tensor(out_features))
            self.rho_bias = Parameter(torch.Tensor(out_features))
            self.register_buffer('eps_bias', torch.Tensor(out_features))
            self.register_buffer('prior_bias_mu', torch.Tensor(out_features))
        else:
            self.register_buffer('prior_bias_mu', None)
            self.register_parameter('mu_bias', None)
            self.register_parameter('rho_bias', None)
            self.register_buffer('eps_bias', None)

        self.init_parameters()

    def init_parameters(self):
        # Priors are centered at prior_mean; posterior parameters are initialized
        # around the requested posterior_mu_init / posterior_rho_init values.
        self.prior_weight_mu.fill_(self.prior_mean)
        self.mu_weight.data.normal_(mean=self.posterior_mu_init, std=0.1)
        self.rho_weight.data.normal_(mean=self.posterior_rho_init, std=0.1)
        if self.mu_bias is not None:
            self.prior_bias_mu.fill_(self.prior_mean)
            self.mu_bias.data.normal_(mean=self.posterior_mu_init, std=0.1)
            self.rho_bias.data.normal_(mean=self.posterior_rho_init, std=0.1)

    def kl_div(self, mu_q, sigma_q, mu_p, sigma_p):
        # Closed-form KL divergence between two Gaussians:
        # KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2))
        #   = log(sigma_p / sigma_q) + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2) - 1/2
        sigma_p = torch.tensor(sigma_p)
        kl = torch.log(sigma_p) - torch.log(sigma_q) + (
            sigma_q**2 + (mu_q - mu_p)**2) / (2 * (sigma_p**2)) - 0.5
        return kl.sum()

    def forward(self, input):
        # Reparameterization trick: sample weight = mu + sigma * eps with eps ~ N(0, 1),
        # where sigma = log(1 + exp(rho)) keeps the standard deviation positive.
        sigma_weight = torch.log1p(torch.exp(self.rho_weight))
        weight = self.mu_weight + (sigma_weight * self.eps_weight.normal_())
        kl_weight = self.kl_div(self.mu_weight, sigma_weight,
                                self.prior_weight_mu, self.prior_variance)
        bias = None
        if self.mu_bias is not None:
            sigma_bias = torch.log1p(torch.exp(self.rho_bias))
            bias = self.mu_bias + (sigma_bias * self.eps_bias.normal_())
            kl_bias = self.kl_div(self.mu_bias, sigma_bias, self.prior_bias_mu,
                                  self.prior_variance)

        out = F.linear(input, weight, bias)
        if self.mu_bias is not None:
            kl = kl_weight + kl_bias
        else:
            kl = kl_weight

        return out, kl
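

# ======================================================================================
# Minimal usage sketch (illustrative only): the hyperparameter values, the MSE data-fit
# term, and the 1/num_minibatches KL scaling below are assumptions made for this
# example, not part of the layer itself.
if __name__ == "__main__":
    torch.manual_seed(0)

    layer = LinearVariational(prior_mean=0.0,
                              prior_variance=1.0,
                              posterior_mu_init=0.0,
                              posterior_rho_init=-3.0,
                              in_features=16,
                              out_features=4)

    x = torch.randn(8, 16)        # batch of 8 inputs
    target = torch.randn(8, 4)    # dummy regression targets
    num_minibatches = 100         # assumed number of mini-batches per epoch

    out, kl = layer(x)            # stochastic forward pass and its KL contribution
    loss = F.mse_loss(out, target) + kl / num_minibatches   # ELBO-style loss
    loss.backward()               # gradients flow through the reparameterized sample
    print(out.shape, kl.item(), loss.item())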