70 changes: 67 additions & 3 deletions csrc/attention/attention_kernels.cu
@@ -44,10 +44,28 @@ typedef __hip_bfloat16 __nv_bfloat16;
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

#define DATTN_UNIFIED_QK_MAX 1
#define WARNING(msg) printf("\033[33mWARNING: %s\033[0m\n", msg)
#define DATTENTION_QK_MAX 1.73f // Calibrated on Llama-2-7B short sequences (99.99th-percentile mean)
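// The unified-qk_max scheme: threads assume this pre-calibrated constant as the
// softmax max instead of reducing qk across the block; if a score lands outside
// the range where expf(qk - DATTENTION_QK_MAX) is safe in float32, a `recompute`
// flag rolls the block back to the exact reduction.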


using namespace std;


#if !defined(likely)
#define likely(x) __builtin_expect(!!(x), 1)
#endif
#if !defined(unlikely)
#define unlikely(x) __builtin_expect(!!(x), 0)
#endif

namespace vllm {

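// True when `val` lies outside the largest finite float16 magnitude (+/-65504),
// i.e. converting it to half precision would overflow to infinity.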
inline __device__ bool is_half_inf(float val) {
return (val <= -65504.0f || val >= 65504.0f);
}

// Utility function for attention softmax.
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
@@ -158,6 +176,9 @@ __device__ void paged_attention_kernel(
return;
}

// Block-wide rollback flag: any thread that detects a potential overflow sets it.
__shared__ bool recompute;
if (threadIdx.x == 0) recompute = false;
__syncthreads();

const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
@@ -347,12 +368,26 @@ __device__ void paged_attention_kernel(
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);

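// expf(x) exceeds FLT_MAX once x > logf(FLT_MAX) ~= 88.7228 and falls below the
// smallest normal float once x < logf(FLT_MIN) ~= -87.3365; the thresholds below
// mirror those bounds.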
if (unlikely((qk - DATTENTION_QK_MAX) >= 88.72283f ||
             (qk - DATTENTION_QK_MAX) <= -87.33654f)) {
/* Rollback */
recompute = true;
// WARNING("qk_max causes float32 overflow!!");
printf("\033[33mWARNING: Recompute. qk_max overflow!! qk = %.6f\033[0m\n", qk);
}
}
}
}

/* Synchronize so every thread sees the shared recompute flag before branching. */
__syncthreads();
if (likely(!recompute)) {
/* Fast path: use the unified qk_max value. */
qk_max = DATTENTION_QK_MAX;
} else {
/* Rollback: fall back to the exact block-wide max reduction. */
qk_max = propogate_qk_max<NUM_WARPS, THREAD_GROUP_SIZE>(&red_smem[0], qk_max);
}
// Get the sum of the exp values.
float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
@@ -820,6 +855,7 @@ __global__ void dattention_kernel(
// Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX;
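// Per-thread rollback flag (see the TODO below about promoting it to __shared__).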
bool recompute = false;

// blocksparse specific vars
int bs_block_offset;
@@ -920,15 +956,43 @@ __global__ void dattention_kernel(
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len;

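// DATTN_UNIFIED_QK_MAX selects between the unified-constant qk_max path and the
// vanilla running-max path below.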
#if DATTN_UNIFIED_QK_MAX // new
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
/***** debug ******/
if (is_half_inf(qk - DATTENTION_QK_MAX)) {
WARNING("qk_max causes float16 overflow!!");
}
if (likely(!isinf(qk - DATTENTION_QK_MAX))) { /* Set unified qk_max value */
qk_max = DATTENTION_QK_MAX;
} else { /* Rollback */
recompute = true;
WARNING("qk_max causes float32 overflow!!");
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
#else // vanilla
logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
#endif
}
}
}

// Perform reduction across all threads in the same thread block
qk_max = propogate_qk_max<NUM_WARPS, THREAD_GROUP_SIZE>(&red_smem[0], qk_max);
__syncthreads(); // Needed in practice; the exact ordering requirement is still being verified.
// Multiple thread blocks/warps here
// TODO - use __shared__ for recompute to reduce sync overhead
/* Rollback handling for threads that detected an overflow. */
if (recompute) { /* Someone overflowed */
// Perform reduction across all threads in the same thread block
qk_max = propogate_qk_max<NUM_WARPS, THREAD_GROUP_SIZE>(&red_smem[0], qk_max);
//if(threadIdx.x == 0) {
// printf("[%d, %d, %d]: scale %f qk_max %f. layer_offset %ld, kv_head_stride %d - %d. q_stride %ld\n", blockIdx.x, blockIdx.y, threadIdx.x, scale, qk_max, layer_offset, KV_HEAD_STRIDE, kv_head_stride, q_stride);
//}
} else {
// No overflow: the unified qk_max was already assigned above.
}

// Get the sum of the exp values.
float exp_sum = 0.f;