Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions src/llmcompressor/modifiers/quantization/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def update(
_pad_and_append_at_idx_(self.k_observers, layer_idx, k_observer)
_pad_and_append_at_idx_(self.v_observers, layer_idx, v_observer)

kv_states_dim = key_states.dim()
if kv_states_dim == 4:
# reshape for per channel scenario
num_heads = key_states.shape[1]
head_dim = key_states.shape[-1]
# from [batch_size, num_heads, seq_len - residual_length, head_dim]
# to [batch_size, seq_len - residual_length, num_heads * head_dim]
key_states = key_states.transpose(1, 2).flatten(2)
value_states = value_states.transpose(1, 2).flatten(2)

q_key_states = self._quantize(
key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
)
Expand All @@ -106,6 +116,19 @@ def update(
q_value_states, KVCacheScaleType.VALUE, layer_idx
)

if kv_states_dim == 4:
# reshape for per channel scenario
# from [batch_size, seq_len - residual_length, num_heads * head_dim]
# to [batch_size, num_heads, seq_len - residual_length, head_dim]
qdq_key_states = qdq_key_states.view(
qdq_key_states.shape[0], qdq_key_states.shape[1],
num_heads, head_dim
).transpose(1, 2).contiguous()
qdq_value_states = qdq_value_states.view(
qdq_value_states.shape[0], qdq_value_states.shape[1],
num_heads, head_dim
).transpose(1, 2).contiguous()

keys_to_return, values_to_return = qdq_key_states, qdq_value_states

return keys_to_return, values_to_return
Expand Down Expand Up @@ -155,8 +178,8 @@ def _quantize(self, tensor, kv_type, layer_idx):
zps = self.v_zps

scale, zp = observer(tensor)
_pad_and_append_at_idx_(scales, layer_idx, scale)
_pad_and_append_at_idx_(zps, layer_idx, zp)
_pad_and_append_at_idx_(scales, layer_idx, scale.squeeze())
_pad_and_append_at_idx_(zps, layer_idx, zp.squeeze())

q_tensor = quantize(
x=tensor,
Expand Down
4 changes: 4 additions & 0 deletions src/llmcompressor/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ def calibrate_kv_cache_output_hook(module: Module, _args: Any, _output: torch.Te
kv_cache = getattr(module, "kv_cache")
k_scale = kv_cache.k_scales[module.layer_idx]
v_scale = kv_cache.v_scales[module.layer_idx]

if kv_cache.quantization_args.strategy == QuantizationStrategy.CHANNEL:
k_scale = k_scale.unsqueeze(-1)
v_scale = v_scale.unsqueeze(-1)
update_offload_parameter(module, KVCacheScaleType.KEY.value, k_scale)
update_offload_parameter(module, KVCacheScaleType.VALUE.value, v_scale)

Expand Down
11 changes: 9 additions & 2 deletions src/llmcompressor/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,15 @@ def get_qparams(
self._zero_point[:, group_index] = zero_point.squeeze(1)

elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
# assume observed is transposed, because its the output, hence use dim 0
self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
# 1. dim=2 scenario: in kv cache quant scenario which is
# [batch_size, seq_len - residual_length, num_heads * head_dim]
# 2. dim=0 scenario: assume observed is transposed,
# because its the output, hence use dim 0
dim = 2 if observed.dim() == 3 else 0
self._scale, self._zero_point = self.get_qparams_along_dim(
observed,
dim
)

elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
# use dim 1, assume the obsersed.shape = [batch, token, hidden]
Expand Down