-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add mkl library wrapper with impemntation of conv1D (#443)
Summary: MKL Vector Statistical Library (VSL) fetures convolution and cross correlation routines as well as commonly used pseudo- or quasi-random number generators with continuous and discrete distribution. This diff adds a wrapper for conv1D which performed x3000 faster then hand written nested loops on sound effect RIR application. Pull Request resolved: flashlight/flashlight#443 Reviewed By: vineelpratap Differential Revision: D26132123 Pulled By: avidov fbshipit-source-id: 9729625bef328d30c752c79ccc5139e66f3b5ce3
- Loading branch information
1 parent
329d8c5
commit a6dcdc4
Showing
5 changed files
with
139 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
cmake_minimum_required(VERSION 3.10) | ||
|
||
# ----------------------------- Dependencies ----------------------------- | ||
find_package(MKL REQUIRED) | ||
|
||
# ----------------------------- Lib ----------------------------- | ||
|
||
target_sources( | ||
fl-libraries | ||
PRIVATE | ||
${CMAKE_CURRENT_LIST_DIR}/Functions.cpp | ||
) | ||
|
||
target_link_libraries( | ||
fl-libraries | ||
PRIVATE | ||
${MKL_LIBRARIES} | ||
) | ||
|
||
target_include_directories( | ||
fl-libraries | ||
PRIVATE | ||
${MKL_INCLUDE_DIR} | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include "flashlight/lib/mkl/Functions.h" | ||
|
||
#include <sstream> | ||
|
||
#include <mkl.h> | ||
|
||
namespace fl { | ||
namespace lib { | ||
namespace mkl { | ||
|
||
#define FL_VSL_CHECK(cmd) \ | ||
::fl::lib::mkl::vslCheck(cmd, __FILE__, __LINE__, #cmd) | ||
|
||
void vslCheck(MKL_INT err, const char* file, int line, const char* cmd) { | ||
if (err != VSL_STATUS_OK) { | ||
std::ostringstream ess; | ||
ess << file << ':' << line << "] MKL-VSL error: " << err << " cmd:" << cmd; | ||
throw std::runtime_error(ess.str()); | ||
} | ||
} | ||
|
||
std::vector<float> Correlate( | ||
const std::vector<float>& kernel, | ||
const std::vector<float>& input) { | ||
std::vector<float> output(kernel.size() + input.size() - 1, 0); | ||
VSLConvTaskPtr task; | ||
FL_VSL_CHECK(vslsConvNewTask1D( | ||
&task, VSL_CONV_MODE_AUTO, kernel.size(), input.size(), output.size())); | ||
FL_VSL_CHECK(vslsConvExec1D( | ||
task, kernel.data(), 1, input.data(), 1, output.data(), 1)); | ||
FL_VSL_CHECK(vslConvDeleteTask(&task)); | ||
return output; | ||
} | ||
|
||
} // namespace mkl | ||
} // namespace lib | ||
} // namespace fl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <vector> | ||
|
||
namespace fl { | ||
namespace lib { | ||
namespace mkl { | ||
|
||
/** | ||
* Convolves the kernel on the input by delegating to MKL-VSL convolution. | ||
* Size of return value is kernel.size() + input.size() - 1 | ||
*/ | ||
std::vector<float> Correlate( | ||
const std::vector<float>& kernel, | ||
const std::vector<float>& input); | ||
|
||
} // namespace mkl | ||
} // namespace lib | ||
} // namespace fl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) Facebook, Inc. and its affiliates. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <string> | ||
|
||
#include <gmock/gmock.h> | ||
#include <gtest/gtest.h> | ||
|
||
#include "flashlight/lib/mkl/Functions.h" | ||
|
||
using namespace ::fl::lib; | ||
using ::testing::ElementsAre; | ||
using ::testing::Pointwise; | ||
|
||
MATCHER_P(FloatNearPointwise, tol, "Out of range") { | ||
return ( | ||
std::get<0>(arg) > std::get<1>(arg) - tol && | ||
std::get<0>(arg) < std::get<1>(arg) + tol); | ||
} | ||
|
||
TEST(CorrelateTest, Identity) { | ||
std::vector<float> kernel = {1}; | ||
std::vector<float> input = {1, 2, 3, 4, 5}; | ||
std::vector<float> output = fl::lib::mkl::Correlate(kernel, input); | ||
EXPECT_EQ(output.size(), input.size() + kernel.size() - 1); | ||
EXPECT_THAT(output, Pointwise(FloatNearPointwise(0.01), input)); | ||
} | ||
|
||
TEST(CorrelateTest, BasicReverb) { | ||
std::vector<float> kernel = {1, 2, 3}; | ||
std::vector<float> input = {1, 2, 3, 4, 5}; | ||
std::vector<float> output = fl::lib::mkl::Correlate(kernel, input); | ||
EXPECT_EQ(output.size(), input.size() + kernel.size() - 1); | ||
EXPECT_THAT(output, ElementsAre(1, 4, 10, 16, 22, 22, 15)); | ||
} | ||
|
||
int main(int argc, char** argv) { | ||
::testing::InitGoogleTest(&argc, argv); | ||
return RUN_ALL_TESTS(); | ||
} |