Skip to content

Commit bb89527

Browse files
authored
Fixes a bug in DataSet count retrieval and adds unit tests. (#663)
Corrects an indexing bug that would cause DataSet lookups to silently fail (return incorrect values) when data had been added with dictionaries that didn't have keys ordered in the same way as the DataSet's outcome list. This commit fixes this issue and adds several unit tests used in debugging. Previously failing minimal example (now included as a unit test) is: ``` import pygsti c = "Gxpi2:0@(0)" counts = {'00': 1, '10': 0, '01': 97, '11': 2} ds = pygsti.data.DataSet(outcome_labels=["00", "10", "01", "11"], static=False) ds.add_count_dict(c, counts) check = ds[c].counts print(ds) assert {k[0]: v for k,v in check.items()} == counts ```
2 parents fd94318 + 3423dc1 commit bb89527

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

pygsti/data/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,18 +575,18 @@ def _get_counts(self, timestamp=None, all_outcomes=False):
575575
cntDict.setitem_unsafe(ol, cnt)
576576
else:
577577
for ol, i in self.dataset.olIndex.items():
578-
inds = oli_tslc[oli_tslc == i]
578+
inds = _np.where(oli_tslc == i)[0]
579579
if len(inds) > 0 or all_outcomes:
580580
cntDict.setitem_unsafe(ol, float(sum(self.reps[tslc][inds])))
581581
else:
582582
if self.reps is None:
583583
for ol_index in oli_tslc:
584584
ol = self.dataset.ol[ol_index]
585-
cntDict.setitem_unsafe(ol, 1.0 + cntDict.getitem_unsafe(ol, 0.0))
585+
cntDict.setitem_unsafe(ol, float(1.0 + cntDict.getitem_unsafe(ol, 0.0)))
586586
else:
587587
for ol_index, reps in zip(oli_tslc, self.reps[tslc]):
588588
ol = self.dataset.ol[ol_index]
589-
cntDict.setitem_unsafe(ol, reps + cntDict.getitem_unsafe(ol, 0.0))
589+
cntDict.setitem_unsafe(ol, float(reps + cntDict.getitem_unsafe(ol, 0.0)))
590590

591591
return cntDict
592592

test/unit/objects/test_dataset.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,43 @@ def test_static_constructor_raises_on_missing_counts(self):
9999
with self.assertRaises(ValueError):
100100
DataSet(circuits=self.gstrs, outcome_labels=['0', '1'], static=True)
101101

102+
def test_add_count_dict_order(self):
103+
c = "Gxpi2:0@(0)"
104+
counts = {'00': 1, '10': 0, '01': 97, '11': 2}
105+
ds = DataSet(outcome_labels=["00", "10", "01", "11"], static=False)
106+
ds.add_count_dict(c, counts)
107+
check = ds[c].counts
108+
print(ds)
109+
110+
self.assertEqual({k[0]: v for k,v in check.items()}, counts)
111+
112+
# Add more
113+
c2 = "Gypi2:0@(0)"
114+
counts2 = {'00': 1, '01': 43, '10': 24, '11': 7}
115+
ds.add_count_dict(c2, counts2)
116+
check = ds[c2].counts
117+
118+
self.assertEqual({k[0]: v for k,v in check.items()}, counts2)
119+
120+
ds.done_adding_data()
121+
self.assertEqual({k[0]: v for k,v in ds[c].counts.items()}, counts)
122+
self.assertEqual({k[0]: v for k,v in ds[c2].counts.items()}, counts2)
123+
124+
def test_add_count_dict_single_outcome(self):
125+
c0 = Circuit("Gi:Q0@Q0")
126+
c1 = Circuit("Gi:Q0Gi:Q0@Q0")
127+
c0_counts = {("1",): 10.0}
128+
c1_counts = {("0",): 4, ("1",): 6}
129+
130+
ds = DataSet()
131+
ds.add_count_dict(c0, c0_counts)
132+
ds.add_count_dict(c1, c1_counts)
133+
134+
self.assertEqual({k: v for k,v in ds[c0].counts.items()}, c0_counts)
135+
self.assertEqual({k: v for k,v in ds[c1].counts.items()}, c1_counts)
136+
137+
138+
102139

103140
class DefaultDataSetInstance(object):
104141
def setUp(self):

0 commit comments

Comments
 (0)