@@ -638,6 +638,8 @@ def concatenate(cls, batches: List["FlashCausalLMBatch"]) -> "FlashCausalLMBatch
         # Needed to avoid dropping blocks when the batches will go out of scope
         for b in batches:
             b.block_tables = None
+            del b
+        torch.cuda.empty_cache()
 
         return FlashCausalLMBatch(
             batch_id=batches[0].batch_id,
@@ -732,6 +734,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_total_tokens: int):
             )
             raise e
         del batch
+        torch.cuda.empty_cache()
 
     def decode(self, generated_ids: Union[torch.Tensor, List[int]]) -> str:
         return self.tokenizer.decode(
@@ -775,16 +778,21 @@ def generate_token(
             # Allocate blocks to this batch
             CACHE_MANAGER.allocate(batch)
 
-        out = self.forward(
-            batch.input_ids,
-            batch.position_ids,
-            batch.cu_seqlen_prefill,
-            batch.block_tables_tensor,
-            batch.slots[batch.slot_indices],
-            batch.input_lengths_tensor,
-            batch.max_seqlen,
-            batch.prefill_head_indices,
-        )
+        try:
+            out = self.forward(
+                batch.input_ids,
+                batch.position_ids,
+                batch.cu_seqlen_prefill,
+                batch.block_tables_tensor,
+                batch.slots[batch.slot_indices],
+                batch.input_lengths_tensor,
+                batch.max_seqlen,
+                batch.prefill_head_indices,
+            )
+        except Exception as e:
+            del batch
+            torch.cuda.empty_cache()
+            raise e
 
         if prefill:
             next_token_logits = (
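Each hunk applies the same recovery pattern: drop the references to the batch, call torch.cuda.empty_cache() so PyTorch hands cached blocks back to the CUDA allocator, and re-raise so the caller still sees the failure. A minimal standalone sketch of that pattern is below; guarded_forward and the model/batch arguments are hypothetical stand-ins, not part of this repository.

import torch


def guarded_forward(model, batch):
    # Run the forward pass; on any failure, release the batch and the cached
    # CUDA blocks before propagating the error, mirroring the hunks above.
    try:
        return model(batch)
    except Exception:
        del batch
        if torch.cuda.is_available():
            # Returns cached, unused memory to the CUDA driver so a later,
            # smaller batch has a better chance of fitting.
            torch.cuda.empty_cache()
        raise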