
Commit 6ec03a6

Update request_distribution in DP input preparation (#1211)
1 parent 05e4b16 commit 6ec03a6

File tree

2 files changed (+133, −1)


tests/runner/test_tpu_runner_dp.py

Lines changed: 125 additions & 0 deletions
@@ -564,6 +564,131 @@ def mock_get_padded_token_len(paddings_list, val):
         np.testing.assert_array_equal(logits_indices_selector,
                                       expected_selector)
 
+    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
+    @patch('tpu_inference.runner.tpu_runner.runner_utils')
+    @patch('tpu_inference.runner.tpu_runner.device_array',
+           side_effect=lambda mesh, tensors, **kwargs: tensors)
+    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
+    def test_prepare_inputs_dp_with_decode_requests(self,
+                                                    mock_sampling_metadata,
+                                                    mock_device_array,
+                                                    mock_runner_utils,
+                                                    mock_named_sharding):
+        """Test _prepare_inputs_dp with decode requests (1 token each) to verify request_distribution."""
+
+        # Set up mocks.
+        def mock_get_padded_token_len(paddings_list, val):
+            if val <= 2:
+                return 4  # For request padding
+            elif val <= 4:
+                return 8  # For token padding
+            else:
+                return 16
+
+        mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
+        mock_sampling_instance = MagicMock()
+        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
+        mock_named_sharding.return_value = MagicMock()
+
+        # Test data mixes decode requests (1 token) and a prefill request (>1 token):
+        # req1: decode (1 token), req2: decode (1 token), req3: prefill (3 tokens), req4: decode (1 token)
+        num_scheduled_tokens = {"req1": 1, "req2": 1, "req3": 3, "req4": 1}
+        assigned_dp_ranks = {"req1": 0, "req2": 0, "req3": 1, "req4": 1}
+
+        self.runner.input_batch.num_reqs = 4
+        self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
+        self.runner.input_batch.num_computed_tokens_cpu = np.array(
+            [5, 6, 7, 8])
+        self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
+                                                         dtype=np.int32)
+
+        scheduler_output = self._create_mock_scheduler_output(
+            num_scheduled_tokens, assigned_dp_ranks)
+
+        # Set up the required runner attributes.
+        self.runner.uses_mrope = False
+        self.runner.phase_based_profiler = None
+        self.runner.lora_config = None
+        self.runner.mesh = MagicMock()
+        self.runner.data_parallel_sharding = MagicMock()
+        self.runner.data_parallel_attn_sharding = MagicMock()
+        self.runner.mm_manager = MagicMock()
+        self.runner.speculative_decoding_manager = MagicMock()
+        self.runner.lora_utils = MagicMock()
+
+        # Execute the method under test.
+        result = self.runner._prepare_inputs_dp(scheduler_output)
+        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+
+        # Verify request_distribution:
+        # DP rank 0: req1 (decode), req2 (decode) -> [2, 2, 2]
+        # DP rank 1: req3 (prefill), req4 (decode) -> [1, 1, 2]
+        expected_distribution = np.array([[2, 2, 2], [1, 1, 2]]).flatten()
+        np.testing.assert_array_equal(attention_metadata.request_distribution,
+                                      expected_distribution)
+
+    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
+    @patch('tpu_inference.runner.tpu_runner.runner_utils')
+    @patch('tpu_inference.runner.tpu_runner.device_array',
+           side_effect=lambda mesh, tensors, **kwargs: tensors)
+    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
+    def test_prepare_inputs_dp_all_decode_requests(self,
+                                                   mock_sampling_metadata,
+                                                   mock_device_array,
+                                                   mock_runner_utils,
+                                                   mock_named_sharding):
+        """Test _prepare_inputs_dp with all decode requests."""
+
+        # Set up mocks.
+        def mock_get_padded_token_len(paddings_list, val):
+            if val <= 2:
+                return 4
+            elif val <= 4:
+                return 8
+            else:
+                return 16
+
+        mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
+        mock_sampling_instance = MagicMock()
+        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
+        mock_named_sharding.return_value = MagicMock()
+
+        # All requests are decode (1 token each).
+        num_scheduled_tokens = {"req1": 1, "req2": 1}
+        assigned_dp_ranks = {"req1": 0, "req2": 1}
+
+        self.runner.input_batch.num_reqs = 2
+        self.runner.input_batch.req_ids = ["req1", "req2"]
+        self.runner.input_batch.num_computed_tokens_cpu = np.array([5, 6])
+        self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
+                                                         dtype=np.int32)
+
+        scheduler_output = self._create_mock_scheduler_output(
+            num_scheduled_tokens, assigned_dp_ranks)
+
+        # Set up the required runner attributes.
+        self.runner.uses_mrope = False
+        self.runner.phase_based_profiler = None
+        self.runner.lora_config = None
+        self.runner.mesh = MagicMock()
+        self.runner.data_parallel_sharding = MagicMock()
+        self.runner.data_parallel_attn_sharding = MagicMock()
+        self.runner.mm_manager = MagicMock()
+        self.runner.speculative_decoding_manager = MagicMock()
+        self.runner.lora_utils = MagicMock()
+
+        # Execute the method under test.
+        result = self.runner._prepare_inputs_dp(scheduler_output)
+        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+
+        # Verify request_distribution; both ranks hold only decode requests:
+        # DP rank 0: req1 (decode) -> [1, 1, 1]
+        # DP rank 1: req2 (decode) -> [1, 1, 1]
+        expected_distribution = np.array([[1, 1, 1], [1, 1, 1]]).flatten()
+        np.testing.assert_array_equal(attention_metadata.request_distribution,
+                                      expected_distribution)
+
     @patch('tpu_inference.runner.tpu_runner.NamedSharding')
     @patch('tpu_inference.runner.tpu_runner.runner_utils')
     @patch('tpu_inference.runner.tpu_runner.device_array',
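
Both tests assert the same layout for attention_metadata.request_distribution: each DP rank contributes one [num_decode, num_decode, num_total] triple, and the triples are flattened row-major into a single 1-D array. A minimal standalone sketch that derives the expected arrays from the tests' inputs (the helper name and plain-dict arguments are illustrative, not part of the runner's API):

import numpy as np

# Hypothetical helper mirroring the layout the tests assert: one
# [num_decode, num_decode, num_total] triple per DP rank, flattened.
def expected_distribution(num_scheduled_tokens, assigned_dp_ranks, dp_size):
    triples = []
    for dp_rank in range(dp_size):
        reqs = [r for r, rank in assigned_dp_ranks.items() if rank == dp_rank]
        # A request counts as decode if exactly one token is scheduled for it.
        num_decode = sum(1 for r in reqs if num_scheduled_tokens[r] == 1)
        triples.append([num_decode, num_decode, len(reqs)])
    return np.array(triples).flatten()

# Mixed decode/prefill case from test_prepare_inputs_dp_with_decode_requests:
print(expected_distribution({"req1": 1, "req2": 1, "req3": 3, "req4": 1},
                            {"req1": 0, "req2": 0, "req3": 1, "req4": 1}, 2))
# -> [2 2 2 1 1 2]

# All-decode case from test_prepare_inputs_dp_all_decode_requests:
print(expected_distribution({"req1": 1, "req2": 1}, {"req1": 0, "req2": 1}, 2))
# -> [1 1 1 1 1 1]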

tpu_inference/runner/tpu_runner.py

Lines changed: 8 additions & 1 deletion
@@ -1336,7 +1336,14 @@ def _prepare_inputs_dp(self, scheduler_output: "VllmSchedulerOutput"):
         _request_distribution = []
         for dp_rank in range(dp_size):
             _num_reqs = num_req_per_dp_rank[dp_rank]
-            _request_distribution.append([0, 0, _num_reqs])
+            # The batch has been reordered by _reorder_batch, so decode requests come first.
+            # Count the decode requests (those with num_scheduled_tokens == 1) in this DP rank.
+            num_decode_in_dp_rank = 0
+            for req_id in req_ids_dp[dp_rank]:
+                if scheduler_output.num_scheduled_tokens[req_id] == 1:
+                    num_decode_in_dp_rank += 1
+            _request_distribution.append(
+                [num_decode_in_dp_rank, num_decode_in_dp_rank, _num_reqs])
         request_distribution = np.array(_request_distribution).ravel()
 
         use_spec_decode = len(
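
For contrast, the replaced line emitted [0, 0, _num_reqs] for every rank, so the flattened distribution always reported zero decode requests regardless of the schedule. A hedged before/after sketch of this hunk's effect, with the per-rank bookkeeping simplified to plain dicts (illustrative stand-ins, not the runner's actual data structures):

import numpy as np

# Stand-ins: in _prepare_inputs_dp these come from the scheduler output
# and the batch reordering.
req_ids_dp = {0: ["req1", "req2"], 1: ["req3", "req4"]}
num_scheduled_tokens = {"req1": 1, "req2": 1, "req3": 3, "req4": 1}

# Before this commit: decode counts hard-coded to zero.
old = [[0, 0, len(req_ids_dp[r])] for r in range(2)]

# After this commit: count requests with exactly one scheduled token.
new = []
for dp_rank in range(2):
    num_decode = sum(1 for req_id in req_ids_dp[dp_rank]
                     if num_scheduled_tokens[req_id] == 1)
    new.append([num_decode, num_decode, len(req_ids_dp[dp_rank])])

print(np.array(old).ravel())  # [0 0 2 0 0 2]
print(np.array(new).ravel())  # [2 2 2 1 1 2]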
