feat: support base model generation and refactors
drbh committed Jun 7, 2024
1 parent f032ccf commit f5e75e7
Showing 8 changed files with 50 additions and 165 deletions.
12 changes: 7 additions & 5 deletions server/text_generation_server/adapters/lora.py
@@ -8,7 +8,6 @@

from text_generation_server.adapters.config import AdapterConfig, ModuleMap

LORA = "lora"
from text_generation_server.adapters.weights import (
AdapterBatchMetadata,
AdapterWeights,
@@ -246,7 +245,7 @@ def can_vectorize(self, pg: ProcessGroup) -> bool:

@classmethod
def key(cls) -> str:
return LORA
return "lora"

@classmethod
def load(
@@ -279,9 +278,12 @@ def load(
}

max_rank = max(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
(
adapter_weights[idx].lora_a_r
for idx in segment_indices
if idx in adapter_weights
),
default=0,
)

if prefill or max_rank > BGMV_MAX_RANK:
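The `default=0` argument is what lets this path handle base-model-only batches: when no LoRA adapters are present, the generator passed to `max()` is empty, and plain `max()` would raise `ValueError`. With `default=0`, `max_rank` simply falls back to 0. A minimal standalone sketch of the behaviour (plain Python, not the server code):

```python
# Minimal sketch: max() over a possibly empty generator of LoRA ranks.
adapter_weights = {}      # no adapters loaded -> base model only
segment_indices = [0, 1]

ranks = (adapter_weights[i].lora_a_r for i in segment_indices if i in adapter_weights)

# max(ranks) alone would raise ValueError("max() arg is an empty sequence").
max_rank = max(ranks, default=0)   # falls back to 0, so generation can proceed
print(max_rank)                    # -> 0
```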
9 changes: 2 additions & 7 deletions server/text_generation_server/adapters/weights.py
@@ -1,4 +1,3 @@
#############
from abc import ABC, abstractclassmethod
from collections import defaultdict
from dataclasses import dataclass
@@ -7,10 +6,6 @@
import torch


LORA = "lora"
LM_HEAD = "lm_head"


@dataclass
class AdapterBatchMetadata:
# [batch_size]
@@ -127,15 +122,15 @@ def from_meta(
if v.is_empty():
continue
data[k] = v.get_data(
meta, prefill, prefill_head_indices if k == LM_HEAD else None
meta, prefill, prefill_head_indices if k == "lm_head" else None
)
return AdapterBatchData(meta=meta, data=data, prefill=prefill)

def ranks(self) -> Set[int]:
# TODO(travis): refactor to be less coupled to lora implementation
ranks = set()
for layer_data in self.data.values():
lora_data = layer_data.get(LORA)
lora_data = layer_data.get("lora")
if lora_data is None:
continue

1 change: 0 additions & 1 deletion server/text_generation_server/cli.py
@@ -77,7 +77,6 @@ def serve(
if otlp_endpoint is not None:
setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)

# TODO: determine if this api makes sense
lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)

# split on comma and strip whitespace
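The parsing that follows the comment is outside this hunk; a rough sketch of what "split on comma and strip whitespace" presumably amounts to (the `LORA_ADAPTERS` variable comes from the snippet above; the rest is an assumption, not the exact server code):

```python
import os

# Hypothetical sketch of the elided parsing step:
# LORA_ADAPTERS="adapter-one, adapter-two" -> ["adapter-one", "adapter-two"]
raw = os.getenv("LORA_ADAPTERS", None)
lora_adapter_ids = [a.strip() for a in raw.split(",") if a.strip()] if raw else []
```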
18 changes: 4 additions & 14 deletions (file path not captured in this view)
@@ -52,16 +52,6 @@
except Exception as e:
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")

# Constants
Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"

GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"


def load_attention(config, prefix, weights, layer_id):
# Only defined in granite.
@@ -96,7 +86,7 @@ def load_attention(config, prefix, weights, layer_id):
return TensorParallelMultiAdapterLinear.load(
base_layer,
layer_id,
[Q_PROJ, K_PROJ, V_PROJ],
["q_proj", "k_proj", "v_proj"],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
@@ -151,7 +141,7 @@ def __init__(
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
index,
O_PROJ,
"o_proj",
process_group=weights.process_group,
)

@@ -259,7 +249,7 @@ def __init__(self, prefix, config, weights, index):
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
index,
[GATE_PROJ, UP_PROJ],
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
@@ -277,7 +267,7 @@ def __init__(self, prefix, config, weights, index):
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
index,
DOWN_PROJ,
"down_proj",
process_group=weights.process_group,
)

48 changes: 17 additions & 31 deletions server/text_generation_server/models/flash_llama.py
@@ -20,31 +20,17 @@
tracer = trace.get_tracer(__name__)

from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.lora import LoraConfig

Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"

GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"

LM_HEAD = "lm_head"


# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
ADAPTER_LAYERS = [
Q_PROJ,
K_PROJ,
V_PROJ,
O_PROJ,
GATE_PROJ,
UP_PROJ,
DOWN_PROJ,
] # LM_HEAD
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class FlashLlama(FlashCausalLM):
@@ -123,32 +109,32 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:

prefix = "model.layers"
for i, layer in enumerate(self.model.model.layers):
layer_weights[(i, Q_PROJ)] = (
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, K_PROJ)] = (
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, V_PROJ)] = (
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, O_PROJ)] = (
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)

layer_weights[(i, GATE_PROJ)] = (
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, UP_PROJ)] = (
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, DOWN_PROJ)] = (
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
@@ -162,7 +148,7 @@ def adapter_layers(self) -> List[str]:

@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]
return ["q_proj", "v_proj"]

def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
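Worth noting in the mapping above: `q_proj`, `k_proj`, and `v_proj` all target the same fused `query_key_value` module, and `gate_proj`/`up_proj` both target the fused `gate_up_proj`; only the adapter weight names differ. A self-contained toy mirroring that structure (illustrative stand-ins, not code from this commit):

```python
# Toy version of the (layer_index, layer_name) -> (weight_name, target_module) mapping.
fused_qkv = object()      # stands in for layer.self_attn.query_key_value
fused_gate_up = object()  # stands in for layer.mlp.gate_up_proj

layer_weights = {
    (0, "q_proj"): ("model.layers.0.self_attn.q_proj", fused_qkv),
    (0, "k_proj"): ("model.layers.0.self_attn.k_proj", fused_qkv),
    (0, "v_proj"): ("model.layers.0.self_attn.v_proj", fused_qkv),
    (0, "gate_proj"): ("model.layers.0.mlp.gate_proj", fused_gate_up),
    (0, "up_proj"): ("model.layers.0.mlp.up_proj", fused_gate_up),
}

# Separate adapter weight names, one shared fused target module:
assert layer_weights[(0, "q_proj")][1] is layer_weights[(0, "v_proj")][1]
```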
51 changes: 19 additions & 32 deletions server/text_generation_server/models/flash_mistral.py
@@ -21,29 +21,16 @@
tracer = trace.get_tracer(__name__)


Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"

GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"

LM_HEAD = "lm_head"


# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
ADAPTER_LAYERS = [
Q_PROJ,
K_PROJ,
V_PROJ,
O_PROJ,
GATE_PROJ,
UP_PROJ,
DOWN_PROJ,
] # LM_HEAD
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}


class BaseFlashMistral(FlashCausalLM):
@@ -133,37 +120,37 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:

prefix = "model.layers"
for i, layer in enumerate(self.model.model.layers):
layer_weights[(i, Q_PROJ)] = (
layer_weights[(i, "q_proj")] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, K_PROJ)] = (
layer_weights[(i, "k_proj")] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, V_PROJ)] = (
layer_weights[(i, "v_proj")] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, O_PROJ)] = (
layer_weights[(i, "o_proj")] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)

layer_weights[(i, GATE_PROJ)] = (
layer_weights[(i, "gate_proj")] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, UP_PROJ)] = (
layer_weights[(i, "up_proj")] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, DOWN_PROJ)] = (
layer_weights[(i, "down_proj")] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)

layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
return layer_weights

@property
@@ -172,10 +159,10 @@ def adapter_layers(self) -> List[str]:

@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]
return ["q_proj", "v_proj"]

def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
return 1 if layer_type == "lm_head" else len(self.model.model.layers)

def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
2 changes: 1 addition & 1 deletion server/text_generation_server/server.py
@@ -249,7 +249,7 @@ async def serve_inner(
density=1.0,
majority_sign_method=0,
)
adapter_index = index
adapter_index = index + 1
adapter_to_index[adapter_id] = adapter_index
model.load_adapter(
adapter_parameters,
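The change from `index` to `index + 1` is the piece of this hunk that supports base-model generation: it appears to leave adapter index 0 free for the base model (no adapter), so user-supplied adapters are numbered from 1. A sketch of the assumed convention (hypothetical adapter ids; only `adapter_to_index` comes from the hunk above):

```python
# Assumed indexing convention after this change:
# index 0 -> base model (no adapter); loaded adapters start at 1.
lora_adapter_ids = ["adapter-a", "adapter-b"]   # hypothetical ids

adapter_to_index = {}
for index, adapter_id in enumerate(lora_adapter_ids):
    adapter_index = index + 1        # 0 stays reserved for the base model
    adapter_to_index[adapter_id] = adapter_index

print(adapter_to_index)              # {'adapter-a': 1, 'adapter-b': 2}
```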
74 changes: 0 additions & 74 deletions server/text_generation_server/utils/lora.py

This file was deleted.
