feat: Support Linear Cross Entropy fused kernel #1322
Conversation
Code Review
This pull request introduces a fused linear cross-entropy (LCE) optimization using Triton kernels to avoid materializing large logit tensors, which significantly reduces memory overhead and improves performance for the Megatron backend. The implementation includes a context manager for capturing hidden states, a custom autograd function, and comprehensive benchmarking and testing suites. Review feedback identifies opportunities to improve telemetry accuracy by using actual logit values instead of logprobs, suggests moving kernel alignment assertions into a compatibility check for graceful fallbacks, and recommends using in-place operations in the backward pass for better efficiency.
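For orientation, the capture-plus-autograd design the summary describes could look roughly like the sketch below. All names are illustrative, not the actual AReaL API, and this reference version still materialises logits; the real Triton kernel fuses the matmul and log-softmax to avoid that.

```python
import torch

class LinearCrossEntropySketch(torch.autograd.Function):
    """Minimal reference sketch of a linear cross-entropy autograd Function."""

    @staticmethod
    def forward(ctx, hidden, weight, labels):
        # hidden: [tokens, d_model], weight: [vocab, d_model], labels: [tokens]
        logits = hidden.float() @ weight.float().t()
        logprobs = torch.log_softmax(logits, dim=-1)
        ctx.save_for_backward(hidden, weight, labels)
        # Negative log-likelihood per token.
        return -logprobs.gather(1, labels.unsqueeze(1)).squeeze(1)

    @staticmethod
    def backward(ctx, grad_out):
        hidden, weight, labels = ctx.saved_tensors
        # Recompute instead of storing the [tokens, vocab] tensor.
        logits = hidden.float() @ weight.float().t()
        d_logits = torch.softmax(logits, dim=-1)
        d_logits[torch.arange(labels.numel(), device=labels.device), labels] -= 1.0
        d_logits *= grad_out.unsqueeze(1).float()
        d_hidden = (d_logits @ weight.float()).to(hidden.dtype)
        d_weight = (d_logits.t() @ hidden.float()).to(weight.dtype)
        return d_hidden, d_weight, None
```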
| Case | rtol | atol |
|---|---|---|
| Forward float32 | 1e-5 | 1e-5 |
| Forward bfloat16 | 2e-2 | 2e-2 |
| Forward float16 | 1e-2 | 1e-2 |
| Temperature float32 | 1e-5 | 1e-5 |
| Backward hidden.grad | 1e-4 | 1e-4 |
| Backward weight.grad small/medium | 1e-4 | 1e-4 |
| Backward weight.grad large | 1e-4 | 5e-4 |
These tolerances are strict enough to catch real numerical regressions while allowing expected low-precision accumulation drift.
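For illustration, tolerances like these map directly onto `torch.testing.assert_close`; a minimal sketch (the helper and its arguments are hypothetical, not the test suite's actual code):

```python
import torch

# Per-case tolerances from the table above.
TOLS = {
    torch.float32: dict(rtol=1e-5, atol=1e-5),
    torch.bfloat16: dict(rtol=2e-2, atol=2e-2),
    torch.float16: dict(rtol=1e-2, atol=1e-2),
}

def check_forward(fused_out, ref_out, dtype):
    # Compare fused-kernel output against the materialised reference path.
    torch.testing.assert_close(fused_out, ref_out, **TOLS[dtype])
```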
Benchmark Purpose

The benchmark measures fused vs. materialised latency and peak memory at the kernel level. This provides a focused way to evaluate the kernel-level benefit independently from full end-to-end training noise; the sketch below shows the general shape of such a harness.
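A generic timing-and-memory harness sketch, assuming a CUDA device (this is not the benchmark script's actual code):

```python
import torch

def bench(fn, *args, warmup=5, iters=20):
    """Average latency (ms) and peak memory (bytes) of a CUDA callable."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters, torch.cuda.max_memory_allocated()
```

Running both the fused and the materialised path through the same harness isolates the kernel's latency/memory benefit from everything else in the training loop.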
Future FSDP Adaptation

The current pull request implements the core fused LCE capability and integrates it into the Megatron engine first. After this PR is merged, the same design can be adapted to the FSDP engine.
garrett4wade
left a comment
LGTM except for several coding style issues; we can make the code look much better.
Also, please fix the pre-commit error with `pre-commit run --all-files`.
```python
fused_weight = mb_input.orig_mb.get(FUSED_LCE_WEIGHT_KEY)
if (
    fused_weight is not None
    and output.dtype != fused_weight.dtype
):
    output = output.to(fused_weight.dtype)
mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output
```
Since we usually require fp32 logits, will this downcast operation cause a precision issue?
The fused LCE kernel internally accumulates the matrix multiplication in fp32. Therefore, even with bf16 input hidden states, the precision of the logits and log-softmax computations within the kernel remains fully preserved in fp32.
In practice, the non-fused computation path follows:
bf16 hidden → bf16 matmul → bf16 logits → fp32 logits (upcast by Float16Module) → fp32 log-softmax.
In contrast, the fused path maintains fp32 accumulation throughout the entire computation, ensuring its numerical precision is at least on par with, if not better than, the non-fused baseline.
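A toy PyTorch comparison of the two paths (illustrative only; the real kernel performs the fp32 accumulation inside Triton rather than by upcasting inputs):

```python
import torch

hidden = torch.randn(4, 8).bfloat16()
weight = torch.randn(16, 8).bfloat16()

# Non-fused baseline: logits are produced in bf16, then upcast.
baseline = torch.log_softmax((hidden @ weight.t()).float(), dim=-1)

# Fused-path behaviour: the matmul accumulates in fp32 throughout.
fused_like = torch.log_softmax(hidden.float() @ weight.float().t(), dim=-1)

# Any gap comes from the rounding incurred when logits are stored in bf16.
print((baseline - fused_like).abs().max())
```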
@garrett4wade Thank you for the feedback. I've updated the code based on your suggestions and resolved the pre-commit error. Looking forward to your review! Note: The pre-commit check in CI is still failing due to a bad commit message. This is expected and can be safely ignored.







Description
Adds a fused Linear Cross Entropy (LCE) path for Megatron training to avoid materialising full `[tokens, vocab]` logits.

Key changes:

- Fused Triton kernels compute `logprobs` and entropy without materialising the logits.
- The TP `d_hidden` all-reduce is handled in backward.

Related Issue
Fixes #TBD
Type of Change
Checklist

- Pre-commit checks pass (`pre-commit run --all-files`).
- Docs build successfully (`./docs/build_all.sh`).
- Branch is up to date with `main`.
- PR reviewed/created via the `/review-pr` and `/create-pr` commands.

Breaking Change Details (if applicable):
N/A
Additional Context
Key files:

- `areal/utils/kernel/kernels.py`: implements the Triton fused LCE kernels, including forward logprob/entropy computation and split-N backward.
- `areal/utils/kernel/linear_cross_entropy.py`: exposes the fused LCE autograd function and handles the TP `d_hidden` all-reduce in backward.
- `areal/utils/functional/linear_cross_entropy.py`: provides AReaL-facing wrappers with fallback to the materialised reference path.
- `areal/engine/megatron_utils/fused_lce_capture.py`: captures LM-head hidden states and weights without materialising logits.
- `areal/engine/megatron_engine.py`: wires fused LCE into the Megatron training/logprob path behind `actor.use_fused_linear_ce`.
- `tests/test_linear_cross_entropy.py` and `tests/torchrun/run_lce_tp2.py`: cover single-GPU and TP=2 correctness/performance checks.
- `benchmark/bench_linear_cross_entropy.py`: provides standalone fused vs materialised latency/memory benchmarking, including TP mode.

Need help? Check the Contributing Guide or ask in
https://github.com/inclusionAI/AReaL/discussions!