Commit

add const_pa and fix const_norm
Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Jan 23, 2025
1 parent 5100b9e commit 32bf527
Showing 1 changed file with 42 additions and 5 deletions.
47 changes: 42 additions & 5 deletions vllm/attention/backends/hpu_attn.py
@@ -20,6 +20,7 @@
HPUPagedAttentionMetadata)
from vllm.logger import init_logger
import habana_frameworks.torch.core as htcore
import math

logger = init_logger(__name__)

@@ -57,24 +58,58 @@ def prompt_fsdpa(
    return attn_weights

const_norm = os.environ.get('VLLM_SOFTMAX_CONST_NORM', 'false').lower() == 'true'
const_pa = os.environ.get('VLLM_SOFTMAX_CONST_PA', 'false').lower() == 'true'
const_val = float(os.environ.get('VLLM_SOFTMAX_CONST_VAL', '10.0'))
eps_value = float(os.environ.get('VLLM_SOFTMAX_EPS_VALUE', str(torch.finfo(torch.bfloat16).tiny)))
def pa(attn, value, block_groups, block_mapping, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op):

def wsum_head_amax(attn, block_mapping, block_scales, **rest):
"""Perform weighted sum fused with head maximum normalization"""
attn_max = attn.amax(-1)
missing_dims = attn_max.dim() - block_scales.dim()
block_sum_attn = attn_max.mul(block_scales.reshape(-1, *[1 for _ in range(missing_dims)]))
block_sum_attn = ops.block2batch(block_sum_attn, block_mapping)
block_sum_attn = ops.batch2block(block_sum_attn, block_mapping)
attn.sub_(block_sum_attn.unsqueeze(-1))
attn_max.sub_(block_sum_attn)
attn_max = attn_max.amax(0, keepdim=True)
return attn_max.unsqueeze(-1)
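
# Standalone sketch (not part of this commit): what the wsum_head_amax shift
# computes, with block2batch/batch2block modelled as a one-hot matmul and its
# transpose. That mapping semantics is an assumption for illustration only;
# the real ops.* kernels on HPU may differ.
import torch

num_blocks, kv_len = 4, 6
scores = torch.randn(num_blocks, kv_len) * 5.0
# blocks 0,1 belong to sequence 0; blocks 2,3 belong to sequence 1
block_mapping = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
block_scales = torch.full((num_blocks,), 0.5)               # 1 / blocks per sequence

block_max = scores.amax(-1)                                 # per-block score maximum
per_seq = block_mapping.t() @ (block_max * block_scales)    # block2batch: weighted sum per sequence
shift = block_mapping @ per_seq                             # batch2block: broadcast back to blocks
stable_scores = scores - shift.unsqueeze(-1)                # same shift for every block of a sequence
print(stable_scores.amax(-1))                               # residual maxima after the shift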

def pa(attn, value, batch_size, block_groups, block_mapping, block_scales, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op):
    # normalization
    attn.sub_(const_val)
    #attn_max = wsum_head_amax(attn, block_mapping, block_scales)
    #print("attn_max(mean, max, min) is ", torch.mean(attn_max).item(), torch.max(attn_max).item(), torch.min(attn_max).item())
    #attn.sub_(attn_max)
    # end of norm
    attn = attn.exp()
    sums = attn.sum(dim=-1).unsqueeze(-1)
    block_sum = sums
    # Sum block sums that belong to the same sequence
    group_sums = ops.block2batch(sums, block_mapping, block2batch_matmul_op)
    group_sums = ops.batch2block(group_sums, block_mapping, batch2block_matmul_op)
    group_sums = ops.block2batch(sums, block_mapping)
    group_sums = ops.batch2block(group_sums, block_mapping)
    group_sums.add_(eps_value)
    group_sums = torch.maximum(block_sum, group_sums)
    attn.div_(group_sums)
    attn = matmul_av_op(attn, value)
    return attn
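
# Standalone sketch (not part of this commit): why subtracting a constant
# (VLLM_SOFTMAX_CONST_VAL) before exp() still yields a correct softmax. The
# exp(-c) factor cancels when dividing by the row sum, so any shift c gives the
# same result as the usual max-shift, as long as exp() does not overflow the
# working dtype. Shapes and values below are illustrative only.
import torch

scores = torch.randn(4, 8) * 5.0
const_val = 10.0

reference = torch.softmax(scores, dim=-1)                    # max-shifted internally

shifted = (scores - const_val).exp()
const_softmax = shifted / shifted.sum(dim=-1, keepdim=True)

print(torch.allclose(reference, const_softmax, atol=1e-6))   # True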

def pipelined_const_pa(attn, value, block_groups, block_mapping, block_scales,
                       matmul_av_op, batch2block_matmul_op, block2batch_matmul_op):
    # Normalize the attention scores
    attn.sub_(const_val)
    attn = attn.exp()
    # Sum block sums that belong to the same sequence
    sums = attn.sum(dim=-1).unsqueeze(-1)
    block_sums = sums
    group_sums = ops.block2batch(sums, block_mapping)
    group_sums = ops.batch2block(group_sums, block_mapping)
    # For stability in case some of the sums have been zeroed out during block aggregation
    group_sums.add_(eps_value)
    group_sums = torch.maximum(block_sums, group_sums)
    attn = matmul_av_op(attn, value)
    attn.div_(group_sums)
    return attn
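
# Standalone sketch (not part of this commit): the difference from pa() above is
# that the division by the denominator is deferred until after the A*V matmul.
# Row-wise scaling commutes with a matmul over the key dimension, so both
# orderings give the same output while the deferred version divides the smaller
# (head_dim-wide) tensor. Shapes are illustrative only.
import torch

attn = torch.rand(2, 3, 5)                         # (batch, queries, keys), already exp()'d
value = torch.rand(2, 5, 4)                        # (batch, keys, head_dim)
sums = attn.sum(dim=-1, keepdim=True)

out_div_first = torch.matmul(attn / sums, value)   # normalize, then A*V (as in pa)
out_div_last = torch.matmul(attn, value) / sums    # A*V, then normalize (as above)

print(torch.allclose(out_div_first, out_div_last, atol=1e-6))   # True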

def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
            block_bias, block_scales, block_groups, scale, matmul_qk_op,
            matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,
@@ -101,8 +136,10 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping,
    attn = attn.float()
    htcore.mark_step()
    attn = attn + block_bias
    if const_norm:
        attn = pa(attn, value, block_groups, block_mapping, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,)
    if const_pa:
        attn = pipelined_const_pa(attn, value, block_groups, block_mapping, block_scales, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op)
    elif const_norm:
        attn = pa(attn, value, batch_size, block_groups, block_mapping, block_scales, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op,)
    else:
        attn = ops.pipelined_pa(attn, value, block_groups, block_mapping, block_scales=block_scales,
                                batch_size=batch_size, matmul_av_op=matmul_av_op,
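
For completeness, a usage sketch (an assumption about deployment, not taken from the commit): the new paths are selected purely through environment variables that the module reads at import time, so they can be toggled from a launch script before vLLM's HPU backend is loaded. The values shown are placeholders.

import os

# Select the pipelined constant-normalization path; in flat_pa it takes
# precedence over VLLM_SOFTMAX_CONST_NORM.
os.environ["VLLM_SOFTMAX_CONST_PA"] = "true"
os.environ["VLLM_SOFTMAX_CONST_VAL"] = "10.0"     # constant subtracted before exp()
os.environ["VLLM_SOFTMAX_EPS_VALUE"] = "1e-38"    # guard added to the aggregated sums

# ... then import/launch vLLM as usual; flat_pa picks the branch accordingly.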
