44#include < iostream>
55#include < cassert>
66namespace cp_algo ::math {
7- inline constexpr uint64_t inv64 ( uint64_t x) {
7+ inline constexpr auto inv2 ( auto x) {
88 assert (x % 2 );
9- uint64_t y = 1 ;
9+ std:: make_unsigned_t < decltype (x)> y = 1 ;
1010 while (y * x != 1 ) {
1111 y *= 2 - x * y;
1212 }
1313 return y;
1414 }
1515
16- template <typename modint>
16+ template <typename modint, typename _Int >
1717 struct modint_base {
18- static int64_t mod () {
18+ using Int = _Int;
19+ using Uint = std::make_unsigned_t <Int>;
20+ static constexpr size_t bits = sizeof (Int) * 8 ;
21+ using Int2 = std::conditional_t <bits <= 32 , uint64_t , __uint128_t >;
22+ static Int mod () {
1923 return modint::mod ();
2024 }
21- static uint64_t imod () {
25+ static Uint imod () {
2226 return modint::imod ();
2327 }
24- static __uint128_t pw128 () {
28+ static Int2 pw128 () {
2529 return modint::pw128 ();
2630 }
27- static uint64_t m_reduce (__uint128_t ab) {
31+ static Uint m_reduce (Int2 ab) {
2832 if (mod () % 2 == 0 ) [[unlikely]] {
2933 return ab % mod ();
3034 } else {
31- uint64_t m = ab * imod ();
32- return (ab + __uint128_t (m) * mod ()) >> 64 ;
35+ Uint m = ab * imod ();
36+ return (ab + (Int2)m * mod ()) >> bits ;
3337 }
3438 }
35- static uint64_t m_transform (uint64_t a) {
39+ static Uint m_transform (Uint a) {
3640 if (mod () % 2 == 0 ) [[unlikely]] {
3741 return a;
3842 } else {
3943 return m_reduce (a * pw128 ());
4044 }
4145 }
4246 modint_base (): r(0 ) {}
43- modint_base (int64_t rr): r(rr % mod()) {
47+ modint_base (Int rr): r(rr % mod()) {
4448 r = std::min (r, r + mod ());
4549 r = m_transform (r);
4650 }
@@ -56,7 +60,7 @@ namespace cp_algo::math {
5660 return to_modint () *= t.inv ();
5761 }
5862 modint& operator *= (const modint &t) {
59- r = m_reduce (__uint128_t (r) * t.r );
63+ r = m_reduce ((Int2)r * t.r );
6064 return to_modint ();
6165 }
6266 modint& operator += (const modint &t) {
@@ -78,86 +82,89 @@ namespace cp_algo::math {
7882 auto operator >= (const modint_base &t) const {return getr () >= t.getr ();}
7983 auto operator < (const modint_base &t) const {return getr () < t.getr ();}
8084 auto operator > (const modint_base &t) const {return getr () > t.getr ();}
81- int64_t rem () const {
82- uint64_t R = getr ();
83- return 2 * R > (uint64_t )mod () ? R - mod () : R;
85+ Int rem () const {
86+ Uint R = getr ();
87+ return 2 * R > (Uint )mod () ? R - mod () : R;
8488 }
8589
8690 // Only use if you really know what you're doing!
87- uint64_t modmod () const {return 8ULL * mod () * mod ();};
88- void add_unsafe (uint64_t t) {r += t;}
91+ Uint modmod () const {return (Uint) 8 * mod () * mod ();};
92+ void add_unsafe (Uint t) {r += t;}
8993 void pseudonormalize () {r = std::min (r, r - modmod ());}
9094 modint const & normalize () {
91- if (r >= (uint64_t )mod ()) {
95+ if (r >= (Uint )mod ()) {
9296 r %= mod ();
9397 }
9498 return to_modint ();
9599 }
96- void setr (uint64_t rr) {r = m_transform (rr);}
97- uint64_t getr () const {
98- uint64_t res = m_reduce (r);
100+ void setr (Uint rr) {r = m_transform (rr);}
101+ Uint getr () const {
102+ Uint res = m_reduce (r);
99103 return std::min (res, res - mod ());
100104 }
101- void setr_direct (uint64_t rr) {r = rr;}
102- uint64_t getr_direct () const {return r;}
105+ void setr_direct (Uint rr) {r = rr;}
106+ Uint getr_direct () const {return r;}
103107 private:
104- uint64_t r;
108+ Uint r;
105109 modint& to_modint () {return static_cast <modint&>(*this );}
106110 modint const & to_modint () const {return static_cast <modint const &>(*this );}
107111 };
108112 template <typename modint>
109- std::istream& operator >> (std::istream &in, modint_base<modint> &x) {
110- uint64_t r;
113+ concept modint_type = std::is_base_of_v<modint_base<modint, typename modint::Int>, modint>;
114+ template <modint_type modint>
115+ std::istream& operator >> (std::istream &in, modint &x) {
116+ typename modint::Uint r;
111117 auto &res = in >> r;
112118 x.setr (r);
113119 return res;
114120 }
115- template <typename modint>
116- std::ostream& operator << (std::ostream &out, modint_base< modint> const & x) {
121+ template <modint_type modint>
122+ std::ostream& operator << (std::ostream &out, modint const & x) {
117123 return out << x.getr ();
118124 }
119125
120- template <typename modint>
121- concept modint_type = std::is_base_of_v<modint_base<modint>, modint>;
122-
123- template <int64_t m>
124- struct modint : modint_base<modint<m>> {
125- static constexpr uint64_t im = m % 2 ? inv64(-m) : 0 ;
126- static constexpr uint64_t r2 = __uint128_t (-1 ) % m + 1 ;
127- static constexpr int64_t mod () {return m;}
128- static constexpr uint64_t imod () {return im;}
129- static constexpr __uint128_t pw128 () {return r2;}
130- using Base = modint_base<modint<m>>;
126+ template <auto m>
127+ struct modint : modint_base<modint<m>, decltype (m)> {
128+ using Base = modint_base<modint<m>, decltype (m)>;
131129 using Base::Base;
130+ static constexpr Base::Uint im = m % 2 ? inv2(-m) : 0 ;
131+ static constexpr Base::Uint r2 = (typename Base::Int2)(-1 ) % m + 1 ;
132+ static constexpr Base::Int mod () {return m;}
133+ static constexpr Base::Uint imod () {return im;}
134+ static constexpr Base::Int2 pw128 () {return r2;}
132135 };
133136
134- struct dynamic_modint : modint_base<dynamic_modint> {
135- static int64_t mod () {return m;}
136- static uint64_t imod () {return im;}
137- static __uint128_t pw128 () {return r2;}
138- static void switch_mod (int64_t nm) {
137+ template <typename Int = int64_t >
138+ struct dynamic_modint : modint_base<dynamic_modint<Int>, Int> {
139+ using Base = modint_base<dynamic_modint<Int>, Int>;
140+ using Base::Base;
141+ static Int mod () {return m;}
142+ static Base::Uint imod () {return im;}
143+ static Base::Int2 pw128 () {return r2;}
144+ static void switch_mod (Int nm) {
139145 m = nm;
140- im = m % 2 ? inv64 (-m) : 0 ;
141- r2 = __uint128_t (-1 ) % m + 1 ;
146+ im = m % 2 ? inv2 (-m) : 0 ;
147+ r2 = ( typename Base::Int2) (-1 ) % m + 1 ;
142148 }
143- using Base = modint_base<dynamic_modint>;
144- using Base::Base;
145149
146150 // Wrapper for temp switching
147- auto static with_mod (int64_t tmp, auto callback) {
151+ auto static with_mod (Int tmp, auto callback) {
148152 struct scoped {
149- int64_t prev = mod();
153+ Int prev = mod();
150154 ~scoped () {switch_mod (prev);}
151155 } _;
152156 switch_mod (tmp);
153157 return callback ();
154158 }
155159 private:
156- static int64_t m;
157- static uint64_t im, r1, r2;
160+ static Int m;
161+ static Base::Uint im, r1, r2;
158162 };
159- int64_t dynamic_modint::m = 1 ;
160- uint64_t dynamic_modint::im = -1 ;
161- uint64_t dynamic_modint::r2 = 0 ;
163+ template <typename Int>
164+ Int dynamic_modint<Int>::m = 1 ;
165+ template <typename Int>
166+ dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::im = -1 ;
167+ template <typename Int>
168+ dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::r2 = 0 ;
162169}
163170#endif // CP_ALGO_MATH_MODINT_HPP
0 commit comments