
Commit faaf1ac

make fp8 blockwise linear differentiable; use new kernels

1 parent ee6ce03 · commit faaf1ac

4 files changed: +132 −33 lines
Lines changed: 64 additions & 0 deletions (new test file for BlockwiseQuantLinear)
@@ -0,0 +1,64 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import pytest
+import torch
+
+from torchao.prototype.blockwise_fp8.blockwise_linear import BlockwiseQuantLinear
+from torchao.float8.float8_utils import compute_error
+
+triton = pytest.importorskip("triton", reason="Triton required to run this test")
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("in_features", [1024])
+@pytest.mark.parametrize("out_features", [1024])
+@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("block_size", [128])
+def test_blockwise_quant_linear_fwd_bwd(
+    in_features, out_features, batch_size, block_size,
+):
+    if in_features % block_size != 0 or out_features % block_size != 0:
+        pytest.skip(f"Dimensions must be divisible by block_size={block_size}")
+
+    torch.random.manual_seed(0)
+    layer_test = BlockwiseQuantLinear(
+        in_features=in_features,
+        out_features=out_features,
+        block_size=block_size,
+    ).cuda()
+
+    torch.random.manual_seed(0)
+    layer_ref = torch.nn.Linear(
+        in_features=in_features,
+        out_features=out_features,
+    ).cuda()
+
+    # Create input tensors (x_test requires grad so the commented-out backward check can use x_test.grad)
+    x_test = torch.randn(batch_size, in_features).cuda().requires_grad_(True)
+    x_ref = x_test.clone().detach().requires_grad_(True)
+
+    # Forward pass
+    y_test = layer_test(x_test)
+    y_ref = layer_ref(x_ref)
+
+    # Compare outputs
+    sqnr = compute_error(y_test, y_ref)
+    assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0"
+
+    # # Backward pass
+    # y_test.sum().backward()
+    # y_ref.sum().backward()
+
+    # # Compare input grads
+    # sqnr = compute_error(x_test.grad, x_ref.grad)
+    # assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0"
+
+    # # Compare weight grads
+    # sqnr = compute_error(layer_test.weight.grad, layer_ref.weight.grad)
+    # assert sqnr >= 25.0, f"SQNR: {sqnr} must be >= 25.0"
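For reference, the 25.0 threshold above is a signal-to-quantization-noise ratio in decibels. A minimal sketch of the metric, assuming compute_error reports SQNR roughly as below (sqnr_db is an illustrative helper, not the torchao API):

import torch

def sqnr_db(ref: torch.Tensor, test: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means test is closer to ref.
    signal = torch.linalg.norm(ref.float())
    noise = torch.linalg.norm((ref - test).float())
    return 20 * torch.log10(signal / noise)

At 25 dB the error norm is roughly 18x smaller than the reference norm, which is the accuracy bar this test uses.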

torchao/prototype/blockwise_fp8/blockwise_linear.py

Lines changed: 61 additions & 33 deletions
@@ -9,11 +9,12 @@

 from torchao.prototype.blockwise_fp8.kernels import (
     blockwise_fp8_gemm,
-    fp8_blockwise_act_quant,
+    torch_blockwise_scale_act_quant,
+    triton_quantize_fp8_block,
 )


-class BlockwiseQuantLinear(nn.Module):
+class BlockwiseQuantLinear(nn.Linear):
     """
     Custom linear layer with support for quantized weights and optional bias.

@@ -24,54 +25,81 @@ class BlockwiseQuantLinear(nn.Module):
         block_size (int): Block size for quantization. Defaults to 128.
         dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn.
     """
-
-    dtype = torch.bfloat16
+    supported_dtypes = [
+        torch.bfloat16,
+    ]

     def __init__(
         self,
-        in_features: int,
-        out_features: int,
-        bias: bool = False,
+        *args,
         block_size: int = 128,
-        dtype: torch.dtype = torch.float8_e4m3fn,
+        dtype = torch.bfloat16,
+        **kwargs,
     ):
-        super().__init__()
-        supported_dtypes = [
-            torch.float8_e4m3fn,
-            torch.float8_e5m2,
-        ]
-        assert dtype in supported_dtypes, (
-            f"Unsupported dtype: {dtype}. Supported dtypes: {supported_dtypes}"
-        )
-        scale_in_features = (in_features + block_size - 1) // block_size
-        scale_out_features = (out_features + block_size - 1) // block_size
-        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
-        self.weight.scale = self.scale = nn.Parameter(
-            torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+        super().__init__(*args, **kwargs)
+
+        assert dtype in self.supported_dtypes, (
+            f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}"
         )
         self.block_size = block_size
-        self.dtype
-
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            self.register_parameter("bias", None)
+        self.dtype = dtype

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the custom linear layer.

         Args:
-            x (torch.Tensor): Input tensor.
+            x (torch.Tensor): input tensor.

         Returns:
             torch.Tensor: Transformed tensor after linear computation.
         """
-        x, scale = fp8_blockwise_act_quant(x, self.block_size, self.dtype)
+        return fp8_blockwise_mm.apply(x, self.weight, self.block_size)
+
+
+class fp8_blockwise_mm(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, block_size):
+        # torch.compile currently has the fastest activation quantization (1 x block_size)
+        x_fp8, x_scale = torch_blockwise_scale_act_quant(x, tile_size=block_size)
+
+        # fbgemm currently has the fastest weight quantization (block_size x block_size)
+        weight_fp8, weight_scale = triton_quantize_fp8_block(weight, block_m=block_size, block_k=block_size)
+
         y = blockwise_fp8_gemm(
-            x, scale, self.weight, self.weight.scale, self.block_size
+            x_fp8, x_scale,
+            weight_fp8, weight_scale,
+            block_size,
         )
-
-        if self.bias is not None:
-            y += self.bias
+        ctx.save_for_backward(x_fp8, x_scale, weight_fp8, weight_scale)
+        ctx.block_size = block_size
         return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x_fp8, x_scale, weight_fp8, weight_scale = ctx.saved_tensors
+        block_size = ctx.block_size
+
+        grad_output_fp8, grad_output_scale = torch_blockwise_scale_act_quant(
+            grad_output, block_size,
+        )
+
+        grad_output_t_fp8, grad_output_t_scale = torch_blockwise_scale_act_quant(
+            grad_output.t(), block_size,
+        )
+
+        # grad_x = grad_output @ weight.T
+        grad_x = blockwise_fp8_gemm(
+            grad_output_fp8, grad_output_scale,
+            weight_fp8.t(), weight_scale.t(),
+            block_size,
+        )
+
+        # grad_weight = grad_output.T @ x
+        grad_weight = blockwise_fp8_gemm(
+            grad_output_t_fp8, grad_output_t_scale,
+            x_fp8, x_scale,
+            block_size,
+        )
+
+        return grad_x, grad_weight, None, None
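For context, a minimal usage sketch of the layer after this change, assuming a CUDA device with Triton and the quantization kernels above available, and dimensions divisible by block_size (this mirrors the new test rather than documenting a stable API):

import torch

from torchao.prototype.blockwise_fp8.blockwise_linear import BlockwiseQuantLinear

# 1024 x 1024 layer with the default 128 x 128 quantization blocks.
# bias=False because the new forward only routes x and the weight through the fp8 GEMM.
layer = BlockwiseQuantLinear(1024, 1024, bias=False, block_size=128).cuda()

x = torch.randn(4, 1024, device="cuda", requires_grad=True)
y = layer(x)          # forward: quantize activations and weight, run the blockwise fp8 GEMM
y.sum().backward()    # backward: two fp8 GEMMs produce grad_x and grad_weight

Routing the matmul through the fp8_blockwise_mm autograd.Function is what makes the layer differentiable: gradients now flow to both the input and the weight instead of stopping at the quantized forward.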

torchao/prototype/blockwise_fp8/kernels.py

Lines changed: 7 additions & 0 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

+import sys
 import math
 from typing import Optional, Tuple

@@ -12,6 +13,12 @@
 import triton.language as tl
 from triton import Config

+# try:
+#     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_block
+# except ImportError:
+#     print("Please install fbgemm-gpu to use this feature")
+#     sys.exit(1)
+
 # Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py

 fp8_gemm_configs = [
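As background for the "1 x block_size" activation quantization referenced in the new forward and backward, here is a rough illustrative sketch of the idea, assuming per-tile amax scaling to torch.float8_e4m3fn; the scale layout, rounding, and performance of the actual torch.compile and fbgemm Triton kernels may differ:

import torch

def naive_act_block_quant(x: torch.Tensor, tile_size: int = 128):
    # One scale per (1 x tile_size) tile along the last dimension.
    M, K = x.shape
    assert K % tile_size == 0
    tiles = x.reshape(M, K // tile_size, tile_size)
    amax = tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max   # 448.0 for e4m3fn
    scale = amax / fp8_max                           # per-tile dequantization scale
    x_fp8 = (tiles / scale).to(torch.float8_e4m3fn).reshape(M, K)
    return x_fp8, scale.squeeze(-1)                  # scales: (M, K // tile_size)

Weight quantization (triton_quantize_fp8_block, per the commented-out fbgemm import above) follows the same idea with square block_size x block_size tiles, so each weight block carries its own scale.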
