
Commit cdc09a1

[mlir][IntRangeInference] Infer values for {memref,tensor}.dim (#122945)
Implement the integer range inference interface for memref.dim and tensor.dim using shared code. The inference treats a `dim` of a dynamic dimension as [0, index_max] and takes the union over all dimensions that the `dim` index operand could validly refer to.
1 parent: de7438e
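As a sketch of the resulting behavior (function names below are illustrative; the bounds mirror the tests added in this commit), `-int-range-optimizations` can now fold a `dim` of a statically known dimension to a constant, bound a `dim` with an unknown index by the union of the candidate dimension sizes, and fall back to [0, index_max] for dynamic dimensions:

// Illustrative sketch, not part of the commit.

// dim of a static dimension: the inferred range collapses to the constant 3,
// so -int-range-optimizations can fold the op away.
func.func @static_dim(%t: tensor<3x5xi32>) -> index {
  %c0 = arith.constant 0 : index
  %d = tensor.dim %t, %c0 : tensor<3x5xi32>   // inferred range: [3, 3]
  return %d : index
}

// dim with an unknown index over a static shape: the result is the union
// of all dimension sizes, here [3, 5].
func.func @unknown_index(%t: tensor<3x5xi32>, %i: index) -> index {
  %d = tensor.dim %t, %i : tensor<3x5xi32>    // inferred range: [3, 5]
  return %d : index
}

// dim of a dynamic dimension: nothing is known beyond non-negativity,
// so the result is [0, index_max].
func.func @dynamic_dim(%t: tensor<?x5xi32>) -> index {
  %c0 = arith.constant 0 : index
  %d = tensor.dim %t, %c0 : tensor<?x5xi32>   // inferred range: [0, 2^63 - 1]
  return %d : index
}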

13 files changed: +229 −3

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

+1
@@ -17,6 +17,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/MemorySlotInterfaces.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

+4 −2

@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/MemorySlotInterfaces.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -577,7 +578,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     MemRefsNormalizable,
     ConditionallySpeculatable, NoMemoryEffect,
-    ShapedDimOpInterface]> {
+    ShapedDimOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
   let summary = "dimension index operation";
   let description = [{
     The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -1675,7 +1677,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
     }]>,

     // Builder that infers the result layout map. The result shape must be
-    // specified. Otherwise, the op may be ambiguous. The output shape for
+    // specified. Otherwise, the op may be ambiguous. The output shape for
     // the op will be inferred using the inferOutputShape() method.
     OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
                "ArrayRef<ReassociationIndices>":$reassociation)>,

mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

+1
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

+3 −1

@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -197,7 +198,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
 def Tensor_DimOp : Tensor_Op<"dim", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     ConditionallySpeculatable, NoMemoryEffect,
-    ShapedDimOpInterface]> {
+    ShapedDimOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
   let summary = "dimension index operation";
   let description = [{
     The `tensor.dim` operation takes a tensor and a dimension operand of type

mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h

+8
@@ -20,6 +20,8 @@
 #include <optional>

 namespace mlir {
+class ShapedDimOpInterface;
+
 namespace intrange {
 /// Function that performs inference on an array of `ConstantIntRanges`,
 /// abstracted away here to permit writing the function that handles both
@@ -143,6 +145,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
                                  const ConstantIntRanges &lhs,
                                  const ConstantIntRanges &rhs);

+/// Returns the integer range for the result of a `ShapedDimOpInterface` given
+/// the optional inferred ranges for the `dimension` index `maybeDim`. When a
+/// dynamic dimension is encountered, returns [0, signed_max(type(result))].
+ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
+                                            const IntegerValueRange &maybeDim);
+
 } // namespace intrange
 } // namespace mlir
mlir/lib/Dialect/MemRef/IR/CMakeLists.txt

+2
@@ -16,6 +16,8 @@ add_mlir_dialect_library(MLIRMemRefDialect
   MLIRControlFlowInterfaces
   MLIRDialect
   MLIRDialectUtils
+  MLIRInferIntRangeCommon
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRMemorySlotInterfaces

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

+7
@@ -19,6 +19,7 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -915,6 +916,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }

+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+                                          SetIntLatticeFn setResultRange) {
+  setResultRange(getResult(),
+                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
 /// Return a map with key being elements in `vals` and data being number of
 /// occurences of it. Use std::map, since the `vals` here are strides and the
 /// dynamic stride value is the same as the tombstone value for

mlir/lib/Dialect/Tensor/IR/CMakeLists.txt

+2
@@ -26,6 +26,8 @@ add_mlir_dialect_library(MLIRTensorDialect
   MLIRDestinationStyleOpInterface
   MLIRDialectUtils
   MLIRIR
+  MLIRInferIntRangeCommon
+  MLIRInferIntRangeInterface
   MLIRInferTypeOpInterface
   MLIRParallelCombiningOpInterface
   MLIRShapedOpInterfaces

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

+8
@@ -23,7 +23,9 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
@@ -782,6 +784,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
   return Speculation::Speculatable;
 }

+void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
+                                          SetIntLatticeFn setResultRange) {
+  setResultRange(getResult(),
+                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
+}
+
 OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
   auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

mlir/lib/Interfaces/Utils/CMakeLists.txt

+1
@@ -8,6 +8,7 @@ add_mlir_library(MLIRInferIntRangeCommon
  MLIRInferIntRangeInterfaceIncGen

  LINK_LIBS PUBLIC
+ MLIRShapedOpInterfaces
  MLIRInferIntRangeInterface
  MLIRIR
 )

mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp

+44
@@ -14,6 +14,7 @@
 #include "mlir/Interfaces/Utils/InferIntRangeCommon.h"

 #include "mlir/Interfaces/InferIntRangeInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"

 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
@@ -725,3 +726,46 @@ std::optional<bool> mlir::intrange::evaluatePred(CmpPredicate pred,
     return false;
   return std::nullopt;
 }
+
+//===----------------------------------------------------------------------===//
+// Shaped type dimension accessors / ShapedDimOpInterface
+//===----------------------------------------------------------------------===//
+
+ConstantIntRanges
+mlir::intrange::inferShapedDimOpInterface(ShapedDimOpInterface op,
+                                          const IntegerValueRange &maybeDim) {
+  unsigned width =
+      ConstantIntRanges::getStorageBitwidth(op->getResult(0).getType());
+  APInt zero = APInt::getZero(width);
+  APInt typeMax = APInt::getSignedMaxValue(width);
+
+  auto shapedTy = cast<ShapedType>(op.getShapedValue().getType());
+  if (!shapedTy.hasRank())
+    return ConstantIntRanges::fromSigned(zero, typeMax);
+
+  int64_t rank = shapedTy.getRank();
+  int64_t minDim = 0;
+  int64_t maxDim = rank - 1;
+  if (!maybeDim.isUninitialized()) {
+    const ConstantIntRanges &dim = maybeDim.getValue();
+    minDim = std::max(minDim, dim.smin().getSExtValue());
+    maxDim = std::min(maxDim, dim.smax().getSExtValue());
+  }
+
+  std::optional<ConstantIntRanges> result;
+  auto joinResult = [&](const ConstantIntRanges &thisResult) {
+    if (!result.has_value())
+      result = thisResult;
+    else
+      result = result->rangeUnion(thisResult);
+  };
+  for (int64_t i = minDim; i <= maxDim; ++i) {
+    int64_t length = shapedTy.getDimSize(i);
+
+    if (ShapedType::isDynamic(length))
+      joinResult(ConstantIntRanges::fromSigned(zero, typeMax));
+    else
+      joinResult(ConstantIntRanges::constant(APInt(width, length)));
+  }
+  return result.value_or(ConstantIntRanges::fromSigned(zero, typeMax));
+}
@@ -0,0 +1,74 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%m: memref<3x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<3x5xi32>
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%m: memref<3x5xi32>, %x: index) -> index {
+  %0 = memref.dim %m, %x : memref<3x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%m: memref<?x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%m: memref<?x5xi32>, %x: index) -> index {
+  %0 = memref.dim %m, %x : memref<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%m: memref<?x3x5xi32>, %x: index) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = arith.maxsi %x, %c1 : index
+  %1 = memref.dim %m, %0 : memref<?x3x5xi32>
+  %2 = test.reflect_bounds %1 : index
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_unranked
+// CHECK: %[[op:.+]] = memref.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_unranked(%m: memref<*xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = memref.dim %m, %c0 : memref<*xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
@@ -0,0 +1,74 @@
+// RUN: mlir-opt -int-range-optimizations --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @dim_const
+// CHECK: %[[ret:.+]] = arith.constant 3 : index
+// CHECK: return %[[ret]]
+func.func @dim_const(%t: tensor<3x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %t, %c0 : tensor<3x5xi32>
+  return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_static
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_static(%t: tensor<3x5xi32>, %x: index) -> index {
+  %0 = tensor.dim %t, %x : tensor<3x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_dynamic(%t: tensor<?x5xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %t, %c0 : tensor<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_any_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_any_dynamic(%t: tensor<?x5xi32>, %x: index) -> index {
+  %0 = tensor.dim %t, %x : tensor<?x5xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_some_omitting_dynamic
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 5 : index, smin = 3 : index, umax = 5 : index, umin = 3 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_some_omitting_dynamic(%t: tensor<?x3x5xi32>, %x: index) -> index {
+  %c1 = arith.constant 1 : index
+  %0 = arith.maxsi %x, %c1 : index
+  %1 = tensor.dim %t, %0 : tensor<?x3x5xi32>
+  %2 = test.reflect_bounds %1 : index
+  return %2 : index
+}
+
+// -----
+
+// CHECK-LABEL: @dim_unranked
+// CHECK: %[[op:.+]] = tensor.dim
+// CHECK: %[[ret:.+]] = test.reflect_bounds {smax = 9223372036854775807 : index, smin = 0 : index, umax = 9223372036854775807 : index, umin = 0 : index} %[[op]]
+// CHECK: return %[[ret]]
+func.func @dim_unranked(%t: tensor<*xi32>) -> index {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %t, %c0 : tensor<*xi32>
+  %1 = test.reflect_bounds %0 : index
+  return %1 : index
+}
