 from torchao.prototype.blockwise_fp8.kernels import (
     blockwise_fp8_gemm,
-    fp8_blockwise_act_quant,
+    torch_blockwise_scale_act_quant,
+    triton_quantize_fp8_block,
 )
 
 
-class BlockwiseQuantLinear(nn.Module):
+class BlockwiseQuantLinear(nn.Linear):
     """
     Custom linear layer with support for quantized weights and optional bias.
@@ -24,54 +25,81 @@ class BlockwiseQuantLinear(nn.Module):
         block_size (int): Block size for quantization. Defaults to 128.
         dtype (torch.dtype): Data type for the weights. Defaults to torch.float8_e4m3fn.
     """
-
-    dtype = torch.bfloat16
+    supported_dtypes = [
+        torch.bfloat16,
+    ]
 
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
-        bias: bool = False,
+        *args,
         block_size: int = 128,
-        dtype: torch.dtype = torch.float8_e4m3fn,
+        dtype=torch.bfloat16,
+        **kwargs,
     ):
-        super().__init__()
-        supported_dtypes = [
-            torch.float8_e4m3fn,
-            torch.float8_e5m2,
-        ]
-        assert dtype in supported_dtypes, (
-            f"Unsupported dtype: {dtype}. Supported dtypes: {supported_dtypes}"
-        )
-        scale_in_features = (in_features + block_size - 1) // block_size
-        scale_out_features = (out_features + block_size - 1) // block_size
-        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
-        self.weight.scale = self.scale = nn.Parameter(
-            torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+        super().__init__(*args, **kwargs)
+
+        assert dtype in self.supported_dtypes, (
+            f"Unsupported dtype: {dtype}. Supported dtypes: {self.supported_dtypes}"
         )
         self.block_size = block_size
-        self.dtype
-
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features))
-        else:
-            self.register_parameter("bias", None)
+        self.dtype = dtype
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the custom linear layer.
 
         Args:
-            x (torch.Tensor): Input tensor.
+            x (torch.Tensor): input tensor.
 
         Returns:
             torch.Tensor: Transformed tensor after linear computation.
         """
-        x, scale = fp8_blockwise_act_quant(x, self.block_size, self.dtype)
+        return fp8_blockwise_mm.apply(x, self.weight, self.block_size)
+
+
+class fp8_blockwise_mm(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, weight, block_size):
+        # torch.compile currently has the fastest activation quantization (1 x block_size)
+        x_fp8, x_scale = torch_blockwise_scale_act_quant(x, tile_size=block_size)
+
+        # fbgemm currently has the fastest weight quantization (block_size x block_size)
+        weight_fp8, weight_scale = triton_quantize_fp8_block(weight, block_m=block_size, block_k=block_size)
+
         y = blockwise_fp8_gemm(
-            x, scale, self.weight, self.weight.scale, self.block_size
+            x_fp8, x_scale,
+            weight_fp8, weight_scale,
+            block_size,
         )
-
-        if self.bias is not None:
-            y += self.bias
+        ctx.save_for_backward(x_fp8, x_scale, weight_fp8, weight_scale)
+        ctx.block_size = block_size
         return y
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        x_fp8, x_scale, weight_fp8, weight_scale = ctx.saved_tensors
+        block_size = ctx.block_size
+
+        grad_output_fp8, grad_output_scale = torch_blockwise_scale_act_quant(
+            grad_output, block_size,
+        )
+
+        grad_output_t_fp8, grad_output_t_scale = torch_blockwise_scale_act_quant(
+            grad_output.t(), block_size,
+        )
+
+        # grad_x = grad_output @ weight.T
+        grad_x = blockwise_fp8_gemm(
+            grad_output_fp8, grad_output_scale,
+            weight_fp8.t(), weight_scale.t(),
+            block_size,
+        )
+
+        # grad_weight = grad_output.T @ x
+        grad_weight = blockwise_fp8_gemm(
+            grad_output_t_fp8, grad_output_t_scale,
+            x_fp8, x_scale,
+            block_size,
+        )
+
+        return grad_x, grad_weight, None, None
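
For reference, a minimal usage sketch of the layer as changed above. Everything outside the diff is assumed for illustration: the sizes, the CUDA device with FP8 support, and the bf16 inputs are not part of this change.

    import torch

    # Hypothetical sizes; block_size=128 matches the default above.
    linear = BlockwiseQuantLinear(1024, 512, bias=False, block_size=128, dtype=torch.bfloat16).cuda()
    x = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    y = linear(x)        # activations quantized per (1 x 128) tile, weights per (128 x 128) block
    y.sum().backward()   # grad_x and grad_weight computed with blockwise FP8 GEMMs

Note that fp8_blockwise_mm saves the already-quantized x_fp8/weight_fp8 tensors for backward, so x and weight are not re-quantized when computing gradients.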
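The two quantization granularities above differ only in tile shape: each activation row gets one scale per (1 x block_size) tile, while weights get one scale per (block_size x block_size) block. The sketch below illustrates the activation-side tiling in plain PyTorch; it is illustrative only, and the real torch_blockwise_scale_act_quant kernel in torchao may differ in scale layout (e.g. storing reciprocal scales) and padding behavior.

    import torch

    def tile_quant_fp8_sketch(x: torch.Tensor, tile_size: int = 128):
        # One float32 scale per (1 x tile_size) tile along the last dimension,
        # chosen so each tile's absolute max maps to the FP8 e4m3 max value.
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        rows, cols = x.shape
        assert cols % tile_size == 0, "a real kernel would handle ragged tiles"
        tiles = x.float().reshape(rows, cols // tile_size, tile_size)
        amax = tiles.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
        scale = amax / fp8_max
        x_fp8 = (tiles / scale).to(torch.float8_e4m3fn).reshape(rows, cols)
        return x_fp8, scale.squeeze(-1)  # scales: (rows, cols // tile_size)

    x = torch.randn(4, 256, dtype=torch.bfloat16)
    x_fp8, scales = tile_quant_fp8_sketch(x)
    # Dequantize to check the round trip: multiply each tile by its scale.
    x_approx = (x_fp8.float().reshape(4, 2, 128) * scales.unsqueeze(-1)).reshape(4, 256)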