Skip to content

Commit 29a8d6a

Browse files
authored
[Fix] Don't deep-copy LogitsProcessors when copying SamplingParams (#3099)
1 parent 2c08ff2 commit 29a8d6a

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

vllm/engine/llm_engine.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,9 @@ def add_request(
484484
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
485485
if lora_request else 0) if prefix_pos is not None else None
486486

487-
# Defensive copy of SamplingParams, which are used by the sampler
488-
sampling_params = copy.deepcopy(sampling_params)
487+
# Defensive copy of SamplingParams, which are used by the sampler,
488+
# this doesn't deep-copy LogitsProcessor objects
489+
sampling_params = sampling_params.clone()
489490

490491
# Create the sequence group.
491492
seq_group = SequenceGroup(request_id, [seq], sampling_params,

vllm/sampling_params.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Sampling parameters for text generation."""
2+
import copy
23
from enum import IntEnum
34
from functools import cached_property
45
from typing import Callable, List, Optional, Union
@@ -237,6 +238,20 @@ def sampling_type(self) -> SamplingType:
237238
return SamplingType.RANDOM_SEED
238239
return SamplingType.RANDOM
239240

241+
def clone(self) -> "SamplingParams":
242+
"""Deep copy excluding LogitsProcessor objects.
243+
244+
LogitsProcessor objects are excluded because they may contain an
245+
arbitrary, nontrivial amount of data.
246+
See https://github.com/vllm-project/vllm/issues/3087
247+
"""
248+
249+
logit_processor_refs = None if self.logits_processors is None else {
250+
id(lp): lp
251+
for lp in self.logits_processors
252+
}
253+
return copy.deepcopy(self, memo=logit_processor_refs)
254+
240255
def __repr__(self) -> str:
241256
return (
242257
f"SamplingParams(n={self.n}, "

0 commit comments

Comments
 (0)