make fp8 blockwise linear differentiable; use new kernels #2602

Closed
2 changes: 1 addition & 1 deletion benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -13,7 +13,7 @@
from triton.testing import do_bench

from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
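For context, a minimal sketch of how the kernels imported above fit together. This is a hypothetical composition, not code from this PR: only the quantization helpers' keyword arguments appear elsewhere in the diff, and the argument order of blockwise_fp8_gemm is an assumption.

# Hypothetical composition sketch, not part of this PR. Assumes a CUDA device and
# assumes blockwise_fp8_gemm takes (activation, act_scales, weight, weight_scales).
import torch

from torchao.prototype.blockwise_fp8.kernels import (
    blockwise_fp8_gemm,
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
)

block_size = 128
x = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
w = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# Quantize activations in (1 x block_size) groups and weights in
# (block_size x block_size) blocks, then run the blockwise-scaled GEMM.
x_fp8, x_scales = fp8_blockwise_act_quant(x, block_size=block_size)
w_fp8, w_scales = fp8_blockwise_weight_quant(w, block_size=block_size)
y = blockwise_fp8_gemm(x_fp8, x_scales, w_fp8, w_scales)  # assumed signature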
169 changes: 169 additions & 0 deletions benchmarks/float8/bench_fp8_blockwise_quant_kernels.py
@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
import argparse
import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from utils import benchmark_microseconds

from torchao.prototype.blockwise_fp8.kernels import (
    fp8_blockwise_act_quant,
    fp8_blockwise_weight_quant,
    torch_blockwise_scale_act_quant,
    torch_blockwise_scale_weight_quant,
    triton_quantize_fp8_block,
)

device = torch.device("cuda")

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    A_shape: tuple[int]
    block_m: int
    block_k: int


@dataclass(frozen=True)
class ExperimentResult:
    torch_us: float
    fbgemm_us: float
    deepgemm_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    A_shapes = [
        (1024, 1024),
        (2048, 2048),
        (4096, 4096),
        (8192, 8192),
        (16384, 16384),
        (32768, 32768),
    ]
    block_m_opts = [1, 128]
    block_k_opts = [
        128,
    ]
    configs = []
    for A_shape, block_m, block_k in itertools.product(
        A_shapes,
        block_m_opts,
        block_k_opts,
    ):
        configs.append(
            ExperimentConfig(
                A_shape=A_shape,
                block_m=block_m,
                block_k=block_k,
            )
        )
    return configs


def run_experiment(
    config: ExperimentConfig, args: argparse.Namespace
) -> ExperimentResult:
    A = torch.randn(
        *config.A_shape,
        dtype=torch.bfloat16,
        device=device,
    )

    # Torch and DeepGEMM implementations are specific to activation quantization (1 x block_size)
    # and weight quantization (block_size x block_size)
    if config.block_m == 1:
        torch_func = torch.compile(torch_blockwise_scale_act_quant)
        deepgemm_func = fp8_blockwise_act_quant
    else:
        torch_func = torch.compile(torch_blockwise_scale_weight_quant)
        deepgemm_func = fp8_blockwise_weight_quant

    # Validate output shapes and strides
    torch_out, torch_scale = torch_func(A, tile_size=config.block_k)
    deepgemm_out, deepgemm_scale = deepgemm_func(A, block_size=config.block_k)
    fbgemm_out, fbgemm_scale = triton_quantize_fp8_block(
        A, block_m=config.block_m, block_k=config.block_k, k_major=True
    )
    assert torch_out.shape == deepgemm_out.shape == fbgemm_out.shape
    assert torch_out.stride() == deepgemm_out.stride() == fbgemm_out.stride()
    assert torch_scale.shape == deepgemm_scale.shape == fbgemm_scale.shape
    assert torch_scale.stride() == deepgemm_scale.stride() == fbgemm_scale.stride()

    # Do benchmarking
    torch_us = benchmark_microseconds(torch_func, A, tile_size=config.block_k)
    # Time the DeepGEMM-style kernel selected above (act or weight quant)
    deepgemm_us = benchmark_microseconds(
        deepgemm_func, A, block_size=config.block_k
    )
    fbgemm_us = benchmark_microseconds(
        triton_quantize_fp8_block,
        A,
        block_m=config.block_m,
        block_k=config.block_k,
        k_major=True,
    )

    return ExperimentResult(
        torch_us=round(torch_us, 3),
        fbgemm_us=round(fbgemm_us, 3),
        deepgemm_us=round(deepgemm_us, 3),
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "A_shape",
        "block_shape",
        "torch_us",
        "fbgemm_us",
        "deepgemm_us",
    ]
    rows = []
    for experiment in experiments:
        A_shape = f"({experiment.config.A_shape[0]}, {experiment.config.A_shape[1]})"
        block_shape = f"({experiment.config.block_m},{experiment.config.block_k})"
        rows.append(
            [
                A_shape,
                block_shape,
                experiment.result.torch_us,
                experiment.result.fbgemm_us,
                experiment.result.deepgemm_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def main(args: argparse.Namespace):
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config, args)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--compile", action="store_true")
    args = arg_parser.parse_args()
    main(args)
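As a quick sanity check on the sweep defined above, using only what the script itself defines, get_configs() enumerates 6 A_shapes x 2 block_m options x 1 block_k option. For example, appended at the bottom of this script:

# Sanity check on the sweep size, using get_configs() from this script.
configs = get_configs()
assert len(configs) == 12  # 6 A_shapes * 2 block_m options * 1 block_k option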
10 changes: 10 additions & 0 deletions benchmarks/float8/utils.py
@@ -11,6 +11,7 @@

import torch.utils.benchmark as benchmark
from torch.profiler import ProfilerActivity, profile
from triton.testing import do_bench


def profiler_output_to_filtered_time_by_kernel_name(
@@ -428,3 +429,12 @@ def do_benchmarks(
    tops_sec = float(tops) / time_sec
    pct_top_peak = tops_sec / peak_tops
    return time_sec, tops_sec, pct_top_peak


def benchmark_microseconds(f, *args, warmup=25, rep=100, **kwargs):
    return (
        do_bench(
            lambda: f(*args, **kwargs), warmup=warmup, rep=rep, return_mode="median"
        )
        * 1e3
    )
65 changes: 65 additions & 0 deletions test/prototype/blockwise_fp8/test_blockwise_linear.py
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8.blockwise_linear import Float8BlockwiseLinear

triton = pytest.importorskip("triton", reason="Triton required to run this test")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("in_features", [1024])
@pytest.mark.parametrize("out_features", [1024])
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize("block_size", [128])
def test_blockwise_quant_linear_fwd_bwd(
    in_features,
    out_features,
    batch_size,
    block_size,
):
    if in_features % block_size != 0 or out_features % block_size != 0:
        pytest.skip(f"Dimensions must be divisible by block_size={block_size}")

    torch.random.manual_seed(0)
    layer_test = Float8BlockwiseLinear(
        in_features=in_features,
        out_features=out_features,
        block_size=block_size,
    ).cuda()

    torch.random.manual_seed(0)
    layer_ref = torch.nn.Linear(
        in_features=in_features,
        out_features=out_features,
    ).cuda()

    # Create input tensor
    x_test = torch.randn(batch_size, in_features).cuda().requires_grad_(True)
    x_ref = x_test.clone().detach().requires_grad_(True)

    # Forward pass
    y_test = layer_test(x_test)
    y_ref = layer_ref(x_ref)

    # Compare outputs
    sqnr = compute_error(y_ref, y_test)
    assert sqnr >= 25.0, f"SQNR: {sqnr.item()} must be >= 25.0"

    # Backward pass
    y_test.sum().backward()
    y_ref.sum().backward()

    # Compare input grads
    sqnr = compute_error(x_ref.grad, x_test.grad)
    assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0"

    # Compare weight grads
    sqnr = compute_error(layer_ref.weight.grad, layer_test.weight.grad)
    assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0"
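For reference, a minimal forward/backward sketch of the module exercised by this test. It mirrors the constructor arguments used above and is illustrative only, not an example shipped with the PR.

# Minimal usage sketch of the prototype module tested above (assumes CUDA).
import torch

from torchao.prototype.blockwise_fp8.blockwise_linear import Float8BlockwiseLinear

layer = Float8BlockwiseLinear(in_features=1024, out_features=1024, block_size=128).cuda()
x = torch.randn(16, 1024, device="cuda", requires_grad=True)

y = layer(x)        # forward runs the blockwise fp8 path
y.sum().backward()  # differentiable: gradients reach both x and layer.weight
print(x.grad.shape, layer.weight.grad.shape)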