Format bert or transformers code (microsoft#12646)
(1) Modify some lines to fit the 120-character line length limit
(2) Adjust the parameter order of LaunchAttentionKernel
(3) Format the code with Clang-Format in VS Code (a sample configuration is sketched below)
(4) Fix spelling errors
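
For context, Clang-Format reads its settings from a .clang-format file at the repository root. A minimal sketch of a configuration matching change (1) is shown below; this is an illustrative assumption, not necessarily the exact file onnxruntime ships:

    # .clang-format (illustrative sketch, not the repository's actual file)
    BasedOnStyle: Google   # start from the Google C++ style
    ColumnLimit: 120       # wrap lines longer than 120 characters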
tianleiwu authored Aug 22, 2022
1 parent dc486d1 commit d93e653
Showing 46 changed files with 1,371 additions and 945 deletions.
120 changes: 77 additions & 43 deletions onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -117,7 +117,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     }
 
     if (hidden_size % num_heads_ != 0) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads.");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisible by num_heads.");
     }
   } else {
     int qkv_sizes = 0;
@@ -129,12 +129,13 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
 
     if (qkv_hidden_sizes_[0] != qkv_hidden_sizes_[1]) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                              "qkv_hidden_sizes first element should be same as the second");
+                             "qkv_hidden_sizes first element should be same as the second");
     }
 
     for (size_t i = 0; i < qkv_hidden_sizes_.size(); i++) {
       if (qkv_hidden_sizes_[i] % num_heads_ != 0) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "hidden_size should be divisiable by num_heads:", qkv_hidden_sizes_[i]);
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "hidden_size should be divisible by num_heads:", qkv_hidden_sizes_[i]);
       }
 
       qkv_sizes += static_cast<int>(qkv_hidden_sizes_[i]);
@@ -164,13 +165,16 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 0 shall have length of 2");
     }
     if (static_cast<int>(past_dims[1]) != batch_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0");
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Inputs 'past' dimension 1 shall have same length as dimension 0 of input 0");
     }
     if (static_cast<int>(past_dims[2]) != num_heads_) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Inputs 'past' dimension 2 shall have length of num_heads", num_heads_);
     }
     if (static_cast<int>(past_dims[4]) != hidden_size / num_heads_) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'past' dimension 2 shall have length of ", hidden_size / num_heads_);
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Inputs 'past' dimension 2 shall have length of ", hidden_size / num_heads_);
     }
     past_sequence_length = static_cast<int>(past_dims[3]);
   }
@@ -179,31 +183,50 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     const auto& mask_dims = mask_index->Shape().GetDims();
     if (mask_dims.size() == 1) {
       if (static_cast<int>(mask_dims[0]) != batch_size && static_cast<int>(mask_dims[0]) != 2 * batch_size) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size");
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size");
       }
     } else if (mask_dims.size() == 2) {
-      if (static_cast<int>(mask_dims[0]) != batch_size || static_cast<int>(mask_dims[1]) != past_sequence_length + sequence_length) {
+      if (static_cast<int>(mask_dims[0]) != batch_size ||
+          static_cast<int>(mask_dims[1]) != past_sequence_length + sequence_length) {
         // Add operator supports broadcasting. Here we handle a case with only one element in the 2nd dimension.
-        if ((static_cast<int>(mask_dims[0]) == batch_size || static_cast<int>(mask_dims[0]) == 1) && static_cast<int>(mask_dims[1]) == 1) {
-          // Mask will have same value after propogation, which has same effect as no mask.
+        if ((static_cast<int>(mask_dims[0]) == batch_size || static_cast<int>(mask_dims[0]) == 1) &&
+            static_cast<int>(mask_dims[1]) == 1) {
+          // Mask will have same value after propagation, which has same effect as no mask.
           mask_index = nullptr;
         } else {
-          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 2D data shall have shape batch_size x (past_sequence_length + sequence_length)");
+          return ORT_MAKE_STATUS(
+              ONNXRUNTIME, INVALID_ARGUMENT,
+              "Inputs 'mask_index' with 2D data shall have shape "
+              "batch_size x (past_sequence_length + sequence_length)");
         }
       }
     } else if (mask_dims.size() == 3) {
-      if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3D data shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
+      if (static_cast<int>(mask_dims[0]) != batch_size ||
+          mask_dims[1] != sequence_length ||
+          static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
+        return ORT_MAKE_STATUS(
+            ONNXRUNTIME, INVALID_ARGUMENT,
+            "Inputs 'mask_index' with 3D data shall have shape "
+            "batch_size x sequence_length x (past_sequence_length + sequence_length)");
       }
     } else if (mask_dims.size() == 4) {
-      if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != 1 || mask_dims[2] != mask_dims[3] || mask_dims[2] < static_cast<int64_t>(past_sequence_length) + sequence_length) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 4D data shall have shape batch_size x 1 x max_sequence_length x max_sequence_length)");
+      if (static_cast<int>(mask_dims[0]) != batch_size ||
+          mask_dims[1] != 1 ||
+          mask_dims[2] != mask_dims[3] ||
+          mask_dims[2] < static_cast<int64_t>(past_sequence_length) + sequence_length) {
+        return ORT_MAKE_STATUS(
+            ONNXRUNTIME, INVALID_ARGUMENT,
+            "Inputs 'mask_index' with 4D data shall have shape "
+            "batch_size x 1 x max_sequence_length x max_sequence_length)");
      }
       if (is_unidirectional_ == true) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false");
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                               "Inputs 'mask_index' with 4D data shall have is_unidirectional_ set to false");
       }
     } else {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'mask_index' is expected to have 1, 2, 3 or 4 dimensions, got ",
                              mask_dims.size());
     }
   }
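
To summarize the mask_index contract enforced by the hunk above: the accepted shapes are (batch_size) or (2 * batch_size) for 1D, (batch_size, past_sequence_length + sequence_length) for 2D, (batch_size, sequence_length, past_sequence_length + sequence_length) for 3D, and (batch_size, 1, max_sequence_length, max_sequence_length) for 4D. A standalone sketch of the same checks follows; IsValidMaskShape is a hypothetical helper written for illustration, not a function in onnxruntime:

    #include <cstdint>
    #include <vector>

    // Returns true when `dims` is one of the mask_index shapes accepted by
    // AttentionBase::CheckInputs (ignoring the broadcast special case for 2D masks).
    bool IsValidMaskShape(const std::vector<int64_t>& dims,
                          int64_t batch_size, int64_t sequence_length,
                          int64_t past_sequence_length) {
      const int64_t total_length = past_sequence_length + sequence_length;
      switch (dims.size()) {
        case 1:  // (batch_size) or (2 * batch_size)
          return dims[0] == batch_size || dims[0] == 2 * batch_size;
        case 2:  // (batch_size, past_sequence_length + sequence_length)
          return dims[0] == batch_size && dims[1] == total_length;
        case 3:  // (batch_size, sequence_length, past_sequence_length + sequence_length)
          return dims[0] == batch_size && dims[1] == sequence_length && dims[2] == total_length;
        case 4:  // (batch_size, 1, max_seq_len, max_seq_len) with max_seq_len >= total_length
          return dims[0] == batch_size && dims[1] == 1 &&
                 dims[2] == dims[3] && dims[2] >= total_length;
        default:
          return false;
      }
    }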
@@ -212,24 +235,29 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
     const auto& extra_add_qk_dims = extra_add_qk->Shape().GetDims();
 
     if (extra_add_qk_dims.size() != 4) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' is expected to have 4 dimensions, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' is expected to have 4 dimensions, got ",
                              extra_add_qk_dims.size());
     }
 
     if (extra_add_qk_dims[0] != batch_size) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' dimension 0 should be same as batch_size, got ",
                              extra_add_qk_dims[0]);
     }
     if (extra_add_qk_dims[1] != num_heads_) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' dimension 1 should be same as number of heads, got ",
                              extra_add_qk_dims[1]);
     }
     if (extra_add_qk_dims[2] != sequence_length) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' dimension 2 should be same as sequence_length, got ",
                              extra_add_qk_dims[2]);
     }
     if (extra_add_qk_dims[3] != sequence_length) {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ",
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Input 'extra_add_qk' dimension 3 should be same as sequence_length, got ",
                              extra_add_qk_dims[3]);
     }
   }
@@ -322,7 +350,6 @@ template <typename T>
 Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr alloc,
                              /*out*/ bool& is_packed,
                              /*out*/ PrePackedWeights* prepacked_weights) {
-
   /* The PrePack() massages the weights to speed up Compute(), there is an option to
    * use shared prepacked weights in which case prepacked_weights parameter would be non-null.
    *
@@ -375,9 +402,14 @@ Status Attention<T>::PrePack(const Tensor& weights, int input_idx, AllocatorPtr
   const size_t qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
   const size_t weight_matrix_col_size = q_hidden_size + k_hidden_size + v_hidden_size;
 
-  if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size, weights_data, weight_matrix_col_size, prepacked_weights) ||
-      !IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size, weights_data + (num_heads_ * qkv_head_size[0]), weight_matrix_col_size, prepacked_weights) ||
-      !IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size, weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])), weight_matrix_col_size, prepacked_weights)) {
+  if (!IsPackWeightsSuccessful(0, alloc, qkv_head_size[0], input_hidden_size,
+                               weights_data, weight_matrix_col_size, prepacked_weights) ||
+      !IsPackWeightsSuccessful(1, alloc, qkv_head_size[1], input_hidden_size,
+                               weights_data + (num_heads_ * qkv_head_size[0]),
+                               weight_matrix_col_size, prepacked_weights) ||
+      !IsPackWeightsSuccessful(2, alloc, qkv_head_size[2], input_hidden_size,
+                               weights_data + (num_heads_ * (qkv_head_size[0] + qkv_head_size[1])),
+                               weight_matrix_col_size, prepacked_weights)) {
     if (prepacked_weights == nullptr) {
       FreePackedWeights(packed_weights_, qkv_hidden_sizes_.size());
     }
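
The three IsPackWeightsSuccessful calls above slice one fused weight matrix of shape (input_hidden_size, q_hidden_size + k_hidden_size + v_hidden_size) into its Q, K and V sub-matrices by column offset. A minimal sketch of that offset arithmetic is below; FusedWeightOffsets is a hypothetical helper for illustration only:

    #include <cstddef>

    struct QkvOffsets {
      std::size_t q, k, v;  // starting column of each sub-matrix in the fused weights
    };

    // Mirrors the pointer arithmetic of the three pack calls: Q starts at column 0,
    // K after the N*H0 columns of Q, and V after the N*(H0+H1) columns of Q and K.
    QkvOffsets FusedWeightOffsets(std::size_t num_heads, const std::size_t qkv_head_size[3]) {
      return {0,
              num_heads * qkv_head_size[0],
              num_heads * (qkv_head_size[0] + qkv_head_size[1])};
    }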
@@ -469,7 +501,8 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
   // gemm_data(BS, NT) = input(BS, D) x weights(D, NT) + bias(NT)
   // D (input_hidden_size) is hidden dimension of input, where D could be larger than any of the hidden_sizes
   // (NH) when model is pruned. T = H1 + H2 + H3, where H1, H2, H3 are head sizes of Q, K, V respectively
-  auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * (q_hidden_size + k_hidden_size + v_hidden_size) * element_size);
+  int qkv_hidden_size = (q_hidden_size + k_hidden_size + v_hidden_size);
+  auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * qkv_hidden_size * element_size);
   BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(std::move(allocator)));
 
   auto Q = reinterpret_cast<T*>(gemm_data);
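
As a worked example of the comment above: with typical BERT-base dimensions, batch_size B = 8, sequence_length S = 128, input hidden size D = 768, N = 12 heads, and head sizes H1 = H2 = H3 = 64, the fused output width is NT = N * (H1 + H2 + H3) = 2304, so the Q, K and V projections land back to back in one buffer. A small sketch of the size arithmetic, assuming float32 elements (all values here are illustrative, not taken from the diff):

    #include <cstddef>
    #include <cstdio>

    int main() {
      const std::size_t B = 8, S = 128;            // batch size, sequence length
      const std::size_t N = 12;                    // attention heads
      const std::size_t H1 = 64, H2 = 64, H3 = 64; // per-head sizes of Q, K, V
      const std::size_t T = H1 + H2 + H3;          // fused per-head width (192)
      const std::size_t element_size = 4;          // float32
      // gemm_data(BS, NT): one allocation holding Q, K and V back to back.
      const std::size_t bytes = B * S * (N * T) * element_size;
      std::printf("fused QKV buffer: %zu bytes\n", bytes);  // prints 9437184
      return 0;
    }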
@@ -523,12 +556,13 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
         // C: QKV[qkv_index] (BxNxSxT) (B.N.)S x T S x H
         if (is_prepack_) {
           uint8_t* packed_weight;
-          packed_weight = static_cast<uint8_t*>(packed_weights_[qkv_index].get()) + packed_weights_size_[qkv_index] * (weights_offset / head_size);
+          packed_weight = static_cast<uint8_t*>(packed_weights_[qkv_index].get()) +
+                          packed_weights_size_[qkv_index] * (weights_offset / head_size);
 
           MlasGemm(
               CblasNoTrans,               // TransA = no
               sequence_length,            // M = S
-              head_size,                   // N = H
+              head_size,                  // N = H
               input_hidden_size,          // K = D
               1.0f,                       // alpha
               input_data + input_offset,  // A
@@ -540,20 +574,20 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
               nullptr);                   // use single-thread
         } else {
           math::GemmEx<float, ThreadPool>(
-              CblasNoTrans,                                 // TransA = no
-              CblasNoTrans,                                 // TransB = no
-              sequence_length,                              // M = S
-              head_size,                                    // N = H
-              input_hidden_size,                            // K = D
-              1.0f,                                         // alpha
-              input_data + input_offset,                    // A
-              input_hidden_size,                            // lda = D
-              weights_data + weights_offset,                // B
-              q_hidden_size + k_hidden_size + v_hidden_size,// ldb = NH1 + NH2 + NH3
-              1.0f,                                         // beta
-              qkv_dest + qkv_offset,                        // C
-              head_size,                                    // ldc
-              nullptr                                       // use single-thread
+              CblasNoTrans,                                   // TransA = no
+              CblasNoTrans,                                   // TransB = no
+              sequence_length,                                // M = S
+              head_size,                                      // N = H
+              input_hidden_size,                              // K = D
+              1.0f,                                           // alpha
+              input_data + input_offset,                      // A
+              input_hidden_size,                              // lda = D
+              weights_data + weights_offset,                  // B
+              q_hidden_size + k_hidden_size + v_hidden_size,  // ldb = NH1 + NH2 + NH3
+              1.0f,                                           // beta
+              qkv_dest + qkv_offset,                          // C
+              head_size,                                      // ldc
+              nullptr                                         // use single-thread
           );
         }
       }
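
The ldb comment above is worth unpacking: in the non-prepacked path, B points into the fused row-major weight matrix, so the stride between consecutive rows of the Q, K or V sub-matrix is the full fused width NH1 + NH2 + NH3, not the sub-matrix width H. A tiny sketch of that indexing, with hypothetical names for illustration only:

    #include <cstddef>

    // Element (row k, col n) of a sub-matrix that starts at column `col_offset`
    // inside a row-major matrix whose full row width (leading dimension) is `ldb`.
    inline float SubMatrixAt(const float* weights, std::size_t ldb,
                             std::size_t col_offset, std::size_t k, std::size_t n) {
      return weights[k * ldb + col_offset + n];
    }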
11 changes: 6 additions & 5 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -3,6 +3,7 @@
 
 #pragma once
 
+#include <vector>
 #include "core/common/common.h"
 #include "core/framework/op_kernel.h"
 
@@ -17,7 +18,7 @@ class AttentionBase {
                     const TensorShape& bias_shape,
                     const Tensor*& mask_index,  // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr.
                     const Tensor* past,
-                    const Tensor *extra_add_qk,
+                    const Tensor* extra_add_qk,
                     const int max_threads_per_block) const;
 
   Tensor* GetPresent(OpKernelContext* context,
@@ -45,11 +46,11 @@ class AttentionBase {
                     const TensorShape& bias_shape,
                     const Tensor*& mask_index,  // For dummy mask with shape (1, 1) or (batch_size, 1), it will be updated to nullptr.
                     const Tensor* past,
-                    const Tensor *extra_add_qk) const;
+                    const Tensor* extra_add_qk) const;
 
-  int num_heads_;              // number of attention heads
-  bool is_unidirectional_;     // whether every token can only attend to previous tokens.
-  std::vector<int64_t> qkv_hidden_sizes_;  // Q, K, V path hidden layer sizes
+  int num_heads_;                           // number of attention heads
+  bool is_unidirectional_;                  // whether every token can only attend to previous tokens.
+  std::vector<int64_t> qkv_hidden_sizes_;   // Q, K, V path hidden layer sizes
 };
 
 } // namespace contrib