19 changes: 16 additions & 3 deletions QEfficient/cloud/export.py
@@ -11,18 +11,20 @@

from QEfficient.base.common import QEFFCommonLoader
from QEfficient.utils import check_and_assign_cache_dir
+from QEfficient.utils.custom_yaml import generate_custom_io
from QEfficient.utils.logging_utils import logger

# Specifically for Docker images.
ROOT_DIR = os.path.dirname(os.path.abspath(""))


-def get_onnx_model_path(
+def get_onnx_path_and_setup_customIO(
model_name: str,
cache_dir: Optional[str] = None,
hf_token: Optional[str] = None,
full_batch_size: Optional[int] = None,
local_model_dir: Optional[str] = None,
+    mxint8_kv_cache: Optional[bool] = False,
):
"""
Exports the PyTorch model to ONNX format if a pre-exported file is not found,
@@ -63,6 +65,9 @@ def get_onnx_model_path(
)
onnx_model_path = qeff_model.export()
logger.info(f"Generated onnx_path: {onnx_model_path}")

+    # Generate the custom-IO config consumed later by compile.
+    generate_custom_io(qeff_model, mxint8_kv_cache=mxint8_kv_cache)
return onnx_model_path


@@ -72,13 +77,14 @@ def main(
hf_token: Optional[str] = None,
local_model_dir: Optional[str] = None,
full_batch_size: Optional[int] = None,
+    mxint8_kv_cache: Optional[bool] = False,
) -> None:
"""
Main function for the QEfficient ONNX export CLI application.

This function serves as the entry point for exporting a PyTorch model, loaded
via QEFFCommonLoader, to the ONNX format. It prepares the necessary
-    paths and calls `get_onnx_model_path`.
+    paths and calls `get_onnx_path_and_setup_customIO`.

Parameters
----------
@@ -106,12 +112,13 @@ def main(

"""
cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
-    get_onnx_model_path(
+    get_onnx_path_and_setup_customIO(
model_name=model_name,
cache_dir=cache_dir,
hf_token=hf_token,
full_batch_size=full_batch_size,
local_model_dir=local_model_dir,
+        mxint8_kv_cache=mxint8_kv_cache,
)


@@ -137,5 +144,11 @@ def main(
default=None,
help="Set full batch size to enable continuous batching mode, default is None",
)
+    parser.add_argument(
+        "--mxint8_kv_cache",
+        "--mxint8-kv-cache",
+        action="store_true",
+        required=False,
+        help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False",
+    )
args = parser.parse_args()
main(**args.__dict__)
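A minimal usage sketch for the new flag (the checkpoint name is hypothetical; the module-style invocation mirrors the existing QEfficient CLI):

# Hypothetical CLI invocation enabling MXINT8 KV-cache compression:
#   python -m QEfficient.cloud.export --model_name gpt2 --mxint8_kv_cache
# Equivalent direct call through the entry point defined above:
from QEfficient.cloud.export import main

main(model_name="gpt2", mxint8_kv_cache=True)
# Exports the ONNX model and writes ./custom_io_int8.yaml via generate_custom_io()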
16 changes: 13 additions & 3 deletions QEfficient/compile/compile_helper.py
@@ -270,6 +270,7 @@ def compile(
This method will be removed soon; use `QEFFAutoModelForCausalLM.compile` instead.

"""

if full_batch_size and batch_size != 1:
raise ValueError("Only either batch_size or full_batch_size should be greater than one")

@@ -284,11 +285,20 @@
full_batch_size=full_batch_size,
)

+    # Select the custom-IO config based on the mxint8 flag.
-    custom_io_file_name = "custom_io_int8.yaml" if mxint8 else "custom_io_fp16.yaml"
+    dtype_suffix = "int8" if mxint8 else "fp16"
+    source_path = f"./custom_io_{dtype_suffix}.yaml"
+    destination_path = os.path.join(os.path.dirname(qpc_path), f"custom_io_{dtype_suffix}.yaml")

+    # Move the custom-IO YAML file to the cache/qeff_model directory.
+    try:
+        shutil.move(source_path, destination_path)
+        print(f"Successfully moved '{source_path}' to '{destination_path}'.")
+    except Exception as e:
+        print(f"Error while moving file '{source_path}': {e}")

+    custom_io_file_name = f"custom_io_{dtype_suffix}.yaml"
    if custom_io_file_path is None:
-        custom_io_file_path = os.path.join(os.path.dirname(onnx_path), custom_io_file_name)
+        custom_io_file_path = os.path.join(os.path.dirname(qpc_path), custom_io_file_name)

if not os.path.isfile(custom_io_file_path):
raise FileNotFoundError(
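To make the relocation concrete, a small sketch with hypothetical paths showing where the YAML written at export time lands before compilation reads it:

import os

qpc_path = "/home/user/.cache/qeff_models/gpt2/qpc_16cores_1bs/qpcs"  # hypothetical
dtype_suffix = "fp16"  # mxint8=False
source_path = f"./custom_io_{dtype_suffix}.yaml"  # written by generate_custom_io() in the CWD
destination_path = os.path.join(os.path.dirname(qpc_path), f"custom_io_{dtype_suffix}.yaml")
print(destination_path)  # /home/user/.cache/qeff_models/gpt2/qpc_16cores_1bs/custom_io_fp16.yaml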
213 changes: 213 additions & 0 deletions QEfficient/utils/custom_yaml.py
@@ -0,0 +1,213 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# ----------------------------------------------------------------------------

import warnings
from pathlib import Path


class CustomIOGenerator:
"""
Abstract base class for generating custom IO mappings for different model types.

Args:
model (object): The model instance for which IO mappings are to be generated.
cache_dir (str): Directory path where the generated YAML files will be saved.
mxint8_kv_cache (bool): If True, use 'mxint8' precision for KV cache; otherwise, use 'float16'.
"""

def __init__(self, model, cache_dir=".", mxint8_kv_cache=False):
self.model = model
self.cache_dir = Path(cache_dir)
self.kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
self.dtype_suffix = "int8" if mxint8_kv_cache else "fp16"

def dump(self, custom_io: dict, suffix: str):
"""
Writes the custom IO mapping to a YAML file.

Args:
custom_io (dict): Dictionary containing IO names and their precision types.
suffix (str): Suffix to append to the output filename.
"""
custom_io_yaml = self.cache_dir / f"custom_io_{suffix}.yaml"
with open(custom_io_yaml, "w") as fp:
for io_name, dtype in custom_io.items():
fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n")

def generate(self) -> dict:
"""
Abstract method to generate custom IO mappings.

Returns:
dict: A dictionary of IO names and their precision types.

Raises:
NotImplementedError: Must be implemented by subclasses.
"""
raise NotImplementedError("Subclasses must implement this method")


class CausalLMIOGenerator(CustomIOGenerator):
"""
IO generator for causal language models.
"""

def generate(self) -> dict:
"""
Generates IO mappings for past key/value states in causal language models.

Returns:
dict: Mapping of IO names to precision types.
"""
custom_io = {}
num_layers = getattr(self.model, "num_layers", 12)
for suffix in ["", "_RetainedState"]:
for i in range(num_layers):
for kv in ["key", "value"]:
custom_io[f"past_{kv}.{i}{suffix}"] = self.kv_cache_dtype
self.dump(custom_io, self.dtype_suffix)
return custom_io
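    # Hedged example: for a model exposing num_layers=2 with mxint8_kv_cache=False,
    # generate() maps past_key.0, past_value.0, past_key.1, past_value.1 and their
    # *_RetainedState counterparts to "float16", then dumps custom_io_fp16.yaml.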


class DualQPCIOGenerator(CustomIOGenerator):
"""
IO generator for dual QPC models (e.g., vision-language models).
"""

def generate(self) -> dict:
"""
Generates IO mappings for both vision and language components.

Returns:
dict: Combined mapping of IO names to precision types for vision and language outputs.
"""
output_names = self.model.model.get_output_names()
custom_io_vision = {
name: self.kv_cache_dtype if name.startswith("past_") else "float16"
for name in output_names.get("vision", [])
}

custom_io_lang = {}
for name in output_names.get("lang", []):
if name.endswith("_RetainedState"):
base = name[: -len("_RetainedState")]
dtype = "float16" if "vision_embeds" in name else self.kv_cache_dtype
custom_io_lang[base] = dtype
custom_io_lang[name] = dtype

self.dump(custom_io_vision, f"{self.dtype_suffix}_vision")
self.dump(custom_io_lang, f"{self.dtype_suffix}_lang")
warnings.warn(f"Unsupported model class via CLI: {type(self.model).__name__}", UserWarning)
return {**custom_io_vision, **custom_io_lang}
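    # Hedged example: for a hypothetical VLM whose get_output_names() returns
    # {"vision": ["vision_embeds"], "lang": ["logits", "past_key.0_RetainedState"]},
    # this dumps custom_io_<suffix>_vision.yaml (vision_embeds -> float16) and
    # custom_io_<suffix>_lang.yaml (past_key.0 and past_key.0_RetainedState -> KV dtype).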


class SingleQPCIOGenerator(CustomIOGenerator):
"""
IO generator for single QPC models.
"""

def generate(self) -> dict:
"""
Generates IO mappings for retained states in single QPC models.

Returns:
dict: Mapping of IO names to precision types.
"""
output_names = self.model.model.get_output_names()
custom_io = {}
for name in output_names:
if name.endswith("_RetainedState"):
base = name[: -len("_RetainedState")]
dtype = "float16" if "pixel_values" in name else self.kv_cache_dtype
custom_io[base] = dtype
custom_io[name] = dtype
self.dump(custom_io, self.dtype_suffix)
return custom_io


class SpeechSeq2SeqIOGenerator(CustomIOGenerator):
"""
IO generator for speech sequence-to-sequence models.
"""

def generate(self) -> dict:
"""
Generates IO mappings for input features and retained states in speech models.

Returns:
dict: Mapping of IO names to precision types.
"""
output_names = self.model.model.get_output_names()
custom_io = {"input_features": self.kv_cache_dtype}
for name in output_names:
if name.endswith("_RetainedState"):
base = name[: -len("_RetainedState")]
custom_io[base] = self.kv_cache_dtype
custom_io[name] = self.kv_cache_dtype
self.dump(custom_io, self.dtype_suffix)
return custom_io


class UnsupportedModelIOGenerator(CustomIOGenerator):
"""
Fallback IO generator for unsupported model types.
"""

def generate(self) -> dict:
"""
Emits a warning for unsupported model types.

Returns:
dict: Empty dictionary.
"""
warnings.warn(f"Unsupported model class: {type(self.model).__name__}", UserWarning)
return {}


class CustomIOFactory:
"""
Factory class to instantiate the appropriate IO generator based on model type.
"""

@staticmethod
def get_generator(model, cache_dir=".", mxint8_kv_cache=False) -> CustomIOGenerator:
"""
Returns the appropriate IO generator instance for the given model.

Args:
model (object): The model instance.
cache_dir (str): Directory to store YAML files.
mxint8_kv_cache (bool): Flag to use 'mxint8' precision.

Returns:
CustomIOGenerator: An instance of the appropriate subclass.
"""
model_class_name = type(model).__name__
mapping = {
"QEFFAutoModelForCausalLM": CausalLMIOGenerator,
"_QEFFAutoModelForImageTextToTextDualQPC": DualQPCIOGenerator,
"_QEFFAutoModelForImageTextToTextSingleQPC": SingleQPCIOGenerator,
"QEFFAutoModelForSpeechSeq2Seq": SpeechSeq2SeqIOGenerator,
}
generator_class = mapping.get(model_class_name, UnsupportedModelIOGenerator)
return generator_class(model, cache_dir, mxint8_kv_cache)


def generate_custom_io(qeff_model, cache_dir=".", mxint8_kv_cache=False) -> dict:
"""
Generates and returns custom IO mappings for the given QEFF model.

Args:
qeff_model (object): The model instance.
cache_dir (str): Directory to store YAML files.
mxint8_kv_cache (bool): Flag to use 'mxint8' precision.

Returns:
dict: Custom IO mapping generated by the appropriate generator.
"""
generator = CustomIOFactory.get_generator(qeff_model, cache_dir, mxint8_kv_cache)
return generator.generate()
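A hedged end-to-end sketch of the module's public entry point; the loader call and checkpoint name are assumptions mirroring the imports in QEfficient/cloud/export.py:

from QEfficient.base.common import QEFFCommonLoader
from QEfficient.utils.custom_yaml import generate_custom_io

qeff_model = QEFFCommonLoader.from_pretrained("gpt2")  # assumed loader API, hypothetical checkpoint
io_map = generate_custom_io(qeff_model, cache_dir=".", mxint8_kv_cache=True)
# For a causal LM this writes ./custom_io_int8.yaml and returns entries such as
# {"past_key.0": "mxint8", "past_key.0_RetainedState": "mxint8", ...}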
4 changes: 2 additions & 2 deletions examples/cpp_execution/text_inference_using_cpp.py
@@ -14,7 +14,7 @@
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

import QEfficient
-from QEfficient.cloud.export import get_onnx_model_path
+from QEfficient.cloud.export import get_onnx_path_and_setup_customIO
from QEfficient.generation.text_generation_inference import fix_prompts, get_compilation_dims, get_input_prompts
from QEfficient.utils import check_and_assign_cache_dir, get_qpc_dir_path, load_hf_tokenizer, qpc_exists
from QEfficient.utils.logging_utils import logger
@@ -103,7 +103,7 @@ def main(
logger.info(f"Pre-compiled qpc found at {qpc_dir_path}! Executing with given prompt")
else:
# Handle onnx model generation
-        onnx_model_path = get_onnx_model_path(
+        onnx_model_path = get_onnx_path_and_setup_customIO(
model_name, cache_dir, tokenizer, hf_token, local_model_dir, full_batch_size
)
_ = QEfficient.compile(
6 changes: 3 additions & 3 deletions tests/cloud/test_export_compile_execute.py
@@ -18,7 +18,7 @@

def check_export_compile_execute(mocker, model_name, full_batch_size=None, enable_qnn=False):
check_and_assign_cache_dir_spy = mocker.spy(QEfficient.cloud.export, "check_and_assign_cache_dir")
-    get_onnx_model_path_spy = mocker.spy(QEfficient.cloud.export, "get_onnx_model_path")
+    get_onnx_path_and_setup_customIO_spy = mocker.spy(QEfficient.cloud.export, "get_onnx_path_and_setup_customIO")
load_hf_tokenizer_spy = mocker.spy(QEfficient.cloud.execute, "load_hf_tokenizer")
cloud_ai_100_exec_kv_spy = mocker.spy(QEfficient.cloud.execute, "cloud_ai_100_exec_kv")

@@ -29,9 +29,9 @@ def check_export_compile_execute(mocker, model_name, full_batch_size=None, enable_qnn=False):
)

check_and_assign_cache_dir_spy.assert_called_once()
-    get_onnx_model_path_spy.assert_called_once()
+    get_onnx_path_and_setup_customIO_spy.assert_called_once()

-    onnx_model_path = get_onnx_model_path_spy.spy_return
+    onnx_model_path = get_onnx_path_and_setup_customIO_spy.spy_return

assert os.path.isfile(onnx_model_path)
