diff --git a/flashlight/lib/mkl/CMakeLists.txt b/flashlight/lib/mkl/CMakeLists.txt new file mode 100644 index 00000000..447f9c4f --- /dev/null +++ b/flashlight/lib/mkl/CMakeLists.txt @@ -0,0 +1,24 @@ +cmake_minimum_required(VERSION 3.10) + +# ----------------------------- Dependencies ----------------------------- +find_package(MKL REQUIRED) + +# ----------------------------- Lib ----------------------------- + +target_sources( + fl-libraries + PRIVATE + ${CMAKE_CURRENT_LIST_DIR}/Functions.cpp + ) + +target_link_libraries( + fl-libraries + PRIVATE + ${MKL_LIBRARIES} + ) + +target_include_directories( + fl-libraries + PRIVATE + ${MKL_INCLUDE_DIR} + ) diff --git a/flashlight/lib/mkl/Functions.cpp b/flashlight/lib/mkl/Functions.cpp new file mode 100644 index 00000000..44ef2d22 --- /dev/null +++ b/flashlight/lib/mkl/Functions.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "flashlight/lib/mkl/Functions.h" + +#include + +#include + +namespace fl { +namespace lib { +namespace mkl { + +#define FL_VSL_CHECK(cmd) \ + ::fl::lib::mkl::vslCheck(cmd, __FILE__, __LINE__, #cmd) + +void vslCheck(MKL_INT err, const char* file, int line, const char* cmd) { + if (err != VSL_STATUS_OK) { + std::ostringstream ess; + ess << file << ':' << line << "] MKL-VSL error: " << err << " cmd:" << cmd; + throw std::runtime_error(ess.str()); + } +} + +std::vector Correlate( + const std::vector& kernel, + const std::vector& input) { + std::vector output(kernel.size() + input.size() - 1, 0); + VSLConvTaskPtr task; + FL_VSL_CHECK(vslsConvNewTask1D( + &task, VSL_CONV_MODE_AUTO, kernel.size(), input.size(), output.size())); + FL_VSL_CHECK(vslsConvExec1D( + task, kernel.data(), 1, input.data(), 1, output.data(), 1)); + FL_VSL_CHECK(vslConvDeleteTask(&task)); + return output; +} + +} // namespace mkl +} // namespace lib +} // namespace fl diff --git a/flashlight/lib/mkl/Functions.h b/flashlight/lib/mkl/Functions.h new file mode 100644 index 00000000..10ad9545 --- /dev/null +++ b/flashlight/lib/mkl/Functions.h @@ -0,0 +1,26 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace fl { +namespace lib { +namespace mkl { + +/** + * Convolves the kernel on the input by delegating to MKL-VSL convolution. + * Size of return value is kernel.size() + input.size() - 1 + */ +std::vector Correlate( + const std::vector& kernel, + const std::vector& input); + +} // namespace mkl +} // namespace lib +} // namespace fl diff --git a/flashlight/lib/test/CMakeLists.txt b/flashlight/lib/test/CMakeLists.txt index ea7bf5e8..dccaad53 100644 --- a/flashlight/lib/test/CMakeLists.txt +++ b/flashlight/lib/test/CMakeLists.txt @@ -19,6 +19,7 @@ build_test(SRC ${DIR}/audio/feature/WindowingTest.cpp LIBS ${LIBS}) build_test(SRC ${DIR}/common/ProducerConsumerQueueTest.cpp LIBS ${LIBS}) build_test(SRC ${DIR}/common/StringTest.cpp LIBS ${LIBS}) build_test(SRC ${DIR}/common/SystemTest.cpp LIBS ${LIBS}) +build_test(SRC ${DIR}/mkl/FunctionsTest.cpp LIBS ${LIBS}) build_test( SRC ${DIR}/text/dictionary/DictionaryTest.cpp LIBS ${LIBS} diff --git a/flashlight/lib/test/mkl/FunctionsTest.cpp b/flashlight/lib/test/mkl/FunctionsTest.cpp new file mode 100644 index 00000000..26ccef9d --- /dev/null +++ b/flashlight/lib/test/mkl/FunctionsTest.cpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +#include "flashlight/lib/mkl/Functions.h" + +using namespace ::fl::lib; +using ::testing::ElementsAre; +using ::testing::Pointwise; + +MATCHER_P(FloatNearPointwise, tol, "Out of range") { + return ( + std::get<0>(arg) > std::get<1>(arg) - tol && + std::get<0>(arg) < std::get<1>(arg) + tol); +} + +TEST(CorrelateTest, Identity) { + std::vector kernel = {1}; + std::vector input = {1, 2, 3, 4, 5}; + std::vector output = fl::lib::mkl::Correlate(kernel, input); + EXPECT_EQ(output.size(), input.size() + kernel.size() - 1); + EXPECT_THAT(output, Pointwise(FloatNearPointwise(0.01), input)); +} + +TEST(CorrelateTest, BasicReverb) { + std::vector kernel = {1, 2, 3}; + std::vector input = {1, 2, 3, 4, 5}; + std::vector output = fl::lib::mkl::Correlate(kernel, input); + EXPECT_EQ(output.size(), input.size() + kernel.size() - 1); + EXPECT_THAT(output, ElementsAre(1, 4, 10, 16, 22, 22, 15)); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}