15
15
from scipy import stats
16
16
from pymc3 .distributions .distribution import generate_samples , draw_values
17
17
18
+
18
19
def extend_axis_aet (array , axis ):
19
20
n = array .shape [axis ] + 1
20
21
sum_vals = array .sum (axis , keepdims = True )
21
22
norm = sum_vals / (np .sqrt (n ) + n )
22
23
fill_val = norm - sum_vals / np .sqrt (n )
23
-
24
+
24
25
out = aet .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
25
26
return out - norm .astype (str (array .dtype ))
26
27
@@ -32,7 +33,7 @@ def extend_axis_rev_aet(array: np.ndarray, axis: int):
32
33
33
34
n = array .shape [axis ]
34
35
last = aet .take (array , [- 1 ], axis = axis )
35
-
36
+
36
37
sum_vals = - last * np .sqrt (n )
37
38
norm = sum_vals / (np .sqrt (n ) + n )
38
39
slice_before = (slice (None , None ),) * axis
@@ -44,15 +45,15 @@ def extend_axis(array, axis):
44
45
sum_vals = array .sum (axis , keepdims = True )
45
46
norm = sum_vals / (np .sqrt (n ) + n )
46
47
fill_val = norm - sum_vals / np .sqrt (n )
47
-
48
+
48
49
out = np .concatenate ([array , fill_val .astype (str (array .dtype ))], axis = axis )
49
50
return out - norm .astype (str (array .dtype ))
50
51
51
52
52
53
def extend_axis_rev (array , axis ):
53
54
n = array .shape [axis ]
54
55
last = np .take (array , [- 1 ], axis = axis )
55
-
56
+
56
57
sum_vals = - last * np .sqrt (n )
57
58
norm = sum_vals / (np .sqrt (n ) + n )
58
59
slice_before = (slice (None , None ),) * len (array .shape [:axis ])
@@ -61,60 +62,60 @@ def extend_axis_rev(array, axis):
61
62
62
63
class ZeroSumTransform (pm .distributions .transforms .Transform ):
63
64
name = "zerosum"
64
-
65
+
65
66
_active_dims : List [int ]
66
-
67
+
67
68
def __init__ (self , active_dims ):
68
69
self ._active_dims = active_dims
69
-
70
+
70
71
def forward (self , x ):
71
72
for axis in self ._active_dims :
72
73
x = extend_axis_rev_aet (x , axis = axis )
73
74
return x
74
-
75
+
75
76
def forward_val (self , x , point = None ):
76
77
for axis in self ._active_dims :
77
78
x = extend_axis_rev (x , axis = axis )
78
79
return x
79
-
80
+
80
81
def backward (self , z ):
81
82
z = aet .as_tensor_variable (z )
82
83
for axis in self ._active_dims :
83
84
z = extend_axis_aet (z , axis = axis )
84
85
return z
85
-
86
+
86
87
def jacobian_det (self , x ):
87
- return aet .constant (0. )
88
-
89
-
88
+ return aet .constant (0.0 )
89
+
90
+
90
91
class ZeroSumNormal (pm .Continuous ):
91
92
def __init__ (self , sigma = 1 , * , active_dims = None , active_axes = None , ** kwargs ):
92
93
shape = kwargs .get ("shape" , ())
93
94
dims = kwargs .get ("dims" , None )
94
95
if isinstance (shape , int ):
95
96
shape = (shape ,)
96
-
97
+
97
98
if isinstance (dims , str ):
98
99
dims = (dims ,)
99
100
100
101
self .mu = self .median = self .mode = aet .zeros (shape )
101
102
self .sigma = aet .as_tensor_variable (sigma )
102
-
103
+
103
104
if active_dims is None and active_axes is None :
104
105
if shape :
105
106
active_axes = (- 1 ,)
106
107
else :
107
108
active_axes = ()
108
-
109
+
109
110
if isinstance (active_axes , int ):
110
111
active_axes = (active_axes ,)
111
-
112
+
112
113
if isinstance (active_dims , str ):
113
114
active_dims = (active_dims ,)
114
-
115
+
115
116
if active_axes is not None and active_dims is not None :
116
117
raise ValueError ("Only one of active_axes and active_dims can be specified." )
117
-
118
+
118
119
if active_dims is not None :
119
120
model = pm .modelcontext (None )
120
121
print (model .RV_dims )
@@ -123,19 +124,19 @@ def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
123
124
active_axes = []
124
125
for dim in active_dims :
125
126
active_axes .append (dims .index (dim ))
126
-
127
+
127
128
super ().__init__ (** kwargs , transform = ZeroSumTransform (active_axes ))
128
129
129
130
def logp (self , x ):
130
131
return pm .Normal .dist (sigma = self .sigma ).logp (x )
131
-
132
+
132
133
@staticmethod
133
134
def _random (scale , size ):
134
135
samples = stats .norm .rvs (loc = 0 , scale = scale , size = size )
135
136
return samples - np .mean (samples , axis = - 1 , keepdims = True )
136
-
137
+
137
138
def random (self , point = None , size = None ):
138
- sigma , = draw_values ([self .sigma ], point = point , size = size )
139
+ ( sigma ,) = draw_values ([self .sigma ], point = point , size = size )
139
140
return generate_samples (self ._random , scale = sigma , dist_shape = self .shape , size = size )
140
141
141
142
def _distr_parameters_for_repr (self ):
0 commit comments