Skip to content

Commit 7d47d8f

Browse files
authored
[Fix] fix resources limit error when apply speculative decoding and aclgraph (#2472)
### What this PR does / why we need it? When both speculative decoding and aclgraph are applied, and cudagraph_capture_sizes uses the default value, an "insufficient stream resources" error is reported. This change accounts for the draft model's hidden layers when computing the maximum number of capturable batch sizes. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.10.1.1 - vLLM main: vllm-project/vllm@9c99e48 Signed-off-by: withHades <[email protected]>
1 parent 0c0789b commit 7d47d8f

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

tests/ut/test_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,20 @@ def test_update_aclgraph_sizes(self):
261261
self.assertEqual(
262262
147,
263263
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
264+
265+
test_vllm_config.speculative_config = mock.MagicMock()
266+
test_vllm_config.speculative_config.draft_model_config = mock.MagicMock(
267+
)
268+
test_vllm_config.speculative_config.draft_model_config.hf_config = mock.MagicMock(
269+
)
270+
test_vllm_config.speculative_config.draft_model_config.hf_config.num_hidden_layers = 2
271+
os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
272+
utils.update_aclgraph_sizes(test_vllm_config)
273+
del os.environ['HCCL_OP_EXPANSION_MODE']
274+
self.assertEqual(
275+
120,
276+
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
277+
264278
# max_num_batch_sizes >= len(original_sizes)
265279
test_compilation_config = CompilationConfig(
266280
cudagraph_capture_sizes=[1, 2, 3])

vllm_ascend/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,12 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
304304
num_hidden_layers = get_max_hidden_layers(hf_config)
305305
parallel_config = vllm_config.parallel_config
306306

307+
# Calculate maximum supported batch sizes considering model architecture
308+
resources_per_graph = num_hidden_layers + 1
309+
if vllm_config.speculative_config is not None:
310+
draft_model_hf_config = vllm_config.speculative_config.draft_model_config.hf_config
311+
resources_per_graph += draft_model_hf_config.num_hidden_layers + 1
312+
307313
# TODO: Find out whether we need to take into account the pp_size
308314
num_comm_groups = sum(size > 1 for size in [
309315
parallel_config.data_parallel_size,
@@ -318,8 +324,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
318324
# Assume the following case:
319325
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
320326
# According to the formula, max_num_batch_sizes = math.floor(1920 / (48 + 1) / 2) = 19
321-
max_num_batch_sizes = math.floor(
322-
MAX_CAPTURE_SIZE / (num_hidden_layers + 1) / parallel_factor)
327+
max_num_batch_sizes = math.floor(MAX_CAPTURE_SIZE /
328+
resources_per_graph / parallel_factor)
323329
logger.info(
324330
"Calculated maximum supported batch sizes for ACL graph: %s",
325331
max_num_batch_sizes)
@@ -335,8 +341,8 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
335341
# MAX_CAPTURE_SIZE = 1920, num_hidden_layers = 48, data_parallel_size is 1, tensor_parallel_size is 4,
336342
# According to the formula, max_num_batch_sizes = math.floor((1920 - 1 * 40) / (48 + 1) / (1 + 1 * 2)) = 12
337343
max_num_batch_sizes = math.floor(
338-
(MAX_CAPTURE_SIZE - num_comm_groups * 40) /
339-
(num_hidden_layers + 1) / (1 + num_comm_groups * 2))
344+
(MAX_CAPTURE_SIZE - num_comm_groups * 40) / resources_per_graph /
345+
(1 + num_comm_groups * 2))
340346
logger.info(
341347
"Calculated maximum supported batch sizes for ACL graph: %s",
342348
max_num_batch_sizes)

0 commit comments

Comments (0)