
Commit

avoid most torch cpu operations
Signed-off-by: daquexian <[email protected]>
daquexian committed Aug 12, 2023
1 parent 69bfb70 commit aec55ba
Showing 2 changed files with 28 additions and 14 deletions.
39 changes: 26 additions & 13 deletions rwkv_pip_package/src/rwkv/cuda/att_one_v5.cu
@@ -59,6 +59,9 @@ struct Mix {
 using torch::Tensor;
 
 void gemm_cublas_tensor(const Tensor &a, const Tensor &b, const Tensor &c);
+void gemm_cublas(const void *a, const void *b, void *c, int batch, int ori_m,
+                 int ori_n, int ori_k, at::ScalarType torch_input_dtype,
+                 at::ScalarType torch_output_dtype);
 
 Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
                   Tensor lx_w, Tensor lx_b, Tensor kvr_mix,
@@ -67,20 +70,27 @@ Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
                   /* imm */ Tensor buf,
                   /* imm */ Tensor s1,
                   /* out */ Tensor x_plus_out, /* out */ Tensor s2) {
-  Tensor xx = at::layer_norm(x, {x.size(-1)}, ln_w, ln_b);
+  const int x_numel = x.numel();
+  Tensor xx = at::layer_norm(x, {x_numel}, ln_w, ln_b);
   element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
-                   data_ptr<half>(kvr_mix), static_cast<int>(x.numel()),
+                   data_ptr<half>(kvr_mix), static_cast<int>(x_numel),
                    data_ptr<half>(kvrx)},
-               x.numel());
+               x_numel);
 
   int H = t_decay.size(0);
-  int S = x.size(-1) / H;
-  gemm_cublas_tensor(at::unsqueeze(kvrx, 1), kvrw, kvr);
-  Tensor k = at::reshape(kvr[0], {H, S, 1});
-  Tensor v = at::reshape(kvr[1], {H, 1, S});
-  Tensor r = at::reshape(kvr[2], {H, 1, S});
+  int S = x_numel / H;
+  // gemm_cublas_tensor(at::unsqueeze(kvrx, 1), kvrw, kvr);
+  gemm_cublas(data_ptr<half>(kvrx), data_ptr<half>(kvrw), data_ptr<float>(kvr),
+              3, 1, x_numel, x_numel, at::kHalf, at::kFloat);
+  float* k = data_ptr<float>(kvr);
+  float* v = k + x_numel;
+  float* r = v + x_numel;
+  // Tensor k = at::reshape(kvr[0], {H, S, 1});
+  // Tensor v = at::reshape(kvr[1], {H, 1, S});
+  // Tensor r = at::reshape(kvr[2], {H, 1, S});
 
-  gemm_cublas_tensor(k, v, a);
+  // gemm_cublas_tensor(k, v, a);
+  gemm_cublas(k, v, data_ptr<float>(a), H, S, S, 1, at::kFloat, at::kFloat);
   // s1 = t_first * a + s
   // s2 = a + t_decay * s
   element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
@@ -89,12 +99,15 @@ Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
                       data_ptr<float>(s1), data_ptr<float>(s2)},
                a.numel());
 
-  gemm_cublas_tensor(r, s1, buf);
-  buf = at::flatten(buf);
-  buf = at::squeeze(at::group_norm(at::unsqueeze(buf, 0), H, lx_w, lx_b), 0);
+  // gemm_cublas_tensor(r, s1, buf);
+  gemm_cublas(r, data_ptr<float>(s1), data_ptr<float>(buf), H, 1, S, S,
+              at::kFloat, at::kFloat);
+  buf = at::group_norm(buf, H, lx_w, lx_b);
   buf = at::_cast_Half(buf);
 
-  gemm_cublas_tensor(buf, ow, x_plus_out);
+  // gemm_cublas_tensor(buf, ow, x_plus_out);
+  gemm_cublas(data_ptr<half>(buf), data_ptr<half>(ow), data_ptr<half>(x_plus_out),
+              1, 1, x_numel, x_numel, at::kHalf, at::kHalf);
   x_plus_out += x;
   return xx;
 }
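
The new gemm_cublas entry point takes raw device pointers plus explicit batch/shape arguments, so the three matrix products above no longer go through at::unsqueeze / at::reshape / tensor indexing, which build tensor metadata on the CPU for every generated token. Its definition is not part of this diff; the following is only a minimal sketch of what such a wrapper might look like, assuming CUDA 11+ cuBLAS (cublasGemmStridedBatchedEx), the current ATen cuBLAS handle, and the two dtypes the call sites actually pass (at::kHalf / at::kFloat):

// Sketch only, not the repo's implementation: row-major C = A x B for each of
// `batch` matrices, expressed to column-major cuBLAS as C^T = B^T x A^T
// (swap the operands and the m/n sizes, keep k). Error checking omitted.
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cublas_v2.h>

static cudaDataType_t as_cuda_dtype(at::ScalarType t) {
  // Only the two dtypes used by att_one_v5.cu are handled here.
  return t == at::kHalf ? CUDA_R_16F : CUDA_R_32F;
}

void gemm_cublas(const void *a, const void *b, void *c, int batch, int ori_m,
                 int ori_n, int ori_k, at::ScalarType torch_input_dtype,
                 at::ScalarType torch_output_dtype) {
  const float alpha = 1.0f, beta = 0.0f;
  const cudaDataType_t in_t = as_cuda_dtype(torch_input_dtype);
  const cudaDataType_t out_t = as_cuda_dtype(torch_output_dtype);
  // b plays the role of cuBLAS "A" (leading dim ori_n), a plays the role of "B"
  // (leading dim ori_k); strides are per-matrix element counts, so the same call
  // covers both the batched and the batch == 1 uses above.
  cublasGemmStridedBatchedEx(
      at::cuda::getCurrentCUDABlasHandle(), CUBLAS_OP_N, CUBLAS_OP_N,
      /*m=*/ori_n, /*n=*/ori_m, /*k=*/ori_k, &alpha,
      b, in_t, ori_n, (long long)ori_k * ori_n,
      a, in_t, ori_k, (long long)ori_m * ori_k, &beta,
      c, out_t, ori_n, (long long)ori_m * ori_n,
      batch, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

Because kvr, a, buf and x_plus_out are all pre-allocated on the Python side and passed in, the per-token path is left with no tensor allocations and almost no CPU-side shape bookkeeping, which is what the commit title refers to.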
3 changes: 2 additions & 1 deletion rwkv_pip_package/src/rwkv/model.py
@@ -599,6 +599,7 @@ def att_seq_i8(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay
     @MyFunction
     def att_one_v5(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, kvr_mix, t_decay, t_first, kvrw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
         xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
+        # import pdb; pdb.set_trace()
         kvrx = xx * kvr_mix + sx * (1 - kvr_mix)
         # kx = xx * k_mix + sx * (1 - k_mix)
         # vx = xx * v_mix + sx * (1 - v_mix)
@@ -752,7 +753,7 @@ def cuda_att_one_v5_fp16(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, kvr_mix, t_deca
 
         kvr = torch.empty((3, 1, x.shape[-1]), dtype=torch.float32, device=x.device)
         a = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
-        buf = torch.empty((H, 1, S), dtype=torch.float32, device=x.device)
+        buf = torch.empty((1, x.shape[-1]), dtype=torch.float32, device=x.device)
         s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
         s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
         x_plus_out = torch.empty_like(x)
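On the Python side, buf becomes a flat (1, x.shape[-1]) buffer because the C++ code now passes it straight to at::group_norm (an (N, C) input with num_groups = H) instead of flattening and unsqueezing an (H, 1, S) tensor. A hypothetical way to sanity-check the raw-pointer GEMM against the tensor-level op it replaces, using the same r @ s1 shapes and the same flat output layout (not part of this commit; it assumes the gemm_cublas sketched above):

// Hypothetical check: compare gemm_cublas with at::matmul for the batched
// (H, 1, S) x (H, S, S) product used in att_one_v5.
#include <ATen/ATen.h>

void gemm_cublas(const void *a, const void *b, void *c, int batch, int ori_m,
                 int ori_n, int ori_k, at::ScalarType torch_input_dtype,
                 at::ScalarType torch_output_dtype);  // declared in att_one_v5.cu

void check_r_s1_gemm(int H, int S) {
  auto opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  at::Tensor r   = at::randn({H, 1, S}, opts);
  at::Tensor s1  = at::randn({H, S, S}, opts);
  at::Tensor buf = at::empty({1, H * S}, opts);  // same flat layout model.py now allocates
  gemm_cublas(r.data_ptr<float>(), s1.data_ptr<float>(), buf.data_ptr<float>(),
              /*batch=*/H, /*ori_m=*/1, /*ori_n=*/S, /*ori_k=*/S,
              at::kFloat, at::kFloat);
  // Reference via the tensor-level op the commit avoids on the hot path.
  at::Tensor ref = at::matmul(r, s1).reshape({1, H * S});
  TORCH_CHECK(at::allclose(buf, ref, /*rtol=*/1e-3, /*atol=*/1e-3),
              "gemm_cublas result differs from at::matmul reference");
}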
