
Commit 712a731

merge main
Signed-off-by: Brian Dellabetta <[email protected]>
2 parents 7f2c5de + c1e53b1 commit 712a731

File tree: 4 files changed, +40 −25 lines changed

.github/workflows/test.yml

Lines changed: 10 additions & 0 deletions

@@ -89,6 +89,16 @@ jobs:
       with:
         python-version: ${{ inputs.python }}
 
+    - name: install system dependencies
+      run: |
+        if command -v g++ >/dev/null 2>&1; then
+          echo "found g++ compiler"
+        else
+          echo "installing g++ etc compilers..."
+          sudo apt update && sudo apt install -y g++ gcc
+        fi
+      shell: bash
+
     - name: checkout code
       id: checkout
       uses: actions/checkout@v4

src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 10 additions & 7 deletions

@@ -316,10 +316,10 @@ def __init__(
 
         self.quantization_compressor = {}
         for format in self.compression_formats:
-            self.quantization_compressor[
-                format
-            ] = BaseCompressor.load_from_registry(
-                format, config=quantization_config
+            self.quantization_compressor[format] = (
+                BaseCompressor.load_from_registry(
+                    format, config=quantization_config
+                )
             )
 
     # ----- used by hf quantizer ----- #

@@ -705,9 +705,12 @@ def decompress(self, model_path: str, model: Module):
         with override_quantization_status(
             self.quantization_config, QuantizationStatus.FROZEN
         ):
-            names_to_scheme = apply_quantization_config(
-                model, self.quantization_config
-            )
+            apply_quantization_config(model, self.quantization_config)
+            names_to_scheme: Set[QuantizationScheme] = {
+                name: getattr(module, "quantization_scheme")
+                for name, module in model.named_modules()
+                if getattr(module, "quantization_scheme", None) is not None
+            }
             # Load activation scales/zp or any other quantization parameters
             # Conditionally load the weight quantization parameters if we have a
             # dense compressor or if a sparsity compressor has already been applied
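Note on the decompress() change above: apply_quantization_config no longer returns a name-to-scheme mapping, so the compressor rebuilds it from the quantization_scheme attributes that the call leaves on each matched module. A minimal standalone sketch of that pattern (TinyModel and the placeholder scheme string are illustrative, not part of the repository):

import torch
from torch import nn


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)


model = TinyModel()
# Pretend apply_quantization_config() tagged only `proj` with a scheme.
model.proj.quantization_scheme = "scheme"

# Rebuild the name -> scheme mapping from attributes left on the modules.
names_to_scheme = {
    name: getattr(module, "quantization_scheme")
    for name, module in model.named_modules()
    if getattr(module, "quantization_scheme", None) is not None
}
print(names_to_scheme)  # {'proj': 'scheme'}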

src/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py

Lines changed: 5 additions & 4 deletions

@@ -123,6 +123,7 @@ def decompress_weight(
     return decompressed_weight
 
 
+@torch.compile(fullgraph=True, dynamic=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor with values in the fp4 range into uint8.

@@ -145,12 +146,11 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
 
     # Find closest valid FP4 value index for each element
     abs_x = torch.abs(x)
-    abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
-    for i, val in enumerate(kE2M1):
-        abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)
+    abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1)  # [m, n, 8]
+    abs_indices = torch.argmin(abs_diff_x, dim=-1)  # [m, n]
 
     # Apply sign bit (bit 3) to get final 4-bit representation
-    indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)
+    indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)
 
     # Reshape to prepare for packing pairs of values
     indices = indices.reshape(-1)

@@ -174,6 +174,7 @@ def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
 
 
 # reference: : https://github.com/vllm-project/vllm/pull/16362
+@torch.compile(fullgraph=True, dynamic=True)
 def unpack_fp4_from_uint8(
     a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
 ) -> torch.Tensor:
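Note on the pack_fp4_to_uint8 change above: the per-entry torch.where loop is replaced by one broadcasted absolute-difference against the eight-entry kE2M1 table followed by argmin, and the sign bit is cast to long before shifting, which keeps the function graph-friendly for the new torch.compile decorators. A small self-contained sketch of the lookup (the kE2M1 values and the sample input below are illustrative):

import torch

# Non-negative FP4 E2M1 values, assumed to match the table in the source file.
kE2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

x = torch.tensor([[0.4, -1.6, 5.0], [-0.1, 2.9, -6.0]])
abs_x = torch.abs(x)

# Vectorized: broadcast against the 8-entry table and take the argmin.
abs_diff = torch.abs(abs_x.unsqueeze(-1) - kE2M1)   # [m, n, 8]
abs_indices = torch.argmin(abs_diff, dim=-1)        # [m, n]

# Sign bit becomes bit 3 of the 4-bit code; cast to long before shifting
# so the shift runs on integers rather than on a bool tensor.
indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)
print(indices)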

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 15 additions & 14 deletions

@@ -115,7 +115,7 @@ def load_pretrained_quantization_parameters(
 
 def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
-) -> Dict[str, QuantizationScheme]:
+):
     """
     Initializes the model for quantization in-place based on the given config.
     Optionally coverts quantizable modules to compressed_linear modules

@@ -125,19 +125,17 @@ def apply_quantization_config(
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
     """
-    # Workaround for when HF Quantizer passes None, see PR #180
-    if config is None:
-        return dict()
 
-    # remove reference to the original `config`
-    # argument. This function can mutate it, and we'd
-    # like to keep the original `config` as it is.
     config = deepcopy(config)
+    if config is None:  # see PR #180
+        return dict()
+
+    # preprocess to support kv cache scheme
+    config = process_quantization_config(config)
+
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
-    config = process_quantization_config(config)
-    names_to_scheme = dict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme

@@ -150,13 +148,20 @@
         # quant scheme to the matching layers
         matched_targets = match_targets(name, submodule, target_to_scheme)
         scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
+
+        # target matched - add layer and scheme to target list
+        submodule.quantization_scheme = scheme
+
+        # replace with run compressed if applicable
+        # FUTURE: move this to model compressor
         if (
             run_compressed
-            and config.format != CompressionFormat.dense.value
             and isinstance(submodule, torch.nn.Linear)
+            and config.format != CompressionFormat.dense.value
         ):
             from compressed_tensors.linear.compressed_linear import CompressedLinear
 
+            # TODO: expand to more module types
             compressed_linear = CompressedLinear.from_linear(
                 submodule,
                 quantization_scheme=scheme,

@@ -167,13 +172,9 @@
         # target matched - add layer and scheme to target list
         submodule.quantization_scheme = scheme
 
-        names_to_scheme[name] = submodule.quantization_scheme
-
         # apply current quantization status to each targeted submodule
         apply_quantization_status(submodule, config.quantization_status)
 
-    return names_to_scheme
-
 
 def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
     """

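Note on the apply_quantization_config reordering above: the deepcopy now runs before the None guard, which is safe because deepcopy(None) is simply None, and the copy still shields the caller's config from in-place mutation. A small illustrative check (DummyConfig and apply below are stand-ins, not the library API):

from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class DummyConfig:
    config_groups: dict = field(default_factory=dict)


def apply(config):
    config = deepcopy(config)          # copy first ...
    if config is None:                 # ... then the PR #180 None guard
        return
    config.config_groups["group_0"] = "mutated"  # only the copy changes


apply(None)                            # no error on the None path
original = DummyConfig()
apply(original)
print(original.config_groups)          # {} -- caller's config untouched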