Skip to content

Commit fa8c237

Browse files
committed
fix mulop
1 parent 48a1e14 commit fa8c237

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

stablehlo/conversions/tosa/transforms/StablehloLegalizeToTosa.pdll

+18-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
#include "mlir/Dialect/Tosa/IR/TosaOps.td"
1616
#include "stablehlo/dialect/StablehloOps.td"
1717

18+
Rewrite zeroConst() -> Op [{
19+
auto type = rewriter.getI8Type();
20+
auto attr = mlir::DenseElementsAttr::get(
21+
llvm::cast<mlir::ShapedType>(type), rewriter.getZeroAttr(type));
22+
return rewriter.create<mlir::tosa::ConstOp>(
23+
rewriter.getUnknownLoc(), type, attr);
24+
}];
25+
1826
// Helper functions.
1927
Rewrite onesLike(op: Op, type: Type) -> Op [{
2028
auto elementType = llvm::cast<mlir::TensorType>(type).getElementType();
@@ -137,7 +145,16 @@ Pattern =>
137145
Pattern =>
138146
replace op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
139147
input1 : Value<_: Tosa_Tensor>)
140-
with op<tosa.mul>(input0, input1) {shift = attr<"0 : i8">};
148+
with op<tosa.mul>(input0, input1, zeroConst());
149+
// Pattern {
150+
// let root = op<stablehlo.multiply>(input0 : Value<_: Tosa_Tensor>,
151+
// input1 : Value<_: Tosa_Tensor>);
152+
// rewrite root with {
153+
// let c0 = zeroConst();
154+
// let mulResult = op<tosa.mul>(input0, input1, c0);
155+
// replace root with mulResult;
156+
// };
157+
// }
141158
Pattern =>
142159
replace op<stablehlo.or>(input0 : Value<_: Tosa_Tensor>,
143160
input1 : Value<_: Tosa_Tensor>)

0 commit comments

Comments
 (0)