diff --git a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py index c0242f973..84cd0232d 100644 --- a/qualtran/bloqs/mod_arithmetic/mod_multiplication.py +++ b/qualtran/bloqs/mod_arithmetic/mod_multiplication.py @@ -586,10 +586,14 @@ def __attrs_post_init__(self): assert self.mod > 1 and self.mod % 2 == 1 # Must be an odd integer greater than 1. if isinstance(self.mod, int) and isinstance(self.bitsize, int): - assert 2 * self.mod - 1 < 2**self.bitsize, f'bitsize={self.bitsize} is too small' + assert self.mod - 1 < 2**self.bitsize, f'bitsize={self.bitsize} is too small' if isinstance(self.window_size, int) and isinstance(self.bitsize, int): - assert self.bitsize % self.window_size == 0 + assert (self.bitsize + self._add_extra_qubit) % self.window_size == 0 + + @cached_property + def _add_extra_qubit(self): + return 2 * self.mod - 2 >= 2**self.bitsize @cached_property def signature(self) -> 'Signature': @@ -656,7 +660,13 @@ def on_classical_vals( if self.uncompute: assert ( target is not None - and target == (x * y * pow(2, self.bitsize * (self.mod - 2), self.mod)) % self.mod + and target + == ( + x + * y + * pow(2, (self.bitsize + self._add_extra_qubit) * (self.mod - 2), self.mod) + ) + % self.mod ) assert qrom_indices is not None assert reduced is not None @@ -671,20 +681,24 @@ def on_classical_vals( target = 0 qrom_indices = 0 reduced = 0 - for i in range(0, self.bitsize, self.window_size): + for i in range(0, self.bitsize + self._add_extra_qubit, self.window_size): target, qrom_indices = self._classical_action_window(x >> i, y, target, qrom_indices) if target >= self.mod: target -= self.mod reduced = 1 - montgomery_prod = (x * y * pow(2, self.bitsize * (self.mod - 2), self.mod)) % self.mod + montgomery_prod = ( + x * y * pow(2, (self.bitsize + self._add_extra_qubit) * (self.mod - 2), self.mod) + ) % self.mod assert target == montgomery_prod return {'x': x, 'y': y, 'target': target, 'qrom_indices': qrom_indices, 'reduced': reduced} @cached_property def _mod_mul_impl(self) -> Bloq: - b: Bloq = _DirtyOutOfPlaceMontgomeryModMulImpl(self.bitsize, self.window_size, self.mod) + b: Bloq = _DirtyOutOfPlaceMontgomeryModMulImpl( + self.bitsize + self._add_extra_qubit, self.window_size, self.mod + ) if self.uncompute: b = b.adjoint() return b @@ -703,6 +717,11 @@ def build_composite_bloq( assert qrom_indices is not None assert reduced is not None + if self._add_extra_qubit: + x = bb.join([x, bb.allocate(1, dtype=QMontgomeryUInt(1))]) + y = bb.join([y, bb.allocate(1, dtype=QMontgomeryUInt(1))]) + target = bb.join([target, bb.allocate(1, dtype=QMontgomeryUInt(1))]) + x, y, target, qrom_indices, reduced = bb.add_from( # type: ignore self._mod_mul_impl, x=x, @@ -718,6 +737,10 @@ def build_composite_bloq( return {'x': x, 'y': y} target = bb.allocate(self.bitsize, QMontgomeryUInt(self.bitsize)) + if self._add_extra_qubit: + x = bb.join([x, bb.allocate(1, dtype=QMontgomeryUInt(1))]) + y = bb.join([y, bb.allocate(1, dtype=QMontgomeryUInt(1))]) + target = bb.join([target, bb.allocate(1, dtype=QMontgomeryUInt(1))]) num_windows = (self.bitsize + self.window_size - 1) // self.window_size qrom_indices = bb.allocate( num_windows * self.window_size, QMontgomeryUInt(num_windows * self.window_size) @@ -727,6 +750,20 @@ def build_composite_bloq( x, y, target, qrom_indices, reduced = bb.add_from( self._mod_mul_impl, x=x, y=y, target=target, qrom_indices=qrom_indices, reduced=reduced ) + + if self._add_extra_qubit: + x_arr = bb.split(x) # type: ignore[arg-type] + bb.free(x_arr[0]) + x = bb.join(x_arr[1:]) + + y_arr = bb.split(y) # type: ignore[arg-type] + bb.free(y_arr[0]) + y = bb.join(y_arr[1:]) + + target_arr = bb.split(target) # type: ignore[arg-type] + bb.free(target_arr[0]) + target = bb.join(target_arr[1:]) + return {'x': x, 'y': y, 'target': target, 'qrom_indices': qrom_indices, 'reduced': reduced} def build_call_graph(