Conversation

@jthakurH
Contributor

Summary

This PR introduces a decode attention batch split feature, which splits the attention computation along the batch dimension during the decoding phase, together with a distributed barrier optimization that reduces network communication overhead.

Key Changes

1. Decode Attention Batch Split Feature

  • Added a decode_attn_batch_split parameter to enable batch splitting during the decoding phase
  • Implemented the batch splitting logic in the GaudiLlamaDecoderLayer.forward() method (see the sketch after this list)
  • Handles the KV cache correctly across batch splits during the decode phase
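For illustration only, here is a minimal sketch of the decode-time batch-split pattern: the single-token query batch and its KV cache are chunked along the batch dimension, attention runs per chunk, and the chunk outputs are concatenated back. The function name split_attention_decode, the tensor layouts, and the use of PyTorch's scaled_dot_product_attention are assumptions made for this sketch, not the actual GaudiLlamaDecoderLayer code from this PR.

```python
# Illustrative sketch only: NOT the PR's GaudiLlamaDecoderLayer code.
import torch
import torch.nn.functional as F


def split_attention_decode(query, kv_cache, batch_split):
    """Run single-token (decode) attention in `batch_split` chunks along the
    batch dimension, so each chunk only touches its own slice of the KV cache.

    query:    [batch, heads, 1, head_dim]  -- one new token per sequence
    kv_cache: tuple (k, v), each [batch, heads, seq_len, head_dim]
    """
    k_cache, v_cache = kv_cache
    outputs = []
    # Split the decode batch; each chunk is processed independently.
    for q_chunk, k_chunk, v_chunk in zip(
        query.chunk(batch_split, dim=0),
        k_cache.chunk(batch_split, dim=0),
        v_cache.chunk(batch_split, dim=0),
    ):
        # Attend the new token of each sequence in the chunk to that
        # sequence's cached keys/values only.
        outputs.append(F.scaled_dot_product_attention(q_chunk, k_chunk, v_chunk))
    # Re-assemble the full batch in the original order.
    return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    batch, heads, seq_len, head_dim = 8, 4, 128, 64
    q = torch.randn(batch, heads, 1, head_dim)
    k = torch.randn(batch, heads, seq_len, head_dim)
    v = torch.randn(batch, heads, seq_len, head_dim)
    out = split_attention_decode(q, (k, v), batch_split=2)
    # The split result matches unsplit attention over the whole batch.
    ref = F.scaled_dot_product_attention(q, k, v)
    print(torch.allclose(out, ref, atol=1e-6), out.shape)
```

Because each attention call only sees one chunk of the batch and the matching slice of the cache, the per-call working set stays bounded by the chunk size, which is the memory behavior the split aims for.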

2. Implementation Details

  • Extended the existing attn_batch_split functionality to support the decode phase
  • Maintains backward compatibility with the existing attention batch split for the prompt phase (see the sketch after this list)
  • Handles residual connections properly across batch splits
  • Manages memory efficiently for split operations
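A rough sketch of how the phase dispatch and per-split residual handling could fit together is shown below. ToyDecoderLayer and its internals are hypothetical stand-ins rather than the GaudiLlamaDecoderLayer implementation, the KV cache is omitted for brevity, and detecting the decode phase via a single-token sequence length is an assumption of this sketch.

```python
# Hedged sketch of prompt-split vs. decode-split dispatch with per-chunk residuals.
import torch
import torch.nn as nn


class ToyDecoderLayer(nn.Module):
    def __init__(self, hidden_size=64):
        super().__init__()
        # A single norm is reused before attention and MLP to keep the toy short.
        self.norm = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, attn_batch_split=1, decode_attn_batch_split=1):
        is_decode = hidden_states.shape[1] == 1        # one new token per sequence
        split = decode_attn_batch_split if is_decode else attn_batch_split
        outputs = []
        for chunk in hidden_states.chunk(split, dim=0):
            # The residual is tracked per chunk, so each split is self-contained
            # and no full-batch intermediate tensor is materialized.
            residual = chunk
            normed = self.norm(chunk)
            attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
            chunk = residual + attn_out
            chunk = chunk + self.mlp(self.norm(chunk))
            outputs.append(chunk)
        return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    layer = ToyDecoderLayer()
    prompt = torch.randn(8, 16, 64)   # prompt phase: full sequences
    decode = torch.randn(8, 1, 64)    # decode phase: one new token each
    print(layer(prompt, attn_batch_split=2).shape)
    print(layer(decode, decode_attn_batch_split=2).shape)
```

Keeping a separate split factor for each phase is what preserves backward compatibility: setting decode_attn_batch_split to 1 leaves the existing prompt-phase behavior untouched.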

Benefits

  1. Performance: Enables better memory utilization during decoding by splitting the attention computation across the batch
  2. Scalability: Reduces network communication overhead in multi-node setups

Total throughput for the 405B Llama model increases by 4.8%:

  • baseline, 2k_2k_180 config: 965 tokens/sec
  • with decode split of 2, 2k_2k_180 config: 1010 tokens/sec
