diff --git a/server/text_generation_server/adapters/lora.py b/server/text_generation_server/adapters/lora.py
index 458a22e1dc2..fb07ca289d6 100644
--- a/server/text_generation_server/adapters/lora.py
+++ b/server/text_generation_server/adapters/lora.py
@@ -8,7 +8,6 @@
 
 from text_generation_server.adapters.config import AdapterConfig, ModuleMap
 
-LORA = "lora"
 from text_generation_server.adapters.weights import (
     AdapterBatchMetadata,
     AdapterWeights,
@@ -246,7 +245,7 @@ def can_vectorize(self, pg: ProcessGroup) -> bool:
 
     @classmethod
     def key(cls) -> str:
-        return LORA
+        return "lora"
 
     @classmethod
     def load(
@@ -279,9 +278,12 @@ def load(
         }
 
         max_rank = max(
-            adapter_weights[idx].lora_a_r
-            for idx in segment_indices
-            if idx in adapter_weights
+            (
+                adapter_weights[idx].lora_a_r
+                for idx in segment_indices
+                if idx in adapter_weights
+            ),
+            default=0,
         )
 
         if prefill or max_rank > BGMV_MAX_RANK:
diff --git a/server/text_generation_server/adapters/weights.py b/server/text_generation_server/adapters/weights.py
index 2ed08df568c..50c072caf5d 100644
--- a/server/text_generation_server/adapters/weights.py
+++ b/server/text_generation_server/adapters/weights.py
@@ -1,4 +1,3 @@
-#############
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
 from dataclasses import dataclass
@@ -7,10 +6,6 @@
 
 import torch
 
-LORA = "lora"
-LM_HEAD = "lm_head"
-
-
 @dataclass
 class AdapterBatchMetadata:
     # [batch_size]
@@ -127,7 +122,7 @@ def from_meta(
            if v.is_empty():
                continue
            data[k] = v.get_data(
-                meta, prefill, prefill_head_indices if k == LM_HEAD else None
+                meta, prefill, prefill_head_indices if k == "lm_head" else None
            )
 
        return AdapterBatchData(meta=meta, data=data, prefill=prefill)
@@ -135,7 +130,7 @@ def ranks(self) -> Set[int]:
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for layer_data in self.data.values():
-            lora_data = layer_data.get(LORA)
+            lora_data = layer_data.get("lora")
            if lora_data is None:
                continue
 
diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py
index 71a2c9210e5..c3ee0a17796 100644
--- a/server/text_generation_server/cli.py
+++ b/server/text_generation_server/cli.py
@@ -77,7 +77,6 @@ def serve(
    if otlp_endpoint is not None:
        setup_tracing(shard=os.getenv("RANK", 0), otlp_endpoint=otlp_endpoint)
 
-    # TODO: determine if this api makes sense
    lora_adapter_ids = os.getenv("LORA_ADAPTERS", None)
 
    # split on comma and strip whitespace
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 704c8ef3b8b..95a7f77a709 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -52,16 +52,6 @@
    except Exception as e:
        raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
-# Constants
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
-
 
 def load_attention(config, prefix, weights, layer_id):
    # Only defined in granite.
@@ -96,7 +86,7 @@ def load_attention(config, prefix, weights, layer_id):
    return TensorParallelMultiAdapterLinear.load(
        base_layer,
        layer_id,
-        [Q_PROJ, K_PROJ, V_PROJ],
+        ["q_proj", "k_proj", "v_proj"],
        sizes=[
            head_size * config.num_attention_heads,
            head_size * config.num_key_value_heads,
@@ -151,7 +141,7 @@ def __init__(
        self.o_proj = TensorParallelAdapterRowLinear.load(
            o_proj,
            index,
-            O_PROJ,
+            "o_proj",
            process_group=weights.process_group,
        )
 
@@ -259,7 +249,7 @@ def __init__(self, prefix, config, weights, index):
        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
            gate_up_proj,
            index,
-            [GATE_PROJ, UP_PROJ],
+            ["gate_proj", "up_proj"],
            sizes=[
                config.intermediate_size,
                config.intermediate_size,
@@ -277,7 +267,7 @@ def __init__(self, prefix, config, weights, index):
        self.down_proj = TensorParallelAdapterRowLinear.load(
            down_proj,
            index,
-            DOWN_PROJ,
+            "down_proj",
            process_group=weights.process_group,
        )
 
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 1266f6de6ec..327e4a6faa0 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -20,31 +20,17 @@
 tracer = trace.get_tracer(__name__)
 
 from text_generation_server.utils.import_utils import SYSTEM
-from text_generation_server.utils.lora import LoraConfig
 
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
-
-LM_HEAD = "lm_head"
-
-
-# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
 ADAPTER_LAYERS = [
-    Q_PROJ,
-    K_PROJ,
-    V_PROJ,
-    O_PROJ,
-    GATE_PROJ,
-    UP_PROJ,
-    DOWN_PROJ,
-]  # LM_HEAD
-ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
 
 
 class FlashLlama(FlashCausalLM):
@@ -123,32 +109,32 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
 
        prefix = "model.layers"
        for i, layer in enumerate(self.model.model.layers):
-            layer_weights[(i, Q_PROJ)] = (
+            layer_weights[(i, "q_proj")] = (
                f"{prefix}.{i}.self_attn.q_proj",
                layer.self_attn.query_key_value,
            )
-            layer_weights[(i, K_PROJ)] = (
+            layer_weights[(i, "k_proj")] = (
                f"{prefix}.{i}.self_attn.k_proj",
                layer.self_attn.query_key_value,
            )
-            layer_weights[(i, V_PROJ)] = (
+            layer_weights[(i, "v_proj")] = (
                f"{prefix}.{i}.self_attn.v_proj",
                layer.self_attn.query_key_value,
            )
-            layer_weights[(i, O_PROJ)] = (
+            layer_weights[(i, "o_proj")] = (
                f"{prefix}.{i}.self_attn.o_proj",
                layer.self_attn.o_proj,
            )
-            layer_weights[(i, GATE_PROJ)] = (
+            layer_weights[(i, "gate_proj")] = (
                f"{prefix}.{i}.mlp.gate_proj",
                layer.mlp.gate_up_proj,
            )
-            layer_weights[(i, UP_PROJ)] = (
+            layer_weights[(i, "up_proj")] = (
                f"{prefix}.{i}.mlp.up_proj",
                layer.mlp.gate_up_proj,
            )
-            layer_weights[(i, DOWN_PROJ)] = (
+            layer_weights[(i, "down_proj")] = (
                f"{prefix}.{i}.mlp.down_proj",
                layer.mlp.down_proj,
            )
 
@@ -162,7 +148,7 @@ def adapter_layers(self) -> List[str]:
 
    @property
    def default_traced_adapter_layers(self) -> List[str]:
-        return [Q_PROJ, V_PROJ]
+        return ["q_proj", "v_proj"]
 
    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 37cc0235a6f..90a95c41dc8 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -21,29 +21,16 @@
 
 tracer = trace.get_tracer(__name__)
 
-Q_PROJ = "q_proj"
-K_PROJ = "k_proj"
-V_PROJ = "v_proj"
-O_PROJ = "o_proj"
-
-GATE_PROJ = "gate_proj"
-UP_PROJ = "up_proj"
-DOWN_PROJ = "down_proj"
-
-LM_HEAD = "lm_head"
-
-
-# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
 ADAPTER_LAYERS = [
-    Q_PROJ,
-    K_PROJ,
-    V_PROJ,
-    O_PROJ,
-    GATE_PROJ,
-    UP_PROJ,
-    DOWN_PROJ,
-]  # LM_HEAD
-ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
+    "q_proj",
+    "k_proj",
+    "v_proj",
+    "o_proj",
+    "gate_proj",
+    "up_proj",
+    "down_proj",
+]
+ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"}
 
 
 class BaseFlashMistral(FlashCausalLM):
@@ -133,37 +120,37 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
 
        prefix = "model.layers"
        for i, layer in enumerate(self.model.model.layers):
-            layer_weights[(i, Q_PROJ)] = (
+            layer_weights[(i, "q_proj")] = (
                f"{prefix}.{i}.self_attn.q_proj",
                layer.self_attn.query_key_value,
            )
-            layer_weights[(i, K_PROJ)] = (
+            layer_weights[(i, "k_proj")] = (
                f"{prefix}.{i}.self_attn.k_proj",
                layer.self_attn.query_key_value,
            )
-            layer_weights[(i, V_PROJ)] = (
+            layer_weights[(i, "v_proj")] = (
                f"{prefix}.{i}.self_attn.v_proj",
                layer.self_attn.query_key_value,
            )
-            layer_weights[(i, O_PROJ)] = (
+            layer_weights[(i, "o_proj")] = (
                f"{prefix}.{i}.self_attn.o_proj",
                layer.self_attn.o_proj,
            )
-            layer_weights[(i, GATE_PROJ)] = (
+            layer_weights[(i, "gate_proj")] = (
                f"{prefix}.{i}.mlp.gate_proj",
                layer.mlp.gate_up_proj,
            )
-            layer_weights[(i, UP_PROJ)] = (
+            layer_weights[(i, "up_proj")] = (
                f"{prefix}.{i}.mlp.up_proj",
                layer.mlp.gate_up_proj,
            )
-            layer_weights[(i, DOWN_PROJ)] = (
+            layer_weights[(i, "down_proj")] = (
                f"{prefix}.{i}.mlp.down_proj",
                layer.mlp.down_proj,
            )
 
-        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
        return layer_weights
 
    @property
@@ -172,10 +159,10 @@ def adapter_layers(self) -> List[str]:
 
    @property
    def default_traced_adapter_layers(self) -> List[str]:
-        return [Q_PROJ, V_PROJ]
+        return ["q_proj", "v_proj"]
 
    def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
 
    def is_row_parallel(self, layer_type: str) -> bool:
        return layer_type in ROW_PARALLEL
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 334ad4f900b..a81da92a408 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -249,7 +249,7 @@ async def serve_inner(
                    density=1.0,
                    majority_sign_method=0,
                )
-                adapter_index = index
+                adapter_index = index + 1
                adapter_to_index[adapter_id] = adapter_index
                model.load_adapter(
                    adapter_parameters,
diff --git a/server/text_generation_server/utils/lora.py b/server/text_generation_server/utils/lora.py
deleted file mode 100644
index 8eed3a970af..00000000000
--- a/server/text_generation_server/utils/lora.py
+++ /dev/null
@@ -1,74 +0,0 @@
-import json
-from text_generation_server.utils import (
-    hub,
-)
-import os
-
-
-class LoraConfig:
-    def __init__(
-        self,
-        alpha_pattern=None,
-        auto_mapping=None,
-        base_model_name_or_path="",
-        bias="none",
-        fan_in_fan_out=False,
-        inference_mode=True,
-        init_lora_weights=True,
-        layer_replication=None,
-        layers_pattern=None,
-        layers_to_transform=None,
-        loftq_config=None,
-        lora_alpha=16,
-        lora_dropout=0.1,
-        megatron_config=None,
-        megatron_core="megatron.core",
-        modules_to_save=None,
-        peft_type="LORA",
-        r=8,
-        rank_pattern=None,
-        revision=None,
-        target_modules=None,
-        task_type="CAUSAL_LM",
-        use_dora=False,
-        use_rslora=False,
-        config_path=None,
-    ):
-        self.alpha_pattern = alpha_pattern or {}
-        self.auto_mapping = auto_mapping
-        self.base_model_name_or_path = base_model_name_or_path
-        self.bias = bias
-        self.fan_in_fan_out = fan_in_fan_out
-        self.inference_mode = inference_mode
-        self.init_lora_weights = init_lora_weights
-        self.layer_replication = layer_replication
-        self.layers_pattern = layers_pattern
-        self.layers_to_transform = layers_to_transform
-        self.loftq_config = loftq_config or {}
-        self.lora_alpha = lora_alpha
-        self.lora_dropout = lora_dropout
-        self.megatron_config = megatron_config
-        self.megatron_core = megatron_core
-        self.modules_to_save = modules_to_save
-        self.peft_type = peft_type
-        self.r = r
-        self.rank_pattern = rank_pattern or {}
-        self.revision = revision
-        self.target_modules = target_modules or ["q_proj", "v_proj"]
-        self.task_type = task_type
-        self.use_dora = use_dora
-        self.use_rslora = use_rslora
-        self.config_path = config_path
-
-    @classmethod
-    def from_file(cls, filename):
-        with open(filename, "r") as f:
-            json_data = json.load(f)
-        return cls(**json_data, config_path=filename)
-
-    # TODO: support fetching the model from the hub if it's not in the cache
-    @classmethod
-    def from_pretrained(cls, adapter_id, revision=None):
-        d = hub._get_cached_revision_directory(adapter_id, revision)
-        filename = os.path.join(d, "adapter_config.json")
-        return cls.from_file(filename)
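
Besides deleting utils/lora.py and inlining the string constants, two hunks change runtime behavior: the max(..., default=0) in the load classmethod of adapters/lora.py, and the adapter_index = index + 1 offset in server.py (presumably so index 0 stays reserved for the base model). A minimal, standalone sketch of what default=0 guards against, using placeholder data that is not part of the patch:

    # Python's max() raises ValueError on an empty iterable unless `default` is given.
    # The generator below is empty whenever no segment index has LoRA weights loaded.
    adapter_weights = {}  # placeholder: no adapters loaded for this batch
    segment_indices = [0, 1, 2]

    max_rank = max(
        (
            adapter_weights[idx].lora_a_r
            for idx in segment_indices
            if idx in adapter_weights
        ),
        default=0,
    )
    assert max_rank == 0  # the pre-patch form raised "max() arg is an empty sequence" here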