156 changes: 131 additions & 25 deletions fms_mo/run_quant.py
@@ -27,10 +27,15 @@

# Standard
import logging
import os
import sys
import time
import traceback

# Third Party
from datasets import load_from_disk
from huggingface_hub.errors import HFValidationError
from torch.cuda import OutOfMemoryError
from transformers import AutoTokenizer
import transformers

@@ -44,9 +49,14 @@
ModelArguments,
OptArguments,
)
from fms_mo.utils.config_utils import get_json_config
from fms_mo.utils.error_logging import (
INTERNAL_ERROR_EXIT_CODE,
USER_ERROR_EXIT_CODE,
write_termination_log,
)
from fms_mo.utils.import_utils import available_packages

logger = logging.Logger("fms_mo.main")
from fms_mo.utils.logging_utils import set_log_level


def quantize(
@@ -71,6 +81,8 @@ def quantize(
output_dir (str) Output directory to write to
"""

logger = set_log_level(opt_args.log_level, "fms_mo.quantize")

logger.info(f"{fms_mo_args}\n{opt_args.quant_method}\n")

if opt_args.quant_method == "gptq":
@@ -120,6 +132,8 @@ def run_gptq(model_args, data_args, opt_args, gptq_args):
# Local
from fms_mo.utils.custom_gptq_models import custom_gptq_classes

logger = set_log_level(opt_args.log_level, "fms_mo.run_gptq")

quantize_config = BaseQuantizeConfig(
bits=gptq_args.bits,
group_size=gptq_args.group_size,
@@ -179,6 +193,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

logger = set_log_level(opt_args.log_level, "fms_mo.run_fp8")

model = SparseAutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path, torch_dtype=model_args.torch_dtype
)
@@ -205,9 +221,8 @@ def run_fp8(model_args, data_args, opt_args, fp8_args):
tokenizer.save_pretrained(opt_args.output_dir)


def main():
"""Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""

def get_parser():
"""Get the command-line argument parser."""
parser = transformers.HfArgumentParser(
dataclass_types=(
ModelArguments,
@@ -218,20 +233,53 @@ def main():
FP8Arguments,
)
)
return parser

(
model_args,
data_args,
opt_args,
fms_mo_args,
gptq_args,
fp8_args,
_,
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)

logger.debug(
"Input args parsed: \nmodel_args %s, data_args %s, opt_args %s, fms_mo_args %s, "
"gptq_args %s, fp8_args %s",
def parse_arguments(parser, json_config=None):
"""Parses arguments provided either via command-line or JSON config.

Args:
parser: argparse.ArgumentParser
Command-line argument parser.
json_config: dict[str, Any]
Dict of arguments to use for optimization.

Returns:
ModelArguments
Arguments pertaining to which model we are going to quantize.
DataArguments
Arguments pertaining to what data we are going to use for optimization and evaluation.
OptArguments
Arguments generic to optimization.
FMSMOArguments
Configuration for PTQ quantization.
GPTQArguments
Configuration for GPTQ quantization.
FP8Arguments
Configuration for FP8 quantization.
"""
if json_config:
(
model_args,
data_args,
opt_args,
fms_mo_args,
gptq_args,
fp8_args,
) = parser.parse_dict(json_config, allow_extra_keys=True)
else:
(
model_args,
data_args,
opt_args,
fms_mo_args,
gptq_args,
fp8_args,
_,
) = parser.parse_args_into_dataclasses(return_remaining_strings=True)

return (
model_args,
data_args,
opt_args,
@@ -240,14 +288,72 @@ def main():
fp8_args,
)

quantize(
model_args=model_args,
data_args=data_args,
opt_args=opt_args,
fms_mo_args=fms_mo_args,
gptq_args=gptq_args,
fp8_args=fp8_args,
)

def main():
"""Main entry point for quantize API for GPTQ, FP8 and DQ quantization techniques"""

parser = get_parser()
logger = logging.getLogger()
job_config = get_json_config()
# accept arguments via command-line or JSON
try:
(
model_args,
data_args,
opt_args,
fms_mo_args,
gptq_args,
fp8_args,
) = parse_arguments(parser, job_config)

logger = set_log_level(opt_args.log_level, __name__)

logger.debug(f"Input args parsed: \nmodel_args {model_args}, data_args {data_args}, \
opt_args {opt_args}, fms_mo_args {fms_mo_args}, gptq_args {gptq_args}, \
fp8_args {fp8_args}")
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc())
write_termination_log(
f"Exception raised during optimization. This may be a problem with your input: {e}"
)
sys.exit(USER_ERROR_EXIT_CODE)

if opt_args.output_dir:
os.makedirs(opt_args.output_dir, exist_ok=True)
logger.info("Using the output directory at %s", opt_args.output_dir)
try:
quantize(
model_args=model_args,
data_args=data_args,
opt_args=opt_args,
fms_mo_args=fms_mo_args,
gptq_args=gptq_args,
fp8_args=fp8_args,
)
except (MemoryError, OutOfMemoryError) as e:
logger.error(traceback.format_exc())
write_termination_log(f"OOM error during optimization. {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)
except FileNotFoundError as e:
logger.error(traceback.format_exc())
write_termination_log(f"Unable to load file: {e}")
sys.exit(USER_ERROR_EXIT_CODE)
except HFValidationError as e:
logger.error(traceback.format_exc())
write_termination_log(
f"There may be a problem with loading the model. Exception: {e}"
)
sys.exit(USER_ERROR_EXIT_CODE)
except (TypeError, ValueError, EnvironmentError) as e:
logger.error(traceback.format_exc())
write_termination_log(
f"Exception raised during optimization. This may be a problem with your input: {e}"
)
sys.exit(USER_ERROR_EXIT_CODE)
except Exception as e: # pylint: disable=broad-except
logger.error(traceback.format_exc())
write_termination_log(f"Unhandled exception during optimization: {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)


if __name__ == "__main__":
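As a sketch of how the new JSON-config path can be exercised end to end (the model id, paths, and values below are illustrative; the authoritative field names are the ones defined on the argument dataclasses):

# Minimal sketch: drive main() through FMS_MO_CONFIG_JSON_PATH instead of CLI flags.
# All values are illustrative; field names mirror the dataclasses parsed above.
import json
import os
import tempfile

from fms_mo.run_quant import main

job_config = {
    "model_name_or_path": "ibm-granite/granite-20b-code-base",  # hypothetical model id
    "output_dir": "/tmp/quantized-model",
    "quant_method": "gptq",
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(job_config, f)

os.environ["FMS_MO_CONFIG_JSON_PATH"] = f.name
main()  # get_json_config() reads the file and parse_dict() fills the dataclasses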
80 changes: 80 additions & 0 deletions fms_mo/utils/config_utils.py
@@ -0,0 +1,80 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import base64
import json
import os
import pickle


def update_config(config, **kwargs):
"""Updates config from key-value pairs provided through kwargs"""
if isinstance(config, (tuple, list)):
for c in config:
update_config(c, **kwargs)
else:
for k, v in kwargs.items():
if hasattr(config, k):
setattr(config, k, v)
elif "." in k:
# allow --some_config.some_param=True
config_name, param_name = k.split(".")
if type(config).__name__ == config_name:
if hasattr(config, param_name):
setattr(config, param_name, v)
else:
# In case of a specialized config we can warn the user
print(f"Warning: {config_name} does not accept parameter: {k}")


def get_json_config():
"""Parses JSON configuration if provided via environment variables
FMS_MO_CONFIG_JSON_ENV_VAR or FMS_MO_CONFIG_JSON_PATH.

FMS_MO_CONFIG_JSON_ENV_VAR is the base64 encoded JSON.
FMS_MO_CONFIG_JSON_PATH is the path to the JSON config file.

Returns: dict[str, Any], empty if no config was provided.
"""
json_env_var = os.getenv("FMS_MO_CONFIG_JSON_ENV_VAR")
json_path = os.getenv("FMS_MO_CONFIG_JSON_PATH")

# accepts either path to JSON file or encoded string config
# env var takes precedent
job_config_dict = {}
if json_env_var:
job_config_dict = txt_to_obj(json_env_var)
elif json_path:
with open(json_path, "r", encoding="utf-8") as f:
job_config_dict = json.load(f)

return job_config_dict


def txt_to_obj(txt):
"""Given encoded byte string, converts to base64 decoded dict.

Args:
txt: str
Returns: dict[str, Any]
"""
base64_bytes = txt.encode("ascii")
message_bytes = base64.b64decode(base64_bytes)
try:
# If the bytes represent JSON string
return json.loads(message_bytes)
except UnicodeDecodeError:
# Otherwise the bytes are a pickled python dictionary
return pickle.loads(message_bytes)
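For reference, a small sketch of the two channels get_json_config() accepts (the config keys are illustrative; as noted above, the encoded env var takes precedence over the file path):

# Sketch: hand a job config to get_json_config() either as base64-encoded JSON
# in FMS_MO_CONFIG_JSON_ENV_VAR or as a file path in FMS_MO_CONFIG_JSON_PATH.
import base64
import json
import os

from fms_mo.utils.config_utils import get_json_config

cfg = {"quant_method": "fp8", "output_dir": "/tmp/quantized-model"}  # illustrative keys

# Option 1: base64-encoded JSON string in the environment (takes precedence).
os.environ["FMS_MO_CONFIG_JSON_ENV_VAR"] = base64.b64encode(
    json.dumps(cfg).encode("ascii")
).decode("ascii")

# Option 2: path to a JSON file on disk (only used when the env var above is unset).
# os.environ["FMS_MO_CONFIG_JSON_PATH"] = "/path/to/job_config.json"

assert get_json_config() == cfg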
4 changes: 3 additions & 1 deletion fms_mo/utils/dq_utils.py
@@ -38,7 +38,9 @@ def config_quantize_smooth_layers(qcfg):
"granite-20b-code",
"granite-20b-code",
]
if any(model in qcfg["model"] for model in llama_architecture):
if any(model in qcfg["model"] for model in llama_architecture) or any(
model in qcfg["model_type"] for model in llama_architecture
):
qcfg["qlayer_name_pattern"] = ["model.layers."]
qcfg["scale_layers"] = ["k_proj", "v_proj", "gate_proj", "up_proj"]
qcfg["qskip_layer_name"] = []
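A hedged illustration of what the widened check buys: a checkpoint loaded from a local path carries no architecture hint in its name, but still matches through model_type (the qcfg values are made up):

# Mirrors the condition above: substring match against either "model" or "model_type".
llama_architecture = ["granite-20b-code"]  # abbreviated from the list above
qcfg = {
    "model": "/data/checkpoints/my-finetuned-model",  # hypothetical local path
    "model_type": "granite-20b-code",
}
matches = any(m in qcfg["model"] for m in llama_architecture) or any(
    m in qcfg["model_type"] for m in llama_architecture
)
assert matches  # matching on qcfg["model"] alone would have returned False here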
41 changes: 41 additions & 0 deletions fms_mo/utils/error_logging.py
@@ -0,0 +1,41 @@
# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import logging
import os

# The USER_ERROR_EXIT_CODE will be thrown when the process must exit
# as a result of a user input error. User-related errors should be
# >= 1 and <= 127 due to how some Kubernetes operators interpret them.
USER_ERROR_EXIT_CODE = 1
# The INTERNAL_ERROR_EXIT_CODE will be thrown when the process
# abnormally terminates and it is not clearly the fault of the user.
# System-level errors should be >= 128 and <= 254.
INTERNAL_ERROR_EXIT_CODE = 203


def write_termination_log(text, log_file="error.log"):
"""Writes text to termination log.

Args:
text: str
log_file: Optional[str]
"""
log_file = os.environ.get("TERMINATION_LOG_FILE", log_file)
try:
with open(log_file, "a", encoding="utf-8") as handle:
handle.write(text)
except Exception as e: # pylint: disable=broad-except
logging.warning(f"Unable to write termination log due to error {e}")