[WIP] Add fast cuda kernels for one mode #154

Closed · wants to merge 1 commit
158 changes: 158 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/gemv.cuh
@@ -0,0 +1,158 @@
// Based on https://github.com/wangsiping97/FastGEMV

#include "util.cuh"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <cassert>     // for assert() in the host-side launcher below
#include <type_traits> // for std::is_same_v in the static_assert below

#define WARP_SIZE 32
#define SHARED_MEM_MAX_ROWS 64
#define MAX_THREADS_PER_BLOCK 1024

__device__ __forceinline__ float warp_reduce_sum(float sum,
unsigned int thread_num) {
if (thread_num >= 32)
sum += __shfl_down_sync(0xffffffff, sum, 16); // 0-16, 1-17, 2-18, etc.
if (thread_num >= 16)
sum += __shfl_down_sync(0xffffffff, sum, 8); // 0-8, 1-9, 2-10, etc.
if (thread_num >= 8)
sum += __shfl_down_sync(0xffffffff, sum, 4); // 0-4, 1-5, 2-6, etc.
if (thread_num >= 4)
sum += __shfl_down_sync(0xffffffff, sum, 2); // 0-2, 1-3, 4-6, 5-7, etc.
if (thread_num >= 2)
sum += __shfl_down_sync(0xffffffff, sum, 1); // 0-1, 2-3, 4-5, etc.
return sum;
}
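// Editor's illustration (not part of this PR): a minimal sketch showing how
// warp_reduce_sum is meant to be used. Every thread of a 32-thread block
// contributes 1.0f, so after the shuffle reduction lane 0 holds 32.0f.
//
//   __global__ void warp_reduce_sum_demo(float *out) {
//     float total = warp_reduce_sum(1.0f, blockDim.x);
//     if (threadIdx.x == 0)
//       *out = total; // expected: 32.0f when launched as <<<1, 32>>>
//   }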

template <typename OutputT, typename Func>
__global__ void _gemv_fp16(half *mat, half *vec, OutputT *res, unsigned int n,
unsigned int num_per_thread, const Func epilogue) {
static_assert(std::is_same_v<OutputT, float> || std::is_same_v<OutputT, half>,
"Output type must be float or half");
{
float sum = 0;
// each thread loads num_per_thread elements from global memory
unsigned int tid = threadIdx.x;
unsigned int row = blockIdx.y * blockDim.y + threadIdx.y;
unsigned int start_idx = threadIdx.x;
float4 *mat4 = reinterpret_cast<float4 *>(mat);
float4 *vec4 = reinterpret_cast<float4 *>(vec);

#pragma unroll
for (int iter = 0; iter < num_per_thread >> 3; iter++) {
unsigned int j = start_idx + iter * blockDim.x;
if (j < n >> 3) {
float4 vec_val = vec4[j];
float4 mat_val = mat4[row * (n >> 3) + j];
const half2 *vec_h1 = (half2 *)&vec_val.x;
const half2 *vec_h2 = (half2 *)&vec_val.y;
const half2 *vec_h3 = (half2 *)&vec_val.z;
const half2 *vec_h4 = (half2 *)&vec_val.w;
const half2 *mat_h1 = (half2 *)&mat_val.x;
const half2 *mat_h2 = (half2 *)&mat_val.y;
const half2 *mat_h3 = (half2 *)&mat_val.z;
const half2 *mat_h4 = (half2 *)&mat_val.w;
sum += __half2float(vec_h1->x) * __half2float(mat_h1->x);
sum += __half2float(vec_h1->y) * __half2float(mat_h1->y);
sum += __half2float(vec_h2->x) * __half2float(mat_h2->x);
sum += __half2float(vec_h2->y) * __half2float(mat_h2->y);
sum += __half2float(vec_h3->x) * __half2float(mat_h3->x);
sum += __half2float(vec_h3->y) * __half2float(mat_h3->y);
sum += __half2float(vec_h4->x) * __half2float(mat_h4->x);
sum += __half2float(vec_h4->y) * __half2float(mat_h4->y);
}
}

sum = warp_reduce_sum(sum, blockDim.x);

if (blockDim.x <= WARP_SIZE) {
if (tid == 0) {
if constexpr (std::is_same_v<OutputT, float>) {
res[row] = epilogue(sum, row);
} else {
res[row] = epilogue(__float2half(sum), row);
}
}
return;
}

// Shared mem for partial sums (one per warp in the block)
static __shared__ float warpLevelSums[SHARED_MEM_MAX_ROWS][WARP_SIZE];
const int laneId = threadIdx.x % WARP_SIZE;
const int warpId = threadIdx.x / WARP_SIZE;
if (laneId == 0)
warpLevelSums[threadIdx.y][warpId] = sum;
__syncthreads();
// read from shared memory only if that warp existed
sum = (threadIdx.x < blockDim.x / WARP_SIZE)
? warpLevelSums[threadIdx.y][laneId]
: 0.0;
// Final reduce using first warp
if (warpId == 0)
sum = warp_reduce_sum(sum, blockDim.x / WARP_SIZE);
if (tid == 0) {
if constexpr (std::is_same_v<OutputT, float>) {
res[row] = epilogue(sum, row);
} else {
res[row] = epilogue(__float2half(sum), row);
}
}
}
}
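// Editor's note (illustration, not part of this PR): stripped of the float4 /
// half2 vectorization and the two-stage warp/shared-memory reduction, each row
// of the kernel above computes
//   res[row] = epilogue(sum_j (float)vec[j] * (float)mat[row * n + j], row)
// i.e. one fp16 dot product per output row, accumulated in fp32.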

template <typename OutputT, typename Func>
void gemv_fp16(torch::Tensor mat, torch::Tensor vec, torch::Tensor out,
const Func &func) {
const int32_t BLOCK_DIM_X = 32;
const int32_t BLOCK_DIM_Y = 16;
assert(BLOCK_DIM_Y <= SHARED_MEM_MAX_ROWS);
assert(BLOCK_DIM_X * BLOCK_DIM_Y <= MAX_THREADS_PER_BLOCK);

auto N = mat.size(0);
auto K = mat.size(1);
assert(vec.size(0) == K);
assert(out.size(0) == N);
const int32_t num_per_thread = K / BLOCK_DIM_X;
assert(num_per_thread >= 8);

dim3 grid_dim(1, N / BLOCK_DIM_Y);
dim3 block_dim(BLOCK_DIM_X, BLOCK_DIM_Y);
_gemv_fp16<OutputT><<<grid_dim, block_dim>>>(
data_ptr<half>(mat), data_ptr<half>(vec), data_ptr<OutputT>(out), K,
num_per_thread, func);
}
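// Editor's illustration (mirrors the calls made in operators.cu, not new API):
// for an fp16 weight `w` of shape [N, K] and an fp16 vector `v` of shape [K],
//   gemv_fp16<half>(w, v, out, SigmoidEpilogue{});
// writes out[i] = sigmoid(dot(w[i, :], v)) for every row i. Note the launch
// geometry effectively requires N % BLOCK_DIM_Y == 0 and K % 256 == 0
// (8 elements per iteration times BLOCK_DIM_X threads).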

template <typename T> struct IdentityEpilogue {
__device__ __forceinline__ T operator()(T x, unsigned int idx) const {
return x;
}
};

struct SigmoidEpilogue {
__device__ __forceinline__ half operator()(half x, unsigned int idx) const {
return hrcp(__hadd(__float2half(1.0), hexp(__hneg(x))));
}
};

struct ReLUAndSqaureEpilogue {
__device__ __forceinline__ half operator()(half x, unsigned int idx) const {
return __hgt(x, __float2half(0.0)) ? __hmul(x, x) : __float2half(0.0);
}
};

struct BiasEpilogue {
BiasEpilogue(half *bias) : bias(bias) {}
__device__ __forceinline__ half operator()(half x, unsigned int idx) const {
return __hadd(bias[idx], x);
}
half *bias;
};

struct ScaleAndBiasEpilogue {
ScaleAndBiasEpilogue(half *scale, half *bias) : scale(scale), bias(bias) {}
__device__ __forceinline__ half operator()(half x, unsigned int idx) const {
return __hadd(bias[idx], __hmul(scale[idx], x));
}
half *scale;
half *bias;
};
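// Editor's illustration (hypothetical, not part of this PR): any further
// epilogue only needs the same functor shape, i.e. a __device__ operator()
// taking the accumulated value and the output row index. A scale-only variant
// could look like:
struct ScaleEpilogueExample {
  explicit ScaleEpilogueExample(half s) : s(s) {}
  __device__ __forceinline__ half operator()(half x, unsigned int idx) const {
    return __hmul(s, x); // idx unused: the same scale is applied to every row
  }
  half s;
};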
53 changes: 50 additions & 3 deletions rwkv_pip_package/src/rwkv/cuda/operators.cu
@@ -1,12 +1,16 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#include "gemv.cuh"
#include "wkv_forward_one.cuh"
#include <cuda_fp16.h>
#include <functional>
#include <torch/extension.h>

using torch::Tensor;
#define MIN_VALUE (-1e38)
typedef at::Half fp16;
__half *cast(fp16 *ptr) { return reinterpret_cast<__half *>(ptr); }

template <typename F>
__global__ void kernel_wkv_forward(const int B, const int T, const int C,
@@ -244,3 +248,46 @@ void cuda_mm8_one<fp16>(int N, int M,
N, M, cast(x), w, w_stride,
cast(mx), cast(rx), cast(my), cast(ry), y);
}

void ffn_one(torch::Tensor rw, torch::Tensor rx, torch::Tensor kw,
torch::Tensor kx, torch::Tensor vw, torch::Tensor x,
/* imm */ torch::Tensor r, /* imm */ torch::Tensor vx,
/* out */ torch::Tensor x_plus_out) {

// r = torch.sigmoid(gemv(rw, rx))
gemv_fp16<half>(rw, rx, r, SigmoidEpilogue{});

// vx = torch.square(torch.relu(gemv(kw, kx)))
gemv_fp16<half>(kw, kx, vx, ReLUAndSqaureEpilogue{});

// x_plus_out = x + r * gemv(vw, vx)
half *r_ptr = cast(r.data_ptr<fp16>());
half *x_ptr = cast(x.data_ptr<fp16>());
gemv_fp16<half>(vw, vx, x_plus_out,
ScaleAndBiasEpilogue{r_ptr, x_ptr});
}
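// Editor's illustration (an assumption about the calling convention, not code
// from this PR): the intermediate buffers are fp16 tensors sized by the output
// rows of the gemv that fills them.
void ffn_one_example(torch::Tensor rw, torch::Tensor rx, torch::Tensor kw,
                     torch::Tensor kx, torch::Tensor vw, torch::Tensor x) {
  auto r = torch::empty({rw.size(0)}, x.options());  // sigmoid(rw @ rx)
  auto vx = torch::empty({kw.size(0)}, x.options()); // relu(kw @ kx) ** 2
  auto x_plus_out = torch::empty_like(x);            // x + r * (vw @ vx)
  ffn_one(rw, rx, kw, kx, vw, x, r, vx, x_plus_out);
}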

void att_one(Tensor x, Tensor kw, Tensor kx, Tensor vw, Tensor vx, Tensor rw,
Tensor rx, Tensor ow, Tensor t_first, /* imm */ Tensor k,
Tensor pp, Tensor ww, Tensor aa, Tensor bb, Tensor t_decay,
/* imm */ Tensor v, /* in & out */ Tensor r,
/* out */ Tensor x_plus_out, /* out */ Tensor t1,
/* out */ Tensor t2, /* out */ Tensor p) {
gemv_fp16<float>(kw, kx, k, IdentityEpilogue<float>{});
gemv_fp16<float>(vw, vx, v, IdentityEpilogue<float>{});
gemv_fp16<half>(rw, rx, r, SigmoidEpilogue{});

size_t elem_num = t_first.numel();
// 256 is good enough on most GPUs
const int32_t BLOCK_SIZE = 256;
assert(elem_num % BLOCK_SIZE == 0);
wkv_forward_one<<<elem_num / BLOCK_SIZE, BLOCK_SIZE>>>(
t_first.data_ptr<float>(), k.data_ptr<float>(), pp.data_ptr<float>(),
aa.data_ptr<float>(), bb.data_ptr<float>(), t_decay.data_ptr<float>(),
v.data_ptr<float>(), t1.data_ptr<float>(), t2.data_ptr<float>(),
p.data_ptr<float>(), cast(r.data_ptr<fp16>()), elem_num);

half *x_ptr = cast(x.data_ptr<fp16>());
gemv_fp16<half>(ow, r, x_plus_out, BiasEpilogue{x_ptr});
}
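// Editor's note (a reading of the calls above, not new behavior): the weights
// kw/vw/rw/ow, the inputs kx/vx/rx/x and the outputs r/x_plus_out are fp16,
// while the wkv state tensors (t_first, t_decay, pp, aa, bb) and the
// intermediates k, v, t1, t2, p stay in fp32, because gemv_fp16<float> writes
// float results for k and v and wkv_forward_one works entirely in float.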
28 changes: 28 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/util.cuh
@@ -0,0 +1,28 @@
#include "ATen/ATen.h"
#include <cuda_fp16.h>
#include <torch/extension.h>

template <typename T> T *data_ptr(torch::Tensor x) { return x.data_ptr<T>(); }
template <> inline half *data_ptr(torch::Tensor x) {
return reinterpret_cast<half *>(x.data_ptr<at::Half>());
}

inline __host__ __device__ float4 operator+(float4 a, float4 b) {
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ float4 operator-(float4 a, float4 b) {
return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ float4 operator*(float4 a, float4 b) {
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}
inline __host__ __device__ float4 operator/(float4 a, float4 b) {
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
}
inline __host__ __device__ float4 fmaxf(float4 a, float4 b) {
return make_float4(fmaxf(a.x, b.x), fmaxf(a.y, b.y), fmaxf(a.z, b.z),
fmaxf(a.w, b.w));
}
inline __host__ __device__ float4 expf(float4 a) {
return make_float4(expf(a.x), expf(a.y), expf(a.z), expf(a.w));
}
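// Editor's note (illustration, not part of this PR): these overloads let the
// float4 path in wkv_forward_one.cu read like the scalar code, e.g.
//   float4 p = fmaxf(pp, ww);  // lane-wise max
//   float4 e1 = expf(pp - p);  // lane-wise exp of the element-wise difference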
89 changes: 89 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/wkv_forward_one.cu
@@ -0,0 +1,89 @@
#include "util.cuh"
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Equivalent Python code:
// ww = t_first + k
// p = torch.maximum(pp, ww)
// e1 = torch.exp(pp - p)
// e2 = torch.exp(ww - p)
// wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
// ww = t_decay + pp
// p = torch.maximum(ww, k)
// e1 = torch.exp(ww - p)
// e2 = torch.exp(k - p)
// t1 = e1 * aa + e2 * v
// t2 = e1 * bb + e2
// r = r * wkv
// return t1, t2, p, r

// NOTE: float4 is overkill for the current sizes (4096 in the 7B model and
// 768 in the 0.1B model) and is not faster than the plain float version,
// so the plain float version below is the one actually used.
__global__ void wkv_forward_one(float4 *t_first, float4 *k, float4 *pp,
float4 *aa, float4 *bb, float4 *t_decay,
float4 *v, /* out */ float4 *t1,
/* out */ float4 *t2, /* out */ float4 *p,
/* in & out, half */ float2 *r,
unsigned int n) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n / 4;
i += blockDim.x * gridDim.x) {
float4 ww = t_first[i] + k[i];
float4 pp_ = pp[i];
float4 p_ = fmaxf(pp_, ww);
float4 e1 = expf(pp_ - p_);
float4 e2 = expf(ww - p_);

float4 aa_ = aa[i];
float4 bb_ = bb[i];
float4 v_ = v[i];
half2 wkv1 = make_half2(
__float2half(((e1.x * aa_.x + e2.x * v_.x) / (e1.x * bb_.x + e2.x))),
__float2half(((e1.y * aa_.y + e2.y * v_.y) / (e1.y * bb_.y + e2.y))));
half2 wkv2 = make_half2(
__float2half(((e1.z * aa_.z + e2.z * v_.z) / (e1.z * bb_.z + e2.z))),
__float2half(((e1.w * aa_.w + e2.w * v_.w) / (e1.w * bb_.w + e2.w))));
half2 *r1 = reinterpret_cast<half2 *>(&r[i].x);
half2 *r2 = reinterpret_cast<half2 *>(&r[i].y);
*r1 = __hmul2(wkv1, *r1);
*r2 = __hmul2(wkv2, *r2);

ww = t_decay[i] + pp_;
float4 k_ = k[i];
p_ = fmaxf(ww, k_);
e1 = expf(ww - p_);
e2 = expf(k_ - p_);

t1[i] = e1 * aa_ + e2 * v_;
t2[i] = e1 * bb_ + e2;

p[i] = p_;
}
}

__global__ void wkv_forward_one(float *t_first, float *k, float *pp, float *aa,
float *bb, float *t_decay, float *v,
/* out */ float *t1, /* out */ float *t2,
/* out */ float *p, /* in & out */ half *r,
unsigned int n) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
i += blockDim.x * gridDim.x) {
float ww = t_first[i] + k[i];
float pp_ = pp[i];
float p_ = (pp_ > ww) ? pp_ : ww;
float e1 = expf(pp_ - p_);
float e2 = expf(ww - p_);
float aa_ = aa[i];
float bb_ = bb[i];
float v_ = v[i];
r[i] = __hmul(r[i], __float2half(((e1 * aa_ + e2 * v_) / (e1 * bb_ + e2))));
ww = t_decay[i] + pp_;
float k_ = k[i];
p_ = (ww > k_) ? ww : k_;
e1 = expf(ww - p_);
e2 = expf(k_ - p_);
t1[i] = e1 * aa_ + e2 * v_;
t2[i] = e1 * bb_ + e2;
p[i] = p_;
}
}
12 changes: 12 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/wkv_forward_one.cuh
@@ -0,0 +1,12 @@
#include "cuda_runtime.h"

__global__ void wkv_forward_one(float4 *t_first, float4 *k, float4 *pp,
float4 *aa, float4 *bb, float4 *t_decay,
float4 *v, /* out */ float4 *t1,
/* out */ float4 *t2, /* out */ float4 *p,
/* in & out, half */ float2 *r, unsigned int n);
__global__ void wkv_forward_one(float *t_first, float *k, float *pp, float *aa,
float *bb, float *t_decay, float *v,
/* out */ float *t1, /* out */ float *t2,
/* out */ float *p, /* in & out */ half *r,
unsigned int n);
10 changes: 10 additions & 0 deletions rwkv_pip_package/src/rwkv/cuda/wrapper.cpp
@@ -116,14 +116,24 @@ void mm8_one(int64_t N, int64_t M,
}
}

using torch::Tensor;

void ffn_one(Tensor rw, Tensor rx, Tensor kw, Tensor kx, Tensor vw, Tensor x, /* imm */ Tensor r, /* imm */ Tensor vx, /* out */ Tensor x_plus_out);

void att_one(Tensor x, Tensor kw, Tensor kx, Tensor vw, Tensor vx, Tensor rw, Tensor rx, Tensor ow, Tensor t_first, /* imm */ Tensor k, Tensor pp, Tensor ww, Tensor aa, Tensor bb, Tensor t_decay, /* imm */ Tensor v, /* in & out */ Tensor r, /* out */ Tensor x_plus_out, /* out */ Tensor t1, /* out */ Tensor t2, /* out */ Tensor p);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("wkv_forward", &wkv_forward, "wkv forward");
m.def("mm8_seq", &mm8_seq, "mm8 seq");
m.def("mm8_one", &mm8_one, "mm8 one");
m.def("ffn_one", &ffn_one, "ffn one large kernel");
m.def("att_one", &att_one, "att one");
}

TORCH_LIBRARY(rwkv, m) {
m.def("wkv_forward", wkv_forward);
m.def("mm8_seq", mm8_seq);
m.def("mm8_one", mm8_one);
m.def("ffn_one", ffn_one);
m.def("att_one", att_one);
}
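// Editor's note (usage sketch, an assumption rather than code from this PR):
// with the TORCH_LIBRARY registration above, the fused ops are also reachable
// as torch.ops.rwkv.ffn_one(...) and torch.ops.rwkv.att_one(...) once the
// compiled extension is loaded, in addition to the functions exposed through
// PYBIND11_MODULE.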