Skip to content

Commit

Permalink
Add basic FP8 KV cache support
Browse files Browse the repository at this point in the history
This change adds rudimentary FP8 KV cache support. The support is
enabled by passing `--kv-cache-dtype fp8_e5m2` to the launcher. Doing so
uses this type for the KV cache. However support is still limited:

* Only the `fp8_e5m2` type is supported.
* The KV cache layout is the same as `float16`/`bfloat16` (HND).
* The FP8 KV cache is only supported for FlashInfer.
* Loading of scales is not yet supported.
  • Loading branch information
danieldk committed Oct 3, 2024
1 parent 1c84a30 commit 37df2ff
Show file tree
Hide file tree
Showing 32 changed files with 1,008 additions and 232 deletions.
9 changes: 9 additions & 0 deletions docs/source/reference/launcher.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ Options:
[env: DTYPE=]
[possible values: float16, bfloat16]

```
## KV_CACHE_DTYPE
```shell
--kv-cache-dtype <KV_CACHE_DTYPE>
Specify the dtype for the key-value cache. When this option is not provided, the dtype of the model is used (typically `float16` or `bfloat16`). Currently the only supported value is `fp8_e5m2` on CUDA

[env: KV_CACHE_DTYPE=]
[possible values: fp8_e5m2]

```
## TRUST_REMOTE_CODE
```shell
Expand Down
9 changes: 8 additions & 1 deletion integration-tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def local_launcher(
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
Expand Down Expand Up @@ -375,6 +376,9 @@ def local_launcher(
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
Expand Down Expand Up @@ -434,6 +438,7 @@ def docker_launcher(
use_flash_attention: bool = True,
disable_grammar_support: bool = False,
dtype: Optional[str] = None,
kv_cache_dtype: Optional[str] = None,
revision: Optional[str] = None,
max_input_length: Optional[int] = None,
max_batch_prefill_tokens: Optional[int] = None,
Expand All @@ -456,6 +461,9 @@ def docker_launcher(
if dtype is not None:
args.append("--dtype")
args.append(dtype)
if kv_cache_dtype is not None:
args.append("--kv-cache-dtype")
args.append(kv_cache_dtype)
if revision is not None:
args.append("--revision")
args.append(revision)
Expand Down Expand Up @@ -589,7 +597,6 @@ async def generate_load_inner(
max_new_tokens: int,
seed: Optional[int] = None,
) -> List[Response]:

import numpy as np

arange = np.arange(len(prompts))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 3923,
"logprob": -5.6328125,
"text": "What"
},
{
"id": 374,
"logprob": -1.2265625,
"text": " is"
},
{
"id": 5655,
"logprob": -9.1015625,
"text": " deep"
},
{
"id": 6975,
"logprob": -1.8085938,
"text": " learning"
},
{
"id": 30,
"logprob": -1.0439453,
"text": "?"
}
],
"seed": null,
"tokens": [
{
"id": 18682,
"logprob": -2.1992188,
"special": false,
"text": " Deep"
},
{
"id": 6975,
"logprob": -0.079956055,
"special": false,
"text": " learning"
},
{
"id": 374,
"logprob": -0.2763672,
"special": false,
"text": " is"
},
{
"id": 264,
"logprob": -0.37548828,
"special": false,
"text": " a"
},
{
"id": 27084,
"logprob": -1.4628906,
"special": false,
"text": " subset"
},
{
"id": 315,
"logprob": -0.02885437,
"special": false,
"text": " of"
},
{
"id": 5780,
"logprob": -0.2565918,
"special": false,
"text": " machine"
},
{
"id": 6975,
"logprob": -0.0063438416,
"special": false,
"text": " learning"
},
{
"id": 430,
"logprob": -1.3056641,
"special": false,
"text": " that"
},
{
"id": 374,
"logprob": -1.6035156,
"special": false,
"text": " is"
}
],
"top_tokens": null
},
"generated_text": " Deep learning is a subset of machine learning that is"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 3,
"prefill": [
{
"id": 128000,
"logprob": null,
"text": "<|begin_of_text|>"
},
{
"id": 374,
"logprob": -22.96875,
"text": " is"
},
{
"id": 5655,
"logprob": -10.71875,
"text": " deep"
},
{
"id": 6975,
"logprob": -2.6992188,
"text": " learning"
},
{
"id": 30,
"logprob": -4.8398438,
"text": "?"
}
],
"seed": 0,
"tokens": [
{
"id": 720,
"logprob": -0.4411621,
"special": false,
"text": " \n"
},
{
"id": 220,
"logprob": -0.35864258,
"special": false,
"text": " "
},
{
"id": 128001,
"logprob": 0.0,
"special": true,
"text": "<|end_of_text|>"
}
],
"top_tokens": null
},
"generated_text": "What is deep learning? \n "
}
Loading

0 comments on commit 37df2ff

Please sign in to comment.