# distributions.py

import math

import torch


class ParametrizedGaussian(object):
    """Diagonal Gaussian with mean ``mu`` and standard deviation ``softplus(rho)``."""

    def __init__(self, mu, rho):
        super().__init__()
        self.mu = mu
        self.rho = rho
        # torch.distributions objects do not move to CUDA when model.to('cuda') is
        # called. We therefore move the sampled values to mu's device manually in
        # self.sample. The other way is to store the distribution parameters as
        # buffers, as described in
        # https://stackoverflow.com/questions/59179609/how-to-make-a-pytorch-distribution-on-gpu
        self.normal = torch.distributions.Normal(0, 1)
        # Per-dimension constant term of the Gaussian entropy: (1 + ln(2*pi)) / 2
        self.constant = (1 + math.log(2 * math.pi)) / 2

    @property
    def sigma(self):
        # The standard deviation, obtained from rho via the softplus transform so
        # that it is always positive: sigma = ln(1 + e^rho).
        # torch.log1p returns a new tensor with the natural logarithm of (1 + input).
        return torch.log1p(torch.exp(self.rho))

    def sample(self, n_samples=1):
        # Reparameterization trick: sample epsilon ~ N(0, 1), then shift and scale.
        epsilon = self.normal.sample(sample_shape=(n_samples, *self.rho.size()))
        epsilon = epsilon.to(self.mu.device)
        return self.mu + self.sigma * epsilon

    def log_prob(self, x):
        # Sum of independent univariate Gaussian log densities:
        # log N(x | mu, sigma) = -ln(sqrt(2*pi)) - ln(sigma) - (x - mu)^2 / (2*sigma^2)
        return (-math.log(math.sqrt(2 * math.pi))
                - torch.log(self.sigma)
                - ((x - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()

    def entropy(self):
        """
        Computes the entropy of the diagonal Gaussian distribution.
        Details on the computation can be found at
        https://math.stackexchange.com/questions/2029707/entropy-of-the-multivariate-gaussian
        """
        part1 = torch.sum(torch.log(self.sigma))
        part2 = self.mu.numel() * self.constant
        return part1 + part2
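

# Illustrative usage (not part of the original file): a minimal sketch of
# reparameterized sampling, log-density evaluation, and entropy. The tensor
# shapes and rho value below are arbitrary assumptions for the example.
def _demo_parametrized_gaussian():
    mu = torch.zeros(3)
    rho = torch.full((3,), -3.0)  # sigma = softplus(-3) is approximately 0.049
    q = ParametrizedGaussian(mu, rho)
    samples = q.sample(n_samples=5)  # shape: (5, 3)
    print("samples:", samples.shape)
    print("log_prob:", q.log_prob(samples[0]).item())
    print("entropy:", q.entropy().item())
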
class ScaleMixtureGaussian(object):
    """Mixture of two zero-mean Gaussians with mixing weight ``pi``."""

    def __init__(self, pi, sigma1, sigma2):
        super().__init__()
        if pi > 1 or pi < 0:
            raise ValueError(f"pi must be in the range [0, 1]. Got {pi} instead")
        # pi is the (hyper)parameter balancing the two Gaussian components
        self.pi = pi
        self.sigma1 = sigma1
        self.sigma2 = sigma2
        self.gaussian1 = torch.distributions.Normal(0, sigma1)
        self.gaussian2 = torch.distributions.Normal(0, sigma2)

    def log_prob(self, x):
        # Log density of the mixture, summed over all elements of x. Note that
        # exponentiating the component log densities can underflow for very small
        # probabilities; torch.logsumexp would be a numerically stabler alternative.
        prob1 = torch.exp(self.gaussian1.log_prob(x))
        prob2 = torch.exp(self.gaussian2.log_prob(x))
        return (torch.log(self.pi * prob1 + (1 - self.pi) * prob2)).sum()


class SpikeAndSlab(ScaleMixtureGaussian):
    """
    Spike-and-slab style prior. In this file it is realized as the same
    two-component zero-mean Gaussian mixture as ScaleMixtureGaussian (a narrow
    "spike" and a wider "slab"), so it reuses that implementation unchanged.
    """
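

# Illustrative usage (not part of the original file): a minimal sketch of
# evaluating the mixture prior on a weight tensor. The pi and sigma values are
# arbitrary assumptions chosen so one component is a narrow "spike" and the
# other a wider "slab"; SpikeAndSlab behaves identically.
def _demo_mixture_prior():
    prior = ScaleMixtureGaussian(pi=0.5, sigma1=1.0, sigma2=0.0025)
    weights = torch.randn(4, 4)
    print("prior log_prob:", prior.log_prob(weights).item())
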
class InverseGamma(object):
    """ Inverse Gamma distribution """

    def __init__(self, shape, rate):
        """
        Class constructor, sets the parameters of the distribution.

        Args:
            shape: torch tensor of floats, shape parameters of the distribution
            rate: torch tensor of floats, rate parameters of the distribution
        """
        super().__init__()
        self.shape = shape
        self.rate = rate

    def exp_inverse(self):
        """
        Calculates the expectation E[1/x], where x follows
        the inverse gamma distribution.
        """
        return self.shape / self.rate

    def exp_log(self):
        """
        Calculates the expectation E[log(x)], where x follows
        the inverse gamma distribution.
        """
        exp_log = torch.log(self.rate) - torch.digamma(self.shape)
        return exp_log

    def entropy(self):
        """
        Calculates the entropy of the inverse gamma distribution, E[-ln(p(x))].
        """
        entropy = self.shape + torch.log(self.rate) + torch.lgamma(self.shape) \
            - (1 + self.shape) * torch.digamma(self.shape)
        return torch.sum(entropy)

    def logprob(self, target):
        """
        Computes the log likelihood at the target value:
        log(pdf(Inv-Gamma)) = shape * log(rate) - log(Gamma(shape)) - (shape + 1) * log(x) - rate / x

        Args:
            target: torch tensor of floats, point(s) at which to evaluate the log likelihood

        Returns:
            loglike: float, the log likelihood
        """
        part1 = self.shape * torch.log(self.rate)
        part2 = -torch.lgamma(self.shape)
        part3 = -(self.shape + 1) * torch.log(target)
        part4 = -self.rate / target
        return part1 + part2 + part3 + part4

    def update(self, shape, rate):
        """
        Updates the shape and rate of the distribution. Used for fixed-point updates.

        Args:
            shape: float, shape parameter of the distribution
            rate: float, rate parameter of the distribution
        """
        self.shape = shape
        self.rate = rate
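

# Illustrative usage (not part of the original file): a minimal sketch of the
# inverse-gamma expectations, log density, and a fixed-point style parameter
# update. The parameter values are arbitrary assumptions for the example.
def _demo_inverse_gamma():
    shape = torch.tensor([3.0])
    rate = torch.tensor([2.0])
    ig = InverseGamma(shape, rate)
    print("E[1/x]:", ig.exp_inverse().item())  # shape / rate = 1.5
    print("E[log x]:", ig.exp_log().item())    # log(rate) - digamma(shape)
    print("entropy:", ig.entropy().item())
    print("logprob(1.0):", ig.logprob(torch.tensor([1.0])).item())
    ig.update(shape + 0.5, rate + 0.1)         # e.g. one fixed-point update step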