Skip to content

Commit ff1e2b2

Browse files
Merge branch 'main' into montgomery
2 parents 0444490 + 676b02a commit ff1e2b2

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

qualtran/_infra/data_types.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
import abc
5252
from enum import Enum
5353
from functools import cached_property
54-
from typing import Any, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union
54+
from typing import Any, Iterable, List, Literal, Optional, Sequence, TYPE_CHECKING, Union
5555

5656
import attrs
5757
import numpy as np
@@ -908,8 +908,19 @@ class QGF(QDType):
908908

909909
characteristic: SymbolicInt
910910
degree: SymbolicInt
911-
irreducible_poly: Optional['galois.Poly'] = None
912-
element_repr: str = 'int'
911+
irreducible_poly: Optional['galois.Poly'] = attrs.field()
912+
element_repr: Literal["int", "poly", "power"] = attrs.field(default='int')
913+
914+
@irreducible_poly.default
915+
def _irreducible_poly_default(self):
916+
if is_symbolic(self.characteristic, self.degree):
917+
return None
918+
919+
from galois import GF
920+
921+
return GF( # type: ignore[call-overload]
922+
int(self.characteristic), int(self.degree), compile='python-calculate'
923+
).irreducible_poly
913924

914925
@cached_property
915926
def order(self) -> SymbolicInt:
@@ -938,10 +949,12 @@ def _quint_equivalent(self) -> QUInt:
938949
def gf_type(self):
939950
from galois import GF
940951

952+
poly = self.irreducible_poly if self.degree > 1 else None
953+
941954
return GF( # type: ignore[call-overload]
942955
int(self.characteristic),
943956
int(self.degree),
944-
irreducible_poly=self.irreducible_poly,
957+
irreducible_poly=poly,
945958
repr=self.element_repr,
946959
compile='python-calculate',
947960
)

qualtran/_infra/data_types_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -536,3 +536,11 @@ def test_montgomery_bit_conversion(bitsize):
536536
dtype = QMontgomeryUInt(bitsize)
537537
for v in range(1 << bitsize):
538538
assert v == dtype.from_bits(dtype.to_bits(v))
539+
540+
541+
def test_qgf_with_default_poly_is_compatible():
542+
qgf_one = QGF(2, 4)
543+
544+
qgf_two = QGF(2, 4, irreducible_poly=qgf_one.gf_type.irreducible_poly)
545+
546+
assert qgf_one == qgf_two

0 commit comments

Comments
 (0)