diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 44f1c89673..1f09e619d6 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -9,44 +9,57 @@ import pathlib from transformer_engine.pytorch.module import GroupedLinear -from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling +from transformer_engine.common.recipe import ( + Float8BlockScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager from contextlib import nullcontext """ # Profile BF16 recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16 \ + --output=./benchmarks/linear/b200_numgemm_8_bf16 \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16 # Profile FP8 sub-channel recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \ + --output=./benchmarks/linear/h100hbm_numgemm_8_fp8_sub_channel \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel # Profile MXFP8 recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8 \ + --output=./benchmarks/linear/b200_numgemm_8_mxfp8 \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8 +# Profile NVFP4 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_numgemm_8_nvfp4 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 + """ RECIPES = { "bf16": None, "fp8_sub_channel": Float8BlockScaling(), "mxfp8": MXFP8BlockScaling(), + "nvfp4": NVFP4BlockScaling(), } mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( FP8GlobalStateManager.is_fp8_block_scaling_available() ) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None): @@ -145,7 +158,7 @@ def benchmark_linear( "recipe": recipe, }, num_threads=1, - ).blocked_autorange(min_run_time=5) + ).blocked_autorange(min_run_time=10) print(f"{recipe_name}: {timing} \n") timing_ms = timing.median * 1000 / num_microbatches @@ -228,30 +241,44 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): use_bias = False # Set the MKN values to benchmark + # Deepseek V3 EP64, SEQ_LEN=8192, topK8 + # 256 experts => 4 local experts + # Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 16384 + # M = AvgM * localExperts = 65536 + # K = 7168 + # N = 2048 + + # Deepseek V3 EP32, SEQ_LEN=8192, topK8 + # 256 experts => 8 local experts + # Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 8192 + # M = AvgM * localExperts = 65536 + # K = 7168 + # N = 2048 + + # 4 or 8 local experts per rank + num_gemms_list = [4, 8] + + # MKN for grouped linear mkns = [] - for m in [8192]: - # for m in [4096, 8192, 16384]: - # for n in [1024, 2048, 4096, 8192, 16384]: - for n in [8192]: - for k in [4096]: + for m in [65536]: + for k in [7168]: + for n in [2048]: mkns.append((m, k, n)) # 
default recipes to run if not specified recipe_list = ["bf16"] if args.recipe == "all": - recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"] + recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"] else: recipe_list = [args.recipe] - num_gemms_list = [8] - if args.profile: - mkns = [(4096 * 8, 4096, 4096)] + mkns = [(8192 * 8, 7168, 2048)] # in profile mode, only run one recipe specified in args.recipe assert args.recipe != "all", ( "In profile mode, only one recipe can be specified, please specify the recipe as" - " fp8_sub_channel, mxfp8, or bf16" + " fp8_sub_channel, mxfp8, nvfp4, or bf16" ) recipe_list = [args.recipe] num_gemms_list = [8] @@ -268,13 +295,17 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): "bf16", "fp8_sub_channel", "mxfp8", - ], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8" + "nvfp4", + ], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4" if recipe_name == "mxfp8" and not mxfp8_available: print(f"MXFP8 is not available, skipping {recipe_name}") continue if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available: print(f"FP8 block scaling is not available, skipping {recipe_name}") continue + if recipe_name == "nvfp4" and not nvfp4_available: + print(f"NVFP4 is not available, skipping {recipe_name}") + continue df = run_benchmark_linear( mkns, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a0e285b913..8fff387aff 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -55,6 +55,7 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() +nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -116,6 +117,43 @@ ) +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +def check_rht_usage(recipe: recipe.Recipe) -> bool: + # if using RHT, we can only support bf16 + # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad + if recipe.nvfp4(): + if ( + recipe.fp4_quant_fwd_inp.random_hadamard_transform + or recipe.fp4_quant_fwd_weight.random_hadamard_transform + or recipe.fp4_quant_bwd_grad.random_hadamard_transform + ): + return True + return False + + +def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool: + supported_input_dtypes = [] + if recipe.nvfp4(): + supported_input_dtypes.append(torch.bfloat16) + # if not using RHT, we can add fp32 as well + if not check_rht_usage(recipe): + supported_input_dtypes.append(torch.float32) + return supported_input_dtypes + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) @@ -124,6 +162,8 @@ if fp8_available: fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) +if nvfp4_available: + fp8_recipes.append(nvfp4_rht_and_2d_quantization()) use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper @@ -601,6 +641,11 @@ def 
_test_e2e_selective_recompute( def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) config = model_configs[model] @@ -711,6 +756,11 @@ def test_gpt_full_activation_recompute( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) config = model_configs[model] @@ -1304,6 +1354,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): te_linear_ref = Linear( config.hidden_size, @@ -1747,8 +1803,8 @@ def _test_grouped_linear_accuracy( split_size = 1 if fp8: split_size = 16 - if recipe.mxfp8(): - split_size = 128 + if recipe.mxfp8() or recipe.nvfp4(): + split_size = 32 m = config.max_seqlen_q // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero @@ -1820,6 +1876,12 @@ def test_grouped_linear_accuracy( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, @@ -1956,6 +2018,12 @@ def test_grouped_linear_accuracy_save_original_input( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, @@ -2043,7 +2111,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): align_size = 16 - if recipe.mxfp8(): + if recipe.mxfp8() or recipe.nvfp4(): align_size = 32 padded_tokens_per_expert = [ (num_tokens + align_size - 1) // align_size * align_size @@ -2158,6 +2226,12 @@ def test_padding_grouped_linear_accuracy( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, 
@@ -2229,6 +2303,12 @@ def test_padding_grouped_linear_accuracy_save_original_input( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, @@ -2438,6 +2518,12 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + config = model_configs[model] outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index bddd9bf194..f4b19e4949 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -154,7 +154,8 @@ struct Tensor { return acc; } - bool has_data() const noexcept { return data.dptr != nullptr; } + // Check for size (not just pointer) for 0-dim or no token cases. + bool has_data() const noexcept { return data.dptr != nullptr || data.shape.size() != 0; } // Check for size (not just pointer) for 0-dim or no token cases. bool has_columnwise_data() const noexcept { diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 36e06173d0..06735e3104 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -332,11 +332,9 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ } // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { - NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING || - input->scaling_mode == NVTE_BLOCK_SCALING_1D || - input->scaling_mode == NVTE_BLOCK_SCALING_2D || - input->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); + NVTE_CHECK( + input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (", to_string(input->dtype()), ")."); @@ -583,16 +581,19 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, NVTE_CHECK_CUDA(cudaGetLastError()); } -// TODO(nvfp4): Add NVFP4 support. 
void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; + bool all_nvfp4 = true; for (size_t i = 0; i < num_tensors; i++) { - if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) { - NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + "."); - } + auto scaling_mode = input[i]->scaling_mode; + auto is_fp8 = is_fp8_dtype(input[i]->dtype()); + auto is_fp4 = is_fp4_dtype(input[i]->dtype()); + NVTE_CHECK( + (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)), + "Not implemented scaling mode " + to_string(scaling_mode) + "."); // We don't allow empty tensors. They should be filtered out before calling this function. if (input[i]->data.numel() == 0) { NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty."); @@ -601,13 +602,17 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); all_has_data &= input[i]->has_data(); all_has_columnwise_data &= input[i]->has_columnwise_data(); + all_nvfp4 &= is_nvfp4_scaling(scaling_mode); } NVTE_CHECK(all_has_data || all_has_columnwise_data, "All tensors should have data or columnwise data."); + const bool rowwise_swizzle = all_has_data || all_nvfp4; + const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4; + constexpr int SF_TILE_DIM_M = 128; constexpr int SF_TILE_DIM_K = 4; - if (all_has_data) { + if (rowwise_swizzle) { MultiSwizzleArgs kernel_args; kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; @@ -623,29 +628,60 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args.num_tensors = 0; vec_load_size = 4; } - const int m = input[i]->scale_inv.shape[0]; - const int k = input[i]->scale_inv.shape[1]; + + int m, k; + + if (all_has_data) { + m = input[i]->scale_inv.shape[0]; + k = input[i]->scale_inv.shape[1]; + } else { + NVTE_CHECK(all_nvfp4, "When doing rowwise swizzle with rowwise data, it has to be NVFP4"); + m = input[i]->columnwise_scale_inv.shape[0]; + k = input[i]->columnwise_scale_inv.shape[1]; + } NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - NVTE_CHECK( - m * k == std::accumulate(output[i]->scale_inv.shape.begin(), - output[i]->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); + + if (output[i]->has_data()) { + NVTE_CHECK( + m * k == std::accumulate(output[i]->scale_inv.shape.begin(), + output[i]->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output[i]->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } int num_tiles_k = k / SF_TILE_DIM_K; int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; // We use the minimum vec_load_size across all tensors. 
- vec_load_size = std::min(vec_load_size, vec_load_size_i); + // TODO(zhongbo): fix vec_load_size for NVFP4 + // Current unit test won't capture this issue, but in E2E, + // using a vec_load_size other than 1 leads to misaligned + // address errors in MoE training + vec_load_size = all_nvfp4 ? 1 : std::min(vec_load_size, vec_load_size_i); const int pos = kernel_args.num_tensors; - kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); - kernel_args.output_list[pos] = output[i]->scale_inv.dptr; kernel_args.m_list[pos] = m; kernel_args.k_list[pos] = k; - kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); - kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE; + if (!all_nvfp4 || all_has_data) { + int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / block_scale_size; + } else { + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / NVFP4_BLOCK_SIZE; + } kernel_args.num_tensors++; } // Launch the remaining tensors @@ -655,7 +691,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args, vec_load_size, true, stream); } - if (all_has_columnwise_data) { + if (columnwise_swizzle) { + // NVFP4 shouldn't end up here because it only needs rowwise swizzle + NVTE_CHECK(!all_nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); + MultiSwizzleArgs kernel_args; kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 49ae963d74..e054424dd4 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -190,8 +190,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( const std::vector meta_shape{1}; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = - (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? 
DType::kFloat8E4M3 + : DType::kFloat32; ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, columnwise_scale_inv_shape); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b6e9ef828c..0f27e19719 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -491,6 +491,207 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } +// allocate fp4 data, fp8 scalings, and amax values +// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] +// amax values need to be zeroed out, use cudaMemsetAsync for the last few bytes for amax +std::tuple, std::vector> bulk_allocate_nvfp4_tensors( + std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { + init_extension(); + std::tuple, std::vector> retval; + auto &tensor_py_list = std::get<0>(retval); + auto &tensor_cpp_list = std::get<1>(retval); + + // Number of tensors + const size_t num_tensors = shape_list.size(); + if (num_tensors == 0) { + return retval; + } + + // Quantization parameters + const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; + const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); + const auto fp4_dtype = quantizer_cpp_list[0]->dtype; + constexpr size_t scale_elem_size = 1; + + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + size_t offset, at::ScalarType dtype) -> at::Tensor { + std::vector shape_int64(shape.begin(), shape.end()); + bool is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { + return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); + } + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); + }; + + // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) + auto to_fp4_shape = [](const std::vector &shape) { + std::vector fp4_shape(shape.begin(), shape.end()); + if (!fp4_shape.empty()) { + fp4_shape.back() /= 2; + } + return fp4_shape; + }; + + // Allocate row-wise data + std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; + std::vector> rowwise_data_shapes, rowwise_scale_shapes; + if (rowwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_shapes.emplace_back(shape_list[i]); + rowwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets, amax_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + // Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes). + // Integer arithmetic: ceil(product / 2) == (product + 1) / 2. 
+ buffer_size += (product(rowwise_data_shapes[i]) + 1) / 2; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + amax_offsets.push_back(buffer_size); + // amax is scalar in fp32, 4 bytes each + buffer_size += 4; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), + data_offsets[i], torch::kUInt8)); + rowwise_scale_list.emplace_back( + make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + amax_rowwise_list.emplace_back( + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kUInt8)); + } + } + + // Allocate column-wise data + std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; + std::vector> columnwise_data_shapes, columnwise_scale_shapes; + if (columnwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + // push the transposed shape into NVFP4 columnwise shape + // NVFP4 on SM100 is TN only + columnwise_data_shapes.emplace_back(); + auto &shape = columnwise_data_shapes.back(); + shape.push_back(shape_list[i].back()); + for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { + shape.push_back(shape_list[i][j]); + } + columnwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets, amax_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + // Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes). + // Integer arithmetic: ceil(product / 2) == (product + 1) / 2. + buffer_size += (product(columnwise_data_shapes[i]) + 1) / 2; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + amax_offsets.push_back(buffer_size); + // amax is scalar in fp32, 4 bytes each + buffer_size += 4; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + columnwise_data_list.emplace_back(make_torch_view( + buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); + columnwise_scale_list.emplace_back( + make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + amax_columnwise_list.emplace_back( + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kUInt8)); + } + } + + // Construct nvfp4 tensors + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); + for (size_t i = 0; i < num_tensors; ++i) { + // Create tensor objects with proper reference counting + py::object rowwise_data = rowwise_usage ? 
py::cast(rowwise_data_list[i]) : py::none(); + py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); + py::object columnwise_data = + (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none()); + py::object columnwise_scale = + (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); + py::object amax_rowwise = rowwise_usage ? py::cast(amax_rowwise_list[i]) : py::none(); + py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); + + // Construct Python tensor + tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, + columnwise_scale, amax_rowwise, amax_columnwise, + fp4_dtype, quantizer_py_list[i])); + + // Construct C++ tensor + // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, + // then set the amax and amax_columnwise values. + { + auto tensor_wrapper = makeTransformerEngineTensor( + rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_data_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_data_shapes[i] : std::vector{}, fp4_dtype, + /*amax_ptr=*/nullptr, + /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_scale_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_scale_shapes[i] : std::vector{}, scaling_mode); + + // Set the amax rowwise and amax columnwise if available + if (rowwise_usage) { + tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (columnwise_usage) { + tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, + std::vector{1}); + } + tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); + } + } + + return retval; +} + } // namespace std::vector split_quantize(const at::Tensor &tensor, @@ -549,7 +750,8 @@ std::vector split_quantize(const at::Tensor &tensor, bool use_fused_bulk_alloc = true; for (size_t i = 0; i < quantizer_list.size(); i++) { if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) && - !detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) { + !detail::IsMXFP8Quantizers(quantizer_list[i].ptr()) && + !detail::IsNVFP4Quantizers(quantizer_list[i].ptr())) { use_fused_bulk_alloc = false; break; } @@ -570,6 +772,7 @@ std::vector split_quantize(const at::Tensor &tensor, // TODO(zhongbo): make a better api to make this part less hacky bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr()); bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr()); + bool is_nvfp4 = detail::IsNVFP4Quantizers(quantizer_list[0].ptr()); if (is_fp8_blockwise) { // FP8 block-scaling: construct output tensors with bulk allocations std::vector blockwise_quantizers; @@ -586,6 +789,14 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); + } else if (is_nvfp4) { + // NVFP4: construct output tensors with bulk allocations + std::vector nvfp4_quantizers; + for (auto &quantizer : quantizer_cpp_list) { + nvfp4_quantizers.push_back(static_cast(quantizer.get())); + } + std::tie(output_py_list, output_cpp_list) = + bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); } else { NVTE_CHECK(false, "Expected either FP8 
block-scaling, MXFP8, or NVFP4 quantizer"); } diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 3635d4a9c0..f5a3563a1c 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -20,10 +20,11 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor"); TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element"); + auto* amax_ptr = amax.data_ptr(); TensorWrapper fake_te_output( - nullptr, te_input.shape(), + amax_ptr, te_input.shape(), DType::kFloat8E4M3, // It doesn't matter because we only compute amax. - amax.data_ptr()); + amax_ptr, nullptr, amax_ptr); nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 42ae658f2a..d7e8912ac7 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1200,6 +1200,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed amax_rowwise = at::empty({1}, bit32_tensor_opts); } if (columnwise_usage) { @@ -1213,6 +1215,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve columnwise_data_tensor = at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed amax_columnwise = at::empty({1}, bit32_tensor_opts); } @@ -1352,6 +1356,8 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } if (!amax_rowwise) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed amax_rowwise = at::empty({1}, opts); tensor.attr("_amax_rowwise") = *amax_rowwise; } @@ -1392,7 +1398,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } if (!amax_columnwise) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - amax_columnwise = at::zeros({1}, opts); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed + amax_columnwise = at::empty({1}, opts); tensor.attr("_amax_columnwise") = *amax_columnwise; } } else { // columnwise_usage == false diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 3bb6be715d..e6907e1f1c 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -99,10 +99,14 @@ std::optional multi_tensor_swizzle_scaling_factors( if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { + } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING && + 
tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } + const auto scaling_mode = tensors.front().scaling_mode(); + const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; + std::vector wrappers; std::vector input_tensors, output_tensors; @@ -130,39 +134,44 @@ std::optional multi_tensor_swizzle_scaling_factors( // Allocate full buffer auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + const auto input_dtype = + (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; + const auto scale_inv_dtype = + (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + for (size_t i = 0; i < tensors.size(); ++i) { auto& tensor = tensors[i]; void* scale_inv_dptr = scale_inv_dptrs[i]; void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); - auto input_shape = nvte_shape_to_vector(tensor.shape()); - + // auto input_shape = nvte_shape_to_vector(tensor.shape()); + NVTEShape nvte_input_shape; + if (rowwise) { + nvte_input_shape = tensor.shape(); + } else { + nvte_input_shape = tensor.get_columnwise_data().shape; + } + auto input_shape = nvte_shape_to_vector(nvte_input_shape); // Reconstruct input only to avoid swizzling both directions if not needed. // Use any 8 bit type, it's irrelevant. - transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper input_cu(scaling_mode); + transformer_engine::TensorWrapper output_cu(scaling_mode); if (rowwise) { - input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); - output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); + output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); // Set the swizzled scaling factor to the original tensor. 
- tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); + tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); } else { - input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); - output_cu.set_columnwise_data(tensor.columnwise_dptr(), - transformer_engine::DType::kFloat8E4M3, input_shape); - output_cu.set_columnwise_scale_inv( - swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); + output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); // Set the swizzled scaling factor to the original tensor. - tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); } input_tensors.emplace_back(input_cu.data()); diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 0d2e3e6d7c..28c0e44302 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -111,7 +111,14 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." if self.align_size is None: - self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + self.align_size = ( + 32 + if ( + FP8GlobalStateManager.get_fp8_recipe().mxfp8() + or FP8GlobalStateManager.get_fp8_recipe().nvfp4() + ) + else 16 + ) # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 3b0f8928fa..27f5b15a27 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -109,7 +109,14 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." if self.align_size is None: - self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + self.align_size = ( + 32 + if ( + FP8GlobalStateManager.get_fp8_recipe().mxfp8() + or FP8GlobalStateManager.get_fp8_recipe().nvfp4() + ) + else 16 + ) # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index ca2154f554..133ed7af6f 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -103,7 +103,7 @@ def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor: def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int: """Sign mask for random Hadamard transform.""" if with_random_sign_mask: - return get_sign_from_vector(get_wgrad_sign_vector()) + return get_sign_from_vector(get_wgrad_sign_vector()).item() return 0
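
For reviewers, a minimal usage sketch of the NVFP4 path exercised by the updated benchmark. This is illustrative only and not part of the patch; it assumes an NVFP4-capable build and GPU, and uses the same GroupedLinear, fp8_autocast, and NVFP4BlockScaling APIs imported in benchmark_grouped_linear.py:

# Illustrative sketch: run GroupedLinear under the NVFP4 recipe, mirroring the benchmark setup.
import torch
from transformer_engine.pytorch.module import GroupedLinear
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.common.recipe import NVFP4BlockScaling

num_gemms, k, n = 4, 7168, 2048
m_splits = [16384] * num_gemms  # per-expert token counts, already multiples of 32
x = torch.randn(sum(m_splits), k, dtype=torch.bfloat16, device="cuda", requires_grad=True)

layer = GroupedLinear(num_gemms, k, n, bias=False, params_dtype=torch.bfloat16)
with fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()):
    out = layer(x, m_splits)
out.backward(torch.randn_like(out))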
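
The Fp8Padding / Fp8Unpadding changes above switch the alignment size to 32 for both MXFP8 and NVFP4 recipes. A small sketch of the rounding rule, using a hypothetical helper name but the same arithmetic as the modules and the test's _pad_tensor_for_fp8:

# Sketch of the per-expert alignment rule: 32 for MXFP8/NVFP4 recipes, 16 otherwise.
def pad_m_splits(m_splits, recipe):
    align_size = 32 if (recipe.mxfp8() or recipe.nvfp4()) else 16
    return [(m + align_size - 1) // align_size * align_size for m in m_splits]

# e.g. pad_m_splits([100, 37], NVFP4BlockScaling()) returns [128, 64]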
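
bulk_allocate_nvfp4_tensors packs all outputs into one uint8 buffer laid out as [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN], with FP4 payloads 256-byte aligned at two elements per byte and scales/amaxes 16-byte aligned. A sketch of the offset arithmetic, as a hypothetical Python helper that mirrors the C++ logic for one direction (rowwise or columnwise):

# Sketch of the buffer-size/offset computation for one set of NVFP4 outputs.
from math import prod

def roundup(x, align):
    return (x + align - 1) // align * align

def nvfp4_buffer_layout(data_shapes, scale_shapes):
    size, data_off, scale_off, amax_off = 0, [], [], []
    for s in data_shapes:  # FP4 payload: ceil(elements / 2) bytes, 256B aligned
        size = roundup(size, 256)
        data_off.append(size)
        size += (prod(s) + 1) // 2
    for s in scale_shapes:  # FP8 (1-byte) scale factors, 16B aligned
        size = roundup(size, 16)
        scale_off.append(size)
        size += prod(s)
    for _ in data_shapes:  # one FP32 amax (4 bytes) per tensor, 16B aligned
        size = roundup(size, 16)
        amax_off.append(size)
        size += 4
    return size, data_off, scale_off, amax_off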