Skip to content

Commit 1801fb4

Browse files
authored
[MLIR] Fixes arith.sub folder crash on dynamically shaped tensors (#118908)
We can't create a constant for a value with dynamic shape. Fixes #118772
1 parent 92376c3 commit 1801fb4

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,12 @@ void arith::AddUIExtendedOp::getCanonicalizationPatterns(
393393

394394
OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
395395
// subi(x,x) -> 0
396-
if (getOperand(0) == getOperand(1))
397-
return Builder(getContext()).getZeroAttr(getType());
396+
if (getOperand(0) == getOperand(1)) {
397+
auto shapedType = dyn_cast<ShapedType>(getType());
398+
// We can't generate a constant with a dynamic shaped tensor.
399+
if (!shapedType || shapedType.hasStaticShape())
400+
return Builder(getContext()).getZeroAttr(getType());
401+
}
398402
// subi(x,0) -> x
399403
if (matchPattern(adaptor.getRhs(), m_Zero()))
400404
return getLhs();

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,27 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
869869
return %add2 : index
870870
}
871871

872+
873+
// CHECK-LABEL: @foldSubXX_tensor
874+
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
875+
// CHECK: %[[sub:.+]] = arith.subi
876+
// CHECK: return %[[c0]], %[[sub]]
877+
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
878+
%static_sub = arith.subi %static, %static : tensor<10xi32>
879+
%dyn_sub = arith.subi %dyn, %dyn : tensor<?x?xi32>
880+
return %static_sub, %dyn_sub : tensor<10xi32>, tensor<?x?xi32>
881+
}
882+
883+
// CHECK-LABEL: @foldSubXX_vector
884+
// CHECK-DAG: %[[c0:.+]] = arith.constant dense<0> : vector<8xi32>
885+
// CHECK-DAG: %[[c0_scalable:.+]] = arith.constant dense<0> : vector<[4]xi32>
886+
// CHECK: return %[[c0]], %[[c0_scalable]]
887+
func.func @foldSubXX_vector(%static : vector<8xi32>, %dyn : vector<[4]xi32>) -> (vector<8xi32>, vector<[4]xi32>) {
888+
%static_sub = arith.subi %static, %static : vector<8xi32>
889+
%dyn_sub = arith.subi %dyn, %dyn : vector<[4]xi32>
890+
return %static_sub, %dyn_sub : vector<8xi32>, vector<[4]xi32>
891+
}
892+
872893
// CHECK-LABEL: @tripleAddSub0
873894
// CHECK: %[[cres:.+]] = arith.constant 59 : index
874895
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index

0 commit comments

Comments
 (0)