Skip to content

Commit 9f51e93

Browse files
feat: support isolated MTP auxiliary loss
Co-authored-by: liuzhenhai93 <liuzhenhai93@outlook.com>
1 parent 1fe7825 commit 9f51e93

7 files changed

Lines changed: 260 additions & 14 deletions

File tree

megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
1+
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22

33
from contextlib import nullcontext
44
from typing import Optional
@@ -325,6 +325,7 @@ def __init__(
325325
runtime_gather_output: Optional[bool] = None,
326326
loss_mask: Optional[Tensor] = None,
327327
padding_mask=None,
328+
return_logits: bool = False,
328329
):
329330
"""Initialize the schedule plan of all Transformer layers' sub-modules.
330331
@@ -342,6 +343,8 @@ def __init__(
342343
extra_block_kwargs: Additional keyword arguments for blocks.
343344
runtime_gather_output: Whether to gather output at runtime.
344345
loss_mask (torch.Tensor): Used to mask out some portions of the loss
346+
return_logits (bool): Return logits instead of main LM loss when labels
347+
are provided. MTP auxiliary loss still consumes labels.
345348
346349
Returns:
347350
The model chunk schedule plan.
@@ -365,6 +368,7 @@ def __init__(
365368
self._model_chunk_state.loss_mask = loss_mask
366369
self._model_chunk_state.packed_seq_params = packed_seq_params
367370
self._model_chunk_state.padding_mask = padding_mask
371+
self._model_chunk_state.return_logits = return_logits
368372
self._model_chunk_state.extra_block_kwargs = extra_block_kwargs
369373
self._model_chunk_state.runtime_gather_output = runtime_gather_output
370374
self._model_chunk_state.model = model

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22

33
import weakref
44
from contextlib import nullcontext
@@ -228,6 +228,7 @@ def forward_impl(self, hidden_states):
228228
sequence_len_offset=self.chunk_state.sequence_len_offset,
229229
runtime_gather_output=self.chunk_state.runtime_gather_output,
230230
extra_block_kwargs=self.chunk_state.extra_block_kwargs,
231+
return_logits=self.chunk_state.return_logits,
231232
)
232233

233234
# For now, 1f1b only supports fp16 module

megatron/core/models/gpt/gpt_model.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22

33
from collections import OrderedDict
44
from typing import Dict, Literal, Optional
@@ -503,6 +503,7 @@ def forward(
503503
inference_params: Optional[BaseInferenceContext] = None,
504504
loss_mask: Optional[Tensor] = None,
505505
padding_mask: Optional[Tensor] = None,
506+
return_logits: bool = False,
506507
) -> Tensor:
507508
"""Forward function of the GPT Model. This function passes the input tensors
508509
through the embedding layer, and then the decoder and finally into the post
@@ -516,6 +517,9 @@ def forward(
516517
padding_mask (Tensor, optional): Padding mask for MoE routing.
517518
Shape [bsz, seq_length]. True = padding (exclude), False = valid (include).
518519
Only used for MoE layers to exclude padding tokens from routing computations.
520+
return_logits (bool): If True, return logits even when `labels` are provided.
521+
This lets online RL pass sampled labels for MTP auxiliary loss while
522+
computing the main RL loss externally from logits.
519523
"""
520524
if self.config.fine_grained_activation_offloading:
521525
self.preprocess_for_fine_grained_offloading()
@@ -591,6 +595,7 @@ def forward(
591595
extra_block_kwargs=extra_block_kwargs,
592596
inference_context=inference_context,
593597
mhc_multistream=mhc_multistream,
598+
return_logits=return_logits,
594599
)
595600

596601
def _postprocess(
@@ -613,6 +618,7 @@ def _postprocess(
613618
extra_block_kwargs=None,
614619
inference_context=None,
615620
mhc_multistream=None,
621+
return_logits=False,
616622
):
617623
"""Postprocesses decoder hidden states to generate logits or compute loss.
618624
@@ -699,7 +705,8 @@ def _postprocess(
699705
reshaped = hidden_states.squeeze(1).unsqueeze(0)
700706
hidden_states = inference_context.last_token_logits(reshaped).unsqueeze(1)
701707

702-
if has_config_logger_enabled(self.config) or labels is None:
708+
should_return_logits = return_logits or labels is None
709+
if has_config_logger_enabled(self.config) or should_return_logits:
703710
logits, _ = self.output_layer(
704711
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output
705712
)
@@ -730,7 +737,9 @@ def _postprocess(
730737
)
731738
log_config_to_disk(self.config, payload, prefix='input_and_logits')
732739

733-
if labels is None:
740+
if should_return_logits:
741+
# `return_logits` only changes the main LM output contract. MTP auxiliary
742+
# loss above still consumes `labels`/`loss_mask` when they are provided.
734743
# [s b h] => [b s h]
735744
return logits.transpose(0, 1).contiguous()
736745

@@ -763,6 +772,7 @@ def build_schedule_plan(
763772
inference_params: Optional[BaseInferenceContext] = None,
764773
loss_mask: Optional[Tensor] = None,
765774
padding_mask: Optional[Tensor] = None,
775+
return_logits: bool = False,
766776
):
767777
"""Builds a computation schedule plan for the model.
768778
@@ -789,6 +799,8 @@ def build_schedule_plan(
789799
Parameters for inference. Defaults to None.
790800
loss_mask (Optional[Tensor], optional): Loss mask. Defaults to None.
791801
padding_mask (Optional[Tensor], optional): Padding mask. Defaults to None.
802+
return_logits (bool, optional): Return logits instead of main LM loss when labels
803+
are provided. MTP auxiliary loss still uses labels. Defaults to False.
792804
793805
Returns:
794806
TransformerModelChunkSchedulePlan: The model chunk schedule plan.
@@ -813,6 +825,7 @@ def build_schedule_plan(
813825
runtime_gather_output,
814826
loss_mask,
815827
padding_mask,
828+
return_logits,
816829
)
817830

818831
def sharded_state_dict(

megatron/core/models/hybrid/hybrid_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023-2026, NVIDIA CORPORATION. All rights reserved.
1+
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22

33
import logging
44
from typing import Literal, Optional
@@ -403,12 +403,15 @@ def forward(
403403
loss_mask: Optional[Tensor] = None,
404404
packed_seq_params: Optional[PackedSeqParams] = None,
405405
padding_mask: Optional[Tensor] = None,
406+
return_logits: bool = False,
406407
) -> Tensor:
407408
"""Forward function of the Hybrid model. This function passes the input tensors
408409
through the embedding layer, and then the decoder and finally into the post
409410
processing layer (optional).
410411
411412
It returns the Loss values if labels are given and `return_logits` is False;
413+
otherwise it returns the final hidden units. When `return_logits` is True and
414+
labels are given, labels still drive MTP auxiliary loss, while the main LM head
returns logits for an external loss.
412415
"""
413416
# If decoder_input is provided (not None), then input_ids and position_ids are ignored.
414417
# Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input.
@@ -572,7 +575,9 @@ def forward(
572575
)
573576
self.output_layer.sequence_parallel = True
574577

575-
if labels is None:
578+
if return_logits or labels is None:
579+
# `return_logits` only controls the main LM output. Labels, when present,
580+
# have already been consumed by MTP auxiliary loss above.
576581
# [s b h] => [b s h]
577582
return logits.transpose(0, 1).contiguous()
578583

megatron/core/transformer/multi_token_prediction.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved.
1+
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
from __future__ import annotations
33

44
import warnings
@@ -640,7 +640,7 @@ def set_loss_scale(scale: torch.Tensor):
640640

641641
def process_mtp_loss(
642642
hidden_states: Tensor,
643-
labels: Tensor,
643+
labels: Optional[Tensor],
644644
loss_mask: Optional[Tensor],
645645
output_layer: Callable,
646646
output_weight: Optional[Tensor],
@@ -685,6 +685,23 @@ def process_mtp_loss(
685685
if loss_mask is None:
686686
loss_mask = torch.ones_like(mtp_labels)
687687

688+
output_weight_for_mtp = output_weight
689+
output_layer_for_mtp = output_layer
690+
if config.mtp_isolated_loss:
691+
if output_weight_for_mtp is not None:
692+
output_weight_for_mtp = output_weight_for_mtp.detach()
693+
if isinstance(output_layer, torch.nn.Module):
694+
output_layer_params = {
695+
name: param.detach() for name, param in output_layer.named_parameters()
696+
}
697+
output_layer_buffers = dict(output_layer.named_buffers())
698+
output_layer_state = {**output_layer_params, **output_layer_buffers}
699+
700+
def output_layer_for_mtp(input_: Tensor, **kwargs):
701+
return torch.func.functional_call(
702+
output_layer, output_layer_state, args=(input_,), kwargs=kwargs
703+
)
704+
688705
# Store the original number of tokens before rolling for proper normalization
689706
# when calculate_per_token_loss is enabled. This ensures MTP gradients are
690707
# correctly scaled relative to the main loss gradients in finalize_model_grads.
@@ -701,17 +718,17 @@ def process_mtp_loss(
701718
loss_mask, shifts=-1, dims=-1, cp_group=cp_group, packed_seq_params=packed_seq_params
702719
)
703720
if fuse_linear_cross_entropy:
704-
mtp_loss = output_layer(
721+
mtp_loss = output_layer_for_mtp(
705722
hidden_states_list[mtp_layer_number + 1],
706-
weight=output_weight,
723+
weight=output_weight_for_mtp,
707724
runtime_gather_output=runtime_gather_output,
708725
output_cross_entropy_loss=True,
709726
labels=mtp_labels,
710727
)
711728
else:
712-
mtp_logits, _ = output_layer(
729+
mtp_logits, _ = output_layer_for_mtp(
713730
hidden_states_list[mtp_layer_number + 1],
714-
weight=output_weight,
731+
weight=output_weight_for_mtp,
715732
runtime_gather_output=runtime_gather_output,
716733
)
717734
if scale_logits_fn is not None:
@@ -991,6 +1008,8 @@ def _get_embeddings(
9911008
)
9921009
# embedding
9931010
decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
1011+
if self.config.mtp_isolated_loss:
1012+
decoder_input = decoder_input.detach()
9941013

9951014
hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
9961015

@@ -1724,6 +1743,11 @@ def forward(
17241743
hidden_states = mhc_chunks[offset]
17251744
else:
17261745
hidden_states = hidden_states_list[offset]
1746+
if self.config.mtp_isolated_loss:
1747+
hidden_states = hidden_states.detach().requires_grad_(True)
1748+
hidden_states = make_viewless_tensor(
1749+
inp=hidden_states, requires_grad=True, keep_graph=False
1750+
)
17271751
for iteration in range(self.config.mtp_num_layers):
17281752
layer_idx = 0 if self.mtp_use_repeated_layer else iteration
17291753
(hidden_states, input_ids, position_ids) = self.layers[layer_idx](

megatron/core/transformer/transformer_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ class TransformerConfig(ModelParallelConfig):
8181
which serves as an additional training objective.
8282
"""
8383

84+
mtp_isolated_loss: bool = False
85+
"""If True, MTP loss only updates MTP module parameters. The MTP loss graph is
86+
detached from the main decoder, shared embeddings, and output layer weights."""
87+
8488
mtp_use_repeated_layer: bool = False
8589
"""Use a single MTP layer repeatedly instead of multiple separate layers."""
8690

0 commit comments

Comments
 (0)