diff --git a/test/test_estim/test_linear.py b/test/test_estim/test_linear.py
index d33264f..d3a4905 100644
--- a/test/test_estim/test_linear.py
+++ b/test/test_estim/test_linear.py
@@ -48,6 +48,8 @@ def lin_test(zshape=(500,10),Ashape=(1000,500),verbose=False,tol=0.1):
     # Construct the linear estimator
     est = vp.estim.LinEst(Aop,y,wvar,var_axes='all')
+    # Construct the same linear estimator with the conjugate gradient method
+    est_cg = vp.estim.LinEst(Aop,y,wvar,var_axes='all',est_meth='cg')
 
     # Perform the initial estimate. This is just run to make sure it
     # doesn't crash
@@ -56,6 +58,11 @@
         raise vp.common.TestException(\
             "est_init does not produce the correct shape")
 
+    # Initial estimate with the conjugate gradient method should match the SVD method
+    zhat_cg, zhatvar_cg = est_cg.est_init()
+    if not np.allclose(zhat_cg, zhat, atol=1e-3, rtol=1e-3) or \
+       not np.allclose(zhatvar_cg, zhatvar, atol=1e-1, rtol=2e-1):
+        raise vp.common.TestException(
+            "Conjugate gradient initial estimate does not match SVD method")
+
     # Posterior estimate
     zhat, zhatvar, cost = est.est(r,rvar,return_cost=True)
     zerr = np.mean(np.abs(z-zhat)**2)
@@ -65,6 +72,11 @@
     if fail:
         raise vp.common.TestException("Posterior estimate Gaussian error "+
             " does not match predicted value")
+
+    # Posterior estimate with the conjugate gradient method should match the SVD method
+    zhat_cg, zhatvar_cg = est_cg.est(r,rvar)
+    if not np.allclose(zhat_cg, zhat, atol=1e-3, rtol=1e-3) or \
+       not np.allclose(zhatvar_cg, zhatvar, atol=1e-1, rtol=2e-1):
+        raise vp.common.TestException(
+            "Conjugate gradient posterior estimate does not match SVD method")
 
 class TestCases(unittest.TestCase):
diff --git a/test/test_solver/test_vamp.py b/test/test_solver/test_vamp.py
index 190027e..934d334 100644
--- a/test/test_solver/test_vamp.py
+++ b/test/test_solver/test_vamp.py
@@ -13,7 +13,7 @@
 def vamp_gauss_test(nz=100,ny=200,ns=10, snr=30, map_est=False, verbose=False,\
-                    is_complex=False, tol=1e-5):
+                    is_complex=False, tol=1e-5, tol_var=1e-5, tol_cost=1e-5, est_meth='svd'):
     """
     Unit test for VAMP method using a Gaussian prior.
@@ -68,7 +68,7 @@ def vamp_gauss_test(nz=100,ny=200,ns=10, snr=30, map_est=False, verbose=False,\
     # Create output estimator
     Aop = vp.trans.MatrixLT(A,zshape)
-    est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est, name='Posterior')
+    est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est, name='Posterior', est_meth=est_meth)
 
     # Create the variance handler
     msg_hdl = vp.estim.MsgHdlSimp(map_est=map_est, is_complex=is_complex, \
@@ -96,7 +96,7 @@
     zvar_err = np.abs(zvar-np.mean(solver.zvar2))
     if verbose:
         print("zvar error: {0:12.4e}".format(zvar_err))
-    if zvar_err > tol:
+    if zvar_err > tol_var:
         raise vp.common.TestException("Variance does not match")
@@ -148,12 +148,12 @@
     if verbose:
         print("costs: Direct {0:12.4e} ".format(cost_tot)+\
             "Termwise {0:12.4e} solver: {1:12.4e}".format(cost_tota,solver.cost))
-    if np.abs(cost_tot - cost_tota) > tol:
+    if np.abs(cost_tot - cost_tota) > tol_cost:
         raise vp.common.TestException("Direct and termwise costs do not match")
-    if np.abs(cost_tot - cost_tota) > tol:
+    if np.abs(cost_tot - solver.cost) > tol_cost:
         raise vp.common.TestException("Predicted cost does not match solver output")
 
-def vamp_gmm_test(nz=100,ny=200,ns=10, snr=30, verbose=False, mse_tol=-17):
+def vamp_gmm_test(nz=100,ny=200,ns=10, snr=30, verbose=False, mse_tol=-17, est_meth='svd'):
     """
     Unit test for VAMP using a Gaussian mixture model (GMM)
@@ -216,7 +216,7 @@
     # Create output estimator
     Aop = vp.trans.MatrixLT(A,zshape)
-    est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est)
+    est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est,est_meth=est_meth)
 
     # Create the variance handler
     msg_hdl = vp.estim.MsgHdlSimp(map_est=map_est, is_complex=is_complex,\
@@ -243,7 +243,7 @@
     if mse[-1] > mse_tol:
         raise vp.common.TestException("MSE exceeded expected value")
 
-def vamp_bg_test(nz=1000,ny=500,ns=1, snr=30, verbose=False, pred_tol=3.0):
+def vamp_bg_test(nz=1000,ny=500,ns=1, snr=30, verbose=False, pred_tol=3.0, est_meth='svd'):
     """
     Unit test for VAMP using a Gaussian mixture model (GMM)
@@ -299,7 +299,7 @@
     # Create output estimator
     Aop = vp.trans.MatrixLT(A,zshape)
-    est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est)
+    est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est, est_meth=est_meth)
 
     # Create the variance handler
     msg_hdl = vp.estim.MsgHdlSimp(map_est=map_est, is_complex=is_complex,\
@@ -342,18 +342,25 @@ def test_vamp_gauss(self):
         vamp_gauss_test(nz=200,ny=100,ns=10,map_est=map_est,verbose=verbose)
         vamp_gauss_test(nz=100,ny=200,ns=1,map_est=map_est,verbose=verbose)
         vamp_gauss_test(nz=200,ny=100,ns=1,map_est=map_est,verbose=verbose)
+
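+        # The CG variants are run with map_est=True, since the CG cost
+        # computation in LinEst only handles the MAP case, and with looser
+        # tolerances, in particular on the variance, since the CG variance
+        # estimate is based on a random numerical gradient rather than an
+        # exact SVD computation.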
+        vamp_gauss_test(nz=100,ny=200,ns=10,map_est=True,est_meth='cg',verbose=verbose,tol=1e-4,tol_var=2e-1)
+        vamp_gauss_test(nz=200,ny=100,ns=10,map_est=True,est_meth='cg',verbose=verbose,tol=1e-4,tol_var=2e-1)
+        vamp_gauss_test(nz=100,ny=200,ns=1,map_est=True,est_meth='cg',verbose=verbose,tol=1e-4,tol_var=2e-1)
+        vamp_gauss_test(nz=200,ny=100,ns=1,map_est=True,est_meth='cg',verbose=verbose,tol=1e-4,tol_var=3e-1)
 
     def test_vamp_gmm(self):
         """
         Run the vamp_gmm_test
         """
         vamp_gmm_test(nz=1000,ny=500,ns=1,verbose=False)
+        vamp_gmm_test(nz=1000,ny=500,ns=1,verbose=False,est_meth='cg')
 
     def test_vamp_bg(self):
         """
         Run VAMP with a BG prior
         """
         vamp_bg_test(nz=1000,ny=500,ns=10,verbose=False)
+        vamp_bg_test(nz=1000,ny=500,ns=10,verbose=False,est_meth='cg')
 
 if __name__ == '__main__':
     #vamp_bg_test(nz=1000,ny=500,ns=10,verbose=verbose)
diff --git a/vampyre/estim/linear.py b/vampyre/estim/linear.py
index d656599..74b1207 100644
--- a/vampyre/estim/linear.py
+++ b/vampyre/estim/linear.py
@@ -4,6 +4,7 @@
 from __future__ import division
 import numpy as np
+import scipy.sparse.linalg
 
 # Import other subpackages in vampyre
 import vampyre.common as common
@@ -52,7 +53,7 @@ class LinEst(BaseEst):
     """
     def __init__(self,A,y,wvar=0,\
                  wrep_axes='all', var_axes=(0,),name=None,map_est=False,\
-                 is_complex=False,rvar_init=1e5,tune_wvar=False):
+                 is_complex=False,rvar_init=1e5,tune_wvar=False, est_meth='svd', nit_cg=100, atol_cg=1e-10):
 
         BaseEst.__init__(self, shape=A.shape0, var_axes=var_axes,\
             dtype=A.dtype0, name=name,\
@@ -76,23 +77,55 @@ def __init__(self,A,y,wvar=0,\
             wrep_axes = tuple(range(ndim))
         self.wrep_axes = wrep_axes
 
+        # Initialization depending on the estimation method
+        self.est_meth = est_meth
+        self.nit_cg = nit_cg
+        self.atol_cg = atol_cg
+        if self.est_meth == 'svd':
+            self.init_svd()
+        elif self.est_meth == 'cg':
+            self.init_cg()
+        else:
+            raise common.VpException(
+                "Unknown estimation method {0:s}".format(est_meth))
+
+    def init_cg(self):
+        """
+        Initialization that is specific to the conjugate gradient method
+        """
+        # Draw random perturbations for computing the numerical gradients
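+        # These perturbations drive the variance estimate in est_cg: the
+        # solver is re-run from the perturbed input r + dr0, and the average
+        # sensitivity alpha0 = <dr0, dzhat>/||dr0||^2 of the estimate to the
+        # perturbation is used to scale rvar into zhatvar.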
+        grad_step = .01
+        self.dr0 = np.random.normal(0,grad_step,self.zshape)
+        self.dr0_norm_sq = np.mean(np.abs(self.dr0)**2, self.var_axes)
+
+        # Initialize the variables
+        self.zlast = None
+        self.zvec0_last = None
+
+    def init_svd(self):
+        """
+        Initialization for the SVD method
+        """
         # Compute the SVD terms
         # Take an SVD A=USV'.  Then write p = SV'z + w,
-        if not A.svd_avail:
+        if not self.A.svd_avail:
             raise common.VpException("Transform must support an SVD")
-        self.p = A.UsvdH(y)
-        srep_axes = A.get_svd_diag()[2]
+        self.p = self.A.UsvdH(self.y)
+        srep_axes = self.A.get_svd_diag()[2]
 
         # Compute the norm of ||y-UU*(y)||^2/wvar
         if np.all(self.wvar > 0):
-            yp = A.Usvd(self.p)
-            wvar1 = common.repeat_axes(wvar, self.yshape, self.wrep_axes, rep=False)
-            err = np.abs(y-yp)**2
+            yp = self.A.Usvd(self.p)
+            wvar1 = common.repeat_axes(self.wvar, self.yshape, self.wrep_axes, rep=False)
+            err = np.abs(self.y-yp)**2
             self.ypnorm = np.sum(err/wvar1)
         else:
             self.ypnorm = 0
 
         # Check that all axes on which A operates are repeated
+        ndim = len(self.yshape)
         for i in range(ndim):
             if not (i in self.wrep_axes) and not (i in srep_axes):
                 raise common.VpException(
@@ -121,6 +154,9 @@ def est_init(self, return_cost=False, ind_out=None,\
         if not avg_var_cost:
             raise ValueError("disabling variance averaging not supported for LinEst")
 
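+        # For CG, the initial estimate is computed by calling est_cg with
+        # r = 0 and infinite input variance, which reduces the least squares
+        # objective in est_cg to the data term ||A.dot(z)-y||^2/wvar alone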
+        if self.est_meth == 'cg':
+            return self.est_cg(np.zeros(self.zshape),
+                np.mean(np.full(self.zshape,np.Inf), axis=self.var_axes),
+                return_cost=return_cost)
+
         # Get the diagonal parameters
         s, sshape, srep_axes = self.A.get_svd_diag()
         shape0 = self.A.shape0
@@ -182,7 +218,15 @@ def est(self,r,rvar,return_cost=False, ind_out=None,\
         if not avg_var_cost:
             raise ValueError("disabling variance averaging not supported for LinEst")
 
+        if self.est_meth == 'svd':
+            return self.est_svd(r,rvar,return_cost)
+        elif self.est_meth == 'cg':
+            return self.est_cg(r,rvar,return_cost)
+        else:
+            raise common.VpException(
+                "Unknown estimation method {0:s}".format(self.est_meth))
+
+    def est_svd(self,r,rvar,return_cost=False):
         # Get the diagonal parameters
         s, sshape, srep_axes = self.A.get_svd_diag()
@@ -244,3 +288,188 @@ def est(self,r,rvar,return_cost=False, ind_out=None,\
         cost = 0.5*cost
 
         return zhat, zhatvar, cost
+
+    def est_cg(self,r,rvar,return_cost=False):
+        """
+        First-order terms
+        """
+        # Create the LSQR transform for the problem
+        # The VAMP problem is equivalent to minimizing ||F(z)-g||^2
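+        #
+        # Explicitly, the objective is
+        #     ||(z-r)/sqrt(rvar)||^2 + ||(A.dot(z)-y)/sqrt(wvar)||^2
+        # The solve is warm-started: lsqr is applied to the increment
+        # relative to zinit, since g is shifted by F.dot(zinit) below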
+        F = LSQROp(self.A,self.y,rvar, self.wvar,\
+            self.var_axes, self.wrep_axes,\
+            self.zshape, self.yshape, self.is_complex)
+        g = F.get_tgt_vec(r)
+
+        # Get the initial condition
+        if self.zlast is None:
+            zinit = F.pack(r)
+        else:
+            zinit = self.zlast
+        g -= F.dot(zinit)
+
+        # Run the LSQR optimization
+        lsqr_out = scipy.sparse.linalg.lsqr(F,g,iter_lim=self.nit_cg,atol=self.atol_cg)
+        zvec = lsqr_out[0] + zinit
+        self.zlast = zvec
+        zhat = F.unpack(zvec)
+
+        """
+        Cost
+        """
+        if return_cost:
+            # lsqr_out[3] is the residual norm ||F(z)-g||, so its square
+            # gives the quadratic part of the cost
+            cost = lsqr_out[3]**2
+
+            # Add the cost for the second order terms.
+            #
+            # We only consider the MAP case, where the second-order cost is
+            # (1/2)*ny*log(2*pi*wvar)
+            if self.is_complex:
+                cscale = 1
+            else:
+                cscale = 2
+            cost /= cscale
+            if np.all(self.wvar > 0):
+                ny = np.prod(self.yshape)
+                cost += (1/cscale)*ny*np.mean(np.log(cscale*np.pi*self.wvar))
+
+        """
+        Second-order terms
+
+        These are computed via the numerical gradient along a random direction
+        """
+        if np.any(rvar==np.Inf):
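+            # With infinite input variance (the est_init case), perturbing r
+            # has no effect on the solution, so the numerical gradient cannot
+            # be used.  Instead, the least squares problem is solved a second
+            # time against a resampled observation y + N(0,wvar), and the
+            # variance is estimated from the spread of the two solutions,
+            # blended with rvar_init in proportion to the fraction of
+            # dimensions of z that are not observed through A.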
+            # Create the LSQR transform for the problem
+            # The VAMP problem is equivalent to minimizing ||F(z)-g||^2
+            F = LSQROp(self.A,np.random.normal(self.y,np.sqrt(self.wvar),self.y.shape),\
+                rvar, self.wvar, self.var_axes, self.wrep_axes,\
+                self.zshape, self.yshape, self.is_complex)
+            g = F.get_tgt_vec(r)
+
+            # Get the initial condition
+            zinit = F.pack(r)
+            g -= F.dot(zinit)
+
+            # Run the LSQR optimization
+            lsqr_out = scipy.sparse.linalg.lsqr(F,g,iter_lim=self.nit_cg,atol=self.atol_cg)
+            zvec = lsqr_out[0] + zinit
+            zhat_ = F.unpack(zvec)
+            zhatvar = np.mean(np.abs(zhat_-zhat)**2, axis=self.var_axes)
+
+            # Fraction of the dimensions of z that are observed through A:
+            # the first axis is capped at min(nz,ny), the maximum rank of A
+            sshape = np.array(self.zshape)
+            sshape[0] = min(self.zshape[0], self.yshape[0])
+            sshape = tuple(sshape)
+            shape0 = self.A.shape0
+            rdim = np.prod(sshape)/np.prod(shape0)
+            zhatvar = rdim*zhatvar + (1-rdim)*self.rvar_init
+
+        else:
+            # Perturb r0
+            r0p = r + self.dr0
+            g0 = F.get_tgt_vec(r0p)
+
+            # Get the initial condition
+            if self.zvec0_last is None:
+                zinit = F.pack(r0p)
+            else:
+                zinit = self.zvec0_last
+            g0 -= F.dot(zinit)
+
+            # Run the LSQR optimization
+            lsqr_out = scipy.sparse.linalg.lsqr(F,g0,iter_lim=self.nit_cg,atol=self.atol_cg)
+            zvec0 = lsqr_out[0] + zinit
+            self.zvec0_last = zvec0
+            dzvec = zvec0 - zvec
+            dz0 = F.unpack(dzvec)
+
+            # Compute the correlations
+            alpha0 = np.mean(np.real(self.dr0.conj()*dz0),self.var_axes) /\
+                self.dr0_norm_sq
+            zhatvar = alpha0*rvar
+
+        if return_cost:
+            return zhat,zhatvar, cost
+        else:
+            return zhat,zhatvar
+
+class LSQROp(scipy.sparse.linalg.LinearOperator):
+    """
+    LSQR operator for the VAMP least squares problem.
+
+    Defines an operator F(z) and constant vector g such that the VAMP
+    optimization is equivalent to
+
+        min_z ||F(z) - g||^2
+
+    This can be solved with LSQR.
+
+        F(z) = D*[z; A.dot(z)]    g = D*[r; b]
+        D = diag(1/sqrt([rvar; wvar]))
+    """
+    def __init__(self,A,b,rvar,wvar,zrep_axes,wrep_axes,\
+                 shape0,shape1,is_complex):
+        self.A = A
+        self.b = b
+        self.shape0 = shape0
+        self.shape1 = shape1
+        self.n0 = np.prod(shape0)
+        self.n1 = np.prod(shape1)
+
+        # Compute scale factors
+        self.rsqrt = common.repeat_axes(
+            np.sqrt(rvar), self.shape0, zrep_axes, rep=False)
+        self.wsqrt = common.repeat_axes(
+            np.sqrt(wvar), self.shape1, wrep_axes, rep=False)
+
+        # Compute dimensions of the transform F
+        nin = self.n0
+        nout = self.n0 + self.n1
+        self.shape = (nout,nin)
+        if is_complex:
+            self.dtype = np.dtype(complex)
+        else:
+            self.dtype = np.dtype(float)
+
+    def unpack(self,zvec):
+        """
+        Unpacks the variable vector for the CG estimation
+        """
+        z = zvec.reshape(self.shape0)
+        return z
+
+    def pack(self,z):
+        """
+        Packs the variables for the CG estimation into a vector
+        """
+        zvec = z.ravel()
+        return zvec
+
+    def get_tgt_vec(self,r):
+        """
+        Computes the target vector `g` in the above description
+        """
+        g0 = 1/self.rsqrt*r
+        g1 = 1/self.wsqrt*(self.b)
+        g = np.hstack((g0.ravel(),g1.ravel()))
+        return g
+
+    def _matvec(self,r):
+        """
+        Forward multiplication for the operator `F` defined above
+        """
+        r0 = self.unpack(r)
+        y0 = 1/self.rsqrt*r0
+        yout = 1/self.wsqrt*self.A.dot(r0)
+        y = np.hstack((y0.ravel(), yout.ravel()))
+        return y
+
+    def _rmatvec(self,y):
+        """
+        Adjoint multiplication for the operator `F` defined above
+        """
+        y0 = np.reshape(y[:self.n0], self.shape0)
+        y1 = np.reshape(y[self.n0:], self.shape1)
+        r0 = 1/self.rsqrt*y0
+        r0 += 1/self.wsqrt*self.A.dotH(y1)
+        r = r0.ravel()
+        return r
diff --git a/vampyre/estim/linear_two.py b/vampyre/estim/linear_two.py
index a0117af..1233ce6 100644
--- a/vampyre/estim/linear_two.py
+++ b/vampyre/estim/linear_two.py
@@ -187,7 +187,7 @@ def est_init(self, return_cost=False, ind_out=None, avg_var_cost=True):
         if ind_out is None:
             ind_out = [0,1]
         if not avg_var_cost:
-            raise ValueError("disabling variance averaging not supported for MixEst")
+            raise ValueError("disabling variance averaging not supported for LinEstTwo")
 
         # Set initial mean and variance
@@ -232,7 +232,7 @@ def est(self,r,rvar,return_cost=False,ind_out=None, avg_var_cost=True):
         if ind_out is None:
             ind_out = [0,1]
         if not avg_var_cost:
-            raise ValueError("disabling variance averaging not supported for MixEst")
+            raise ValueError("disabling variance averaging not supported for LinEstTwo")
 
         if self.est_meth == 'svd':
             return self.est_svd(r,rvar,return_cost,ind_out)
@@ -627,8 +627,8 @@ class LSQROp(scipy.sparse.linalg.LinearOperator):
     This can be solved with LSQR.
 
     When wvar == 0:
-        F(z0) = [D*z0; A.dot(z0)]    g = D*[r0; r1-b]
-        D= diag(1/sqrt([rvar0; rvar1]))
+        F(z0) = D*[z0; A.dot(z0)]    g = D*[r0; r1-b]
+        D = diag(1/sqrt([rvar0; rvar1]))
 
     When wvar > 0:
         F(z0,z1) = D*[z1-A.dot(z0); z0; z1]    g=D*[b; r0; r1]
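
A minimal usage sketch of the new option (not part of the patch; the shapes
follow lin_test above, and the noise level is illustrative):

    import numpy as np
    import vampyre as vp

    # Random problem instance with zshape=(500,10), Ashape=(1000,500)
    A = np.random.normal(0, 1, (1000, 500))
    z = np.random.normal(0, 1, (500, 10))
    wvar = 0.1
    y = A.dot(z) + np.random.normal(0, np.sqrt(wvar), (1000, 10))

    # Select the conjugate gradient path; nit_cg and atol_cg control
    # the inner scipy.sparse.linalg.lsqr solver
    Aop = vp.trans.MatrixLT(A, (500, 10))
    est = vp.estim.LinEst(Aop, y, wvar, var_axes='all',
                          est_meth='cg', nit_cg=100, atol_cg=1e-10)
    zhat, zhatvar = est.est_init()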