Skip to content

Commit 7dbe3cf

Browse files
Add a modulus parameter to QMontgomeryUInt
1 parent 7b3aef9 commit 7dbe3cf

File tree

10 files changed

+107
-111
lines changed

10 files changed

+107
-111
lines changed

qualtran/_infra/data_types.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,8 @@ class QMontgomeryUInt(QDType):
792792
We follow Montgomery form as described in the above paper; namely, r = 2^bitsize.
793793
"""
794794

795-
# TODO(https://github.com/quantumlib/Qualtran/issues/1471): Add modulus p as a class member.
796795
bitsize: SymbolicInt
796+
modulus: Optional[SymbolicInt] = None
797797

798798
@property
799799
def num_qubits(self):
@@ -803,7 +803,9 @@ def is_symbolic(self) -> bool:
803803
return is_symbolic(self.bitsize)
804804

805805
def get_classical_domain(self) -> Iterable[Any]:
806-
return range(2**self.bitsize)
806+
if self.modulus is None or is_symbolic(self.modulus):
807+
return range(2**self.bitsize)
808+
return range(int(self.modulus))
807809

808810
def to_bits(self, x: int) -> List[int]:
809811
self.assert_valid_classical_val(x)
@@ -828,42 +830,42 @@ def assert_valid_classical_val_array(
828830
if np.any(val_array >= 2**self.bitsize):
829831
raise ValueError(f"Too-large classical values encountered in {debug_str}")
830832

831-
def montgomery_inverse(self, xm: int, p: int) -> int:
833+
def montgomery_inverse(self, xm: int) -> int:
832834
"""Returns the modular inverse of an integer in montgomery form.
833835
834836
Args:
835837
xm: An integer in montgomery form.
836838
p: The modulus of the finite field.
837839
"""
838-
return ((pow(xm, -1, p)) * pow(2, 2 * self.bitsize, p)) % p
840+
return ((pow(xm, -1, self.modulus)) * pow(2, 2 * self.bitsize, self.modulus)) % self.modulus
839841

840-
def montgomery_product(self, xm: int, ym: int, p: int) -> int:
842+
def montgomery_product(self, xm: int, ym: int) -> int:
841843
"""Returns the modular product of two integers in montgomery form.
842844
843845
Args:
844846
xm: The first montgomery form integer for the product.
845847
ym: The second montgomery form integer for the product.
846848
p: The modulus of the finite field.
847849
"""
848-
return (xm * ym * pow(2, -self.bitsize, p)) % p
850+
return (xm * ym * pow(2, -self.bitsize, self.modulus)) % self.modulus
849851

850-
def montgomery_to_uint(self, xm: int, p: int) -> int:
852+
def montgomery_to_uint(self, xm: int) -> int:
851853
"""Converts an integer in montgomery form to a normal form integer.
852854
853855
Args:
854856
xm: An integer in montgomery form.
855857
p: The modulus of the finite field.
856858
"""
857-
return (xm * pow(2, -self.bitsize, p)) % p
859+
return (xm * pow(2, -self.bitsize, self.modulus)) % self.modulus
858860

859-
def uint_to_montgomery(self, x: int, p: int) -> int:
861+
def uint_to_montgomery(self, x: int) -> int:
860862
"""Converts an integer into montgomery form.
861863
862864
Args:
863865
x: An integer.
864866
p: The modulus of the finite field.
865867
"""
866-
return (x * pow(2, int(self.bitsize), p)) % p
868+
return (x * pow(2, int(self.bitsize), self.modulus)) % self.modulus
867869

868870

869871
@attrs.frozen

qualtran/_infra/data_types_test.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -138,27 +138,23 @@ def test_qmontgomeryuint():
138138
@pytest.mark.parametrize('p', [13, 17, 29])
139139
@pytest.mark.parametrize('val', [1, 5, 7, 9])
140140
def test_qmontgomeryuint_operations(val, p):
141-
qmontgomeryuint_8 = QMontgomeryUInt(8)
141+
qmontgomeryuint_8 = QMontgomeryUInt(8, p)
142142
# Convert value to montgomery form and get the modular inverse.
143-
val_m = qmontgomeryuint_8.uint_to_montgomery(val, p)
144-
mod_inv = qmontgomeryuint_8.montgomery_inverse(val_m, p)
143+
val_m = qmontgomeryuint_8.uint_to_montgomery(val)
144+
mod_inv = qmontgomeryuint_8.montgomery_inverse(val_m)
145145

146146
# Calculate the product in montgomery form and convert back to normal form for assertion.
147147
assert (
148-
qmontgomeryuint_8.montgomery_to_uint(
149-
qmontgomeryuint_8.montgomery_product(val_m, mod_inv, p), p
150-
)
148+
qmontgomeryuint_8.montgomery_to_uint(qmontgomeryuint_8.montgomery_product(val_m, mod_inv))
151149
== 1
152150
)
153151

154152

155153
@pytest.mark.parametrize('p', [13, 17, 29])
156154
@pytest.mark.parametrize('val', [1, 5, 7, 9])
157155
def test_qmontgomeryuint_conversions(val, p):
158-
qmontgomeryuint_8 = QMontgomeryUInt(8)
159-
assert val == qmontgomeryuint_8.montgomery_to_uint(
160-
qmontgomeryuint_8.uint_to_montgomery(val, p), p
161-
)
156+
qmontgomeryuint_8 = QMontgomeryUInt(8, p)
157+
assert val == qmontgomeryuint_8.montgomery_to_uint(qmontgomeryuint_8.uint_to_montgomery(val))
162158

163159

164160
def test_qgf():

qualtran/bloqs/cryptography/ecc/ec_add.py

+11-20
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,8 @@ def on_classical_vals(
250250
lam = lam_r
251251
f1 = 0
252252
else:
253-
lam = QMontgomeryUInt(self.n).montgomery_product(
254-
int(y),
255-
QMontgomeryUInt(self.n).montgomery_inverse(int(x), int(self.mod)),
256-
int(self.mod),
253+
lam = QMontgomeryUInt(self.n, self.mod).montgomery_product(
254+
int(y), QMontgomeryUInt(self.n, self.mod).montgomery_inverse(int(x))
257255
)
258256
# TODO(https://github.com/quantumlib/Qualtran/issues/1461): Fix bug in circuit
259257
# which flips f1 when lam and lam_r are equal.
@@ -540,10 +538,10 @@ def on_classical_vals(
540538
self, x: 'ClassicalValT', y: 'ClassicalValT', lam: 'ClassicalValT'
541539
) -> Dict[str, 'ClassicalValT']:
542540
x = (
543-
x - QMontgomeryUInt(self.n).montgomery_product(int(lam), int(lam), int(self.mod))
541+
x - QMontgomeryUInt(self.n, self.mod).montgomery_product(int(lam), int(lam))
544542
) % self.mod
545543
if lam > 0:
546-
y = QMontgomeryUInt(self.n).montgomery_product(int(x), int(lam), int(self.mod))
544+
y = QMontgomeryUInt(self.n, self.mod).montgomery_product(int(x), int(lam))
547545
return {'x': x, 'y': y, 'lam': lam}
548546

549547
def build_composite_bloq(
@@ -1071,30 +1069,23 @@ def build_composite_bloq(
10711069
return {'a': a, 'b': b, 'x': x, 'y': y, 'lam_r': lam_r}
10721070

10731071
def on_classical_vals(self, a, b, x, y, lam_r) -> Dict[str, Union['ClassicalValT', sympy.Expr]]:
1072+
dtype = QMontgomeryUInt(self.n, self.mod)
10741073
curve_a = (
1075-
QMontgomeryUInt(self.n).montgomery_to_uint(lam_r, int(self.mod))
1076-
* 2
1077-
* QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod))
1078-
- (3 * QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)) ** 2)
1074+
dtype.montgomery_to_uint(lam_r) * 2 * dtype.montgomery_to_uint(b)
1075+
- (3 * dtype.montgomery_to_uint(a) ** 2)
10791076
) % self.mod
10801077
p1 = ECPoint(
1081-
QMontgomeryUInt(self.n).montgomery_to_uint(a, int(self.mod)),
1082-
QMontgomeryUInt(self.n).montgomery_to_uint(b, int(self.mod)),
1083-
mod=self.mod,
1084-
curve_a=curve_a,
1078+
dtype.montgomery_to_uint(a), dtype.montgomery_to_uint(b), mod=self.mod, curve_a=curve_a
10851079
)
10861080
p2 = ECPoint(
1087-
QMontgomeryUInt(self.n).montgomery_to_uint(x, int(self.mod)),
1088-
QMontgomeryUInt(self.n).montgomery_to_uint(y, int(self.mod)),
1089-
mod=self.mod,
1090-
curve_a=curve_a,
1081+
dtype.montgomery_to_uint(x), dtype.montgomery_to_uint(y), mod=self.mod, curve_a=curve_a
10911082
)
10921083
result = p1 + p2
10931084
return {
10941085
'a': a,
10951086
'b': b,
1096-
'x': QMontgomeryUInt(self.n).uint_to_montgomery(result.x, int(self.mod)),
1097-
'y': QMontgomeryUInt(self.n).uint_to_montgomery(result.y, int(self.mod)),
1087+
'x': dtype.uint_to_montgomery(result.x),
1088+
'y': dtype.uint_to_montgomery(result.y),
10981089
'lam_r': lam_r,
10991090
}
11001091

qualtran/bloqs/cryptography/ecc/ec_add_r.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -179,17 +179,17 @@ def qrom(self) -> QROAMClean:
179179

180180
cR = self.R
181181
data_a, data_b, data_lam = [0], [0], [0]
182-
mon_int = QMontgomeryUInt(self.n)
182+
mon_int = QMontgomeryUInt(self.n, self.R.mod)
183183
for _ in range(1, 2**self.add_window_size):
184-
data_a.append(mon_int.uint_to_montgomery(int(cR.x), int(self.R.mod)))
185-
data_b.append(mon_int.uint_to_montgomery(int(cR.y), int(self.R.mod)))
184+
data_a.append(mon_int.uint_to_montgomery(int(cR.x)))
185+
data_b.append(mon_int.uint_to_montgomery(int(cR.y)))
186186
lam_num = (3 * cR.x**2 + cR.curve_a) % cR.mod
187187
lam_denom = (2 * cR.y) % cR.mod
188188
if lam_denom != 0:
189189
lam = (lam_num * pow(lam_denom, -1, mod=cR.mod)) % cR.mod
190190
else:
191191
lam = 0
192-
data_lam.append(mon_int.uint_to_montgomery(int(lam), int(self.R.mod)))
192+
data_lam.append(mon_int.uint_to_montgomery(int(lam)))
193193
cR = cR + self.R
194194

195195
return QROAMClean(
@@ -244,18 +244,19 @@ def build_call_graph(self, ssa: 'SympySymbolAllocator') -> 'BloqCountDictT':
244244

245245
def on_classical_vals(self, ctrl, x, y) -> Dict[str, Union['ClassicalValT', sympy.Expr]]:
246246
# TODO(https://github.com/quantumlib/Qualtran/issues/1476): make ECAdd accept SymbolicInt.
247+
dtype = QMontgomeryUInt(self.n, self.R.mod)
247248
A = ECPoint(
248-
QMontgomeryUInt(self.n).montgomery_to_uint(int(x), int(self.R.mod)),
249-
QMontgomeryUInt(self.n).montgomery_to_uint(int(y), int(self.R.mod)),
249+
dtype.montgomery_to_uint(int(x)),
250+
dtype.montgomery_to_uint(int(y)),
250251
mod=self.R.mod,
251252
curve_a=self.R.curve_a,
252253
)
253254
ctrls = QUInt(self.n).from_bits(ctrl)
254255
result: ECPoint = A + (ctrls * self.R)
255256
return {
256257
'ctrl': ctrl,
257-
'x': QMontgomeryUInt(self.n).uint_to_montgomery(int(result.x), int(self.R.mod)),
258-
'y': QMontgomeryUInt(self.n).uint_to_montgomery(int(result.y), int(self.R.mod)),
258+
'x': dtype.uint_to_montgomery(int(result.x)),
259+
'y': dtype.uint_to_montgomery(int(result.y)),
259260
}
260261

261262
def wire_symbol(

qualtran/bloqs/cryptography/ecc/ec_add_r_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def test_ec_window_add_r_bloq_counts(n, window_size, a, b):
6060
def test_ec_window_add_r_classical(n, m, ctrl, x, y, a, b):
6161
p = 17
6262
R = ECPoint(a, b, mod=p)
63-
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
64-
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
63+
x = QMontgomeryUInt(n, p).uint_to_montgomery(x)
64+
y = QMontgomeryUInt(n, p).uint_to_montgomery(y)
6565
ctrl = np.array(QUInt(m).to_bits(ctrl % (2**m)))
6666
bloq = ECWindowAddR(n=n, R=R, add_window_size=m, mul_window_size=m)
6767
ret1 = bloq.call_classically(ctrl=ctrl, x=x, y=y)
@@ -80,8 +80,8 @@ def test_ec_window_add_r_classical(n, m, ctrl, x, y, a, b):
8080
def test_ec_window_add_r_classical_slow(n, m, ctrl, x, y, a, b):
8181
p = 17
8282
R = ECPoint(a, b, mod=p)
83-
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
84-
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
83+
x = QMontgomeryUInt(n, p).uint_to_montgomery(x)
84+
y = QMontgomeryUInt(n, p).uint_to_montgomery(y)
8585
ctrl = np.array(QUInt(m).to_bits(ctrl % (2**m)))
8686
bloq = ECWindowAddR(n=n, R=R, add_window_size=m, mul_window_size=m)
8787
ret1 = bloq.call_classically(ctrl=ctrl, x=x, y=y)

qualtran/bloqs/cryptography/ecc/ec_add_test.py

+34-30
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ def test_ec_add_steps_classical_fast(n, m, a, b, x, y):
4444
lam_denom = (2 * b) % p
4545
lam_r = 0 if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p
4646

47-
a = QMontgomeryUInt(n).uint_to_montgomery(a, p)
48-
b = QMontgomeryUInt(n).uint_to_montgomery(b, p)
49-
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
50-
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
51-
lam_r = QMontgomeryUInt(n).uint_to_montgomery(lam_r, p) if lam_r != 0 else p
47+
dtype = QMontgomeryUInt(n, p)
48+
a = dtype.uint_to_montgomery(a)
49+
b = dtype.uint_to_montgomery(b)
50+
x = dtype.uint_to_montgomery(x)
51+
y = dtype.uint_to_montgomery(y)
52+
lam_r = dtype.uint_to_montgomery(lam_r) if lam_r != 0 else p
5253

5354
bloq = _ECAddStepOne(n=n, mod=p)
5455
ret1 = bloq.call_classically(a=a, b=b, x=x, y=y)
@@ -184,11 +185,12 @@ def test_ec_add_steps_classical(n, m, a, b, x, y):
184185
lam_denom = (2 * b) % p
185186
lam_r = 0 if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p
186187

187-
a = QMontgomeryUInt(n).uint_to_montgomery(a, p)
188-
b = QMontgomeryUInt(n).uint_to_montgomery(b, p)
189-
x = QMontgomeryUInt(n).uint_to_montgomery(x, p)
190-
y = QMontgomeryUInt(n).uint_to_montgomery(y, p)
191-
lam_r = QMontgomeryUInt(n).uint_to_montgomery(lam_r, p) if lam_r != 0 else p
188+
dtype = QMontgomeryUInt(n, p)
189+
a = dtype.uint_to_montgomery(a)
190+
b = dtype.uint_to_montgomery(b)
191+
x = dtype.uint_to_montgomery(x)
192+
y = dtype.uint_to_montgomery(y)
193+
lam_r = dtype.uint_to_montgomery(lam_r) if lam_r != 0 else p
192194

193195
bloq = _ECAddStepOne(n=n, mod=p)
194196
ret1 = bloq.call_classically(a=a, b=b, x=x, y=y)
@@ -307,19 +309,20 @@ def test_ec_add_classical_fast(n, m, a, b, x, y):
307309
lam_num = (3 * a**2) % p
308310
lam_denom = (2 * b) % p
309311
lam_r = p if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p
312+
dtype = QMontgomeryUInt(n, p)
310313
ret1 = bloq.call_classically(
311-
a=QMontgomeryUInt(n).uint_to_montgomery(a, p),
312-
b=QMontgomeryUInt(n).uint_to_montgomery(b, p),
313-
x=QMontgomeryUInt(n).uint_to_montgomery(x, p),
314-
y=QMontgomeryUInt(n).uint_to_montgomery(y, p),
315-
lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p),
314+
a=dtype.uint_to_montgomery(a),
315+
b=dtype.uint_to_montgomery(b),
316+
x=dtype.uint_to_montgomery(x),
317+
y=dtype.uint_to_montgomery(y),
318+
lam_r=dtype.uint_to_montgomery(lam_r),
316319
)
317320
ret2 = bloq.decompose_bloq().call_classically(
318-
a=QMontgomeryUInt(n).uint_to_montgomery(a, p),
319-
b=QMontgomeryUInt(n).uint_to_montgomery(b, p),
320-
x=QMontgomeryUInt(n).uint_to_montgomery(x, p),
321-
y=QMontgomeryUInt(n).uint_to_montgomery(y, p),
322-
lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p),
321+
a=dtype.uint_to_montgomery(a),
322+
b=dtype.uint_to_montgomery(b),
323+
x=dtype.uint_to_montgomery(x),
324+
y=dtype.uint_to_montgomery(y),
325+
lam_r=dtype.uint_to_montgomery(lam_r),
323326
)
324327
assert ret1 == ret2
325328

@@ -352,19 +355,20 @@ def test_ec_add_classical(n, m, a, b, x, y):
352355
lam_num = (3 * a**2) % p
353356
lam_denom = (2 * b) % p
354357
lam_r = p if b == 0 else (lam_num * pow(lam_denom, -1, mod=p)) % p
358+
dtype = QMontgomeryUInt(n, p)
355359
ret1 = bloq.call_classically(
356-
a=QMontgomeryUInt(n).uint_to_montgomery(a, p),
357-
b=QMontgomeryUInt(n).uint_to_montgomery(b, p),
358-
x=QMontgomeryUInt(n).uint_to_montgomery(x, p),
359-
y=QMontgomeryUInt(n).uint_to_montgomery(y, p),
360-
lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p),
360+
a=dtype.uint_to_montgomery(a),
361+
b=dtype.uint_to_montgomery(b),
362+
x=dtype.uint_to_montgomery(x),
363+
y=dtype.uint_to_montgomery(y),
364+
lam_r=dtype.uint_to_montgomery(lam_r),
361365
)
362366
ret2 = bloq.decompose_bloq().call_classically(
363-
a=QMontgomeryUInt(n).uint_to_montgomery(a, p),
364-
b=QMontgomeryUInt(n).uint_to_montgomery(b, p),
365-
x=QMontgomeryUInt(n).uint_to_montgomery(x, p),
366-
y=QMontgomeryUInt(n).uint_to_montgomery(y, p),
367-
lam_r=QMontgomeryUInt(n).uint_to_montgomery(lam_r, p),
367+
a=dtype.uint_to_montgomery(a),
368+
b=dtype.uint_to_montgomery(b),
369+
x=dtype.uint_to_montgomery(x),
370+
y=dtype.uint_to_montgomery(y),
371+
lam_r=dtype.uint_to_montgomery(lam_r),
368372
)
369373
assert ret1 == ret2
370374

qualtran/bloqs/mod_arithmetic/mod_addition.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ class ModAdd(Bloq):
7979
def signature(self) -> 'Signature':
8080
return Signature(
8181
[
82-
Register('x', QMontgomeryUInt(self.bitsize)),
83-
Register('y', QMontgomeryUInt(self.bitsize)),
82+
Register('x', QMontgomeryUInt(self.bitsize, self.mod)),
83+
Register('y', QMontgomeryUInt(self.bitsize, self.mod)),
8484
]
8585
)
8686

0 commit comments

Comments
 (0)