Add FP8 KV Support #737
base: main
Conversation
Signed-off-by: yiliu30 <[email protected]>
Pull Request Overview
This PR adds FP8 KV cache quantization support to the AutoRound quantization library, enabling the key-value cache of attention layers to be quantized to FP8 for improved memory efficiency.
- Adds a new `enable_fp8_kv` parameter to control FP8 KV cache quantization (see the usage sketch after this list)
- Implements FP8 KV cache infrastructure with calibration and quantization context
- Updates test coverage to verify FP8 KV cache serialization and functionality
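A rough usage sketch of the new flag, assuming it is exposed on the AutoRound constructor as the overview describes; the surrounding arguments follow the library's existing public API, and the exact names in this PR may differ:

```python
from auto_round import AutoRound

# `model` and `tokenizer` are assumed to be an already-loaded HF causal LM and its tokenizer.
autoround = AutoRound(
    model,
    tokenizer,
    bits=4,
    group_size=128,
    enable_fp8_kv=True,  # new flag from this PR: quant-dequant the KV cache in FP8
)
autoround.quantize()
autoround.save_quantized("./qmodel", format="auto_round")
```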
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
File | Description
---|---
auto_round/autoround.py | Adds `enable_fp8_kv` parameter and context manager integration
auto_round/experimental/fp8_kv_cache.py | Core FP8 KV cache implementation with quantization logic
test/test_cpu/test_export.py | Test updates to verify FP8 KV cache export functionality
from transformers.cache_utils import DynamicCache

logger.add(sys.stderr, level="TRACE")
Adding a sys.stderr sink at TRACE level at module import time can interfere with application-level logging configuration. Consider removing this global logger configuration or making it conditional.
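One way to make this opt-in, as a minimal sketch assuming the module uses loguru (the `AR_LOG_LEVEL` environment variable name is hypothetical):

```python
import os
import sys

from loguru import logger

# Only reconfigure logging when the user explicitly opts in, so importing the
# module does not override the application's own logging configuration.
_log_level = os.environ.get("AR_LOG_LEVEL")
if _log_level:
    logger.remove()  # drop loguru's default stderr sink before adding ours
    logger.add(sys.stderr, level=_log_level)
```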
Get the k_scale and v_scale and output the quant-dequant key_states and value_states
"""

# FIXME: Should we append the key_states/value_states to the cache?
The FIXME comment indicates uncertainty about the intended behavior. This should be resolved before production use, as it could affect the correctness of the KV cache implementation.
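For context on what the docstring describes, a minimal sketch of a per-tensor FP8 quant-dequant of the key/value states; the function name, the per-tensor scaling granularity, and the e4m3 format are assumptions here, not necessarily what the PR implements:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0


def fake_quant_fp8(states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (scale, quant-dequant states) using per-tensor FP8 e4m3 scaling."""
    scale = states.abs().max().clamp(min=1e-12) / FP8_E4M3_MAX
    q = (states / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return scale, q.to(states.dtype) * scale


# k_scale/v_scale plus the quant-dequant key/value states, as in the docstring.
key_states = torch.randn(1, 8, 16, 64)
value_states = torch.randn(1, 8, 16, 64)
k_scale, key_states_qdq = fake_quant_fp8(key_states)
v_scale, value_states_qdq = fake_quant_fp8(value_states)
```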
# FIXME: Handle this better.
return "attention" in module.__class__.__name__.lower() and (
    hasattr(module, "k_proj") or hasattr(module, "v_proj") or hasattr(module, "qkv_proj")
)
The attention module detection logic uses a simple string check and attribute existence. This brittle approach should be improved with a more robust detection mechanism.
)
"""
Robustly detect if a module is an attention module.
Checks for known attention base classes and required attributes.
"""
# Check for known attention base classes (assumed to be imported elsewhere, or None if unavailable)
attention_bases = tuple(
    cls for cls in (MultiheadAttention, LlamaAttention) if cls is not None
)
if attention_bases and isinstance(module, attention_bases):
    return True
# Fallback: check for common attention attributes and method signatures
has_proj = any(
    hasattr(module, attr)
    for attr in ("k_proj", "v_proj", "qkv_proj", "in_proj_weight", "out_proj")
)
has_forward = hasattr(module, "forward")
# Optionally, check for typical attention input signatures
return has_proj and has_forward
`[batch_size, num_heads, seq_len - residual_length, head_dim]`.

# TODO: Triggered by adding kv_cache_scheme in ...
The TODO comment is incomplete and doesn't provide meaningful information. It should either be completed with specific details or removed.
# TODO: Triggered by adding kv_cache_scheme in ...
Please note that the KV cache does not affect tuning, as we only use the forward pass during tuning. So should the first change be moving the arg from init to save_quantized?
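A rough illustration of that suggestion; the exact AutoRound signatures here are assumptions based on the library's existing API, not the PR's final behavior:

```python
from auto_round import AutoRound

# Tuning stays unchanged; FP8 KV would only be requested at export time.
autoround = AutoRound(model, tokenizer, bits=4, group_size=128)
autoround.quantize()
autoround.save_quantized("./qmodel", format="auto_round", enable_fp8_kv=True)
```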
Signed-off-by: yiliu30 <[email protected]>