Skip to content

Commit

Permalink
fix precision warning
Browse files (browse the repository at this point in the history)
  • Loading branch information
amarin16 committed Oct 1, 2024
1 parent ab2e5f2 commit 63e9644
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,17 +29,17 @@ ORT_FORCEINLINE float* OnlyCreateBufferIfMLFloat16(MLFloat16* p_output, int64_t

template <typename T>
ORT_FORCEINLINE std::shared_ptr<std::vector<float>> ConvertMLFloat16ToFloatBufferIfNeeded(
[[maybe_unused]] const T* p_input, [[maybe_unused]] int64_t num_elems);
[[maybe_unused]] const T* p_input, [[maybe_unused]] size_t num_elems);

template <typename T>
ORT_FORCEINLINE std::shared_ptr<std::vector<float>> ConvertMLFloat16ToFloatBufferIfNeeded(
[[maybe_unused]] const std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>, T>* p_input,
[[maybe_unused]] int64_t num_elems) {
[[maybe_unused]] size_t num_elems) {
return nullptr;
}

template <>
std::shared_ptr<std::vector<float>> ConvertMLFloat16ToFloatBufferIfNeeded<MLFloat16>(const MLFloat16* p_input, int64_t num_elems) {
std::shared_ptr<std::vector<float>> ConvertMLFloat16ToFloatBufferIfNeeded<MLFloat16>(const MLFloat16* p_input, size_t num_elems) {
if (!p_input) {
return nullptr;
}
Expand All @@ -51,7 +51,7 @@ std::shared_ptr<std::vector<float>> ConvertMLFloat16ToFloatBufferIfNeeded<MLFloa
return vec;
}

void ConvertFloatBufferToMLFloat16(const float* output_buffer, MLFloat16* p_output, int64_t num_elems) {
void ConvertFloatBufferToMLFloat16(const float* output_buffer, MLFloat16* p_output, size_t num_elems) {
if (!output_buffer || !p_output) {
return;
}
Expand Down Expand Up @@ -192,7 +192,8 @@ Status LayerNormImpl::ComputeWithoutContext(
DoubleOrFloat mean(0.0f);
DoubleOrFloat mean_square(0.0f);

std::shared_ptr<std::vector<float>> float_input = ConvertMLFloat16ToFloatBufferIfNeeded<T>(p_input, norm_size);
std::shared_ptr<std::vector<float>> float_input = ConvertMLFloat16ToFloatBufferIfNeeded<T>(
p_input, static_cast<size_t>(norm_size));
const DoubleOrFloat* converted_input =
float_input == nullptr
? reinterpret_cast<const DoubleOrFloat*>(p_input)
Expand All @@ -215,12 +216,14 @@ Status LayerNormImpl::ComputeWithoutContext(
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

std::shared_ptr<std::vector<float>> float_scale = ConvertMLFloat16ToFloatBufferIfNeeded<T>(scale_data, norm_size);
std::shared_ptr<std::vector<float>> float_scale = ConvertMLFloat16ToFloatBufferIfNeeded<T>(
scale_data, static_cast<size_t>(norm_size));
const DoubleOrFloat* converted_scale =
float_scale == nullptr
? reinterpret_cast<const DoubleOrFloat*>(scale_data)
: reinterpret_cast<const DoubleOrFloat*>(&(*float_scale)[0]);
std::shared_ptr<std::vector<float>> float_bias = ConvertMLFloat16ToFloatBufferIfNeeded<T>(bias_data, norm_size);
std::shared_ptr<std::vector<float>> float_bias = ConvertMLFloat16ToFloatBufferIfNeeded<T>(
bias_data, static_cast<size_t>(norm_size));
const DoubleOrFloat* converted_bias =
float_bias == nullptr
? reinterpret_cast<const DoubleOrFloat*>(bias_data)
Expand All @@ -238,7 +241,9 @@ Status LayerNormImpl::ComputeWithoutContext(

if (std::is_same_v<decltype(p_output), MLFloat16*>) {
ConvertFloatBufferToMLFloat16(
reinterpret_cast<float*>(output_buffer), reinterpret_cast<MLFloat16*>(p_output), norm_size);
reinterpret_cast<float*>(output_buffer),
reinterpret_cast<MLFloat16*>(p_output),
static_cast<size_t>(norm_size));
delete[] output_buffer;
}

Expand Down

0 comments on commit 63e9644

Please sign in to comment.