Commit 01af659

QuantizationMetadata class
Signed-off-by: Brian Dellabetta <[email protected]>
1 parent 7d8c5a4 commit 01af659

4 files changed, +65 -59 lines changed

src/compressed_tensors/quantization/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,6 +17,6 @@
 
 from .quant_args import *
 from .quant_config import *
-from .quant_names import *
+from .quant_metadata import *
 from .quant_scheme import *
 from .lifecycle import *
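
A sketch, not part of the diff, of the package-level surface after this change (assuming compressed-tensors is installed at this commit): the star import re-exports the new class in place of the removed name list.

# Sketch: the removed ALL_QPARAM_NAMES constant is now produced on demand.
from compressed_tensors.quantization import QuantizationMetadata

print(QuantizationMetadata.all_qparam_names()[:2])  # ['k_scale', 'v_scale']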

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 2 additions & 15 deletions
@@ -20,11 +20,11 @@
 
 import torch
 from compressed_tensors.quantization import (
-    ALL_QPARAM_NAMES,
     FP8_E4M3_DATA,
     ActivationOrdering,
     KVCacheScaleType,
     QuantizationArgs,
+    QuantizationMetadata,
     QuantizationScheme,
     QuantizationStatus,
     QuantizationStrategy,

@@ -76,7 +76,7 @@ def initialize_module_for_quantization(
         # no scheme passed and layer not targeted for quantization - skip
         return
 
-    _clear_all_qparams(module)
+    QuantizationMetadata.clear_all_qparams(module)
 
     if is_attention_module(module):
         # quantized actions based on calltime status

@@ -133,19 +133,6 @@ def is_attention_module(module: Module):
     )
 
 
-def _clear_all_qparams(
-    module: Module,
-):
-    """
-    Clear all previously registered quantization parameters from module
-
-    :param module: module to clear qparams from
-    """
-    for key in ALL_QPARAM_NAMES:
-        if hasattr(module, key):
-            delete_offload_parameter(module, key)
-
-
 def _initialize_scale_zero_point(
     module: Module,
     base_name: str,

src/compressed_tensors/quantization/quant_metadata.py

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+
+from compressed_tensors.utils import delete_offload_parameter
+from torch.nn import Module
+
+
+__all__ = ["QuantizationMetadata", "KVCacheScaleType"]
+
+
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
+class QuantizationMetadata:
+    """
+    Container class for metadata related to quantization
+    """
+
+    @staticmethod
+    def all_qparam_names():
+        """
+        All quantization parameter names that might be registered
+        onto a module during lifecycle (excluding serialized parameters)
+        """
+        return [KVCacheScaleType.KEY.value, KVCacheScaleType.VALUE.value] + [
+            f"{base_name}_{suffix}"
+            for base_name in ("input", "weight", "output")
+            for suffix in (
+                "global_scale",
+                "scale",
+                "zero_point",
+                "g_idx",
+            )
+        ]
+
+    @classmethod
+    def clear_all_qparams(cls, module: Module):
+        """
+        Remove all parameters related to quantization that might have
+        been registered onto a module previously in lifecycle (excluding
+        serialized parameters)
+
+        :param module: Module to clear
+        """
+        for key in cls.all_qparam_names():
+            if hasattr(module, key):
+                delete_offload_parameter(module, key)
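
For reference, a minimal sketch (not part of the commit) exercising the new class on a plain torch module; the hand-registered `weight_scale` parameter and the `Linear(16, 16)` layer are purely illustrative stand-ins for qparams added earlier in the quantization lifecycle.

# Sketch: clear_all_qparams removes any registered qparam it recognizes.
import torch
from torch.nn import Linear

from compressed_tensors.quantization import QuantizationMetadata

linear = Linear(16, 16)
# Simulate a qparam registered earlier in the lifecycle
linear.register_parameter(
    "weight_scale", torch.nn.Parameter(torch.ones(1), requires_grad=False)
)

assert "weight_scale" in QuantizationMetadata.all_qparam_names()
QuantizationMetadata.clear_all_qparams(linear)
assert not hasattr(linear, "weight_scale")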

src/compressed_tensors/quantization/quant_names.py

Lines changed: 0 additions & 43 deletions
This file was deleted.
