Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/test_estim/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def lin_test(zshape=(500,10),Ashape=(1000,500),verbose=False,tol=0.1):

# Construct the linear estimator
est = vp.estim.LinEst(Aop,y,wvar,var_axes='all')
# Construct same linear estimator with conjugate gradient method
est_cg = vp.estim.LinEst(Aop,y,wvar,var_axes='all',est_meth='cg')

# Perform the initial estimate. This is just run to make sure it
# doesn't crash
Expand All @@ -56,6 +58,11 @@ def lin_test(zshape=(500,10),Ashape=(1000,500),verbose=False,tol=0.1):
raise vp.common.TestException(\
"est_init does not produce the correct shape")

# Initial estimate with conjugate gradient method should match SVD method
zhat_cg, zhatvar_cg = est_cg.est_init()
if not np.allclose(zhat_cg, zhat, atol=1e-3, rtol=1e-3) or not np.allclose(zhatvar_cg, zhatvar, atol=1e-1, rtol=2e-1):
raise vp.common.TestException("Conjugate gradient initial estimate estimate does not match SVD method")

# Posterior estimate
zhat, zhatvar, cost = est.est(r,rvar,return_cost=True)
zerr = np.mean(np.abs(z-zhat)**2)
Expand All @@ -65,6 +72,11 @@ def lin_test(zshape=(500,10),Ashape=(1000,500),verbose=False,tol=0.1):
if fail:
raise vp.common.TestException("Posterior estimate Gaussian error "+
" does not match predicted value")

# Posterior estimate with conjugate gradient method should match SVD method
zhat_cg, zhatvar_cg = est_cg.est(r,rvar)
if not np.allclose(zhat_cg, zhat, atol=1e-3, rtol=1e-3) or not np.allclose(zhatvar_cg, zhatvar, atol=1e-1, rtol=2e-1):
raise vp.common.TestException("Conjugate gradient posterior estimate estimate does not match SVD method")


class TestCases(unittest.TestCase):
Expand Down
25 changes: 16 additions & 9 deletions test/test_solver/test_vamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def vamp_gauss_test(nz=100,ny=200,ns=10, snr=30, map_est=False, verbose=False,\
is_complex=False, tol=1e-5):
is_complex=False, tol=1e-5, tol_var=1e-5, tol_cost=1e-5, est_meth='svd'):
"""
Unit test for VAMP method using a Gaussian prior.

Expand Down Expand Up @@ -68,7 +68,7 @@ def vamp_gauss_test(nz=100,ny=200,ns=10, snr=30, map_est=False, verbose=False,\

# Create output estimator
Aop = vp.trans.MatrixLT(A,zshape)
est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est, name='Posterior')
est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est, name='Posterior', est_meth=est_meth)

# Create the variance handler
msg_hdl = vp.estim.MsgHdlSimp(map_est=map_est, is_complex=is_complex, \
Expand Down Expand Up @@ -96,7 +96,7 @@ def vamp_gauss_test(nz=100,ny=200,ns=10, snr=30, map_est=False, verbose=False,\
zvar_err = np.abs(zvar-np.mean(solver.zvar2))
if verbose:
print("zvar error: {0:12.4e}".format(zvar_err))
if zvar_err > tol:
if zvar_err > tol_var:
raise vp.common.TestException("Variance does not match")


Expand Down Expand Up @@ -148,12 +148,12 @@ def vamp_gauss_test(nz=100,ny=200,ns=10, snr=30, map_est=False, verbose=False,\
if verbose:
print("costs: Direct {0:12.4e} ".format(cost_tot)+\
"Termwise {0:12.4e} solver: {1:12.4e}".format(cost_tota,solver.cost))
if np.abs(cost_tot - cost_tota) > tol:
if np.abs(cost_tot - cost_tota) > tol_cost:
raise vp.common.TestException("Direct and termwise costs do not match")
if np.abs(cost_tot - cost_tota) > tol:
if np.abs(cost_tot - solver.cost) > tol_cost:
raise vp.common.TestException("Predicted cost does not match solver output")

def vamp_gmm_test(nz=100,ny=200,ns=10, snr=30, verbose=False, mse_tol=-17):
def vamp_gmm_test(nz=100,ny=200,ns=10, snr=30, verbose=False, mse_tol=-17, est_meth='svd'):
"""
Unit test for VAMP using a Gaussian mixture model (GMM)

Expand Down Expand Up @@ -216,7 +216,7 @@ def vamp_gmm_test(nz=100,ny=200,ns=10, snr=30, verbose=False, mse_tol=-17):

# Create output estimator
Aop = vp.trans.MatrixLT(A,zshape)
est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est)
est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est,est_meth=est_meth)

# Create the variance handler
msg_hdl = vp.estim.MsgHdlSimp(map_est=map_est, is_complex=is_complex,\
Expand All @@ -243,7 +243,7 @@ def vamp_gmm_test(nz=100,ny=200,ns=10, snr=30, verbose=False, mse_tol=-17):
if mse[-1] > mse_tol:
raise vp.common.TestException("MSE exceeded expected value")

def vamp_bg_test(nz=1000,ny=500,ns=1, snr=30, verbose=False, pred_tol=3.0):
def vamp_bg_test(nz=1000,ny=500,ns=1, snr=30, verbose=False, pred_tol=3.0, est_meth='svd'):
"""
Unit test for VAMP using a Gaussian mixture model (GMM)

Expand Down Expand Up @@ -299,7 +299,7 @@ def vamp_bg_test(nz=1000,ny=500,ns=1, snr=30, verbose=False, pred_tol=3.0):

# Create output estimator
Aop = vp.trans.MatrixLT(A,zshape)
est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est)
est_out = vp.estim.LinEst(Aop,y,wvar,map_est=map_est, est_meth=est_meth)

# Create the variance handler
msg_hdl = vp.estim.MsgHdlSimp(map_est=map_est, is_complex=is_complex,\
Expand Down Expand Up @@ -342,18 +342,25 @@ def test_vamp_gauss(self):
vamp_gauss_test(nz=200,ny=100,ns=10,map_est=map_est,verbose=verbose)
vamp_gauss_test(nz=100,ny=200,ns=1,map_est=map_est,verbose=verbose)
vamp_gauss_test(nz=200,ny=100,ns=1,map_est=map_est,verbose=verbose)

vamp_gauss_test(nz=100,ny=200,ns=10,map_est=True,est_meth='cg',verbose=verbose,tol=1e-4,tol_var=2e-1)
vamp_gauss_test(nz=200,ny=100,ns=10,map_est=True,est_meth='cg',verbose=verbose,tol=1e-4,tol_var=2e-1)
vamp_gauss_test(nz=100,ny=200,ns=1,map_est=True,est_meth='cg',verbose=verbose,tol=1e-4,tol_var=2e-1)
vamp_gauss_test(nz=200,ny=100,ns=1,map_est=True,est_meth='cg',verbose=verbose,tol=1e-4,tol_var=3e-1)

def test_vamp_gmm(self):
"""
Run the vamp_gmm_test
"""
vamp_gmm_test(nz=1000,ny=500,ns=1,verbose=False)
vamp_gmm_test(nz=1000,ny=500,ns=1,verbose=False,est_meth='cg')

def test_vamp_bg(self):
"""
Run VAMP with a BG prior
"""
vamp_bg_test(nz=1000,ny=500,ns=10,verbose=False)
vamp_bg_test(nz=1000,ny=500,ns=10,verbose=False,est_meth='cg')

if __name__ == '__main__':
#vamp_bg_test(nz=1000,ny=500,ns=10,verbose=verbose)
Expand Down
Loading