diff --git a/lib/dft-tables.cpp b/lib/dft-tables.cpp index bbedc75..21635b5 100644 --- a/lib/dft-tables.cpp +++ b/lib/dft-tables.cpp @@ -2,6 +2,7 @@ #include "datacache.h" #include +#include #include #include @@ -11,42 +12,71 @@ namespace dsplib { namespace tables { //TODO: optional disable caching -static datacache g_dft_cache; +static datacache g_dft_cache; static datacache g_bitrev_cache; //------------------------------------------------------------------------------------------------- -static dft_ptr _gen_dft_table(size_t size) { - auto tb = std::make_shared>(size); - auto data = tb->data(); +fft2tb_ptr fft2tb::alloc(size_t n) { + if (n != (1L << nextpow2(n))) { + DSPLIB_THROW("fft size is not power of 2"); + } - real_t p; - for (size_t i = 0; i < size; ++i) { - p = i / real_t(size); - data[i].re = std::cos(2 * pi * p); - data[i].im = -std::sin(2 * pi * p); + if (g_dft_cache.cached(n)) { + return g_dft_cache.get(n); } - return tb; + auto ptr = fft2tb_ptr(new fft2tb(n)); + g_dft_cache.update(n, ptr); + return ptr; } //------------------------------------------------------------------------------------------------- -const dft_ptr dft_table(size_t size) { - if (g_dft_cache.cached(size)) { - return g_dft_cache.get(size); - } +void fft2tb::reset(size_t n) { + g_dft_cache.reset(n); +} - g_dft_cache.update(size, _gen_dft_table(size)); - return g_dft_cache.get(size); +//------------------------------------------------------------------------------------------------- +bool fft2tb::is_cached(size_t n) { + return g_dft_cache.cached(n); } //------------------------------------------------------------------------------------------------- -void dft_clear(size_t size) { - g_dft_cache.reset(size); +arr_cmplx fft2tb::unpack() const noexcept { + arr_cmplx r(_n); + + //real + for (size_t i = 0; i < _n4; i++) { + r[i].re = _cos_tb[i]; + } + r[_n4].re = 0; + for (size_t i = 0; i < _n4; i++) { + r[_n4 + 1 + i].re = -_cos_tb[_n4 - i - 1]; + } + for (size_t i = 0; i < _n2 - 1; i++) { + r[_n2 + 1 + i].re = r[_n2 - i - 1].re; + } + + //imag + const uint32_t ns = (_n - 1); + for (size_t i = 0; i < _n; i++) { + r[i].im = r[(i + _n4) & ns].re; + } + + return r; } //------------------------------------------------------------------------------------------------- -bool dft_cached(size_t size) { - return g_dft_cache.cached(size); +fft2tb::fft2tb(uint32_t n) noexcept + : _n{n} + , _n2{n / 2} + , _n4{n / 4} + , _cos_tb(_n4) { + assert(n >= 4); + assert(n == (1L << nextpow2(n))); + const real_t dt = 1 / real_t(_n); + for (size_t i = 0; i < _n4; ++i) { + _cos_tb[i] = std::cos(2 * pi * i * dt); + } } //------------------------------------------------------------------------------------------------- @@ -89,7 +119,7 @@ static bitrev_ptr _gen_bitrev_table(size_t size) { } //------------------------------------------------------------------------------------------------- -const bitrev_ptr bitrev_table(size_t size) { +bitrev_ptr bitrev_table(size_t size) { if (g_bitrev_cache.cached(size)) { return g_bitrev_cache.get(size); } diff --git a/lib/dft-tables.h b/lib/dft-tables.h index d782705..ff71e95 100644 --- a/lib/dft-tables.h +++ b/lib/dft-tables.h @@ -1,38 +1,39 @@ #pragma once -#include +#include + #include #include -#include +#include namespace dsplib { namespace tables { -using dft_ptr = std::shared_ptr>; +class fft2tb; +using fft2tb_ptr = std::shared_ptr; + +//wrapper for table exp(-1i * 2 * pi * i / n) compresed to 1/4 +class fft2tb +{ +public: + static fft2tb_ptr alloc(size_t n); + static void reset(size_t n); + static bool is_cached(size_t n); -/*! - * \brief Get (or generate) a table for calculating DFT - * \param n DFT base - * \return Table pointer - */ -const dft_ptr dft_table(size_t n); + arr_cmplx unpack() const noexcept; -/*! - * \brief Clear table from cache - * \param n DFT base - */ -void dft_clear(size_t n); +private: + explicit fft2tb(uint32_t n) noexcept; -/*! - * \brief Check if table cached - * \param n DFT base - * \return Cached - */ -bool dft_cached(size_t n); + const uint32_t _n; + const uint32_t _n2; + const uint32_t _n4; + std::vector _cos_tb; +}; //bit-reverse table using bitrev_ptr = std::shared_ptr>; -const bitrev_ptr bitrev_table(size_t n); +bitrev_ptr bitrev_table(size_t n); bool bitrev_cached(size_t n); void bitrev_clear(size_t n); diff --git a/lib/fft.cpp b/lib/fft.cpp index 9e5fac4..91f76a3 100644 --- a/lib/fft.cpp +++ b/lib/fft.cpp @@ -76,11 +76,11 @@ class fft_plan_impl const int n2 = 1L << nextpow2(n); if (n == n2) { //n == 2^K - auto brev = tables::bitrev_table(n); - auto coeff = tables::dft_table(n); + const auto brev = tables::bitrev_table(n); + const auto coeff = tables::fft2tb::alloc(n)->unpack(); solve = [brev, coeff, n](const arr_cmplx& x) { arr_cmplx r = x; - _fft2(r.data(), coeff->data(), brev->data(), n); + _fft2(r.data(), coeff.data(), brev->data(), n); return r; }; } else { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a25ff7a..f45587d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -15,5 +15,5 @@ CPMAddPackage(NAME googletest file(GLOB_RECURSE SOURCES "*.cpp" "*.h") add_executable(${PROJECT_NAME} ${SOURCES}) -target_include_directories(${PROJECT_NAME} PUBLIC ${CMAKE_CURRENT_LIST_DIR}) +target_include_directories(${PROJECT_NAME} PUBLIC ${CMAKE_CURRENT_LIST_DIR} "${CMAKE_SOURCE_DIR}/lib") target_link_libraries(${PROJECT_NAME} PUBLIC dsplib gtest) \ No newline at end of file diff --git a/tests/fft_test.cpp b/tests/fft_test.cpp index 0e95b5d..d3a56b8 100644 --- a/tests/fft_test.cpp +++ b/tests/fft_test.cpp @@ -1,7 +1,9 @@ #include "tests_common.h" +#include + //------------------------------------------------------------------------------------------------- -TEST(MathTest, FftReal) { +TEST(FFT, FftReal) { using namespace dsplib; int idx = 10; int nfft = 512; @@ -15,7 +17,7 @@ TEST(MathTest, FftReal) { } //------------------------------------------------------------------------------------------------- -TEST(MathTest, FftCmplx) { +TEST(FFT, FftCmplx) { using namespace dsplib; int idx = 10; int nfft = 512; @@ -28,7 +30,7 @@ TEST(MathTest, FftCmplx) { } //------------------------------------------------------------------------------------------------- -TEST(MathTest, Ifft) { +TEST(FFT, Ifft) { using namespace dsplib; { @@ -49,7 +51,7 @@ TEST(MathTest, Ifft) { } //------------------------------------------------------------------------------------------------- -TEST(MathTest, Czt) { +TEST(FFT, Czt) { using namespace dsplib; arr_cmplx dft_ref = {6.00000000000000 + 0.00000000000000i, -1.50000000000000 + 0.866025403784439i, -1.50000000000000 - 0.866025403784439i}; @@ -61,7 +63,7 @@ TEST(MathTest, Czt) { } //------------------------------------------------------------------------------------------------- -TEST(MathTest, CztICzt) { +TEST(FFT, CztICzt) { using namespace dsplib; for (size_t i = 0; i < 1000; i++) { int n = randi({16, 2000}); @@ -73,7 +75,7 @@ TEST(MathTest, CztICzt) { } //------------------------------------------------------------------------------------------------- -TEST(MathTest, CztFft2) { +TEST(FFT, CztFft2) { using namespace dsplib; for (size_t i = 0; i < 1000; i++) { int n = randi({16, 2000}); @@ -87,7 +89,7 @@ TEST(MathTest, CztFft2) { } //------------------------------------------------------------------------------------------------- -TEST(MathTest, CztIFft2) { +TEST(FFT, CztIFft2) { using namespace dsplib; for (size_t i = 0; i < 1000; i++) { int n = randi({16, 2000}); @@ -99,3 +101,15 @@ TEST(MathTest, CztIFft2) { ASSERT_EQ_ARR_CMPLX(y1, y2); } } + +//------------------------------------------------------------------------------------------------- +TEST(FFT, Fft2Table) { + using namespace dsplib; + auto nfft_list = {4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192}; + for (auto nfft : nfft_list) { + auto tb = tables::fft2tb::alloc(nfft); + auto x1 = tb->unpack(); + auto x2 = expj(-2 * dsplib::pi * range(nfft) / nfft); + ASSERT_EQ_ARR_CMPLX(x1, x2); + } +}