27 changes: 14 additions & 13 deletions megatron/core/inference/engines/dynamic_engine.py
@@ -50,7 +50,7 @@
unset_inference_cuda_graphed_iteration_for_ep_inference,
)
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import delete_cuda_graphs, graph_capture
from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
from megatron.core.transformer.enums import CudaGraphScope
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
from megatron.core.utils import (
@@ -410,25 +410,27 @@ def create_cuda_graphs(self, reset_context: bool = True):
# MTP CUDA graph warmup for this batch dimension.
if mtp_warmup_enabled:
n = cuda_graph_batch_dimension.req_count
# pylint: disable-next=possibly-used-before-assignment
if sp_enabled:
n = round_up_to_nearest_multiple(n, tp_size)
# pylint: disable-next=possibly-used-before-assignment
if n > 0 and n not in mtp_seen_batch_sizes:
mtp_seen_batch_sizes.add(n)
device = torch.cuda.current_device()
batch_dim = n // tp_size if sp_enabled else n
# Use zeros (not empty) — garbage token IDs cause OOB embedding lookups during graph capture/replay.
for depth in mtp_warmup_depths:
with graph_capture():
unwrapped.compute_mtp_single_step(
hidden_states=torch.zeros(
(batch_dim, 1, model_config.hidden_size),
device=device,
dtype=model_config.params_dtype,
),
next_token_ids=torch.zeros((1, n), device=device, dtype=torch.long),
position_ids=torch.zeros((1, n), device=device, dtype=torch.int64),
depth=depth,
)
unwrapped.compute_mtp_single_step(
hidden_states=torch.zeros(
(batch_dim, 1, model_config.hidden_size),
device=device,
dtype=model_config.params_dtype,
),
next_token_ids=torch.zeros((1, n), device=device, dtype=torch.long),
position_ids=torch.zeros((1, n), device=device, dtype=torch.int64),
depth=depth,
cache_key=("mtp", n, depth),
)

context.reset()

@@ -437,7 +439,6 @@
unset_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model)

if mtp_warmup_enabled and mtp_seen_batch_sizes:
controller.has_mtp_cuda_graphs = True
logging.info("> MTP CUDA graph warmup: %d batch size(s)", len(mtp_seen_batch_sizes))

# Memory usage.
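With `graph_capture()` gone, the warmup above relies on `CudaGraphManager` keying captured graphs by the `cache_key` it is passed: the first call for a given `("mtp", n, depth)` captures a graph on fixed buffers, and later calls with the same key replay it. The sketch below is a minimal, self-contained illustration of that capture-by-key pattern using plain `torch.cuda.CUDAGraph`; the class and variable names are hypothetical and this is not the actual `CudaGraphManager` implementation.

```python
# A minimal, self-contained sketch of the capture-by-key pattern (requires a CUDA device).
# The names below are illustrative only; this is not Megatron's CudaGraphManager.
import torch


class KeyedGraphCache:
    """Capture one CUDA graph per cache_key on first use, replay it on later calls."""

    def __init__(self, fn):
        self.fn = fn
        self.graphs = {}  # cache_key -> (graph, static_inputs, static_output)

    def run(self, cache_key, *inputs, eager=False):
        if eager or cache_key is None:
            return self.fn(*inputs)  # eager fallback: no graph involved
        if cache_key not in self.graphs:
            static_inputs = [t.clone() for t in inputs]
            # Warm up on a side stream before capture (standard CUDA-graph recipe).
            side = torch.cuda.Stream()
            side.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side):
                self.fn(*static_inputs)
            torch.cuda.current_stream().wait_stream(side)
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                static_output = self.fn(*static_inputs)
            self.graphs[cache_key] = (graph, static_inputs, static_output)
        graph, static_inputs, static_output = self.graphs[cache_key]
        for dst, src in zip(static_inputs, inputs):
            dst.copy_(src)  # graphs replay on fixed buffers, so copy live inputs in
        graph.replay()
        return static_output


if __name__ == "__main__":
    layer = torch.nn.Linear(16, 16, device="cuda", dtype=torch.float16)
    cache = KeyedGraphCache(layer)
    with torch.inference_mode():
        for n in (8, 16):  # one graph per batch size, mirroring the ("mtp", n, depth) keys
            x = torch.zeros((n, 16), device="cuda", dtype=torch.float16)
            _ = cache.run(("mtp", n, 0), x)          # first call with this key captures
            out = cache.run(("mtp", n, 0), x + 1.0)  # subsequent calls replay the graph
```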
@@ -100,7 +100,6 @@ def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, token

self.sampling_rng = torch.Generator(device=torch.cuda.current_device())
self.num_mtp_heads = self._get_mtp_num_heads()
self.has_mtp_cuda_graphs = False
self.sampling_rng.manual_seed(self.model_config.inference_sampling_seed)

if (
@@ -615,7 +614,7 @@ def _dynamic_step_context_init(
# Derive the MTP padded batch size from the existing padded graph dimensions.
# For MoE models this is the post-EP-sync value. In eager mode, MTP uses the locally
# SP-aligned batch size instead.
if self.has_mtp_cuda_graphs and context.using_cuda_graph_this_step():
if context.using_cuda_graph_this_step():
self._mtp_resolved_padded_count = context.padded_batch_dimensions.req_count
if self._sp_enabled:
self._mtp_resolved_padded_count = round_up_to_nearest_multiple(
@@ -624,13 +623,6 @@
else:
self._mtp_resolved_padded_count = None

# Tell the model whether to use MTP CUDA graphs this step. When the
# main model falls back to eager mode, MTP must also run eagerly across
# all EP ranks — otherwise some ranks may replay a captured graph while
# others run eagerly, causing EP collectives to hang.
if self.has_mtp_cuda_graphs:
unwrapped_model.use_mtp_cuda_graphs = context.using_cuda_graph_this_step()

# If using symmetric kernels and we are using nccl
# for prefill, turn off symmetric kernels
symmetric_ar_type = self.model_config.symmetric_ar_type
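For the SP branch above, `round_up_to_nearest_multiple` pads the graph's request count so it splits evenly across tensor-parallel ranks. A minimal equivalent, assuming the usual ceiling-to-multiple semantics (the real helper is imported from `megatron.core.utils` and may be implemented differently):

```python
# Minimal stand-in for round_up_to_nearest_multiple, shown only to make the
# padding arithmetic concrete.
def round_up_to_nearest_multiple(value: int, multiple: int) -> int:
    return ((value + multiple - 1) // multiple) * multiple

# With sequence parallelism enabled and tp_size=4:
assert round_up_to_nearest_multiple(13, 4) == 16  # 13 requests -> padded to 16
assert round_up_to_nearest_multiple(16, 4) == 16  # already aligned -> unchanged
```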
@@ -938,6 +930,12 @@ def _compute_serial_mtp_and_sample(self):
next_token_ids=token_ids_buf,
position_ids=position_ids_buf,
depth=mtp_depth,
eager=not context.using_cuda_graph_this_step(),
cache_key=(
("mtp", padded_count, mtp_depth)
if context.using_cuda_graph_this_step()
else None
),
)
nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}/forward")

@@ -1689,6 +1687,8 @@ def _dummy_serial_mtp_forward(self):
dummy_token_ids = torch.zeros((1, padded_count), device=device, dtype=torch.long)
dummy_position_ids = torch.zeros((1, padded_count), device=device, dtype=torch.long)

context = self.inference_wrapped_model.inference_context

for depth in range(self._num_mtp_depths):
nvtx_range_push(f"mtp-spec-decoding/dummy-depth-{depth}")
mtp_logits_2d = None
@@ -1699,6 +1699,12 @@
next_token_ids=dummy_token_ids,
position_ids=dummy_position_ids,
depth=mtp_depth,
eager=not context.using_cuda_graph_this_step(),
cache_key=(
("mtp", padded_count, mtp_depth)
if context.using_cuda_graph_this_step()
else None
),
)
mtp_logits_2d = mtp_logits.squeeze(1) # [padded_count, vocab_size]

15 changes: 11 additions & 4 deletions megatron/core/models/common/language_module/language_module.py
@@ -75,7 +75,7 @@ def _setup_mtp_cuda_graphs(self):
base_module=self,
function_name="compute_mtp_single_step",
need_backward=False,
is_mtp_inference=True,
inline_capture=True,
)

def _is_in_embd_group(self):
@@ -345,6 +345,8 @@ def compute_mtp_single_step(
next_token_ids: Tensor,
position_ids: Tensor,
depth: Optional[int] = None,
eager: bool = False,
cache_key=None,
) -> tuple:
"""Compute a single MTP depth for speculative decoding.

@@ -355,14 +357,19 @@
hidden_states (Tensor): Hidden states at last accepted positions.
next_token_ids (Tensor): Correct next token IDs [1, N].
position_ids (Tensor): Position IDs for the next tokens [1, N].
depth (int, optional): MTP depth index. Only needed when
``mtp_use_repeated_layer`` is False (each depth uses a
distinct layer). Omit for repeated-layer models so that a
depth (int, optional): MTP depth index. Only needed when `mtp_use_repeated_layer` is
False (each depth uses a distinct layer). Omit for repeated-layer models so that a
single CUDA graph can serve all depths.
eager, cache_key: The `CudaGraphManager` normally monkey-patches these arguments onto the
function signature. Including them explicitly removes the need for a monkey-patch,
and makes it straightforward to call the same method with and without eager mode.
These arguments are consumed by `CudaGraphManager`, if it exists.

Returns:
tuple: (new_hidden_states, logits [N, 1, vocab_size]).
"""
# CudaGraphManager consumes these args, if it exists
del eager, cache_key
layer_idx = 0 if depth is None else depth
mtp_hidden = self.mtp.layers[layer_idx].forward_single_position(
hidden_states=hidden_states,
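The explicit `eager`/`cache_key` parameters, and the `del` at the top of the body, follow the calling convention sketched below: the wrapped function declares the arguments the graph manager cares about and immediately discards them, while the wrapper inspects them before deciding whether to dispatch eagerly. This is a simplified illustration of that convention, not the actual `CudaGraphManager` hook; the names `graph_managed` and `compute_step` are hypothetical.

```python
import functools


def graph_managed(fn, manager=None):
    """Wrap fn so a graph manager can read eager/cache_key without monkey-patching fn."""
    @functools.wraps(fn)
    def wrapped(*args, eager=False, cache_key=None, **kwargs):
        if manager is None or eager or cache_key is None:
            # Eager path (or no manager installed): call through unchanged.
            return fn(*args, eager=eager, cache_key=cache_key, **kwargs)
        # Graph path: the manager decides whether to capture or replay for this key.
        return manager(fn, cache_key, *args, **kwargs)
    return wrapped


def compute_step(x, *, eager=False, cache_key=None):
    # Consumed by the wrapper/manager; the body itself never needs them.
    del eager, cache_key
    return 2 * x


step = graph_managed(compute_step)
print(step(3, eager=True))               # 6, explicit eager call
print(step(3, cache_key=("mtp", 1, 0)))  # 6, same signature on the graph path
```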
93 changes: 19 additions & 74 deletions megatron/core/transformer/cuda_graphs.py
@@ -8,7 +8,7 @@
import os
import time
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from contextlib import nullcontext
from copy import deepcopy
from dataclasses import dataclass, is_dataclass
from enum import Enum
@@ -98,16 +98,6 @@ def _set_capture_end():
_IS_GRAPH_CAPTURING = False


@contextmanager
def graph_capture():
"""Context manager that brackets a graph-capture region."""
_set_capture_start()
try:
yield
finally:
_set_capture_end()


def is_graph_warmup():
"""Query if currently warming up for graph capture."""
return _IS_GRAPH_WARMUP
@@ -341,10 +331,6 @@ class _CudagraphGlobalRecord:
cudagraph_record: list[tuple] = []
cudagraph_inference_record: list[tuple] = []

# MTP CudaGraphManagers registered at construction time so that
# delete_cuda_graphs() can clear their lookup tables.
mtp_cudagraph_managers: list = []

"""A pool-like data structure to reuse input and output buffers across cudagraph."""
tensor_reuse_pool = TensorReusePool()

@@ -520,19 +506,6 @@ def delete_cuda_graphs():
runner.bwd_graph = None
runner.mempool = None

# Reset MTP runners (excluded from the global inference record).
for mgr in _CudagraphGlobalRecord.mtp_cudagraph_managers:
for runner in mgr.cudagraph_runners:
runner.cudagraph_created = False
runner.fwd_graph_recorded = False
runner.bwd_graph_recorded = False
runner.fwd_graph = None
runner.bwd_graph = None
runner.mempool = None
mgr.cudagraph_runners.clear()
mgr.custom_cudagraphs_lookup_table.clear()
_CudagraphGlobalRecord.mtp_cudagraph_managers.clear()

# Reset global tracking state
_CudagraphGlobalRecord.cudagraph_created = False
_CudagraphGlobalRecord.cudagraph_record = []
@@ -832,6 +805,7 @@ def create_fwd_graph(self, args, kwargs, outputs=None, clone_inputs=True):
# graph capture's forward passes do not corrupt its value. Inference is not affected
# (no known buffer mutators) and would add new buffers (lazy MoE _fc1_weight/
# _fc2_weight) that misalign the positional restore.

if self.training and torch.is_grad_enabled():
buffer_backup = []
for buf in self.base_module.buffers():
@@ -1446,7 +1420,6 @@ def __init__(
function_name=None,
need_backward=True,
pg_collection=None,
is_mtp_inference=False,
inline_capture=False,
num_warmup_steps=None,
):
@@ -1456,7 +1429,6 @@
Args:
config: TransformerConfig object containing CUDA graph settings for memory
pooling, graph retention, gradient accumulation, FP8/FP4, and warmup steps.
is_mtp_inference: Whether this manager wraps an MTP inference forward pass.
inline_capture: Normally, whether the inline capture path is taken depends on whether
`inference_context` is present in the kwargs of the forward call.
Setting this argument to True always forces the inline capture path to be taken.
@@ -1469,7 +1441,6 @@
self.pg_collection = pg_collection
rng_tracker = get_cuda_rng_tracker()
self.need_backward = need_backward
self.is_mtp_inference = is_mtp_inference

if function_name is not None:
func = getattr(base_module, function_name)
@@ -1515,10 +1486,6 @@ def wrapped_func(*args, eager=False, cache_key=None, **kwargs):
self.custom_cudagraphs_lookup_table: dict = defaultdict(lambda: None)
self.is_first_microbatch = False

if is_mtp_inference:
# Registered so delete_cuda_graphs() can clear the lookup table.
_CudagraphGlobalRecord.mtp_cudagraph_managers.append(self)

# Without pipeline parallelism, microbatches execute one at a time.
# Therefore modules will always execute in the same order, so cudagraphs
# can both be reused and share a single mempool.
@@ -1624,19 +1591,14 @@ def __call__(self, megatron_module, args, kwargs, cache_key=None):
cache_key: Optional hashable key for O(1) runner lookup.
If `inference_context` is provided, this gets set to the correct value.
"""
is_inference_mode = (
'inference_context' in kwargs.keys() and kwargs['inference_context']
) or self.is_mtp_inference
is_inference_mode = 'inference_context' in kwargs.keys() and kwargs['inference_context']
if cache_key is None and is_inference_mode:
if 'inference_context' in kwargs and kwargs['inference_context']:
inference_context = kwargs['inference_context']
if inference_context.is_static_batching():
batch_size = kwargs['hidden_states'].shape[0]
cache_key = (batch_size, inference_context.is_decode_only())
else:
cache_key = inference_context.padded_batch_dimensions
elif self.is_mtp_inference:
cache_key = ('mtp', kwargs['hidden_states'].shape, kwargs.get('depth'))
inference_context = kwargs['inference_context']
if inference_context.is_static_batching():
batch_size = kwargs['hidden_states'].shape[0]
cache_key = (batch_size, inference_context.is_decode_only())
else:
cache_key = inference_context.padded_batch_dimensions
is_in_checkpoint_fwd = is_checkpointing()
if HAVE_TE_GRAPHS:
is_in_checkpoint_fwd = is_in_checkpoint_fwd or is_fp8_activation_recompute_enabled()
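The key-derivation rule above can be read as a small pure function: an explicit key wins, static batching keys on (batch size, decode-only), and dynamic batching keys on the padded batch dimensions. The sketch below restates that rule outside the class; `InferenceContextLike` is a hypothetical stand-in for the real inference context, shown only to make the rule explicit.

```python
from typing import Any, Optional, Protocol


class InferenceContextLike(Protocol):
    # Hypothetical minimal interface; the real context lives in megatron.core.inference.
    padded_batch_dimensions: Any

    def is_static_batching(self) -> bool: ...
    def is_decode_only(self) -> bool: ...


def derive_cache_key(
    cache_key: Optional[Any], context: Optional[InferenceContextLike], batch_size: int
) -> Optional[Any]:
    if cache_key is not None or context is None:
        return cache_key  # explicit key (or not in inference mode): nothing to derive
    if context.is_static_batching():
        return (batch_size, context.is_decode_only())  # one graph per (batch, phase)
    return context.padded_batch_dimensions  # dynamic batching: key on padded dims
```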
@@ -1654,39 +1616,26 @@ def __call__(self, megatron_module, args, kwargs, cache_key=None):
out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs)
else:
if is_inference_mode or self._inline_capture:
# MTP must match the main model's eager/graph mode so all EP
# ranks take the same code path. Skip during graph capture.
if (
self.is_mtp_inference
and not getattr(megatron_module, 'use_mtp_cuda_graphs', False)
and not is_graph_capturing()
):
return self.func(*args, **kwargs)

# Inference generation mode creates graphs immediately
runner = self.get_cudagraph_runner(
megatron_module, args, kwargs, True, cache_key=cache_key
)

if (
not runner.fwd_graph_recorded
and self.is_mtp_inference
and not is_graph_capturing()
):
# No pre-warmed graph for this batch size — run eagerly.
return self.func(*args, **kwargs)

if not runner.fwd_graph_recorded:
# Reuse graph input-output buffers for inference
local_args, local_kwargs = args, kwargs
if not runner.is_first_layer:
# Find previous layer's runner in the global record
# Find the previous layer's runner in the global record.
# Method-wrapped managers (e.g. the MTP wrapper around
# `compute_mtp_single_step`) have a base_module without
# `layer_number`; `getattr(..., None)` lets the predicate
# skip those entries harmlessly.
try:
previous_runner = next(
r
for r in _CudagraphGlobalRecord.cudagraph_inference_record
if (
r[0].base_module.layer_number
getattr(r[0].base_module, 'layer_number', None)
== runner.base_module.layer_number - 1
and r[0].fwd_graph is not None
and ArgMetadata(r[3]['hidden_states'])
Expand All @@ -1707,14 +1656,10 @@ def __call__(self, megatron_module, args, kwargs, cache_key=None):
runner.cudagraph_created = True
runner = runner.eval()

# Record to the global execution record. MTP runners are
# excluded — they don't chain with decoder layers (the
# previous-layer lookup expects layer_number) and are
# cleaned up via mtp_cudagraph_managers instead.
if not self.is_mtp_inference:
_CudagraphGlobalRecord.cudagraph_inference_record.append(
(runner, "fwd", args, kwargs)
)
# Record this to the global execution record
_CudagraphGlobalRecord.cudagraph_inference_record.append(
(runner, "fwd", args, kwargs)
)

# Now replay the graph
out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs)
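The `getattr(..., None)` change above makes the previous-layer search tolerant of record entries whose base module is not a decoder layer. A minimal, hypothetical illustration of that lookup, with plain objects standing in for runners:

```python
from types import SimpleNamespace

record = [
    SimpleNamespace(base_module=SimpleNamespace(layer_number=1)),
    SimpleNamespace(base_module=SimpleNamespace(name="mtp_wrapper")),  # no layer_number
    SimpleNamespace(base_module=SimpleNamespace(layer_number=2)),
]

target = 3  # looking for the runner of layer_number == target - 1
previous = next(
    r for r in record
    if getattr(r.base_module, "layer_number", None) == target - 1
)
assert previous.base_module.layer_number == 2  # the MTP entry is skipped, not an error
```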