Skip to content

Commit

Permalink
sample() tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gplepage committed Mar 5, 2024
1 parent aea62e1 commit e577685
Showing 1 changed file with 36 additions and 36 deletions.
72 changes: 36 additions & 36 deletions tests/test_gvar.py
Original file line number Diff line number Diff line change
Expand Up @@ -1371,45 +1371,45 @@ def test_sample(self):
s2 = next(raniter(g, eps=eps))
self.assertEqual(str(s1), str(s2))
# test default svdcut (=> no error generated)
with self.assertRaises(UserWarning):
sample(gvar([1,1], [[1,1],[1,1]]), svdcut=None)
# with self.assertRaises(UserWarning):
# sample(gvar([1,1], [[1,1],[1,1]]), svdcut=None)
sample(gvar([1,1], [[1,1],[1,1]]))


# def test_batch_sample(self): ### makes no sense!
# " sample(g, nbatch=...) raniter(g, nbatch=...)"
# # dictionary
# g = gvar(BufferDict(s='1(1)', a=[['1(1)','1(1)','1(1)']]))
# nbatch = 5
# ranseed(1)
# sl = sample(g, nbatch=nbatch, mode='lbatch')
# ranseed(1)
# sr = sample(g, nbatch=nbatch, mode='rbatch')
# for k in g:
# self.assertTrue(sl[k].shape[0] == sr[k].shape[-1] == nbatch)
# np.testing.assert_allclose(np.sum(sl.flat), np.sum(sr.flat))
# for s in sl.batch_iter('lbatch'):
# self.assertLess(chi2(s, g) / g.size, 10.)
# for s in sr.batch_iter('rbatch'):
# self.assertLess(chi2(s, g) / g.size, 10.)
# # array
# ranseed(1)
# sl = sample(g['a'], nbatch=nbatch, mode='lbatch')
# ranseed(1)
# sr = sample(g['a'], nbatch=nbatch, mode='rbatch')
# self.assertTrue(sl.shape[0] == sr.shape[-1] == nbatch)
# np.testing.assert_allclose(np.sum(sl.flat), np.sum(sr.flat))
# for s in sl:
# self.assertLess(chi2(s, g['a']) / g['a'].size, 10.)
# # gvar
# ranseed(1)
# sl = sample(g['s'], nbatch=nbatch, mode='lbatch')
# ranseed(1)
# sr = sample(g['s'], nbatch=nbatch, mode='rbatch')
# self.assertTrue(sl.shape[0] == sr.shape[-1] == nbatch)
# self.assertEqual(list(sl), list(sr))
# for s in sl:
# self.assertLess(chi2(s, g['s']), 10.)
def test_batch_sample(self):
" sample(g, nbatch=...) raniter(g, nbatch=...)"
# dictionary
g = gvar(BufferDict(s='1(1)', a=[['1(1)','1(1)','1(1)']]))
nbatch = 5
ranseed(1)
sl = sample(g, nbatch=nbatch, mode='lbatch')
ranseed(1)
sr = sample(g, nbatch=nbatch, mode='rbatch')
for k in g:
self.assertTrue(sl[k].shape[0] == sr[k].shape[-1] == nbatch)
np.testing.assert_allclose(np.sum(sl.flat), np.sum(sr.flat))
for s in sl.batch_iter('lbatch'):
self.assertLess(chi2(s, g) / g.size, 10.)
for s in sr.batch_iter('rbatch'):
self.assertLess(chi2(s, g) / g.size, 10.)
# array
ranseed(1)
sl = sample(g['a'], nbatch=nbatch, mode='lbatch')
ranseed(1)
sr = sample(g['a'], nbatch=nbatch, mode='rbatch')
self.assertTrue(sl.shape[0] == sr.shape[-1] == nbatch)
np.testing.assert_allclose(np.sum(sl.flat), np.sum(sr.flat))
for s in sl:
self.assertLess(chi2(s, g['a']) / g['a'].size, 10.)
# gvar
ranseed(1)
sl = sample(g['s'], nbatch=nbatch, mode='lbatch')
ranseed(1)
sr = sample(g['s'], nbatch=nbatch, mode='rbatch')
self.assertTrue(sl.shape[0] == sr.shape[-1] == nbatch)
self.assertEqual(list(sl), list(sr))
for s in sl:
self.assertLess(chi2(s, g['s']), 10.)

@unittest.skipIf(FAST,"skipping test_gvar_from_sample for speed")
def test_gvar_from_sample(self):
Expand Down

0 comments on commit e577685

Please sign in to comment.