diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py
index f96235db0c0..2a40937ae6b 100644
--- a/megatron/core/inference/engines/dynamic_engine.py
+++ b/megatron/core/inference/engines/dynamic_engine.py
@@ -50,7 +50,7 @@
     unset_inference_cuda_graphed_iteration_for_ep_inference,
 )
 from megatron.core.process_groups_config import ProcessGroupCollection
-from megatron.core.transformer.cuda_graphs import delete_cuda_graphs, graph_capture
+from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
 from megatron.core.transformer.enums import CudaGraphScope
 from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
 from megatron.core.utils import (
@@ -410,25 +410,27 @@ def create_cuda_graphs(self, reset_context: bool = True):
             # MTP CUDA graph warmup for this batch dimension.
             if mtp_warmup_enabled:
                 n = cuda_graph_batch_dimension.req_count
+                # pylint: disable-next=possibly-used-before-assignment
                 if sp_enabled:
                     n = round_up_to_nearest_multiple(n, tp_size)
+                # pylint: disable-next=possibly-used-before-assignment
                 if n > 0 and n not in mtp_seen_batch_sizes:
                     mtp_seen_batch_sizes.add(n)
                     device = torch.cuda.current_device()
                     batch_dim = n // tp_size if sp_enabled else n
                     # Use zeros (not empty) — garbage token IDs cause OOB embedding lookups
                     # during graph capture/replay.
                     for depth in mtp_warmup_depths:
-                        with graph_capture():
-                            unwrapped.compute_mtp_single_step(
-                                hidden_states=torch.zeros(
-                                    (batch_dim, 1, model_config.hidden_size),
-                                    device=device,
-                                    dtype=model_config.params_dtype,
-                                ),
-                                next_token_ids=torch.zeros((1, n), device=device, dtype=torch.long),
-                                position_ids=torch.zeros((1, n), device=device, dtype=torch.int64),
-                                depth=depth,
-                            )
+                        unwrapped.compute_mtp_single_step(
+                            hidden_states=torch.zeros(
+                                (batch_dim, 1, model_config.hidden_size),
+                                device=device,
+                                dtype=model_config.params_dtype,
+                            ),
+                            next_token_ids=torch.zeros((1, n), device=device, dtype=torch.long),
+                            position_ids=torch.zeros((1, n), device=device, dtype=torch.int64),
+                            depth=depth,
+                            cache_key=("mtp", n, depth),
+                        )

         context.reset()
@@ -437,7 +439,6 @@ def create_cuda_graphs(self, reset_context: bool = True):
         unset_inference_cuda_graphed_iteration_for_ep_inference(unwrapped_model)

         if mtp_warmup_enabled and mtp_seen_batch_sizes:
-            controller.has_mtp_cuda_graphs = True
             logging.info("> MTP CUDA graph warmup: %d batch size(s)", len(mtp_seen_batch_sizes))

         # Memory usage.
diff --git a/megatron/core/inference/text_generation_controllers/text_generation_controller.py b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
index 993d05afbe0..e66591edad0 100644
--- a/megatron/core/inference/text_generation_controllers/text_generation_controller.py
+++ b/megatron/core/inference/text_generation_controllers/text_generation_controller.py
@@ -100,7 +100,6 @@ def __init__(self, inference_wrapped_model: AbstractModelInferenceWrapper, token
         self.sampling_rng = torch.Generator(device=torch.cuda.current_device())

         self.num_mtp_heads = self._get_mtp_num_heads()
-        self.has_mtp_cuda_graphs = False
         self.sampling_rng.manual_seed(self.model_config.inference_sampling_seed)

         if (
@@ -615,7 +614,7 @@ def _dynamic_step_context_init(
         # Derive the MTP padded batch size from the existing padded graph dimensions.
         # For MoE models this is post EP sync. In eager mode MTP uses locally SP-aligned
         # batch size instead.
-        if self.has_mtp_cuda_graphs and context.using_cuda_graph_this_step():
+        if context.using_cuda_graph_this_step():
             self._mtp_resolved_padded_count = context.padded_batch_dimensions.req_count
             if self._sp_enabled:
                 self._mtp_resolved_padded_count = round_up_to_nearest_multiple(
@@ -624,13 +623,6 @@ def _dynamic_step_context_init(
         else:
             self._mtp_resolved_padded_count = None

-        # Tell the model whether to use MTP CUDA graphs this step. When the
-        # main model falls back to eager mode, MTP must also run eagerly across
-        # all EP ranks — otherwise some ranks may replay a captured graph while
-        # others run eagerly, causing EP collectives to hang.
-        if self.has_mtp_cuda_graphs:
-            unwrapped_model.use_mtp_cuda_graphs = context.using_cuda_graph_this_step()
-
         # If using symmetric kernels and we are using using nccl
         # for prefill turn off symmetric kernels
         symmetric_ar_type = self.model_config.symmetric_ar_type
@@ -938,6 +930,12 @@ def _compute_serial_mtp_and_sample(self):
                 next_token_ids=token_ids_buf,
                 position_ids=position_ids_buf,
                 depth=mtp_depth,
+                eager=not context.using_cuda_graph_this_step(),
+                cache_key=(
+                    ("mtp", padded_count, mtp_depth)
+                    if context.using_cuda_graph_this_step()
+                    else None
+                ),
             )
             nvtx_range_pop(f"mtp-spec-decoding/depth-{depth}/forward")

@@ -1689,6 +1687,8 @@ def _dummy_serial_mtp_forward(self):
         dummy_token_ids = torch.zeros((1, padded_count), device=device, dtype=torch.long)
         dummy_position_ids = torch.zeros((1, padded_count), device=device, dtype=torch.long)

+        context = self.inference_wrapped_model.inference_context
+
         for depth in range(self._num_mtp_depths):
             nvtx_range_push(f"mtp-spec-decoding/dummy-depth-{depth}")
             mtp_logits_2d = None
@@ -1699,6 +1699,12 @@ def _dummy_serial_mtp_forward(self):
                 next_token_ids=dummy_token_ids,
                 position_ids=dummy_position_ids,
                 depth=mtp_depth,
+                eager=not context.using_cuda_graph_this_step(),
+                cache_key=(
+                    ("mtp", padded_count, mtp_depth)
+                    if context.using_cuda_graph_this_step()
+                    else None
+                ),
             )

             mtp_logits_2d = mtp_logits.squeeze(1)  # [padded_count, vocab_size]
diff --git a/megatron/core/models/common/language_module/language_module.py b/megatron/core/models/common/language_module/language_module.py
index 85870726269..84b0ca2fea3 100644
--- a/megatron/core/models/common/language_module/language_module.py
+++ b/megatron/core/models/common/language_module/language_module.py
@@ -75,7 +75,7 @@ def _setup_mtp_cuda_graphs(self):
             base_module=self,
             function_name="compute_mtp_single_step",
             need_backward=False,
-            is_mtp_inference=True,
+            inline_capture=True,
         )

     def _is_in_embd_group(self):
@@ -345,6 +345,8 @@ def compute_mtp_single_step(
         next_token_ids: Tensor,
         position_ids: Tensor,
         depth: Optional[int] = None,
+        eager: bool = False,
+        cache_key=None,
     ) -> tuple:
         """Compute a single MTP depth for speculative decoding.

@@ -355,14 +357,19 @@ def compute_mtp_single_step(
             hidden_states (Tensor): Hidden states at last accepted positions.
             next_token_ids (Tensor): Correct next token IDs [1, N].
             position_ids (Tensor): Position IDs for the next tokens [1, N].
-            depth (int, optional): MTP depth index. Only needed when
-                ``mtp_use_repeated_layer`` is False (each depth uses a
-                distinct layer). Omit for repeated-layer models so that a
+            depth (int, optional): MTP depth index. Only needed when `mtp_use_repeated_layer` is
+                False (each depth uses a distinct layer). Omit for repeated-layer models so that a
                 single CUDA graph can serve all depths.
+            eager, cache_key: `CudaGraphManager` normally monkey-patches these arguments onto the
+                wrapped function's signature. Explicitly including them removes the need for that
+                monkey-patch and makes it straightforward to call the same method with or without
+                eager mode; both are consumed by `CudaGraphManager`, if one exists.

         Returns:
             tuple: (new_hidden_states, logits [N, 1, vocab_size]).
         """
+        # CudaGraphManager consumes these args, if one exists.
+        del eager, cache_key
         layer_idx = 0 if depth is None else depth
         mtp_hidden = self.mtp.layers[layer_idx].forward_single_position(
             hidden_states=hidden_states,
diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py
index 06e1d8958fe..9cd8cd2ffb6 100644
--- a/megatron/core/transformer/cuda_graphs.py
+++ b/megatron/core/transformer/cuda_graphs.py
@@ -8,7 +8,7 @@
 import os
 import time
 from collections import defaultdict
-from contextlib import contextmanager, nullcontext
+from contextlib import nullcontext
 from copy import deepcopy
 from dataclasses import dataclass, is_dataclass
 from enum import Enum
@@ -98,16 +98,6 @@ def _set_capture_end():
     _IS_GRAPH_CAPTURING = False


-@contextmanager
-def graph_capture():
-    """Context manager that brackets a graph-capture region."""
-    _set_capture_start()
-    try:
-        yield
-    finally:
-        _set_capture_end()
-
-
 def is_graph_warmup():
     """Query if currently warming up for graph capture."""
     return _IS_GRAPH_WARMUP
@@ -341,10 +331,6 @@ class _CudagraphGlobalRecord:
     cudagraph_record: list[tuple] = []
     cudagraph_inference_record: list[tuple] = []

-    # MTP CudaGraphManagers registered at construction time so that
-    # delete_cuda_graphs() can clear their lookup tables.
-    mtp_cudagraph_managers: list = []
-
     """A pool-like data structure to reuse input and output buffers across cudagraph."""
     tensor_reuse_pool = TensorReusePool()

@@ -520,19 +506,6 @@ def delete_cuda_graphs():
             runner.bwd_graph = None
             runner.mempool = None

-    # Reset MTP runners (excluded from the global inference record).
-    for mgr in _CudagraphGlobalRecord.mtp_cudagraph_managers:
-        for runner in mgr.cudagraph_runners:
-            runner.cudagraph_created = False
-            runner.fwd_graph_recorded = False
-            runner.bwd_graph_recorded = False
-            runner.fwd_graph = None
-            runner.bwd_graph = None
-            runner.mempool = None
-        mgr.cudagraph_runners.clear()
-        mgr.custom_cudagraphs_lookup_table.clear()
-    _CudagraphGlobalRecord.mtp_cudagraph_managers.clear()
-
     # Reset global tracking state
     _CudagraphGlobalRecord.cudagraph_created = False
     _CudagraphGlobalRecord.cudagraph_record = []
@@ -832,6 +805,7 @@ def create_fwd_graph(self, args, kwargs, outputs=None, clone_inputs=True):
         # graph capture's forward passes do not corrupt its value. Inference is not affected
         # (no known buffer mutators) and would add new buffers (lazy MoE _fc1_weight/
         # _fc2_weight) that misalign the positional restore.
+
         if self.training and torch.is_grad_enabled():
             buffer_backup = []
             for buf in self.base_module.buffers():
@@ -1446,7 +1420,6 @@ def __init__(
         function_name=None,
         need_backward=True,
         pg_collection=None,
-        is_mtp_inference=False,
         inline_capture=False,
         num_warmup_steps=None,
     ):
@@ -1456,7 +1429,6 @@ def __init__(
         Args:
             config: TransformerConfig object containing CUDA graph settings for memory pooling,
                 graph retention, gradient accumulation, FP8/FP4, and warmup steps.
-            is_mtp_inference: Whether this manager wraps an MTP inference forward pass.
             inline_capture: Normally, whether the inline capture path is taken depends on whether
                 `inference_context` is present in the kwargs of the forward call. Setting this
                 argument to True always forces the inline capture path to be taken.
@@ -1469,7 +1441,6 @@ def __init__(
         self.pg_collection = pg_collection
         rng_tracker = get_cuda_rng_tracker()
         self.need_backward = need_backward
-        self.is_mtp_inference = is_mtp_inference

         if function_name is not None:
             func = getattr(base_module, function_name)
@@ -1515,10 +1486,6 @@ def wrapped_func(*args, eager=False, cache_key=None, **kwargs):
         self.custom_cudagraphs_lookup_table: dict = defaultdict(lambda: None)
         self.is_first_microbatch = False

-        if is_mtp_inference:
-            # Registered so delete_cuda_graphs() can clear the lookup table.
-            _CudagraphGlobalRecord.mtp_cudagraph_managers.append(self)
-
         # Without pipeline parallelism, microbatches execute one at a time.
         # Therefore modules will always execute in the same order, so cudagraphs
         # can both be reused and share a single mempool.
@@ -1624,19 +1591,14 @@ def __call__(self, megatron_module, args, kwargs, cache_key=None):
             cache_key: Optional hashable key for O(1) runner lookup. If `inference_context` is
                 provided, this gets set to the correct value.
         """
-        is_inference_mode = (
-            'inference_context' in kwargs.keys() and kwargs['inference_context']
-        ) or self.is_mtp_inference
+        is_inference_mode = 'inference_context' in kwargs.keys() and kwargs['inference_context']
         if cache_key is None and is_inference_mode:
-            if 'inference_context' in kwargs and kwargs['inference_context']:
-                inference_context = kwargs['inference_context']
-                if inference_context.is_static_batching():
-                    batch_size = kwargs['hidden_states'].shape[0]
-                    cache_key = (batch_size, inference_context.is_decode_only())
-                else:
-                    cache_key = inference_context.padded_batch_dimensions
-            elif self.is_mtp_inference:
-                cache_key = ('mtp', kwargs['hidden_states'].shape, kwargs.get('depth'))
+            inference_context = kwargs['inference_context']
+            if inference_context.is_static_batching():
+                batch_size = kwargs['hidden_states'].shape[0]
+                cache_key = (batch_size, inference_context.is_decode_only())
+            else:
+                cache_key = inference_context.padded_batch_dimensions
         is_in_checkpoint_fwd = is_checkpointing()
         if HAVE_TE_GRAPHS:
             is_in_checkpoint_fwd = is_in_checkpoint_fwd or is_fp8_activation_recompute_enabled()
@@ -1654,39 +1616,26 @@ def __call__(self, megatron_module, args, kwargs, cache_key=None):
                 out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs)
         else:
             if is_inference_mode or self._inline_capture:
-                # MTP must match the main model's eager/graph mode so all EP
-                # ranks take the same code path. Skip during graph capture.
-                if (
-                    self.is_mtp_inference
-                    and not getattr(megatron_module, 'use_mtp_cuda_graphs', False)
-                    and not is_graph_capturing()
-                ):
-                    return self.func(*args, **kwargs)
-
                 # Inference generation mode creates graphs immediately
                 runner = self.get_cudagraph_runner(
                     megatron_module, args, kwargs, True, cache_key=cache_key
                 )
-                if (
-                    not runner.fwd_graph_recorded
-                    and self.is_mtp_inference
-                    and not is_graph_capturing()
-                ):
-                    # No pre-warmed graph for this batch size — run eagerly.
-                    return self.func(*args, **kwargs)
-
                 if not runner.fwd_graph_recorded:
                     # Reuse graph input-output buffers for inference
                     local_args, local_kwargs = args, kwargs
                     if not runner.is_first_layer:
-                        # Find previous layer's runner in the global record
+                        # Find previous layer's runner in the global record.
+                        # Method-wrapped managers (e.g. the MTP wrapper around
+                        # `compute_mtp_single_step`) have a base_module without
+                        # `layer_number`; `getattr(..., None)` makes those rows
+                        # harmlessly skipped by the predicate.
                        try:
                            previous_runner = next(
                                r
                                for r in _CudagraphGlobalRecord.cudagraph_inference_record
                                if (
-                                    r[0].base_module.layer_number
+                                    getattr(r[0].base_module, 'layer_number', None)
                                     == runner.base_module.layer_number - 1
                                     and r[0].fwd_graph is not None
                                     and ArgMetadata(r[3]['hidden_states'])
@@ -1707,14 +1656,10 @@ def __call__(self, megatron_module, args, kwargs, cache_key=None):
                     runner.cudagraph_created = True
                     runner = runner.eval()

-                # Record to the global execution record. MTP runners are
-                # excluded — they don't chain with decoder layers (the
-                # previous-layer lookup expects layer_number) and are
-                # cleaned up via mtp_cudagraph_managers instead.
-                if not self.is_mtp_inference:
-                    _CudagraphGlobalRecord.cudagraph_inference_record.append(
-                        (runner, "fwd", args, kwargs)
-                    )
+                # Record this to the global execution record
+                _CudagraphGlobalRecord.cudagraph_inference_record.append(
+                    (runner, "fwd", args, kwargs)
+                )

                 # Now replay the graph
                 out = runner.replay_graph_capture(self.is_first_microbatch, args, kwargs)
diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py
index 4f02369d0ed..e88924f1af0 100644
--- a/tests/unit_tests/inference/engines/test_dynamic_engine.py
+++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py
@@ -2294,7 +2294,9 @@ def mock_mtp_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             logits = torch.zeros(
                 n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16
@@ -2417,7 +2419,9 @@ def mock_deterministic_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             # Predict next_token_ids + 1 (continuing the ascending sequence)
             pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1)
@@ -2501,7 +2505,9 @@ def mock_deterministic_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             # Predict next_token_ids + 1 (continuing the ascending sequence)
             pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1)
@@ -2586,7 +2592,9 @@ def mock_deterministic_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             # Predict next_token_ids + 1 (continuing the ascending sequence)
             pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1)
@@ -2709,8 +2717,12 @@ def deterministic_forward(*args, **kwargs):
         # Wrap the real MTP step similarly.
         real_mtp = unwrapped_model.compute_mtp_single_step

-        def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth):
-            hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth)
+        def deterministic_mtp(
+            hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None
+        ):
+            hidden_states, logits = real_mtp(
+                hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key
+            )
             logits.zero_()
             logits[..., 0] = 100.0
             return hidden_states, logits
@@ -2863,8 +2875,17 @@ def deterministic_forward(*args, **kwargs):

         real_mtp = model.compute_mtp_single_step

-        def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth):
-            hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth)
+        def deterministic_mtp(
+            hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None
+        ):
+            hidden_states, logits = real_mtp(
+                hidden_states,
+                next_token_ids,
+                position_ids,
+                depth,
+                eager=eager,
+                cache_key=cache_key,
+            )
             logits.zero_()
             logits[..., 0] = 100.0
             return hidden_states, logits
@@ -2940,7 +2961,9 @@ def mock_safe_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             logits = torch.zeros(
                 n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16
@@ -3153,7 +3176,9 @@ def mock_deterministic_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             logits = torch.randn(
                 n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16
@@ -3273,7 +3298,9 @@ def mock_deterministic_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             logits = torch.randn(
                 n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16
@@ -3403,7 +3430,9 @@ def mock_deterministic_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             logits = torch.randn(
                 n, 1, test_config.vocab_size, device=hidden_states.device, dtype=torch.bfloat16
@@ -3656,7 +3685,9 @@ def mock_deterministic_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_wrong(hidden_states, next_token_ids, position_ids, depth):
+        def mock_compute_mtp_wrong(
+            hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             wrong_toks = (next_token_ids + 5).clamp(max=test_config.vocab_size - 1)
             logits = torch.zeros(
@@ -3752,7 +3783,9 @@ def mock_deterministic_forward(*args, **kwargs):
             )
             return base_logits

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             n = hidden_states.size(0)
             pred_toks = (next_token_ids + 1).clamp(max=test_config.vocab_size - 1)
             logits = torch.zeros(
@@ -3864,8 +3897,12 @@ def deterministic_forward(*args, **kwargs):

         real_mtp = unwrapped_model.compute_mtp_single_step

-        def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth):
-            hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth)
+        def deterministic_mtp(
+            hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None
+        ):
+            hidden_states, logits = real_mtp(
+                hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key
+            )
             logits.zero_()
             logits[..., 0] = 100.0
             return hidden_states, logits
@@ -3973,8 +4010,12 @@ def deterministic_forward(*args, **kwargs):
         # Deterministic MTP: also predict token 0 → all speculative tokens accepted.
         real_mtp = unwrapped_model.compute_mtp_single_step

-        def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth):
-            hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth)
+        def deterministic_mtp(
+            hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None
+        ):
+            hidden_states, logits = real_mtp(
+                hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key
+            )
             logits.zero_()
             logits[..., 0] = 100.0
             return hidden_states, logits
@@ -4045,8 +4086,12 @@ def deterministic_forward(*args, **kwargs):
         # During prefill, no MTP runs, so request 2 is unaffected.
         real_mtp = unwrapped_model.compute_mtp_single_step

-        def heterogeneous_mtp(hidden_states, next_token_ids, position_ids, depth):
-            hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth)
+        def heterogeneous_mtp(
+            hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None
+        ):
+            hidden_states, logits = real_mtp(
+                hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key
+            )
             n = logits.size(0)
             logits.zero_()
             if n >= 2:
@@ -4144,8 +4189,12 @@ def deterministic_forward(*args, **kwargs):

         real_mtp = unwrapped_model.compute_mtp_single_step

-        def deterministic_mtp(hidden_states, next_token_ids, position_ids, depth):
-            hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth)
+        def deterministic_mtp(
+            hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None
+        ):
+            hidden_states, logits = real_mtp(
+                hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key
+            )
             logits.zero_()
             logits[..., 0] = 100.0
             return hidden_states, logits
@@ -4247,9 +4296,13 @@ def deterministic_forward(*args, **kwargs):

         real_mtp = unwrapped_model.compute_mtp_single_step

-        def mtp_with_rejection(hidden_states, next_token_ids, position_ids, depth):
+        def mtp_with_rejection(
+            hidden_states, next_token_ids, position_ids, depth, eager=False, cache_key=None
+        ):
             # Run real MTP to exercise Mamba intermediate state saving.
-            hidden_states, logits = real_mtp(hidden_states, next_token_ids, position_ids, depth)
+            hidden_states, logits = real_mtp(
+                hidden_states, next_token_ids, position_ids, depth, eager=eager, cache_key=cache_key
+            )
             logits.zero_()
             if rejection_mode == "all_accepted":
                 # Predict token 0 (same as base) → accepted.
diff --git a/tests/unit_tests/inference/engines/test_mtp_cuda_graph_inference.py b/tests/unit_tests/inference/engines/test_mtp_cuda_graph_inference.py
index d6605e88d0d..d266e2ef1ed 100644
--- a/tests/unit_tests/inference/engines/test_mtp_cuda_graph_inference.py
+++ b/tests/unit_tests/inference/engines/test_mtp_cuda_graph_inference.py
@@ -39,7 +39,7 @@
 from megatron.core.tensor_parallel.mappings import scatter_to_sequence_parallel_region
 from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
 from megatron.core.transformer import TransformerConfig
-from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord, delete_cuda_graphs
+from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
 from megatron.core.transformer.enums import AttnBackend
 from megatron.core.utils import unwrap_model
 from tests.unit_tests.test_utilities import Utils
@@ -182,10 +182,16 @@ def _get_mtp_warmed_batch_sizes(engine):
         return sorted(sizes)

     @staticmethod
-    def _set_mtp_cuda_graph_flag(model, enabled):
-        """Set `use_mtp_cuda_graphs` on the model."""
-        unwrapped = unwrap_model(model)
-        unwrapped.use_mtp_cuda_graphs = enabled
+    def _mtp_kwargs(use_graph, batch_size, mtp_depth):
+        """Construct call-site kwargs that route `compute_mtp_single_step` to
+        either CUDA graph replay or eager execution.
+
+        The wrapped `compute_mtp_single_step` honors `eager=True` to bypass the
+        manager and `cache_key=...` for O(1) runner lookup.
+        """
+        if use_graph:
+            return {"cache_key": ("mtp", batch_size, mtp_depth)}
+        return {"eager": True}

     @staticmethod
     def _assert_mtp_cuda_graphs_were_replayed(model, expect_replayed):
@@ -248,22 +254,22 @@ def test_cuda_graph_output_matches_eager(self, mtp_use_repeated_layer):
         dist.broadcast(token_ids, src=0)
         position_ids = torch.arange(batch_size, device='cuda', dtype=torch.int64).unsqueeze(0)

-        self._set_mtp_cuda_graph_flag(model, True)
         h_graph, logits_graph = unwrapped.compute_mtp_single_step(
             hidden_states=hidden.clone(),
             next_token_ids=token_ids.clone(),
             position_ids=position_ids.clone(),
             depth=mtp_depth,
+            **self._mtp_kwargs(use_graph=True, batch_size=batch_size, mtp_depth=mtp_depth),
         )
         h_graph = h_graph.clone()
         logits_graph = logits_graph.clone()

-        self._set_mtp_cuda_graph_flag(model, False)
         h_eager, logits_eager = unwrapped.compute_mtp_single_step(
             hidden_states=hidden.clone(),
             next_token_ids=token_ids.clone(),
             position_ids=position_ids.clone(),
             depth=mtp_depth,
+            **self._mtp_kwargs(use_graph=False, batch_size=batch_size, mtp_depth=mtp_depth),
         )

         torch.testing.assert_close(
@@ -308,22 +314,22 @@ def test_cuda_graph_output_matches_eager_with_sp(self, mtp_use_repeated_layer):
         dist.broadcast(token_ids, src=0)
         position_ids = torch.arange(batch_size, device='cuda', dtype=torch.int64).unsqueeze(0)

-        self._set_mtp_cuda_graph_flag(model, True)
         h_graph, logits_graph = unwrapped.compute_mtp_single_step(
             hidden_states=hidden_sp.clone(),
             next_token_ids=token_ids.clone(),
             position_ids=position_ids.clone(),
             depth=mtp_depth,
+            **self._mtp_kwargs(use_graph=True, batch_size=batch_size, mtp_depth=mtp_depth),
         )
         h_graph = h_graph.clone()
         logits_graph = logits_graph.clone()

-        self._set_mtp_cuda_graph_flag(model, False)
         h_eager, logits_eager = unwrapped.compute_mtp_single_step(
             hidden_states=hidden_sp.clone(),
             next_token_ids=token_ids.clone(),
             position_ids=position_ids.clone(),
             depth=mtp_depth,
+            **self._mtp_kwargs(use_graph=False, batch_size=batch_size, mtp_depth=mtp_depth),
         )

         torch.testing.assert_close(
@@ -410,7 +416,7 @@ def test_cuda_graph_sp_padding_end_to_end(self, mtp_use_repeated_layer):
         ctrl._last_accepted_seq_indices = torch.arange(active_request_count, device='cuda')
         ctrl._mtp_resolved_padded_count = padded_count

-        self._set_mtp_cuda_graph_flag(model, True)
+        context._using_cuda_graph_this_step = True

         ctrl._torch_sampling_buckets = [(list(range(active_request_count)), 1.0, 1, 0.0)]
         ctrl._torch_sampling_bucket_index_tensors = [
@@ -496,13 +502,11 @@ def _run_mtp(use_cuda_graph):
             )

             if use_cuda_graph:
-                ctrl.has_mtp_cuda_graphs = True
                 ctrl._mtp_resolved_padded_count = padded_count
-                self._set_mtp_cuda_graph_flag(model, True)
+                context._using_cuda_graph_this_step = True
             else:
-                ctrl.has_mtp_cuda_graphs = False
                 ctrl._mtp_resolved_padded_count = None
-                self._set_mtp_cuda_graph_flag(model, False)
+                context._using_cuda_graph_this_step = False

             tp_rank = parallel_state.get_tensor_model_parallel_rank()
@@ -561,7 +565,6 @@ def test_cuda_graph_multi_depth(self, mtp_use_repeated_layer):
         use_repeated = unwrapped.mtp.mtp_use_repeated_layer
         batch_size = batch_sizes[0]

-        self._set_mtp_cuda_graph_flag(model, True)
         hidden = torch.randn(batch_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16)
         dist.broadcast(hidden, src=0)

@@ -577,6 +580,7 @@ def test_cuda_graph_multi_depth(self, mtp_use_repeated_layer):
                 next_token_ids=token_ids.clone(),
                 position_ids=position_ids.clone(),
                 depth=mtp_depth,
+                **self._mtp_kwargs(use_graph=True, batch_size=batch_size, mtp_depth=mtp_depth),
             )
             current_hidden = current_hidden.clone()

@@ -594,14 +598,14 @@ def test_cuda_graph_multi_depth(self, mtp_use_repeated_layer):

         self._assert_mtp_cuda_graphs_were_replayed(model, True)

-    # ---- Test 6: eager fallback when no matching graph exists ------------- #
+    # ---- Test 6: caller-driven eager bypass for non-warmed shapes --------- #

     @pytest.mark.parametrize("mtp_use_repeated_layer", [False, True])
     @torch.inference_mode()
-    def test_eager_fallback_no_matching_graph(self, mtp_use_repeated_layer):
-        """When `use_mtp_cuda_graphs` is True but no warmed graph matches the
-        batch size, `compute_mtp_single_step` falls back to eager execution.
-        The system should produce valid outputs without errors.
+    def test_eager_bypass_for_non_warmed_shape(self, mtp_use_repeated_layer):
+        """Passing `eager=True` runs `compute_mtp_single_step` outside the
+        CudaGraphManager wrapper. This is the canonical caller-side fallback
+        for a shape that warmup did not capture.
""" engine = self._build_engine(mtp_use_repeated_layer=mtp_use_repeated_layer) model = engine.controller.inference_wrapped_model.model @@ -618,7 +622,6 @@ def test_eager_fallback_no_matching_graph(self, mtp_use_repeated_layer): mtp_depth = None if unwrapped.mtp.mtp_use_repeated_layer else 0 - self._set_mtp_cuda_graph_flag(model, True) hidden = torch.randn( fallback_size, 1, self.HIDDEN_SIZE, device='cuda', dtype=torch.bfloat16 ) @@ -632,36 +635,21 @@ def test_eager_fallback_no_matching_graph(self, mtp_use_repeated_layer): next_token_ids=token_ids.clone(), position_ids=position_ids.clone(), depth=mtp_depth, + eager=True, ) assert h_out.shape == (fallback_size, 1, self.HIDDEN_SIZE) assert logits.shape == (fallback_size, 1, self.VOCAB_SIZE) assert torch.all(torch.isfinite(logits)) - # ---- Test 7: graph flag propagation matches main model ---------------- # - - @torch.inference_mode() - def test_mtp_graph_flag_propagation(self): - """`use_mtp_cuda_graphs` is correctly toggled via the helper.""" - model = self._build_model(mtp_num_layers=2) - unwrapped = unwrap_model(model) - - self._set_mtp_cuda_graph_flag(model, True) - assert unwrapped.use_mtp_cuda_graphs is True - - self._set_mtp_cuda_graph_flag(model, False) - assert unwrapped.use_mtp_cuda_graphs is False - - # ---- Test 8: delete_cuda_graphs resets MTP runners -------------------- # + # ---- Test 7: delete_cuda_graphs resets MTP runners -------------------- # @torch.inference_mode() def test_delete_cuda_graphs_resets_mtp_runners(self): """`delete_cuda_graphs()` resets MTP CUDA graph runners. - MTP runners are excluded from the global inference record, so they - require special handling in `delete_cuda_graphs()`. After deletion, - no MTP runners should have `fwd_graph_recorded=True` and the global - manager list should be cleared. + MTP runners join the standard `cudagraph_inference_record`, so the + standard cleanup loop resets their `fwd_graph_recorded` flag. """ engine = self._build_engine() model = engine.controller.inference_wrapped_model.model @@ -671,12 +659,13 @@ def test_delete_cuda_graphs_resets_mtp_runners(self): unwrapped = unwrap_model(model) manager = getattr(unwrapped, '_mtp_cudagraph_manager', None) assert manager is not None - assert len(manager.custom_cudagraphs_lookup_table) > 0 + assert len(manager.cudagraph_runners) > 0 + assert all(r.fwd_graph_recorded for r in manager.cudagraph_runners) delete_cuda_graphs() - assert len(manager.custom_cudagraphs_lookup_table) == 0 - assert len(_CudagraphGlobalRecord.mtp_cudagraph_managers) == 0 + assert all(not r.fwd_graph_recorded for r in manager.cudagraph_runners) + assert all(r.fwd_graph is None for r in manager.cudagraph_runners) # --------------------------------------------------------------------------- # @@ -828,6 +817,7 @@ def test_ep_mtp_eager_forward(self, batch_size): next_token_ids=token_ids.clone(), position_ids=position_ids.clone(), depth=0, + eager=True, ) assert h_out.shape == (batch_size, 1, self.HIDDEN_SIZE) @@ -865,7 +855,11 @@ def test_ep_mtp_eager_dummy_and_real_ranks(self): # All ranks must complete without hanging. 
         h_out, logits = unwrapped.compute_mtp_single_step(
-            hidden_states=hidden, next_token_ids=token_ids, position_ids=position_ids, depth=0
+            hidden_states=hidden,
+            next_token_ids=token_ids,
+            position_ids=position_ids,
+            depth=0,
+            eager=True,
         )

         assert h_out.shape == (batch_size, 1, self.HIDDEN_SIZE)
@@ -985,6 +979,7 @@ def test_ep_dummy_bailout_with_decode_only_cuda_graphs(self, peer_state):
             next_token_ids=dummy_tokens,
             position_ids=dummy_positions,
             depth=0,
+            eager=True,
         )

         assert h_out.shape == (tp_size, 1, self.HIDDEN_SIZE)
diff --git a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py
index dd4764ee92d..f0c41ca7d83 100644
--- a/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py
+++ b/tests/unit_tests/inference/text_generation_controllers/test_text_generation_controller.py
@@ -1394,7 +1394,9 @@ def test_speculative_mtp_position_ids_with_prefill(self):

         captured_position_ids = []

-        def mock_compute_mtp_single_step(hidden_states, next_token_ids, position_ids, depth=None):
+        def mock_compute_mtp_single_step(
+            hidden_states, next_token_ids, position_ids, depth=None, eager=False, cache_key=None
+        ):
             captured_position_ids.append(position_ids.clone())
             return hidden_states, torch.randn(2, 1, self.vocab_size, device='cuda')

diff --git a/tests/unit_tests/transformer/test_cuda_graphs.py b/tests/unit_tests/transformer/test_cuda_graphs.py
index 4d29f080d24..fe3fe287622 100644
--- a/tests/unit_tests/transformer/test_cuda_graphs.py
+++ b/tests/unit_tests/transformer/test_cuda_graphs.py
@@ -1100,6 +1100,24 @@ def my_op(self, x):
         return self.linear(x)


+class _SimpleNonModule:
+    """A non-nn.Module base_module for testing the `function_name=` form of `CudaGraphManager`."""
+
+    def __init__(self, config):
+        self.weight = torch.randn(config.hidden_size, config.hidden_size, device="cuda")
+
+    def my_op(self, x):
+        return x @ self.weight
+
+
+def _make_simple_module(config):
+    return _SimpleModule(config).cuda().eval()
+
+
+def _make_simple_non_module(config):
+    return _SimpleNonModule(config)
+
+
 class TestInlineCaptureManager:
     """Tests for CudaGraphManager with inline_capture, function_name, eager, and cache_key."""

@@ -1126,11 +1144,18 @@ def teardown_method(self, method):
         CudaGraphManager.global_mempool = None
         Utils.destroy_model_parallel()

+    @pytest.mark.parametrize(
+        "make_module",
+        [
+            pytest.param(_make_simple_module, id="nn_module"),
+            pytest.param(_make_simple_non_module, id="plain_class"),
+        ],
+    )
     @torch.inference_mode()
-    def test_inline_capture_matches_eager(self):
+    def test_inline_capture_matches_eager(self, make_module):
         """Inline-captured graph output must match eager execution."""
         config = self._make_config()
-        module = _SimpleModule(config).cuda().eval()
+        module = make_module(config)

         # Get eager reference before wrapping
         x = torch.randn(4, config.hidden_size, device="cuda")