
Commit

fix: pass model_id for all flash causal lms
drbh committed Jun 6, 2024
1 parent d0f1470 commit 6c4135e
Showing 13 changed files with 15 additions and 1 deletion.
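The change itself is mechanical: every flash model subclass already receives model_id in its own __init__, and the fix is simply to forward that value to the shared base-class constructor so the base class can store it. A minimal, self-contained sketch of the pattern, using stand-in class names rather than the repository's actual classes:

class FlashCausalLMBase:
    """Stand-in for the shared flash causal-LM base class."""

    def __init__(self, model_id: str, model: object, tokenizer: object):
        # After this commit, every subclass forwards its model_id here.
        self.model_id = model_id
        self.model = model
        self.tokenizer = tokenizer


class FlashSomethingLM(FlashCausalLMBase):
    """Stand-in for one of the per-architecture subclasses."""

    def __init__(self, model_id: str, model: object, tokenizer: object):
        # Previously the subclass dropped model_id at this point; the fix
        # is to pass it through explicitly.
        super().__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
        )


lm = FlashSomethingLM(model_id="some-org/some-model", model=None, tokenizer=None)
print(lm.model_id)  # "some-org/some-model"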
3 changes: 2 additions & 1 deletion server/text_generation_server/models/flash_causal_lm.py
@@ -950,12 +950,13 @@ def warmup(self, batch: FlashCausalLMBatch):
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
         free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+        batch_num_blocks = batch.num_blocks if batch is not None else 0
 
         num_blocks = (
             # Leave 5% for some wiggle room
             int((free_memory * 0.95) // total_cache_size)
             # Add batch.num_blocks as we allocated it above, so it is included in the peak memory.
-            + batch.num_blocks
+            + batch_num_blocks
         )
 
         del batch
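For context, warmup() budgets the number of KV-cache blocks out of free device memory, and the new batch_num_blocks local keeps that arithmetic safe when warmup is called without a batch. A self-contained sketch of the calculation, with illustrative numbers (the real values come from the model configuration and the device):

def available_kv_cache_blocks(
    free_memory: int,       # bytes reported free on the device
    total_cache_size: int,  # bytes one KV-cache block needs across all layers
    batch_num_blocks: int,  # blocks already held by the warmup batch, 0 if there is none
) -> int:
    return (
        # Leave 5% for some wiggle room
        int((free_memory * 0.95) // total_cache_size)
        # Add the warmup batch's blocks back, since they are part of peak memory
        + batch_num_blocks
    )


# Example: ~20 GiB free, ~2 MiB per block, a warmup batch holding 8 blocks.
print(available_kv_cache_blocks(20 * 1024**3, 2 * 1024**2, 8))  # roughly 9.7k blocks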
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_cohere.py
@@ -62,6 +62,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashCohere, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_dbrx.py
@@ -87,6 +87,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashDbrx, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_gemma.py
@@ -62,6 +62,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashGemma, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_gpt2.py
@@ -65,6 +65,7 @@ def __init__(
         model = FlashGPT2ForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashGPT2, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
2 changes: 2 additions & 0 deletions server/text_generation_server/models/flash_mistral.py
@@ -79,6 +79,7 @@ def __init__(
         torch.distributed.barrier(group=self.process_group)
         num_layers, num_kv_heads, head_size = self.get_layer_config(model)
         super().__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=num_layers,
@@ -110,6 +111,7 @@ def __init__(
         trust_remote_code: bool = False,
     ):
         super(FlashMistral, self).__init__(
+            model_id=model_id,
             config_cls=MistralConfig,
             model_cls=FlashMistralForCausalLM,
             model_id=model_id,
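FlashMistral (and FlashMixtral below) differ from the other subclasses in that they build their model through a config_cls/model_cls wrapper base rather than calling the flash causal-LM constructor directly, so model_id has to travel through one extra layer. A stand-in sketch of that extra hop, with illustrative names rather than the repository's real classes:

class FlashBase:
    """Stand-in for the shared flash causal-LM base class."""

    def __init__(self, model_id: str, model: object):
        self.model_id = model_id
        self.model = model


class ConfigDrivenFlash(FlashBase):
    """Stand-in for a wrapper that instantiates the model from config_cls/model_cls."""

    def __init__(self, model_id: str, config_cls: type, model_cls: type):
        model = model_cls(config_cls())
        # The wrapper itself must forward model_id to the base class,
        # which is the kind of hop this commit adds.
        super().__init__(model_id=model_id, model=model)


class DummyConfig:
    pass


class DummyModel:
    def __init__(self, config):
        self.config = config


lm = ConfigDrivenFlash(model_id="some-org/some-model", config_cls=DummyConfig, model_cls=DummyModel)
print(lm.model_id)  # "some-org/some-model"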
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_mixtral.py
@@ -20,6 +20,7 @@ def __init__(
         trust_remote_code: bool = False,
     ):
         super(FlashMixtral, self).__init__(
+            model_id=model_id,
             config_cls=MixtralConfig,
             model_cls=FlashMixtralForCausalLM,
             model_id=model_id,
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_neox.py
@@ -65,6 +65,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashNeoXSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.gpt_neox.layers),
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_phi.py
@@ -91,6 +91,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashPhi, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_qwen2.py
@@ -71,6 +71,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_rw.py
@@ -74,6 +74,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashRWSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.transformer.h),
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_santacoder.py
@@ -76,6 +76,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashSantacoderSharded, self).__init__(
+            model_id=model_id,
             model=model.to(device),
             tokenizer=tokenizer,
             num_layers=len(model.transformer.h),
1 change: 1 addition & 0 deletions server/text_generation_server/models/flash_starcoder2.py
@@ -70,6 +70,7 @@ def __init__(
 
         torch.distributed.barrier(group=self.process_group)
         super(BaseFlashMistral, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             num_layers=len(model.model.layers),
