Skip to content

Commit

Permalink
[FIRRTL] Backport IMDCE updates to sifive 1.5 (#3771)
Browse files Browse the repository at this point in the history
* [FIRRTL] Remove empty modules in IMDCE (#3378)

This commit extends IMDCE to delete empty modules. A module is empty if
it doesn't have any argument, operation and annotation. In the post-processing,
we traverse modules in a post-order of the instance graph, and remove empty
modules and their instances in a bottom-up manner.

* [IMDCE] Forward constant output ports to caller sides (#3688)

This PR makes IMDCE propagate constant output ports to caller sides before
actually performing dataflow analysis so that we can eliminate constant output ports.
  • Loading branch information
uenoku authored Aug 24, 2022
1 parent 69b3f68 commit c88b624
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 15 deletions.
98 changes: 98 additions & 0 deletions lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "circt/Dialect/FIRRTL/Passes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/Support/Debug.h"

Expand Down Expand Up @@ -38,6 +39,8 @@ struct IMDeadCodeElimPass : public IMDeadCodeElimBase<IMDeadCodeElimPass> {

void rewriteModuleSignature(FModuleOp module);
void rewriteModuleBody(FModuleOp module);
void eraseEmptyModule(FModuleOp module);
void forwardConstantOutputPort(FModuleOp module);

void markAlive(Value value) {
// If the value is already in `liveSet`, skip it.
Expand Down Expand Up @@ -177,11 +180,57 @@ void IMDeadCodeElimPass::markBlockExecutable(Block *block) {
}
}

void IMDeadCodeElimPass::forwardConstantOutputPort(FModuleOp module) {
// This tracks constant values of output ports.
SmallVector<std::pair<unsigned, APSInt>> constantPortIndicesAndValues;
auto ports = module.getPorts();
auto *instanceGraphNode = instanceGraph->lookup(module);

for (const auto &e : llvm::enumerate(ports)) {
unsigned index = e.index();
auto port = e.value();
auto arg = module.getArgument(index);

// If the port has don't touch, don't propagate the constant value.
if (!port.isOutput() || hasDontTouch(arg))
continue;

// Remember the index and constant value connected to an output port.
if (auto connect = getSingleConnectUserOf(arg))
if (auto constant = connect.src().getDefiningOp<ConstantOp>())
constantPortIndicesAndValues.push_back({index, constant.value()});
}

// If there is no constant port, abort.
if (constantPortIndicesAndValues.empty())
return;

// Rewrite all uses.
for (auto *use : instanceGraphNode->uses()) {
auto instance = cast<InstanceOp>(*use->getInstance());
ImplicitLocOpBuilder builder(instance.getLoc(), instance);
for (auto [index, constant] : constantPortIndicesAndValues) {
auto result = instance.getResult(index);
assert(ports[index].isOutput() && "must be an output port");

// Replace the port with the constant.
result.replaceAllUsesWith(builder.create<ConstantOp>(constant));
}
}
}

void IMDeadCodeElimPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "===----- Remove unused ports -----==="
<< "\n");
auto circuit = getOperation();
instanceGraph = &getAnalysis<InstanceGraph>();

// Forward constant output ports to caller sides so that we can eliminate
// constant outputs.
for (auto *node : llvm::post_order(instanceGraph))
if (auto module = dyn_cast_or_null<FModuleOp>(*node->getModule()))
forwardConstantOutputPort(module);

for (auto module : circuit.getBody()->getOps<FModuleOp>()) {
// Mark the ports of public modules as alive.
if (module.isPublic()) {
Expand All @@ -204,6 +253,19 @@ void IMDeadCodeElimPass::runOnOperation() {
mlir::parallelForEach(circuit.getContext(),
circuit.getBody()->getOps<FModuleOp>(),
[&](auto op) { rewriteModuleBody(op); });

// Erase empty modules. To erase empty modules transitively, it is necessary
// to visit modules in the post order of instance graph.
// FIXME: We copy the list of modules into a vector first to avoid iterator
// invalidation while we mutate the instance graph. See issue 3387.
SmallVector<FModuleOp, 0> modules(llvm::make_filter_range(
llvm::map_range(
llvm::post_order(instanceGraph),
[](auto *node) { return dyn_cast<FModuleOp>(*node->getModule()); }),
[](auto module) { return module; }));

for (auto module : modules)
eraseEmptyModule(module);
}

void IMDeadCodeElimPass::visitValue(Value value) {
Expand Down Expand Up @@ -399,6 +461,42 @@ void IMDeadCodeElimPass::rewriteModuleSignature(FModuleOp module) {
numRemovedPorts += deadPortIndexes.size();
}

void IMDeadCodeElimPass::eraseEmptyModule(FModuleOp module) {
// Public modules cannot be erased.
if (module.isPublic())
return;

// If the module doesn't have arguments, operations or annotations, we
// consider it to be dead.
if (!module.getBody()->args_empty() || !module.getBody()->empty() ||
!module.annotations().empty())
return;

// Ok, the module is empty. Delete instances unless they have symbols.
LLVM_DEBUG(llvm::dbgs() << "Erase " << module.getName() << "\n");

InstanceGraphNode *instanceGraphNode =
instanceGraph->lookup(module.moduleNameAttr());

bool existsInstanceWithSymbol = false;
for (auto *use : llvm::make_early_inc_range(instanceGraphNode->uses())) {
auto instance = cast<InstanceOp>(use->getInstance());
if (instance.inner_sym()) {
existsInstanceWithSymbol = true;
continue;
}
use->erase();
instance.erase();
}

// If there is an instance with a symbol, we don't delete the module itself.
if (existsInstanceWithSymbol)
return;

instanceGraph->erase(instanceGraphNode);
module.erase();
}

std::unique_ptr<mlir::Pass> circt::firrtl::createIMDeadCodeElimPass() {
return std::make_unique<IMDeadCodeElimPass>();
}
53 changes: 41 additions & 12 deletions test/Dialect/FIRRTL/imdce.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ firrtl.circuit "top" {
// In `dead_module`, %source is connected to %dest through several dead operations such as
// node, wire, reg or rgereset. %dest is also dead at any instantiation, so check that
// all operations are removed by IMDeadCodeElim pass.
// CHECK-LABEL: private @dead_module() {
// CHECK-NEXT: }
// CHECK-NOT: @dead_module
firrtl.module private @dead_module(in %source: !firrtl.uint<1>, out %dest: !firrtl.uint<1>,
in %clock:!firrtl.clock, in %reset:!firrtl.uint<1>) {
%dead_node = firrtl.node %source: !firrtl.uint<1>
Expand Down Expand Up @@ -54,8 +53,7 @@ firrtl.circuit "top" {
%tmp = firrtl.node %source: !firrtl.uint<1>
firrtl.strictconnect %dest, %tmp : !firrtl.uint<1>

// TODO: Remove instances of empty modules.
// CHECK-NEXT: firrtl.instance dead_module @dead_module()
// CHECK-NOT: @dead_module
%source1, %dest1, %clock1, %reset1 = firrtl.instance dead_module @dead_module(in source: !firrtl.uint<1>, out dest: !firrtl.uint<1>, in clock:!firrtl.clock, in reset:!firrtl.uint<1>)
firrtl.strictconnect %source1, %source : !firrtl.uint<1>
firrtl.strictconnect %clock1, %clock : !firrtl.clock
Expand All @@ -82,22 +80,18 @@ firrtl.circuit "top" {

// Check that it's possible to analyze complex dependency across different modules.
firrtl.circuit "top" {
// CHECK-LABEL: firrtl.module private @Child1() {
// CHECK-NEXT: }
// CHECK-NOT: @Child1
firrtl.module private @Child1(in %input: !firrtl.uint<1>, out %output: !firrtl.uint<1>) {
firrtl.strictconnect %output, %input : !firrtl.uint<1>
}
// CHECK-LABEL: firrtl.module private @Child2() {
// CHECK-NEXT: }
// CHECK-NOT: @Child2
firrtl.module private @Child2(in %input: !firrtl.uint<1>, in %clock: !firrtl.clock, out %output: !firrtl.uint<1>) {
%r = firrtl.reg %clock : !firrtl.uint<1>
firrtl.strictconnect %r, %input : !firrtl.uint<1>
firrtl.strictconnect %output, %r : !firrtl.uint<1>
}

// CHECK-LABEL: firrtl.module @top(in %clock: !firrtl.clock, in %input: !firrtl.uint<1>) {
// CHECK-NEXT: firrtl.instance tile @Child1()
// CHECK-NEXT: firrtl.instance bar @Child2()
// CHECK-NEXT: }
firrtl.module @top(in %clock: !firrtl.clock, in %input: !firrtl.uint<1>) {
%tile_input, %tile_output = firrtl.instance tile @Child1(in input: !firrtl.uint<1>, out output: !firrtl.uint<1>)
Expand Down Expand Up @@ -142,12 +136,47 @@ firrtl.circuit "UnusedOutput" {
// CHECK-LABEL: "PreserveOutputFile"
firrtl.circuit "PreserveOutputFile" {
// CHECK-NEXT: firrtl.module {{.+}}@Sub
// CHECK-NOT: %a
// CHECK-SAME: output_file
firrtl.module private @Sub(in %a: !firrtl.uint<1>) attributes {output_file = #hw.output_file<"hello">} {}
firrtl.module private @Sub(in %a: !firrtl.uint<1>, in %b: !firrtl.uint<1> sym @sym) attributes {output_file = #hw.output_file<"hello">} {}
// CHECK: firrtl.module @PreserveOutputFile
firrtl.module @PreserveOutputFile() {
// CHECK-NEXT: firrtl.instance sub
// CHECK-SAME: output_file
firrtl.instance sub {output_file = #hw.output_file<"hello">} @Sub(in a: !firrtl.uint<1>)
firrtl.instance sub {output_file = #hw.output_file<"hello">} @Sub(in a: !firrtl.uint<1>, in b: !firrtl.uint<1>)
}
}

// -----

// CHECK-LABEL: "DeleteEmptyModule"
firrtl.circuit "DeleteEmptyModule" {
// Don't delete @Sub because instance `sub1` has a symbol.
// CHECK: firrtl.module private @Sub
firrtl.module private @Sub(in %a: !firrtl.uint<1>) {}
// CHECK: firrtl.module @DeleteEmptyModule
firrtl.module @DeleteEmptyModule() {
// CHECK-NEXT: firrtl.instance sub1 sym @Foo @Sub()
firrtl.instance sub1 sym @Foo @Sub(in a: !firrtl.uint<1>)
// CHECK-NOT: sub2
firrtl.instance sub2 @Sub(in a: !firrtl.uint<1>)
}
}

// -----

// CHECK-LABEL: "ForwardConstant"
firrtl.circuit "ForwardConstant" {
// CHECK-NOT: Zero
firrtl.module private @Zero(out %zero: !firrtl.uint<1>) {
%c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
firrtl.strictconnect %zero, %c0_ui1 : !firrtl.uint<1>
}
// CHECK-LABEL: @ForwardConstant
firrtl.module @ForwardConstant(out %zero: !firrtl.uint<1>) {
// CHECK: %c0_ui1 = firrtl.constant 0
%sub_zero = firrtl.instance sub @Zero(out zero: !firrtl.uint<1>)
// CHECK-NEXT: firrtl.strictconnect %zero, %c0_ui1
firrtl.strictconnect %zero, %sub_zero : !firrtl.uint<1>
}
}
3 changes: 0 additions & 3 deletions test/firtool/firtool.fir
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ circuit test_mod : %[[{"a": "a"}]]
; MLIR-NEXT: firrtl.strictconnect %multibitMux_a_1, %vec_1 : !firrtl.uint<1>
; MLIR-NEXT: firrtl.strictconnect %multibitMux_a_2, %vec_2 : !firrtl.uint<1>
; MLIR-NEXT: firrtl.strictconnect %multibitMux_sel, %b : !firrtl.uint<2>
; MLIR-NEXT: firrtl.instance unusedPortsMod interesting_name @UnusedPortsMod()

; ANNOTATIONS-LABEL: firrtl.module @test_mod
; ANNOTATIONS-SAME: info = "a ModuleTarget Annotation"
Expand Down Expand Up @@ -173,7 +172,6 @@ circuit test_mod : %[[{"a": "a"}]]
; VERILOG-NEXT: .sel (b),
; VERILOG-NEXT: .b (out_multibitMux)
; VERILOG-NEXT: );
; VERILOG-NEXT: UnusedPortsMod unusedPortsMod ();
; VERILOG: endmodule

; Check that we canonicalize the HW output of lowering.
Expand Down Expand Up @@ -273,4 +271,3 @@ circuit test_mod : %[[{"a": "a"}]]
input in : UInt<1>
output out : UInt<1>
out is invalid
; VERILOG-LABEL: module UnusedPortsMod();

0 comments on commit c88b624

Please sign in to comment.