|
15 | 15 | #include "mlir/Dialect/Tosa/IR/TosaOps.td"
|
16 | 16 | #include "stablehlo/dialect/StablehloOps.td"
|
17 | 17 |
|
| 18 | +Rewrite zeroConst() -> Op [{ |
| 19 | + auto type = rewriter.getI8Type(); |
| 20 | + auto attr = mlir::DenseElementsAttr::get( |
| 21 | + llvm::cast<mlir::ShapedType>(type), rewriter.getZeroAttr(type)); |
| 22 | + return rewriter.create<mlir::tosa::ConstOp>( |
| 23 | + rewriter.getUnknownLoc(), type, attr); |
| 24 | +}]; |
| 25 | + |
18 | 26 | // Helper functions.
|
19 | 27 | Rewrite onesLike(op: Op, type: Type) -> Op [{
|
20 | 28 | auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
|
@@ -137,7 +145,16 @@ Pattern =>
|
137 | 145 | Pattern =>
|
138 | 146 | replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
|
139 | 147 | input1 : Value<_: Tosa_Tensor>)
|
140 |
| - with op<tosa.mul>(input0, input1) {shift = attr<"0 : i8">}; |
| 148 | + with op<tosa.mul>(input0, input1, zeroConst()); |
| 149 | +// Pattern { |
| 150 | +// let root = op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>, |
| 151 | +// input1 : Value<_: Tosa_Tensor>); |
| 152 | +// rewrite root with { |
| 153 | +// let c0 = zeroConst(); |
| 154 | +// let mulResult = op<tosa.mul>(input0, input1, c0); |
| 155 | +// replace root with mulResult; |
| 156 | +// }; |
| 157 | +// } |
141 | 158 | Pattern =>
|
142 | 159 | replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>,
|
143 | 160 | input1 : Value<_: Tosa_Tensor>)
|
|
0 commit comments