Commit e592774
Resolved Custom IO file not found via CLI (#556)

This PR resolves an error raised when executing a model via the CLI because the custom_IO.yaml file was missing. In previous versions, the CLI did not explicitly generate a separate custom I/O configuration. To address this, a module has been added that automatically creates the custom_IO.yaml file based on the specified modeling class. Additionally, an optional argument, --mxint8_kv_cache, has been introduced in QEfficient.cloud.export(); this flag enables compression of the Present/Past KV cache to MXINT8 format using the generated CustomIO configuration.

Signed-off-by: abhishek-singh591 <[email protected]>
1 parent 8e13633 commit e592774
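
For context, a minimal sketch of the new export flow (hedged: the model name is illustrative, and only the flag shown here is new in this commit):

    from QEfficient.cloud.export import get_onnx_path_and_setup_customIO

    # Exports the model to ONNX and, as a side effect, writes the Custom IO
    # YAML. With mxint8_kv_cache=True the KV-cache IOs are mapped to mxint8
    # and the file lands as custom_io_int8.yaml in the working directory.
    onnx_path = get_onnx_path_and_setup_customIO(
        model_name="gpt2",  # illustrative model name
        mxint8_kv_cache=True,
    )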

File tree

5 files changed: +247 -11 lines changed

QEfficient/cloud/export.py

Lines changed: 16 additions & 3 deletions
@@ -11,18 +11,20 @@
 
 from QEfficient.base.common import QEFFCommonLoader
 from QEfficient.utils import check_and_assign_cache_dir
+from QEfficient.utils.custom_yaml import generate_custom_io
 from QEfficient.utils.logging_utils import logger
 
 # Specifically for Docker images.
 ROOT_DIR = os.path.dirname(os.path.abspath(""))
 
 
-def get_onnx_model_path(
+def get_onnx_path_and_setup_customIO(
     model_name: str,
     cache_dir: Optional[str] = None,
     hf_token: Optional[str] = None,
     full_batch_size: Optional[int] = None,
     local_model_dir: Optional[str] = None,
+    mxint8_kv_cache: Optional[bool] = False,
 ):
     """
     Exports the PyTorch model to ONNX format if a pre-exported file is not found,
@@ -63,6 +65,9 @@ def get_onnx_model_path
     )
     onnx_model_path = qeff_model.export()
     logger.info(f"Generated onnx_path: {onnx_model_path}")
+
+    # Generating Custom IO for the compile.
+    generate_custom_io(qeff_model, mxint8_kv_cache=mxint8_kv_cache)
     return onnx_model_path
 
 
@@ -72,13 +77,14 @@ def main(
     hf_token: Optional[str] = None,
     local_model_dir: Optional[str] = None,
     full_batch_size: Optional[int] = None,
+    mxint8_kv_cache: Optional[bool] = False,
 ) -> None:
     """
     Main function for the QEfficient ONNX export CLI application.
 
     This function serves as the entry point for exporting a PyTorch model, loaded
     via QEFFCommonLoader, to the ONNX format. It prepares the necessary
-    paths and calls `get_onnx_model_path`.
+    paths and calls `get_onnx_path_and_setup_customIO`.
 
     Parameters
     ----------
@@ -106,12 +112,13 @@ def main(
 
     """
     cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
-    get_onnx_model_path(
+    get_onnx_path_and_setup_customIO(
         model_name=model_name,
         cache_dir=cache_dir,
         hf_token=hf_token,
         full_batch_size=full_batch_size,
         local_model_dir=local_model_dir,
+        mxint8_kv_cache=mxint8_kv_cache,
     )
 
 
@@ -137,5 +144,11 @@ def main(
         default=None,
         help="Set full batch size to enable continuous batching mode, default is None",
     )
+    parser.add_argument(
+        "--mxint8_kv_cache",
+        "--mxint8-kv-cache",
+        action="store_true",
+        help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False",
+    )
     args = parser.parse_args()
     main(**args.__dict__)
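
With the two aliases registered above, either spelling should work from the command line; assuming the repo's documented CLI entry point, an invocation might look like `python -m QEfficient.cloud.export --model_name gpt2 --mxint8_kv_cache` (model name illustrative).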

QEfficient/compile/compile_helper.py

Lines changed: 13 additions & 3 deletions
@@ -270,6 +270,7 @@ def compile(
     This method will be removed soon; use `QEFFAutoModelForCausalLM.compile` instead.
 
     """
+
     if full_batch_size and batch_size != 1:
         raise ValueError("Only either batch_size or full_batch_size should be greater than one")
 
@@ -284,11 +285,20 @@ def compile(
         full_batch_size=full_batch_size,
     )
 
-    # Select the customIO config based on the mx flag.
-    custom_io_file_name = "custom_io_int8.yaml" if mxint8 else "custom_io_fp16.yaml"
+    dtype_suffix = "int8" if mxint8 else "fp16"
+    source_path = f"./custom_io_{dtype_suffix}.yaml"
+    destination_path = os.path.join(os.path.dirname(qpc_path), f"custom_io_{dtype_suffix}.yaml")
+
+    # Move the custom YAML file to the cache/qeff_model directory
+    try:
+        shutil.move(source_path, destination_path)
+        print(f"Successfully moved '{source_path}' to '{destination_path}'.")
+    except Exception as e:
+        print(f"Error while moving file '{source_path}': {e}")
 
+    custom_io_file_name = f"custom_io_{dtype_suffix}.yaml"
     if custom_io_file_path is None:
-        custom_io_file_path = os.path.join(os.path.dirname(onnx_path), custom_io_file_name)
+        custom_io_file_path = os.path.join(os.path.dirname(qpc_path), custom_io_file_name)
 
     if not os.path.isfile(custom_io_file_path):
         raise FileNotFoundError(
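
To make the hand-off concrete, a small sketch of the path logic this hunk relies on (all paths illustrative, not the library's actual defaults): export leaves custom_io_<suffix>.yaml in the working directory, and compile() moves it next to the QPC directory before verifying it exists.

    import os

    mxint8 = True
    qpc_path = "qeff_cache/gpt2/qpcs"  # illustrative QPC location

    dtype_suffix = "int8" if mxint8 else "fp16"
    source_path = f"./custom_io_{dtype_suffix}.yaml"  # written by the export step
    destination_path = os.path.join(os.path.dirname(qpc_path), f"custom_io_{dtype_suffix}.yaml")
    print(destination_path)  # qeff_cache/gpt2/custom_io_int8.yaml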

QEfficient/utils/custom_yaml.py (new file)

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# ----------------------------------------------------------------------------
7+
8+
import warnings
9+
from pathlib import Path
10+
11+
12+
class CustomIOGenerator:
13+
"""
14+
Abstract base class for generating custom IO mappings for different model types.
15+
16+
Args:
17+
model (object): The model instance for which IO mappings are to be generated.
18+
cache_dir (str): Directory path where the generated YAML files will be saved.
19+
mxint8_kv_cache (bool): If True, use 'mxint8' precision for KV cache; otherwise, use 'float16'.
20+
"""
21+
22+
def __init__(self, model, cache_dir=".", mxint8_kv_cache=False):
23+
self.model = model
24+
self.cache_dir = Path(cache_dir)
25+
self.kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
26+
self.dtype_suffix = "int8" if mxint8_kv_cache else "fp16"
27+
28+
def dump(self, custom_io: dict, suffix: str):
29+
"""
30+
Writes the custom IO mapping to a YAML file.
31+
32+
Args:
33+
custom_io (dict): Dictionary containing IO names and their precision types.
34+
suffix (str): Suffix to append to the output filename.
35+
"""
36+
custom_io_yaml = self.cache_dir / f"custom_io_{suffix}.yaml"
37+
with open(custom_io_yaml, "w") as fp:
38+
for io_name, dtype in custom_io.items():
39+
fp.write(f" - IOName: {io_name}\n Precision: {dtype}\n\n")
40+
41+
def generate(self) -> dict:
42+
"""
43+
Abstract method to generate custom IO mappings.
44+
45+
Returns:
46+
dict: A dictionary of IO names and their precision types.
47+
48+
Raises:
49+
NotImplementedError: Must be implemented by subclasses.
50+
"""
51+
raise NotImplementedError("Subclasses must implement this method")
52+
53+
54+
class CausalLMIOGenerator(CustomIOGenerator):
55+
"""
56+
IO generator for causal language models.
57+
"""
58+
59+
def generate(self) -> dict:
60+
"""
61+
Generates IO mappings for past key/value states in causal language models.
62+
63+
Returns:
64+
dict: Mapping of IO names to precision types.
65+
"""
66+
custom_io = {}
67+
num_layers = getattr(self.model, "num_layers", 12)
68+
for suffix in ["", "_RetainedState"]:
69+
for i in range(num_layers):
70+
for kv in ["key", "value"]:
71+
custom_io[f"past_{kv}.{i}{suffix}"] = self.kv_cache_dtype
72+
self.dump(custom_io, self.dtype_suffix)
73+
return custom_io
74+
75+
76+
class DualQPCIOGenerator(CustomIOGenerator):
77+
"""
78+
IO generator for dual QPC models (e.g., vision-language models).
79+
"""
80+
81+
def generate(self) -> dict:
82+
"""
83+
Generates IO mappings for both vision and language components.
84+
85+
Returns:
86+
dict: Combined mapping of IO names to precision types for vision and language outputs.
87+
"""
88+
output_names = self.model.model.get_output_names()
89+
custom_io_vision = {
90+
name: self.kv_cache_dtype if name.startswith("past_") else "float16"
91+
for name in output_names.get("vision", [])
92+
}
93+
94+
custom_io_lang = {}
95+
for name in output_names.get("lang", []):
96+
if name.endswith("_RetainedState"):
97+
base = name[: -len("_RetainedState")]
98+
dtype = "float16" if "vision_embeds" in name else self.kv_cache_dtype
99+
custom_io_lang[base] = dtype
100+
custom_io_lang[name] = dtype
101+
102+
self.dump(custom_io_vision, f"{self.dtype_suffix}_vision")
103+
self.dump(custom_io_lang, f"{self.dtype_suffix}_lang")
104+
warnings.warn(f"Unsupported model class via CLI: {type(self.model).__name__}", UserWarning)
105+
return {**custom_io_vision, **custom_io_lang}
106+
107+
108+
class SingleQPCIOGenerator(CustomIOGenerator):
109+
"""
110+
IO generator for single QPC models.
111+
"""
112+
113+
def generate(self) -> dict:
114+
"""
115+
Generates IO mappings for retained states in single QPC models.
116+
117+
Returns:
118+
dict: Mapping of IO names to precision types.
119+
"""
120+
output_names = self.model.model.get_output_names()
121+
custom_io = {}
122+
for name in output_names:
123+
if name.endswith("_RetainedState"):
124+
base = name[: -len("_RetainedState")]
125+
dtype = "float16" if "pixel_values" in name else self.kv_cache_dtype
126+
custom_io[base] = dtype
127+
custom_io[name] = dtype
128+
self.dump(custom_io, self.dtype_suffix)
129+
return custom_io
130+
131+
132+
class SpeechSeq2SeqIOGenerator(CustomIOGenerator):
133+
"""
134+
IO generator for speech sequence-to-sequence models.
135+
"""
136+
137+
def generate(self) -> dict:
138+
"""
139+
Generates IO mappings for input features and retained states in speech models.
140+
141+
Returns:
142+
dict: Mapping of IO names to precision types.
143+
"""
144+
output_names = self.model.model.get_output_names()
145+
custom_io = {"input_features": self.kv_cache_dtype}
146+
for name in output_names:
147+
if name.endswith("_RetainedState"):
148+
base = name[: -len("_RetainedState")]
149+
custom_io[base] = self.kv_cache_dtype
150+
custom_io[name] = self.kv_cache_dtype
151+
self.dump(custom_io, self.dtype_suffix)
152+
return custom_io
153+
154+
155+
class UnsupportedModelIOGenerator(CustomIOGenerator):
156+
"""
157+
Fallback IO generator for unsupported model types.
158+
"""
159+
160+
def generate(self) -> dict:
161+
"""
162+
Emits a warning for unsupported model types.
163+
164+
Returns:
165+
dict: Empty dictionary.
166+
"""
167+
warnings.warn(f"Unsupported model class: {type(self.model).__name__}", UserWarning)
168+
return {}
169+
170+
171+
class CustomIOFactory:
172+
"""
173+
Factory class to instantiate the appropriate IO generator based on model type.
174+
"""
175+
176+
@staticmethod
177+
def get_generator(model, cache_dir=".", mxint8_kv_cache=False) -> CustomIOGenerator:
178+
"""
179+
Returns the appropriate IO generator instance for the given model.
180+
181+
Args:
182+
model (object): The model instance.
183+
cache_dir (str): Directory to store YAML files.
184+
mxint8_kv_cache (bool): Flag to use 'mxint8' precision.
185+
186+
Returns:
187+
CustomIOGenerator: An instance of the appropriate subclass.
188+
"""
189+
model_class_name = type(model).__name__
190+
mapping = {
191+
"QEFFAutoModelForCausalLM": CausalLMIOGenerator,
192+
"_QEFFAutoModelForImageTextToTextDualQPC": DualQPCIOGenerator,
193+
"_QEFFAutoModelForImageTextToTextSingleQPC": SingleQPCIOGenerator,
194+
"QEFFAutoModelForSpeechSeq2Seq": SpeechSeq2SeqIOGenerator,
195+
}
196+
generator_class = mapping.get(model_class_name, UnsupportedModelIOGenerator)
197+
return generator_class(model, cache_dir, mxint8_kv_cache)
198+
199+
200+
def generate_custom_io(qeff_model, cache_dir=".", mxint8_kv_cache=False) -> dict:
201+
"""
202+
Generates and returns custom IO mappings for the given QEFF model.
203+
204+
Args:
205+
qeff_model (object): The model instance.
206+
cache_dir (str): Directory to store YAML files.
207+
mxint8_kv_cache (bool): Flag to use 'mxint8' precision.
208+
209+
Returns:
210+
dict: Custom IO mapping generated by the appropriate generator.
211+
"""
212+
generator = CustomIOFactory.get_generator(qeff_model, cache_dir, mxint8_kv_cache)
213+
return generator.generate()
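
As a usage sketch (hedged: assumes QEFFCommonLoader.from_pretrained returns a QEFFAutoModelForCausalLM here; the model name is illustrative):

    from QEfficient.base.common import QEFFCommonLoader
    from QEfficient.utils.custom_yaml import generate_custom_io

    qeff_model = QEFFCommonLoader.from_pretrained("gpt2")  # illustrative
    mapping = generate_custom_io(qeff_model, cache_dir=".", mxint8_kv_cache=True)

    # For a causal LM this writes ./custom_io_int8.yaml with entries such as:
    #  - IOName: past_key.0
    #    Precision: mxint8
    #  - IOName: past_key.0_RetainedState
    #    Precision: mxint8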

examples/cpp_execution/text_inference_using_cpp.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 import QEfficient
-from QEfficient.cloud.export import get_onnx_model_path
+from QEfficient.cloud.export import get_onnx_path_and_setup_customIO
 from QEfficient.generation.text_generation_inference import fix_prompts, get_compilation_dims, get_input_prompts
 from QEfficient.utils import check_and_assign_cache_dir, get_qpc_dir_path, load_hf_tokenizer, qpc_exists
 from QEfficient.utils.logging_utils import logger
@@ -103,7 +103,7 @@ def main(
         logger.info(f"Pre-compiled qpc found at {qpc_dir_path}! Executing with given prompt")
     else:
         # Handle onnx model generation
-        onnx_model_path = get_onnx_model_path(
+        onnx_model_path = get_onnx_path_and_setup_customIO(
             model_name, cache_dir, tokenizer, hf_token, local_model_dir, full_batch_size
         )
         _ = QEfficient.compile(

tests/cloud/test_export_compile_execute.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
 
 def check_export_compile_execute(mocker, model_name, full_batch_size=None, enable_qnn=False):
     check_and_assign_cache_dir_spy = mocker.spy(QEfficient.cloud.export, "check_and_assign_cache_dir")
-    get_onnx_model_path_spy = mocker.spy(QEfficient.cloud.export, "get_onnx_model_path")
+    get_onnx_path_and_setup_customIO_spy = mocker.spy(QEfficient.cloud.export, "get_onnx_path_and_setup_customIO")
     load_hf_tokenizer_spy = mocker.spy(QEfficient.cloud.execute, "load_hf_tokenizer")
     cloud_ai_100_exec_kv_spy = mocker.spy(QEfficient.cloud.execute, "cloud_ai_100_exec_kv")
 
@@ -29,9 +29,9 @@ def check_export_compile_execute(mocker, model_name, full_batch_size=None, enabl
     )
 
     check_and_assign_cache_dir_spy.assert_called_once()
-    get_onnx_model_path_spy.assert_called_once()
+    get_onnx_path_and_setup_customIO_spy.assert_called_once()
 
-    onnx_model_path = get_onnx_model_path_spy.spy_return
+    onnx_model_path = get_onnx_path_and_setup_customIO_spy.spy_return
 
     assert os.path.isfile(onnx_model_path)
