Skip to content

Commit a28daa5

Browse files
committed
Don't get data from a prepacked tensor
1 parent 00b7d8c commit a28daa5

File tree

2 files changed

+31
-46
lines changed

2 files changed

+31
-46
lines changed

onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,17 @@ namespace {
4242
template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
4343
void ComputeJob(
4444
const T* input_data,
45+
const T* skip_data,
4546
const T* gamma_data,
4647
const T* beta_data,
4748
const T* bias_data,
48-
const T* skip_data,
49-
const float* gamma_float_ptr,
50-
const float* beta_float_ptr,
51-
const float* bias_float_ptr,
52-
float* skip_float_ptr,
53-
bool should_convert_skip,
5449
ptrdiff_t task_idx,
5550
int hidden_size,
5651
int64_t skip_size,
5752
float epsilon,
5853
bool simplified,
5954
T* output_data,
60-
T* skip_input_bias_add_output_data,
61-
AllocatorPtr alloc) {
62-
ORT_UNUSED_PARAMETER(gamma_float_ptr); // only used in MLFloat16 overload
63-
ORT_UNUSED_PARAMETER(beta_float_ptr); // only used in MLFloat16 overload
64-
ORT_UNUSED_PARAMETER(bias_float_ptr); // only used in MLFloat16 overload
65-
ORT_UNUSED_PARAMETER(skip_float_ptr); // only used in MLFloat16 overload
66-
ORT_UNUSED_PARAMETER(should_convert_skip); // only used in MLFloat16 overload
67-
ORT_UNUSED_PARAMETER(alloc);
68-
55+
T* skip_input_bias_add_output_data) {
6956
auto offset = task_idx * hidden_size;
7057
const T* p_input = input_data + offset;
7158
const T* p_skip = skip_data + (offset % skip_size);
@@ -111,9 +98,6 @@ void ComputeJob(
11198

11299
void ComputeJob(
113100
const MLFloat16* input_data,
114-
const MLFloat16* gamma_data,
115-
const MLFloat16* beta_data,
116-
const MLFloat16* bias_data,
117101
const MLFloat16* skip_data,
118102
const float* gamma_float_ptr,
119103
const float* beta_float_ptr,
@@ -128,10 +112,6 @@ void ComputeJob(
128112
MLFloat16* output_data,
129113
MLFloat16* skip_input_bias_add_output_data,
130114
AllocatorPtr alloc) {
131-
ORT_UNUSED_PARAMETER(gamma_data); // only used in double/float overload
132-
ORT_UNUSED_PARAMETER(beta_data); // only used in double/float overload
133-
ORT_UNUSED_PARAMETER(bias_data); // only used in double/float overload
134-
135115
auto offset = task_idx * hidden_size;
136116
const MLFloat16* p_input = input_data + offset;
137117
const MLFloat16* p_skip = skip_data + (offset % skip_size);
@@ -206,21 +186,21 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I
206186
template <typename T, bool simplified>
207187
SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
208188
: OpKernel(op_kernel_info),
189+
prepacked_skip_fp32_data_(nullptr),
209190
prepacked_gamma_fp32_data_(nullptr),
210191
prepacked_beta_fp32_data_(nullptr),
211-
prepacked_bias_fp32_data_(nullptr),
212-
prepacked_skip_fp32_data_(nullptr) {
192+
prepacked_bias_fp32_data_(nullptr) {
213193
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
214194
ORT_ENFORCE(epsilon_ >= 0);
215195
}
216196

217197
template <typename T, bool simplified>
218198
Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
219199
const Tensor* input = p_ctx->Input<Tensor>(0);
220-
const Tensor* skip = p_ctx->Input<Tensor>(1);
221-
const Tensor* gamma = p_ctx->Input<Tensor>(2);
222-
const Tensor* beta = p_ctx->Input<Tensor>(3);
223-
const Tensor* bias = p_ctx->Input<Tensor>(4);
200+
const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
201+
const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
202+
const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
203+
const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
224204
Tensor* output = p_ctx->Output(0, input->Shape());
225205
// For inferencing, we support one more optional output which is the sum of the input and skip tensors
226206
Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
@@ -240,8 +220,8 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
240220
int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);
241221

242222
const T* input_data = input->Data<T>();
243-
const T* skip_data = skip->Data<T>();
244-
const T* gamma_data = gamma->Data<T>();
223+
const T* skip_data = skip == nullptr ? nullptr : skip->Data<T>();
224+
const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data<T>();
245225
const T* beta_data = beta == nullptr ? nullptr : beta->Data<T>();
246226
const T* bias_data = bias == nullptr ? nullptr : bias->Data<T>();
247227

@@ -255,15 +235,21 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
255235
AllocatorPtr alloc;
256236
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
257237

238+
IAllocatorUniquePtr<float> skip_fp32;
258239
IAllocatorUniquePtr<float> gamma_fp32;
259240
IAllocatorUniquePtr<float> beta_fp32;
260241
IAllocatorUniquePtr<float> bias_fp32;
261-
IAllocatorUniquePtr<float> skip_fp32;
262242
bool should_convert_skip = false;
263243
if constexpr (std::is_same_v<T, MLFloat16>) {
264244
const size_t num_elems = static_cast<size_t>(hidden_size);
265245

266-
if (prepacked_gamma_fp32_data_ == nullptr) {
246+
if (prepacked_skip_fp32_data_ == nullptr && skip_data) {
247+
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
248+
should_convert_skip = true;
249+
// skip data will be converted inside ComputeJob, because it needs to use the offset.
250+
}
251+
252+
if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) {
267253
gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
268254
MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
269255
}
@@ -277,24 +263,23 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
277263
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
278264
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
279265
}
280-
281-
if (prepacked_skip_fp32_data_ == nullptr) {
282-
skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
283-
should_convert_skip = true;
284-
// skip data will be converted inside ComputeJob, because it needs to use the offset.
285-
}
286266
}
287267

288268
concurrency::ThreadPool::TryBatchParallelFor(
289269
p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
290270
[&](ptrdiff_t task_idx) {
291-
ComputeJob(input_data, gamma_data, beta_data, bias_data, skip_data,
292-
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
293-
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
294-
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
295-
prepacked_skip_fp32_data_ ? prepacked_skip_fp32_data_.get() : skip_fp32.get(),
296-
should_convert_skip, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
297-
skip_input_bias_add_output_data, alloc);
271+
if constexpr (std::is_same_v<T, MLFloat16>) {
272+
ComputeJob(input_data, skip_data,
273+
prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
274+
prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
275+
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
276+
prepacked_skip_fp32_data_ ? prepacked_skip_fp32_data_.get() : skip_fp32.get(),
277+
should_convert_skip, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
278+
skip_input_bias_add_output_data, alloc);
279+
} else {
280+
ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size,
281+
epsilon_, simplified, output_data, skip_input_bias_add_output_data);
282+
}
298283
},
299284
0);
300285

onnxruntime/contrib_ops/cpu/skip_layer_norm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ class SkipLayerNorm final : public OpKernel {
2121

2222
private:
2323
float epsilon_;
24+
IAllocatorUniquePtr<float> prepacked_skip_fp32_data_;
2425
IAllocatorUniquePtr<float> prepacked_gamma_fp32_data_;
2526
IAllocatorUniquePtr<float> prepacked_beta_fp32_data_;
2627
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
27-
IAllocatorUniquePtr<float> prepacked_skip_fp32_data_;
2828
};
2929

3030
} // namespace contrib

0 commit comments

Comments (0)