diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 479d378ba6..b2f14b1892 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -32,12 +32,6 @@ add_executable(test_operator test_swap_first_dims.cu ../test_common.cu) -# Add profiling and debug flags for CUDA compilation -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage -# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping - # Find required packages find_package(OpenMP REQUIRED) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index e905a00640..be716639fe 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -663,12 +663,7 @@ std::vector> tensor_dims = { // Only GeLU activation tests are supported std::vector Activation_types = { - ActivationType::Identity, - ActivationType::GeLU, - ActivationType::SiLU, - ActivationType::ReLU, - ActivationType::QGeLU, - ActivationType::SReLU, + ActivationType::Identity }; } // namespace diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e6be47686a..34d7a09be9 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -119,7 +119,7 @@ list(APPEND transformer_engine_SOURCES normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu - util/cast.cu + cast/cast.cu util/padding.cu util/cuda_driver.cpp util/cuda_nvml.cpp @@ -267,7 +267,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) set_source_files_properties(activation/gelu.cu activation/relu.cu activation/swiglu.cu - util/cast.cu + cast/cast.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") endif() diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 1d9a3fb43c..d87197edae 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -14,11 +14,9 @@ #include #include +#include "../cast/dispatch/gated.cuh" +#include "../cast/dispatch/quantize.cuh" #include "../common.h" -#include "../util/cast_gated_kernels.cuh" -#include "../util/cast_kernels.cuh" -#include "../util/math.h" -#include "../util/vectorized_pointwise.h" namespace transformer_engine { @@ -32,8 +30,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + dispatch::quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } template @@ -46,16 +44,14 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + dispatch::quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } template void gated_act_fn(const NVTETensor input, NVTETensor 
output, Param &p, cudaStream_t stream) { using namespace detail; - constexpr bool IS_DGATED = false; - constexpr NVTETensor grad = nullptr; - quantize_gated_helper(grad, input, output, p, stream); + dispatch::quantize_gated_helper(input, output, p, stream); } template (grad, input, output, p, stream); + dispatch::quantize_dgated_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/cast/cast.cu similarity index 86% rename from transformer_engine/common/util/cast.cu rename to transformer_engine/common/cast/cast.cu index 107965d342..260a9407e9 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -10,20 +10,13 @@ #include #include -#include -#include -#include -#include - #include "../common.h" #include "../transpose/cast_transpose.h" #include "../util/multi_stream.h" -#include "../util/vectorized_pointwise.h" #include "../utils.cuh" -#include "cast_kernels.cuh" -#include "dequantize_kernels.cuh" -#include "math.h" -#include "ptx.cuh" +#include "dispatch/dequantize.cuh" +#include "dispatch/gated.cuh" +#include "dispatch/quantize.cuh" #include "transformer_engine/activation.h" #include "transformer_engine/transpose.h" @@ -38,8 +31,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper(input, grad, output, dbias, - workspace, nullptr, stream); + dispatch::quantize_helper(input, grad, output, dbias, + workspace, nullptr, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -66,7 +59,7 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output, constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper( + dispatch::quantize_helper( input, grad, output, dbias, workspace, quant_config, stream); } @@ -80,7 +73,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d constexpr bool IS_ACT = false; constexpr const NVTETensor activation_input = nullptr; - detail::quantize_helper( + dispatch::quantize_helper( activation_input, input, output, dbias, workspace, nullptr, stream); } @@ -94,7 +87,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_DACT = true; constexpr bool IS_ACT = false; - detail::quantize_helper>( + dispatch::quantize_helper>( activation_input, input, output, dbias, workspace, nullptr, stream); } @@ -108,7 +101,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati constexpr bool IS_DACT = true; constexpr bool IS_ACT = false; - detail::quantize_helper>( + dispatch::quantize_helper>( activation_input, input, output, dbias, workspace, nullptr, stream); } @@ -122,7 +115,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_DACT = true; constexpr bool IS_ACT = false; - detail::quantize_helper>( + dispatch::quantize_helper>( activation_input, input, output, dbias, workspace, nullptr, stream); } @@ -136,7 +129,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_DACT = true; constexpr bool IS_ACT = false; - detail::quantize_helper>( + dispatch::quantize_helper>( activation_input, input, output, dbias, workspace, nullptr, stream); } @@ -150,14 +143,15 @@ void nvte_quantize_dbias_dsrelu(const 
NVTETensor input, const NVTETensor activat constexpr bool IS_DACT = true; constexpr bool IS_ACT = false; - detail::quantize_helper>( + dispatch::quantize_helper>( activation_input, input, output, dbias, workspace, nullptr, stream); } void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dequantize); using namespace transformer_engine; - detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); + dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), + stream); } void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, @@ -184,7 +178,7 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, } for (int i = 0; i < num_tensors; i++) { - detail::quantize_helper( + dispatch::quantize_helper( inputs[i], grad, outputs[i], dbias, workspace, nullptr, detail::get_compute_stream(i % num_streams)); } diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh new file mode 100644 index 0000000000..b750142f5b --- /dev/null +++ b/transformer_engine/common/cast/core/common.cuh @@ -0,0 +1,97 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file common.cuh + * \brief Common functions in quantize. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace common { +inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { + const size_t N = product(t->data.shape); + const bool isFullTile = (N % elems_per_block == 0); + return isFullTile; +} + +inline bool dimensions_supported_by_TMA(const Tensor *const t) { + const size_t cols = t->flat_last_dim(); + constexpr size_t TMA_bytes = 16; + const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); + return cols % alignment_requirement == 0; +} + +namespace kernel { + +constexpr size_t THREADS_PER_BLOCK = 256; +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, + const size_t rows, const size_t cols) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + thread_id * nvec; + OType *const thread_out_base = dbias_output + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} +} // namespace kernel + +template +void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, + cudaStream_t stream) { + using namespace kernel; + constexpr size_t 
reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); + const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace common +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh new file mode 100644 index 0000000000..d9a1038692 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -0,0 +1,53 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize.cuh + * \brief Dequantize dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ + +#include + +#include "../../common.h" +#include "../fp8/dequantize_fp8.cuh" +#include "../mxfp8/dequantize_mxfp8.cuh" +#include "../nvfp4/dequantize_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { + +inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + fp8::dequantize(input, output, stream); + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + mxfp8::dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + nvfp4::dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } +} + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh new file mode 100644 index 0000000000..7ee9a09586 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -0,0 +1,167 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file gated.cuh + * \brief Gated dispatcher. 
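+ *        Routes gated activation and gated-gradient quantization to the FP8 or
+ *        MXFP8 kernels based on the output tensor's scaling mode.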
+ */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ + +#include + +#include "../../common.h" +#include "../../utils.cuh" +#include "../fp8/gated_fp8.cuh" +#include "../mxfp8/gated_mxfp8.cuh" + +namespace transformer_engine { +namespace dispatch { + +template +void quantize_gated_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, + cudaStream_t stream) { + using namespace dispatch; + const Tensor input = *convertNVTETensorCheck(nvte_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + const auto scaling_mode = output->scaling_mode; + if ((scaling_mode != NVTE_DELAYED_TENSOR_SCALING) && !is_supported_by_CC_100()) { + NVTE_ERROR("Not supported by the Arch < 10.0"); + } + + constexpr bool allow_empty = false; + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", allow_empty); + + NVTE_CHECK(input.flat_last_dim() % 2 == 0, "Number of columns must be even."); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim() / 2; + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + bool is_fp8_rowwise_output = true; + bool is_fp8_colwise_output = true; + if (output->has_data()) { + is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == cols, "Wrong dimension of the output."); + } + if (output->has_columnwise_data()) { + is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == cols, "Wrong dimension of the output."); + } + + const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && (cols % 32 == 0) && + is_supported_by_CC_100(); + + switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (use_tma_kernels) { + Tensor dummy_tensor; // grad + fp8::cast_gated_dgated_tma(dummy_tensor, input, output, p, + stream); + } else { + fp8::cast_gated(input, output, p, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (use_tma_kernels) { + Tensor dummy_tensor; // grad + mxfp8::quantize_gated_dgated(dummy_tensor, input, output, p, + stream); + } else { + NVTE_ERROR("Invalid input shape. 
Expected the last dimension to be divisible ", + "by 32, got input of shape ", input.data.shape); + } + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(scaling_mode) + "."); + } +} + +template +void quantize_dgated_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, + NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { + using namespace dispatch; + const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + const auto scaling_mode = output->scaling_mode; + if ((scaling_mode != NVTE_DELAYED_TENSOR_SCALING) && !is_supported_by_CC_100()) { + NVTE_ERROR("Not supported by the Arch < 10.0"); + } + + constexpr bool allow_empty = false; + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", allow_empty); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = 2 * cols; + + CheckInputTensor(grad, "grad"); + NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); + NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); + NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); + NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + bool is_fp8_rowwise_output = true; + bool is_fp8_colwise_output = true; + if (output->has_data()) { + is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + if (output->has_columnwise_data()) { + is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); + NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); + } + + const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && (cols % 32 == 0) && + is_supported_by_CC_100(); + + switch (scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (use_tma_kernels) { + fp8::cast_gated_dgated_tma(grad, gated_input, output, p, + stream); + } else { + fp8::cast_dgated(grad, gated_input, output, p, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (use_tma_kernels) { + mxfp8::quantize_gated_dgated(grad, gated_input, output, p, + stream); + } else { + NVTE_ERROR("Invalid input shape. 
Expected the last dimension to be divisible ", + "by 32, got input of shape ", gated_input.data.shape); + } + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(scaling_mode) + "."); + } +} +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh new file mode 100644 index 0000000000..c36a50e0c8 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -0,0 +1,186 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize.cuh + * \brief Quantize dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ + +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/vectorized_pointwise.h" +#include "../core/common.cuh" +#include "../fp8/quantize_fp8.cuh" +#include "../mxfp8/quantize_mxfp8.cuh" +#include "../nvfp4/quantize_nvfp4.cuh" +#include "../nvfp4/quantize_transpose_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { + +template +void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, + NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + const Tensor *input_tensor; + const Tensor *activation_input_tensor; + if constexpr (IS_DBIAS || IS_DACT) { + // backward - input is incoming gradient + input_tensor = convertNVTETensorCheck(grad); + activation_input_tensor = convertNVTETensor(input); + } else { + // forward = input is activation input + input_tensor = convertNVTETensorCheck(input); + activation_input_tensor = nullptr; + } + auto output_tensor = convertNVTETensorCheck(output); + auto dbias_tensor = convertNVTETensor(dbias); + auto workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + } + break; + } + case 
NVTE_MXFP8_1D_SCALING: { + mxfp8::quantize( + *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, + workspace_tensor, stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not supported by NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 && + output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + rowwise_option = rowwise_compact ? 
FP8BlockwiseRowwiseOption::ROWWISE_COMPACT + : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + columnwise_option = columnwise_compact + ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT + : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh new file mode 100644 index 0000000000..599e87d05a --- /dev/null +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -0,0 +1,58 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_fp8.cuh + * \brief CUDA kernels to dequantize from FP8. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { + return value * (*(param.scale_inv)); +} + +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + constexpr int nvec = 32 / sizeof(OType); + DequantizeParam p; p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh new file mode 100644 index 0000000000..90540da77f --- /dev/null +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -0,0 +1,424 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +/*! \file gated_fp8.cuh + * \brief CUDA kernels to cast to FP8 with gated activations. + */ + +#ifndef TRANSFORMER_ENGINE_GATED_FP8_CUH_ +#define TRANSFORMER_ENGINE_GATED_FP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +namespace kernel { +__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 512; +constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 +constexpr size_t BUFFERS_NUM = 2; +constexpr size_t BUFFER_DIM_Y = 32; +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +static_assert(ITERATIONS >= 1); + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act, + const __grid_constant__ CUtensorMap tensor_map_output_gate, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t rows, const size_t cols, + const ParamOP p) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t grad_mem = IS_DGATED ? 
buff_size_aligned_in : 0; + + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); + const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + // Prefetch data of the first stage + + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, + TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, + chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } else { + copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const size_t buff = it % BUFFERS_NUM; + const size_t next_it = it + 1; + if (next_it < ITERATIONS) { + const size_t next_buff = next_it % BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3( + &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); + } else { + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, + chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, + &mbar[next_it], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_sh_curr = out_act_sh + buff * buff_elems; + OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const size_t 
shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; + } + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + + const float x = act_elt; + float act_x; + float dact_x; + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); + act_x = x * s; + if (act_elt <= p.limit) { + dact_x = s + s * (1 - s) * p.alpha * x; + } else { + dact_x = 0.0f; + } + } else { + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } + } + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; + + out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, p) * gate_elt; + out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + + // dGeLU + ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(out_act_sh_curr)); + + if constexpr (IS_DGATED) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_sh_curr)); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. 
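+  // A single (master) thread is sufficient to invalidate each mbarrier.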
+ if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace kernel + +template +void cast_gated_dgated_tma(const Tensor &grad, const Tensor &gated_input, Tensor *output, + ParamOP &p, cudaStream_t stream) { + using namespace kernel; + checkCuDriverContext(stream); + + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + } + + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act{}; + alignas(64) CUtensorMap tensor_map_output_gate{}; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, + cols, 0, typeToNumBits(gated_input.dtype())); + } + + const uint32_t tensor_stride_elems = output_cols; + + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, + typeToNumBits(output->dtype())); + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; + + auto kernel = cast_fp8_gated_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { + CheckInputTensor(input, "gated_act_input"); + CheckOutputTensor(*output, "gated_act_output"); + NVTE_CHECK(input.flat_last_dim() % 2 == 0, + "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2, + "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2, + "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), + output->flat_last_dim(), p, stream);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p, + cudaStream_t stream) { + CheckInputTensor(grad, "dgated_act_grad"); + CheckInputTensor(input, "dgated_act_input"); + CheckOutputTensor(*output, "dgated_act_output"); + NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), + "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), + ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, + "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, + "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(input.data.shape == output->data.shape, + "Input and output shapes must match. 
Input shape: ", input.data.shape, + ", output shape: ", output->data.shape, "."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), p, stream);); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GATED_FP8_CUH_ diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh new file mode 100644 index 0000000000..efc5015b75 --- /dev/null +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -0,0 +1,580 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_fp8.cuh + * \brief CUDA kernels to quantize to FP8. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +namespace quantize_2D_kernel { + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; + + const size_t dbias_offset_Y = blockIdx.y + tid_Y; + const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + 
thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const size_t dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + const size_t chunk_offset_Y = block_offset_Y; + const size_t chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const size_t chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const size_t buff = iter % FP8_BUFFERS_NUM; + const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const size_t next_buff = next_iter % FP8_BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if 
(!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const size_t dbias_offset_X = my_column; + const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_2D_kernel + +namespace quantize_1D_kernel { +using namespace quantize_2D_kernel; + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const size_t buff = iter % SHMEM_BUFFERS; + const size_t it_offset = iter * SHMEM_DIM; + + const size_t next_iter = iter + 1; + const size_t next_buff = next_iter % SHMEM_BUFFERS; + const size_t next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, + &(mbar[next_iter]), is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
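+      // At most one bulk async-group may remain in flight here, which guarantees the
+      // previously used output buffer has been fully read before it is overwritten.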
+ ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_1D_kernel + +template +void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace quantize_1D_kernel; + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel<<>>( + input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) + ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + using namespace quantize_2D_kernel; + checkCuDriverContext(stream); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); + + cast_fp8_2D_kernel + <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols); + NVTE_CHECK_CUDA(cudaGetLastError()); + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +namespace detail { +using Empty = transformer_engine::Empty; +__device__ inline float identity(float value, const Empty &) { return value; } +} // namespace detail + +template +void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + using namespace quantize_1D_kernel; + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + if (!IS_DBIAS && !IS_DACT) { + if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { + // Aligned AND FP8 + quantize_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (common::dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { + // Aligned AND FP8 (+dAct) + quantize_2D(input, act_input, output, dbias, workspace, + stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } else { + quantize_2D(input, act_input, output, dbias, workspace, + stream); + } + } else { + if (IS_DBIAS) { + // zhongboz: should we just ignore IS_ACT here? 
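+ // The fused DBIAS path relies on the TMA-based 2D kernel above, which is only compiled for
+ // compute capability 10.0+, so older architectures fall back to the plain vectorized cast
+ // (optionally fused with dAct) and cannot fuse DBIAS.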
+ NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + + " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); + } + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } +} + +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh similarity index 69% rename from transformer_engine/common/util/dequantize_kernels.cuh rename to transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 9f70ce4cd4..fb43fce96b 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -4,36 +4,27 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file dequantize_kernels.cuh - * \brief CUDA kernels to cast from MXFP8. +/*! \file dequantize_mxfp8.cuh + * \brief CUDA kernels to dequantize from MXFP8. */ -#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ #include #include #include -#include - -#include -#include -#include -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/transformer_engine.h" -#include "transformer_engine/transpose.h" +#include -namespace transformer_engine { +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" -namespace dequantization { +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace dequantize_kernel { constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; @@ -228,29 +219,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace dequantize_kernel -void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->data.dtype, OType, - - constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), nullptr, - reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, - stream);); // NOLINT(*) - ); // NOLINT(*) -} - -void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace dequantize_kernel; bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); checkCuDriverContext(stream); @@ 
-334,113 +306,8 @@ void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); } - -#if CUDA_VERSION >= 12080 -template -__global__ void __launch_bounds__(512) - dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t N, const size_t M, - const size_t scale_stride) { - const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - const size_t x = thread_idx % M; - const size_t y = thread_idx / M; - - union fp4vec { - uint64_t vec; - fp4e2m1x4 small_vec[4]; - }; - using OVec = Vec; - const uint64_t *const input_vectorized = reinterpret_cast(input); - OVec *output_vec = reinterpret_cast(output); - - const size_t my_index = x + y * M; - const size_t my_scale_index = x + y * scale_stride; - const size_t my_output_index = (x + y * M) * 4; - fp4vec value; - value.vec = input_vectorized[my_index]; - fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; - constexpr float factor_inv = 1.0 / (6.0 * 448.0); - float final_scale = static_cast(scale) * amax * factor_inv; -#pragma unroll - for (int i = 0; i < 4; i++) { - float4 current = static_cast(value.small_vec[i]); - OVec out; - out.data.elt[0] = static_cast(current.x * final_scale); - out.data.elt[1] = static_cast(current.y * final_scale); - out.data.elt[2] = static_cast(current.z * final_scale); - out.data.elt[3] = static_cast(current.w * final_scale); - output_vec[my_output_index + i] = out; - } -} -#endif // CUDA_VERSION - -void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { -#if CUDA_VERSION >= 12080 - CheckInputTensor(input, "input"); - CheckOutputTensor(*output, "output"); - NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); - NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - constexpr int FP4_BLOCK_SIZE = 16; - const size_t N = input.flat_first_dim(); - const size_t M = input.flat_last_dim(); - - NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", - FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); - - const size_t Mread = M / FP4_BLOCK_SIZE; - const size_t total = N * Mread; - const size_t threads = 512; - const size_t blocks = DIVUP(total, threads); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->data.dtype, OType, - - dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, - input.scale_inv.shape.back());); // NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -#else - NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); -#endif // CUDA_VERSION >= 12080 -} - -} // namespace dequantization - -namespace detail { - -void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - switch (input.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - dequantization::fp8_dequantize(input, output, stream); - break; - } - case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - break; - } - case NVTE_NVFP4_1D_SCALING: { - 
dequantization::fp4_dequantize(input, output, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } -} - -} // namespace detail - +} // namespace mxfp8 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh similarity index 55% rename from transformer_engine/common/util/cast_gated_kernels.cuh rename to transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 93086bd827..52bfff71f4 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -4,281 +4,30 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file cast_gated_kernels.cuh - * \brief CUDA gated activations kernels to cast to/from FP8/MXFP8. +/*! \file gated_mxfp8.cuh + * \brief CUDA kernels to cast to MXFP8 with gated activations. */ -#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#ifndef TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ #include #include #include -#include -#include +#include -#include - -#include "../common.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" namespace transformer_engine { - -namespace gated_kernels { - -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 512; -constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 -constexpr size_t BUFFERS_NUM = 2; -constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 -static_assert(ITERATIONS >= 1); +namespace dispatch { +namespace mxfp8 { +namespace gated_kernel { __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act, - const __grid_constant__ CUtensorMap tensor_map_output_gate, - float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols, - const ParamOP p) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; - - constexpr size_t in_act_mem = buff_size_aligned_in; - constexpr size_t in_gate_mem = buff_size_aligned_in; - constexpr size_t in_mem = in_act_mem + in_gate_mem; - - constexpr size_t out_act_mem = buff_size_aligned_out; - constexpr size_t in_transaction_size = buff_elems * sizeof(IType); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); - - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); - const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. 
-#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - // Prefetch data of the first stage - - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, - TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, - chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } else { - copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } - -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const size_t buff = it % BUFFERS_NUM; - const size_t next_it = it + 1; - if (next_it < ITERATIONS) { - const size_t next_buff = next_it % BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3( - &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, - &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, - in_transaction_size, &mbar[next_it], is_master_thread); - } else { - copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, - chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, - &mbar[next_it], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); - - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; - OType *out_act_sh_curr = out_act_sh + buff * buff_elems; - OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; -#pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - bool dgate_elt = true; // gating is ideally an identity function - if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values - dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp - gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; - } - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - - const float x = act_elt; - float act_x; - float dact_x; - if constexpr (std::is_same::value) { - const float x = min(act_elt, p.limit); - const float s = sigmoidf(p.alpha * x); - act_x = x * s; - if (act_elt <= p.limit) { - dact_x = s + s * (1 - s) * p.alpha * x; - } else { - dact_x = 0.0f; - } - } else { - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, p); - dact_x = DActOP(x, p); - } - } - float after_dact = dact_x * grad_elt * gate_elt; - float 
after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; - - out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); - out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); - - amax = fmaxf(amax, fabsf(after_dact)); - amax = fmaxf(amax, fabsf(after_dgate)); - } else { - const float after_act = ActOP(act_elt, p) * gate_elt; - out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); - amax = fmaxf(amax, fabsf(after_act)); - } - } - - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - - // dGeLU - ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, - chunk_it_offset_y, - reinterpret_cast(out_act_sh_curr)); - - if constexpr (IS_DGATED) { - // dGate - ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_sh_curr)); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. 
- if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -namespace mxfp8_kernel { - constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; constexpr size_t THREADS_PER_CHUNK_COLWISE = 128; @@ -306,16 +55,17 @@ template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, - const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise, const ParamOP p) { + quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, + const size_t scale_stride_rowwise, + const size_t scale_stride_colwise, const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -920,94 +670,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -} // namespace mxfp8_kernel +} // namespace gated_kernel template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, - cudaStream_t stream) { - checkCuDriverContext(stream); - - if (output->has_data()) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - - NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act{}; - alignas(64) CUtensorMap tensor_map_output_gate{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - - const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_fp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, - cudaStream_t stream) { +void quantize_gated_dgated(const Tensor &grad, const Tensor &gated_input, Tensor *output, + ParamOP &p, cudaStream_t stream) { + using namespace gated_kernel; checkCuDriverContext(stream); const bool USE_ROWWISE_SCALING = output->has_data(); @@ -1033,15 +702,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; - - const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; - constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) ? THREADS_PER_CHUNK_COLWISE : THREADS_PER_CHUNK_NON_COLWISE; @@ -1118,230 +781,60 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; switch (scaling_type) { - case ScalingType::ROWWISE: + case ScalingType::ROWWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - case ScalingType::COLWISE: + } + case ScalingType::COLWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - case ScalingType::BIDIMENSIONAL: + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, 
tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2, - "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), p, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, - cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, - "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match. 
Input shape: ", input.data.shape, - ", output shape: ", output->data.shape, "."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), p, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, - cudaStream_t stream) { - constexpr bool allow_empty = false; - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", allow_empty); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - if constexpr (IS_DGATED) { - CheckInputTensor(grad, "grad"); - NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); - NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); - NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); - } + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - bool is_fp8_rowwise_output = true; - bool is_fp8_colwise_output = true; - if (output->has_data()) { - is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - if (output->has_columnwise_data()) { - is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - - const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; - - if (is_delayed_tensor_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, p, stream); - } else { - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, p, stream); - } else { - cast_gated(gated_input, output, p, stream); - } - } - } else if (is_mxfp8_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, p, stream); - } else { - NVTE_ERROR("Invalid input shape. 
Expected the last dimension to be divisible ", - "by 32, got input of shape ", gated_input.data.shape); - } - } else { - NVTE_ERROR("Not supported scaling mode"); - } -} -} // namespace gated_kernels - -namespace detail { - -template -void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - ParamOP p, cudaStream_t stream) { - using namespace gated_kernels; - Tensor grad_empty_tensor; - const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; - const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input); - Tensor *output_tensor = convertNVTETensorCheck(output); - - if (is_supported_by_CC_100()) { - quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, p, stream); - } else { - if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { - if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, p, - stream); - } else { - cast_gated(gated_input_tensor, output_tensor, p, stream); - } - } else { - // MX scaling - NVTE_ERROR("Not supported by the Arch < 10.0"); - } - } + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); + break; + } + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) } -} // namespace detail +} // namespace mxfp8 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#endif // TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh new file mode 100644 index 0000000000..5505de6050 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -0,0 +1,722 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_mxfp8.cuh + * \brief CUDA kernels to quantize to MXFP8. 
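+ *        Supports rowwise and/or columnwise E8M0 block scaling with optional fused
+ *        activation, activation-gradient (dAct) and dBias computation.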
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace quantize_kernel { + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; + static_assert(BUFF_DIM_Y == 32); + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = 
scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. 
Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + + // 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. 
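+      // cp_async_bulk_commit_group() folds the TMA stores issued above into one bulk async-group;
+      // the cp_async_bulk_wait_group_read<1>() call at the top of the next stage then blocks until
+      // at most one such group is still pending reads, so a staging buffer is only reused once its
+      // outgoing TMA transfer has finished reading it.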
+ ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_kernel + +template +void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + using namespace quantize_kernel; + checkCuDriverContext(stream); + + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + + constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 
128 : 64; + + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + 
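+          // Rough worked example of the shared-memory budget assembled below (assumed values for
+          // illustration only: 32x64 staging buffers, BUFFS_NUM = 2, BF16 input, FP8 output,
+          // TMA_SHMEM_ALIGNMENT = 128; the real constants are defined above in this file):
+          //   buff_elems_total      = 2 * 32 * 64        = 4096 elements
+          //   buff_size_aligned_in  = 4096 * 16 bit / 8  = 8192 B
+          //   buff_size_aligned_out = 4096 *  8 bit / 8  = 4096 B
+          // A bidimensional launch without DACT therefore requests roughly
+          //   8192 + (4096 + 4096) + 128 = 16512 B of dynamic shared memory.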
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::COLWISE: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + } + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh new file mode 100644 index 0000000000..660b014bf0 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -0,0 +1,308 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file core_nvfp4.cuh + * \brief Core functions used in NVFP4. 
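+ *
+ *        Contains the per-block E4M3 decoding scale-factor helpers, the global FP4 encode-scale
+ *        computation, random-bit generation for stochastic rounding, and the PTX wrappers that
+ *        convert BF16/FP32 values to packed FP4 (round-to-nearest or stochastic rounding).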
+ */ + +#ifndef TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ + +#include +#include +#include + +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "curanddx.hpp" +#if CUDA_VERSION >= 12080 +#include +#endif // CUDA_VERSION >= 12080 + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +using nvfp4_scale_t = fp8e4m3; + +namespace quantization_and_transposition_SF { +// Used in transpose variant +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if CUDA_VERSION >= 12080 + // constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + // NOTE: Divide by 6.0f is not elegant and not efficient. + // However, this is part of the emulation code to ensure exact match. + using namespace detail; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + const float S_dec_b = block_amax / fp4_max * S_enc; + return static_cast(fminf(S_dec_b, TypeExtrema::max)); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +#else + NVTE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantization_and_transposition_SF + +namespace quantization_SF { +// Used in non-transpose variant +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#if CUDA_VERSION >= 12080 + constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + return static_cast(block_amax * rcp_6f * S_enc); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +#else + NVTE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantization_SF + +namespace core { + +#if CUDA_VERSION >= 12080 +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + + curanddx::SM<800>() + curanddx::Thread()); + +using namespace ptx; + +// Compute the global encode scale factor for a given global amax +__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { + using namespace detail; + constexpr float fp8_max = TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + curanddx::uniform_bits dist; + random_uint4 = 
dist.generate4(rng); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( + const uint64_t in_4x, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + // NOTE: rbits unused for rn. + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); + } else { + return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); + } +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const float2 scale, + const uint32_t rbits) { + // NOTE: rbits unused for rn. + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); + } else { + return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); + } +} + +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#endif // CUDA_VERSION >= 12080 + +} // namespace core +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh new file mode 100644 index 0000000000..ff5a4a5d9b --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -0,0 +1,111 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_nvfp4.cuh + * \brief CUDA kernels to dequantize from NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" + +#if CUDA_VERSION >= 12080 +#include +#endif // CUDA_VERSION >= 12080 + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace dequantize_kernel { +#if CUDA_VERSION >= 12080 +template +__global__ void __launch_bounds__(512) + dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, + const float *const tensor_amax, const size_t N, const size_t M, + const size_t scale_stride) { + const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t x = thread_idx % M; + const size_t y = thread_idx / M; + + union fp4vec { + uint64_t vec; + fp4e2m1x4 small_vec[4]; + }; + using OVec = Vec; + const uint64_t *const input_vectorized = reinterpret_cast(input); + OVec *output_vec = reinterpret_cast(output); + + const size_t my_index = x + y * M; + const size_t my_scale_index = x + y * scale_stride; + const size_t my_output_index = (x + y * M) * 4; + fp4vec value; + value.vec = input_vectorized[my_index]; + fp8e4m3 scale = scales[my_scale_index]; + float amax = *tensor_amax; + constexpr float factor_inv = 1.0 / (6.0 * 448.0); + float final_scale = static_cast(scale) * amax * factor_inv; +#pragma unroll + for (int i = 0; i < 4; i++) { + float4 current = static_cast(value.small_vec[i]); + OVec out; + out.data.elt[0] = static_cast(current.x * final_scale); + out.data.elt[1] = static_cast(current.y * final_scale); + out.data.elt[2] = static_cast(current.z * final_scale); + out.data.elt[3] = static_cast(current.w * final_scale); + output_vec[my_output_index + i] = out; + } +} +#endif // CUDA_VERSION +} // namespace dequantize_kernel + +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +#if CUDA_VERSION >= 12080 + using namespace dequantize_kernel; + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output"); + NVTE_CHECK(input.data.dtype == 
DType::kFloat4E2M1, "Input must have FP4 type."); + NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + constexpr int FP4_BLOCK_SIZE = 16; + const size_t N = input.flat_first_dim(); + const size_t M = input.flat_last_dim(); + + NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", + FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); + + const size_t Mread = M / FP4_BLOCK_SIZE; + const size_t total = N * Mread; + const size_t threads = 512; + const size_t blocks = DIVUP(total, threads); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back());); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif // CUDA_VERSION >= 12080 +} +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh new file mode 100644 index 0000000000..69af2841e1 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -0,0 +1,688 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4. 
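+ *
+ *        Produces rowwise NVFP4 output (1x16 blocks with E4M3 scales) and, optionally, a second
+ *        columnwise MXFP8 output (32x1 blocks with E8M0 scales) in the same pass.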
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace quantize_kernel { + +using namespace ptx; +using namespace quantization_SF; +using namespace core; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 16; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; + +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 + +#define DIRECT_SCALING_FACTORS_STORE 1 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, + const float *noop, float *const amax_ptr, + const float *const nvfp4_second_stage_scale_ptr, const size_t rows, + const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool ROWWISE_SCALING = true; + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; + + static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); + static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + + constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size + constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + + constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; + // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of + // // threads to process one row in a single iteration + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / 
SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t buff_size_nvfp4_scales = + CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); + constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); + constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
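+  // The round-up below assumes TMA_SHMEM_ALIGNMENT is a power of two:
+  // (p + A - 1) & ~(A - 1) advances p to the next multiple of A (or keeps p if already aligned).
+  // The extra TMA_SHMEM_ALIGNMENT bytes added to the launch-time shared-memory request pay for
+  // this padding.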
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + fp8e4m3 *out_rowwise_scales_sh = + reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + e8m0_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = + (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + const int buff_offset_in = buff * BUFF_IN_DIM; + const int buff_offset_out = buff * BUFF_OUT_DIM; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_IN_DIM; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM_Y]; + IType in_colwise_IType[SCALE_DIM_Y]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if (colwise_scale_is_within_bounds) { + scales_colwise_e8m0[scale_idx] = biased_exponent; + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { + const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const int shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const int shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. 
Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no 
activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + +#if DIRECT_SCALING_FACTORS_STORE + // Check boundaries + if (rowwise_scale_is_within_bounds) { + const int scales_offset_Y = + scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = scales_offset_X_rowwise; + const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; + scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; + } +#else + const int shmem_scales_offset_Y = + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; + const int shmem_scales_offset_X = tid_X_rowwise; + const int scale_idx = + shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; + out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; +#endif + // Compute "correct" per-block encoding scaling factor + const float block_scale_inverse = + __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; // Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in01 = in_IType[w].data.elt[2 * e]; + in23 = in_IType[w].data.elt[2 * e + 1]; + } else if constexpr (IS_CACHED_ACT_OP) { + in01.x = in_cached[w].data.elt[4 * e]; + in01.y = in_cached[w].data.elt[4 * e + 1]; + in23.x = in_cached[w].data.elt[4 * e + 2]; + in23.y = in_cached[w].data.elt[4 * e + 3]; + } else { + const int j = w * PACK_SIZE + 4 * e; + in01.x = in_compute_rowwise[j]; + in01.y = in_compute_rowwise[j + 1]; + in23.x = in_compute_rowwise[j + 2]; + in23.y = in_compute_rowwise[j + 3]; + } + fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); + ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + __builtin_assume(block_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; + const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. 
+ ptx::cp_async_bulk_commit_group(); + } + } + +#if !DIRECT_SCALING_FACTORS_STORE + // Vectorized store of scaling factors. + // Each thread stores multiple scaling factors in one store instruction. + if constexpr (ROWWISE_SCALING) { + // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; + const int scale_idx_global = + scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; + const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + + if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && + (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { + using ScalesVec_t = Vec; + const ScalesVec_t &scales = + *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); + scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); + } + } +#endif + + float chunk_amax = 0.0f; + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + chunk_amax = reduce_max(thread_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, chunk_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_kernel + +// This kernel supports only two scaling cases: +// 1. r16c0 - Rowwise NVFP4 +// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8 +inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { +#if CUDA_VERSION >= 12080 + using namespace quantize_kernel; + using namespace ptx; + checkCuDriverContext(stream); + + constexpr bool COMPUTE_ACTIVATIONS = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const ParamOP &) = nullptr; + + NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + bool use_colwise_scaling = output->has_columnwise_data(); + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = output->scale_inv.shape[1]; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_e8m0_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const ScalingType scaling_type = + use_colwise_scaling ? 
ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const nvfp4_second_stage_scale_ptr = + reinterpret_cast(output->scale.dptr); + + // Output data type is only required for the column-wise MXFP8 scaling. + // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work + const DType output_data_type = + use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, 4); + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_nvfp4_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; + const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; + + const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; + const size_t out_colwise_scales_mem = use_colwise_scaling ? 
buff_size_mxfp8_scales : 0; + + const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + + out_rowwise_scales_mem + out_colwise_scales_mem + + TMA_SHMEM_ALIGNMENT; + + const size_t dshmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = + quantize_nvfp4_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + + kernel<<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_nvfp4_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + + kernel<<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh similarity index 82% rename from transformer_engine/common/util/nvfp4_transpose.cuh rename to transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index 712b557c5d..c260f9ec2f 100644 --- a/transformer_engine/common/util/nvfp4_transpose.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -4,40 +4,35 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file nvfp4_transpose.cuh +/*! \file quantize_transpose_nvfp4.cuh * \brief CUDA kernels to cast to NVFP4 and transpose. 
*/ -#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ -#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ +#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ #include #include #include +#include -#if CUDA_VERSION > 12080 -#include -#endif // CUDA_VERSION > 12080 - -#include - -#include "../common.h" -#include "../utils.cuh" -#include "curanddx.hpp" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/transformer_engine.h" +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { -#if CUDA_VERSION > 12080 -namespace nvfp4_transpose { - -using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + - curanddx::SM<800>() + curanddx::Thread()); +namespace quantize_transpose_kernel { +using namespace quantization_and_transposition_SF; +using namespace core; using namespace ptx; -using nvfp4_scale_t = fp8e4m3; + +#if CUDA_VERSION >= 12080 constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) @@ -49,8 +44,9 @@ constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; -constexpr size_t RNG_GENS_PER_THREAD = - SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers + +// Each call generates 4x uint32_t random numbers +constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4; constexpr size_t TILE_DIM_Y = 32; constexpr size_t TILE_DIM_X = 128; @@ -110,244 +106,18 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 -// Compute per-block E4M3 encoding/decoding scaling factor -__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, - const float S_enc) { - // constexpr float rcp_6f = 1.0f / 6.0f; - // const float S_dec_b = block_amax * rcp_6f; - // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); - // return S_dec_b_fp8; - // NOTE: Divide by 6.0f is not elegant and not efficient. - // However, this is part of the emulation code to ensure exact match. 
- using namespace detail; - constexpr float fp4_max = TypeExtrema::max; // 6.0f; - const float S_dec_b = block_amax / fp4_max * S_enc; - return static_cast(fminf(S_dec_b, TypeExtrema::max)); -} - -// Compute the global encode scale factor for a given global amax -__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { - using namespace detail; - constexpr float fp8_max = TypeExtrema::max; // 448.0f; - constexpr float fp4_max = TypeExtrema::max; // 6.0f; - float global_encode_scale = fp8_max * fp4_max / global_amax; - // If scale is infinity, return max value of float32 - global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); - // If global amax is 0 or infinity, return 1 - if (global_amax == 0.0f || global_encode_scale == 0.0f) { - return 1.0f; - } - return global_encode_scale; -} - -__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) { - if (rnd_idx == 4) { - rnd_idx = 0; - curanddx::uniform_bits dist; - random_uint4 = dist.generate4(rng); - } - // Treat uint4 as an array of 4x uint32_t elements for indexing - const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); - const uint32_t rbits = rbits_arr[rnd_idx++]; - return rbits; -} - -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - -__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( - const uint64_t in_4x, const float2 scale, const uint32_t rbits) { - uint16_t out_4x = 0; -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b16 v0_bf16; \n\t" - ".reg.b16 v1_bf16; \n\t" - ".reg.b16 v2_bf16; \n\t" - ".reg.b16 v3_bf16; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - "cvt.f32.bf16 v0, v0_bf16; \n\t" - "cvt.f32.bf16 v1, v1_bf16; \n\t" - "cvt.f32.bf16 v2, v2_bf16; \n\t" - "cvt.f32.bf16 v3, v3_bf16; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order - "}" - : "=h"(out_4x) - : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL - return *reinterpret_cast(&out_4x); -} - -__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, - const float2 scale, - const uint32_t rbits) { - // NOTE: rbits unused for rn. - uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. 
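// Host-side reference of the two-level NVFP4 scaling that the helpers above
// implement on the device: a single global encode scale S_enc derived from the
// tensor amax, and a per-16-element-block decode scale S_dec_b that is stored
// in FP8 (E4M3). Plain float stands in for fp8e4m3 here; only the arithmetic
// is meant to match.
#include <algorithm>
#include <cfloat>
#include <cstdio>

constexpr float FP4_MAX = 6.0f;    // max magnitude of E2M1
constexpr float FP8_MAX = 448.0f;  // max magnitude of E4M3

float global_encode_scale(float global_amax) {
  float s_enc = FP8_MAX * FP4_MAX / global_amax;
  s_enc = std::min(s_enc, FLT_MAX);                       // clamp +inf
  if (global_amax == 0.0f || s_enc == 0.0f) return 1.0f;  // degenerate amax -> identity scale
  return s_enc;
}

float block_decode_scale(float block_amax, float s_enc) {
  // S_dec_b = block_amax / 6 * S_enc, saturated to the E4M3 range before the FP8 cast.
  return std::min(block_amax / FP4_MAX * s_enc, FP8_MAX);
}

int main() {
  const float global_amax = 3.5f;
  const float s_enc = global_encode_scale(global_amax);
  printf("S_enc = %f, S_dec_b(block_amax=1.25) = %f\n", s_enc, block_decode_scale(1.25f, s_enc));
  return 0;
}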
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b16 v0_bf16; \n\t" - ".reg.b16 v1_bf16; \n\t" - ".reg.b16 v2_bf16; \n\t" - ".reg.b16 v3_bf16; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - "cvt.f32.bf16 v0, v0_bf16; \n\t" - "cvt.f32.bf16 v1, v1_bf16; \n\t" - "cvt.f32.bf16 v2, v2_bf16; \n\t" - "cvt.f32.bf16 v3, v3_bf16; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "l"(in_4x), "l"(reinterpret_cast(scale))); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL - return reinterpret_cast(&out_4x)[0]; -} - -template -__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, - const float2 scale, - const uint32_t rbits) { - if constexpr (USE_STOCHASTIC_ROUNDING) { - return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); - } else { - return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); - } -} - -__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( - const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { - uint16_t out_4x = 0; -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - "mov.b64 {v0, v1} , %1; \n\t" - "mov.b64 {v2, v3} , %2; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order - "}" - : "=h"(out_4x) - : "l"(reinterpret_cast(in01)), - "l"(reinterpret_cast(in23)), - "l"(reinterpret_cast(scale)), "r"(rbits)); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL - return *reinterpret_cast(&out_4x); -} - -__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, - const float2 in23, - const float2 scale, - const uint32_t rbits) { - // NOTE: rbits unused for rn. - uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. 
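// The converters above differ only in rounding mode: cvt.rn rounds to nearest,
// while cvt.rs consumes Philox random bits to round stochastically, which keeps
// the quantization unbiased on average. A plain host analogue on an arbitrary
// grid (not the FP4 format itself) to show the difference:
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <random>

float round_to_grid_rn(float x, float step) { return std::nearbyint(x / step) * step; }

float round_to_grid_sr(float x, float step, uint32_t rbits) {
  const float q = x / step;
  const float lo = std::floor(q);
  const float u = rbits * (1.0f / 4294967296.0f);  // uniform in [0, 1)
  // Round up with probability equal to the fractional part of q.
  return ((u < q - lo) ? lo + 1.0f : lo) * step;
}

int main() {
  std::mt19937 gen(0);
  const int n = 100000;
  double mean_rn = 0.0, mean_sr = 0.0;
  for (int i = 0; i < n; ++i) {
    mean_rn += round_to_grid_rn(0.3f, 1.0f);         // always 0.0 -> biased
    mean_sr += round_to_grid_sr(0.3f, 1.0f, gen());  // ~0.3 on average -> unbiased
  }
  printf("round-to-nearest mean = %f, stochastic mean = %f\n", mean_rn / n, mean_sr / n);
  return 0;
}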
-#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "mov.b64 {v0, v1} , %1; \n\t" - "mov.b64 {v2, v3} , %2; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "l"(reinterpret_cast(in01)), - "l"(reinterpret_cast(in23)), - "l"(reinterpret_cast(scale))); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL - return reinterpret_cast(&out_4x)[0]; -} - -template -__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, - const float2 scale, - const uint32_t rbits) { - if constexpr (USE_STOCHASTIC_ROUNDING) { - return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); - } else { - return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); - } -} - -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - template __global__ void __launch_bounds__(THREADS_NUM) - nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const __grid_constant__ CUtensorMap tensor_map_output_t, - nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, - const float *noop, const float *const amax_rowwise_ptr, - const float *const amax_colwise_ptr, const size_t rows, - const size_t cols, const size_t scale_stride, - const size_t scale_stride_t, const size_t *rng_state) { + quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, + const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); @@ -852,14 +622,15 @@ __global__ void __launch_bounds__(THREADS_NUM) template __global__ void __launch_bounds__(THREADS_NUM) - nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const __grid_constant__ CUtensorMap tensor_map_output_t, - nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, - const float *noop, const float *const amax_rowwise_ptr, - const float *const amax_colwise_ptr, const size_t rows, - const size_t cols, const size_t scale_stride, - const size_t scale_stride_t, const size_t *rng_state) { + quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const 
scales_t_ptr, const float *noop, + const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); @@ -1379,19 +1150,20 @@ __global__ void __launch_bounds__(THREADS_NUM) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -} // namespace nvfp4_transpose -#endif // CUDA_VERSION > 12080 +#endif // CUDA_VERSION >= 12080 +} // namespace quantize_transpose_kernel // Compile-time flag to choose kernel variant #ifndef USE_2D_NVFP4_KERNEL #define USE_2D_NVFP4_KERNEL 0 #endif -template -void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, - const QuantizationConfig *quant_config, cudaStream_t stream) { -#if CUDA_VERSION > 12080 +template +void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if CUDA_VERSION >= 12080 + using namespace quantize_transpose_kernel; + using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to @@ -1399,8 +1171,9 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o // TODO(Frank): Is there a better way to do this? bool return_transpose = output->has_columnwise_data(); - using namespace nvfp4_transpose; - using namespace ptx; + constexpr bool COMPUTE_ACTIVATIONS = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const ParamOP &) = nullptr; checkCuDriverContext(stream); CheckNoopTensor(*noop, "cast_noop"); @@ -1493,12 +1266,12 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = nvfp4_transpose_kernel; + auto kernel = quantize_transpose_nvfp4_kernel; if constexpr (use_2d_quantization) { - kernel = nvfp4_transpose_kernel_2D; + kernel = quantize_transpose_nvfp4_2D_kernel; } cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); @@ -1509,8 +1282,11 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o });); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); -#endif // CUDA_VERSION > 12080 +#endif // CUDA_VERSION >= 12080 } + +} // namespace nvfp4 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh deleted file mode 100644 index b0498602b5..0000000000 --- a/transformer_engine/common/util/cast_kernels.cuh +++ /dev/null @@ -1,2188 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file cast_kernels.cuh - * \brief CUDA kernels to cast to/from FP8/MXFP8. 
- */ - -#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ - -#include -#include -#include -#include - -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "nvfp4_transpose.cuh" -#include "ptx.cuh" -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine { - -namespace mxfp8_kernel { - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 32; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t PACK_SIZE = 4; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; - constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - if constexpr (NO_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; - static_assert(BUFF_DIM_Y == 32); - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - static_assert(STAGES >= 1); - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X; - const size_t tid_Y_colwise = 0; - const size_t tid_X_colwise = threadIdx.x; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t thread_offset_Y_colwise = tid_Y_colwise; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = 
(col_base_colwise >= cols); - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - - OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } - - float block_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); - - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -// 3. 
Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; - const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - - // 3. 
Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. 
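// Host sketch of the E8M0 block scaling used in steps 2 and 3 above: the scale
// is a pure power of two chosen from the block amax so that the scaled block
// fits the FP8 (E4M3) range, and only its biased exponent is stored. The real
// ptx::float_to_e8m0 also handles NaN/zero and rounding details; this is a
// simplified reference.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

constexpr float FP8_E4M3_MAX = 448.0f;

uint8_t float_to_e8m0(float x) {
  if (x <= 0.0f) return 0;  // simplified handling of zero/negative input
  int exp = static_cast<int>(std::ceil(std::log2(x)));
  exp = std::max(-127, std::min(127, exp));
  return static_cast<uint8_t>(exp + 127);
}

int main() {
  const float block_amax = 97.5f;
  // biased_exponent = float_to_e8m0(amax * max_norm_rcp), with max_norm_rcp = 1/448
  const uint8_t biased = float_to_e8m0(block_amax / FP8_E4M3_MAX);
  const float block_scale_inverse = std::exp2f(static_cast<float>(127 - biased));
  printf("biased exponent = %u, elements scaled by %f before the FP8 cast\n", biased,
         block_scale_inverse);  // 97.5 * 4 = 390 <= 448
  return 0;
}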
- ptx::cp_async_bulk_commit_group(); - } - } - - parity ^= 1; - - if constexpr (IS_DBIAS) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); - - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; - } - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; - } - } - const int dbias_stride = cols; - const int dbias_offset_Y = blockIdx.y; - const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace mxfp8_kernel - -namespace nvfp4_kernel { - -using namespace ptx; - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 16; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t BUFF_DIM_Y = 32; - -constexpr size_t PACK_SIZE = 8; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 - -// Compute per-block E4M3 encoding/decoding scaling factor -__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, - const float S_enc) { - constexpr float rcp_6f = 1.0f / 6.0f; - // const float S_dec_b = block_amax * rcp_6f; - // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); - // return S_dec_b_fp8; - return static_cast(block_amax * rcp_6f * S_enc); -} - -#define DIRECT_SCALING_FACTORS_STORE 1 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const 
scales_colwise_e8m0, - const float *noop, float *const amax_ptr, - const float *const nvfp4_second_stage_scale_ptr, const size_t rows, - const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool ROWWISE_SCALING = true; - constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = - (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); - - using IType2 = typename ptx::FPx2; - - if constexpr (!COMPUTE_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; - constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; - - static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && - "Number of buffer rows must be greater or equal to the size of the columwise " - "scaling block\0"); - static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); - static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && - "Number of buffer rows must be greater or equal to the number of rowwise " - "processing threads in Y dimension\0"); - - constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size - constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; - constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - - constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; - // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of - // // threads to process one row in a single iteration - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * CHUNK_DIM_X; - const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; - const int tid_Y_colwise = 0; - const int tid_X_colwise = threadIdx.x; - - const int thread_offset_Y_rowwise = tid_Y_rowwise; - const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const int thread_offset_Y_colwise = tid_Y_colwise; - const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements - - const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const int col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; - - // helps resolving bank conflicts in shmem - const 
int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_nvfp4 = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_mxfp8 = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t buff_size_nvfp4_scales = - CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); - constexpr size_t buff_size_mxfp8_scales = - (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); - - constexpr size_t in_mem = buff_size_aligned_in; - - constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); - constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); - constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); - constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); - fp8e4m3 *out_rowwise_scales_sh = - reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); - e8m0_t *out_colwise_scales_sh = reinterpret_cast( - dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = - (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); - - float thread_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const int buff = stage % BUFFS_NUM; - const int next_stage = stage + 1; - const int stage_offset_Y = stage * BUFF_DIM_Y; - - const int buff_offset_in = buff * BUFF_IN_DIM; - const int buff_offset_out = buff * BUFF_OUT_DIM; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const int next_buff = next_stage % BUFFS_NUM; - const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const int global_offset_Y = block_offset_Y + next_stage_offset_Y; - const int global_offset_X = block_offset_X; - const int next_buff_offset = next_buff * BUFF_IN_DIM; - - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], 0); - - float block_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; - - block_amax = 0.0f; - float in_compute_colwise[SCALE_DIM_Y]; - IType in_colwise_IType[SCALE_DIM_Y]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType block_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); - } - block_amax = static_cast(block_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - block_amax = fmaxf(block_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - block_amax = fmaxf(block_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); - - const int global_scales_offset_Y = scales_offset_Y_colwise + stage; - const int global_scales_offset_X = scales_offset_X_colwise; - const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - if (colwise_scale_is_within_bounds) { - scales_colwise_e8m0[scale_idx] = biased_exponent; - } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - -// 3. 
Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; -#pragma unroll - for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { - const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; - - const int shmem_offset_base_rowwise_in = - buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; - const int shmem_offset_base_rowwise_out = - buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; - - const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; - - block_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find NVFP4-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = - (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - block_amax = fmaxf(block_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - block_amax = fmaxf(block_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E4M3 scaling factor - const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); - -#if DIRECT_SCALING_FACTORS_STORE - // Check boundaries - if (rowwise_scale_is_within_bounds) { - const int scales_offset_Y = - scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = scales_offset_X_rowwise; - const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; - scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; - } -#else - const int shmem_scales_offset_Y = - stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; - const int shmem_scales_offset_X = tid_X_rowwise; - const int scale_idx = - shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; - out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; -#endif - // Compute "correct" per-block encoding scaling factor - const float block_scale_inverse = - __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 - -// 3. 
Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; // Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 4; ++e) { - IType2 in01; - IType2 in23; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - in01 = in_IType[w].data.elt[2 * e]; - in23 = in_IType[w].data.elt[2 * e + 1]; - } else if constexpr (IS_CACHED_ACT_OP) { - in01.x = in_cached[w].data.elt[4 * e]; - in01.y = in_cached[w].data.elt[4 * e + 1]; - in23.x = in_cached[w].data.elt[4 * e + 2]; - in23.y = in_cached[w].data.elt[4 * e + 3]; - } else { - const int j = w * PACK_SIZE + 4 * e; - in01.x = in_compute_rowwise[j]; - in01.y = in_compute_rowwise[j + 1]; - in23.x = in_compute_rowwise[j + 2]; - in23.y = in_compute_rowwise[j + 3]; - } - fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); - ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); - } - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - } - } - } - - __builtin_assume(thread_amax >= 0); - __builtin_assume(block_amax >= 0); - thread_amax = fmaxf(thread_amax, block_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; - const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } - -#if !DIRECT_SCALING_FACTORS_STORE - // Vectorized store of scaling factors. - // Each thread stores multiple scaling factors in one store instruction. 
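// The fallback path below packs a row of 1-byte scaling factors and writes them
// with a single wide store instead of one byte store per factor. A host-side
// stand-in for the Vec helper (plain bytes and uint32_t here) to show the
// packing idea:
#include <cstdint>
#include <cstdio>
#include <cstring>

void store_scales_vectorized(const uint8_t *scales_shmem, uint8_t *scales_global, int n) {
  for (int i = 0; i + 4 <= n; i += 4) {
    uint32_t packed;
    std::memcpy(&packed, scales_shmem + i, sizeof(packed));   // one 4-byte load
    std::memcpy(scales_global + i, &packed, sizeof(packed));  // one 4-byte store
  }
}

int main() {
  uint8_t src[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  uint8_t dst[8] = {};
  store_scales_vectorized(src, dst, 8);
  printf("%u ... %u\n", dst[0], dst[7]);
  return 0;
}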
- if constexpr (ROWWISE_SCALING) { - // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X - const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; - const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; - const int scale_idx_global = - scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; - const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; - - if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && - (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { - using ScalesVec_t = Vec; - const ScalesVec_t &scales = - *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); - scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); - } - } -#endif - - float chunk_amax = 0.0f; - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - chunk_amax = reduce_max(thread_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, chunk_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace nvfp4_kernel - -constexpr size_t FP8_CHUNK_DIM_Y = 128; -constexpr size_t FP8_CHUNK_DIM_X = 128; -constexpr size_t FP8_THREADS_PER_CHUNK = 128; -constexpr size_t FP8_BUFFERS_NUM = 2; -constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); - -constexpr size_t FP8_BUFFER_DIM_Y = 16; -constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 -constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 - -constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 -static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); - -template -__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) - cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output, - float *const dbias_workspace, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, - const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; - const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - const size_t dbias_offset_Y = blockIdx.y + tid_Y; - const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; - const bool col_out_of_bounds = my_column >= cols; - const size_t dbias_stride = cols; - - float partial_dbias = 0.f; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - - constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - const size_t chunk_offset_Y = block_offset_Y; - const size_t chunk_offset_X = block_offset_X; - -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; - const size_t chunk_stage_offset_X = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); - } - } - -#pragma unroll - for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { - const size_t buff = iter % FP8_BUFFERS_NUM; - const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; - if (next_iter < FP8_ITERATIONS) { - const size_t next_buff = next_iter % FP8_BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); - } - } - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = row >= rows; - const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; - - float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if constexpr (IS_DACT) { - if (!out_of_bounds) { - partial_dbias += elt; - } - } else { - // If no activation, elt is 0 so we can safely do this - partial_dbias += elt; - } - } - __builtin_assume(amax >= 0); - if (IS_DACT) { - if 
(!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); - } - out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - parity ^= 1; - - if constexpr (IS_DBIAS) { - const size_t dbias_offset_X = my_column; - const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t CHUNKS_PER_BLOCK = 128; -constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; -constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; -constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; -constexpr size_t CHUNKS_PER_ITERATION = 32; -constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; -constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; -constexpr size_t SHMEM_BUFFERS = 2; -static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); - -template -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; - const IType *input = input_ptr + block_offset; - OType *output = output_ptr + block_offset; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; - - constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; - constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. 
-#pragma nv_diag_suppress static_var_with_dynamic_init
-  __shared__ alignas(8) uint64_t mbar[ITERATIONS];
-
-  initialize_barriers(mbar, is_master_thread);
-
-  int parity = 0;
-
-  copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread);
-
-#pragma unroll
-  for (int iter = 0; iter < ITERATIONS; ++iter) {
-    const size_t buff = iter % SHMEM_BUFFERS;
-    const size_t it_offset = iter * SHMEM_DIM;
-
-    const size_t next_iter = iter + 1;
-    const size_t next_buff = next_iter % SHMEM_BUFFERS;
-    const size_t next_iter_offset = next_iter * SHMEM_DIM;
-
-    if (next_iter < ITERATIONS) {
-      copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN,
-                        &(mbar[next_iter]), is_master_thread);
-    }
-
-    ptx::fence_proxy_async_shared_cta();
-
-    // Wait for the data to have arrived
-    ptx::mbarrier_wait_parity(&mbar[iter], parity);
-
-#pragma unroll
-    for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) {
-      const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
-      float elt = static_cast<float>(in_sh[buff][shmem_offset]);
-      if constexpr (IS_ACT) {
-        elt = OP(elt, {});
-      }
-      __builtin_assume(amax >= 0);
-      amax = fmaxf(amax, fabsf(elt));
-      out_sh[buff][shmem_offset] = static_cast<OType>(elt * scale);
-    }
-
-    // Wait for shared memory writes to be visible to TMA engine.
-    ptx::fence_proxy_async_shared_cta();
-    __syncthreads();
-    // After syncthreads, writes by all threads are visible to TMA engine.
-
-    // Initiate TMA transfer to copy shared memory to global memory
-    if (is_master_thread) {
-      ptx::cp_async_bulk_tensor_1d_shared_to_global(
-          reinterpret_cast<uint64_t *>(output + it_offset),
-          reinterpret_cast<uint64_t *>(&out_sh[buff]), transaction_size_OUT);
-
-      // Create a "bulk async-group" out of the previous bulk copy operation.
-      ptx::cp_async_bulk_commit_group();
-
-      // Wait for TMA transfer to have finished reading shared memory.
- ptx::cp_async_bulk_wait_group_read<1>(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; -template -__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) - reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, - const size_t rows, const size_t cols) { - using ComputeVec = Vec; - using OutputVec = Vec; - - const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; - - if (thread_id * nvec >= cols) { - return; - } - - const float *const thread_in_base = dbias_partial + thread_id * nvec; - OType *const thread_out_base = dbias_output + thread_id * nvec; - - ComputeVec ldg_vec; - ComputeVec acc_vec; - acc_vec.clear(); - for (int i = 0; i < rows; ++i) { - ldg_vec.load_from(thread_in_base + i * cols); -#pragma unroll - for (int e = 0; e < nvec; ++e) { - acc_vec.data.elt[e] += ldg_vec.data.elt[e]; - } - } - - OutputVec stg_vec; -#pragma unroll - for (int e = 0; e < nvec; ++e) { - stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); - } - stg_vec.store_to(thread_out_base); -} - -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream) { - constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 - constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); - - NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); - const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); - - reduce_dbias_kernel - <<>>( - reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { - const size_t N = product(input.data.shape); - - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - NVTE_CHECK(isFullTile, "Only full tiles are supported."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - const size_t chunks = DIVUP(N, CHUNK_SIZE); - const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - const float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(THREADS_PER_BLOCK); - const dim3 grid(blocks); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - const IType *input_ptr = reinterpret_cast(input.data.dptr); - OType *output_ptr = reinterpret_cast(output->data.dptr); - - cast_fp8_1D_kernel<<>>( - input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) - ); // NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, - Tensor *workspace, cudaStream_t 
stream) { - checkCuDriverContext(stream); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); - const size_t blocks_Y = chunks_Y; - const size_t blocks_X = chunks_X; - - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(FP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - } - - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); - - cast_fp8_2D_kernel - <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, - workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError()); - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, - const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - using namespace mxfp8_kernel; - checkCuDriverContext(stream); - - bool use_rowwise_scaling = output->has_data(); - bool use_colwise_scaling = output->has_columnwise_data(); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - - if (use_rowwise_scaling) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - } - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); - - constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 
128 : 64; - constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; - - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - e8m0_t *const scales_rowwise_ptr = - use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; - e8m0_t *const scales_colwise_ptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - ScalingType scaling_type; - if (use_rowwise_scaling && (!use_colwise_scaling)) { - scaling_type = ScalingType::ROWWISE; - } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - scaling_type = ScalingType::COLWISE; - } else if (use_rowwise_scaling && use_colwise_scaling) { - scaling_type = ScalingType::BIDIMENSIONAL; - } - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - - float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, - cols, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? 
buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::BIDIMENSIONAL: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} - -// This kernel supports only two scaling cases: -// 1. r16c0 - Rowwise NVFP4 -// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8 -template -void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { - using namespace nvfp4_kernel; - using namespace ptx; - checkCuDriverContext(stream); - - NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - bool use_colwise_scaling = output->has_columnwise_data(); - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - constexpr size_t CHUNK_DIM_Y = 128; - constexpr size_t CHUNK_DIM_X = 128; - constexpr size_t THREADS_PER_CHUNK = 128; - - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = output->scale_inv.shape[1]; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); - e8m0_t *const scales_colwise_e8m0_ptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - - const ScalingType scaling_type = - use_colwise_scaling ? 
ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - const float *const nvfp4_second_stage_scale_ptr = - reinterpret_cast(output->scale.dptr); - - // Output data type is only required for the column-wise MXFP8 scaling. - // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work - const DType output_data_type = - use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, sizeof(IType) * 8); - - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, - nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); - } - - constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_nvfp4 = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_mxfp8 = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_nvfp4_scales = - (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); - constexpr size_t buff_size_mxfp8_scales = - (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); - - constexpr size_t in_mem = buff_size_aligned_in; - - const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; - const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; - - const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; - const size_t out_colwise_scales_mem = use_colwise_scaling ? 
buff_size_mxfp8_scales : 0; - - const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + - out_rowwise_scales_mem + out_colwise_scales_mem + - TMA_SHMEM_ALIGNMENT; - - const size_t dshmem_size = in_mem + out_mem; - - switch (scaling_type) { - case ScalingType::ROWWISE: - cudaFuncSetAttribute( - cast_nvfp4_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - - cast_nvfp4_kernel - <<>>( - tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, - scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, - nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - break; - case ScalingType::BIDIMENSIONAL: - cudaFuncSetAttribute( - cast_nvfp4_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - - cast_nvfp4_kernel - <<>>( - tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, - scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, - nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - break; - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace detail { - -using Empty = transformer_engine::Empty; - -__device__ inline float identity(float value, const Empty &) { return value; } - -struct DequantizeParam { - const float *scale_inv; -}; - -__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); -} - -} // namespace detail - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(noop->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream) { - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; - const size_t N = product(input->data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input->data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace { - -static bool is_full_tile_1D_tensor(const Tensor *const t) { - const size_t N = product(t->data.shape); - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - return isFullTile; -} - -bool dimensions_supported_by_TMA(const Tensor *const t) { - const size_t cols = t->flat_last_dim(); - constexpr size_t TMA_bytes = 16; - const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); - return cols % alignment_requirement == 0; -} - -} // namespace - -// Supported by the Arch >= 10.0 -template -void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DBIAS && !IS_DACT) { - if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 - cast_fp8_1D(input, output, stream); - } else { - // Unaligned - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } else if (!IS_DBIAS && IS_DACT) { - if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 (+dAct) - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } else { - // Unaligned - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - } else { - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -// Supported by the Arch < 10.0 -template -void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { - // zhongboz: should we just ignore IS_ACT here? 
- NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + - " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); - } - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DACT) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } else { - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -template -void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, - Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckNoopTensor(*noop, "cast_noop"); - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias != nullptr); - CheckOutputTensor(*dbias, "dbias"); - } - if constexpr (IS_DACT) { - NVTE_CHECK(act_input != nullptr); - CheckInputTensor(*act_input, "activation_input"); - NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); - } - - NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { - fp8_quantize_arch_ge_100(input, act_input, noop, output, - dbias, workspace, stream); - } else { - // Supported by the Arch < 10.0 - fp8_quantize_arch_l_100(input, act_input, noop, output, - dbias, workspace, stream); - } -} - -namespace detail { - -template -void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, - NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - const Tensor *input_tensor; - const Tensor *activation_input_tensor; - if constexpr (IS_DBIAS || IS_DACT) { - // backward - input is incoming gradient - input_tensor = convertNVTETensorCheck(grad); - activation_input_tensor = convertNVTETensor(input); - } else { - // forward = input is activation input - input_tensor = convertNVTETensorCheck(input); - activation_input_tensor = nullptr; - } - auto output_tensor = convertNVTETensorCheck(output); - auto dbias_tensor = convertNVTETensor(dbias); - auto workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, 
activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - } else if (output_tensor->has_data()) { - fp8_quantize( - *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize( - *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 && - output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4_quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4_quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for " - "2D quantization"); - quantize_transpose_vector_blockwise_fp4( - /*input=*/input_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. 
-      NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
-                 "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
-      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
-      float epsilon = quant_config_cpp.amax_epsilon;
-      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
-      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
-      if (output_tensor->has_data()) {
-        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                                Float8BlockScaleTensorFormat::COMPACT);
-        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
-                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
-      }
-      if (output_tensor->has_columnwise_data()) {
-        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                                   Float8BlockScaleTensorFormat::COMPACT);
-        columnwise_option = columnwise_compact
-                                ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
-                                : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
-      }
-      quantize_transpose_vector_blockwise(
-          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
-          output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
-          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
-      break;
-    }
-    default:
-      NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
-  }
-}
-
-}  // namespace detail
-}  // namespace transformer_engine
-
-#endif  // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_
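The kernels removed above all finish with the same amax bookkeeping: a block-wide max reduction followed by an atomic max into a global amax pointer. The sketch below is a minimal, self-contained CUDA version of that pattern; block_reduce_max and atomic_max_float are illustrative stand-ins, not the reduce_max/atomicMaxFloat helpers referenced in the diff.

// Illustrative sketch of the block amax reduction + float atomic max pattern.
#include <cuda_runtime.h>
#include <math.h>

__device__ inline float atomic_max_float(float *addr, float value) {
  // CUDA has no native atomicMax for float; emulate it with atomicCAS on the
  // bit pattern (sufficient for the non-negative amax values used here).
  int *addr_as_int = reinterpret_cast<int *>(addr);
  int old = *addr_as_int;
  while (__int_as_float(old) < value) {
    const int assumed = old;
    old = atomicCAS(addr_as_int, assumed, __float_as_int(value));
    if (old == assumed) break;
  }
  return __int_as_float(old);
}

template <int BLOCK_THREADS>
__device__ float block_reduce_max(float val) {
  constexpr int WARP_SIZE = 32;
  static_assert(BLOCK_THREADS % WARP_SIZE == 0, "block size must be a multiple of the warp size");
  __shared__ float warp_max[BLOCK_THREADS / WARP_SIZE];
  // Intra-warp reduction via shuffles.
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
  }
  const int lane = threadIdx.x % WARP_SIZE;
  const int warp = threadIdx.x / WARP_SIZE;
  if (lane == 0) warp_max[warp] = val;
  __syncthreads();
  // The first warp reduces the per-warp maxima (amax is non-negative, so 0 is a safe identity).
  val = (threadIdx.x < BLOCK_THREADS / WARP_SIZE) ? warp_max[threadIdx.x] : 0.0f;
  if (warp == 0) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
      val = fmaxf(val, __shfl_down_sync(0xffffffff, val, offset));
    }
  }
  return val;  // the full-block maximum is valid in thread 0
}

__global__ void amax_kernel(const float *x, size_t n, float *amax_ptr) {
  float amax = 0.0f;
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x) {
    amax = fmaxf(amax, fabsf(x[i]));
  }
  amax = block_reduce_max<128>(amax);
  if (threadIdx.x == 0) atomic_max_float(amax_ptr, amax);
}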
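Numerically, the delayed-tensor-scaling cast path above reduces to: multiply by a precomputed scale, saturate to the representable E4M3 range, record the amax of the unscaled input for the next scale update, and publish scale_inv = 1/scale. A host-side reference of that arithmetic (illustrative names; the final FP8 rounding is omitted):

#include <algorithm>
#include <cmath>
#include <vector>

struct Fp8CastResult {
  std::vector<float> quantized;  // scaled, clamped values as they would go into FP8 storage
  float amax = 0.0f;             // max |x| of the unscaled input, fed into the next scale update
  float scale_inv = 1.0f;        // written once per tensor; consumers use it to dequantize
};

Fp8CastResult cast_to_fp8_range(const std::vector<float> &x, float scale) {
  constexpr float E4M3_MAX = 448.0f;  // largest finite FP8 E4M3 magnitude
  Fp8CastResult r;
  r.quantized.reserve(x.size());
  for (float v : x) {
    r.amax = std::max(r.amax, std::fabs(v));
    // Saturating cast: the real kernel relies on the FP8 conversion itself to clamp.
    r.quantized.push_back(std::clamp(v * scale, -E4M3_MAX, E4M3_MAX));
  }
  r.scale_inv = 1.0f / scale;
  return r;
}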
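The dbias fusion above is a two-stage reduction: each block accumulates a partial column sum into a float workspace of shape [num_blocks, cols], and reduce_dbias then collapses the workspace over its first dimension. A plain C++ reference of the same arithmetic (reduce_dbias_reference is a hypothetical name, not a TE helper):

#include <cstddef>
#include <vector>

std::vector<float> reduce_dbias_reference(const std::vector<float> &input, std::size_t rows,
                                          std::size_t cols, std::size_t rows_per_block) {
  // Stage 1 (done per block in the kernels above): partial column sums over a slab of rows.
  const std::size_t num_blocks = (rows + rows_per_block - 1) / rows_per_block;
  std::vector<float> workspace(num_blocks * cols, 0.0f);  // [num_blocks, cols]
  for (std::size_t r = 0; r < rows; ++r) {
    const std::size_t block = r / rows_per_block;
    for (std::size_t c = 0; c < cols; ++c) {
      workspace[block * cols + c] += input[r * cols + c];
    }
  }
  // Stage 2 (reduce_dbias): sum the per-block partials down to one value per column.
  std::vector<float> dbias(cols, 0.0f);
  for (std::size_t b = 0; b < num_blocks; ++b) {
    for (std::size_t c = 0; c < cols; ++c) {
      dbias[c] += workspace[b * cols + c];
    }
  }
  return dbias;
}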
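The removed fp8_quantize_arch_ge_100 helper picks between the TMA-based kernels and the vectorized fallback using tile-size and alignment checks. The sketch below is a simplified model of that decision; the constants and the choose_cast_path helper are illustrative assumptions, not the actual TE checks:

#include <cstddef>
#include <cstdint>

constexpr std::size_t kElemsPerBlock = 128 * 128;  // full-tile granularity assumed for the 1D path
constexpr std::size_t kTmaBytes = 16;              // assumed TMA global-memory alignment in bytes

enum class CastPath { Tma1D, Tma2D, VectorizedFallback };

bool is_full_tile_1d(std::size_t num_elems) { return num_elems % kElemsPerBlock == 0; }

bool row_length_supported_by_tma(std::size_t cols, std::size_t elem_bits) {
  // Each row must cover a whole number of 16-byte TMA words.
  return (cols * elem_bits) % (kTmaBytes * 8) == 0;
}

bool pointer_is_tma_aligned(const void *p) {
  return reinterpret_cast<std::uintptr_t>(p) % kTmaBytes == 0;
}

CastPath choose_cast_path(const void *in, const void *out, std::size_t rows, std::size_t cols,
                          std::size_t elem_bits, bool needs_dact) {
  const bool aligned = pointer_is_tma_aligned(in) && pointer_is_tma_aligned(out);
  if (!needs_dact && aligned && is_full_tile_1d(rows * cols)) return CastPath::Tma1D;
  if (needs_dact && aligned && row_length_supported_by_tma(cols, elem_bits)) return CastPath::Tma2D;
  return CastPath::VectorizedFallback;  // element-wise vectorized kernel handles the rest
}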