Skip to content

Commit cdc09a1

Browse files
authored
[mlir][IntRangeInference] Infer values for {memref,tensor}.dim (#122945)
Implement the integer range inference niterface for memref.dim and tetnor.dim using shared code. The inference will infer the `dim` of dynamic dimensions to [0, index_max] and take the union of all the dimensions that the `dim` argument could be validly referring to.
1 parent de7438e commit cdc09a1

File tree

13 files changed

+229
-3
lines changed

13 files changed

+229
-3
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Interfaces/CastInterfaces.h"
1818
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1919
#include "mlir/Interfaces/CopyOpInterface.h"
20+
#include "mlir/Interfaces/InferIntRangeInterface.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
2122
#include "mlir/Interfaces/MemorySlotInterfaces.h"
2223
#include "mlir/Interfaces/ShapedOpInterfaces.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ include "mlir/Dialect/MemRef/IR/MemRefBase.td"
1414
include "mlir/Interfaces/CastInterfaces.td"
1515
include "mlir/Interfaces/ControlFlowInterfaces.td"
1616
include "mlir/Interfaces/CopyOpInterface.td"
17+
include "mlir/Interfaces/InferIntRangeInterface.td"
1718
include "mlir/Interfaces/InferTypeOpInterface.td"
1819
include "mlir/Interfaces/MemorySlotInterfaces.td"
1920
include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -577,7 +578,8 @@ def MemRef_DimOp : MemRef_Op<"dim", [
577578
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
578579
MemRefsNormalizable,
579580
ConditionallySpeculatable, NoMemoryEffect,
580-
ShapedDimOpInterface]> {
581+
ShapedDimOpInterface,
582+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
581583
let summary = "dimension index operation";
582584
let description = [{
583585
The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -1675,7 +1677,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16751677
}]>,
16761678

16771679
// Builder that infers the result layout map. The result shape must be
1678-
// specified. Otherwise, the op may be ambiguous. The output shape for
1680+
// specified. Otherwise, the op may be ambiguous. The output shape for
16791681
// the op will be inferred using the inferOutputShape() method.
16801682
OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
16811683
"ArrayRef<ReassociationIndices>":$reassociation)>,

mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Interfaces/CastInterfaces.h"
1919
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2020
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
21+
#include "mlir/Interfaces/InferIntRangeInterface.h"
2122
#include "mlir/Interfaces/InferTypeOpInterface.h"
2223
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
2324
#include "mlir/Interfaces/ShapedOpInterfaces.h"

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ include "mlir/Dialect/Tensor/IR/TensorBase.td"
1313
include "mlir/Interfaces/CastInterfaces.td"
1414
include "mlir/Interfaces/ControlFlowInterfaces.td"
1515
include "mlir/Interfaces/DestinationStyleOpInterface.td"
16+
include "mlir/Interfaces/InferIntRangeInterface.td"
1617
include "mlir/Interfaces/InferTypeOpInterface.td"
1718
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
1819
include "mlir/Interfaces/ShapedOpInterfaces.td"
@@ -197,7 +198,8 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
197198
def Tensor_DimOp : Tensor_Op<"dim", [
198199
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
199200
ConditionallySpeculatable, NoMemoryEffect,
200-
ShapedDimOpInterface]> {
201+
ShapedDimOpInterface,
202+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>]> {
201203
let summary = "dimension index operation";
202204
let description = [{
203205
The `tensor.dim` operation takes a tensor and a dimension operand of type

mlir/include/mlir/Interfaces/Utils/InferIntRangeCommon.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#include <optional>
2121

2222
namespace mlir {
23+
class ShapedDimOpInterface;
24+
2325
namespace intrange {
2426
/// Function that performs inference on an array of `ConstantIntRanges`,
2527
/// abstracted away here to permit writing the function that handles both
@@ -143,6 +145,12 @@ std::optional<bool> evaluatePred(CmpPredicate pred,
143145
const ConstantIntRanges &lhs,
144146
const ConstantIntRanges &rhs);
145147

148+
/// Returns the integer range for the result of a `ShapedDimOpInterface` given
149+
/// the optional inferred ranges for the `dimension` index `maybeDim`. When a
150+
/// dynamic dimension is encountered, returns [0, signed_max(type(result))].
151+
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op,
152+
const IntegerValueRange &maybeDim);
153+
146154
} // namespace intrange
147155
} // namespace mlir
148156

mlir/lib/Dialect/MemRef/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ add_mlir_dialect_library(MLIRMemRefDialect
1616
MLIRControlFlowInterfaces
1717
MLIRDialect
1818
MLIRDialectUtils
19+
MLIRInferIntRangeCommon
20+
MLIRInferIntRangeInterface
1921
MLIRInferTypeOpInterface
2022
MLIRIR
2123
MLIRMemorySlotInterfaces

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/TypeUtilities.h"
2020
#include "mlir/Interfaces/InferTypeOpInterface.h"
2121
#include "mlir/Interfaces/SideEffectInterfaces.h"
22+
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
2223
#include "mlir/Interfaces/ViewLikeInterface.h"
2324
#include "llvm/ADT/STLExtras.h"
2425
#include "llvm/ADT/SmallBitVector.h"
@@ -915,6 +916,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
915916
return Speculation::Speculatable;
916917
}
917918

919+
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
920+
SetIntLatticeFn setResultRange) {
921+
setResultRange(getResult(),
922+
intrange::inferShapedDimOpInterface(*this, argRanges[1]));
923+
}
924+
918925
/// Return a map with key being elements in `vals` and data being number of
919926
/// occurences of it. Use std::map, since the `vals` here are strides and the
920927
/// dynamic stride value is the same as the tombstone value for

mlir/lib/Dialect/Tensor/IR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ add_mlir_dialect_library(MLIRTensorDialect
2626
MLIRDestinationStyleOpInterface
2727
MLIRDialectUtils
2828
MLIRIR
29+
MLIRInferIntRangeCommon
30+
MLIRInferIntRangeInterface
2931
MLIRInferTypeOpInterface
3032
MLIRParallelCombiningOpInterface
3133
MLIRShapedOpInterfaces

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@
2323
#include "mlir/IR/OpDefinition.h"
2424
#include "mlir/IR/TypeUtilities.h"
2525
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
26+
#include "mlir/Interfaces/InferIntRangeInterface.h"
2627
#include "mlir/Interfaces/LoopLikeInterface.h"
28+
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
2729
#include "mlir/Support/LLVM.h"
2830
#include "llvm/ADT/DenseSet.h"
2931
#include "llvm/ADT/STLExtras.h"
@@ -782,6 +784,12 @@ Speculation::Speculatability DimOp::getSpeculatability() {
782784
return Speculation::Speculatable;
783785
}
784786

787+
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
788+
SetIntLatticeFn setResultRange) {
789+
setResultRange(getResult(),
790+
intrange::inferShapedDimOpInterface(*this, argRanges[1]));
791+
}
792+
785793
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
786794
// All forms of folding require a known index.
787795
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());

mlir/lib/Interfaces/Utils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_library(MLIRInferIntRangeCommon
88
MLIRInferIntRangeInterfaceIncGen
99

1010
LINK_LIBS PUBLIC
11+
MLIRShapedOpInterfaces
1112
MLIRInferIntRangeInterface
1213
MLIRIR
1314
)

0 commit comments

Comments
 (0)