Add microbenchmark for layer normalization
amarin16 committed Sep 25, 2024
1 parent a7d056c commit e9c3bf7
Showing 4 changed files with 207 additions and 56 deletions.
3 changes: 2 additions & 1 deletion cmake/onnxruntime_unittests.cmake
@@ -1130,7 +1130,8 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
${BENCHMARK_DIR}/gelu.cc
${BENCHMARK_DIR}/activation.cc
${BENCHMARK_DIR}/quantize.cc
${BENCHMARK_DIR}/reduceminmax.cc)
${BENCHMARK_DIR}/reduceminmax.cc
${BENCHMARK_DIR}/layer_normalization.cc)
target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc)
target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE)
if(WIN32)
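Note: the new benchmark source is added to the existing onnxruntime_benchmark target rather than a new one. Assuming the usual onnxruntime build setup, that binary is built when microbenchmarks are enabled (e.g. via build.py's --build_micro_benchmarks flag, if present on your branch), and the new cases can then be selected at run time with Google Benchmark's --benchmark_filter=LayerNormalization.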
113 changes: 58 additions & 55 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -12,6 +12,8 @@

namespace onnxruntime {

namespace {

// Utility to convert from MLFloat16 to float only when the input type is MLFloat16.
template <typename T, typename Ret>
ORT_FORCEINLINE Ret ConvertMLFloat16ToDoubleOrFloatIfNeeded(T val);
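The specializations themselves are collapsed in this diff view. As a rough sketch of the pattern (not the exact code from this file, and assuming MLFloat16::ToFloat() for the half-to-float conversion):

template <>
ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<MLFloat16, float>(MLFloat16 val) {
  return val.ToFloat();  // widen half precision to float before accumulating
}

template <>
ORT_FORCEINLINE float ConvertMLFloat16ToDoubleOrFloatIfNeeded<float, float>(float val) {
  return val;  // already float, pass through unchanged
}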
@@ -63,15 +65,16 @@ ORT_FORCEINLINE constexpr double ConvertToMLFloat16IfNeeded(double val) {
return val;
}

} // namespace

LayerNormImpl::LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified, bool contrib_op)
: OpKernel(op_kernel_info), simplified_{simplified}, contrib_op_{contrib_op} {
ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
}

namespace {
template <typename T, typename U>
Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) {
Status LayerNormImpl::ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const {
// Inputs
const Tensor* X = p_ctx->Input<Tensor>(0);
const Tensor* scale = p_ctx->Input<Tensor>(1);
@@ -81,21 +84,12 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

const TensorShape& x_shape = X->Shape();
const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions());
int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));

const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;
if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", norm_size,
". Size of scale and bias (if provided) must match this. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}

const TensorShape& scale_shape = scale->Shape();
const TensorShape& bias_shape = bias->Shape();
Tensor* Y = p_ctx->Output(0, x_shape);
auto Y_data = Y->MutableData<T>();
T* Y_data = Y->MutableData<T>();

const int64_t axis = HandleNegativeAxis(orig_axis, x_shape.NumDimensions());

std::vector<int64_t> mean_inv_std_dev_dim;
mean_inv_std_dev_dim.reserve(x_shape.NumDimensions());
@@ -107,17 +101,11 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
}
}

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

int output_index = 1;

Tensor* mean = p_ctx->Output(output_index++, TensorShape(mean_inv_std_dev_dim));
U* mean_data = nullptr;
if (!simplified) {
Tensor* mean = p_ctx->Output(output_index++, TensorShape(mean_inv_std_dev_dim));
if (mean != nullptr) {
mean_data = mean->MutableData<U>();
}
if (mean != nullptr) {
mean_data = mean->MutableData<U>();
}

U* inv_std_dev_data = nullptr;
@@ -126,8 +114,51 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
inv_std_dev_data = inv_std_dev->MutableData<U>();
}

onnxruntime::concurrency::ThreadPool* thread_pool = p_ctx->GetOperatorThreadPool();

return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape,
Y_data, mean_data, inv_std_dev_data, thread_pool, axis, epsilon, simplified);
}

Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const {
const auto elem_type = p_ctx->Input<Tensor>(0)->GetElementType();

using SupportedTypeList = boost::mp11::mp_list<float, double, MLFloat16>;

utils::MLTypeCallDispatcherFromTypeList<SupportedTypeList> t_disp(elem_type);
return t_disp.InvokeRet<Status, SrcDispatcher>(this, p_ctx, axis_, epsilon_, simplified_, contrib_op_);
}

template<typename T, typename U>
Status LayerNormImpl::ComputeWithoutContext(
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
const TensorShape& scale_shape,
const T* bias_data,
const TensorShape& bias_shape,
T* Y_data,
U* mean_data,
U* inv_std_dev_data,
onnxruntime::concurrency::ThreadPool* thread_pool,
int64_t axis,
float epsilon,
bool simplified
) const {
int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));

const auto scale_size = scale_shape.Size();
const auto bias_size = (bias_data) ? bias_shape.Size() : 0;
if (scale_size != norm_size || (bias_data && bias_size != norm_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", norm_size,
". Size of scale and bias (if provided) must match this. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}

concurrency::ThreadPool::TryBatchParallelFor(
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(norm_count),
thread_pool, static_cast<int32_t>(norm_count),
[&](ptrdiff_t task_idx) {
const T* p_input = X_data + task_idx * norm_size;
T* p_output = Y_data + task_idx * norm_size;
@@ -159,7 +190,7 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
DoubleOrFloat scale_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(scale_data[h]);
if (simplified) {
p_output[h] = ConvertToMLFloat16IfNeeded<T>(input_value / mean_square * scale_value);
} else if (nullptr == bias) {
} else if (nullptr == bias_data) {
p_output[h] = ConvertToMLFloat16IfNeeded<T>((input_value - mean) / mean_square * scale_value);
} else {
DoubleOrFloat bias_value = ConvertMLFloat16ToDoubleOrFloatIfNeeded<T, DoubleOrFloat>(bias_data[h]);
Expand All @@ -181,32 +212,4 @@ Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, boo
return Status::OK();
}
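For reference, the per-row loop above implements standard layer normalization, roughly Y = (X - mean) / sqrt(var + epsilon) * scale + bias over each slice of norm_size elements (judging from the visible division, mean_square ends up holding the sqrt(var + epsilon) denominator after the collapsed adjustment step); with simplified == true it reduces to the RMSNorm-style Y = X / sqrt(mean(X^2) + epsilon) * scale, with no mean subtraction and no bias.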

template <typename T>
struct SrcDispatcher {
Status operator()(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified, bool contrib_op) const {
// the contrib op kernel was always registered with the same type for all constraints.
// our implementation of the onnx op only supports 'float' as the U constraint.
#if !defined(DISABLE_CONTRIB_OPS)
if (contrib_op) {
return ComputeImpl<T, T>(p_ctx, orig_axis, epsilon, simplified);
} else
#else
ORT_UNUSED_PARAMETER(contrib_op);
#endif
{
return ComputeImpl<T, float>(p_ctx, orig_axis, epsilon, simplified);
}
}
};
} // namespace

Status LayerNormImpl::Compute(OpKernelContext* p_ctx) const {
const auto elem_type = p_ctx->Input<Tensor>(0)->GetElementType();

using SupportedTypeList = boost::mp11::mp_list<float, double, MLFloat16>;

utils::MLTypeCallDispatcherFromTypeList<SupportedTypeList> t_disp(elem_type);
return t_disp.InvokeRet<Status, SrcDispatcher>(p_ctx, axis_, epsilon_, simplified_, contrib_op_);
}

} // namespace onnxruntime
39 changes: 39 additions & 0 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
@@ -14,7 +14,46 @@ class LayerNormImpl : public OpKernel {
LayerNormImpl(const OpKernelInfo& op_kernel_info, bool simplified = false, bool contrib_op = false);
Status Compute(OpKernelContext* p_op_kernel_context) const override;

// This method was created so that it can be called directly from `test/onnx/microbenchmark/layer_normalization.cc`.
template<typename T, typename U>
Status ComputeWithoutContext(
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
const TensorShape& scale_shape,
const T* bias_data,
const TensorShape& bias_shape,
T* Y_data,
U* mean_data,
U* inv_std_dev,
onnxruntime::concurrency::ThreadPool* thread_pool,
int64_t axis,
float epsilon,
bool simplified
) const;

private:
template <typename T, typename U>
Status ComputeImpl(OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified) const;

template <typename T>
struct SrcDispatcher {
Status operator()(const LayerNormImpl* p_instance, OpKernelContext* p_ctx, int64_t orig_axis, float epsilon, bool simplified, bool contrib_op) const {
// the contrib op kernel was always registered with the same type for all constraints.
// our implementation of the onnx op only supports 'float' as the U constraint.
#if !defined(DISABLE_CONTRIB_OPS)
if (contrib_op) {
return p_instance->ComputeImpl<T, T>(p_ctx, orig_axis, epsilon, simplified);
} else
#else
ORT_UNUSED_PARAMETER(contrib_op);
#endif
{
return p_instance->ComputeImpl<T, float>(p_ctx, orig_axis, epsilon, simplified);
}
}
};
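For example, for an MLFloat16 input tensor the dispatcher above resolves to one of the following instantiations (illustration only, following the two branches shown):

// contrib op path: U matches the input type T
p_instance->ComputeImpl<MLFloat16, MLFloat16>(p_ctx, orig_axis, epsilon, simplified);
// ONNX op path: mean / inv_std_dev are always produced as float
p_instance->ComputeImpl<MLFloat16, float>(p_ctx, orig_axis, epsilon, simplified);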

int64_t axis_;
float epsilon_;
const bool simplified_;
108 changes: 108 additions & 0 deletions onnxruntime/test/onnx/microbenchmark/layer_normalization.cc
@@ -0,0 +1,108 @@
#include "core/platform/threadpool.h"
#include "core/util/thread_utils.h"
#include <benchmark/benchmark.h>
#include <iostream>  // for the std::cout error message below

#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif

#include "core/framework/allocator.h"
#include "core/framework/config_options.h"
#include "core/framework/data_transfer_manager.h"
#include "core/framework/op_kernel_info.h"
#include "core/framework/ort_value_name_idx_map.h"
#include "core/platform/windows/env.h"
#include "core/providers/cpu/nn/layer_norm_impl.h"
#include "core/providers/cpu/cpu_provider_factory.h"
#include "core/providers/cpu/cpu_provider_factory_creator.h"
#include "core/util/thread_utils.h"

#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic pop
#endif

using namespace onnxruntime;

template<typename T, typename U>
static void BM_LayerNormalization(benchmark::State& state) {
bool simplified = false;
const float epsilon = 1e-05f;
int64_t axis = 1;

onnxruntime::Node node;
// Required by LayerNormImpl constructor
node.AddAttribute("axis", axis);
node.AddAttribute("epsilon", epsilon);

KernelDef kernel_def;
std::unique_ptr<IExecutionProvider> execution_provider = CPUProviderFactoryCreator::Create(true)->CreateProvider();
std::unordered_map<int, OrtValue> constant_initialized_tensors;
OrtValueNameIdxMap mlvalue_name_idx_map;
DataTransferManager data_transfer_mgr;
AllocatorMap allocators;
ConfigOptions config_options;

OpKernelInfo op_kernel_info(node, kernel_def, *execution_provider, constant_initialized_tensors, mlvalue_name_idx_map,
data_transfer_mgr, allocators, config_options);

LayerNormImpl layer_norm_impl(op_kernel_info);

std::vector<int64_t> x_dims{2, 2, 2};
TensorShape x_shape(x_dims);
std::vector<float> x{1, 1, 1, 1, 1, 1, 1, 1};

std::vector<int64_t> scale_bias_dims{1, 2, 2};
TensorShape scale_shape(scale_bias_dims);
TensorShape bias_shape(scale_bias_dims);
std::vector<float> scale{1, 1, 1, 1};
std::vector<float> bias{1, 1, 1, 1};

T* X_data = static_cast<T*>(malloc(x.size() * sizeof(T)));
T* scale_data = static_cast<T*>(malloc(scale.size() * sizeof(T)));
T* bias_data = static_cast<T*>(malloc(bias.size() * sizeof(T)));
for (size_t i = 0; i < x.size(); i++) {
X_data[i] = T(x[i]);
}
for (size_t i = 0; i < scale.size(); i++) {
scale_data[i] = T(scale[i]);
}
for (size_t i = 0; i < bias.size(); i++) {
bias_data[i] = T(bias[i]);
}

T* Y_data = static_cast<T*>(malloc(x.size() * sizeof(T)));
U* mean_data = static_cast<U*>(malloc(x.size() * sizeof(U)));
U* inv_std_dev_data = static_cast<U*>(malloc(x.size() * sizeof(U)));

OrtThreadPoolParams tp_params;
tp_params.name = ORT_TSTR("intra-op");
std::unique_ptr<concurrency::ThreadPool> thread_pool = concurrency::CreateThreadPool(
&Env::Default(), tp_params, concurrency::ThreadPoolType::INTRA_OP);

for (auto _ : state) {
auto status = layer_norm_impl.ComputeWithoutContext(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape,
Y_data, mean_data, inv_std_dev_data, thread_pool.get(), axis, epsilon, simplified);

if (!status.IsOK()) {
std::cout << "ComputeWithoutContext status not OK: " << status.ErrorMessage() << std::endl;
break;
}
}
}
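(The benchmark drives ComputeWithoutContext directly, per the comment added in layer_norm_impl.h, presumably so each iteration measures only the kernel math without standing up a full OpKernelContext; only the OpKernelInfo needed by the LayerNormImpl constructor is mocked up above.)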


BENCHMARK(BM_LayerNormalization<float, float>)
->Arg(1)
->Arg(256)
->Arg(1024)
->UseRealTime()
->Unit(benchmark::TimeUnit::kMicrosecond);

BENCHMARK(BM_LayerNormalization<MLFloat16, MLFloat16>)
->Arg(1)
->Arg(256)
->Arg(1024)
->UseRealTime()
->Unit(benchmark::TimeUnit::kMicrosecond);
