forked from ndwork/dworkLib
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchambollePockWLS.m
More file actions
129 lines (113 loc) · 3.83 KB
/
chambollePockWLS.m
File metadata and controls
129 lines (113 loc) · 3.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
function [xStar,objValues] = chambollePockWLS( x, proxf, proxgConj, varargin )
% [xStar,objValues] = chambollePockWLS( x, proxf, proxgConj [, ...
%   'N', N, 'A', A, 'f', f, 'g', g, 'beta', beta, 'delta', delta, 'mu', mu, ...
%   'tau', tau, 'theta', theta, 'y', y, 'doCheckAdjoint', doCheckAdjoint, ...
%   'printEvery', printEvery, 'verbose', verbose ] )
%
% Implements Chambolle-Pock (Primal-Dual Hybrid Gradient method) with line search
% based on "A First-Order Primal-Dual Algorithm with Linesearch" by Malitsky and Pock
%
% minimizes f( x ) + g( A x )
%
% Inputs:
% x - initial guess
% proxf - function handle; called as proxf( v, t ) where t is the current
%   primal step size
% proxgConj - function handle; called as proxgConj( v, s ) where s = beta * tau
%   is the current dual step size
%
% Optional Inputs:
% A - if A is not provided, it is assumed to be the identity.  A may be a matrix
%   or a function handle called as A( x, 'notransp' ) and A( y, 'transp' )
% beta - ratio of the dual step size to the primal step size (default is 1)
% delta - linesearch acceptance parameter, in (0,1) (default is 0.99)
% doCheckAdjoint - if true, the adjoint of A is verified with checkAdjoint
%   before iterating (default is false)
% f - to determine the objective values, f must be provided
% g - to determine the objective values, g must be provided
% mu - linesearch backtracking factor, in (0,1) (default is 0.7)
% N - the number of iterations that CP will perform (default is 100)
% printEvery - when verbose, a message is displayed every printEvery iterations
%   (and on the first iteration) (default is 1)
% tau - the initial primal step size (default is 1)
% theta - initial value of theta; the first trial step size is
%   tau * sqrt( 1 + theta ) (default is 1)
% y - the initial values of y in the CP iterations (default is A x)
% verbose - if true, progress information is displayed (default is false)
%
% Outputs:
% xStar - the optimal point
%
% Optional Outputs:
% objValues - a 1D array containing the objective value of each iteration
%   (evaluated at the primal iterate after its proximal step); requires f and g
%
% Written by Nicholas Dwork - Copyright 2019
%
% This software is offered under the GNU General Public License 3.0.  It
% is offered without any warranty expressed or implied, including the
% implied warranties of merchantability or fitness for a particular
% purpose.

  p = inputParser;
  p.addParameter( 'A', [] );
  p.addParameter( 'beta', 1, @ispositive );
  p.addParameter( 'delta', 0.99, @(x) x>0 && x<1 );
  p.addParameter( 'doCheckAdjoint', false, @(x) islogical(x) || x == 1 || x == 0 );
  p.addParameter( 'f', [] );
  p.addParameter( 'g', [] );
  p.addParameter( 'mu', 0.7, @(x) x>0 && x<1 );
  p.addParameter( 'N', 100, @ispositive );
  p.addParameter( 'printEvery', 1, @ispositive );
  p.addParameter( 'tau', 1, @ispositive );
  p.addParameter( 'theta', 1, @ispositive );
  p.addParameter( 'y', [], @isnumeric );
  p.addParameter( 'verbose', false, @(x) islogical(x) || x == 1 || x == 0 );
  p.parse( varargin{:} );
  A = p.Results.A;
  beta = p.Results.beta;
  delta = p.Results.delta;
  doCheckAdjoint = p.Results.doCheckAdjoint;
  f = p.Results.f;
  g = p.Results.g;
  mu = p.Results.mu;
  N = p.Results.N;
  printEvery = p.Results.printEvery;
  tau = p.Results.tau;
  theta = p.Results.theta;
  y = p.Results.y;
  verbose = p.Results.verbose;

  % Objective values need both f and g; fail early with a clear message rather
  % than with an obscure indexing error on the first iteration.
  if nargout > 1 && ( isempty( f ) || isempty( g ) )
    error( 'chambollePockWLS: f and g must be provided to compute objValues' );
  end

  if numel( A ) == 0
    applyA = @(x) x;
    applyAT = @(x) x;
  elseif isnumeric( A )
    applyA = @(x) A * x;
    applyAT = @(y) A' * y;
  else
    applyA = @(x) A( x, 'notransp' );
    applyAT = @(x) A( x, 'transp' );
  end

  if numel( y ) == 0, y = applyA( x ); end

  if doCheckAdjoint == true
    [adjointCheckPassed,adjCheckErr] = checkAdjoint( x, applyA, applyAT );
    if ~adjointCheckPassed, error([ 'checkAdjoint failed with error ', num2str(adjCheckErr) ]); end
  end

  if nargout > 1, objValues = zeros( N, 1 ); end

  % A^T y is cached: the linesearch already computes A^T of each trial y, so the
  % accepted value is reused by the next primal step instead of recomputing it.
  ATy = applyAT( y );

  for optIter = 1 : N
    lastX = x;
    x = proxf( lastX - tau * ATy, tau );

    if nargout > 1
      objValues( optIter ) = f( x ) + g( applyA( x ) );
    end

    if verbose == true && ( mod( optIter, printEvery ) == 0 || optIter == 1 )
      if nargout > 1
        disp([ 'chambollePockWLS: working on ', indx2str(optIter,N), ' of ', num2str(N), ', ', ...
          'objective value: ', num2str( objValues( optIter ),'%15.13f' ) ]);
      else
        disp([ 'chambollePockWLS: working on ', indx2str(optIter,N), ' of ', num2str(N) ]);
      end
    end

    % Trial step size: the largest value allowed, tau_{k-1} * sqrt( 1 + theta_{k-1} )
    lastTau = tau;
    tau = tau * sqrt( 1 + theta );
    diffx = x - lastX;
    lastY = y;
    lastATy = ATy;

    % Backtracking linesearch on the step size
    while true
      theta = tau / lastTau;
      xBar = x + theta * diffx;
      betaTau = beta * tau;
      y = proxgConj( lastY + betaTau * applyA( xBar ), betaTau );
      ATy = applyAT( y );
      diffy = y - lastY;
      ATdiffy = ATy - lastATy;   % equals A^T( y - lastY ) by linearity of A^T
      if tau * sqrt( beta ) * norm( ATdiffy(:) ) <= delta * norm( diffy(:) )
        break
      end
      tau = mu * tau;
    end
  end

  xStar = x;
end