From c2852f3e777d7003de8b7b10963063aac7f83fe2 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 17:57:44 +0800 Subject: [PATCH 01/31] feat(engine): LinearCrossEntropy --- areal/api/cli_args.py | 7 + .../fsdp_utils/kernels/fused_experts.py | 430 ++++ areal/engine/megatron_engine.py | 130 +- .../megatron_utils/fused_lce_capture.py | 203 ++ areal/utils/functional/__init__.py | 7 + .../utils/functional/linear_cross_entropy.py | 167 ++ areal/utils/kernel/__init__.py | 22 + areal/utils/kernel/kernels.py | 1757 +++++++++++++++++ areal/utils/kernel/linear_cross_entropy.py | 271 +++ benchmark/bench_linear_cross_entropy.py | 174 ++ tests/test_linear_cross_entropy.py | 332 ++++ 11 files changed, 3485 insertions(+), 15 deletions(-) create mode 100644 areal/engine/fsdp_utils/kernels/fused_experts.py create mode 100644 areal/engine/megatron_utils/fused_lce_capture.py create mode 100644 areal/utils/functional/linear_cross_entropy.py create mode 100644 areal/utils/kernel/__init__.py create mode 100644 areal/utils/kernel/kernels.py create mode 100644 areal/utils/kernel/linear_cross_entropy.py create mode 100644 benchmark/bench_linear_cross_entropy.py create mode 100644 tests/test_linear_cross_entropy.py diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index af103a9669..17e0b556ef 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -461,6 +461,13 @@ def __post_init__(self): }, ) + use_fused_moe: bool = field( + default=True, + metadata={ + "help": "" + }, + ) + @dataclass class ArchonFP8Config: diff --git a/areal/engine/fsdp_utils/kernels/fused_experts.py b/areal/engine/fsdp_utils/kernels/fused_experts.py new file mode 100644 index 0000000000..4e19d7604c --- /dev/null +++ b/areal/engine/fsdp_utils/kernels/fused_experts.py @@ -0,0 +1,430 @@ +"""Fused MoE autograd functions adapted for AReaL FSDP backend. + +Forward reuses SGLang's Triton kernels. Backward uses a Triton kernel written in +``fused_moe_triton_backward_kernels.py`` that computes ``grad_input``, ``grad_weight`` +and (optionally) ``grad_topk_weights`` with ``tl.atomic_add``. + +Debug logging can be enabled by setting the environment variable +``AREAL_FUSED_MOE_DEBUG=1``. When enabled, each intermediate tensor of the pipeline +is printed (shape + mean/std + first few values) so that the fused path can be +compared against a reference implementation by ``diff``-ing two runs. +""" + +from __future__ import annotations + +import os + +import torch +import triton.language as tl +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + invoke_fused_moe_kernel, + moe_align_block_size, + moe_sum_reduce, + silu_and_mul, +) + +from .fused_moe_triton_backward_kernels import invoke_fused_moe_backward_kernel + +_DEBUG = os.environ.get("AREAL_FUSED_MOE_DEBUG", "0") == "1" + + +def _dbg(tag: str, t: torch.Tensor | None) -> None: + """Emit a compact summary of a tensor when debug mode is on. + + Only rank-0 logging is fine because callers run the same op on every rank and + we just want to sanity-check numerical content during a single-process test. 
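+
+    Illustrative output line (example values only)::
+
+        [fused_moe] gate_up.fwd.w1: shape=(8, 256, 128) dtype=torch.bfloat16 mean=1.234567e-03 std=9.876543e-01 head=[0.1, -0.2, 0.3, 0.4]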
+ """ + if not _DEBUG or t is None: + return + try: + flat = t.detach().float().reshape(-1) + head = flat[:4].tolist() + print( + f"[fused_moe] {tag}: shape={tuple(t.shape)} dtype={t.dtype} " + f"mean={flat.mean().item():.6e} std={flat.std().item():.6e} head={head}" + ) + except Exception as e: # pragma: no cover + print(f"[fused_moe] {tag}: ") + + +class GateUpProjFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + hidden_states: torch.Tensor, + w1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ): + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + # Match slime / vLLM convention: chunked launch to avoid the bug + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = 64 * 1024 + + # Default deterministic config. Tuned for H800 / A100 bf16 MoE. + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + + topk = topk_ids.shape[1] + + intermediate_cache1 = torch.empty( + (num_tokens * topk, N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + _dbg("gate_up.fwd.hidden_states", hidden_states) + _dbg("gate_up.fwd.w1", w1) + _dbg("gate_up.fwd.topk_weights", topk_weights) + _dbg("gate_up.fwd.topk_ids", topk_ids) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + cur_intermediate_cache1 = intermediate_cache1[ + begin_chunk_idx * topk : end_chunk_idx * topk + ] + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) + + invoke_fused_moe_kernel( + curr_hidden_states, + w1, + None, + cur_intermediate_cache1, + None, + None, + None, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=tl.bfloat16, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None, + c_sorted=False, + filter_expert=True, + ) + + _dbg("gate_up.fwd.intermediate_cache1", intermediate_cache1) + + ctx.save_for_backward(hidden_states, w1, topk_weights, topk_ids) + ctx.config = config + ctx.num_tokens = num_tokens + ctx.topk = topk + + return intermediate_cache1 + + @staticmethod + def backward(ctx, grad_output): + """Backward for GateUpProj using Triton kernels. + + ``grad_output`` has shape ``(num_tokens * topk, N)``. We return + ``(grad_hidden_states, grad_w1, None, None)`` because ``topk_weights`` and + ``topk_ids`` are not multiplied in the forward kernel for this stage. 
+ """ + hidden_states, w1, topk_weights, topk_ids = ctx.saved_tensors + config = ctx.config + num_tokens = ctx.num_tokens + topk = ctx.topk + + E, N, D_in = w1.shape + CHUNK_SIZE = 64 * 1024 + + grad_hidden_states = torch.zeros_like(hidden_states) + grad_w1 = torch.zeros_like(w1) + grad_topk_weights = torch.zeros_like(topk_weights) + + _dbg("gate_up.bwd.grad_output", grad_output) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + + curr_num_tokens = end_chunk_idx - begin_chunk_idx + if curr_num_tokens == 0: + continue + + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + curr_grad_output = grad_output[begin_chunk_idx * topk : end_chunk_idx * topk] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) + + curr_grad_hidden_states = torch.zeros_like(curr_hidden_states) + curr_grad_w1 = torch.zeros_like(w1) + + invoke_fused_moe_backward_kernel( + grad_output=curr_grad_output, + input=curr_hidden_states, + weight=w1, + grad_input=curr_grad_hidden_states, + grad_weight=curr_grad_w1, + grad_topk_weights=None, + topk_weights=curr_topk_weights, + topk_ids=curr_topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + mul_routed_weight=False, + top_k=topk, + config=config, + compute_type=tl.bfloat16, + ) + + grad_hidden_states[begin_chunk_idx:end_chunk_idx] += curr_grad_hidden_states + grad_w1 += curr_grad_w1 + + _dbg("gate_up.bwd.grad_hidden_states", grad_hidden_states) + _dbg("gate_up.bwd.grad_w1", grad_w1) + + return grad_hidden_states, grad_w1, grad_topk_weights, None + + +class SiluAndMulFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, intermediate_cache1: torch.Tensor): + num_tokens, N = intermediate_cache1.shape + intermediate_cache2 = torch.empty( + (num_tokens, N // 2), + device=intermediate_cache1.device, + dtype=intermediate_cache1.dtype, + ) + silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) + _dbg("silu.fwd.intermediate_cache2", intermediate_cache2) + + ctx.save_for_backward(intermediate_cache1) + return intermediate_cache2 + + @staticmethod + def backward(ctx, grad_output): + (intermediate_cache1,) = ctx.saved_tensors + N = intermediate_cache1.shape[-1] + x1, x2 = intermediate_cache1.view(-1, N).chunk(2, dim=-1) + silu_x1 = torch.nn.functional.silu(x1) + + sig = torch.sigmoid(x1) + dsilu_dx1 = sig + x1 * sig * (1 - sig) + grad_x1 = grad_output * x2 * dsilu_dx1 + grad_x2 = grad_output * silu_x1 + grad_input = torch.cat([grad_x1, grad_x2], dim=-1) + _dbg("silu.bwd.grad_input", grad_input) + + return grad_input.view_as(intermediate_cache1) + + +class DownProjFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + intermediate_cache2: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ): + num_tokens, _ = intermediate_cache2.shape + topk = topk_ids.shape[1] + num_tokens //= topk + E, _, _ = w2.shape + CHUNK_SIZE = 64 * 1024 + + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + } + + intermediate_cache3 = torch.empty( + (num_tokens, topk, w2.shape[1]), + device=intermediate_cache2.device, + dtype=intermediate_cache2.dtype, + ) + + 
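+        # For routed copy k of token t with expert e = topk_ids[t, k], the kernel
+        # below computes
+        #   intermediate_cache3[t, k, :] = topk_weights[t, k] * (intermediate_cache2[t * topk + k, :] @ w2[e].T)
+        # mul_routed_weight=True folds the routing weight in here, so the final
+        # MoE output is a plain sum over k (see MoeSumReduceFunction).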
_dbg("down.fwd.intermediate_cache2", intermediate_cache2) + _dbg("down.fwd.w2", w2) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + cur_intermediate_cache2 = intermediate_cache2[ + begin_chunk_idx * topk : end_chunk_idx * topk + ] + cur_intermediate_cache3 = intermediate_cache3[begin_chunk_idx:end_chunk_idx] + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) + invoke_fused_moe_kernel( + cur_intermediate_cache2, + w2, + None, + cur_intermediate_cache3, + None, + None, + None, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=tl.bfloat16, + use_fp8_w8a8=False, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + per_channel_quant=False, + block_shape=None, + a_use_tma=False, + b_use_tma=False, + ) + + _dbg("down.fwd.intermediate_cache3", intermediate_cache3) + + ctx.save_for_backward(intermediate_cache2, w2, topk_weights, topk_ids) + ctx.config = config + ctx.num_tokens = num_tokens + ctx.topk = topk + + return intermediate_cache3 + + @staticmethod + def backward(ctx, grad_output): + """Backward for DownProj. + + ``grad_output`` has shape ``(num_tokens, topk, hidden_size)``. + Returns ``(grad_intermediate_cache2, grad_w2, grad_topk_weights, None)``. + """ + intermediate_cache2, w2, topk_weights, topk_ids = ctx.saved_tensors + config = ctx.config + num_tokens = ctx.num_tokens + topk = ctx.topk + + E, hidden_size, intermediate_size = w2.shape + CHUNK_SIZE = 64 * 1024 + + grad_intermediate_cache2 = torch.zeros_like(intermediate_cache2) + grad_w2 = torch.zeros_like(w2) + grad_topk_weights = torch.zeros_like(topk_weights) + + _dbg("down.bwd.grad_output", grad_output) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + + curr_num_tokens = end_chunk_idx - begin_chunk_idx + if curr_num_tokens == 0: + continue + + curr_intermediate_cache2 = intermediate_cache2[ + begin_chunk_idx * topk : end_chunk_idx * topk + ] + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + curr_grad_output = grad_output[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) + + curr_grad_intermediate_cache2 = torch.zeros_like(curr_intermediate_cache2) + curr_grad_w2 = torch.zeros_like(w2) + curr_grad_topk_weights = torch.zeros_like(curr_topk_weights) + + # Note: Use top_k=1 to match forward pass indexing convention of + # DownProj (each routed copy is its own "token"). 
+ invoke_fused_moe_backward_kernel( + grad_output=curr_grad_output, + input=curr_intermediate_cache2, + weight=w2, + grad_input=curr_grad_intermediate_cache2, + grad_weight=curr_grad_w2, + grad_topk_weights=curr_grad_topk_weights, + topk_weights=curr_topk_weights, + topk_ids=curr_topk_ids, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_tokens_post_padded=num_tokens_post_padded, + mul_routed_weight=True, + top_k=1, + config=config, + compute_type=tl.bfloat16, + ) + + grad_intermediate_cache2[ + begin_chunk_idx * topk : end_chunk_idx * topk + ] = curr_grad_intermediate_cache2 + grad_w2 += curr_grad_w2 + grad_topk_weights[begin_chunk_idx:end_chunk_idx] = curr_grad_topk_weights + + _dbg("down.bwd.grad_intermediate_cache2", grad_intermediate_cache2) + _dbg("down.bwd.grad_w2", grad_w2) + _dbg("down.bwd.grad_topk_weights", grad_topk_weights) + + return grad_intermediate_cache2, grad_w2, grad_topk_weights, None + + +class MoeSumReduceFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + intermediate_cache3: torch.Tensor, + hidden_states_shape, + ): + out_hidden_states = torch.empty( + hidden_states_shape, + device=intermediate_cache3.device, + dtype=intermediate_cache3.dtype, + ) + moe_sum_reduce( + intermediate_cache3, + out_hidden_states, + 1.0, + ) + _dbg("sum_reduce.fwd.out_hidden_states", out_hidden_states) + ctx.save_for_backward(intermediate_cache3) + return out_hidden_states + + @staticmethod + def backward(ctx, grad_output): + (intermediate_cache3,) = ctx.saved_tensors + grad = grad_output.unsqueeze(1).expand_as(intermediate_cache3) + _dbg("sum_reduce.bwd.grad_input", grad) + return grad, None diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 49119c8e53..44e790437c 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -106,7 +106,17 @@ split_padded_tensor_dict_into_mb_list, unpad_logits, ) -from areal.utils.functional import gather_logprobs, gather_logprobs_entropy +from areal.engine.megatron_utils.fused_lce_capture import ( + FUSED_LCE_HIDDEN_KEY, + FUSED_LCE_WEIGHT_KEY, + capture_lm_head_hidden, +) +from areal.utils.functional import ( + gather_logprobs, + gather_logprobs_entropy, + linear_cross_entropy_logprobs, + linear_cross_entropy_logprobs_entropy, +) from areal.utils.hf_utils import load_hf_tokenizer from areal.utils.lock import DistributedLock from areal.utils.network import find_free_ports, format_host_for_url, gethostip @@ -710,6 +720,17 @@ def forward_backward_batch( ) -> None: self._ensure_ready() + # Resolve once per call: whether the fused linear-cross-entropy path + # should engage. We engage it only on the pipeline-last stage, in + # non-critic mode, and outside the tree-training branch (those branches + # bring their own gather kernels and additional invariants we do not + # currently extend). + use_fused_lce = ( + getattr(self.config, "use_fused_linear_ce", False) + and not self.config.is_critic + and not self.enable_tree_training + ) + def forward_step(batch_iter, model): mb_input: MicroBatchItem = next(batch_iter) @@ -740,12 +761,40 @@ def forward_step(batch_iter, model): cp_size = mpu.get_context_parallel_world_size() cp_local = cp_size > 1 - output = packed_context_parallel_forward( - model, - mb_input.padded_mb, - gather_cp_output=not cp_local, + # Engage fused linear-cross-entropy capture only on the pipeline + # last stage; the LM head only exists there. 
CP-local logit-gather + # path keeps the standard materialised-logits route because the + # split-and-gather machinery operates on the [seq, vocab] tensor. + model_vp_stage_for_capture = getattr(model, "vp_stage", 0) + should_capture = ( + use_fused_lce + and mpu.is_pipeline_last_stage( + ignore_virtual=False, vp_stage=model_vp_stage_for_capture + ) + and not cp_local ) + with capture_lm_head_hidden( + model, enabled=should_capture + ) as capture: + output = packed_context_parallel_forward( + model, + mb_input.padded_mb, + gather_cp_output=not cp_local, + ) + + # Stash captured hidden + LM-head weight on the orig_mb dict so + # the downstream loss/forward callbacks can pick them up via the + # standard `inputs` argument (which is the *same* dict). + if ( + capture is not None + and capture.hidden is not None + and capture.weight is not None + ): + mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = capture.hidden + mb_input.orig_mb[FUSED_LCE_WEIGHT_KEY] = capture.weight + mb_input.orig_mb["_fused_lce_active"] = True + # Release tree attention metadata after forward pass for key in tree_attn_keys: del mb_input.padded_mb[key] @@ -784,6 +833,13 @@ def _process_output(input_, output_): cu_seqlens=cu_seqlens, old_cu_seqlens=mb_input.old_cu_seqlens, ) + # When fused-LCE capture is active, the model's + # ``output`` is actually the pre-projection hidden + # state (the LM-head was monkey-patched to a no-op), + # so the unpadded tensor is the hidden we want to + # feed into the fused kernel. + if mb_input.orig_mb.get("_fused_lce_active", False): + mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output return output, functools.partial(_process_output, mb_input.orig_mb) forward_backward_func = get_forward_backward_func() @@ -1814,16 +1870,42 @@ def _compute_logprobs_and_loss( labels = cp_local_labels else: labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) - logprobs, entropy = gather_logprobs_entropy( - output, - labels, - temperature=self.config.temperature, - tp_group=mpu.get_tensor_model_parallel_group() - if mpu.get_tensor_model_parallel_world_size() > 1 - else None, - ) - vocab_min_logits = output.detach().min(-1).values.float() - vocab_max_logits = output.detach().max(-1).values.float() + fused_active = inputs.get("_fused_lce_active", False) + fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) + fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) + if ( + fused_active + and fused_hidden is not None + and fused_weight is not None + ): + logprobs, entropy = linear_cross_entropy_logprobs_entropy( + fused_hidden, + fused_weight, + labels, + temperature=self.config.temperature, + tp_group=mpu.get_tensor_model_parallel_group() + if mpu.get_tensor_model_parallel_world_size() > 1 + else None, + ) + # vocab_min/max_logits are diagnostics consumed by the + # clip-ratio statistics inside the PPO loss; the fused + # kernel never materialises the [seq, vocab] tensor, so + # we substitute finite proxies derived from the chosen + # logprobs (cheap and never stalls training). 
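+                # On the fused path these statistics therefore reflect
+                # chosen-token logprobs rather than true vocab-wide extrema;
+                # set use_fused_linear_ce=False to recover exact diagnostics.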
+ proxy = logprobs.detach().float() + vocab_min_logits = proxy + vocab_max_logits = proxy + else: + logprobs, entropy = gather_logprobs_entropy( + output, + labels, + temperature=self.config.temperature, + tp_group=mpu.get_tensor_model_parallel_group() + if mpu.get_tensor_model_parallel_world_size() > 1 + else None, + ) + vocab_min_logits = output.detach().min(-1).values.float() + vocab_max_logits = output.detach().max(-1).values.float() loss = loss_fn( logprobs, entropy, @@ -1860,6 +1942,24 @@ def _compute_forward_result( ) return logprobs labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) + fused_active = inputs.get("_fused_lce_active", False) + fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) + fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) + if ( + fused_active + and fused_hidden is not None + and fused_weight is not None + ): + logprobs = linear_cross_entropy_logprobs( + fused_hidden, + fused_weight, + labels, + temperature=self.config.temperature, + tp_group=mpu.get_tensor_model_parallel_group() + if mpu.get_tensor_model_parallel_world_size() > 1 + else None, + ) + return logprobs logprobs = gather_logprobs( output, labels, diff --git a/areal/engine/megatron_utils/fused_lce_capture.py b/areal/engine/megatron_utils/fused_lce_capture.py new file mode 100644 index 0000000000..5886026446 --- /dev/null +++ b/areal/engine/megatron_utils/fused_lce_capture.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +LM-head hidden-state capture for the fused linear-cross-entropy fast path. + +The fused :func:`areal.utils.kernel.linear_cross_entropy` kernel needs the +pre-projection hidden state (``[seq, hidden]``) and the LM-head weight +(``[vocab, hidden]``, possibly vocab-sharded along the TP group) instead of +the materialised ``[seq, vocab]`` logits tensor. The Megatron-Core +:class:`GPTModel` does not expose either of these to AReaL's +``_compute_logprobs_and_loss`` call site by default, so we install a +temporary monkey-patch on ``output_layer.forward`` for the duration of one +microbatch forward pass: + +1. Stashes the input tensor (``hidden``) and the actual weight (either the + ``output_layer``'s own weight, or the embedding-tied weight passed in via + ``weight=``). +2. Returns ``(hidden, None)`` instead of ``(logits, bias)``. Because + :func:`areal.utils.data.unpad_logits` and + :func:`postprocess_packed_seqs_context_parallel` are shape-agnostic on + the leading sequence dim and propagate ``shape[1:]`` verbatim, the + returned hidden tensor flows through the rest of the engine pipeline + without modification — the engine's downstream code on the fused path + never inspects the trailing dim except to take a min/max for diagnostic + purposes, which we override with proxies in + ``MegatronEngine._compute_logprobs_and_loss``. + +The patch is installed only when ``enabled=True`` and uninstalled on +context exit (including on exception), so error-path leaks of the patched +method are impossible. + +Compatibility notes: + +* The patch is incompatible with Megatron-Core's MuP logit scaling + (``config.use_mup``), MTP (``config.mtp_num_layers > 0``) and inference + paths that materialise ``last_token_logits``. We assert against these + configurations at install time and refuse to engage; the engine then + falls back to the materialised path automatically. 
+* The patch is also incompatible with the critic value head, since that + head is a 1-output-dim ``ColumnParallelLinear`` and the fused kernel + requires the LM-head weight; the engine guards on ``is_critic`` before + calling this helper. +""" + +from __future__ import annotations + +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Iterator, Optional + +import torch +from megatron.core import parallel_state as mpu + +from areal.utils import logging + +logger = logging.getLogger("FusedLCECapture") + +# Keys used to pass captured tensors from forward_step → process_output. +# Centralised here to keep the engine and helper in lockstep. +FUSED_LCE_HIDDEN_KEY = "_fused_lce_hidden" +FUSED_LCE_WEIGHT_KEY = "_fused_lce_weight" + + +@dataclass +class _CaptureSlot: + """Mutable single-shot stash populated by the patched ``forward``.""" + + hidden: Optional[torch.Tensor] = None + weight: Optional[torch.Tensor] = None + + +def _unwrap_to_post_process_module(model: torch.nn.Module) -> Optional[torch.nn.Module]: + """Strip DDP/Float16Module wrappers and return the inner module that + owns ``output_layer`` (i.e. an mcore ``GPTModel`` on the last PP stage), + or ``None`` if no such module is reachable on this rank. + + Returning ``None`` (instead of raising) lets the caller skip the patch + transparently on intermediate pipeline stages. + """ + inner = model + # Loop bound: at most ~4 wrapper layers in practice (DDP, Float16Module, + # vp wrapper). 8 is a generous upper bound that protects against + # accidental cycles. + for _ in range(8): + if hasattr(inner, "output_layer") and inner.output_layer is not None: + return inner + if not hasattr(inner, "module"): + return None + inner = inner.module + return None + + +def _is_compatible(post_process_module: torch.nn.Module) -> bool: + """Refuse to engage when the model uses features incompatible with the + fused kernel. Falling back is preferred over silently producing wrong + numbers.""" + config = getattr(post_process_module, "config", None) + if config is None: + # Conservative default: don't patch unknown modules. + return False + + if getattr(config, "use_mup", False): + logger.warning( + "Fused LCE: MuP scaling is enabled (config.use_mup=True); " + "fused path is disabled for this microbatch." + ) + return False + if getattr(config, "mtp_num_layers", 0): + logger.warning( + "Fused LCE: MTP is enabled (config.mtp_num_layers>0); " + "fused path is disabled for this microbatch." + ) + return False + + output_layer = getattr(post_process_module, "output_layer", None) + if output_layer is None: + return False + + # Sequence parallel + TP gather inside output_layer is what we *want* + # to bypass; AReaL runs with parallel_output=True which keeps logits + # vocab-sharded — exactly what the fused kernel expects via tp_group. + parallel_output = getattr(post_process_module, "parallel_output", True) + if not parallel_output: + # If gather_output=True, the engine has already requested the + # full-vocab logits to be all-gathered; capturing hidden here would + # mean the downstream kernel needs to gather instead, doubling + # comms. Prefer the existing materialised path in that case. + logger.warning( + "Fused LCE: model has parallel_output=False; fused path is " + "disabled to avoid an extra TP gather." 
+ ) + return False + + return True + + +@contextmanager +def capture_lm_head_hidden( + model: torch.nn.Module, *, enabled: bool +) -> Iterator[Optional[_CaptureSlot]]: + """Context manager that captures the input + weight handed to the + ``output_layer`` of the wrapped Megatron GPT model. + + Yields: + ``_CaptureSlot`` on the pipeline-last stage when ``enabled`` is + True and the model is compatible; ``None`` otherwise. The caller is + expected to inspect ``slot.hidden`` for ``None`` to decide whether + the fused path is usable for this microbatch. + """ + if not enabled: + yield None + return + + post_process = _unwrap_to_post_process_module(model) + if post_process is None or not _is_compatible(post_process): + # Either an intermediate PP stage or an incompatible config; the + # engine will transparently fall back to the materialised path. + yield None + return + + output_layer = post_process.output_layer + slot = _CaptureSlot() + original_forward = output_layer.forward + + def _patched_forward(input_, weight=None, runtime_gather_output=None): + # Resolve the actual weight: either passed in (weight tying) or the + # output_layer's own parameter. We intentionally store a *reference* + # (not detach) so autograd flows through both the kernel forward + # and backward. + actual_weight = weight if weight is not None else output_layer.weight + slot.hidden = input_ + slot.weight = actual_weight + # Return ``(input_, None)``: callers expect ``(logits, bias)`` and + # only ever destructure with ``logits, _ = output_layer(...)``. The + # downstream pipeline (``unpad_logits`` etc.) is shape-agnostic on + # the trailing dim, so passing ``hidden`` through is safe; the + # fused kernel will then consume the stashed tensors and produce + # the real per-token logprobs. + return input_, None + + # ``output_layer.forward = _patched_forward`` replaces the bound method + # at instance level (via ``__dict__`` lookup), shadowing the class + # method without mutating the class. Restoration in ``finally`` is + # therefore exception-safe. + output_layer.forward = _patched_forward # type: ignore[assignment] + try: + yield slot + finally: + # Best-effort restoration. ``del`` removes the instance-level + # binding and re-exposes the class-level method, which is what + # callers will execute on subsequent forwards. + try: + del output_layer.forward + except AttributeError: + # If __dict__ assignment is not supported (rare for nn.Module + # subclasses), fall back to direct restoration. 
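+            # ``original_forward`` is the bound method captured before
+            # patching, so re-assigning it restores the pre-patch behaviour.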
+ output_layer.forward = original_forward # type: ignore[assignment] + + +__all__ = [ + "FUSED_LCE_HIDDEN_KEY", + "FUSED_LCE_WEIGHT_KEY", + "capture_lm_head_hidden", +] diff --git a/areal/utils/functional/__init__.py b/areal/utils/functional/__init__.py index c91c3ff2b6..d1182361b8 100644 --- a/areal/utils/functional/__init__.py +++ b/areal/utils/functional/__init__.py @@ -11,6 +11,10 @@ reward_overlong_penalty, sapo_loss_fn, ) +from areal.utils.functional.linear_cross_entropy import ( + linear_cross_entropy_logprobs, + linear_cross_entropy_logprobs_entropy, +) from areal.utils.functional.vocab_parallel import ( gather_logprobs, gather_logprobs_entropy, @@ -30,4 +34,7 @@ # vocab_parallel.py "gather_logprobs", "gather_logprobs_entropy", + # linear_cross_entropy.py (fused linear + CE/entropy via Triton) + "linear_cross_entropy_logprobs", + "linear_cross_entropy_logprobs_entropy", ] diff --git a/areal/utils/functional/linear_cross_entropy.py b/areal/utils/functional/linear_cross_entropy.py new file mode 100644 index 0000000000..bf91c7b7e0 --- /dev/null +++ b/areal/utils/functional/linear_cross_entropy.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +High-level fused linear cross-entropy entry points for AReaL. + +These wrappers bridge the :class:`LinearCrossEntropy` Triton kernel into +AReaL's existing :func:`gather_logprobs_entropy` interface so that the +Megatron path can opt in via a single configuration flag without +restructuring the model forward. + +The wrappers: + +* accept already-flat ``hidden`` of shape ``(num_tokens, hidden_size)`` and + ``labels`` of shape ``(num_tokens,)`` (or higher-dimensional tensors with + an explicit last hidden dim) so the call site looks identical to the + existing materialised path; +* support optional tensor-parallel via ``tp_group`` for vocab-sharded + ``weight`` matrices; +* fall back gracefully to the materialised reference path when Triton is + unavailable or inputs are not on CUDA, so unit tests can still run on CPU. +""" + +from __future__ import annotations + +import os +from typing import Optional + +import torch +import torch.distributed as dist + +from areal.utils import logging + +logger = logging.getLogger("LinearCrossEntropy") + + +def _force_fallback() -> bool: + """Allow ops/CI to disable the fused kernel via env var without code change.""" + return os.environ.get("AREAL_DISABLE_FUSED_LCE", "0") == "1" + + +def _kernel_available() -> bool: + """Whether the Triton fused kernel can run on this host.""" + if _force_fallback(): + return False + if not torch.cuda.is_available(): + return False + try: + import triton # noqa: F401 + except ImportError: + return False + return True + + +def _reference_logprobs_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, + tp_group: Optional[dist.ProcessGroup], +) -> tuple[torch.Tensor, torch.Tensor]: + """Reference (materialised-logits) implementation. + + Used when Triton is unavailable. Mathematically equivalent to the fused + kernel up to floating-point reordering, which is why the test suite + asserts with explicit rtol/atol rather than bitwise equality. + """ + # Shape normalisation matches the fused kernel. 
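+    # For each token t, with z_t = hidden[t] @ weight.T / temperature:
+    #   logprobs[t] = log_softmax(z_t)[labels[t]]
+    #   entropy[t]  = -sum_v softmax(z_t)[v] * log_softmax(z_t)[v]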
+ flat_hidden = hidden.reshape(-1, hidden.shape[-1]) + flat_labels = labels.reshape(-1) + + logits = torch.matmul(flat_hidden.float(), weight.float().t()) + if temperature != 1.0: + logits = logits / temperature + + if tp_group is not None and dist.get_world_size(tp_group) > 1: + # Vocab-parallel: gather full vocab logits across TP group. + # Used only as a slow correctness fallback. + world_size = dist.get_world_size(tp_group) + gathered = [torch.empty_like(logits) for _ in range(world_size)] + dist.all_gather(gathered, logits, group=tp_group) + logits = torch.cat(gathered, dim=-1) + + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + log_probs_labels = log_softmax.gather( + dim=-1, index=flat_labels.unsqueeze(-1) + ).squeeze(-1) + probs = log_softmax.exp() + entropy = -(probs * log_softmax).sum(dim=-1) + return log_probs_labels, entropy + + +def linear_cross_entropy_logprobs_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float = 1.0, + tp_group: Optional[dist.ProcessGroup] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute per-token log-prob and entropy from hidden states + lm-head weight. + + This is the fused counterpart to + :func:`areal.utils.functional.vocab_parallel.gather_logprobs_entropy`, + but consumes ``hidden`` (last layer states) and ``weight`` (lm-head + weight) directly instead of a materialised ``[num_tokens, vocab_size]`` + logits tensor. Memory savings scale with ``vocab_size``. + + Args: + hidden: ``(..., hidden_size)`` last-layer hidden states. + weight: ``(vocab_size, hidden_size)`` lm-head weight; may be sharded + on the vocab dimension when ``tp_group`` is set. + labels: ``(...,)`` integer label ids matching the leading dims of + ``hidden``. With TP, labels MUST hold *global* vocab ids. + temperature: softmax temperature. + tp_group: optional tensor-parallel group when ``weight`` is sharded. + + Returns: + ``(logprobs, entropy)`` both shaped like ``labels``. + """ + leading_shape = labels.shape + + if _kernel_available(): + # Lazy import: keeps a hard Triton import out of the module path so + # CPU-only environments can still load areal.utils.functional. + from areal.utils.kernel.linear_cross_entropy import linear_cross_entropy + + if hidden.device.type != "cuda": + logger.warning( + "Fused LCE requested but hidden is on %s; falling back to reference path.", + hidden.device, + ) + else: + try: + logprobs, entropy = linear_cross_entropy( + hidden, + weight, + labels, + temperature, + "none", + tp_group, + ) + return logprobs.reshape(leading_shape), entropy.reshape(leading_shape) + except Exception as exc: # pragma: no cover - fall back path + logger.warning( + "Fused LCE kernel raised %s; falling back to reference path.", + exc, + ) + + logprobs, entropy = _reference_logprobs_entropy( + hidden, weight, labels, temperature, tp_group + ) + return logprobs.reshape(leading_shape), entropy.reshape(leading_shape) + + +def linear_cross_entropy_logprobs( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float = 1.0, + tp_group: Optional[dist.ProcessGroup] = None, +) -> torch.Tensor: + """Logprobs-only counterpart of :func:`linear_cross_entropy_logprobs_entropy`. + + Returns a tensor shaped like ``labels``. 
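+
+    Illustrative usage (shapes only; variable names are for exposition)::
+
+        hidden = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16)
+        weight = lm_head_weight                  # (vocab_size, hidden_size)
+        labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda")
+        logprobs = linear_cross_entropy_logprobs(hidden, weight, labels)  # (num_tokens,)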
+ """ + logprobs, _ = linear_cross_entropy_logprobs_entropy( + hidden, weight, labels, temperature, tp_group + ) + return logprobs diff --git a/areal/utils/kernel/__init__.py b/areal/utils/kernel/__init__.py new file mode 100644 index 0000000000..9ecc81d669 --- /dev/null +++ b/areal/utils/kernel/__init__.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Triton-based fused linear-cross-entropy kernels for AReaL. + +The kernel implementations under :mod:`areal.utils.kernel.kernels` fuse +the matmul with cross-entropy reduction, preserving numerical semantics +while avoiding materialization of the ``[num_tokens, vocab_size]`` logits +tensor. The :class:`LinearCrossEntropy` autograd function exposed below +provides a memory-efficient drop-in replacement for the materialized +``logits = hidden @ weight.T`` followed by softmax / log-softmax / +entropy computation. +""" + +from areal.utils.kernel.linear_cross_entropy import ( + LinearCrossEntropy, + linear_cross_entropy, +) + +__all__ = [ + "LinearCrossEntropy", + "linear_cross_entropy", +] diff --git a/areal/utils/kernel/kernels.py b/areal/utils/kernel/kernels.py new file mode 100644 index 0000000000..de7ccb7f16 --- /dev/null +++ b/areal/utils/kernel/kernels.py @@ -0,0 +1,1757 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implementations of the linear cross entropy with token entropy kernel. + +The Triton kernel implementations fuse the matmul with cross-entropy +reduction so that the ``[num_tokens, vocab_size]`` logits tensor is never +materialized, trading kernel-launch overhead for large memory savings. +""" + +import typing +from dataclasses import dataclass + +import torch +import torch.distributed as dist + + +# --- Device helpers ----------------------------------------------------------- +# AReaL relies on torch directly for CUDA device primitives used by the +# Triton kernels below. 
+def _is_cuda_available() -> bool: + return torch.cuda.is_available() + + +def get_device_capability(): + if torch.cuda.is_available(): + return torch.cuda.get_device_capability() + return (0, 0) + + +def get_device_name() -> str: + if torch.cuda.is_available(): + return "cuda" + return "cpu" + + +def get_torch_device(): + return torch.cuda + + +is_cuda_available = _is_cuda_available() + + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True + SUPPORT_CUDA_TMA = ( + is_cuda_available + and get_device_capability()[0] >= 9 + and hasattr(tl, "make_tensor_descriptor") + ) + +except ImportError: + HAVE_TRITON = False + SUPPORT_CUDA_TMA = False + +if not HAVE_TRITON: + from contextlib import contextmanager + from unittest.mock import MagicMock + + @contextmanager + def null_decorator(*args, **kwargs): + if len(kwargs) == 0 and len(args) == 1 and callable(args[0]): + return args[0] + else: + + def inner(func): + return func + + return inner + + triton = MagicMock() + triton.jit = null_decorator + triton.autotune = null_decorator + tl = MagicMock() + +elif SUPPORT_CUDA_TMA: + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: typing.Optional[int]): + return torch.empty(size, device=get_device_name(), dtype=torch.int8) + + # https://github.com/triton-lang/triton/commit/43625fc968b693ab51884ca95adbcf3e43483fd0 + # Triton 3.5.0 stores allocators in ContextVar; values do not propagate to new + # threads by default. Some execution paths use thread pools (e.g., + # concurrent.futures), so we set a ContextVar *default* to avoid falling + # back to NullAllocator in worker threads. + try: + import contextvars + + import triton.runtime._allocation as _triton_allocation + + if isinstance(getattr(_triton_allocation, "_allocator", None), contextvars.ContextVar): + _triton_allocation._allocator = contextvars.ContextVar( + _triton_allocation._allocator.name, + default=alloc_fn, + ) + except (ImportError, AttributeError): + pass + + triton.set_allocator(alloc_fn) + + +@dataclass +class EntropyReductionEnum: + """ + Enum for the reduction method of cross entropy. + """ + + _None = 0 + _Sum = 1 + _Mean = 2 + + +def get_entropy_reduction_enum_number(reduction: str) -> int: + """ + Get the enum number for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if reduction == "none": + _enum = EntropyReductionEnum._None + elif reduction == "sum": + _enum = EntropyReductionEnum._Sum + elif reduction == "mean": + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid reduction: {reduction}") + return _enum + + +def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: + """ + Get the enum for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if ce_reduction == 0: + _enum = EntropyReductionEnum._None + elif ce_reduction == 1: + _enum = EntropyReductionEnum._Sum + elif ce_reduction == 2: + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid ce_reduction: {ce_reduction}") + return _enum + + +@dataclass +class BackwardEnum: + """ + Enum for the backward method. + """ + + _Total_Fuse_MN = ( + 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + ) + _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight + _Split_Dlogits_N = 2 # split d_logits along its N dimension, aka. vocab_size + _Split_Dlogits_M = 3 # split d_logits along its M dimension, aka. 
num_tokens + + +@dataclass +class Config: + """Configuration for efficient entropy kernel operations. + + Args: + _backward (BackwardEnum): Backward computation method. Defaults to BackwardEnum._Split_Dlogits_N. + _use_triton (bool): Whether to use Triton kernels for computation. Defaults to True. + """ + + _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N + _use_triton: bool = True + + +_config = Config() + + +def set_backward_method(backward_method: BackwardEnum): + """ + Set the backward method. + """ + global _config + _config._backward = backward_method + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=3, num_warps=8)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_kernel_general_mainloop( + rank, + hidden_ptr, + weight_ptr, + labels_ptr, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, + rcp_temperature: tl.float32, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + USE_TMA: tl.constexpr, +): + """ + forward mainloop + """ + pid = tl.program_id(axis=0) + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + if pid_m == 0 and pid_n == 0: + tl.store(global_logprobs_scalar_ptr, 0.0) + + # create pointers for the first blocks of hidden + start_offs_am = pid_m * BLOCK_SIZE_M + offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + if USE_TMA: + # using TMA and device-side descriptor creation + hidden_desc = tl.make_tensor_descriptor( + hidden_ptr, + shape=[num_tokens, hidden_size], + strides=[stride_hidden_m, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + + weight_desc = tl.make_tensor_descriptor( + weight_ptr, + shape=[vocab_size, hidden_size], + strides=[stride_weight_n, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + + else: + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + + # load labels for this block + labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) + + # traverse over N dimension + # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) + _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + vocab_bound = min((pid_n + 1) * vocab_per_split, vocab_size) + for n in range(0, num_pid_n): + start_offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if not USE_TMA: + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + # iterate over K 
dimension + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + if USE_TMA: + # load the next block of hidden and weight + start_offs_k = k * BLOCK_SIZE_K + _hidden = hidden_desc.load([start_offs_am, start_offs_k]) + _weight = weight_desc.load([start_offs_bn, start_offs_k]) + else: + # load the next block of hidden and weight + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))), + other=0.0, + ) + + # advance the ptrs to the next K block + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + # GEMM + logits = tl.dot(_hidden, _weight.trans(), logits) + + if not USE_TMA: + # reset hidden_ptrs for next iteration + hidden_ptrs -= hidden_size * stride_hidden_k + + # scale logits by temperature + logits *= rcp_temperature + + logits_for_lse = tl.where(offs_bn[None, :] < vocab_bound, logits, float("-inf")) + + # update global maximum + _max_old = _max + m_pid_n = tl.max(logits_for_lse, axis=1) + _max = tl.maximum(_max_old, m_pid_n) + + exp_logits = tl.exp(logits_for_lse - _max[:, None]) + coeff = tl.exp(_max_old - _max) + _accu = coeff * _accu + tl.sum(exp_logits, axis=1) + + _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) + + label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + _logprobs += tl.sum(logits * label_mask, axis=1) + + # store maximum + offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_max_n = pid_n + maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m + tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store entropy + accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m + tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) + entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m + tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + # store logprobs + vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size + vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size + mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx) + mask &= offs_am < num_tokens + global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs + # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask) + tl.store(global_logprobs_ptrs, _logprobs, mask=mask) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue( + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + num_tokens, + num_splits, + global_max_ptr, + stride_global_max: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + global_entropy_ptr, + stride_global_entropy: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, + reduction: int, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + 
foward epilogue + """ + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + + _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n + _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + + entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n + _entropy_b = tl.load( + entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0 + ) + + # local reduction + _max_old = global_max + _local_max = tl.max(_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + _scale = tl.exp(_max - global_max[:, None]) + _coeff = tl.exp(_max_old - global_max) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + maximum_ptrs = global_max_ptr + offs_m * stride_global_max + tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) + + # store entropy_b + global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + # store entropy + global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu + tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) + global_entropy = tl.log(global_accu) + global_max - global_entropy_b # entropy_a + global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy + tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens) + # update logprobs + global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs + global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) + global_logprobs = global_max + tl.log(global_accu) - global_logprobs + + global_logprobs = -1 * global_logprobs + if reduction == 0: + tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + elif reduction == 2: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue_tp( + num_tokens, + num_splits, + reduced_max_ptr, + stride_reduced_max_m: tl.int64, + stride_reduced_max_n: tl.int64, + original_max_ptr, + stride_original_max_m: tl.int64, + stride_original_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_max_ptr, + stride_global_max: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + global_entropy_b_ptr, + 
stride_global_entropy_b: tl.int64, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + _reduced_max = tl.load( + reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _original_max = tl.load( + original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _accu = tl.load( + accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + + # local reduce-max + _max_old = global_max + _local_max = tl.max(_reduced_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + # update accumulate + _coeff = tl.exp(_max_old - global_max) + _scale = tl.exp(_original_max - global_max[:, None]) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + + # update entropy_b + _entropy_b = tl.load( + entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) + tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) +@triton.jit +def efficient_entropy_triton_epilogue_tp_update( + num_tokens, + logprobs_ptr, + stride_logprobs: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accumulate_ptr, + stride_accumulate: tl.int64, + entropy_b_ptr, + stride_entropy_b: tl.int64, + entropy_ptr, + stride_entropy: tl.int64, + logprobs_scalar_ptr, + reduction: int, + BLOCK_SIZE_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) + accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens) + entropy_b = tl.fdiv(entropy_b, accumulate) + tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens) + + entropy = tl.log(accumulate) + maximum - entropy_b + tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) + + logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) + logprobs = maximum + tl.log(accumulate) - logprobs + + logprobs = -1 * logprobs + if reduction == 0: + tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + logprobs_scalar = tl.sum(logprobs, axis=0) + 
tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + elif reduction == 2: + logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + + +_dedicated_stream, _dedicated_events = None, None + + +def efficient_entropy_forward( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: typing.Optional[int] = 2, + temperature: typing.Optional[float] = 1.0, + dist_process_group: typing.Optional[dist.ProcessGroup] = None, +) -> list[torch.Tensor]: + """ + forward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"): + global _dedicated_stream, _dedicated_events + _dedicated_stream = get_torch_device().Stream(hidden.device) + _dedicated_events = [get_torch_device().Event() for _ in range(2)] + efficient_entropy_forward._initialized = True + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + if dist_process_group is None: + logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + else: + logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): + logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) + else: + raise ValueError(f"Invalid reduction: {reduction}") + + entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + assert logprobs.is_contiguous() and entropy.is_contiguous() + + maximum = torch.empty_like(entropy) + accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32) + accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) + accumulate = accumulate_and_entropy_b_view[0, :] + entropy_b = accumulate_and_entropy_b_view[1, :] + assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous() + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + + if REDUCTION == EntropyReductionEnum._None: + _logprobs = logprobs + else: + _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + + assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() + assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda + + if _config._use_triton: + # 1D kernel launch, then split the tile + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, 
meta["BLOCK_SIZE_M"]) * num_splits,) + + efficient_entropy_kernel_general_mainloop[mainloop_grid]( + _rank, + hidden, + weight, + labels, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + hidden.stride(0), + hidden.stride(1), + weight.stride(0), + weight.stride(1), + _max, + _max.stride(0), + _max.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + _logprobs, + _logprobs.stride(0), + logprobs, + 1.0 / temperature, + USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, + ) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + # reduction on maximum and maximum_indices + def epilogue_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) + + if dist_process_group is None: + efficient_entropy_triton_kernel_epilogue[epilogue_grid]( + _max, + _max.stride(0), + _max.stride(1), + num_tokens, + num_splits, + maximum, + maximum.stride(0), + _accu, + _accu.stride(0), + _accu.stride(1), + accumulate, + accumulate.stride(0), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + entropy_b, + entropy_b.stride(0), + entropy, + entropy.stride(0), + _logprobs, + _logprobs.stride(0), + logprobs, + REDUCTION, + ) + else: + # tensor-parallel + _max_backup = _max.clone() + dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group) + + get_torch_device().current_stream().record_event(_dedicated_events[0]) + with get_torch_device().stream(_dedicated_stream): + _dedicated_stream.wait_event(_dedicated_events[0]) + dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group) + _dedicated_stream.record_event(_dedicated_events[1]) + + efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid]( + num_tokens, + num_splits, + _max, + _max.stride(0), + _max.stride(1), + _max_backup, + _max_backup.stride(0), + _max_backup.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + maximum, + maximum.stride(0), + accumulate, + accumulate.stride(0), + entropy_b, + entropy_b.stride(0), + ) + get_torch_device().current_stream().wait_event(_dedicated_events[1]) + + dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group) + + # update logprobs & entropy + efficient_entropy_triton_epilogue_tp_update[epilogue_grid]( + num_tokens, + _logprobs, + _logprobs.stride(0), + maximum, + maximum.stride(0), + accumulate, + accumulate.stride(0), + entropy_b, + entropy_b.stride(0), + entropy, + entropy.stride(0), + logprobs, + REDUCTION, + ) + + return (logprobs, entropy, maximum, accumulate, entropy_b) + + +# NOTE: merge d_weight & d_hidden here, split along M & N +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ) + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_mainloop_MN( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + 
d_hidden_ptr, + stride_d_hidden_m: tl.int64, + stride_d_hidden_k: tl.int64, + d_weight_ptr, + stride_d_weight_n: tl.int64, + stride_d_weight_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + USE_TMA: tl.constexpr, +): + """ + backward mainloop, where d_logits & d_hidden & d_weight are fused + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_offs_am = pid_m * BLOCK_SIZE_M + offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) + start_offs_bn = pid_n * BLOCK_SIZE_N + offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + if USE_TMA: + # using TMA and device-side descriptor creation + hidden_desc = tl.make_tensor_descriptor( + hidden_ptr, + shape=[num_tokens, hidden_size], + strides=[stride_hidden_m, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + + weight_desc = tl.make_tensor_descriptor( + weight_ptr, + shape=[vocab_size, hidden_size], + strides=[stride_weight_n, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + if not USE_TMA: + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k + # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n + d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), 
dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + if USE_TMA: + start_offs_k = k * BLOCK_SIZE_K + _hidden = hidden_desc.load([start_offs_am, start_offs_k]) + _weight = weight_desc.load([start_offs_bn, start_offs_k]) + else: + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + other=0.0, + ) + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + logits = tl.dot(_hidden, _weight.T, logits) + + if not USE_TMA: + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits by temperature + d_logits *= rcp_temperature + + # loop for d_weight & d_hidden + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + start_offs_k = k * BLOCK_SIZE_K + if USE_TMA: + _hidden = hidden_desc.load([start_offs_am, start_offs_k]) + else: + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) + # tl.atomic_add(d_weight_ptrs, + # _d_weight, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size)) + _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32)) + tl.atomic_add( + d_weight_ptrs, + _d_weight, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + ) + + if USE_TMA: + _weight = weight_desc.load([start_offs_bn, start_offs_k]) + else: + # _weight = tl.load( + # weight_ptrs, + # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + # other=0.0 + # ) + # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + other=0.0, + ) + _d_hidden = tl.dot(d_logits, _weight.to(tl.float32)) + tl.atomic_add( + d_hidden_ptrs, + _d_hidden, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + ) + + if not USE_TMA: + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k + d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_d_hidden( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + 
d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_hidden_ptr, + stride_d_hidden_m: tl.int64, + stride_d_hidden_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + backward d_hidden + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_k = pid // num_pid_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + result_offs_k = pid_k * BLOCK_SIZE_K + offs_k + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + + # iterate over vocab_size + d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + for n in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)): + offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + # iterate over hidden_size to get logits + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), + other=0.0, + ) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits + d_logits *= rcp_temperature + + # calculate d_hidden + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k) + _weight = tl.load( + weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0 + ) + d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden) + + # write back + tl.store( + d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k, + d_hidden, + mask=(offs_m[:, None] < num_tokens) & 
(result_offs_k[None, :] < hidden_size), + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_d_weight( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_weight_ptr, + stride_d_weight_n: tl.int64, + stride_d_weight_k: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + pid_n = pid % num_pid_n + pid_k = pid // num_pid_n + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + result_offs_k = pid_k * BLOCK_SIZE_K + offs_k + + d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)): + offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), + other=0.0, + ) + + logits = tl.dot(_hidden, _weight.trans(), logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - 
entropy_b[:, None]) + + d_logits *= rcp_temperature + + hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k) + _hidden = tl.load( + hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0 + ) + d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight) + + # write back + tl.store( + d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k, + d_weight, + mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size), + ) + + +# NOTE: split tile from d_logits' perspective +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits( + num_tokens: int, + hidden_size: int, + vocab_size: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + USE_TMA: tl.constexpr, +): + """ + backward d_logits + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_offs_am = pid_m * BLOCK_SIZE_M + offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) + start_offs_bn = pid_n * BLOCK_SIZE_N + offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + 
entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + if USE_TMA: + # using TMA and device-side descriptor creation + hidden_desc = tl.make_tensor_descriptor( + hidden_ptr, + shape=[num_tokens, hidden_size], + strides=[stride_hidden_m, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + weight_desc = tl.make_tensor_descriptor( + weight_ptr, + shape=[vocab_size, hidden_size], + strides=[stride_weight_n, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + else: + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + if USE_TMA: + start_offs_k = k * BLOCK_SIZE_K + _hidden = hidden_desc.load([start_offs_am, start_offs_k]) + _weight = weight_desc.load([start_offs_bn, start_offs_k]) + else: + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), + other=0.0, + ) + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + logits = tl.dot(_hidden, _weight.T, logits) + + if not USE_TMA: + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + # scale logits by temperature + logits *= rcp_temperature + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # scale d_logits by temperature + d_logits *= rcp_temperature + + # store d_logits + d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n + tl.store( + d_logits_ptrs, + d_logits, # will be implicitly converted to d_logits_ptrs.dtype.element_ty + mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size), + ) + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits_split_N( + split_idx: int, + num_tokens: int, + hidden_size: int, + vocab_size: int, + vocab_per_split: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + reduction: int, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: 
tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + USE_TMA: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_offs_am = pid_m * BLOCK_SIZE_M + offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) + start_offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0) + accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: + d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0) + labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + if USE_TMA: + # using TMA and device-side descriptor creation + hidden_desc = tl.make_tensor_descriptor( + hidden_ptr, + shape=[num_tokens, hidden_size], + strides=[stride_hidden_m, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + weight_desc = tl.make_tensor_descriptor( + weight_ptr, + shape=[vocab_size, hidden_size], + strides=[stride_weight_n, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + else: + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) + + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + if USE_TMA: + start_offs_k = k * BLOCK_SIZE_K + _hidden = hidden_desc.load([start_offs_am, start_offs_k]) + _weight = weight_desc.load([start_offs_bn, start_offs_k]) + else: + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound), + other=0.0, + ) + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + logits = tl.dot(_hidden, _weight.T, logits) + + logits *= rcp_temperature + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + d_logits *= rcp_temperature + + # filter d_logits with mask + result_offs_n 
= pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split) + + tl.store( + d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask + ) + + +def efficient_entropy_backward( + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + acc: torch.Tensor, + entropy_b: torch.Tensor, + reduction: typing.Optional[int] = 2, + should_return_fp32_grad: bool = False, + temperature: typing.Optional[float] = 1.0, + dist_process_group: typing.Optional[dist.ProcessGroup] = None, +) -> list[torch.Tensor]: + """ + backward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + assert dlogprobs.shape == (num_tokens,) + else: + assert dlogprobs.dim() == 0 + + assert dlogprobs.is_contiguous() and dentropy.is_contiguous() + assert dlogprobs.is_cuda and dentropy.is_cuda + assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device + assert dentropy.shape == (num_tokens,) + + d_hidden, d_weight = None, None + if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad: + d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) + d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) + else: + d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device) + d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) + assert d_hidden.is_contiguous() and d_weight.is_contiguous() + + assert maximum.is_contiguous() and acc.is_contiguous() + assert maximum.device == hidden.device and acc.device == hidden.device + assert maximum.shape == labels.shape == acc.shape + assert maximum.is_cuda and acc.is_cuda + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + assert entropy_b.is_contiguous() and entropy_b.is_cuda + assert entropy_b.shape == (num_tokens,) + + if _config._backward == BackwardEnum._Total_Fuse_MN: + # --- Triton doesn't materialize d_logits at all. Split tiles at the perspective of d_logits. 
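+        # Illustrative sketch (comments only, nothing here is executed) of the per-tile
+        # math the fused kernel below performs, writing softmax = exp(logits - maximum)
+        # / accumulate and onehot = (global_vocab_id == labels); temperature scaling is
+        # omitted for brevity:
+        #   d_logits  = d_logprobs * (onehot - softmax)
+        #               - d_entropy * softmax * (logits - entropy_b)
+        #   d_hidden += d_logits @ weight_tile     # accumulated via tl.atomic_add
+        #   d_weight += d_logits.T @ hidden_tile   # accumulated via tl.atomic_add
+        # so the full [num_tokens, vocab_size] d_logits tensor never reaches HBM.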
+ def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + d_hidden, + d_hidden.stride(0), + d_hidden.stride(1), + d_weight, + d_weight.stride(0), + d_weight.stride(1), + 1.0 / temperature, + USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, + ) + + elif _config._backward == BackwardEnum._Total_Separate: + _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() + + if _config._use_triton: + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, + ) + + torch.matmul(_d_logits, weight, out=d_hidden) + torch.matmul(_d_logits.T, hidden, out=d_weight) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + elif _config._backward == BackwardEnum._Split_Dlogits_N: + vocab_per_split = 9504 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]),) + + for split_idx in range(num_splits): + efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( + split_idx, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, + ) + + if split_idx == (num_splits - 1): + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split + _d_logits = _d_logits[:, :vocab_right_bound].contiguous() + + if split_idx == 0: + torch.matmul( + _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden + ) + else: + d_hidden += 
torch.matmul( + _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] + ) + torch.matmul( + _d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] + ) + + elif _config._backward == BackwardEnum._Split_Dlogits_M: + raise NotImplementedError("BackwardEnum._Split_Dlogits_M is not implemented yet") + + return d_hidden, d_weight diff --git a/areal/utils/kernel/linear_cross_entropy.py b/areal/utils/kernel/linear_cross_entropy.py new file mode 100644 index 0000000000..ab0ae8dbe3 --- /dev/null +++ b/areal/utils/kernel/linear_cross_entropy.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +``LinearCrossEntropy`` autograd Function for AReaL. + +This module exposes a drop-in replacement for the +``logits = hidden @ weight.T`` -> ``log_softmax`` -> per-token +log-probability and entropy pipeline. Internally it dispatches to a Triton +kernel that fuses the matmul with the cross-entropy +reduction so that the ``[num_tokens, vocab_size]`` logits tensor is never +materialized. +""" + +from __future__ import annotations + +import os +import typing + +import torch +import torch.distributed as dist + +from areal.utils import logging + +logger = logging.getLogger("LinearCrossEntropy") + + +def _debug_enabled() -> bool: + """Whether to emit per-tensor debug summaries. + + Toggled via ``AREAL_LCE_DEBUG=1`` to avoid GPU-CPU sync in hot paths. + """ + return os.environ.get("AREAL_LCE_DEBUG", "0") == "1" + + +def _summarize(name: str, tensor: torch.Tensor) -> None: + """Emit a tiny statistical summary for diff-driven debugging. + + Triggers a CPU-GPU sync via ``.item()``; only call when + ``_debug_enabled()`` is true. + """ + if not tensor.is_floating_point(): + logger.debug( + "[diff] %s shape=%s dtype=%s device=%s", + name, + tuple(tensor.shape), + tensor.dtype, + tensor.device, + ) + return + flat = tensor.detach().float().reshape(-1) + if flat.numel() == 0: + logger.debug("[diff] %s is empty", name) + return + logger.debug( + "[diff] %s shape=%s dtype=%s mean=%.6e std=%.6e min=%.6e max=%.6e", + name, + tuple(tensor.shape), + tensor.dtype, + flat.mean().item(), + flat.std(unbiased=False).item() if flat.numel() > 1 else 0.0, + flat.min().item(), + flat.max().item(), + ) + + +class LinearCrossEntropy(torch.autograd.Function): + """Fused linear + cross-entropy / token-entropy autograd Function. + + Forward signature: + + Args: + hidden: ``(num_tokens, hidden_size)`` or + ``(batch_size, seq_len, hidden_size)``. Must be contiguous on + CUDA. + weight: ``(vocab_size, hidden_size)`` lm-head weight. Must be + contiguous on CUDA. + labels: integer label ids; either ``(num_tokens,)`` or + ``(batch_size, seq_len)``. + temperature: softmax temperature; defaults to ``1.0``. + reduction: ``"none"`` returns per-token negative log-likelihood; + ``"sum"`` and ``"mean"`` return scalars. + dist_process_group: optional tensor-parallel group for vocab-sharded + ``weight``. ``labels`` must contain *global* vocab ids on every + rank; the kernel handles the per-rank slice internally. + + Returns: + ``(logprobs, entropy)`` where ``entropy`` has shape ``(num_tokens,)`` + and ``logprobs`` has shape ``(num_tokens,)`` for ``reduction="none"`` + or ``()`` otherwise. 
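+
+    Example (illustrative only; shapes, dtypes and values are made up)::
+
+        hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
+        weight = torch.randn(32000, 4096, dtype=torch.bfloat16, device="cuda")
+        labels = torch.randint(0, 32000, (8,), device="cuda")
+        logprobs, entropy = LinearCrossEntropy.apply(
+            hidden, weight, labels, 1.0, "none", None
+        )
+        # Per-token values follow the log_softmax(...).gather(...) reference;
+        # negate them for an NLL-style objective.
+        loss = -(logprobs.float().sum())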
+ """ + + @staticmethod + def forward( + ctx, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: typing.Optional[float] = 1.0, + reduction: typing.Optional[str] = "none", + dist_process_group: typing.Optional[dist.ProcessGroup] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if not isinstance(temperature, float): + temperature = float(temperature) + if not isinstance(reduction, str): + raise TypeError(f"reduction must be str, got {type(reduction)}") + + with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): + # Local import keeps Triton dependency lazy: tests can still + # import this module on machines without Triton. + from areal.utils.kernel import kernels + + REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) + + original_hidden_shape = hidden.shape + if hidden.dim() != 2: + hidden = hidden.reshape(-1, hidden.shape[-1]) + if labels.dim() != 1: + labels = labels.reshape(-1) + + # Triton kernels demand contiguous CUDA tensors; bail out loudly + # on misuse rather than silently materialising copies on a hot + # path. + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda, ( + "LinearCrossEntropy requires CUDA inputs" + ) + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous(), ( + "LinearCrossEntropy requires contiguous tensors" + ) + + if _debug_enabled(): + _summarize("forward.hidden", hidden) + _summarize("forward.weight", weight) + _summarize("forward.labels", labels) + logger.debug( + "[diff] forward.meta temperature=%.6f reduction=%s tp_world=%d", + temperature, + reduction, + 1 if dist_process_group is None else dist.get_world_size(dist_process_group), + ) + + ( + logprobs, + entropy, + _maximum, + _accumulate, + _entropy_b, + ) = kernels.efficient_entropy_forward( + hidden, + weight, + labels, + REDUCTION, + temperature, + dist_process_group, + ) + + if _debug_enabled(): + _summarize("forward.logprobs", logprobs) + _summarize("forward.entropy", entropy) + _summarize("forward._maximum", _maximum) + _summarize("forward._accumulate", _accumulate) + _summarize("forward._entropy_b", _entropy_b) + + ctx.save_for_backward( + hidden, weight, labels, _maximum, _accumulate, _entropy_b + ) + ctx.original_hidden_shape = original_hidden_shape + ctx.REDUCTION = REDUCTION + ctx.dist_process_group = dist_process_group + ctx.should_return_fp32_grad = False + ctx.temperature = temperature + + return logprobs, entropy + + @staticmethod + def backward( + ctx, + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + ) -> tuple: + with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): + from areal.utils.kernel import kernels + + ( + hidden, + weight, + labels, + _maximum, + _accumulate, + _entropy_b, + ) = ctx.saved_tensors + + if _debug_enabled(): + _summarize("backward.dlogprobs", dlogprobs) + _summarize("backward.dentropy", dentropy) + + d_hidden, d_weight = kernels.efficient_entropy_backward( + dlogprobs, + dentropy, + hidden, + weight, + labels, + _maximum, + _accumulate, + _entropy_b, + ctx.REDUCTION, + ctx.should_return_fp32_grad, + ctx.temperature, + ctx.dist_process_group, + ) + + # TP all-reduce on d_hidden. + # + # Why this is required: + # ``efficient_entropy_backward`` computes a *local* contribution + # ``d_hidden_local = d_logits_local @ weight_local`` where each TP + # rank holds only a vocab-shard of ``weight``. The mathematically + # correct gradient is the sum across the TP group: + # d_hidden = sum_over_tp_ranks(d_logits_local @ weight_local). 
+ # In Megatron's normal forward, the surrounding + # ``ColumnParallelLinear`` (output_layer) inserts this all-reduce + # via ``linear_with_grad_accumulation_and_async_allreduce``. The + # fused-LCE fast path monkey-patches ``output_layer.forward`` to + # return ``(hidden, None)`` (an autograd identity), which bypasses + # mcore's machinery — so the all-reduce vanishes unless we + # reproduce it here. + # + # Without this reduction, TP > 1 silently produces gradients that + # equal each rank's local partial, leading to incorrect training + # that is *not* caught by any forward-only invariant since the + # forward kernel already all-reduces (max / logsumexp / entropy + # auxiliaries) inside ``efficient_entropy_forward``. + # + # ``d_weight`` does NOT need an all-reduce: each rank legitimately + # owns its vocab slice's weights, so the gradient on the local + # weight shard is correctly local-only — exactly mirroring how + # mcore handles ColumnParallel weight grads. + if ( + ctx.dist_process_group is not None + and dist.get_world_size(ctx.dist_process_group) > 1 + ): + dist.all_reduce( + d_hidden, + op=dist.ReduceOp.SUM, + group=ctx.dist_process_group, + ) + + d_hidden = d_hidden.view(ctx.original_hidden_shape) + + if _debug_enabled(): + _summarize("backward.d_hidden", d_hidden) + _summarize("backward.d_weight", d_weight) + + # Order matches forward: hidden, weight, labels, temperature, reduction, group + return d_hidden, d_weight, None, None, None, None + + +def linear_cross_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float = 1.0, + reduction: str = "none", + dist_process_group: typing.Optional[dist.ProcessGroup] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Functional wrapper around :class:`LinearCrossEntropy`. + + Returns ``(logprobs, entropy)`` with shapes following ``reduction`` + semantics. See the class docstring for full argument descriptions. + """ + return LinearCrossEntropy.apply( + hidden, weight, labels, temperature, reduction, dist_process_group + ) diff --git a/benchmark/bench_linear_cross_entropy.py b/benchmark/bench_linear_cross_entropy.py new file mode 100644 index 0000000000..83d9dcb280 --- /dev/null +++ b/benchmark/bench_linear_cross_entropy.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Standalone benchmark for the fused linear-cross-entropy kernel. + +Designed to be run *outside* pytest so that NVIDIA Nsight Systems +(``nsys profile``) can capture a clean, deterministic trace covering both +the materialised reference path and the fused Triton path. + +NVTX ranges are emitted around each phase so the resulting ``.nsys-rep`` +file can be filtered down to just the linear-CE kernels in the Nsight UI. + +Usage:: + + # Plain run (sanity) + python -m benchmark.bench_linear_cross_entropy --tokens 4096 --vocab 152064 + + # Profile with Nsight Systems + nsys profile -t nvtx,cuda,cudnn,cublas \\ + -o lce_profile --capture-range cudaProfilerApi --capture-range-end stop \\ + python -m benchmark.bench_linear_cross_entropy \\ + --tokens 4096 --vocab 152064 --use-cuda-profiler-api + +See ``docs/perf/nsight_linear_cross_entropy.md`` for a full Nsight workflow. 
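+
+The same two paths can also be driven programmatically, e.g. for a quick REPL
+sanity check (illustrative sketch reusing this module's own helpers)::
+
+    import torch
+    from benchmark.bench_linear_cross_entropy import _fused_step, _make_inputs, _ref_step
+
+    hidden, weight, labels = _make_inputs(1024, 4096, 32000, torch.bfloat16)
+    ref_dh, ref_dw = _ref_step(hidden, weight, labels)        # materialised reference grads
+    fused_dh, fused_dw = _fused_step(hidden, weight, labels)  # fused Triton path grads
+    print((fused_dh.float() - ref_dh.float()).abs().max())    # eyeball the gradient gap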
+""" + +from __future__ import annotations + +import argparse +import gc +import math +import sys + +import torch + + +def _make_inputs(num_tokens, hidden_size, vocab_size, dtype, seed=0): + g = torch.Generator(device="cuda").manual_seed(seed) + hidden = ( + torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda", generator=g) + * 0.02 + ) + weight = ( + torch.randn(vocab_size, hidden_size, dtype=dtype, device="cuda", generator=g) + * 0.02 + ) + labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda", generator=g) + return hidden.contiguous(), weight.contiguous(), labels.contiguous() + + +def _ref_step(hidden, weight, labels, temperature=1.0): + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + logits = (h.float() @ w.float().t()) / temperature + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ent = -(probs * log_softmax).sum(dim=-1) + (lp.sum() + ent.sum()).backward() + return h.grad, w.grad + + +def _fused_step(hidden, weight, labels, temperature=1.0): + from areal.utils.kernel import linear_cross_entropy + + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + lp, ent = linear_cross_entropy(h, w, labels, temperature, "none", None) + (lp.sum() + ent.sum()).backward() + return h.grad, w.grad + + +def _measure(label, fn, hidden, weight, labels, args, warmup, iters): + nvtx = torch.cuda.nvtx + times = [] + mems = [] + + # Warmup + nvtx.range_push(f"{label}/warmup") + for _ in range(warmup): + fn(hidden, weight, labels) + gc.collect() + torch.cuda.empty_cache() + nvtx.range_pop() + + nvtx.range_push(f"{label}/measure") + for i in range(iters): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + nvtx.range_push(f"{label}/iter{i}") + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn(hidden, weight, labels) + end.record() + torch.cuda.synchronize() + nvtx.range_pop() + times.append(start.elapsed_time(end)) + mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) + nvtx.range_pop() + + return times, mems + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--tokens", type=int, default=4096) + parser.add_argument("--hidden", type=int, default=4096) + parser.add_argument("--vocab", type=int, default=152064) + parser.add_argument( + "--dtype", + choices=["bfloat16", "float16", "float32"], + default="bfloat16", + ) + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--iters", type=int, default=10) + parser.add_argument( + "--use-cuda-profiler-api", + action="store_true", + help=( + "Wrap the measurement region with cudaProfilerStart/Stop so that " + "`nsys profile --capture-range cudaProfilerApi` only records the " + "interesting region." 
+ ), + ) + parser.add_argument("--mode", choices=["both", "ref", "fused"], default="both") + args = parser.parse_args() + + if not torch.cuda.is_available(): + print("CUDA is not available; aborting.", file=sys.stderr) + sys.exit(1) + + dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[ + args.dtype + ] + hidden, weight, labels = _make_inputs(args.tokens, args.hidden, args.vocab, dtype) + print( + f"[bench] tokens={args.tokens} hidden={args.hidden} vocab={args.vocab} " + f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}" + ) + + if args.use_cuda_profiler_api: + torch.cuda.cudart().cudaProfilerStart() + + results = {} + if args.mode in ("both", "ref"): + t, m = _measure("reference", _ref_step, hidden, weight, labels, args, args.warmup, args.iters) + results["reference"] = (t, m) + if args.mode in ("both", "fused"): + t, m = _measure("fused", _fused_step, hidden, weight, labels, args, args.warmup, args.iters) + results["fused"] = (t, m) + + if args.use_cuda_profiler_api: + torch.cuda.cudart().cudaProfilerStop() + + for name, (t, m) in results.items(): + median = sorted(t)[len(t) // 2] + peak = max(m) + print(f"[bench] {name:9s} median={median:7.2f}ms peak_mem={peak:8.1f}MB") + + if "reference" in results and "fused" in results: + ref_med = sorted(results["reference"][0])[len(results["reference"][0]) // 2] + fused_med = sorted(results["fused"][0])[len(results["fused"][0]) // 2] + ref_peak = max(results["reference"][1]) + fused_peak = max(results["fused"][1]) + speedup = ref_med / fused_med if fused_med > 0 else math.inf + mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf + print( + f"[bench] speedup={speedup:.2f}x fused_peak/ref_peak={mem_ratio:.2f}x" + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_linear_cross_entropy.py b/tests/test_linear_cross_entropy.py new file mode 100644 index 0000000000..b2efc0f09f --- /dev/null +++ b/tests/test_linear_cross_entropy.py @@ -0,0 +1,332 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Correctness + performance tests for the fused linear-cross-entropy kernel. + +The test suite verifies that +:func:`areal.utils.functional.linear_cross_entropy_logprobs_entropy` produces +results numerically equivalent to the materialised ``logits @ weight`` + +``log_softmax`` reference, and that it provides a measurable wall-clock / +memory benefit over the reference path on representative LLM shapes. + +The performance assertions are intentionally loose (>=1.0x runtime, i.e. +"not slower") so they remain meaningful in CI where cudagraph capture and +power-state variability can swing absolute timings; the PRINTED report is +the authoritative artifact for review. + +Run only the correctness checks (fast, single-GPU):: + + pytest tests/test_linear_cross_entropy.py -k correctness -s + +Run the full benchmark (includes large-vocab cases, slow):: + + pytest tests/test_linear_cross_entropy.py -m slow -s +""" + +from __future__ import annotations + +import gc +import math + +import pytest +import torch + +CUDA_AVAILABLE = torch.cuda.is_available() +try: + import triton # noqa: F401 + + TRITON_AVAILABLE = True +except ImportError: + TRITON_AVAILABLE = False + + +pytestmark = pytest.mark.skipif( + not (CUDA_AVAILABLE and TRITON_AVAILABLE), + reason="Fused LCE requires CUDA + Triton", +) + + +def _reference_logprobs_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor]: + """Materialised-logits reference. 
Same math, no fusion.""" + logits = (hidden.float() @ weight.float().t()) / temperature + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + logprobs = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + entropy = -(probs * log_softmax).sum(dim=-1) + return logprobs, entropy + + +def _make_inputs( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, + device: str = "cuda", + seed: int = 0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + g = torch.Generator(device=device).manual_seed(seed) + hidden = ( + torch.randn( + num_tokens, hidden_size, dtype=dtype, device=device, generator=g + ) + * 0.02 + ) + weight = ( + torch.randn( + vocab_size, hidden_size, dtype=dtype, device=device, generator=g + ) + * 0.02 + ) + labels = torch.randint( + 0, vocab_size, (num_tokens,), device=device, generator=g + ) + return hidden.contiguous(), weight.contiguous(), labels.contiguous() + + +# --------------------------------------------------------------------------- +# Correctness +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size,dtype", + [ + (256, 512, 4096, torch.float32), + (512, 1024, 32000, torch.bfloat16), + (128, 768, 8192, torch.float16), + ], +) +def test_linear_cross_entropy_correctness( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, +) -> None: + """Fused forward output must match the materialised reference.""" + from areal.utils.functional import linear_cross_entropy_logprobs_entropy + + hidden, weight, labels = _make_inputs(num_tokens, hidden_size, vocab_size, dtype) + + ref_logprobs, ref_entropy = _reference_logprobs_entropy(hidden, weight, labels) + fused_logprobs, fused_entropy = linear_cross_entropy_logprobs_entropy( + hidden, weight, labels, temperature=1.0 + ) + + # Tolerances are dtype-dependent; bf16/fp16 inputs widen them as expected. 
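+    # Rationale (heuristic): the fused kernel accumulates the matmul and the online
+    # softmax in a different order from the fp32 reference, so bf16/fp16 inputs can
+    # legitimately diverge well beyond fp32 round-off; 5e-2 leaves headroom for that
+    # without masking real bugs.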
+ if dtype == torch.float32: + rtol, atol = 1e-4, 1e-4 + else: + rtol, atol = 5e-2, 5e-2 + + torch.testing.assert_close( + fused_logprobs.float(), ref_logprobs.float(), rtol=rtol, atol=atol + ) + torch.testing.assert_close( + fused_entropy.float(), ref_entropy.float(), rtol=rtol, atol=atol + ) + + +@pytest.mark.parametrize("temperature", [0.7, 1.0, 1.5]) +def test_linear_cross_entropy_temperature(temperature: float) -> None: + """Temperature scaling matches the reference for non-trivial values.""" + from areal.utils.functional import linear_cross_entropy_logprobs_entropy + + hidden, weight, labels = _make_inputs( + num_tokens=128, hidden_size=512, vocab_size=4096, dtype=torch.float32 + ) + ref_lp, ref_h = _reference_logprobs_entropy(hidden, weight, labels, temperature) + fused_lp, fused_h = linear_cross_entropy_logprobs_entropy( + hidden, weight, labels, temperature=temperature + ) + torch.testing.assert_close(fused_lp, ref_lp, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(fused_h, ref_h, rtol=1e-4, atol=1e-4) + + +def test_linear_cross_entropy_backward_matches_reference() -> None: + """Backward gradients on hidden/weight match autograd through the reference.""" + from areal.utils.kernel import linear_cross_entropy + + num_tokens, hidden_size, vocab_size = 64, 256, 2048 + hidden_a, weight_a, labels = _make_inputs( + num_tokens, hidden_size, vocab_size, torch.float32 + ) + hidden_b = hidden_a.clone() + weight_b = weight_a.clone() + hidden_a.requires_grad_(True) + weight_a.requires_grad_(True) + hidden_b.requires_grad_(True) + weight_b.requires_grad_(True) + + # Reference path + logits = hidden_b @ weight_b.t() + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + ref_lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ref_h = -(probs * log_softmax).sum(dim=-1) + (ref_lp.sum() + 0.5 * ref_h.sum()).backward() + + # Fused path + fused_lp, fused_h = linear_cross_entropy( + hidden_a, weight_a, labels, 1.0, "none", None + ) + (fused_lp.sum() + 0.5 * fused_h.sum()).backward() + + torch.testing.assert_close( + hidden_a.grad, hidden_b.grad, rtol=5e-3, atol=5e-3 + ) + torch.testing.assert_close( + weight_a.grad, weight_b.grad, rtol=5e-3, atol=5e-3 + ) + + +# --------------------------------------------------------------------------- +# Performance benchmark +# --------------------------------------------------------------------------- + + +def _peak_memory_mb(fn, *args, **kwargs) -> tuple[float, float]: + """Return (elapsed_ms, peak_mem_mb) of a single forward+backward pass.""" + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + out = fn(*args, **kwargs) + if isinstance(out, tuple): + loss = sum(t.float().sum() for t in out if t.requires_grad or t.grad_fn is not None) + else: + loss = out.float().sum() + loss.backward() + end.record() + torch.cuda.synchronize() + elapsed = start.elapsed_time(end) + peak = torch.cuda.max_memory_allocated() / (1024 * 1024) + return elapsed, peak + + +def _run_reference_forward_backward(hidden, weight, labels, temperature): + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + logits = (h.float() @ w.float().t()) / temperature + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ent = -(probs 
* log_softmax).sum(dim=-1) + return lp, ent + + +def _run_fused_forward_backward(hidden, weight, labels, temperature): + from areal.utils.kernel import linear_cross_entropy + + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + return linear_cross_entropy(h, w, labels, temperature, "none", None) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size", + [ + # Small: validates the speedup is measurable even on toy shapes. + (1024, 1024, 32000), + # Medium: typical 7B-class one-microbatch shape. + (4096, 4096, 128256), + # Large vocab: where fused kernel really wins (e.g. Qwen3). + (2048, 4096, 152064), + ], +) +def test_linear_cross_entropy_performance_benchmark( + num_tokens: int, + hidden_size: int, + vocab_size: int, +) -> None: + """Compare fused vs materialised forward+backward time and peak memory. + + Failures here mean the fused path *regressed* against the reference; the + captured numbers are also printed for human review. + """ + dtype = torch.bfloat16 + hidden, weight, labels = _make_inputs( + num_tokens, hidden_size, vocab_size, dtype + ) + + # warm-up + for _ in range(2): + lp, ent = _run_reference_forward_backward(hidden, weight, labels, 1.0) + (lp.sum() + ent.sum()).backward() + del lp, ent + gc.collect() + torch.cuda.empty_cache() + for _ in range(2): + lp, ent = _run_fused_forward_backward(hidden, weight, labels, 1.0) + (lp.sum() + ent.sum()).backward() + del lp, ent + gc.collect() + torch.cuda.empty_cache() + + # Reference timing + ref_times = [] + ref_mems = [] + for _ in range(5): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + lp, ent = _run_reference_forward_backward(hidden, weight, labels, 1.0) + (lp.sum() + ent.sum()).backward() + end.record() + torch.cuda.synchronize() + ref_times.append(start.elapsed_time(end)) + ref_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) + del lp, ent + + # Fused timing + fused_times = [] + fused_mems = [] + for _ in range(5): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + lp, ent = _run_fused_forward_backward(hidden, weight, labels, 1.0) + (lp.sum() + ent.sum()).backward() + end.record() + torch.cuda.synchronize() + fused_times.append(start.elapsed_time(end)) + fused_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) + del lp, ent + + ref_med = sorted(ref_times)[len(ref_times) // 2] + fused_med = sorted(fused_times)[len(fused_times) // 2] + ref_peak = max(ref_mems) + fused_peak = max(fused_mems) + speedup = ref_med / fused_med if fused_med > 0 else math.inf + mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf + + print( + f"\n[LCE-Bench] tokens={num_tokens} hidden={hidden_size} vocab={vocab_size} " + f"dtype={dtype}\n" + f" reference: {ref_med:7.2f} ms / {ref_peak:7.1f} MB peak\n" + f" fused : {fused_med:7.2f} ms / {fused_peak:7.1f} MB peak\n" + f" speedup : {speedup:5.2f}x memory_ratio: {mem_ratio:5.2f}x" + ) + + # Soft assertions: fused path must not be drastically slower or more + # memory-hungry. Tight thresholds would cause flaky CI on shared GPUs. + assert fused_med < ref_med * 1.5, ( + f"Fused LCE is more than 1.5x slower than reference " + f"(fused={fused_med:.2f}ms ref={ref_med:.2f}ms). 
Please investigate." + ) + assert fused_peak < ref_peak * 1.2, ( + f"Fused LCE peak memory exceeds reference by >20% " + f"(fused={fused_peak:.1f}MB ref={ref_peak:.1f}MB)." + ) From fc5211bf9d86f81a7995c6ea68dfa616cb93cea7 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 18:02:18 +0800 Subject: [PATCH 02/31] fix(kernel): continus --- areal/utils/kernel/linear_cross_entropy.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/areal/utils/kernel/linear_cross_entropy.py b/areal/utils/kernel/linear_cross_entropy.py index ab0ae8dbe3..8b0c150332 100644 --- a/areal/utils/kernel/linear_cross_entropy.py +++ b/areal/utils/kernel/linear_cross_entropy.py @@ -192,6 +192,12 @@ def backward( _summarize("backward.dlogprobs", dlogprobs) _summarize("backward.dentropy", dentropy) + # PyTorch autograd may produce non-contiguous gradient tensors + # (e.g. expanded views from broadcast). Triton kernels require + # contiguous inputs, so ensure contiguity before dispatching. + dlogprobs = dlogprobs.contiguous() + dentropy = dentropy.contiguous() + d_hidden, d_weight = kernels.efficient_entropy_backward( dlogprobs, dentropy, From e494738ba2959d2bfb9b4d5bdba6a213586e8763 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 18:16:59 +0800 Subject: [PATCH 03/31] test(linear_cross_entropy): add test for tp > 1 --- tests/test_linear_cross_entropy.py | 244 ++++++++++++++++++++++++++++- 1 file changed, 243 insertions(+), 1 deletion(-) diff --git a/tests/test_linear_cross_entropy.py b/tests/test_linear_cross_entropy.py index b2efc0f09f..74b9b8baa5 100644 --- a/tests/test_linear_cross_entropy.py +++ b/tests/test_linear_cross_entropy.py @@ -184,7 +184,249 @@ def test_linear_cross_entropy_backward_matches_reference() -> None: # --------------------------------------------------------------------------- -# Performance benchmark +# Tensor-parallel (TP=2) correctness + performance +# --------------------------------------------------------------------------- + + +def _tp2_available() -> bool: + """Whether we can launch a 2-rank TP test on this host.""" + if not (CUDA_AVAILABLE and TRITON_AVAILABLE): + return False + if torch.cuda.device_count() < 2: + return False + return True + + +_tp2_skip = pytest.mark.skipif(not _tp2_available(), reason="TP=2 requires >= 2 CUDA GPUs") + + +def _init_tp2(): + """Initialise a 2-rank NCCL process group; return (rank, group).""" + import os + + import torch.distributed as dist + + if dist.is_initialized(): + rank = dist.get_rank() + group = dist.new_group(ranks=[0, 1], backend="nccl") + return rank, group + + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + group = dist.new_group(ranks=list(range(dist.get_world_size())), backend="nccl") + return rank, group + + +@_tp2_skip +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size,dtype", + [ + (128, 512, 8192, torch.float32), + (256, 1024, 32000, torch.bfloat16), + ], +) +def test_linear_cross_entropy_tp2_correctness( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, +) -> None: + """TP=2 fused forward+backward must match the materialised single-GPU reference.""" + import torch.distributed as dist + + from areal.utils.kernel import linear_cross_entropy + + rank, tp_group = _init_tp2() + world_size = dist.get_world_size(tp_group) + assert world_size == 2 + + 
torch.cuda.set_device(rank) + device = f"cuda:{rank}" + + vocab_per_rank = vocab_size // world_size + assert vocab_size % world_size == 0, "vocab_size must be divisible by world_size" + + g = torch.Generator(device=device).manual_seed(42) + hidden = (torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) + labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) + + weight_full = (torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) + weight_shard = weight_full[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].contiguous() + + # --- Reference (single-GPU, full weight) --- + hidden_ref = hidden.detach().clone().requires_grad_(True) + weight_ref = weight_full.detach().clone().requires_grad_(True) + logits = (hidden_ref.float() @ weight_ref.float().t()) + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + ref_lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ref_h = -(probs * log_softmax).sum(dim=-1) + (ref_lp.sum() + 0.5 * ref_h.sum()).backward() + + # --- Fused TP=2 --- + hidden_fused = hidden.detach().clone().requires_grad_(True) + weight_fused = weight_shard.detach().clone().requires_grad_(True) + fused_lp, fused_h = linear_cross_entropy( + hidden_fused, weight_fused, labels, 1.0, "none", tp_group + ) + (fused_lp.sum() + 0.5 * fused_h.sum()).backward() + + if dtype == torch.float32: + rtol, atol = 1e-3, 1e-3 + else: + rtol, atol = 5e-2, 5e-2 + + torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_h.float(), ref_h.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(hidden_fused.grad.float(), hidden_ref.grad.float(), rtol=rtol, atol=atol) + + # d_weight is per-rank (vocab shard), compare only the owned shard + torch.testing.assert_close( + weight_fused.grad.float(), + weight_ref.grad[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].float(), + rtol=rtol, + atol=atol, + ) + + dist.destroy_process_group() + + +@_tp2_skip +@pytest.mark.slow +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size", + [ + (1024, 1024, 32000), + (2048, 4096, 152064), + ], +) +def test_linear_cross_entropy_tp2_performance_benchmark( + num_tokens: int, + hidden_size: int, + vocab_size: int, +) -> None: + """TP=2 fused vs materialised forward+backward time and peak memory.""" + import torch.distributed as dist + + from areal.utils.kernel import linear_cross_entropy + + rank, tp_group = _init_tp2() + world_size = dist.get_world_size(tp_group) + assert world_size == 2 + + torch.cuda.set_device(rank) + device = f"cuda:{rank}" + dtype = torch.bfloat16 + + vocab_per_rank = vocab_size // world_size + assert vocab_size % world_size == 0 + + g = torch.Generator(device=device).manual_seed(0) + hidden = (torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) + labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) + weight_full = (torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) + weight_shard = weight_full[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].contiguous() + + # --- warm-up --- + for _ in range(2): + h = hidden.detach().clone().requires_grad_(True) + w = weight_shard.detach().clone().requires_grad_(True) + lp, ent = linear_cross_entropy(h, w, labels, 1.0, "none", tp_group) + (lp.sum() + ent.sum()).backward() + del lp, ent, h, w + gc.collect() + torch.cuda.empty_cache() + + for _ in 
range(2): + h = hidden.detach().clone().requires_grad_(True) + w = weight_full.detach().clone().requires_grad_(True) + logits = (h.float() @ w.float().t()) + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ent = -(probs * log_softmax).sum(dim=-1) + (lp.sum() + ent.sum()).backward() + del lp, ent, h, w + gc.collect() + torch.cuda.empty_cache() + + # --- Fused TP=2 timing --- + fused_times = [] + fused_mems = [] + for _ in range(5): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + h = hidden.detach().clone().requires_grad_(True) + w = weight_shard.detach().clone().requires_grad_(True) + lp, ent = linear_cross_entropy(h, w, labels, 1.0, "none", tp_group) + (lp.sum() + ent.sum()).backward() + end.record() + torch.cuda.synchronize() + fused_times.append(start.elapsed_time(end)) + fused_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) + del lp, ent, h, w + + # --- Reference (single-GPU, full weight) timing --- + ref_times = [] + ref_mems = [] + for _ in range(5): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + h = hidden.detach().clone().requires_grad_(True) + w = weight_full.detach().clone().requires_grad_(True) + logits = (h.float() @ w.float().t()) + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ent = -(probs * log_softmax).sum(dim=-1) + (lp.sum() + ent.sum()).backward() + end.record() + torch.cuda.synchronize() + ref_times.append(start.elapsed_time(end)) + ref_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) + del lp, ent, h, w + + ref_med = sorted(ref_times)[len(ref_times) // 2] + fused_med = sorted(fused_times)[len(fused_times) // 2] + ref_peak = max(ref_mems) + fused_peak = max(fused_mems) + speedup = ref_med / fused_med if fused_med > 0 else math.inf + mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf + + print( + f"\n[LCE-TP2-Bench rank={rank}] tokens={num_tokens} hidden={hidden_size} vocab={vocab_size} " + f"dtype={dtype}\n" + f" reference: {ref_med:7.2f} ms / {ref_peak:7.1f} MB peak\n" + f" fused : {fused_med:7.2f} ms / {fused_peak:7.1f} MB peak\n" + f" speedup : {speedup:5.2f}x memory_ratio: {mem_ratio:5.2f}x" + ) + + assert fused_med < ref_med * 1.5, ( + f"Fused TP=2 LCE is more than 1.5x slower than reference " + f"(fused={fused_med:.2f}ms ref={ref_med:.2f}ms)." + ) + assert fused_peak < ref_peak * 1.2, ( + f"Fused TP=2 LCE peak memory exceeds reference by >20% " + f"(fused={fused_peak:.1f}MB ref={ref_peak:.1f}MB)." 
+ ) + + dist.destroy_process_group() + + +# --------------------------------------------------------------------------- +# Performance benchmark (single-GPU) # --------------------------------------------------------------------------- From 9495141a9f2792b9c3125f7f0092220f115ac984 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 18:28:10 +0800 Subject: [PATCH 04/31] test(linear_cross_entropy): log test --- tests/test_linear_cross_entropy.py | 83 ++++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 15 deletions(-) diff --git a/tests/test_linear_cross_entropy.py b/tests/test_linear_cross_entropy.py index 74b9b8baa5..6b7e406cf0 100644 --- a/tests/test_linear_cross_entropy.py +++ b/tests/test_linear_cross_entropy.py @@ -200,24 +200,50 @@ def _tp2_available() -> bool: _tp2_skip = pytest.mark.skipif(not _tp2_available(), reason="TP=2 requires >= 2 CUDA GPUs") +import sys + + +def _log(msg: str) -> None: + """Real-time log to stderr (unbuffered, bypasses pytest capture).""" + import os + + rank = os.environ.get("RANK", "?") + local_rank = os.environ.get("LOCAL_RANK", "?") + sys.stderr.write(f"[LCE-TP2 rank={rank} local_rank={local_rank}] {msg}\n") + sys.stderr.flush() + + def _init_tp2(): """Initialise a 2-rank NCCL process group; return (rank, group).""" import os import torch.distributed as dist + _log("Entering _init_tp2") + if dist.is_initialized(): + _log("dist already initialized, creating new subgroup") rank = dist.get_rank() - group = dist.new_group(ranks=[0, 1], backend="nccl") + world_size = dist.get_world_size() + _log(f"rank={rank} world_size={world_size}") + group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + _log(f"subgroup created, group={group}") return rank, group - os.environ.setdefault("MASTER_ADDR", "127.0.0.1") - os.environ.setdefault("MASTER_PORT", "29500") - os.environ.setdefault("RANK", "0") - os.environ.setdefault("WORLD_SIZE", "1") + _log("dist NOT initialized, calling init_process_group") + master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") + master_port = os.environ.get("MASTER_PORT", "29500") + _log(f"MASTER_ADDR={master_addr} MASTER_PORT={master_port}") + _log(f"RANK={os.environ.get('RANK', '?')} WORLD_SIZE={os.environ.get('WORLD_SIZE', '?')} " + f"LOCAL_RANK={os.environ.get('LOCAL_RANK', '?')} LOCAL_WORLD_SIZE={os.environ.get('LOCAL_WORLD_SIZE', '?')}") + dist.init_process_group(backend="nccl") rank = dist.get_rank() - group = dist.new_group(ranks=list(range(dist.get_world_size())), backend="nccl") + world_size = dist.get_world_size() + _log(f"init_process_group done, rank={rank} world_size={world_size}") + + group = dist.new_group(ranks=list(range(world_size)), backend="nccl") + _log(f"subgroup created") return rank, group @@ -240,9 +266,12 @@ def test_linear_cross_entropy_tp2_correctness( from areal.utils.kernel import linear_cross_entropy + _log(f"test start: tokens={num_tokens} hidden={hidden_size} vocab={vocab_size} dtype={dtype}") + rank, tp_group = _init_tp2() world_size = dist.get_world_size(tp_group) assert world_size == 2 + _log(f"init done: rank={rank} world_size={world_size}") torch.cuda.set_device(rank) device = f"cuda:{rank}" @@ -250,14 +279,16 @@ def test_linear_cross_entropy_tp2_correctness( vocab_per_rank = vocab_size // world_size assert vocab_size % world_size == 0, "vocab_size must be divisible by world_size" + _log("Creating inputs...") g = torch.Generator(device=device).manual_seed(42) hidden = (torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) labels = 
torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) - weight_full = (torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) weight_shard = weight_full[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].contiguous() + _log(f"Inputs ready: hidden={hidden.shape} weight_shard={weight_shard.shape} labels={labels.shape}") # --- Reference (single-GPU, full weight) --- + _log("Running reference path...") hidden_ref = hidden.detach().clone().requires_grad_(True) weight_ref = weight_full.detach().clone().requires_grad_(True) logits = (hidden_ref.float() @ weight_ref.float().t()) @@ -266,33 +297,44 @@ def test_linear_cross_entropy_tp2_correctness( probs = log_softmax.exp() ref_h = -(probs * log_softmax).sum(dim=-1) (ref_lp.sum() + 0.5 * ref_h.sum()).backward() + _log(f"Reference done: ref_lp={ref_lp.shape} ref_h={ref_h.shape}") # --- Fused TP=2 --- + _log("Running fused TP=2 path...") hidden_fused = hidden.detach().clone().requires_grad_(True) weight_fused = weight_shard.detach().clone().requires_grad_(True) + _log(f"Calling linear_cross_entropy with tp_group={tp_group}...") fused_lp, fused_h = linear_cross_entropy( hidden_fused, weight_fused, labels, 1.0, "none", tp_group ) + _log(f"Fused forward done: fused_lp={fused_lp.shape} fused_h={fused_h.shape}") + _log("Running fused backward...") (fused_lp.sum() + 0.5 * fused_h.sum()).backward() + _log("Fused backward done") if dtype == torch.float32: rtol, atol = 1e-3, 1e-3 else: rtol, atol = 5e-2, 5e-2 + _log("Asserting logprobs...") torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) + _log("Asserting entropy...") torch.testing.assert_close(fused_h.float(), ref_h.float(), rtol=rtol, atol=atol) + _log("Asserting d_hidden...") torch.testing.assert_close(hidden_fused.grad.float(), hidden_ref.grad.float(), rtol=rtol, atol=atol) - - # d_weight is per-rank (vocab shard), compare only the owned shard + _log("Asserting d_weight...") torch.testing.assert_close( weight_fused.grad.float(), weight_ref.grad[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].float(), rtol=rtol, atol=atol, ) + _log("All assertions passed!") - dist.destroy_process_group() + _log("Calling dist.barrier before cleanup...") + dist.barrier(tp_group) + _log("Test complete, NOT destroying process group (kept for subsequent tests)") @_tp2_skip @@ -317,6 +359,7 @@ def test_linear_cross_entropy_tp2_performance_benchmark( rank, tp_group = _init_tp2() world_size = dist.get_world_size(tp_group) assert world_size == 2 + _log(f"perf bench init done: rank={rank} world_size={world_size}") torch.cuda.set_device(rank) device = f"cuda:{rank}" @@ -325,14 +368,17 @@ def test_linear_cross_entropy_tp2_performance_benchmark( vocab_per_rank = vocab_size // world_size assert vocab_size % world_size == 0 + _log("Creating inputs...") g = torch.Generator(device=device).manual_seed(0) hidden = (torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) weight_full = (torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) weight_shard = weight_full[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].contiguous() + _log(f"Inputs ready: hidden={hidden.shape} weight_shard={weight_shard.shape}") # --- warm-up --- - for _ in range(2): + _log("Warm-up fused...") + for i in range(2): h = hidden.detach().clone().requires_grad_(True) w = weight_shard.detach().clone().requires_grad_(True) lp, ent = 
linear_cross_entropy(h, w, labels, 1.0, "none", tp_group) @@ -340,8 +386,10 @@ def test_linear_cross_entropy_tp2_performance_benchmark( del lp, ent, h, w gc.collect() torch.cuda.empty_cache() + _log("Warm-up fused done") - for _ in range(2): + _log("Warm-up reference...") + for i in range(2): h = hidden.detach().clone().requires_grad_(True) w = weight_full.detach().clone().requires_grad_(True) logits = (h.float() @ w.float().t()) @@ -353,11 +401,13 @@ def test_linear_cross_entropy_tp2_performance_benchmark( del lp, ent, h, w gc.collect() torch.cuda.empty_cache() + _log("Warm-up reference done") # --- Fused TP=2 timing --- + _log("Fused TP=2 timing (5 iters)...") fused_times = [] fused_mems = [] - for _ in range(5): + for i in range(5): torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -373,11 +423,13 @@ def test_linear_cross_entropy_tp2_performance_benchmark( fused_times.append(start.elapsed_time(end)) fused_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) del lp, ent, h, w + _log(f"Fused timing done: median={sorted(fused_times)[len(fused_times)//2]:.2f}ms") # --- Reference (single-GPU, full weight) timing --- + _log("Reference timing (5 iters)...") ref_times = [] ref_mems = [] - for _ in range(5): + for i in range(5): torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -397,6 +449,7 @@ def test_linear_cross_entropy_tp2_performance_benchmark( ref_times.append(start.elapsed_time(end)) ref_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) del lp, ent, h, w + _log(f"Reference timing done: median={sorted(ref_times)[len(ref_times)//2]:.2f}ms") ref_med = sorted(ref_times)[len(ref_times) // 2] fused_med = sorted(fused_times)[len(fused_times) // 2] @@ -422,7 +475,7 @@ def test_linear_cross_entropy_tp2_performance_benchmark( f"(fused={fused_peak:.1f}MB ref={ref_peak:.1f}MB)." 
) - dist.destroy_process_group() + _log("TP2 perf bench complete, NOT destroying process group") # --------------------------------------------------------------------------- From 77ea28070443dbfe9514b0bec04163a35a27967d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 18:33:47 +0800 Subject: [PATCH 05/31] refactor(test): remove useless test code --- tests/test_linear_cross_entropy.py | 68 ++---------------------------- 1 file changed, 4 insertions(+), 64 deletions(-) diff --git a/tests/test_linear_cross_entropy.py b/tests/test_linear_cross_entropy.py index 6b7e406cf0..fdc63cdd6f 100644 --- a/tests/test_linear_cross_entropy.py +++ b/tests/test_linear_cross_entropy.py @@ -200,50 +200,22 @@ def _tp2_available() -> bool: _tp2_skip = pytest.mark.skipif(not _tp2_available(), reason="TP=2 requires >= 2 CUDA GPUs") -import sys - - -def _log(msg: str) -> None: - """Real-time log to stderr (unbuffered, bypasses pytest capture).""" - import os - - rank = os.environ.get("RANK", "?") - local_rank = os.environ.get("LOCAL_RANK", "?") - sys.stderr.write(f"[LCE-TP2 rank={rank} local_rank={local_rank}] {msg}\n") - sys.stderr.flush() - - def _init_tp2(): """Initialise a 2-rank NCCL process group; return (rank, group).""" import os import torch.distributed as dist - _log("Entering _init_tp2") - if dist.is_initialized(): - _log("dist already initialized, creating new subgroup") rank = dist.get_rank() world_size = dist.get_world_size() - _log(f"rank={rank} world_size={world_size}") group = dist.new_group(ranks=list(range(world_size)), backend="nccl") - _log(f"subgroup created, group={group}") return rank, group - _log("dist NOT initialized, calling init_process_group") - master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") - master_port = os.environ.get("MASTER_PORT", "29500") - _log(f"MASTER_ADDR={master_addr} MASTER_PORT={master_port}") - _log(f"RANK={os.environ.get('RANK', '?')} WORLD_SIZE={os.environ.get('WORLD_SIZE', '?')} " - f"LOCAL_RANK={os.environ.get('LOCAL_RANK', '?')} LOCAL_WORLD_SIZE={os.environ.get('LOCAL_WORLD_SIZE', '?')}") - dist.init_process_group(backend="nccl") rank = dist.get_rank() world_size = dist.get_world_size() - _log(f"init_process_group done, rank={rank} world_size={world_size}") - group = dist.new_group(ranks=list(range(world_size)), backend="nccl") - _log(f"subgroup created") return rank, group @@ -266,12 +238,9 @@ def test_linear_cross_entropy_tp2_correctness( from areal.utils.kernel import linear_cross_entropy - _log(f"test start: tokens={num_tokens} hidden={hidden_size} vocab={vocab_size} dtype={dtype}") - rank, tp_group = _init_tp2() world_size = dist.get_world_size(tp_group) assert world_size == 2 - _log(f"init done: rank={rank} world_size={world_size}") torch.cuda.set_device(rank) device = f"cuda:{rank}" @@ -279,16 +248,13 @@ def test_linear_cross_entropy_tp2_correctness( vocab_per_rank = vocab_size // world_size assert vocab_size % world_size == 0, "vocab_size must be divisible by world_size" - _log("Creating inputs...") g = torch.Generator(device=device).manual_seed(42) hidden = (torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) weight_full = (torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) weight_shard = weight_full[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].contiguous() - _log(f"Inputs ready: hidden={hidden.shape} weight_shard={weight_shard.shape} labels={labels.shape}") # --- Reference 
(single-GPU, full weight) --- - _log("Running reference path...") hidden_ref = hidden.detach().clone().requires_grad_(True) weight_ref = weight_full.detach().clone().requires_grad_(True) logits = (hidden_ref.float() @ weight_ref.float().t()) @@ -297,44 +263,31 @@ def test_linear_cross_entropy_tp2_correctness( probs = log_softmax.exp() ref_h = -(probs * log_softmax).sum(dim=-1) (ref_lp.sum() + 0.5 * ref_h.sum()).backward() - _log(f"Reference done: ref_lp={ref_lp.shape} ref_h={ref_h.shape}") # --- Fused TP=2 --- - _log("Running fused TP=2 path...") hidden_fused = hidden.detach().clone().requires_grad_(True) weight_fused = weight_shard.detach().clone().requires_grad_(True) - _log(f"Calling linear_cross_entropy with tp_group={tp_group}...") fused_lp, fused_h = linear_cross_entropy( hidden_fused, weight_fused, labels, 1.0, "none", tp_group ) - _log(f"Fused forward done: fused_lp={fused_lp.shape} fused_h={fused_h.shape}") - _log("Running fused backward...") (fused_lp.sum() + 0.5 * fused_h.sum()).backward() - _log("Fused backward done") if dtype == torch.float32: rtol, atol = 1e-3, 1e-3 else: rtol, atol = 5e-2, 5e-2 - _log("Asserting logprobs...") torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) - _log("Asserting entropy...") torch.testing.assert_close(fused_h.float(), ref_h.float(), rtol=rtol, atol=atol) - _log("Asserting d_hidden...") torch.testing.assert_close(hidden_fused.grad.float(), hidden_ref.grad.float(), rtol=rtol, atol=atol) - _log("Asserting d_weight...") torch.testing.assert_close( weight_fused.grad.float(), weight_ref.grad[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].float(), rtol=rtol, atol=atol, ) - _log("All assertions passed!") - _log("Calling dist.barrier before cleanup...") dist.barrier(tp_group) - _log("Test complete, NOT destroying process group (kept for subsequent tests)") @_tp2_skip @@ -359,7 +312,6 @@ def test_linear_cross_entropy_tp2_performance_benchmark( rank, tp_group = _init_tp2() world_size = dist.get_world_size(tp_group) assert world_size == 2 - _log(f"perf bench init done: rank={rank} world_size={world_size}") torch.cuda.set_device(rank) device = f"cuda:{rank}" @@ -368,17 +320,14 @@ def test_linear_cross_entropy_tp2_performance_benchmark( vocab_per_rank = vocab_size // world_size assert vocab_size % world_size == 0 - _log("Creating inputs...") g = torch.Generator(device=device).manual_seed(0) hidden = (torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) weight_full = (torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) weight_shard = weight_full[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].contiguous() - _log(f"Inputs ready: hidden={hidden.shape} weight_shard={weight_shard.shape}") # --- warm-up --- - _log("Warm-up fused...") - for i in range(2): + for _ in range(2): h = hidden.detach().clone().requires_grad_(True) w = weight_shard.detach().clone().requires_grad_(True) lp, ent = linear_cross_entropy(h, w, labels, 1.0, "none", tp_group) @@ -386,10 +335,8 @@ def test_linear_cross_entropy_tp2_performance_benchmark( del lp, ent, h, w gc.collect() torch.cuda.empty_cache() - _log("Warm-up fused done") - _log("Warm-up reference...") - for i in range(2): + for _ in range(2): h = hidden.detach().clone().requires_grad_(True) w = weight_full.detach().clone().requires_grad_(True) logits = (h.float() @ w.float().t()) @@ -401,13 +348,11 @@ def 
test_linear_cross_entropy_tp2_performance_benchmark( del lp, ent, h, w gc.collect() torch.cuda.empty_cache() - _log("Warm-up reference done") # --- Fused TP=2 timing --- - _log("Fused TP=2 timing (5 iters)...") fused_times = [] fused_mems = [] - for i in range(5): + for _ in range(5): torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -423,13 +368,11 @@ def test_linear_cross_entropy_tp2_performance_benchmark( fused_times.append(start.elapsed_time(end)) fused_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) del lp, ent, h, w - _log(f"Fused timing done: median={sorted(fused_times)[len(fused_times)//2]:.2f}ms") # --- Reference (single-GPU, full weight) timing --- - _log("Reference timing (5 iters)...") ref_times = [] ref_mems = [] - for i in range(5): + for _ in range(5): torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -449,7 +392,6 @@ def test_linear_cross_entropy_tp2_performance_benchmark( ref_times.append(start.elapsed_time(end)) ref_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) del lp, ent, h, w - _log(f"Reference timing done: median={sorted(ref_times)[len(ref_times)//2]:.2f}ms") ref_med = sorted(ref_times)[len(ref_times) // 2] fused_med = sorted(fused_times)[len(fused_times) // 2] @@ -475,8 +417,6 @@ def test_linear_cross_entropy_tp2_performance_benchmark( f"(fused={fused_peak:.1f}MB ref={ref_peak:.1f}MB)." ) - _log("TP2 perf bench complete, NOT destroying process group") - # --------------------------------------------------------------------------- # Performance benchmark (single-GPU) From 47fd3e26f859c45cee063f7dcbaab497115d08e9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 19:25:17 +0800 Subject: [PATCH 06/31] perf: NVTX --- areal/engine/megatron_engine.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 44e790437c..dec172b55f 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -774,6 +774,8 @@ def forward_step(batch_iter, model): and not cp_local ) + lce_label = "fused_lce" if should_capture else "materialized_lce" + torch.cuda.nvtx.range_push(f"lce_forward/{lce_label}") with capture_lm_head_hidden( model, enabled=should_capture ) as capture: @@ -795,6 +797,8 @@ def forward_step(batch_iter, model): mb_input.orig_mb[FUSED_LCE_WEIGHT_KEY] = capture.weight mb_input.orig_mb["_fused_lce_active"] = True + torch.cuda.nvtx.range_pop() + # Release tree attention metadata after forward pass for key in tree_attn_keys: del mb_input.padded_mb[key] @@ -1878,6 +1882,7 @@ def _compute_logprobs_and_loss( and fused_hidden is not None and fused_weight is not None ): + torch.cuda.nvtx.range_push("lce_loss/fused_lce") logprobs, entropy = linear_cross_entropy_logprobs_entropy( fused_hidden, fused_weight, @@ -1895,7 +1900,9 @@ def _compute_logprobs_and_loss( proxy = logprobs.detach().float() vocab_min_logits = proxy vocab_max_logits = proxy + torch.cuda.nvtx.range_pop() else: + torch.cuda.nvtx.range_push("lce_loss/materialized_lce") logprobs, entropy = gather_logprobs_entropy( output, labels, @@ -1906,6 +1913,7 @@ def _compute_logprobs_and_loss( ) vocab_min_logits = output.detach().min(-1).values.float() vocab_max_logits = output.detach().max(-1).values.float() + torch.cuda.nvtx.range_pop() loss = loss_fn( logprobs, entropy, @@ -1950,6 +1958,7 @@ def _compute_forward_result( and fused_hidden is not None and fused_weight is not None ): + 
torch.cuda.nvtx.range_push("lce_forward_result/fused_lce") logprobs = linear_cross_entropy_logprobs( fused_hidden, fused_weight, @@ -1959,7 +1968,9 @@ def _compute_forward_result( if mpu.get_tensor_model_parallel_world_size() > 1 else None, ) + torch.cuda.nvtx.range_pop() return logprobs + torch.cuda.nvtx.range_push("lce_forward_result/materialized_lce") logprobs = gather_logprobs( output, labels, @@ -1968,6 +1979,7 @@ def _compute_forward_result( if mpu.get_tensor_model_parallel_world_size() > 1 else None, ) + torch.cuda.nvtx.range_pop() return logprobs else: values = output.squeeze(-1) From 203a7e457a6257277aef240dbd3848a0c021c34d Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 19:26:56 +0800 Subject: [PATCH 07/31] feat(config): add use_fused_linear_ce config --- areal/api/cli_args.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 17e0b556ef..2085adafcc 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -461,13 +461,6 @@ def __post_init__(self): }, ) - use_fused_moe: bool = field( - default=True, - metadata={ - "help": "" - }, - ) - @dataclass class ArchonFP8Config: @@ -1144,6 +1137,14 @@ class TrainEngineConfig: default=False, metadata={"help": "Enable tree training with flex attention module."}, ) + use_fused_linear_ce: bool = field( + default=False, + metadata={ + "help": "Fuse the linear projection with cross-entropy so that the " + "[num_tokens, vocab_size] logits tensor is never materialised. " + "Only effective for the Megatron actor backend with parallel_output=True." + }, + ) # Scheduling scheduling_spec: tuple[SchedulingSpec, ...] = field( From e2e5d705537067890a8c70229726c99eab2b948e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 19:32:48 +0800 Subject: [PATCH 08/31] fix(utils): network --- areal/utils/network.py | 30 +++++++++++++++----------- examples/math/gsm8k_grpo_megatron.yaml | 2 +- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/areal/utils/network.py b/areal/utils/network.py index ca4ae264cb..2f0c36a883 100644 --- a/areal/utils/network.py +++ b/areal/utils/network.py @@ -23,6 +23,22 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: Raises: RuntimeError: If no suitable address can be determined """ + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.connect((probe_host, probe_port)) + return sock.getsockname()[0] + except OSError: + pass + + try: + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: + sock.connect(("2001:4860:4860::8888", probe_port)) + ip6 = sock.getsockname()[0] + if ip6 and ip6 != "::1": + return ip6 + except OSError: + pass + try: hostname = socket.gethostname() infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_DGRAM) @@ -38,19 +54,7 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: except socket.gaierror: pass - try: - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: - sock.connect((probe_host, probe_port)) - return sock.getsockname()[0] - except OSError as e: - try: - with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: - sock.connect(("2001:4860:4860::8888", probe_port)) - ip6 = sock.getsockname()[0] - if ip6 and ip6 != "::1": - return ip6 - except OSError: - raise RuntimeError("Could not determine host IP") from e + raise RuntimeError("Could not determine host IP") def get_loopback_ip() -> str: diff --git a/examples/math/gsm8k_grpo_megatron.yaml 
b/examples/math/gsm8k_grpo_megatron.yaml index 2482b297bc..0cb03c0830 100644 --- a/examples/math/gsm8k_grpo_megatron.yaml +++ b/examples/math/gsm8k_grpo_megatron.yaml @@ -43,7 +43,7 @@ actor: backend: "megatron:d4p1t1" experiment_name: ${experiment_name} trial_name: ${trial_name} - path: Qwen/Qwen2.5-1.5B-Instruct + path: /workspace/models/Qwen3-0.6B init_from_scratch: false disable_dropout: true gradient_checkpointing: false From fecfcbddee20a58ef32fd357c523d97b69ba6187 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 21:48:17 +0800 Subject: [PATCH 09/31] feat(profiling):nsys flush --- areal/engine/megatron_engine.py | 41 ++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index dec172b55f..398407cd48 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -871,6 +871,27 @@ def train_batch( self._ensure_ready() self.optimizer_zero_grad() + if not hasattr(self, "_lce_profiler"): + lce_profiler_dir = os.environ.get("AREAL_LCE_PROFILER_DIR", "") + if lce_profiler_dir: + import torch.profiler + + self._lce_profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + schedule=torch.profiler.schedule( + wait=1, warmup=1, active=1, repeat=1 + ), + on_trace_ready=torch.profiler.tensorboard_trace_handler( + lce_profiler_dir + ), + ) + self._lce_profiler.start() + input_batched, _ = self._normalize_batch_input(input_) # Step 1: Prepare micro-batches @@ -905,7 +926,12 @@ def process_output( ) # Step 4: Optimizer step - return self.optimizer_step() + result = self.optimizer_step() + + if hasattr(self, "_lce_profiler"): + self._lce_profiler.step() + + return result @torch.no_grad() def eval_batch( @@ -1308,6 +1334,19 @@ def _ensure_ready(self) -> None: if self.model is None: raise RuntimeError("Model is not initialized.") + if not hasattr(self, "_nsys_flush_registered"): + self._nsys_flush_registered = True + import signal + + def _nsys_flush_handler(signum, frame): + try: + torch.cuda.cudart().cudaProfilerStop() + except Exception: + pass + raise SystemExit(128 + signum) + + signal.signal(signal.SIGTERM, _nsys_flush_handler) + def _update_bucket_weights_from_distributed( self, meta: WeightUpdateMeta, From 0b5eefe5f8e86bc6314f032df6d34ae264b1ea7f Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 21:59:33 +0800 Subject: [PATCH 10/31] fix(engine): fix --- areal/engine/megatron_engine.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 398407cd48..2ba6e04787 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1336,16 +1336,19 @@ def _ensure_ready(self) -> None: if not hasattr(self, "_nsys_flush_registered"): self._nsys_flush_registered = True - import signal + import threading - def _nsys_flush_handler(signum, frame): - try: - torch.cuda.cudart().cudaProfilerStop() - except Exception: - pass - raise SystemExit(128 + signum) + if threading.current_thread() is threading.main_thread(): + import signal + + def _nsys_flush_handler(signum, frame): + try: + torch.cuda.cudart().cudaProfilerStop() + except Exception: + pass + raise SystemExit(128 + signum) - signal.signal(signal.SIGTERM, _nsys_flush_handler) + signal.signal(signal.SIGTERM, _nsys_flush_handler) def 
_update_bucket_weights_from_distributed( self, From e69674e8244c4699f3b0c1817c4bab78c34fd798 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Fri, 8 May 2026 22:15:59 +0800 Subject: [PATCH 11/31] feat(profiler): torch profile --- areal/engine/megatron_engine.py | 126 ++++++++++++++++++++------------ 1 file changed, 80 insertions(+), 46 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 2ba6e04787..9219fd2b1e 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -148,6 +148,33 @@ def parameters(self, *args, **kwargs) -> Iterator[nn.Parameter]: yield parameter +_LCE_KERNEL_PREFIXES = ( + "efficient_entropy_kernel", + "lce_backward", + "triton_poi", +) + + +def _print_lce_summary(prof: torch.profiler.profile, rank: int) -> None: + logger = logging.getLogger("LCEProfiler") + events = prof.key_averages() + lce_rows: list[tuple[str, float, float, float]] = [] + for evt in events: + key = evt.key + if any(p in key for p in _LCE_KERNEL_PREFIXES): + cuda_ms = evt.cuda_time_total / 1000.0 + cpu_ms = evt.cpu_time_total / 1000.0 + calls = evt.count + lce_rows.append((key, cuda_ms, cpu_ms, calls)) + if not lce_rows: + logger.info(f"[Rank {rank}] No LCE Triton kernels found in profiler trace.") + return + header = f"[Rank {rank}] LCE Kernel Profiling Summary (CUDA ms / CPU ms / calls):" + logger.info(header) + for key, cuda_ms, cpu_ms, calls in lce_rows: + logger.info(f" {key}: {cuda_ms:.3f} / {cpu_ms:.3f} / {calls}") + + class MegatronEngine(TrainEngine): def __init__(self, config: TrainEngineConfig): self.config = config @@ -774,8 +801,6 @@ def forward_step(batch_iter, model): and not cp_local ) - lce_label = "fused_lce" if should_capture else "materialized_lce" - torch.cuda.nvtx.range_push(f"lce_forward/{lce_label}") with capture_lm_head_hidden( model, enabled=should_capture ) as capture: @@ -797,8 +822,6 @@ def forward_step(batch_iter, model): mb_input.orig_mb[FUSED_LCE_WEIGHT_KEY] = capture.weight mb_input.orig_mb["_fused_lce_active"] = True - torch.cuda.nvtx.range_pop() - # Release tree attention metadata after forward pass for key in tree_attn_keys: del mb_input.padded_mb[key] @@ -871,26 +894,7 @@ def train_batch( self._ensure_ready() self.optimizer_zero_grad() - if not hasattr(self, "_lce_profiler"): - lce_profiler_dir = os.environ.get("AREAL_LCE_PROFILER_DIR", "") - if lce_profiler_dir: - import torch.profiler - - self._lce_profiler = torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - profile_memory=True, - schedule=torch.profiler.schedule( - wait=1, warmup=1, active=1, repeat=1 - ), - on_trace_ready=torch.profiler.tensorboard_trace_handler( - lce_profiler_dir - ), - ) - self._lce_profiler.start() + self._maybe_init_lce_profiler() input_batched, _ = self._normalize_batch_input(input_) @@ -928,8 +932,7 @@ def process_output( # Step 4: Optimizer step result = self.optimizer_step() - if hasattr(self, "_lce_profiler"): - self._lce_profiler.step() + self._maybe_step_lce_profiler() return result @@ -1334,21 +1337,60 @@ def _ensure_ready(self) -> None: if self.model is None: raise RuntimeError("Model is not initialized.") - if not hasattr(self, "_nsys_flush_registered"): - self._nsys_flush_registered = True - import threading + _LCE_PROFILER_KEY = "_lce_profiler" + _LCE_PROFILER_ENV = "ARENAL_LCE_PROFILER_DIR" - if threading.current_thread() is threading.main_thread(): - import signal + def _maybe_init_lce_profiler(self) -> None: 
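With the wait=2 / warmup=1 / active=3 schedule configured in this method, the first two profiler steps are skipped, the third is a warm-up, and only steps four to six are recorded, so the trace (and the kernel summary above) is emitted only after the sixth call to profiler.step(), i.e. after six train_batch() invocations. A minimal self-contained sketch of that behaviour, using a stand-in workload and illustrative names rather than the engine's real step (assumes a CUDA device):

    import torch
    import torch.profiler


    def train_step(x: torch.Tensor) -> None:
        # Stand-in for one train_batch(): any GPU work with a backward pass.
        (x @ x.t()).relu().sum().backward()


    def on_trace_ready(prof: torch.profiler.profile) -> None:
        # Same filtering idea as _print_lce_summary: keep only kernel rows
        # whose name contains a prefix of interest.
        for evt in prof.key_averages():
            if "triton" in evt.key or "gemm" in evt.key:
                print(f"{evt.key}: {evt.cuda_time_total / 1000.0:.3f} ms / {evt.count} calls")


    x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=2, warmup=1, active=3, repeat=1),
        on_trace_ready=on_trace_ready,
    ) as prof:
        for _ in range(6):  # 2 wait + 1 warmup + 3 active -> trace fires after the 6th step
            train_step(x)
            prof.step()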
+ if hasattr(self, self._LCE_PROFILER_KEY): + return + profiler_dir = os.environ.get(self._LCE_PROFILER_ENV, "") + if not profiler_dir: + setattr(self, self._LCE_PROFILER_KEY, None) + return - def _nsys_flush_handler(signum, frame): - try: - torch.cuda.cudart().cudaProfilerStop() - except Exception: - pass - raise SystemExit(128 + signum) + import torch.profiler + + rank = dist.get_rank() if dist.is_initialized() else 0 + output_dir = os.path.join(profiler_dir, f"rank_{rank}") + os.makedirs(output_dir, exist_ok=True) + + def _lce_trace_handler(prof: torch.profiler.profile) -> None: + torch.profiler.tensorboard_trace_handler(output_dir)(prof) + _print_lce_summary(prof, rank) + + setattr( + self, + self._LCE_PROFILER_KEY, + torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + record_shapes=True, + profile_memory=True, + with_stack=True, + schedule=torch.profiler.schedule( + wait=2, warmup=1, active=3, repeat=1 + ), + on_trace_ready=_lce_trace_handler, + ), + ) + getattr(self, self._LCE_PROFILER_KEY).start() + logger = logging.getLogger("LCEProfiler") + logger.info( + f"[Rank {rank}] torch.profiler started, traces will be saved to {output_dir} " + f"(schedule: wait=2, warmup=1, active=3)" + ) - signal.signal(signal.SIGTERM, _nsys_flush_handler) + def _maybe_step_lce_profiler(self) -> None: + profiler = getattr(self, self._LCE_PROFILER_KEY, None) + if profiler is None: + return + profiler.step() + if profiler.profiler is None: + logger = logging.getLogger("LCEProfiler") + rank = dist.get_rank() if dist.is_initialized() else 0 + logger.info(f"[Rank {rank}] torch.profiler finished and stopped.") def _update_bucket_weights_from_distributed( self, @@ -1924,7 +1966,6 @@ def _compute_logprobs_and_loss( and fused_hidden is not None and fused_weight is not None ): - torch.cuda.nvtx.range_push("lce_loss/fused_lce") logprobs, entropy = linear_cross_entropy_logprobs_entropy( fused_hidden, fused_weight, @@ -1942,9 +1983,7 @@ def _compute_logprobs_and_loss( proxy = logprobs.detach().float() vocab_min_logits = proxy vocab_max_logits = proxy - torch.cuda.nvtx.range_pop() else: - torch.cuda.nvtx.range_push("lce_loss/materialized_lce") logprobs, entropy = gather_logprobs_entropy( output, labels, @@ -1955,7 +1994,6 @@ def _compute_logprobs_and_loss( ) vocab_min_logits = output.detach().min(-1).values.float() vocab_max_logits = output.detach().max(-1).values.float() - torch.cuda.nvtx.range_pop() loss = loss_fn( logprobs, entropy, @@ -2000,7 +2038,6 @@ def _compute_forward_result( and fused_hidden is not None and fused_weight is not None ): - torch.cuda.nvtx.range_push("lce_forward_result/fused_lce") logprobs = linear_cross_entropy_logprobs( fused_hidden, fused_weight, @@ -2010,9 +2047,7 @@ def _compute_forward_result( if mpu.get_tensor_model_parallel_world_size() > 1 else None, ) - torch.cuda.nvtx.range_pop() return logprobs - torch.cuda.nvtx.range_push("lce_forward_result/materialized_lce") logprobs = gather_logprobs( output, labels, @@ -2021,7 +2056,6 @@ def _compute_forward_result( if mpu.get_tensor_model_parallel_world_size() > 1 else None, ) - torch.cuda.nvtx.range_pop() return logprobs else: values = output.squeeze(-1) From d444cbeb7fb06dc959e3082d4e806fb5ab274896 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 9 May 2026 00:04:59 +0800 Subject: [PATCH 12/31] fix(sequence_parallel): fix sp --- .../megatron_utils/fused_lce_capture.py | 52 +++++++++++++++++-- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git 
a/areal/engine/megatron_utils/fused_lce_capture.py b/areal/engine/megatron_utils/fused_lce_capture.py index 5886026446..e67d3e6593 100644 --- a/areal/engine/megatron_utils/fused_lce_capture.py +++ b/areal/engine/megatron_utils/fused_lce_capture.py @@ -13,7 +13,11 @@ 1. Stashes the input tensor (``hidden``) and the actual weight (either the ``output_layer``'s own weight, or the embedding-tied weight passed in via - ``weight=``). + ``weight=``). When sequence-parallel is active (TP > 1 in AReaL), the + incoming ``input_`` is scattered along seq to ``[seq/tp_size, hidden]``, + so we first call ``gather_from_sequence_parallel_region`` to restore the + full ``[seq, hidden]`` tensor — exactly mirroring the first step of + mcore's ``ColumnParallelLinear.forward``. 2. Returns ``(hidden, None)`` instead of ``(logits, bias)``. Because :func:`areal.utils.data.unpad_logits` and :func:`postprocess_packed_seqs_context_parallel` are shape-agnostic on @@ -49,6 +53,9 @@ import torch from megatron.core import parallel_state as mpu +from megatron.core.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, +) from areal.utils import logging @@ -161,21 +168,58 @@ def capture_lm_head_hidden( slot = _CaptureSlot() original_forward = output_layer.forward + # Detect sequence-parallel mode. In mcore, when TP > 1 AReaL enables + # ``sequence_parallel=True`` (see ``MegatronEngine._make_parallel_strategy``), + # which means the input handed to ``ColumnParallelLinear.forward`` is + # *scattered* along the sequence dimension to shape ``[seq/tp_size, hidden]``. + # The original ``ColumnParallelLinear.forward`` first calls + # ``gather_from_sequence_parallel_region`` to restore the full ``[seq, hidden]`` + # tensor before doing the matmul. Our identity-style patch must replicate + # that gather, otherwise: + # * the captured ``hidden`` is only this rank's sequence shard, leading + # to wrong fused-kernel inputs and wrong logprobs; + # * the tensor returned to mcore (which then flows through + # ``postprocess_packed_seqs_context_parallel`` and ``unpad_logits``) has + # dim-0 = seq/tp_size, which mismatches ``cu_seqlens`` / ``old_cu_seqlens`` + # and crashes with shape errors like "expanded size (X) must match + # existing size (X/tp_size) at non-singleton dimension 0". + config = getattr(post_process, "config", None) + sequence_parallel = bool(getattr(config, "sequence_parallel", False)) + tp_world_size = mpu.get_tensor_model_parallel_world_size() + needs_sp_gather = sequence_parallel and tp_world_size > 1 + def _patched_forward(input_, weight=None, runtime_gather_output=None): # Resolve the actual weight: either passed in (weight tying) or the # output_layer's own parameter. We intentionally store a *reference* # (not detach) so autograd flows through both the kernel forward # and backward. actual_weight = weight if weight is not None else output_layer.weight - slot.hidden = input_ + + # When sequence parallel is on, ``input_`` is shape ``[seq/tp_size, hidden]``. + # Gather along the sequence dim to obtain the full ``[seq, hidden]`` tensor + # — this is exactly what the original ``ColumnParallelLinear.forward`` + # does as its first step. ``gather_from_sequence_parallel_region`` is + # an autograd-aware op (its backward is a reduce-scatter along seq), + # so gradients flow correctly back into the SP-scattered upstream. 
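The gather-forward / reduce-scatter-backward pairing described in this comment can be pictured with a toy autograd function. The sketch below is an illustrative stand-in only, not mcore's gather_from_sequence_parallel_region, which additionally handles group plumbing, memory formats and async overlap:

    import torch
    import torch.distributed as dist


    class GatherAlongSeq(torch.autograd.Function):
        # Forward: all-gather the per-rank sequence shards along dim 0.
        # Backward: reduce-scatter the incoming gradient so each rank
        # receives the (summed) gradient of exactly its own shard.
        @staticmethod
        def forward(ctx, shard: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
            ctx.group = group
            world = dist.get_world_size(group)
            full = shard.new_empty((shard.shape[0] * world, *shard.shape[1:]))
            dist.all_gather_into_tensor(full, shard.contiguous(), group=group)
            return full

        @staticmethod
        def backward(ctx, grad_full: torch.Tensor):
            world = dist.get_world_size(ctx.group)
            grad_shard = grad_full.new_empty((grad_full.shape[0] // world, *grad_full.shape[1:]))
            dist.reduce_scatter_tensor(grad_shard, grad_full.contiguous(), group=ctx.group)
            return grad_shard, None


    # Usage on every rank of an initialised NCCL group:
    #   full_hidden = GatherAlongSeq.apply(shard_hidden, tp_group)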
+ hidden = input_ + if needs_sp_gather: + hidden = gather_from_sequence_parallel_region(hidden) + + slot.hidden = hidden slot.weight = actual_weight - # Return ``(input_, None)``: callers expect ``(logits, bias)`` and + # Return ``(hidden, None)``: callers expect ``(logits, bias)`` and # only ever destructure with ``logits, _ = output_layer(...)``. The # downstream pipeline (``unpad_logits`` etc.) is shape-agnostic on # the trailing dim, so passing ``hidden`` through is safe; the # fused kernel will then consume the stashed tensors and produce # the real per-token logprobs. - return input_, None + # + # Crucially we return the *gathered* hidden so that the leading + # sequence dim matches what mcore would have produced for the real + # logits tensor (``[seq, vocab/tp_size]``). This keeps every + # downstream shape invariant intact (CP all-gather, batch-padding + # strip, ``unpad_logits`` cu_seqlens slicing). + return hidden, None # ``output_layer.forward = _patched_forward`` replaces the bound method # at instance level (via ``__dict__`` lookup), shadowing the class From 07f03d8a6229fcd7224aaad310f9cd383e9ecd5e Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 9 May 2026 00:19:58 +0800 Subject: [PATCH 13/31] fix(engine): dtype --- .../megatron_utils/fused_lce_capture.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/areal/engine/megatron_utils/fused_lce_capture.py b/areal/engine/megatron_utils/fused_lce_capture.py index e67d3e6593..a1f3899800 100644 --- a/areal/engine/megatron_utils/fused_lce_capture.py +++ b/areal/engine/megatron_utils/fused_lce_capture.py @@ -205,6 +205,28 @@ def _patched_forward(input_, weight=None, runtime_gather_output=None): if needs_sp_gather: hidden = gather_from_sequence_parallel_region(hidden) + # Align ``hidden`` dtype to ``actual_weight`` dtype before handing the + # tensors to the fused Triton kernel. + # + # Why this is required: + # * Megatron-Core feeds ``output_layer`` with the post-final-layernorm + # activation, which is typically fp32 under mixed-precision training, + # while ``output_layer.weight`` is bf16/fp16. The original + # ``ColumnParallelLinear.forward`` silently downcasts ``input_`` to + # the weight dtype inside + # ``linear_with_grad_accumulation_and_async_allreduce``; our + # identity-style monkey-patch bypasses that path and would otherwise + # hand mismatched dtypes to ``efficient_entropy_forward``. + # * Triton's ``tl.dot`` requires both operands to share the same dtype; + # a mismatch triggers warnings such as + # "Both operands must be same dtype. Got fp32 and bf16; falling back + # to reference path." and silently disables the fused fast path. + # * ``Tensor.to(dtype)`` is autograd-aware: backward auto-upcasts + # gradients to the original dtype, so the upstream fp32 activation + # receives a fp32 grad as expected. 
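The autograd behaviour of the dtype cast mentioned above is easy to verify in isolation; a minimal sketch with illustrative shapes (assumes a CUDA device):

    import torch

    # fp32 activation, as produced after the final layernorm under mixed precision.
    hidden_fp32 = torch.randn(4, 8, device="cuda", dtype=torch.float32, requires_grad=True)
    # bf16 LM-head weight shard.
    weight_bf16 = torch.randn(16, 8, device="cuda", dtype=torch.bfloat16)

    # Downcast only for the matmul, mirroring what the patched forward does.
    logits = hidden_fp32.to(weight_bf16.dtype) @ weight_bf16.t()
    logits.float().sum().backward()

    # The cast is recorded by autograd, so the gradient arrives back in fp32.
    assert hidden_fp32.grad.dtype == torch.float32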
+ if hidden.dtype != actual_weight.dtype: + hidden = hidden.to(actual_weight.dtype) + slot.hidden = hidden slot.weight = actual_weight # Return ``(hidden, None)``: callers expect ``(logits, bias)`` and From 003296b12f89027e0efe654a444cfb71e2d3d364 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 9 May 2026 00:38:10 +0800 Subject: [PATCH 14/31] fix(engine): dtype again --- areal/engine/megatron_engine.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 9219fd2b1e..5628abb5d4 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -865,7 +865,37 @@ def _process_output(input_, output_): # state (the LM-head was monkey-patched to a no-op), # so the unpadded tensor is the hidden we want to # feed into the fused kernel. + # + # Megatron's outer ``Float16Module`` wraps the inner + # ``GPTModel`` and, on the last pipeline stage, + # *upcasts the wrapped module's outputs to fp32* + # (see ``Float16Module.forward(..., fp32_output=True)`` + # and ``float16_to_fp32`` in + # ``megatron.core.transformer.module``). The captured + # hidden was already cast to ``weight.dtype`` (bf16/fp16) + # inside ``capture_lm_head_hidden._patched_forward`` to + # mirror ``ColumnParallelLinear``'s implicit downcast, + # but ``Float16Module``'s post-hoc upcast then re-promotes + # the tensor returned to mcore back to fp32. The Triton + # GEMM in ``efficient_entropy_forward`` requires both + # operands to share the same dtype; without re-aligning + # here, the kernel raises + # "Both operands must be same dtype. Got fp32 and + # bf16; falling back to reference path." + # and silently disables the fused fast path. + # + # ``Tensor.to(dtype)`` is autograd-aware; backward will + # auto-upcast gradients to the upstream fp32 dtype, which + # is exactly what mcore would have produced anyway. if mb_input.orig_mb.get("_fused_lce_active", False): + fused_weight = mb_input.orig_mb.get( + FUSED_LCE_WEIGHT_KEY + ) + if ( + fused_weight is not None + and output.dtype != fused_weight.dtype + ): + output = output.to(fused_weight.dtype) mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output return output, functools.partial(_process_output, mb_input.orig_mb) From 0020faa34c4a6e229ded91c0792a2a9cb74f201b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sat, 9 May 2026 23:16:27 +0800 Subject: [PATCH 15/31] test(linear_cross_entropy): fix test --- tests/test_linear_cross_entropy.py | 303 ++++++++++------------------- tests/torchrun/run_lce_tp2.py | 265 +++++++++++++++++++++++++ 2 files changed, 371 insertions(+), 197 deletions(-) create mode 100644 tests/torchrun/run_lce_tp2.py diff --git a/tests/test_linear_cross_entropy.py b/tests/test_linear_cross_entropy.py index fdc63cdd6f..fdc2abd8b1 100644 --- a/tests/test_linear_cross_entropy.py +++ b/tests/test_linear_cross_entropy.py @@ -116,11 +116,17 @@ def test_linear_cross_entropy_correctness( hidden, weight, labels, temperature=1.0 ) - # Tolerances are dtype-dependent; bf16/fp16 inputs widen them as expected. + # Tolerances are dtype-dependent. The fused kernel performs the same + # matmul + log-softmax math as the reference, so fp32 inputs should agree + # to within a few ulps (~1e-5). bf16 / fp16 inputs are widened only to + # absorb the documented matmul-accumulation drift; anything looser would + # mask real numerical regressions. 
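As a quick sanity check on those per-dtype tolerance choices, the machine epsilons can be inspected directly; this is an editorial aside, not part of the test:

    import torch

    # Exact values from torch.finfo:
    #   float32  eps ~= 1.19e-07  (24-bit mantissa; 1e-5 leaves two orders of headroom)
    #   bfloat16 eps ~= 7.81e-03  (8-bit mantissa; 2e-2 absorbs a few accumulation steps)
    #   float16  eps ~= 9.77e-04  (11-bit mantissa; hence the tighter 1e-2 versus bf16)
    for dt in (torch.float32, torch.bfloat16, torch.float16):
        print(dt, torch.finfo(dt).eps)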
if dtype == torch.float32: - rtol, atol = 1e-4, 1e-4 - else: - rtol, atol = 5e-2, 5e-2 + rtol, atol = 1e-5, 1e-5 + elif dtype == torch.bfloat16: + rtol, atol = 2e-2, 2e-2 + else: # float16 + rtol, atol = 1e-2, 1e-2 torch.testing.assert_close( fused_logprobs.float(), ref_logprobs.float(), rtol=rtol, atol=atol @@ -142,15 +148,40 @@ def test_linear_cross_entropy_temperature(temperature: float) -> None: fused_lp, fused_h = linear_cross_entropy_logprobs_entropy( hidden, weight, labels, temperature=temperature ) - torch.testing.assert_close(fused_lp, ref_lp, rtol=1e-4, atol=1e-4) - torch.testing.assert_close(fused_h, ref_h, rtol=1e-4, atol=1e-4) + # fp32 inputs: fused vs reference must agree to ~1e-5 (a few ulps). + torch.testing.assert_close(fused_lp, ref_lp, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(fused_h, ref_h, rtol=1e-5, atol=1e-5) + +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size", + [ + # Small shape: catches obvious correctness bugs cheaply. + (64, 256, 2048), + # Medium shape: typical SFT microbatch. + (512, 1024, 32000), + # Large shape: stresses the fused backward at LLM-class dimensions + # where the materialised reference begins to dominate memory but is + # still fp32-tractable on a single GPU. This is the configuration + # most likely to surface accumulation-order bugs in d_hidden / + # d_weight reductions. + (2048, 2048, 32000), + ], + ids=["small_64x256x2048", "medium_512x1024x32k", "large_2048x2048x32k"], +) +def test_linear_cross_entropy_backward_matches_reference( + num_tokens: int, + hidden_size: int, + vocab_size: int, +) -> None: + """Backward gradients on hidden/weight match autograd through the reference. -def test_linear_cross_entropy_backward_matches_reference() -> None: - """Backward gradients on hidden/weight match autograd through the reference.""" + Runs across small / medium / large shapes so that any accumulation-order + drift in the fused d_hidden / d_weight kernels is caught at scale rather + than only on toy inputs. + """ from areal.utils.kernel import linear_cross_entropy - num_tokens, hidden_size, vocab_size = 64, 256, 2048 hidden_a, weight_a, labels = _make_inputs( num_tokens, hidden_size, vocab_size, torch.float32 ) @@ -175,16 +206,25 @@ def test_linear_cross_entropy_backward_matches_reference() -> None: ) (fused_lp.sum() + 0.5 * fused_h.sum()).backward() + # fp32 inputs: backward must match the reference to ~1e-4. The fused + # kernel's d_weight accumulates ``num_tokens`` partial products, so we + # use a slightly looser absolute tolerance for d_weight at the largest + # shape; rtol stays tight to catch directional errors. torch.testing.assert_close( - hidden_a.grad, hidden_b.grad, rtol=5e-3, atol=5e-3 + hidden_a.grad, hidden_b.grad, rtol=1e-4, atol=1e-4 ) + weight_atol = 1e-4 if num_tokens <= 512 else 5e-4 torch.testing.assert_close( - weight_a.grad, weight_b.grad, rtol=5e-3, atol=5e-3 + weight_a.grad, weight_b.grad, rtol=1e-4, atol=weight_atol ) # --------------------------------------------------------------------------- # Tensor-parallel (TP=2) correctness + performance +# +# These tests are invoked through pytest, while the 2-rank distributed body is +# launched with subprocess.run(["torchrun", ...]) following the repository's +# distributed-test pattern. Users do not need to run torchrun manually. 
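+#
+# Roughly, each test below expands to an invocation of the form (illustrative;
+# the exact flags are assembled in ``_run_lce_tp2_with_torchrun``):
+#
+#   torchrun --nproc_per_node=2 --nnodes=1 --master_port=<free port> \
+#       tests/torchrun/run_lce_tp2.py --test_type=correctness \
+#       --num_tokens=128 --hidden_size=512 --vocab_size=8192 --dtype=float32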
# --------------------------------------------------------------------------- @@ -197,100 +237,77 @@ def _tp2_available() -> bool: return True -_tp2_skip = pytest.mark.skipif(not _tp2_available(), reason="TP=2 requires >= 2 CUDA GPUs") - - -def _init_tp2(): - """Initialise a 2-rank NCCL process group; return (rank, group).""" - import os - - import torch.distributed as dist +_tp2_skip = pytest.mark.skipif( + not _tp2_available(), reason="TP=2 requires >= 2 CUDA GPUs" +) - if dist.is_initialized(): - rank = dist.get_rank() - world_size = dist.get_world_size() - group = dist.new_group(ranks=list(range(world_size)), backend="nccl") - return rank, group - dist.init_process_group(backend="nccl") - rank = dist.get_rank() - world_size = dist.get_world_size() - group = dist.new_group(ranks=list(range(world_size)), backend="nccl") - return rank, group +def _run_lce_tp2_with_torchrun( + test_type: str, + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: str = "bfloat16", +) -> None: + import subprocess + + from areal.utils.network import find_free_ports + + port = find_free_ports(1)[0] + try: + subprocess.run( + [ + "torchrun", + "--nproc_per_node=2", + "--nnodes=1", + "--master-addr=localhost", + f"--master_port={port}", + "tests/torchrun/run_lce_tp2.py", + f"--test_type={test_type}", + f"--num_tokens={num_tokens}", + f"--hidden_size={hidden_size}", + f"--vocab_size={vocab_size}", + f"--dtype={dtype}", + ], + check=True, + capture_output=True, + text=True, + ) + except subprocess.CalledProcessError as e: + pytest.fail( + "TP=2 LCE torchrun test failed:\n" + f"STDOUT:\n{e.stdout}\n" + f"STDERR:\n{e.stderr}" + ) @_tp2_skip +@pytest.mark.multi_gpu @pytest.mark.parametrize( - "num_tokens,hidden_size,vocab_size,dtype", + "num_tokens,hidden_size,vocab_size,dtype_str", [ - (128, 512, 8192, torch.float32), - (256, 1024, 32000, torch.bfloat16), + (128, 512, 8192, "float32"), + (256, 1024, 32000, "bfloat16"), ], ) def test_linear_cross_entropy_tp2_correctness( num_tokens: int, hidden_size: int, vocab_size: int, - dtype: torch.dtype, + dtype_str: str, ) -> None: - """TP=2 fused forward+backward must match the materialised single-GPU reference.""" - import torch.distributed as dist - - from areal.utils.kernel import linear_cross_entropy - - rank, tp_group = _init_tp2() - world_size = dist.get_world_size(tp_group) - assert world_size == 2 - - torch.cuda.set_device(rank) - device = f"cuda:{rank}" + """TP=2 fused forward+backward matches a full-vocab reference. 
- vocab_per_rank = vocab_size // world_size - assert vocab_size % world_size == 0, "vocab_size must be divisible by world_size" - - g = torch.Generator(device=device).manual_seed(42) - hidden = (torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) - labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) - weight_full = (torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) - weight_shard = weight_full[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].contiguous() - - # --- Reference (single-GPU, full weight) --- - hidden_ref = hidden.detach().clone().requires_grad_(True) - weight_ref = weight_full.detach().clone().requires_grad_(True) - logits = (hidden_ref.float() @ weight_ref.float().t()) - log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) - ref_lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - probs = log_softmax.exp() - ref_h = -(probs * log_softmax).sum(dim=-1) - (ref_lp.sum() + 0.5 * ref_h.sum()).backward() - - # --- Fused TP=2 --- - hidden_fused = hidden.detach().clone().requires_grad_(True) - weight_fused = weight_shard.detach().clone().requires_grad_(True) - fused_lp, fused_h = linear_cross_entropy( - hidden_fused, weight_fused, labels, 1.0, "none", tp_group - ) - (fused_lp.sum() + 0.5 * fused_h.sum()).backward() - - if dtype == torch.float32: - rtol, atol = 1e-3, 1e-3 - else: - rtol, atol = 5e-2, 5e-2 - - torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) - torch.testing.assert_close(fused_h.float(), ref_h.float(), rtol=rtol, atol=atol) - torch.testing.assert_close(hidden_fused.grad.float(), hidden_ref.grad.float(), rtol=rtol, atol=atol) - torch.testing.assert_close( - weight_fused.grad.float(), - weight_ref.grad[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].float(), - rtol=rtol, - atol=atol, + The 2-rank worker is launched via torchrun inside this pytest test, so the + caller can use a normal pytest command. 
+ """ + _run_lce_tp2_with_torchrun( + "correctness", num_tokens, hidden_size, vocab_size, dtype_str ) - dist.barrier(tp_group) - @_tp2_skip +@pytest.mark.multi_gpu @pytest.mark.slow @pytest.mark.parametrize( "num_tokens,hidden_size,vocab_size", @@ -304,117 +321,9 @@ def test_linear_cross_entropy_tp2_performance_benchmark( hidden_size: int, vocab_size: int, ) -> None: - """TP=2 fused vs materialised forward+backward time and peak memory.""" - import torch.distributed as dist - - from areal.utils.kernel import linear_cross_entropy - - rank, tp_group = _init_tp2() - world_size = dist.get_world_size(tp_group) - assert world_size == 2 - - torch.cuda.set_device(rank) - device = f"cuda:{rank}" - dtype = torch.bfloat16 - - vocab_per_rank = vocab_size // world_size - assert vocab_size % world_size == 0 - - g = torch.Generator(device=device).manual_seed(0) - hidden = (torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) - labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) - weight_full = (torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02) - weight_shard = weight_full[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].contiguous() - - # --- warm-up --- - for _ in range(2): - h = hidden.detach().clone().requires_grad_(True) - w = weight_shard.detach().clone().requires_grad_(True) - lp, ent = linear_cross_entropy(h, w, labels, 1.0, "none", tp_group) - (lp.sum() + ent.sum()).backward() - del lp, ent, h, w - gc.collect() - torch.cuda.empty_cache() - - for _ in range(2): - h = hidden.detach().clone().requires_grad_(True) - w = weight_full.detach().clone().requires_grad_(True) - logits = (h.float() @ w.float().t()) - log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) - lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - probs = log_softmax.exp() - ent = -(probs * log_softmax).sum(dim=-1) - (lp.sum() + ent.sum()).backward() - del lp, ent, h, w - gc.collect() - torch.cuda.empty_cache() - - # --- Fused TP=2 timing --- - fused_times = [] - fused_mems = [] - for _ in range(5): - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - h = hidden.detach().clone().requires_grad_(True) - w = weight_shard.detach().clone().requires_grad_(True) - lp, ent = linear_cross_entropy(h, w, labels, 1.0, "none", tp_group) - (lp.sum() + ent.sum()).backward() - end.record() - torch.cuda.synchronize() - fused_times.append(start.elapsed_time(end)) - fused_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) - del lp, ent, h, w - - # --- Reference (single-GPU, full weight) timing --- - ref_times = [] - ref_mems = [] - for _ in range(5): - torch.cuda.synchronize() - torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - h = hidden.detach().clone().requires_grad_(True) - w = weight_full.detach().clone().requires_grad_(True) - logits = (h.float() @ w.float().t()) - log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) - lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - probs = log_softmax.exp() - ent = -(probs * log_softmax).sum(dim=-1) - (lp.sum() + ent.sum()).backward() - end.record() - torch.cuda.synchronize() - ref_times.append(start.elapsed_time(end)) - 
ref_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) - del lp, ent, h, w - - ref_med = sorted(ref_times)[len(ref_times) // 2] - fused_med = sorted(fused_times)[len(fused_times) // 2] - ref_peak = max(ref_mems) - fused_peak = max(fused_mems) - speedup = ref_med / fused_med if fused_med > 0 else math.inf - mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf - - print( - f"\n[LCE-TP2-Bench rank={rank}] tokens={num_tokens} hidden={hidden_size} vocab={vocab_size} " - f"dtype={dtype}\n" - f" reference: {ref_med:7.2f} ms / {ref_peak:7.1f} MB peak\n" - f" fused : {fused_med:7.2f} ms / {fused_peak:7.1f} MB peak\n" - f" speedup : {speedup:5.2f}x memory_ratio: {mem_ratio:5.2f}x" - ) - - assert fused_med < ref_med * 1.5, ( - f"Fused TP=2 LCE is more than 1.5x slower than reference " - f"(fused={fused_med:.2f}ms ref={ref_med:.2f}ms)." - ) - assert fused_peak < ref_peak * 1.2, ( - f"Fused TP=2 LCE peak memory exceeds reference by >20% " - f"(fused={fused_peak:.1f}MB ref={ref_peak:.1f}MB)." + """TP=2 fused vs TP-materialised forward+backward time and peak memory.""" + _run_lce_tp2_with_torchrun( + "performance", num_tokens, hidden_size, vocab_size ) diff --git a/tests/torchrun/run_lce_tp2.py b/tests/torchrun/run_lce_tp2.py new file mode 100644 index 0000000000..e19fb84821 --- /dev/null +++ b/tests/torchrun/run_lce_tp2.py @@ -0,0 +1,265 @@ +import argparse +import gc +import math +import os + +import torch +import torch.distributed as dist + +from areal.infra.platforms import current_platform +from areal.utils.functional import gather_logprobs_entropy +from areal.utils.kernel import linear_cross_entropy + + +def _setup_distributed_environment() -> None: + if dist.is_initialized(): + return + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + master_addr = os.environ.get("MASTER_ADDR", "localhost") + master_port = os.environ["MASTER_PORT"] + dist.init_process_group( + backend="nccl", + init_method=f"tcp://{master_addr}:{master_port}", + world_size=world_size, + rank=rank, + ) + current_platform.set_device(rank) + + +def _get_tp_group() -> dist.ProcessGroup: + return dist.distributed_c10d._get_default_group() + + +def _make_tp_inputs( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, + device: str, + seed: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + rank = dist.get_rank() + world_size = dist.get_world_size() + vocab_per_rank = vocab_size // world_size + assert vocab_size % world_size == 0 + + g = torch.Generator(device=device).manual_seed(seed) + hidden = ( + torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) + * 0.02 + ) + labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) + weight_full = ( + torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) + * 0.02 + ) + weight_shard = weight_full[ + rank * vocab_per_rank : (rank + 1) * vocab_per_rank + ].contiguous() + return ( + hidden.contiguous(), + labels.contiguous(), + weight_full.contiguous(), + weight_shard, + ) + + +def _run_full_reference( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + entropy_weight: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_ref = hidden.detach().clone().requires_grad_(True) + weight_ref = weight.detach().clone().requires_grad_(True) + logits = hidden_ref.float() @ weight_ref.float().t() + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + ref_lp = 
log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ref_h = -(probs * log_softmax).sum(dim=-1) + (ref_lp.sum() + entropy_weight * ref_h.sum()).backward() + return ref_lp, ref_h, hidden_ref.grad, weight_ref.grad + + +def _run_tp_materialized_step( + hidden: torch.Tensor, + weight_shard: torch.Tensor, + labels: torch.Tensor, + tp_group: dist.ProcessGroup, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + h = hidden.detach().clone().requires_grad_(True) + w = weight_shard.detach().clone().requires_grad_(True) + local_logits = h.float() @ w.float().t() + lp, ent = gather_logprobs_entropy(local_logits, labels, tp_group=tp_group) + (lp.sum() + ent.sum()).backward() + return lp, ent, h.grad, w.grad + + +def _run_fused_step( + hidden: torch.Tensor, + weight_shard: torch.Tensor, + labels: torch.Tensor, + tp_group: dist.ProcessGroup, + entropy_weight: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + h = hidden.detach().clone().requires_grad_(True) + w = weight_shard.detach().clone().requires_grad_(True) + lp, ent = linear_cross_entropy(h, w, labels, 1.0, "none", tp_group) + (lp.sum() + entropy_weight * ent.sum()).backward() + return lp, ent, h.grad, w.grad + + +def _test_tp2_correctness( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, +) -> None: + rank = dist.get_rank() + world_size = dist.get_world_size() + assert world_size == 2 + device = current_platform.current_device() + tp_group = _get_tp_group() + + hidden, labels, weight_full, weight_shard = _make_tp_inputs( + num_tokens, hidden_size, vocab_size, dtype, device, seed=42 + ) + vocab_per_rank = vocab_size // world_size + + ref_lp, ref_h, ref_dh, ref_dw = _run_full_reference( + hidden, weight_full, labels, entropy_weight=0.5 + ) + fused_lp, fused_h, fused_dh, fused_dw = _run_fused_step( + hidden, weight_shard, labels, tp_group, entropy_weight=0.5 + ) + + if dtype == torch.float32: + rtol, atol = 2e-4, 2e-4 + else: + rtol, atol = 3e-2, 3e-2 + + torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_h.float(), ref_h.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_dh.float(), ref_dh.float(), rtol=rtol, atol=atol) + torch.testing.assert_close( + fused_dw.float(), + ref_dw[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].float(), + rtol=rtol, + atol=atol, + ) + + if rank == 0: + print( + f"[PASS] tp2_correctness: T={num_tokens} H={hidden_size} " + f"V={vocab_size} dtype={dtype}" + ) + + +def _time_step(fn) -> tuple[float, float]: + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end), torch.cuda.max_memory_allocated() / (1024 * 1024) + + +def _test_tp2_performance( + num_tokens: int, + hidden_size: int, + vocab_size: int, +) -> None: + rank = dist.get_rank() + world_size = dist.get_world_size() + assert world_size == 2 + device = current_platform.current_device() + dtype = torch.bfloat16 + tp_group = _get_tp_group() + + hidden, labels, _, weight_shard = _make_tp_inputs( + num_tokens, hidden_size, vocab_size, dtype, device, seed=0 + ) + + for _ in range(2): + _run_fused_step(hidden, weight_shard, labels, tp_group) + gc.collect() + torch.cuda.empty_cache() + for _ in range(2): + _run_tp_materialized_step(hidden, 
weight_shard, labels, tp_group) + gc.collect() + torch.cuda.empty_cache() + + fused_times = [] + fused_mems = [] + for _ in range(5): + t, m = _time_step( + lambda: _run_fused_step(hidden, weight_shard, labels, tp_group) + ) + fused_times.append(t) + fused_mems.append(m) + + ref_times = [] + ref_mems = [] + for _ in range(5): + t, m = _time_step( + lambda: _run_tp_materialized_step(hidden, weight_shard, labels, tp_group) + ) + ref_times.append(t) + ref_mems.append(m) + + ref_med = sorted(ref_times)[len(ref_times) // 2] + fused_med = sorted(fused_times)[len(fused_times) // 2] + ref_peak = max(ref_mems) + fused_peak = max(fused_mems) + speedup = ref_med / fused_med if fused_med > 0 else math.inf + mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf + + if rank == 0: + print( + f"\n[LCE-TP2-Bench] tokens={num_tokens} hidden={hidden_size} " + f"vocab={vocab_size} dtype={dtype}\n" + f" tp materialized: {ref_med:7.2f} ms / {ref_peak:7.1f} MB peak\n" + f" fused : {fused_med:7.2f} ms / {fused_peak:7.1f} MB peak\n" + f" speedup : {speedup:5.2f}x memory_ratio: {mem_ratio:5.2f}x" + ) + + assert fused_med < ref_med * 1.5, ( + f"Fused TP=2 LCE is more than 1.5x slower than TP materialized reference " + f"(fused={fused_med:.2f}ms ref={ref_med:.2f}ms)." + ) + assert fused_peak < ref_peak * 1.2, ( + f"Fused TP=2 LCE peak memory exceeds TP materialized reference by >20% " + f"(fused={fused_peak:.1f}MB ref={ref_peak:.1f}MB)." + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--test_type", choices=["correctness", "performance"], required=True) + parser.add_argument("--num_tokens", type=int, required=True) + parser.add_argument("--hidden_size", type=int, required=True) + parser.add_argument("--vocab_size", type=int, required=True) + parser.add_argument("--dtype", choices=["float32", "bfloat16"], default="bfloat16") + args = parser.parse_args() + + dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16}[args.dtype] + _setup_distributed_environment() + try: + if args.test_type == "correctness": + _test_tp2_correctness( + args.num_tokens, args.hidden_size, args.vocab_size, dtype + ) + else: + _test_tp2_performance(args.num_tokens, args.hidden_size, args.vocab_size) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + main() From fa82bb987406691e3e5dcf9912cac3186ecb73c9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 00:01:14 +0800 Subject: [PATCH 16/31] perf(benchmark): benchmark --- benchmark/bench_linear_cross_entropy.py | 248 ++++++++++++++++++------ 1 file changed, 190 insertions(+), 58 deletions(-) diff --git a/benchmark/bench_linear_cross_entropy.py b/benchmark/bench_linear_cross_entropy.py index 83d9dcb280..a6b1424dd9 100644 --- a/benchmark/bench_linear_cross_entropy.py +++ b/benchmark/bench_linear_cross_entropy.py @@ -2,25 +2,30 @@ """ Standalone benchmark for the fused linear-cross-entropy kernel. -Designed to be run *outside* pytest so that NVIDIA Nsight Systems -(``nsys profile``) can capture a clean, deterministic trace covering both -the materialised reference path and the fused Triton path. - -NVTX ranges are emitted around each phase so the resulting ``.nsys-rep`` -file can be filtered down to just the linear-CE kernels in the Nsight UI. +Designed to be run outside pytest to measure forward+backward latency and +peak memory for the materialised reference path and the fused Triton path. 
Usage:: - # Plain run (sanity) - python -m benchmark.bench_linear_cross_entropy --tokens 4096 --vocab 152064 - - # Profile with Nsight Systems - nsys profile -t nvtx,cuda,cudnn,cublas \\ - -o lce_profile --capture-range cudaProfilerApi --capture-range-end stop \\ - python -m benchmark.bench_linear_cross_entropy \\ - --tokens 4096 --vocab 152064 --use-cuda-profiler-api - -See ``docs/perf/nsight_linear_cross_entropy.md`` for a full Nsight workflow. + # Qwen3 single-GPU full-vocab benchmark + uv run python -m benchmark.bench_linear_cross_entropy \\ + --mode both --tokens 2048 --hidden 4096 --vocab 152064 \\ + --dtype bfloat16 --warmup 5 --iters 15 --check-correctness + + # Qwen3 TP=2 benchmark. The reference path materialises only local + # [tokens, vocab/tp] logits and uses vocab-parallel reductions. + uv run torchrun --nproc_per_node=2 --nnodes=1 \\ + --master-addr=localhost --master_port=29501 \\ + -m benchmark.bench_linear_cross_entropy \\ + --mode both --tp-size 2 --tokens 2048 --hidden 4096 --vocab 152064 \\ + --dtype bfloat16 --warmup 5 --iters 15 --check-correctness + + # Qwen3 TP=4 benchmark + uv run torchrun --nproc_per_node=4 --nnodes=1 \\ + --master-addr=localhost --master_port=29501 \\ + -m benchmark.bench_linear_cross_entropy \\ + --mode both --tp-size 4 --tokens 2048 --hidden 4096 --vocab 152064 \\ + --dtype bfloat16 --warmup 5 --iters 15 --check-correctness """ from __future__ import annotations @@ -28,48 +33,130 @@ import argparse import gc import math +import os import sys import torch +import torch.distributed as dist + +from areal.utils.functional import gather_logprobs_entropy + + +def _setup_distributed(tp_size: int): + if tp_size == 1: + return None + if not dist.is_available(): + raise RuntimeError("torch.distributed is required when --tp-size > 1") + if not dist.is_initialized(): + required = ("RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_PORT") + missing = [k for k in required if k not in os.environ] + if missing: + raise RuntimeError( + "--tp-size > 1 must be launched with torchrun; missing env vars: " + + ", ".join(missing) + ) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + world_size = dist.get_world_size() + if world_size != tp_size: + raise RuntimeError( + f"--tp-size={tp_size} must match torchrun world_size={world_size}" + ) + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", dist.get_rank()))) + return dist.group.WORLD + + +def _rank(tp_group): + return dist.get_rank(tp_group) if tp_group is not None else 0 + +def _world_size(tp_group): + return dist.get_world_size(tp_group) if tp_group is not None else 1 + + +def _make_inputs(num_tokens, hidden_size, vocab_size, dtype, tp_group=None, seed=0): + world_size = _world_size(tp_group) + rank = _rank(tp_group) + if vocab_size % world_size != 0: + raise ValueError( + f"vocab_size={vocab_size} must be divisible by tp_size={world_size}" + ) + local_vocab_size = vocab_size // world_size -def _make_inputs(num_tokens, hidden_size, vocab_size, dtype, seed=0): g = torch.Generator(device="cuda").manual_seed(seed) hidden = ( torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda", generator=g) * 0.02 ) weight = ( - torch.randn(vocab_size, hidden_size, dtype=dtype, device="cuda", generator=g) + torch.randn( + local_vocab_size, + hidden_size, + dtype=dtype, + device="cuda", + generator=g, + ) * 0.02 ) + if tp_group is not None: + weight = weight + (rank * 0.001) labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda", generator=g) return 
hidden.contiguous(), weight.contiguous(), labels.contiguous() -def _ref_step(hidden, weight, labels, temperature=1.0): +def _ref_step(hidden, weight, labels, temperature=1.0, tp_group=None): h = hidden.detach().clone().requires_grad_(True) w = weight.detach().clone().requires_grad_(True) - logits = (h.float() @ w.float().t()) / temperature - log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) - lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) - probs = log_softmax.exp() - ent = -(probs * log_softmax).sum(dim=-1) + logits = h.float() @ w.float().t() + if tp_group is not None: + lp, ent = gather_logprobs_entropy( + logits, labels, temperature=temperature, tp_group=tp_group + ) + else: + log_softmax = torch.nn.functional.log_softmax(logits / temperature, dim=-1) + lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ent = -(probs * log_softmax).sum(dim=-1) (lp.sum() + ent.sum()).backward() - return h.grad, w.grad + if tp_group is not None: + dist.all_reduce(h.grad, op=dist.ReduceOp.SUM, group=tp_group) + return lp.detach(), ent.detach(), h.grad.detach(), w.grad.detach() -def _fused_step(hidden, weight, labels, temperature=1.0): +def _fused_step(hidden, weight, labels, temperature=1.0, tp_group=None): from areal.utils.kernel import linear_cross_entropy h = hidden.detach().clone().requires_grad_(True) w = weight.detach().clone().requires_grad_(True) - lp, ent = linear_cross_entropy(h, w, labels, temperature, "none", None) + lp, ent = linear_cross_entropy(h, w, labels, temperature, "none", tp_group) (lp.sum() + ent.sum()).backward() - return h.grad, w.grad + return lp.detach(), ent.detach(), h.grad.detach(), w.grad.detach() + + +def _check_correctness(hidden, weight, labels, dtype, tp_group=None): + ref_lp, ref_ent, ref_dh, ref_dw = _ref_step( + hidden, weight, labels, tp_group=tp_group + ) + fused_lp, fused_ent, fused_dh, fused_dw = _fused_step( + hidden, weight, labels, tp_group=tp_group + ) + if dtype == torch.float32: + rtol, atol = 1e-4, 1e-4 + elif dtype == torch.bfloat16: + rtol, atol = 3e-2, 3e-2 + else: + rtol, atol = 2e-2, 2e-2 + + torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) + torch.testing.assert_close( + fused_ent.float(), ref_ent.float(), rtol=rtol, atol=atol + ) + torch.testing.assert_close(fused_dh.float(), ref_dh.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_dw.float(), ref_dw.float(), rtol=rtol, atol=atol) -def _measure(label, fn, hidden, weight, labels, args, warmup, iters): + +def _measure(label, fn, hidden, weight, labels, warmup, iters, tp_group=None): nvtx = torch.cuda.nvtx times = [] mems = [] @@ -77,7 +164,7 @@ def _measure(label, fn, hidden, weight, labels, args, warmup, iters): # Warmup nvtx.range_push(f"{label}/warmup") for _ in range(warmup): - fn(hidden, weight, labels) + fn(hidden, weight, labels, tp_group=tp_group) gc.collect() torch.cuda.empty_cache() nvtx.range_pop() @@ -91,7 +178,7 @@ def _measure(label, fn, hidden, weight, labels, args, warmup, iters): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) start.record() - fn(hidden, weight, labels) + fn(hidden, weight, labels, tp_group=tp_group) end.record() torch.cuda.synchronize() nvtx.range_pop() @@ -102,11 +189,20 @@ def _measure(label, fn, hidden, weight, labels, args, warmup, iters): return times, mems +def _distributed_max(value, tp_group): + if tp_group is None: + return value + tensor = torch.tensor(value, dtype=torch.float64, 
device="cuda") + dist.all_reduce(tensor, op=dist.ReduceOp.MAX, group=tp_group) + return float(tensor.item()) + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--tokens", type=int, default=4096) parser.add_argument("--hidden", type=int, default=4096) parser.add_argument("--vocab", type=int, default=152064) + parser.add_argument("--tp-size", type=int, default=1) parser.add_argument( "--dtype", choices=["bfloat16", "float16", "float32"], @@ -114,14 +210,11 @@ def main(): ) parser.add_argument("--warmup", type=int, default=3) parser.add_argument("--iters", type=int, default=10) + parser.add_argument("--check-correctness", action="store_true") parser.add_argument( "--use-cuda-profiler-api", action="store_true", - help=( - "Wrap the measurement region with cudaProfilerStart/Stop so that " - "`nsys profile --capture-range cudaProfilerApi` only records the " - "interesting region." - ), + help="Wrap the measurement region with cudaProfilerStart/Stop.", ) parser.add_argument("--mode", choices=["both", "ref", "fused"], default="both") args = parser.parse_args() @@ -130,45 +223,84 @@ def main(): print("CUDA is not available; aborting.", file=sys.stderr) sys.exit(1) - dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[ - args.dtype - ] - hidden, weight, labels = _make_inputs(args.tokens, args.hidden, args.vocab, dtype) - print( - f"[bench] tokens={args.tokens} hidden={args.hidden} vocab={args.vocab} " - f"dtype={args.dtype} warmup={args.warmup} iters={args.iters}" + tp_group = _setup_distributed(args.tp_size) + dtype = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + }[args.dtype] + hidden, weight, labels = _make_inputs( + args.tokens, args.hidden, args.vocab, dtype, tp_group=tp_group ) + if _rank(tp_group) == 0: + print( + f"[bench] tokens={args.tokens} hidden={args.hidden} vocab={args.vocab} " + f"tp={args.tp_size} dtype={args.dtype} warmup={args.warmup} " + f"iters={args.iters}" + ) + + if args.check_correctness: + _check_correctness(hidden, weight, labels, dtype, tp_group=tp_group) + if _rank(tp_group) == 0: + print("[bench] correctness check passed") if args.use_cuda_profiler_api: torch.cuda.cudart().cudaProfilerStart() results = {} if args.mode in ("both", "ref"): - t, m = _measure("reference", _ref_step, hidden, weight, labels, args, args.warmup, args.iters) + t, m = _measure( + "reference", + _ref_step, + hidden, + weight, + labels, + args.warmup, + args.iters, + tp_group=tp_group, + ) results["reference"] = (t, m) if args.mode in ("both", "fused"): - t, m = _measure("fused", _fused_step, hidden, weight, labels, args, args.warmup, args.iters) + t, m = _measure( + "fused", + _fused_step, + hidden, + weight, + labels, + args.warmup, + args.iters, + tp_group=tp_group, + ) results["fused"] = (t, m) if args.use_cuda_profiler_api: torch.cuda.cudart().cudaProfilerStop() + summaries = {} for name, (t, m) in results.items(): - median = sorted(t)[len(t) // 2] - peak = max(m) - print(f"[bench] {name:9s} median={median:7.2f}ms peak_mem={peak:8.1f}MB") - - if "reference" in results and "fused" in results: - ref_med = sorted(results["reference"][0])[len(results["reference"][0]) // 2] - fused_med = sorted(results["fused"][0])[len(results["fused"][0]) // 2] - ref_peak = max(results["reference"][1]) - fused_peak = max(results["fused"][1]) - speedup = ref_med / fused_med if fused_med > 0 else math.inf - mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf - print( - f"[bench] speedup={speedup:.2f}x 
fused_peak/ref_peak={mem_ratio:.2f}x" + local_median = sorted(t)[len(t) // 2] + local_peak = max(m) + summaries[name] = ( + _distributed_max(local_median, tp_group), + _distributed_max(local_peak, tp_group), ) + if _rank(tp_group) == 0: + for name, (median, peak) in summaries.items(): + print(f"[bench] {name:9s} median={median:7.2f}ms peak_mem={peak:8.1f}MB") + + if "reference" in summaries and "fused" in summaries: + ref_med, ref_peak = summaries["reference"] + fused_med, fused_peak = summaries["fused"] + speedup = ref_med / fused_med if fused_med > 0 else math.inf + mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf + print( + f"[bench] speedup={speedup:.2f}x fused_peak/ref_peak={mem_ratio:.2f}x" + ) + + if dist.is_initialized(): + dist.destroy_process_group() + if __name__ == "__main__": main() From 4e9f8c719152db0cddd014a509ee2009cf776e20 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 00:59:49 +0800 Subject: [PATCH 17/31] refactor(fsdp): remove useless --- .../fsdp_utils/kernels/fused_experts.py | 430 ------------------ 1 file changed, 430 deletions(-) delete mode 100644 areal/engine/fsdp_utils/kernels/fused_experts.py diff --git a/areal/engine/fsdp_utils/kernels/fused_experts.py b/areal/engine/fsdp_utils/kernels/fused_experts.py deleted file mode 100644 index 4e19d7604c..0000000000 --- a/areal/engine/fsdp_utils/kernels/fused_experts.py +++ /dev/null @@ -1,430 +0,0 @@ -"""Fused MoE autograd functions adapted for AReaL FSDP backend. - -Forward reuses SGLang's Triton kernels. Backward uses a Triton kernel written in -``fused_moe_triton_backward_kernels.py`` that computes ``grad_input``, ``grad_weight`` -and (optionally) ``grad_topk_weights`` with ``tl.atomic_add``. - -Debug logging can be enabled by setting the environment variable -``AREAL_FUSED_MOE_DEBUG=1``. When enabled, each intermediate tensor of the pipeline -is printed (shape + mean/std + first few values) so that the fused path can be -compared against a reference implementation by ``diff``-ing two runs. -""" - -from __future__ import annotations - -import os - -import torch -import triton.language as tl -from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( - invoke_fused_moe_kernel, - moe_align_block_size, - moe_sum_reduce, - silu_and_mul, -) - -from .fused_moe_triton_backward_kernels import invoke_fused_moe_backward_kernel - -_DEBUG = os.environ.get("AREAL_FUSED_MOE_DEBUG", "0") == "1" - - -def _dbg(tag: str, t: torch.Tensor | None) -> None: - """Emit a compact summary of a tensor when debug mode is on. - - Only rank-0 logging is fine because callers run the same op on every rank and - we just want to sanity-check numerical content during a single-process test. - """ - if not _DEBUG or t is None: - return - try: - flat = t.detach().float().reshape(-1) - head = flat[:4].tolist() - print( - f"[fused_moe] {tag}: shape={tuple(t.shape)} dtype={t.dtype} " - f"mean={flat.mean().item():.6e} std={flat.std().item():.6e} head={head}" - ) - except Exception as e: # pragma: no cover - print(f"[fused_moe] {tag}: ") - - -class GateUpProjFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - hidden_states: torch.Tensor, - w1: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ): - num_tokens, _ = hidden_states.shape - E, N, _ = w1.shape - # Match slime / vLLM convention: chunked launch to avoid the bug - # https://github.com/vllm-project/vllm/issues/5938 - CHUNK_SIZE = 64 * 1024 - - # Default deterministic config. Tuned for H800 / A100 bf16 MoE. 
- config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - - topk = topk_ids.shape[1] - - intermediate_cache1 = torch.empty( - (num_tokens * topk, N), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - - _dbg("gate_up.fwd.hidden_states", hidden_states) - _dbg("gate_up.fwd.w1", w1) - _dbg("gate_up.fwd.topk_weights", topk_weights) - _dbg("gate_up.fwd.topk_ids", topk_ids) - - for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = ( - chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, num_tokens), - ) - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - cur_intermediate_cache1 = intermediate_cache1[ - begin_chunk_idx * topk : end_chunk_idx * topk - ] - - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - curr_topk_ids, config["BLOCK_SIZE_M"], E - ) - - invoke_fused_moe_kernel( - curr_hidden_states, - w1, - None, - cur_intermediate_cache1, - None, - None, - None, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=tl.bfloat16, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None, - c_sorted=False, - filter_expert=True, - ) - - _dbg("gate_up.fwd.intermediate_cache1", intermediate_cache1) - - ctx.save_for_backward(hidden_states, w1, topk_weights, topk_ids) - ctx.config = config - ctx.num_tokens = num_tokens - ctx.topk = topk - - return intermediate_cache1 - - @staticmethod - def backward(ctx, grad_output): - """Backward for GateUpProj using Triton kernels. - - ``grad_output`` has shape ``(num_tokens * topk, N)``. We return - ``(grad_hidden_states, grad_w1, None, None)`` because ``topk_weights`` and - ``topk_ids`` are not multiplied in the forward kernel for this stage. 
- """ - hidden_states, w1, topk_weights, topk_ids = ctx.saved_tensors - config = ctx.config - num_tokens = ctx.num_tokens - topk = ctx.topk - - E, N, D_in = w1.shape - CHUNK_SIZE = 64 * 1024 - - grad_hidden_states = torch.zeros_like(hidden_states) - grad_w1 = torch.zeros_like(w1) - grad_topk_weights = torch.zeros_like(topk_weights) - - _dbg("gate_up.bwd.grad_output", grad_output) - - for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = ( - chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, num_tokens), - ) - - curr_num_tokens = end_chunk_idx - begin_chunk_idx - if curr_num_tokens == 0: - continue - - curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - curr_grad_output = grad_output[begin_chunk_idx * topk : end_chunk_idx * topk] - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - curr_topk_ids, config["BLOCK_SIZE_M"], E - ) - - curr_grad_hidden_states = torch.zeros_like(curr_hidden_states) - curr_grad_w1 = torch.zeros_like(w1) - - invoke_fused_moe_backward_kernel( - grad_output=curr_grad_output, - input=curr_hidden_states, - weight=w1, - grad_input=curr_grad_hidden_states, - grad_weight=curr_grad_w1, - grad_topk_weights=None, - topk_weights=curr_topk_weights, - topk_ids=curr_topk_ids, - sorted_token_ids=sorted_token_ids, - expert_ids=expert_ids, - num_tokens_post_padded=num_tokens_post_padded, - mul_routed_weight=False, - top_k=topk, - config=config, - compute_type=tl.bfloat16, - ) - - grad_hidden_states[begin_chunk_idx:end_chunk_idx] += curr_grad_hidden_states - grad_w1 += curr_grad_w1 - - _dbg("gate_up.bwd.grad_hidden_states", grad_hidden_states) - _dbg("gate_up.bwd.grad_w1", grad_w1) - - return grad_hidden_states, grad_w1, grad_topk_weights, None - - -class SiluAndMulFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, intermediate_cache1: torch.Tensor): - num_tokens, N = intermediate_cache1.shape - intermediate_cache2 = torch.empty( - (num_tokens, N // 2), - device=intermediate_cache1.device, - dtype=intermediate_cache1.dtype, - ) - silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) - _dbg("silu.fwd.intermediate_cache2", intermediate_cache2) - - ctx.save_for_backward(intermediate_cache1) - return intermediate_cache2 - - @staticmethod - def backward(ctx, grad_output): - (intermediate_cache1,) = ctx.saved_tensors - N = intermediate_cache1.shape[-1] - x1, x2 = intermediate_cache1.view(-1, N).chunk(2, dim=-1) - silu_x1 = torch.nn.functional.silu(x1) - - sig = torch.sigmoid(x1) - dsilu_dx1 = sig + x1 * sig * (1 - sig) - grad_x1 = grad_output * x2 * dsilu_dx1 - grad_x2 = grad_output * silu_x1 - grad_input = torch.cat([grad_x1, grad_x2], dim=-1) - _dbg("silu.bwd.grad_input", grad_input) - - return grad_input.view_as(intermediate_cache1) - - -class DownProjFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - intermediate_cache2: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - ): - num_tokens, _ = intermediate_cache2.shape - topk = topk_ids.shape[1] - num_tokens //= topk - E, _, _ = w2.shape - CHUNK_SIZE = 64 * 1024 - - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - - intermediate_cache3 = torch.empty( - (num_tokens, topk, w2.shape[1]), - device=intermediate_cache2.device, - dtype=intermediate_cache2.dtype, - ) - - 
_dbg("down.fwd.intermediate_cache2", intermediate_cache2) - _dbg("down.fwd.w2", w2) - - for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = ( - chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, num_tokens), - ) - cur_intermediate_cache2 = intermediate_cache2[ - begin_chunk_idx * topk : end_chunk_idx * topk - ] - cur_intermediate_cache3 = intermediate_cache3[begin_chunk_idx:end_chunk_idx] - - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - curr_topk_ids, config["BLOCK_SIZE_M"], E - ) - invoke_fused_moe_kernel( - cur_intermediate_cache2, - w2, - None, - cur_intermediate_cache3, - None, - None, - None, - curr_topk_weights, - curr_topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=tl.bfloat16, - use_fp8_w8a8=False, - use_int8_w8a8=False, - use_int8_w8a16=False, - use_int4_w4a16=False, - per_channel_quant=False, - block_shape=None, - a_use_tma=False, - b_use_tma=False, - ) - - _dbg("down.fwd.intermediate_cache3", intermediate_cache3) - - ctx.save_for_backward(intermediate_cache2, w2, topk_weights, topk_ids) - ctx.config = config - ctx.num_tokens = num_tokens - ctx.topk = topk - - return intermediate_cache3 - - @staticmethod - def backward(ctx, grad_output): - """Backward for DownProj. - - ``grad_output`` has shape ``(num_tokens, topk, hidden_size)``. - Returns ``(grad_intermediate_cache2, grad_w2, grad_topk_weights, None)``. - """ - intermediate_cache2, w2, topk_weights, topk_ids = ctx.saved_tensors - config = ctx.config - num_tokens = ctx.num_tokens - topk = ctx.topk - - E, hidden_size, intermediate_size = w2.shape - CHUNK_SIZE = 64 * 1024 - - grad_intermediate_cache2 = torch.zeros_like(intermediate_cache2) - grad_w2 = torch.zeros_like(w2) - grad_topk_weights = torch.zeros_like(topk_weights) - - _dbg("down.bwd.grad_output", grad_output) - - for chunk in range((num_tokens // CHUNK_SIZE) + 1): - begin_chunk_idx, end_chunk_idx = ( - chunk * CHUNK_SIZE, - min((chunk + 1) * CHUNK_SIZE, num_tokens), - ) - - curr_num_tokens = end_chunk_idx - begin_chunk_idx - if curr_num_tokens == 0: - continue - - curr_intermediate_cache2 = intermediate_cache2[ - begin_chunk_idx * topk : end_chunk_idx * topk - ] - curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] - curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] - curr_grad_output = grad_output[begin_chunk_idx:end_chunk_idx] - - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - curr_topk_ids, config["BLOCK_SIZE_M"], E - ) - - curr_grad_intermediate_cache2 = torch.zeros_like(curr_intermediate_cache2) - curr_grad_w2 = torch.zeros_like(w2) - curr_grad_topk_weights = torch.zeros_like(curr_topk_weights) - - # Note: Use top_k=1 to match forward pass indexing convention of - # DownProj (each routed copy is its own "token"). 
- invoke_fused_moe_backward_kernel( - grad_output=curr_grad_output, - input=curr_intermediate_cache2, - weight=w2, - grad_input=curr_grad_intermediate_cache2, - grad_weight=curr_grad_w2, - grad_topk_weights=curr_grad_topk_weights, - topk_weights=curr_topk_weights, - topk_ids=curr_topk_ids, - sorted_token_ids=sorted_token_ids, - expert_ids=expert_ids, - num_tokens_post_padded=num_tokens_post_padded, - mul_routed_weight=True, - top_k=1, - config=config, - compute_type=tl.bfloat16, - ) - - grad_intermediate_cache2[ - begin_chunk_idx * topk : end_chunk_idx * topk - ] = curr_grad_intermediate_cache2 - grad_w2 += curr_grad_w2 - grad_topk_weights[begin_chunk_idx:end_chunk_idx] = curr_grad_topk_weights - - _dbg("down.bwd.grad_intermediate_cache2", grad_intermediate_cache2) - _dbg("down.bwd.grad_w2", grad_w2) - _dbg("down.bwd.grad_topk_weights", grad_topk_weights) - - return grad_intermediate_cache2, grad_w2, grad_topk_weights, None - - -class MoeSumReduceFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - intermediate_cache3: torch.Tensor, - hidden_states_shape, - ): - out_hidden_states = torch.empty( - hidden_states_shape, - device=intermediate_cache3.device, - dtype=intermediate_cache3.dtype, - ) - moe_sum_reduce( - intermediate_cache3, - out_hidden_states, - 1.0, - ) - _dbg("sum_reduce.fwd.out_hidden_states", out_hidden_states) - ctx.save_for_backward(intermediate_cache3) - return out_hidden_states - - @staticmethod - def backward(ctx, grad_output): - (intermediate_cache3,) = ctx.saved_tensors - grad = grad_output.unsqueeze(1).expand_as(intermediate_cache3) - _dbg("sum_reduce.bwd.grad_input", grad) - return grad, None From 64b91a4f41789f21c95cc1cdcedcadd9a0c4d08f Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 01:03:07 +0800 Subject: [PATCH 18/31] refactor: remove test profile --- areal/engine/megatron_engine.py | 96 ++------------------------------- 1 file changed, 5 insertions(+), 91 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 5628abb5d4..eb3a450f78 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -59,6 +59,11 @@ from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper +from areal.engine.megatron_utils.fused_lce_capture import ( + FUSED_LCE_HIDDEN_KEY, + FUSED_LCE_WEIGHT_KEY, + capture_lm_head_hidden, +) from areal.engine.megatron_utils.megatron import ( all_gather_param, convert_to_hf, @@ -106,11 +111,6 @@ split_padded_tensor_dict_into_mb_list, unpad_logits, ) -from areal.engine.megatron_utils.fused_lce_capture import ( - FUSED_LCE_HIDDEN_KEY, - FUSED_LCE_WEIGHT_KEY, - capture_lm_head_hidden, -) from areal.utils.functional import ( gather_logprobs, gather_logprobs_entropy, @@ -148,33 +148,6 @@ def parameters(self, *args, **kwargs) -> Iterator[nn.Parameter]: yield parameter -_LCE_KERNEL_PREFIXES = ( - "efficient_entropy_kernel", - "lce_backward", - "triton_poi", -) - - -def _print_lce_summary(prof: torch.profiler.profile, rank: int) -> None: - logger = logging.getLogger("LCEProfiler") - events = prof.key_averages() - lce_rows: list[tuple[str, float, float, float]] = [] - for evt in events: - key = evt.key - if any(p in key for p in _LCE_KERNEL_PREFIXES): - cuda_ms = evt.cuda_time_total / 1000.0 - cpu_ms = evt.cpu_time_total / 1000.0 - calls = evt.count - lce_rows.append((key, cuda_ms, cpu_ms, 
calls)) - if not lce_rows: - logger.info(f"[Rank {rank}] No LCE Triton kernels found in profiler trace.") - return - header = f"[Rank {rank}] LCE Kernel Profiling Summary (CUDA ms / CPU ms / calls):" - logger.info(header) - for key, cuda_ms, cpu_ms, calls in lce_rows: - logger.info(f" {key}: {cuda_ms:.3f} / {cpu_ms:.3f} / {calls}") - - class MegatronEngine(TrainEngine): def __init__(self, config: TrainEngineConfig): self.config = config @@ -924,8 +897,6 @@ def train_batch( self._ensure_ready() self.optimizer_zero_grad() - self._maybe_init_lce_profiler() - input_batched, _ = self._normalize_batch_input(input_) # Step 1: Prepare micro-batches @@ -962,8 +933,6 @@ def process_output( # Step 4: Optimizer step result = self.optimizer_step() - self._maybe_step_lce_profiler() - return result @torch.no_grad() @@ -1367,61 +1336,6 @@ def _ensure_ready(self) -> None: if self.model is None: raise RuntimeError("Model is not initialized.") - _LCE_PROFILER_KEY = "_lce_profiler" - _LCE_PROFILER_ENV = "ARENAL_LCE_PROFILER_DIR" - - def _maybe_init_lce_profiler(self) -> None: - if hasattr(self, self._LCE_PROFILER_KEY): - return - profiler_dir = os.environ.get(self._LCE_PROFILER_ENV, "") - if not profiler_dir: - setattr(self, self._LCE_PROFILER_KEY, None) - return - - import torch.profiler - - rank = dist.get_rank() if dist.is_initialized() else 0 - output_dir = os.path.join(profiler_dir, f"rank_{rank}") - os.makedirs(output_dir, exist_ok=True) - - def _lce_trace_handler(prof: torch.profiler.profile) -> None: - torch.profiler.tensorboard_trace_handler(output_dir)(prof) - _print_lce_summary(prof, rank) - - setattr( - self, - self._LCE_PROFILER_KEY, - torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - record_shapes=True, - profile_memory=True, - with_stack=True, - schedule=torch.profiler.schedule( - wait=2, warmup=1, active=3, repeat=1 - ), - on_trace_ready=_lce_trace_handler, - ), - ) - getattr(self, self._LCE_PROFILER_KEY).start() - logger = logging.getLogger("LCEProfiler") - logger.info( - f"[Rank {rank}] torch.profiler started, traces will be saved to {output_dir} " - f"(schedule: wait=2, warmup=1, active=3)" - ) - - def _maybe_step_lce_profiler(self) -> None: - profiler = getattr(self, self._LCE_PROFILER_KEY, None) - if profiler is None: - return - profiler.step() - if profiler.profiler is None: - logger = logging.getLogger("LCEProfiler") - rank = dist.get_rank() if dist.is_initialized() else 0 - logger.info(f"[Rank {rank}] torch.profiler finished and stopped.") - def _update_bucket_weights_from_distributed( self, meta: WeightUpdateMeta, From 92c4298b35a330154802067d77019f4fcdf4de8b Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 02:02:48 +0800 Subject: [PATCH 19/31] refactor(kernel): fix --- areal/utils/kernel/kernels.py | 829 ++-------------------------------- 1 file changed, 36 insertions(+), 793 deletions(-) diff --git a/areal/utils/kernel/kernels.py b/areal/utils/kernel/kernels.py index de7ccb7f16..33a638578a 100644 --- a/areal/utils/kernel/kernels.py +++ b/areal/utils/kernel/kernels.py @@ -36,9 +36,6 @@ materialized, trading kernel-launch overhead for large memory savings. 
""" -import typing -from dataclasses import dataclass - import torch import torch.distributed as dist @@ -106,7 +103,7 @@ def inner(func): elif SUPPORT_CUDA_TMA: # TMA descriptors require a global memory allocation - def alloc_fn(size: int, alignment: int, stream: typing.Optional[int]): + def alloc_fn(size: int, alignment: int, stream: int | None): return torch.empty(size, device=get_device_name(), dtype=torch.int8) # https://github.com/triton-lang/triton/commit/43625fc968b693ab51884ca95adbcf3e43483fd0 @@ -130,7 +127,6 @@ def alloc_fn(size: int, alignment: int, stream: typing.Optional[int]): triton.set_allocator(alloc_fn) -@dataclass class EntropyReductionEnum: """ Enum for the reduction method of cross entropy. @@ -173,42 +169,7 @@ def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: return _enum -@dataclass -class BackwardEnum: - """ - Enum for the backward method. - """ - - _Total_Fuse_MN = ( - 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight - ) - _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight - _Split_Dlogits_N = 2 # split d_logits along its N dimension, aka. vocab_size - _Split_Dlogits_M = 3 # split d_logits along its M dimension, aka. num_tokens - - -@dataclass -class Config: - """Configuration for efficient entropy kernel operations. - - Args: - _backward (BackwardEnum): Backward computation method. Defaults to BackwardEnum._Split_Dlogits_N. - _use_triton (bool): Whether to use Triton kernels for computation. Defaults to True. - """ - - _backward: BackwardEnum = BackwardEnum._Split_Dlogits_N - _use_triton: bool = True - - -_config = Config() - - -def set_backward_method(backward_method: BackwardEnum): - """ - Set the backward method. 
- """ - global _config - _config._backward = backward_method +_USE_TRITON = True @triton.autotune( @@ -598,9 +559,9 @@ def efficient_entropy_forward( hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, - reduction: typing.Optional[int] = 2, - temperature: typing.Optional[float] = 1.0, - dist_process_group: typing.Optional[dist.ProcessGroup] = None, + reduction: int | None = 2, + temperature: float | None = 1.0, + dist_process_group: dist.ProcessGroup | None = None, ) -> list[torch.Tensor]: """ forward host function @@ -664,7 +625,7 @@ def efficient_entropy_forward( assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda - if _config._use_triton: + if _USE_TRITON: # 1D kernel launch, then split the tile def mainloop_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) @@ -787,624 +748,6 @@ def epilogue_grid(meta): return (logprobs, entropy, maximum, accumulate, entropy_b) -# NOTE: merge d_weight & d_hidden here, split along M & N -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ) - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_general_mainloop_MN( - num_tokens: int, - hidden_size: int, - vocab_size: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b: tl.int64, - d_hidden_ptr, - stride_d_hidden_m: tl.int64, - stride_d_hidden_k: tl.int64, - d_weight_ptr, - stride_d_weight_n: tl.int64, - stride_d_weight_k: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - USE_TMA: tl.constexpr, -): - """ - backward mainloop, where d_logits & d_hidden & d_weight are fused - """ - # block swizzling - # pid = tl.program_id(axis=0) - # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - # pid_m = pid % num_pid_m - # pid_n = pid // num_pid_m - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - start_offs_am = pid_m * BLOCK_SIZE_M - offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) - start_offs_bn = pid_n * BLOCK_SIZE_N - offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - if USE_TMA: - # using TMA and device-side descriptor creation - hidden_desc = tl.make_tensor_descriptor( - hidden_ptr, - shape=[num_tokens, hidden_size], - strides=[stride_hidden_m, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], - ) - - weight_desc = tl.make_tensor_descriptor( - weight_ptr, - shape=[vocab_size, hidden_size], - strides=[stride_weight_n, 1], - block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], - ) - - maximum_ptrs = maximum_ptr + offs_am 
* stride_maximum - maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) - accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero - accu_rcp = tl.fdiv(1.0, accu) - - d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy - d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) - if reduction == 0: # none - d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs - d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) - elif reduction == 1: # sum - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: # mean - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - - entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b - entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) - - if not USE_TMA: - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - labels_ptrs = labels_ptr + offs_am * stride_labels - labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) - - d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k - # d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n - d_weight_ptrs = d_weight_ptr + offs_bn[:, None] * stride_d_weight_n + offs_k[None, :] * stride_d_weight_k - - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - if USE_TMA: - start_offs_k = k * BLOCK_SIZE_K - _hidden = hidden_desc.load([start_offs_am, start_offs_k]) - _weight = weight_desc.load([start_offs_bn, start_offs_k]) - else: - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0, - ) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), - other=0.0, - ) - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - - logits = tl.dot(_hidden, _weight.T, logits) - - if not USE_TMA: - hidden_ptrs -= hidden_size * stride_hidden_k - weight_ptrs -= hidden_size * stride_weight_k - - # scale logits by temperature - logits *= rcp_temperature - - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - # scale d_logits by temperature - d_logits *= rcp_temperature - - # loop for d_weight & d_hidden - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - start_offs_k = k * BLOCK_SIZE_K - if USE_TMA: - _hidden = hidden_desc.load([start_offs_am, start_offs_k]) - else: - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0, - ) - # _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) - # tl.atomic_add(d_weight_ptrs, - # _d_weight, 
- # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size)) - _d_weight = tl.dot(d_logits.trans(), _hidden.to(tl.float32)) - tl.atomic_add( - d_weight_ptrs, - _d_weight, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), - ) - - if USE_TMA: - _weight = weight_desc.load([start_offs_bn, start_offs_k]) - else: - # _weight = tl.load( - # weight_ptrs, - # mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), - # other=0.0 - # ) - # _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), - other=0.0, - ) - _d_hidden = tl.dot(d_logits, _weight.to(tl.float32)) - tl.atomic_add( - d_hidden_ptrs, - _d_hidden, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - ) - - if not USE_TMA: - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k - d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ), - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_d_hidden( - num_tokens: int, - hidden_size: int, - vocab_size: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b: tl.int64, - d_hidden_ptr, - stride_d_hidden_m: tl.int64, - stride_d_hidden_k: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - """ - backward d_hidden - """ - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - pid_m = pid % num_pid_m - pid_k = pid // num_pid_m - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_k = tl.arange(0, BLOCK_SIZE_K) - result_offs_k = pid_k * BLOCK_SIZE_K + offs_k - - maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < num_tokens, other=1e-6) - accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) - if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) - elif reduction == 1: - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) - labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) - - # iterate over vocab_size - d_hidden = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) - for n 
in range(0, tl.cdiv(vocab_size, BLOCK_SIZE_N)): - offs_n = n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - - # iterate over hidden_size to get logits - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), - other=0.0, - ) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), - other=0.0, - ) - - logits = tl.dot(_hidden, _weight.trans(), logits) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - - # scale logits by temperature - logits *= rcp_temperature - - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - # scale d_logits - d_logits *= rcp_temperature - - # calculate d_hidden - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + result_offs_k[None, :] * stride_weight_k) - _weight = tl.load( - weight_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_n[:, None] < vocab_size), other=0.0 - ) - d_hidden = tl.dot(d_logits.to(weight_ptr.dtype.element_ty), _weight, d_hidden) - - # write back - tl.store( - d_hidden_ptr + offs_m[:, None] * stride_d_hidden_m + result_offs_k[None, :] * stride_d_hidden_k, - d_hidden, - mask=(offs_m[:, None] < num_tokens) & (result_offs_k[None, :] < hidden_size), - ) - - -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ), - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_d_weight( - num_tokens: int, - hidden_size: int, - vocab_size: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b: tl.int64, - d_weight_ptr, - stride_d_weight_n: tl.int64, - stride_d_weight_k: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) - pid_n = pid % num_pid_n - pid_k = pid // num_pid_n - - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - result_offs_k = pid_k * BLOCK_SIZE_K + offs_k - - d_weight = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) - for m in range(0, tl.cdiv(num_tokens, BLOCK_SIZE_M)): - offs_m = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - - maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_m * stride_accu, mask=offs_m < 
num_tokens, other=1e-6) - accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_m * stride_d_entropy, mask=offs_m < num_tokens, other=0.0) - if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_m * stride_d_logprobs, mask=offs_m < num_tokens, other=0.0) - elif reduction == 1: - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens, other=0.0) - labels = tl.load(labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=0) - - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_n[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_m[:, None] < num_tokens), - other=0.0, - ) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_n[:, None] < vocab_size), - other=0.0, - ) - - logits = tl.dot(_hidden, _weight.trans(), logits) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - - logits *= rcp_temperature - - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_n + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - d_logits *= rcp_temperature - - hidden_ptrs = hidden_ptr + (offs_m[:, None] * stride_hidden_m + result_offs_k[None, :] * stride_hidden_k) - _hidden = tl.load( - hidden_ptrs, mask=(result_offs_k[None, :] < hidden_size) & (offs_m[:, None] < num_tokens), other=0.0 - ) - d_weight = tl.dot(d_logits.to(d_weight_ptr.dtype.element_ty).trans(), _hidden, d_weight) - - # write back - tl.store( - d_weight_ptr + offs_n[:, None] * stride_d_weight_n + result_offs_k[None, :] * stride_d_weight_k, - d_weight, - mask=(offs_n[:, None] < vocab_size) & (result_offs_k[None, :] < hidden_size), - ) - - -# NOTE: split tile from d_logits' perspective -@triton.autotune( - configs=[ - triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, - num_warps=8, - ), - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_general_d_logits( - num_tokens: int, - hidden_size: int, - vocab_size: int, - rank: int, - hidden_ptr, - stride_hidden_m: tl.int64, - stride_hidden_k: tl.int64, - weight_ptr, - stride_weight_n: tl.int64, - stride_weight_k: tl.int64, - labels_ptr, - stride_labels: tl.int64, - maximum_ptr, - stride_maximum: tl.int64, - accu_ptr, - stride_accu: tl.int64, - d_entropy_ptr, - stride_d_entropy: tl.int64, - d_logprobs_ptr, - stride_d_logprobs: tl.int64, - reduction: int, - entropy_b_ptr, - stride_entropy_b: tl.int64, - d_logits_ptr, - stride_d_logits_m: tl.int64, - stride_d_logits_n: tl.int64, - rcp_temperature: tl.float32, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - USE_TMA: 
tl.constexpr, -): - """ - backward d_logits - """ - # block swizzling - # pid = tl.program_id(axis=0) - # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - # pid_m = pid % num_pid_m - # pid_n = pid // num_pid_m - - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - start_offs_am = pid_m * BLOCK_SIZE_M - offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) - start_offs_bn = pid_n * BLOCK_SIZE_N - offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) - offs_k = tl.arange(0, BLOCK_SIZE_K) - - maximum_ptrs = maximum_ptr + offs_am * stride_maximum - maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) - accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero - accu_rcp = tl.fdiv(1.0, accu) - - d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy - d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) - if reduction == 0: # none - d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs - d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) - elif reduction == 1: # sum - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: # mean - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_logprobs = -1 * d_logprobs - - entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b - entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) - - labels_ptrs = labels_ptr + offs_am * stride_labels - labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) - - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - if USE_TMA: - # using TMA and device-side descriptor creation - hidden_desc = tl.make_tensor_descriptor( - hidden_ptr, - shape=[num_tokens, hidden_size], - strides=[stride_hidden_m, 1], - block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], - ) - weight_desc = tl.make_tensor_descriptor( - weight_ptr, - shape=[vocab_size, hidden_size], - strides=[stride_weight_n, 1], - block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], - ) - else: - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) - - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - if USE_TMA: - start_offs_k = k * BLOCK_SIZE_K - _hidden = hidden_desc.load([start_offs_am, start_offs_k]) - _weight = weight_desc.load([start_offs_bn, start_offs_k]) - else: - _hidden = tl.load( - hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0, - ) - _weight = tl.load( - weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_size), - other=0.0, - ) - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - logits = tl.dot(_hidden, _weight.T, logits) - - if not USE_TMA: - 
hidden_ptrs -= hidden_size * stride_hidden_k - weight_ptrs -= hidden_size * stride_weight_k - - # scale logits by temperature - logits *= rcp_temperature - - exp_logits = tl.exp(logits - maximum[:, None]) - - mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] - d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) - - # scale d_logits by temperature - d_logits *= rcp_temperature - - # store d_logits - d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n - tl.store( - d_logits_ptrs, - d_logits, # will be implicitly converted to d_logits_ptrs.dtype.element_ty - mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size), - ) - - @triton.autotune( configs=[ triton.Config( @@ -1552,10 +895,10 @@ def efficient_entropy_backward( maximum: torch.Tensor, acc: torch.Tensor, entropy_b: torch.Tensor, - reduction: typing.Optional[int] = 2, + reduction: int | None = 2, should_return_fp32_grad: bool = False, - temperature: typing.Optional[float] = 1.0, - dist_process_group: typing.Optional[dist.ProcessGroup] = None, + temperature: float | None = 1.0, + dist_process_group: dist.ProcessGroup | None = None, ) -> list[torch.Tensor]: """ backward host function @@ -1586,13 +929,9 @@ def efficient_entropy_backward( assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device assert dentropy.shape == (num_tokens,) - d_hidden, d_weight = None, None - if _config._backward == BackwardEnum._Total_Fuse_MN or should_return_fp32_grad: - d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) - d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) - else: - d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device) - d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) + grad_dtype = torch.float32 if should_return_fp32_grad else hidden.dtype + d_hidden = torch.empty_like(hidden, dtype=grad_dtype, device=hidden.device) + d_weight = torch.empty_like(weight, dtype=grad_dtype, device=weight.device) assert d_hidden.is_contiguous() and d_weight.is_contiguous() assert maximum.is_contiguous() and acc.is_contiguous() @@ -1600,22 +939,26 @@ def efficient_entropy_backward( assert maximum.shape == labels.shape == acc.shape assert maximum.is_cuda and acc.is_cuda - vocab_per_split = 1024 + assert entropy_b.is_contiguous() and entropy_b.is_cuda + assert entropy_b.shape == (num_tokens,) + + vocab_per_split = 9504 assert vocab_per_split % 128 == 0 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - assert entropy_b.is_contiguous() and entropy_b.is_cuda - assert entropy_b.shape == (num_tokens,) + _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() + assert _d_logits.is_contiguous() - if _config._backward == BackwardEnum._Total_Fuse_MN: - # --- Triton doesn't materialize d_logits at all. Split tiles at the perspective of d_logits. 
- def mainloop_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]),) - efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( + for split_idx in range(num_splits): + efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( + split_idx, num_tokens, hidden_size, vocab_size, + vocab_per_split, _rank, hidden, hidden.stride(0), @@ -1636,122 +979,22 @@ def mainloop_grid(meta): REDUCTION, entropy_b, entropy_b.stride(0), - d_hidden, - d_hidden.stride(0), - d_hidden.stride(1), - d_weight, - d_weight.stride(0), - d_weight.stride(1), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), 1.0 / temperature, USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, ) - elif _config._backward == BackwardEnum._Total_Separate: - _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype).contiguous() - assert _d_logits.is_contiguous() - - if _config._use_triton: - - def d_logits_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) - - efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( - num_tokens, - hidden_size, - vocab_size, - _rank, - hidden, - hidden.stride(0), - hidden.stride(1), - weight, - weight.stride(0), - weight.stride(1), - labels, - labels.stride(0), - maximum, - maximum.stride(0), - acc, - acc.stride(0), - dentropy, - dentropy.stride(0), - dlogprobs, - dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, - REDUCTION, - entropy_b, - entropy_b.stride(0), - _d_logits, - _d_logits.stride(0), - _d_logits.stride(1), - 1.0 / temperature, - USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, - ) + split_start = split_idx * vocab_per_split + split_end = min(split_start + vocab_per_split, vocab_size) + current_d_logits = _d_logits[:, : split_end - split_start] + current_weight = weight[split_start:split_end, :] + current_d_weight = d_weight[split_start:split_end, :] - torch.matmul(_d_logits, weight, out=d_hidden) - torch.matmul(_d_logits.T, hidden, out=d_weight) + if split_idx == 0: + torch.matmul(current_d_logits, current_weight, out=d_hidden) else: - raise AssertionError("Triton is required for efficient entropy kernel") - - elif _config._backward == BackwardEnum._Split_Dlogits_N: - vocab_per_split = 9504 - num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - - _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() - assert _d_logits.is_contiguous() - - def d_logits_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]),) - - for split_idx in range(num_splits): - efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( - split_idx, - num_tokens, - hidden_size, - vocab_size, - vocab_per_split, - _rank, - hidden, - hidden.stride(0), - hidden.stride(1), - weight, - weight.stride(0), - weight.stride(1), - labels, - labels.stride(0), - maximum, - maximum.stride(0), - acc, - acc.stride(0), - dentropy, - dentropy.stride(0), - dlogprobs, - dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, - REDUCTION, - entropy_b, - entropy_b.stride(0), - _d_logits, - _d_logits.stride(0), - _d_logits.stride(1), - 1.0 / temperature, - 
USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, - ) - - if split_idx == (num_splits - 1): - vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) - split_idx * vocab_per_split - _d_logits = _d_logits[:, :vocab_right_bound].contiguous() - - if split_idx == 0: - torch.matmul( - _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :], out=d_hidden - ) - else: - d_hidden += torch.matmul( - _d_logits, weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] - ) - torch.matmul( - _d_logits.T, hidden, out=d_weight[split_idx * vocab_per_split : (split_idx + 1) * vocab_per_split, :] - ) - - elif _config._backward == BackwardEnum._Split_Dlogits_M: - raise NotImplementedError("BackwardEnum._Split_Dlogits_M is not implemented yet") - + d_hidden += torch.matmul(current_d_logits, current_weight) + torch.matmul(current_d_logits.T, hidden, out=current_d_weight) return d_hidden, d_weight From c640e03998c911eb597229c38a050850716585ee Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 11:12:06 +0800 Subject: [PATCH 20/31] fix(kernel): fix code --- areal/utils/kernel/kernels.py | 1 - 1 file changed, 1 deletion(-) diff --git a/areal/utils/kernel/kernels.py b/areal/utils/kernel/kernels.py index 33a638578a..df327c66b0 100644 --- a/areal/utils/kernel/kernels.py +++ b/areal/utils/kernel/kernels.py @@ -943,7 +943,6 @@ def efficient_entropy_backward( assert entropy_b.shape == (num_tokens,) vocab_per_split = 9504 - assert vocab_per_split % 128 == 0 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() From 8b1a2436795f76c245f8a72ac3bbca908ae83ec4 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 11:47:25 +0800 Subject: [PATCH 21/31] refactor(kernel): fix --- areal/utils/kernel/kernels.py | 115 ++++----------------- areal/utils/kernel/linear_cross_entropy.py | 21 ++-- 2 files changed, 29 insertions(+), 107 deletions(-) diff --git a/areal/utils/kernel/kernels.py b/areal/utils/kernel/kernels.py index df327c66b0..c610d49a4d 100644 --- a/areal/utils/kernel/kernels.py +++ b/areal/utils/kernel/kernels.py @@ -127,46 +127,16 @@ def alloc_fn(size: int, alignment: int, stream: int | None): triton.set_allocator(alloc_fn) -class EntropyReductionEnum: - """ - Enum for the reduction method of cross entropy. - """ - - _None = 0 - _Sum = 1 - _Mean = 2 +_REDUCTION_NONE = 0 def get_entropy_reduction_enum_number(reduction: str) -> int: """ - Get the enum number for the reduction method of cross entropy. + Validate the only supported reduction mode and return its kernel code. """ - _enum = EntropyReductionEnum._None if reduction == "none": - _enum = EntropyReductionEnum._None - elif reduction == "sum": - _enum = EntropyReductionEnum._Sum - elif reduction == "mean": - _enum = EntropyReductionEnum._Mean - else: - raise ValueError(f"Invalid reduction: {reduction}") - return _enum - - -def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: - """ - Get the enum for the reduction method of cross entropy. 
- """ - _enum = EntropyReductionEnum._None - if ce_reduction == 0: - _enum = EntropyReductionEnum._None - elif ce_reduction == 1: - _enum = EntropyReductionEnum._Sum - elif ce_reduction == 2: - _enum = EntropyReductionEnum._Mean - else: - raise ValueError(f"Invalid ce_reduction: {ce_reduction}") - return _enum + return _REDUCTION_NONE + raise ValueError(f"Only reduction='none' is supported, got {reduction!r}") _USE_TRITON = True @@ -201,7 +171,6 @@ def efficient_entropy_kernel_general_mainloop( stride_entropy_b_n: tl.int64, global_logprobs_ptr, stride_global_logprobs: tl.int64, - global_logprobs_scalar_ptr, rcp_temperature: tl.float32, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, @@ -219,9 +188,6 @@ def efficient_entropy_kernel_general_mainloop( pid_m = pid % num_pid_m pid_n = pid // num_pid_m - if pid_m == 0 and pid_n == 0: - tl.store(global_logprobs_scalar_ptr, 0.0) - # create pointers for the first blocks of hidden start_offs_am = pid_m * BLOCK_SIZE_M offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) @@ -362,8 +328,6 @@ def efficient_entropy_triton_kernel_epilogue( stride_global_entropy: tl.int64, global_logprobs_ptr, stride_global_logprobs: tl.int64, - global_logprobs_scalar_ptr, - reduction: int, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, ): @@ -420,14 +384,7 @@ def efficient_entropy_triton_kernel_epilogue( global_logprobs = global_max + tl.log(global_accu) - global_logprobs global_logprobs = -1 * global_logprobs - if reduction == 0: - tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) - elif reduction == 1: - global_logprobs_scalar = tl.sum(global_logprobs, axis=0) - tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) - elif reduction == 2: - global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) - tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) @triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) @@ -520,8 +477,7 @@ def efficient_entropy_triton_epilogue_tp_update( stride_entropy_b: tl.int64, entropy_ptr, stride_entropy: tl.int64, - logprobs_scalar_ptr, - reduction: int, + logprobs_out_ptr, BLOCK_SIZE_M: tl.constexpr, ): pid_m = tl.program_id(axis=0) @@ -542,14 +498,7 @@ def efficient_entropy_triton_epilogue_tp_update( logprobs = maximum + tl.log(accumulate) - logprobs logprobs = -1 * logprobs - if reduction == 0: - tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) - elif reduction == 1: - logprobs_scalar = tl.sum(logprobs, axis=0) - tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) - elif reduction == 2: - logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32) - tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + tl.store(logprobs_out_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) _dedicated_stream, _dedicated_events = None, None @@ -559,7 +508,7 @@ def efficient_entropy_forward( hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, - reduction: int | None = 2, + reduction: int | None = _REDUCTION_NONE, temperature: float | None = 1.0, dist_process_group: dist.ProcessGroup | None = None, ) -> list[torch.Tensor]: @@ -587,17 +536,12 @@ def efficient_entropy_forward( vocab_size, hidden_size = weight.shape assert hidden_size % 128 == 0 - REDUCTION = get_entropy_reduction_enum(reduction) - - if REDUCTION == EntropyReductionEnum._None: - if dist_process_group is 
None: - logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) - else: - logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) - elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): - logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) - else: + if reduction != _REDUCTION_NONE: raise ValueError(f"Invalid reduction: {reduction}") + if dist_process_group is None: + logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + else: + logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) assert logprobs.is_contiguous() and entropy.is_contiguous() @@ -617,10 +561,7 @@ def efficient_entropy_forward( _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - if REDUCTION == EntropyReductionEnum._None: - _logprobs = logprobs - else: - _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + _logprobs = logprobs assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda @@ -654,7 +595,6 @@ def mainloop_grid(meta): _entropy_b.stride(1), _logprobs, _logprobs.stride(0), - logprobs, 1.0 / temperature, USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, ) @@ -688,8 +628,6 @@ def epilogue_grid(meta): entropy.stride(0), _logprobs, _logprobs.stride(0), - logprobs, - REDUCTION, ) else: # tensor-parallel @@ -742,7 +680,6 @@ def epilogue_grid(meta): entropy, entropy.stride(0), logprobs, - REDUCTION, ) return (logprobs, entropy, maximum, accumulate, entropy_b) @@ -782,7 +719,6 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( stride_d_entropy: tl.int64, d_logprobs_ptr, stride_d_logprobs: tl.int64, - reduction: int, entropy_b_ptr, stride_entropy_b: tl.int64, d_logits_ptr, @@ -815,14 +751,7 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6) accu_rcp = tl.fdiv(1.0, accu) d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0) - if reduction == 0: - d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) - elif reduction == 1: - d_logprobs = tl.load(d_logprobs_ptr) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) - d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) d_logprobs = -1 * d_logprobs entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0) labels = tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0) @@ -895,7 +824,7 @@ def efficient_entropy_backward( maximum: torch.Tensor, acc: torch.Tensor, entropy_b: torch.Tensor, - reduction: int | None = 2, + reduction: int | None = _REDUCTION_NONE, should_return_fp32_grad: bool = False, temperature: float | None = 1.0, dist_process_group: dist.ProcessGroup | None = None, @@ -917,12 +846,9 @@ def efficient_entropy_backward( vocab_size, hidden_size = weight.shape assert hidden_size % 128 == 0 - 
REDUCTION = get_entropy_reduction_enum(reduction) - - if REDUCTION == EntropyReductionEnum._None: - assert dlogprobs.shape == (num_tokens,) - else: - assert dlogprobs.dim() == 0 + if reduction != _REDUCTION_NONE: + raise ValueError(f"Invalid reduction: {reduction}") + assert dlogprobs.shape == (num_tokens,) assert dlogprobs.is_contiguous() and dentropy.is_contiguous() assert dlogprobs.is_cuda and dentropy.is_cuda @@ -974,8 +900,7 @@ def d_logits_grid(meta): dentropy, dentropy.stride(0), dlogprobs, - dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, - REDUCTION, + dlogprobs.stride(0), entropy_b, entropy_b.stride(0), _d_logits, diff --git a/areal/utils/kernel/linear_cross_entropy.py b/areal/utils/kernel/linear_cross_entropy.py index 8b0c150332..e000d2d9c5 100644 --- a/areal/utils/kernel/linear_cross_entropy.py +++ b/areal/utils/kernel/linear_cross_entropy.py @@ -13,7 +13,6 @@ from __future__ import annotations import os -import typing import torch import torch.distributed as dist @@ -76,16 +75,15 @@ class LinearCrossEntropy(torch.autograd.Function): labels: integer label ids; either ``(num_tokens,)`` or ``(batch_size, seq_len)``. temperature: softmax temperature; defaults to ``1.0``. - reduction: ``"none"`` returns per-token negative log-likelihood; - ``"sum"`` and ``"mean"`` return scalars. + reduction: only ``"none"`` is supported and returns per-token + negative log-likelihood. dist_process_group: optional tensor-parallel group for vocab-sharded ``weight``. ``labels`` must contain *global* vocab ids on every rank; the kernel handles the per-rank slice internally. Returns: - ``(logprobs, entropy)`` where ``entropy`` has shape ``(num_tokens,)`` - and ``logprobs`` has shape ``(num_tokens,)`` for ``reduction="none"`` - or ``()`` otherwise. + ``(logprobs, entropy)`` where both tensors have shape + ``(num_tokens,)``. """ @staticmethod @@ -94,9 +92,9 @@ def forward( hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, - temperature: typing.Optional[float] = 1.0, - reduction: typing.Optional[str] = "none", - dist_process_group: typing.Optional[dist.ProcessGroup] = None, + temperature: float | None = 1.0, + reduction: str | None = "none", + dist_process_group: dist.ProcessGroup | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: if not isinstance(temperature, float): temperature = float(temperature) @@ -265,12 +263,11 @@ def linear_cross_entropy( labels: torch.Tensor, temperature: float = 1.0, reduction: str = "none", - dist_process_group: typing.Optional[dist.ProcessGroup] = None, + dist_process_group: dist.ProcessGroup | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Functional wrapper around :class:`LinearCrossEntropy`. - Returns ``(logprobs, entropy)`` with shapes following ``reduction`` - semantics. See the class docstring for full argument descriptions. + Returns per-token ``(logprobs, entropy)``. 
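+
+    Example (an illustrative sketch; tensor shapes, dtypes and variable
+    names are assumed here, not taken from a specific call site)::
+
+        # hidden: (num_tokens, hidden_size), weight: (vocab_size, hidden_size),
+        # labels: (num_tokens,) integer ids; all three must be contiguous CUDA tensors.
+        logprobs, entropy = linear_cross_entropy(hidden, weight, labels, temperature=1.0)
+        assert logprobs.shape == entropy.shape == (hidden.shape[0],)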
""" return LinearCrossEntropy.apply( hidden, weight, labels, temperature, reduction, dist_process_group From a6952f3bebd936e6fb420fdd0b529fa1c54eaac9 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 11:59:11 +0800 Subject: [PATCH 22/31] refactor(kernel): remove useless --- areal/utils/kernel/linear_cross_entropy.py | 285 ++++++++------------- 1 file changed, 106 insertions(+), 179 deletions(-) diff --git a/areal/utils/kernel/linear_cross_entropy.py b/areal/utils/kernel/linear_cross_entropy.py index e000d2d9c5..b839affbd9 100644 --- a/areal/utils/kernel/linear_cross_entropy.py +++ b/areal/utils/kernel/linear_cross_entropy.py @@ -12,54 +12,9 @@ from __future__ import annotations -import os - import torch import torch.distributed as dist -from areal.utils import logging - -logger = logging.getLogger("LinearCrossEntropy") - - -def _debug_enabled() -> bool: - """Whether to emit per-tensor debug summaries. - - Toggled via ``AREAL_LCE_DEBUG=1`` to avoid GPU-CPU sync in hot paths. - """ - return os.environ.get("AREAL_LCE_DEBUG", "0") == "1" - - -def _summarize(name: str, tensor: torch.Tensor) -> None: - """Emit a tiny statistical summary for diff-driven debugging. - - Triggers a CPU-GPU sync via ``.item()``; only call when - ``_debug_enabled()`` is true. - """ - if not tensor.is_floating_point(): - logger.debug( - "[diff] %s shape=%s dtype=%s device=%s", - name, - tuple(tensor.shape), - tensor.dtype, - tensor.device, - ) - return - flat = tensor.detach().float().reshape(-1) - if flat.numel() == 0: - logger.debug("[diff] %s is empty", name) - return - logger.debug( - "[diff] %s shape=%s dtype=%s mean=%.6e std=%.6e min=%.6e max=%.6e", - name, - tuple(tensor.shape), - tensor.dtype, - flat.mean().item(), - flat.std(unbiased=False).item() if flat.numel() > 1 else 0.0, - flat.min().item(), - flat.max().item(), - ) - class LinearCrossEntropy(torch.autograd.Function): """Fused linear + cross-entropy / token-entropy autograd Function. @@ -101,70 +56,51 @@ def forward( if not isinstance(reduction, str): raise TypeError(f"reduction must be str, got {type(reduction)}") - with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): - # Local import keeps Triton dependency lazy: tests can still - # import this module on machines without Triton. - from areal.utils.kernel import kernels + # Local import keeps Triton dependency lazy: tests can still + # import this module on machines without Triton. + from areal.utils.kernel import kernels - REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) + REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) - original_hidden_shape = hidden.shape - if hidden.dim() != 2: - hidden = hidden.reshape(-1, hidden.shape[-1]) - if labels.dim() != 1: - labels = labels.reshape(-1) + original_hidden_shape = hidden.shape + if hidden.dim() != 2: + hidden = hidden.reshape(-1, hidden.shape[-1]) + if labels.dim() != 1: + labels = labels.reshape(-1) - # Triton kernels demand contiguous CUDA tensors; bail out loudly - # on misuse rather than silently materialising copies on a hot - # path. 
- assert hidden.is_cuda and weight.is_cuda and labels.is_cuda, ( - "LinearCrossEntropy requires CUDA inputs" - ) - assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous(), ( - "LinearCrossEntropy requires contiguous tensors" - ) - - if _debug_enabled(): - _summarize("forward.hidden", hidden) - _summarize("forward.weight", weight) - _summarize("forward.labels", labels) - logger.debug( - "[diff] forward.meta temperature=%.6f reduction=%s tp_world=%d", - temperature, - reduction, - 1 if dist_process_group is None else dist.get_world_size(dist_process_group), - ) - - ( - logprobs, - entropy, - _maximum, - _accumulate, - _entropy_b, - ) = kernels.efficient_entropy_forward( - hidden, - weight, - labels, - REDUCTION, - temperature, - dist_process_group, - ) + # Triton kernels demand contiguous CUDA tensors; bail out loudly + # on misuse rather than silently materialising copies on a hot + # path. + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda, ( + "LinearCrossEntropy requires CUDA inputs" + ) + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous(), ( + "LinearCrossEntropy requires contiguous tensors" + ) - if _debug_enabled(): - _summarize("forward.logprobs", logprobs) - _summarize("forward.entropy", entropy) - _summarize("forward._maximum", _maximum) - _summarize("forward._accumulate", _accumulate) - _summarize("forward._entropy_b", _entropy_b) + ( + logprobs, + entropy, + _maximum, + _accumulate, + _entropy_b, + ) = kernels.efficient_entropy_forward( + hidden, + weight, + labels, + REDUCTION, + temperature, + dist_process_group, + ) - ctx.save_for_backward( - hidden, weight, labels, _maximum, _accumulate, _entropy_b - ) - ctx.original_hidden_shape = original_hidden_shape - ctx.REDUCTION = REDUCTION - ctx.dist_process_group = dist_process_group - ctx.should_return_fp32_grad = False - ctx.temperature = temperature + ctx.save_for_backward( + hidden, weight, labels, _maximum, _accumulate, _entropy_b + ) + ctx.original_hidden_shape = original_hidden_shape + ctx.REDUCTION = REDUCTION + ctx.dist_process_group = dist_process_group + ctx.should_return_fp32_grad = False + ctx.temperature = temperature return logprobs, entropy @@ -174,84 +110,75 @@ def backward( dlogprobs: torch.Tensor, dentropy: torch.Tensor, ) -> tuple: - with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): - from areal.utils.kernel import kernels - - ( - hidden, - weight, - labels, - _maximum, - _accumulate, - _entropy_b, - ) = ctx.saved_tensors - - if _debug_enabled(): - _summarize("backward.dlogprobs", dlogprobs) - _summarize("backward.dentropy", dentropy) - - # PyTorch autograd may produce non-contiguous gradient tensors - # (e.g. expanded views from broadcast). Triton kernels require - # contiguous inputs, so ensure contiguity before dispatching. - dlogprobs = dlogprobs.contiguous() - dentropy = dentropy.contiguous() + from areal.utils.kernel import kernels + + ( + hidden, + weight, + labels, + _maximum, + _accumulate, + _entropy_b, + ) = ctx.saved_tensors + + # PyTorch autograd may produce non-contiguous gradient tensors + # (e.g. expanded views from broadcast). Triton kernels require + # contiguous inputs, so ensure contiguity before dispatching. 
+ dlogprobs = dlogprobs.contiguous() + dentropy = dentropy.contiguous() + + d_hidden, d_weight = kernels.efficient_entropy_backward( + dlogprobs, + dentropy, + hidden, + weight, + labels, + _maximum, + _accumulate, + _entropy_b, + ctx.REDUCTION, + ctx.should_return_fp32_grad, + ctx.temperature, + ctx.dist_process_group, + ) - d_hidden, d_weight = kernels.efficient_entropy_backward( - dlogprobs, - dentropy, - hidden, - weight, - labels, - _maximum, - _accumulate, - _entropy_b, - ctx.REDUCTION, - ctx.should_return_fp32_grad, - ctx.temperature, - ctx.dist_process_group, + # TP all-reduce on d_hidden. + # + # Why this is required: + # ``efficient_entropy_backward`` computes a *local* contribution + # ``d_hidden_local = d_logits_local @ weight_local`` where each TP + # rank holds only a vocab-shard of ``weight``. The mathematically + # correct gradient is the sum across the TP group: + # d_hidden = sum_over_tp_ranks(d_logits_local @ weight_local). + # In Megatron's normal forward, the surrounding + # ``ColumnParallelLinear`` (output_layer) inserts this all-reduce + # via ``linear_with_grad_accumulation_and_async_allreduce``. The + # fused-LCE fast path monkey-patches ``output_layer.forward`` to + # return ``(hidden, None)`` (an autograd identity), which bypasses + # mcore's machinery — so the all-reduce vanishes unless we + # reproduce it here. + # + # Without this reduction, TP > 1 silently produces gradients that + # equal each rank's local partial, leading to incorrect training + # that is *not* caught by any forward-only invariant since the + # forward kernel already all-reduces (max / logsumexp / entropy + # auxiliaries) inside ``efficient_entropy_forward``. + # + # ``d_weight`` does NOT need an all-reduce: each rank legitimately + # owns its vocab slice's weights, so the gradient on the local + # weight shard is correctly local-only — exactly mirroring how + # mcore handles ColumnParallel weight grads. + if ( + ctx.dist_process_group is not None + and dist.get_world_size(ctx.dist_process_group) > 1 + ): + dist.all_reduce( + d_hidden, + op=dist.ReduceOp.SUM, + group=ctx.dist_process_group, ) - # TP all-reduce on d_hidden. - # - # Why this is required: - # ``efficient_entropy_backward`` computes a *local* contribution - # ``d_hidden_local = d_logits_local @ weight_local`` where each TP - # rank holds only a vocab-shard of ``weight``. The mathematically - # correct gradient is the sum across the TP group: - # d_hidden = sum_over_tp_ranks(d_logits_local @ weight_local). - # In Megatron's normal forward, the surrounding - # ``ColumnParallelLinear`` (output_layer) inserts this all-reduce - # via ``linear_with_grad_accumulation_and_async_allreduce``. The - # fused-LCE fast path monkey-patches ``output_layer.forward`` to - # return ``(hidden, None)`` (an autograd identity), which bypasses - # mcore's machinery — so the all-reduce vanishes unless we - # reproduce it here. - # - # Without this reduction, TP > 1 silently produces gradients that - # equal each rank's local partial, leading to incorrect training - # that is *not* caught by any forward-only invariant since the - # forward kernel already all-reduces (max / logsumexp / entropy - # auxiliaries) inside ``efficient_entropy_forward``. - # - # ``d_weight`` does NOT need an all-reduce: each rank legitimately - # owns its vocab slice's weights, so the gradient on the local - # weight shard is correctly local-only — exactly mirroring how - # mcore handles ColumnParallel weight grads. 
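+        # Concretely (illustrative numbers): with TP world size 2, ``weight``
+        # is vocab-sharded row-wise into W_0 and W_1 of shape (vocab/2, hidden),
+        # and the full gradient is
+        #     d_hidden = d_logits_0 @ W_0 + d_logits_1 @ W_1,
+        # while the kernel call above produced only this rank's term; hence
+        # the SUM all-reduce below.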
- if ( - ctx.dist_process_group is not None - and dist.get_world_size(ctx.dist_process_group) > 1 - ): - dist.all_reduce( - d_hidden, - op=dist.ReduceOp.SUM, - group=ctx.dist_process_group, - ) - - d_hidden = d_hidden.view(ctx.original_hidden_shape) - - if _debug_enabled(): - _summarize("backward.d_hidden", d_hidden) - _summarize("backward.d_weight", d_weight) + d_hidden = d_hidden.view(ctx.original_hidden_shape) # Order matches forward: hidden, weight, labels, temperature, reduction, group return d_hidden, d_weight, None, None, None, None From 63463966b502830cab151dd3d53e57ec71d7a9d0 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 12:17:11 +0800 Subject: [PATCH 23/31] refactor(kernel): comment --- areal/engine/megatron_engine.py | 45 +---- .../megatron_utils/fused_lce_capture.py | 167 ++---------------- .../utils/functional/linear_cross_entropy.py | 79 +++------ areal/utils/kernel/kernels.py | 12 -- areal/utils/kernel/linear_cross_entropy.py | 80 ++------- 5 files changed, 53 insertions(+), 330 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index eb3a450f78..58c0931a73 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -720,11 +720,6 @@ def forward_backward_batch( ) -> None: self._ensure_ready() - # Resolve once per call: whether the fused linear-cross-entropy path - # should engage. We engage it only on the pipeline-last stage, in - # non-critic mode, and outside the tree-training branch (those branches - # bring their own gather kernels and additional invariants we do not - # currently extend). use_fused_lce = ( getattr(self.config, "use_fused_linear_ce", False) and not self.config.is_critic @@ -761,10 +756,6 @@ def forward_step(batch_iter, model): cp_size = mpu.get_context_parallel_world_size() cp_local = cp_size > 1 - # Engage fused linear-cross-entropy capture only on the pipeline - # last stage; the LM head only exists there. CP-local logit-gather - # path keeps the standard materialised-logits route because the - # split-and-gather machinery operates on the [seq, vocab] tensor. model_vp_stage_for_capture = getattr(model, "vp_stage", 0) should_capture = ( use_fused_lce @@ -783,9 +774,6 @@ def forward_step(batch_iter, model): gather_cp_output=not cp_local, ) - # Stash captured hidden + LM-head weight on the orig_mb dict so - # the downstream loss/forward callbacks can pick them up via the - # standard `inputs` argument (which is the *same* dict). if ( capture is not None and capture.hidden is not None @@ -833,33 +821,7 @@ def _process_output(input_, output_): cu_seqlens=cu_seqlens, old_cu_seqlens=mb_input.old_cu_seqlens, ) - # When fused-LCE capture is active, the model's - # ``output`` is actually the pre-projection hidden - # state (the LM-head was monkey-patched to a no-op), - # so the unpadded tensor is the hidden we want to - # feed into the fused kernel. - # - # Megatron's outer ``Float16Module`` wraps the inner - # ``GPTModel`` and, on the last pipeline stage, - # *upcasts the wrapped module's outputs to fp32* - # (see ``Float16Module.forward(..., fp32_output=True)`` - # and ``float16_to_fp32`` in - # ``megatron.core.transformer.module``). The captured - # hidden was already cast to ``weight.dtype`` (bf16/fp16) - # inside ``capture_lm_head_hidden._patched_forward`` to - # mirror ``ColumnParallelLinear``'s implicit downcast, - # but ``Float16Module``'s post-hoc upcast then re-promotes - # the tensor returned to mcore back to fp32. 
The Triton - # GEMM in ``efficient_entropy_forward`` requires both - # operands to share the same dtype; without re-aligning - # here, the kernel raises - # "Both operands must be same dtype. Got fp32 and - # bf16; falling back to reference path." - # and silently disables the fused fast path. - # - # ``Tensor.to(dtype)`` is autograd-aware; backward will - # auto-upcast gradients to the upstream fp32 dtype, which - # is exactly what mcore would have produced anyway. + # Re-align Float16Module's fp32 hidden to lm-head weight dtype. if mb_input.orig_mb.get("_fused_lce_active", False): fused_weight = mb_input.orig_mb.get( FUSED_LCE_WEIGHT_KEY @@ -1919,11 +1881,6 @@ def _compute_logprobs_and_loss( if mpu.get_tensor_model_parallel_world_size() > 1 else None, ) - # vocab_min/max_logits are diagnostics consumed by the - # clip-ratio statistics inside the PPO loss; the fused - # kernel never materialises the [seq, vocab] tensor, so - # we substitute finite proxies derived from the chosen - # logprobs (cheap and never stalls training). proxy = logprobs.detach().float() vocab_min_logits = proxy vocab_max_logits = proxy diff --git a/areal/engine/megatron_utils/fused_lce_capture.py b/areal/engine/megatron_utils/fused_lce_capture.py index a1f3899800..55c124f252 100644 --- a/areal/engine/megatron_utils/fused_lce_capture.py +++ b/areal/engine/megatron_utils/fused_lce_capture.py @@ -2,54 +2,20 @@ """ LM-head hidden-state capture for the fused linear-cross-entropy fast path. -The fused :func:`areal.utils.kernel.linear_cross_entropy` kernel needs the -pre-projection hidden state (``[seq, hidden]``) and the LM-head weight -(``[vocab, hidden]``, possibly vocab-sharded along the TP group) instead of -the materialised ``[seq, vocab]`` logits tensor. The Megatron-Core -:class:`GPTModel` does not expose either of these to AReaL's -``_compute_logprobs_and_loss`` call site by default, so we install a -temporary monkey-patch on ``output_layer.forward`` for the duration of one -microbatch forward pass: +The fused LCE kernel needs ``(hidden, weight)`` instead of materialised +``[seq, vocab]`` logits. This module temporarily monkey-patches +``output_layer.forward`` to capture those tensors for one microbatch. -1. Stashes the input tensor (``hidden``) and the actual weight (either the - ``output_layer``'s own weight, or the embedding-tied weight passed in via - ``weight=``). When sequence-parallel is active (TP > 1 in AReaL), the - incoming ``input_`` is scattered along seq to ``[seq/tp_size, hidden]``, - so we first call ``gather_from_sequence_parallel_region`` to restore the - full ``[seq, hidden]`` tensor — exactly mirroring the first step of - mcore's ``ColumnParallelLinear.forward``. -2. Returns ``(hidden, None)`` instead of ``(logits, bias)``. Because - :func:`areal.utils.data.unpad_logits` and - :func:`postprocess_packed_seqs_context_parallel` are shape-agnostic on - the leading sequence dim and propagate ``shape[1:]`` verbatim, the - returned hidden tensor flows through the rest of the engine pipeline - without modification — the engine's downstream code on the fused path - never inspects the trailing dim except to take a min/max for diagnostic - purposes, which we override with proxies in - ``MegatronEngine._compute_logprobs_and_loss``. - -The patch is installed only when ``enabled=True`` and uninstalled on -context exit (including on exception), so error-path leaks of the patched -method are impossible. 
- -Compatibility notes: - -* The patch is incompatible with Megatron-Core's MuP logit scaling - (``config.use_mup``), MTP (``config.mtp_num_layers > 0``) and inference - paths that materialise ``last_token_logits``. We assert against these - configurations at install time and refuse to engage; the engine then - falls back to the materialised path automatically. -* The patch is also incompatible with the critic value head, since that - head is a 1-output-dim ``ColumnParallelLinear`` and the fused kernel - requires the LM-head weight; the engine guards on ``is_critic`` before - calling this helper. +Compatibility: incompatible with MuP (``use_mup``), MTP +(``mtp_num_layers > 0``), and critic heads. The engine falls back to +the materialised path automatically when any of these conditions hold. """ from __future__ import annotations +from collections.abc import Iterator from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Iterator, Optional import torch from megatron.core import parallel_state as mpu @@ -61,32 +27,18 @@ logger = logging.getLogger("FusedLCECapture") -# Keys used to pass captured tensors from forward_step → process_output. -# Centralised here to keep the engine and helper in lockstep. FUSED_LCE_HIDDEN_KEY = "_fused_lce_hidden" FUSED_LCE_WEIGHT_KEY = "_fused_lce_weight" @dataclass class _CaptureSlot: - """Mutable single-shot stash populated by the patched ``forward``.""" + hidden: torch.Tensor | None = None + weight: torch.Tensor | None = None - hidden: Optional[torch.Tensor] = None - weight: Optional[torch.Tensor] = None - -def _unwrap_to_post_process_module(model: torch.nn.Module) -> Optional[torch.nn.Module]: - """Strip DDP/Float16Module wrappers and return the inner module that - owns ``output_layer`` (i.e. an mcore ``GPTModel`` on the last PP stage), - or ``None`` if no such module is reachable on this rank. - - Returning ``None`` (instead of raising) lets the caller skip the patch - transparently on intermediate pipeline stages. - """ +def _unwrap_to_post_process_module(model: torch.nn.Module) -> torch.nn.Module | None: inner = model - # Loop bound: at most ~4 wrapper layers in practice (DDP, Float16Module, - # vp wrapper). 8 is a generous upper bound that protects against - # accidental cycles. for _ in range(8): if hasattr(inner, "output_layer") and inner.output_layer is not None: return inner @@ -97,24 +49,18 @@ def _unwrap_to_post_process_module(model: torch.nn.Module) -> Optional[torch.nn. def _is_compatible(post_process_module: torch.nn.Module) -> bool: - """Refuse to engage when the model uses features incompatible with the - fused kernel. Falling back is preferred over silently producing wrong - numbers.""" config = getattr(post_process_module, "config", None) if config is None: - # Conservative default: don't patch unknown modules. return False if getattr(config, "use_mup", False): logger.warning( - "Fused LCE: MuP scaling is enabled (config.use_mup=True); " - "fused path is disabled for this microbatch." + "Fused LCE disabled: MuP scaling is enabled (config.use_mup=True)." ) return False if getattr(config, "mtp_num_layers", 0): logger.warning( - "Fused LCE: MTP is enabled (config.mtp_num_layers>0); " - "fused path is disabled for this microbatch." + "Fused LCE disabled: MTP is enabled (config.mtp_num_layers>0)." 
) return False @@ -122,18 +68,11 @@ def _is_compatible(post_process_module: torch.nn.Module) -> bool: if output_layer is None: return False - # Sequence parallel + TP gather inside output_layer is what we *want* - # to bypass; AReaL runs with parallel_output=True which keeps logits - # vocab-sharded — exactly what the fused kernel expects via tp_group. parallel_output = getattr(post_process_module, "parallel_output", True) if not parallel_output: - # If gather_output=True, the engine has already requested the - # full-vocab logits to be all-gathered; capturing hidden here would - # mean the downstream kernel needs to gather instead, doubling - # comms. Prefer the existing materialised path in that case. logger.warning( - "Fused LCE: model has parallel_output=False; fused path is " - "disabled to avoid an extra TP gather." + "Fused LCE disabled: model has parallel_output=False; " + "would require an extra TP gather." ) return False @@ -143,24 +82,13 @@ def _is_compatible(post_process_module: torch.nn.Module) -> bool: @contextmanager def capture_lm_head_hidden( model: torch.nn.Module, *, enabled: bool -) -> Iterator[Optional[_CaptureSlot]]: - """Context manager that captures the input + weight handed to the - ``output_layer`` of the wrapped Megatron GPT model. - - Yields: - ``_CaptureSlot`` on the pipeline-last stage when ``enabled`` is - True and the model is compatible; ``None`` otherwise. The caller is - expected to inspect ``slot.hidden`` for ``None`` to decide whether - the fused path is usable for this microbatch. - """ +) -> Iterator[_CaptureSlot | None]: if not enabled: yield None return post_process = _unwrap_to_post_process_module(model) if post_process is None or not _is_compatible(post_process): - # Either an intermediate PP stage or an incompatible config; the - # engine will transparently fall back to the materialised path. yield None return @@ -168,97 +96,32 @@ def capture_lm_head_hidden( slot = _CaptureSlot() original_forward = output_layer.forward - # Detect sequence-parallel mode. In mcore, when TP > 1 AReaL enables - # ``sequence_parallel=True`` (see ``MegatronEngine._make_parallel_strategy``), - # which means the input handed to ``ColumnParallelLinear.forward`` is - # *scattered* along the sequence dimension to shape ``[seq/tp_size, hidden]``. - # The original ``ColumnParallelLinear.forward`` first calls - # ``gather_from_sequence_parallel_region`` to restore the full ``[seq, hidden]`` - # tensor before doing the matmul. Our identity-style patch must replicate - # that gather, otherwise: - # * the captured ``hidden`` is only this rank's sequence shard, leading - # to wrong fused-kernel inputs and wrong logprobs; - # * the tensor returned to mcore (which then flows through - # ``postprocess_packed_seqs_context_parallel`` and ``unpad_logits``) has - # dim-0 = seq/tp_size, which mismatches ``cu_seqlens`` / ``old_cu_seqlens`` - # and crashes with shape errors like "expanded size (X) must match - # existing size (X/tp_size) at non-singleton dimension 0". config = getattr(post_process, "config", None) sequence_parallel = bool(getattr(config, "sequence_parallel", False)) tp_world_size = mpu.get_tensor_model_parallel_world_size() needs_sp_gather = sequence_parallel and tp_world_size > 1 def _patched_forward(input_, weight=None, runtime_gather_output=None): - # Resolve the actual weight: either passed in (weight tying) or the - # output_layer's own parameter. We intentionally store a *reference* - # (not detach) so autograd flows through both the kernel forward - # and backward. 
actual_weight = weight if weight is not None else output_layer.weight - # When sequence parallel is on, ``input_`` is shape ``[seq/tp_size, hidden]``. - # Gather along the sequence dim to obtain the full ``[seq, hidden]`` tensor - # — this is exactly what the original ``ColumnParallelLinear.forward`` - # does as its first step. ``gather_from_sequence_parallel_region`` is - # an autograd-aware op (its backward is a reduce-scatter along seq), - # so gradients flow correctly back into the SP-scattered upstream. hidden = input_ if needs_sp_gather: hidden = gather_from_sequence_parallel_region(hidden) - # Align ``hidden`` dtype to ``actual_weight`` dtype before handing the - # tensors to the fused Triton kernel. - # - # Why this is required: - # * Megatron-Core feeds ``output_layer`` with the post-final-layernorm - # activation, which is typically fp32 under mixed-precision training, - # while ``output_layer.weight`` is bf16/fp16. The original - # ``ColumnParallelLinear.forward`` silently downcasts ``input_`` to - # the weight dtype inside - # ``linear_with_grad_accumulation_and_async_allreduce``; our - # identity-style monkey-patch bypasses that path and would otherwise - # hand mismatched dtypes to ``efficient_entropy_forward``. - # * Triton's ``tl.dot`` requires both operands to share the same dtype; - # a mismatch triggers warnings such as - # "Both operands must be same dtype. Got fp32 and bf16; falling back - # to reference path." and silently disables the fused fast path. - # * ``Tensor.to(dtype)`` is autograd-aware: backward auto-upcasts - # gradients to the original dtype, so the upstream fp32 activation - # receives a fp32 grad as expected. if hidden.dtype != actual_weight.dtype: hidden = hidden.to(actual_weight.dtype) slot.hidden = hidden slot.weight = actual_weight - # Return ``(hidden, None)``: callers expect ``(logits, bias)`` and - # only ever destructure with ``logits, _ = output_layer(...)``. The - # downstream pipeline (``unpad_logits`` etc.) is shape-agnostic on - # the trailing dim, so passing ``hidden`` through is safe; the - # fused kernel will then consume the stashed tensors and produce - # the real per-token logprobs. - # - # Crucially we return the *gathered* hidden so that the leading - # sequence dim matches what mcore would have produced for the real - # logits tensor (``[seq, vocab/tp_size]``). This keeps every - # downstream shape invariant intact (CP all-gather, batch-padding - # strip, ``unpad_logits`` cu_seqlens slicing). return hidden, None - # ``output_layer.forward = _patched_forward`` replaces the bound method - # at instance level (via ``__dict__`` lookup), shadowing the class - # method without mutating the class. Restoration in ``finally`` is - # therefore exception-safe. output_layer.forward = _patched_forward # type: ignore[assignment] try: yield slot finally: - # Best-effort restoration. ``del`` removes the instance-level - # binding and re-exposes the class-level method, which is what - # callers will execute on subsequent forwards. try: del output_layer.forward except AttributeError: - # If __dict__ assignment is not supported (rare for nn.Module - # subclasses), fall back to direct restoration. 
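+            # Instance ``__dict__`` deletion unsupported (rare); restore the
+            # saved bound method directly.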
output_layer.forward = original_forward # type: ignore[assignment] diff --git a/areal/utils/functional/linear_cross_entropy.py b/areal/utils/functional/linear_cross_entropy.py index bf91c7b7e0..88239043ef 100644 --- a/areal/utils/functional/linear_cross_entropy.py +++ b/areal/utils/functional/linear_cross_entropy.py @@ -1,28 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 """ -High-level fused linear cross-entropy entry points for AReaL. - -These wrappers bridge the :class:`LinearCrossEntropy` Triton kernel into -AReaL's existing :func:`gather_logprobs_entropy` interface so that the -Megatron path can opt in via a single configuration flag without -restructuring the model forward. - -The wrappers: - -* accept already-flat ``hidden`` of shape ``(num_tokens, hidden_size)`` and - ``labels`` of shape ``(num_tokens,)`` (or higher-dimensional tensors with - an explicit last hidden dim) so the call site looks identical to the - existing materialised path; -* support optional tensor-parallel via ``tp_group`` for vocab-sharded - ``weight`` matrices; -* fall back gracefully to the materialised reference path when Triton is - unavailable or inputs are not on CUDA, so unit tests can still run on CPU. +Fused linear cross-entropy entry points for AReaL. + +These wrappers bridge the fused Triton kernel into AReaL's +:func:`gather_logprobs_entropy` interface so the Megatron path can opt in +via a single config flag. They fall back to the materialised reference +path when Triton is unavailable or inputs are not on CUDA. """ from __future__ import annotations import os -from typing import Optional import torch import torch.distributed as dist @@ -33,12 +21,10 @@ def _force_fallback() -> bool: - """Allow ops/CI to disable the fused kernel via env var without code change.""" return os.environ.get("AREAL_DISABLE_FUSED_LCE", "0") == "1" def _kernel_available() -> bool: - """Whether the Triton fused kernel can run on this host.""" if _force_fallback(): return False if not torch.cuda.is_available(): @@ -55,15 +41,8 @@ def _reference_logprobs_entropy( weight: torch.Tensor, labels: torch.Tensor, temperature: float, - tp_group: Optional[dist.ProcessGroup], + tp_group: dist.ProcessGroup | None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Reference (materialised-logits) implementation. - - Used when Triton is unavailable. Mathematically equivalent to the fused - kernel up to floating-point reordering, which is why the test suite - asserts with explicit rtol/atol rather than bitwise equality. - """ - # Shape normalisation matches the fused kernel. flat_hidden = hidden.reshape(-1, hidden.shape[-1]) flat_labels = labels.reshape(-1) @@ -72,8 +51,6 @@ def _reference_logprobs_entropy( logits = logits / temperature if tp_group is not None and dist.get_world_size(tp_group) > 1: - # Vocab-parallel: gather full vocab logits across TP group. - # Used only as a slow correctness fallback. world_size = dist.get_world_size(tp_group) gathered = [torch.empty_like(logits) for _ in range(world_size)] dist.all_gather(gathered, logits, group=tp_group) @@ -93,22 +70,19 @@ def linear_cross_entropy_logprobs_entropy( weight: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0, - tp_group: Optional[dist.ProcessGroup] = None, + tp_group: dist.ProcessGroup | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute per-token log-prob and entropy from hidden states + lm-head weight. + """Compute per-token log-prob and entropy via the fused kernel. 
- This is the fused counterpart to - :func:`areal.utils.functional.vocab_parallel.gather_logprobs_entropy`, - but consumes ``hidden`` (last layer states) and ``weight`` (lm-head - weight) directly instead of a materialised ``[num_tokens, vocab_size]`` - logits tensor. Memory savings scale with ``vocab_size``. + Falls back to the materialised reference path when the fused kernel is + unavailable. Args: hidden: ``(..., hidden_size)`` last-layer hidden states. - weight: ``(vocab_size, hidden_size)`` lm-head weight; may be sharded - on the vocab dimension when ``tp_group`` is set. - labels: ``(...,)`` integer label ids matching the leading dims of - ``hidden``. With TP, labels MUST hold *global* vocab ids. + weight: ``(vocab_size, hidden_size)`` lm-head weight; may be + vocab-sharded when ``tp_group`` is set. + labels: ``(...,)`` integer label ids. With TP, labels must hold + *global* vocab ids. temperature: softmax temperature. tp_group: optional tensor-parallel group when ``weight`` is sharded. @@ -118,30 +92,22 @@ def linear_cross_entropy_logprobs_entropy( leading_shape = labels.shape if _kernel_available(): - # Lazy import: keeps a hard Triton import out of the module path so - # CPU-only environments can still load areal.utils.functional. from areal.utils.kernel.linear_cross_entropy import linear_cross_entropy if hidden.device.type != "cuda": logger.warning( - "Fused LCE requested but hidden is on %s; falling back to reference path.", + "Fused LCE requested but hidden is on %s; falling back to reference.", hidden.device, ) else: try: logprobs, entropy = linear_cross_entropy( - hidden, - weight, - labels, - temperature, - "none", - tp_group, + hidden, weight, labels, temperature, "none", tp_group, ) return logprobs.reshape(leading_shape), entropy.reshape(leading_shape) - except Exception as exc: # pragma: no cover - fall back path + except Exception as exc: logger.warning( - "Fused LCE kernel raised %s; falling back to reference path.", - exc, + "Fused LCE kernel raised %s; falling back to reference.", exc, ) logprobs, entropy = _reference_logprobs_entropy( @@ -155,12 +121,9 @@ def linear_cross_entropy_logprobs( weight: torch.Tensor, labels: torch.Tensor, temperature: float = 1.0, - tp_group: Optional[dist.ProcessGroup] = None, + tp_group: dist.ProcessGroup | None = None, ) -> torch.Tensor: - """Logprobs-only counterpart of :func:`linear_cross_entropy_logprobs_entropy`. - - Returns a tensor shaped like ``labels``. - """ + """Logprobs-only counterpart of :func:`linear_cross_entropy_logprobs_entropy`.""" logprobs, _ = linear_cross_entropy_logprobs_entropy( hidden, weight, labels, temperature, tp_group ) diff --git a/areal/utils/kernel/kernels.py b/areal/utils/kernel/kernels.py index c610d49a4d..10f2dc08cf 100644 --- a/areal/utils/kernel/kernels.py +++ b/areal/utils/kernel/kernels.py @@ -40,9 +40,6 @@ import torch.distributed as dist -# --- Device helpers ----------------------------------------------------------- -# AReaL relies on torch directly for CUDA device primitives used by the -# Triton kernels below. def _is_cuda_available() -> bool: return torch.cuda.is_available() @@ -131,9 +128,6 @@ def alloc_fn(size: int, alignment: int, stream: int | None): def get_entropy_reduction_enum_number(reduction: str) -> int: - """ - Validate the only supported reduction mode and return its kernel code. 
- """ if reduction == "none": return _REDUCTION_NONE raise ValueError(f"Only reduction='none' is supported, got {reduction!r}") @@ -512,9 +506,6 @@ def efficient_entropy_forward( temperature: float | None = 1.0, dist_process_group: dist.ProcessGroup | None = None, ) -> list[torch.Tensor]: - """ - forward host function - """ assert hidden.is_cuda and weight.is_cuda and labels.is_cuda assert weight.device == hidden.device and labels.device == hidden.device assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 @@ -829,9 +820,6 @@ def efficient_entropy_backward( temperature: float | None = 1.0, dist_process_group: dist.ProcessGroup | None = None, ) -> list[torch.Tensor]: - """ - backward host function - """ assert hidden.is_cuda and weight.is_cuda and labels.is_cuda assert weight.device == hidden.device and labels.device == hidden.device assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 diff --git a/areal/utils/kernel/linear_cross_entropy.py b/areal/utils/kernel/linear_cross_entropy.py index b839affbd9..7bf564d2e5 100644 --- a/areal/utils/kernel/linear_cross_entropy.py +++ b/areal/utils/kernel/linear_cross_entropy.py @@ -1,13 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """ -``LinearCrossEntropy`` autograd Function for AReaL. - -This module exposes a drop-in replacement for the -``logits = hidden @ weight.T`` -> ``log_softmax`` -> per-token -log-probability and entropy pipeline. Internally it dispatches to a Triton -kernel that fuses the matmul with the cross-entropy -reduction so that the ``[num_tokens, vocab_size]`` logits tensor is never -materialized. +Fused linear + cross-entropy autograd Function. + +Dispatches to a Triton kernel that fuses the matmul with cross-entropy +so that the ``[num_tokens, vocab_size]`` logits tensor is never materialised. """ from __future__ import annotations @@ -17,28 +13,19 @@ class LinearCrossEntropy(torch.autograd.Function): - """Fused linear + cross-entropy / token-entropy autograd Function. - - Forward signature: + """Fused linear + cross-entropy autograd Function. Args: - hidden: ``(num_tokens, hidden_size)`` or - ``(batch_size, seq_len, hidden_size)``. Must be contiguous on - CUDA. - weight: ``(vocab_size, hidden_size)`` lm-head weight. Must be - contiguous on CUDA. - labels: integer label ids; either ``(num_tokens,)`` or - ``(batch_size, seq_len)``. + hidden: ``(num_tokens, hidden_size)`` contiguous CUDA tensor. + weight: ``(vocab_size, hidden_size)`` lm-head weight, contiguous CUDA. + labels: ``(num_tokens,)`` integer label ids on CUDA. temperature: softmax temperature; defaults to ``1.0``. - reduction: only ``"none"`` is supported and returns per-token - negative log-likelihood. - dist_process_group: optional tensor-parallel group for vocab-sharded - ``weight``. ``labels`` must contain *global* vocab ids on every - rank; the kernel handles the per-rank slice internally. + reduction: only ``"none"`` is supported. + dist_process_group: optional TP group for vocab-sharded ``weight``. + ``labels`` must contain *global* vocab ids on every rank. Returns: - ``(logprobs, entropy)`` where both tensors have shape - ``(num_tokens,)``. + ``(logprobs, entropy)`` both shaped ``(num_tokens,)``. """ @staticmethod @@ -56,8 +43,6 @@ def forward( if not isinstance(reduction, str): raise TypeError(f"reduction must be str, got {type(reduction)}") - # Local import keeps Triton dependency lazy: tests can still - # import this module on machines without Triton. 
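+        # Lazy import: keeps the Triton dependency optional for CPU-only hosts.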
from areal.utils.kernel import kernels REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) @@ -68,9 +53,6 @@ def forward( if labels.dim() != 1: labels = labels.reshape(-1) - # Triton kernels demand contiguous CUDA tensors; bail out loudly - # on misuse rather than silently materialising copies on a hot - # path. assert hidden.is_cuda and weight.is_cuda and labels.is_cuda, ( "LinearCrossEntropy requires CUDA inputs" ) @@ -121,9 +103,6 @@ def backward( _entropy_b, ) = ctx.saved_tensors - # PyTorch autograd may produce non-contiguous gradient tensors - # (e.g. expanded views from broadcast). Triton kernels require - # contiguous inputs, so ensure contiguity before dispatching. dlogprobs = dlogprobs.contiguous() dentropy = dentropy.contiguous() @@ -142,32 +121,9 @@ def backward( ctx.dist_process_group, ) - # TP all-reduce on d_hidden. - # - # Why this is required: - # ``efficient_entropy_backward`` computes a *local* contribution - # ``d_hidden_local = d_logits_local @ weight_local`` where each TP - # rank holds only a vocab-shard of ``weight``. The mathematically - # correct gradient is the sum across the TP group: - # d_hidden = sum_over_tp_ranks(d_logits_local @ weight_local). - # In Megatron's normal forward, the surrounding - # ``ColumnParallelLinear`` (output_layer) inserts this all-reduce - # via ``linear_with_grad_accumulation_and_async_allreduce``. The - # fused-LCE fast path monkey-patches ``output_layer.forward`` to - # return ``(hidden, None)`` (an autograd identity), which bypasses - # mcore's machinery — so the all-reduce vanishes unless we - # reproduce it here. - # - # Without this reduction, TP > 1 silently produces gradients that - # equal each rank's local partial, leading to incorrect training - # that is *not* caught by any forward-only invariant since the - # forward kernel already all-reduces (max / logsumexp / entropy - # auxiliaries) inside ``efficient_entropy_forward``. - # - # ``d_weight`` does NOT need an all-reduce: each rank legitimately - # owns its vocab slice's weights, so the gradient on the local - # weight shard is correctly local-only — exactly mirroring how - # mcore handles ColumnParallel weight grads. + # TP all-reduce on d_hidden: the fused path bypasses mcore's + # ColumnParallelLinear which normally inserts this reduction. + # d_weight does NOT need all-reduce (each rank owns its vocab shard). if ( ctx.dist_process_group is not None and dist.get_world_size(ctx.dist_process_group) > 1 @@ -180,7 +136,6 @@ def backward( d_hidden = d_hidden.view(ctx.original_hidden_shape) - # Order matches forward: hidden, weight, labels, temperature, reduction, group return d_hidden, d_weight, None, None, None, None @@ -192,10 +147,7 @@ def linear_cross_entropy( reduction: str = "none", dist_process_group: dist.ProcessGroup | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - """Functional wrapper around :class:`LinearCrossEntropy`. - - Returns per-token ``(logprobs, entropy)``. 
- """ + """Functional wrapper around :class:`LinearCrossEntropy`.""" return LinearCrossEntropy.apply( hidden, weight, labels, temperature, reduction, dist_process_group ) From 0b044e514ab77a07b50b73258fc323e778697ea4 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 12:19:57 +0800 Subject: [PATCH 24/31] docs(kernels): fix --- areal/utils/kernel/kernels.py | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/areal/utils/kernel/kernels.py b/areal/utils/kernel/kernels.py index 10f2dc08cf..c59e89f3cf 100644 --- a/areal/utils/kernel/kernels.py +++ b/areal/utils/kernel/kernels.py @@ -1,36 +1,7 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Copyright 2024 Bytedance Ltd. and/or its affiliates -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. """ Implementations of the linear cross entropy with token entropy kernel. +Ref some code from verl. The Triton kernel implementations fuse the matmul with cross-entropy reduction so that the ``[num_tokens, vocab_size]`` logits tensor is never materialized, trading kernel-launch overhead for large memory savings. From 5e35bbe651d1034df5dbd026547765d437094ff2 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 12:21:26 +0800 Subject: [PATCH 25/31] feat(test): fix --- areal/utils/kernel/kernels.py | 1 + 1 file changed, 1 insertion(+) diff --git a/areal/utils/kernel/kernels.py b/areal/utils/kernel/kernels.py index c59e89f3cf..d4965e76c5 100644 --- a/areal/utils/kernel/kernels.py +++ b/areal/utils/kernel/kernels.py @@ -1,3 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 """ Implementations of the linear cross entropy with token entropy kernel. From 09eff549c53c32063c30e46a8d55267c716fa3c6 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 14:20:15 +0800 Subject: [PATCH 26/31] fix(engine): fix --- .../megatron_utils/fused_lce_capture.py | 50 ++++++++++++++++--- 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/areal/engine/megatron_utils/fused_lce_capture.py b/areal/engine/megatron_utils/fused_lce_capture.py index 55c124f252..b7ee8bf2d0 100644 --- a/areal/engine/megatron_utils/fused_lce_capture.py +++ b/areal/engine/megatron_utils/fused_lce_capture.py @@ -7,8 +7,9 @@ ``output_layer.forward`` to capture those tensors for one microbatch. 
Compatibility: incompatible with MuP (``use_mup``), MTP -(``mtp_num_layers > 0``), and critic heads. The engine falls back to -the materialised path automatically when any of these conditions hold. +(``mtp_num_layers > 0``), critic heads, and hidden sizes that do not +satisfy the fused-kernel alignment requirement. The engine falls back +to the materialised path automatically when any of these conditions hold. """ from __future__ import annotations @@ -29,6 +30,8 @@ FUSED_LCE_HIDDEN_KEY = "_fused_lce_hidden" FUSED_LCE_WEIGHT_KEY = "_fused_lce_weight" +_HIDDEN_SIZE_ALIGNMENT = 128 +_WARNED_INCOMPATIBILITIES: set[str] = set() @dataclass @@ -48,18 +51,42 @@ def _unwrap_to_post_process_module(model: torch.nn.Module) -> torch.nn.Module | return None +def _warn_incompatible_once(key: str, message: str, *args: object) -> None: + if key in _WARNED_INCOMPATIBILITIES: + return + _WARNED_INCOMPATIBILITIES.add(key) + logger.warning(message, *args) + + +def _get_lm_head_hidden_size( + config: object, + output_layer: torch.nn.Module, +) -> int | None: + hidden_size = getattr(config, "hidden_size", None) + if hidden_size is not None: + return int(hidden_size) + + weight = getattr(output_layer, "weight", None) + if weight is not None and hasattr(weight, "shape") and len(weight.shape) > 0: + return int(weight.shape[-1]) + + return None + + def _is_compatible(post_process_module: torch.nn.Module) -> bool: config = getattr(post_process_module, "config", None) if config is None: return False if getattr(config, "use_mup", False): - logger.warning( + _warn_incompatible_once( + "use_mup", "Fused LCE disabled: MuP scaling is enabled (config.use_mup=True)." ) return False if getattr(config, "mtp_num_layers", 0): - logger.warning( + _warn_incompatible_once( + "mtp", "Fused LCE disabled: MTP is enabled (config.mtp_num_layers>0)." ) return False @@ -68,11 +95,22 @@ def _is_compatible(post_process_module: torch.nn.Module) -> bool: if output_layer is None: return False + hidden_size = _get_lm_head_hidden_size(config, output_layer) + if hidden_size is not None and hidden_size % _HIDDEN_SIZE_ALIGNMENT != 0: + _warn_incompatible_once( + f"hidden_size:{hidden_size}", + "Fused LCE disabled: hidden_size=%s is not divisible by %s.", + hidden_size, + _HIDDEN_SIZE_ALIGNMENT, + ) + return False + parallel_output = getattr(post_process_module, "parallel_output", True) if not parallel_output: - logger.warning( + _warn_incompatible_once( + "parallel_output", "Fused LCE disabled: model has parallel_output=False; " - "would require an extra TP gather." 
+ "would require an extra TP gather.", ) return False From 825d44ca1eeec3cb24e4d0fc638b05cc53bd89ad Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 16:17:22 +0800 Subject: [PATCH 27/31] fix(megatron): fix vocab --- areal/engine/megatron_engine.py | 27 +++++++----- .../megatron_utils/fused_lce_capture.py | 14 +++++++ areal/trainer/ppo/actor.py | 6 ++- areal/trainer/sft/lm_engine.py | 6 ++- .../utils/functional/linear_cross_entropy.py | 42 +++++++++++++++++-- areal/utils/kernel/linear_cross_entropy.py | 30 ++++++++++--- 6 files changed, 104 insertions(+), 21 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 58c0931a73..b1eae9c488 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -1872,18 +1872,23 @@ def _compute_logprobs_and_loss( and fused_hidden is not None and fused_weight is not None ): - logprobs, entropy = linear_cross_entropy_logprobs_entropy( - fused_hidden, - fused_weight, - labels, - temperature=self.config.temperature, - tp_group=mpu.get_tensor_model_parallel_group() - if mpu.get_tensor_model_parallel_world_size() > 1 - else None, + logprobs, entropy, vocab_max_logits = ( + linear_cross_entropy_logprobs_entropy( + fused_hidden, + fused_weight, + labels, + temperature=self.config.temperature, + tp_group=mpu.get_tensor_model_parallel_group() + if mpu.get_tensor_model_parallel_world_size() > 1 + else None, + return_max_logits=True, + ) ) - proxy = logprobs.detach().float() - vocab_min_logits = proxy - vocab_max_logits = proxy + # Fused kernel does not track per-token vocab min logits; + # skip the min telemetry rather than report a misleading + # proxy. Consumers must guard ``vocab_min_logits`` and + # ``vocab_max_logits`` independently. + vocab_min_logits = None else: logprobs, entropy = gather_logprobs_entropy( output, diff --git a/areal/engine/megatron_utils/fused_lce_capture.py b/areal/engine/megatron_utils/fused_lce_capture.py index b7ee8bf2d0..cb8d3064cc 100644 --- a/areal/engine/megatron_utils/fused_lce_capture.py +++ b/areal/engine/megatron_utils/fused_lce_capture.py @@ -114,6 +114,20 @@ def _is_compatible(post_process_module: torch.nn.Module) -> bool: ) return False + # The Triton kernel hard-requires hidden_size to be a multiple of 128 + # (BLOCK_HD constant). Surface this constraint at the gating layer so + # incompatible models fall back to the materialised path before the + # autograd graph is built; an assert raised inside ``backward`` would + # otherwise hard-kill the training loop. 
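+    # This probe is stricter than the ``_HIDDEN_SIZE_ALIGNMENT`` check above:
+    # it also rejects configs that do not expose ``hidden_size`` at all.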
+ hidden_size = getattr(config, "hidden_size", None) + if hidden_size is None or hidden_size % 128 != 0: + logger.warning( + "Fused LCE disabled: hidden_size=%s is not a multiple of 128 " + "(Triton kernel BLOCK_HD constraint).", + hidden_size, + ) + return False + return True diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 07944a31a5..b0cb676ccc 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -546,9 +546,13 @@ def grpo_loss_fn( if "filtered_fraction" in stat: stats_tracker.scalar(rs_filtered_fraction=stat["filtered_fraction"]) - if vocab_min_logits is not None and vocab_max_logits is not None: + if vocab_min_logits is not None: stats_tracker.stat( vocab_min_logits=vocab_min_logits, + denominator="n_tokens", + ) + if vocab_max_logits is not None: + stats_tracker.stat( vocab_max_logits=vocab_max_logits, denominator="n_tokens", ) diff --git a/areal/trainer/sft/lm_engine.py b/areal/trainer/sft/lm_engine.py index 9dd3f8439e..d229ad804f 100644 --- a/areal/trainer/sft/lm_engine.py +++ b/areal/trainer/sft/lm_engine.py @@ -121,9 +121,13 @@ def compute_packed_sft_loss( stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs") stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens") - if vocab_min_logits is not None and vocab_max_logits is not None: + if vocab_min_logits is not None: stats_tracker.stat( vocab_min_logits=vocab_min_logits, + denominator="n_tokens", + ) + if vocab_max_logits is not None: + stats_tracker.stat( vocab_max_logits=vocab_max_logits, denominator="n_tokens", ) diff --git a/areal/utils/functional/linear_cross_entropy.py b/areal/utils/functional/linear_cross_entropy.py index 88239043ef..2dfd318e6a 100644 --- a/areal/utils/functional/linear_cross_entropy.py +++ b/areal/utils/functional/linear_cross_entropy.py @@ -42,7 +42,8 @@ def _reference_logprobs_entropy( labels: torch.Tensor, temperature: float, tp_group: dist.ProcessGroup | None, -) -> tuple[torch.Tensor, torch.Tensor]: + return_max_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: flat_hidden = hidden.reshape(-1, hidden.shape[-1]) flat_labels = labels.reshape(-1) @@ -62,6 +63,14 @@ def _reference_logprobs_entropy( ).squeeze(-1) probs = log_softmax.exp() entropy = -(probs * log_softmax).sum(dim=-1) + if return_max_logits: + # Return max of the post-temperature logits, scaled back by ``temperature`` + # so the value matches ``raw_logits.max(-1).values`` (matches the + # non-fused telemetry path exactly). + max_logits = logits.detach().max(dim=-1).values.float() + if temperature != 1.0: + max_logits = max_logits * temperature + return log_probs_labels, entropy, max_logits return log_probs_labels, entropy @@ -71,7 +80,8 @@ def linear_cross_entropy_logprobs_entropy( labels: torch.Tensor, temperature: float = 1.0, tp_group: dist.ProcessGroup | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: + return_max_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Compute per-token log-prob and entropy via the fused kernel. Falls back to the materialised reference path when the fused kernel is @@ -85,9 +95,16 @@ def linear_cross_entropy_logprobs_entropy( *global* vocab ids. temperature: softmax temperature. tp_group: optional tensor-parallel group when ``weight`` is sharded. 
+ return_max_logits: when ``True``, additionally returns the per-token + max of the **raw** (pre-temperature) logits, shape ``labels.shape``, + dtype ``float32``. The fused kernel internally tracks + ``max(logits/temperature)``; we multiply it back by ``temperature`` + so the value is numerically identical to + ``raw_logits.max(-1).values`` from the non-fused path. Returns: - ``(logprobs, entropy)`` both shaped like ``labels``. + ``(logprobs, entropy)`` both shaped like ``labels``; or + ``(logprobs, entropy, max_logits)`` when ``return_max_logits=True``. """ leading_shape = labels.shape @@ -101,6 +118,16 @@ def linear_cross_entropy_logprobs_entropy( ) else: try: + if return_max_logits: + logprobs, entropy, max_logits = linear_cross_entropy( + hidden, weight, labels, temperature, "none", tp_group, + return_max_logits=True, + ) + return ( + logprobs.reshape(leading_shape), + entropy.reshape(leading_shape), + max_logits.reshape(leading_shape), + ) logprobs, entropy = linear_cross_entropy( hidden, weight, labels, temperature, "none", tp_group, ) @@ -110,6 +137,15 @@ def linear_cross_entropy_logprobs_entropy( "Fused LCE kernel raised %s; falling back to reference.", exc, ) + if return_max_logits: + logprobs, entropy, max_logits = _reference_logprobs_entropy( + hidden, weight, labels, temperature, tp_group, return_max_logits=True, + ) + return ( + logprobs.reshape(leading_shape), + entropy.reshape(leading_shape), + max_logits.reshape(leading_shape), + ) logprobs, entropy = _reference_logprobs_entropy( hidden, weight, labels, temperature, tp_group ) diff --git a/areal/utils/kernel/linear_cross_entropy.py b/areal/utils/kernel/linear_cross_entropy.py index 7bf564d2e5..6e7c037ae0 100644 --- a/areal/utils/kernel/linear_cross_entropy.py +++ b/areal/utils/kernel/linear_cross_entropy.py @@ -23,9 +23,14 @@ class LinearCrossEntropy(torch.autograd.Function): reduction: only ``"none"`` is supported. dist_process_group: optional TP group for vocab-sharded ``weight``. ``labels`` must contain *global* vocab ids on every rank. + return_max_logits: when ``True``, the autograd Function additionally + returns the per-token raw-logit max (kernel-internal + ``max(logits/temperature)`` re-scaled by ``temperature``). The + extra output is detached / non-differentiable. Returns: - ``(logprobs, entropy)`` both shaped ``(num_tokens,)``. + ``(logprobs, entropy)`` both shaped ``(num_tokens,)``; or + ``(logprobs, entropy, max_logits)`` when ``return_max_logits=True``. """ @staticmethod @@ -37,7 +42,8 @@ def forward( temperature: float | None = 1.0, reduction: str | None = "none", dist_process_group: dist.ProcessGroup | None = None, - ) -> tuple[torch.Tensor, torch.Tensor]: + return_max_logits: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if not isinstance(temperature, float): temperature = float(temperature) if not isinstance(reduction, str): @@ -84,6 +90,17 @@ def forward( ctx.should_return_fp32_grad = False ctx.temperature = temperature + if return_max_logits: + # ``_maximum`` is the per-token max of ``logits / temperature`` + # (post-temperature, online-softmax accumulator). Multiply back + # by ``temperature`` to recover the raw-logit max so the value + # matches ``raw_logits.max(-1).values`` from the non-fused path. 
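+            # (Exact for any positive temperature: max(z) = T * max(z / T).)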
+ if temperature != 1.0: + max_logits = _maximum.detach() * temperature + else: + max_logits = _maximum.detach().clone() + return logprobs, entropy, max_logits + return logprobs, entropy @staticmethod @@ -91,6 +108,7 @@ def backward( ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor, + dmax_logits: torch.Tensor | None = None, ) -> tuple: from areal.utils.kernel import kernels @@ -136,7 +154,7 @@ def backward( d_hidden = d_hidden.view(ctx.original_hidden_shape) - return d_hidden, d_weight, None, None, None, None + return d_hidden, d_weight, None, None, None, None, None def linear_cross_entropy( @@ -146,8 +164,10 @@ def linear_cross_entropy( temperature: float = 1.0, reduction: str = "none", dist_process_group: dist.ProcessGroup | None = None, -) -> tuple[torch.Tensor, torch.Tensor]: + return_max_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Functional wrapper around :class:`LinearCrossEntropy`.""" return LinearCrossEntropy.apply( - hidden, weight, labels, temperature, reduction, dist_process_group + hidden, weight, labels, temperature, reduction, dist_process_group, + return_max_logits, ) From 8aaff93d2746bb4e46504ba06d6c6117d7aaa6d3 Mon Sep 17 00:00:00 2001 From: bingyechen Date: Sun, 10 May 2026 17:21:59 +0800 Subject: [PATCH 28/31] feat(engine): fix conflict --- areal/engine/megatron_engine.py | 82 ++++++++++++++++++++++++++++----- 1 file changed, 71 insertions(+), 11 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index b1eae9c488..5e44e524fb 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -55,7 +55,7 @@ init_custom_process_group, warmup_process_groups, ) -from areal.engine.core.model import disable_dropout_in_model +from areal.engine.core.model import disable_dropout_in_model, is_valid_vision_model from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper @@ -73,6 +73,7 @@ from areal.engine.megatron_utils.megatron_lora import get_vllm_lora_target_modules from areal.engine.megatron_utils.packed_context_parallel import ( packed_context_parallel_forward, + reassemble_cp_packed_logprobs, split_packed_seqs_for_context_parallel, ) from areal.engine.megatron_utils.pipeline_parallel import ( @@ -117,7 +118,7 @@ linear_cross_entropy_logprobs, linear_cross_entropy_logprobs_entropy, ) -from areal.utils.hf_utils import load_hf_tokenizer +from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer from areal.utils.lock import DistributedLock from areal.utils.network import find_free_ports, format_host_for_url, gethostip from areal.utils.offload import is_tms_enabled, torch_memory_saver @@ -191,6 +192,8 @@ def __init__(self, config: TrainEngineConfig): self.quantization_config: dict[str, int | str | list[str]] | None = None self.bridge_cls: str = getattr(self.mcore_config, "bridge_type", "mbridge") self.bridge_lora: MegatronBridgeLoRA | None = None + self.is_vision_model: bool = False + self.processor = None def create_process_group(self, parallel_strategy: ParallelStrategy | None = None): if parallel_strategy is None: @@ -326,6 +329,22 @@ def initialize(self, addr: str | None, ft_spec: FinetuneSpec, *args, **kwargs): self.parallel_strategy, self.hf_config, self.tf_config ) + self.is_vision_model = is_valid_vision_model(self.hf_config.model_type) + if 
self.is_vision_model: + if self.parallel_strategy.context_parallel_size > 1: + raise NotImplementedError( + "Context parallel (CP > 1) is not supported with VLM models. " + f"Got context_parallel_size={self.parallel_strategy.context_parallel_size} " + f"for model_type={self.hf_config.model_type}." + ) + self.processor, self.tokenizer = load_hf_processor_and_tokenizer( + self.config.path + ) + self.logger.info( + f"VLM model detected (type={self.hf_config.model_type}). " + f"Loaded processor and tokenizer." + ) + self.quantization_config = getattr( self.hf_config, "quantization_config", None ) @@ -772,6 +791,7 @@ def forward_step(batch_iter, model): model, mb_input.padded_mb, gather_cp_output=not cp_local, + is_vision_model=self.is_vision_model, ) if ( @@ -805,14 +825,11 @@ def _process_output(input_, output_): cp_labels = split_packed_seqs_for_context_parallel( rolled_ids, padded_cu_seqlens ) - cp_loss_mask = split_packed_seqs_for_context_parallel( - mb_input.padded_mb["loss_mask"], padded_cu_seqlens - ) - cp_cu_seqlens = padded_cu_seqlens // cp_size cp_inputs = dict(mb_input.orig_mb) cp_inputs["_cp_local_labels"] = cp_labels - cp_inputs["loss_mask"] = cp_loss_mask - cp_inputs["cu_seqlens"] = cp_cu_seqlens + cp_inputs["_cp_padded_cu_seqlens"] = padded_cu_seqlens + cp_inputs["_cp_padding_length"] = mb_input.padding_length + cp_inputs["_cp_old_cu_seqlens"] = mb_input.old_cu_seqlens return output, functools.partial(_process_output, cp_inputs) else: output = unpad_logits( @@ -893,9 +910,9 @@ def process_output( ) # Step 4: Optimizer step - result = self.optimizer_step() - - return result + stats = self.optimizer_step() + stats["num_micro_batches"] = len(mb_list.mbs) + return stats @torch.no_grad() def eval_batch( @@ -1860,6 +1877,7 @@ def _compute_logprobs_and_loss( ) else: cp_local_labels = inputs.get("_cp_local_labels") + cp_padded_cu_seqlens = inputs.get("_cp_padded_cu_seqlens") if cp_local_labels is not None: labels = cp_local_labels else: @@ -1900,6 +1918,48 @@ def _compute_logprobs_and_loss( ) vocab_min_logits = output.detach().min(-1).values.float() vocab_max_logits = output.detach().max(-1).values.float() + if cp_padded_cu_seqlens is not None: + logprobs = reassemble_cp_packed_logprobs( + logprobs, cp_padded_cu_seqlens + ) + entropy = reassemble_cp_packed_logprobs( + entropy, cp_padded_cu_seqlens + ) + vocab_min_logits = reassemble_cp_packed_logprobs( + vocab_min_logits, cp_padded_cu_seqlens + ) + vocab_max_logits = reassemble_cp_packed_logprobs( + vocab_max_logits, cp_padded_cu_seqlens + ) + cp_padding_length = inputs.get("_cp_padding_length", 0) + cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens") + logprobs = unpad_logits( + logprobs, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + entropy = unpad_logits( + entropy, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + vocab_min_logits = unpad_logits( + vocab_min_logits, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + vocab_max_logits = unpad_logits( + vocab_max_logits, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + inputs = { + k: v for k, v in inputs.items() if not k.startswith("_cp_") + } loss = loss_fn( logprobs, entropy, From c5525d6749d272aab816402d0e99bb201461d4c9 Mon Sep 17 00:00:00 2001 From: root Date: Sun, 10 May 2026 11:14:12 +0000 Subject: [PATCH 29/31] feat: precommit fix --- areal/engine/megatron_engine.py | 14 +- .../megatron_utils/fused_lce_capture.py | 5 +- .../utils/functional/linear_cross_entropy.py | 32 ++- 
areal/utils/kernel/kernels.py | 272 ++++++++++++++---- areal/utils/kernel/linear_cross_entropy.py | 26 +- areal/utils/stats_logger.py | 3 +- benchmark/bench_linear_cross_entropy.py | 4 +- docs/en/cli_reference.md | 5 + docs/zh/cli_reference.md | 5 + tests/test_linear_cross_entropy.py | 32 +-- tests/torchrun/run_lce_tp2.py | 4 +- 11 files changed, 287 insertions(+), 115 deletions(-) diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index eeb1fc2407..ccaa32444b 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -860,9 +860,7 @@ def forward_step(batch_iter, model): and not cp_local ) - with capture_lm_head_hidden( - model, enabled=should_capture - ) as capture: + with capture_lm_head_hidden(model, enabled=should_capture) as capture: output = packed_context_parallel_forward( model, mb_input.padded_mb, @@ -916,9 +914,7 @@ def _process_output(input_, output_): ) # Re-align Float16Module's fp32 hidden to lm-head weight dtype. if mb_input.orig_mb.get("_fused_lce_active", False): - fused_weight = mb_input.orig_mb.get( - FUSED_LCE_WEIGHT_KEY - ) + fused_weight = mb_input.orig_mb.get(FUSED_LCE_WEIGHT_KEY) if ( fused_weight is not None and output.dtype != fused_weight.dtype @@ -2234,11 +2230,7 @@ def _compute_forward_result( fused_active = inputs.get("_fused_lce_active", False) fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) - if ( - fused_active - and fused_hidden is not None - and fused_weight is not None - ): + if fused_active and fused_hidden is not None and fused_weight is not None: logprobs = linear_cross_entropy_logprobs( fused_hidden, fused_weight, diff --git a/areal/engine/megatron_utils/fused_lce_capture.py b/areal/engine/megatron_utils/fused_lce_capture.py index cb8d3064cc..fd58d83c7f 100644 --- a/areal/engine/megatron_utils/fused_lce_capture.py +++ b/areal/engine/megatron_utils/fused_lce_capture.py @@ -81,13 +81,12 @@ def _is_compatible(post_process_module: torch.nn.Module) -> bool: if getattr(config, "use_mup", False): _warn_incompatible_once( "use_mup", - "Fused LCE disabled: MuP scaling is enabled (config.use_mup=True)." + "Fused LCE disabled: MuP scaling is enabled (config.use_mup=True).", ) return False if getattr(config, "mtp_num_layers", 0): _warn_incompatible_once( - "mtp", - "Fused LCE disabled: MTP is enabled (config.mtp_num_layers>0)." + "mtp", "Fused LCE disabled: MTP is enabled (config.mtp_num_layers>0)." 
) return False diff --git a/areal/utils/functional/linear_cross_entropy.py b/areal/utils/functional/linear_cross_entropy.py index 2dfd318e6a..407abbd0f8 100644 --- a/areal/utils/functional/linear_cross_entropy.py +++ b/areal/utils/functional/linear_cross_entropy.py @@ -43,7 +43,9 @@ def _reference_logprobs_entropy( temperature: float, tp_group: dist.ProcessGroup | None, return_max_logits: bool = False, -) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> ( + tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] +): flat_hidden = hidden.reshape(-1, hidden.shape[-1]) flat_labels = labels.reshape(-1) @@ -81,7 +83,9 @@ def linear_cross_entropy_logprobs_entropy( temperature: float = 1.0, tp_group: dist.ProcessGroup | None = None, return_max_logits: bool = False, -) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> ( + tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] +): """Compute per-token log-prob and entropy via the fused kernel. Falls back to the materialised reference path when the fused kernel is @@ -120,7 +124,12 @@ def linear_cross_entropy_logprobs_entropy( try: if return_max_logits: logprobs, entropy, max_logits = linear_cross_entropy( - hidden, weight, labels, temperature, "none", tp_group, + hidden, + weight, + labels, + temperature, + "none", + tp_group, return_max_logits=True, ) return ( @@ -129,17 +138,28 @@ def linear_cross_entropy_logprobs_entropy( max_logits.reshape(leading_shape), ) logprobs, entropy = linear_cross_entropy( - hidden, weight, labels, temperature, "none", tp_group, + hidden, + weight, + labels, + temperature, + "none", + tp_group, ) return logprobs.reshape(leading_shape), entropy.reshape(leading_shape) except Exception as exc: logger.warning( - "Fused LCE kernel raised %s; falling back to reference.", exc, + "Fused LCE kernel raised %s; falling back to reference.", + exc, ) if return_max_logits: logprobs, entropy, max_logits = _reference_logprobs_entropy( - hidden, weight, labels, temperature, tp_group, return_max_logits=True, + hidden, + weight, + labels, + temperature, + tp_group, + return_max_logits=True, ) return ( logprobs.reshape(leading_shape), diff --git a/areal/utils/kernel/kernels.py b/areal/utils/kernel/kernels.py index d4965e76c5..00575ef9ff 100644 --- a/areal/utils/kernel/kernels.py +++ b/areal/utils/kernel/kernels.py @@ -85,7 +85,9 @@ def alloc_fn(size: int, alignment: int, stream: int | None): import triton.runtime._allocation as _triton_allocation - if isinstance(getattr(_triton_allocation, "_allocator", None), contextvars.ContextVar): + if isinstance( + getattr(_triton_allocation, "_allocator", None), contextvars.ContextVar + ): _triton_allocation._allocator = contextvars.ContextVar( _triton_allocation._allocator.name, default=alloc_fn, @@ -109,7 +111,13 @@ def get_entropy_reduction_enum_number(reduction: str) -> int: @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, num_stages=3, num_warps=8)], + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=8, + ) + ], key=["num_tokens", "hidden_size", "vocab_size"], ) @triton.jit @@ -176,7 +184,9 @@ def efficient_entropy_kernel_general_mainloop( ) else: - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + hidden_ptrs = hidden_ptr + ( + offs_am[:, None] * stride_hidden_m + 
offs_k[None, :] * stride_hidden_k + ) # load labels for this block labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) @@ -195,7 +205,9 @@ def efficient_entropy_kernel_general_mainloop( logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) if not USE_TMA: # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + weight_ptrs = weight_ptr + ( + offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) # iterate over K dimension for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): @@ -208,14 +220,18 @@ def efficient_entropy_kernel_general_mainloop( # load the next block of hidden and weight _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), other=0.0, ) _weight = tl.load( weight_ptrs, mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[:, None] < (min((pid_n + 1) * vocab_per_split, vocab_size))), + & ( + offs_bn[:, None] + < (min((pid_n + 1) * vocab_per_split, vocab_size)) + ), other=0.0, ) @@ -253,13 +269,27 @@ def efficient_entropy_kernel_general_mainloop( offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_max_n = pid_n maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m - tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + tl.store( + maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits) + ) # store entropy accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m - tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) - entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m - tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + tl.store( + accu_ptrs, + _accu, + mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits), + ) + entropy_b_ptrs = ( + entropy_b_ptr + + offs_max_n * stride_entropy_b_n + + offs_max_m * stride_entropy_b_m + ) + tl.store( + entropy_b_ptrs, + _entropy_b, + mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits), + ) # store logprobs vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size @@ -270,7 +300,10 @@ def efficient_entropy_kernel_general_mainloop( tl.store(global_logprobs_ptrs, _logprobs, mask=mask) -@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], + key=["num_tokens", "num_splits"], +) @triton.jit def efficient_entropy_triton_kernel_epilogue( max_ptr, @@ -308,16 +341,34 @@ def efficient_entropy_triton_kernel_epilogue( global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + max_ptrs = ( + max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + ) - _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + 
_max = tl.load( + max_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) - accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n - _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) + accu_ptrs = ( + accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n + ) + _accu = tl.load( + accu_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) - entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n + entropy_b_ptrs = ( + entropy_b_ptr + + offs_m[:, None] * stride_entropy_b_m + + offs_n[None, :] * stride_entropy_b_n + ) _entropy_b = tl.load( - entropy_b_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0 + entropy_b_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, ) # local reduction @@ -328,7 +379,9 @@ def efficient_entropy_triton_kernel_epilogue( _scale = tl.exp(_max - global_max[:, None]) _coeff = tl.exp(_max_old - global_max) global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) - global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum( + _scale * _entropy_b, axis=1 + ) # store maximum_ptrs = global_max_ptr + offs_m * stride_global_max @@ -336,7 +389,11 @@ def efficient_entropy_triton_kernel_epilogue( # store entropy_b global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b - tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + tl.store( + global_entropy_b_ptr + offs_m * stride_global_entropy_b, + global_entropy_b, + mask=offs_m < num_tokens, + ) # store entropy global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu @@ -353,7 +410,10 @@ def efficient_entropy_triton_kernel_epilogue( tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) -@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], + key=["num_tokens", "num_splits"], +) @triton.jit def efficient_entropy_triton_kernel_epilogue_tp( num_tokens, @@ -390,17 +450,23 @@ def efficient_entropy_triton_kernel_epilogue_tp( offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) _reduced_max = tl.load( - reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + offs_n[None, :] * stride_reduced_max_n, + reduced_max_ptr + + offs_m[:, None] * stride_reduced_max_m + + offs_n[None, :] * stride_reduced_max_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0, ) _original_max = tl.load( - original_max_ptr + offs_m[:, None] * stride_original_max_m + offs_n[None, :] * stride_original_max_n, + original_max_ptr + + offs_m[:, None] * stride_original_max_m + + offs_n[None, :] * stride_original_max_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0, ) _accu = tl.load( - accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, + accu_ptr + + offs_m[:, None] * stride_accu_m + + offs_n[None, :] * stride_accu_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0, ) @@ -417,16 +483,32 @@ def efficient_entropy_triton_kernel_epilogue_tp( # update entropy_b _entropy_b = tl.load( - 
entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n, + entropy_b_ptr + + offs_m[:, None] * stride_entropy_b_m + + offs_n[None, :] * stride_entropy_b_n, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0, ) - global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum( + _scale * _entropy_b, axis=1 + ) # store - tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) - tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) - tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + tl.store( + global_max_ptr + offs_m * stride_global_max, + global_max, + mask=offs_m < num_tokens, + ) + tl.store( + global_accu_ptr + offs_m * stride_global_accu, + global_accu, + mask=offs_m < num_tokens, + ) + tl.store( + global_entropy_b_ptr + offs_m * stride_global_entropy_b, + global_entropy_b, + mask=offs_m < num_tokens, + ) @triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) @@ -451,20 +533,30 @@ def efficient_entropy_triton_epilogue_tp_update( offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) - accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) + accumulate = tl.load( + accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens + ) - entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens) + entropy_b = tl.load( + entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens + ) entropy_b = tl.fdiv(entropy_b, accumulate) - tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens) + tl.store( + entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens + ) entropy = tl.log(accumulate) + maximum - entropy_b tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) - logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) + logprobs = tl.load( + logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens + ) logprobs = maximum + tl.log(accumulate) - logprobs logprobs = -1 * logprobs - tl.store(logprobs_out_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) + tl.store( + logprobs_out_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens + ) _dedicated_stream, _dedicated_events = None, None @@ -486,9 +578,13 @@ def efficient_entropy_forward( assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) - _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + _world_size = ( + 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + ) - if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"): + if dist_process_group is not None and not hasattr( + efficient_entropy_forward, "_initialized" + ): global _dedicated_stream, _dedicated_events _dedicated_stream = get_torch_device().Stream(hidden.device) _dedicated_events = [get_torch_device().Event() for _ in range(2)] @@ -510,19 +606,31 @@ def efficient_entropy_forward( assert logprobs.is_contiguous() and entropy.is_contiguous() maximum = 
torch.empty_like(entropy) - accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32) + accumulate_and_entropy_b = torch.empty( + (num_tokens * 2,), device=hidden.device, dtype=torch.float32 + ) accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) accumulate = accumulate_and_entropy_b_view[0, :] entropy_b = accumulate_and_entropy_b_view[1, :] - assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous() + assert ( + maximum.is_contiguous() + and accumulate.is_contiguous() + and entropy_b.is_contiguous() + ) vocab_per_split = 1024 assert vocab_per_split % 128 == 0 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _max = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) + _accu = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) + _entropy_b = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) _logprobs = logprobs @@ -559,7 +667,9 @@ def mainloop_grid(meta): _logprobs, _logprobs.stride(0), 1.0 / temperature, - USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, + USE_TMA=SUPPORT_CUDA_TMA + and hidden.stride(1) == 1 + and weight.stride(1) == 1, ) else: raise AssertionError("Triton is required for efficient entropy kernel") @@ -627,7 +737,9 @@ def epilogue_grid(meta): ) get_torch_device().current_stream().wait_event(_dedicated_events[1]) - dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group) + dist.all_reduce( + accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group + ) # update logprobs & entropy efficient_entropy_triton_epilogue_tp_update[epilogue_grid]( @@ -651,7 +763,12 @@ def epilogue_grid(meta): @triton.autotune( configs=[ triton.Config( - {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + }, num_stages=3, num_warps=8, ), @@ -710,14 +827,28 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) - maximum = tl.load(maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0) - accu = tl.load(accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6) + maximum = tl.load( + maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0 + ) + accu = tl.load( + accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6 + ) accu_rcp = tl.fdiv(1.0, accu) - d_entropy = tl.load(d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0) - d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens, other=0.0) + d_entropy = tl.load( + d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0 + ) + d_logprobs = tl.load( + d_logprobs_ptr + offs_am * stride_d_logprobs, + mask=offs_am < num_tokens, + other=0.0, + ) d_logprobs = -1 * d_logprobs - entropy_b = tl.load(entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0) - labels = 
tl.load(labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0) + entropy_b = tl.load( + entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0 + ) + labels = tl.load( + labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0 + ) logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) @@ -736,8 +867,12 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], ) else: - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k) + hidden_ptrs = hidden_ptr + ( + offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) + weight_ptrs = weight_ptr + ( + offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): @@ -748,12 +883,14 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( else: _hidden = tl.load( hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), other=0.0, ) _weight = tl.load( weight_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[:, None] < vocab_right_bound), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < vocab_right_bound), other=0.0, ) hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k @@ -765,7 +902,11 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) - d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + d_logits += ( + d_entropy[:, None] + * (-exp_logits * accu_rcp[:, None]) + * (logits - entropy_b[:, None]) + ) d_logits *= rcp_temperature @@ -774,7 +915,11 @@ def efficient_entropy_backward_kernel_general_d_logits_split_N( mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split) tl.store( - d_logits_ptr + offs_am[:, None] * stride_d_logits_m + result_offs_n[None, :] * stride_d_logits_n, d_logits, mask + d_logits_ptr + + offs_am[:, None] * stride_d_logits_m + + result_offs_n[None, :] * stride_d_logits_n, + d_logits, + mask, ) @@ -799,7 +944,9 @@ def efficient_entropy_backward( assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) - _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + _world_size = ( + 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + ) num_tokens, hidden_size = hidden.shape num_tokens = labels.shape[0] @@ -831,11 +978,16 @@ def efficient_entropy_backward( vocab_per_split = 9504 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - _d_logits = torch.empty((num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype).contiguous() + _d_logits = torch.empty( + (num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype + ).contiguous() assert _d_logits.is_contiguous() def d_logits_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_per_split, 
meta["BLOCK_SIZE_N"]),) + return ( + triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]), + ) for split_idx in range(num_splits): efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( @@ -867,7 +1019,9 @@ def d_logits_grid(meta): _d_logits.stride(0), _d_logits.stride(1), 1.0 / temperature, - USE_TMA=SUPPORT_CUDA_TMA and hidden.stride(1) == 1 and weight.stride(1) == 1, + USE_TMA=SUPPORT_CUDA_TMA + and hidden.stride(1) == 1 + and weight.stride(1) == 1, ) split_start = split_idx * vocab_per_split diff --git a/areal/utils/kernel/linear_cross_entropy.py b/areal/utils/kernel/linear_cross_entropy.py index 6e7c037ae0..3411a194cb 100644 --- a/areal/utils/kernel/linear_cross_entropy.py +++ b/areal/utils/kernel/linear_cross_entropy.py @@ -43,7 +43,10 @@ def forward( reduction: str | None = "none", dist_process_group: dist.ProcessGroup | None = None, return_max_logits: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ): if not isinstance(temperature, float): temperature = float(temperature) if not isinstance(reduction, str): @@ -62,9 +65,9 @@ def forward( assert hidden.is_cuda and weight.is_cuda and labels.is_cuda, ( "LinearCrossEntropy requires CUDA inputs" ) - assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous(), ( - "LinearCrossEntropy requires contiguous tensors" - ) + assert ( + hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + ), "LinearCrossEntropy requires contiguous tensors" ( logprobs, @@ -81,9 +84,7 @@ def forward( dist_process_group, ) - ctx.save_for_backward( - hidden, weight, labels, _maximum, _accumulate, _entropy_b - ) + ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) ctx.original_hidden_shape = original_hidden_shape ctx.REDUCTION = REDUCTION ctx.dist_process_group = dist_process_group @@ -165,9 +166,16 @@ def linear_cross_entropy( reduction: str = "none", dist_process_group: dist.ProcessGroup | None = None, return_max_logits: bool = False, -) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> ( + tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] +): """Functional wrapper around :class:`LinearCrossEntropy`.""" return LinearCrossEntropy.apply( - hidden, weight, labels, temperature, reduction, dist_process_group, + hidden, + weight, + labels, + temperature, + reduction, + dist_process_group, return_max_logits, ) diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index 220e6d3920..33dd809053 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -8,9 +8,10 @@ import swanlab import torch.distributed as dist import trackio -import wandb from tensorboardX import SummaryWriter +import wandb + from areal.api import FinetuneSpec from areal.api.cli_args import BaseExperimentConfig, StatsLoggerConfig from areal.utils import logging diff --git a/benchmark/bench_linear_cross_entropy.py b/benchmark/bench_linear_cross_entropy.py index a6b1424dd9..080d6b05dd 100644 --- a/benchmark/bench_linear_cross_entropy.py +++ b/benchmark/bench_linear_cross_entropy.py @@ -149,9 +149,7 @@ def _check_correctness(hidden, weight, labels, dtype, tp_group=None): rtol, atol = 2e-2, 2e-2 torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) - 
torch.testing.assert_close( - fused_ent.float(), ref_ent.float(), rtol=rtol, atol=atol - ) + torch.testing.assert_close(fused_ent.float(), ref_ent.float(), rtol=rtol, atol=atol) torch.testing.assert_close(fused_dh.float(), ref_dh.float(), rtol=rtol, atol=atol) torch.testing.assert_close(fused_dw.float(), ref_dw.float(), rtol=rtol, atol=atol) diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 0b217a2673..8c6260daff 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -370,6 +370,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -442,6 +443,7 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -488,6 +490,7 @@ Core configuration for model training, including optimization and backend settin | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. 
Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -993,6 +996,7 @@ fields. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -1257,6 +1261,7 @@ Configuration class: TeacherConfig | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index e9e6f11180..31bf3cac24 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -368,6 +368,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. 
Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -440,6 +441,7 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -486,6 +488,7 @@ Core configuration for model training, including optimization and backend settin | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. 
**Choices:** `v1`, `v2` | @@ -991,6 +994,7 @@ fields. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -1255,6 +1259,7 @@ Configuration class: TeacherConfig | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. 
**Choices:** `v1`, `v2` | diff --git a/tests/test_linear_cross_entropy.py b/tests/test_linear_cross_entropy.py index fdc2abd8b1..dec29bd46c 100644 --- a/tests/test_linear_cross_entropy.py +++ b/tests/test_linear_cross_entropy.py @@ -70,20 +70,14 @@ def _make_inputs( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: g = torch.Generator(device=device).manual_seed(seed) hidden = ( - torch.randn( - num_tokens, hidden_size, dtype=dtype, device=device, generator=g - ) + torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) * 0.02 ) weight = ( - torch.randn( - vocab_size, hidden_size, dtype=dtype, device=device, generator=g - ) + torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) * 0.02 ) - labels = torch.randint( - 0, vocab_size, (num_tokens,), device=device, generator=g - ) + labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) return hidden.contiguous(), weight.contiguous(), labels.contiguous() @@ -210,9 +204,7 @@ def test_linear_cross_entropy_backward_matches_reference( # kernel's d_weight accumulates ``num_tokens`` partial products, so we # use a slightly looser absolute tolerance for d_weight at the largest # shape; rtol stays tight to catch directional errors. - torch.testing.assert_close( - hidden_a.grad, hidden_b.grad, rtol=1e-4, atol=1e-4 - ) + torch.testing.assert_close(hidden_a.grad, hidden_b.grad, rtol=1e-4, atol=1e-4) weight_atol = 1e-4 if num_tokens <= 512 else 5e-4 torch.testing.assert_close( weight_a.grad, weight_b.grad, rtol=1e-4, atol=weight_atol @@ -275,9 +267,7 @@ def _run_lce_tp2_with_torchrun( ) except subprocess.CalledProcessError as e: pytest.fail( - "TP=2 LCE torchrun test failed:\n" - f"STDOUT:\n{e.stdout}\n" - f"STDERR:\n{e.stderr}" + f"TP=2 LCE torchrun test failed:\nSTDOUT:\n{e.stdout}\nSTDERR:\n{e.stderr}" ) @@ -322,9 +312,7 @@ def test_linear_cross_entropy_tp2_performance_benchmark( vocab_size: int, ) -> None: """TP=2 fused vs TP-materialised forward+backward time and peak memory.""" - _run_lce_tp2_with_torchrun( - "performance", num_tokens, hidden_size, vocab_size - ) + _run_lce_tp2_with_torchrun("performance", num_tokens, hidden_size, vocab_size) # --------------------------------------------------------------------------- @@ -342,7 +330,9 @@ def _peak_memory_mb(fn, *args, **kwargs) -> tuple[float, float]: start.record() out = fn(*args, **kwargs) if isinstance(out, tuple): - loss = sum(t.float().sum() for t in out if t.requires_grad or t.grad_fn is not None) + loss = sum( + t.float().sum() for t in out if t.requires_grad or t.grad_fn is not None + ) else: loss = out.float().sum() loss.backward() @@ -395,9 +385,7 @@ def test_linear_cross_entropy_performance_benchmark( captured numbers are also printed for human review. 
""" dtype = torch.bfloat16 - hidden, weight, labels = _make_inputs( - num_tokens, hidden_size, vocab_size, dtype - ) + hidden, weight, labels = _make_inputs(num_tokens, hidden_size, vocab_size, dtype) # warm-up for _ in range(2): diff --git a/tests/torchrun/run_lce_tp2.py b/tests/torchrun/run_lce_tp2.py index e19fb84821..1e783f0821 100644 --- a/tests/torchrun/run_lce_tp2.py +++ b/tests/torchrun/run_lce_tp2.py @@ -240,7 +240,9 @@ def _test_tp2_performance( def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("--test_type", choices=["correctness", "performance"], required=True) + parser.add_argument( + "--test_type", choices=["correctness", "performance"], required=True + ) parser.add_argument("--num_tokens", type=int, required=True) parser.add_argument("--hidden_size", type=int, required=True) parser.add_argument("--vocab_size", type=int, required=True) From 9a6b467c9f79f35d3dd65d10f017682a12018bf8 Mon Sep 17 00:00:00 2001 From: TaoZex <45089228+TaoZex@users.noreply.github.com> Date: Fri, 15 May 2026 00:56:29 +0800 Subject: [PATCH 30/31] feat: fix by comment --- areal/api/cli_args.py | 17 +- areal/engine/megatron_engine.py | 188 ++++++++++-------- areal/models/kernel/__init__.py | 34 ++++ .../kernel/functional.py} | 2 +- areal/{utils => models}/kernel/kernels.py | 0 .../kernel/linear_cross_entropy.py | 4 +- areal/utils/functional/__init__.py | 7 - areal/utils/kernel/__init__.py | 22 -- .../bench_linear_cross_entropy.py | 8 +- examples/math/gsm8k_grpo_megatron.yaml | 2 +- tests/test_linear_cross_entropy.py | 10 +- tests/torchrun/run_lce_tp2.py | 2 +- 12 files changed, 158 insertions(+), 138 deletions(-) create mode 100644 areal/models/kernel/__init__.py rename areal/{utils/functional/linear_cross_entropy.py => models/kernel/functional.py} (98%) rename areal/{utils => models}/kernel/kernels.py (100%) rename areal/{utils => models}/kernel/linear_cross_entropy.py (98%) delete mode 100644 areal/utils/kernel/__init__.py rename benchmark/{ => kernels}/bench_linear_cross_entropy.py (97%) diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 07971fec74..3681bd2458 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -929,6 +929,15 @@ class MegatronEngineConfig: }, ) + use_fused_linear_ce: bool = field( + default=False, + metadata={ + "help": "Fuse the linear projection with cross-entropy so that the " + "[num_tokens, vocab_size] logits tensor is never materialised. " + "Only effective for the Megatron actor backend with parallel_output=True." + }, + ) + class SchedulingStrategyType(str, Enum): separation = "separation" @@ -1143,14 +1152,6 @@ class TrainEngineConfig: default=False, metadata={"help": "Enable tree training with flex attention module."}, ) - use_fused_linear_ce: bool = field( - default=False, - metadata={ - "help": "Fuse the linear projection with cross-entropy so that the " - "[num_tokens, vocab_size] logits tensor is never materialised. " - "Only effective for the Megatron actor backend with parallel_output=True." - }, - ) # Scheduling scheduling_spec: tuple[SchedulingSpec, ...] 
= field( diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index ccaa32444b..4f11162964 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -88,6 +88,10 @@ ) from areal.infra.dist_rollout import DistRolloutCoordinator from areal.infra.platforms import current_platform +from areal.models.kernel import ( + linear_cross_entropy_logprobs, + linear_cross_entropy_logprobs_entropy, +) from areal.models.mcore.hf_load import load_weights_from_hf_with_mbridge_fast from areal.models.mcore.hf_save import ( save_critic_value_head, @@ -125,8 +129,6 @@ from areal.utils.functional import ( gather_logprobs, gather_logprobs_entropy, - linear_cross_entropy_logprobs, - linear_cross_entropy_logprobs_entropy, ) from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer from areal.utils.lock import DistributedLock @@ -816,7 +818,7 @@ def forward_backward_batch( self._ensure_ready() use_fused_lce = ( - getattr(self.config, "use_fused_linear_ce", False) + getattr(self.config.megatron, "use_fused_linear_ce", False) and not self.config.is_critic and not self.enable_tree_training ) @@ -2107,90 +2109,9 @@ def _compute_logprobs_and_loss( else None, ) else: - cp_local_labels = inputs.get("_cp_local_labels") - cp_padded_cu_seqlens = inputs.get("_cp_padded_cu_seqlens") - if cp_local_labels is not None: - labels = cp_local_labels - else: - labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) - fused_active = inputs.get("_fused_lce_active", False) - fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) - fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) - if ( - fused_active - and fused_hidden is not None - and fused_weight is not None - ): - logprobs, entropy, vocab_max_logits = ( - linear_cross_entropy_logprobs_entropy( - fused_hidden, - fused_weight, - labels, - temperature=self.config.temperature, - tp_group=mpu.get_tensor_model_parallel_group() - if mpu.get_tensor_model_parallel_world_size() > 1 - else None, - return_max_logits=True, - ) - ) - # Fused kernel does not track per-token vocab min logits; - # skip the min telemetry rather than report a misleading - # proxy. Consumers must guard ``vocab_min_logits`` and - # ``vocab_max_logits`` independently. 
- vocab_min_logits = None - else: - logprobs, entropy = gather_logprobs_entropy( - output, - labels, - temperature=self.config.temperature, - tp_group=mpu.get_tensor_model_parallel_group() - if mpu.get_tensor_model_parallel_world_size() > 1 - else None, - ) - vocab_min_logits = output.detach().min(-1).values.float() - vocab_max_logits = output.detach().max(-1).values.float() - if cp_padded_cu_seqlens is not None: - logprobs = reassemble_cp_packed_logprobs( - logprobs, cp_padded_cu_seqlens - ) - entropy = reassemble_cp_packed_logprobs( - entropy, cp_padded_cu_seqlens - ) - vocab_min_logits = reassemble_cp_packed_logprobs( - vocab_min_logits, cp_padded_cu_seqlens - ) - vocab_max_logits = reassemble_cp_packed_logprobs( - vocab_max_logits, cp_padded_cu_seqlens - ) - cp_padding_length = inputs.get("_cp_padding_length", 0) - cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens") - logprobs = unpad_logits( - logprobs, - cp_padding_length, - cp_padded_cu_seqlens, - cp_old_cu_seqlens, - ) - entropy = unpad_logits( - entropy, - cp_padding_length, - cp_padded_cu_seqlens, - cp_old_cu_seqlens, - ) - vocab_min_logits = unpad_logits( - vocab_min_logits, - cp_padding_length, - cp_padded_cu_seqlens, - cp_old_cu_seqlens, - ) - vocab_max_logits = unpad_logits( - vocab_max_logits, - cp_padding_length, - cp_padded_cu_seqlens, - cp_old_cu_seqlens, - ) - inputs = { - k: v for k, v in inputs.items() if not k.startswith("_cp_") - } + logprobs, entropy, vocab_min_logits, vocab_max_logits, inputs = ( + self._compute_packed_logprobs_entropy(output, inputs) + ) loss = loss_fn( logprobs, entropy, @@ -2205,6 +2126,99 @@ def _compute_logprobs_and_loss( loss_scale = local_weight / total_loss_weight * loss_multiplier return loss * loss_scale + def _compute_packed_logprobs_entropy( + self, + output: torch.Tensor, + inputs: dict[str, Any], + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + dict[str, Any], + ]: + """Compute per-token logprobs/entropy for the non-tree packed path. + + Returns ``(logprobs, entropy, vocab_min_logits, vocab_max_logits, inputs)``. + ``inputs`` is returned because the materialised CP branch strips the + ``_cp_*`` keys before the loss is invoked. + """ + cp_local_labels = inputs.get("_cp_local_labels") + cp_padded_cu_seqlens = inputs.get("_cp_padded_cu_seqlens") + if cp_local_labels is not None: + labels = cp_local_labels + else: + labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) + + tp_group = ( + mpu.get_tensor_model_parallel_group() + if mpu.get_tensor_model_parallel_world_size() > 1 + else None + ) + + # Fused LCE fast path: logits are never materialised, so we skip the + # min telemetry rather than report a misleading proxy. + fused_active = inputs.get("_fused_lce_active", False) + fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) + fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) + if fused_active and fused_hidden is not None and fused_weight is not None: + logprobs, entropy, vocab_max_logits = linear_cross_entropy_logprobs_entropy( + fused_hidden, + fused_weight, + labels, + temperature=self.config.temperature, + tp_group=tp_group, + return_max_logits=True, + ) + return logprobs, entropy, None, vocab_max_logits, inputs + + # Materialised path. 
+ logprobs, entropy = gather_logprobs_entropy( + output, + labels, + temperature=self.config.temperature, + tp_group=tp_group, + ) + vocab_min_logits = output.detach().min(-1).values.float() + vocab_max_logits = output.detach().max(-1).values.float() + if cp_padded_cu_seqlens is not None: + logprobs = reassemble_cp_packed_logprobs(logprobs, cp_padded_cu_seqlens) + entropy = reassemble_cp_packed_logprobs(entropy, cp_padded_cu_seqlens) + vocab_min_logits = reassemble_cp_packed_logprobs( + vocab_min_logits, cp_padded_cu_seqlens + ) + vocab_max_logits = reassemble_cp_packed_logprobs( + vocab_max_logits, cp_padded_cu_seqlens + ) + cp_padding_length = inputs.get("_cp_padding_length", 0) + cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens") + logprobs = unpad_logits( + logprobs, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + entropy = unpad_logits( + entropy, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + vocab_min_logits = unpad_logits( + vocab_min_logits, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + vocab_max_logits = unpad_logits( + vocab_max_logits, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + inputs = {k: v for k, v in inputs.items() if not k.startswith("_cp_")} + return logprobs, entropy, vocab_min_logits, vocab_max_logits, inputs + def _compute_forward_result( self, output: torch.Tensor, diff --git a/areal/models/kernel/__init__.py b/areal/models/kernel/__init__.py new file mode 100644 index 0000000000..0d61c6f2cf --- /dev/null +++ b/areal/models/kernel/__init__.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Triton-based fused linear-cross-entropy kernels for AReaL. + +The kernel implementations under :mod:`areal.models.kernel.kernels` fuse +the matmul with cross-entropy reduction, preserving numerical semantics +while avoiding materialization of the ``[num_tokens, vocab_size]`` logits +tensor. The :class:`LinearCrossEntropy` autograd function exposed below +provides a memory-efficient drop-in replacement for the materialized +``logits = hidden @ weight.T`` followed by softmax / log-softmax / +entropy computation. + +The :mod:`areal.models.kernel.functional` submodule additionally provides +high-level wrappers (``linear_cross_entropy_logprobs`` / +``linear_cross_entropy_logprobs_entropy``) that fall back to a +materialized reference implementation when the fused kernel is +unavailable. 
+""" + +from areal.models.kernel.functional import ( + linear_cross_entropy_logprobs, + linear_cross_entropy_logprobs_entropy, +) +from areal.models.kernel.linear_cross_entropy import ( + LinearCrossEntropy, + linear_cross_entropy, +) + +__all__ = [ + "LinearCrossEntropy", + "linear_cross_entropy", + "linear_cross_entropy_logprobs", + "linear_cross_entropy_logprobs_entropy", +] diff --git a/areal/utils/functional/linear_cross_entropy.py b/areal/models/kernel/functional.py similarity index 98% rename from areal/utils/functional/linear_cross_entropy.py rename to areal/models/kernel/functional.py index 407abbd0f8..f08b5108e0 100644 --- a/areal/utils/functional/linear_cross_entropy.py +++ b/areal/models/kernel/functional.py @@ -113,7 +113,7 @@ def linear_cross_entropy_logprobs_entropy( leading_shape = labels.shape if _kernel_available(): - from areal.utils.kernel.linear_cross_entropy import linear_cross_entropy + from areal.models.kernel.linear_cross_entropy import linear_cross_entropy if hidden.device.type != "cuda": logger.warning( diff --git a/areal/utils/kernel/kernels.py b/areal/models/kernel/kernels.py similarity index 100% rename from areal/utils/kernel/kernels.py rename to areal/models/kernel/kernels.py diff --git a/areal/utils/kernel/linear_cross_entropy.py b/areal/models/kernel/linear_cross_entropy.py similarity index 98% rename from areal/utils/kernel/linear_cross_entropy.py rename to areal/models/kernel/linear_cross_entropy.py index 3411a194cb..32bbca3c86 100644 --- a/areal/utils/kernel/linear_cross_entropy.py +++ b/areal/models/kernel/linear_cross_entropy.py @@ -52,7 +52,7 @@ def forward( if not isinstance(reduction, str): raise TypeError(f"reduction must be str, got {type(reduction)}") - from areal.utils.kernel import kernels + from areal.models.kernel import kernels REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) @@ -111,7 +111,7 @@ def backward( dentropy: torch.Tensor, dmax_logits: torch.Tensor | None = None, ) -> tuple: - from areal.utils.kernel import kernels + from areal.models.kernel import kernels ( hidden, diff --git a/areal/utils/functional/__init__.py b/areal/utils/functional/__init__.py index d1182361b8..c91c3ff2b6 100644 --- a/areal/utils/functional/__init__.py +++ b/areal/utils/functional/__init__.py @@ -11,10 +11,6 @@ reward_overlong_penalty, sapo_loss_fn, ) -from areal.utils.functional.linear_cross_entropy import ( - linear_cross_entropy_logprobs, - linear_cross_entropy_logprobs_entropy, -) from areal.utils.functional.vocab_parallel import ( gather_logprobs, gather_logprobs_entropy, @@ -34,7 +30,4 @@ # vocab_parallel.py "gather_logprobs", "gather_logprobs_entropy", - # linear_cross_entropy.py (fused linear + CE/entropy via Triton) - "linear_cross_entropy_logprobs", - "linear_cross_entropy_logprobs_entropy", ] diff --git a/areal/utils/kernel/__init__.py b/areal/utils/kernel/__init__.py deleted file mode 100644 index 9ecc81d669..0000000000 --- a/areal/utils/kernel/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -Triton-based fused linear-cross-entropy kernels for AReaL. - -The kernel implementations under :mod:`areal.utils.kernel.kernels` fuse -the matmul with cross-entropy reduction, preserving numerical semantics -while avoiding materialization of the ``[num_tokens, vocab_size]`` logits -tensor. 
The :class:`LinearCrossEntropy` autograd function exposed below -provides a memory-efficient drop-in replacement for the materialized -``logits = hidden @ weight.T`` followed by softmax / log-softmax / -entropy computation. -""" - -from areal.utils.kernel.linear_cross_entropy import ( - LinearCrossEntropy, - linear_cross_entropy, -) - -__all__ = [ - "LinearCrossEntropy", - "linear_cross_entropy", -] diff --git a/benchmark/bench_linear_cross_entropy.py b/benchmark/kernels/bench_linear_cross_entropy.py similarity index 97% rename from benchmark/bench_linear_cross_entropy.py rename to benchmark/kernels/bench_linear_cross_entropy.py index 080d6b05dd..5b470f1929 100644 --- a/benchmark/bench_linear_cross_entropy.py +++ b/benchmark/kernels/bench_linear_cross_entropy.py @@ -8,7 +8,7 @@ Usage:: # Qwen3 single-GPU full-vocab benchmark - uv run python -m benchmark.bench_linear_cross_entropy \\ + uv run python -m benchmark.kernels.bench_linear_cross_entropy \\ --mode both --tokens 2048 --hidden 4096 --vocab 152064 \\ --dtype bfloat16 --warmup 5 --iters 15 --check-correctness @@ -16,14 +16,14 @@ # [tokens, vocab/tp] logits and uses vocab-parallel reductions. uv run torchrun --nproc_per_node=2 --nnodes=1 \\ --master-addr=localhost --master_port=29501 \\ - -m benchmark.bench_linear_cross_entropy \\ + -m benchmark.kernels.bench_linear_cross_entropy \\ --mode both --tp-size 2 --tokens 2048 --hidden 4096 --vocab 152064 \\ --dtype bfloat16 --warmup 5 --iters 15 --check-correctness # Qwen3 TP=4 benchmark uv run torchrun --nproc_per_node=4 --nnodes=1 \\ --master-addr=localhost --master_port=29501 \\ - -m benchmark.bench_linear_cross_entropy \\ + -m benchmark.kernels.bench_linear_cross_entropy \\ --mode both --tp-size 4 --tokens 2048 --hidden 4096 --vocab 152064 \\ --dtype bfloat16 --warmup 5 --iters 15 --check-correctness """ @@ -124,7 +124,7 @@ def _ref_step(hidden, weight, labels, temperature=1.0, tp_group=None): def _fused_step(hidden, weight, labels, temperature=1.0, tp_group=None): - from areal.utils.kernel import linear_cross_entropy + from areal.models.kernel import linear_cross_entropy h = hidden.detach().clone().requires_grad_(True) w = weight.detach().clone().requires_grad_(True) diff --git a/examples/math/gsm8k_grpo_megatron.yaml b/examples/math/gsm8k_grpo_megatron.yaml index 0cb03c0830..2482b297bc 100644 --- a/examples/math/gsm8k_grpo_megatron.yaml +++ b/examples/math/gsm8k_grpo_megatron.yaml @@ -43,7 +43,7 @@ actor: backend: "megatron:d4p1t1" experiment_name: ${experiment_name} trial_name: ${trial_name} - path: /workspace/models/Qwen3-0.6B + path: Qwen/Qwen2.5-1.5B-Instruct init_from_scratch: false disable_dropout: true gradient_checkpointing: false diff --git a/tests/test_linear_cross_entropy.py b/tests/test_linear_cross_entropy.py index dec29bd46c..90d82f29d0 100644 --- a/tests/test_linear_cross_entropy.py +++ b/tests/test_linear_cross_entropy.py @@ -3,7 +3,7 @@ Correctness + performance tests for the fused linear-cross-entropy kernel. The test suite verifies that -:func:`areal.utils.functional.linear_cross_entropy_logprobs_entropy` produces +:func:`areal.models.kernel.linear_cross_entropy_logprobs_entropy` produces results numerically equivalent to the materialised ``logits @ weight`` + ``log_softmax`` reference, and that it provides a measurable wall-clock / memory benefit over the reference path on representative LLM shapes. 
@@ -101,7 +101,7 @@ def test_linear_cross_entropy_correctness( dtype: torch.dtype, ) -> None: """Fused forward output must match the materialised reference.""" - from areal.utils.functional import linear_cross_entropy_logprobs_entropy + from areal.models.kernel import linear_cross_entropy_logprobs_entropy hidden, weight, labels = _make_inputs(num_tokens, hidden_size, vocab_size, dtype) @@ -133,7 +133,7 @@ def test_linear_cross_entropy_correctness( @pytest.mark.parametrize("temperature", [0.7, 1.0, 1.5]) def test_linear_cross_entropy_temperature(temperature: float) -> None: """Temperature scaling matches the reference for non-trivial values.""" - from areal.utils.functional import linear_cross_entropy_logprobs_entropy + from areal.models.kernel import linear_cross_entropy_logprobs_entropy hidden, weight, labels = _make_inputs( num_tokens=128, hidden_size=512, vocab_size=4096, dtype=torch.float32 @@ -174,7 +174,7 @@ def test_linear_cross_entropy_backward_matches_reference( drift in the fused d_hidden / d_weight kernels is caught at scale rather than only on toy inputs. """ - from areal.utils.kernel import linear_cross_entropy + from areal.models.kernel import linear_cross_entropy hidden_a, weight_a, labels = _make_inputs( num_tokens, hidden_size, vocab_size, torch.float32 @@ -355,7 +355,7 @@ def _run_reference_forward_backward(hidden, weight, labels, temperature): def _run_fused_forward_backward(hidden, weight, labels, temperature): - from areal.utils.kernel import linear_cross_entropy + from areal.models.kernel import linear_cross_entropy h = hidden.detach().clone().requires_grad_(True) w = weight.detach().clone().requires_grad_(True) diff --git a/tests/torchrun/run_lce_tp2.py b/tests/torchrun/run_lce_tp2.py index 1e783f0821..e489b57c29 100644 --- a/tests/torchrun/run_lce_tp2.py +++ b/tests/torchrun/run_lce_tp2.py @@ -7,8 +7,8 @@ import torch.distributed as dist from areal.infra.platforms import current_platform +from areal.models.kernel import linear_cross_entropy from areal.utils.functional import gather_logprobs_entropy -from areal.utils.kernel import linear_cross_entropy def _setup_distributed_environment() -> None: From 7fcc99b8c1fbacf0dc62ce6b5a0c66637dc966bc Mon Sep 17 00:00:00 2001 From: TaoZex <2633363995@qq.com> Date: Fri, 15 May 2026 01:10:56 +0800 Subject: [PATCH 31/31] feat: fix --- areal/utils/stats_logger.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/areal/utils/stats_logger.py b/areal/utils/stats_logger.py index 33dd809053..220e6d3920 100644 --- a/areal/utils/stats_logger.py +++ b/areal/utils/stats_logger.py @@ -8,9 +8,8 @@ import swanlab import torch.distributed as dist import trackio -from tensorboardX import SummaryWriter - import wandb +from tensorboardX import SummaryWriter from areal.api import FinetuneSpec from areal.api.cli_args import BaseExperimentConfig, StatsLoggerConfig