Skip to content

Commit 9910707

Browse files
Fixes a bug in DataSet count retrieval and adds unit tests. (#663) (Develop Merge) (#665)
Note: Since the automatic merge from bugfix into develop didn't work this is just a manual merge of PR #663 into develop. Co-authored-by: Erik Nielsen <[email protected]>
1 parent 66e6394 commit 9910707

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

pygsti/data/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -575,18 +575,18 @@ def _get_counts(self, timestamp=None, all_outcomes=False):
575575
cntDict.setitem_unsafe(ol, cnt)
576576
else:
577577
for ol, i in self.dataset.olIndex.items():
578-
inds = oli_tslc[oli_tslc == i]
578+
inds = _np.where(oli_tslc == i)[0]
579579
if len(inds) > 0 or all_outcomes:
580580
cntDict.setitem_unsafe(ol, float(sum(self.reps[tslc][inds])))
581581
else:
582582
if self.reps is None:
583583
for ol_index in oli_tslc:
584584
ol = self.dataset.ol[ol_index]
585-
cntDict.setitem_unsafe(ol, 1.0 + cntDict.getitem_unsafe(ol, 0.0))
585+
cntDict.setitem_unsafe(ol, float(1.0 + cntDict.getitem_unsafe(ol, 0.0)))
586586
else:
587587
for ol_index, reps in zip(oli_tslc, self.reps[tslc]):
588588
ol = self.dataset.ol[ol_index]
589-
cntDict.setitem_unsafe(ol, reps + cntDict.getitem_unsafe(ol, 0.0))
589+
cntDict.setitem_unsafe(ol, float(reps + cntDict.getitem_unsafe(ol, 0.0)))
590590

591591
return cntDict
592592

test/unit/objects/test_dataset.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,43 @@ def test_static_constructor_raises_on_missing_counts(self):
9999
with self.assertRaises(ValueError):
100100
DataSet(circuits=self.gstrs, outcome_labels=['0', '1'], static=True)
101101

102+
def test_add_count_dict_order(self):
103+
c = "Gxpi2:0@(0)"
104+
counts = {'00': 1, '10': 0, '01': 97, '11': 2}
105+
ds = DataSet(outcome_labels=["00", "10", "01", "11"], static=False)
106+
ds.add_count_dict(c, counts)
107+
check = ds[c].counts
108+
print(ds)
109+
110+
self.assertEqual({k[0]: v for k,v in check.items()}, counts)
111+
112+
# Add more
113+
c2 = "Gypi2:0@(0)"
114+
counts2 = {'00': 1, '01': 43, '10': 24, '11': 7}
115+
ds.add_count_dict(c2, counts2)
116+
check = ds[c2].counts
117+
118+
self.assertEqual({k[0]: v for k,v in check.items()}, counts2)
119+
120+
ds.done_adding_data()
121+
self.assertEqual({k[0]: v for k,v in ds[c].counts.items()}, counts)
122+
self.assertEqual({k[0]: v for k,v in ds[c2].counts.items()}, counts2)
123+
124+
def test_add_count_dict_single_outcome(self):
125+
c0 = Circuit("Gi:Q0@Q0")
126+
c1 = Circuit("Gi:Q0Gi:Q0@Q0")
127+
c0_counts = {("1",): 10.0}
128+
c1_counts = {("0",): 4, ("1",): 6}
129+
130+
ds = DataSet()
131+
ds.add_count_dict(c0, c0_counts)
132+
ds.add_count_dict(c1, c1_counts)
133+
134+
self.assertEqual({k: v for k,v in ds[c0].counts.items()}, c0_counts)
135+
self.assertEqual({k: v for k,v in ds[c1].counts.items()}, c1_counts)
136+
137+
138+
102139

103140
class DefaultDataSetInstance(object):
104141
def setUp(self):

0 commit comments

Comments
 (0)