Fixup other models that use packed QKV
danieldk committed Jun 7, 2024
1 parent f1db89c commit 9bff076
Showing 5 changed files with 16 additions and 133 deletions.
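
The common thread in the diffs below is that `load_qkv` and `get_weights_col_packed_qkv` no longer take an explicit `head_size`; callers now pass only `num_heads` and `num_key_value_heads`. As a rough illustration (not part of the commit), the head size is implied by the packed tensor's shape once both head counts are known, assuming the usual fused layout of query rows followed by key and value rows; the helper below and its name are hypothetical.

```python
import torch

def split_packed_qkv(packed: torch.Tensor, num_heads: int, num_key_value_heads: int):
    """Illustrative split of a fused [Q; K; V] weight along its output dimension."""
    # The fused dimension is (num_heads + 2 * num_key_value_heads) * head_size,
    # so head_size can be derived instead of being passed around explicitly.
    head_size = packed.shape[0] // (num_heads + 2 * num_key_value_heads)
    q, k, v = torch.split(
        packed,
        [
            num_heads * head_size,            # query rows
            num_key_value_heads * head_size,  # key rows
            num_key_value_heads * head_size,  # value rows
        ],
        dim=0,
    )
    return q, k, v
```

For plain multi-head attention (the GPT-2 change below) both counts are equal, so the same value is passed twice; for grouped-query attention the key/value head count is smaller than the query head count.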
2 changes: 0 additions & 2 deletions server/text_generation_server/layers/tensor_parallel.py
@@ -135,15 +135,13 @@ def load_qkv(
prefix: str,
weights,
bias: bool,
head_size: int,
num_heads: int,
num_key_value_heads: int,
):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(
prefix,
quantize=config.quantize,
head_size=head_size,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
)
server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -20,7 +20,6 @@
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM != "xpu":
@@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:


def load_attention(config, prefix, weights):
if config.n_heads != config.attn_config.kv_n_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.Wqkv",
weights=weights,
bias=False,
)


def _load_gqa(config, prefix: str, weights):
assert config.d_model % config.n_heads == 0
assert config.n_heads % weights.process_group.size() == 0

head_dim = config.d_model // config.n_heads
world_size = weights.process_group.size()
rank = weights.process_group.rank()

q_block_size = config.d_model // world_size
q_start = rank * q_block_size
q_stop = (rank + 1) * q_block_size

kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
k_offset = config.d_model
k_start = k_offset + rank * kv_block_size
k_stop = k_offset + (rank + 1) * kv_block_size

v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
v_start = v_offset + rank * kv_block_size
v_stop = v_offset + (rank + 1) * kv_block_size

if config.quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight

try:
qweight_slice = weights._get_slice(f"{prefix}.qweight")
q_qweight = qweight_slice[:, q_start:q_stop]
k_qweight = qweight_slice[:, k_start:k_stop]
v_qweight = qweight_slice[:, v_start:v_stop]

qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
)

qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
q_qzeros = qzeros_slice[:, q_start:q_stop]
k_qzeros = qzeros_slice[:, k_start:k_stop]
v_qzeros = qzeros_slice[:, v_start:v_stop]

qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)

scales_slice = weights._get_slice(f"{prefix}.scales")
q_scales = scales_slice[:, q_start:q_stop]
k_scales = scales_slice[:, k_start:k_stop]
v_scales = scales_slice[:, v_start:v_stop]

scales = torch.cat([q_scales, k_scales, v_scales], dim=1)

bits, groupsize, desc_act, quant_method = weights._get_gptq_params()

from text_generation_server.layers import HAS_EXLLAMA

use_exllama = (
bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
)

if config.quantize == "gptq" and quant_method == "gptq":
g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
q_g_idx = g_idx_slice[:, q_start:q_stop]
k_g_idx = g_idx_slice[:, k_start:k_stop]
v_g_idx = g_idx_slice[:, v_start:v_stop]

w = [q_g_idx, k_g_idx, v_g_idx]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif config.quantize == "gptq" and quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conveersion_utils import (
fast_awq_to_gptq,
)

qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
// groupsize
).to(dtype=torch.int32)
else:
g_idx = None

weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
elif config.quantize == "marlin":
# NOTE: at the time marlin support was added, the only model that
# exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
# but it requires manual concatenation of weight files.
raise RuntimeError("dbrx models with marlin quantization are not yet supported")
else:
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
q = qkv_slice[q_start:q_stop]
k = qkv_slice[k_start:k_stop]
v = qkv_slice[v_start:v_stop]

weight = torch.cat([q, k, v], dim=0)
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.Wqkv",
weights=weights,
bias=False,
num_heads=config.n_heads,
num_key_value_heads=config.attn_config.kv_n_heads,
)
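
The dbrx-specific `_load_gqa` removed above computed each tensor-parallel rank's slice of the packed `Wqkv` by hand. For reference only, its offset arithmetic boils down to the sketch below, a condensed restatement of the deleted code with a hypothetical function name; the shared `load_qkv` path is now expected to cover this case.

```python
def packed_qkv_rank_slices(d_model: int, n_heads: int, kv_n_heads: int,
                           world_size: int, rank: int):
    """Start/stop offsets of one rank's Q, K and V blocks along the packed Wqkv dimension."""
    head_dim = d_model // n_heads
    # Query entries occupy the first d_model positions; each rank takes a contiguous block.
    q_block = d_model // world_size
    q = (rank * q_block, (rank + 1) * q_block)
    # Key and value entries follow, each section spanning kv_n_heads * head_dim positions.
    kv_block = (kv_n_heads * head_dim) // world_size
    k_offset = d_model
    v_offset = d_model + kv_n_heads * head_dim
    k = (k_offset + rank * kv_block, k_offset + (rank + 1) * kv_block)
    v = (v_offset + rank * kv_block, v_offset + (rank + 1) * kv_block)
    return q, k, v
```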


server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py
@@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights):
rank = weights.process_group.rank()

# Weights
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn",
config.quantize,
config.num_attention_heads,
config.num_attention_heads,
)

# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -55,18 +55,14 @@ def load_attention(config, prefix, weights):
# Only defined in granite.
bias = getattr(config, "attention_bias", False)

num_heads = config.num_attention_heads
head_size = config.hidden_size // num_heads

# if specific model type, load the correct attention
if config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=bias,
head_size=head_size,
num_heads=num_heads,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
elif config.model_type == "baichuan":
@@ -75,6 +71,8 @@ def load_attention(config, prefix, weights):
prefix=f"{prefix}.W_pack",
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)

# otherwise, load the default attention based on the number of heads
1 change: 0 additions & 1 deletion server/text_generation_server/utils/weights.py
@@ -148,7 +148,6 @@ def get_weights_col_packed_qkv(
self,
prefix: str,
quantize: str,
head_size: int,
num_heads: int,
num_key_value_heads: int,
):
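
Taken together, every call site in this commit follows the same pattern against the trimmed signature above. The sketch below only summarizes that pattern; the `load_packed_qkv_weight` wrapper is hypothetical, and `weights`/`config` stand in for the objects used in the diff.

```python
def load_packed_qkv_weight(weights, config, prefix: str):
    # Multi-head attention models (e.g. the GPT-2 hunk) pass the same head count
    # twice; grouped-query models pass their smaller key/value head count.
    num_key_value_heads = getattr(
        config, "num_key_value_heads", config.num_attention_heads
    )
    return weights.get_weights_col_packed_qkv(
        prefix,
        quantize=config.quantize,
        num_heads=config.num_attention_heads,
        num_key_value_heads=num_key_value_heads,
    )
```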
