diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 867390f2c..8fe8b9e2a 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3054,213 +3054,39 @@ template __global__ void kExtractOutliers(char *A, int *idx, char * } } -#define WARPS 3 -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc) -{ - -#if __CUDA_ARCH__ >= 750 - using namespace nvcuda; - int col_offset = blockIdx.x *32; - const int warp_id = threadIdx.x / 32; - const int half_warp_id = threadIdx.x / 16; - const int half_warp_lane = threadIdx.x % 16; - const int batch_size_warps = (WARPS-1)*2; - const int val_per_iter = blockDim.x-32; - - T local_A[4]; - T local_B[128]; - - const int a_tile_offset = 16; - const int b_tile_offset = (16*32 + 16); - - __shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))]; - __shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))]; - //__shared__ T smem_C[8*32]; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0f); - - int ticktock = 0; - int idx = 0 + threadIdx.x; - int loaded_values = 0; - // prefetch - if(idx < K && warp_id < (WARPS-1)) - { - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+(1*val_per_iter)]; - local_A[2] = A[idx+(2*val_per_iter)]; - local_A[3] = A[idx+(3*val_per_iter)]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - { - local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; - local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; - local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; - } - loaded_values = 3; - } - else - { - - if(loaded_values == 3) - { - local_A[0] = local_A[1]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(32)]; - } - else if(loaded_values == 2) - { - local_A[0] = local_A[2]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(64)]; - } - else - { - local_A[0] = local_A[3]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(96)]; - } - loaded_values--; +template cudaError_t CutlassSGemmNN(int M, int N, int K, T * __restrict__ const A, T * B, T* out, int lda, int lbd, int ldc, float alpha, float beta){ + + #include "cutlass/gemm/device/gemm.h" + //initial sgemm without using cute + using col_major = cutlass::layout::ColumnMajor; + using cutlass_gemm = cutlass::gemm::device::gemm; + cutlass_gemm gemm_operator; + cutlass_gemm::Arguments args({M, N, K}, + {A, lda}, + {B, ldb}, + {out, ldc}, + {out, ldc}, + {alpha, beta}); + + cutlass::Status status = gemm_operator(args); + if(status!=cutlass::Status::kSuccess){ + return cudaErrorUnknown; } + return cudaSuccess; - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - - //for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32) - { - idx = base_idx + threadIdx.x; - - __syncthreads(); - if(idx < K && warp_id < (WARPS-1)) - { - //local_A[0] = A[idx]; - - //#pragma unroll 32 - //for(int col = 0; col < 32; col++) - // local_B[col] = B[(col_offset+col)*ldb+idx]; - if(loaded_values == 0) - { - local_A[0] = A[idx]; - local_A[1] = A[idx+(1*val_per_iter)]; - local_A[2] = A[idx+(2*val_per_iter)]; - local_A[3] = A[idx+(3*val_per_iter)]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - { - local_B[col] = B[(col_offset+col)*ldb+idx]; - local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)]; - local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)]; - local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)]; - } - loaded_values = 3; - - } - else - { - - if(loaded_values == 3) - { - local_A[0] = local_A[1]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(32)]; - } - else if(loaded_values == 2) - { - local_A[0] = local_A[2]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(64)]; - } - else - { - local_A[0] = local_A[3]; - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = local_B[col+(96)]; - } - loaded_values--; - } - - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0]; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col]; - } - else if(warp_id < (WARPS-1)) - { - local_A[0] = T(0.0); - smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - local_B[col] = 0.0f; - - #pragma unroll 32 - for(int col = 0; col < 32; col++) - smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f; - } - ticktock = ticktock == 0 ? 1 : 0; - - if(warp_id == (WARPS-1)) - for(int k = 0; k < batch_size_warps; k++) - { - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - } +} - __syncthreads(); - if(warp_id != (WARPS-1)){ return; } - // only warp_id == (WARPS-1) from here - int warp_lane = threadIdx.x % 32; - ticktock = ticktock == 0 ? 1 : 0; - for(int k = 0; k < batch_size_warps; k++) - { - wmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu - wmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } +#define WARPS 3 +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, float alpha, float beta) +{ - // 129 mu - if(warp_id == (WARPS-1)) - wmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, wmma::mem_row_major); + cudaError_t result; + result = CutlassSGemmNN(M, N, K, A, B, out, lda, ldb, ldc, alpha, beta); + if(result!= cudaSuccess) { + std::cerr << "CUTLASS GEMM kernel failed: " + << cudaGetErrorString(result) << std::endl; - if(col_offset + warp_lane < M) - out[col_offset + warp_lane] = smem_A[warp_lane]; -#endif } @@ -3764,25 +3590,25 @@ template __global__ void kfunc(float *A, float *B, float value, lon // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); // these are not used and make no sense, but the compiler needs them //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); //template __global__ void gemm_device(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); -template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); +template __global__ void gemm_device(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index ec6daebe5..d88b362a2 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -123,7 +123,7 @@ template __global__ void kExtractOutliers(char *A, int *idx, char *out, int idx_size, int rowsA, int colsA, int tiledRowsA, int tiledColsA); -template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc); +template __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc, float alpha , float beta); template __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); template __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); diff --git a/csrc/ops.cu b/csrc/ops.cu index 7ca854baf..9bb65fc70 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -682,29 +682,21 @@ template void extractOutliers(char * A, int *idx, char *out, int id -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits) +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, float alpha, float beta) { - int num_blocks = (m+31)/32; + auto M = (int)m; + auto N = (int)n; + auto K = (int)k; + dim3 dimBlock(16, 16); + dim3 dimGrid((M+127)/128, (M+127)/128); - //cout << num_blocks << endl; - //cout << lda << endl; - //cout << ldb << endl; - //cout << ldc << endl; + gemm_device<<< dimGrid, dimBlock, 0, 0>>> + (M, N, K, + A, B, out, + lda, ldb, ldc, + alpha, beta); - //cout << m << endl; - //cout << n << endl; - //cout << k << endl; - if(bits == 32) - //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - if(bits == 16) - //gemm_device<<< num_blocks, 256, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - gemm_device<<< num_blocks, 160, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 96, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); - //gemm_device<<< num_blocks, 64, 0, 0 >>>(m, n, k, A, B, out, lda, ldb, ldc); } template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize) @@ -759,7 +751,7 @@ template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); -template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, float alpha, float beta); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); template void extractOutliers(char * A, int *idx, char *out, int idx_size, int rows, int cols); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index b0ecc4622..028f8c659 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -195,7 +195,7 @@ template void extractOutliers(char * A, int *idx, char *out, int id void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rowsA, int colsA, int colsB); -template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); +template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, float alpha, float beta); template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream);