diff --git a/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py b/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
new file mode 100644
index 0000000000..bd630b2b82
--- /dev/null
+++ b/benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
@@ -0,0 +1,185 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+# This benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
+
+from dataclasses import dataclass
+from typing import List
+
+import torch
+from tabulate import tabulate
+from tqdm import tqdm
+
+from benchmarks.utils import benchmark_cuda_function_in_microseconds
+from torchao.prototype import mxfp8_cuda
+from torchao.prototype.moe_training.scaled_grouped_mm import (
+    _to_mxfp8_dim1_3d,
+)
+from torchao.prototype.mx_formats.mx_tensor import to_mx
+
+device = torch.device("cuda")
+
+# Needed since changing args to the function causes recompiles
+torch._dynamo.config.cache_size_limit = 1000
+
+
+@dataclass(frozen=True)
+class ExperimentConfig:
+    input_shape: tuple[int, ...]
+
+
+@dataclass(frozen=True)
+class ExperimentResult:
+    # time
+    to_mx_us: float
+    cuda_2d_us: float
+    cuda_3d_us: float
+    # mem bw
+    to_mx_gbps: float
+    cuda_2d_gbps: float
+    cuda_3d_gbps: float
+
+
+@dataclass(frozen=True)
+class Experiment:
+    config: ExperimentConfig
+    result: ExperimentResult
+
+
+def get_configs() -> List[ExperimentConfig]:
+    # Llama4 shapes. Input activations are scaled along K dim.
+    input_shapes = [
+        (1, 8192, 5120),
+        (2, 8192, 5120),
+        (4, 8192, 5120),
+        (8, 8192, 5120),
+        (16, 8192, 5120),
+        (64, 8192, 5120),
+    ]
+    configs = []
+    for shape in input_shapes:
+        configs.append(
+            ExperimentConfig(
+                input_shape=shape,
+            )
+        )
+    return configs
+
+
+def run_experiment(config: ExperimentConfig) -> ExperimentResult:
+    block_size = 32
+    input_shape = config.input_shape
+    input_tensor = torch.randn(
+        *input_shape,
+        dtype=torch.bfloat16,
+        device=device,
+    )
+
+    def using_to_mx(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        # Reference implementation
+        s_d1_ref, y_d1_ref = to_mx(
+            # Transpose (E,N,K) to (E,K,N) so N is the final dim,
+            # since to_mx scales along that dim
+            x.transpose(-2, -1).contiguous(),
+            elem_dtype=torch.float8_e4m3fn,
+            block_size=block_size,
+        )
+
+        # Transpose tensors and scales back so we have effectively
+        # quantized input shape (E, N, K) along N
+        y_d1_ref = y_d1_ref.transpose(-2, -1)
+        s_d1_ref = s_d1_ref.transpose(-2, -1)
+        return y_d1_ref, s_d1_ref
+
+    # bench to_mx
+    using_to_mx_c = torch.compile(using_to_mx)
+    data_to_mx, scales_to_mx = using_to_mx_c(input_tensor)
+    to_mx_time_us = benchmark_cuda_function_in_microseconds(
+        using_to_mx_c,
+        input_tensor,
+    )
+
+    # bench 2D dim1 kernel, then transform to col major
+    using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
+    scales_cuda_2d, data_cuda_2d = using_cuda_2d_c(input_tensor)
+    time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
+        using_cuda_2d_c,
+        input_tensor,
+    )
+
+    # bench 3D cuda kernel
+    data_cuda_3d, scales_cuda_3d = mxfp8_cuda.quantize_3d(input_tensor)
+    time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
+        mxfp8_cuda.quantize_3d,
+        input_tensor,
+    )
+
+    # mem bw calculations
+    bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
+    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
+    bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8
+
+    read_bytes = input_tensor.numel() * bytes_per_input_el
+    write_bytes = (
+        data_cuda_3d.numel() * bytes_per_output_el
+        + scales_cuda_3d.numel() * bytes_per_scale_el
+    )
+
+    to_mx_gbps = ((read_bytes + write_bytes) / 1e9) / (to_mx_time_us / 1e6)
+    cuda_2d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_2d_us / 1e6)
+    cuda_3d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_3d_us / 1e6)
+
+    return ExperimentResult(
+        # time
+        to_mx_us=to_mx_time_us,
+        cuda_2d_us=time_cuda_2d_us,
+        cuda_3d_us=time_cuda_3d_us,
+        # mem bw
+        to_mx_gbps=to_mx_gbps,
+        cuda_2d_gbps=cuda_2d_gbps,
+        cuda_3d_gbps=cuda_3d_gbps,
+    )
+
+
+def print_results(experiments: List[Experiment]):
+    headers = [
+        "input_shape",
+        "to_mx_us",
+        "cuda_2d_us",
+        "cuda_3d_us",
+        "to_mx_gbps",
+        "cuda_2d_gbps",
+        "cuda_3d_gbps",
+    ]
+    rows = []
+    for experiment in experiments:
+        rows.append(
+            [
+                str(experiment.config.input_shape),
+                experiment.result.to_mx_us,
+                experiment.result.cuda_2d_us,
+                experiment.result.cuda_3d_us,
+                round(experiment.result.to_mx_gbps, 3),
+                round(experiment.result.cuda_2d_gbps, 3),
+                round(experiment.result.cuda_3d_gbps, 3),
+            ]
+        )
+    print(tabulate(rows, headers=headers))
+
+
+def main():
+    torch.random.manual_seed(123)
+    configs = get_configs()
+    results = []
+    for config in tqdm(configs):
+        result = run_experiment(config)
+        results.append(Experiment(config=config, result=result))
+
+    # Use tabulate to print results
+    print_results(results)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py
index 9c044b9fef..1da4899667 100644
--- a/test/prototype/moe_training/test_kernels.py
+++ b/test/prototype/moe_training/test_kernels.py
@@ -12,7 +12,6 @@
 
 if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9):
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
-
 from torchao.prototype.moe_training.kernels.float8_rowwise import (
     triton_fp8_rowwise_3d_transpose_rhs,
     triton_fp8_rowwise_3d_transpose_rhs_fused_reduction,
@@ -38,8 +37,11 @@
     torch_to_float8_per_group_colwise,
     torch_to_float8_per_group_rowwise,
 )
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
 from torchao.testing.utils import skip_if_rocm
+from torchao.utils import (
+    is_sm_at_least_100,
+)
 
 
 @skip_if_rocm("ROCm enablement in progress")
@@ -313,3 +315,53 @@ def test_triton_mx_block_rearrange_2d_K_groups(
         output_group_offsets,
     )
     assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"
+
+
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="MXFP8 requires CUDA capability 10.0 or greater",
+)
+@pytest.mark.parametrize("E", (1, 2, 4, 8))
+@pytest.mark.parametrize("N", (32, 64, 8192))
+@pytest.mark.parametrize("K", (32, 64, 8192))
+@pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
+@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,))
+def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
+    from torchao.prototype import mxfp8_cuda
+
+    scaling_mode_str = (
+        "floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
+    )
+    block_size = 32
+
+    # Use distinct incrementing values from 0 to E*N*K-1 to make debugging easier.
+    x = (
+        torch.arange(0, E * N * K, dtype=input_dtype, device="cuda")
+        .reshape(E, N, K)
+        .contiguous()
+    )
+
+    # Reference implementation
+    s_d1_ref, y_d1_ref = to_mx(
+        # Transpose so N is final dim, since to_mx scales along that dim
+        x.transpose(-2, -1).contiguous(),
+        elem_dtype=torch.float8_e4m3fn,
+        block_size=block_size,
+    )
+
+    # Transpose tensors and scales back so we have effectively
+    # quantized input shape (E, N, K) along N
+    y_d1_ref = y_d1_ref.transpose(-2, -1)
+    s_d1_ref = s_d1_ref.transpose(-2, -1)
+
+    # CUDA implementation (should work with any stride pattern)
+    y_d1, s_d1 = mxfp8_cuda.quantize_3d(
+        x, scale_dim_n=block_size, scaling_mode=scaling_mode_str
+    )
+
+    # Check scales
+    torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
+
+    # Check quantized values
+    torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
+    assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu b/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu
index ffb91d38c6..7546dc7b7b 100644
--- a/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu
+++ b/torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu
@@ -109,4 +109,72 @@ void mxfp8_quantize_cuda(const torch::Tensor &input,
                          stream);
 }
 
+void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
+                            torch::Tensor &output_colwise,
+                            torch::Tensor &scales_colwise,
+                            int64_t scale_dim_n,
+                            const std::string &fp8_format,
+                            const std::string &scaling_mode) {
+
+  // Get tensor properties for 3D tensor (E, N, K)
+  const int64_t E = input.size(0);
+  const int64_t N = input.size(1);
+  const int64_t K = input.size(2);
+
+  // Get data pointers
+  const void *input_ptr = input.data_ptr();
+  void *output_colwise_ptr = output_colwise.data_ptr();
+  e8m0_t *scales_colwise_ptr =
+      reinterpret_cast<e8m0_t *>(scales_colwise.data_ptr());
+
+  // Get CUDA stream
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Get strides of scales tensor
+  int64_t scales_colwise_stride_dim0 = scales_colwise.stride(0);
+  int64_t scales_colwise_stride_dim1 = scales_colwise.stride(1);
+  int64_t scales_colwise_stride_dim2 = scales_colwise.stride(2);
+
+  // Get input tensor strides for generic layout support
+  int64_t input_stride_dim0 = input.stride(0); // E dimension stride
+  int64_t input_stride_dim1 = input.stride(1); // N dimension stride
+  int64_t input_stride_dim2 = input.stride(2); // K dimension stride
+
+  // Get output tensor strides (should be col major)
+  int64_t output_stride_dim0 = output_colwise.stride(0); // E dimension stride
+  int64_t output_stride_dim1 = output_colwise.stride(1); // N dimension stride
+  int64_t output_stride_dim2 = output_colwise.stride(2); // K dimension stride
+
+
+#if defined(DEBUG)
+  printf("mxfp8_quantize_3d_cuda:\n");
+  printf("Quantizing 3D input tensor of size %ld x %ld x %ld\n", E, N, K);
+  printf("scaling_mode: %s\n", scaling_mode.c_str());
+  printf("Scale dim n: %ld\n", scale_dim_n);
+  printf("Output scale shape: %ld x %ld x %ld\n",
+         scales_colwise.sizes()[0], scales_colwise.sizes()[1], scales_colwise.sizes()[2]);
+  printf("scales_colwise_stride_dim0 = %ld\n", scales_colwise_stride_dim0);
+  printf("scales_colwise_stride_dim1 = %ld\n", scales_colwise_stride_dim1);
+  printf("input_stride_dim0 = %ld\n", input_stride_dim0);
+  printf("input_stride_dim1 = %ld\n", input_stride_dim1);
+  printf("input_stride_dim2 = %ld\n", input_stride_dim2);
+  printf("output_stride_dim0 = %ld\n", output_stride_dim0);
+  printf("output_stride_dim1 = %ld\n", output_stride_dim1);
+
printf("output_stride_dim2 = %ld\n", output_stride_dim2); +#endif + + // Call the 3D quantization kernel + MXFP8Quantizer::quantize_3d(input_ptr, + output_colwise_ptr, + scales_colwise_ptr, + E, N, K, + input_stride_dim0, input_stride_dim1, input_stride_dim2, + output_stride_dim0, output_stride_dim1, output_stride_dim2, + scales_colwise_stride_dim0, scales_colwise_stride_dim1, scales_colwise_stride_dim2, + get_input_dtype(input), get_output_dtype(fp8_format), + scale_dim_n, + get_scaling_mode(scaling_mode), + stream); +} + } // namespace mxfp8 diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp index 1f76788133..6119a4ce61 100644 --- a/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp @@ -18,6 +18,13 @@ void mxfp8_quantize_cuda(const torch::Tensor &input, const std::string &fp8_format, const std::string &scaling_mode); +void mxfp8_quantize_3d_cuda(const torch::Tensor &input, + torch::Tensor &output_colwise, + torch::Tensor &scales_colwise, + int64_t scale_dim_n, + const std::string &fp8_format, + const std::string &scaling_mode); + // Helper for tensor validation void check_cuda_tensor(const torch::Tensor &t, const char *name) { TORCH_CHECK(t.is_cuda(), name, " must be a CUDA tensor"); @@ -115,6 +122,60 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise, scales_colwise); } +// 3D tensor quantization function +std::tuple +mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n, + const std::string &fp8_format, + const std::string &scaling_mode) { + + // Validate inputs + TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor"); + // Note: We don't check contiguous for 3D as it may have column major strides + TORCH_CHECK(input.dim() == 3, "input must be 3D"); + TORCH_CHECK(input.scalar_type() == torch::kFloat32 || + input.scalar_type() == torch::kFloat16 || + input.scalar_type() == torch::kBFloat16, + "Input must be float32, float16, or bfloat16"); + TORCH_CHECK(scale_dim_n == 32, "scale_dim_n must be 32 for now"); + + validate_fp8_format(fp8_format); + + const int64_t E = input.size(0); + const int64_t N = input.size(1); + const int64_t K = input.size(2); + + // Check dimensions are valid for 3D kernel + TORCH_CHECK((N >= 32) && (N % 32 == 0), "N must be a multiple of 32"); + TORCH_CHECK((K >= 32) && (K % 32 == 0), "K must be a multiple of 32"); + + // The kernel should work with any stride pattern - no layout requirements + + c10::cuda::CUDAGuard device_guard(input.device()); + + // Create tensor options + const auto options_fp8 = torch::TensorOptions() + .dtype(torch::kFloat8_e4m3fn) + .device(input.device()); + + const auto options_scale = torch::TensorOptions() + .dtype(torch::kFloat8_e8m0fnu) + .device(input.device()); + + // Create output tensor with column major layout (required for downstream ops) + torch::Tensor output_colwise = torch::empty_strided( + {E, N, K}, {N * K, 1, N}, options_fp8); + + // Create scales tensor with shape (E, num_n_blocks, K) + const int64_t num_n_blocks = (N + scale_dim_n - 1) / scale_dim_n; + torch::Tensor scales_colwise = torch::empty({E, num_n_blocks, K}, options_scale); + + // Call CUDA kernel + mxfp8_quantize_3d_cuda(input, output_colwise, scales_colwise, + scale_dim_n, fp8_format, scaling_mode); + + return std::make_tuple(output_colwise, scales_colwise); +} + } // namespace mxfp8 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -125,4 +186,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("scale_dim_x") = 32, 
py::arg("scale_dim_y") = 32, py::arg("fp8_format") = "e4m3", py::arg("scaling_mode") = "floor"); + + m.def("quantize_3d", &mxfp8::mxfp8_quantize_3d, "MXFP8 3D quantization", + py::arg("input"), py::arg("scale_dim_n") = 32, + py::arg("fp8_format") = "e4m3", + py::arg("scaling_mode") = "floor"); } diff --git a/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh b/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh index 188ccd5203..50e7e88afa 100644 --- a/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh +++ b/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh @@ -22,6 +22,7 @@ #include #include + #define MIN_CUDA_SM 1000 // SM90 = 900, SM100 = 1000 // Check if we're compiling for supported architecture @@ -697,7 +698,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) for (int y = 0; y < MXFP8_SHMEM_DIM_Y; y++) { for (int x = 0; x < MXFP8_SHMEM_DIM_X; x++) { printf("in_sh[%d][%d][%d] = %f\n", b, y, x, - (float)in_sh[b][y][x]); + DataTypeTraits::to_float(in_sh[b][y][x])); } } } @@ -900,10 +901,244 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) // #endif } +// 3D MXFP8 quantization kernel using 2D TMA +template +__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + mxfp8_quantize_kernel_3d( + const CUtensorMap* tensor_maps_input, + const CUtensorMap* tensor_maps_output, + e8m0_t *const scales_colwise, + const size_t E, const size_t N, const size_t K, + const size_t scales_colwise_stride_dim0, + const size_t scales_colwise_stride_dim1, + const size_t scales_colwise_stride_dim2) { + + static_assert(DataTypeTraits::is_supported, + "Input data type is not supported by this kernel."); + + // Only support colwise scaling for 3D case + constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + static_assert(USE_COLWISE_SCALING, "3D kernel only supports colwise scaling"); + + constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = + MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 + constexpr size_t SCALES_COLWISE_PER_CHUNK_X = + MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = + SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 + constexpr size_t SCALES_COLWISE_PER_BLOCK_X = + SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 + + const int block_offset_Y = + blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; + const int block_offset_X = + blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; + const int scales_colwise_block_offset_Y = + blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; + const int scales_colwise_block_offset_X = + blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; + + const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; + const int expert_idx = blockIdx.z; + + // The destination shared memory buffer of a bulk tensor operation should be + // 128 e8m0_t aligned + __shared__ alignas(128) + IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; + __shared__ alignas(128) OType + out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_X][MXFP8_SHMEM_DIM_Y]; + + constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in +// the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + + initialize_barriers( + mbar, is_master_thread); + + int parity = 0; + +// Process chunks +#pragma unroll + // Calculate chunk offsets + for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { + const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; + const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; + + const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; + const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + + const int scales_colwise_chunk_offset_Y = + scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; + const int scales_colwise_chunk_offset_X = + scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + +// Prefetch initial data +#pragma unroll + // Kick off TMA async copy from global to shared memory + for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; + ++prefetch_buff) { + const int chunk_stage_offset_Y = + chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; + const int chunk_stage_offset_X = chunk_offset_X; + copy_2d_to_shared(&in_sh[prefetch_buff], + &tensor_maps_input[expert_idx], + chunk_stage_offset_X, + chunk_stage_offset_Y, + shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + +// Process iterations +#pragma unroll + // Iterate through the chunk along the Y dim + for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { + const int buff = iter % MXFP8_BUFFERS_NUM; + const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + + // Prefetch next iteration data + if (next_iter < MXFP8_ITERATIONS) { + const int next_buff = next_iter % MXFP8_BUFFERS_NUM; + const int chunk_it_offset_y = + chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; + const int chunk_it_offset_x = chunk_offset_X; + copy_2d_to_shared(&in_sh[next_buff], + &tensor_maps_input[expert_idx], + chunk_it_offset_x, + chunk_it_offset_y, + shmem_buff_size, + &mbar[next_iter], + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#if defined(DEBUG_SMEM) + // Debugging smem data + if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0) { + printf("Shared memory values for expert %d:\n", expert_idx); + for (int b = 0; b < MXFP8_BUFFERS_NUM; b++) { + for (int y = 0; y < MXFP8_SHMEM_DIM_Y; y++) { + for (int x = 0; x < MXFP8_SHMEM_DIM_X; x++) { + printf("in_sh[%d][%d][%d] = %f\n", b, y, x, + DataTypeTraits::to_float(in_sh[b][y][x])); + } + } + } + } +#endif + + // ======== 3d tensor column-wise scaling + + // Create bounds checker for this chunk + BoundsChecker bounds(N, K, chunk_offset_X, chunk_offset_Y); + + const size_t col = chunk_offset_X + tid_colwise_X; + const bool col_out_of_bounds = (col >= K); + + float in_compute[SCALE_DIM_Y]; + float amax = 0; + + // Calculate amax and prepare input values +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const bool out_of_bounds = + bounds.is_colwise_out_of_bounds(i, col, row_base); + + // Load and convert to float + float elt = + DataTypeTraits::to_float(in_sh[buff][i][tid_colwise_X]); + in_compute[i] = elt; + + // Update thread local amax + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } + + // Apply quantization to the local block. 
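+      // quantize_block computes a single shared e8m0 scale from the block amax
+      // and casts the SCALE_DIM_Y values in in_compute to the fp8 output type.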
+ e8m0_t e8m0_biased_scale; + OType quantized_values[SCALE_DIM_Y]; + quantize_block( + amax, e8m0_biased_scale, in_compute, quantized_values); + + // Write scaling factor to global memory + const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; + const int global_scales_offset_X = + scales_colwise_chunk_offset_X + tid_colwise_X; + + // Calculate scale offset using expert base offset plus local scale offset. + const int expert_scale_base_offset = expert_idx * scales_colwise_stride_dim0; + const int scale_idx = expert_scale_base_offset + + global_scales_offset_Y * scales_colwise_stride_dim1 + + global_scales_offset_X * scales_colwise_stride_dim2; + + // Bounds check for scale writing + const bool row_out_of_bounds = (row_base >= N); + if (!row_out_of_bounds && !col_out_of_bounds) { + scales_colwise[scale_idx] = e8m0_biased_scale; + } + + // Store quantized values to shared memory +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + out_colwise_sh[buff][tid_colwise_X][i] = quantized_values[i]; + } + +#if defined(DEBUG) + if (tid_colwise_X == 0) { + printf("Colwise: amax=%f, e8m0_scale=%u\n", amax, e8m0_biased_scale); + } +#endif + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + // Swap logical destination offsets for TMA to write into column major layout. + const int chunk_it_offset_y = chunk_offset_X; + const int chunk_it_offset_x = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + // TMA descriptor for this expert in the output tensor + reinterpret_cast(&tensor_maps_output[expert_idx]), + chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(&out_colwise_sh[buff])); + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
+ ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + } + + destroy_barriers(mbar, is_master_thread); + // #endif +} + // Simple wrapper class for MXFP8 quantization class MXFP8Quantizer { public: - // Quantize a tensor using MXFP8 + // Quantize a 2D tensor using MXFP8 // input: pointer to input data // output_rowwise: pointer to row-wise quantized output (can be nullptr) // output_colwise: pointer to column-wise quantized output (can be nullptr) @@ -1044,6 +1279,146 @@ public: #undef LAUNCH_KERNEL +#endif + } + + // Quantize a 3D tensor using MXFP8 with colwise scaling + // input: pointer to input data with shape (E, N, K) and strides (N*K, 1, N) (column major) + // output_colwise: pointer to column-wise quantized output with same layout + // scales_colwise: pointer to column-wise scaling factors with shape (E, num_n_blocks, K) + // E, N, K: tensor dimensions + // scales_colwise_stride_dim0: stride for E dimension in scales + // scales_colwise_stride_dim1: stride for num_n_blocks dimension in scales + // input_dtype: data type of input + // output_dtype: FP8 output type (fp8e4m3 or fp8e5m2) + // scale_dim_n: block size for column-wise scaling along N dimension (typically 32) + static void + quantize_3d(const void *input, void *output_colwise, e8m0_t *scales_colwise, + size_t E, size_t N, size_t K, + size_t input_stride_dim0, size_t input_stride_dim1, size_t input_stride_dim2, + size_t output_stride_dim0, size_t output_stride_dim1, size_t output_stride_dim2, + size_t scales_colwise_stride_dim0, size_t scales_colwise_stride_dim1, size_t scales_colwise_stride_dim2, + DType input_dtype, DType output_dtype, + size_t scale_dim_n = 32, + ScaleCalculationMode scaling_mode = ScaleCalculationMode::FLOOR, + cudaStream_t stream = 0) { + + // Check parameters + assert(scale_dim_n == 32); // Only support 32 for now + assert(output_colwise != nullptr); + assert(scales_colwise != nullptr); + + // Calculate grid dimensions for 3D tensor: Z handles E dimension, X,Y handle (N,K) + const size_t chunks_Y = DIVUP(N, MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(K, MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + const dim3 block(MXFP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y, E); // 3D grid: Z dimension handles experts + + // Create TMA descriptors for each expert + // Allocate GPU-accessible memory for TMA descriptors + CUtensorMap* tensor_maps_input = nullptr; + CUtensorMap* tensor_maps_output = nullptr; + + // Use cudaMallocManaged for GPU-accessible TMA descriptors + cudaError_t err1 = cudaMallocManaged(&tensor_maps_input, E * sizeof(CUtensorMap)); + cudaError_t err2 = cudaMallocManaged(&tensor_maps_output, E * sizeof(CUtensorMap)); + + if (err1 != cudaSuccess || err2 != cudaSuccess) { + printf("Failed to allocate managed memory for TMA descriptors\n"); + return; + } + + int32_t input_bits_per_elem = get_dtype_bits(input_dtype); + int32_t output_bits_per_elem = get_dtype_bits(output_dtype); + + for (int expert_idx = 0; expert_idx < E; ++expert_idx) { + // Calculate expert base addresses using actual tensor strides + const char* input_base = static_cast(input); + char* output_base = static_cast(output_colwise); + + // Use input_stride_dim0 to get correct byte offset for each expert + void* input_expert_base_addr = const_cast(input_base) + + expert_idx * input_stride_dim0 * (input_bits_per_elem / 8); + 
void* output_expert_base_addr = output_base + + expert_idx * output_stride_dim0 * (output_bits_per_elem / 8); + + // Input tensor map for reading from a specific expert, from input shape (E,N,K). + // For input stride pattern (input_stride_dim0, input_stride_dim1, input_stride_dim2) + // within each expert (N,K) slice, the stride for rows is input_stride_dim1 (elements) + create_2D_tensor_map( + tensor_maps_input[expert_idx], + input_expert_base_addr, + input_dtype, + N, K, + MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, + input_stride_dim1, // stride between rows within expert (N,K) slice + input_bits_per_elem); // bits per elem in input + + // Output tensor map: column major layout with dimensions swapped for TMA + // For output stride pattern (output_stride_dim0, output_stride_dim1, output_stride_dim2) + // within each expert (K,N) slice (swapped), use output_stride_dim2 for TMA stride + create_2D_tensor_map( + tensor_maps_output[expert_idx], + output_expert_base_addr, + output_dtype, + K, N, // Swap for column major layout + MXFP8_SHMEM_DIM_X, MXFP8_SHMEM_DIM_Y, + output_stride_dim2, // stride for swapped dimensions in column major + output_bits_per_elem); // bits per elem in output fp8e4m3 + } + +// Launch 3D kernel based on input/output types and scaling dimensions +// Only compile kernel launches for SM90+ +#if defined(__CUDACC__) && \ + (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= MIN_CUDA_SM) + +// Use TMA and mbarrier instructions for 3D +#define LAUNCH_KERNEL_3D(IType, OType, SCALE_Y, SCALE_X, ScalingMode) \ + mxfp8_quantize_kernel_3d \ + <<>>( \ + tensor_maps_input, tensor_maps_output, \ + scales_colwise, \ + E, N, K, \ + scales_colwise_stride_dim0, scales_colwise_stride_dim1, scales_colwise_stride_dim2); + + // Validate output dtype + if (output_dtype != DType::kFloat8E4M3) { + printf("unsupported output dtype, must be fp8e4m3\n"); + exit(1); + } + + if (scaling_mode == ScaleCalculationMode::FLOOR) { + if (input_dtype == DType::kFloat32) { + LAUNCH_KERNEL_3D(float, fp8e4m3, 32, 1, ScaleCalculationMode::FLOOR); + } else if (input_dtype == DType::kBFloat16) { + LAUNCH_KERNEL_3D(bfloat16, fp8e4m3, 32, 1, ScaleCalculationMode::FLOOR); + } else { + printf("unsupported input dtype, must be float32 or bfloat16\n"); + exit(1); + } + } else if (scaling_mode == ScaleCalculationMode::RCEIL) { + if (input_dtype == DType::kFloat32) { + LAUNCH_KERNEL_3D(float, fp8e4m3, 32, 1, ScaleCalculationMode::RCEIL); + } else if (input_dtype == DType::kBFloat16) { + LAUNCH_KERNEL_3D(bfloat16, fp8e4m3, 32, 1, ScaleCalculationMode::RCEIL); + } else { + printf("unsupported input dtype, must be float32 or bfloat16\n"); + exit(1); + } + } else { + printf("unsupported scaling mode\n"); + exit(1); + } + +#undef LAUNCH_KERNEL_3D + + // Clean up managed memory for TMA descriptors + cudaFree(tensor_maps_input); + cudaFree(tensor_maps_output); + #endif } }; diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index ae891c0dc9..24c1e6b60d 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -447,6 +447,7 @@ def backward(ctx, grad_out: torch.Tensor): def _to_mxfp8_dim1_3d( B: torch.Tensor, block_size: int = 32, + scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, ) -> tuple[torch.Tensor, torch.Tensor]: """ Convert a 3D tensor to MXFP8 format with (block_size, 1) scaling granularity. 
@@ -460,7 +461,7 @@ def _to_mxfp8_dim1_3d( hp_dtype=B_reshaped.dtype, gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, - scale_calculation_mode=ScaleCalculationMode.FLOOR, + scale_calculation_mode=scaling_mode, ) B_data = B_t_mx.qdata.t() # (K, E*N) -> (E*N, K) B_data = B_data.reshape(E, N, K) # (E*N, K) -> (E, N, K)
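
A minimal usage sketch of the new quantize_3d entry point, mirroring test_cuda_mx_dim1_3d_numerics above. It assumes an SM100+ GPU with the mxfp8_cuda extension built; the shapes and variable names below are illustrative, not part of this diff.

    import torch

    from torchao.prototype import mxfp8_cuda
    from torchao.prototype.mx_formats.mx_tensor import to_mx

    E, N, K, block_size = 2, 256, 512, 32  # N and K must be multiples of 32
    x = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

    # CUDA path: one e8m0 scale per (32, 1) block along N, returning column-major
    # fp8 data of shape (E, N, K) and scales of shape (E, N // 32, K).
    y, s = mxfp8_cuda.quantize_3d(
        x, scale_dim_n=block_size, fp8_format="e4m3", scaling_mode="floor"
    )

    # Reference path: make N the last dim, quantize with to_mx, transpose back.
    s_ref, y_ref = to_mx(
        x.transpose(-2, -1).contiguous(),
        elem_dtype=torch.float8_e4m3fn,
        block_size=block_size,
    )
    torch.testing.assert_close(y, y_ref.transpose(-2, -1), rtol=0, atol=0)
    torch.testing.assert_close(s, s_ref.transpose(-2, -1), rtol=0, atol=0)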