Commit 864df00

add fp4
Signed-off-by: Kyle Sayers <[email protected]>
1 parent dc43b64 commit 864df00

File tree

3 files changed: +63 -27 lines changed


src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ def initialize_qparams(

         group_size = quantization_args.group_size
         num_groups = strategy_cdiv(observed_shape[-1], group_size, strategy)
-        expected_shape = (*observed_shape[:-1], num_groups)
+        expected_shape = (observed_shape[-2], num_groups)

         # initialize activation ordering if applicable
         if actorder == ActivationOrdering.GROUP:
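For context, here is a minimal sketch (with assumed sizes, not taken from the repo) of the qparam shape this change yields for a group-quantized 2-D weight; `math.ceil` stands in for `strategy_cdiv`:

import math

# hypothetical weight observed as (num_rows, num_cols); sizes are illustrative
observed_shape = (4096, 11008)
group_size = 128

# plain ceil-division stand-in for strategy_cdiv in this sketch
num_groups = math.ceil(observed_shape[-1] / group_size)

# scales/zero-points are laid out per (row, group), i.e. a fixed 2-D shape,
# rather than carrying every leading dim via (*observed_shape[:-1], num_groups)
expected_shape = (observed_shape[-2], num_groups)
print(expected_shape)  # (4096, 86)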

tests/mock_observer.py

Lines changed: 29 additions & 25 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Tuple
+from typing import Optional, Tuple
 from weakref import ref

 import torch
@@ -42,7 +42,7 @@ def get_min_max(self, observed: torch.Tensor):
         return min_vals, max_vals

     def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        observed = flatten_for_quantization(observed, self.base_name, self.args)
+        observed = flatten_for_calibration(observed, self.base_name, self.args)

         self.min_vals, self.max_vals = self.get_min_max(observed)

@@ -57,26 +57,31 @@ def forward(self, observed: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

     def get_global_scale(self, observed: torch.Tensor):
         observed = observed.reshape((1, 1, -1))  # per tensor reshape
-        min_vals, max_vals = self.get_min_max(observed)
-        global_scale = generate_gparam(min_vals, max_vals)
+        self.min_vals, self.max_vals = self.get_min_max(observed)
+        global_scale = generate_gparam(self.min_vals, self.max_vals)

         return global_scale


-def flatten_for_quantization(
-    value: torch.Tensor, base_name: str, args: QuantizationArgs
+def flatten_for_calibration(
+    value: torch.Tensor,
+    base_name: str,
+    args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if base_name == "weight":
-        return flatten_weight_for_quantization(value, args)
+        return _flatten_weight(value, args, g_idx)
     elif base_name in ("input", "output"):
-        return flatten_activation_for_quantization(value, args)
+        return _flatten_activation(value, args)
     elif base_name in ("q", "k", "v"):
-        return flatten_attention_for_quantization(value, args)
+        return _flatten_attention(value, args)
     else:
         raise ValueError(f"Unknown quantization base name: {base_name}")


-def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_weight(
+    value: torch.Tensor, args: QuantizationArgs, g_idx: Optional[torch.Tensor] = None
+):
     # value.shape = (num_rows, num_cols)

     if args.strategy == QuantizationStrategy.TENSOR:
@@ -91,34 +96,32 @@ def flatten_weight_for_quantization(value: torch.Tensor, args: QuantizationArgs):
         return value.unsqueeze(-2).unsqueeze(0)

     if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
+        if g_idx is not None:
+            value = value.index_select(dim=1, index=torch.argsort(g_idx))
+
         # (1, num_rows, num_groups, group_size)
         return value.unflatten(-1, (-1, args.group_size)).unsqueeze(0)

     if args.strategy == QuantizationStrategy.BLOCK:
         # (1, num_block_rows, num_block_cols, block_width * block_height)
         block_height, block_width = args.block_structure
-        num_rows, num_cols = value.shape
-        num_block_rows = strategy_cdiv(num_rows, block_height, args.strategy)
-        num_block_cols = strategy_cdiv(num_cols, block_width, args.strategy)
+        rows, cols = value.shape
+        block_rows = strategy_cdiv(rows, block_height, args.strategy, strict=True)
+        block_cols = strategy_cdiv(cols, block_width, args.strategy, strict=True)
         return (
-            value.reshape(
-                num_block_rows,
-                block_height,
-                num_block_cols,
-                block_width,
-            )
+            value.reshape(block_rows, block_height, block_cols, block_width)
             .transpose(1, 2)
             .flatten(-2, -1)
             .unsqueeze(0)
         )

     if args.strategy == QuantizationStrategy.ATTN_HEAD:
-        raise ValueError("attention head quantization cannot be applied to weights")
+        raise ValueError("Attention head quantization cannot be applied to weights")

     assert False, f"Unknown strategy {args.strategy}"


-def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_activation(value: torch.Tensor, args: QuantizationArgs):
     # value.shape = (batch_size, seq_len, hidden_dim)

     if args.strategy == QuantizationStrategy.TENSOR:
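As an editorial aside, a small sketch (toy sizes, not from the repo) of what the new `g_idx` handling above does: columns are permuted into group order before the tensor is folded into groups.

import torch

value = torch.arange(12.0).reshape(2, 6)   # toy (num_rows=2, num_cols=6) weight
g_idx = torch.tensor([2, 0, 1, 0, 2, 1])   # activation-ordering group id per column
group_size = 2

# index_select with argsort(g_idx) makes columns of the same group contiguous
reordered = value.index_select(dim=1, index=torch.argsort(g_idx))

# grouping then proceeds exactly as in the non-ordered case
grouped = reordered.unflatten(-1, (-1, group_size)).unsqueeze(0)
print(grouped.shape)  # torch.Size([1, 2, 3, 2]) -> (1, num_rows, num_groups, group_size)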
@@ -128,7 +131,7 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
     if args.strategy == QuantizationStrategy.TOKEN:
         # (batch_size, seq_len, hidden_dim)
         # warning: token quantization uses `compute_dynamic_scales_and_zp`
-        return value.flatten(2, -1)
+        return value

     if args.strategy == QuantizationStrategy.CHANNEL:
         raise ValueError("Channel quantization cannot be applied to activations")
@@ -142,12 +145,12 @@ def flatten_activation_for_quantization(value: torch.Tensor, args: QuantizationArgs):
         raise ValueError("Block quantization cannot be applied to activations")

     if args.strategy == QuantizationStrategy.ATTN_HEAD:
-        raise ValueError("attention head quantization cannot be applied to linear acts")
+        raise ValueError("Attention head quantization cannot be applied to activations")

     assert False, f"Unknown strategy {args.strategy}"


-def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
+def _flatten_attention(value: torch.Tensor, args: QuantizationArgs):
     # value.shape = (batch_size, num_heads, seq_len, head_dim)

     if args.strategy == QuantizationStrategy.TENSOR:
@@ -161,7 +164,8 @@ def flatten_attention_for_quantization(value: torch.Tensor, args: QuantizationArgs):
         raise ValueError("Channel quantization cannot be applied to attention")

     if args.strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
-        raise ValueError("Group quantization cannot be applied to attention")
+        # (batch_size * num_heads * seq_len, num_groups, group_size)
+        return value.flatten(0, 2).unflatten(-1, (-1, args.group_size))

     if args.strategy == QuantizationStrategy.BLOCK:
         raise ValueError("Block quantization cannot be applied to attention")

tests/test_quantization/lifecycle/test_static_lifecycle.py

Lines changed: 33 additions & 1 deletion
@@ -309,7 +309,35 @@ class MockAttention(torch.nn.Module):
         # static token is not supported
         # channel is not supported
         # group is not supported
-        # tensor group is not supported
+        (
+            QuantizationArgs(
+                num_bits=4,
+                type="float",  # must be fp4
+                symmetric=True,
+                strategy="tensor_group",
+                dynamic="local",
+                group_size=2,
+            ),
+            torch.tensor([0.0]),
+            torch.tensor([23.0]),
+            torch.tensor(
+                [
+                    [
+                        [
+                            [0.0000, 1.0234, 2.0469, 3.0781],
+                            [3.2812, 4.9375, 4.9375, 7.3750],
+                            [9.0000, 9.0000, 10.6875, 10.6875],
+                        ],
+                        [
+                            [13.1250, 13.1250, 14.7500, 14.7500],
+                            [16.3750, 16.3750, 19.7500, 19.7500],
+                            [21.3750, 21.3750, 23.0000, 23.0000],
+                        ],
+                    ]
+                ]
+            ),
+            0.55,
+        ),
         # block is not supported
         (
             QuantizationArgs(
@@ -369,6 +397,10 @@ def test_static_attention_quantization(
     attention.k_observer = MockMinMaxObserver("k", args, attention)

     # calibrate quantization parameters
+    if hasattr(attention, "k_global_scale"):
+        global_scale = attention.k_observer.get_global_scale(input)
+        attention.k_global_scale.data = global_scale
+
     if scheme.input_activations.dynamic is False:
         scale, zero_point = attention.k_observer(input)
         attention.k_scale.data = scale
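For intuition about the new fp4 tensor_group case, here is a rough illustrative sketch of fake-quantizing one group onto the FP4 E2M1 grid with a per-group scale; it is not the library's kernel, ignores how generate_gparam derives the global scale and how group scales are stored in fp8, and uses simple round-to-nearest.

import torch

# positive magnitudes representable in FP4 E2M1 (sign handled separately)
FP4_E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def fake_quantize_fp4(group: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # scale into fp4 range, snap each magnitude to the nearest grid value,
    # then scale back
    sign = torch.sign(group)
    mag = (group / scale).abs().clamp(max=6.0)
    idx = (mag.unsqueeze(-1) - FP4_E2M1_GRID).abs().argmin(dim=-1)
    return sign * FP4_E2M1_GRID[idx] * scale

# toy group of two values (group_size=2, as in the new test case) with an assumed scale
print(fake_quantize_fp4(torch.tensor([22.0, 23.0]), torch.tensor(23.0 / 6.0)))
# tensor([23., 23.])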
