Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix layernorm_A #43

Open
wants to merge 1 commit into base branch `main` (the source-branch name was not captured in this page snapshot)
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions lib/layernorm.cc
Original file line number Diff line number Diff line change
@@ -1,44 +1,75 @@
#include "layernorm.h"

#include <arm_neon.h>
#include <cstdio>

void layernorm_A(Matmul *x, int rows, int cols,
const std::vector<__fp16> &gamma,
const std::vector<__fp16> &beta, __fp16 eps) {
const std::vector<__fp16> &beta, float eps)
{
assert(!(cols % 8));
#pragma omp parallel for
for (int i = 0; i < rows; i++) {
for (int i = 0; i < rows; i++)
{
// Calculate mean.
__fp16 *input_ptr = x->get_A_ptr() + 8 * i;
float16x8_t mean = vdupq_n_f16(0);
for (int j = 0; j < cols; j += 8) {
float32x4_t mean_f32x4_high = vdupq_n_f32(0);
float32x4_t mean_f32x4_low = vdupq_n_f32(0);

for (int j = 0; j < cols; j += 8)
{
int offset = rows * j;
mean = vaddq_f16(vld1q_f16(input_ptr + offset), mean);
// mean = vaddq_f16(vld1q_f16(input_ptr + offset), mean);
float16x8_t cur_8 = vld1q_f16(input_ptr + offset);
mean_f32x4_high = vaddq_f32(vcvt_f32_f16(vget_high_f16(cur_8)),mean_f32x4_high);
mean_f32x4_low = vaddq_f32(vcvt_f32_f16(vget_low_f16(cur_8)),mean_f32x4_low);
}
float16x4_t mean_f16x4 = vadd_f16(vget_high_f16(mean), vget_low_f16(mean));
float32x4_t mean_f32x4 = vcvt_f32_f16(mean_f16x4);

// float16x4_t mean_f16x4 = vadd_f16(vget_high_f16(mean), vget_low_f16(mean));
// float32x4_t mean_f32x4 = vcvt_f32_f16(mean_f16x4);

// mean_f32x4_high = vcvt_f32_f16(vget_high_f16(mean));
// mean_f32x4_low = vcvt_f32_f16(vget_low_f16(mean));
float32x4_t mean_f32x4 = vaddq_f32(mean_f32x4_high, mean_f32x4_low);

float32_t mean_f32 = vaddvq_f32(mean_f32x4) / (float32_t)cols;

mean = vdupq_n_f16((__fp16)mean_f32);
mean_f32x4_high = vdupq_n_f32(mean_f32);
mean_f32x4_low = vdupq_n_f32(mean_f32);

// Calculate mean squared deviation
float16x8_t msd = vdupq_n_f16(0);
for (int j = 0; j < cols; j += 8) {
// float16x8_t msd = vdupq_n_f16(0);
float32x4_t msd_f32x4_high = vdupq_n_f32(0);
float32x4_t msd_f32x4_low = vdupq_n_f32(0);
for (int j = 0; j < cols; j += 8)
{
int offset = rows * j;
float16x8_t val = vld1q_f16(input_ptr + offset);
val = vsubq_f16(val, mean);
val = vmulq_f16(val, val);
msd = vaddq_f16(val, msd);
float32x4_t val_f32x4_high = vcvt_f32_f16(vget_high_f16(val));
float32x4_t val_f32x4_low = vcvt_f32_f16(vget_low_f16(val));
val_f32x4_high = vsubq_f32(val_f32x4_high, mean_f32x4_high);
val_f32x4_low = vsubq_f32(val_f32x4_low, mean_f32x4_low);
msd_f32x4_high = vaddq_f32(vmulq_f32(val_f32x4_high, val_f32x4_high), msd_f32x4_high);
msd_f32x4_low = vaddq_f32(vmulq_f32(val_f32x4_low, val_f32x4_low), msd_f32x4_low);

// val = vsubq_f16(val, mean);
// val = vmulq_f16(val, val);
// msd = vaddq_f16(val, msd);
}

float16x4_t msd_f16x4 = vadd_f16(vget_high_f16(msd), vget_low_f16(msd));
float32x4_t msd_f32x4 = vcvt_f32_f16(msd_f16x4);
// float16x4_t msd_f16x4 = vadd_f16(vget_high_f16(msd), vget_low_f16(msd));
// float32x4_t msd_f32x4 = vcvt_f32_f16(msd_f16x4);
float32x4_t msd_f32x4 = vaddq_f32(msd_f32x4_high, msd_f32x4_low);
float32_t msd_f32 = vaddvq_f32(msd_f32x4);
float32_t denom_single = msd_f32 / (float32_t)(cols - 1);
float32_t denom_single = msd_f32 / (float32_t)(cols);
denom_single = std::sqrt(denom_single + eps);
float16x8_t denom = vdupq_n_f16((__fp16)denom_single);

// Normalize based on mean + MSD, beta, gamma and epsilon.
for (int j = 0; j < cols; j += 8) {
for (int j = 0; j < cols; j += 8)
{
int offset = rows * j;
float16x8_t val = vld1q_f16(input_ptr + offset);
float16x8_t gamma_val = vld1q_f16(gamma.data() + j);
Expand All @@ -51,3 +82,4 @@ void layernorm_A(Matmul *x, int rows, int cols,
}
}
}

2 changes: 1 addition & 1 deletion lib/layernorm.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@

void layernorm_A(Matmul *x, int rows, int cols,
const std::vector<__fp16> &gamma,
const std::vector<__fp16> &beta, __fp16 eps = 1e-5);
const std::vector<__fp16> &beta, float eps = 1e-5);

#endif // _LIB_LAYERNORM_H_