diff --git a/README.md b/README.md
index 8ae96eb3..ff9297de 100644
--- a/README.md
+++ b/README.md
@@ -129,6 +129,20 @@ Benchmarks run on one GCD of a MI-250x.
 | Llama-2-7B | Base | 76.33 | 1028.70 |
 | | 8-bit | 101.86 | 700.06 |
 
+### Using Grouped Query Attention
+Benchmarks run on 1 NVIDIA H100.
+
+Using
+```bash
+export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
+```
+
+| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) |
+| -------- | ------- | ------ | ------ |
+| Llama-2-7B | Base | 146.66 | 1938.12 |
+| | 8-bit | 233.50 | 1543.55 |
+| | 4-bit (G=32) | 267.11 | 1103.14 |
+
 ## Generate Text
 
 Model definition in `model.py`, generation code in `generate.py`.
diff --git a/model.py b/model.py
index b89a19a0..06cf365a 100644
--- a/model.py
+++ b/model.py
@@ -195,9 +195,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         if self.kv_cache is not None:
             k, v = self.kv_cache.update(input_pos, k, v)
 
-        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, enable_gqa=True)
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 
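
For context, a minimal standalone sketch (not part of the patch) of why `enable_gqa=True` can replace the manual `repeat_interleave` expansion: with grouped-query attention, `scaled_dot_product_attention` broadcasts the smaller number of key/value heads across the query heads internally. The shapes and variable names below are illustrative only, and `enable_gqa` requires PyTorch 2.5 or later.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: 16 query heads sharing 4 KV heads (group size 4).
bsz, n_head, n_local_heads, seqlen, head_dim = 2, 16, 4, 8, 64

q = torch.randn(bsz, n_head, seqlen, head_dim)
k = torch.randn(bsz, n_local_heads, seqlen, head_dim)
v = torch.randn(bsz, n_local_heads, seqlen, head_dim)

# Old path: materialize the repeated KV heads before calling SDPA.
k_rep = k.repeat_interleave(n_head // n_local_heads, dim=1)
v_rep = v.repeat_interleave(n_head // n_local_heads, dim=1)
y_old = F.scaled_dot_product_attention(q, k_rep, v_rep, dropout_p=0.0)

# New path: let SDPA group the KV heads internally (PyTorch >= 2.5).
y_new = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, enable_gqa=True)

print(torch.allclose(y_old, y_new, atol=1e-5))  # expected: True (up to numerical tolerance)
```

The intent of `enable_gqa=True` is to push the head grouping into the attention kernel instead of materializing repeated key/value tensors up front; the number of query heads must be divisible by the number of key/value heads.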