diff --git a/docker/Dockerfile.onnx-mlir b/docker/Dockerfile.onnx-mlir
index f028ed29b3..170f3d63b8 100644
--- a/docker/Dockerfile.onnx-mlir
+++ b/docker/Dockerfile.onnx-mlir
@@ -26,7 +26,7 @@ RUN ONNX_ROOT=${WORK_DIR}/onnx-mlir/third_party/onnx \
ARG NPROC=4
ARG ACCEL=NNPA
ARG TEST_NOFLOAT16
-ARG TEST_MCPU
+ARG TEST_MARCH
ARG KEEPSRC
RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \
@@ -53,21 +53,21 @@ RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \
([ "$(uname -m)" = "x86_64" ] && echo true || \
([ "$(uname -m)" = "ppc64le" ] && echo || echo)))} \
# User image is built with SIMD (currently on s390x only)
- && TEST_MCPU=${TEST_MCPU:-$([ "$(uname -m)" = "s390x" ] && echo z16 || \
+ && TEST_MARCH=${TEST_MARCH:-$([ "$(uname -m)" = "s390x" ] && echo z16 || \
([ "$(uname -m)" = "x86_64" ] && echo || \
([ "$(uname -m)" = "ppc64le" ] && echo || echo)))} \
- && TEST_ARGS="-mcpu=${TEST_MCPU}" \
+ && TEST_ARGS="-march=${TEST_MARCH}" \
&& make check-docs \
&& make check-unittest \
&& make check-multiple-models \
&& make NPROC=${NPROC} \
CTEST_PARALLEL_LEVEL=${NPROC} \
TEST_NOFLOAT16=${TEST_NOFLOAT16} \
- TEST_MCPU=${TEST_MCPU} \
+ TEST_MARCH=${TEST_MARCH} \
TEST_ARGS="${TEST_ARGS}" \
-j${NPROC} \
check-onnx-backend-numerical \
- && if [ "${TEST_MCPU}" = "z16" ]; then \
+ && if [ "${TEST_MARCH}" = "z16" ]; then \
make NPROC=${NPROC} \
CTEST_PARALLEL_LEVEL=${NPROC} \
-j${NPROC} \
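For reference, a hedged sketch of how the renamed `TEST_MARCH` build argument might be passed when building the user image; the image tag, `NPROC` value, and build context are placeholders, and any other build arguments the image may need are omitted:

```
# Hypothetical invocation: build the user image with SIMD tests targeting z16.
# TEST_MARCH maps to the ARG introduced above; tag, NPROC, and context are placeholders.
docker build -f docker/Dockerfile.onnx-mlir \
  --build-arg NPROC=8 \
  --build-arg TEST_MARCH=z16 \
  -t onnx-mlir:z16 .
```

The dev image (`docker/Dockerfile.onnx-mlir-dev`) accepts the same `TEST_MARCH` build argument.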
diff --git a/docker/Dockerfile.onnx-mlir-dev b/docker/Dockerfile.onnx-mlir-dev
index 574737c1a9..344fa273b5 100644
--- a/docker/Dockerfile.onnx-mlir-dev
+++ b/docker/Dockerfile.onnx-mlir-dev
@@ -20,7 +20,7 @@ RUN ONNX_ROOT=${WORK_DIR}/onnx-mlir/third_party/onnx \
ARG NPROC=4
ARG ACCEL=NNPA
ARG TEST_NOFLOAT16
-ARG TEST_MCPU
+ARG TEST_MARCH
RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \
&& ONNX_MLIR_ROOT=${WORK_DIR}/onnx-mlir \
@@ -51,10 +51,10 @@ RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \
([ "$(uname -m)" = "x86_64" ] && echo true || \
([ "$(uname -m)" = "ppc64le" ] && echo || echo)))} \
# Dev image is built without SIMD, placeholder for easy SIMD enablement
- && TEST_MCPU=$([ "$(uname -m)" = "s390x" ] && echo || \
+ && TEST_MARCH=$([ "$(uname -m)" = "s390x" ] && echo || \
([ "$(uname -m)" = "x86_64" ] && echo || \
([ "$(uname -m)" = "ppc64le" ] && echo || echo))) \
- && TEST_ARGS="-mcpu=${TEST_MCPU}" \
+ && TEST_ARGS="-march=${TEST_MARCH}" \
&& TEST_OPTLEVEL=0 \
&& make check-docs \
&& make check-unittest \
@@ -62,7 +62,7 @@ RUN LLVM_PROJECT_ROOT=${WORK_DIR}/llvm-project \
&& make NPROC=${NPROC} \
CTEST_PARALLEL_LEVEL=${NPROC} \
TEST_NOFLOAT16=${TEST_NOFLOAT16} \
- TEST_MCPU=${TEST_MCPU} \
+ TEST_MARCH=${TEST_MARCH} \
TEST_ARGS="${TEST_ARGS}" \
TEST_OPTLEVEL=${TEST_OPTLEVEL} \
-j${NPROC} \
diff --git a/docs/DebuggingNumericalError.md b/docs/DebuggingNumericalError.md
index 62b513ff33..0eeabcb505 100644
--- a/docs/DebuggingNumericalError.md
+++ b/docs/DebuggingNumericalError.md
@@ -65,7 +65,7 @@ optional arguments:
## Helper script to compare a model under two distinct compile option.
Based on the above `utils/runONNXModel.py`, the `utils/checkONNXModel.py` allows a user to run a given model twice, under two distinct compile options, and compare its results.
-This let a user simply test a new option, comparing the safe version of the compiler (e.g. `-O0` or `-O3`) with a more advanced version (e.g. `-O3` or `-O3 -march=x86-64`). Simply specify the compile options using the `--ref-compile-args` and `--test-compile-args` flags, a model using the `--model` flag, and possibly a `--shape-info` in presence of dynamic shape inputs.
+This lets a user easily test a new option by comparing a safe version of the compiler (e.g. `-O0` or `-O3`) with a more advanced version (e.g. `-O3` or `-O3 --march=x86-64`). Simply specify the compile options using the `--ref-compile-args` and `--test-compile-args` flags, a model using the `--model` flag, and possibly a `--shape-info` in the presence of dynamic shape inputs.
Full options are listed under the `--help` flag.
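For illustration, a hedged example of the comparison described above; the model path and `--shape-info` value are placeholders rather than values taken from this patch:

```
# Minimal sketch: compare a reference -O3 build against -O3 --march=x86-64.
# The model file and shape-info string are hypothetical; see --help for all options.
python utils/checkONNXModel.py \
  --model my_model.onnx \
  --ref-compile-args "-O3" \
  --test-compile-args "-O3 --march=x86-64" \
  --shape-info "0:1x3x224x224"
```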
## Debugging the Code Generated for an Operator.
diff --git a/docs/Dialects/zhigh.md b/docs/Dialects/zhigh.md
index 4780cbe551..dd87eeecf5 100644
--- a/docs/Dialects/zhigh.md
+++ b/docs/Dialects/zhigh.md
@@ -337,6 +337,61 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `hn_output` | unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS
+### `zhigh.Gelu` (::onnx_mlir::zhigh::ZHighGeluOp)
+
+_ZHigh Gelu operation_
+
+"ZHigh operation to perform a Gelu."
+
+Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+approximate | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
+### `zhigh.InvSqrt` (::onnx_mlir::zhigh::ZHighInvSqrtOp)
+
+_ZHigh InvSqrt operation_
+
+ZHigh operation to perform an InvSqrt.
+
+Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
### `zhigh.LSTM` (::onnx_mlir::zhigh::ZHighLSTMOp)
_ZHigh LSTM operation_
@@ -389,6 +444,37 @@ Effects: `MemoryEffects::Effect{}`
| `hn_output` | unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS
| `cf_output` | unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS
+### `zhigh.LeakyRelu` (::onnx_mlir::zhigh::ZHighLeakyReluOp)
+
+_ZHigh LeakyRelu operation_
+
+"ZHigh operation to perform a LeakyRelu."
+
+Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+alpha | ::mlir::FloatAttr | 32-bit float attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
### `zhigh.Log` (::onnx_mlir::zhigh::ZHighLogOp)
_ZHigh Log operation_
@@ -425,6 +511,14 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterfac
Effects: `MemoryEffects::Effect{}`
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+transposeA | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+transposeB | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+
+
#### Operands:
| Operand | Description |
@@ -577,6 +671,168 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+### `zhigh.QuantizedMatMul` (::onnx_mlir::zhigh::ZHighQuantizedMatMulOp)
+
+_ZHigh QuantizedMatMul operation_
+
+ZHigh operation to perform a quantized MatMul.
+
+`OutRecScaleIn` and `OutOffsetIn` are recscale and offset for the output.
+If `OutRecScaleIn` is given, it will be passed to `OutRecScale`. If it is
+None, `OutRecScale` is set to 1.0.
+If `OutOffsetIn` is given, it will be passed to `OutOffset`. If it is
+None, `OutOffset` is set to 0.0.
+
+* PreComputedBias: -1 bias is pre-computed, 0: bias is not pre-computed.
+
+`DequantizeOutput` indicates if the output
+is dequantized to real dlfloat16 or not. If not, the output is int8 but stored in dlfloat (int8-as-dlfloat).
+* DequantizeOutput: -1 output is dequantized, 0: output is not dequantized.
+
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+PreComputedBias | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+DisableClipping | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+DequantizeOutput | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS
+| `XRecScale` | 0D tensor of 32-bit float values
+| `XOffset` | 0D tensor of 32-bit float values
+| `Y` | unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS
+| `YRecScale` | 0D tensor of 32-bit float values
+| `YOffset` | 0D tensor of 32-bit float values
+| `B` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 8-bit signless integer or 16-bit float values or 1D tensor of 8-bit signless integer or 16-bit float values with layout _1D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2DS or none type
+| `BRecScale` | 0D tensor of 32-bit float values or none type
+| `BOffset` | 0D tensor of 32-bit float values or none type
+| `OutRecScaleIn` | 0D tensor of 32-bit float values or none type
+| `OutOffsetIn` | 0D tensor of 32-bit float values or none type
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Out` | unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS
+| `OutRecScale` | 0D tensor of 32-bit float values
+| `OutOffset` | 0D tensor of 32-bit float values
+
+### `zhigh.QuantizedStick` (::onnx_mlir::zhigh::ZHighQuantizedStickOp)
+
+_ZHigh QuantizedStick operation_
+
+ZHigh operation to perform a quantized Stick.
+Type is one of the values: dlfloat16, int8, and weights.
+`sym_mode` indicates whether symmetric quantization is used to compute the output rescale and offset.
+`sym_mode` is only effective when the input rescale and offset are None.
+By default, asymmetric quantization is used.
+
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+quantized_type | ::mlir::StringAttr | string attribute |
+sym_mode | ::mlir::IntegerAttr | 64-bit signless integer attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `In` | tensor of 32-bit float values or tensor of 8-bit signless integer values or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS
+| `InRecScale` | 0D tensor of 32-bit float values or none type
+| `InOffset` | 0D tensor of 32-bit float values or none type
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Out` | unranked tensor of 8-bit signless integer or 16-bit float values or 1D tensor of 8-bit signless integer or 16-bit float values with layout _1D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2DS or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS or none type
+| `RecScale` | 0D tensor of 32-bit float values
+| `Offset` | 0D tensor of 32-bit float values
+
+### `zhigh.ReduceMax` (::onnx_mlir::zhigh::ZHighReduceMaxOp)
+
+_ZHigh ReduceMax operation_
+
+ZHigh operation to perform a ReduceMax.
+op_type: REDUCE_OP_MAXIMUM or REDUCE_OP_MINIMUM.
+
+Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+op_type | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `data` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `output` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
+### `zhigh.ReduceMin` (::onnx_mlir::zhigh::ZHighReduceMinOp)
+
+_ZHigh ReduceMin operation_
+
+ZHigh operation to perform a ReduceMin.
+op_type: REDUCE_OP_MAXIMUM or REDUCE_OP_MINIMUM.
+
+Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+op_type | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `data` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `output` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
### `zhigh.Relu` (::onnx_mlir::zhigh::ZHighReluOp)
_ZHigh Relu operation_
@@ -657,6 +913,30 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `Out` | unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS
+### `zhigh.Sqrt` (::onnx_mlir::zhigh::ZHighSqrtOp)
+
+_ZHigh Sqrt operation_
+
+ZHigh operation to perform a Sqrt.
+
+Traits: `AlwaysSpeculatableImplTrait`, `SameOperandsAndResultLayout`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
### `zhigh.StickForGRU` (::onnx_mlir::zhigh::ZHighStickForGRUOp)
_ZHigh stick operation for GRU_
@@ -815,7 +1095,7 @@ Effects: `MemoryEffects::Effect{}`
| Result | Description |
| :----: | ----------- |
-| `output` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+| `output` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH or unranked tensor of 8-bit signless integer or 16-bit float values or 1D tensor of 8-bit signless integer or 16-bit float values with layout _1D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2D or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3D or unranked tensor of 8-bit signless integer or 16-bit float values or 2D tensor of 8-bit signless integer or 16-bit float values with layout _2DS or unranked tensor of 8-bit signless integer or 16-bit float values or 3D tensor of 8-bit signless integer or 16-bit float values with layout _3DS
### `zhigh.Sub` (::onnx_mlir::zhigh::ZHighSubOp)
diff --git a/docs/Dialects/zlow.md b/docs/Dialects/zlow.md
index ba6907fced..7be1c6457b 100644
--- a/docs/Dialects/zlow.md
+++ b/docs/Dialects/zlow.md
@@ -342,6 +342,52 @@ Interfaces: `MemoryEffectOpInterface`
| `shape` | memref of 64-bit signless integer values
| `hn_output` | memref of dlfloat16 type values
+### `zlow.gelu` (::onnx_mlir::zlow::ZLowGeluOp)
+
+_ZLow gelu operation_
+
+ZLow operation to perform a gelu.
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of dlfloat16 type values
+| `shape` | memref of 64-bit signless integer values
+| `Out` | memref of dlfloat16 type values
+
+### `zlow.invsqrt` (::onnx_mlir::zlow::ZLowInvSqrtOp)
+
+_ZLow invsqrt operation_
+
+ZLow operation to perform an invsqrt.
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of dlfloat16 type values
+| `shape` | memref of 64-bit signless integer values
+| `Out` | memref of dlfloat16 type values
+
### `zlow.lstm` (::onnx_mlir::zlow::ZLowLSTMOp)
_ZLow lstm operation_
@@ -387,6 +433,30 @@ Interfaces: `MemoryEffectOpInterface`
| `hn_output` | memref of dlfloat16 type values
| `cf_output` | memref of dlfloat16 type values
+### `zlow.leakyrelu` (::onnx_mlir::zlow::ZLowLeakyReluOp)
+
+_ZLow leakyrelu operation_
+
+ZLow operation to perform a leakyrelu.
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+alpha | ::mlir::FloatAttr | 32-bit float attribute |
+layout | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of dlfloat16 type values
+| `shape` | memref of 64-bit signless integer values
+| `Out` | memref of dlfloat16 type values
+
### `zlow.log` (::onnx_mlir::zlow::ZLowLogOp)
_ZLow log operation_
@@ -423,14 +493,18 @@ shape is a 1D MemRef (memref<3xi64>) whose items are:
* 2nd item: n
* 3rd item: p
* In case of stacked: X(s, m, n) * Y(s, n, p) + Bias(s, p)
- or broadcasting: X(s, m, n) * Y(n, p) + Bias(p)
+ or broadcasting1: X(m, n) * Y(s, n, p) + Bias(s, p)
+ or broadcasting23: X(s, m, n) * Y(n, p) + Bias(p)
shape is a 1D MemRef (memref<4xi64>) whose items are:
* 1st item: s
* 2nd item: m
* 3rd item: n
* 4th item: p
-* is_bcast: -1 broadcasting, 0: no broadcasting.
+* is_bcast1: -1 broadcasting1, 0: no broadcasting1.
+* is_bcast23: -1 broadcasting23, 0: no broadcasting23.
* is_stacked: -1 stacked, 0: unstacked.
+* transposeA: !0 transpose A, 0: do not transpose A.
+* transposeB: !0 transpose B, 0: do not transpose B.
Traits: `MemRefsNormalizable`
@@ -440,8 +514,11 @@ Interfaces: `MemoryEffectOpInterface`
Attribute | MLIR Type | Description |
-is_bcast | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+is_bcast1 | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+is_bcast23 | ::mlir::IntegerAttr | 64-bit signed integer attribute |
is_stacked | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+transposeA | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+transposeB | ::mlir::IntegerAttr | 64-bit signed integer attribute |
#### Operands:
@@ -592,6 +669,144 @@ Interfaces: `MemoryEffectOpInterface`
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values
+### `zlow.quantizedMatmul` (::onnx_mlir::zlow::ZLowQuantizedMatMulOp)
+
+_ZLow quantized matmul operation_
+
+ZLow operation to perform a quantized matmul.
+work_area: a 4K-aligned buffer having the same layout as bias but dlfloat16 type.
+* In case of unstacked: X(m, n) * Y(n, p) + Bias(p)
+shape is a 1D MemRef (memref<3xi64>) whose items are:
+ * 1st item: m
+ * 2nd item: n
+ * 3rd item: p
+* In case of stacked: X(s, m, n) * Y(s, n, p) + Bias(s, p)
+ or broadcasting: X(s, m, n) * Y(n, p) + Bias(p)
+shape is a 1D MemRef (memref<4xi64>) whose items are:
+ * 1st item: s
+ * 2nd item: m
+ * 3rd item: n
+ * 4th item: p
+* is_bcast: -1 broadcasting, 0: no broadcasting.
+* is_stacked: -1 stacked, 0: unstacked.
+* DequantizeOutput: -1 output is dequantized, 0: output is not dequantized.
+* PreComputedBias: -1 bias is pre-computed, 0: bias is not pre-computed.
+
+Values for `q_type` are "DLFLOAT16", "INT8", "WEIGHTS", "UNDEFINED".
+
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+x_q_type | ::mlir::StringAttr | string attribute |
+y_q_type | ::mlir::StringAttr | string attribute |
+bias_q_type | ::mlir::StringAttr | string attribute |
+out_q_type | ::mlir::StringAttr | string attribute |
+is_bcast | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+is_stacked | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+pre_computed_bias | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+disable_clipping | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+dequantize_output | ::mlir::IntegerAttr | 64-bit signed integer attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of dlfloat16 type or 8-bit signless integer values
+| `x_rec_scale` | 0D memref of 32-bit float values
+| `x_offset` | 0D memref of 32-bit float values
+| `Y` | memref of dlfloat16 type or 8-bit signless integer values
+| `y_rec_scale` | 0D memref of 32-bit float values
+| `y_offset` | 0D memref of 32-bit float values
+| `Bias` | memref of dlfloat16 type or 8-bit signless integer values
+| `bias_rec_scale` | 0D memref of 32-bit float values
+| `bias_offset` | 0D memref of 32-bit float values
+| `work_area` | memref of dlfloat16 type or 8-bit signless integer values or none type
+| `shape` | memref of 64-bit signless integer values
+| `Out` | memref of dlfloat16 type or 8-bit signless integer values
+| `out_rec_scale` | 0D memref of 32-bit float values
+| `out_offset` | 0D memref of 32-bit float values
+
+### `zlow.quantizedStick` (::onnx_mlir::zlow::ZLowQuantizedStickOp)
+
+_ZLow stick operation for quantization_
+
+"ZLow operation to perform a quantization stick."
+"Type is one of values: dlfloat16, int8, and weights."
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+q_type | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of 8-bit signless integer or 32-bit float values
+| `rec_scale` | 0D memref of 32-bit float values
+| `offset` | 0D memref of 32-bit float values
+| `out` | memref of dlfloat16 type or 8-bit signless integer values
+
+### `zlow.reducemax` (::onnx_mlir::zlow::ZLowReduceMaxOp)
+
+_ZLow reducemax operation_
+
+ZLow operation to perform a reducemax.
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+op_type | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of dlfloat16 type values
+| `work_area` | memref of 8-bit signless integer values
+| `shape` | memref of 64-bit signless integer values
+| `Out` | memref of dlfloat16 type values
+
+### `zlow.reducemin` (::onnx_mlir::zlow::ZLowReduceMinOp)
+
+_ZLow reducemin operation_
+
+ZLow operation to perform a reducemin.
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+op_type | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of dlfloat16 type values
+| `work_area` | memref of 8-bit signless integer values
+| `shape` | memref of 64-bit signless integer values
+| `Out` | memref of dlfloat16 type values
+
### `zlow.relu` (::onnx_mlir::zlow::ZLowReluOp)
_ZLow relu operation_
@@ -670,6 +885,29 @@ Interfaces: `MemoryEffectOpInterface`
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values
+### `zlow.sqrt` (::onnx_mlir::zlow::ZLowSqrtOp)
+
+_ZLow sqrt operation_
+
+ZLow operation to perform a sqrt.
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of dlfloat16 type values
+| `shape` | memref of 64-bit signless integer values
+| `Out` | memref of dlfloat16 type values
+
### `zlow.stickForGRU` (::onnx_mlir::zlow::ZLowStickForGRUOp)
_ZLow stick operation for GRU_
diff --git a/docs/Instrumentation.md b/docs/Instrumentation.md
index 31969ff15e..25b77153b6 100644
--- a/docs/Instrumentation.md
+++ b/docs/Instrumentation.md
@@ -61,11 +61,11 @@ The output for the memory measurement is explained here.
Other example for NNPA
- Performance profiling for onnx ops before lowering to zhigh ops:
- `onnx-mlir --mcpu=z16 --maccel=NNPA --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx`
+ `onnx-mlir --march=z16 --maccel=NNPA --instrument-stage=Onnx --instrument-ops=onnx.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx`
- Performance profiling for onnx and zhigh ops:
- `onnx-mlir --mcpu=z16 --maccel=NNPA --instrument-stage=ZHigh --instrument-ops=onnx.*,zhigh.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx`
+ `onnx-mlir --march=z16 --maccel=NNPA --instrument-stage=ZHigh --instrument-ops=onnx.*,zhigh.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx`
- Performance profiling for zlow ops:
- `onnx-mlir --mcpu=z16 --maccel=NNPA --instrument-stage=ZLow --instrument-ops=zlow.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx`
+ `onnx-mlir --march=z16 --maccel=NNPA --instrument-stage=ZLow --instrument-ops=zlow.* --InstrumentBeforeOp --InstrumentAfterOp --InstrumentReportTime mymodel.onnx`
## Control instrument at runtime
By providing certain env variable at runtime, you can disable reports from instrument library.
diff --git a/docs/SupportedONNXOps-NNPA.md b/docs/SupportedONNXOps-NNPA.md
index 80fa3287cf..a0f85aef41 100644
--- a/docs/SupportedONNXOps-NNPA.md
+++ b/docs/SupportedONNXOps-NNPA.md
@@ -8,38 +8,38 @@ Onnx-mlir currently supports ONNX operations targeting up to opset 21. Limitatio
* Operations are defined by the [ONNX Standard](https://github.com/onnx/onnx/blob/main/docs/Operators.md).
* **Supported Opsets** indicates the lowest and highest opset a model may have for onnx-mlir to support compiling a model with the operator.
* A * indicates onnx-mlir is compatible with the latest version of that operator available as of opset 21.
+ * A ^ indicates onnx-mlir is compatible with the latest level of the NNPA Architecture, which is z16.
-NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, CPU is used instead of NNPA. NNPA currently only support DLFLOAT16 as its data type. Common data formats like FP32, FP16, BFLOAT need to undergo data conversions to the NNPA internal format DLFLOAT16. Hence ONNX ops which updated their tensors to BFLOAT16 will not be natively supported on NNPA.
+NNPA has hardware limitations in dimension index size and tensor size, which are described in [NNPALimit.hpp](../src/Accelerators/NNPA/Support/NNPALimit.hpp). They are large enough for normal use cases, but if your model exceeds the limitations, the CPU is used instead of NNPA. NNPA currently supports only DLFLOAT16 as its data type. Common data formats like FP32, FP16, and BFLOAT16 need to undergo data conversions to the NNPA internal format DLFLOAT16. Hence ONNX ops whose tensors were updated to BFLOAT16 will not be natively supported on NNPA. Onnx-mlir with NNPA utilizes the hardware when possible. To accomplish this, the compiler converts ONNX ops to [ZHigh](Dialects/zhigh.md) ops and then [ZLow](Dialects/zlow.md) ops, which are processed by the [IBM Z Deep Neural Network Library (zDNN)](https://github.com/IBM/zDNN).
-| Op |Supported Opsets (inclusive) |Limitations |Notes |
-| --- |--- |--- |--- |
-| **Add** |6 - * |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
-| **AveragePool** |6 - * |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.<br>- `ceil_mode` must be default value(0)<br>- Input and output tensors must be 4D tensors (N x C x H x W).<br>- `kernel_shape` must be static.<br>- `count_include_pad` must be default value(0).<br>- `ceil_mode` must be default value(0). | |
-| **BatchNormalization** |6 - * |Input and output tensor must be 4D(N x C x H x W). | |
-| **Conv** |6 - * |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.<br>- Dimension in Height and weight must be static.<br>- `group` must be default value(1).<br>- `dilations` must be default value(1).<br>- Input and output tensors must have 4D (N x C x H x W).<br>- `kernel_shape` must be static. | |
-| **ConvTranspose** |6 - * |- 1D and 3D not supported because Conv1D and Conv3D not supported in zDNN. non-default `dilations` not supported because dilated convolution not supported in zDNN. | |
-| **Div** |6 - * |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
-| **Exp** |6 - * |Input tensor must have 4 dimensions. | |
-| **GRU** |7 - * |- `direction` and `hidden_size` in `W` must have static dimensions.<br>- `R` must have static dimensions.<br>- If `B` and `initial_h` are given, they must have static dimensions.<br>- `sequence_lens` is not supported for bidirectional GRU.<br>- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.<br>- `clip` is not supported.<br>- `linear_before_reset` must be 1.<br>- `layout` is not supported. | |
-| **Gemm** |6 - * |- `alpha` and `beta` must be default value(1).<br>- Rank of `C` must be 1 or 2. If the rank is 1, the dimension of `C` must be the same with the seconde dimension of `B`. | |
-| **GlobalAveragePool** |6 - * |- Input shape must be 4D tensor(NCHW).<br>- Dimensions in `H` and `W` must be static. | |
-| **LSTM** |7 - * |- `direction` and `hidden_size` in `W` must have static dimensions.<br>- `R` must have static dimensions.<br>- `B` and `initial_h` have static dimensions if given. `B`'s direction dim must be 1 or 2.<br>- `P`(peepholes), `activation_alpha`, and `activation_beta` are not supported.<br>- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.<br>- `clip` is not supported.<br>- `input_forget` must be default value(0).<br>- `layout` is not supported. | |
-| **LeakyRelu** |6 - * |The operations immediately before and after the LeakyRelu operation must be executed on the NNPA. Otherwise, LeakyRelu is executed on the CPU. This limitation is set to avoid performance degradation. | |
-| **Log** |6 - * |Input tensor must have 4 dimensions. | |
-| **LogSoftmax** |6 - * | | |
-| **MatMul** |6 - * |Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2. | |
-| **Max** |6 - * |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
-| **MaxPool** |6 - * |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding valid type or same upper.<br>- `ceil_mode` must be default value(0)<br>- Input and output tensors must be 4D tensors(N x C x H x W).<br>- `kernel_shape` must be static.<br>- `ceil_mode` must be default value(0).<br>- `dilations` must be default value(1). | |
-| **Min** |6 - * |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
-| **Mul** |6 - * |- Shape of input tensors should be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
-| **Pow** |7 - * |- Exponent should be a scalar integer and less or equal to 64. | |
-| **ReduceMean** |6 - * |- `keepdims` must be 1.<br>- Input tensor must be 4D tensors and `axis` must be [2, 3]. | |
-| **Relu** |6 - * |Input tensor must be less than or equal to 4 dimensions. | |
-| **Sigmoid** |6 - * |Input tensor must be less than or equal to 4 dimensions. | |
-| **Softmax** |6 - * |- `axis` must be the last dimension, i.e. `rank - 1` or -1. | |
-| **Softplus** |6 - * |The operations immediately before and after the Softplus operation must be executed on the NNPA. Otherwise, Softplus is executed on the CPU. This limitation is set to avoid performance degradation. | |
-| **Sub** |6 - * |- Shape of input tensors should be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
-| **Sum** |6 - * |- All inputs must have the same static shape (Broadcasting not supported.)<br>- Single input not supported. | |
-| **Tanh** |6 - * |Input tensor must be less than or equal to 4 dimensions. | |
+| Op |Supported Opsets (inclusive) |Minimum NNPA Level (Inclusive) |Limitations |Notes |
+| --- |--- |--- |--- |--- |
+| **Add** |6 - * |z16 |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
+| **AveragePool** |6 - * |z16 |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding is valid type or same upper.<br>- `ceil_mode` must be default value(0)<br>- Input and output tensors must be 4D tensors (N x C x H x W).<br>- `kernel_shape` must be static.<br>- `count_include_pad` must be default value(0).<br>- `ceil_mode` must be default value(0). | |
+| **BatchNormalization** |6 - * |z16 |Input and output tensor must be 4D(N x C x H x W). | |
+| **Conv** |6 - * |z16 |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding is valid type or same upper.<br>- Dimension in Height and weight must be static.<br>- `group` must be default value(1).<br>- `dilations` must be default value(1).<br>- Input and output tensors must have 4D (N x C x H x W).<br>- `kernel_shape` must be static. | |
+| **ConvTranspose** |6 - * |z16 |- 1D and 3D not supported because Conv1D and Conv3D not supported in zDNN. non-default `dilations` not supported because dilated convolution not supported in zDNN. | |
+| **Div** |6 - * |z16 |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
+| **Exp** |6 - * |z16 |Input tensor must have 4 dimensions. | |
+| **GRU** |7 - * |z16 |- `direction` and `hidden_size` in `W` must have static dimensions.<br>- `R` must have static dimensions.<br>- If `B` and `initial_h` are given, they must have static dimensions.<br>- `sequence_lens` is not supported for bidirectional GRU.<br>- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.<br>- `clip` is not supported.<br>- `linear_before_reset` must be 1.<br>- `layout` is not supported. | |
+| **Gemm** |6 - * |z16 |- `alpha` and `beta` must be default value(1).<br>- Rank of `C` must be 1 or 2. If the rank is 1, the dimension of `C` must be the same as the second dimension of `B`. | |
+| **GlobalAveragePool** |6 - * |z16 |- Input shape must be 4D tensor(NCHW).<br>- Dimensions in `H` and `W` must be static. | |
+| **LSTM** |7 - * |z16 |- `direction` and `hidden_size` in `W` must have static dimensions.<br>- `R` must have static dimensions.<br>- `B` and `initial_h` have static dimensions if given. `B`'s direction dim must be 1 or 2.<br>- `P`(peepholes), `activation_alpha`, and `activation_beta` are not supported.<br>- `activations` must be `["Sigmoid", "Tanh", "Tanh"]`.<br>- `clip` is not supported.<br>- `input_forget` must be default value(0).<br>- `layout` is not supported. | |
+| **Log** |6 - * |z16 |Input tensor must have 4 dimensions. | |
+| **LogSoftmax** |6 - * |z16 | | |
+| **MatMul** |6 - * |z16 |Ranks of input tensors must be (Rank of A, Rank of B) = (M, N), where M >= 2 and N >= 2. | |
+| **Max** |6 - * |z16 |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
+| **MaxPool** |6 - * |z16 |- `auto_pad` must be `NOTSET`, `VALID`, and `SAME_UPPER`. If `NOTSET` is used, `pads` must be set so that the padding is valid type or same upper.<br>- `ceil_mode` must be default value(0)<br>- Input and output tensors must be 4D tensors(N x C x H x W).<br>- `kernel_shape` must be static.<br>- `ceil_mode` must be default value(0).<br>- `dilations` must be default value(1). | |
+| **Min** |6 - * |z16 |- Shape of input tensors must be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
+| **Mul** |6 - * |z16 |- Shape of input tensors should be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
+| **Pow** |7 - * |z16 |- Exponent should be a scalar integer and less or equal to 64. | |
+| **ReduceMean** |6 - * |z16 |- `keepdims` must be 1.<br>- Input tensor must be 4D tensors and `axis` must be [2, 3]. | |
+| **Relu** |6 - * |z16 |Input tensor must be less than or equal to 4 dimensions. | |
+| **Sigmoid** |6 - * |z16 |Input tensor must be less than or equal to 4 dimensions. | |
+| **Softmax** |6 - * |z16 |- `axis` must be the last dimension, i.e. `rank - 1` or -1. | |
+| **Softplus** |6 - * |z16 |The operations immediately before and after the Softplus operation must be executed on the NNPA. Otherwise, Softplus is executed on the CPU. This limitation is set to avoid performance degradation. | |
+| **Sub** |6 - * |z16 |- Shape of input tensors should be the same since broadcasting is not supported.<br>- Input tensors must have static dimensions. | |
+| **Sum** |6 - * |z16 |- All inputs must have the same static shape (Broadcasting not supported.)<br>- Single input not supported. | |
+| **Tanh** |6 - * |z16 |Input tensor must be less than or equal to 4 dimensions. | |
diff --git a/docs/Testing.md b/docs/Testing.md
index 9fafc75876..c5029f01a5 100644
--- a/docs/Testing.md
+++ b/docs/Testing.md
@@ -122,9 +122,9 @@ cmake --build . --config Release --target check-onnx-backend-signature
### Enable SIMD instructions
-On supported platforms, currently s390x only, backend tests can generate SIMD instructions for the compiled models. To enable SIMD, set the TEST_MCPU environment variable, e.g.,
+On supported platforms (currently s390x z14 and up, x86, and arm), backend tests can generate SIMD instructions for the compiled models. To enable SIMD, set the TEST_MARCH environment variable, e.g.,
```
-TEST_MCPU=z14 cmake --build . --config Release --target check-onnx-backend[-jni]
+TEST_MARCH=z16 cmake --build . --config Release --target check-onnx-backend[-jni]
```
### Execution of backend tests
@@ -294,9 +294,9 @@ If you need to change ATOL and RTOL for accuracy checks, set the environment var
### Enable SIMD instructions
-On supported platforms, currently s390x only, numerical tests can generate SIMD instructions for the compiled models. To enable SIMD, set the `TEST_ARGS` environment variable, e.g.,
+On supported platforms (currently s390x z14 and up, x86, and arm), numerical tests can generate SIMD instructions for the compiled models. To enable SIMD, set the `TEST_ARGS` environment variable, e.g.,
```
-TEST_ARGS="-mcpu=z14" CTEST_PARALLEL_LEVEL=$(nproc) cmake --build . --config Release --target check-onnx-numerical
+TEST_ARGS="-march=z16" CTEST_PARALLEL_LEVEL=$(nproc) cmake --build . --config Release --target check-onnx-numerical
```
### Testing of specific accelerators
@@ -395,7 +395,7 @@ Without specifying a model using `-m`, the script will check all models in the O
If you want to gather performance info about a model zoo (or any models, for that matter), simplest is to request the desired statistic at compile time (using `-profile-ir` flag), divert the output statistic to a file, and then analyze it using `make-report.py`. For example:
```
-> ONNX_MLIR_INSTRUMENT_FILE=run.log RunONNXModelZoo.py -c "-O3 -march=arm64 --profile-ir=Onnx" -m bertsquad-10
+> ONNX_MLIR_INSTRUMENT_FILE=run.log RunONNXModelZoo.py -c "-O3 --march=arm64 --profile-ir=Onnx" -m bertsquad-10
...
> make-report.py -r run.log
...
@@ -408,7 +408,7 @@ Statistics start (all ops).
The runtime profiling info can be combined with specific compile-time statistics as well. Let's say that we are interested in SIMD statistics. We inform the compiler of the compile-time statistic to emit using `-opt-report` option, and inform `RunONNXModelZoo.py` that we want to preserve the compiler output using the `--log-to-file` option. For example
```
-> ONNX_MLIR_INSTRUMENT_FILE=run.log RunONNXModelZoo.py -c "-O3 -march=arm64 -opt-report=Simd --profile-ir=Onnx" -m bertsquad-10 --log-to-file compile.log
+> ONNX_MLIR_INSTRUMENT_FILE=run.log RunONNXModelZoo.py -c "-O3 --march=arm64 -opt-report=Simd --profile-ir=Onnx" -m bertsquad-10 --log-to-file compile.log
...
> make-report.py -c compile.log -r run.log
...
diff --git a/src/Accelerators/NNPA/CMakeLists.txt b/src/Accelerators/NNPA/CMakeLists.txt
index 51625e984b..d3687aabc9 100644
--- a/src/Accelerators/NNPA/CMakeLists.txt
+++ b/src/Accelerators/NNPA/CMakeLists.txt
@@ -33,7 +33,7 @@ else()
endif()
include(zdnn.cmake)
-setup_zdnn(v1.0.1)
+setup_zdnn(v1.1.1)
add_subdirectory(Dialect)
add_subdirectory(Conversion)
diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
index 91b6aab183..52d7933888 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.cpp
@@ -101,4 +101,27 @@ llvm::cl::opt nnpaEnableSaturation("nnpa-saturation",
"Default is false."),
llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
+llvm::cl::opt<bool> nnpaUseDynamicQuantizeLinearOnCPU("nnpa-cpu-dql",
+ llvm::cl::desc("Use dynamic quantized linear on CPU. Default is false"),
+ llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
+
+llvm::cl::opt<bool> nnpaUseDynamicQuantizeLinearOnCPUForScaleOffset(
+ "nnpa-cpu-dql-scale",
+ llvm::cl::desc("Use dynamic quantized linear computation of "
+ " scale and offset on CPU. Default is false"),
+ llvm::cl::init(false), llvm::cl::cat(OnnxMlirCommonOptions));
+
+llvm::cl::opt<NNPAQuantType> nnpaQuantization("nnpa-quantization",
+ llvm::cl::desc("Enable quantization with a specific type. Only "
+ "MatMul whose weight is a constant is supported."),
+ llvm::cl::values(
+ clEnumVal(DynSymI8,
+ "Dynamic Quantization to signed integer 8. Asymmetric "
+ "quant for activations and symmetric quant for weights."),
+ clEnumVal(SymSymI8,
+ "Dynamic Quantization to signed integer 8. Symmetric "
+ "quant for activations and symmetric quant for weights."),
+ clEnumVal(QNONE, "No quantization (default).")),
+ llvm::cl::init(QNONE), llvm::cl::cat(OnnxMlirOptions));
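+// Example usage (assuming the standard onnx-mlir driver flags): compiling with
+// `--maccel=NNPA --nnpa-quantization=DynSymI8` enables dynamic i8 quantization
+// for MatMul ops whose weights are constants.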
+
} // namespace onnx_mlir
diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
index 2b0343295c..366efee3fe 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp
@@ -55,6 +55,15 @@ typedef enum {
MuchFasterOpsWSU, /* FasterOpsWSU only if significantly faster. */
} NNPAPlacementHeuristic;
+// Quantization type
+typedef enum {
+ DynSymI8, /* Dynamic quantization to signed integer 8. Asymmetric quant for
+ activations and symmetric quant for weights.*/
+ SymSymI8, /* Dynamic quantization to signed integer 8. Symmetric quant for
+ activations and symmetric quant for weights.*/
+  QNONE, /* No quantization (default). */
+} NNPAQuantType;
+
extern llvm::cl::OptionCategory OnnxMlirOptions;
extern llvm::cl::OptionCategory OnnxMlirCommonOptions;
extern llvm::cl::opt nnpaEmissionTarget;
@@ -68,6 +77,9 @@ extern llvm::cl::opt profileZHighIR;
extern llvm::cl::opt nnpaLoadDevicePlacementFile;
extern llvm::cl::opt nnpaSaveDevicePlacementFile;
extern llvm::cl::opt nnpaEnableSaturation;
+extern llvm::cl::opt<bool> nnpaUseDynamicQuantizeLinearOnCPU;
+extern llvm::cl::opt<bool> nnpaUseDynamicQuantizeLinearOnCPUForScaleOffset;
+extern llvm::cl::opt<NNPAQuantType> nnpaQuantization;
} // namespace onnx_mlir
#endif
diff --git a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
index eefe6b9a15..d7c5cfcac0 100644
--- a/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
+++ b/src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp
@@ -36,6 +36,7 @@
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
#include "src/Accelerators/NNPA/Dialect/ZLow/ZLowOps.hpp"
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
+#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Compiler/CompilerPasses.hpp"
#include "src/Pass/Passes.hpp"
@@ -49,9 +50,9 @@ namespace onnx_mlir {
void configurePassesNNPA() {
configureOnnxToZHighLoweringPass(optReport == OptReport::NNPAUnsupportedOps);
- // Compiler generated sticks supports saturation, so force its usage.
- // TODO: remove this if zDNN adds support for saturation.
- if (nnpaEnableSaturation)
+  // z16 does not support hardware saturation,
+  // so force the use of compiler-generated sticks.
+ if (nnpaEnableSaturation && isLessEqualNNPALevel(NNPALevel::M14))
nnpaEnableCompilerStickUnstick = true;
}
@@ -84,7 +85,7 @@ void addONNXToZHighPasses(mlir::PassManager &pm) {
pm.addNestedPass(
onnx_mlir::createInstrumentPass(instrumentOps, instrumentActions));
- pm.addPass(onnx_mlir::createONNXToZHighPass());
+ pm.addPass(onnx_mlir::createONNXToZHighPass(nnpaQuantization));
pm.addNestedPass(onnx_mlir::createShapeInferencePass());
// There are more opportunities for const propagation once all zhigh ops were
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
index 58e6439897..47724d8d3e 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/DevicePlacement.cpp
@@ -200,13 +200,13 @@ void DevicePlacementPass::runOnOperation() {
// Call ONNXToZHigh pass for lowering multiple ONNX ops at once to ZHigh.
// E.g. `onnx.ReLu (onnx.Conv)` to zhigh.Conv.
RewritePatternSet Patterns2(context);
- getONNXToZHighMultipleOpPatterns(Patterns2);
+ getONNXToZHighMultipleOpPatterns(Patterns2, nnpaQuantization);
(void)applyAnalysisConversion(module, target, std::move(Patterns2),
ConversionConfig{.legalizableOps = &legalizedOps2});
// Call ONNXToZHigh pass for lowering a single ONNX op to ZHigh.
RewritePatternSet Patterns3(context);
- getONNXToZHighOneOpPatterns(Patterns3);
+ getONNXToZHighOneOpPatterns(Patterns3, nnpaQuantization);
getONNXToZHighOneOpDynamicallyLegal(&target, &dimAnalysis);
(void)applyAnalysisConversion(module, target, std::move(Patterns3),
ConversionConfig{.legalizableOps = &legalizedOps3});
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
index 621c8ffcbf..aa161a9f9e 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
@@ -15,6 +15,7 @@
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
+#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp"
#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Conversion/ONNXToKrnl/RNN/RNNBase.hpp"
@@ -38,12 +39,16 @@ bool onnxToZHighUnsupportedReport(Operation *op, const std::string &message) {
/// Report incompatibility with NNPA Level.
bool onnxToZHighInCompatibilityReport(
- Operation *op, std::string inputNNPALevel) {
- std::string message =
- "onnx-mlir NNPA level (" + inputNNPALevel +
- ") is not compatible with NNPA level specified by '-mcpu'(" + mcpu +
- ").";
- return onnxToZHighUnsupportedReport(op, message);
+ Operation *op, const std::string &message) {
+ std::string compilerNNPALevelStr = getNNPAString(getNNPAFromFlags());
+ std::string errorMessage =
+ "onnx-mlir NNPA level \"" + message + "\" is not compatible with " +
+ "NNPA level specified by \"" + compilerNNPALevelStr + "\".";
+ return onnxToZHighUnsupportedReport(op, errorMessage);
+}
+
+bool onnxToZHighInCompatibilityReport(Operation *op, NNPALevel level) {
+ return onnxToZHighInCompatibilityReport(op, getNNPAString(level));
}
/// A function to check whether a value's element type is valid for zAIU or not.
@@ -357,8 +362,8 @@ template <>
bool isSuitableForZDNN(
ONNXAddOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16)) {
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14)) {
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
}
if (!isValidElementTypeAndRank(op.getOperation(), op.getA()))
return false;
@@ -376,8 +381,8 @@ template <>
bool isSuitableForZDNN(
ONNXSubOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
if (!isValidElementTypeAndRank(op.getOperation(), op.getA()))
return false;
if (!isValidElementTypeAndRank(op.getOperation(), op.getB()))
@@ -394,8 +399,8 @@ template <>
bool isSuitableForZDNN(
ONNXMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
if (!isValidElementTypeAndRank(op.getOperation(), op.getA()))
return false;
if (!isValidElementTypeAndRank(op.getOperation(), op.getB()))
@@ -414,8 +419,8 @@ bool isSuitableForZDNN(
Value A = op.getA();
Value B = op.getB();
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Broadcast with a scalar operand.
if (isEnableScalarBcastBinary()) {
if (isF32ScalarConstantTensor(A) &&
@@ -442,8 +447,8 @@ template <>
bool isSuitableForZDNN(
ONNXSumOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Do not support a single input.
if (op.getData_0().size() < 2)
return onnxToZHighUnsupportedReport(op.getOperation(),
@@ -473,8 +478,8 @@ template <>
bool isSuitableForZDNN(
ONNXMinOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
int64_t opnum = op.getNumOperands();
if (opnum != 2)
return onnxToZHighUnsupportedReport(op.getOperation(),
@@ -491,13 +496,13 @@ bool isSuitableForZDNN(
}
/// Check legality for ONNXMax.
-/// zDNN Min/Max do not support boradcasting, and getNumOperands != 2.
+/// zDNN Min/Max do not support broadcasting, and getNumOperands != 2.
template <>
bool isSuitableForZDNN(
ONNXMaxOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
int64_t opnum = op.getNumOperands();
if (opnum != 2)
return onnxToZHighUnsupportedReport(op.getOperation(),
@@ -520,8 +525,8 @@ template <>
bool isSuitableForZDNN(
ONNXSoftmaxOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
if (!isValidElementTypeAndRank(op.getOperation(), op.getInput()))
return false;
ShapedType inputType = mlir::cast(op.getType());
@@ -541,13 +546,37 @@ bool isSuitableForZDNN(
return true;
}
+/// Check legality for ONNXLeakyRelu.
+template <>
+bool isSuitableForZDNN<ONNXLeakyReluOp>(
+ ONNXLeakyReluOp op, const DimAnalysis *dimAnalysis) {
+ // Check NNPA level.
+ if (!isCompatibleWithNNPALevel(NNPALevel::M15))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15);
+ if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
+ return false;
+ return true;
+}
+
/// Check legality for ONNXRelu.
template <>
bool isSuitableForZDNN(
ONNXReluOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
+ if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
+ return false;
+ return true;
+}
+
+/// Check legality for ONNXGelu.
+template <>
+bool isSuitableForZDNN<ONNXGeluOp>(
+ ONNXGeluOp op, const DimAnalysis *dimAnalysis) {
+ // Check NNPA level.
+ if (!isCompatibleWithNNPALevel(NNPALevel::M15))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15);
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
return false;
return true;
@@ -558,8 +587,8 @@ template <>
bool isSuitableForZDNN(
ONNXTanhOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
if (!isValidElementTypeAndRank(op.getOperation(), op.getInput()))
return false;
return true;
@@ -570,8 +599,20 @@ template <>
bool isSuitableForZDNN(
ONNXSigmoidOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
+ if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
+ return false;
+ return true;
+}
+
+/// Check legality for ONNXSqrt.
+template <>
+bool isSuitableForZDNN<ONNXSqrtOp>(
+ ONNXSqrtOp op, const DimAnalysis *dimAnalysis) {
+ // Check NNPA level.
+ if (!isCompatibleWithNNPALevel(NNPALevel::M15))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15);
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
return false;
return true;
@@ -582,8 +623,8 @@ template <>
bool isSuitableForZDNN(
ONNXLogOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
if (!isValidElementTypeAndRank(op.getOperation(), op.getInput()))
return false;
return true;
@@ -594,8 +635,8 @@ template <>
bool isSuitableForZDNN(
ONNXExpOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
if (!isValidElementTypeAndRank(op.getOperation(), op.getInput()))
return false;
return true;
@@ -606,8 +647,8 @@ template <>
bool isSuitableForZDNN(
ONNXMatMulOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
int64_t opnum = op.getNumOperands();
if (opnum != 2)
return onnxToZHighUnsupportedReport(op.getOperation(),
@@ -663,10 +704,10 @@ bool isSuitableForZDNN(
}
return true;
} else if ((shapeA.size() == 3) && (shapeB.size() == 2)) {
- // stacked w/ bcast
+ // stacked w/ bcast23 case
if (aType.hasStaticShape() && bType.hasStaticShape()) {
if (shapeA[2] != shapeB[0]) {
- std::string message = "Stacked w/ bcast case: the 3rd dim of A (" +
+ std::string message = "Stacked w/ bcast23 case: the 3rd dim of A (" +
std::to_string(shapeA[2]) +
") and the 1st dim of B (" +
std::to_string(shapeB[0]) + ") are not the same.";
@@ -674,6 +715,21 @@ bool isSuitableForZDNN(
}
}
return true;
+ } else if ((shapeA.size() == 2) && (shapeB.size() == 3)) {
+ // stacked w/ bcast1 case
+ if (!isCompatibleWithNNPALevel(NNPALevel::M15))
+ return onnxToZHighInCompatibilityReport(
+ op.getOperation(), NNPALevel::M15);
+ if (aType.hasStaticShape() && bType.hasStaticShape()) {
+ if (shapeA[1] != shapeB[1]) {
+ std::string message = "Stacked w/ bcast1 case: the 2nd dim of A (" +
+ std::to_string(shapeA[1]) +
+ ") and the 2nd dim of B (" +
+ std::to_string(shapeB[1]) + ") are not the same.";
+ return onnxToZHighUnsupportedReport(op.getOperation(), message);
+ }
+ }
+ return true;
}
std::string message = "Dim size of A(" + std::to_string(shapeA.size()) +
") and B(" + std::to_string(shapeB.size()) +
@@ -681,6 +737,141 @@ bool isSuitableForZDNN(
return onnxToZHighUnsupportedReport(op.getOperation(), message);
}
+/// Check legality for ONNXMatMulInteger.
+template <>
+bool isSuitableForZDNN<ONNXMatMulIntegerOp>(
+ ONNXMatMulIntegerOp op, const DimAnalysis *dimAnalysis) {
+ // Check NNPA level.
+ if (!isCompatibleWithNNPALevel(NNPALevel::M15))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15);
+
+ // Only support per-tensor quantization.
+ Value AZeroPoint = op.getAZeroPoint();
+ Value BZeroPoint = op.getBZeroPoint();
+ if (!isScalarTensor(AZeroPoint))
+ return onnxToZHighInCompatibilityReport(
+ op.getOperation(), "A's zeropoint is not scalar");
+ if (!isScalarTensor(BZeroPoint))
+ return onnxToZHighInCompatibilityReport(
+ op.getOperation(), "B's zeropoint is not scalar");
+
+  ShapedType aType = mlir::cast<ShapedType>(op.getA().getType());
+  ShapedType bType = mlir::cast<ShapedType>(op.getB().getType());
+
+ // Illegal if A or B is unranked.
+ if (!aType.hasRank() || !bType.hasRank())
+ return false;
+
+ auto shapeA = aType.getShape();
+ auto shapeB = bType.getShape();
+
+ // In case of Tensors with unknown dimension, check only size of matrices.
+ // Actual shape is not checked. If actual shape does not meet, get error at
+ // runtime.
+ // TODO: Support other cases
+ // (https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul) on zDNN
+ // by using broadcasting etc.
+ if ((shapeA.size() == 2) && (shapeB.size() == 2)) {
+ // unstacked case
+ if (aType.hasStaticShape() && bType.hasStaticShape())
+ return (shapeA[1] == shapeB[0]);
+ else
+ return true;
+ } else if ((shapeA.size() == 3) && (shapeB.size() == 3)) {
+ // stacked w/o bcast case
+ if (aType.hasStaticShape() && bType.hasStaticShape())
+ return ((shapeA[0] == shapeB[0]) && (shapeA[2] == shapeB[1]));
+ else
+ return true;
+ } else if ((shapeA.size() == 3) && (shapeB.size() == 2)) {
+ // stacked w/ bcast
+ if (aType.hasStaticShape() && bType.hasStaticShape())
+ return (shapeA[2] == shapeB[0]);
+ else
+ return true;
+ }
+
+ return false; // unsupported case
+}
+
+/// Check legality for ONNXQLinearMatMul.
+template <>
+bool isSuitableForZDNN<ONNXQLinearMatMulOp>(
+ ONNXQLinearMatMulOp op, const DimAnalysis *dimAnalysis) {
+ Value A = op.getA();
+ Value AScale = op.getAScale();
+ Value AZeroPoint = op.getAZeroPoint();
+ Value B = op.getB();
+ Value BScale = op.getBScale();
+ Value BZeroPoint = op.getBZeroPoint();
+ Value Y = op.getY();
+ Value YScale = op.getYScale();
+
+ // Check NNPA level.
+ if (!isCompatibleWithNNPALevel(NNPALevel::M15))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15);
+
+ // Only support float32 <-> int8/uint8.
+ Type elemTyA = getElementType(A.getType());
+ Type elemTyAScale = getElementType(AScale.getType());
+ Type elemTyB = getElementType(B.getType());
+ Type elemTyBScale = getElementType(BScale.getType());
+ Type elemTyY = getElementType(Y.getType());
+ Type elemTyYScale = getElementType(YScale.getType());
+
+ if (!elemTyAScale.isF32() || !elemTyBScale.isF32() || !elemTyYScale.isF32())
+ return false;
+ if (!(elemTyA.isInteger(8) || elemTyA.isUnsignedInteger(8)))
+ return false;
+ if (!(elemTyB.isInteger(8) || elemTyB.isUnsignedInteger(8)))
+ return false;
+ if (!(elemTyY.isInteger(8) || elemTyY.isUnsignedInteger(8)))
+ return false;
+
+ // Only support per-tensor quantization.
+ if (!isScalarTensor(AScale) || !isScalarTensor(BScale) ||
+ !isScalarTensor(AZeroPoint) || !isScalarTensor(BZeroPoint))
+ return false;
+
+  ShapedType aType = mlir::cast<ShapedType>(A.getType());
+  ShapedType bType = mlir::cast<ShapedType>(B.getType());
+
+ // Illegal if A or B is unranked.
+ if (!aType.hasRank() || !bType.hasRank())
+ return false;
+
+ auto shapeA = aType.getShape();
+ auto shapeB = bType.getShape();
+
+ // In case of Tensors with unknown dimension, check only size of matrices.
+ // Actual shape is not checked. If actual shape does not meet, get error at
+ // runtime.
+ // TODO: Support other cases
+ // (https://github.com/onnx/onnx/blob/main/docs/Operators.md#MatMul) on zDNN
+ // by using broadcasting etc.
+ if ((shapeA.size() == 2) && (shapeB.size() == 2)) {
+ // unstacked case
+ if (aType.hasStaticShape() && bType.hasStaticShape())
+ return (shapeA[1] == shapeB[0]);
+ else
+ return true;
+ } else if ((shapeA.size() == 3) && (shapeB.size() == 3)) {
+ // stacked w/o bcast case
+ if (aType.hasStaticShape() && bType.hasStaticShape())
+ return ((shapeA[0] == shapeB[0]) && (shapeA[2] == shapeB[1]));
+ else
+ return true;
+ } else if ((shapeA.size() == 3) && (shapeB.size() == 2)) {
+ // stacked w/ bcast
+ if (aType.hasStaticShape() && bType.hasStaticShape())
+ return (shapeA[2] == shapeB[0]);
+ else
+ return true;
+ }
+
+ return false; // unsupported case
+}
+
/// Check legality for ONNXGemm.
template <>
bool isSuitableForZDNN(
@@ -690,8 +881,8 @@ bool isSuitableForZDNN(
Value C = op.getC();
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), A))
@@ -759,13 +950,99 @@ bool isSuitableForZDNN(
return true;
}
+// Common function for ReduceMax and ReduceMin
+template <typename OP_TYPE>
+static bool checkReduceParam(OP_TYPE op) {
+ IndexExprBuilderForAnalysis createIE(op.getLoc());
+
+ // Check NNPA level.
+ if (!isCompatibleWithNNPALevel(NNPALevel::M15))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M15);
+
+ // Check data type.
+ Value data = op.getData();
+ if (!isValidElementTypeAndRank(op.getOperation(), data))
+ return false;
+
+ // Check axes value
+ Value axesVal = op.getAxes();
+ if (!isDenseONNXConstant(axesVal))
+ return false;
+
+ ONNXConstantOp axesConstant =
+      mlir::cast<ONNXConstantOp>(axesVal.getDefiningOp());
+  int64_t axesInt = getScalarValue<int64_t>(axesConstant);
+
+ int64_t keepdims = op.getKeepdims();
+ int64_t noop_with_empty_axes = op.getNoopWithEmptyAxes();
+ int64_t rank = createIE.getShapedTypeRank(data);
+
+ // Check if axes (int64) is exactly a size of one
+ if (floor(log10(axesInt)) + 1 == 1) {
+ int64_t axis = axesInt;
+ // Accepted range is [-r, r-1] where r = rank(data)
+ if (axis < -rank || axis > rank - 1) {
+ std::string message =
+ "The `axis` is out of the accepted range which is [-r, r-1]";
+ return onnxToZHighUnsupportedReport(op, message);
+ }
+ if ((axis != -1) && (axis != rank - 1)) {
+ std::string message = "The `axis` must be the innermost dimension. ";
+ return onnxToZHighUnsupportedReport(op, message);
+ }
+ } else {
+ std::string message = "Axes can only be a scalar size of one. ";
+ return onnxToZHighUnsupportedReport(op, message);
+ }
+
+ // REMINDER: Should we check the input tensor rank.
+
+ // Check keepdims and noop_with_empty_axes, we only support the default
+ // value. Attributes: keepdims (default is 1) and noop_with_empty_axes
+ // (default is 0)
+ if ((noop_with_empty_axes == 1) || (keepdims == 0)) {
+ std::string message = "`noop_with_empty_axes` (" +
+ std::to_string(noop_with_empty_axes) +
+ ") must be 0 and `keepdims` (" +
+ std::to_string(keepdims) + ") must be 1.";
+ return onnxToZHighUnsupportedReport(op, message);
+ }
+ return true;
+}
+
+/// Check legality for ONNXReduceMax.
+template <>
+bool isSuitableForZDNN<ONNXReduceMaxOp>(
+ ONNXReduceMaxOp op, const DimAnalysis *dimAnalysis) {
+
+ // Check parameter restrictions for ReduceMax
+ bool isReduceMax = checkReduceParam(op);
+ if (!isReduceMax)
+ return false;
+
+ return true;
+}
+
+/// Check legality for ONNXReduceMin.
+template <>
+bool isSuitableForZDNN<ONNXReduceMinOp>(
+ ONNXReduceMinOp op, const DimAnalysis *dimAnalysis) {
+
+ // Check parameter restrictions for ReduceMin
+ bool isReduceMin = checkReduceParam(op);
+ if (!isReduceMin)
+ return false;
+
+ return true;
+}
+
/// Check legality for ONNXReduceMeanV13.
template <>
bool isSuitableForZDNN(
ONNXReduceMeanV13Op op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getData()))
@@ -826,7 +1103,7 @@ template <>
bool isSuitableForZDNN(
ONNXSoftplusOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
return false;
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
return false;
@@ -844,8 +1121,8 @@ bool isSuitableForZDNN(
Value B = op.getB();
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Check direction.
if ((direction != FORWARD) && (direction != REVERSE) &&
@@ -869,7 +1146,8 @@ bool isSuitableForZDNN(
std::string message =
"The first dimension of weight tensor `W` for `num_directions` (" +
std::to_string(wShape[0]) +
- ") must be 1 or 2, and the second dimension of it for `hidden_size` (" +
+ ") must be 1 or 2, and the second dimension of it for `hidden_size` "
+ "(" +
std::to_string(wShape[1]) + ") must be static.";
return onnxToZHighUnsupportedReport(op.getOperation(), message);
}
@@ -877,9 +1155,9 @@ bool isSuitableForZDNN(
ArrayRef rShape = mlir::cast(R.getType()).getShape();
if (!mlir::cast(R.getType()).hasStaticShape() ||
(rShape[0] != 1 && rShape[0] != 2)) {
- std::string message =
- "The recurrence weight tensor `R` must have static dimension, and the "
- "first dimension of it must be 1 or 2.";
+ std::string message = "The recurrence weight tensor `R` must have static "
+ "dimension, and the "
+ "first dimension of it must be 1 or 2.";
return onnxToZHighUnsupportedReport(op.getOperation(), message);
}
// Check hidden_size.
@@ -957,8 +1235,8 @@ bool isSuitableForZDNN(
Value B = op.getB();
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Check direction.
if ((direction != FORWARD) && (direction != REVERSE) &&
@@ -982,7 +1260,8 @@ bool isSuitableForZDNN(
std::string message =
"The first dimension of weight tensor `W` for `num_directions` (" +
std::to_string(wShape[0]) +
- ") must be 1 or 2, and the second dimension of it for `hidden_size` (" +
+ ") must be 1 or 2, and the second dimension of it for `hidden_size` "
+ "(" +
std::to_string(wShape[1]) + ") must be static.";
return onnxToZHighUnsupportedReport(op.getOperation(), message);
}
@@ -1062,8 +1341,8 @@ template <>
bool isSuitableForZDNN(
ONNXMaxPoolSingleOutOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
@@ -1094,8 +1373,8 @@ template <>
bool isSuitableForZDNN(
ONNXAveragePoolOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
@@ -1111,9 +1390,9 @@ bool isSuitableForZDNN(
ONNXAveragePoolOpShapeHelper>(op, op.getY(), dimAnalysis);
}
-/// Check if input, output, kernel, strides, and paddingType for each axis meet
-/// parameter restrictions for conv2d. See "Conv2D Parameter Restrictions"
-/// in "zDNN API Reference"
+/// Check if input, output, kernel, strides, and paddingType for each axis
+/// meet parameter restrictions for conv2d. See "Conv2D Parameter
+/// Restrictions" in "zDNN API Reference"
static bool checkConv2DParamRestrictions(Operation *op, int64_t inputDim,
int64_t kernelDim, int64_t stride, int64_t outputDim,
StringRef paddingType) {
@@ -1218,8 +1497,8 @@ template <>
bool isSuitableForZDNN(
ONNXConvOp op, const DimAnalysis *dimAnalysis) {
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// Check data type.
if (!isValidElementTypeAndRank(op.getOperation(), op.getX()))
@@ -1255,7 +1534,8 @@ bool isSuitableForZDNN(
ShapedType::isDynamic(shapeOutput[2]) ||
ShapedType::isDynamic(shapeOutput[3]))
return onnxToZHighUnsupportedReport(op,
- "Height and/or width have dynamic dimensions. They are not supported.");
+ "Height and/or width have dynamic dimensions. They are not "
+ "supported.");
// Do not support group.
if (operandAdaptor.getGroup() != 1)
@@ -1271,7 +1551,8 @@ bool isSuitableForZDNN(
}
// `getStrPaddingType` returns `SAME_PADDING`, `VALID_PADDING`, or empty.
- // `zdnn_conv2d` only support padding for `SAME_PADDING` and `VALID_PADDING`.
+ // `zdnn_conv2d` only support padding for `SAME_PADDING` and
+ // `VALID_PADDING`.
StringRef paddingType =
getStrPaddingType(
op);
@@ -1324,8 +1605,8 @@ bool isSuitableForZDNN(
ArrayRef shapeOutput = outputType.getShape();
// Check NNPA level.
- if (!isCompatibleWithNNPALevel(NNPA_Z16))
- return onnxToZHighInCompatibilityReport(op.getOperation(), NNPA_Z16);
+ if (!isCompatibleWithNNPALevel(NNPALevel::M14))
+ return onnxToZHighInCompatibilityReport(op.getOperation(), NNPALevel::M14);
// 4D tensors(N x C x H x W) are supported as input and output.
if (shapeInput.size() != 4 || shapeOutput.size() != 4)
@@ -1344,3 +1625,19 @@ bool isSuitableForZDNN(
// Noop Reshape is suitable for zAIU as this pass removes such reshape ops.
return isIdentityReshape(op, dimAnalysis);
}
+
+/// Check legality for ONNXDequantizeLinearOp.
+template <>
+bool isSuitableForZDNN<ONNXDequantizeLinearOp>(
+ ONNXDequantizeLinearOp op, const DimAnalysis *dimAnalysis) {
+ // The pass rewrite-onnx-for-zhigh has a rule to rewrite the pattern
+ // `DequantizeLinear (QLinearMatMul inputs)` where ND inputs are reshaped
+ // into 3D inputs. This rule uses the function template
+ // `addDynamicallyLegalOpFor` to define legality using a custom lambda
+ // function instead of `isSuitableForZDNN`. Hence, the legality here should
+ // not be used/called. This legality is here to complete the function
+ // template `addDynamicallyLegalOpFor` so that it's not failed when building
+ // the compiler.
+ llvm_unreachable("Not used");
+ return false;
+}
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp
index f9c36372c4..09bfa6f4f6 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.hpp
@@ -17,6 +17,7 @@
#ifndef ONNX_MLIR_LEGALITY_H
#define ONNX_MLIR_LEGALITY_H
+#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp"
#include "src/Dialect/ONNX/ONNXDimAnalysis.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
@@ -53,6 +54,8 @@ bool onnxToZHighUnsupportedReport(
mlir::Operation *op, const std::string &message);
bool onnxToZHighInCompatibilityReport(
- mlir::Operation *op, std::string inputNNPALevel);
+ mlir::Operation *op, const std::string &message);
+
+bool onnxToZHighInCompatibilityReport(mlir::Operation *op, NNPALevel level);
#endif
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
index 2bfa450691..78e94a6a2a 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.cpp
@@ -4,7 +4,7 @@
//====------ ONNXToZHigh.cpp - ONNX dialect to ZHigh lowering -------------===//
//
-// Copyright 2019-2022 The IBM Research Authors.
+// Copyright 2019-2024 The IBM Research Authors.
//
// =============================================================================
//
@@ -13,6 +13,9 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Support/Debug.h"
+
+#include "src/Accelerators/NNPA/Compiler/NNPACompilerOptions.hpp"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.hpp"
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHighCommon.hpp"
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps.hpp"
@@ -25,6 +28,8 @@
#include "src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp"
#include "src/Dialect/ONNX/Transforms/ShapeInference.hpp"
+#define DEBUG_TYPE "onnx-to-zhigh"
+
using namespace mlir;
//
@@ -33,6 +38,17 @@ using namespace mlir;
namespace onnx_mlir {
+using namespace zhigh;
+
+#define QUANT_PATTERN_BENEFIT 1000
+
+/// Checks whether a constant tensor's elements are of type FloatType.
+bool isFloatType(Value constValue) {
+ ElementsAttr constElements = getElementAttributeFromONNXValue(constValue);
+ Type elemType = constElements.getElementType();
+  return mlir::isa<FloatType>(elemType);
+}
+
ArrayAttr getLSTMGRUBiasSplitShape(
Location loc, PatternRewriter &rewriter, ArrayRef shapeR) {
int64_t hiddenSize = shapeR[2];
@@ -253,6 +269,349 @@ SmallVector getArrayStrides(OP op) {
return shapeHelper.strides;
}
+/// Get approximate
+template <typename OP>
+StringRef getStrApproximateType(OP op) {
+ return op.getApproximate();
+}
+
+// Computes the folded bias to be passed to quantized matmul call when
+// operation is MATMUL_OP_ADDITION. Zb should be equal to 0, meaning the
+// correction term for input_a is also equal to 0. This allows the
+// correction term for input_b to be folded into qc_tilde, which removes the
+// need for correction being applied after the quantized matmul call.
+//
+// The original equation for qc_tilde is:
+// M = (Sa * Sb) / Sy
+// qc_tilde = Zy - (Sc / Sy) * Zc + (Sc / Sy) * input_c[j] + M*N*Za*Zb
+//
+// Given Zb = 0, the equation becomes:
+// M = (Sa * Sb) / Sy
+// qc_tilde = Zy - (Sc / Sy) * Zc + (Sc / Sy) * input_c[j]
+//
+// Given scales are stored as the reciprocal in zTensor, the modified equation
+// becomes:
+// M = RSy / (RSa * RSb)
+// qc_tilde = Zy - (RSy / RSc) * Zc + (RSy / RSc) * input_c[j]
+//
+// where RS = 1/S.
+//
+// We can reorder this to:
+// M = RSy / (RSa * RSb)
+// qc_tilde = input_c[j] * (RSy / RSc) + Zy - (RSy / RSc) * Zc
+//
+// This allows us to pre-compute a scale and offset to apply to input_c[j]:
+// M = RSy / (RSa * RSb).
+// scale = (RSy / RSc)
+// offset = Zy - scale * Zc
+// qc_tilde[j] = input_c[j] * scale + offset
+//
+// The original equation for the correction term for input_b is:
+// M = (RSa * RSb) / RSy
+// term_b = M * Za * sum(input_b[:,j])
+//
+// Given scales are stored as the reciprocal, the modified equation becomes:
+// M = RSy / (RSa * RSb)
+// term_b = M * Za * sum(input_b[:,j])
+//
+// This gives us the equation:
+// M = RSy / (RSa * RSb)
+// MZa = M * Za
+// scale = (RSy / RSc)
+// offset = Zy - scale * Zc
+// qc_tilde[j] = input_c[j] * scale + offset - MZa * sum(input_b[:,j])
+//
+// In case of MatMulInteger, input_c = 0, RSc = 1, Zc = 0, the final equation
+// is:
+// M = RSy / (RSa * RSb)
+// MZa = M * Za
+// scale = RSy
+// offset = Zy
+// qc_tilde[j] = offset - Za * (RSy / RSa / RSb) * sum(input_b[:,j])
+//
+// When Zy = 0, qc_tilde[j] = -Za * (RSy / RSa / RSb) * sum(input_b[:,j])
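+//
+// As a concrete numeric illustration of the last formula (values chosen
+// arbitrarily): with RSa = 2, RSb = 4, RSy = 1, Za = 3, Zy = 0 and
+// sum(input_b[:,j]) = 10 for some column j,
+//   M = 1 / (2 * 4) = 0.125, M * Za = 0.375,
+//   qc_tilde[j] = 0 - 0.375 * 10 = -3.75.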
+static void preComputeBias(MultiDialectBuilder<OnnxBuilder> &create, Value RSa,
+ Value Za, Value BI8, Value RSb, Value RSy, Value Zy, Value &qcTilde,
+ Value &RSqctilde, Value &Zqctilde) {
+ OpBuilder rewriter = create.getBuilder();
+ Location loc = create.getLoc();
+
+ Type i64Ty = rewriter.getI64Type();
+ Type f32Ty = rewriter.getF32Type();
+  auto cstMinus2Attr = DenseElementsAttr::get(
+      RankedTensorType::get({}, i64Ty), static_cast<int64_t>(-2));
+  auto cst0Attr = DenseElementsAttr::get(
+      RankedTensorType::get({}, f32Ty), static_cast<float>(0));
+  auto cst1Attr = DenseElementsAttr::get(
+      RankedTensorType::get({}, f32Ty), static_cast<float>(1));
+
+ Value cst0 = create.onnx.constant(cst0Attr);
+ Value cst1 = create.onnx.constant(cst1Attr);
+
+ // Can be optimized further when Zy is zero.
+ bool ZyIsZero = isDenseONNXConstant(Zy) && isConstOf(Zy, 0.);
+
+ Value qcF32;
+ Value B = create.onnx.cast(BI8, f32Ty);
+ Value lastSecondAxis = create.onnx.constant(cstMinus2Attr);
+ // Emit: sum(input_b[:,j])
+ Value BSum = create.onnx.reduceSum(
+ UnrankedTensorType::get(f32Ty), B, lastSecondAxis, false);
+ // RSy, RSa, RSb, Za are scalar, do scalar computation.
+ // Emit: Za * (RSy / RSa / RSb)
+ Value RSyRSa = create.onnx.div(RSy, RSa);
+ Value RSyRSaRSb = create.onnx.div(RSyRSa, RSb);
+ Value MZa = create.onnx.mul(RSyRSaRSb, Za);
+ // Negate ZaRSyRSa to avoid broadcasting Sub:
+ // `Zy - Za * (RSy / RSa / RSb) * ...`
+ MZa = create.onnx.sub(cst0, MZa);
+ // Broadcast ops.
+ // Emit: - Za * (RSy / RSa / RSb) * sum(input_b[:,j])
+ Value MZaBSum = create.onnx.mul(MZa, BSum);
+ // Emit: Zy - Za * (RSy / RSa / RSb) * sum(input_b[:,j])
+ if (ZyIsZero) {
+ qcF32 = MZaBSum;
+ } else {
+ qcF32 = create.onnx.add(Zy, MZaBSum);
+ }
+
+ // Use 1 for recscale and 0 for offset. This is a dlfloat16 stickification.
+ int64_t rank = getRank(qcF32.getType());
+ StringAttr layoutAttr =
+ rewriter.getStringAttr((rank == 1) ? LAYOUT_1D : LAYOUT_2DS);
+  ZHighQuantizedStickOp qcOp = rewriter.create<ZHighQuantizedStickOp>(loc,
+ qcF32, cst1, cst0, layoutAttr, rewriter.getStringAttr(QTYPE_DLFLOAT16));
+ qcTilde = qcOp.getResult(0);
+ RSqctilde = qcOp.getResult(1);
+ Zqctilde = qcOp.getResult(2);
+}
+
+static Value getOrCastToI8(Value val, MultiDialectBuilder<OnnxBuilder> &create,
+ bool simpleCast = false) {
+ if (!getElementType(val.getType()).isUnsignedInteger())
+ return val;
+
+ Type i8Ty = create.getBuilder().getI8Type();
+ if (simpleCast)
+ return create.onnx.cast(val, i8Ty);
+
+ // Use int16 to avoid integer overflow.
+ Type i16Ty = create.getBuilder().getI16Type();
+ auto cst128Attr = DenseElementsAttr::get(
+      RankedTensorType::get({}, i16Ty), static_cast<int16_t>(128));
+ Value valI16 = create.onnx.cast(val, i16Ty);
+ valI16 = create.onnx.sub(valI16, create.onnx.constant(cst128Attr));
+ Value valI8 = create.onnx.cast(valI16, i8Ty);
+ return valI8;
+}
+
+// Dynamic quantization helper to match and rewrite values A, B, C of A*B+C.
+class DynQuantI8PatternHelper {
+public:
+ DynQuantI8PatternHelper(PatternRewriter &rewriter, Location loc,
+ Operation *op, Value A, Value B, Value C, bool symForA)
+ : rewriter(rewriter), loc(loc), op(op), A(A), B(B), C(C),
+ symForA(symForA) {}
+
+ // Check the inputs A, B, C of `A*B+C` to see if they are suitable for doing
+ // dynamic quantization on NNPA.
+ LogicalResult match() {
+ // A is of f32.
+    if (!mlir::isa<FloatType>(getElementType(A.getType())))
+ return rewriter.notifyMatchFailure(op, "MatMul's A is not of f32.");
+
+ // Weight is a constant.
+ if (!isDenseONNXConstant(B))
+ return rewriter.notifyMatchFailure(op, "MatMul's B is not a constant.");
+
+ if (C) {
+ // Bias is a constant.
+ if (!isDenseONNXConstant(C))
+ return rewriter.notifyMatchFailure(op, "MatMul's C is not a constant");
+ // B and C shapes must be consistent. The reduction shape of B on the
+ // second dim from the last is the same as the shape of B, e.g. If B is
+ // [2x3x4], C must be [2x4].
+      ArrayRef<int64_t> bShape = getShape(B.getType());
+      ArrayRef<int64_t> cShape = getShape(C.getType());
+ int64_t bRank = bShape.size();
+ int64_t cRank = cShape.size();
+ if (bRank - 1 != cRank)
+ return rewriter.notifyMatchFailure(
+            op, "The ranks of B and C are incompatible.");
+ if (bShape[bRank - 1] != cShape[cRank - 1])
+ return rewriter.notifyMatchFailure(
+ op, "The last dimensions of B and C are not the same.");
+ if (bShape.drop_back(2) != cShape.drop_back(1))
+ return rewriter.notifyMatchFailure(
+            op, "The shapes of B and C are incompatible.");
+ }
+
+ return success();
+ }
+
+ // clang-format off
+ /*
+ * Emit the following code to compute `A*B+C` using i8 dynamic quantization.
+ * A can be quantized using asymmetric or symmetric quantization depending on
+ * the flag `symForA`, while B is always quantized using symmetric quantization.
+ * (Note that: If C is given, it will be added into the pre_computed_bias)
+ *
+ * ```
+ * (Quantize A using asymmetric/symmetric quant by setting `sym_mode` attr to the `symForA` flag)
+ * %qa, %a_recscale, %a_offset = zhigh.QuantizedStick(%A, none, none) { quantized_type = QUANTIZED_DLFLOAT16, sym_mode = 1/0}
+ *
+ * (Quantize B using symmetric quant)
+   * %b_offset = 0 // Symmetric quant mode for i8. Offset is always zero, qmin = -127, qmax = 127.
+ * %absmax = onnx.ReduceMax(onnx.Abs(%B))
+ * %b_rescale = onnx.Div(127, absmax)
+ * %qb = onnx.cast(onnx.Clip(onnx.Round(onnx.Mul(%B, %b_rescale)), qmin, qmax))
+ * %qb, %b_recscale, %b_offset = zhigh.QuantizedStick(%qb, %b_recscale, %b_offset) { quantized_type = QUANTIZED_WEIGHTS_INT8 }
+ *
+ * (Pre computed bias, %C is added)
+ * %qc = emit_ops_for_pre_computed_bias_at_compile_time
+ * %qc = zhigh.Add(%qc, zhigh.Stick(%C)) // only done if C is given.
+ * %qc_recscale = 1
+ * %qc_offset = 0
+ *
+ * %Y_recscale = 1
+ * %Y_offset = 0
+ * %Y, %Y_recscale, %Y_offset = zhigh.QuantizedMatMul (%qa, %a_recscale, %a_offset,
+ * %qb, %b_recscale, %b_offset,
+ * %qc, %c_recscale, %c_offset,
+ * %Y_recscale, %Y_offset) {
+ * PreComputedBias = true, DisableClipping = true, DequantizeOutput = false
+ * }
+ * ```
+ *
+ * where the computation of `%qb` and `%qb_recscale` are expected to be folded by constant
+ * propagation so that they become constants.
+ *
+ * For more information about dynamic quantization, see https://www.maartengrootendorst.com/blog/quantization
+ */
+ // clang-format on
+ Value rewriteSym() {
+    MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
+
+ Type i8Ty = rewriter.getIntegerType(8);
+ Type si64Ty = rewriter.getIntegerType(64, true);
+ Type f16Ty = rewriter.getF16Type();
+ Type f32Ty = rewriter.getF32Type();
+ RankedTensorType scalarTy = RankedTensorType::get({}, f32Ty);
+
+ IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1);
+ IntegerAttr falseAttr = rewriter.getIntegerAttr(si64Ty, 0);
+
+ Value none = create.onnx.none();
+    Value cst0 = create.onnx.constant(
+        DenseElementsAttr::get(scalarTy, static_cast<float>(0)));
+    Value cst1 = create.onnx.constant(
+        DenseElementsAttr::get(scalarTy, static_cast<float>(1)));
+    Value cst127 = create.onnx.constant(
+        DenseElementsAttr::get(scalarTy, static_cast<float>(127)));
+    Value cstNeg127 = create.onnx.constant(
+        DenseElementsAttr::get(scalarTy, static_cast<float>(-127)));
+
+ int64_t rankA = getRank(A.getType());
+ int64_t rankB = getRank(B.getType());
+ StringAttr aLayoutAttr =
+ rewriter.getStringAttr((rankA == 2) ? LAYOUT_2D : LAYOUT_3DS);
+ StringAttr bLayoutAttr =
+ rewriter.getStringAttr((rankB == 2) ? LAYOUT_2D : LAYOUT_3DS);
+
+ // Quantize and stickify A.
+ IntegerAttr symModeAttr =
+ rewriter.getIntegerAttr(rewriter.getI64Type(), symForA ? 1 : 0);
+ ZHighQuantizedStickOp qAOp =
+        rewriter.create<ZHighQuantizedStickOp>(loc, A, none, none, aLayoutAttr,
+ rewriter.getStringAttr(QTYPE_DLFLOAT16), symModeAttr);
+ Value AI8 = qAOp.getResult(0);
+ Value ARecScale = qAOp.getResult(1);
+ Value AOffset = qAOp.getResult(2);
+
+ // Quantize B. All computations here would be folded by constprop.
+ // Though computation here can be generalized for other integer types by
+ // changing qmin and qmax, we optimize it for i8 since NNPA supports i8 only
+ // at this moment.
+ // Symmetric mode for i8, meaning offset = 0, qmin = -127, qmax = 127.
+ Value BOffset = cst0;
+ Value qmin = cstNeg127;
+ Value qmax = cst127;
+ // %absmax = onnx.ReduceMax(onnx.Abs(%B))
+ // %b_rescale = onnx.Div(127, absmax)
+ Value absMax =
+ create.onnx.reduceMax(scalarTy, create.onnx.abs(B), none, false, false);
+ Value BRecScale = create.onnx.div(cst127, absMax);
+ // %qb = onnx.Cast(
+ // onnx.Clip(onnx.Round(onnx.Mul(%B, %b_rescale)), qmin, qmax))
+ Value BI8 = create.onnx.cast(
+ create.onnx.clip(
+ create.onnx.round(create.onnx.mul(B, BRecScale)), qmin, qmax),
+ i8Ty);
+ // Stickify B.
+ ZHighQuantizedStickOp qBOp =
+        rewriter.create<ZHighQuantizedStickOp>(loc, BI8, BRecScale, BOffset,
+ bLayoutAttr, rewriter.getStringAttr(QTYPE_WEIGHTS));
+
+ // Output information.
+ Value YRecScale = cst1;
+ Value YOffset = cst0;
+
+ // When A is also quantized using symmetric mode, both correction terms for
+ // A and B are canceled out. Thus, no precomputation is needed.
+ Value qcTilde = none, qcTildeRecScale = cst1, qcTildeOffset = cst0;
+ if (!symForA) {
+ // When only B is quantized using symmetric mode, precompute the
+ // correction term for B only.
+ preComputeBias(create, ARecScale, AOffset, BI8, BRecScale, YRecScale,
+ YOffset, qcTilde, qcTildeRecScale, qcTildeOffset);
+ }
+ // Add up C into bias if C is given.
+ if (C) {
+ int64_t rankC = getRank(C.getType());
+ assert((rankC == rankB - 1) &&
+ "C has a wrong shape to be added into pre_computed_bias");
+ assert((rankC == 1 || rankC == 2) && "Wrong rank for C");
+ StringAttr cLayoutAttr =
+ rewriter.getStringAttr((rankC == 1) ? LAYOUT_1D : LAYOUT_2DS);
+      Value stickC = rewriter.create<ZHighStickOp>(loc, C, cLayoutAttr);
+ if (symForA)
+ qcTilde = stickC;
+ else
+        qcTilde = rewriter.create<ZHighAddOp>(
+ loc, qcTilde.getType(), qcTilde, stickC);
+ }
+
+ // Emit zhigh.QuantizedMatMul.
+ // No need to dequantize since Y's rescale is 1.
+ // Do not clip the output values to i8, keep i32.
+    SmallVector<Type> resTypes;
+ resTypes.emplace_back(UnrankedTensorType::get(f16Ty));
+ resTypes.emplace_back(RankedTensorType::get({}, f32Ty));
+ resTypes.emplace_back(RankedTensorType::get({}, f32Ty));
+ ZHighQuantizedMatMulOp zhighQuantizedMatMulOp =
+        rewriter.create<ZHighQuantizedMatMulOp>(loc, resTypes, AI8, ARecScale,
+ AOffset, qBOp.getResult(0), BRecScale, BOffset, qcTilde,
+ qcTildeRecScale, qcTildeOffset,
+ /*OutRecScale*/ YRecScale, /*OutOffset*/ YOffset,
+ /*PreComputedBias*/ trueAttr, /*DisableClipping*/ trueAttr,
+ /*DequantizeOutput*/ falseAttr);
+ (void)zhighQuantizedMatMulOp.inferShapes([](Region ®ion) {});
+
+ // Unstickify the matmul result that is int8-as-float.
+    Value res = rewriter.create<ZHighUnstickOp>(
+ loc, zhighQuantizedMatMulOp.getResult(0));
+ return res;
+ }
+
+private:
+ PatternRewriter &rewriter;
+ Location loc;
+ Operation *op;
+ Value A, B, C;
+ // Whether do symmetric quant for activation input A or not.
+ bool symForA = false;
+};
+
//===----------------------------------------------------------------------===//
// ONNX to ZHigh Lowering Pass
//===----------------------------------------------------------------------===//
@@ -262,9 +621,9 @@ namespace {
#include "src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXONNXToZHigh.inc"
// Enhance 'replaceONNXSumOpPatternRecursion' to allow operating recursively.
-struct ONNXSumOpPatternEnhancedRecursion
+struct replaceONNXSumOpPatternEnhancedRecursion
: public replaceONNXSumOpPatternRecursion {
- ONNXSumOpPatternEnhancedRecursion(MLIRContext *context)
+ replaceONNXSumOpPatternEnhancedRecursion(MLIRContext *context)
: replaceONNXSumOpPatternRecursion(context) {}
void initialize() {
// This pattern recursively unpacks one variadic operand at a time. The
@@ -274,6 +633,892 @@ struct ONNXSumOpPatternEnhancedRecursion
}
};
+/**
+ * This is a pattern for doing i8 dynamic quantization (symmetric mode) for
+ * onnx.MatMul(%A, %B), where %B is a constant.
+ */
+
+class replaceONNXMatMulByDynQuantI8Pattern
+    : public OpRewritePattern<ONNXMatMulOp> {
+public:
+  using OpRewritePattern<ONNXMatMulOp>::OpRewritePattern;
+
+  replaceONNXMatMulByDynQuantI8Pattern(
+      MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false)
+      : OpRewritePattern<ONNXMatMulOp>(context, benefit), symForA(symForA) {}
+
+ LogicalResult matchAndRewrite(
+ ONNXMatMulOp mmOp, PatternRewriter &rewriter) const override {
+ Location loc = mmOp.getLoc();
+ Operation *op = mmOp.getOperation();
+ Value A = mmOp.getA();
+ Value B = mmOp.getB();
+
+ // Dynamic quantization helper.
+ DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, nullptr, symForA);
+
+ // Match
+    if (!isSuitableForZDNN<ONNXMatMulOp>(mmOp) || failed(dqHelper.match()))
+ return rewriter.notifyMatchFailure(op, "MatMul is not suitable for zDNN");
+
+ // Rewrite
+ Value res = dqHelper.rewriteSym();
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+private:
+ bool symForA = false;
+};
+
+/**
+ * This is a pattern for doing i8 dynamic quantization (symmetric mode) for
+ * `onnx.Add(onnx.MatMul(%A, %B), %C)`. where
+ * - %B and %C are a constant and
+ * - %B and %C must have compatible shape, i.e. the reduction shape on the last
+ * second dim of %B is the same as %C's shape.
+ */
+class replaceONNXMatMulAddByDynQuantI8Pattern
+    : public OpRewritePattern<ONNXAddOp> {
+public:
+  using OpRewritePattern<ONNXAddOp>::OpRewritePattern;
+
+  replaceONNXMatMulAddByDynQuantI8Pattern(
+      MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false)
+      : OpRewritePattern<ONNXAddOp>(context, benefit), symForA(symForA) {}
+
+ LogicalResult matchAndRewrite(
+ ONNXAddOp addOp, PatternRewriter &rewriter) const override {
+ Location loc = addOp.getLoc();
+ Operation *op = addOp.getOperation();
+ Value lhs = addOp.getOperand(0);
+ Value rhs = addOp.getOperand(1);
+
+ // Match A*B+C and C+A*B where B and C are constants, and then rewrite.
+ Value AB, C;
+    if (!areDefinedBy<ONNXMatMulOp>(lhs, rhs, AB, C))
+ return rewriter.notifyMatchFailure(
+ op, "MatMulAdd is not suitable for zDNN.");
+    ONNXMatMulOp mmOp = AB.getDefiningOp<ONNXMatMulOp>();
+ Value A = mmOp.getA();
+ Value B = mmOp.getB();
+
+ // Match A, B, C.
+ DynQuantI8PatternHelper dqHelper(rewriter, loc, op, A, B, C, symForA);
+ if (succeeded(dqHelper.match())) {
+ Value res = dqHelper.rewriteSym();
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+ return failure();
+ }
+
+private:
+ bool symForA = false;
+};
+
+/**
+ * This is a pattern for doing i8 dynamic quantization (symmetric mode) for
+ * onnx.Gemm(%A, %B, %C), where %B and %C are constants.
+ *
+ * This pattern is applied only when the compiler option
+ * `--nnpa-quantization={DynSymI8|SymSymI8}` is specified.
+ *
+ */
+
+class replaceONNXGemmByDynQuantI8Pattern : public OpRewritePattern<ONNXGemmOp> {
+public:
+  using OpRewritePattern<ONNXGemmOp>::OpRewritePattern;
+
+  replaceONNXGemmByDynQuantI8Pattern(
+      MLIRContext *context, PatternBenefit benefit = 1, bool symForA = false)
+      : OpRewritePattern<ONNXGemmOp>(context, benefit), symForA(symForA) {}
+
+ LogicalResult matchAndRewrite(
+ ONNXGemmOp gemmOp, PatternRewriter &rewriter) const override {
+ Location loc = gemmOp.getLoc();
+ Operation *op = gemmOp.getOperation();
+
+ Value A = gemmOp.getA();
+ Value B = gemmOp.getB();
+ Value C = gemmOp.getC();
+ bool transA = (gemmOp.getTransA() != 0);
+ bool transB = (gemmOp.getTransB() != 0);
+
+ // Dynamic quantization helper.
+ DynQuantI8PatternHelper dqHelper(
+ rewriter, loc, op, A, B, isNoneValue(C) ? nullptr : C, symForA);
+
+ // Match
+ // TODO: if B is a constant and it is transposed, we can do transpose
+ // explicitly.
+ if (transA || transB)
+ return rewriter.notifyMatchFailure(op, "Gemm is with transpose");
+ if (!isSuitableForZDNN(gemmOp))
+ return rewriter.notifyMatchFailure(op, "Gemm is not suitable for zDNN");
+ if (failed(dqHelper.match()))
+ return failure();
+
+ // Rewrite
+ Value res = dqHelper.rewriteSym();
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+private:
+ bool symForA = false;
+};
+
+class replaceONNXMatMulIntegerPattern
+ : public OpRewritePattern<ONNXMatMulIntegerOp> {
+public:
+ using OpRewritePattern<ONNXMatMulIntegerOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ ONNXMatMulIntegerOp mmiOp, PatternRewriter &rewriter) const override {
+ Location loc = mmiOp.getLoc();
+ Operation *op = mmiOp.getOperation();
+ MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
+
+ // Match
+ if (failed(canBeRewritten(rewriter, mmiOp)))
+ return failure();
+
+ Type si64Ty = rewriter.getIntegerType(64, true);
+ Type f16Ty = rewriter.getF16Type();
+ Type f32Ty = rewriter.getF32Type();
+ Type outElemTy = getElementType(mmiOp.getY().getType());
+ IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1);
+ IntegerAttr falseAttr = rewriter.getIntegerAttr(si64Ty, 0);
+
+ auto cst0Attr = DenseElementsAttr::get(
+ RankedTensorType::get({}, f32Ty), static_cast<float>(0));
+ auto cst1Attr = DenseElementsAttr::get(
+ RankedTensorType::get({}, f32Ty), static_cast<float>(1));
+ Value none = create.onnx.none();
+ Value zero = create.onnx.constant(cst0Attr);
+ Value zeroI64 = create.onnx.constantInt64({0});
+ Value one = create.onnx.constant(cst1Attr);
+
+ // Prepare inputs for zhigh QuantizedMatMul.
+
+ // I8 tensors
+ Value AI8 = getOrCastToI8(mmiOp.getA(), create, true);
+ Value BI8 = getOrCastToI8(mmiOp.getB(), create, true);
+
+ // Zero points in f32.
+ Value AZeroPointI8 = mmiOp.getAZeroPoint();
+ if (getRank(AZeroPointI8.getType()) == 1) {
+ // Normalize the 1-D zeropoint tensor to a scalar tensor.
+ AZeroPointI8 = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(AZeroPointI8.getType())),
+ AZeroPointI8, {zeroI64});
+ }
+ AZeroPointI8 = getOrCastToI8(AZeroPointI8, create, true);
+ Value AZeroPointF32 = create.onnx.cast(AZeroPointI8, f32Ty);
+ // TESTING: minus zeropoint in advance to cancel out the software part of
+ // zdnn quantized matmul.
+ // AI8 = create.onnx.sub(AI8, AZeroPointI8);
+ // Value AZeroPointF32 = zero;
+ Value BZeroPointI8 = mmiOp.getBZeroPoint();
+ if (getRank(BZeroPointI8.getType()) == 1) {
+ // Normalize the 1-D zeropoint tensor to a scalar tensor.
+ BZeroPointI8 = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(BZeroPointI8.getType())),
+ BZeroPointI8, {zeroI64});
+ }
+ BZeroPointI8 = getOrCastToI8(BZeroPointI8, create, true);
+ Value BZeroPointF32 = create.onnx.cast(BZeroPointI8, f32Ty);
+ // TESTING: minus zeropoint in advance to cancel out the software part of
+ // zdnn quantized matmul.
+ // BI8 = create.onnx.sub(BI8, BZeroPointI8);
+ // Value BZeroPointF32 = zero;
+ Value YZeroPointF32 = zero;
+
+ // Recscale in f32.
+ // Set recscale of A and B to 1. In dynamic quantization the output of
+ // MatMulInteger is scaled later outside the op.
+ Value ARecScale = one;
+ Value BRecScale = one;
+ Value YRecScale = one;
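+ // With unit recscales and a zero output offset, the quantized matmul keeps
+ // the integer scale of MatMulInteger's i32 output; any rescaling (typically
+ // onnx.Mul by a_scale * b_scale, as in the pattern_in_bert example below)
+ // stays in the graph and is applied afterwards.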
+
+ // Only pre-compute bias when B is a constant and BZeroPoint is zero.
+ bool canPreComputeBias = isDenseONNXConstant(BI8) &&
+ isDenseONNXConstant(BZeroPointI8) &&
+ isConstOf(BZeroPointI8, 0.0);
+
+ // Stickify AI8: transform AI8 into zTensor format.
+ int64_t rankA = getRank(AI8.getType());
+ StringAttr aLayoutAttr =
+ rewriter.getStringAttr((rankA == 2) ? LAYOUT_2D : LAYOUT_3DS);
+ ZHighQuantizedStickOp qAOp =
+ rewriter.create<ZHighQuantizedStickOp>(loc, AI8, ARecScale,
+ AZeroPointF32, aLayoutAttr, rewriter.getStringAttr(QTYPE_INT8));
+
+ // Stickify BI8. It is potentially folded at compile time.
+ int64_t rankB = getRank(BI8.getType());
+ StringAttr bLayoutAttr =
+ rewriter.getStringAttr((rankB == 2) ? LAYOUT_2D : LAYOUT_3DS);
+ ZHighQuantizedStickOp qBOp =
+ rewriter.create<ZHighQuantizedStickOp>(loc, BI8, BRecScale,
+ BZeroPointF32, bLayoutAttr, rewriter.getStringAttr(QTYPE_WEIGHTS));
+
+ // Bias is none or precomputed.
+ Value qcTilde, qcTildeRecScale, qcTildeZeroPointF32;
+ if (canPreComputeBias)
+ preComputeBias(create, ARecScale, AZeroPointF32, BI8, BRecScale,
+ YRecScale, YZeroPointF32, qcTilde, qcTildeRecScale,
+ qcTildeZeroPointF32);
+
+ // Emit zhigh.QuantizedMatMul. Bias is none.
+ // Do not dequantize, we want to keep the integer values that will be scaled
+ // outside this op.
+ // Do not clip the output values to i8, keep i32.
+ SmallVector<Type> resTypes;
+ resTypes.emplace_back(UnrankedTensorType::get(f16Ty));
+ resTypes.emplace_back(RankedTensorType::get({}, f32Ty));
+ resTypes.emplace_back(RankedTensorType::get({}, f32Ty));
+ ZHighQuantizedMatMulOp zhighQuantizedMatMulOp =
+ rewriter.create<ZHighQuantizedMatMulOp>(loc, resTypes,
+ qAOp.getResult(0), qAOp.getResult(1), qAOp.getResult(2),
+ qBOp.getResult(0), qBOp.getResult(1), qBOp.getResult(2),
+ /*Bias*/ canPreComputeBias ? qcTilde : none,
+ /*BiasRecScale*/ canPreComputeBias ? qcTildeRecScale : none,
+ /*BiasOffset*/ canPreComputeBias ? qcTildeZeroPointF32 : none,
+ /*OutRecScale*/ YRecScale, /*OutOffset*/ YZeroPointF32,
+ /*PreComputedBias*/ canPreComputeBias ? trueAttr : falseAttr,
+ /*DisableClipping*/ trueAttr,
+ /*DequantizeOutput*/ falseAttr);
+ (void)zhighQuantizedMatMulOp.inferShapes([](Region &region) {});
+
+ // Unstickify the matmul result that is int8-as-float.
+ Value resI8F32 = rewriter.create<ZHighUnstickOp>(
+ loc, zhighQuantizedMatMulOp.getResult(0));
+ Value res = create.onnx.cast(resI8F32, outElemTy);
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+ static mlir::LogicalResult canBeRewritten(
+ PatternRewriter &rewriter, ONNXMatMulIntegerOp mmiOp) {
+ if (!isSuitableForZDNN(mmiOp))
+ return rewriter.notifyMatchFailure(
+ mmiOp, "MatMulInteger is not suitable for zDNN");
+ return success();
+ }
+};
+
+// Replace by zhigh ops the following pattern:
+// clang-format off
+// func.func @pattern_in_bert(%X: tensor<?x?x768xf32>) -> tensor<?x?x768xf32> {
+// %y = onnx.Constant dense_resource<__elided__> : tensor<768x768xi8>
+// %y_scale = onnx.Constant dense<0.00656270096> : tensor<f32>
+// %y_zero_point = onnx.Constant dense<0> : tensor<i8>
+//
+// %x, %x_scale, %x_zero_point = "onnx.DynamicQuantizeLinear"(%X) : (tensor<?x?x768xf32>) -> (tensor<?x?x768xui8>, tensor<f32>, tensor<ui8>)
+//
+// %matmul = "onnx.MatMulInteger"(%x, %y, %x_zero_point, %y_zero_point) : (tensor<?x?x768xui8>, tensor<768x768xi8>, tensor<ui8>, tensor<i8>) -> tensor<?x?x768xi32>
+// %cast = "onnx.Cast"(%matmul) {saturate = 1 : si64, to = f32} : (tensor<?x?x768xi32>) -> tensor<?x?x768xf32>
+// %mul_1 = "onnx.Mul"(%cast, %x_scale) : (tensor<?x?x768xf32>, tensor<f32>) -> tensor<?x?x768xf32>
+// %mul_2 = "onnx.Mul"(%mul_1, %y_scale) : (tensor<?x?x768xf32>, tensor<f32>) -> tensor<?x?x768xf32>
+//
+// return %mul_2 : tensor<?x?x768xf32>
+// }
+// clang-format on
+class replaceMatMulIntegerSubGraphFromMulPattern
+ : public OpRewritePattern<ONNXMulOp> {
+public:
+ using OpRewritePattern<ONNXMulOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ ONNXMulOp mulOp, PatternRewriter &rewriter) const override {
+ Location loc = mulOp.getLoc();
+ Operation *op = mulOp.getOperation();
+ MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
+
+ // Match
+ Value A, AI8, AScale, AZeroPointI8, BI8, BScale, BZeroPointI8;
+ if (failed(canBeRewritten(rewriter, mulOp, A, AI8, AScale, AZeroPointI8,
+ BI8, BScale, BZeroPointI8)))
+ return failure();
+
+ Type si64Ty = rewriter.getIntegerType(64, true);
+ Type f16Ty = rewriter.getF16Type();
+ Type f32Ty = rewriter.getF32Type();
+ IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1);
+ IntegerAttr falseAttr = rewriter.getIntegerAttr(si64Ty, 0);
+ Value none = create.onnx.none();
+
+ // Only pre-compute bias when BZeroPoint is zero.
+ bool canPreComputeBias = isDenseONNXConstant(BI8) &&
+ isDenseONNXConstant(BZeroPointI8) &&
+ isConstOf(BZeroPointI8, 0.0);
+
+ // Stickify A.
+ int64_t rankA = getRank(A.getType());
+ StringAttr aLayoutAttr =
+ rewriter.getStringAttr((rankA == 2) ? LAYOUT_2D : LAYOUT_3DS);
+ ZHighQuantizedStickOp qAOp;
+ if (nnpaUseDynamicQuantizeLinearOnCPU) {
+ Value zeroI64 = create.onnx.constantInt64({0});
+ // Input A was quantized on CPU by onnx.DynamicQuantizeLinear: f32 to ui8.
+ if (getRank(AZeroPointI8.getType()) == 1) {
+ // Normalize the 1-D zeropoint tensor to a scalar tensor.
+ AZeroPointI8 = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(AZeroPointI8.getType())),
+ AZeroPointI8, {zeroI64});
+ }
+ AZeroPointI8 = getOrCastToI8(AZeroPointI8, create, true);
+ Value AZeroPointF32 = create.onnx.cast(AZeroPointI8, f32Ty);
+ Value ARecScale = create.onnx.reciprocal(AScale);
+ AI8 = getOrCastToI8(AI8, create, true);
+ // Stickify the quantized input A to ztensor format.
+ qAOp = rewriter.create<ZHighQuantizedStickOp>(loc, AI8, ARecScale,
+ AZeroPointF32, aLayoutAttr, rewriter.getStringAttr(QTYPE_INT8));
+ } else {
+ // Stickify input A to dlfloat16, and it will be quantized internally by
+ // the NNPA quantized matmul.
+ qAOp = rewriter.create<ZHighQuantizedStickOp>(loc, A, none, none,
+ aLayoutAttr, rewriter.getStringAttr(QTYPE_DLFLOAT16));
+ }
+ Value qA = qAOp.getResult(0);
+ Value ARecScale = qAOp.getResult(1);
+ Value AZeroPoint = qAOp.getResult(2);
+
+ // Stickify B. It is potentially folded at compile time.
+ int64_t rankB = getRank(BI8.getType());
+ StringAttr bLayoutAttr =
+ rewriter.getStringAttr((rankB == 2) ? LAYOUT_2D : LAYOUT_3DS);
+ Value BRecScale = create.onnx.reciprocal(BScale);
+ Value BZeroPoint = create.onnx.cast(BZeroPointI8, f32Ty);
+ ZHighQuantizedStickOp qBOp =
+ rewriter.create<ZHighQuantizedStickOp>(loc, BI8, BRecScale, BZeroPoint,
+ bLayoutAttr, rewriter.getStringAttr(QTYPE_WEIGHTS));
+ Value qB = qBOp.getResult(0);
+
+ // Output's rescale and zeropoint
+ auto cst0Attr =
+ DenseElementsAttr::get(RankedTensorType::get({}, f32Ty), (float)0);
+ auto cst1Attr =
+ DenseElementsAttr::get(RankedTensorType::get({}, f32Ty), (float)1);
+ Value OutRecScale = create.onnx.constant(cst1Attr);
+ Value OutZeroPoint = create.onnx.constant(cst0Attr);
+
+ // Bias is none or precomputed.
+ Value qcTilde, qcTildeRecScale, qcTildeZeroPoint;
+ if (canPreComputeBias)
+ preComputeBias(create, ARecScale, AZeroPoint, BI8, BRecScale, OutRecScale,
+ OutZeroPoint, qcTilde, qcTildeRecScale, qcTildeZeroPoint);
+
+ // Emit zhigh.QuantizedMatMul.
+ SmallVector<Type> resTypes;
+ resTypes.emplace_back(UnrankedTensorType::get(f16Ty));
+ resTypes.emplace_back(RankedTensorType::get({}, f32Ty));
+ resTypes.emplace_back(RankedTensorType::get({}, f32Ty));
+ ZHighQuantizedMatMulOp zhighQuantizedMatMulOp =
+ rewriter.create<ZHighQuantizedMatMulOp>(loc, resTypes, qA, ARecScale,
+ AZeroPoint, qB, BRecScale, BZeroPoint,
+ /*Bias*/ canPreComputeBias ? qcTilde : none,
+ /*BiasRecScale*/ canPreComputeBias ? qcTildeRecScale : none,
+ /*BiasOffset*/ canPreComputeBias ? qcTildeZeroPoint : none,
+ /*OutRecScale*/ OutRecScale, /*OutOffset*/ OutZeroPoint,
+ /*PreComputedBias*/ canPreComputeBias ? trueAttr : falseAttr,
+ /*DequantizeOutput*/ trueAttr);
+ (void)zhighQuantizedMatMulOp.inferShapes([](Region &region) {});
+
+ // Unstickify the matmul result.
+ Value res = rewriter.create<ZHighUnstickOp>(
+ loc, zhighQuantizedMatMulOp.getResult(0));
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+ // clang-format off
+ // func.func @pattern_in_bert(%A) {
+ // // A is dynamically quantized.
+ // %a, %a_scale, %a_zero_point = "onnx.DynamicQuantizeLinear"(%A)
+ //
+ // // B is a constant and already quantized.
+ // %b = onnx.Constant
+ // %b_scale = onnx.Constant
+ // %b_zero_point = onnx.Constant
+ //
+ //
+ // %matmul = "onnx.MatMulInteger"(%a, %b, %a_zero_point, %b_zero_point)
+ //
+ // // Scale the output.
+ // %mm_f32 = "onnx.Cast"(%matmul) {to = f32}
+ // %mm_a_scale = "onnx.Mul"(%mm_f32, %a_scale)
+ // %mm_ab_scale = "onnx.Mul"(%mm_a_scale, %b_scale)
+ //
+ // return %mm_ab_scale
+ // }
+ // clang-format on
+ static mlir::LogicalResult canBeRewritten(PatternRewriter &rewriter,
+ ONNXMulOp mulOp, Value &A, Value &AI8, Value &AScale, Value &AZeroPoint,
+ Value &BI8, Value &BScale, Value &BZeroPoint) {
+
+ // Match `cast(mm_out) * a_scale * b_scale` to find the two scales; at this
+ // point we do not yet know which scale belongs to A and which to B.
+ Value scale1, scale2;
+ ONNXCastOp castOp;
+ ONNXMulOp mulScaleOp;
+
+ Value opr1 = mulOp.getOperand(0);
+ Value opr2 = mulOp.getOperand(1);
+
+ // Match cast(mm_out) * (a_scale * b_scale)
+ castOp = opr1.getDefiningOp<ONNXCastOp>();
+ mulScaleOp = opr2.getDefiningOp<ONNXMulOp>();
+ bool foundScales = false;
+ if (castOp && mulScaleOp && isScalarTensor(opr2)) {
+ Value lhs = mulScaleOp.getOperand(0);
+ Value rhs = mulScaleOp.getOperand(1);
+ if (isScalarTensor(lhs) && isScalarTensor(rhs)) {
+ // mulScaleOp is a_scale * b_scale;
+ foundScales = true;
+ scale1 = lhs;
+ scale2 = rhs;
+ }
+ }
+ // Match (a_scale * b_scale) * cast(mm_out)
+ if (!foundScales) {
+ mulScaleOp = opr1.getDefiningOp<ONNXMulOp>();
+ castOp = opr2.getDefiningOp<ONNXCastOp>();
+ if (mulScaleOp && isScalarTensor(opr1) && castOp) {
+ Value lhs = mulScaleOp.getOperand(0);
+ Value rhs = mulScaleOp.getOperand(1);
+ if (isScalarTensor(lhs) && isScalarTensor(rhs)) {
+ // mulScaleOp is a_scale * b_scale;
+ foundScales = true;
+ scale1 = lhs;
+ scale2 = rhs;
+ }
+ }
+ }
+ // Match [cast(mm_out) * a_scale] * b_scale
+ if (!foundScales && isScalarTensor(opr2)) {
+ scale1 = opr2;
+ mulScaleOp = opr1.getDefiningOp<ONNXMulOp>();
+ if (mulScaleOp) {
+ Value lhs = mulScaleOp.getOperand(0);
+ Value rhs = mulScaleOp.getOperand(1);
+ castOp = lhs.getDefiningOp<ONNXCastOp>();
+ if (castOp && isScalarTensor(rhs)) {
+ // Match cast(mm_out) * a_scale
+ scale2 = rhs;
+ foundScales = true;
+ }
+ if (!foundScales) {
+ // Match a_scale * cast(mm_out)
+ castOp = rhs.getDefiningOp<ONNXCastOp>();
+ if (isScalarTensor(lhs) && castOp) {
+ scale2 = lhs;
+ foundScales = true;
+ }
+ }
+ }
+ // Match b_scale * [cast(mm_out) * a_scale]
+ if (!foundScales && isScalarTensor(opr1)) {
+ scale1 = opr1;
+ mulScaleOp = opr2.getDefiningOp<ONNXMulOp>();
+ if (mulScaleOp) {
+ Value lhs = mulScaleOp.getOperand(0);
+ Value rhs = mulScaleOp.getOperand(1);
+ castOp = lhs.getDefiningOp<ONNXCastOp>();
+ if (castOp && isScalarTensor(rhs)) {
+ // Match cast(mm_out) * a_scale
+ scale2 = rhs;
+ foundScales = true;
+ }
+ if (!foundScales) {
+ // Match a_scale * cast(mm_out)
+ castOp = rhs.getDefiningOp<ONNXCastOp>();
+ if (isScalarTensor(lhs) && castOp) {
+ scale2 = lhs;
+ foundScales = true;
+ }
+ }
+ }
+ }
+ }
+ if (!foundScales)
+ return rewriter.notifyMatchFailure(mulOp, "Not found scale values");
+
+ // Identify a_scale and b_scale.
+ // a_scale is from DynamicQuantizeLinear.
+ if (scale1.getDefiningOp<ONNXDynamicQuantizeLinearOp>()) {
+ AScale = scale1;
+ BScale = scale2;
+ } else if (scale2.getDefiningOp<ONNXDynamicQuantizeLinearOp>()) {
+ AScale = scale2;
+ BScale = scale1;
+ } else {
+ return rewriter.notifyMatchFailure(
+ mulOp, "Could not identify a_scale and b_scale");
+ }
+
+ // Match cast.
+ // %cast = "onnx.Cast"(%matmul) {saturate = 1 : si64, to = f32}
+ Type castOutputType = castOp.getOutput().getType();
+ Type castInputType = castOp.getInput().getType();
+ if (isRankedShapedType(castInputType) &&
+ isRankedShapedType(castOutputType)) {
+ if (!getElementType(castInputType).isInteger(32))
+ return rewriter.notifyMatchFailure(
+ mulOp, "ONNXCast is not casting from i32");
+ if (!getElementType(castOutputType).isF32())
+ return rewriter.notifyMatchFailure(
+ mulOp, "ONNXCast is not casting to f32");
+ } else {
+ return rewriter.notifyMatchFailure(mulOp, "ONNXCast is unranked");
+ }
+
+ // Match matmul to get BI8 and BZeroPoint.
+ ONNXMatMulIntegerOp matmulOp =
+ castOp.getInput().getDefiningOp<ONNXMatMulIntegerOp>();
+ if (!matmulOp)
+ return rewriter.notifyMatchFailure(
+ mulOp, "The input of the CastOp is not defined by MatMulIntegerOp");
+ if (!isSuitableForZDNN(matmulOp))
+ return rewriter.notifyMatchFailure(
+ mulOp, "MatMulInteger is not suitable for zDNN");
+
+ AI8 = matmulOp->getOperand(0);
+ BI8 = matmulOp->getOperand(1);
+ AZeroPoint = matmulOp->getOperand(2);
+ BZeroPoint = matmulOp->getOperand(3);
+ if (!isDenseONNXConstant(BI8))
+ return rewriter.notifyMatchFailure(mulOp, "Quantized Y is not constant");
+ if (!isDenseONNXConstant(BZeroPoint))
+ return rewriter.notifyMatchFailure(mulOp, "BZeroPoint is not constant");
+ if (!(getElementType(BI8.getType()).isUnsignedInteger(8) ||
+ getElementType(BI8.getType()).isSignlessInteger(8)))
+ return rewriter.notifyMatchFailure(
+ mulOp, "Quantized Y is not signed int8");
+
+ // Match dynamic quantize linear to get A.
+ if (auto dqlOp =
+ llvm::dyn_cast<ONNXDynamicQuantizeLinearOp>(AI8.getDefiningOp())) {
+ if (AScale != dqlOp.getResult(1))
+ return rewriter.notifyMatchFailure(mulOp, "AScale is not used");
+ if (AZeroPoint != dqlOp.getResult(2))
+ return rewriter.notifyMatchFailure(mulOp, "AZeroPoint is not used");
+ // return A.
+ A = dqlOp.getOperand();
+ } else {
+ return rewriter.notifyMatchFailure(
+ mulOp, "Quantized A is not defined by DynamicQuantizeLinearOp");
+ }
+
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// Fuse ZHighQuantizedMatMul and ONNXAdd
+//===----------------------------------------------------------------------===//
+// Rewrite this pattern:
+// (ONNXAddOp
+// $x,
+// (ZHighUnstickOp
+// (ZHighQuantizedMatMulOp:$mm_res
+// $a, $Sa, $Za,
+// $b, $Sb, $Zb,
+// (ZHighQuantizedStick $c), $Sc, $Zc,
+// $So, $Zo,
+// $preComputed, $disableClipping, $dequantized))),
+//
+// into this pattern where $x is added to $c:
+//
+// (ZHighUnstickOp
+// (ZHighQuantizedMatMulOp
+// $a, $Sa, $Za,
+// $b, $Sb, $Zb,
+// (ZHighQuantizedStick (ONNXAddOp $x, $c)), $Sc, $Zc,
+// $So, $Zo,
+// $preComputed, $disableClipping, $dequantized)),
+//
+// Requirement: `preComputed` is true.
+
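+// Informal justification (a sketch of the algebra, not a statement from the
+// DAG above): with a pre-computed bias the matmul result is matmul(a, b) + c,
+// so x + Unstick(QMatMul(a, b, Stick(c))) == Unstick(QMatMul(a, b, Stick(x + c))),
+// provided x broadcasts like c (hence the rank check in canBeRewritten).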
+class fuseZHighQuantizedMatMulONNXAddPattern
+ : public OpRewritePattern<ONNXAddOp> {
+public:
+ using OpRewritePattern<ONNXAddOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ ONNXAddOp addOp, PatternRewriter &rewriter) const override {
+ Location loc = addOp.getLoc();
+ Operation *op = addOp.getOperation();
+ MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
+
+ ZHighUnstickOp unstickOp;
+ ZHighQuantizedMatMulOp mmOp;
+ ZHighQuantizedStickOp qstickOp;
+ Value addInput;
+
+ // match
+ if (failed(canBeRewritten(
+ rewriter, addOp, unstickOp, mmOp, qstickOp, addInput)))
+ return failure();
+
+ // rewrite
+ Value newBias = create.onnx.add(addInput, qstickOp.getIn());
+ ZHighQuantizedStickOp newQStickOp = rewriter.create<ZHighQuantizedStickOp>(
+ loc, newBias, qstickOp.getInRecScale(), qstickOp.getInOffset(),
+ qstickOp.getLayoutAttr(), qstickOp.getQuantizedTypeAttr());
+
+ SmallVector<Type> resTypes;
+ resTypes.emplace_back(mmOp.getResult(0).getType());
+ resTypes.emplace_back(mmOp.getResult(1).getType());
+ resTypes.emplace_back(mmOp.getResult(2).getType());
+ ZHighQuantizedMatMulOp newQMMOp = rewriter.create<ZHighQuantizedMatMulOp>(
+ loc, resTypes, mmOp.getX(), mmOp.getXRecScale(), mmOp.getXOffset(),
+ mmOp.getY(), mmOp.getYRecScale(), mmOp.getYOffset(),
+ newQStickOp.getResult(0), newQStickOp.getResult(1),
+ newQStickOp.getResult(2), mmOp.getOutRecScaleIn(),
+ mmOp.getOutOffsetIn(), mmOp.getPreComputedBiasAttr(),
+ mmOp.getDisableClippingAttr(), mmOp.getDequantizeOutputAttr());
+ ZHighUnstickOp newUnstickOp =
+ rewriter.create<ZHighUnstickOp>(loc, newQMMOp.getResult(0));
+
+ rewriter.replaceOp(op, newUnstickOp);
+ return success();
+ }
+
+ static mlir::LogicalResult canBeRewritten(PatternRewriter &rewriter,
+ ONNXAddOp addOp, ZHighUnstickOp &unstickOp, ZHighQuantizedMatMulOp &mmOp,
+ ZHighQuantizedStickOp &qstickOp, Value &addInput) {
+ Value lhs = addOp.getOperand(0);
+ Value rhs = addOp.getOperand(1);
+ bool found = false;
+ if (auto op1 = lhs.getDefiningOp<ZHighUnstickOp>()) {
+ addInput = rhs;
+ unstickOp = op1;
+ Value mmOutput = unstickOp.getIn();
+ if (auto op2 = mmOutput.getDefiningOp<ZHighQuantizedMatMulOp>()) {
+ mmOp = op2;
+ bool precomputed = (mmOp.getPreComputedBias() == -1);
+ if (!precomputed)
+ return rewriter.notifyMatchFailure(
+ addOp, "not precomputed quantized matmul");
+ Value qBias = mmOp.getB();
+ if (auto op3 = qBias.getDefiningOp<ZHighQuantizedStickOp>()) {
+ qstickOp = op3;
+ Value bias = qstickOp.getIn();
+ // Check rank.
+ if (getRank(bias.getType()) != getRank(addInput.getType()))
+ return rewriter.notifyMatchFailure(addOp, "rank mismatched");
+ found = true;
+ }
+ }
+ }
+ if (found)
+ return success();
+
+ if (auto op1 = rhs.getDefiningOp<ZHighUnstickOp>()) {
+ addInput = lhs;
+ unstickOp = op1;
+ Value mmOutput = unstickOp.getIn();
+ if (auto op2 = mmOutput.getDefiningOp<ZHighQuantizedMatMulOp>()) {
+ mmOp = op2;
+ bool precomputed = (mmOp.getPreComputedBias() == -1);
+ if (!precomputed)
+ return rewriter.notifyMatchFailure(
+ addOp, "not precomputed quantized matmul");
+ Value qBias = mmOp.getB();
+ if (auto op3 = qBias.getDefiningOp<ZHighQuantizedStickOp>()) {
+ qstickOp = op3;
+ Value bias = qstickOp.getIn();
+ // Check rank.
+ if (getRank(bias.getType()) != getRank(addInput.getType()))
+ return rewriter.notifyMatchFailure(addOp, "rank mismatched");
+ found = true;
+ }
+ }
+ }
+ if (found)
+ return success();
+
+ return rewriter.notifyMatchFailure(addOp, "unstick not found");
+ }
+};
+
+class replaceONNXQLinearMatMulPattern
+ : public OpRewritePattern<ONNXQLinearMatMulOp> {
+public:
+ using OpRewritePattern<ONNXQLinearMatMulOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(
+ ONNXQLinearMatMulOp qmmOp, PatternRewriter &rewriter) const override {
+ Location loc = qmmOp.getLoc();
+ Operation *op = qmmOp.getOperation();
+ MultiDialectBuilder<OnnxBuilder> create(rewriter, loc);
+
+ // Match
+ if (failed(canBeRewritten(rewriter, qmmOp)))
+ return failure();
+
+ Type si64Ty = rewriter.getIntegerType(64, true);
+ Type f16Ty = rewriter.getF16Type();
+ Type f32Ty = rewriter.getF32Type();
+ IntegerAttr trueAttr = rewriter.getIntegerAttr(si64Ty, -1);
+ IntegerAttr falseAttr = rewriter.getIntegerAttr(si64Ty, 0);
+
+ Value A = qmmOp.getA();
+ Value AScale = qmmOp.getAScale();
+ Value AZeroPoint = qmmOp.getAZeroPoint();
+ Value B = qmmOp.getB();
+ Value BScale = qmmOp.getBScale();
+ Value BZeroPoint = qmmOp.getBZeroPoint();
+ Value Y = qmmOp.getY();
+ Value YScale = qmmOp.getYScale();
+ Value YZeroPoint = qmmOp.getYZeroPoint();
+
+ // Only pre-compute bias when B is a constant and BZeroPoint is int8 zero.
+ bool canPreComputeBias = false;
+ if (isDenseONNXConstant(B) && isDenseONNXConstant(BZeroPoint)) {
+ if (getElementType(BZeroPoint.getType()).isUnsignedInteger())
+ canPreComputeBias = isConstOf(BZeroPoint, 128.0);
+ else
+ canPreComputeBias = isConstOf(BZeroPoint, 0.0);
+ }
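+ // Illustrative note (assuming getOrCastToI8 shifts unsigned values by -128):
+ // a ui8 zero point of 128 becomes 0 after the conversion below, which is
+ // exactly the symmetric case required for pre-computing the bias.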
+
+ // Emit some common values.
+ Value none = create.onnx.none();
+ Value zero = create.onnx.constantInt64({0});
+
+ // Normalize rank-1 scalar tensors to rank-0 tensors.
+ if (getRank(AScale.getType()) == 1) {
+ AScale = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(AScale.getType())), AScale,
+ {zero});
+ }
+ if (getRank(AZeroPoint.getType()) == 1) {
+ AZeroPoint = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(AZeroPoint.getType())),
+ AZeroPoint, {zero});
+ }
+ if (getRank(BScale.getType()) == 1) {
+ BScale = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(BScale.getType())), BScale,
+ {zero});
+ }
+ if (getRank(BZeroPoint.getType()) == 1) {
+ BZeroPoint = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(BZeroPoint.getType())),
+ BZeroPoint, {zero});
+ }
+ if (getRank(YScale.getType()) == 1) {
+ YScale = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(YScale.getType())), YScale,
+ {zero});
+ }
+ if (getRank(YZeroPoint.getType()) == 1) {
+ YZeroPoint = create.onnx.squeeze(
+ RankedTensorType::get({}, getElementType(YZeroPoint.getType())),
+ YZeroPoint, {zero});
+ }
+
+ // zdnn supports signed int8, convert unsigned int8 inputs to signed int8.
+ Value AI8 = getOrCastToI8(A, create);
+ Value BI8 = getOrCastToI8(B, create);
+
+ Value ARecScale = create.onnx.reciprocal(AScale);
+ Value AZeroPointI8 = getOrCastToI8(AZeroPoint, create);
+ Value AZeroPointF32 = create.onnx.cast(AZeroPointI8, f32Ty);
+
+ Value BRecScale = create.onnx.reciprocal(BScale);
+ Value BZeroPointI8 = getOrCastToI8(BZeroPoint, create);
+ Value BZeroPointF32 = create.onnx.cast(BZeroPointI8, f32Ty);
+
+ Value YRecScale = create.onnx.reciprocal(YScale);
+ Value YZeroPointI8 = getOrCastToI8(YZeroPoint, create);
+ Value YZeroPointF32 = create.onnx.cast(YZeroPointI8, f32Ty);
+
+ // Stickify AI8: transform AI8 into zTensor format.
+ int64_t rankA = getRank(AI8.getType());
+ StringAttr aLayoutAttr =
+ rewriter.getStringAttr((rankA == 2) ? LAYOUT_2D : LAYOUT_3DS);
+ ZHighQuantizedStickOp qAOp =
+ rewriter.create<ZHighQuantizedStickOp>(loc, AI8, ARecScale,
+ AZeroPointF32, aLayoutAttr, rewriter.getStringAttr(QTYPE_INT8));
+
+ // Stickify BI8. It is potentially folded at compile time.
+ int64_t rankB = getRank(BI8.getType());
+ StringAttr bLayoutAttr =
+ rewriter.getStringAttr((rankB == 2) ? LAYOUT_2D : LAYOUT_3DS);
+ ZHighQuantizedStickOp qBOp =
+ rewriter.create<ZHighQuantizedStickOp>(loc, BI8, BRecScale,
+ BZeroPointF32, bLayoutAttr, rewriter.getStringAttr(QTYPE_WEIGHTS));
+
+ // Bias is none or precomputed.
+ Value qcTilde, qcTildeRecScale, qcTildeZeroPointF32;
+ if (canPreComputeBias)
+ preComputeBias(create, ARecScale, AZeroPointF32, BI8, BRecScale,
+ YRecScale, YZeroPointF32, qcTilde, qcTildeRecScale,
+ qcTildeZeroPointF32);
+
+ // Emit zhigh.QuantizedMatMul. Bias is none.
+ // DisableClipping gives the same output as the onnx backend test since the
+ // onnx backend test uses `astype` instead of `clipping` to cast the output
+ // to i8.
+ SmallVector<Type> resTypes;
+ resTypes.emplace_back(UnrankedTensorType::get(f16Ty));
+ resTypes.emplace_back(RankedTensorType::get({}, f32Ty));
+ resTypes.emplace_back(RankedTensorType::get({}, f32Ty));
+ ZHighQuantizedMatMulOp zhighQuantizedMatMulOp =
+ rewriter.create<ZHighQuantizedMatMulOp>(loc, resTypes,
+ qAOp.getResult(0), qAOp.getResult(1), qAOp.getResult(2),
+ qBOp.getResult(0), qBOp.getResult(1), qBOp.getResult(2),
+ /*Bias*/ canPreComputeBias ? qcTilde : none,
+ /*BiasRecScale*/ canPreComputeBias ? qcTildeRecScale : none,
+ /*BiasOffset*/ canPreComputeBias ? qcTildeZeroPointF32 : none,
+ /*OutRecScale*/ YRecScale, /*OutOffset*/ YZeroPointF32,
+ /*PreComputedBias*/ canPreComputeBias ? trueAttr : falseAttr,
+ /*DisableClipping*/ trueAttr,
+ /*DequantizeOutput*/ falseAttr);
+ (void)zhighQuantizedMatMulOp.inferShapes([](Region &region) {});
+
+ // Unstickify the matmul result that is int8-as-float.
+ Value resI8F32 = rewriter.create<ZHighUnstickOp>(
+ loc, zhighQuantizedMatMulOp.getResult(0));
+ Value res;
+ Type outElemTy = getElementType(Y.getType());
+ if (outElemTy.isUnsignedInteger(8)) {
+ // The zdnn output is int8. Convert int8 to uint8.
+ // Use int16 to avoid integer overflow.
+ Type i16Ty = rewriter.getI16Type();
+ Type ui16Ty = rewriter.getIntegerType(16, false);
+ auto cst128Attr = DenseElementsAttr::get(
+ RankedTensorType::get({}, i16Ty), static_cast<int16_t>(128));
+ // clang-format off
+ Value resUI16 =
+ create.onnx.cast(
+ create.onnx.add(create.onnx.cast(resI8F32, i16Ty),
+ create.onnx.constant(cst128Attr)),
+ ui16Ty);
+ // clang-format on
+ res = create.onnx.cast(resUI16, outElemTy);
+ } else {
+ res = create.onnx.cast(resI8F32, outElemTy);
+ }
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+ static mlir::LogicalResult canBeRewritten(
+ PatternRewriter &rewriter, ONNXQLinearMatMulOp qmmOp) {
+ if (!isSuitableForZDNN(qmmOp))
+ return rewriter.notifyMatchFailure(
+ qmmOp, "QLinearMatMul is not suitable for zDNN");
+ return success();
+ }
+};
+
struct ONNXToZHighLoweringPass
 : public PassWrapper<ONNXToZHighLoweringPass, OperationPass<ModuleOp>> {
@@ -290,14 +1535,85 @@ struct ONNXToZHighLoweringPass
ONNXToZHighLoweringPass() = default;
ONNXToZHighLoweringPass(const ONNXToZHighLoweringPass &pass)
 : PassWrapper<ONNXToZHighLoweringPass, OperationPass<ModuleOp>>() {}
+ ONNXToZHighLoweringPass(NNPAQuantType quantMode) {
+ this->quantMode = quantMode;
+ }
void runOnOperation() final;
+
+public:
+ Option quantMode{*this, "quantization",
+ llvm::cl::desc("Enable quantization"),
+ llvm::cl::values(
+ clEnumVal(DynSymI8,
+ "Dynamic Quantization to signed integer 8. Asymmetric quant for "
+ "activations and symmetric quant for weights."),
+ clEnumVal(SymSymI8,
+ "Dynamic Quantization to signed integer 8. Symmetric quant for "
+ "activations and symmetric quant for weights."),
+ clEnumVal(QNONE, "No quantization (default).")),
+ llvm::cl::init(QNONE)};
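+ // Typical invocation (hypothetical command line; assumes the driver flag
+ // `--nnpa-quantization` described earlier maps onto this pass option):
+ // onnx-mlir --maccel=NNPA --nnpa-quantization=DynSymI8 model.onnx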
};
} // end anonymous namespace.
-void getONNXToZHighOneOpPatterns(RewritePatternSet &patterns) {
+void getONNXToZHighOneOpPatterns(
+ RewritePatternSet &patterns, NNPAQuantType quantMode) {
MLIRContext *context = patterns.getContext();
- populateWithGenerated(patterns);
- patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+ patterns.insert(context);
+
+ // Pattern for i8 dynamic quantization, symmetric mode.
+ if (isCompatibleWithNNPALevel(NNPALevel::M15) &&
+ (quantMode == NNPAQuantType::DynSymI8 ||
+ quantMode == NNPAQuantType::SymSymI8)) {
+ // Bump up the pattern benefit to run these before non-quantization
+ // patterns.
+ PatternBenefit quantPriority(QUANT_PATTERN_BENEFIT);
+ patterns.insert(
+ context, quantPriority, quantMode == NNPAQuantType::SymSymI8);
+ patterns.insert(
+ context, quantPriority, quantMode == NNPAQuantType::SymSymI8);
+ }
}
void getONNXToZHighOneOpDynamicallyLegal(
@@ -309,7 +1625,10 @@ void getONNXToZHighOneOpDynamicallyLegal(
addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
+ addDynamicallyLegalOpFor(target, dimAnalysis);
+ addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
+ addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
@@ -319,18 +1638,42 @@ void getONNXToZHighOneOpDynamicallyLegal(
addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
+ addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
+ addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
addDynamicallyLegalOpFor(target, dimAnalysis);
+ addDynamicallyLegalOpFor(target, dimAnalysis);
+ addDynamicallyLegalOpFor(target, dimAnalysis);
}
-void getONNXToZHighMultipleOpPatterns(RewritePatternSet &patterns) {
+void getONNXToZHighMultipleOpPatterns(
+ RewritePatternSet &patterns, NNPAQuantType quantMode) {
MLIRContext *context = patterns.getContext();
patterns.insert