1- # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+ # Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
33from collections import OrderedDict
44from typing import Dict , Literal , Optional
@@ -503,6 +503,7 @@ def forward(
503503 inference_params : Optional [BaseInferenceContext ] = None ,
504504 loss_mask : Optional [Tensor ] = None ,
505505 padding_mask : Optional [Tensor ] = None ,
506+ return_logits : bool = False ,
506507 ) -> Tensor :
507508 """Forward function of the GPT Model. This function passes the input tensors
508509 through the embedding layer, and then the decoder and finally into the post
@@ -516,6 +517,9 @@ def forward(
516517 padding_mask (Tensor, optional): Padding mask for MoE routing.
517518 Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
518519 Only used for MoE layers to exclude padding tokens from routing computations.
520+ return_logits (bool): If True, return logits even when `labels` are provided.
521+ This lets online RL pass sampled labels for MTP auxiliary loss while
522+ computing the main RL loss externally from logits.
519523 """
520524 if self .config .fine_grained_activation_offloading :
521525 self .preprocess_for_fine_grained_offloading ()
@@ -591,6 +595,7 @@ def forward(
591595 extra_block_kwargs = extra_block_kwargs ,
592596 inference_context = inference_context ,
593597 mhc_multistream = mhc_multistream ,
598+ return_logits = return_logits ,
594599 )
595600
596601 def _postprocess (
@@ -613,6 +618,7 @@ def _postprocess(
613618 extra_block_kwargs = None ,
614619 inference_context = None ,
615620 mhc_multistream = None ,
621+ return_logits = False ,
616622 ):
617623 """Postprocesses decoder hidden states to generate logits or compute loss.
618624
@@ -699,7 +705,8 @@ def _postprocess(
699705 reshaped = hidden_states .squeeze (1 ).unsqueeze (0 )
700706 hidden_states = inference_context .last_token_logits (reshaped ).unsqueeze (1 )
701707
702- if has_config_logger_enabled (self .config ) or labels is None :
708+ should_return_logits = return_logits or labels is None
709+ if has_config_logger_enabled (self .config ) or should_return_logits :
703710 logits , _ = self .output_layer (
704711 hidden_states , weight = output_weight , runtime_gather_output = runtime_gather_output
705712 )
@@ -730,7 +737,9 @@ def _postprocess(
730737 )
731738 log_config_to_disk (self .config , payload , prefix = 'input_and_logits' )
732739
733- if labels is None :
740+ if should_return_logits :
741+ # `return_logits` only changes the main LM output contract. MTP auxiliary
742+ # loss above still consumes `labels`/`loss_mask` when they are provided.
734743 # [s b h] => [b s h]
735744 return logits .transpose (0 , 1 ).contiguous ()
736745
@@ -763,6 +772,7 @@ def build_schedule_plan(
763772 inference_params : Optional [BaseInferenceContext ] = None ,
764773 loss_mask : Optional [Tensor ] = None ,
765774 padding_mask : Optional [Tensor ] = None ,
775+ return_logits : bool = False ,
766776 ):
767777 """Builds a computation schedule plan for the model.
768778
@@ -789,6 +799,8 @@ def build_schedule_plan(
789799 Parameters for inference. Defaults to None.
790800 loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
791801 padding_mask (Optional[Tensor], optional): Padding mask. Defaults to None.
802+ return_logits (bool, optional): Return logits instead of main LM loss when labels
803+ are provided. MTP auxiliary loss still uses labels. Defaults to False.
792804
793805 Returns:
794806 TransformerModelChunkSchedulePlan: The model chunk schedule plan.
@@ -813,6 +825,7 @@ def build_schedule_plan(
813825 runtime_gather_output ,
814826 loss_mask ,
815827 padding_mask ,
828+ return_logits ,
816829 )
817830
818831 def sharded_state_dict (
0 commit comments