Skip to content

Commit 788e909

Browse files
committed
make tests pass - formatting of ZeroSumNormal.py
1 parent 5e0b00c commit 788e909

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

examples/generalized_linear_models/ZeroSumNormal.py

+24-23
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
from scipy import stats
1616
from pymc3.distributions.distribution import generate_samples, draw_values
1717

18+
1819
def extend_axis_aet(array, axis):
1920
n = array.shape[axis] + 1
2021
sum_vals = array.sum(axis, keepdims=True)
2122
norm = sum_vals / (np.sqrt(n) + n)
2223
fill_val = norm - sum_vals / np.sqrt(n)
23-
24+
2425
out = aet.concatenate([array, fill_val.astype(str(array.dtype))], axis=axis)
2526
return out - norm.astype(str(array.dtype))
2627

@@ -32,7 +33,7 @@ def extend_axis_rev_aet(array: np.ndarray, axis: int):
3233

3334
n = array.shape[axis]
3435
last = aet.take(array, [-1], axis=axis)
35-
36+
3637
sum_vals = -last * np.sqrt(n)
3738
norm = sum_vals / (np.sqrt(n) + n)
3839
slice_before = (slice(None, None),) * axis
@@ -44,15 +45,15 @@ def extend_axis(array, axis):
4445
sum_vals = array.sum(axis, keepdims=True)
4546
norm = sum_vals / (np.sqrt(n) + n)
4647
fill_val = norm - sum_vals / np.sqrt(n)
47-
48+
4849
out = np.concatenate([array, fill_val.astype(str(array.dtype))], axis=axis)
4950
return out - norm.astype(str(array.dtype))
5051

5152

5253
def extend_axis_rev(array, axis):
5354
n = array.shape[axis]
5455
last = np.take(array, [-1], axis=axis)
55-
56+
5657
sum_vals = -last * np.sqrt(n)
5758
norm = sum_vals / (np.sqrt(n) + n)
5859
slice_before = (slice(None, None),) * len(array.shape[:axis])
@@ -61,60 +62,60 @@ def extend_axis_rev(array, axis):
6162

6263
class ZeroSumTransform(pm.distributions.transforms.Transform):
6364
name = "zerosum"
64-
65+
6566
_active_dims: List[int]
66-
67+
6768
def __init__(self, active_dims):
6869
self._active_dims = active_dims
69-
70+
7071
def forward(self, x):
7172
for axis in self._active_dims:
7273
x = extend_axis_rev_aet(x, axis=axis)
7374
return x
74-
75+
7576
def forward_val(self, x, point=None):
7677
for axis in self._active_dims:
7778
x = extend_axis_rev(x, axis=axis)
7879
return x
79-
80+
8081
def backward(self, z):
8182
z = aet.as_tensor_variable(z)
8283
for axis in self._active_dims:
8384
z = extend_axis_aet(z, axis=axis)
8485
return z
85-
86+
8687
def jacobian_det(self, x):
87-
return aet.constant(0.)
88-
89-
88+
return aet.constant(0.0)
89+
90+
9091
class ZeroSumNormal(pm.Continuous):
9192
def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
9293
shape = kwargs.get("shape", ())
9394
dims = kwargs.get("dims", None)
9495
if isinstance(shape, int):
9596
shape = (shape,)
96-
97+
9798
if isinstance(dims, str):
9899
dims = (dims,)
99100

100101
self.mu = self.median = self.mode = aet.zeros(shape)
101102
self.sigma = aet.as_tensor_variable(sigma)
102-
103+
103104
if active_dims is None and active_axes is None:
104105
if shape:
105106
active_axes = (-1,)
106107
else:
107108
active_axes = ()
108-
109+
109110
if isinstance(active_axes, int):
110111
active_axes = (active_axes,)
111-
112+
112113
if isinstance(active_dims, str):
113114
active_dims = (active_dims,)
114-
115+
115116
if active_axes is not None and active_dims is not None:
116117
raise ValueError("Only one of active_axes and active_dims can be specified.")
117-
118+
118119
if active_dims is not None:
119120
model = pm.modelcontext(None)
120121
print(model.RV_dims)
@@ -123,19 +124,19 @@ def __init__(self, sigma=1, *, active_dims=None, active_axes=None, **kwargs):
123124
active_axes = []
124125
for dim in active_dims:
125126
active_axes.append(dims.index(dim))
126-
127+
127128
super().__init__(**kwargs, transform=ZeroSumTransform(active_axes))
128129

129130
def logp(self, x):
130131
return pm.Normal.dist(sigma=self.sigma).logp(x)
131-
132+
132133
@staticmethod
133134
def _random(scale, size):
134135
samples = stats.norm.rvs(loc=0, scale=scale, size=size)
135136
return samples - np.mean(samples, axis=-1, keepdims=True)
136-
137+
137138
def random(self, point=None, size=None):
138-
sigma, = draw_values([self.sigma], point=point, size=size)
139+
(sigma,) = draw_values([self.sigma], point=point, size=size)
139140
return generate_samples(self._random, scale=sigma, dist_shape=self.shape, size=size)
140141

141142
def _distr_parameters_for_repr(self):

0 commit comments

Comments
 (0)