
Commit c4bb526

fix(server): decrease memory fragmentation (#557)
Parent: 6f42942

File tree

2 files changed: +22 −10 lines

  server/text_generation_server/cache.py
  server/text_generation_server/models/flash_causal_lm.py
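
The fix is built around PyTorch's caching CUDA allocator: del on a batch returns its memory to the allocator, but the allocator keeps those blocks reserved for reuse, and over many batch lifetimes the reserved pool can fragment. Calling torch.cuda.empty_cache() hands unused cached blocks back to the driver. A minimal sketch of that behaviour, assuming a CUDA device is available (the tensor shape is illustrative only):

import torch

# del releases the tensor back to PyTorch's caching allocator, but the
# allocator keeps the blocks reserved for reuse; this is where fragmentation
# can accumulate across many batch lifetimes.
x = torch.empty(1024, 1024, 256, device="cuda")  # ~1 GiB of float32
del x

print(torch.cuda.memory_allocated())  # low again after del
print(torch.cuda.memory_reserved())   # still high: blocks stay cached

# empty_cache() returns unused cached blocks to the CUDA driver; this is the
# call the commit adds after batches are deleted or dropped on error.
torch.cuda.empty_cache()
print(torch.cuda.memory_reserved())   # reduced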

server/text_generation_server/cache.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,5 @@
+import torch
+
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
@@ -20,6 +22,8 @@ def delete(self, batch_id: int):
         batch = self.pop(batch_id)
         if batch is not None:
             del batch
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
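
The new call in Cache.delete is guarded by torch.cuda.is_available(), so CPU-only deployments are unaffected. A minimal sketch of the same guarded-release pattern as a standalone helper (release_cuda_cache is a hypothetical name, not part of this repository):

import torch

def release_cuda_cache() -> None:
    # Mirror the guard added in Cache.delete: return unused cached allocator
    # blocks to the CUDA driver on GPU, and do nothing on CPU-only installs.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()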

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 18 additions & 10 deletions
@@ -638,6 +638,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
+        torch.cuda.empty_cache()
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -732,6 +734,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
             )
             raise e
         del batch
+        torch.cuda.empty_cache()
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -775,16 +778,21 @@ def generate_token(
             # Allocate blocks to this batch
             CACHE_MANAGER.allocate(batch)
 
-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e
 
         if prefill:
             next_token_logits = (
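
In generate_token, the forward pass is now wrapped so that a failure (for example a CUDA out-of-memory error) drops the batch references and empties the allocator cache before the exception propagates, letting the caller recover without the worker holding on to fragmented memory. A generic sketch of this cleanup-and-reraise pattern, with model and batch as illustrative placeholders rather than the repository's actual API:

import torch

def forward_with_cleanup(model, batch):
    # Sketch of the pattern used in generate_token: if the forward pass fails,
    # drop the batch references and return cached blocks to the driver before
    # re-raising, so the caller can retry without inheriting fragmented memory.
    try:
        return model.forward(batch)
    except Exception:
        del batch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise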
