Add support for Phi-3-medium
The main difference between the medium and mini models is that medium
uses grouped query attention with a packed QKV matrix. This change adds
support for GQA with packed matrices to `Weights.get_weights_col_packed`
and uses it for Phi-3. This also allows us to remove the custom
implementation of GQA from dbrx attention loading.
danieldk committed Jun 7, 2024
1 parent bf3c813 commit 75ec1a2
Showing 5 changed files with 118 additions and 164 deletions.
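
The change described above boils down to one sharding rule: for a packed QKV projection under grouped query attention, the packed dimension is split into three blocks proportional to [num_heads, num_key_value_heads, num_key_value_heads], and each tensor-parallel rank takes its own slice of every block before the slices are concatenated again. Below is a minimal, self-contained sketch of that rule in plain PyTorch; the function name and the example shapes (32 query heads, 8 KV heads, head_dim 128, 4 shards) are hypothetical and are not part of the diff that follows.

import torch

def shard_packed_qkv(
    qkv: torch.Tensor,
    num_heads: int,
    num_key_value_heads: int,
    head_dim: int,
    rank: int,
    world_size: int,
) -> torch.Tensor:
    """Slice one rank's shard out of a [q + k + v, hidden] packed QKV weight."""
    # Blocks proportional to [num_heads, num_key_value_heads, num_key_value_heads].
    block_sizes = [
        num_heads * head_dim,
        num_key_value_heads * head_dim,
        num_key_value_heads * head_dim,
    ]
    shards = []
    block_offset = 0
    for block_size in block_sizes:
        assert block_size % world_size == 0, "block must be divisible across shards"
        shard_size = block_size // world_size
        start = block_offset + rank * shard_size
        shards.append(qkv[start : start + shard_size])
        block_offset += block_size
    # Q, K and V shards for this rank, concatenated back into one packed matrix.
    return torch.cat(shards, dim=0)

# Hypothetical GQA shapes: 32 query heads, 8 KV heads, head_dim 128, hidden 4096.
qkv = torch.randn((32 + 2 * 8) * 128, 4096)
shard = shard_packed_qkv(
    qkv, num_heads=32, num_key_value_heads=8, head_dim=128, rank=0, world_size=4
)
print(shard.shape)  # torch.Size([1536, 4096]) == ((32 + 2 * 8) * 128) // 4 rows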
17 changes: 15 additions & 2 deletions server/text_generation_server/layers/tensor_parallel.py
@@ -129,9 +129,22 @@ def load_gate_up(cls, config, prefix: str, weights, bias: bool):
return cls(linear)

@classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool):
def load_qkv(
cls,
config,
prefix: str,
weights,
bias: bool,
num_heads: int,
num_key_value_heads: int,
):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
weight = weights.get_weights_col_packed_qkv(
prefix,
quantize=config.quantize,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
)
if bias:
raise NotImplementedError("packed_qkv only implemented for baichuan")
else:
@@ -20,7 +20,6 @@
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any
from loguru import logger
from text_generation_server.utils.import_utils import SYSTEM

if SYSTEM != "xpu":
@@ -164,129 +163,13 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor:


def load_attention(config, prefix, weights):
if config.n_heads != config.attn_config.kv_n_heads:
return _load_gqa(config, prefix, weights)
else:
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.Wqkv",
weights=weights,
bias=False,
)


def _load_gqa(config, prefix: str, weights):
assert config.d_model % config.n_heads == 0
assert config.n_heads % weights.process_group.size() == 0

head_dim = config.d_model // config.n_heads
world_size = weights.process_group.size()
rank = weights.process_group.rank()

q_block_size = config.d_model // world_size
q_start = rank * q_block_size
q_stop = (rank + 1) * q_block_size

kv_block_size = (config.attn_config.kv_n_heads * head_dim) // world_size
k_offset = config.d_model
k_start = k_offset + rank * kv_block_size
k_stop = k_offset + (rank + 1) * kv_block_size

v_offset = config.d_model + config.attn_config.kv_n_heads * head_dim
v_start = v_offset + rank * kv_block_size
v_stop = v_offset + (rank + 1) * kv_block_size

if config.quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight

try:
qweight_slice = weights._get_slice(f"{prefix}.qweight")
q_qweight = qweight_slice[:, q_start:q_stop]
k_qweight = qweight_slice[:, k_start:k_stop]
v_qweight = qweight_slice[:, v_start:v_stop]

qweight = torch.cat([q_qweight, k_qweight, v_qweight], dim=1)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{config.quantize}` weight, make sure the model is already quantized"
)

qzeros_slice = weights._get_slice(f"{prefix}.qzeros")
q_qzeros = qzeros_slice[:, q_start:q_stop]
k_qzeros = qzeros_slice[:, k_start:k_stop]
v_qzeros = qzeros_slice[:, v_start:v_stop]

qzeros = torch.cat([q_qzeros, k_qzeros, v_qzeros], dim=1)

scales_slice = weights._get_slice(f"{prefix}.scales")
q_scales = scales_slice[:, q_start:q_stop]
k_scales = scales_slice[:, k_start:k_stop]
v_scales = scales_slice[:, v_start:v_stop]

scales = torch.cat([q_scales, k_scales, v_scales], dim=1)

bits, groupsize, desc_act, quant_method = weights._get_gptq_params()

from text_generation_server.layers import HAS_EXLLAMA

use_exllama = (
bits == 4 and HAS_EXLLAMA and config.quantize == "gptq" and not desc_act
)

if config.quantize == "gptq" and quant_method == "gptq":
g_idx_slice = weights._get_slice(f"{prefix}.g_idx")
q_g_idx = g_idx_slice[:, q_start:q_stop]
k_g_idx = g_idx_slice[:, k_start:k_stop]
v_g_idx = g_idx_slice[:, v_start:v_stop]

w = [q_g_idx, k_g_idx, v_g_idx]
for w2 in w[1:]:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
elif config.quantize == "gptq" and quant_method == "awq":
log_once(
logger.info, "Converting AWQ model to Exllama/GPTQ packing format."
)
from text_generation_server.layers.awq.conveersion_utils import (
fast_awq_to_gptq,
)

qweight, qzeros = fast_awq_to_gptq(qweight, qzeros)
if use_exllama:
g_idx = None
else:
g_idx = (
torch.arange(qweight.shape[0] * (32 // bits), device=qweight.device)
// groupsize
).to(dtype=torch.int32)
else:
g_idx = None

weight = GPTQWeight(
qweight=qweight,
qzeros=qzeros,
scales=scales,
g_idx=g_idx,
bits=bits,
groupsize=groupsize,
use_exllama=use_exllama,
)
elif config.quantize == "marlin":
# NOTE: at the time marlin support was added, the only model that
# exists is LnL-AI/dbrx-base-converted-v2-4bit-gptq-marlin(-v2),
# but it requires manual concatenation of weight files.
raise RuntimeError("dbrx models with marlin quantization are not yet supported")
else:
qkv_slice = weights._get_slice(f"{prefix}.Wqkv.weight")
q = qkv_slice[q_start:q_stop]
k = qkv_slice[k_start:k_stop]
v = qkv_slice[v_start:v_stop]

weight = torch.cat([q, k, v], dim=0)
weight = weight.to(dtype=weights.dtype).to(device=weights.device)

return TensorParallelColumnLinear(
get_linear(weight, bias=None, quantize=config.quantize)
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.Wqkv",
weights=weights,
bias=False,
num_heads=config.n_heads,
num_key_value_heads=config.attn_config.kv_n_heads,
)


@@ -59,7 +59,12 @@ def _load_qkv_gptq(config, prefix: str, weights):
rank = weights.process_group.rank()

# Weights
weight = weights.get_weights_col_packed_qkv(f"{prefix}.c_attn", config.quantize)
weight = weights.get_weights_col_packed_qkv(
f"{prefix}.c_attn",
config.quantize,
config.num_attention_heads,
config.num_attention_heads,
)

# Bias
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
@@ -62,13 +62,17 @@ def load_attention(config, prefix, weights):
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)
elif config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
weights=weights,
bias=bias,
num_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
)

# otherwise, load the default attention based on the number of heads
@@ -107,6 +111,11 @@ def __init__(
f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
if config.num_key_value_heads % weights.process_group.size() != 0:
raise ValueError(
f"`num_key_value_heads` must be divisible by `num_shards` (got `num_key_value_heads`: {config.num_key_value_heads} "
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.num_key_value_heads = (
config.num_key_value_heads // weights.process_group.size()
118 changes: 81 additions & 37 deletions server/text_generation_server/utils/weights.py
@@ -1,7 +1,6 @@
from dataclasses import dataclass, field
import os
from pathlib import Path
from typing import List, Dict, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
@@ -121,58 +120,71 @@ def get_sharded(self, tensor_name: str, dim: int):
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)

def _get_qweight(self, name: str, blocks: int):
def _get_qweight(self, name: str, block_sizes: Union[int, List[int]]):
slice_ = self._get_slice(name)
total_size = slice_.get_shape()[1]
assert (
total_size % blocks == 0
), f"Prepacked quantized matrix is not divisible by {blocks}"
single_size = total_size // blocks
block_sizes = _blocks_to_block_sizes(total_size=total_size, blocks=block_sizes)

world_size = self.process_group.size()
rank = self.process_group.rank()

assert (
single_size % world_size == 0
), f"Prepacked quantized matrix cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size

weights = []
for block in range(blocks):
weights.append(
slice_[:, start + block * single_size : stop + block * single_size]
)
block_offset = 0
for block_size in block_sizes:
assert (
block_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
weights.append(slice_[:, block_offset + start : block_offset + stop])
block_offset += block_size

weight = torch.cat(weights, dim=1)
weight = weight.to(device=self.device)
return weight

def get_weights_col_packed_qkv(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 3)
def get_weights_col_packed_qkv(
self,
prefix: str,
quantize: str,
num_heads: int,
num_key_value_heads: int,
):
return self.get_weights_col_packed(
prefix, quantize, [num_heads, num_key_value_heads, num_key_value_heads]
)

def get_weights_col_packed_gate_up(self, prefix: str, quantize: str):
return self.get_weights_col_packed(prefix, quantize, 2)

def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
def get_weights_col_packed(
self, prefix: str, quantize: str, block_sizes: Union[int, List[int]]
):
"""
Highly specific when the underlying tensor is a simple cat of Q,K,V instead of being
already alternating Q,K,V within the main tensor
already alternating Q,K,V within the main tensor.
The columns are split in equally sized blocks when blocks is an `int`, or
in blocks proportional to the given sizes. For instance `[2, 1, 1]` will
divide an input with dimensionality `1024` in `[512, 256, 256]`. This is
convenient for e.g. splitting QKV without knowing the storage details of
quantized weights.
"""
if quantize in ["gptq", "awq"]:
from text_generation_server.layers.gptq import GPTQWeight

try:
qweight = self._get_qweight(f"{prefix}.qweight", blocks)
qweight = self._get_qweight(f"{prefix}.qweight", block_sizes)
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)

bits, groupsize, _, quant_method = self._get_gptq_params()

qzeros = self._get_qweight(f"{prefix}.qzeros", blocks)
scales = self._get_qweight(f"{prefix}.scales", blocks)
qzeros = self._get_qweight(f"{prefix}.qzeros", block_sizes)
scales = self._get_qweight(f"{prefix}.scales", block_sizes)
scales = scales.to(dtype=self.dtype)

if quantize == "gptq" and quant_method == "gptq":
@@ -205,27 +217,31 @@ def get_weights_col_packed(self, prefix: str, quantize: str, blocks: int):
elif quantize == "marlin":
from text_generation_server.layers.marlin import MarlinWeight

B = self._get_qweight(f"{prefix}.B", blocks)
s = self._get_qweight(f"{prefix}.s", blocks)
B = self._get_qweight(f"{prefix}.B", block_sizes)
s = self._get_qweight(f"{prefix}.s", block_sizes)
weight = MarlinWeight(B=B, s=s)
else:
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
block_sizes = _blocks_to_block_sizes(
total_size=total_size, blocks=block_sizes
)

world_size = self.process_group.size()
rank = self.process_group.rank()

assert (
single_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensors = []
for i in range(blocks):
tensor = slice_[start + i * single_size : stop + i * single_size]
block_offset = 0
for block_size in block_sizes:
assert (
block_size % world_size == 0
), f"Prepacked weights cannot be sharded across {world_size} shards"
shard_block_size = block_size // world_size
start = rank * shard_block_size
stop = (rank + 1) * shard_block_size
tensor = slice_[block_offset + start : block_offset + stop]
tensors.append(tensor)
block_offset += block_size
weight = torch.cat(tensors, dim=0)
weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype)
@@ -593,3 +609,31 @@ def _set_gptq_params(self, model_id, revision):
self.quant_method = "awq"
except Exception:
pass


def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""
Convert block count or proportions to block sizes.
This function accepts
- The number of blocks (int), in which case the block size is
total_size//blocks; or
- A list of block sizes (List[int]).
In the latter case, if sum(blocks) < total_size, the ratios between
the block sizes will be preserved. For instance, if blocks is
[2, 1, 1] and total_size is 1024, the returned block sizes are
[512, 256, 256].
"""
if isinstance(blocks, list):
total_blocks = sum(blocks)
assert (
total_size % total_blocks == 0
), f"Cannot split {total_size} in proportional blocks: {blocks}"
part_size = total_size // total_blocks
return [part_size * block for block in blocks]
else:
assert total_size % blocks == 0, f"Prepacked is not divisible by {blocks}"
single_size = total_size // blocks
return [single_size] * blocks
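
For reference, two worked calls to the helper above; the GQA proportions here are hypothetical and chosen only to illustrate the docstring's semantics:

# Proportional form: 32 query heads and 8 KV heads packed into 6144 columns (head_dim 128).
_blocks_to_block_sizes(total_size=6144, blocks=[32, 8, 8])
# -> [4096, 1024, 1024]: Q gets 32/48 of the packed dimension, K and V get 8/48 each.

# Integer form: split into equally sized blocks, e.g. for a packed gate/up projection.
_blocks_to_block_sizes(total_size=6144, blocks=2)
# -> [3072, 3072]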
