Commit 7853a72

add momentum
1 parent cca4925 commit 7853a72

File tree: 4 files changed (+274, -82 lines)


README.md

+6-2
```diff
@@ -12,8 +12,12 @@ Giving reconstruction algorithms a warm start can speed up reconstruction time b
 
 We employ three different iterative algorithms.
 
-### 1) EM-preconditioner, DOwG step size rule, SAGA gradient estimation (in branch: main)
-**Update rule**: SGD-like for the first epochs, then SAGA-like afterwards with full-gradients computed as 2nd, 6th, 10th and 14th epochs. Here, we do not use random subsets, but rather accessing the subsets according to a Herman Meyer order.
+### 1) EM-preconditioner, DOwG step size rule, SAGA gradient estimation with Katyusha momentum (in branch: main)
+**Update rule**: We use ideas from the [Katyusha paper](https://arxiv.org/abs/1603.05953) to accelerate SAGA. In particular, we choose $\theta_1 = \theta_2 = 0.5$, such that the update rule becomes
+
+$$x_{k+1} = 0.5 z_k + 0.5 \tilde{x}, \qquad \tilde{\nabla} = \nabla f(\tilde{x}) + \nabla f_i(x_{k+1}) - \nabla f_i(\tilde{x}), \qquad z_{k+1} = z_k - \alpha \tilde{\nabla}$$
+
+We access subsets according to a Herman Meyer order. We choose $\tilde{x}$ as the last iterate of the previous epoch.
 
 **Step-size rule**: All iterations use [DoWG](https://arxiv.org/abs/2305.16284) (Distance over Weighted Gradients) for the step size calculation.
 
```
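For reference, the update rule above can be written out on a toy problem in a few lines of NumPy. This is a minimal sketch under illustrative assumptions (a least-squares objective, a fixed step size `alpha`, cyclic subset order), not the repository implementation: `bsrem_saga.py` applies the same idea to SIRF image objects with the EM-type preconditioner, the Herman Meyer subset order and the DOwG step size.

```python
import numpy as np

rng = np.random.default_rng(0)

# toy objective f(x) = (1/n) * sum_i f_i(x), with one least-squares term per "subset"
A = [rng.standard_normal((5, 3)) for _ in range(4)]
b = [rng.standard_normal(5) for _ in range(4)]
n = len(A)

def grad_subset(x, i):                   # gradient of f_i
    return A[i].T @ (A[i] @ x - b[i])

def grad_full(x):                        # gradient of f (average over subsets)
    return sum(grad_subset(x, i) for i in range(n)) / n

alpha = 0.01                             # fixed step size here; DOwG in the real code
x_tilde = np.zeros(3)                    # snapshot point
z = x_tilde.copy()

for epoch in range(20):
    g_anchor = grad_full(x_tilde)        # full gradient at the snapshot
    for i in range(n):                   # cyclic order here; Herman Meyer order in the real code
        x = 0.5 * z + 0.5 * x_tilde      # theta_1 = theta_2 = 0.5
        g = g_anchor + grad_subset(x, i) - grad_subset(x_tilde, i)
        z = z - alpha * g                # gradient step on the z-sequence
    x_tilde = x                          # snapshot = last iterate of the epoch
```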

bsrem_saga.py

+31-78
```diff
@@ -68,7 +68,10 @@ def __init__(self, data, initial,
         self.x_prev = None
         self.x_update_prev = None
 
+        self.x_tilde = initial.copy()
+
         self.x_update = initial.get_uniform_copy(0)
+        self.z = initial.copy()
 
         self.gm = [self.x.get_uniform_copy(0) for _ in range(self.num_subsets)]
 
@@ -94,94 +97,44 @@ def epoch(self):
         return self.iteration // self.num_subsets
 
     def update(self):
+
+        if self.iteration % self.num_subsets == 0 or self.iteration == 0:
+            self.sum_gm = self.x.get_uniform_copy(0)
+            for i in range(self.num_subsets):
+                gm = self.subset_gradient(self.x_tilde, self.subset_order[i])
+                self.gm[self.subset_order[i]] = gm
+                self.sum_gm.add(gm, out=self.sum_gm)
+
+            self.sum_gm /= self.num_subsets
+
+        subset_choice = self.subset_order[self.subset]
+        g = self.subset_gradient(self.x, subset_choice)
 
-        # for the first epochs just do SGD
-        if self.epoch() < 1:
-            # construct gradient of subset
-            subset_choice = self.subset_order[self.subset]
-            g = self.subset_gradient(self.x, subset_choice)
-
-            g.multiply(self.x + self.eps, out=self.x_update)
-            self.x_update.divide(self.average_sensitivity, out=self.x_update)
-
-            # DOwG learning rate: DOG unleashed!
-            self.r = max((self.x - self.initial).norm(), self.r)
-            self.v += self.r**2 * self.x_update.norm()**2
-            step_size = 1.2*self.r**2 / np.sqrt(self.v)
-            step_size = max(step_size, 1e-4) # dont get too small
-
-            if self.update_filter is not None:
-                self.update_filter.apply(self.x_update)
-
-            #print(self.alpha, self.sum_gradient)
-            self.x.sapyb(1.0, self.x_update, step_size, out=self.x)
-            #self.x += self.alpha * self.x_update
-            self.x.maximum(0, out=self.x)
-
-        # do SAGA
-        else:
-            # do one step of full gradient descent to set up subset gradients
-            if (self.epoch() in [1,2,6,10,14]) and self.iteration % self.num_subsets == 0:
-                # construct gradient of subset
-                #print("One full gradient step to intialise SAGA")
-                g = self.x.get_uniform_copy(0)
-                for i in range(self.num_subsets):
-                    gm = self.subset_gradient(self.x, self.subset_order[i])
-                    self.gm[self.subset_order[i]] = gm
-                    g.add(gm, out=g)
-                    #g += gm
-
-                g /= self.num_subsets
-
-                # DOwG learning rate: DOG unleashed!
-                self.r = max((self.x - self.initial).norm(), self.r)
-                self.v += self.r**2 * self.x_update.norm()**2
-                step_size = self.r**2 / np.sqrt(self.v)
-                step_size = max(step_size, 1e-4) # dont get too small
-
-                g.multiply(self.x + self.eps, out=self.x_update)
-                self.x_update.divide(self.average_sensitivity, out=self.x_update)
-
-                if self.update_filter is not None:
-                    self.update_filter.apply(self.x_update)
-
-                self.x.sapyb(1.0, self.x_update, step_size, out=self.x)
-
-                # threshold to non-negative
-                self.x.maximum(0, out=self.x)
-
-                self.sum_gm = self.x.get_uniform_copy(0)
-                for gm in self.gm:
-                    self.sum_gm += gm
-
-
-            subset_choice = self.subset_order[self.subset]
-            g = self.subset_gradient(self.x, subset_choice)
-
-            gradient = (g - self.gm[subset_choice]) + self.sum_gm / self.num_subsets
+        gradient = (g - self.gm[subset_choice]) + self.sum_gm
 
-            gradient.multiply(self.x + self.eps, out=self.x_update)
-            self.x_update.divide(self.average_sensitivity, out=self.x_update)
+        gradient.multiply(self.x + self.eps, out=self.x_update)
+        self.x_update.divide(self.average_sensitivity, out=self.x_update)
 
-            if self.update_filter is not None:
-                self.update_filter.apply(self.x_update)
+        if self.update_filter is not None:
+            self.update_filter.apply(self.x_update)
 
-            # DOwG learning rate: DOG unleashed!
-            self.r = max((self.x - self.initial).norm(), self.r)
-            self.v += self.r**2 * self.x_update.norm()**2
-            step_size = self.r**2 / np.sqrt(self.v)
-            step_size = max(step_size, 1e-4) # dont get too small
+        # DOwG learning rate: DOG unleashed!
+        self.r = max((self.x - self.initial).norm(), self.r)
+        self.v += self.r**2 * self.x_update.norm()**2
+        step_size = self.r**2 / np.sqrt(self.v)
+        step_size = max(step_size, 1e-3) # dont get too small
+        self.z.sapyb(1.0, self.x_update, step_size, out=self.z)
 
-            self.x.sapyb(1.0, self.x_update, step_size, out=self.x)
+        # threshold to non-negative
+        self.z.maximum(0, out=self.z)
 
-            # threshold to non-negative
-            self.x.maximum(0, out=self.x)
+        self.x_tilde.sapyb(0.5, self.z, 0.5, out=self.x)
 
-            self.sum_gm = self.sum_gm - self.gm[subset_choice] + g
-            self.gm[subset_choice] = g
+        self.x_tilde = self.x.copy()
 
         self.subset = (self.subset + 1) % self.num_subsets
 
+
     def update_objective(self):
         # required for current CIL (needs to set self.loss)
         self.loss.append(self.objective_function(self.x))
```
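The DOwG step-size bookkeeping that the new `update()` keeps in `self.r`, `self.v` and `step_size` can be read in isolation as follows. This is a sketch on NumPy arrays with a hypothetical `DoWGStepSize` helper name, not code from the repository; it mirrors the `r`/`v` recursion and the `1e-3` floor used above, while the repository keeps the state directly on the algorithm object.

```python
import numpy as np

class DoWGStepSize:
    """Distance over Weighted Gradients (DoWG) step size, mirroring the r/v recursion above."""
    def __init__(self, x0, r_init=0.1, floor=1e-3):
        self.x0 = x0.copy()   # reference point (the initial image in the algorithm)
        self.r = r_init       # running estimate of the distance travelled from x0
        self.v = 0.0          # weighted sum of squared (preconditioned) update norms
        self.floor = floor    # lower bound so the step size cannot collapse

    def __call__(self, x, update):
        self.r = max(np.linalg.norm(x - self.x0), self.r)
        self.v += self.r**2 * np.linalg.norm(update)**2
        return max(self.r**2 / np.sqrt(self.v), self.floor)

# usage sketch: step = DoWGStepSize(x_initial); alpha = step(x, preconditioned_update)
```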

legacy_stuff/bsrem_saga_old.py

+235
New file:

```python
#
#
# Classes implementing the SAGA algorithm in sirf.STIR
#
# A. Defazio, F. Bach, and S. Lacoste-Julien, “SAGA: A Fast
# Incremental Gradient Method With Support for Non-Strongly
# Convex Composite Objectives,” in Advances in Neural Infor-
# mation Processing Systems, vol. 27, Curran Associates, Inc., 2014
#
# Twyman, R., Arridge, S., Kereta, Z., Jin, B., Brusaferri, L.,
# Ahn, S., ... & Thielemans, K. (2022). An investigation of stochastic variance
# reduction algorithms for relative difference penalized 3D PET image reconstruction.
# IEEE Transactions on Medical Imaging, 42(1), 29-41.

import numpy
import numpy as np
import sirf.STIR as STIR

from cil.optimisation.algorithms import Algorithm
from utils.herman_meyer import herman_meyer_order

import torch

class BSREMSkeleton(Algorithm):
    ''' Main implementation of a modified BSREM algorithm

    This essentially implements constrained preconditioned gradient ascent
    with an EM-type preconditioner.

    In each update step, the gradient of a subset is computed, multiplied by a step_size and a EM-type preconditioner.
    Before adding this to the previous iterate, an update_filter can be applied.

    '''
    def __init__(self, data, initial,
                 update_filter=STIR.TruncateToCylinderProcessor(),
                 **kwargs):
        '''
        Arguments:
        ``data``: list of items as returned by `partitioner`
        ``initial``: initial estimate
        ``initial_step_size``, ``relaxation_eta``: step-size constants
        ``update_filter`` is applied on the (additive) update term, i.e. before adding to the previous iterate.
        Set the filter to `None` if you don't want any.
        '''
        super().__init__(**kwargs)
        self.x = initial.copy()
        self.initial = initial.copy()
        self.data = data
        self.num_subsets = len(data)

        # compute small number to add to image in preconditioner
        # don't make it too small as otherwise the algorithm cannot recover from zeroes.
        self.eps = initial.max()/1e3
        self.average_sensitivity = initial.get_uniform_copy(0)
        for s in range(len(data)):
            self.average_sensitivity += self.subset_sensitivity(s)/self.num_subsets
        # add a small number to avoid division by zero in the preconditioner
        self.average_sensitivity += self.average_sensitivity.max()/1e4

        self.precond = initial.get_uniform_copy(0)

        self.subset = 0
        self.update_filter = update_filter
        self.configured = True

        self.subset_order = herman_meyer_order(self.num_subsets)

        self.x_prev = None
        self.x_update_prev = None

        self.x_update = initial.get_uniform_copy(0)

        self.gm = [self.x.get_uniform_copy(0) for _ in range(self.num_subsets)]

        self.sum_gm = self.x.get_uniform_copy(0)
        self.x_update = self.x.get_uniform_copy(0)

        self.r = 0.1
        self.v = 0 # weighted gradient sum

    def subset_sensitivity(self, subset_num):
        raise NotImplementedError

    def subset_gradient(self, x, subset_num):
        raise NotImplementedError

    def subset_gradient_likelihood(self, x, subset_num):
        raise NotImplementedError

    def subset_gradient_prior(self, x, subset_num):
        raise NotImplementedError

    def epoch(self):
        return self.iteration // self.num_subsets

    def update(self):

        # for the first epochs just do SGD
        if self.epoch() < 1:
            # construct gradient of subset
            subset_choice = self.subset_order[self.subset]
            g = self.subset_gradient(self.x, subset_choice)

            g.multiply(self.x + self.eps, out=self.x_update)
            self.x_update.divide(self.average_sensitivity, out=self.x_update)

            if self.update_filter is not None:
                self.update_filter.apply(self.x_update)

            # DOwG learning rate: DOG unleashed!
            self.r = max((self.x - self.initial).norm(), self.r)
            self.v += self.r**2 * self.x_update.norm()**2
            step_size = 1.05*self.r**2 / np.sqrt(self.v)
            step_size = max(step_size, 1e-4) # dont get too small

            #print(self.alpha, self.sum_gradient)
            self.x.sapyb(1.0, self.x_update, step_size, out=self.x)
            #self.x += self.alpha * self.x_update
            self.x.maximum(0, out=self.x)

        # do SAGA
        else:
            # do one step of full gradient descent to set up subset gradients
            if (self.epoch() in [1,2,6,10,14]) and self.iteration % self.num_subsets == 0:
                # construct gradient of subset
                #print("One full gradient step to intialise SAGA")
                g = self.x.get_uniform_copy(0)
                for i in range(self.num_subsets):
                    gm = self.subset_gradient(self.x, self.subset_order[i])
                    self.gm[self.subset_order[i]] = gm
                    g.add(gm, out=g)
                    #g += gm

                g /= self.num_subsets


                g.multiply(self.x + self.eps, out=self.x_update)
                self.x_update.divide(self.average_sensitivity, out=self.x_update)

                if self.update_filter is not None:
                    self.update_filter.apply(self.x_update)

                # DOwG learning rate: DOG unleashed!
                self.r = max((self.x - self.initial).norm(), self.r)
                self.v += self.r**2 * self.x_update.norm()**2
                step_size = self.r**2 / np.sqrt(self.v)
                step_size = max(step_size, 1e-4) # dont get too small

                self.x.sapyb(1.0, self.x_update, step_size, out=self.x)

                # threshold to non-negative
                self.x.maximum(0, out=self.x)

                self.sum_gm = self.x.get_uniform_copy(0)
                for gm in self.gm:
                    self.sum_gm += gm


            subset_choice = self.subset_order[self.subset]
            g = self.subset_gradient(self.x, subset_choice)

            gradient = (g - self.gm[subset_choice]) + self.sum_gm / self.num_subsets

            gradient.multiply(self.x + self.eps, out=self.x_update)
            self.x_update.divide(self.average_sensitivity, out=self.x_update)

            if self.update_filter is not None:
                self.update_filter.apply(self.x_update)

            # DOwG learning rate: DOG unleashed!
            self.r = max((self.x - self.initial).norm(), self.r)
            self.v += self.r**2 * self.x_update.norm()**2
            step_size = self.r**2 / np.sqrt(self.v)
            step_size = max(step_size, 1e-4) # dont get too small

            self.x.sapyb(1.0, self.x_update, step_size, out=self.x)

            # threshold to non-negative
            self.x.maximum(0, out=self.x)

            self.sum_gm = self.sum_gm - self.gm[subset_choice] + g
            self.gm[subset_choice] = g

        self.subset = (self.subset + 1) % self.num_subsets

    def update_objective(self):
        # required for current CIL (needs to set self.loss)
        self.loss.append(self.objective_function(self.x))

    def objective_function(self, x):
        ''' value of objective function summed over all subsets '''
        v = 0
        #for s in range(len(self.data)):
        #    v += self.subset_objective(x, s)
        return v

    def objective_function_inter(self, x):
        ''' value of objective function summed over all subsets '''
        v = 0
        for s in range(len(self.data)):
            v += self.subset_objective(x, s)
        return v


    def subset_objective(self, x, subset_num):
        ''' value of objective function for one subset '''
        raise NotImplementedError


class BSREM(BSREMSkeleton):
    ''' SAGA implementation using sirf.STIR objective functions'''
    def __init__(self, data, obj_funs, initial, **kwargs):
        '''
        construct Algorithm with lists of data and, objective functions, initial estimate
        and optionally Algorithm parameters
        '''
        self.obj_funs = obj_funs
        super().__init__(data, initial, **kwargs)

    def subset_sensitivity(self, subset_num):
        ''' Compute sensitivity for a particular subset'''
        self.obj_funs[subset_num].set_up(self.x)
        # note: sirf.STIR Poisson likelihood uses `get_subset_sensitivity(0) for the whole
        # sensitivity if there are no subsets in that likelihood
        return self.obj_funs[subset_num].get_subset_sensitivity(0)

    def subset_gradient(self, x, subset_num):
        ''' Compute gradient at x for a particular subset'''
        return self.obj_funs[subset_num].gradient(x)

    def subset_objective(self, x, subset_num):
        ''' value of objective function for one subset '''
        return self.obj_funs[subset_num](x)
```
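For comparison with the new momentum variant, the SAGA bookkeeping in this legacy class (`self.gm`, `self.sum_gm` and the `(g - gm[i]) + sum_gm / num_subsets` estimate) reduces to the following sketch. The `SAGAEstimator` name and the `grad_fns` argument are illustrative placeholders, not part of the repository.

```python
class SAGAEstimator:
    """Minimal SAGA gradient estimator: one stored gradient per subset, as in self.gm above."""
    def __init__(self, grad_fns, x0):
        self.grad_fns = grad_fns                # list of per-subset gradient functions
        self.gm = [g(x0) for g in grad_fns]     # stored subset gradients ("gradient memory")
        self.sum_gm = sum(self.gm)              # running sum of the stored gradients

    def gradient(self, x, i):
        g = self.grad_fns[i](x)
        # variance-reduced estimate: new subset gradient corrected by the memory
        est = (g - self.gm[i]) + self.sum_gm / len(self.gm)
        self.sum_gm = self.sum_gm - self.gm[i] + g   # refresh the memory for subset i
        self.gm[i] = g
        return est
```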
