@@ -42,30 +42,17 @@ namespace {
 template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
 void ComputeJob(
     const T* input_data,
+    const T* skip_data,
     const T* gamma_data,
     const T* beta_data,
     const T* bias_data,
-    const T* skip_data,
-    const float* gamma_float_ptr,
-    const float* beta_float_ptr,
-    const float* bias_float_ptr,
-    float* skip_float_ptr,
-    bool should_convert_skip,
     ptrdiff_t task_idx,
     int hidden_size,
     int64_t skip_size,
     float epsilon,
     bool simplified,
     T* output_data,
-    T* skip_input_bias_add_output_data,
-    AllocatorPtr alloc) {
-  ORT_UNUSED_PARAMETER(gamma_float_ptr);      // only used in MLFloat16 overload
-  ORT_UNUSED_PARAMETER(beta_float_ptr);       // only used in MLFloat16 overload
-  ORT_UNUSED_PARAMETER(bias_float_ptr);       // only used in MLFloat16 overload
-  ORT_UNUSED_PARAMETER(skip_float_ptr);       // only used in MLFloat16 overload
-  ORT_UNUSED_PARAMETER(should_convert_skip);  // only used in MLFloat16 overload
-  ORT_UNUSED_PARAMETER(alloc);
-
+    T* skip_input_bias_add_output_data) {
   auto offset = task_idx * hidden_size;
   const T* p_input = input_data + offset;
   const T* p_skip = skip_data + (offset % skip_size);
@@ -111,9 +98,6 @@ void ComputeJob(
 
 void ComputeJob(
     const MLFloat16* input_data,
-    const MLFloat16* gamma_data,
-    const MLFloat16* beta_data,
-    const MLFloat16* bias_data,
     const MLFloat16* skip_data,
     const float* gamma_float_ptr,
     const float* beta_float_ptr,
@@ -128,10 +112,6 @@ void ComputeJob(
     MLFloat16* output_data,
     MLFloat16* skip_input_bias_add_output_data,
     AllocatorPtr alloc) {
-  ORT_UNUSED_PARAMETER(gamma_data);  // only used in double/float overload
-  ORT_UNUSED_PARAMETER(beta_data);   // only used in double/float overload
-  ORT_UNUSED_PARAMETER(bias_data);   // only used in double/float overload
-
   auto offset = task_idx * hidden_size;
   const MLFloat16* p_input = input_data + offset;
   const MLFloat16* p_skip = skip_data + (offset % skip_size);
@@ -206,21 +186,21 @@ void ConvertMLFloat16ToFloatIfNeeded(const Tensor& tensor, AllocatorPtr alloc, I
 template <typename T, bool simplified>
 SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
     : OpKernel(op_kernel_info),
+      prepacked_skip_fp32_data_(nullptr),
       prepacked_gamma_fp32_data_(nullptr),
       prepacked_beta_fp32_data_(nullptr),
-      prepacked_bias_fp32_data_(nullptr),
-      prepacked_skip_fp32_data_(nullptr) {
+      prepacked_bias_fp32_data_(nullptr) {
   ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
   ORT_ENFORCE(epsilon_ >= 0);
 }
 
 template <typename T, bool simplified>
 Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   const Tensor* input = p_ctx->Input<Tensor>(0);
-  const Tensor* skip = p_ctx->Input<Tensor>(1);
-  const Tensor* gamma = p_ctx->Input<Tensor>(2);
-  const Tensor* beta = p_ctx->Input<Tensor>(3);
-  const Tensor* bias = p_ctx->Input<Tensor>(4);
+  const Tensor* skip = prepacked_skip_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(1);
+  const Tensor* gamma = prepacked_gamma_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(2);
+  const Tensor* beta = prepacked_beta_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(3);
+  const Tensor* bias = prepacked_bias_fp32_data_ ? nullptr : p_ctx->Input<Tensor>(4);
   Tensor* output = p_ctx->Output(0, input->Shape());
   // For inferencing, we support one more optional output which is the sum of the input and skip tensors
   Tensor* skip_input_bias_add_output = p_ctx->Output(3, input->Shape());
@@ -240,8 +220,8 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   int64_t task_count = input->Shape().SizeToDimension(input_dims_size - 1);
 
   const T* input_data = input->Data<T>();
-  const T* skip_data = skip->Data<T>();
-  const T* gamma_data = gamma->Data<T>();
+  const T* skip_data = skip == nullptr ? nullptr : skip->Data<T>();
+  const T* gamma_data = gamma == nullptr ? nullptr : gamma->Data<T>();
   const T* beta_data = beta == nullptr ? nullptr : beta->Data<T>();
   const T* bias_data = bias == nullptr ? nullptr : bias->Data<T>();
 
@@ -255,15 +235,21 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
   AllocatorPtr alloc;
   ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
 
+  IAllocatorUniquePtr<float> skip_fp32;
   IAllocatorUniquePtr<float> gamma_fp32;
   IAllocatorUniquePtr<float> beta_fp32;
   IAllocatorUniquePtr<float> bias_fp32;
-  IAllocatorUniquePtr<float> skip_fp32;
   bool should_convert_skip = false;
   if constexpr (std::is_same_v<T, MLFloat16>) {
     const size_t num_elems = static_cast<size_t>(hidden_size);
 
-    if (prepacked_gamma_fp32_data_ == nullptr) {
+    if (prepacked_skip_fp32_data_ == nullptr && skip_data) {
+      skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
+      should_convert_skip = true;
+      // skip data will be converted inside ComputeJob, because it needs to use the offset.
+    }
+
+    if (prepacked_gamma_fp32_data_ == nullptr && gamma_data) {
       gamma_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(gamma_data, gamma_fp32.get(), num_elems);
     }
@@ -277,24 +263,23 @@ Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
       bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
       MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
     }
-
-    if (prepacked_skip_fp32_data_ == nullptr) {
-      skip_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
-      should_convert_skip = true;
-      // skip data will be converted inside ComputeJob, because it needs to use the offset.
-    }
   }
 
   concurrency::ThreadPool::TryBatchParallelFor(
       p_ctx->GetOperatorThreadPool(), static_cast<int32_t>(task_count),
      [&](ptrdiff_t task_idx) {
-        ComputeJob(input_data, gamma_data, beta_data, bias_data, skip_data,
-                   prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
-                   prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
-                   prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
-                   prepacked_skip_fp32_data_ ? prepacked_skip_fp32_data_.get() : skip_fp32.get(),
-                   should_convert_skip, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
-                   skip_input_bias_add_output_data, alloc);
+        if constexpr (std::is_same_v<T, MLFloat16>) {
+          ComputeJob(input_data, skip_data,
+                     prepacked_gamma_fp32_data_ ? prepacked_gamma_fp32_data_.get() : gamma_fp32.get(),
+                     prepacked_beta_fp32_data_ ? prepacked_beta_fp32_data_.get() : beta_fp32.get(),
+                     prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
+                     prepacked_skip_fp32_data_ ? prepacked_skip_fp32_data_.get() : skip_fp32.get(),
+                     should_convert_skip, task_idx, hidden_size, skip_size, epsilon_, simplified, output_data,
+                     skip_input_bias_add_output_data, alloc);
+        } else {
+          ComputeJob(input_data, skip_data, gamma_data, beta_data, bias_data, task_idx, hidden_size, skip_size,
+                     epsilon_, simplified, output_data, skip_input_bias_add_output_data);
+        }
       },
       0);
 
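
Note for reviewers: the refactor replaces the old single ComputeJob signature, which forced each precision's overload to accept and then ORT_UNUSED_PARAMETER away the other precision's arguments, with two genuine overloads selected at the call site. Below is a minimal, self-contained sketch of that pattern, assuming hypothetical names (Half16, Kernel, Run) as stand-ins for MLFloat16 and ComputeJob; it is not the ONNX Runtime implementation.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <type_traits>

struct Half16 { uint16_t bits; };  // hypothetical stand-in for MLFloat16

// Participates in overload resolution only for float/double, mirroring the
// enable_if_t guard on the templated ComputeJob in this diff.
template <typename T, typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, void>>
void Kernel(const T* data, std::size_t n) {
  T sum = 0;
  for (std::size_t i = 0; i < n; ++i) sum += data[i];
  std::cout << "fp32/fp64 path, sum = " << sum << '\n';
}

// The half-precision overload alone takes the fp32 scratch buffer it needs
// (analogous to skip_float_ptr). The toy cast below is illustrative; the
// real kernel converts with MlasConvertHalfToFloatBuffer.
void Kernel(const Half16* data, float* fp32_scratch, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) fp32_scratch[i] = static_cast<float>(data[i].bits);
  std::cout << "fp16 path, converted " << n << " elements\n";
}

// The discarded branch of `if constexpr` is never instantiated, so each
// overload keeps only the parameters it actually uses.
template <typename T>
void Run(const T* data, float* scratch, std::size_t n) {
  if constexpr (std::is_same_v<T, Half16>) {
    Kernel(data, scratch, n);
  } else {
    Kernel(data, n);
  }
}

int main() {
  float f[3] = {1.0f, 2.0f, 3.0f};
  Half16 h[3] = {};
  float scratch[3];
  Run(f, nullptr, 3);  // instantiates only the float/double overload
  Run(h, scratch, 3);  // instantiates only the fp16 overload
}

Because the untaken branch is discarded at compile time, the float/double overload drops the alloc and fp32 scratch parameters entirely, which is what lets this diff delete the ORT_UNUSED_PARAMETER blocks from both overloads.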