diff --git a/lib/layernorm.cc b/lib/layernorm.cc
index c60b62b..1f1921d 100644
--- a/lib/layernorm.cc
+++ b/lib/layernorm.cc
@@ -1,44 +1,75 @@
 #include "layernorm.h"
 #include <arm_neon.h>
+#include <cmath>
 
 void layernorm_A(Matmul *x, int rows, int cols,
                  const std::vector<__fp16> &gamma,
-                 const std::vector<__fp16> &beta, __fp16 eps) {
+                 const std::vector<__fp16> &beta, float eps)
+{
   assert(!(cols % 8));
 #pragma omp parallel for
-  for (int i = 0; i < rows; i++) {
+  for (int i = 0; i < rows; i++)
+  {
     // Calculate mean.
     __fp16 *input_ptr = x->get_A_ptr() + 8 * i;
     float16x8_t mean = vdupq_n_f16(0);
-    for (int j = 0; j < cols; j += 8) {
+    float32x4_t mean_f32x4_high = vdupq_n_f32(0);
+    float32x4_t mean_f32x4_low = vdupq_n_f32(0);
+
+    for (int j = 0; j < cols; j += 8)
+    {
       int offset = rows * j;
-      mean = vaddq_f16(vld1q_f16(input_ptr + offset), mean);
+      // mean = vaddq_f16(vld1q_f16(input_ptr + offset), mean);
+      float16x8_t cur_8 = vld1q_f16(input_ptr + offset);
+      mean_f32x4_high = vaddq_f32(vcvt_f32_f16(vget_high_f16(cur_8)),mean_f32x4_high);
+      mean_f32x4_low = vaddq_f32(vcvt_f32_f16(vget_low_f16(cur_8)),mean_f32x4_low);
     }
-    float16x4_t mean_f16x4 = vadd_f16(vget_high_f16(mean), vget_low_f16(mean));
-    float32x4_t mean_f32x4 = vcvt_f32_f16(mean_f16x4);
+
+    // float16x4_t mean_f16x4 = vadd_f16(vget_high_f16(mean), vget_low_f16(mean));
+    // float32x4_t mean_f32x4 = vcvt_f32_f16(mean_f16x4);
+
+    // mean_f32x4_high = vcvt_f32_f16(vget_high_f16(mean));
+    // mean_f32x4_low = vcvt_f32_f16(vget_low_f16(mean));
+    float32x4_t mean_f32x4 = vaddq_f32(mean_f32x4_high, mean_f32x4_low);
+    float32_t mean_f32 = vaddvq_f32(mean_f32x4) / (float32_t)cols;
+    mean = vdupq_n_f16((__fp16)mean_f32);
+    mean_f32x4_high = vdupq_n_f32(mean_f32);
+    mean_f32x4_low = vdupq_n_f32(mean_f32);
 
     // Calculate mean squared deviation
-    float16x8_t msd = vdupq_n_f16(0);
-    for (int j = 0; j < cols; j += 8) {
+    // float16x8_t msd = vdupq_n_f16(0);
+    float32x4_t msd_f32x4_high = vdupq_n_f32(0);
+    float32x4_t msd_f32x4_low = vdupq_n_f32(0);
+    for (int j = 0; j < cols; j += 8)
+    {
       int offset = rows * j;
       float16x8_t val = vld1q_f16(input_ptr + offset);
-      val = vsubq_f16(val, mean);
-      val = vmulq_f16(val, val);
-      msd = vaddq_f16(val, msd);
+      float32x4_t val_f32x4_high = vcvt_f32_f16(vget_high_f16(val));
+      float32x4_t val_f32x4_low = vcvt_f32_f16(vget_low_f16(val));
+      val_f32x4_high = vsubq_f32(val_f32x4_high, mean_f32x4_high);
+      val_f32x4_low = vsubq_f32(val_f32x4_low, mean_f32x4_low);
+      msd_f32x4_high = vaddq_f32(vmulq_f32(val_f32x4_high, val_f32x4_high), msd_f32x4_high);
+      msd_f32x4_low = vaddq_f32(vmulq_f32(val_f32x4_low, val_f32x4_low), msd_f32x4_low);
+
+      // val = vsubq_f16(val, mean);
+      // val = vmulq_f16(val, val);
+      // msd = vaddq_f16(val, msd);
     }
-    float16x4_t msd_f16x4 = vadd_f16(vget_high_f16(msd), vget_low_f16(msd));
-    float32x4_t msd_f32x4 = vcvt_f32_f16(msd_f16x4);
+    // float16x4_t msd_f16x4 = vadd_f16(vget_high_f16(msd), vget_low_f16(msd));
+    // float32x4_t msd_f32x4 = vcvt_f32_f16(msd_f16x4);
+    float32x4_t msd_f32x4 = vaddq_f32(msd_f32x4_high, msd_f32x4_low);
     float32_t msd_f32 = vaddvq_f32(msd_f32x4);
-    float32_t denom_single = msd_f32 / (float32_t)(cols - 1);
+    float32_t denom_single = msd_f32 / (float32_t)(cols);
     denom_single = std::sqrt(denom_single + eps);
     float16x8_t denom = vdupq_n_f16((__fp16)denom_single);
 
     // Normalize based on mean + MSD, beta, gamma and epsilon.
-    for (int j = 0; j < cols; j += 8) {
+    for (int j = 0; j < cols; j += 8)
+    {
       int offset = rows * j;
       float16x8_t val = vld1q_f16(input_ptr + offset);
       float16x8_t gamma_val = vld1q_f16(gamma.data() + j);
@@ -51,3 +82,4 @@ void layernorm_A(Matmul *x, int rows, int cols,
     }
   }
 }
+
diff --git a/lib/layernorm.h b/lib/layernorm.h
index c241ad6..e325328 100644
--- a/lib/layernorm.h
+++ b/lib/layernorm.h
@@ -7,6 +7,6 @@
 
 void layernorm_A(Matmul *x, int rows, int cols,
                  const std::vector<__fp16> &gamma,
-                 const std::vector<__fp16> &beta, __fp16 eps = 1e-5);
+                 const std::vector<__fp16> &beta, float eps = 1e-5);
 
 #endif // _LIB_LAYERNORM_H_