diff --git a/.gitignore b/.gitignore
index 4447f95c..a42eb8f4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,6 +42,6 @@ error.log
 
 # Files generated from running examples
 fms_mo.log
-data_train/
-data_test/
+data*_train/
+data*_test/
 act_scales/
diff --git a/.spellcheck-en-custom.txt b/.spellcheck-en-custom.txt
index eb138cfa..f2d09f64 100644
--- a/.spellcheck-en-custom.txt
+++ b/.spellcheck-en-custom.txt
@@ -1,4 +1,5 @@
 activations
+acc
 ADR
 Args
 AutoGPTQ
@@ -67,6 +68,7 @@ NLP
 Nouterloop
 Nvidia
 Nvidia's
+openai
 orchestrator
 param
 pre
@@ -99,6 +101,8 @@ SmoothQuant
 socio
 sparsification
 SQuAD
+stderr
+Stderr
 straightforward
 tokenization
 tokenized
diff --git a/examples/FP8_QUANT/README.md b/examples/FP8_QUANT/README.md
index 7ec40c74..976ea3f2 100644
--- a/examples/FP8_QUANT/README.md
+++ b/examples/FP8_QUANT/README.md
@@ -73,20 +73,18 @@ This end-to-end example utilizes the common set of interfaces provided by `fms_m
 ## Example Test Results
 
 - BF16 (not quantized) LLAMA3-8B model.
-  ``` bash
-  | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
-  |--------------|------:|------|-----:|----------|---|-----:|---|-----:|
-  |lambada_openai| 1|none | 5|acc |↑ |0.7120|± |0.0287|
-  | | |none | 5|perplexity|↓ |3.8683|± |0.3716|
-  ```
+
+| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
+|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
+|lambada_openai| 1|none | 5|acc |↑ |0.7120|± |0.0287|
+| | |none | 5|perplexity|↓ |3.8683|± |0.3716|
 
 - FP8 quantized LLAMA3-8B model.
-  ``` bash
-  | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
-  |--------------|------:|------|-----:|----------|---|-----:|---|-----:|
-  |lambada_openai| 1|none | 5|acc |↑ |0.7160|± |0.0286|
-  | | |none | 5|perplexity|↓ |3.8915|± |0.3727|
-  ```
+
+| Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr|
+|--------------|------:|------|-----:|----------|---|-----:|---|-----:|
+|lambada_openai| 1|none | 5|acc |↑ |0.7160|± |0.0286|
+| | |none | 5|perplexity|↓ |3.8915|± |0.3727|
 
 ## Code Walk-through
 
diff --git a/fms_mo/dq.py b/fms_mo/dq.py
index cbb1ef03..5bf0a288 100644
--- a/fms_mo/dq.py
+++ b/fms_mo/dq.py
@@ -21,7 +21,6 @@
 # Standard
 from pathlib import Path
 import logging
-import os
 
 # Third Party
 from datasets import load_from_disk
@@ -114,7 +113,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         revision="main",
         use_auth_token=True if model_args.use_auth_token else None,
         torch_dtype=torch_dtype,
-        low_cpu_mem_usage=False,
+        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
+        device_map="auto" if model_args.low_cpu_mem_usage else None,
     )
 
     embedding_size = model.get_input_embeddings().weight.shape[0]
@@ -125,7 +125,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")
     qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
-    # for models that cannot fit in 1 GPU, keep it in CPU and use block-wise calibration.
+    # for models that cannot fit in 1 GPU, keep it on CPU and use block-wise calibration.
+    # or leverage HF's device_map="auto".
     total_gpu_memory = 1e-5
     if torch.cuda.is_available():
         total_gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
@@ -143,7 +144,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     qcfg["large_model"] = any(
         name in model_args.model_name_or_path for name in known_large_models
     ) or (gpu_mem_util_per > 0.7)
-    dev = "cpu" if qcfg["large_model"] else "cuda:0"
+    dev = "cpu" if qcfg["large_model"] else "cuda"
+    model.to(dev)
 
     if hasattr(model.config, "model_type"):
         qcfg["model_type"] = model.config.model_type
@@ -180,23 +182,27 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         batch_size=1,
     )
 
-    # For loading or creating smoothquant scale.
-    act_scale_directory = "./act_scales"
-    if not os.path.exists(act_scale_directory):
-        os.makedirs(act_scale_directory)
+    # For loading or creating smoothquant scales. Sometimes we may include scales in ckpt as well.
+    scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
+    if qcfg.get("act_scale_path", None):
+        # user provided a scale file (or a dir)
+        scale_file_or_dir = Path(qcfg["act_scale_path"])
+        if scale_file_or_dir.is_dir():
+            scale_file = scale_file_or_dir / f"{qcfg['model'].replace('/', '-')}.pt"
+        elif scale_file_or_dir.is_file():
+            scale_file = scale_file_or_dir
 
-    if qcfg["act_scale_path"] is not None:
-        act_scales = torch.load(qcfg["act_scale_path"], map_location="cpu")
+    if not scale_file.parent.exists():
+        scale_file.parent.mkdir(exist_ok=False)
+
+    if scale_file.exists():
+        act_scales = torch.load(scale_file, map_location=getattr(model, "device", dev))
     else:
         logger.info("Generate activation scales")
         if qcfg["large_model"]:
            act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
         else:
-            if gpu_mem_util_per < 0.7:
-                model.to(dev)
-
             act_scales = get_act_scales(model, dq_dataloader, qcfg)
-        scale_file = f"{act_scale_directory}/{qcfg['model'].replace('/', '-')}" + ".pt"
         torch.save(act_scales, scale_file)
 
     qmodel_prep(
diff --git a/fms_mo/modules/linear.py b/fms_mo/modules/linear.py
index 2c2e6527..ddb0069b 100644
--- a/fms_mo/modules/linear.py
+++ b/fms_mo/modules/linear.py
@@ -93,6 +93,10 @@ def __init__(
                 Defaults to 32.
             qw_mode (str, optional): Quantization mode for weight. Defaults to None.
             **kwargs (dict): Additional keyword arguments.
+
+        Note:
+            Scales could be of higher precision than x or W; make sure qinput.dtype after
+            Qa(x/scale) is consistent with x. The same applies to W.
         """
 
         super().__init__(
@@ -275,7 +279,7 @@ def forward(self, x):
             # pylint: disable=not-callable
             return F.linear(x, self.W_fp, self.bias)
         else:
-            qinput = self.quantize_feature(x / scale)
+            qinput = self.quantize_feature(x / scale).to(x.dtype)
             # Default self.update_type == 'hard' pruning.
             if self.mask is not None:
                 pweight = HardPrune.apply(
@@ -283,7 +287,9 @@ def forward(self, x):
                 )
                 qweight = self.quantize_weight(pweight)
             else:
-                qweight = self.quantize_weight(self.weight * scale)
+                qweight = self.quantize_weight(self.weight * scale).to(
+                    self.weight.dtype
+                )
 
         qbias = self.bias
 
diff --git a/fms_mo/quant/ptq.py b/fms_mo/quant/ptq.py
index db79b364..2e192a25 100644
--- a/fms_mo/quant/ptq.py
+++ b/fms_mo/quant/ptq.py
@@ -1943,14 +1943,12 @@ def __init__(self, module, qcfg):
     def forward(self, inp, **kwargs):
         self.qcfg["cached_block0_input"].append(inp.cpu())
         self.qcfg["cache_id"] += 1
-        for k, v in kwargs.items():
-            if k == "attention_mask":
-                if v is not None:
-                    self.qcfg["cached_mask"].append(v.cpu())
-            if k == "alibi":
-                self.qcfg["cached_alibi"].append(v.cpu())
-            if k == "position_ids":
-                self.qcfg["position_ids"].append(v.cpu())
+        for kw_org, kw_qcfg in self.qcfg["kw_to_cache"].items():
+            if kw_qcfg not in self.qcfg:
+                self.qcfg[kw_qcfg] = []
+            v = kwargs.get(kw_org, None)
+            if v is not None:
+                self.qcfg[kw_qcfg].append(move_to(v, "cpu"))
 
         raise ValueError
 
@@ -1965,14 +1963,15 @@ def __init__(self, module, qcfg):
         self.module = module
 
     def forward(self, **kwargs):
-        for k, v in kwargs.items():
-            if k == "x":
-                self.qcfg["cached_block0_input"][self.qcfg["cache_id"]] = v.cpu()
-                self.qcfg["cache_id"] += 1
-            if k == "mask":
-                self.qcfg["cached_mask"] = v.cpu()
-            if k == "rel_pos_bias":
-                self.qcfg["cached_pos_bias"] = v.cpu()
+        self.qcfg["cached_block0_input"][self.qcfg["cache_id"]] = kwargs["x"].cpu()
+        self.qcfg["cache_id"] += 1
+        for kw_org, kw_qcfg in self.qcfg["kw_to_cache"].items():
+            if kw_qcfg not in self.qcfg:
+                self.qcfg[kw_qcfg] = []
+            v = kwargs.get(kw_org, None)
+            if v is not None:
+                self.qcfg[kw_qcfg].append(v.cpu())
+
         raise ValueError
 
 
@@ -2126,13 +2125,21 @@ def cache_block0_inputs(
     qcfg["cache_id"] = 0
     qcfg["cached_mask"] = []
     qcfg["cached_alibi"] = []
-    qcfg[
-        "position_ids"
-    ] = []  # latest transformers requires pos_ids to be fed into fwd()
     # move block0 to GPU and excuting fwd() until finish block0
     if "fms" in qcfg["model_type"]:
+        qcfg["kw_to_cache"] = {
+            "mask": "cached_mask",
+            "rel_pos_bias": "cached_pos_bias",
+        }
         blocks[0] = RunFMModule(blocks[0], qcfg)
     else:
+        # latest transformers requires pos_ids to be fed into fwd()
+        qcfg["kw_to_cache"] = {
+            "attention_mask": "cached_mask",
+            "alibi": "cached_alibi",
+            "position_ids": "position_ids",
+            "position_embeddings": "position_embeddings",
+        }
         blocks[0] = RunModule(blocks[0], qcfg)
 
     if isinstance(dloader, torch.utils.data.DataLoader):
@@ -2464,12 +2471,13 @@ def get_module_act_scales(m, block_idx, qcfg, act_scales):
                 alibi=qcfg["cached_alibi"][i].unsqueeze(0).to(dev),
             )[0].cpu()
         else:
+            kwargs = {
+                kw_org: move_to(qcfg[kw_qcfg][i], dev) if qcfg[kw_qcfg] != [] else None
+                for kw_org, kw_qcfg in qcfg["kw_to_cache"].items()
+            }
             qcfg["cached_input"][i] = m(
                 qcfg["cached_input"][i].to(dev),
-                attention_mask=None
-                if qcfg["cached_mask"] == []
-                else qcfg["cached_mask"][i].to(dev),
-                position_ids=qcfg["position_ids"][i].to(dev),
+                **kwargs,
             )[0].cpu()
     for h in hooks:
         h.remove()
@@ -2482,7 +2490,7 @@ def get_act_scales_1gpu(model, dloader, qcfg):
     """
     get activation blocks on 1gpu for very large models that cannot fit in 1gpu
     """
-    dev = "cuda:0"
+    dev = "cuda"
     qcfg["batch_size"] = 1
     qcfg["loader_len"] = len(dloader)
     qcfg["dtype"] = next(iter(model.parameters())).dtype
diff --git a/fms_mo/training_args.py b/fms_mo/training_args.py
index aa9bb854..1da6978f 100644
--- a/fms_mo/training_args.py
+++ b/fms_mo/training_args.py
@@ -56,6 +56,13 @@ class ModelArguments(TypeChecker):
 
     model_name_or_path: str = field(default="facebook/opt-125m")
     torch_dtype: str = field(default="bfloat16")
+    low_cpu_mem_usage: bool = field(
+        default=False,
+        metadata={
+            "help": "When set to True, leverage device_map='auto' and let HF move modules "
+            "between cpu and cuda automatically during inference."
+        },
+    )
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
diff --git a/fms_mo/utils/dq_utils.py b/fms_mo/utils/dq_utils.py
index 7698ed64..ea2546b6 100644
--- a/fms_mo/utils/dq_utils.py
+++ b/fms_mo/utils/dq_utils.py
@@ -38,50 +38,35 @@ def config_quantize_smooth_layers(qcfg):
         "granite-20b-code",
         "granite-20b-code",
     ]
-    if any(model in qcfg["model"] for model in llama_architecture) or any(
-        model in qcfg["model_type"] for model in llama_architecture
+    if (
+        any(model in qcfg["model"] for model in llama_architecture)
+        or any(model in qcfg["model_type"] for model in llama_architecture)
     ):
         qcfg["qlayer_name_pattern"] = ["model.layers."]
         qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
-        qcfg["qskip_layer_name"] = []
-        if "2-7b" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [1, 30]
-                ]
-        if "2-13b" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [3, 37]
-                ]
-        if "2-70b" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [2, 8, 79]
-                ]
-        if "3-8B" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [1, 31]
-                ]
-        if "3-70B" in qcfg["model"]:
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [3, 78, 79]
-                ]
-        if "405B-Instruct" in qcfg["model"]:  # llama3.1
-            if qcfg["qskip_large_mag_layers"]:
-                qcfg["qskip_layer_name"] = [
-                    f"model.layers.{i}.mlp.down_proj" for i in [5, 124, 125]
+        large_mag_layers = {
+            "2-7b": [1, 30],
+            "2-13b": [3, 37],
+            "2-70b": [2, 8, 79],
+            "3-8B": [1, 31],
+            "3-70B": [3, 78, 79],
+            "405B-Instruct": [5, 124, 125],  # llama3.1
+        }
+        for llama_family, layers in large_mag_layers.items():
+            if llama_family in qcfg["model"] and qcfg["qskip_large_mag_layers"]:
+                qcfg["qskip_layer_name"] += [
+                    f"model.layers.{i}.mlp.down_proj" for i in layers
                 ]
+                break
+
     elif "mixtral" in qcfg["model"]:
         qcfg["qlayer_name_pattern"] = (
             ["model.layers"] if qcfg["nbits_bmm1"] == 32 else []
         )
         qcfg["scale_layers"] = ["q_proj", "k_proj", "v_proj", "w1", "w3"]
-        qcfg["qskip_layer_name"] = []
-        for i in range(32):
-            qcfg["qskip_layer_name"].append(f"model.layers.{i}.block_sparse_moe.gate")
+        qcfg["qskip_layer_name"] += [
+            f"model.layers.{i}.block_sparse_moe.gate" for i in range(32)
+        ]
         if qcfg["qskip_large_mag_layers"]:
             qcfg["qskip_layer_name"] += [
                 f"model.layers.{i}.block_sparse_moe.experts.{j}.w2"