63 changes: 47 additions & 16 deletions benchmarks/linear/benchmark_grouped_linear.py
@@ -9,44 +9,57 @@
import pathlib

from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling
from transformer_engine.common.recipe import (
Float8BlockScaling,
MXFP8BlockScaling,
NVFP4BlockScaling,
)
from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager
from contextlib import nullcontext

"""
# Profile BF16 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16 \
--output=./benchmarks/linear/b200_numgemm_8_bf16 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16

# Profile FP8 sub-channel recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \
--output=./benchmarks/linear/h100hbm_numgemm_8_fp8_sub_channel \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel

# Profile MXFP8 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8 \
--output=./benchmarks/linear/b200_numgemm_8_mxfp8 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8

# Profile NVFP4 recipe with Nsight Systems
nsys profile \
--output=./benchmarks/linear/b200_numgemm_8_nvfp4 \
--force-overwrite true \
--trace=cuda,nvtx,cudnn,cublas \
python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4

"""

RECIPES = {
"bf16": None,
"fp8_sub_channel": Float8BlockScaling(),
"mxfp8": MXFP8BlockScaling(),
"nvfp4": NVFP4BlockScaling(),
}

mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available()


def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None):
@@ -145,7 +158,7 @@ def benchmark_linear(
"recipe": recipe,
},
num_threads=1,
).blocked_autorange(min_run_time=5)
).blocked_autorange(min_run_time=10)
print(f"{recipe_name}: {timing} \n")
timing_ms = timing.median * 1000 / num_microbatches
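
The timing object above appears to come from torch.utils.benchmark.Timer (the num_threads argument and the blocked_autorange call belong to that API). A minimal sketch of the measurement pattern with the new 10-second floor, with the timed body reduced to a placeholder:

import torch.utils.benchmark as benchmark

def run_once():
    pass  # placeholder for the grouped-linear forward/backward step being timed

timing = benchmark.Timer(
    stmt="run_once()",
    globals={"run_once": run_once},
    num_threads=1,
).blocked_autorange(min_run_time=10)  # measure for at least 10 s, as in the diff

# Median time per call in milliseconds, mirroring the timing_ms computation above
print(timing.median * 1000)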

Expand Down Expand Up @@ -228,30 +241,44 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):

use_bias = False
# Set the MKN values to benchmark
# Deepseek V3 EP64, SEQ_LEN=8192, topK8
# 256 experts => 4 local experts
# Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 16384
# M = AvgM * localExperts = 65536
# K = 7168
# N = 2048

# Deepseek V3 EP32, SEQ_LEN=8192, topK8
# 256 experts => 8 local experts
# Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 8192
# M = AvgM * localExperts = 65536
# K = 7168
# N = 2048

# 4 or 8 local experts per rank
num_gemms_list = [4, 8]

# MKN for group linear
mkns = []
for m in [8192]:
# for m in [4096, 8192, 16384]:
# for n in [1024, 2048, 4096, 8192, 16384]:
for n in [8192]:
for k in [4096]:
for m in [65536]:
for k in [7168]:
for n in [2048]:
mkns.append((m, k, n))

# default recipes to run if not specified
recipe_list = ["bf16"]

if args.recipe == "all":
recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"]
recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"]
else:
recipe_list = [args.recipe]

num_gemms_list = [8]

if args.profile:
mkns = [(4096 * 8, 4096, 4096)]
mkns = [(8192 * 8, 7168, 2048)]
# in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as"
" fp8_sub_channel, mxfp8, or bf16"
" fp8_sub_channel, mxfp8, nvfp4, or bf16"
)
recipe_list = [args.recipe]
num_gemms_list = [8]
@@ -268,13 +295,17 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
"bf16",
"fp8_sub_channel",
"mxfp8",
], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8"
"nvfp4",
], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4"
if recipe_name == "mxfp8" and not mxfp8_available:
print(f"MXFP8 is not available, skipping {recipe_name}")
continue
if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available:
print(f"FP8 block scaling is not available, skipping {recipe_name}")
continue
if recipe_name == "nvfp4" and not nvfp4_available:
print(f"NVFP4 is not available, skipping {recipe_name}")
continue

df = run_benchmark_linear(
mkns,
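As a quick sanity check of the Deepseek V3 sizing comments in the hunk above, here is a small sketch (the grouped_m helper is hypothetical, not part of the benchmark) reproducing the per-expert M derivation for EP64 and EP32 with SEQ_LEN=8192, topK=8, and 256 total experts:

SEQ_LEN, TOP_K, TOTAL_EXPERTS = 8192, 8, 256

def grouped_m(ep_size):
    """Follow the comment's formula: AvgM = SEQ_LEN * topK / localExperts."""
    local_experts = TOTAL_EXPERTS // ep_size      # 4 for EP64, 8 for EP32
    avg_m = SEQ_LEN * TOP_K // local_experts      # 16384 for EP64, 8192 for EP32
    return avg_m, avg_m * local_experts           # M = AvgM * localExperts = 65536 in both cases

print(grouped_m(64))  # (16384, 65536) -> mkns entry (65536, 7168, 2048) with K=7168, N=2048
print(grouped_m(32))  # (8192, 65536)
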
92 changes: 89 additions & 3 deletions tests/pytorch/test_numerics.py
@@ -55,6 +55,7 @@
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available()

sm_80plus = get_device_compute_capability() >= (8, 0)

@@ -116,6 +117,43 @@
)


def nvfp4_rht_and_2d_quantization():
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(
random_hadamard_transform=False, fp4_2d_quantization=True
)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(
random_hadamard_transform=True, fp4_2d_quantization=False
)
return nvfp4_recipe


def check_rht_usage(recipe: recipe.Recipe) -> bool:
# if using RHT, we can only support bf16
# check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad
if recipe.nvfp4():
if (
recipe.fp4_quant_fwd_inp.random_hadamard_transform
or recipe.fp4_quant_fwd_weight.random_hadamard_transform
or recipe.fp4_quant_bwd_grad.random_hadamard_transform
):
return True
return False


def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> list[torch.dtype]:
supported_input_dtypes = []
if recipe.nvfp4():
supported_input_dtypes.append(torch.bfloat16)
# if not using RHT, we can add fp32 as well
if not check_rht_usage(recipe):
supported_input_dtypes.append(torch.float32)
return supported_input_dtypes


fp8_recipes = []
if mxfp8_available:
fp8_recipes.append(recipe.MXFP8BlockScaling())
@@ -124,6 +162,8 @@
if fp8_available:
fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(recipe.DelayedScaling())
if nvfp4_available:
fp8_recipes.append(nvfp4_rht_and_2d_quantization())

use_cutlass_grouped_gemm = [False]
# Only enable cutlass grouped gemm on Hopper
@@ -601,6 +641,11 @@ def _test_e2e_selective_recompute(
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

config = model_configs[model]

@@ -711,6 +756,11 @@ def test_gpt_full_activation_recompute(
):
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

config = model_configs[model]

@@ -1304,6 +1354,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe):
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")

if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
te_linear_ref = Linear(
config.hidden_size,
@@ -1747,8 +1803,8 @@ def _test_grouped_linear_accuracy(
split_size = 1
if fp8:
split_size = 16
if recipe.mxfp8():
split_size = 128
if recipe.mxfp8() or recipe.nvfp4():
split_size = 32
m = config.max_seqlen_q // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
dist.append(dist[-1]) # Manually add a zero
@@ -1820,6 +1876,12 @@ def test_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")

if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
@@ -1956,6 +2018,12 @@ def test_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")

if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = GroupedLinear(
num_gemms,
@@ -2043,7 +2111,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r

def _pad_tensor_for_fp8(hidden_states, tokens_per_expert):
align_size = 16
if recipe.mxfp8():
if recipe.mxfp8() or recipe.nvfp4():
align_size = 32
padded_tokens_per_expert = [
(num_tokens + align_size - 1) // align_size * align_size
@@ -2158,6 +2226,12 @@ def test_padding_grouped_linear_accuracy(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")

if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
@@ -2229,6 +2303,12 @@ def test_padding_grouped_linear_accuracy_save_original_input(
if config.max_seqlen_q % 16 != 0 and fp8:
pytest.skip("FP8 requires sequence length to be divisible by 16.")

if fp8 and recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
grouped_linear = TorchGroupedLinearWithPadding(
num_gemms,
@@ -2438,6 +2518,12 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")

if recipe.nvfp4():
if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype):
pytest.skip(
f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}"
)

config = model_configs[model]

outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe)
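The nvfp4_rht_and_2d_quantization() helper added above enables the random Hadamard transform for activation and gradient quantization and 2-D quantization for weights. A rough sketch of driving such a recipe end to end, assuming NVFP4 support on the device; the te.Linear and fp8_autocast(enabled=..., fp8_recipe=...) calls follow Transformer Engine's existing FP8 entry points, and the shapes are illustrative (kept divisible by 32, in line with the split sizes used in the tests). Only a bf16 input is used here because RHT is enabled:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Same construction as the test helper: RHT for inputs/grads, 2-D quantization for weights.
nvfp4_recipe = recipe.NVFP4BlockScaling()
nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True, fp4_2d_quantization=False)
nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False, fp4_2d_quantization=True)
nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True, fp4_2d_quantization=False)

layer = te.Linear(1024, 1024, bias=False, params_dtype=torch.bfloat16).cuda()
x = torch.randn(128, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe):
    y = layer(x)
y.sum().backward()  # the gradient quantization path also uses RHT per fp4_quant_bwd_grad
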
3 changes: 2 additions & 1 deletion transformer_engine/common/common.h
@@ -154,7 +154,8 @@ struct Tensor {
return acc;
}

bool has_data() const noexcept { return data.dptr != nullptr; }
// Check for size (not just pointer) for 0-dim or no token cases.
bool has_data() const noexcept { return data.dptr != nullptr || data.shape.size() != 0; }
Reviewer comment on lines +157 to +158:

Mathematically, a 0-D tensor is a scalar with 1 entry.

Suggested change
// Check for size (not just pointer) for 0-dim or no token cases.
bool has_data() const noexcept { return data.dptr != nullptr || data.shape.size() != 0; }
bool has_data() const noexcept { return data.dptr != nullptr; }


// Check for size (not just pointer) for 0-dim or no token cases.
bool has_columnwise_data() const noexcept {