
Commit 4fef2fb

committed Mar 15, 2025
Revise the comments in the attention module for better readability
1 parent dc5b2e6 commit 4fef2fb

File tree: 1 file changed (+24 −27 lines)
kuiper/source/op/kernels/cuda/mha_kernel.cu

+24 −27
@@ -1,7 +1,9 @@
 #include <base/cuda_config.h>
 #include <tensor/tensor.h>
+#include <cfloat>
 #include <cub/cub.cuh>
 #include "mha_kernel.cuh"
+#include <base/tick.h>
 namespace kernel {
 constexpr static int thread_num = 256;
 __device__ void softmax_gpu(float* __restrict__ x, int size) {
@@ -44,6 +46,7 @@ __device__ void softmax_gpu(float* __restrict__ x, int size) {
   }
 }
 
+
 __global__ void multi_head_attention_kernel(int32_t pos, int32_t seq_len, float* query,
                                             float* score_ptr, float* output, float* key_cache,
                                             float* value_cache, int32_t kv_dim, int32_t kv_mul,
@@ -54,38 +57,32 @@ __global__ void multi_head_attention_kernel(int32_t pos, int32_t seq_len, float*
     return;
   }
 
-  float scale = 1.f / sqrtf(head_size);
+  extern __shared__ float s_query_head[];
+  float scale = 1.f / sqrtf(float(head_size));
   float* query_head = query + head * head_size;
+
+  // Preload the query for this head into shared memory
+  for (int i = threadIdx.x; i < head_size; i += blockDim.x) {
+    s_query_head[i] = query_head[i];
+  }
+  __syncthreads();
+
   float* score_head = score_ptr + head * seq_len;
+  // head is the index of the current attention head; kv_mul is used for GQA; head_size is the dimension of one attention head
+  // kv_dim = head_size * head_num: key/value dimension in the multi-head attention case
+  // kv_dim = head_size * head_num / kv_num: key/value dimension in the GQA case
   int head_offset = (head / kv_mul) * head_size;
+  // Compute the self-attention scores
   for (int t = threadIdx.x; t <= pos; t += blockDim.x) {
     float* key_head = key_cache + layer_offset + t * kv_dim + head_offset;
-    /**
-     * In Meta's Llama attention implementation, head_dim equals head_size.
-     *
-     * xq = xq.transpose(1, 2)      # shape after the transpose is (heads, sequence_length, head_dim)
-     *                              # if sequence_length is 1, it reduces to (heads, head_dim)
-     * keys = keys.transpose(1, 2)  # keys are transposed the same way: (heads, sequence_length, head_dim)
-     *                              # again reducing to (heads, head_dim) when sequence_length is 1
-     *
-     * In our implementation the offset is computed as (head / kv_mul) * head_size.
-     * In plain multi-head attention (MHA) kv_mul is 1,
-     * so head_offset is simply head * head_size.
-     *
-     * head_offset locates the head currently being processed, while t * kv_dim (i.e. t *
-     * dim) locates the historical key vector.
-     */
-
-    // query @ key, multiplied head by head, as the code above shows
+
     float score = 0.0f;
-#pragma unroll
     for (int i = 0; i < head_size; i += 4) {
-      float4 key_head_float4 = *reinterpret_cast<float4*>(key_head + i);
-      float4 query_head_float4 = *reinterpret_cast<float4*>(query_head + i);
-      score += key_head_float4.x * query_head_float4.x;
-      score += key_head_float4.y * query_head_float4.y;
-      score += key_head_float4.z * query_head_float4.z;
-      score += key_head_float4.w * query_head_float4.w;
+      float4 key_val = *reinterpret_cast<float4*>(key_head + i);
+      float4 query_val = *reinterpret_cast<float4*>(s_query_head + i);
+
+      score += key_val.x * query_val.x + key_val.y * query_val.y + key_val.z * query_val.z +
+               key_val.w * query_val.w;
     }
 
     score *= scale;
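
A note on the offset arithmetic the new comments describe: each query head reads the key/value head it maps to via head / kv_mul, so head_offset = (head / kv_mul) * head_size. Below is a minimal host-side sketch of that mapping, using hypothetical sizes (head_num = 32, kv_head_num = 8, head_size = 128, none of which come from this commit):

#include <cstdio>

int main() {
  // Hypothetical GQA configuration, for illustration only.
  const int head_num = 32;     // number of query heads
  const int kv_head_num = 8;   // number of key/value heads
  const int head_size = 128;   // dimension of one attention head

  const int kv_mul = head_num / kv_head_num;   // 4 query heads share one kv head
  const int kv_dim = head_size * kv_head_num;  // 1024, i.e. head_size * head_num / kv_num

  for (int head = 0; head < head_num; ++head) {
    // Same formula as the kernel: the kv head and byte offset used by this query head.
    const int head_offset = (head / kv_mul) * head_size;
    printf("query head %2d -> kv head %d, head_offset %4d (kv_dim %d)\n",
           head, head / kv_mul, head_offset, kv_dim);
  }
  return 0;
}

With kv_mul = 1 (plain MHA) the offset collapses to head * head_size, which is the case the removed comment block used to spell out.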
@@ -97,9 +94,9 @@ __global__ void multi_head_attention_kernel(int32_t pos, int32_t seq_len, float*
   __syncthreads();
 
   float* output_head = output + head * head_size;
+  // Weight the value matrix with the self-attention scores
   for (int i = threadIdx.x; i < head_size; i += blockDim.x) {
     float value = 0.0f;
-#pragma unroll
     for (int t = 0; t <= pos; t++) {
       float* value_head = value_cache + layer_offset + t * kv_dim + head_offset;
       float score = score_head[t];
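
For reference, the weighting step the new comment describes computes, for each element i of this head's output, a sum over timesteps t of score[t] * value[t][i]. A plain CPU sketch of that reduction, assuming each timestep's value row for this head is head_size contiguous floats (mirroring value_cache + layer_offset + t * kv_dim + head_offset):

// CPU reference for the score-weighted sum over cached values (one head).
// `scores` holds pos + 1 attention weights; `values` points at pos + 1 rows
// of head_size floats, one row per timestep (layout assumed for this sketch).
void weighted_value_sum_ref(const float* scores, const float* values,
                            int pos, int head_size, float* out) {
  for (int i = 0; i < head_size; ++i) {
    float acc = 0.0f;
    for (int t = 0; t <= pos; ++t) {
      acc += scores[t] * values[t * head_size + i];
    }
    out[i] = acc;
  }
}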
@@ -124,7 +121,7 @@ void mha_kernel_cu(int32_t pos, int32_t head_num, int32_t layer_index, int32_t s
   float* value_cache = const_cast<float*>(value_cache_tensor.ptr<float>());
 
   cudaStream_t stream = config->stream;
-  multi_head_attention_kernel<<<head_num, thread_num, 0, stream>>>(
+  multi_head_attention_kernel<<<head_num, thread_num, head_size * sizeof(float), stream>>>(
       pos, seq_len, query, score, output, key_cache, value_cache, kv_dim, kv_mul, head_num,
       head_size, layer_offset);
 }
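
The launch change is the counterpart of the extern __shared__ float s_query_head[] declaration added above: an extern shared array is unsized at compile time, so its byte count must be supplied as the third <<<>>> argument, here head_size * sizeof(float). A standalone sketch of the same pattern (hypothetical kernel and names, independent of kuiper):

#include <cuda_runtime.h>

// Dynamic shared memory: the array is declared extern and unsized, and its
// size in bytes is supplied at launch as the third <<<>>> parameter.
__global__ void stage_through_smem(const float* src, float* dst, int n) {
  extern __shared__ float s_buf[];  // backed by the launch-time byte count
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    s_buf[i] = src[i];              // stage the vector in shared memory
  }
  __syncthreads();
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    dst[i] = s_buf[i];              // every thread now reads from shared memory
  }
}

// Example launch: 1 block, 256 threads, n floats of dynamic shared memory.
// stage_through_smem<<<1, 256, n * sizeof(float), stream>>>(src, dst, n);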

0 commit comments