fix for FP8 (HabanaAI#639)
Signed-off-by: Chendi Xue <[email protected]>
xuechendi authored Dec 17, 2024
1 parent a23e1a1 commit e2ea481
Showing 2 changed files with 7 additions and 4 deletions.
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/linear.py
@@ -1011,6 +1011,7 @@ def __init__(self,
         self.split_threshold = split_threshold
         self.split_size = split_size
         self.prefix = prefix
+        self.skip_seq_split = False

         # Divide the weight matrix along the last dimension.
         self.tp_rank = get_tensor_model_parallel_rank()
@@ -1099,7 +1100,7 @@ def resolve_input(self, input_):
             input_parallel = splitted_input[tp_rank].contiguous()
         return input_parallel

-    def forward(self, input_, skip_seq_split=False):
+    def forward(self, input_):
         input_parallel = self.resolve_input(input_)

         # Matrix multiply.
@@ -1122,7 +1123,7 @@ def forward(self, input_, skip_seq_split=False):
         do_split = self.do_split and seq_len > 1 # split decode
         # NOTE: we found split tensor when it is too small is not helping with the performance.
         # 1 * 1024 * 4096 * 3 is [batch_size, seq_len, hidden_size * 3]
-        do_split = do_split and shape_total > 1 * 1024 * 8192 * 3 and not skip_seq_split
+        do_split = do_split and shape_total > 1 * 1024 * 8192 * 3 and not self.skip_seq_split

         if do_split:
             input_parallels = split_tensor_along_x_dim(input_parallel, 1, self.split_size)
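For context, the condition this hunk touches reduces to the check sketched below. This is a standalone illustration, not the vLLM implementation: the 1 * 1024 * 8192 * 3 threshold and the skip_seq_split flag come straight from the diff, while the helper name should_split_seq and the assumption that shape_total is the element count of the activation are illustrative.

import torch

def should_split_seq(input_parallel: torch.Tensor,
                     do_split: bool,
                     skip_seq_split: bool,
                     threshold: int = 1 * 1024 * 8192 * 3) -> bool:
    # input_parallel is assumed to be [batch_size, seq_len, hidden_size].
    seq_len = input_parallel.size(1)
    shape_total = input_parallel.numel()  # assumed meaning of shape_total
    # Split only multi-token inputs whose activation is large enough for the
    # split to pay off, and let skip_seq_split disable it for this layer.
    return do_split and seq_len > 1 and shape_total > threshold and not skip_seq_split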
6 changes: 4 additions & 2 deletions vllm/model_executor/models/llama.py
@@ -99,7 +99,8 @@ def __init__(
     def forward(self, x, skip_seq_split=False):
         x, _ = self.gate_up_proj(x)
         x = self.act_fn(x)
-        x, _ = self.down_proj(x, skip_seq_split=skip_seq_split)
+        self.down_proj.skip_seq_split=skip_seq_split
+        x, _ = self.down_proj(x)
         return x


@@ -215,7 +216,8 @@ def forward(
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata, **kwargs)
-        output, _ = self.o_proj(attn_output, skip_seq_split=skip_seq_split)
+        self.o_proj.skip_seq_split=skip_seq_split
+        output, _ = self.o_proj(attn_output)
         return output
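Both llama.py call sites apply the same pattern: instead of threading skip_seq_split through forward(), the flag is set as an attribute on the projection module just before the call, so RowParallelLinear.forward keeps a plain single-argument signature (presumably what the FP8-quantized path expects). A minimal sketch of that pattern, using a hypothetical SeqSplitLinear module rather than the real layer, looks like this:

import torch

class SeqSplitLinear(torch.nn.Module):
    # Hypothetical stand-in for a layer that may split along the sequence dim.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)
        self.skip_seq_split = False  # default, as added in linear.py above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.skip_seq_split and x.size(1) > 1:
            # Process the sequence in two chunks; the real code sizes the
            # chunks via split_size and applies the threshold shown in linear.py.
            chunks = torch.chunk(x, 2, dim=1)
            return torch.cat([self.proj(c) for c in chunks], dim=1)
        return self.proj(x)

layer = SeqSplitLinear(16)
x = torch.randn(1, 8, 16)
layer.skip_seq_split = True  # caller opts out for this pass, as llama.py now does
out = layer(x)

Setting the attribute rather than passing a kwarg means wrappers that invoke forward(x) with a fixed signature, as quantized linear methods typically do, keep working unchanged.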
