@@ -564,6 +564,131 @@ def mock_get_padded_token_len(paddings_list, val):
564564 np .testing .assert_array_equal (logits_indices_selector ,
565565 expected_selector )
566566
567+ @patch ('tpu_inference.runner.tpu_runner.NamedSharding' )
568+ @patch ('tpu_inference.runner.tpu_runner.runner_utils' )
569+ @patch ('tpu_inference.runner.tpu_runner.device_array' ,
570+ side_effect = lambda mesh , tensors , ** kwargs : tensors )
571+ @patch ('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata' )
572+ def test_prepare_inputs_dp_with_decode_requests (self ,
573+ mock_sampling_metadata ,
574+ mock_device_array ,
575+ mock_runner_utils ,
576+ mock_named_sharding ):
577+ """Test _prepare_inputs_dp with decode requests (1 token each) to verify request_distribution."""
578+
579+ # Setup mocking
580+ def mock_get_padded_token_len (paddings_list , val ):
581+ if val <= 2 :
582+ return 4 # For request padding
583+ elif val <= 4 :
584+ return 8 # For token padding
585+ else :
586+ return 16
587+
588+ mock_runner_utils .get_padded_token_len .side_effect = mock_get_padded_token_len
589+ mock_sampling_instance = MagicMock ()
590+ mock_sampling_metadata .from_input_batch .return_value = mock_sampling_instance
591+ mock_named_sharding .return_value = MagicMock ()
592+
593+ # Setup test data with decode requests (1 token) and prefill requests (>1 token)
594+ # req1: decode (1 token), req2: decode (1 token), req3: prefill (3 tokens), req4: decode (1 token)
595+ num_scheduled_tokens = {"req1" : 1 , "req2" : 1 , "req3" : 3 , "req4" : 1 }
596+ assigned_dp_ranks = {"req1" : 0 , "req2" : 0 , "req3" : 1 , "req4" : 1 }
597+
598+ self .runner .input_batch .num_reqs = 4
599+ self .runner .input_batch .req_ids = ["req1" , "req2" , "req3" , "req4" ]
600+ self .runner .input_batch .num_computed_tokens_cpu = np .array (
601+ [5 , 6 , 7 , 8 ])
602+ self .runner .input_batch .token_ids_cpu = np .zeros ((8 , 64 ),
603+ dtype = np .int32 )
604+
605+ scheduler_output = self ._create_mock_scheduler_output (
606+ num_scheduled_tokens , assigned_dp_ranks )
607+
608+ # Setup required attributes
609+ self .runner .uses_mrope = False
610+ self .runner .phase_based_profiler = None
611+ self .runner .lora_config = None
612+ self .runner .mesh = MagicMock ()
613+ self .runner .data_parallel_sharding = MagicMock ()
614+ self .runner .data_parallel_attn_sharding = MagicMock ()
615+ self .runner .mm_manager = MagicMock ()
616+ self .runner .speculative_decoding_manager = MagicMock ()
617+ self .runner .lora_utils = MagicMock ()
618+
619+ # Execute the method
620+ result = self .runner ._prepare_inputs_dp (scheduler_output )
621+ input_ids , positions , attention_metadata , sampling_metadata , logits_indices , spec_decode_metadata , logits_indices_selector , padded_num_reqs = result
622+
623+ # Verify request_distribution
624+ # DP rank 0: req1 (decode), req2 (decode) -> [2, 2, 2]
625+ # DP rank 1: req3 (prefill), req4 (decode) -> [1, 1, 2]
626+ expected_distribution = np .array ([[2 , 2 , 2 ], [1 , 1 , 2 ]]).flatten ()
627+ np .testing .assert_array_equal (attention_metadata .request_distribution ,
628+ expected_distribution )
629+
630+ @patch ('tpu_inference.runner.tpu_runner.NamedSharding' )
631+ @patch ('tpu_inference.runner.tpu_runner.runner_utils' )
632+ @patch ('tpu_inference.runner.tpu_runner.device_array' ,
633+ side_effect = lambda mesh , tensors , ** kwargs : tensors )
634+ @patch ('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata' )
635+ def test_prepare_inputs_dp_all_decode_requests (self ,
636+ mock_sampling_metadata ,
637+ mock_device_array ,
638+ mock_runner_utils ,
639+ mock_named_sharding ):
640+ """Test _prepare_inputs_dp with all decode requests."""
641+
642+ # Setup mocking
643+ def mock_get_padded_token_len (paddings_list , val ):
644+ if val <= 2 :
645+ return 4
646+ elif val <= 4 :
647+ return 8
648+ else :
649+ return 16
650+
651+ mock_runner_utils .get_padded_token_len .side_effect = mock_get_padded_token_len
652+ mock_sampling_instance = MagicMock ()
653+ mock_sampling_metadata .from_input_batch .return_value = mock_sampling_instance
654+ mock_named_sharding .return_value = MagicMock ()
655+
656+ # All requests are decode (1 token each)
657+ num_scheduled_tokens = {"req1" : 1 , "req2" : 1 }
658+ assigned_dp_ranks = {"req1" : 0 , "req2" : 1 }
659+
660+ self .runner .input_batch .num_reqs = 2
661+ self .runner .input_batch .req_ids = ["req1" , "req2" ]
662+ self .runner .input_batch .num_computed_tokens_cpu = np .array ([5 , 6 ])
663+ self .runner .input_batch .token_ids_cpu = np .zeros ((8 , 64 ),
664+ dtype = np .int32 )
665+
666+ scheduler_output = self ._create_mock_scheduler_output (
667+ num_scheduled_tokens , assigned_dp_ranks )
668+
669+ # Setup required attributes
670+ self .runner .uses_mrope = False
671+ self .runner .phase_based_profiler = None
672+ self .runner .lora_config = None
673+ self .runner .mesh = MagicMock ()
674+ self .runner .data_parallel_sharding = MagicMock ()
675+ self .runner .data_parallel_attn_sharding = MagicMock ()
676+ self .runner .mm_manager = MagicMock ()
677+ self .runner .speculative_decoding_manager = MagicMock ()
678+ self .runner .lora_utils = MagicMock ()
679+
680+ # Execute the method
681+ result = self .runner ._prepare_inputs_dp (scheduler_output )
682+ input_ids , positions , attention_metadata , sampling_metadata , logits_indices , spec_decode_metadata , logits_indices_selector , padded_num_reqs = result
683+
684+ # Verify request_distribution
685+ # Both ranks have only decode requests
686+ # DP rank 0: req1 (decode) -> [1, 1, 1]
687+ # DP rank 1: req2 (decode) -> [1, 1, 1]
688+ expected_distribution = np .array ([[1 , 1 , 1 ], [1 , 1 , 1 ]]).flatten ()
689+ np .testing .assert_array_equal (attention_metadata .request_distribution ,
690+ expected_distribution )
691+
567692 @patch ('tpu_inference.runner.tpu_runner.NamedSharding' )
568693 @patch ('tpu_inference.runner.tpu_runner.runner_utils' )
569694 @patch ('tpu_inference.runner.tpu_runner.device_array' ,
0 commit comments