diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu
index 734ee93e043b..5095010d202d 100644
--- a/csrc/attention/attention_kernels.cu
+++ b/csrc/attention/attention_kernels.cu
@@ -44,10 +44,28 @@ typedef __hip_bfloat16 __nv_bfloat16;
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+#define DATTN_UNIFIED_QK_MAX 1
+#define WARNING(msg) printf("\033[33mWARNING: %s\033[0m\n", msg)
+#define DATTENTION_QK_MAX 1.73f  // llama2-7B short 99.99mean
+
+
 using namespace std;
+
+#if !defined(likely)
+#define likely(x) __builtin_expect(!!(x), 1)
+#endif
+#if !defined(unlikely)
+#define unlikely(x) __builtin_expect(!!(x), 0)
+#endif
+
 
 namespace vllm {
 
+inline __device__ bool is_half_inf(float val) {
+  return (val <= -65504.0f || val >= 65504.0f);
+}
+
 // Utility function for attention softmax.
 template <int NUM_WARPS>
 inline __device__ float block_sum(float* red_smem, float sum) {
@@ -158,6 +176,9 @@ __device__ void paged_attention_kernel(
     return;
   }
 
+  __shared__ bool recompute;
+  recompute = false;
+
   const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
   const int num_blocks_per_partition =
       USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
@@ -347,12 +368,26 @@ __device__ void paged_attention_kernel(
         logits[token_idx - start_token_idx] = mask ? 0.f : qk;
         // Update the max value.
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+
+        if (unlikely((qk - DATTENTION_QK_MAX) >= 88.72283f ||
+                     (qk - DATTENTION_QK_MAX) <= -87.33654f)) {
+          /* Rollback */
+          recompute = true;
+          // WARNING("qk_max causes float32 overflow!!");
+          printf("\033[33mWARNING: Recompute. qk_max overflow!! qk = %.6f\033[0m\n", qk);
+        }
       }
     }
   }
 
+
+  /* For synchronizing all recompute decisions */
+  if (likely(recompute == false)) {
+    /* Set the unified qk_max value */
+    qk_max = DATTENTION_QK_MAX;
+  } else {
   qk_max = propogate_qk_max(&red_smem[0], qk_max);
-
+  }
   // Get the sum of the exp values.
   float exp_sum = 0.f;
   for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
@@ -820,6 +855,7 @@ __global__ void dattention_kernel(
   // Each thread group fetches x elements from the key at a time.
   constexpr int x = 16 / sizeof(cache_t);
   float qk_max = -FLT_MAX;
+  bool recompute = false;
 
   // blocksparse specific vars
   int bs_block_offset;
@@ -920,15 +956,43 @@ __global__ void dattention_kernel(
         // Store the partial reductions to shared memory.
         // NOTE(woosuk): It is required to zero out the masked logits.
         const bool mask = token_idx >= seq_len;
+
+  #if DATTN_UNIFIED_QK_MAX // new
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        /***** debug *****/
+        if (is_half_inf(qk - DATTENTION_QK_MAX)) {
+          WARNING("qk_max causes float16 overflow!!");
+        }
+        if (likely(!isinf(qk - DATTENTION_QK_MAX))) {  /* Set the unified qk_max value */
+          qk_max = DATTENTION_QK_MAX;
+        } else {  /* Rollback */
+          recompute = true;
+          WARNING("qk_max causes float32 overflow!!");
+          // Update the max value.
+          qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+        }
+  #else // vanilla
         logits[token_idx - start_token_idx] = mask ? 0.f : qk;
         // Update the max value.
         qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+  #endif
       }
     }
   }
 
-  // Perform reduction across all threads in the same thread block
-  qk_max = propogate_qk_max(&red_smem[0], qk_max);
+  __syncthreads();  // this works, though it is not yet clear why
+  // Multiple thread blocks/warps here
+  // TODO - use __shared__ for recompute to reduce sync overhead
+  /* For synchronizing all recompute decisions */
+  if (recompute == true) {  /* Someone overflowed */
+    // Perform reduction across all threads in the same thread block
+    qk_max = propogate_qk_max(&red_smem[0], qk_max);
+    // if (threadIdx.x == 0) {
+    //   printf("[%d, %d, %d]: scale %f qk_max %f. layer_offset %ld, kv_head_stride %d - %d. q_stride %ld\n", blockIdx.x, blockIdx.y, threadIdx.x, scale, qk_max, layer_offset, KV_HEAD_STRIDE, kv_head_stride, q_stride);
+    // }
+  } else {
+    // qk_max was already set to the unified DATTENTION_QK_MAX in the loop above
+  }
 
   // Get the sum of the exp values.
   float exp_sum = 0.f;
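
Note on the magic numbers in the patch (this note and the sketch below are not part of the diff): 88.72283f and -87.33654f are, to within rounding, ln(FLT_MAX) and ln(FLT_MIN) for float32, so the rollback fires exactly when exp(qk - DATTENTION_QK_MAX) would leave the normal float32 range. The following host-side C++ sketch only reproduces that bound check so it can be sanity-checked without a GPU; kUnifiedQkMax and needs_recompute are illustrative names introduced here, and only the DATTENTION_QK_MAX value of 1.73f is taken from the diff.

#include <cfloat>
#include <cmath>
#include <cstdio>

// Same value as DATTENTION_QK_MAX in the diff (assumed profile for llama2-7B).
static const float kUnifiedQkMax = 1.73f;
// expf() stays within the normal float32 range only for arguments in
// (ln(FLT_MIN), ln(FLT_MAX)) ~= (-87.33654, 88.72283).
static const float kExpHi = logf(FLT_MAX);
static const float kExpLo = logf(FLT_MIN);

// Mirrors the rollback condition in the patched kernels: if exp(qk - qk_max)
// would overflow or drop below the smallest normal float32, the block falls
// back to the exact per-block max reduction (propogate_qk_max in the diff).
static bool needs_recompute(float qk) {
  const float d = qk - kUnifiedQkMax;
  return d >= kExpHi || d <= kExpLo;
}

int main() {
  const float samples[] = {0.0f, 1.73f, 50.0f, 91.0f, -90.0f};
  for (float qk : samples) {
    const float e = expf(qk - kUnifiedQkMax);
    std::printf("qk = %8.3f  exp(qk - qk_max) = %-13g  recompute = %d\n",
                qk, static_cast<double>(e), needs_recompute(qk) ? 1 : 0);
  }
  return 0;
}

Built with an ordinary C++ compiler (e.g. g++ -std=c++11), this prints recompute = 1 only for the 91.0f and -90.0f inputs, mirroring the kernel's fast path for typical logits and its fallback for outliers.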