Skip to content

Commit 65b35a0

Browse files
Relax restriction on classical inputs for modular addition to fix nightly CI (#1537)
Fix Nightly CI
1 parent 676b02a commit 65b35a0

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

qualtran/bloqs/mod_arithmetic/mod_addition.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,15 @@ def signature(self) -> 'Signature':
8787
def on_classical_vals(
8888
self, x: 'ClassicalValT', y: 'ClassicalValT'
8989
) -> Dict[str, 'ClassicalValT']:
90-
if not (0 <= x < self.mod):
90+
# The construction still works when at most one of inputs equals `mod`.
91+
special_case = (x == self.mod) ^ (y == self.mod)
92+
if not (0 <= x < self.mod or special_case):
9193
raise ValueError(
92-
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
94+
f'{x=} is outside the valid interval for modular addition [0, {self.mod}]'
9395
)
94-
if not (0 <= y < self.mod):
96+
if not (0 <= y < self.mod or special_case):
9597
raise ValueError(
96-
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
98+
f'{y=} is outside the valid interval for modular addition [0, {self.mod}]'
9799
)
98100

99101
y = (x + y) % self.mod
@@ -320,7 +322,7 @@ def on_classical_vals(
320322

321323
if not (0 <= x < self.mod):
322324
raise ValueError(
323-
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
325+
f'{x=} is outside the valid interval for modular addition [0, {self.mod}]'
324326
)
325327

326328
x = (x + self.k) % self.mod
@@ -508,13 +510,15 @@ def on_classical_vals(
508510
if ctrl != self.cv:
509511
return {'ctrl': ctrl, 'x': x, 'y': y}
510512

511-
if not (0 <= x < self.mod):
513+
# The construction still works when at most one of inputs equals `mod`.
514+
special_case = (x == self.mod) ^ (y == self.mod)
515+
if not (0 <= x < self.mod or special_case):
512516
raise ValueError(
513-
f'{x=} is outside the valid interval for modular addition [0, {self.mod})'
517+
f'{x=} is outside the valid interval for modular addition [0, {self.mod}]'
514518
)
515-
if not (0 <= y < self.mod):
519+
if not (0 <= y < self.mod or special_case):
516520
raise ValueError(
517-
f'{y=} is outside the valid interval for modular addition [0, {self.mod})'
521+
f'{y=} is outside the valid interval for modular addition [0, {self.mod}]'
518522
)
519523

520524
y = (x + y) % self.mod

qualtran/bloqs/mod_arithmetic/mod_addition_test.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,8 @@ def test_classical_action_mod_add(prime, bitsize):
131131
def test_classical_action_cmodadd(control, prime, dtype, bitsize):
132132
b = CModAdd(dtype(bitsize), mod=prime, cv=control)
133133
cb = b.decompose_bloq()
134-
valid_range = range(prime)
135134
for c in range(2):
136-
for x, y in itertools.product(valid_range, repeat=2):
135+
for x, y in itertools.product(range(prime + 1), range(prime)):
137136
assert b.call_classically(ctrl=c, x=x, y=y) == cb.call_classically(ctrl=c, x=x, y=y)
138137

139138

@@ -207,7 +206,7 @@ def test_cmod_add_complexity_vs_ref():
207206
@pytest.mark.parametrize(['prime', 'bitsize'], [(p, bitsize) for p in [5, 7] for bitsize in (5, 6)])
208207
def test_mod_add_classical_action(bitsize, prime):
209208
b = ModAdd(bitsize, prime)
210-
assert_consistent_classical_action(b, x=range(prime), y=range(prime))
209+
assert_consistent_classical_action(b, x=range(prime + 1), y=range(prime))
211210

212211

213212
def test_cmodadd_tensor():

0 commit comments

Comments
 (0)