185 changes: 185 additions & 0 deletions benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py
@@ -0,0 +1,185 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# This benchmarking script is a modified version of the original script from:
# https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype import mxfp8_cuda
from torchao.prototype.moe_training.scaled_grouped_mm import (
_to_mxfp8_dim1_3d,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

device = torch.device("cuda")

# Needed since changing input shapes across runs causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
input_shape: tuple[int, int, int]


@dataclass(frozen=True)
class ExperimentResult:
# time
to_mx_us: float
cuda_2d_us: float
cuda_3d_us: float
# mem bw
to_mx_gbps: float
cuda_2d_gbps: float
cuda_3d_gbps: float


@dataclass(frozen=True)
class Experiment:
config: ExperimentConfig
result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
# Llama4 shapes (E, N, K). Inputs are quantized along the N dim (dim1).
input_shapes = [
(1, 8192, 5120),
(2, 8192, 5120),
(4, 8192, 5120),
(8, 8192, 5120),
(16, 8192, 5120),
(64, 8192, 5120),
]
configs = []
for shape in input_shapes:
configs.append(
ExperimentConfig(
input_shape=shape,
)
)
return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
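# MX block size: one shared e8m0 scale per 32 consecutive elements (OCP MX spec).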
block_size = 32
input_shape = config.input_shape
input_tensor = torch.randn(
*input_shape,
dtype=torch.bfloat16,
device=device,
)

def using_to_mx(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Reference implementation
s_d1_ref, y_d1_ref = to_mx(
# Transpose (E,N,K) to (E,K,N) so N is final dim,
# since to_mx scales along that dim
x.transpose(-2, -1).contiguous(),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)

# Transpose tensors and scales back so we have effectively
# quantized input shape (E, N, K) along N
y_d1_ref = y_d1_ref.transpose(-2, -1)
s_d1_ref = s_d1_ref.transpose(-2, -1)
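# Resulting shapes: y_d1_ref is (E, N, K) in fp8, s_d1_ref is
# (E, N // block_size, K) in e8m0 (one scale per block of 32 along N).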
return y_d1_ref, s_d1_ref

# bench to_mx
using_to_mx_c = torch.compile(using_to_mx)
data_to_mx, scales_to_mx = using_to_mx_c(input_tensor)
to_mx_time_us = benchmark_cuda_function_in_microseconds(
using_to_mx_c,
input_tensor,
)

# bench 2d dim1 kernel then transforming to col major
using_cuda_2d_c = torch.compile(_to_mxfp8_dim1_3d)
scales_cuda_2d, data_cuda_2d = using_cuda_2d_c(input_tensor)
time_cuda_2d_us = benchmark_cuda_function_in_microseconds(
using_cuda_2d_c,
input_tensor,
)

# bench 3d cuda kernel
data_cuda_3d, scales_cuda_3d = mxfp8_cuda.quantize_3d(input_tensor)
time_cuda_3d_us = benchmark_cuda_function_in_microseconds(
mxfp8_cuda.quantize_3d,
input_tensor,
)

# mem bw calculations
bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8
bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
bytes_per_scale_el = torch.finfo(torch.float8_e8m0fnu).bits / 8

read_bytes = input_tensor.numel() * bytes_per_input_el
write_bytes = (
data_cuda_3d.numel() * bytes_per_output_el
+ scales_cuda_3d.numel() * bytes_per_scale_el
)

to_mx_gbps = ((read_bytes + write_bytes) / 1e9) / (to_mx_time_us / 1e6)
cuda_2d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_2d_us / 1e6)
cuda_3d_gbps = ((read_bytes + write_bytes) / 1e9) / (time_cuda_3d_us / 1e6)
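# Worked example with a hypothetical timing: for input_shape (8, 8192, 5120),
# read = 8 * 8192 * 5120 * 2 B ~= 0.671 GB; write ~= 0.336 GB (fp8 data)
# + 0.010 GB (e8m0 scales); a 1000 us kernel would then be ~1.02 TB/s.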

return ExperimentResult(
# time
to_mx_us=to_mx_time_us,
cuda_2d_us=time_cuda_2d_us,
cuda_3d_us=time_cuda_3d_us,
# mem bw
to_mx_gbps=to_mx_gbps,
cuda_2d_gbps=cuda_2d_gbps,
cuda_3d_gbps=cuda_3d_gbps,
)


def print_results(experiments: List[Experiment]):
headers = [
"input_shape",
"to_mx_us",
"cuda_2d_us",
"cuda_3d_us",
"to_mx_gbps",
"cuda_2d_gbps",
"cuda_3d_gbps",
]
rows = []
for experiment in experiments:
rows.append(
[
str(experiment.config.input_shape),
experiment.result.to_mx_us,
experiment.result.cuda_2d_us,
experiment.result.cuda_3d_us,
round(experiment.result.to_mx_gbps, 3),
round(experiment.result.cuda_2d_gbps, 3),
round(experiment.result.cuda_3d_gbps, 3),
]
)
print(tabulate(rows, headers=headers))


def main():
torch.random.manual_seed(123)
configs = get_configs()
results = []
for config in tqdm(configs):
result = run_experiment(config)
results.append(Experiment(config=config, result=result))

# Use Tabulate to print results
print_results(results)


if __name__ == "__main__":
main()
56 changes: 54 additions & 2 deletions test/prototype/moe_training/test_kernels.py
@@ -12,7 +12,6 @@
if not (torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9):
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


from torchao.prototype.moe_training.kernels.float8_rowwise import (
triton_fp8_rowwise_3d_transpose_rhs,
triton_fp8_rowwise_3d_transpose_rhs_fused_reduction,
@@ -38,8 +37,11 @@
torch_to_float8_per_group_colwise,
torch_to_float8_per_group_rowwise,
)
from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_mx
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
is_sm_at_least_100,
)


@skip_if_rocm("ROCm enablement in progress")
@@ -313,3 +315,53 @@ def test_mxfp8_per_group_blocked_scales_2d2d(
output_group_offsets,
)
assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"


@pytest.mark.skipif(
not is_sm_at_least_100(),
reason="MXFP8 requires CUDA capability 10.0 or greater",
)
@pytest.mark.parametrize("E", (1, 2, 4, 8))
@pytest.mark.parametrize("N", (32, 64, 8192))
@pytest.mark.parametrize("K", (32, 64, 8192))
@pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,))
def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
from torchao.prototype import mxfp8_cuda

scaling_mode_str = (
"floor" if scaling_mode == ScaleCalculationMode.FLOOR else "rceil"
)
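# Assumed semantics: FLOOR truncates the shared block exponent toward -inf
# (per the OCP MX spec), while RCEIL rounds the scale up so scaled values
# cannot overflow the fp8 range.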
block_size = 32

# Use distinct incrementing values from 0 to E*N*K-1 to make debugging easier.
x = (
torch.arange(0, E * N * K, dtype=input_dtype, device="cuda")
.reshape(E, N, K)
.contiguous()
)

# Reference implementation
s_d1_ref, y_d1_ref = to_mx(
# Transpose so N is final dim, since to_mx scales along that dim
x.transpose(-2, -1).contiguous(),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)

# Transpose tensors and scales back so we have effectively
# quantized input shape (E, N, K) along N
y_d1_ref = y_d1_ref.transpose(-2, -1)
s_d1_ref = s_d1_ref.transpose(-2, -1)

# CUDA implementation (should work with any stride pattern)
y_d1, s_d1 = mxfp8_cuda.quantize_3d(
x, scale_dim_n=block_size, scaling_mode=scaling_mode_str
)

# Check scales
torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)

# Check quantized values
torch.testing.assert_close(y_d1, y_d1_ref, rtol=0, atol=0)
assert y_d1.stride() == y_d1_ref.stride(), "quantized tensor strides do not match"
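# A round-trip dequantization sanity check could be added here (a sketch;
# assumes e8m0 scales convert to their power-of-two float values via
# .to(torch.float32)):
#   x_hat = y_d1.to(torch.float32) * s_d1.to(torch.float32).repeat_interleave(block_size, dim=1)
#   torch.testing.assert_close(x_hat, x.float(), rtol=..., atol=...)  # tolerances sized for fp8 block quantization error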
68 changes: 68 additions & 0 deletions torchao/csrc/cuda/mx_kernels/mxfp8_cuda.cu
@@ -109,4 +109,72 @@ void mxfp8_quantize_cuda(const torch::Tensor &input,
stream);
}

void mxfp8_quantize_3d_cuda(const torch::Tensor &input,
torch::Tensor &output_colwise,
torch::Tensor &scales_colwise,
int64_t scale_dim_n,
const std::string &fp8_format,
const std::string &scaling_mode) {

// Get tensor properties for 3D tensor (E, N, K)
const int64_t E = input.size(0);
const int64_t N = input.size(1);
const int64_t K = input.size(2);

// Get data pointers
const void *input_ptr = input.data_ptr();
void *output_colwise_ptr = output_colwise.data_ptr();
e8m0_t *scales_colwise_ptr =
reinterpret_cast<e8m0_t *>(scales_colwise.data_ptr());

// Get CUDA stream
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Get strides of scales tensor
int64_t scales_colwise_stride_dim0 = scales_colwise.stride(0);
int64_t scales_colwise_stride_dim1 = scales_colwise.stride(1);
int64_t scales_colwise_stride_dim2 = scales_colwise.stride(2);

// Get input tensor strides for generic layout support
int64_t input_stride_dim0 = input.stride(0); // E dimension stride
int64_t input_stride_dim1 = input.stride(1); // N dimension stride
int64_t input_stride_dim2 = input.stride(2); // K dimension stride

// Get output tensor strides (should be column-major)
int64_t output_stride_dim0 = output_colwise.stride(0); // E dimension stride
int64_t output_stride_dim1 = output_colwise.stride(1); // N dimension stride
int64_t output_stride_dim2 = output_colwise.stride(2); // K dimension stride
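// Illustrative example: a contiguous (E, N, K) input has strides
// (N*K, K, 1); a column-major-per-expert output has strides (N*K, 1, N).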


#if defined(DEBUG)
printf("mxfp8_quantize_3d_cuda:\n");
printf("Quantizing 3D input tensor of size %ld x %ld x %ld\n", E, N, K);
printf("scaling_mode: %s\n", scaling_mode.c_str());
printf("Scale dim n: %ld\n", scale_dim_n);
printf("Output scale shape: %ld x %ld x %ld\n",
scales_colwise.sizes()[0], scales_colwise.sizes()[1], scales_colwise.sizes()[2]);
printf("scales_colwise_stride_dim0 = %ld\n", scales_colwise_stride_dim0);
printf("scales_colwise_stride_dim1 = %ld\n", scales_colwise_stride_dim1);
printf("input_stride_dim0 = %ld\n", input_stride_dim0);
printf("input_stride_dim1 = %ld\n", input_stride_dim1);
printf("input_stride_dim2 = %ld\n", input_stride_dim2);
printf("output_stride_dim0 = %ld\n", output_stride_dim0);
printf("output_stride_dim1 = %ld\n", output_stride_dim1);
printf("output_stride_dim2 = %ld\n", output_stride_dim2);
#endif

// Call the 3D quantization kernel
MXFP8Quantizer::quantize_3d(input_ptr,
output_colwise_ptr,
scales_colwise_ptr,
E, N, K,
input_stride_dim0, input_stride_dim1, input_stride_dim2,
output_stride_dim0, output_stride_dim1, output_stride_dim2,
scales_colwise_stride_dim0, scales_colwise_stride_dim1, scales_colwise_stride_dim2,
get_input_dtype(input), get_output_dtype(fp8_format),
scale_dim_n,
get_scaling_mode(scaling_mode),
stream);
}

} // namespace mxfp8