Add batch splitting in attention layer for decode to hide NIC latency #2334
+75
−16
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
This PR introduces a decode attention batch split feature that enables splitting the attention computation across batches during the decoding phase, along with a distributed barrier optimization to reduce network communication overhead.
Key Changes
1. Decode Attention Batch Split Feature
decode_attn_batch_splitparameter to enable batch splitting during decoding phaseGaudiLlamaDecoderLayer.forward()method2. Implementation Details
attn_batch_splitfunctionality to support decode phaseBenefits
Total Perf for 405B llama model increase by 4.8% .
baseline 2k_2k_180 config: Perf -> 965 tokens/sec
with decode split 2 , 2k_2k_180 config: perf->1010 tokens/sec