34 changes: 34 additions & 0 deletions src/neural/backends/cuda/common_kernels.cu
@@ -70,6 +70,7 @@ void addVectors(T* c, T* a, T* b, int size, int asize, int bsize,
const int kBlockSize = 256;
int blocks = DivUp(size, kBlockSize);

CUDA_KERNEL_LAUNCH_LOG(addVectors_kernel, blocks, kBlockSize, 0, stream);
addVectors_kernel<<<blocks, kBlockSize, 0, stream>>>(c, a, b, size, asize,
bsize, activation);
ReportCUDAErrors(cudaGetLastError());
@@ -98,6 +99,7 @@ template <typename T>
void addVectorsHNC_NHC(T* a, T* b, int N, int H, int C, cudaStream_t stream) {
const int kBlockSize = 256;
int blocks = DivUp(N * H * C, kBlockSize);
CUDA_KERNEL_LAUNCH_LOG(addVectorsHNC_NHC_kernel, blocks, kBlockSize, 0, stream);
addVectorsHNC_NHC_kernel<<<blocks, kBlockSize, 0, stream>>>(a, b, N, H, C);

ReportCUDAErrors(cudaGetLastError());
@@ -171,26 +173,32 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N,

switch (activation) {
case ACTIVATION_NONE:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_NONE, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_NONE>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C);
break;
case ACTIVATION_SELU:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_SELU, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_SELU>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C);
break;
case ACTIVATION_MISH:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_MISH, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_MISH>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C);
break;
case ACTIVATION_RELU:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_RELU, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_RELU>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C);
break;
case ACTIVATION_SWISH:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_SWISH, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_SWISH>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C);
break;
case ACTIVATION_RELU_2: // square relu
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_RELU_2, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_RELU_2>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C);
break;
@@ -271,31 +279,37 @@ void addBiasBatched(T* output, const T* input, const T* bias, int Batch, int N,

switch (activation) {
case ACTIVATION_NONE:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_NONE, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_NONE>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C,
Nstride);
break;
case ACTIVATION_SELU:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_SELU, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_SELU>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C,
Nstride);
break;
case ACTIVATION_MISH:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_MISH, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_MISH>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C,
Nstride);
break;
case ACTIVATION_RELU:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_RELU, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_RELU>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C,
Nstride);
break;
case ACTIVATION_SWISH:
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_SWISH, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_SWISH>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C,
Nstride);
break;
case ACTIVATION_RELU_2: // square relu
CUDA_KERNEL_LAUNCH_LOG(addBiasBatched_kernel_RELU_2, gridDim, blockDim, 0, stream);
addBiasBatched_kernel<T, ACTIVATION_RELU_2>
<<<gridDim, blockDim, 0, stream>>>(output, input, bias, N, C,
Nstride);
@@ -336,6 +350,7 @@ void addBias_NCHW(T* c, T* a, T* b, int N, int C, int H, int W,
const int kBlockSize = 256;
int blocks = DivUp(size, kBlockSize);

CUDA_KERNEL_LAUNCH_LOG(addBias_NCHW_kernel, blocks, kBlockSize, 0, stream);
addBias_NCHW_kernel<<<blocks, kBlockSize, 0, stream>>>(c, a, b, N, C, H, W,
activation);
ReportCUDAErrors(cudaGetLastError());
@@ -387,6 +402,7 @@ void convertNCHWtoNHWC(DstType* output_tensor, const SrcType* input_tensor,
size_t numElements = Nout * Cout * H * W;
const int blockSize = 256;
int blocks = DivUp(numElements, blockSize);
CUDA_KERNEL_LAUNCH_LOG(NCHWtoNHWC_kernel, blocks, blockSize, 0, stream);
NCHWtoNHWC_kernel<<<blocks, blockSize, 0, stream>>>(
output_tensor, input_tensor, Nin, Cin, Nout, Cout, H, W);
}
@@ -405,6 +421,7 @@ template <typename DstType, typename SrcType>
void copyTypeConverted(DstType* op, SrcType* ip, int N, cudaStream_t stream) {
const int kBlockSize = 256;
int blocks = DivUp(N, kBlockSize);
CUDA_KERNEL_LAUNCH_LOG(copyTypeConverted_kernel, blocks, kBlockSize, 0, stream);
copyTypeConverted_kernel<<<blocks, kBlockSize, 0, stream>>>(op, ip, N);
}

@@ -444,6 +461,7 @@ void batchNorm(T* output, const T* input, const T* skipInput, int N, int C,
const int kBlockSize = 256;
int blocks = DivUp(total_elements, kBlockSize);

CUDA_KERNEL_LAUNCH_LOG(batchNorm_kernel, blocks, kBlockSize, 0, stream);
batchNorm_kernel<<<blocks, kBlockSize, 0, stream>>>(
output, input, skipInput, N, C, H, W, means, var_multipliers, activation);

@@ -483,6 +501,7 @@ void expandPlanes_Fp32_NCHW(float* output, const uint64_t* masks,
int threads = n * 8 * 8 / 2; // Each thread writes two elements.
const int blockSize = 256;
int blocks = DivUp(threads, blockSize);
CUDA_KERNEL_LAUNCH_LOG(expandPlanes_kernel_Fp32_NCHW, blocks, blockSize, 0, stream);
expandPlanes_kernel_Fp32_NCHW<<<blocks, blockSize, 0, stream>>>(output, masks,
values, n);
ReportCUDAErrors(cudaGetLastError());
@@ -517,6 +536,7 @@ void expandPlanes_Fp16_NHWC(half* output, const uint64_t* masks,
int threads = n * 8 * 8; // Each thread writes a single element.
const int kBlockSize = 256;
int blocks = DivUp(threads, kBlockSize);
CUDA_KERNEL_LAUNCH_LOG(expandPlanes_kernel_Fp16_NHWC, blocks, kBlockSize, 0, stream);
expandPlanes_kernel_Fp16_NHWC<<<blocks, kBlockSize, 0, stream>>>(
output, masks, values, n);
ReportCUDAErrors(cudaGetLastError());
@@ -558,6 +578,7 @@ void expandPlanes_Fp16_NCHW(half* output, const uint64_t* masks,
unsigned threads = n * 8 * 8 / 2; // each thread writes two elements.
const int blockSize = 256;
unsigned blocks = DivUp(threads, blockSize);
CUDA_KERNEL_LAUNCH_LOG(expandPlanes_kernel_Fp16_NCHW, blocks, blockSize, 0, stream);
expandPlanes_kernel_Fp16_NCHW<<<blocks, blockSize, 0, stream>>>(output, masks,
values, n);
ReportCUDAErrors(cudaGetLastError());
@@ -704,6 +725,7 @@ void globalAvgPool(int N, int C, T* output, const T* input,
if (nhwc) {
assert((std::is_same<half, T>::value));
// For NHWC fp16, simply launch N blocks, each with C threads.
CUDA_KERNEL_LAUNCH_LOG(globalAvgPool_kernel_NHWC_fp16, N, C, 0, stream);
globalAvgPool_kernel_NHWC_fp16<<<N, C, 0, stream>>>(
(half*)output, (half*)input, (half*)prevLayerBias, N * C * kPlaneSize,
N * C);
@@ -717,6 +739,7 @@
const int kBlockSize = kWarpsPerBlock * 32;

int blocks = DivUp(kTotalWarps, kWarpsPerBlock);
CUDA_KERNEL_LAUNCH_LOG(globalAvgPool_kernel, blocks, kBlockSize, 0, stream);
globalAvgPool_kernel<<<blocks, kBlockSize, 0, stream>>>(
output, input, prevLayerBias, N * C * kPlaneSize, N * C, C);
}
@@ -733,10 +756,12 @@ void globalScale(int N, int C, T* output, const T* input, const T* scaleBias,

if (nhwc) {
assert((std::is_same<half, T>::value));
CUDA_KERNEL_LAUNCH_LOG(globalScale_kernel_fp16_nhwc, kBlocks, kBlockSize, 0, stream);
globalScale_kernel_fp16_nhwc<<<kBlocks, kBlockSize, 0, stream>>>(
(half*)output, (half*)input, (half*)scaleBias, (half*)prevLayerBias,
N * C * 8 * 8, C, 8 * 8 * C, activation);
} else {
CUDA_KERNEL_LAUNCH_LOG(globalScale_kernel, kBlocks, kBlockSize, 0, stream);
globalScale_kernel<<<kBlocks, kBlockSize, 0, stream>>>(
output, input, scaleBias, prevLayerBias, N * C * 8 * 8, C, activation);
}
@@ -770,6 +795,7 @@ void PolicyMap(int N, T* output, const T* input, const short* indices,
const int kBlockSize = 256;
const int kBlocks = DivUp(N * usedSize, kBlockSize);

CUDA_KERNEL_LAUNCH_LOG(policyMap_kernel, kBlocks, kBlockSize, 0, stream);
policyMap_kernel<T><<<kBlocks, kBlockSize, 0, stream>>>(
(T*)output, (T*)input, (short*)indices, N, inputSize, usedSize,
outputSize);
@@ -785,6 +811,7 @@ void OutputInputTransform(int N, int C, int se_K, T* output, const T* input,
// Each thread processes entire chess board
if (use_se == false) {
dim3 grid_dim(DivUp(C, kOpInpTransformBlockSize), N, 1);
CUDA_KERNEL_LAUNCH_LOG(OutputTransform_relu_InputTransform_kernel, grid_dim, kOpInpTransformBlockSize, 0, stream);
OutputTransform_relu_InputTransform_kernel<float, activation, use_bias,
use_skip>
<<<grid_dim, kOpInpTransformBlockSize, 0, stream>>>(N, C, output, input,
@@ -794,6 +821,7 @@
"res block fusing opt not supported for the given data type and no "
"of filters\n");
} else {
CUDA_KERNEL_LAUNCH_LOG(OutputTransform_SE_relu_InputTransform_kernel, N, C, 0, stream);
OutputTransform_SE_relu_InputTransform_kernel<float, activation, use_bias,
use_skip>
<<<N, C, 0, stream>>>(N, C, se_K, output, input, (float*)skip, bias, w1,
@@ -934,9 +962,11 @@ void Softmax(int N, int C, T* output, const T* input, const T* input2,
int size = N * 32; // Total no of threads needed
const int kBlockSize = 256;
int blocks = DivUp(size, kBlockSize);
CUDA_KERNEL_LAUNCH_LOG(softmax_opt_64_kernel, blocks, kBlockSize, 0, stream);
softmax_opt_64_kernel<T>
<<<blocks, kBlockSize, 0, stream>>>(output, input, input2, size);
} else {
CUDA_KERNEL_LAUNCH_LOG(softmax_kernel, N, C, 0, stream);
softmax_kernel<T><<<N, C, 0, stream>>>(output, input, input2);
}

@@ -1143,6 +1173,7 @@ void LayerNorm(int N, int C, T* output, const T* input, const T* bias,
gridDim.y = 1;
gridDim.z = 1;

CUDA_KERNEL_LAUNCH_LOG(layer_norm_kernel, gridDim, blockDim, 0, stream);
layer_norm_kernel<T><<<gridDim, blockDim, 0, stream>>>(
N, C, output, input, bias, skip, gammas, betas, ep, alpha, act);

@@ -1239,6 +1270,7 @@ void ComputePromotionLogits(int N, int C, T* output, const T* keys,
// 8 * 24 threads
// Each thread computes a single output element
dim3 blockDim(24, 8, 1);
CUDA_KERNEL_LAUNCH_LOG(promotion_logits_kernel, N, blockDim, 0, stream);
promotion_logits_kernel<T>
<<<N, blockDim, 0, stream>>>(C, output, keys, ppo, policy_attn_logits);
}
@@ -1281,6 +1313,7 @@ void inputPreprocessForAttentionBody(T* output, const T* input,
// Each thread computes a single output element
dim3 gridSize = dim3(N, 64);
int blockSize = input_size + encoding_size;
CUDA_KERNEL_LAUNCH_LOG(preprocess_for_attention_body_kernel, gridSize, blockSize, 0, stream);
preprocess_for_attention_body_kernel<T><<<gridSize, blockSize, 0, stream>>>(
output, input, encoding, input_size, encoding_size,
is_pe_dense_embedding);
@@ -1317,6 +1350,7 @@ void applyInputGating(T* output, const T* input, const T* mult, const T* add,
gridSize.x = DivUp(C, blockSize.x);
gridSize.y = 1;
gridSize.z = N;
CUDA_KERNEL_LAUNCH_LOG(input_gating_kernel, gridSize, blockSize, 0, stream);
input_gating_kernel<T>
<<<gridSize, blockSize, 0, stream>>>(output, input, mult, add, HW, C);

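The pattern in this file is uniform: every kernel launch is now preceded by a `CUDA_KERNEL_LAUNCH_LOG(kernel, grid, block, shmem, stream)` call, where the grid/block arguments are sometimes plain `int` counts and sometimes `dim3` values. The macro itself lives in `cuda_wrapper.h`, which is not shown in this diff. Below is a minimal sketch, assuming the macro is purely observational and never changes the launch; the helper names and the output format are illustrative assumptions, not the project's actual implementation.

```cpp
// Sketch only: hypothetical logging helpers, not the real cuda_wrapper.h.
#include <cstdio>
#include <cuda_runtime.h>

// Call sites pass both int and dim3 launch configurations, so normalize
// everything to dim3 before printing.
inline dim3 ToDim3(int x) { return dim3(static_cast<unsigned>(x), 1, 1); }
inline dim3 ToDim3(unsigned x) { return dim3(x, 1, 1); }
inline dim3 ToDim3(dim3 d) { return d; }

// Print the kernel name, launch geometry, dynamic shared-memory size and
// stream to stderr. The launch itself is left untouched.
inline void LogKernelLaunch(const char* name, dim3 grid, dim3 block,
                            size_t shared_bytes, cudaStream_t stream) {
  fprintf(stderr,
          "[cuda launch] %s grid=(%u,%u,%u) block=(%u,%u,%u) shmem=%zu "
          "stream=%p\n",
          name, grid.x, grid.y, grid.z, block.x, block.y, block.z,
          shared_bytes, static_cast<void*>(stream));
}

// Stringify the kernel identifier and forward the launch parameters.
#define CUDA_KERNEL_LAUNCH_LOG(kernel, grid, block, shmem, stream) \
  LogKernelLaunch(#kernel, ToDim3(grid), ToDim3(block), (shmem), (stream))
```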
3 changes: 3 additions & 0 deletions src/neural/backends/cuda/cuda_common.h
@@ -38,6 +38,9 @@
typedef void* cudnnHandle_t;
#endif

// Include CUDA wrapper functions for debug logging
#include "cuda_wrapper.h"

#if CUBLAS_VER_MAJOR < 11
#define CUBLAS_PEDANTIC_MATH CUBLAS_DEFAULT_MATH
#endif
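Since this header only gains the include of `cuda_wrapper.h` and the wrapper itself is not part of the diff, one design note: launch logging like this is usually compiled out unless a debug switch is set, so release builds pay no overhead. A hypothetical gate is sketched below; the define name is an assumption, not anything taken from the actual header.

```cpp
// Hypothetical gating, not taken from the real cuda_wrapper.h: the macro
// expands to a no-op unless a debug define is supplied at build time.
#ifdef LC0_LOG_KERNEL_LAUNCHES  // define name is an assumption
#define CUDA_KERNEL_LAUNCH_LOG(kernel, grid, block, shmem, stream) \
  LogKernelLaunch(#kernel, ToDim3(grid), ToDim3(block), (shmem), (stream))
#else
#define CUDA_KERNEL_LAUNCH_LOG(kernel, grid, block, shmem, stream) \
  do {                                                             \
  } while (0)
#endif
```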