Commit 0a01dde

Trying to fix non-chunking targets.
Narsil committed Oct 23, 2024
1 parent a31db04 commit 0a01dde
Showing 1 changed file with 26 additions and 16 deletions.
server/text_generation_server/models/flash_causal_lm.py (26 additions, 16 deletions)
@@ -1398,22 +1398,32 @@ def warmup(
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size
 
         if max_total_tokens is None:
-            model_max_length = self.tokenizer.model_max_length
-            free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-            spare_blocks = (
-                # Leave 5% for some wiggle room
-                int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
-                + batch.num_blocks
-            )
-            spare_blocks = small_power_of_2(spare_blocks)
-
-            available_blocks = min(model_max_length, spare_blocks)
-            batch.num_blocks = available_blocks
-            batch.max_blocks = available_blocks
-            max_input_tokens = (
-                available_blocks - 1 if max_input_tokens is None else max_input_tokens
-            )
-            max_total_tokens = available_blocks
+            if get_support_chunking():
+                model_max_length = self.tokenizer.model_max_length
+                free_memory = get_free_memory(self.device, MEMORY_FRACTION)
+                spare_blocks = (
+                    # Leave 5% for some wiggle room
+                    int((free_memory * TGI_WIGGLE_ROOM) // total_cache_size)
+                    + batch.num_blocks
+                )
+                spare_blocks = small_power_of_2(spare_blocks)
+
+                available_blocks = min(model_max_length, spare_blocks)
+                batch.num_blocks = available_blocks
+                batch.max_blocks = available_blocks
+                max_input_tokens = (
+                    available_blocks - 1
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
+                max_total_tokens = available_blocks
+            else:
+                max_total_tokens = batch.num_blocks
+                max_input_tokens = (
+                    batch.num_blocks - 1
+                    if max_input_tokens is None
+                    else max_input_tokens
+                )
 
         try:
             self.init_kv_cache(
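In short, the change gates the free-memory-based KV-cache sizing behind get_support_chunking(): when prefill chunking is unsupported, the token budget falls back to the block count already allocated for the warmup batch instead of being derived from free GPU memory. Below is a minimal standalone sketch of that decision, not TGI's actual code; small_power_of_2 here is a simplified stand-in, and the caller is assumed to pass in values that TGI would normally read from the device and tokenizer.

    def small_power_of_2(n: int) -> int:
        # Assumed behavior: round n down to the nearest power of two.
        return 1 << (n.bit_length() - 1)

    def compute_token_budget(
        support_chunking: bool,
        free_memory: int,
        total_cache_size: int,
        batch_num_blocks: int,
        model_max_length: int,
        max_input_tokens: int | None,
        wiggle_room: float = 0.95,  # "Leave 5% for some wiggle room"
    ) -> tuple[int, int]:
        """Return (max_input_tokens, max_total_tokens), mirroring the diff."""
        if support_chunking:
            # Chunking targets: size the cache from free memory.
            spare_blocks = (
                int((free_memory * wiggle_room) // total_cache_size)
                + batch_num_blocks
            )
            spare_blocks = small_power_of_2(spare_blocks)
            available_blocks = min(model_max_length, spare_blocks)
            if max_input_tokens is None:
                max_input_tokens = available_blocks - 1
            return max_input_tokens, available_blocks
        else:
            # Non-chunking targets: size strictly from the warmup batch.
            if max_input_tokens is None:
                max_input_tokens = batch_num_blocks - 1
            return max_input_tokens, batch_num_blocks

The design point of the fix: before this commit, non-chunking targets also took the free-memory path, which could produce a max_total_tokens the runtime could not actually honor without chunked prefill; pinning them to batch.num_blocks keeps the budget within what warmup proved feasible.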
