diff --git a/lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp b/lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp index 095c767ea301..a3624039ff3b 100644 --- a/lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp +++ b/lib/Dialect/FIRRTL/Transforms/IMDeadCodeElim.cpp @@ -11,6 +11,7 @@ #include "circt/Dialect/FIRRTL/Passes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Threading.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Support/Debug.h" @@ -38,6 +39,8 @@ struct IMDeadCodeElimPass : public IMDeadCodeElimBase { void rewriteModuleSignature(FModuleOp module); void rewriteModuleBody(FModuleOp module); + void eraseEmptyModule(FModuleOp module); + void forwardConstantOutputPort(FModuleOp module); void markAlive(Value value) { // If the value is already in `liveSet`, skip it. @@ -177,11 +180,57 @@ void IMDeadCodeElimPass::markBlockExecutable(Block *block) { } } +void IMDeadCodeElimPass::forwardConstantOutputPort(FModuleOp module) { + // This tracks constant values of output ports. + SmallVector> constantPortIndicesAndValues; + auto ports = module.getPorts(); + auto *instanceGraphNode = instanceGraph->lookup(module); + + for (const auto &e : llvm::enumerate(ports)) { + unsigned index = e.index(); + auto port = e.value(); + auto arg = module.getArgument(index); + + // If the port has don't touch, don't propagate the constant value. + if (!port.isOutput() || hasDontTouch(arg)) + continue; + + // Remember the index and constant value connected to an output port. + if (auto connect = getSingleConnectUserOf(arg)) + if (auto constant = connect.src().getDefiningOp()) + constantPortIndicesAndValues.push_back({index, constant.value()}); + } + + // If there is no constant port, abort. + if (constantPortIndicesAndValues.empty()) + return; + + // Rewrite all uses. + for (auto *use : instanceGraphNode->uses()) { + auto instance = cast(*use->getInstance()); + ImplicitLocOpBuilder builder(instance.getLoc(), instance); + for (auto [index, constant] : constantPortIndicesAndValues) { + auto result = instance.getResult(index); + assert(ports[index].isOutput() && "must be an output port"); + + // Replace the port with the constant. + result.replaceAllUsesWith(builder.create(constant)); + } + } +} + void IMDeadCodeElimPass::runOnOperation() { LLVM_DEBUG(llvm::dbgs() << "===----- Remove unused ports -----===" << "\n"); auto circuit = getOperation(); instanceGraph = &getAnalysis(); + + // Forward constant output ports to caller sides so that we can eliminate + // constant outputs. + for (auto *node : llvm::post_order(instanceGraph)) + if (auto module = dyn_cast_or_null(*node->getModule())) + forwardConstantOutputPort(module); + for (auto module : circuit.getBody()->getOps()) { // Mark the ports of public modules as alive. if (module.isPublic()) { @@ -204,6 +253,19 @@ void IMDeadCodeElimPass::runOnOperation() { mlir::parallelForEach(circuit.getContext(), circuit.getBody()->getOps(), [&](auto op) { rewriteModuleBody(op); }); + + // Erase empty modules. To erase empty modules transitively, it is necessary + // to visit modules in the post order of instance graph. + // FIXME: We copy the list of modules into a vector first to avoid iterator + // invalidation while we mutate the instance graph. See issue 3387. + SmallVector modules(llvm::make_filter_range( + llvm::map_range( + llvm::post_order(instanceGraph), + [](auto *node) { return dyn_cast(*node->getModule()); }), + [](auto module) { return module; })); + + for (auto module : modules) + eraseEmptyModule(module); } void IMDeadCodeElimPass::visitValue(Value value) { @@ -399,6 +461,42 @@ void IMDeadCodeElimPass::rewriteModuleSignature(FModuleOp module) { numRemovedPorts += deadPortIndexes.size(); } +void IMDeadCodeElimPass::eraseEmptyModule(FModuleOp module) { + // Public modules cannot be erased. + if (module.isPublic()) + return; + + // If the module doesn't have arguments, operations or annotations, we + // consider it to be dead. + if (!module.getBody()->args_empty() || !module.getBody()->empty() || + !module.annotations().empty()) + return; + + // Ok, the module is empty. Delete instances unless they have symbols. + LLVM_DEBUG(llvm::dbgs() << "Erase " << module.getName() << "\n"); + + InstanceGraphNode *instanceGraphNode = + instanceGraph->lookup(module.moduleNameAttr()); + + bool existsInstanceWithSymbol = false; + for (auto *use : llvm::make_early_inc_range(instanceGraphNode->uses())) { + auto instance = cast(use->getInstance()); + if (instance.inner_sym()) { + existsInstanceWithSymbol = true; + continue; + } + use->erase(); + instance.erase(); + } + + // If there is an instance with a symbol, we don't delete the module itself. + if (existsInstanceWithSymbol) + return; + + instanceGraph->erase(instanceGraphNode); + module.erase(); +} + std::unique_ptr circt::firrtl::createIMDeadCodeElimPass() { return std::make_unique(); } diff --git a/test/Dialect/FIRRTL/imdce.mlir b/test/Dialect/FIRRTL/imdce.mlir index 5156289b8e7a..219a4ae6d88a 100644 --- a/test/Dialect/FIRRTL/imdce.mlir +++ b/test/Dialect/FIRRTL/imdce.mlir @@ -3,8 +3,7 @@ firrtl.circuit "top" { // In `dead_module`, %source is connected to %dest through several dead operations such as // node, wire, reg or rgereset. %dest is also dead at any instantiation, so check that // all operations are removed by IMDeadCodeElim pass. - // CHECK-LABEL: private @dead_module() { - // CHECK-NEXT: } + // CHECK-NOT: @dead_module firrtl.module private @dead_module(in %source: !firrtl.uint<1>, out %dest: !firrtl.uint<1>, in %clock:!firrtl.clock, in %reset:!firrtl.uint<1>) { %dead_node = firrtl.node %source: !firrtl.uint<1> @@ -54,8 +53,7 @@ firrtl.circuit "top" { %tmp = firrtl.node %source: !firrtl.uint<1> firrtl.strictconnect %dest, %tmp : !firrtl.uint<1> - // TODO: Remove instances of empty modules. - // CHECK-NEXT: firrtl.instance dead_module @dead_module() + // CHECK-NOT: @dead_module %source1, %dest1, %clock1, %reset1 = firrtl.instance dead_module @dead_module(in source: !firrtl.uint<1>, out dest: !firrtl.uint<1>, in clock:!firrtl.clock, in reset:!firrtl.uint<1>) firrtl.strictconnect %source1, %source : !firrtl.uint<1> firrtl.strictconnect %clock1, %clock : !firrtl.clock @@ -82,13 +80,11 @@ firrtl.circuit "top" { // Check that it's possible to analyze complex dependency across different modules. firrtl.circuit "top" { - // CHECK-LABEL: firrtl.module private @Child1() { - // CHECK-NEXT: } + // CHECK-NOT: @Child1 firrtl.module private @Child1(in %input: !firrtl.uint<1>, out %output: !firrtl.uint<1>) { firrtl.strictconnect %output, %input : !firrtl.uint<1> } - // CHECK-LABEL: firrtl.module private @Child2() { - // CHECK-NEXT: } + // CHECK-NOT: @Child2 firrtl.module private @Child2(in %input: !firrtl.uint<1>, in %clock: !firrtl.clock, out %output: !firrtl.uint<1>) { %r = firrtl.reg %clock : !firrtl.uint<1> firrtl.strictconnect %r, %input : !firrtl.uint<1> @@ -96,8 +92,6 @@ firrtl.circuit "top" { } // CHECK-LABEL: firrtl.module @top(in %clock: !firrtl.clock, in %input: !firrtl.uint<1>) { - // CHECK-NEXT: firrtl.instance tile @Child1() - // CHECK-NEXT: firrtl.instance bar @Child2() // CHECK-NEXT: } firrtl.module @top(in %clock: !firrtl.clock, in %input: !firrtl.uint<1>) { %tile_input, %tile_output = firrtl.instance tile @Child1(in input: !firrtl.uint<1>, out output: !firrtl.uint<1>) @@ -142,12 +136,47 @@ firrtl.circuit "UnusedOutput" { // CHECK-LABEL: "PreserveOutputFile" firrtl.circuit "PreserveOutputFile" { // CHECK-NEXT: firrtl.module {{.+}}@Sub + // CHECK-NOT: %a // CHECK-SAME: output_file - firrtl.module private @Sub(in %a: !firrtl.uint<1>) attributes {output_file = #hw.output_file<"hello">} {} + firrtl.module private @Sub(in %a: !firrtl.uint<1>, in %b: !firrtl.uint<1> sym @sym) attributes {output_file = #hw.output_file<"hello">} {} // CHECK: firrtl.module @PreserveOutputFile firrtl.module @PreserveOutputFile() { // CHECK-NEXT: firrtl.instance sub // CHECK-SAME: output_file - firrtl.instance sub {output_file = #hw.output_file<"hello">} @Sub(in a: !firrtl.uint<1>) + firrtl.instance sub {output_file = #hw.output_file<"hello">} @Sub(in a: !firrtl.uint<1>, in b: !firrtl.uint<1>) + } +} + +// ----- + +// CHECK-LABEL: "DeleteEmptyModule" +firrtl.circuit "DeleteEmptyModule" { + // Don't delete @Sub because instance `sub1` has a symbol. + // CHECK: firrtl.module private @Sub + firrtl.module private @Sub(in %a: !firrtl.uint<1>) {} + // CHECK: firrtl.module @DeleteEmptyModule + firrtl.module @DeleteEmptyModule() { + // CHECK-NEXT: firrtl.instance sub1 sym @Foo @Sub() + firrtl.instance sub1 sym @Foo @Sub(in a: !firrtl.uint<1>) + // CHECK-NOT: sub2 + firrtl.instance sub2 @Sub(in a: !firrtl.uint<1>) + } +} + +// ----- + +// CHECK-LABEL: "ForwardConstant" +firrtl.circuit "ForwardConstant" { + // CHECK-NOT: Zero + firrtl.module private @Zero(out %zero: !firrtl.uint<1>) { + %c0_ui1 = firrtl.constant 0 : !firrtl.uint<1> + firrtl.strictconnect %zero, %c0_ui1 : !firrtl.uint<1> + } + // CHECK-LABEL: @ForwardConstant + firrtl.module @ForwardConstant(out %zero: !firrtl.uint<1>) { + // CHECK: %c0_ui1 = firrtl.constant 0 + %sub_zero = firrtl.instance sub @Zero(out zero: !firrtl.uint<1>) + // CHECK-NEXT: firrtl.strictconnect %zero, %c0_ui1 + firrtl.strictconnect %zero, %sub_zero : !firrtl.uint<1> } } diff --git a/test/firtool/firtool.fir b/test/firtool/firtool.fir index 77baaca5a8da..0cb8f932cc0d 100644 --- a/test/firtool/firtool.fir +++ b/test/firtool/firtool.fir @@ -118,7 +118,6 @@ circuit test_mod : %[[{"a": "a"}]] ; MLIR-NEXT: firrtl.strictconnect %multibitMux_a_1, %vec_1 : !firrtl.uint<1> ; MLIR-NEXT: firrtl.strictconnect %multibitMux_a_2, %vec_2 : !firrtl.uint<1> ; MLIR-NEXT: firrtl.strictconnect %multibitMux_sel, %b : !firrtl.uint<2> -; MLIR-NEXT: firrtl.instance unusedPortsMod interesting_name @UnusedPortsMod() ; ANNOTATIONS-LABEL: firrtl.module @test_mod ; ANNOTATIONS-SAME: info = "a ModuleTarget Annotation" @@ -173,7 +172,6 @@ circuit test_mod : %[[{"a": "a"}]] ; VERILOG-NEXT: .sel (b), ; VERILOG-NEXT: .b (out_multibitMux) ; VERILOG-NEXT: ); -; VERILOG-NEXT: UnusedPortsMod unusedPortsMod (); ; VERILOG: endmodule ; Check that we canonicalize the HW output of lowering. @@ -273,4 +271,3 @@ circuit test_mod : %[[{"a": "a"}]] input in : UInt<1> output out : UInt<1> out is invalid -; VERILOG-LABEL: module UnusedPortsMod();