Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,10 @@
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]

def __init__(self, model: torch.nn.Module) -> None:
def __init__(self, model: torch.nn.Module, onnx_slim_transfom: bool = False) -> None:
super().__init__()
self.model = model
self.onnx_slim_transform = onnx_slim_transfom
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
Expand Down Expand Up @@ -119,6 +120,7 @@
example_inputs: Dict[str, torch.Tensor],
output_names: List[str],
dynamic_axes: Dict[str, Dict[int, str]],
onnx_slim_transform: bool = False,
export_kwargs: Optional[Dict[str, any]] = None,
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
Expand Down Expand Up @@ -146,7 +148,6 @@
tmp_onnx_dir.mkdir(parents=True, exist_ok=True)

# Create input_names from example_inputs

input_names = []
for param in inspect.signature(self.model.forward).parameters:
if param in example_inputs:
Expand Down Expand Up @@ -183,11 +184,14 @@
**export_kwargs,
)
logger.info("Pytorch export successful")

model = onnx.load(tmp_onnx_path, load_external_data=False)
transform_kwargs = {
"temp_onnx_path": tmp_onnx_path,
"model_name": self.model_name,
"enable_onnx_slim_transform": onnx_slim_transform,
"onnx_base_dir": str(tmp_onnx_dir),
"model_name": self.model_name,

Check failure on line 193 in QEfficient/base/modeling_qeff.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F601)

QEfficient/base/modeling_qeff.py:193:17: F601 Dictionary key literal `"model_name"` repeated
"enable_onnx_slim_transform": onnx_slim_transform,

Check failure on line 194 in QEfficient/base/modeling_qeff.py

View workflow job for this annotation

GitHub Actions / lint

Ruff (F601)

QEfficient/base/modeling_qeff.py:194:17: F601 Dictionary key literal `"enable_onnx_slim_transform"` repeated
}
if onnx_transform_kwargs is not None:
transform_kwargs.update(onnx_transform_kwargs)
Expand Down Expand Up @@ -249,7 +253,6 @@
"""
if onnx_path is None and self.onnx_path is None:
self.export()

onnx_path = Path(onnx_path or self.onnx_path)
compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
Expand Down Expand Up @@ -368,5 +371,4 @@
)

self.qpc_path = qpc_path

return qpc_path
33 changes: 33 additions & 0 deletions QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from typing import Optional, Tuple

import numpy as np
import onnx
import onnxslim
from onnx import ModelProto, external_data_helper, numpy_helper


Expand Down Expand Up @@ -37,6 +39,8 @@ class FP16ClipTransform(OnnxTransform):
Clips the tensor values to be in FP16 range, but preserves -inf values.
"""

print("FP16ClipTransform is applied")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to use logger to print any messages.


@classmethod
def apply(cls, model: ModelProto, *, onnx_base_dir: Optional[str] = None, **kwargs) -> Tuple[ModelProto, bool]:
"""
Expand Down Expand Up @@ -99,3 +103,32 @@ def apply(
current_file_size = tsize
external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
return model, transformed


class OnnxSlimTransform(OnnxTransform):
    """
    Applies onnx-slim graph simplification to an ONNX model.

    Recognized keyword arguments (passed through ``**kwargs``):

    - ``enable_onnx_slim_transform`` (bool): gate for the transform;
      when False (the default) the model is returned unchanged.
    - ``temp_onnx_path`` (str): destination path where the slimmed
      model is saved. Required when the transform is enabled.
    """

    @classmethod
    def apply(
        cls,
        model: ModelProto,
        *,
        onnx_base_dir: Optional[str] = None,
        **kwargs,
    ) -> Tuple[ModelProto, bool]:
        """
        Simplify the ONNX graph with ``onnxslim.slim`` when enabled.

        :param model: The ONNX model to transform.
        :param onnx_base_dir: Unused here; accepted for interface parity
            with the other ``OnnxTransform`` subclasses.
        :returns: Tuple of (resulting model, whether a transform was applied).
        :raises ValueError: If the transform is enabled but no
            ``temp_onnx_path`` is provided.
        """
        # Gate on the kwarg so this transform is a no-op unless explicitly enabled.
        if not kwargs.get("enable_onnx_slim_transform", False):
            return model, False

        # Validate the output path before doing any work, so a misconfigured
        # caller fails loudly instead of crashing inside onnx.save.
        temp_onnx_path = kwargs.get("temp_onnx_path")
        if not temp_onnx_path:
            raise ValueError(
                "temp_onnx_path must be provided when enable_onnx_slim_transform is True"
            )

        slimmed_model = onnxslim.slim(model)
        onnx.save(slimmed_model, temp_onnx_path)
        return slimmed_model, True
Loading
Loading