From b28946d6955c327aa556bac94e2e437478282e3d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 11:13:51 +0000 Subject: [PATCH 01/22] Refactor dead code. --- .../text_generation_server/models/__init__.py | 27 ++-- .../models/flash_causal_lm.py | 106 ++++++++++++--- .../models/flash_mistral.py | 122 +----------------- 3 files changed, 109 insertions(+), 146 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 5ea432909c8..52499b33136 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -56,8 +56,12 @@ from text_generation_server.models.flash_rw import FlashRWSharded from text_generation_server.models.flash_gpt2 import FlashGPT2 from text_generation_server.models.flash_neox import FlashNeoXSharded - from text_generation_server.models.flash_llama import ( - FlashLlama, + + # from text_generation_server.models.flash_llama import ( + # FlashLlama, + # ) + from text_generation_server.models.custom_modeling.flash_llama_modeling import ( + FlashLlamaForCausalLM, ) from text_generation_server.models.flash_qwen2 import ( FlashQwen2, @@ -81,7 +85,9 @@ from text_generation_server.models.llava_next import LlavaNext from text_generation_server.models.idefics2 import Idefics2 from text_generation_server.models.flash_mistral import FlashMistral - from text_generation_server.models.flash_mixtral import FlashMixtral + from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( + FlashMistralForCausalLM, + ) from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 from text_generation_server.models.flash_dbrx import FlashDbrx @@ -97,7 +103,7 @@ __all__.append(FlashNeoXSharded) __all__.append(FlashRWSharded) __all__.append(FlashSantacoderSharded) - __all__.append(FlashLlama) + # __all__.append(FlashLlama) __all__.append(IDEFICSSharded) __all__.append(FlashMistral) __all__.append(FlashMixtral) @@ -599,9 +605,10 @@ def get_model( elif model_type == LLAMA or model_type == BAICHUAN or model_type == PHI3: if FLASH_ATTENTION: - return FlashLlama( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashLlamaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -743,12 +750,14 @@ def get_model( if model_type == MISTRAL: if FLASH_ATTENTION: return FlashMistral( - model_id, - revision, + model_id=model_id, + model_class=FlashMistralForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 4f276ed4f7d..6b7cceef557 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -10,7 +10,12 @@ from loguru import logger from dataclasses import dataclass from opentelemetry import trace -from transformers import PreTrainedTokenizerBase +from transformers import ( + PreTrainedTokenizerBase, + AutoConfig, + AutoTokenizer, + GenerationConfig, +) from typing import Iterable, Optional, Tuple, List, Type, Dict from text_generation_server.adapters import AdapterBatchData, AdapterBatchMetadata @@ -21,6 +26,12 
@@ from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.dist import RANK from text_generation_server.utils.speculate import get_speculate +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, + hub, +) from text_generation_server.models.types import ( Batch, Tokens, @@ -803,25 +814,88 @@ class FlashCausalLM(Model): def __init__( self, model_id: str, - model: torch.nn.Module, - tokenizer: PreTrainedTokenizerBase, - num_layers: int, - num_kv_heads: int, - head_size: int, - dtype: torch.dtype, - device: torch.device, - rank: int = 0, - world_size: int = 1, - sliding_window: Optional[int] = None, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + lora_adapter_ids: Optional[list] = [], + tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + default_dtype=torch.float16, + # self, + # model_id: str, + # model_class, + # tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + # num_layers: int, + # num_kv_heads: int, + # head_size: int, + # dtype: torch.dtype, + # device: torch.device, + # rank: int = 0, + # world_size: int = 1, + # sliding_window: Optional[int] = None, ): - self.num_layers = num_layers - self.num_kv_heads = num_kv_heads - self.head_size = head_size + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. 
+ dtype = torch.bfloat16 if dtype is None else dtype + else: + raise NotImplementedError(f"{model_class} is only available on GPU") + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + try: + generation_config = GenerationConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + if isinstance(generation_config.eos_token_id, (list, set)): + # TODO Huge hack + tokenizer._eos_token_ids = set(generation_config.eos_token_id) + except Exception: + pass + + config = AutoConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + config.quantize = quantize + config.speculator = speculator + if getattr(config, "sliding_window", None) is not None: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + + torch.distributed.barrier(group=self.process_group) + + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + prefix = "" + model = model_class(prefix, config, weights) + torch.distributed.barrier(group=self.process_group) + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {} self.kv_cache = [] - super(FlashCausalLM, self).__init__( + super().__init__( model_id=model_id, model=model, tokenizer=tokenizer, @@ -830,7 +904,7 @@ def __init__( device=device, rank=rank, world_size=world_size, - sliding_window=sliding_window, + sliding_window=config.sliding_window, ) @property diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 0f5746debeb..c2482dc27aa 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -1,24 +1,7 @@ import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig from typing import Optional, Tuple, Dict, List from text_generation_server.models import FlashCausalLM -from text_generation_server.models.flash_causal_lm import set_sliding_window -from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( - FlashMistralForCausalLM, - MistralConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) ADAPTER_LAYERS = [ @@ -33,88 +16,7 @@ ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} -class BaseFlashMistral(FlashCausalLM): - def __init__( - self, - model_cls, - model_id: str, - config_cls=AutoConfig, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - tokenizer_class=AutoTokenizer, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = 
torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashMistral is only available on GPU") - - tokenizer = tokenizer_class.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = config_cls.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: - config.sliding_window = None - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = model_cls(prefix, config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - num_layers, num_kv_heads, head_size = self.get_layer_config(model) - super().__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=num_layers, - num_kv_heads=num_kv_heads, - head_size=head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.model.layers), - model.model.num_key_value_heads, - model.model.head_size, - ) - +class FlashMistral(FlashCausalLM): @property def supports_adapter_loading(self) -> bool: return True @@ -183,25 +85,3 @@ def get_num_layers_for_type(self, layer_type: str) -> int: def is_row_parallel(self, layer_type: str) -> bool: return layer_type in ROW_PARALLEL - - -class FlashMistral(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashMistral, self).__init__( - config_cls=MistralConfig, - model_cls=FlashMistralForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) From 69cb084b5fb2d3ca85790c308750e4ab6fbb398b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 11:25:18 +0000 Subject: [PATCH 02/22] First working step. 
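This step deletes the thin FlashMixtral wrapper and instead routes Mixtral through the generalized FlashMistral entry point by passing the architecture module as a constructor argument, the same way Llama and Mistral were wired up in the previous patch. A minimal, runnable sketch of that dispatch pattern follows; apart from the two class names taken from the diff, the bodies and signatures are simplified stand-ins, not the real TGI implementations.

    # Sketch only: the real constructor also loads the tokenizer, config and
    # sharded safetensors weights before building the architecture module.
    class FlashMistral:
        def __init__(self, model_id, model_class, **kwargs):
            self.model = model_class(prefix="", config=None, weights=None)

    class FlashMixtralForCausalLM:
        def __init__(self, prefix, config, weights):
            self.prefix = prefix

    # Before: FlashMixtral(model_id, revision, ...) -- a subclass whose only
    # job was to forward FlashMixtralForCausalLM to the shared base class.
    # After: the dispatcher parameterizes the shared class directly.
    model = FlashMistral(
        model_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
        model_class=FlashMixtralForCausalLM,
    )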
--- .../text_generation_server/models/__init__.py | 12 ++++--- .../models/flash_mixtral.py | 31 ------------------- .../models/flash_qwen2.py | 11 ++----- .../models/flash_starcoder2.py | 5 ++- .../models/vlm_causal_lm.py | 4 +-- 5 files changed, 15 insertions(+), 48 deletions(-) delete mode 100644 server/text_generation_server/models/flash_mixtral.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 52499b33136..d159df88d85 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -88,6 +88,9 @@ from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, ) + from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( + FlashMixtralForCausalLM, + ) from text_generation_server.models.flash_phi import FlashPhi from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 from text_generation_server.models.flash_dbrx import FlashDbrx @@ -106,7 +109,6 @@ # __all__.append(FlashLlama) __all__.append(IDEFICSSharded) __all__.append(FlashMistral) - __all__.append(FlashMixtral) __all__.append(FlashDbrx) __all__.append(FlashPhi) __all__.append(FlashQwen2) @@ -773,13 +775,15 @@ def get_model( if model_type == MIXTRAL: if FLASH_ATTENTION: - return FlashMixtral( - model_id, - revision, + return FlashMistral( + model_id=model_id, + model_class=FlashMixtralForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py deleted file mode 100644 index 587d423f94b..00000000000 --- a/server/text_generation_server/models/flash_mixtral.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch - -from typing import Optional - -from text_generation_server.models.flash_mistral import BaseFlashMistral -from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( - MixtralConfig, - FlashMixtralForCausalLM, -) - - -class FlashMixtral(BaseFlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - super(FlashMixtral, self).__init__( - config_cls=MixtralConfig, - model_cls=FlashMixtralForCausalLM, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index cd6078f1a40..4176aa05729 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -8,8 +8,7 @@ from typing import Optional from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - set_sliding_window, + FlashMistral, ) from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( Qwen2ForCausalLM, @@ -24,7 +23,7 @@ tracer = trace.get_tracer(__name__) -class FlashQwen2(BaseFlashMistral): +class FlashQwen2(FlashMistral): def __init__( self, model_id: str, @@ -62,10 +61,6 @@ def __init__( config.quantize = quantize config.speculator = speculator 
- # Set context windows - if config.sliding_window is not None: - set_sliding_window(config.sliding_window) - torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -78,7 +73,7 @@ def __init__( self.cuda_graphs = {} torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( + super(FlashMistral, self).__init__( model_id=model_id, model=model, tokenizer=tokenizer, diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 369e9e4c1f1..16c9a8b9c29 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -7,8 +7,7 @@ from transformers.models.gpt2 import GPT2TokenizerFast from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, - set_sliding_window, + FlashMistral, ) from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( Starcoder2Config, @@ -22,7 +21,7 @@ # Starcoder2 has the same base as Mistral -class FlashStarcoder2(BaseFlashMistral): +class FlashStarcoder2(FlashMistral): def __init__( self, model_id: str, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 1cdf37ea6b0..708f8ac63ab 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -11,7 +11,7 @@ from text_generation_server.pb import generate_pb2 from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch from text_generation_server.models.flash_mistral import ( - BaseFlashMistral, + FlashMistral, ) tracer = trace.get_tracer(__name__) @@ -239,7 +239,7 @@ def from_pb_processor( return batch -class VlmCausalLM(BaseFlashMistral): +class VlmCausalLM(FlashMistral): @property def batch_type(self) -> Type[VlmCausalLMBatch]: return VlmCausalLMBatch From ed34cf0222a8879aa45839ebb89bc70a5cfe5134 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 14:17:57 +0200 Subject: [PATCH 03/22] Remove a lot of duplicated code. 
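The per-architecture wrapper modules deleted below (flash_gpt2.py, flash_neox.py, flash_phi.py, flash_rw.py, flash_santacoder.py, ...) all repeated the same load sequence: initialize the process group, pick a device and dtype, load the tokenizer, config and weights, then hand the model's layer count, KV-head count and head size to the base class. The shared FlashCausalLM constructor now derives that geometry straight from the Hugging Face config, so get_model only needs to pick the architecture class plus the occasional extra such as config_class or weight aliases. A small, self-contained sketch of the derivation (the Config dataclass is a stand-in for transformers.AutoConfig, and the example numbers are Llama-7B-like, not taken from the diff):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Config:
        num_hidden_layers: int
        num_key_value_heads: int
        num_attention_heads: int
        hidden_size: int
        sliding_window: Optional[int] = None

    def geometry(config: Config):
        # Mirrors what the generalized __init__ computes instead of taking
        # num_layers / num_kv_heads / head_size as explicit arguments.
        num_layers = config.num_hidden_layers
        num_kv_heads = config.num_key_value_heads
        head_size = config.hidden_size // config.num_attention_heads
        return num_layers, num_kv_heads, head_size

    print(geometry(Config(32, 32, 32, 4096)))  # -> (32, 32, 128)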
--- .../text_generation_server/models/__init__.py | 202 +++++++++++------- .../flash_santacoder_modeling.py | 1 + .../models/flash_causal_lm.py | 20 +- .../models/flash_dbrx.py | 100 --------- .../models/flash_gemma.py | 83 ------- .../models/flash_gemma2.py | 83 ------- .../models/flash_gpt2.py | 82 ------- .../models/flash_neox.py | 82 ------- .../models/flash_phi.py | 111 ---------- .../models/flash_qwen2.py | 88 -------- .../text_generation_server/models/flash_rw.py | 91 -------- .../models/flash_santacoder.py | 99 --------- .../text_generation_server/models/idefics2.py | 51 ----- .../models/llava_next.py | 46 ---- .../models/pali_gemma.py | 26 --- .../models/vlm_causal_lm.py | 28 ++- 16 files changed, 156 insertions(+), 1037 deletions(-) delete mode 100644 server/text_generation_server/models/flash_dbrx.py delete mode 100644 server/text_generation_server/models/flash_gemma.py delete mode 100644 server/text_generation_server/models/flash_gemma2.py delete mode 100644 server/text_generation_server/models/flash_gpt2.py delete mode 100644 server/text_generation_server/models/flash_neox.py delete mode 100644 server/text_generation_server/models/flash_phi.py delete mode 100644 server/text_generation_server/models/flash_qwen2.py delete mode 100644 server/text_generation_server/models/flash_rw.py delete mode 100644 server/text_generation_server/models/flash_santacoder.py delete mode 100644 server/text_generation_server/models/idefics2.py delete mode 100644 server/text_generation_server/models/llava_next.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index d159df88d85..ddc1d461655 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -53,47 +53,62 @@ try: from text_generation_server.models.flash_causal_lm import FlashCausalLM - from text_generation_server.models.flash_rw import FlashRWSharded - from text_generation_server.models.flash_gpt2 import FlashGPT2 - from text_generation_server.models.flash_neox import FlashNeoXSharded - - # from text_generation_server.models.flash_llama import ( - # FlashLlama, - # ) + from text_generation_server.models.vlm_causal_lm import VlmCausalLM from text_generation_server.models.custom_modeling.flash_llama_modeling import ( FlashLlamaForCausalLM, ) - from text_generation_server.models.flash_qwen2 import ( - FlashQwen2, + from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( + FlashCohereForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( + FlashGemmaForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( + FlashGemma2ForCausalLM, ) - from text_generation_server.models.flash_cohere import ( - FlashCohere, + from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( + FlashDbrxForCausalLM, + DbrxConfig, ) - from text_generation_server.models.flash_gemma import ( - FlashGemma, + from text_generation_server.models.custom_modeling.flash_rw_modeling import ( + RWConfig, + FlashRWForCausalLM, ) - from text_generation_server.models.flash_gemma2 import ( - FlashGemma2, + from text_generation_server.models.custom_modeling.flash_neox_modeling import ( + FlashGPTNeoXForCausalLM, ) from text_generation_server.models.pali_gemma import ( - PaliGemma, + PaliGemmaBatch, ) - from text_generation_server.models.flash_santacoder import ( - FlashSantacoderSharded, + from 
text_generation_server.models.custom_modeling.flash_phi_modeling import ( + FlashPhiForCausalLM, ) from text_generation_server.models.idefics import IDEFICSSharded - from text_generation_server.models.llava_next import LlavaNext - from text_generation_server.models.idefics2 import Idefics2 + from text_generation_server.models.custom_modeling.llava_next import ( + LlavaNextForConditionalGeneration, + ) from text_generation_server.models.flash_mistral import FlashMistral + from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( + FlashSantacoderForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( + FlashStarcoder2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( + Qwen2ForCausalLM, + ) from text_generation_server.models.custom_modeling.flash_mistral_modeling import ( FlashMistralForCausalLM, ) from text_generation_server.models.custom_modeling.flash_mixtral_modeling import ( FlashMixtralForCausalLM, ) - from text_generation_server.models.flash_phi import FlashPhi - from text_generation_server.models.flash_starcoder2 import FlashStarcoder2 - from text_generation_server.models.flash_dbrx import FlashDbrx + from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( + FlashGPT2ForCausalLM, + ) + from text_generation_server.models.custom_modeling.idefics2 import ( + Idefics2ForConditionalGeneration, + ) from text_generation_server.layers.attention import SUPPORTS_WINDOWING except ImportError as e: logger.warning(f"Could not import Flash Attention enabled models: {e}") @@ -102,20 +117,8 @@ if FLASH_ATTENTION: __all__.append(FlashCausalLM) - __all__.append(FlashGPT2) - __all__.append(FlashNeoXSharded) - __all__.append(FlashRWSharded) - __all__.append(FlashSantacoderSharded) - # __all__.append(FlashLlama) __all__.append(IDEFICSSharded) __all__.append(FlashMistral) - __all__.append(FlashDbrx) - __all__.append(FlashPhi) - __all__.append(FlashQwen2) - __all__.append(FlashStarcoder2) - __all__.append(FlashGemma) - __all__.append(FlashGemma2) - __all__.append(FlashCohere) MAMBA_AVAILABLE = True try: @@ -468,13 +471,16 @@ def get_model( and model_id.startswith("bigcode/") ): if FLASH_ATTENTION: - return FlashSantacoderSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashSantacoderForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + aliases={"transformer.wte.weight": ["lm_head.weight"]}, ) elif sharded: raise NotImplementedError( @@ -511,13 +517,15 @@ def get_model( elif model_type == GPT2: if FLASH_ATTENTION: try: - return FlashGPT2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPT2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) except RuntimeError as e: # Lots of legacy models with various weight names. 
@@ -543,13 +551,15 @@ def get_model( ) elif model_type == GPT_NEOX: if FLASH_ATTENTION: - return FlashNeoXSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGPTNeoXForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: return GPTNeoxSharded( @@ -572,13 +582,15 @@ def get_model( elif model_type == PHI: if FLASH_ATTENTION: - return FlashPhi( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashPhiForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) else: return CausalLM( @@ -630,13 +642,15 @@ def get_model( ) if model_type == GEMMA: if FLASH_ATTENTION: - return FlashGemma( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemmaForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) @@ -651,13 +665,15 @@ def get_model( ) elif model_type == GEMMA2: if FLASH_ATTENTION: - return FlashGemma2( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashGemma2ForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) @@ -673,13 +689,15 @@ def get_model( if model_type == COHERE: if FLASH_ATTENTION: - return FlashCohere( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashCohereForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) @@ -695,13 +713,16 @@ def get_model( if model_type == DBRX: if FLASH_ATTENTION: - return FlashDbrx( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashDbrxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=DbrxConfig, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) @@ -720,24 +741,30 @@ def get_model( if FLASH_ATTENTION: if config_dict.get("alibi", False): raise NotImplementedError("sharded is not supported for this model") - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Falcon")) else: if FLASH_ATTENTION and not config_dict.get("alibi", False): - return FlashRWSharded( - model_id, - revision, + return FlashCausalLM( + model_id=model_id, + model_class=FlashRWForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + config_class=RWConfig, ) else: return 
RW( @@ -799,12 +826,15 @@ def get_model( if model_type == STARCODER2: if FLASH_ATTENTION: - return FlashStarcoder2( - model_id, - revision, + return FlashMistral( + model_id=model_id, + model_class=FlashStarcoder2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError( @@ -822,12 +852,15 @@ def get_model( if model_type == QWEN2: if FLASH_ATTENTION: - return FlashQwen2( - model_id, - revision, + return FlashMistral( + model_id=model_id, + model_class=Qwen2ForCausalLM, + revision=revision, quantize=quantize, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) @@ -874,34 +907,43 @@ def get_model( raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == IDEFICS2: if FLASH_ATTENTION: - return Idefics2( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=Idefics2ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + # XXX: Extremely important to cap resolution in order to limit + # VRAM usage. + processor_kwargs={"size": {"longest_edge": 448, "shortest_edge": 378}}, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == "paligemma": if FLASH_ATTENTION: - return PaliGemma( - model_id, - revision, + return VlmCausalLM( + model_id=model_id, + model_class=PaliGemmaForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + lora_adapter_ids=lora_adapter_ids, + batch_class=PaliGemmaBatch, ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) if model_type == LLAVA_NEXT: if FLASH_ATTENTION: - return LlavaNext( - model_id, - revision, + return VlmCausalLM( + model_class=LlavaNextForConditionalGeneration, + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index 30989a375db..a77a76552cf 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -466,6 +466,7 @@ def forward( class FlashSantacoderForCausalLM(nn.Module): def __init__(self, config, weights): super().__init__() + config.transpose = config.architectures[0].startswith("GPT2") self.transformer = FlashSantacoderModel(config, weights) self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 6b7cceef557..f2e66d56e21 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -822,19 +822,9 @@ def __init__( trust_remote_code: bool = False, lora_adapter_ids: Optional[list] = [], tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, + config_class: PreTrainedTokenizerBase = AutoConfig, default_dtype=torch.float16, - # self, - # model_id: str, - # 
model_class, - # tokenizer_class: PreTrainedTokenizerBase = AutoTokenizer, - # num_layers: int, - # num_kv_heads: int, - # head_size: int, - # dtype: torch.dtype, - # device: torch.device, - # rank: int = 0, - # world_size: int = 1, - # sliding_window: Optional[int] = None, + aliases=None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -868,7 +858,7 @@ def __init__( except Exception: pass - config = AutoConfig.from_pretrained( + config = config_class.from_pretrained( model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize @@ -881,7 +871,9 @@ def __init__( torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) + weights = Weights( + filenames, device, dtype, process_group=self.process_group, aliases=aliases + ) if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py deleted file mode 100644 index 2aba6a002bf..00000000000 --- a/server/text_generation_server/models/flash_dbrx.py +++ /dev/null @@ -1,100 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_dbrx_modeling import ( - FlashDbrxForCausalLM, - DbrxConfig, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashDbrx(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashDBRX is only available on GPU") - - try: - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - try: - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - except: - # FIXME: change back to model id once the tokenizer.json is merged - tokenizer = GPT2TokenizerFast.from_pretrained( - "Xenova/dbrx-instruct-tokenizer", - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = DbrxConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if 
config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashDbrxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashDbrx, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py deleted file mode 100644 index 7e2b8780bdd..00000000000 --- a/server/text_generation_server/models/flash_gemma.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma_modeling import ( - FlashGemmaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemmaForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gemma2.py b/server/text_generation_server/models/flash_gemma2.py deleted file mode 100644 index 86cfc7e2bca..00000000000 --- a/server/text_generation_server/models/flash_gemma2.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing 
import Optional -from transformers import PretrainedConfig, AutoTokenizer - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gemma2_modeling import ( - FlashGemma2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGemma2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.bfloat16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGemma2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = PretrainedConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - # TODO hardcoded - prefix = "" - model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True) - - torch.distributed.barrier(group=self.process_group) - super(FlashGemma2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py deleted file mode 100644 index 323fcafa8ae..00000000000 --- a/server/text_generation_server/models/flash_gpt2.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from transformers.models.gpt2 import GPT2Tokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_gpt2_modeling import ( - FlashGPT2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashGPT2(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = 
initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashGPT2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashGPT2ForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashGPT2, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py deleted file mode 100644 index ac1fd5732f3..00000000000 --- a/server/text_generation_server/models/flash_neox.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_neox_modeling import ( - FlashGPTNeoXForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashNeoXSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashNeoX is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - 
torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashGPTNeoXForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashNeoXSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.gpt_neox.layers), - num_kv_heads=model.gpt_neox.num_heads, - head_size=model.gpt_neox.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py deleted file mode 100644 index a530d1c3fe5..00000000000 --- a/server/text_generation_server/models/flash_phi.py +++ /dev/null @@ -1,111 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_phi_modeling import ( - FlashPhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashPhi(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashPhi is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashPhiForCausalLM(config, weights) - if speculator: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - Path(speculator).exists() and Path(speculator).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - medusa_config = hf_hub_download( - speculator, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - speculator, revision=revision, 
filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(speculator) / "config.json") - medusa_head = str(Path(speculator) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) - - torch.distributed.barrier(group=self.process_group) - super(FlashPhi, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py deleted file mode 100644 index 4176aa05729..00000000000 --- a/server/text_generation_server/models/flash_qwen2.py +++ /dev/null @@ -1,88 +0,0 @@ -import math - -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional - -from text_generation_server.models.flash_mistral import ( - FlashMistral, -) -from text_generation_server.models.custom_modeling.flash_qwen2_modeling import ( - Qwen2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashQwen2(FlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashQwen2 is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = Qwen2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(FlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) diff --git 
a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py deleted file mode 100644 index b1f75adc83f..00000000000 --- a/server/text_generation_server/models/flash_rw.py +++ /dev/null @@ -1,91 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer -from typing import Optional - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_rw_modeling import ( - RWConfig, - FlashRWForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashRWSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashRW is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = RWConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device, - dtype, - process_group=self.process_group, - aliases={ - "lm_head.weight": ["transformer.word_embeddings.weight"], - "transformer.word_embeddings.weight": ["lm_head.weight"], - }, - ) - - config.quantize = quantize - config.speculator = speculator - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashRWForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashRWSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=model.transformer.cache_size, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py deleted file mode 100644 index e1a7b36e852..00000000000 --- a/server/text_generation_server/models/flash_santacoder.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoTokenizer, AutoConfig -from typing import Optional, List -import json -import os - -from huggingface_hub import hf_hub_download -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( - FlashSantacoderForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - 
weight_files, - Weights, -) - -from text_generation_server.utils.import_utils import SYSTEM - -tracer = trace.get_tracer(__name__) - - -class FlashSantacoderSharded(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashSantacoderSharded is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=True, - ) - config.quantize = quantize - config.speculator = speculator - config.transpose = config.architectures[0].startswith("GPT2") - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={"transformer.wte.weight": ["lm_head.weight"]}, - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashSantacoderForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashSantacoderSharded, self).__init__( - model_id=model_id, - model=model.to(device), - tokenizer=tokenizer, - num_layers=len(model.transformer.h), - num_kv_heads=1, - head_size=model.transformer.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py deleted file mode 100644 index 314c0500ddc..00000000000 --- a/server/text_generation_server/models/idefics2.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.idefics2 import ( - Idefics2ForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class Idefics2(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - # XXX: Extremely important to cap resolution in order to limit - # VRAM usage. 
- size={"longest_edge": 448, "shortest_edge": 378}, - ) - super().__init__( - model_cls=Idefics2ForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py deleted file mode 100644 index effe8b9107c..00000000000 --- a/server/text_generation_server/models/llava_next.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch - -from typing import Optional, Tuple - -from transformers import ( - AutoProcessor, -) -from text_generation_server.models.custom_modeling.llava_next import ( - LlavaNextForConditionalGeneration, -) - -from text_generation_server.models.vlm_causal_lm import VlmCausalLM - - -class LlavaNext(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - super().__init__( - model_cls=LlavaNextForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.language_model.model.layers), - model.language_model.model.num_key_value_heads, - model.language_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.language_model, "max_past", None) diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index a167e4679a5..533a47ea8e3 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -77,32 +77,6 @@ def batch_tokenized_inputs( class PaliGemma(VlmCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.processor = AutoProcessor.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - - super().__init__( - config_cls=AutoConfig, - model_cls=PaliGemmaForConditionalGeneration, - model_id=model_id, - revision=revision, - quantize=quantize, - speculator=speculator, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - @property def batch_type(self): return PaliGemmaBatch diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 708f8ac63ab..bee848e1a23 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -13,6 +13,7 @@ from text_generation_server.models.flash_mistral import ( FlashMistral, ) +from transformers import AutoProcessor tracer = trace.get_tracer(__name__) @@ -240,9 +241,34 @@ def from_pb_processor( class VlmCausalLM(FlashMistral): + def __init__( + 
self, + model_id: str, + *, + processor_class=AutoProcessor, + processor_kwargs=None, + batch_class=VlmCausalLMBatch, + revision, + trust_remote_code: bool, + **kwargs, + ): + if processor_kwargs is None: + processor_kwargs = {} + self.processor = processor_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + **processor_kwargs, + ) + self.batch_class = batch_class + super().__init__(**kwargs) + @property def batch_type(self) -> Type[VlmCausalLMBatch]: - return VlmCausalLMBatch + return self.batch_class + + def max_past(self) -> Optional[int]: + return getattr(self.model.text_model, "max_past", None) def forward( self, From 7d96b1a10374ea1c1eb2114043753a9a251772ba Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 14:45:35 +0200 Subject: [PATCH 04/22] More dead code. --- .../models/flash_cohere.py | 75 -------- .../models/flash_llama.py | 171 ------------------ .../models/flash_starcoder2.py | 83 --------- 3 files changed, 329 deletions(-) delete mode 100644 server/text_generation_server/models/flash_cohere.py delete mode 100644 server/text_generation_server/models/flash_llama.py delete mode 100644 server/text_generation_server/models/flash_starcoder2.py diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py deleted file mode 100644 index 9f8bcb3fbc8..00000000000 --- a/server/text_generation_server/models/flash_cohere.py +++ /dev/null @@ -1,75 +0,0 @@ -import torch -import torch.distributed - -from opentelemetry import trace -from typing import Optional -from transformers import AutoTokenizer, AutoConfig - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_cohere_modeling import ( - FlashCohereForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class FlashCohere(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - raise NotImplementedError("FlashCohere is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - use_fast=True, - from_slow=False, - ) - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashCohereForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(FlashCohere, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - 
device=device, - rank=rank, - world_size=world_size, - ) diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py deleted file mode 100644 index d996b9c3a6e..00000000000 --- a/server/text_generation_server/models/flash_llama.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import torch -import torch.distributed - -from opentelemetry import trace -from transformers import AutoConfig, AutoTokenizer, GenerationConfig -from typing import Optional, Tuple, Dict, List - -from text_generation_server.models import FlashCausalLM -from text_generation_server.models.custom_modeling.flash_llama_modeling import ( - FlashLlamaForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, - hub, -) - -tracer = trace.get_tracer(__name__) - -from text_generation_server.utils.import_utils import SYSTEM - -ADAPTER_LAYERS = [ - "q_proj", - "k_proj", - "v_proj", - "o_proj", - "gate_proj", - "up_proj", - "down_proj", -] -ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} - - -class FlashLlama(FlashCausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - lora_adapter_ids: Optional[list] = [], - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - elif SYSTEM == "ipex": - if hasattr(torch, "xpu") and torch.xpu.is_available(): - device = torch.device(f"xpu:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.bfloat16 if dtype is None else dtype - else: - raise NotImplementedError("FlashLlama is only available on GPU") - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - try: - generation_config = GenerationConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - if isinstance(generation_config.eos_token_id, (list, set)): - # TODO Huge hack - tokenizer._eos_token_ids = set(generation_config.eos_token_id) - except Exception: - pass - - config = AutoConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["awq", "exl2", "gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - prefix = "" - model = FlashLlamaForCausalLM(prefix, config, weights) - torch.distributed.barrier(group=self.process_group) - super(FlashLlama, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def supports_adapter_loading(self) -> bool: - return True - - def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: - layer_weights = {} - - prefix = "model.layers" - - # This accounts for VLMs (e.g. 
LlavaNext, Idefics2) - # that have a language_model inside of the larger model. - if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): - _model = self.model.text_model - else: - _model = self.model - - for i, layer in enumerate(_model.model.layers): - layer_weights[(i, "q_proj")] = ( - f"{prefix}.{i}.self_attn.q_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "k_proj")] = ( - f"{prefix}.{i}.self_attn.k_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "v_proj")] = ( - f"{prefix}.{i}.self_attn.v_proj", - layer.self_attn.query_key_value, - ) - layer_weights[(i, "o_proj")] = ( - f"{prefix}.{i}.self_attn.o_proj", - layer.self_attn.o_proj, - ) - - layer_weights[(i, "gate_proj")] = ( - f"{prefix}.{i}.mlp.gate_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "up_proj")] = ( - f"{prefix}.{i}.mlp.up_proj", - layer.mlp.gate_up_proj, - ) - layer_weights[(i, "down_proj")] = ( - f"{prefix}.{i}.mlp.down_proj", - layer.mlp.down_proj, - ) - - layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) - return layer_weights - - @property - def adapter_layers(self) -> List[str]: - return ADAPTER_LAYERS - - @property - def default_traced_adapter_layers(self) -> List[str]: - return ["q_proj", "v_proj"] - - def get_num_layers_for_type(self, layer_type: str) -> int: - return 1 if layer_type == "lm_head" else len(self.model.model.layers) - - def is_row_parallel(self, layer_type: str) -> bool: - return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py deleted file mode 100644 index 16c9a8b9c29..00000000000 --- a/server/text_generation_server/models/flash_starcoder2.py +++ /dev/null @@ -1,83 +0,0 @@ -import math - -import torch - -from typing import Optional - -from transformers.models.gpt2 import GPT2TokenizerFast - -from text_generation_server.models.flash_mistral import ( - FlashMistral, -) -from text_generation_server.models.custom_modeling.flash_starcoder2_modeling import ( - Starcoder2Config, - FlashStarcoder2ForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -# Starcoder2 has the same base as Mistral -class FlashStarcoder2(FlashMistral): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - raise NotImplementedError("FlashStarcoder2 is only available on GPU") - - tokenizer = GPT2TokenizerFast.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = Starcoder2Config.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - config.quantize = quantize - config.speculator = speculator - - # Set context windows - if config.sliding_window is not None: - set_sliding_window(config.sliding_window) - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize 
in ["gptq", "awq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = FlashStarcoder2ForCausalLM(config, weights) - - self.cuda_graphs = {} - - torch.distributed.barrier(group=self.process_group) - super(BaseFlashMistral, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - num_layers=len(model.model.layers), - num_kv_heads=model.model.num_key_value_heads, - head_size=model.model.head_size, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - sliding_window=config.sliding_window, - ) From ce913b874b14c9b0e84f2031562c55a9fcb61df9 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 16:11:53 +0200 Subject: [PATCH 05/22] More cleanup. --- .../text_generation_server/models/__init__.py | 121 ++++++++++-------- server/text_generation_server/models/bloom.py | 73 ----------- .../models/causal_lm.py | 76 ++++++++++- .../text_generation_server/models/gpt_neox.py | 89 ------------- server/text_generation_server/models/mpt.py | 105 --------------- server/text_generation_server/models/phi.py | 69 ---------- server/text_generation_server/models/rw.py | 84 ------------ .../models/santacoder.py | 77 ----------- .../models/seq2seq_lm.py | 80 +++++++++++- .../models/{opt.py => sharded_seq2seq_lm.py} | 53 ++++---- server/text_generation_server/models/t5.py | 115 ----------------- 11 files changed, 248 insertions(+), 694 deletions(-) delete mode 100644 server/text_generation_server/models/gpt_neox.py delete mode 100644 server/text_generation_server/models/mpt.py delete mode 100644 server/text_generation_server/models/phi.py delete mode 100644 server/text_generation_server/models/rw.py delete mode 100644 server/text_generation_server/models/santacoder.py rename server/text_generation_server/models/{opt.py => sharded_seq2seq_lm.py} (62%) delete mode 100644 server/text_generation_server/models/t5.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index ddc1d461655..fb7c8cbeea0 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -12,16 +12,26 @@ from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model from text_generation_server.models.causal_lm import CausalLM +from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM +from text_generation_server.models.custom_modeling.mpt_modeling import ( + MPTForCausalLM, +) from text_generation_server.models.bloom import BLOOMSharded -from text_generation_server.models.mpt import MPTSharded +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) from text_generation_server.models.seq2seq_lm import Seq2SeqLM -from text_generation_server.models.rw import RW -from text_generation_server.models.opt import OPTSharded from text_generation_server.models.galactica import GalacticaSharded -from text_generation_server.models.santacoder import SantaCoder -from text_generation_server.models.t5 import T5Sharded -from text_generation_server.models.gpt_neox import GPTNeoxSharded -from text_generation_server.models.phi import Phi +from text_generation_server.models.custom_modeling.neox_modeling import ( + GPTNeoxForCausalLM, +) +from text_generation_server.models.custom_modeling.phi_modeling import ( + PhiConfig, + PhiForCausalLM, +) +from text_generation_server.models.custom_modeling.t5_modeling import ( + T5ForConditionalGeneration, +) 
from text_generation_server.utils.import_utils import SYSTEM @@ -41,9 +51,6 @@ "CausalLM", "GalacticaSharded", "Seq2SeqLM", - "SantaCoder", - "OPTSharded", - "T5Sharded", "get_model", ] @@ -457,8 +464,9 @@ def get_model( if model_id.startswith("facebook/galactica"): return GalacticaSharded( - model_id, - revision, + model_id=model_id, + model_class=OPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -487,9 +495,9 @@ def get_model( FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") ) else: - return SantaCoder( - model_id, - revision, + return CausalLM.fallback( + model_id=model_id, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -498,17 +506,19 @@ def get_model( if model_type == BLOOM: return BLOOMSharded( - model_id, - revision, + model_id=model_id, + model_class=BloomForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) elif model_type == MPT: - return MPTSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=MPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -530,7 +540,7 @@ def get_model( except RuntimeError as e: # Lots of legacy models with various weight names. logger.warning(f"Couldn't load flash gpt2 variant: {e}") - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -541,7 +551,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded GPT-2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -562,16 +572,17 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) elif sharded: - return GPTNeoxSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=GPTNeoxForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -593,7 +604,7 @@ def get_model( lora_adapter_ids=lora_adapter_ids, ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -608,9 +619,11 @@ def get_model( "Legacy phi-msft is not supported with Flash Attention" ) else: - return Phi( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=PhiForCausalLM, + config_class=PhiConfig, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -632,7 +645,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -655,7 +668,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -678,7 +691,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -702,7 +715,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -727,7 +740,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded DBRX")) 
else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -767,7 +780,7 @@ def get_model( config_class=RWConfig, ) else: - return RW( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -791,7 +804,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mistral")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -815,7 +828,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Mixtral")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -841,7 +854,7 @@ def get_model( FLASH_ATT_ERROR_MESSAGE.format("Sharded Starcoder2") ) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -865,7 +878,7 @@ def get_model( elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen2")) else: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -875,9 +888,10 @@ def get_model( ) if model_type == OPT: - return OPTSharded( - model_id, - revision, + return CausalLM( + model_id=model_id, + model_class=OPTForCausalLM, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, @@ -885,13 +899,20 @@ def get_model( ) if model_type == T5: - return T5Sharded( - model_id, - revision, + return Seq2SeqLM( + model_id=model_id, + model_class=T5ForConditionalGeneration, + revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + aliases={ + "shared.weight": [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + }, ) if model_type == IDEFICS: if FLASH_ATTENTION: @@ -967,7 +988,7 @@ def get_model( elif quantize == "exl2": raise NotImplementedError("exl2 quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -976,7 +997,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if model_type in modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES: - return Seq2SeqLM( + return Seq2SeqLM.fallback( model_id, revision, quantize=quantize, @@ -988,7 +1009,7 @@ def get_model( auto_map = config_dict.get("auto_map", None) if trust_remote_code and auto_map is not None: if "AutoModelForCausalLM" in auto_map.keys(): - return CausalLM( + return CausalLM.fallback( model_id, revision, quantize=quantize, @@ -997,7 +1018,7 @@ def get_model( trust_remote_code=trust_remote_code, ) if "AutoModelForSeq2SeqLM" in auto_map.keys(): - return Seq2SeqLM( + return Seq2SeqLM.fallback( model_id, revision, quantize=quantize, diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 17aa12e84dc..732b4c5394c 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -4,22 +4,12 @@ from typing import Optional, Type from transformers import ( - AutoTokenizer, - AutoConfig, PreTrainedTokenizerBase, ) -from text_generation_server.models.custom_modeling.bloom_modeling import ( - BloomForCausalLM, -) from text_generation_server.models import CausalLM from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.pb import generate_pb2 -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) class 
BloomCausalLMBatch(CausalLMBatch): @@ -37,69 +27,6 @@ def from_pb( class BLOOMSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - slow_but_exact=False, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) - config.pad_token_id = 3 - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - prefix="transformer", - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = BloomForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - @property def batch_type(self) -> Type[CausalLMBatch]: return BloomCausalLMBatch diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 10c64c6611f..71a59fee5c5 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -1,11 +1,22 @@ import torch import time +import torch.distributed from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase +from transformers import ( + AutoConfig, + AutoTokenizer, + AutoModelForCausalLM, + PreTrainedTokenizerBase, +) from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.models import Model from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens @@ -482,6 +493,67 @@ class CausalLM(Model): def __init__( self, model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + tokenizer_class=AutoTokenizer, + config_class=AutoConfig, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + + config = config_class.from_pretrained( + 
model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = config.pad_token_id + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, device=device, dtype=dtype, process_group=self.process_group + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + model = model_class(config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -537,7 +609,7 @@ def __init__( else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - super(CausalLM, self).__init__( + super(CausalLM, cls).__init__( model_id=model_id, model=model, tokenizer=tokenizer, diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py deleted file mode 100644 index c37cfb7da72..00000000000 --- a/server/text_generation_server/models/gpt_neox.py +++ /dev/null @@ -1,89 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional - -from transformers import ( - AutoTokenizer, - AutoConfig, -) -from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.neox_modeling import ( - GPTNeoxForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class GPTNeoxSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.pad_token = tokenizer.eos_token - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = GPTNeoxForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits 
= self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py deleted file mode 100644 index 1e79b25f263..00000000000 --- a/server/text_generation_server/models/mpt.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import torch.distributed - -from pathlib import Path -from typing import Optional, Type -from opentelemetry import trace -from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase -from huggingface_hub import hf_hub_download -import json - -from text_generation_server.models import CausalLM -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.pb import generate_pb2 -from text_generation_server.models.custom_modeling.mpt_modeling import ( - MPTForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - -tracer = trace.get_tracer(__name__) - - -class MPTCausalLMBatch(CausalLMBatch): - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) - batch.keys_head_dim_last = False - return batch - - -class MPTSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.pad_token = tokenizer.eos_token - - # If model_id is a local path, load the file directly - local_path = Path(model_id, "config.json") - if local_path.exists(): - filename = str(local_path.resolve()) - else: - filename = hf_hub_download( - model_id, revision=revision, filename="config.json" - ) - with open(filename, "r") as f: - config = json.load(f) - config = PretrainedConfig(**config) - config.quantize = quantize - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - config.quantize = quantize - model = MPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=False, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return MPTCausalLMBatch diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py deleted file mode 100644 index 
93d42b2b8dc..00000000000 --- a/server/text_generation_server/models/phi.py +++ /dev/null @@ -1,69 +0,0 @@ -import torch -import torch.distributed - -from transformers import AutoConfig, AutoTokenizer -from typing import Optional, List, Tuple - -from text_generation_server.models import CausalLM -from text_generation_server.models.custom_modeling.phi_modeling import ( - PhiConfig, - PhiForCausalLM, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class Phi(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, _rank, _world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - config = PhiConfig.from_pretrained( - model_id, revision=revision, trust_remote_code=trust_remote_code - ) - - tokenizer.bos_token_id = config.bos_token_id - tokenizer.eos_token_id = config.eos_token_id - tokenizer.pad_token = tokenizer.eos_token - - config.quantize = quantize - config.speculator = speculator - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights(filenames, device, dtype, process_group=self.process_group) - model = PhiForCausalLM(config, weights) - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py deleted file mode 100644 index 37ca277b7e0..00000000000 --- a/server/text_generation_server/models/rw.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch - -from transformers import AutoTokenizer, AutoModelForCausalLM -from typing import List, Optional, Tuple - -from text_generation_server.models import CausalLM - - -class RW(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if speculator: - raise RuntimeError("Medusa decoding is not enabled for AutoModel") - - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - device_map=( - "auto" - if torch.cuda.is_available() and torch.cuda.device_count() > 1 - else None - ), - load_in_8bit=quantize == "bitsandbytes", - 
trust_remote_code=trust_remote_code, - ) - if torch.cuda.is_available() and torch.cuda.device_count() == 1: - model = model.cuda() - - if tokenizer.pad_token_id is None: - if model.config.pad_token_id is not None: - tokenizer.pad_token_id = model.config.pad_token_id - elif model.config.eos_token_id is not None: - tokenizer.pad_token_id = model.config.eos_token_id - elif tokenizer.eos_token_id is not None: - tokenizer.pad_token_id = tokenizer.eos_token_id - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - # Model Forward - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py deleted file mode 100644 index caddbe191b3..00000000000 --- a/server/text_generation_server/models/santacoder.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -import torch.distributed - -from typing import Optional, List -from transformers import AutoTokenizer, AutoModelForCausalLM - -from text_generation_server.models import CausalLM - -FIM_PREFIX = "<fim-prefix>" -FIM_MIDDLE = "<fim-middle>" -FIM_SUFFIX = "<fim-suffix>" -FIM_PAD = "<fim-pad>" -EOD = "<|endoftext|>" - - -class SantaCoder(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - if torch.cuda.is_available(): - device = torch.device("cuda") - dtype = torch.float16 if dtype is None else dtype - else: - if quantize: - raise ValueError("quantization is not available on CPU") - - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.add_special_tokens( - { - "additional_special_tokens": [ - EOD, - FIM_PREFIX, - FIM_MIDDLE, - FIM_SUFFIX, - FIM_PAD, - ], - "pad_token": EOD, - } - ) - with device: - model = AutoModelForCausalLM.from_pretrained( - model_id, - revision=revision, - torch_dtype=dtype, - load_in_8bit=quantize == "bitsandbytes", - trust_remote_code=trust_remote_code, - ) - - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - ) - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index d454d80477a..e3684071302 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -1,11 +1,22 @@ import torch +import torch.distributed import time from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,
PreTrainedTokenizerBase +from transformers import ( + AutoTokenizer, + AutoModelForSeq2SeqLM, + PreTrainedTokenizerBase, + AutoConfig, +) from typing import Optional, Tuple, List, Type, Dict +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models import Model @@ -531,6 +542,71 @@ class Seq2SeqLM(Model): def __init__( self, model_id: str, + model_class, + revision: Optional[str] = None, + quantize: Optional[str] = None, + speculator: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + config_class=AutoConfig, + tokenizer_class=AutoTokenizer, + aliases=None, + ): + self.process_group, rank, world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device(f"cuda:{rank}") + dtype = torch.float16 if dtype is None else dtype + else: + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + config = config_class.from_pretrained( + model_id, + revision=revision, + trust_remote_code=trust_remote_code, + ) + config.quantize = quantize + config.speculator = speculator + + tokenizer = tokenizer_class.from_pretrained( + model_id, + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + tokenizer.bos_token_id = config.decoder_start_token_id + + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights( + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases=aliases, + ) + if config.quantize in ["awq", "exl2", "gptq", "marlin"]: + weights._set_gptq_params(model_id, revision) + + model = model_class(config, weights) + + torch.distributed.barrier(group=self.process_group) + super().__init__( + model_id=model_id, + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + rank=rank, + world_size=world_size, + ) + + @classmethod + def fallback( + cls, + model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, @@ -574,7 +650,7 @@ def __init__( ) tokenizer.bos_token_id = model.config.decoder_start_token_id - super(Seq2SeqLM, self).__init__( + super(Seq2SeqLM, cls).__init__( model_id=model_id, model=model, tokenizer=tokenizer, diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/sharded_seq2seq_lm.py similarity index 62% rename from server/text_generation_server/models/opt.py rename to server/text_generation_server/models/sharded_seq2seq_lm.py index 6d7d07f59c3..b73df83ab66 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/sharded_seq2seq_lm.py @@ -1,14 +1,17 @@ import torch import torch.distributed -from typing import Optional +from typing import List, Optional, Tuple from transformers import ( AutoTokenizer, AutoConfig, ) -from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM -from text_generation_server.models import CausalLM + +from text_generation_server.models import Seq2SeqLM +from text_generation_server.models.custom_modeling.t5_modeling import ( + T5ForConditionalGeneration, +) from text_generation_server.utils import ( initialize_torch_distributed, weight_files, @@ -16,15 +19,19 @@ 
) -class OPTSharded(CausalLM): +class ShardedSeq2SeqLM(Seq2SeqLM): def __init__( self, model_id: str, + model_class, revision: Optional[str] = None, quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, + config_class=AutoConfig, + tokenizer_class=AutoTokenizer, + aliases=None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -34,35 +41,37 @@ def __init__( device = torch.device("cpu") dtype = torch.float32 if dtype is None else dtype - tokenizer = AutoTokenizer.from_pretrained( + config = config_class.from_pretrained( model_id, revision=revision, - padding_side="left", - truncation_side="left", trust_remote_code=trust_remote_code, ) + config.quantize = quantize + config.speculator = speculator - config = AutoConfig.from_pretrained( + tokenizer = tokenizer_class.from_pretrained( model_id, revision=revision, + padding_side="left", + truncation_side="left", trust_remote_code=trust_remote_code, ) - config.quantize = quantize - config.speculator = speculator - tokenizer.pad_token_id = config.pad_token_id + tokenizer.bos_token_id = config.decoder_start_token_id torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group + filenames, + device=device, + dtype=dtype, + process_group=self.process_group, + aliases=aliases, ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - model = OPTForCausalLM(config, weights) + model = model_class(config, weights) torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( + super(Seq2SeqLM, self).__init__( model_id=model_id, model=model, tokenizer=tokenizer, @@ -72,15 +81,3 @@ def __init__( rank=rank, world_size=world_size, ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py deleted file mode 100644 index adef664c75e..00000000000 --- a/server/text_generation_server/models/t5.py +++ /dev/null @@ -1,115 +0,0 @@ -import torch -import torch.distributed - -from typing import List, Optional, Tuple - -from transformers import ( - AutoTokenizer, - AutoConfig, -) - -from text_generation_server.models import Seq2SeqLM -from text_generation_server.models.custom_modeling.t5_modeling import ( - T5ForConditionalGeneration, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class T5Sharded(Seq2SeqLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - config = AutoConfig.from_pretrained( - model_id, - 
revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.bos_token_id = config.decoder_start_token_id - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - process_group=self.process_group, - aliases={ - "shared.weight": [ - "encoder.embed_tokens.weight", - "decoder.embed_tokens.weight", - ] - }, - ) - - model = T5ForConditionalGeneration(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(Seq2SeqLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - def forward( - self, - input_ids, - attention_mask, - decoder_input_ids, - decoder_attention_mask: Optional, - encoder_last_hidden_state: Optional, - past_key_values: Optional = None, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], - ]: - # Model Forward - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - encoder_outputs=encoder_last_hidden_state, - past_key_values=past_key_values, - use_cache=True, - ) - - return ( - outputs.logits, - speculative_logits, - outputs.encoder_last_hidden_state, - outputs.past_key_values, - ) From db9acc44182b69d81d9901fca195e4c90d518a02 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 16:55:48 +0200 Subject: [PATCH 06/22] Fix Santacoder test. --- server/tests/models/test_santacoder.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/tests/models/test_santacoder.py b/server/tests/models/test_santacoder.py index cb2622d9b53..d5c91bffc84 100644 --- a/server/tests/models/test_santacoder.py +++ b/server/tests/models/test_santacoder.py @@ -1,13 +1,12 @@ import pytest from text_generation_server.pb import generate_pb2 -from text_generation_server.models.causal_lm import CausalLMBatch -from text_generation_server.models.santacoder import SantaCoder +from text_generation_server.models.causal_lm import CausalLMBatch, CausalLM @pytest.fixture(scope="session") def default_santacoder(): - return SantaCoder("bigcode/santacoder") + return CausalLM.fallback(model_id="bigcode/santacoder") @pytest.fixture From 298500a08eedaef076fa410a2cee49db079c563d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 15:13:24 +0000 Subject: [PATCH 07/22] Fixing the simple tests. 
--- server/tests/models/test_bloom.py | 8 +++++++- server/tests/models/test_causal_lm.py | 2 +- server/tests/models/test_seq2seq_lm.py | 2 +- server/text_generation_server/models/causal_lm.py | 7 ++++++- server/text_generation_server/models/seq2seq_lm.py | 7 ++++++- 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/server/tests/models/test_bloom.py b/server/tests/models/test_bloom.py index 32ee6686b6b..08292920e2f 100644 --- a/server/tests/models/test_bloom.py +++ b/server/tests/models/test_bloom.py @@ -8,6 +8,9 @@ from text_generation_server.models.causal_lm import CausalLMBatch from text_generation_server.utils import weight_hub_files, download_weights from text_generation_server.models.bloom import BloomCausalLMBatch, BLOOMSharded +from text_generation_server.models.custom_modeling.bloom_modeling import ( + BloomForCausalLM, +) @pytest.fixture(scope="session") @@ -16,7 +19,10 @@ def default_bloom(): revision = "main" filenames = weight_hub_files(model_id, revision, ".safetensors") download_weights(filenames, model_id, revision) - return BLOOMSharded(model_id) + return BLOOMSharded( + model_id, + model_class=BloomForCausalLM, + ) @pytest.fixture(scope="session") diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index 6e6463bc948..c000ef26d8a 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -10,7 +10,7 @@ @pytest.fixture(scope="session") def default_causal_lm(): - return CausalLM("gpt2") + return CausalLM.fallback("gpt2") @pytest.fixture(scope="session") diff --git a/server/tests/models/test_seq2seq_lm.py b/server/tests/models/test_seq2seq_lm.py index 943c3b0820d..02666042040 100644 --- a/server/tests/models/test_seq2seq_lm.py +++ b/server/tests/models/test_seq2seq_lm.py @@ -20,7 +20,7 @@ def mt0_small_tokenizer(): @pytest.fixture(scope="session") def default_seq2seq_lm(): - return Seq2SeqLM("bigscience/mt0-small") + return Seq2SeqLM.fallback("bigscience/mt0-small") @pytest.fixture diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 71a59fee5c5..685177c7782 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -609,7 +609,11 @@ def fallback( else: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - super(CausalLM, cls).__init__( + self = cls.__new__( + cls, + ) + super().__init__( + self, model_id=model_id, model=model, tokenizer=tokenizer, @@ -617,6 +621,7 @@ def fallback( dtype=dtype, device=device, ) + return self @property def batch_type(self) -> Type[CausalLMBatch]: diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index e3684071302..38695b1952f 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -650,7 +650,11 @@ def fallback( ) tokenizer.bos_token_id = model.config.decoder_start_token_id - super(Seq2SeqLM, cls).__init__( + self = cls.__new__( + cls, + ) + super().__init__( + self, model_id=model_id, model=model, tokenizer=tokenizer, @@ -658,6 +662,7 @@ def fallback( dtype=dtype, device=device, ) + return self @property def batch_type(self) -> Type[Seq2SeqLMBatch]: From b2fb845923eaec07de43c23405476a1fd8885900 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 15:37:27 +0000 Subject: [PATCH 08/22] Fixing sharding. 
--- server/text_generation_server/models/flash_causal_lm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index f2e66d56e21..a5da215a75a 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -881,7 +881,9 @@ def __init__( model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) self.num_layers = config.num_hidden_layers - self.num_kv_heads = config.num_key_value_heads + + # Validation is done in the model itself + self.num_kv_heads = config.num_key_value_heads // self.process_group.size() self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {} From 43ef5268fd1c5161227eee1ec1cff5f99d47ea59 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 16:02:33 +0000 Subject: [PATCH 09/22] Fixes for VLM. --- .../models/custom_modeling/llava_next.py | 8 ++++---- server/text_generation_server/models/flash_causal_lm.py | 5 ++++- server/text_generation_server/models/vlm_causal_lm.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index 6d38442cc02..567131ef7ad 100644 --- a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -136,7 +136,7 @@ def __init__(self, prefix, config, weights): self.config = config config.text_config.quantize = config.quantize config.text_config.speculator = config.speculator - self.language_model = load_text_model( + self.text_model = load_text_model( prefix="language_model" if not prefix else f"{prefix}.language_model", config=config.text_config, weights=weights, @@ -180,7 +180,7 @@ def forward( image_sizes: Optional[torch.LongTensor] = None, adapter_data: Optional[torch.Tensor] = None, ): - inputs_embeds = self.language_model.embed_tokens(input_ids) + inputs_embeds = self.text_model.embed_tokens(input_ids) if pixel_values is not None and len(pixel_values) > 0: # num_special_image_tokens = (input_ids == self.config.image_token_index).sum() # assert num_special_image_tokens == len(pixel_values), f"Received {num_special_image_tokens} for {len(pixel_values)} images, this is invalid" @@ -269,7 +269,7 @@ def forward( input_ids, inputs_embeds, image_features ) - hidden_states = self.language_model.model( + hidden_states = self.text_model.model( inputs_embeds=inputs_embeds, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -283,5 +283,5 @@ def forward( ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits, speculative_logits = self.language_model.lm_head(hidden_states) + logits, speculative_logits = self.text_model.lm_head(hidden_states) return logits, speculative_logits diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index a5da215a75a..07e9f97f306 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -880,8 +880,11 @@ def __init__( prefix = "" model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) - self.num_layers = config.num_hidden_layers + text_config = getattr(config, "text_config", None) + 
if text_config is not None: + config = text_config + self.num_layers = config.num_hidden_layers # Validation is done in the model itself self.num_kv_heads = config.num_key_value_heads // self.process_group.size() self.head_size = config.hidden_size // config.num_attention_heads diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index bee848e1a23..90c4c46e330 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -261,7 +261,7 @@ def __init__( **processor_kwargs, ) self.batch_class = batch_class - super().__init__(**kwargs) + super().__init__(model_id=model_id, **kwargs) @property def batch_type(self) -> Type[VlmCausalLMBatch]: From dbf9292afc1bec7737735ed33e0d1916da210d62 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 16:35:08 +0000 Subject: [PATCH 10/22] Fixing santacoder (num_kv_heads hardcoded). --- server/text_generation_server/models/__init__.py | 4 ++++ .../models/custom_modeling/flash_santacoder_modeling.py | 2 +- server/text_generation_server/models/flash_causal_lm.py | 7 ++++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index fb7c8cbeea0..32c2168f52b 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -87,6 +87,9 @@ from text_generation_server.models.pali_gemma import ( PaliGemmaBatch, ) + from text_generation_server.models.custom_modeling.flash_pali_gemma_modeling import ( + PaliGemmaForConditionalGeneration, + ) from text_generation_server.models.custom_modeling.flash_phi_modeling import ( FlashPhiForCausalLM, ) @@ -489,6 +492,7 @@ def get_model( trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, aliases={"transformer.wte.weight": ["lm_head.weight"]}, + num_kv_heads=1, ) elif sharded: raise NotImplementedError( diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index a77a76552cf..2bc305fedad 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -464,7 +464,7 @@ def forward( class FlashSantacoderForCausalLM(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__() config.transpose = config.architectures[0].startswith("GPT2") self.transformer = FlashSantacoderModel(config, weights) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 07e9f97f306..5f558caa892 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -825,6 +825,8 @@ def __init__( config_class: PreTrainedTokenizerBase = AutoConfig, default_dtype=torch.float16, aliases=None, + # Used for Santacoder override of config + num_kv_heads=None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -886,7 +888,10 @@ def __init__( config = text_config self.num_layers = config.num_hidden_layers # Validation is done in the model itself - self.num_kv_heads = config.num_key_value_heads // self.process_group.size() + num_heads = getattr(config, 
"num_key_value_heads", config.n_head) + if num_kv_heads is None: + num_kv_heads = config.num_key_value_heads + self.num_kv_heads = num_kv_heads // self.process_group.size() self.head_size = config.hidden_size // config.num_attention_heads self.cuda_graphs = {} From 24bbd7b822dbbb987d7270352e24ec49a5f65e78 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 16:46:52 +0000 Subject: [PATCH 11/22] Removing more dead code. --- server/text_generation_server/models/causal_lm.py | 9 +++++---- server/text_generation_server/models/flash_causal_lm.py | 1 + server/text_generation_server/models/seq2seq_lm.py | 9 +++++---- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 685177c7782..dce793f5b72 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -627,10 +627,11 @@ def fallback( def batch_type(self) -> Type[CausalLMBatch]: return CausalLMBatch - def decode(self, generated_ids: List[int]) -> str: - return self.tokenizer.decode( - generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) + # This is not used anymore + # def decode(self, generated_ids: List[int]) -> str: + # return self.tokenizer.decode( + # generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + # ) def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5f558caa892..77125f53028 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -827,6 +827,7 @@ def __init__( aliases=None, # Used for Santacoder override of config num_kv_heads=None, + skip_special_tokens: bool = True, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 38695b1952f..5d16c364658 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -668,10 +668,11 @@ def fallback( def batch_type(self) -> Type[Seq2SeqLMBatch]: return Seq2SeqLMBatch - def decode(self, decoder_ids: List[int]) -> str: - return self.tokenizer.decode( - decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - ) + # Not used anymore + # def decode(self, decoder_ids: List[int]) -> str: + # return self.tokenizer.decode( + # decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False + # ) def forward( self, From e8ff76fd187f449958a21f2d2aa2af3ff3211490 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 2 Jul 2024 17:01:25 +0000 Subject: [PATCH 12/22] Fixing `config.n_head`. 
--- server/text_generation_server/models/flash_causal_lm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 77125f53028..5fc16fe3893 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -884,14 +884,17 @@ def __init__( model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) + # VLM models define the config we care about in their text_config text_config = getattr(config, "text_config", None) if text_config is not None: config = text_config self.num_layers = config.num_hidden_layers # Validation is done in the model itself - num_heads = getattr(config, "num_key_value_heads", config.n_head) if num_kv_heads is None: - num_kv_heads = config.num_key_value_heads + num_kv_heads = getattr(config, "num_key_value_heads", None) + if num_kv_heads is None: + # Final overide for GPT2 + num_kv_heads = config.n_head self.num_kv_heads = num_kv_heads // self.process_group.size() self.head_size = config.hidden_size // config.num_attention_heads From 2259d2f78add6ae5aff37b90f05b70fc1817aad7 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 08:09:58 +0000 Subject: [PATCH 13/22] Stopping earlier because of `` in idefics2. --- .../test_flash_idefics2_two_images.json | 48 ++++++++----------- integration-tests/models/test_idefics2.py | 2 +- 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json index bf2dc5a1ffe..44ccea7111a 100644 --- a/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json +++ b/integration-tests/models/__snapshots__/test_idefics2/test_flash_idefics2_two_images.json @@ -1,130 +1,124 @@ { "details": { "best_of_sequences": null, - "finish_reason": "length", - "generated_tokens": 20, + "finish_reason": "eos_token", + "generated_tokens": 19, "prefill": [], "seed": null, "tokens": [ { "id": 415, - "logprob": -0.039886475, + "logprob": -0.03665161, "special": false, "text": " The" }, { "id": 12072, - "logprob": -0.1430664, + "logprob": -0.13549805, "special": false, "text": " cow" }, { "id": 349, - "logprob": -0.056488037, + "logprob": -0.05819702, "special": false, "text": " is" }, { "id": 6328, - "logprob": -0.6855469, + "logprob": -0.6826172, "special": false, "text": " standing" }, { "id": 356, - "logprob": -0.1685791, + "logprob": -0.1607666, "special": false, "text": " on" }, { "id": 272, - "logprob": -0.50097656, + "logprob": -0.5073242, "special": false, "text": " the" }, { "id": 10305, - "logprob": -0.017303467, + "logprob": -0.016418457, "special": false, "text": " beach" }, { "id": 304, - "logprob": -1.3564453, + "logprob": -1.3916016, "special": false, "text": " and" }, { "id": 272, - "logprob": -0.017868042, + "logprob": -0.020217896, "special": false, "text": " the" }, { "id": 13088, - "logprob": -0.0027103424, + "logprob": -0.0028133392, "special": false, "text": " chicken" }, { "id": 349, - "logprob": -0.003156662, + "logprob": -0.003145218, "special": false, "text": " is" }, { "id": 6398, - "logprob": -0.37304688, + "logprob": -0.37060547, "special": false, "text": " sitting" }, { "id": 356, - "logprob": -0.034576416, + "logprob": -0.034851074, "special": false, "text": " on" }, { "id": 
264, - "logprob": -0.29418945, + "logprob": -0.2878418, "special": false, "text": " a" }, { "id": 17972, - "logprob": -0.042877197, + "logprob": -0.046051025, "special": false, "text": " pile" }, { "id": 302, - "logprob": -0.00028443336, + "logprob": -0.00028848648, "special": false, "text": " of" }, { "id": 2445, - "logprob": -0.023223877, + "logprob": -0.025772095, "special": false, "text": " money" }, { "id": 28723, - "logprob": -0.018157959, + "logprob": -0.018127441, "special": false, "text": "." }, { "id": 32002, - "logprob": -0.00018393993, + "logprob": -0.00019824505, "special": true, "text": "" - }, - { - "id": 2, - "logprob": -1.1920929e-07, - "special": true, - "text": "" } ], "top_tokens": null diff --git a/integration-tests/models/test_idefics2.py b/integration-tests/models/test_idefics2.py index 9aaf6d8ae4c..c5f48da3525 100644 --- a/integration-tests/models/test_idefics2.py +++ b/integration-tests/models/test_idefics2.py @@ -57,7 +57,7 @@ async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot) response.generated_text == " The cow is standing on the beach and the chicken is sitting on a pile of money." ), f"{repr(response.generated_text)}" - assert response.details.generated_tokens == 20 + assert response.details.generated_tokens == 19 assert response == response_snapshot From 9cc58d1cb333fc1e01edc3763cb3222635374838 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 13:29:19 +0000 Subject: [PATCH 14/22] Addresses comments. --- .../text_generation_server/models/__init__.py | 32 +++++-- .../models/causal_lm.py | 21 ++++- .../models/galactica.py | 80 ------------------ .../models/pali_gemma.py | 16 ---- .../models/seq2seq_lm.py | 11 ++- .../models/sharded_seq2seq_lm.py | 83 ------------------- 6 files changed, 55 insertions(+), 188 deletions(-) delete mode 100644 server/text_generation_server/models/sharded_seq2seq_lm.py diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 32c2168f52b..853968e10d3 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -11,17 +11,16 @@ from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model import Model -from text_generation_server.models.causal_lm import CausalLM +from text_generation_server.models.causal_lm import CausalLM, CausalLMBatchKeysLast from text_generation_server.models.custom_modeling.opt_modeling import OPTForCausalLM from text_generation_server.models.custom_modeling.mpt_modeling import ( MPTForCausalLM, ) -from text_generation_server.models.bloom import BLOOMSharded from text_generation_server.models.custom_modeling.bloom_modeling import ( BloomForCausalLM, ) from text_generation_server.models.seq2seq_lm import Seq2SeqLM -from text_generation_server.models.galactica import GalacticaSharded +from text_generation_server.models.galactica import GalacticaCausalLMBatch from text_generation_server.models.custom_modeling.neox_modeling import ( GPTNeoxForCausalLM, ) @@ -169,6 +168,11 @@ class ModelType(enum.Enum): "name": "Gemma", "url": "https://huggingface.co/google/gemma-7b", } + PALIGEMMA = { + "type": "paligemma", + "name": "PaliGemma", + "url": "https://huggingface.co/google/paligemma-3b-pt-224", + } GEMMA2 = { "type": "gemma2", "name": "Gemma2", @@ -466,14 +470,16 @@ def get_model( ) if model_id.startswith("facebook/galactica"): - return GalacticaSharded( + return CausalLM( model_id=model_id, + # 
Yes galactica is just an OPT model. model_class=OPTForCausalLM, revision=revision, quantize=quantize, speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=GalacticaCausalLMBatch, ) if ( @@ -509,7 +515,7 @@ def get_model( ) if model_type == BLOOM: - return BLOOMSharded( + return CausalLM( model_id=model_id, model_class=BloomForCausalLM, revision=revision, @@ -517,6 +523,7 @@ def get_model( speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=CausalLMBatchKeysLast, ) elif model_type == MPT: return CausalLM( @@ -527,6 +534,7 @@ def get_model( speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, + batch_class=CausalLMBatchKeysLast, ) elif model_type == GPT2: if FLASH_ATTENTION: @@ -666,6 +674,8 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -689,6 +699,8 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, ) @@ -737,6 +749,8 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + # Dbrx works better in bfloat16. + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, config_class=DbrxConfig, @@ -765,6 +779,10 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + aliases={ + "lm_head.weight": ["transformer.word_embeddings.weight"], + "transformer.word_embeddings.weight": ["lm_head.weight"], + }, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, config_class=RWConfig, @@ -947,7 +965,7 @@ def get_model( ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) - if model_type == "paligemma": + if model_type == PALIGEMMA: if FLASH_ATTENTION: return VlmCausalLM( model_id=model_id, @@ -956,6 +974,8 @@ def get_model( quantize=quantize, speculator=speculator, dtype=dtype, + # Works better for these models + default_dtype=torch.bfloat16, trust_remote_code=trust_remote_code, lora_adapter_ids=lora_adapter_ids, batch_class=PaliGemmaBatch, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index dce793f5b72..ea0e4ae4759 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -489,6 +489,11 @@ def __len__(self): return len(self.requests) +@dataclass +class CausalLMBatchKeysLast(Batch): + keys_head_dim_last: bool = False + + class CausalLM(Model): def __init__( self, @@ -498,14 +503,25 @@ def __init__( quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, trust_remote_code: bool = False, tokenizer_class=AutoTokenizer, config_class=AutoConfig, + batch_class=CausalLMBatch, ): + self.batch_class = batch_class self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 
doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 if dtype is None else dtype @@ -612,6 +628,7 @@ def fallback( self = cls.__new__( cls, ) + self.batch_class = CausalLMBatch super().__init__( self, model_id=model_id, @@ -625,7 +642,7 @@ def fallback( @property def batch_type(self) -> Type[CausalLMBatch]: - return CausalLMBatch + return self.batch # This is not used anymore # def decode(self, generated_ids: List[int]) -> str: diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index 30c92d90e27..2d43244a0fb 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -162,83 +162,3 @@ def from_pb( padding_right_offset=padding_right_offset, max_tokens=max_tokens, ) - - -class GalacticaSharded(CausalLM): - def __init__( - self, - model_id: str, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - tokenizer = AutoTokenizer.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - - config = AutoConfig.from_pretrained( - model_id, - revision=revision, - tp_parallel=True, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - tokenizer.pad_token_id = config.pad_token_id - config.speculator = speculator - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, device=device, dtype=dtype, process_group=self.process_group - ) - if config.quantize in ["gptq", "marlin"]: - weights._set_gptq_params(model_id, revision) - - model = OPTForCausalLM(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(CausalLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) - - @property - def batch_type(self) -> Type[CausalLMBatch]: - return GalacticaCausalLMBatch - - def decode(self, generated_ids: List[int]) -> str: - # Do not skip special tokens as they are used for custom parsing rules of the generated text - return self.tokenizer.decode( - generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - def forward( - self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ): - outputs, speculative_logits = self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - use_cache=True, - ) - return outputs.logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/pali_gemma.py b/server/text_generation_server/models/pali_gemma.py index 533a47ea8e3..3994ac70730 100644 --- a/server/text_generation_server/models/pali_gemma.py +++ b/server/text_generation_server/models/pali_gemma.py @@ -74,19 +74,3 @@ def batch_tokenized_inputs( else: image_inputs = None return batch_tokenized_inputs, image_inputs - - 
-class PaliGemma(VlmCausalLM): - @property - def batch_type(self): - return PaliGemmaBatch - - def get_layer_config(self, model) -> Tuple[int, int, int]: - return ( - len(model.text_model.model.layers), - model.text_model.model.num_key_value_heads, - model.text_model.model.head_size, - ) - - def max_past(self) -> Optional[int]: - return getattr(self.model.text_model, "max_past", None) diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 5d16c364658..bc30404e6d1 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -547,6 +547,7 @@ def __init__( quantize: Optional[str] = None, speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, + default_dtype=torch.float16, trust_remote_code: bool = False, config_class=AutoConfig, tokenizer_class=AutoTokenizer, @@ -555,7 +556,15 @@ def __init__( self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype + dtype = default_dtype if dtype is None else dtype + elif SYSTEM == "ipex": + if hasattr(torch, "xpu") and torch.xpu.is_available(): + device = torch.device(f"xpu:{rank}") + dtype = default_dtype if dtype is None else dtype + else: + device = torch.device("cpu") + # Float16 doesn't exist on target. + dtype = torch.bfloat16 if dtype is None else dtype else: device = torch.device("cpu") dtype = torch.float32 if dtype is None else dtype diff --git a/server/text_generation_server/models/sharded_seq2seq_lm.py b/server/text_generation_server/models/sharded_seq2seq_lm.py deleted file mode 100644 index b73df83ab66..00000000000 --- a/server/text_generation_server/models/sharded_seq2seq_lm.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -import torch.distributed - -from typing import List, Optional, Tuple - -from transformers import ( - AutoTokenizer, - AutoConfig, -) - -from text_generation_server.models import Seq2SeqLM -from text_generation_server.models.custom_modeling.t5_modeling import ( - T5ForConditionalGeneration, -) -from text_generation_server.utils import ( - initialize_torch_distributed, - weight_files, - Weights, -) - - -class ShardedSeq2SeqLM(Seq2SeqLM): - def __init__( - self, - model_id: str, - model_class, - revision: Optional[str] = None, - quantize: Optional[str] = None, - speculator: Optional[str] = None, - dtype: Optional[torch.dtype] = None, - trust_remote_code: bool = False, - config_class=AutoConfig, - tokenizer_class=AutoTokenizer, - aliases=None, - ): - self.process_group, rank, world_size = initialize_torch_distributed() - if torch.cuda.is_available(): - device = torch.device(f"cuda:{rank}") - dtype = torch.float16 if dtype is None else dtype - else: - device = torch.device("cpu") - dtype = torch.float32 if dtype is None else dtype - - config = config_class.from_pretrained( - model_id, - revision=revision, - trust_remote_code=trust_remote_code, - ) - config.quantize = quantize - config.speculator = speculator - - tokenizer = tokenizer_class.from_pretrained( - model_id, - revision=revision, - padding_side="left", - truncation_side="left", - trust_remote_code=trust_remote_code, - ) - tokenizer.bos_token_id = config.decoder_start_token_id - - torch.distributed.barrier(group=self.process_group) - filenames = weight_files(model_id, revision=revision, extension=".safetensors") - weights = Weights( - filenames, - device=device, - dtype=dtype, - 
process_group=self.process_group, - aliases=aliases, - ) - - model = model_class(config, weights) - - torch.distributed.barrier(group=self.process_group) - super(Seq2SeqLM, self).__init__( - model_id=model_id, - model=model, - tokenizer=tokenizer, - requires_padding=True, - dtype=dtype, - device=device, - rank=rank, - world_size=world_size, - ) From fbf38c997c1aad02a7f2feb433a73bc8f1361526 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 13:34:37 +0000 Subject: [PATCH 15/22] Removing the dead code. --- server/text_generation_server/models/causal_lm.py | 6 ------ server/text_generation_server/models/seq2seq_lm.py | 6 ------ 2 files changed, 12 deletions(-) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index ea0e4ae4759..b8ed9f47aaa 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -644,12 +644,6 @@ def fallback( def batch_type(self) -> Type[CausalLMBatch]: return self.batch - # This is not used anymore - # def decode(self, generated_ids: List[int]) -> str: - # return self.tokenizer.decode( - # generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - # ) - def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ) -> Tuple[ diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index bc30404e6d1..dbaf1253093 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -677,12 +677,6 @@ def fallback( def batch_type(self) -> Type[Seq2SeqLMBatch]: return Seq2SeqLMBatch - # Not used anymore - # def decode(self, decoder_ids: List[int]) -> str: - # return self.tokenizer.decode( - # decoder_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False - # ) - def forward( self, input_ids, From f5ff9b57428357f0af0e7cdb0991f9548943969b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 15:08:44 +0000 Subject: [PATCH 16/22] Fuse back mistral into FlashCausalLM. 
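
The adapter/LoRA plumbing that previously lived only on `FlashMistral`
(`ADAPTER_LAYERS`, `ROW_PARALLEL`, `adapter_target_to_layer`, ...) moves onto
`FlashCausalLM`, so Mistral is served by the generic class and every flash
model with the same llama-style layer layout gets adapter loading without a
dedicated subclass. The mapping mostly exists because adapters are trained
against separate q/k/v projections while the flash models fuse them; a
trimmed sketch of one entry, mirroring the code added below:

    # For each transformer layer i, the q/k/v targets all point at the
    # fused QKV module of that layer.
    layer_weights[(i, "q_proj")] = (
        f"model.layers.{i}.self_attn.q_proj",
        layer.self_attn.query_key_value,
    )
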
--- .../text_generation_server/models/__init__.py | 7 +- .../models/flash_causal_lm.py | 81 +++++++++++++++++++ .../models/flash_mistral.py | 4 +- 3 files changed, 86 insertions(+), 6 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 853968e10d3..31cea93889a 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -96,7 +96,8 @@ from text_generation_server.models.custom_modeling.llava_next import ( LlavaNextForConditionalGeneration, ) - from text_generation_server.models.flash_mistral import FlashMistral + + # from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( FlashSantacoderForCausalLM, ) @@ -127,7 +128,7 @@ if FLASH_ATTENTION: __all__.append(FlashCausalLM) __all__.append(IDEFICSSharded) - __all__.append(FlashMistral) + # __all__.append(FlashMistral) MAMBA_AVAILABLE = True try: @@ -813,7 +814,7 @@ def get_model( if model_type == MISTRAL: if FLASH_ATTENTION: - return FlashMistral( + return FlashCausalLM( model_id=model_id, model_class=FlashMistralForCausalLM, revision=revision, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 5fc16fe3893..c7f5f1f9331 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -810,6 +810,18 @@ def __len__(self): return len(self.requests) +ADAPTER_LAYERS = [ + "q_proj", + "k_proj", + "v_proj", + "o_proj", + "gate_proj", + "up_proj", + "down_proj", +] +ROW_PARALLEL = {"o_proj", "down_proj", "lm_head"} + + class FlashCausalLM(Model): def __init__( self, @@ -1658,3 +1670,72 @@ def generate_token( forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, batch, (forward_ns, decode_ns) + + @property + def supports_adapter_loading(self) -> bool: + return True + + def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: + layer_weights = {} + + prefix = "model.layers" + + # This accounts for VLMs (e.g. LlavaNext, Idefics2) + # that have a language_model inside of the larger model. 
+ if hasattr(self.model, "language_model"): + _model = self.model.language_model + elif hasattr(self.model, "text_model"): + _model = self.model.text_model + else: + _model = self.model + + for i, layer in enumerate(_model.model.layers): + layer_weights[(i, "q_proj")] = ( + f"{prefix}.{i}.self_attn.q_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "k_proj")] = ( + f"{prefix}.{i}.self_attn.k_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "v_proj")] = ( + f"{prefix}.{i}.self_attn.v_proj", + layer.self_attn.query_key_value, + ) + layer_weights[(i, "o_proj")] = ( + f"{prefix}.{i}.self_attn.o_proj", + layer.self_attn.o_proj, + ) + + # TODO: this is a hack to avoid the gate_proj for + # FlashStarcoder2 that doesnt have these layers + if hasattr(layer, "mlp") and hasattr(layer.mlp, "gate_up_proj"): + layer_weights[(i, "gate_proj")] = ( + f"{prefix}.{i}.mlp.gate_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "up_proj")] = ( + f"{prefix}.{i}.mlp.up_proj", + layer.mlp.gate_up_proj, + ) + layer_weights[(i, "down_proj")] = ( + f"{prefix}.{i}.mlp.down_proj", + layer.mlp.down_proj, + ) + + layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head) + return layer_weights + + @property + def adapter_layers(self) -> List[str]: + return ADAPTER_LAYERS + + @property + def default_traced_adapter_layers(self) -> List[str]: + return ["q_proj", "v_proj"] + + def get_num_layers_for_type(self, layer_type: str) -> int: + return 1 if layer_type == "lm_head" else len(self.model.model.layers) + + def is_row_parallel(self, layer_type: str) -> bool: + return layer_type in ROW_PARALLEL diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index c2482dc27aa..2b2bd2e0495 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -28,9 +28,7 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]: # This accounts for VLMs (e.g. LlavaNext, Idefics2) # that have a language_model inside of the larger model. - if hasattr(self.model, "language_model"): - _model = self.model.language_model - elif hasattr(self.model, "text_model"): + if hasattr(self.model, "text_model"): _model = self.model.text_model else: _model = self.model From e2edf2beb2c57eadc357a42f514553a26cfd74fd Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 3 Jul 2024 15:19:06 +0000 Subject: [PATCH 17/22] Finish removal. 
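
With the adapter hooks on `FlashCausalLM`, the remaining `FlashMistral` call
sites can switch to the generic class as well (Mixtral, Starcoder2 and Qwen2
in `get_model`), and the vision path now inherits from it directly. Only the
import and the base class change on that side, roughly:

    # vlm_causal_lm.py
    from text_generation_server.models.flash_causal_lm import (
        FlashCausalLMBatch,
        FlashCausalLM,
    )

    class VlmCausalLM(FlashCausalLM):
        ...
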
--- server/text_generation_server/models/__init__.py | 8 +++----- server/text_generation_server/models/vlm_causal_lm.py | 8 ++++---- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 31cea93889a..15e746229a9 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -97,7 +97,6 @@ LlavaNextForConditionalGeneration, ) - # from text_generation_server.models.flash_mistral import FlashMistral from text_generation_server.models.custom_modeling.flash_santacoder_modeling import ( FlashSantacoderForCausalLM, ) @@ -128,7 +127,6 @@ if FLASH_ATTENTION: __all__.append(FlashCausalLM) __all__.append(IDEFICSSharded) - # __all__.append(FlashMistral) MAMBA_AVAILABLE = True try: @@ -838,7 +836,7 @@ def get_model( if model_type == MIXTRAL: if FLASH_ATTENTION: - return FlashMistral( + return FlashCausalLM( model_id=model_id, model_class=FlashMixtralForCausalLM, revision=revision, @@ -862,7 +860,7 @@ def get_model( if model_type == STARCODER2: if FLASH_ATTENTION: - return FlashMistral( + return FlashCausalLM( model_id=model_id, model_class=FlashStarcoder2ForCausalLM, revision=revision, @@ -888,7 +886,7 @@ def get_model( if model_type == QWEN2: if FLASH_ATTENTION: - return FlashMistral( + return FlashCausalLM( model_id=model_id, model_class=Qwen2ForCausalLM, revision=revision, diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py index 90c4c46e330..ace4880588d 100644 --- a/server/text_generation_server/models/vlm_causal_lm.py +++ b/server/text_generation_server/models/vlm_causal_lm.py @@ -9,9 +9,9 @@ from transformers import PreTrainedTokenizerBase from transformers.image_processing_utils import select_best_resolution from text_generation_server.pb import generate_pb2 -from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch -from text_generation_server.models.flash_mistral import ( - FlashMistral, +from text_generation_server.models.flash_causal_lm import ( + FlashCausalLMBatch, + FlashCausalLM, ) from transformers import AutoProcessor @@ -240,7 +240,7 @@ def from_pb_processor( return batch -class VlmCausalLM(FlashMistral): +class VlmCausalLM(FlashCausalLM): def __init__( self, model_id: str, From 8ecee7283ce5fe99554f868ced36e15d78abb469 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 12:59:59 +0200 Subject: [PATCH 18/22] Fixing docs + causal_lm `batch_class`. 
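
Patch 14 stored the batch class on `self.batch_class` but made `batch_type`
return `self.batch`, an attribute that is never set, so the property would
raise `AttributeError` as soon as it is read. It now returns what `__init__`
stored; the relevant bits, trimmed:

    # in CausalLM.__init__ (signature elided)
    self.batch_class = batch_class

    @property
    def batch_type(self) -> Type[CausalLMBatch]:
        return self.batch_class  # was `self.batch`

The docs side only adds PaliGemma to the supported models list.
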
--- docs/source/supported_models.md | 1 + server/text_generation_server/models/causal_lm.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/supported_models.md b/docs/source/supported_models.md index 1eeed39f812..2bdd00de61d 100644 --- a/docs/source/supported_models.md +++ b/docs/source/supported_models.md @@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) - [Gemma](https://huggingface.co/google/gemma-7b) +- [PaliGemma](https://huggingface.co/google/paligemma-3b-pt-224) - [Gemma2](https://huggingface.co/google/gemma2-9b) - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus) - [Dbrx](https://huggingface.co/databricks/dbrx-instruct) diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index b8ed9f47aaa..6cf876f79fc 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -642,7 +642,7 @@ def fallback( @property def batch_type(self) -> Type[CausalLMBatch]: - return self.batch + return self.batch_class def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None From fc5bfa070a57d5eeb9e91058ced76daea0b4334b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 13:37:22 +0200 Subject: [PATCH 19/22] Fixing docs + causal.lm. --- docs/openapi.json | 2 +- server/text_generation_server/models/causal_lm.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/openapi.json b/docs/openapi.json index 5e0399e0d49..9c9a8b1a346 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -10,7 +10,7 @@ "name": "Apache 2.0", "url": "https://www.apache.org/licenses/LICENSE-2.0" }, - "version": "2.1.1-dev0" + "version": "2.1.2-dev0" }, "paths": { "/": { diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 6cf876f79fc..cac36ebdd0a 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -19,6 +19,7 @@ ) from text_generation_server.models import Model from text_generation_server.utils.chunks import concat_text_chunks +from text_generation_server.utils.import_utils import SYSTEM from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.models.types import ( Batch, From 425f348e485ef726192bd3318b4aa54e92cc1212 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 16:36:16 +0200 Subject: [PATCH 20/22] Add default to Gemma Causality. 
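
The `causal` flag exists on the Gemma modeling code because the same stack is
shared with PaliGemma-style prefix-LM use (presumably the only caller that
passes `causal=False`). The generic loader never threads extra kwargs
through, it only does:

    # flash_causal_lm.py
    model = model_class(prefix, config, weights)

so the text-only path needs `causal` to default to `True` somewhere, while a
caller that needs non-causal attention can still pass `causal=False`
explicitly.
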
--- .../models/custom_modeling/flash_gemma_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 842df0d4d2d..4d731cbf371 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -162,7 +162,7 @@ def _load_gqa(config, prefix: str, weights): class FlashGemmaAttention(torch.nn.Module): - def __init__(self, prefix: str, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool = True): super().__init__() self.num_heads = config.num_attention_heads self.head_size = config.head_dim From 4aa0642f4d842710f874b569d20c4aa21f767dab Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 17:17:46 +0200 Subject: [PATCH 21/22] Default value for gemma/gemma2. --- .../models/custom_modeling/flash_gemma2_modeling.py | 2 +- .../models/custom_modeling/flash_gemma_modeling.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index cfa6b2fe4b3..625baa9109b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -442,7 +442,7 @@ def forward( class FlashGemma2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4d731cbf371..864bf9b0fd3 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -419,7 +419,7 @@ def forward( class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 From 25c9611c04b70859dd51350d6721e5e07f34db4a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 4 Jul 2024 17:18:26 +0200 Subject: [PATCH 22/22] Wrong default. --- .../models/custom_modeling/flash_gemma_modeling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 864bf9b0fd3..b7ce6307580 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -162,7 +162,7 @@ def _load_gqa(config, prefix: str, weights): class FlashGemmaAttention(torch.nn.Module): - def __init__(self, prefix: str, config, weights, causal: bool = True): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.num_heads = config.num_attention_heads self.head_size = config.head_dim