This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 48f7ec8

River707 authored and tensorflower-gardener committed
Add support for inlining toy call operations.
The GenericCallOp needs to implement CallOpInterface in order to be picked up by the inliner. This change also adds a CastOp to perform the shape casts that are generated during inlining; the casts generated by the inliner are folded away after shape inference.

PiperOrigin-RevId: 275150438
1 parent b682c8a commit 48f7ec8
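To make the mechanism concrete, here is a minimal sketch of the cast the inliner materializes when a ranked tensor flows into a callee argument typed as an unranked tensor (SSA names are illustrative, not taken from this commit):

  // The call site passes tensor<2x3xf64>; the inlined callee body expects
  // tensor<*xf64>, so materializeCallConversion() bridges the mismatch:
  %cast = "toy.cast"(%ranked) : (tensor<2x3xf64>) -> tensor<*xf64>

After shape inference both sides become ranked, the cast becomes a same-type cast, and the canonicalizer folds it away.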

File tree: 7 files changed, +97 -7 lines


examples/toy/Ch4/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ add_toy_chapter(toyc-ch4
 add_dependencies(toyc-ch4 ToyCh4OpsIncGen)
 add_dependencies(toyc-ch4 ToyCh4ShapeInferenceInterfaceIncGen)
 add_dependencies(toyc-ch4 ToyCh4CombineIncGen)
+add_dependencies(toyc-ch4 MLIRCallOpInterfacesIncGen)
 include_directories(include/)
 include_directories(${CMAKE_CURRENT_BINARY_DIR})
 include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)

examples/toy/Ch4/include/toy/Ops.td

Lines changed: 26 additions & 1 deletion
@@ -23,6 +23,11 @@
 #else
 #define TOY_OPS

+#ifdef MLIR_CALLINTERFACES
+#else
+include "mlir/Analysis/CallInterfaces.td"
+#endif // MLIR_CALLINTERFACES
+
 #ifdef SHAPE_INFERENCE_INTERFACE
 #else
 include "toy/ShapeInferenceInterface.td"
@@ -111,7 +116,27 @@ def AddOp : Toy_Op<"add",
   >];
 }

-def GenericCallOp : Toy_Op<"generic_call"> {
+def CastOp : Toy_Op<"cast",
+    [DeclareOpInterfaceMethods<ShapeInferenceOpInterface>, NoSideEffect,
+     SameOperandsAndResultShape]> {
+  let summary = "shape cast operation";
+  let description = [{
+    The "cast" operation converts a tensor from one type to an equivalent type
+    without changing any data elements. The source and destination types
+    must both be tensor types with the same element type. If both are ranked
+    then the rank should be the same and static dimensions should match. The
+    operation is invalid if converting to a mismatching constant dimension.
+  }];
+
+  let arguments = (ins F64Tensor:$input);
+  let results = (outs F64Tensor:$output);
+
+  // Set the folder bit so that we can fold redundant cast operations.
+  let hasFolder = 1;
+}
+
+def GenericCallOp : Toy_Op<"generic_call",
+    [DeclareOpInterfaceMethods<CallOpInterface>]> {
   let summary = "generic call operation";
   let description = [{
     Generic calls represent calls to a user defined function that needs to
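To illustrate the CastOp semantics described above, here is a hedged sketch (names and values are made up, not part of this diff): a cast may erase or refine rank, but may not alter known dimensions.

  // Legal: erasing and refining rank, element type unchanged.
  %a = "toy.cast"(%t) : (tensor<2x3xf64>) -> tensor<*xf64>
  %b = "toy.cast"(%u) : (tensor<*xf64>) -> tensor<2x3xf64>

  // Invalid per the description: mismatching constant dimensions.
  // %c = "toy.cast"(%t) : (tensor<2x3xf64>) -> tensor<3x3xf64>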

examples/toy/Ch4/mlir/Dialect.cpp

Lines changed: 27 additions & 1 deletion
@@ -64,6 +64,17 @@ struct ToyInlinerInterface : public DialectInlinerInterface {
     for (const auto &it : llvm::enumerate(returnOp.getOperands()))
       valuesToRepl[it.index()]->replaceAllUsesWith(it.value());
   }
+
+  /// Attempts to materialize a conversion for a type mismatch between a call
+  /// from this dialect, and a callable region. This method should generate an
+  /// operation that takes 'input' as the only operand, and produces a single
+  /// result of 'resultType'. If a conversion can not be generated, nullptr
+  /// should be returned.
+  Operation *materializeCallConversion(OpBuilder &builder, Value *input,
+                                       Type resultType,
+                                       Location conversionLoc) const final {
+    return builder.create<CastOp>(conversionLoc, resultType, input);
+  }
 };

 //===----------------------------------------------------------------------===//
@@ -94,7 +105,12 @@ static void buildConstantOp(mlir::Builder *builder, mlir::OperationState &state,
   ConstantOp::build(builder, state, dataType, dataAttribute);
 }

-/// Verifier for constant operation.
+/// Infer the output shape of the CastOp, this is required by the shape
+/// inference interface.
+void CastOp::inferShapes() { getResult()->setType(getOperand()->getType()); }
+
+/// Verifier for the constant operation. This corresponds to the `::verify(...)`
+/// in the op definition.
 static mlir::LogicalResult verify(ConstantOp op) {
   // If the return type of the constant is not an unranked tensor, the shape
   // must match the shape of the attribute holding the data.
@@ -139,6 +155,16 @@ static void buildGenericCallOp(mlir::Builder *builder,
   state.addAttribute("callee", builder->getSymbolRefAttr(callee));
 }

+/// Return the callee of the generic call operation, this is required by the
+/// call interface.
+CallInterfaceCallable GenericCallOp::getCallableForCallee() {
+  return getAttrOfType<SymbolRefAttr>("callee");
+}
+
+/// Get the argument operands to the called function, this is required by the
+/// call interface.
+Operation::operand_range GenericCallOp::getArgOperands() { return inputs(); }
+
 static void buildMulOp(mlir::Builder *builder, mlir::OperationState &state,
                        mlir::Value *lhs, mlir::Value *rhs) {
   state.addTypes(builder->getTensorType(builder->getF64Type()));
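In practice, CastOp::inferShapes() above simply copies the operand type onto the result, so a rank-erasing cast turns into a same-type cast once its input is ranked (illustrative IR, not from this commit):

  // Before shape inference:
  %2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<*xf64>
  // After inferShapes() sets the result type to the operand type:
  %2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<2x3xf64>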

examples/toy/Ch4/mlir/ShapeInferencePass.cpp

Lines changed: 7 additions & 2 deletions
@@ -80,8 +80,13 @@ class ShapeInferencePass : public mlir::FunctionPass<ShapeInferencePass> {

       // Ask the operation to infer its output shapes.
       LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
-      auto shapeOp = dyn_cast<ShapeInference>(op);
-      shapeOp.inferShapes();
+      if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
+        shapeOp.inferShapes();
+      } else {
+        op->emitError("unable to infer shape of operation without shape "
+                      "inference interface");
+        return signalPassFailure();
+      }
     }

     // If the operation worklist isn't empty, this indicates a failure.
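With this change the pass fails with a diagnostic instead of crashing on a null dyn_cast. A hypothetical trigger (the op name toy.opaque is invented for illustration; the diagnostic text is the one added above):

  %0 = "toy.opaque"(%arg) : (tensor<2x3xf64>) -> tensor<*xf64>
  // error: unable to infer shape of operation without shape inference interface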

examples/toy/Ch4/mlir/ToyCombine.cpp

Lines changed: 5 additions & 0 deletions
@@ -32,6 +32,11 @@ namespace {
 #include "ToyCombine.inc"
 } // end anonymous namespace

+/// Fold simple cast operations that return the same type as the input.
+OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+  return mlir::impl::foldCastOp(*this);
+}
+
 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
 /// optimizes the following scenario: transpose(transpose(x)) -> transpose(x)
 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
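A sketch of the fold in action (illustrative IR): once shape inference makes a cast's input and output types identical, mlir::impl::foldCastOp returns the input operand and canonicalization erases the cast.

  // Redundant after shape inference; folds to %1:
  %2 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<2x3xf64>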

examples/toy/Ch4/toyc.cpp

Lines changed: 1 addition & 3 deletions
@@ -122,16 +122,14 @@ int dumpMLIR() {
   // Apply any generic pass manager command line options and run the pipeline.
   applyPassManagerCLOptions(pm);

-  // Add a run of the canonicalizer to optimize the mlir module.
-  pm.addPass(mlir::createCanonicalizerPass());
-
   // Inline all functions into main and then delete them.
   pm.addPass(mlir::createInlinerPass());
   pm.addPass(mlir::toy::createDeadFunctionEliminationPass());

   // Now that there is only one function, we can infer the shapes of each of
   // the operations.
   pm.addPass(mlir::toy::createShapeInferencePass());
+  pm.addPass(mlir::createCanonicalizerPass());

   if (mlir::failed(pm.run(*module)))
     return 4;
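The reordering matters: the canonicalizer now runs after inlining and shape inference, so it can fold away the casts the inliner introduced. A hedged sketch of the IR progression (names illustrative):

  // After inlining: a cast bridges the ranked operand to the unranked argument.
  %4 = "toy.cast"(%1) : (tensor<2x3xf64>) -> tensor<*xf64>
  // After shape inference: (tensor<2x3xf64>) -> tensor<2x3xf64>
  // After canonicalization: the now-trivial cast is folded away.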
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+// RUN: toyc-ch4 %s -emit=mlir -opt 2>&1 | FileCheck %s
+
+// Check the result of inlining+shape inference on an input module.
+
+func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
+  %0 = "toy.transpose"(%arg1) : (tensor<*xf64>) -> tensor<*xf64>
+  %1 = "toy.mul"(%arg0, %0) : (tensor<*xf64>, tensor<*xf64>) -> tensor<*xf64>
+  "toy.return"(%1) : (tensor<*xf64>) -> ()
+}
+func @main() {
+  %0 = "toy.constant"() {value = dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+  %1 = "toy.reshape"(%0) : (tensor<2x3xf64>) -> tensor<2x3xf64>
+  %2 = "toy.constant"() {value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>} : () -> tensor<6xf64>
+  %3 = "toy.reshape"(%2) : (tensor<6xf64>) -> tensor<2x3xf64>
+  %4 = "toy.generic_call"(%1, %3) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+  %5 = "toy.generic_call"(%3, %1) {callee = @multiply_transpose} : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
+  "toy.print"(%5) : (tensor<*xf64>) -> ()
+  "toy.return"() : () -> ()
+}
+
+// CHECK-NOT: func @multiply_transpose
+// CHECK-NOT: tensor<*xf64>
+
+// CHECK-LABEL: func @main() {
+// CHECK: [[VAL_0:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+// CHECK: [[VAL_1:%.*]] = "toy.constant"() {value = dense<{{\[\[}}1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>} : () -> tensor<2x3xf64>
+// CHECK: [[VAL_2:%.*]] = "toy.transpose"([[VAL_0]]) : (tensor<2x3xf64>) -> tensor<3x2xf64>
+// CHECK: [[VAL_3:%.*]] = "toy.mul"([[VAL_1]], [[VAL_2]]) : (tensor<2x3xf64>, tensor<3x2xf64>) -> tensor<2x2xf64>
+// CHECK: "toy.print"([[VAL_3]]) : (tensor<2x2xf64>) -> ()
+// CHECK: "toy.return"() : () -> ()

0 commit comments