@@ -17,18 +17,8 @@ namespace cp_algo::math::fft {
1717 using vpoint = complex <vftype>;
1818 static constexpr size_t flen = vftype::size();
1919
20-
21- template <typename ft>
22- constexpr ft to_ft (auto x) {
23- return ft{} + x;
24- }
25- template <typename pt>
26- constexpr pt to_pt (point r) {
27- using ft = std::conditional_t <std::is_same_v<point, pt>, ftype, vftype>;
28- return {to_ft<ft>(r.real ()), to_ft<ft>(r.imag ())};
29- }
3020 struct cvector {
31- static constexpr size_t pre_roots = 1 << 17 ;
21+ static constexpr size_t pre_roots = 1 << 19 ;
3222 std::vector<vftype> x, y;
3323 cvector (size_t n) {
3424 n = std::max (flen, std::bit_ceil (n));
@@ -67,32 +57,28 @@ namespace cp_algo::math::fft {
6757 }
6858 }
6959 static const cvector roots;
70- template <class pt = point>
71- static pt root (size_t n, size_t k) {
72- if (n < pre_roots) {
60+ template <class pt = point, bool precalc = false >
61+ static pt root (size_t n, size_t k, auto &&arg ) {
62+ if (n < pre_roots && !precalc ) {
7363 return roots.get <pt>(n + k);
7464 } else {
75- auto arg = std::numbers::pi / (ftype)n;
76- if constexpr (std::is_same_v<pt, point>) {
77- return {cos ((ftype)k * arg), sin ((ftype)k * arg)};
78- } else {
79- return pt{vftype{[&](auto i) {return cos (ftype (k + i) * arg);}},
80- vftype{[&](auto i) {return sin (ftype (k + i) * arg);}}};
81- }
65+ return polar<typename pt::value_type>(1 ., arg);
8266 }
8367 }
84- template <class pt = point>
68+ template <class pt = point, bool precalc = false >
8569 static void exec_on_roots (size_t n, size_t m, auto &&callback) {
70+ ftype arg = std::numbers::pi / (ftype)n;
8671 size_t step = sizeof (pt) / sizeof (point);
87- pt cur;
88- pt arg = to_pt<pt>(root<point>(n, step));
89- for (size_t i = 0 ; i < m; i += step) {
90- if (i % 32 == 0 || n < pre_roots) {
91- cur = root<pt>(n, i);
72+ using ft = pt::value_type;
73+ auto k = [&]() {
74+ if constexpr (std::is_same_v<pt, point>) {
75+ return ft{};
9276 } else {
93- cur *= arg ;
77+ return ft{[]( auto i) { return i;}} ;
9478 }
95- callback (i, cur);
79+ }();
80+ for (size_t i = 0 ; i < m; i += step, k += (ftype)step) {
81+ callback (i, root<pt, precalc>(n, i, arg * k));
9682 }
9783 }
9884
@@ -106,15 +92,15 @@ namespace cp_algo::math::fft {
10692 set (k + i, get<pt>(k) - t);
10793 set (k, get<pt>(k) + t);
10894 };
109- if (2 * i <= flen) {
95+ if (i < flen) {
11096 exec_on_roots (i, i, butterfly);
11197 } else {
11298 exec_on_roots<vpoint>(i, i, butterfly);
11399 }
114100 }
115101 }
116102 for (size_t k = 0 ; k < n; k += flen) {
117- set (k, get<vpoint>(k) /= to_pt<vpoint>(( ftype)n) );
103+ set (k, get<vpoint>(k) /= ( ftype)n);
118104 }
119105 }
120106 void fft () {
@@ -128,7 +114,7 @@ namespace cp_algo::math::fft {
128114 set (k, A);
129115 set (k + i, B * rt);
130116 };
131- if (2 * i <= flen) {
117+ if (i < flen) {
132118 exec_on_roots (i, i, butterfly);
133119 } else {
134120 exec_on_roots<vpoint>(i, i, butterfly);
@@ -140,14 +126,13 @@ namespace cp_algo::math::fft {
140126 const cvector cvector::roots = []() {
141127 cvector res (pre_roots);
142128 for (size_t n = 1 ; n < res.size (); n *= 2 ) {
143- auto base = polar<ftype>(1 ., std::numbers::pi / (ftype)n);
144- point cur = 1 ;
145- for (size_t k = 0 ; k < n; k++) {
146- if ((k & 15 ) == 0 ) {
147- cur = polar<ftype>(1 ., std::numbers::pi * (ftype)k / (ftype)n);
148- }
149- res.set (n + k, cur);
150- cur *= base;
129+ auto propagate = [&](size_t k, auto rt) {
130+ res.set (n + k, rt);
131+ };
132+ if (n < flen) {
133+ res.exec_on_roots <point, true >(n, n, propagate);
134+ } else {
135+ res.exec_on_roots <vpoint, true >(n, n, propagate);
151136 }
152137 }
153138 return res;
0 commit comments