Skip to content

Commit 5d10e22

Browse files
Merge pull request tensorflow#51984 from srishti-pm:srishti/introduce_tf_if_to_scf_if_lowering
PiperOrigin-RevId: 405421029 Change-Id: I36c64e8f970dd0d529714a61214f115a922de208
2 parents 1f535da + d305672 commit 5d10e22

File tree

6 files changed

+200
-0
lines changed

6 files changed

+200
-0
lines changed

tensorflow/compiler/mlir/tensorflow/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,7 @@ cc_library(
11031103
"transforms/constant_op_device_assignment.cc",
11041104
"transforms/convert_control_to_data_outputs.cc",
11051105
"transforms/convert_launch_func_to_tf_call.cc",
1106+
"transforms/convert_tf_control_flow_to_scf.cc",
11061107
"transforms/decompose_resource_ops_pass.cc",
11071108
"transforms/device_attribute_to_launch.cc",
11081109
"transforms/device_index_selector.cc",
@@ -1259,6 +1260,7 @@ cc_library(
12591260
"@llvm-project//mlir:Parser",
12601261
"@llvm-project//mlir:Pass",
12611262
"@llvm-project//mlir:Rewrite",
1263+
"@llvm-project//mlir:SCFDialect",
12621264
"@llvm-project//mlir:StandardOps",
12631265
"@llvm-project//mlir:Support",
12641266
"@llvm-project//mlir:TensorDialect",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: tf-opt -convert-tf-control-flow-to-scf %s | FileCheck %s
2+
3+
// `tf.IfRegion` which returns values gets converted to `scf.if`.
4+
func private @test_if_then1(tensor<4xf32>) -> tensor<4xf32>
5+
func private @test_if_else1(tensor<4xf32>) -> tensor<4xf32>
6+
// CHECK-LABEL: func @test_supported_lowering_of_tf_if_region1
7+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<i1>, %[[ARG1:.*]]: tensor<4xf32>)
8+
func @test_supported_lowering_of_tf_if_region1(%arg0: tensor<i1>, %arg1: tensor<4xf32>) -> (tensor<*xf32>, tensor<4xf32>) {
9+
%res:2 = "tf.IfRegion"(%arg0) ( {
10+
%call = call @test_if_then1(%arg1) : (tensor<4xf32>) -> tensor<4xf32>
11+
%add = "tf.AddV2"(%call, %call) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
12+
"tf.Yield"(%call, %add) : (tensor<4xf32>, tensor<4xf32>) -> ()
13+
}, {
14+
%call_0 = call @test_if_else1(%arg1) : (tensor<4xf32>) -> tensor<4xf32>
15+
"tf.Yield"(%call_0, %call_0) : (tensor<4xf32>, tensor<4xf32>) -> ()
16+
}) {is_stateless = false} : (tensor<i1>) -> (tensor<*xf32>, tensor<4xf32>)
17+
return %res#0, %res#1 : tensor<*xf32>, tensor<4xf32>
18+
19+
// CHECK-NEXT: %[[COND:.*]] = tensor.extract %[[ARG0]][] : tensor<i1>
20+
// CHECK-NEXT: %[[RES:.*]]:2 = scf.if %[[COND]] -> (tensor<*xf32>, tensor<4xf32>) {
21+
// CHECK-NEXT: %[[CALL:.*]] = call @test_if_then1(%[[ARG1]]) : (tensor<4xf32>) -> tensor<4xf32>
22+
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CALL]], %[[CALL]]) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
23+
// CHECK-NEXT: %[[CAST:.*]] = "tf.Cast"(%[[CALL]]) {Truncate = false} : (tensor<4xf32>) -> tensor<*xf32>
24+
// CHECK-NEXT: scf.yield %[[CAST]], %[[ADD]] : tensor<*xf32>, tensor<4xf32>
25+
// CHECK-NEXT: } else {
26+
// CHECK-NEXT: %[[CALL_0:.*]] = call @test_if_else1(%[[ARG1]]) : (tensor<4xf32>) -> tensor<4xf32>
27+
// CHECK-NEXT: %[[CAST_0:.*]] = "tf.Cast"(%[[CALL_0]]) {Truncate = false} : (tensor<4xf32>) -> tensor<*xf32>
28+
// CHECK-NEXT: scf.yield %[[CAST_0]], %[[CALL_0]] : tensor<*xf32>, tensor<4xf32>
29+
// CHECK-NEXT: }
30+
// CHECK-NEXT: return %[[RES]]#0, %[[RES]]#1 : tensor<*xf32>, tensor<4xf32>
31+
}
32+
33+
// `tf.IfRegion` which doesn't return values gets converted to `scf.if`.
34+
func private @test_if_then2(tensor<4xf32>) -> ()
35+
func private @test_if_else2(tensor<4xf32>) -> ()
36+
// CHECK-LABEL: func @test_supported_lowering_of_tf_if_region2
37+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<i1>, %[[ARG1:.*]]: tensor<4xf32>)
38+
func @test_supported_lowering_of_tf_if_region2(%arg0: tensor<i1>, %arg1: tensor<4xf32>) -> () {
39+
"tf.IfRegion"(%arg0) ( {
40+
call @test_if_then2(%arg1) : (tensor<4xf32>) -> ()
41+
"tf.Yield"() : () -> ()
42+
}, {
43+
call @test_if_else2(%arg1) : (tensor<4xf32>) -> ()
44+
"tf.Yield"() : () -> ()
45+
}) {is_stateless = false} : (tensor<i1>) -> ()
46+
return
47+
48+
// CHECK-NEXT: %[[COND:.*]] = tensor.extract %[[ARG0]][] : tensor<i1>
49+
// CHECK-NEXT: scf.if %[[COND]] {
50+
// CHECK-NEXT: call @test_if_then2(%[[ARG1]]) : (tensor<4xf32>) -> ()
51+
// CHECK-NEXT: } else {
52+
// CHECK-NEXT: call @test_if_else2(%[[ARG1]]) : (tensor<4xf32>) -> ()
53+
// CHECK-NEXT: }
54+
// CHECK-NEXT: return
55+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
17+
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
19+
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20+
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
21+
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
22+
23+
namespace mlir {
24+
namespace TF {
25+
26+
namespace {
27+
28+
/// Convert the `tf.IfRegion` op to the `scf.if` op.
29+
class ConvertIfRegionOp : public OpRewritePattern<IfRegionOp> {
30+
public:
31+
using OpRewritePattern<IfRegionOp>::OpRewritePattern;
32+
33+
LogicalResult matchAndRewrite(IfRegionOp op,
34+
PatternRewriter& rewriter) const override {
35+
// Creates the `then` or `else` region of the `scf.if` op. Note that
36+
// `tf_then_or_else_region` is the `then` or `else` region of the
37+
// `tf.IfRegion` op and `scf_then_or_else_region` is the `then` or `else`
38+
// region of the new `scf.if` op. Further, `tf_if_region_return_type` is the
39+
// list of return types of the `tf.IfRegion` op.
40+
auto createScfThenOrElse = [](Region& tf_then_or_else_region,
41+
Region& scf_then_or_else_region,
42+
TypeRange tf_if_region_return_type,
43+
PatternRewriter& rewriter) {
44+
// Clone all the ops of `tf_then_or_else_region` into
45+
// `scf_then_or_else_region`.
46+
rewriter.cloneRegionBefore(tf_then_or_else_region,
47+
&scf_then_or_else_region.front());
48+
rewriter.eraseBlock(&scf_then_or_else_region.back());
49+
50+
Block* first_block_of_scf_then_or_else_region =
51+
&scf_then_or_else_region.front();
52+
53+
// Replace the current terminator (a `tf.Yield` op) with an `scf.yield`
54+
// op. The input of the `scf.yield` op is a list of results of `tf.Cast`
55+
// ops, each of which casts an operand of the current terminator to the
56+
// corresponding result type of the `tf.IfRegion` op.
57+
Operation* current_terminator =
58+
first_block_of_scf_then_or_else_region->getTerminator();
59+
rewriter.setInsertionPoint(current_terminator);
60+
SmallVector<Value, 4> scf_yield_input;
61+
for (auto it : llvm::zip(tf_if_region_return_type,
62+
current_terminator->getOperands())) {
63+
scf_yield_input.push_back(rewriter.create<CastOp>(
64+
current_terminator->getLoc(), std::get<0>(it), std::get<1>(it)));
65+
}
66+
67+
rewriter.replaceOpWithNewOp<scf::YieldOp>(current_terminator,
68+
scf_yield_input);
69+
};
70+
71+
Location loc = op.getLoc();
72+
73+
// The condition of an `scf.if` op is a 1-bit signless integer. Whereas, the
74+
// condition of the `tf.IfRegion` op is a 0-D tensor of 1-bit signless
75+
// integers. Thus, we use the `tensor.extract` op to compute the condition
76+
// of `scf.if` from that of `tf.IfRegion`.
77+
auto scf_if_condition = rewriter.create<tensor::ExtractOp>(loc, op.cond());
78+
79+
TypeRange tf_if_region_return_type = op.getResultTypes();
80+
81+
// Create the `scf.if` op.
82+
auto scf_if_op =
83+
rewriter.create<scf::IfOp>(loc, tf_if_region_return_type,
84+
scf_if_condition, /*withElseRegion=*/true);
85+
86+
Region& then_region = op.then_branch();
87+
Region& else_region = op.else_branch();
88+
89+
// Create the `then` and `else` regions of the `scf.if` op.
90+
createScfThenOrElse(then_region, scf_if_op.thenRegion(),
91+
tf_if_region_return_type, rewriter);
92+
createScfThenOrElse(else_region, scf_if_op.elseRegion(),
93+
tf_if_region_return_type, rewriter);
94+
95+
// Replace the `tf.IfRegion` op with the results of the `scf.if` op.
96+
rewriter.replaceOp(op, scf_if_op.getResults());
97+
return success();
98+
}
99+
};
100+
101+
} // end anonymous namespace
102+
103+
void populateTfControlFlowToScfPatterns(MLIRContext* context,
104+
OwningRewritePatternList* patterns) {
105+
patterns->insert<ConvertIfRegionOp>(context);
106+
}
107+
108+
struct ConvertTfControlFlowToScf
109+
: public ConvertTfControlFlowToScfPassBase<ConvertTfControlFlowToScf> {
110+
void runOnOperation() override {
111+
OwningRewritePatternList patterns(&getContext());
112+
populateTfControlFlowToScfPatterns(&getContext(), &patterns);
113+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
114+
}
115+
};
116+
117+
std::unique_ptr<OperationPass<ModuleOp>> createConvertTfControlFlowToScfPass() {
118+
return std::make_unique<ConvertTfControlFlowToScf>();
119+
}
120+
121+
} // namespace TF
122+
} // end namespace mlir

tensorflow/compiler/mlir/tensorflow/transforms/passes.h

+5
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,11 @@ CreateTensorDeviceCopyConversionPass();
113113
// have built in broadcasting support.
114114
std::unique_ptr<OperationPass<FuncOp>> CreateBroadcastFoldPass();
115115

116+
void populateTfControlFlowToScfPatterns(MLIRContext* context,
117+
OwningRewritePatternList* patterns);
118+
// Create a pass to convert TensorFlow control flow to SCF.
119+
std::unique_ptr<OperationPass<ModuleOp>> createConvertTfControlFlowToScfPass();
120+
116121
struct LayoutOptimizationPipelineOptions
117122
: public PassPipelineOptions<LayoutOptimizationPipelineOptions> {
118123
Option<std::string> force_data_format{

tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ limitations under the License.
2121
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
2222

2323
namespace mlir {
24+
namespace scf {
25+
class SCFDialect;
26+
}
2427
namespace tensor {
2528
class TensorDialect;
2629
}

tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td

+13
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,19 @@ def ClusterOutliningPass : Pass<"tf-device-cluster-outlining", "ModuleOp"> {
505505
let constructor = "TFDevice::CreateClusterOutliningPass()";
506506
}
507507

508+
def ConvertTfControlFlowToScfPass : Pass<"convert-tf-control-flow-to-scf", "ModuleOp"> {
509+
let summary = "Convert TensorFlow control flow to SCF.";
510+
511+
let description = [{
512+
This pass can be used for all direct control flow lowerings from the TensorFlow
513+
dialect to the SCF dialect.
514+
}];
515+
516+
let dependentDialects = ["scf::SCFDialect", "tensor::TensorDialect"];
517+
518+
let constructor = "TF::createConvertTfControlFlowToScfPass()";
519+
}
520+
508521
def LaunchOutliningPass : Pass<"tf-device-launch-outlining", "ModuleOp"> {
509522
let summary = "Outlines regions of tf_device.launch operations";
510523

0 commit comments

Comments
 (0)