Skip to content

Commit

Permalink
add mkl library wrapper with impemntation of conv1D (#443)
Browse files Browse the repository at this point in the history
Summary:
MKL Vector Statistical Library (VSL) fetures convolution and cross correlation routines as well as
commonly used pseudo- or quasi-random number generators with continuous and discrete distribution.
This diff adds a wrapper for conv1D which performed x3000 faster then hand written nested loops
on sound effect RIR application.

Pull Request resolved: flashlight/flashlight#443

Reviewed By: vineelpratap

Differential Revision: D26132123

Pulled By: avidov

fbshipit-source-id: 9729625bef328d30c752c79ccc5139e66f3b5ce3
  • Loading branch information
Your Name authored and facebook-github-bot committed Apr 13, 2021
1 parent 329d8c5 commit a6dcdc4
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 0 deletions.
24 changes: 24 additions & 0 deletions flashlight/lib/mkl/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
cmake_minimum_required(VERSION 3.10)

# ----------------------------- Dependencies -----------------------------
find_package(MKL REQUIRED)

# ----------------------------- Lib -----------------------------

target_sources(
fl-libraries
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/Functions.cpp
)

target_link_libraries(
fl-libraries
PRIVATE
${MKL_LIBRARIES}
)

target_include_directories(
fl-libraries
PRIVATE
${MKL_INCLUDE_DIR}
)
44 changes: 44 additions & 0 deletions flashlight/lib/mkl/Functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "flashlight/lib/mkl/Functions.h"

#include <sstream>

#include <mkl.h>

namespace fl {
namespace lib {
namespace mkl {

#define FL_VSL_CHECK(cmd) \
::fl::lib::mkl::vslCheck(cmd, __FILE__, __LINE__, #cmd)

void vslCheck(MKL_INT err, const char* file, int line, const char* cmd) {
if (err != VSL_STATUS_OK) {
std::ostringstream ess;
ess << file << ':' << line << "] MKL-VSL error: " << err << " cmd:" << cmd;
throw std::runtime_error(ess.str());
}
}

std::vector<float> Correlate(
const std::vector<float>& kernel,
const std::vector<float>& input) {
std::vector<float> output(kernel.size() + input.size() - 1, 0);
VSLConvTaskPtr task;
FL_VSL_CHECK(vslsConvNewTask1D(
&task, VSL_CONV_MODE_AUTO, kernel.size(), input.size(), output.size()));
FL_VSL_CHECK(vslsConvExec1D(
task, kernel.data(), 1, input.data(), 1, output.data(), 1));
FL_VSL_CHECK(vslConvDeleteTask(&task));
return output;
}

} // namespace mkl
} // namespace lib
} // namespace fl
26 changes: 26 additions & 0 deletions flashlight/lib/mkl/Functions.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <vector>

namespace fl {
namespace lib {
namespace mkl {

/**
* Convolves the kernel on the input by delegating to MKL-VSL convolution.
* Size of return value is kernel.size() + input.size() - 1
*/
std::vector<float> Correlate(
const std::vector<float>& kernel,
const std::vector<float>& input);

} // namespace mkl
} // namespace lib
} // namespace fl
1 change: 1 addition & 0 deletions flashlight/lib/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ build_test(SRC ${DIR}/audio/feature/WindowingTest.cpp LIBS ${LIBS})
build_test(SRC ${DIR}/common/ProducerConsumerQueueTest.cpp LIBS ${LIBS})
build_test(SRC ${DIR}/common/StringTest.cpp LIBS ${LIBS})
build_test(SRC ${DIR}/common/SystemTest.cpp LIBS ${LIBS})
build_test(SRC ${DIR}/mkl/FunctionsTest.cpp LIBS ${LIBS})
build_test(
SRC ${DIR}/text/dictionary/DictionaryTest.cpp
LIBS ${LIBS}
Expand Down
44 changes: 44 additions & 0 deletions flashlight/lib/test/mkl/FunctionsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <string>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "flashlight/lib/mkl/Functions.h"

using namespace ::fl::lib;
using ::testing::ElementsAre;
using ::testing::Pointwise;

MATCHER_P(FloatNearPointwise, tol, "Out of range") {
return (
std::get<0>(arg) > std::get<1>(arg) - tol &&
std::get<0>(arg) < std::get<1>(arg) + tol);
}

TEST(CorrelateTest, Identity) {
std::vector<float> kernel = {1};
std::vector<float> input = {1, 2, 3, 4, 5};
std::vector<float> output = fl::lib::mkl::Correlate(kernel, input);
EXPECT_EQ(output.size(), input.size() + kernel.size() - 1);
EXPECT_THAT(output, Pointwise(FloatNearPointwise(0.01), input));
}

TEST(CorrelateTest, BasicReverb) {
std::vector<float> kernel = {1, 2, 3};
std::vector<float> input = {1, 2, 3, 4, 5};
std::vector<float> output = fl::lib::mkl::Correlate(kernel, input);
EXPECT_EQ(output.size(), input.size() + kernel.size() - 1);
EXPECT_THAT(output, ElementsAre(1, 4, 10, 16, 22, 22, 15));
}

int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

0 comments on commit a6dcdc4

Please sign in to comment.