From 6c4135eacd9bb276a6e8b08fe571601c33a23739 Mon Sep 17 00:00:00 2001
From: drbh
Date: Thu, 6 Jun 2024 21:02:03 +0000
Subject: [PATCH] fix: pass model_id for all flash causal lms

---
 server/text_generation_server/models/flash_causal_lm.py  | 3 ++-
 server/text_generation_server/models/flash_cohere.py     | 1 +
 server/text_generation_server/models/flash_dbrx.py       | 1 +
 server/text_generation_server/models/flash_gemma.py      | 1 +
 server/text_generation_server/models/flash_gpt2.py       | 1 +
 server/text_generation_server/models/flash_mistral.py    | 2 ++
 server/text_generation_server/models/flash_mixtral.py    | 1 +
 server/text_generation_server/models/flash_neox.py       | 1 +
 server/text_generation_server/models/flash_phi.py        | 1 +
 server/text_generation_server/models/flash_qwen2.py      | 1 +
 server/text_generation_server/models/flash_rw.py         | 1 +
 server/text_generation_server/models/flash_santacoder.py | 1 +
 server/text_generation_server/models/flash_starcoder2.py | 1 +
 13 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index fe229853c2e..e7e6c15b6c2 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -950,12 +950,13 @@ def warmup(self, batch: FlashCausalLMBatch):
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

         free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+        batch_num_blocks = batch.num_blocks if batch is not None else 0

         num_blocks = (
             # Leave 5% for some wiggle room
             int((free_memory * 0.95) // total_cache_size)
             # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
-            + batch.num_blocks
+            + batch_num_blocks
         )

         del batch
diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py
index b907ee08642..6738a6810ab 100644
--- a/server/text_generation_server/models/flash_cohere.py
+++ b/server/text_generation_server/models/flash_cohere.py
@@ -62,6 +62,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(FlashCohere, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py
index d5eb1a6e20a..0e9c913c221 100644
--- a/server/text_generation_server/models/flash_dbrx.py
+++ b/server/text_generation_server/models/flash_dbrx.py
@@ -87,6 +87,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(FlashDbrx, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py
index 358883e688e..c635ccfaf26 100644
--- a/server/text_generation_server/models/flash_gemma.py
+++ b/server/text_generation_server/models/flash_gemma.py
@@ -62,6 +62,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py
index ef129e922f4..e9fc471ee97 100644
--- a/server/text_generation_server/models/flash_gpt2.py
+++ b/server/text_generation_server/models/flash_gpt2.py
@@ -65,6 +65,7 @@ def __init__(
         model = FlashGPT2ForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashGPT2, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 081c2e2c8dc..e66a1c3de78 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -79,6 +79,7 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         num_layers, num_kv_heads, head_size = self.get_layer_config(model)
         super().__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=num_layers,
@@ -110,6 +111,7 @@ def __init__(
         trust_remote_code: bool = False,
     ):
         super(FlashMistral, self).__init__(
+            model_id=model_id,
             config_cls=MistralConfig,
             model_cls=FlashMistralForCausalLM,
             model_id=model_id,
diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py
index 587d423f94b..216279a8d78 100644
--- a/server/text_generation_server/models/flash_mixtral.py
+++ b/server/text_generation_server/models/flash_mixtral.py
@@ -20,6 +20,7 @@ def __init__(
         trust_remote_code: bool = False,
     ):
         super(FlashMixtral, self).__init__(
+            model_id=model_id,
             config_cls=MixtralConfig,
             model_cls=FlashMixtralForCausalLM,
             model_id=model_id,
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index adefaeb22f0..f14f20f500e 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -65,6 +65,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(FlashNeoXSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.gpt_neox.layers),
diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py
index 32b573a953b..709c15c5962 100644
--- a/server/text_generation_server/models/flash_phi.py
+++ b/server/text_generation_server/models/flash_phi.py
@@ -91,6 +91,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(FlashPhi, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py
index 752858635d4..b984c7cac03 100644
--- a/server/text_generation_server/models/flash_qwen2.py
+++ b/server/text_generation_server/models/flash_qwen2.py
@@ -71,6 +71,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index e6350611c40..b1867988a05 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -74,6 +74,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(FlashRWSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.transformer.h),
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 2ad36b938f5..8ee662928fe 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -76,6 +76,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(FlashSantacoderSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.transformer.h),
diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py
index 5533c9d957d..9eb88816d76 100644
--- a/server/text_generation_server/models/flash_starcoder2.py
+++ b/server/text_generation_server/models/flash_starcoder2.py
@@ -70,6 +70,7 @@ def __init__(

         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
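
Note on the flash_causal_lm.py hunk, as a sketch rather than part of the patch above: the warmup block-count arithmetic now tolerates batch being None by falling back to zero pre-allocated blocks, instead of dereferencing batch.num_blocks unconditionally. The snippet below mirrors that guard in isolation; compute_num_blocks and FakeBatch are hypothetical names for illustration, and the constants in the usage example are made up. The remaining twelve hunks are mechanical: each flash model's super().__init__ call now forwards model_id.

from typing import Optional


class FakeBatch:
    """Stand-in for FlashCausalLMBatch; only the field the hunk touches."""

    def __init__(self, num_blocks: int):
        self.num_blocks = num_blocks


def compute_num_blocks(
    free_memory: int, total_cache_size: int, batch: Optional[FakeBatch]
) -> int:
    # Fall back to 0 when warmup runs without a batch, instead of raising
    # AttributeError on None.
    batch_num_blocks = batch.num_blocks if batch is not None else 0
    return (
        # Leave 5% for some wiggle room
        int((free_memory * 0.95) // total_cache_size)
        # Add the batch's blocks: they were already allocated during warmup,
        # so they count toward the measured peak memory.
        + batch_num_blocks
    )


if __name__ == "__main__":
    # Hypothetical numbers: ~24 GB free, ~1.3 MB per KV-cache block.
    print(compute_num_blocks(24 * 1024**3, 1_310_720, FakeBatch(num_blocks=64)))
    print(compute_num_blocks(24 * 1024**3, 1_310_720, None))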