[WIP] Add fake quant for model and observer activation #318

Open · wants to merge 6 commits into base: main
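For readers skimming the diff: fake quantization (the core of this PR's title) simulates integer quantization in floating point by rounding a tensor onto an integer grid and immediately dequantizing it, so the rounding error is visible during calibration without changing dtypes. A minimal sketch, assuming a per-tensor scale/zero-point and an int8 range; the function name and values below are illustrative, not code from this PR:

import torch

def fake_quantize_sketch(x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor,
                         quant_min: int = -128, quant_max: int = 127) -> torch.Tensor:
    # quantize: snap to the integer grid defined by scale/zero_point, then clamp to range
    q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
    # dequantize: map back to floating point, keeping the rounding/clipping error
    return (q - zero_point) * scale

# example: symmetric per-tensor fake quantization of a weight matrix
w = torch.randn(4, 4)
w_fq = fake_quantize_sketch(w, scale=w.abs().max() / 127, zero_point=torch.tensor(0.0))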
4 changes: 2 additions & 2 deletions examples/emu/conf/compress/compress_emu3.yaml
@@ -4,8 +4,8 @@ defaults:
 
 data:
   data_path: null
-  max_calib_data: null
-  max_seq_len: null
+  num_calibration_steps: null
+  max_seq_length: null
 tokenzier_args:
   tokenizer_path: BAAI/Emu3-Gen/
   special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt
2 changes: 1 addition & 1 deletion examples/emu/conf/compress/compress_emu3_w4a16.yaml
@@ -4,7 +4,7 @@ defaults:
 
 data:
   data_path:
-  num_calibration_samples: 16
+  num_calibration_steps: 16
   max_seq_length: 9216
 tokenzier_args:
   tokenizer_path: BAAI/Emu3-Gen/
@@ -4,7 +4,7 @@ defaults:
 
 data:
   data_path:
-  num_calibration_samples: 16
+  num_calibration_steps: 16
   max_seq_length: 8192
 tokenzier_args: null
 
2 changes: 1 addition & 1 deletion examples/llava_onevision/conf/config_compress.yaml
@@ -1,6 +1,6 @@
 defaults:
   - _self_
-  - compress: compress_llava_ov_w4a16
+  - compress: compress_llava_ov
 
 experiment:
   exp_name: llava_ov
115 changes: 76 additions & 39 deletions flagscale/compress/adapter.py
@@ -2,6 +2,7 @@
import os

import torch
from torch.nn import Module
from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
from llmcompressor.utils.fsdp.context import fix_fsdp_module_name
@@ -22,52 +23,94 @@
disable_quantization,
enable_quantization,
)
from llmcompressor.modifiers.quantization.calibration import initialize_observer, update_weight_zp_scale, freeze_module_quantization
from llmcompressor.modifiers.quantization.calibration import (
apply_calibration_status,
calibrate_input_hook,
calibrate_kv_cache_input_hook,
calibrate_kv_cache_output_hook,
calibrate_output_hook,
freeze_module_quantization,
initialize_observer,
set_unset_kv_cache,
update_weight_zp_scale,
)
from llmcompressor.transformers.sparsification.compressed_tensors_utils import modify_save_pretrained
from flagscale.compress.algo import SmoothQuantWrapper, RTNWrapper
from flagscale.compress.blockwise_compressor import BlockCompressor

from flagscale.runner.runner_utils import logger

__all__ = ["LLMCompressorAdapter"]

QUANT_MAPPING_NAMES = {
"gptq": GPTQWrapper
"gptq": GPTQWrapper,
}

BLOCKWISE_WRAPPER_NAMES = {
"smoothquant": SmoothQuantWrapper,
}

class LLMCompressorAdapter:
def __init__(self, model, scheme, targets, algo=None, ignore=None, dataset=None, num_calibration_steps=384):
def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dataset=None, num_calibration_steps=384):
self.model = model
modify_save_pretrained(self.model)
if algo is not None:
assert len(algo) == 1
for k, v in algo.items():
self.algo = k
self.algo_args = v
else:
self.algo = algo
self.algo_args = {}
self.scheme = scheme
self.ignore = ignore
self.targets = targets
self.num_calibration_steps = num_calibration_steps
self.dataset = dataset
self.config_groups = None
self.wrapper_cls = None
self.compress_granularity = None
self.layer_compressors_ = []
self.num_calibration_steps = num_calibration_steps
self.dataset = dataset

if (self.algo is None and is_preset_scheme(self.scheme)) or self.algo in list(QUANT_MAPPING_NAMES.keys()):
self.wrapper_cls = QUANT_MAPPING_NAMES[self.algo] if self.algo is not None else None
quant_config = self.init_quant_config()
self.require_calib = True

support_algos = list(QUANT_MAPPING_NAMES.keys()) + list(BLOCKWISE_WRAPPER_NAMES.keys())
if (self.algo is None and is_preset_scheme(self.scheme)) or self.algo in support_algos:
if self.algo is not None:
if self.algo in QUANT_MAPPING_NAMES:
self.wrapper_cls = QUANT_MAPPING_NAMES[self.algo]
self.compress_granularity = LayerCompressor
elif self.algo in BLOCKWISE_WRAPPER_NAMES:
self.wrapper_cls = BLOCKWISE_WRAPPER_NAMES[self.algo]
self.compress_granularity = BlockCompressor
else:
raise NotImplementedError(f"algorithm: {self.algo} not implemented")
else:
self.wrapper_cls = RTNWrapper
self.compress_granularity = LayerCompressor
quant_config = self.init_quant_config()

if quant_config is not None:
### find ignore and target to quant, initialize module for quant
### overwrite forward if quantization_enabled is True
apply_quantization_config(self.model, quant_config)
if self.wrapper_cls is None:
self.preprocess_weight()
self.require_calib = quant_config.requires_calibration_data()

self.init_compressor()
if self.require_calib:
if model.training == False: ### Post Training
assert self.dataset is not None, f"The algorithm {self.algo} you selected requires a calibration process, please provide the calibration data"
self.run_blockwise_calib_forward()
self.model.apply(freeze_module_quantization)
else: ### Training Aware
pass
else:
self.init_compressor()
if self.dataset is not None:
self.run_blockwise_calib_forward()
self.model.apply(freeze_module_quantization)
self.layer_compressors_[0].clear_early_stop()
for idx, layer_compressor in enumerate(self.layer_compressors_):
layer_compressor.pre_compress()
layer_compressor.compress()
layer_compressor.post_compress()
layer_compressor.revert_layer_wrappers()



def init_quant_config(self):
if self.scheme is not None:
# takes precedence over config_groups
@@ -87,19 +130,21 @@ def init_quant_config(self):
group_name = f"group_{idx}"
self.config_groups[group_name] = scheme

if self.config_groups is None or len(self.config_groups) == 0:
default_quant_scheme = QuantizationScheme(targets=self.targets)
self.config_groups = {"group_0": default_quant_scheme}
logger.info(
f"No config groups were provided, using default {self.config_groups}"
)
if self.config_groups is None or len(self.config_groups) == 0:
default_quant_scheme = QuantizationScheme(targets=self.targets)
self.config_groups = {"group_0": default_quant_scheme}
logger.info(
f"No config groups were provided, using default {self.config_groups}"
)

return QuantizationConfig(
config_groups=self.config_groups,
kv_cache_scheme=None, ### TODO(lvmengsi): not support kv cache quant for now
quantization_status=QuantizationStatus.INITIALIZED,
ignore=self.ignore,
)
return QuantizationConfig(
config_groups=self.config_groups,
kv_cache_scheme=None, ### TODO(lvmengsi): not support kv cache quant for now
quantization_status=QuantizationStatus.INITIALIZED,
ignore=self.ignore,
)
else:
return None

def init_compressor(self):
for name, layer in self.model.named_modules():
@@ -114,18 +159,10 @@ def init_compressor(self):
if matches := find_name_or_class_matches(name, layer, self.ignore):
continue
logger.info(f"prepare compressor for layer {name}")
compressor = LayerCompressor(self.wrapper_cls, self.model, layer, idx, name, self.algo_args)
compressor = self.compress_granularity(self.wrapper_cls, self.model, layer, idx, name, self.algo_args)
self.layer_compressors_.append(compressor)
self.layer_compressors_[0].set_early_stop()

def preprocess_weight(self):
for idx, (name, layer) in enumerate(self.model.named_modules()):
layer.apply(lambda module: initialize_observer(layer, base_name="weight"))
self.model.apply(update_weight_zp_scale)

def add_hook(self):
pass

@torch.no_grad()
def run_blockwise_calib_forward(self):
logger.info(f"start calibration")
@@ -147,4 +184,4 @@ def run_blockwise_calib_forward(self):
error = get_output_error(unquantized_outputs, quantized_outputs)
logger.info(f"Mean output error from quantization: {error:.3f}")
intermediates = quantized_outputs
self.model.apply(enable_quantization)
self.model.apply(enable_quantization)
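A hedged usage sketch of the updated adapter, inferred from the constructor signature and branches above; the model, calibration loader, and the specific scheme/targets/ignore values are placeholders, not taken from this PR:

from flagscale.compress.adapter import LLMCompressorAdapter

adapter = LLMCompressorAdapter(
    model=model,                     # an already-loaded torch.nn.Module (assumed)
    scheme="W4A16",                  # example preset; must satisfy is_preset_scheme when algo is None
    targets=["Linear"],              # module types used for the default QuantizationScheme
    algo=None,                       # None falls back to the new RTNWrapper path
    ignore=["lm_head"],              # example layers excluded from quantization
    dataset=calib_loader,            # needed when the quant config requires calibration data
    num_calibration_steps=16,
)
# __init__ applies the quantization config and, when a calibration dataset is provided,
# runs the blockwise calibration forward pass and freezes module quantization.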
2 changes: 2 additions & 0 deletions flagscale/compress/algo/__init__.py
@@ -0,0 +1,2 @@
from .rtn import *
from .smooth_quant import *
13 changes: 0 additions & 13 deletions flagscale/compress/algo/algo_base.py

This file was deleted.

72 changes: 72 additions & 0 deletions flagscale/compress/algo/rtn.py
@@ -0,0 +1,72 @@
import os
import torch
from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
from compressed_tensors.utils import update_parameter_data
from compressed_tensors.quantization.lifecycle import fake_quantize
from flagscale.compress.observers.base import Observer

__all__ = ["RTNWrapper"]
class RTNWrapper(ModuleCompressionWrapper):
def __init__(self, name, layer, enable_fake_quant=False):
super(RTNWrapper, self).__init__(name, layer)
quantization_scheme = getattr(self.layer, "quantization_scheme", None)
self._enable_fake_quant = enable_fake_quant
self.weight_observer = None
self.input_observer = None
self.output_observer = None
self.weights_observer_args = getattr(quantization_scheme, "weights", None)
self.input_observer_args = getattr(quantization_scheme, "input_activations", None)
self.output_observer_args = getattr(quantization_scheme, "output_activations", None)
if self._enable_fake_quant:
if self.input_observer_args and self.input_observer_args.dynamic:
self.input_observer_args.observer = "minmax"
self.input_observer = Observer.load_from_registry(self.input_observer_args.get_observer(), quantization_args=self.input_observer_args)
if self.weights_observer_args:
# origin_weight = self.layer.weight.clone()
W = fake_quantize(self.layer.weight, self.layer.weight_scale, self.layer.weight_zero_point, self.weights_observer_args)
update_parameter_data(self.layer, W, f"weight")
del W
else:
if self.weights_observer_args and not self.weights_observer_args.dynamic:
self.weight_observer = Observer.load_from_registry(self.weights_observer_args.get_observer(), quantization_args=self.weights_observer_args)
if self.input_observer_args and not self.input_observer_args.dynamic:
self.input_observer = Observer.load_from_registry(self.input_observer_args.get_observer(), quantization_args=self.input_observer_args)
if self.output_observer_args and not self.output_observer_args.dynamic:
self.output_observer = Observer.load_from_registry(self.output_observer_args.get_observer(), quantization_args=self.output_observer_args)

def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
if self.input_observer:
updated_scale, updated_zero_point = self.input_observer(inp)
update_parameter_data(self.layer, updated_scale, f"input_scale")
update_parameter_data(self.layer, updated_zero_point, f"input_zero_point")

def compress(self, g_idx=None):
if self.weight_observer:
updated_scale, updated_zero_point = self.weight_observer(self.layer.weight, g_idx=g_idx)
update_parameter_data(self.layer, updated_scale, f"weight_scale")
update_parameter_data(self.layer, updated_zero_point, f"weight_zero_point")

def enable_fake_quant(self):
self._enable_fake_quant = True

def forward(self, inp, **kwargs):
"""
Run a forward pass of the wrapped layer
"""
if self._enable_fake_quant:
if self.input_observer_args:
print("self.input_observer_args: ", self.input_observer_args)
if self.input_observer_args.dynamic:
scale, zp = self.input_observer(inp)
tmp_inp = fake_quantize(inp, scale, zp, self.input_observer_args)
error = torch.nn.functional.mse_loss(inp, tmp_inp)
# print("input dynamic error: ", error, inp, tmp_inp, scale, zp)
inp = tmp_inp
del tmp_inp, error
else:
inp = fake_quantize(inp, self.layer.input_scale, self.layer.input_zero_point, self.input_observer_args)
out = self.layer(inp, **kwargs)
# if self._enable_fake_quant and self.output_observer:
# out = fake_quantize(out, self.layer.output_scale, self.layer.output_zero_point, self.output_observer.quantization_args)
torch.cuda.empty_cache()
return out
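For context on the observers RTNWrapper loads: a "minmax" observer derives the scale and zero point from the running minimum and maximum of the observed tensor. A simplified per-tensor sketch assuming an asymmetric uint8 range; real observers in llmcompressor/compressed_tensors also handle grouping, dtypes, and moving averages, and the function name here is illustrative:

import torch

def minmax_scale_zero_point(x: torch.Tensor, quant_min: int = 0, quant_max: int = 255):
    x_min = torch.clamp(x.min(), max=0.0)   # keep 0.0 representable
    x_max = torch.clamp(x.max(), min=0.0)
    scale = torch.clamp((x_max - x_min) / float(quant_max - quant_min), min=1e-8)
    zero_point = torch.clamp(torch.round(quant_min - x_min / scale), quant_min, quant_max)
    return scale, zero_point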
21 changes: 21 additions & 0 deletions flagscale/compress/algo/smooth_quant.py
@@ -0,0 +1,21 @@
import torch
from flagscale.compress.blockwise_wrapper import BlockCompressionWrapper
from llmcompressor.modifiers.smoothquant import *
from llmcompressor.modifiers.smoothquant.utils import DEFAULT_SMOOTHQUANT_MAPPINGS

__all__ = ["SmoothQuantWrapper"]
class SmoothQuantWrapper(BlockCompressionWrapper):
def __init__(self, name, layer):
super().__init__(name=name, layer=layer)
self.sq = SmoothQuantModifier()
self.sq.ignore = [] if not self.sq.ignore else self.sq.ignore
self.sq.mappings = self.sq._infer_mappings_from_model(self.layer)
self.sq.resolved_mappings_ = self.sq._resolve_mappings(self.layer)
self.sq.scales_ = {}

def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
self.sq._setup_scale_hooks()

def compress(self, smoothing_strength):
self.sq.smoothing_strength = smoothing_strength
self.sq._apply_smoothing(self.layer)
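The wrapper above delegates to llmcompressor's SmoothQuantModifier; the underlying rebalancing scales activation channels down and folds the inverse into the following weight so the matmul output is unchanged while activation outliers shrink. A minimal sketch of that transform, with alpha playing the role of the smoothing_strength passed to compress() (shapes and names are assumptions for illustration):

import torch

def smooth_scales(act_absmax: torch.Tensor, weight_absmax: torch.Tensor, alpha: float = 0.5):
    # per-input-channel scale: s_j = act_max_j**alpha / weight_max_j**(1 - alpha)
    return (act_absmax.pow(alpha) / weight_absmax.pow(1.0 - alpha)).clamp(min=1e-5)

def apply_smoothing(x: torch.Tensor, weight: torch.Tensor, s: torch.Tensor):
    # x: (batch, in_features); weight: (out_features, in_features)
    # (x / s) @ (weight * s).T == x @ weight.T, so the layer output is preserved
    return x / s, weight * s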
42 changes: 42 additions & 0 deletions flagscale/compress/blockwise_compressor.py
@@ -0,0 +1,42 @@
import operator
import torch
from torch.nn import Module
from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
from llmcompressor.utils.fsdp.context import (
summon_full_params_context,
)
from llmcompressor.utils.pytorch.module import set_layer, get_layer

def replace_block(target: str, model: Module, target_module: Module):
parent_target = ".".join(target.split(".")[:-1])
parent_layer = get_layer(parent_target, model)[1]
setattr(parent_layer, target.split(".")[-1], target_module)

class BlockCompressor(LayerCompressor):
def pre_compress(self):
full_name = self.name
with summon_full_params_context(self.layer):
wrapper = self.module_compressor_class(full_name, self.layer)
replace_block(full_name, self.model, wrapper)
self.modules[full_name] = wrapper

self.layer = operator.attrgetter(self.name)(self.model)

def add_batch(name):
def tmp(_, inp, out):
self.modules[name].add_batch(inp[0].data, out[0].data)

return tmp

for name in self.modules:
self.handles.append(self.modules[name].register_forward_hook(add_batch(name)))

def revert_layer_wrappers(self):
"""
Reverts wrapped root modules back to their original structure
"""
for name, module_wrapper in self.modules.items():
full_name = self.name
replace_block(full_name, self.model, module_wrapper.layer)
torch.cuda.empty_cache()
self.modules = None
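A hedged illustration of what replace_block does: it resolves the parent module from the dotted target name and rebinds the child attribute, which is how BlockCompressor swaps an entire block for its compression wrapper (and back again in revert_layer_wrappers). Module names below are examples, not from this PR:

import torch.nn as nn

model = nn.Sequential()
model.add_module("decoder", nn.Sequential())
model.decoder.add_module("layers", nn.ModuleList([nn.Linear(8, 8)]))

target = "decoder.layers.0"
parent = model.get_submodule(".".join(target.split(".")[:-1]))  # like get_layer(parent_target, model)
setattr(parent, target.split(".")[-1], nn.Identity())           # replace the block in place
print(model.get_submodule("decoder.layers.0"))                  # Identity() after replacement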