Skip to content

Commit

Permalink
update based on comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amarin16 committed Nov 8, 2024
1 parent d3ae5ef commit 96e5bfb
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,19 @@ void ComputeJob(
void ComputeJob(
const MLFloat16* input_data,
const MLFloat16* skip_data,
const IAllocatorUniquePtr<float>& prepacked_skip_fp32_data,
const float* gamma_float_ptr,
const float* beta_float_ptr,
const float* bias_float_ptr,
float* input_float_ptr,
float* output_float_ptr,
float* skip_float_ptr,
ptrdiff_t task_idx,
int hidden_size,
int64_t skip_size,
float epsilon,
bool simplified,
MLFloat16* output_data,
MLFloat16* skip_input_bias_add_output_data) {
MLFloat16* skip_input_bias_add_output_data,
AllocatorPtr alloc) {
auto offset = task_idx * hidden_size;
const MLFloat16* p_input = input_data + offset;
MLFloat16* p_output = output_data + offset;
Expand All @@ -121,13 +121,18 @@ void ComputeJob(
float mean_square(0.0f);
const size_t num_elems = static_cast<size_t>(hidden_size);

MlasConvertHalfToFloatBuffer(p_input, input_float_ptr, num_elems);
IAllocatorUniquePtr<float> input_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_input, input_float_uptr.get(), num_elems);

if (skip_data) {
IAllocatorUniquePtr<float> skip_float_uptr = nullptr;
if (prepacked_skip_fp32_data == nullptr && skip_data) {
const MLFloat16* p_skip = skip_data + (offset % skip_size);
MlasConvertHalfToFloatBuffer(p_skip, skip_float_ptr, num_elems);
skip_float_uptr = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(p_skip, skip_float_uptr.get(), num_elems);
}

const float* input_float_ptr = input_float_uptr.get();
const float* skip_float_ptr = prepacked_skip_fp32_data ? prepacked_skip_fp32_data.get() : skip_float_uptr.get();
for (size_t h = 0; h < num_elems; h++) {
float val = input_float_ptr[h] + skip_float_ptr[h];

Expand Down Expand Up @@ -211,8 +216,8 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
bias,
hidden_size,
input_dims_size,
bool(prepacked_skip_fp32_data_),
bool(prepacked_gamma_fp32_data_)));
prepacked_skip_fp32_data_ != nullptr,
prepacked_gamma_fp32_data_ != nullptr));

int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);

Expand All @@ -232,24 +237,16 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));

IAllocatorUniquePtr<float> input_fp32;
IAllocatorUniquePtr<float> output_fp32;
IAllocatorUniquePtr<float> skip_fp32;
IAllocatorUniquePtr<float> gamma_fp32;
IAllocatorUniquePtr<float> beta_fp32;
IAllocatorUniquePtr<float> bias_fp32;

if constexpr (std::is_same_v<T, MLFloat16>) {
const size_t num_elems = static_cast<size_t>(hidden_size);

input_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
output_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);

if (prepacked_skip_fp32_data_ == nullptr && skip_data) {
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
// skip data will be converted inside ComputeJob, because it needs to use an offset based on task_idx.
}

if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) {
gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
Expand All @@ -271,13 +268,13 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
[&](ptrdiff_t task_idx) {
if constexpr (std::is_same_v<T, MLFloat16>) {
ComputeJob(input_data, skip_data,
prepacked_skip_fp32_data_,
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
input_fp32.get(), output_fp32.get(),
prepacked_skip_fp32_data_ ? prepacked_skip_fp32_data_.get() : skip_fp32.get(),
output_fp32.get(),
task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
skip_input_bias_add_output_data);
skip_input_bias_add_output_data, alloc);
} else {
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size,
epsilon_, simplified, output_data, skip_input_bias_add_output_data);
Expand Down

0 comments on commit 96e5bfb

Please sign in to comment.