diff --git a/fms_mo/run_quant.py b/fms_mo/run_quant.py index dac90648..beb8db27 100644 --- a/fms_mo/run_quant.py +++ b/fms_mo/run_quant.py @@ -27,10 +27,15 @@ # Standard import logging +import os +import sys import time +import traceback # Third Party from datasets import load_from_disk +from huggingface_hub.errors import HFValidationError +from torch.cuda import OutOfMemoryError from transformers import AutoTokenizer import transformers @@ -44,9 +49,14 @@ ModelArguments, OptArguments, ) +from fms_mo.utils.config_utils import get_json_config +from fms_mo.utils.error_logging import ( + INTERNAL_ERROR_EXIT_CODE, + USER_ERROR_EXIT_CODE, + write_termination_log, +) from fms_mo.utils.import_utils import available_packages - -logger = logging.Logger("fms_mo.main") +from fms_mo.utils.logging_utils import set_log_level def quantize( @@ -70,6 +80,8 @@ def quantize( fp8_args (fms_mo.training_args.FP8Arguments): Parameters to use for FP8 quantization """ + logger = set_log_level(opt_args.log_level, "fms_mo.quantize") + logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n") if opt_args.quant_method == "gptq": @@ -119,6 +131,8 @@ def run_gptq(model_args, data_args, opt_args, gptq_args): # Local from fms_mo.utils.custom_gptq_models import custom_gptq_classes + logger = set_log_level(opt_args.log_level, "fms_mo.run_gptq") + quantize_config = BaseQuantizeConfig( bits=gptq_args.bits, group_size=gptq_args.group_size, @@ -178,6 +192,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args): from llmcompressor.modifiers.quantization import QuantizationModifier from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot + logger = set_log_level(opt_args.log_level, "fms_mo.run_fp8") + model = SparseAutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, torch_dtype=model_args.torch_dtype ) @@ -204,9 +220,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args): tokenizer.save_pretrained(opt_args.output_dir) -def main(): - """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques""" - +def get_parser(): + """Get the command-line argument parser.""" parser = transformers.HfArgumentParser( dataclass_types=( ModelArguments, @@ -217,20 +232,53 @@ def main(): FP8Arguments, ) ) + return parser - ( - model_args, - data_args, - opt_args, - fms_mo_args, - gptq_args, - fp8_args, - _, - ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) - logger.debug( - "Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, " - "gptq_args %s, fp8_args %s", +def parse_arguments(parser, json_config=None): + """Parses arguments provided either via command-line or JSON config. + + Args: + parser: argparse.ArgumentParser + Command-line argument parser. + json_config: dict[str, Any] + Dict of arguments to use with tuning. + + Returns: + ModelArguments + Arguments pertaining to which model we are going to quantize. + DataArguments + Arguments pertaining to what data we are going to use for optimization and evaluation. + OptArguments + Arguments generic to optimization. + FMSMOArguments + Configuration for PTQ quantization. + GPTQArguments + Configuration for GPTQ quantization. + FP8Arguments + Configuration for FP8 quantization. 
+ """ + if json_config: + ( + model_args, + data_args, + opt_args, + fms_mo_args, + gptq_args, + fp8_args, + ) = parser.parse_dict(json_config, allow_extra_keys=True) + else: + ( + model_args, + data_args, + opt_args, + fms_mo_args, + gptq_args, + fp8_args, + _, + ) = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + return ( model_args, data_args, opt_args, @@ -239,14 +287,72 @@ def main(): fp8_args, ) - quantize( - model_args=model_args, - data_args=data_args, - opt_args=opt_args, - fms_mo_args=fms_mo_args, - gptq_args=gptq_args, - fp8_args=fp8_args, - ) + +def main(): + """Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques""" + + parser = get_parser() + logger = logging.getLogger() + job_config = get_json_config() + # accept arguments via command-line or JSON + try: + ( + model_args, + data_args, + opt_args, + fms_mo_args, + gptq_args, + fp8_args, + ) = parse_arguments(parser, job_config) + + logger = set_log_level(opt_args.log_level, __name__) + + logger.debug(f"Input args parsed: \nmodel_args {model_args}, data_args {data_args}, \ + opt_args {opt_args}, fms_mo_args {fms_mo_args}, gptq_args {gptq_args}, \ + fp8_args {fp8_args}") + except Exception as e: # pylint: disable=broad-except + logger.error(traceback.format_exc()) + write_termination_log( + f"Exception raised during optimization. This may be a problem with your input: {e}" + ) + sys.exit(USER_ERROR_EXIT_CODE) + + if opt_args.output_dir: + os.makedirs(opt_args.output_dir, exist_ok=True) + logger.info("Using the output directory at %s", opt_args.output_dir) + try: + quantize( + model_args=model_args, + data_args=data_args, + opt_args=opt_args, + fms_mo_args=fms_mo_args, + gptq_args=gptq_args, + fp8_args=fp8_args, + ) + except (MemoryError, OutOfMemoryError) as e: + logger.error(traceback.format_exc()) + write_termination_log(f"OOM error during optimization. {e}") + sys.exit(INTERNAL_ERROR_EXIT_CODE) + except FileNotFoundError as e: + logger.error(traceback.format_exc()) + write_termination_log(f"Unable to load file: {e}") + sys.exit(USER_ERROR_EXIT_CODE) + except HFValidationError as e: + logger.error(traceback.format_exc()) + write_termination_log( + f"There may be a problem with loading the model. Exception: {e}" + ) + sys.exit(USER_ERROR_EXIT_CODE) + except (TypeError, ValueError, EnvironmentError) as e: + logger.error(traceback.format_exc()) + write_termination_log( + f"Exception raised during optimization. This may be a problem with your input: {e}" + ) + sys.exit(USER_ERROR_EXIT_CODE) + except Exception as e: # pylint: disable=broad-except + logger.error(traceback.format_exc()) + write_termination_log(f"Unhandled exception during optimization: {e}") + sys.exit(INTERNAL_ERROR_EXIT_CODE) if __name__ == "__main__": diff --git a/fms_mo/utils/config_utils.py b/fms_mo/utils/config_utils.py new file mode 100644 index 00000000..181024cd --- /dev/null +++ b/fms_mo/utils/config_utils.py @@ -0,0 +1,80 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import base64 +import json +import os +import pickle + + +def update_config(config, **kwargs): + """Updates config from key-value pairs provided through kwargs""" + if isinstance(config, (tuple, list)): + for c in config: + update_config(c, **kwargs) + else: + for k, v in kwargs.items(): + if hasattr(config, k): + setattr(config, k, v) + elif "." in k: + # allow --some_config.some_param=True + config_name, param_name = k.split(".") + if type(config).__name__ == config_name: + if hasattr(config, param_name): + setattr(config, param_name, v) + else: + # In case of specialized config we can warm user + print(f"Warning: {config_name} does not accept parameter: {k}") + + +def get_json_config(): + """Parses JSON configuration if provided via environment variables + FMS_MO_CONFIG_JSON_ENV_VAR or FMS_MO_CONFIG_JSON_PATH. + + FMS_MO_CONFIG_JSON_ENV_VAR is the base64 encoded JSON. + FMS_MO_CONFIG_JSON_PATH is the path to the JSON config file. + + Returns: dict or {} + """ + json_env_var = os.getenv("FMS_MO_CONFIG_JSON_ENV_VAR") + json_path = os.getenv("FMS_MO_CONFIG_JSON_PATH") + + # accepts either path to JSON file or encoded string config + # env var takes precedent + job_config_dict = {} + if json_env_var: + job_config_dict = txt_to_obj(json_env_var) + elif json_path: + with open(json_path, "r", encoding="utf-8") as f: + job_config_dict = json.load(f) + + return job_config_dict + + +def txt_to_obj(txt): + """Given encoded byte string, converts to base64 decoded dict. + + Args: + txt: str + Returns: dict[str, Any] + """ + base64_bytes = txt.encode("ascii") + message_bytes = base64.b64decode(base64_bytes) + try: + # If the bytes represent JSON string + return json.loads(message_bytes) + except UnicodeDecodeError: + # Otherwise the bytes are a pickled python dictionary + return pickle.loads(message_bytes) diff --git a/fms_mo/utils/dq_utils.py b/fms_mo/utils/dq_utils.py index d4bef436..7698ed64 100644 --- a/fms_mo/utils/dq_utils.py +++ b/fms_mo/utils/dq_utils.py @@ -38,7 +38,9 @@ def config_quantize_smooth_layers(qcfg): "granite-20b-code", "granite-20b-code", ] - if any(model in qcfg["model"] for model in llama_architecture): + if any(model in qcfg["model"] for model in llama_architecture) or any( + model in qcfg["model_type"] for model in llama_architecture + ): qcfg["qlayer_name_pattern"] = ["model.layers."] qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"] qcfg["qskip_layer_name"] = [] diff --git a/fms_mo/utils/error_logging.py b/fms_mo/utils/error_logging.py new file mode 100644 index 00000000..66171903 --- /dev/null +++ b/fms_mo/utils/error_logging.py @@ -0,0 +1,41 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import logging +import os + +# The USER_ERROR_EXIT_CODE will be thrown when the process must exit +# as result of a user input error. 
User-related errors should be +# >= 1 and <=127 due to how some kubernetes operators interpret them. +USER_ERROR_EXIT_CODE = 1 +# The INTERNAL_ERROR_EXIT_CODE will be thrown when training +# abnormally terminates, and it is not clearly fault of the user. +# System-level errors should be >= 128 and <= 254 +INTERNAL_ERROR_EXIT_CODE = 203 + + +def write_termination_log(text, log_file="error.log"): + """Writes text to termination log. + + Args: + text: str + log_file: Optional[str] + """ + log_file = os.environ.get("TERMINATION_LOG_FILE", log_file) + try: + with open(log_file, "a", encoding="utf-8") as handle: + handle.write(text) + except Exception as e: # pylint: disable=broad-except + logging.warning(f"Unable to write termination log due to error {e}") diff --git a/fms_mo/utils/logging_utils.py b/fms_mo/utils/logging_utils.py new file mode 100644 index 00000000..82daf361 --- /dev/null +++ b/fms_mo/utils/logging_utils.py @@ -0,0 +1,49 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Standard +import logging +import os + + +def set_log_level(log_level=None, logger_name=None): + """Set log level of python native logger and TF logger via argument from CLI or env variable. + + Args: + train_args + Training arguments for training model. + logger_name + Logger name with which the logger is instantiated. + + Returns: + train_args + Updated training arguments for training model. + train_logger + Logger with updated effective log level + """ + + # Clear any existing handlers if necessary + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + # Configure Python native logger + # If CLI arg is passed, assign same log level to python native logger + log_level = log_level or os.environ.get("LOG_LEVEL", "WARNING") + + logging.basicConfig( + format="%(levelname)s:%(filename)s:%(message)s", level=log_level.upper() + ) + + logger = logging.getLogger(logger_name) if logger_name else logging.getLogger() + return logger diff --git a/tests/__init__.py b/tests/__init__.py index e69de29b..094b6434 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/artifacts/configs/dummy_job_config.json b/tests/artifacts/configs/dummy_job_config.json new file mode 100644 index 00000000..5f84d231 --- /dev/null +++ b/tests/artifacts/configs/dummy_job_config.json @@ -0,0 +1,12 @@ +{ + "accelerate_launch_args": { + "main_process_port": 1234 + }, + "model_name_or_path": "Maykeye/TinyLLama-v0", + "training_data_path": "data_train", + "quant_method": "gptq", + "bits": 4, + "group_size": 128, + "output_dir": "models/Maykeye/TinyLLama-v0-GPTQ", + "log_level":"DEBUG" + } \ No newline at end of file diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py new file mode 100644 index 00000000..d0ac159a --- /dev/null +++ b/tests/artifacts/testdata/__init__.py @@ -0,0 +1,28 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helpful datasets for configuring individual unit tests.""" + +# Standard +import os + +### Constants used for data +MODEL_NAME = "Maykeye/TinyLLama-v0" +DATA_DIR = os.path.join(os.path.dirname(__file__)) + +WIKITEXT_TOKENIZED_DATA_JSON = os.path.join( + DATA_DIR, "wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048" +) +EMPTY_DATA = os.path.join(DATA_DIR, "empty_data.json") +MALFORMATTED_DATA = os.path.join(DATA_DIR, "malformatted_data.json") diff --git a/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/data-00000-of-00001.arrow b/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/data-00000-of-00001.arrow new file mode 100644 index 00000000..de2c6a62 Binary files /dev/null and b/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/data-00000-of-00001.arrow differ diff --git a/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/dataset_info.json b/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/dataset_info.json new file mode 100644 index 00000000..8f91ccaa --- /dev/null +++ b/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/dataset_info.json @@ -0,0 +1,22 @@ +{ + "citation": "", + "description": "", + "features": { + "input_ids": { + "feature": { + "dtype": "int32", + "_type": "Value" + }, + "_type": "Sequence" + }, + "attention_mask": { + "feature": { + "dtype": "int8", + "_type": "Value" + }, + "_type": "Sequence" + } + }, + "homepage": "", + "license": "" +} \ No newline at end of file diff --git a/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/state.json b/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/state.json new file mode 100644 index 00000000..caaeddab --- /dev/null +++ b/tests/artifacts/testdata/wiki_maykeye_tinyllama_v0_numsamp2_seqlen2048/state.json @@ -0,0 +1,13 @@ +{ + "_data_files": [ + { + "filename": "data-00000-of-00001.arrow" + } + ], + "_fingerprint": "27301518aea394cd", + "_format_columns": null, + "_format_kwargs": {}, + "_format_type": null, + "_output_all_columns": false, + "_split": null +} \ No newline at end of file diff --git a/tests/models/__init__.py 
b/tests/models/__init__.py index e69de29b..094b6434 100644 --- a/tests/models/__init__.py +++ b/tests/models/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/test_run_quant.py b/tests/test_run_quant.py new file mode 100644 index 00000000..74cfd09a --- /dev/null +++ b/tests/test_run_quant.py @@ -0,0 +1,141 @@ +# Copyright The FMS Model Optimizer Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for run_quant.py""" + +# Standard +import copy +import json +import os + +# Third Party +import pytest +import torch + +# Local +from fms_mo.run_quant import get_parser, parse_arguments, quantize +from fms_mo.training_args import ( + DataArguments, + FMSMOArguments, + FP8Arguments, + GPTQArguments, + ModelArguments, + OptArguments, +) +from tests.artifacts.testdata import MODEL_NAME, WIKITEXT_TOKENIZED_DATA_JSON + +MODEL_ARGS = ModelArguments(model_name_or_path=MODEL_NAME, torch_dtype="float16") +DATA_ARGS = DataArguments( + training_data_path=WIKITEXT_TOKENIZED_DATA_JSON, +) +OPT_ARGS = OptArguments(quant_method="dq", output_dir="tmp") +GPTQ_ARGS = GPTQArguments( + bits=4, + group_size=64, +) +FP8_ARGS = FP8Arguments() +DQ_ARGS = FMSMOArguments( + nbits_w=8, + nbits_a=8, + nbits_kvcache=32, + qa_mode="fp8_e4m3_scale", + qw_mode="fp8_e4m3_scale", + qmodel_calibration_new=0, +) + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason="Only runs if GPUs are available" +) +def test_run_train_requires_output_dir(): + """Check fails when output dir not provided.""" + updated_output_dir_opt_args = copy.deepcopy(OPT_ARGS) + updated_output_dir_opt_args.output_dir = None + with pytest.raises(TypeError): + quantize( + model_args=MODEL_ARGS, + data_args=DATA_ARGS, + opt_args=updated_output_dir_opt_args, + fms_mo_args=DQ_ARGS, + ) + + +def test_run_train_fails_training_data_path_not_exist(): + """Check fails when data path not found.""" + updated_data_path_args = copy.deepcopy(DATA_ARGS) + updated_data_path_args.training_data_path = "fake/path" + with pytest.raises(FileNotFoundError): + quantize( + model_args=MODEL_ARGS, + data_args=updated_data_path_args, + opt_args=OPT_ARGS, + fms_mo_args=DQ_ARGS, + ) + + +HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join( + os.path.dirname(__file__), "artifacts", "configs", "dummy_job_config.json" +) + + +@pytest.fixture(name="job_config", scope="session") +def fixture_job_config(): + """Fixture to get happy path dummy config as a dict, note that job_config dict gets 
+ modified during process training args""" + with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f: + dummy_job_config_dict = json.load(f) + return dummy_job_config_dict + + +############################# Arg Parsing Tests ############################# + + +def test_parse_arguments(job_config): + """Test that arg parser can parse json job config correctly""" + parser = get_parser() + job_config_copy = copy.deepcopy(job_config) + ( + model_args, + data_args, + opt_args, + _, + _, + _, + ) = parse_arguments(parser, job_config_copy) + assert str(model_args.torch_dtype) == "torch.bfloat16" + assert data_args.training_data_path == "data_train" + assert opt_args.output_dir == "models/Maykeye/TinyLLama-v0-GPTQ" + assert opt_args.quant_method == "gptq" + + +def test_parse_arguments_defaults(job_config): + """Test that defaults set in fms_mo/training_args.py are retained""" + parser = get_parser() + job_config_defaults = copy.deepcopy(job_config) + assert "torch_dtype" not in job_config_defaults + assert "max_seq_length" not in job_config_defaults + assert "model_revision" not in job_config_defaults + assert "nbits_kvcache" not in job_config_defaults + ( + model_args, + data_args, + _, + fms_mo_args, + _, + _, + ) = parse_arguments(parser, job_config_defaults) + assert str(model_args.torch_dtype) == "torch.bfloat16" + assert model_args.model_revision == "main" + assert data_args.max_seq_length == 2048 + assert fms_mo_args.nbits_kvcache == 32
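Usage sketch (illustrative, not part of the patch): with this change the quantizer can take its arguments from a JSON job config supplied through the environment instead of CLI flags, with FMS_MO_CONFIG_JSON_ENV_VAR (base64-encoded JSON) taking precedence over FMS_MO_CONFIG_JSON_PATH (path to a JSON file). The minimal example below assumes the fms_mo package is installed and runnable as a module (python -m fms_mo.run_quant); the config values and file paths are placeholders mirroring the dummy test config, not verified settings.

# Standard
import json
import os
import subprocess
import sys

# Build a job config equivalent to the usual command-line flags (placeholder values).
job_config = {
    "model_name_or_path": "Maykeye/TinyLLama-v0",
    "training_data_path": "data_train",
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "output_dir": "models/Maykeye/TinyLLama-v0-GPTQ",
    "log_level": "DEBUG",
}
with open("job_config.json", "w", encoding="utf-8") as f:
    json.dump(job_config, f)

# Point the entry point at the JSON file; get_json_config() picks it up from the
# environment and parse_arguments() feeds it through HfArgumentParser.parse_dict().
env = dict(os.environ, FMS_MO_CONFIG_JSON_PATH="job_config.json")
subprocess.run([sys.executable, "-m", "fms_mo.run_quant"], env=env, check=True)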