#include <base/cuda_config.h>
#include <tensor/tensor.h>
+ #include <cfloat>
#include <cub/cub.cuh>
#include "mha_kernel.cuh"
+ #include <base/tick.h>
namespace kernel {
constexpr static int thread_num = 256;
__device__ void softmax_gpu(float* __restrict__ x, int size) {
@@ -44,6 +46,7 @@ __device__ void softmax_gpu(float* __restrict__ x, int size) {
  }
}

+
__global__ void multi_head_attention_kernel(int32_t pos, int32_t seq_len, float* query,
                                            float* score_ptr, float* output, float* key_cache,
                                            float* value_cache, int32_t kv_dim, int32_t kv_mul,
@@ -54,38 +57,32 @@ __global__ void multi_head_attention_kernel(int32_t pos, int32_t seq_len, float*
    return;
  }

-   float scale = 1.f / sqrtf(head_size);
+   extern __shared__ float s_query_head[];
+   float scale = 1.f / sqrtf(float(head_size));
  float* query_head = query + head * head_size;
+
+   // Preload this head's query vector into shared memory
+   for (int i = threadIdx.x; i < head_size; i += blockDim.x) {
+     s_query_head[i] = query_head[i];
+   }
+   __syncthreads();
+
  float* score_head = score_ptr + head * seq_len;
+   // head is the index of the current attention head; kv_mul is the GQA group size; head_size is the dimension of one attention head
+   // kv_dim = head_size * head_num: key/value dimension for standard multi-head attention
+   // kv_dim = head_size * head_num / kv_num: key/value dimension in the GQA case
  int head_offset = (head / kv_mul) * head_size;
+   // Compute the attention scores
  for (int t = threadIdx.x; t <= pos; t += blockDim.x) {
    float* key_head = key_cache + layer_offset + t * kv_dim + head_offset;
-     /**
-      * In Meta's Llama attention implementation, head_dim equals head_size.
-      *
-      * xq = xq.transpose(1, 2)      # shape after the transpose: (heads, sequence_length, head_dim)
-      *                              # if sequence_length is 1, this reduces to (heads, head_dim)
-      * keys = keys.transpose(1, 2)  # keys are transposed the same way: (heads, sequence_length, head_dim)
-      *                              # again reducing to (heads, head_dim) when sequence_length is 1
-      *
-      * In our implementation the offset is computed as (head / kv_mul) * head_size.
-      * In plain multi-head attention (MHA) kv_mul equals 1,
-      * so the resulting head_offset is simply head * head_size.
-      *
-      * head_offset locates the head currently being processed, while t * kv_dim (i.e. t * dim)
-      * locates the historical key vector.
-      */
-
-     // query @ key is computed per head, as the code above shows
+
    float score = 0.0f;
-     #pragma unroll
    for (int i = 0; i < head_size; i += 4) {
-       float4 key_head_float4 = *reinterpret_cast<float4*>(key_head + i);
-       float4 query_head_float4 = *reinterpret_cast<float4*>(query_head + i);
-       score += key_head_float4.x * query_head_float4.x;
-       score += key_head_float4.y * query_head_float4.y;
-       score += key_head_float4.z * query_head_float4.z;
-       score += key_head_float4.w * query_head_float4.w;
+       float4 key_val = *reinterpret_cast<float4*>(key_head + i);
+       float4 query_val = *reinterpret_cast<float4*>(s_query_head + i);
+
+       score += key_val.x * query_val.x + key_val.y * query_val.y + key_val.z * query_val.z +
+                key_val.w * query_val.w;
    }

    score *= scale;
@@ -97,9 +94,9 @@ __global__ void multi_head_attention_kernel(int32_t pos, int32_t seq_len, float*
  __syncthreads();

  float* output_head = output + head * head_size;
+   // Weight the value vectors by the attention scores
  for (int i = threadIdx.x; i < head_size; i += blockDim.x) {
    float value = 0.0f;
-     #pragma unroll
    for (int t = 0; t <= pos; t++) {
      float* value_head = value_cache + layer_offset + t * kv_dim + head_offset;
      float score = score_head[t];
@@ -124,7 +121,7 @@ void mha_kernel_cu(int32_t pos, int32_t head_num, int32_t layer_index, int32_t s
  float* value_cache = const_cast<float*>(value_cache_tensor.ptr<float>());

  cudaStream_t stream = config->stream;
-   multi_head_attention_kernel<<<head_num, thread_num, 0, stream>>>(
+   multi_head_attention_kernel<<<head_num, thread_num, head_size * sizeof(float), stream>>>(
      pos, seq_len, query, score, output, key_cache, value_cache, kv_dim, kv_mul, head_num,
      head_size, layer_offset);
}
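
The launch-site change above passes head_size * sizeof(float) as the third launch parameter, i.e. the number of bytes of dynamic shared memory backing the extern __shared__ s_query_head buffer declared inside the kernel. Below is a minimal, self-contained sketch of that pattern; the names stage_query_kernel and s_query are illustrative assumptions, not part of this repository.

#include <cuda_runtime.h>

// Minimal sketch (assumed names) of the dynamic shared-memory pattern this
// commit introduces: the kernel declares an unsized extern __shared__ array,
// and the launch supplies its size in bytes as the third <<<...>>> argument.
__global__ void stage_query_kernel(const float* query, float* out, int head_size) {
  extern __shared__ float s_query[];  // sized at launch time

  // Cooperatively copy this block's head into shared memory, then sync.
  const float* query_head = query + blockIdx.x * head_size;
  for (int i = threadIdx.x; i < head_size; i += blockDim.x) {
    s_query[i] = query_head[i];
  }
  __syncthreads();

  // Every thread can now reuse the staged values without rereading global memory.
  for (int i = threadIdx.x; i < head_size; i += blockDim.x) {
    out[blockIdx.x * head_size + i] = s_query[i] * 2.0f;
  }
}

int main() {
  const int head_num = 4, head_size = 128, threads = 256;
  float *query = nullptr, *out = nullptr;
  cudaMalloc(&query, head_num * head_size * sizeof(float));
  cudaMalloc(&out, head_num * head_size * sizeof(float));
  cudaMemset(query, 0, head_num * head_size * sizeof(float));

  // Third launch parameter = bytes of dynamic shared memory per block.
  stage_query_kernel<<<head_num, threads, head_size * sizeof(float)>>>(query, out, head_size);
  cudaDeviceSynchronize();

  cudaFree(query);
  cudaFree(out);
  return 0;
}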
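
For reference, here is a small host-side walkthrough of the (head / kv_mul) * head_size offset arithmetic used in the kernel. The concrete dimensions (head_num = 32, kv_head_num = 8, head_size = 128) are assumed for illustration only and do not come from this commit.

#include <cstdio>

// Worked example (hypothetical dimensions) of the (head / kv_mul) * head_size
// offset used in multi_head_attention_kernel.
int main() {
  const int head_num = 32;     // number of query heads (assumed)
  const int kv_head_num = 8;   // number of key/value heads under GQA (assumed)
  const int head_size = 128;   // dimension of a single attention head (assumed)

  const int kv_mul = head_num / kv_head_num;   // query heads per kv head -> 4
  const int kv_dim = kv_head_num * head_size;  // key/value width per cached token -> 1024

  // kv_mul consecutive query heads share one kv head, so they read the same
  // offset inside every cached token of key_cache / value_cache.
  for (int head = 0; head < head_num; head += kv_mul) {
    const int head_offset = (head / kv_mul) * head_size;
    printf("query heads %2d..%2d -> kv head %d, head_offset = %4d floats (kv_dim = %d)\n",
           head, head + kv_mul - 1, head / kv_mul, head_offset, kv_dim);
  }
  // With kv_mul == 1 (plain MHA) the offset reduces to head * head_size,
  // which is what the comment removed in this commit described.
  return 0;
}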