diff --git a/QEfficient/cloud/export.py b/QEfficient/cloud/export.py
index d2e3f66fc..a5e0b6e19 100644
--- a/QEfficient/cloud/export.py
+++ b/QEfficient/cloud/export.py
@@ -11,18 +11,20 @@
 
 from QEfficient.base.common import QEFFCommonLoader
 from QEfficient.utils import check_and_assign_cache_dir
+from QEfficient.utils.custom_yaml import generate_custom_io
 from QEfficient.utils.logging_utils import logger
 
 # Specifically for Docker images.
 ROOT_DIR = os.path.dirname(os.path.abspath(""))
 
 
-def get_onnx_model_path(
+def get_onnx_path_and_setup_customIO(
     model_name: str,
     cache_dir: Optional[str] = None,
     hf_token: Optional[str] = None,
     full_batch_size: Optional[int] = None,
     local_model_dir: Optional[str] = None,
+    mxint8_kv_cache: Optional[bool] = False,
 ):
     """
     Exports the PyTorch model to ONNX format if a pre-exported file is not found,
@@ -63,6 +65,9 @@ def get_onnx_model_path(
     )
     onnx_model_path = qeff_model.export()
     logger.info(f"Generated onnx_path: {onnx_model_path}")
+
+    # Generate the custom IO config for the compile step.
+    generate_custom_io(qeff_model, mxint8_kv_cache=mxint8_kv_cache)
     return onnx_model_path
 
 
@@ -72,13 +77,14 @@ def main(
     hf_token: Optional[str] = None,
     local_model_dir: Optional[str] = None,
     full_batch_size: Optional[int] = None,
+    mxint8_kv_cache: Optional[bool] = False,
 ) -> None:
     """
     Main function for the QEfficient ONNX export CLI application.
 
     This function serves as the entry point for exporting a PyTorch model, loaded
     via QEFFCommonLoader, to the ONNX format. It prepares the necessary
-    paths and calls `get_onnx_model_path`.
+    paths and calls `get_onnx_path_and_setup_customIO`.
 
     Parameters
     ----------
@@ -106,12 +112,13 @@ def main(
     """
     cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
 
-    get_onnx_model_path(
+    get_onnx_path_and_setup_customIO(
         model_name=model_name,
         cache_dir=cache_dir,
         hf_token=hf_token,
         full_batch_size=full_batch_size,
         local_model_dir=local_model_dir,
+        mxint8_kv_cache=mxint8_kv_cache,
     )
 
 
@@ -137,5 +144,12 @@ def main(
         default=None,
         help="Set full batch size to enable continuous batching mode, default is None",
     )
+    parser.add_argument(
+        "--mxint8_kv_cache",
+        "--mxint8-kv-cache",
+        action="store_true",
+        required=False,
+        help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False",
+    )
     args = parser.parse_args()
     main(**args.__dict__)
diff --git a/QEfficient/compile/compile_helper.py b/QEfficient/compile/compile_helper.py
index 9cb2e1062..5de21f876 100644
--- a/QEfficient/compile/compile_helper.py
+++ b/QEfficient/compile/compile_helper.py
@@ -270,6 +270,7 @@ def compile(
 
     This method will be removed soon; use `QEFFAutoModelForCausalLM.compile` instead.
     """
+
     if full_batch_size and batch_size != 1:
         raise ValueError("Only either batch_size or full_batch_size should be greater than one")
 
@@ -284,11 +285,20 @@ def compile(
         full_batch_size=full_batch_size,
     )
 
-    # Select the customIO config based on the mx flag.
-    custom_io_file_name = "custom_io_int8.yaml" if mxint8 else "custom_io_fp16.yaml"
+    dtype_suffix = "int8" if mxint8 else "fp16"
+    source_path = f"./custom_io_{dtype_suffix}.yaml"
+    destination_path = os.path.join(os.path.dirname(qpc_path), f"custom_io_{dtype_suffix}.yaml")
+
+    # Move the custom YAML file to the cache/qeff_model directory
+    try:
+        shutil.move(source_path, destination_path)
+        print(f"Successfully moved '{source_path}' to '{destination_path}'.")
+    except Exception as e:
+        print(f"Error while moving file '{source_path}': {e}")
+
+    custom_io_file_name = f"custom_io_{dtype_suffix}.yaml"
 
     if custom_io_file_path is None:
-        custom_io_file_path = os.path.join(os.path.dirname(onnx_path), custom_io_file_name)
+        custom_io_file_path = os.path.join(os.path.dirname(qpc_path), custom_io_file_name)
 
     if not os.path.isfile(custom_io_file_path):
         raise FileNotFoundError(
diff --git a/QEfficient/utils/custom_yaml.py b/QEfficient/utils/custom_yaml.py
new file mode 100644
index 000000000..2adb656b5
--- /dev/null
+++ b/QEfficient/utils/custom_yaml.py
@@ -0,0 +1,213 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+import warnings
+from pathlib import Path
+
+
+class CustomIOGenerator:
+    """
+    Abstract base class for generating custom IO mappings for different model types.
+
+    Args:
+        model (object): The model instance for which IO mappings are to be generated.
+        cache_dir (str): Directory path where the generated YAML files will be saved.
+        mxint8_kv_cache (bool): If True, use 'mxint8' precision for KV cache; otherwise, use 'float16'.
+    """
+
+    def __init__(self, model, cache_dir=".", mxint8_kv_cache=False):
+        self.model = model
+        self.cache_dir = Path(cache_dir)
+        self.kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
+        self.dtype_suffix = "int8" if mxint8_kv_cache else "fp16"
+
+    def dump(self, custom_io: dict, suffix: str):
+        """
+        Writes the custom IO mapping to a YAML file.
+
+        Args:
+            custom_io (dict): Dictionary containing IO names and their precision types.
+            suffix (str): Suffix to append to the output filename.
+        """
+        custom_io_yaml = self.cache_dir / f"custom_io_{suffix}.yaml"
+        with open(custom_io_yaml, "w") as fp:
+            for io_name, dtype in custom_io.items():
+                fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n")
+
+    def generate(self) -> dict:
+        """
+        Abstract method to generate custom IO mappings.
+
+        Returns:
+            dict: A dictionary of IO names and their precision types.
+
+        Raises:
+            NotImplementedError: Must be implemented by subclasses.
+        """
+        raise NotImplementedError("Subclasses must implement this method")
+
+
+class CausalLMIOGenerator(CustomIOGenerator):
+    """
+    IO generator for causal language models.
+    """
+
+    def generate(self) -> dict:
+        """
+        Generates IO mappings for past key/value states in causal language models.
+
+        Returns:
+            dict: Mapping of IO names to precision types.
+        """
+        custom_io = {}
+        num_layers = getattr(self.model, "num_layers", 12)
+        for suffix in ["", "_RetainedState"]:
+            for i in range(num_layers):
+                for kv in ["key", "value"]:
+                    custom_io[f"past_{kv}.{i}{suffix}"] = self.kv_cache_dtype
+        self.dump(custom_io, self.dtype_suffix)
+        return custom_io
+
+
+class DualQPCIOGenerator(CustomIOGenerator):
+    """
+    IO generator for dual QPC models (e.g., vision-language models).
+    """
+
+    def generate(self) -> dict:
+        """
+        Generates IO mappings for both vision and language components.
+
+        Returns:
+            dict: Combined mapping of IO names to precision types for vision and language outputs.
+        """
+        output_names = self.model.model.get_output_names()
+        custom_io_vision = {
+            name: self.kv_cache_dtype if name.startswith("past_") else "float16"
+            for name in output_names.get("vision", [])
+        }
+
+        custom_io_lang = {}
+        for name in output_names.get("lang", []):
+            if name.endswith("_RetainedState"):
+                base = name[: -len("_RetainedState")]
+                dtype = "float16" if "vision_embeds" in name else self.kv_cache_dtype
+                custom_io_lang[base] = dtype
+                custom_io_lang[name] = dtype
+
+        self.dump(custom_io_vision, f"{self.dtype_suffix}_vision")
+        self.dump(custom_io_lang, f"{self.dtype_suffix}_lang")
+        warnings.warn(f"Unsupported model class via CLI: {type(self.model).__name__}", UserWarning)
+        return {**custom_io_vision, **custom_io_lang}
+
+
+class SingleQPCIOGenerator(CustomIOGenerator):
+    """
+    IO generator for single QPC models.
+    """
+
+    def generate(self) -> dict:
+        """
+        Generates IO mappings for retained states in single QPC models.
+
+        Returns:
+            dict: Mapping of IO names to precision types.
+        """
+        output_names = self.model.model.get_output_names()
+        custom_io = {}
+        for name in output_names:
+            if name.endswith("_RetainedState"):
+                base = name[: -len("_RetainedState")]
+                dtype = "float16" if "pixel_values" in name else self.kv_cache_dtype
+                custom_io[base] = dtype
+                custom_io[name] = dtype
+        self.dump(custom_io, self.dtype_suffix)
+        return custom_io
+
+
+class SpeechSeq2SeqIOGenerator(CustomIOGenerator):
+    """
+    IO generator for speech sequence-to-sequence models.
+    """
+
+    def generate(self) -> dict:
+        """
+        Generates IO mappings for input features and retained states in speech models.
+
+        Returns:
+            dict: Mapping of IO names to precision types.
+        """
+        output_names = self.model.model.get_output_names()
+        custom_io = {"input_features": self.kv_cache_dtype}
+        for name in output_names:
+            if name.endswith("_RetainedState"):
+                base = name[: -len("_RetainedState")]
+                custom_io[base] = self.kv_cache_dtype
+                custom_io[name] = self.kv_cache_dtype
+        self.dump(custom_io, self.dtype_suffix)
+        return custom_io
+
+
+class UnsupportedModelIOGenerator(CustomIOGenerator):
+    """
+    Fallback IO generator for unsupported model types.
+    """
+
+    def generate(self) -> dict:
+        """
+        Emits a warning for unsupported model types.
+
+        Returns:
+            dict: Empty dictionary.
+        """
+        warnings.warn(f"Unsupported model class: {type(self.model).__name__}", UserWarning)
+        return {}
+
+
+class CustomIOFactory:
+    """
+    Factory class to instantiate the appropriate IO generator based on model type.
+    """
+
+    @staticmethod
+    def get_generator(model, cache_dir=".", mxint8_kv_cache=False) -> CustomIOGenerator:
+        """
+        Returns the appropriate IO generator instance for the given model.
+
+        Args:
+            model (object): The model instance.
+            cache_dir (str): Directory to store YAML files.
+            mxint8_kv_cache (bool): Flag to use 'mxint8' precision.
+
+        Returns:
+            CustomIOGenerator: An instance of the appropriate subclass.
+        """
+        model_class_name = type(model).__name__
+        mapping = {
+            "QEFFAutoModelForCausalLM": CausalLMIOGenerator,
+            "_QEFFAutoModelForImageTextToTextDualQPC": DualQPCIOGenerator,
+            "_QEFFAutoModelForImageTextToTextSingleQPC": SingleQPCIOGenerator,
+            "QEFFAutoModelForSpeechSeq2Seq": SpeechSeq2SeqIOGenerator,
+        }
+        generator_class = mapping.get(model_class_name, UnsupportedModelIOGenerator)
+        return generator_class(model, cache_dir, mxint8_kv_cache)
+
+
+def generate_custom_io(qeff_model, cache_dir=".", mxint8_kv_cache=False) -> dict:
+    """
+    Generates and returns custom IO mappings for the given QEFF model.
+
+    Args:
+        qeff_model (object): The model instance.
+        cache_dir (str): Directory to store YAML files.
+        mxint8_kv_cache (bool): Flag to use 'mxint8' precision.
+
+    Returns:
+        dict: Custom IO mapping generated by the appropriate generator.
+    """
+    generator = CustomIOFactory.get_generator(qeff_model, cache_dir, mxint8_kv_cache)
+    return generator.generate()
diff --git a/examples/cpp_execution/text_inference_using_cpp.py b/examples/cpp_execution/text_inference_using_cpp.py
index 9b0d59c73..072f2c57c 100644
--- a/examples/cpp_execution/text_inference_using_cpp.py
+++ b/examples/cpp_execution/text_inference_using_cpp.py
@@ -14,7 +14,7 @@
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 import QEfficient
-from QEfficient.cloud.export import get_onnx_model_path
+from QEfficient.cloud.export import get_onnx_path_and_setup_customIO
 from QEfficient.generation.text_generation_inference import fix_prompts, get_compilation_dims, get_input_prompts
 from QEfficient.utils import check_and_assign_cache_dir, get_qpc_dir_path, load_hf_tokenizer, qpc_exists
 from QEfficient.utils.logging_utils import logger
@@ -103,7 +103,7 @@ def main(
         logger.info(f"Pre-compiled qpc found at {qpc_dir_path}! Executing with given prompt")
     else:
         # Handle onnx model generation
-        onnx_model_path = get_onnx_model_path(
+        onnx_model_path = get_onnx_path_and_setup_customIO(
             model_name, cache_dir, tokenizer, hf_token, local_model_dir, full_batch_size
         )
         _ = QEfficient.compile(
diff --git a/tests/cloud/test_export_compile_execute.py b/tests/cloud/test_export_compile_execute.py
index 7cac59da7..f1c80a6b0 100644
--- a/tests/cloud/test_export_compile_execute.py
+++ b/tests/cloud/test_export_compile_execute.py
@@ -18,7 +18,7 @@
 
 def check_export_compile_execute(mocker, model_name, full_batch_size=None, enable_qnn=False):
     check_and_assign_cache_dir_spy = mocker.spy(QEfficient.cloud.export, "check_and_assign_cache_dir")
-    get_onnx_model_path_spy = mocker.spy(QEfficient.cloud.export, "get_onnx_model_path")
+    get_onnx_path_and_setup_customIO_spy = mocker.spy(QEfficient.cloud.export, "get_onnx_path_and_setup_customIO")
     load_hf_tokenizer_spy = mocker.spy(QEfficient.cloud.execute, "load_hf_tokenizer")
     cloud_ai_100_exec_kv_spy = mocker.spy(QEfficient.cloud.execute, "cloud_ai_100_exec_kv")
@@ -29,9 +29,9 @@ def check_export_compile_execute(mocker, model_name, full_batch_size=None, enable_qnn=False):
     )
 
     check_and_assign_cache_dir_spy.assert_called_once()
-    get_onnx_model_path_spy.assert_called_once()
+    get_onnx_path_and_setup_customIO_spy.assert_called_once()
 
-    onnx_model_path = get_onnx_model_path_spy.spy_return
+    onnx_model_path = get_onnx_path_and_setup_customIO_spy.spy_return
     assert os.path.isfile(onnx_model_path)
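To make the new export-side flow concrete, here is a minimal sketch of the generator in isolation, assuming this patch is applied. `DummyModel` is a hypothetical stand-in: `CausalLMIOGenerator` only reads a `num_layers` attribute from the model it is handed, so any object exposing that attribute will do.

```python
# Sketch only: exercises CausalLMIOGenerator directly, bypassing the factory
# (CustomIOFactory dispatches on the QEFF class name, which DummyModel lacks).
from QEfficient.utils.custom_yaml import CausalLMIOGenerator


class DummyModel:
    num_layers = 2  # hypothetical stand-in for a QEFFAutoModelForCausalLM


generator = CausalLMIOGenerator(DummyModel(), cache_dir=".", mxint8_kv_cache=True)
custom_io = generator.generate()

# Keys cover both the graph inputs and their *_RetainedState outputs.
assert custom_io["past_key.0"] == "mxint8"
assert custom_io["past_value.1_RetainedState"] == "mxint8"

# generate() also wrote ./custom_io_int8.yaml, with one
# "- IOName: ... / Precision: ..." entry per key; compile() later moves
# that file next to the QPC directory before invoking the compiler.
```

From the CLI, the same file should be produced by passing the new flag to the export app, e.g. `python -m QEfficient.cloud.export --model_name gpt2 --mxint8_kv_cache` (model name illustrative).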