diff --git a/examples/emu/conf/compress/compress_emu3.yaml b/examples/emu/conf/compress/compress_emu3.yaml
index 13e18481a..2e5744fda 100644
--- a/examples/emu/conf/compress/compress_emu3.yaml
+++ b/examples/emu/conf/compress/compress_emu3.yaml
@@ -4,8 +4,8 @@ defaults:
 data:
   data_path: null
-  max_calib_data: null
-  max_seq_len: null
+  num_calibration_steps: null
+  max_seq_length: null
   tokenzier_args:
     tokenizer_path: BAAI/Emu3-Gen/
     special_tokens_file: BAAI/Emu3-Gen/emu3_vision_tokens.txt
diff --git a/examples/emu/conf/compress/compress_emu3_w4a16.yaml b/examples/emu/conf/compress/compress_emu3_w4a16.yaml
index 534727f67..c3174fef3 100644
--- a/examples/emu/conf/compress/compress_emu3_w4a16.yaml
+++ b/examples/emu/conf/compress/compress_emu3_w4a16.yaml
@@ -4,7 +4,7 @@ defaults:
 data:
   data_path:
-  num_calibration_samples: 16
+  num_calibration_steps: 16
   max_seq_length: 9216
   tokenzier_args:
     tokenizer_path: BAAI/Emu3-Gen/
diff --git a/examples/llava_onevision/conf/compress/compress_llava_ov_w4a16.yaml b/examples/llava_onevision/conf/compress/compress_llava_ov_w4a16.yaml
index afdb92144..4dbc4198a 100644
--- a/examples/llava_onevision/conf/compress/compress_llava_ov_w4a16.yaml
+++ b/examples/llava_onevision/conf/compress/compress_llava_ov_w4a16.yaml
@@ -4,7 +4,7 @@ defaults:
 data:
   data_path:
-  num_calibration_samples: 16
+  num_calibration_steps: 16
   max_seq_length: 8192
   tokenzier_args: null
diff --git a/examples/llava_onevision/conf/config_compress.yaml b/examples/llava_onevision/conf/config_compress.yaml
index 0cd5a76b4..a5d2c3be3 100644
--- a/examples/llava_onevision/conf/config_compress.yaml
+++ b/examples/llava_onevision/conf/config_compress.yaml
@@ -1,6 +1,6 @@
 defaults:
   - _self_
-  - compress: compress_llava_ov_w4a16
+  - compress: compress_llava_ov
 
 experiment:
   exp_name: llava_ov
diff --git a/flagscale/compress/adapter.py b/flagscale/compress/adapter.py
index 7414159b4..e055aec99 100644
--- a/flagscale/compress/adapter.py
+++ b/flagscale/compress/adapter.py
@@ -2,6 +2,7 @@
 import os
 
 import torch
+from torch.nn import Module
 from llmcompressor.modifiers.quantization.gptq.utils.gptq_wrapper import GPTQWrapper
 from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
 from llmcompressor.utils.fsdp.context import fix_fsdp_module_name
@@ -22,21 +23,36 @@
     disable_quantization,
     enable_quantization,
 )
-from llmcompressor.modifiers.quantization.calibration import initialize_observer, update_weight_zp_scale, freeze_module_quantization
+from llmcompressor.modifiers.quantization.calibration import (
+    apply_calibration_status,
+    calibrate_input_hook,
+    calibrate_kv_cache_input_hook,
+    calibrate_kv_cache_output_hook,
+    calibrate_output_hook,
+    freeze_module_quantization,
+    initialize_observer,
+    set_unset_kv_cache,
+    update_weight_zp_scale,
+)
 from llmcompressor.transformers.sparsification.compressed_tensors_utils import modify_save_pretrained
 
+from flagscale.compress.algo import SmoothQuantWrapper, RTNWrapper
+from flagscale.compress.blockwise_compressor import BlockCompressor
 from flagscale.runner.runner_utils import logger
 
 __all__ = ["LLMCompressorAdapter"]
 
 QUANT_MAPPING_NAMES = {
-    "gptq": GPTQWrapper
+    "gptq": GPTQWrapper,
 }
 
+BLOCKWISE_WRAPPER_NAMES = {
+    "smoothquant": SmoothQuantWrapper,
+}
+
 class LLMCompressorAdapter:
-    def __init__(self, model, scheme, targets, algo=None, ignore=None, dataset=None, num_calibration_steps=384):
+    def __init__(self, model, scheme=None, targets=None, algo=None, ignore=None, dataset=None, num_calibration_steps=384):
         self.model = model
-        modify_save_pretrained(self.model)
         if algo is not None:
             assert len(algo) == 1
             for k, v in algo.items():
@@ -44,30 +60,57 @@ def __init__(self, model, scheme, targets, algo=None, ignore=None, dataset=None
             self.algo_args = v
         else:
             self.algo = algo
+            self.algo_args = {}
         self.scheme = scheme
         self.ignore = ignore
         self.targets = targets
+        self.num_calibration_steps = num_calibration_steps
+        self.dataset = dataset
+
         self.config_groups = None
         self.wrapper_cls = None
+        self.compress_granularity = None
         self.layer_compressors_ = []
-        self.num_calibration_steps = num_calibration_steps
-        self.dataset = dataset
-
-        if (self.algo is None and is_preset_scheme(self.scheme)) or self.algo in list(QUANT_MAPPING_NAMES.keys()):
-            self.wrapper_cls = QUANT_MAPPING_NAMES[self.algo] if self.algo is not None else None
-            quant_config = self.init_quant_config()
+        self.require_calib = True
+
+        support_algos = list(QUANT_MAPPING_NAMES.keys()) + list(BLOCKWISE_WRAPPER_NAMES.keys())
+        if (self.algo is None and is_preset_scheme(self.scheme)) or self.algo in support_algos:
+            if self.algo is not None:
+                if self.algo in QUANT_MAPPING_NAMES:
+                    self.wrapper_cls = QUANT_MAPPING_NAMES[self.algo]
+                    self.compress_granularity = LayerCompressor
+                elif self.algo in BLOCKWISE_WRAPPER_NAMES:
+                    self.wrapper_cls = BLOCKWISE_WRAPPER_NAMES[self.algo]
+                    self.compress_granularity = BlockCompressor
+                else:
+                    raise NotImplementedError(f"algorithm: {self.algo} not implemented")
+            else:
+                self.wrapper_cls = RTNWrapper
+                self.compress_granularity = LayerCompressor
+            quant_config = self.init_quant_config()
+        if quant_config is not None:
             ### find ignore and target to quant, initialize module for quant
             ### overwrite forward if quantization_enabled is Tue
             apply_quantization_config(self.model, quant_config)
-            if self.wrapper_cls is None:
-                self.preprocess_weight()
+            self.require_calib = quant_config.requires_calibration_data()
+
+            self.init_compressor()
+            if self.require_calib:
+                if model.training == False: ### Post Training
+                    assert self.dataset is not None, f"The algorithm {self.algo} you selected requires a calibration process, please provide the calibration data"
+                    self.run_blockwise_calib_forward()
+                    self.model.apply(freeze_module_quantization)
+                else: ### Training Aware
+                    pass
             else:
-                self.init_compressor()
-                if self.dataset is not None:
-                    self.run_blockwise_calib_forward()
-                    self.model.apply(freeze_module_quantization)
+                self.layer_compressors_[0].clear_early_stop()
+                for idx, layer_compressor in enumerate(self.layer_compressors_):
+                    layer_compressor.pre_compress()
+                    layer_compressor.compress()
+                    layer_compressor.post_compress()
+                    layer_compressor.revert_layer_wrappers()
+
-
     def init_quant_config(self):
         if self.scheme is not None:
             # takes precedence over config_groups
@@ -87,19 +130,21 @@ def init_quant_config(self):
                 group_name = f"group_{idx}"
                 self.config_groups[group_name] = scheme
 
-        if self.config_groups is None or len(self.config_groups) == 0:
-            default_quant_scheme = QuantizationScheme(targets=self.targets)
-            self.config_groups = {"group_0": default_quant_scheme}
-            logger.info(
-                f"No config groups were provided, using default {self.config_groups}"
-            )
+            if self.config_groups is None or len(self.config_groups) == 0:
+                default_quant_scheme = QuantizationScheme(targets=self.targets)
+                self.config_groups = {"group_0": default_quant_scheme}
+                logger.info(
+                    f"No config groups were provided, using default {self.config_groups}"
+                )
 
-        return QuantizationConfig(
-            config_groups=self.config_groups,
-            kv_cache_scheme=None, ### TODO(lvmengsi): not support kv cache quant for now
-            quantization_status=QuantizationStatus.INITIALIZED,
-            ignore=self.ignore,
-        )
+            return QuantizationConfig(
+                config_groups=self.config_groups,
+                kv_cache_scheme=None, ### TODO(lvmengsi): not support kv cache quant for now
+                quantization_status=QuantizationStatus.INITIALIZED,
+                ignore=self.ignore,
+            )
+        else:
+            return None
 
     def init_compressor(self):
         for name, layer in self.model.named_modules():
@@ -114,18 +159,10 @@ def init_compressor(self):
             if matches := find_name_or_class_matches(name, layer, self.ignore):
                 continue
             logger.info(f"prepare compressor for layer {name}")
-            compressor = LayerCompressor(self.wrapper_cls, self.model, layer, idx, name, self.algo_args)
+            compressor = self.compress_granularity(self.wrapper_cls, self.model, layer, idx, name, self.algo_args)
             self.layer_compressors_.append(compressor)
         self.layer_compressors_[0].set_early_stop()
 
-    def preprocess_weight(self):
-        for idx, (name, layer) in enumerate(self.model.named_modules()):
-            layer.apply(lambda module: initialize_observer(layer, base_name="weight"))
-        self.model.apply(update_weight_zp_scale)
-
-    def add_hook(self):
-        pass
-
     @torch.no_grad()
     def run_blockwise_calib_forward(self):
         logger.info(f"start calibration")
@@ -147,4 +184,4 @@ def run_blockwise_calib_forward(self):
             error = get_output_error(unquantized_outputs, quantized_outputs)
             logger.info(f"Mean output error from quantization: {error:.3f}")
             intermediates = quantized_outputs
-        self.model.apply(enable_quantization)
\ No newline at end of file
+        self.model.apply(enable_quantization)
diff --git a/flagscale/compress/algo/__init__.py b/flagscale/compress/algo/__init__.py
index e69de29bb..d51a93938 100644
--- a/flagscale/compress/algo/__init__.py
+++ b/flagscale/compress/algo/__init__.py
@@ -0,0 +1,2 @@
+from .rtn import *
+from .smooth_quant import *
diff --git a/flagscale/compress/algo/algo_base.py b/flagscale/compress/algo/algo_base.py
deleted file mode 100644
index f4bf4607a..000000000
--- a/flagscale/compress/algo/algo_base.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import os
-
-class BaseALGO:
-    def __init__(self, name):
-        self.name = name
-        self._observer = False
-        self._compress = False
-
-    def preprocess_weight(self):
-        raise NotImplementedError
-
-    def add_batch(self):
-        raise NotImplementedError
diff --git a/flagscale/compress/algo/rtn.py b/flagscale/compress/algo/rtn.py
new file mode 100644
index 000000000..0521e21b8
--- /dev/null
+++ b/flagscale/compress/algo/rtn.py
@@ -0,0 +1,72 @@
+import os
+import torch
+from llmcompressor.modifiers.utils.compression_wrapper import ModuleCompressionWrapper
+from compressed_tensors.utils import update_parameter_data
+from compressed_tensors.quantization.lifecycle import fake_quantize
+from flagscale.compress.observers.base import Observer
+
+__all__ = ["RTNWrapper"]
+class RTNWrapper(ModuleCompressionWrapper):
+    def __init__(self, name, layer, enable_fake_quant=False):
+        super(RTNWrapper, self).__init__(name, layer)
+        quantization_scheme = getattr(self.layer, "quantization_scheme", None)
+        self._enable_fake_quant = enable_fake_quant
+        self.weight_observer = None
+        self.input_observer = None
+        self.output_observer = None
+        self.weights_observer_args = getattr(quantization_scheme, "weights", None)
+        self.input_observer_args = getattr(quantization_scheme, "input_activations", None)
+        self.output_observer_args = getattr(quantization_scheme, "output_activations", None)
+        if self._enable_fake_quant:
+            if self.input_observer_args and self.input_observer_args.dynamic:
+                self.input_observer_args.observer = "minmax"
+                self.input_observer = Observer.load_from_registry(self.input_observer_args.get_observer(), quantization_args=self.input_observer_args)
+            if self.weights_observer_args:
+                # origin_weight = self.layer.weight.clone()
+                W = fake_quantize(self.layer.weight, self.layer.weight_scale, self.layer.weight_zero_point, self.weights_observer_args)
+                update_parameter_data(self.layer, W, f"weight")
+                del W
+        else:
+            if self.weights_observer_args and not self.weights_observer_args.dynamic:
+                self.weight_observer = Observer.load_from_registry(self.weights_observer_args.get_observer(), quantization_args=self.weights_observer_args)
+            if self.input_observer_args and not self.input_observer_args.dynamic:
+                self.input_observer = Observer.load_from_registry(self.input_observer_args.get_observer(), quantization_args=self.input_observer_args)
+            if self.output_observer_args and not self.output_observer_args.dynamic:
+                self.output_observer = Observer.load_from_registry(self.output_observer_args.get_observer(), quantization_args=self.output_observer_args)
+
+    def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
+        if self.input_observer:
+            updated_scale, updated_zero_point = self.input_observer(inp)
+            update_parameter_data(self.layer, updated_scale, f"input_scale")
+            update_parameter_data(self.layer, updated_zero_point, f"input_zero_point")
+
+    def compress(self, g_idx=None):
+        if self.weight_observer:
+            updated_scale, updated_zero_point = self.weight_observer(self.layer.weight, g_idx=g_idx)
+            update_parameter_data(self.layer, updated_scale, f"weight_scale")
+            update_parameter_data(self.layer, updated_zero_point, f"weight_zero_point")
+
+    def enable_fake_quant(self):
+        self._enable_fake_quant = True
+
+    def forward(self, inp, **kwargs):
+        """
+        Run a forward pass of the wrapped layer
+        """
+        if self._enable_fake_quant:
+            if self.input_observer_args:
+                print("self.input_observer_args: ", self.input_observer_args)
+                if self.input_observer_args.dynamic:
+                    scale, zp = self.input_observer(inp)
+                    tmp_inp = fake_quantize(inp, scale, zp, self.input_observer_args)
+                    error = torch.nn.functional.mse_loss(inp, tmp_inp)
+                    # print("input dynamic error: ", error, inp, tmp_inp, scale, zp)
+                    inp = tmp_inp
+                    del tmp_inp, error
+                else:
+                    inp = fake_quantize(inp, self.layer.input_scale, self.layer.input_zero_point, self.input_observer_args)
+        out = self.layer(inp, **kwargs)
+        # if self._enable_fake_quant and self.output_observer:
+        #     out = fake_quantize(out, self.layer.output_scale, self.layer.output_zero_point, self.output_observer.quantization_args)
+        torch.cuda.empty_cache()
+        return out
diff --git a/flagscale/compress/algo/smooth_quant.py b/flagscale/compress/algo/smooth_quant.py
new file mode 100644
index 000000000..23de7b9e2
--- /dev/null
+++ b/flagscale/compress/algo/smooth_quant.py
@@ -0,0 +1,21 @@
+import torch
+from flagscale.compress.blockwise_wrapper import BlockCompressionWrapper
+from llmcompressor.modifiers.smoothquant import *
+from llmcompressor.modifiers.smoothquant.utils import DEFAULT_SMOOTHQUANT_MAPPINGS
+
+__all__ = ["SmoothQuantWrapper"]
+class SmoothQuantWrapper(BlockCompressionWrapper):
+    def __init__(self, name, layer):
+        super().__init__(name=name, layer=layer)
+        self.sq = SmoothQuantModifier()
+        self.sq.ignore = [] if not self.sq.ignore else self.sq.ignore
+        self.sq.mappings = self.sq._infer_mappings_from_model(self.layer)
+        self.sq.resolved_mappings_ = self.sq._resolve_mappings(self.layer)
+        self.sq.scales_ = {}
+
+    def add_batch(self, inp: torch.Tensor, out: torch.Tensor):
+        self.sq._setup_scale_hooks()
+
+    def compress(self, smoothing_strength):
+        self.sq.smoothing_strength = smoothing_strength
+        self.sq._apply_smoothing(self.layer)
diff --git a/flagscale/compress/blockwise_compressor.py b/flagscale/compress/blockwise_compressor.py
new file mode 100644
index 000000000..d3b1accbb
--- /dev/null
+++ b/flagscale/compress/blockwise_compressor.py
@@ -0,0 +1,42 @@
+import operator
+import torch
+from torch.nn import Module
+from llmcompressor.modifiers.utils.layer_compressor import LayerCompressor
+from llmcompressor.utils.fsdp.context import (
+    summon_full_params_context,
+)
+from llmcompressor.utils.pytorch.module import set_layer, get_layer
+
+def replace_block(target: str, model: Module, target_module: Module):
+    parent_target = ".".join(target.split(".")[:-1])
+    parent_layer = get_layer(parent_target, model)[1]
+    setattr(parent_layer, target.split(".")[-1], target_module)
+
+class BlockCompressor(LayerCompressor):
+    def pre_compress(self):
+        full_name = self.name
+        with summon_full_params_context(self.layer):
+            wrapper = self.module_compressor_class(full_name, self.layer)
+            replace_block(full_name, self.model, wrapper)
+            self.modules[full_name] = wrapper
+
+        self.layer = operator.attrgetter(self.name)(self.model)
+
+        def add_batch(name):
+            def tmp(_, inp, out):
+                self.modules[name].add_batch(inp[0].data, out[0].data)
+
+            return tmp
+
+        for name in self.modules:
+            self.handles.append(self.modules[name].register_forward_hook(add_batch(name)))
+
+    def revert_layer_wrappers(self):
+        """
+        Reverts wrapped root modules back to their original structure
+        """
+        for name, module_wrapper in self.modules.items():
+            full_name = self.name
+            replace_block(full_name, self.model, module_wrapper.layer)
+            torch.cuda.empty_cache()
+        self.modules = None
diff --git a/flagscale/compress/blockwise_wrapper.py b/flagscale/compress/blockwise_wrapper.py
new file mode 100644
index 000000000..b359b4f62
--- /dev/null
+++ b/flagscale/compress/blockwise_wrapper.py
@@ -0,0 +1,90 @@
+from abc import ABC, abstractmethod
+from typing import Optional, Set
+
+import torch
+import torch.nn as nn
+from torch.nn import Module
+try:
+    import transformers
+except ImportError as err:
+    transformers = None
+    transformers_err = err
+
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+__all__ = ["BlockCompressionWrapper"]
+
+class BlockCompressionWrapper(Module, ABC):
+    def __init__(self, name, layer):
+        super(BlockCompressionWrapper, self).__init__()
+        if transformers is None:
+            raise transformers_err
+
+        self.name = name
+        self.layer = layer
+
+        self.dev = next(self.layer.parameters()).device
+        if hasattr(self.layer, "_hf_hook") and self.layer._hf_hook.offload:
+            self.dev = self.layer._hf_hook.execution_device
+
+        # This need to be a buffer so its preserved between forward passes
+        self.register_buffer(
+            "nsamples", torch.zeros(1, dtype=torch.int32, device=self.dev)
+        )
+
+    def forward(self, *args, **kwargs):
+        """
+        Run a forward pass of the wrapped layer
+        """
+        return self.layer(*args, **kwargs)
+
+    def free(self):
+        """
+        Free buffers used for compression
+        """
+        delattr(self, "nsamples")
+
+    @abstractmethod
+    def add_batch(self, *args, **kwargs):
+        """
+        Add a batch of layer input and output data to the layer statistics calculation
+        """
+        raise NotImplementedError("Child class must implement `add_batch`")
+
+    @abstractmethod
+    def compress(self, *args, **kwargs):
+        """
+        Run pruning on the layer up to the target sparsity
+        """
+        raise NotImplementedError("Child class must implement `compress`")
+
+    def state_dict(self, destination=None, prefix="", keep_vars=False, **kwargs):
+        """
+        Pass request to wrapped layer, so compression wrapper does not appear in
+        the state_dict
+        """
+        return self.layer.state_dict(
+            destination=destination, prefix=prefix, keep_vars=keep_vars, **kwargs
+        )
+
+    def load_state_dict(self, state_dict, strict=True):
+        """
+        Pass request to wrapped layer, so compression wrapper does not appear in
+        the state_dict
+        """
+        return self.layer.load_state_dict(state_dict, strict=strict)
+
+    def named_modules(
+        self,
+        memo: Optional[Set["Module"]] = None,
+        prefix: str = "",
+        remove_duplicate: bool = True,
+    ):
+        """
+        Pass request to wrapped layer, so compression wrapper does not appear in
+        the module list
+        """
+        return self.layer.named_modules(
+            memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
+        )
\ No newline at end of file
diff --git a/flagscale/compress/compressor.py b/flagscale/compress/compressor.py
index 9ef35bd0b..1a2182dda 100644
--- a/flagscale/compress/compressor.py
+++ b/flagscale/compress/compressor.py
@@ -7,9 +7,12 @@
 import torch
 from transformers import *
 
+from llmcompressor.transformers.sparsification.compressed_tensors_utils import modify_save_pretrained
+from llmcompressor.utils.pytorch.module import set_layer
 from flagscale.compress.combined_algo import prepare_compress_methods
 from flagscale.compress.adapter import LLMCompressorAdapter
+from flagscale.compress.algo import RTNWrapper
 
 _g_ignore_fields = ["experiment", "action"]
 
@@ -55,29 +58,50 @@ def copy_rest_file(src_path, dst_path):
         elif os.path.isdir(full_file_name):
             shutil.copytree(full_file_name, os.path.join(dst_path, filename))
 
-def compress(cfg, model=None, dataset=None):
-    tokenizer = None
-    model_path = cfg.model.pop("model_path")
-    if cfg.data.tokenzier_args is not None:
-        tokenizer = AutoTokenizer.from_pretrained(cfg.data.tokenzier_args.pop("tokenizer_path"), **cfg.data.tokenzier_args)
-    if model is None:
-        model_cls = eval(cfg.model.pop("model_cls"))
-        model = model_cls.from_pretrained(model_path, **cfg.model)
-    assert isinstance(model, torch.nn.Module), f"model type {type(model)} error, please check it"
-    compress_args = cfg.compress_args
-    recipes = prepare_compress_methods(compress_args)
-    for method, recipe in recipes.items():
-        for algo_args in recipe:
-            algo_args = OmegaConf.to_container(algo_args)
-            algo_args["dataset"] = dataset
-            algo_args["num_calibration_steps"] = cfg.data.get("max_seq_length", 384)
-            adapter = LLMCompressorAdapter(model=model, **algo_args)
-            ### modify model inplace
-            model = adapter.model
-
-    # oneshot(model=model, dataset=dataset, recipe=recipe, tokenizer=tokenizer, output_dir=cfg.system.save_dir, max_seq_length=cfg.data.get("max_seq_length", 384), num_calibration_samples=cfg.data.get("num_calibration_samples", 512), splits="calibration")
-    model.save_pretrained(cfg.system.save_dir, save_compressed=True)
-    copy_rest_file(model_path, cfg.system.save_dir)
+class Compressor:
+    def __init__(self, cfg, model=None, dataset=None):
+        self.cfg = cfg
+        self.model = model
+        self.dataset = dataset
+
+    def compress(self):
+        self.tokenizer = None
+        self.model_path = None
+        if self.model is None:
+            assert self.cfg.model is not None
+            self.model_path = self.cfg.model.pop("model_path")
+        if self.cfg.data.tokenzier_args is not None:
+            self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.data.tokenzier_args.pop("tokenizer_path"), **self.cfg.data.tokenzier_args)
+        if self.model is None:
+            model_cls = eval(self.cfg.model.pop("model_cls"))
+            self.model = model_cls.from_pretrained(self.model_path, **self.cfg.model)
+        assert isinstance(self.model, torch.nn.Module), f"model type {type(self.model)} error, please check it"
+        compress_args = self.cfg.compress_args
+        recipes = prepare_compress_methods(compress_args)
+        for method, recipe in recipes.items():
+            for algo_args in recipe:
+                algo_args = OmegaConf.to_container(algo_args)
+                algo_args["dataset"] = self.dataset
+                algo_args["num_calibration_steps"] = self.cfg.data.get("num_calibration_steps", 384)
+                adapter = LLMCompressorAdapter(model=self.model, **algo_args)
+                ### modify model inplace
+                self.model = adapter.model
+
+        # oneshot(model=model, dataset=dataset, recipe=recipe, tokenizer=tokenizer, output_dir=cfg.system.save_dir, max_seq_length=cfg.data.get("max_seq_length", 384), num_calibration_samples=cfg.data.get("num_calibration_samples", 512), splits="calibration")
+    def save_pretrained(self, save_compressed=True):
+        modify_save_pretrained(self.model)
+        self.model.save_pretrained(self.cfg.system.save_dir, save_compressed=save_compressed)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(self.cfg.system.save_dir)
+        copy_rest_file(self.model_path, self.cfg.system.save_dir)
+
+    @torch.no_grad()
+    def convert(self, model):
+        for name, mod in model.named_modules():
+            if hasattr(mod, "weight_scale"):
+                wrapper = RTNWrapper(name, mod, enable_fake_quant=True)
+                set_layer(name, wrapper, model)
+        return model
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
@@ -90,4 +114,4 @@
     args = parser.parse_args()
 
     cfg = prepare_config(args.config_path)
-    compress(cfg)
+    Compressor(cfg).compress()
\ No newline at end of file
diff --git a/flagscale/compress/compressor_emu3.py b/flagscale/compress/compressor_emu3.py
index f8b6c9c8f..11c7076f1 100644
--- a/flagscale/compress/compressor_emu3.py
+++ b/flagscale/compress/compressor_emu3.py
@@ -9,7 +9,7 @@
 from megatron.core.datasets.indexed_dataset import IndexedDataset
 from torch.utils.data import Dataset
 
-from flagscale.compress.compressor import compress, prepare_config
+from flagscale.compress.compressor import Compressor, prepare_config
 
 class CusDataset(Dataset):
     def __init__(self, ds):
@@ -59,4 +59,6 @@ def prepare_dataset(cfg):
     args = parser.parse_args()
 
     cfg = prepare_config(args.config_path)
     dataset = prepare_dataset(cfg)
-    compress(cfg, dataset=dataset)
+    cmp = Compressor(cfg, dataset=dataset)
+    cmp.compress()
+
diff --git a/flagscale/compress/compressor_llava_ov.py b/flagscale/compress/compressor_llava_ov.py
index 5bcb6a84f..497d1860f 100644
--- a/flagscale/compress/compressor_llava_ov.py
+++ b/flagscale/compress/compressor_llava_ov.py
@@ -13,7 +13,7 @@
 import torch
 from torch.utils.data import Dataset
 
-from flagscale.compress.compressor import compress, prepare_config
+from flagscale.compress.compressor import Compressor, prepare_config
 import transformers
 from llava.model.builder import load_pretrained_model
 from llava.train.train import make_supervised_data_module, DataArguments, LLaVATrainer
@@ -67,7 +67,7 @@ def prepare_dataset(cfg, model, tokenizer):
     if cfg.data.data_path is None:
         return None
     new_data_args = copy.deepcopy(cfg.data)
-    new_data_args.pop("num_calibration_samples")
+    new_data_args.pop("num_calibration_steps")
     new_data_args.pop("max_seq_length")
     new_data_args.pop("tokenzier_args")
 
@@ -117,4 +117,5 @@
     cfg = prepare_config(args.config_path)
     model, tokenizer = prepare_model(cfg)
     dataset = prepare_dataset(cfg, model, tokenizer)
-    compress(cfg, dataset=dataset, model=model)
+    com = Compressor(cfg, dataset=dataset, model=model)
+    com.compress()
diff --git a/flagscale/compress/observers/__init__.py b/flagscale/compress/observers/__init__.py
new file mode 100644
index 000000000..3b7f7ffd5
--- /dev/null
+++ b/flagscale/compress/observers/__init__.py
@@ -0,0 +1,4 @@
+from .base import Observer
+from .minmax import MinMaxObserver
+from .mse import MSEObserver
+
diff --git a/flagscale/compress/observers/base.py b/flagscale/compress/observers/base.py
new file mode 100644
index 000000000..faf18ea5c
--- /dev/null
+++ b/flagscale/compress/observers/base.py
@@ -0,0 +1 @@
+from llmcompressor.observers.base import Observer
diff --git a/flagscale/compress/observers/minmax.py b/flagscale/compress/observers/minmax.py
new file mode 100644
index 000000000..adb4449f6
--- /dev/null
+++ b/flagscale/compress/observers/minmax.py
@@ -0,0 +1,49 @@
+from typing import Any, Iterable, Optional, Tuple, Union
+import torch
+from torch import FloatTensor, IntTensor, Tensor
+from compressed_tensors.quantization.utils import calculate_qparams
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from .base import Observer
+from compressed_tensors.registry.registry import _REGISTRY, _ALIAS_REGISTRY
+
+_REGISTRY[Observer]["moving_minmax"] = _REGISTRY[Observer]["minmax"]
+_REGISTRY[Observer].pop("minmax")
+_ALIAS_REGISTRY[Observer]["moving_minmax"] = "moving_minmax"
+_ALIAS_REGISTRY[Observer].pop("minmax")
+
+@Observer.register("minmax")
+class MinMaxObserver(Observer):
+    def __init__(
+        self, quantization_args: QuantizationArgs
+    ):
+        super().__init__(quantization_args=quantization_args)
+
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        return calculate_qparams(
+            min_val, max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self,
+        observed,
+        dim: Union[int, Iterable[int]],
+    ):
+        if isinstance(dim, int):
+            dim = [dim]
+        dim = set(dim)
+
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims
+        )
+
\ No newline at end of file
diff --git a/flagscale/compress/observers/mse.py b/flagscale/compress/observers/mse.py
new file mode 100644
index 000000000..81dbecc2a
--- /dev/null
+++ b/flagscale/compress/observers/mse.py
@@ -0,0 +1,103 @@
+from typing import Any, Iterable, Optional, Tuple, Union
+import torch
+from torch import FloatTensor, IntTensor, Tensor
+from compressed_tensors.quantization.utils import calculate_qparams
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.lifecycle import fake_quantize
+from .base import Observer
+from compressed_tensors.registry.registry import _REGISTRY, _ALIAS_REGISTRY
+
+_REGISTRY[Observer]["moving_mse"] = _REGISTRY[Observer]["mse"]
+_REGISTRY[Observer].pop("mse")
+_ALIAS_REGISTRY[Observer]["moving_mse"] = "moving_mse"
+_ALIAS_REGISTRY[Observer].pop("mse")
+
+print(_REGISTRY[Observer])
+@Observer.register("mse")
+class MSEObserver(Observer):
+    def __init__(
+        self,
+        quantization_args: QuantizationArgs,
+        grid: float = 100.0,
+        maxshrink: float = 0.80,
+        norm: float = 2.4,
+    ):
+        super().__init__(quantization_args=quantization_args)
+        self.grid = grid
+        self.maxshrink = maxshrink
+        self.norm = norm
+
+    def calculate_mse_min_max(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ):
+        if not reduce_dims:
+            absolute_min_val, absolute_max_val = torch.aminmax(observed)
+        else:
+            absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        best = torch.full_like(
+            absolute_min_val, torch.finfo(absolute_min_val.dtype).max
+        )
+        min_val = torch.ones_like(absolute_min_val)
+        max_val = torch.zeros_like(absolute_max_val)
+        for i in range(int(self.maxshrink * self.grid)):
+            p = 1 - i / self.grid
+            shrinked_min_val = p * absolute_min_val
+            shrinked_max_val = p * absolute_max_val
+
+            candidate_scales, candidate_zero_points = calculate_qparams(
+                shrinked_min_val, shrinked_max_val, self.quantization_args
+            )
+            q = fake_quantize(
+                observed,
+                candidate_scales,
+                candidate_zero_points,
+                self.quantization_args,
+            )
+
+            q -= observed
+            q.abs_()
+            q.pow_(self.norm)
+            if not reduce_dims:
+                err = torch.sum(q)
+            else:
+                err = torch.sum(q, reduce_dims, keepdims=True)
+
+            tmp = err < best
+            if torch.any(tmp):
+                best[tmp] = err[tmp]
+                min_val[tmp] = shrinked_min_val[tmp]
+                max_val[tmp] = shrinked_max_val[tmp]
+        return min_val, max_val
+
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
+        min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
+
+        return calculate_qparams(
+            min_val, max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self,
+        observed,
+        dim: Union[int, Iterable[int]],
+        tensor_id: Optional[Any] = None,
+    ):
+        if isinstance(dim, int):
+            dim = [dim]
+        dim = set(dim)
+
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims
+        )
+
+