Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a6faaa7

Browse files
committedJan 26, 2025
[HW] Add Passes: hw-expunge-module, hw-tree-shake
1 parent 0d35c61 commit a6faaa7

File tree

6 files changed

+437
-0
lines changed

6 files changed

+437
-0
lines changed
 

‎include/circt/Dialect/HW/HWPasses.h

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
3333
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
3434
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
3535
std::unique_ptr<mlir::Pass> createFooWiresPass();
36+
std::unique_ptr<mlir::Pass> createHWExpungeModulePass();
37+
std::unique_ptr<mlir::Pass> createHWTreeShakePass();
3638

3739
/// Generate the code for registering passes.
3840
#define GEN_PASS_REGISTRATION

‎include/circt/Dialect/HW/Passes.td

+32
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,38 @@ def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> {
7575
let constructor = "circt::hw::createVerifyInnerRefNamespacePass()";
7676
}
7777

78+
def HWExpungeModule : Pass<"hw-expunge-module", "mlir::ModuleOp"> {
79+
let summary = "Remove module from the hierarchy, and recursively expose their ports to upper level.";
80+
let description = [{
81+
This pass removes a list of modules from the hierarchy on-by-one, recursively exposing their ports to upper level.
82+
The newly generated ports are by default named as <instance_path>__<port_name>. During a naming conflict, an warning would be genreated,
83+
and an random suffix would be added to the <instance_path> part.
84+
85+
For each given (transitive) parent module, the prefix can alternatively be specified by option instead of using the instance path.
86+
}];
87+
let constructor = "circt::hw::createHWExpungeModulePass()";
88+
89+
let options = [
90+
ListOption<"modules", "modules", "std::string",
91+
"Comma separated list of module names to be removed from the hierarchy.">,
92+
ListOption<"portPrefixes", "port-prefixes", "std::string",
93+
"Specify the prefix for ports of a given parent module's expunged childen. Each specification is formatted as <module>:<instance-path>=<prefix>. Only affect the top-most level module of the instance path.">,
94+
];
95+
}
96+
97+
def HWTreeShake : Pass<"hw-tree-shake", "mlir::ModuleOp"> {
98+
let summary = "Remove unused modules.";
99+
let description = [{
100+
This pass removes all modules besides a specified list of modules and their transitive dependencies.
101+
}];
102+
let constructor = "circt::hw::createHWTreeShakePass()";
103+
104+
let options = [
105+
ListOption<"keep", "keep", "std::string",
106+
"Comma separated list of module names to be kept.">,
107+
];
108+
}
109+
78110
/**
79111
* Tutorial Pass, doesn't do anything interesting
80112
*/

‎lib/Dialect/HW/Transforms/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ add_circt_dialect_library(CIRCTHWTransforms
66
VerifyInnerRefNamespace.cpp
77
FlattenModules.cpp
88
FooWires.cpp
9+
HWExpungeModule.cpp
10+
HWTreeShake.cpp
911

1012
DEPENDS
1113
CIRCTHWTransformsIncGen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
#include "circt/Dialect/HW/HWOps.h"
2+
#include "circt/Dialect/HW/HWPasses.h"
3+
#include "circt/Dialect/HW/HWTypes.h"
4+
#include "circt/Dialect/HW/HWInstanceGraph.h"
5+
#include "mlir/Pass/Pass.h"
6+
#include "llvm/ADT/DenseMap.h"
7+
#include "llvm/ADT/ImmutableList.h"
8+
#include "llvm/ADT/PostOrderIterator.h"
9+
#include "llvm/Support/Regex.h"
10+
#include <numeric>
11+
12+
namespace circt {
13+
namespace hw {
14+
#define GEN_PASS_DEF_HWEXPUNGEMODULE
15+
#include "circt/Dialect/HW/Passes.h.inc"
16+
} // namespace hw
17+
} // namespace circt
18+
19+
namespace {
20+
struct HWExpungeModulePass
21+
: circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
22+
void runOnOperation() override;
23+
};
24+
25+
struct InstPathSeg {
26+
llvm::StringRef seg;
27+
28+
InstPathSeg(llvm::StringRef seg) : seg(seg) {}
29+
const llvm::StringRef &getSeg() const { return seg; }
30+
operator llvm::StringRef() const { return seg; }
31+
32+
void Profile(llvm::FoldingSetNodeID &ID) const { ID.AddString(seg); }
33+
};
34+
using InstPath = llvm::ImmutableList<InstPathSeg>;
35+
std::string defaultPrefix(InstPath path) {
36+
std::string accum;
37+
while (!path.isEmpty()) {
38+
accum += path.getHead().getSeg();
39+
accum += "_";
40+
path = path.getTail();
41+
}
42+
accum += "_";
43+
return std::move(accum);
44+
}
45+
46+
// The regex for port prefix specification
47+
// "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
48+
// Unfortunately, the LLVM Regex cannot capture repeating capture groups, so
49+
// manually parse the spec This parser may accept identifiers with invalid
50+
// characters
51+
52+
std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>,
53+
std::string>
54+
parsePrefixSpec(llvm::StringRef in, InstPath::Factory &listFac) {
55+
auto [l, r] = in.split("=");
56+
if (r == "")
57+
return "No '=' found in input";
58+
auto [ll, lr] = l.split(":");
59+
if (lr == "")
60+
return "No ':' found before '='";
61+
llvm::SmallVector<llvm::StringRef, 4> segs;
62+
while (lr != "") {
63+
auto [seg, rest] = lr.split(".");
64+
segs.push_back(seg);
65+
lr = rest;
66+
}
67+
InstPath path;
68+
for (auto &seg : llvm::reverse(segs))
69+
path = listFac.add(seg, path);
70+
return std::make_tuple(ll, path, r);
71+
}
72+
} // namespace
73+
74+
void HWExpungeModulePass::runOnOperation() {
75+
auto root = getOperation();
76+
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
77+
root.walk(
78+
[&](circt::hw::HWModuleLike mod) { allModules[mod.getName()] = mod; });
79+
80+
// The instance graph. We only use this graph to traverse the hierarchy in post order.
81+
// The order does not change throught out the operation, onlygets weakened, but still valid.
82+
// So we keep this cached instance graph throughout the pass.
83+
auto &instanceGraph = getAnalysis<circt::hw::InstanceGraph>();
84+
85+
// Instance path.
86+
InstPath::Factory pathFactory;
87+
88+
// Process port prefix specifications
89+
// (Module name, Instance path) -> Prefix
90+
llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef>
91+
designatedPrefixes;
92+
bool containsFailure = false;
93+
for (const auto &raw : portPrefixes) {
94+
auto matched = parsePrefixSpec(raw, pathFactory);
95+
if (std::holds_alternative<std::string>(matched)) {
96+
llvm::errs() << "Invalid port prefix specification: " << raw << "\n";
97+
llvm::errs() << "Error: " << std::get<std::string>(matched) << "\n";
98+
containsFailure = true;
99+
continue;
100+
}
101+
102+
auto [module, path, prefix] =
103+
std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(
104+
matched);
105+
if (!allModules.contains(module)) {
106+
llvm::errs() << "Module not found in port prefix specification: "
107+
<< module << "\n";
108+
llvm::errs() << "From specification: " << raw << "\n";
109+
containsFailure = true;
110+
continue;
111+
}
112+
113+
// Skip checking instance paths' existence. Non-existent paths are ignored
114+
designatedPrefixes.insert({{module, path}, prefix});
115+
}
116+
117+
if (containsFailure)
118+
return signalPassFailure();
119+
120+
// Instance path * prefix name
121+
using ReplacedDescendent = std::pair<InstPath, std::string>;
122+
// This map holds the expunged descendents of a module
123+
llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>>
124+
expungedDescendents;
125+
for (auto &expunging : this->modules) {
126+
// Clear expungedDescendents
127+
for (auto &it : expungedDescendents)
128+
it.getSecond().clear();
129+
130+
auto expungingMod = allModules.lookup(expunging);
131+
if (!expungingMod)
132+
continue; // Ignored missing modules
133+
auto expungingModTy = expungingMod.getHWModuleType();
134+
auto expungingModPorts = expungingModTy.getPorts();
135+
136+
auto createPortsOn = [&expungingModPorts](circt::hw::HWModuleOp mod,
137+
const std::string &prefix,
138+
auto genOutput, auto emitInput) {
139+
mlir::OpBuilder builder(mod);
140+
// Create ports using *REVERSE* direction of their definitions
141+
for (auto &port : expungingModPorts) {
142+
auto defaultName = prefix + port.name.getValue();
143+
auto finalName = defaultName;
144+
if (port.dir == circt::hw::PortInfo::Input) {
145+
auto val = genOutput(port);
146+
assert(val.getType() == port.type);
147+
mod.appendOutput(finalName, val);
148+
} else if (port.dir == circt::hw::PortInfo::Output) {
149+
auto [_, arg] = mod.appendInput(finalName, port.type);
150+
emitInput(port, arg);
151+
}
152+
}
153+
};
154+
155+
for (auto &instGraphNode : llvm::post_order(&instanceGraph)) {
156+
// Skip extmodule and intmodule because they cannot contain anything
157+
circt::hw::HWModuleOp processing =
158+
llvm::dyn_cast_if_present<circt::hw::HWModuleOp>(
159+
instGraphNode->getModule().getOperation());
160+
if (!processing)
161+
continue;
162+
163+
std::optional<decltype(expungedDescendents.lookup("")) *>
164+
outerExpDescHold = {};
165+
auto getOuterExpDesc = [&]() -> decltype(**outerExpDescHold) {
166+
if (!outerExpDescHold.has_value())
167+
outerExpDescHold = {
168+
&expungedDescendents.insert({processing.getName(), {}})
169+
.first->getSecond()};
170+
return **outerExpDescHold;
171+
};
172+
173+
mlir::OpBuilder outerBuilder(processing);
174+
175+
processing.walk([&](circt::hw::InstanceOp inst) {
176+
auto instName = inst.getInstanceName();
177+
auto instMod = allModules.lookup(inst.getModuleName());
178+
179+
if (instMod.getOutputNames().size() != inst.getResults().size() ||
180+
instMod.getNumInputPorts() != inst.getInputs().size()) {
181+
// Module have been modified during this pass, create new instances
182+
assert(instMod.getNumOutputPorts() >= inst.getResults().size());
183+
assert(instMod.getNumInputPorts() >= inst.getInputs().size());
184+
185+
auto instModInTypes = instMod.getInputTypes();
186+
187+
llvm::SmallVector<mlir::Value> newInputs;
188+
newInputs.reserve(instMod.getNumInputPorts());
189+
190+
outerBuilder.setInsertionPointAfter(inst);
191+
192+
// Appended inputs are at the end of the input list
193+
for (size_t i = 0; i < instMod.getNumInputPorts(); ++i) {
194+
mlir::Value input;
195+
if (i < inst.getNumInputPorts()) {
196+
input = inst.getInputs()[i];
197+
if (auto existingName = inst.getInputName(i))
198+
assert(existingName == instMod.getInputName(i));
199+
} else {
200+
input =
201+
outerBuilder
202+
.create<mlir::UnrealizedConversionCastOp>(
203+
inst.getLoc(), instModInTypes[i], mlir::ValueRange{})
204+
.getResult(0);
205+
}
206+
newInputs.push_back(input);
207+
}
208+
209+
auto newInst = outerBuilder.create<circt::hw::InstanceOp>(
210+
inst.getLoc(), instMod, inst.getInstanceNameAttr(), newInputs,
211+
inst.getParameters(),
212+
inst.getInnerSym().value_or<circt::hw::InnerSymAttr>({}));
213+
214+
for (size_t i = 0; i < inst.getNumResults(); ++i)
215+
assert(inst.getOutputName(i) == instMod.getOutputName(i));
216+
inst.replaceAllUsesWith(
217+
newInst.getResults().slice(0, inst.getNumResults()));
218+
inst.erase();
219+
inst = newInst;
220+
}
221+
222+
llvm::StringMap<mlir::Value> instOMap;
223+
llvm::StringMap<mlir::Value> instIMap;
224+
assert(instMod.getOutputNames().size() == inst.getResults().size());
225+
for (auto [oname, oval] :
226+
llvm::zip(instMod.getOutputNames(), inst.getResults()))
227+
instOMap[llvm::cast<mlir::StringAttr>(oname).getValue()] = oval;
228+
assert(instMod.getInputNames().size() == inst.getInputs().size());
229+
for (auto [iname, ival] :
230+
llvm::zip(instMod.getInputNames(), inst.getInputs()))
231+
instIMap[llvm::cast<mlir::StringAttr>(iname).getValue()] = ival;
232+
233+
// Get outer expunged descendent first because it may modify the map and
234+
// invalidate iterators.
235+
auto &outerExpDesc = getOuterExpDesc();
236+
auto instExpDesc = expungedDescendents.find(inst.getModuleName());
237+
238+
if (inst.getModuleName() == expunging) {
239+
// Handle the directly expunged module
240+
// input maps also useful for directly expunged instance
241+
242+
auto singletonPath = pathFactory.create(instName);
243+
244+
auto designatedPrefix =
245+
designatedPrefixes.find({processing.getName(), singletonPath});
246+
std::string prefix = designatedPrefix != designatedPrefixes.end()
247+
? designatedPrefix->getSecond().str()
248+
: (instName + "__").str();
249+
250+
// Port name collision is still possible, but current relying on MLIR
251+
// to automatically rename input arguments.
252+
// TODO: name collision detect
253+
254+
createPortsOn(
255+
processing, prefix,
256+
[&](circt::hw::ModulePort port) {
257+
// Generate output for outer module, so input for us
258+
return instIMap.at(port.name);
259+
},
260+
[&](circt::hw::ModulePort port, mlir::Value val) {
261+
// Generated input for outer module, replace inst results
262+
assert(instOMap.contains(port.name));
263+
instOMap[port.name].replaceAllUsesWith(val);
264+
});
265+
266+
outerExpDesc.emplace_back(singletonPath, prefix);
267+
268+
assert(instExpDesc == expungedDescendents.end() ||
269+
instExpDesc->getSecond().size() == 0);
270+
inst.erase();
271+
} else if (instExpDesc != expungedDescendents.end()) {
272+
// Handle all transitive descendents
273+
if (instExpDesc->second.size() == 0)
274+
return;
275+
llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
276+
for (const auto &exp : instExpDesc->second) {
277+
auto newPath = pathFactory.add(instName, exp.first);
278+
auto designatedPrefix =
279+
designatedPrefixes.find({processing.getName(), newPath});
280+
std::string prefix = designatedPrefix != designatedPrefixes.end()
281+
? designatedPrefix->getSecond().str()
282+
: defaultPrefix(newPath);
283+
284+
// TODO: name collision detect
285+
286+
createPortsOn(
287+
processing, prefix,
288+
[&](circt::hw::ModulePort port) {
289+
// Generate output for outer module, directly forward from
290+
// inner inst
291+
return instOMap.at((exp.second + port.name.getValue()).str());
292+
},
293+
[&](circt::hw::ModulePort port, mlir::Value val) {
294+
// Generated input for outer module, replace inst results.
295+
// The operand in question has to be an backedge
296+
auto in =
297+
instIMap.at((exp.second + port.name.getValue()).str());
298+
auto inDef = in.getDefiningOp();
299+
assert(llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
300+
in.replaceAllUsesWith(val);
301+
inDef->erase();
302+
});
303+
304+
outerExpDesc.emplace_back(newPath, prefix);
305+
}
306+
}
307+
});
308+
}
309+
}
310+
}
311+
312+
std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass() {
313+
return std::make_unique<HWExpungeModulePass>();
314+
}
+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include "circt/Dialect/HW/HWOps.h"
2+
#include "circt/Dialect/HW/HWPasses.h"
3+
#include "circt/Dialect/HW/HWTypes.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "llvm/ADT/DenseMap.h"
6+
#include "llvm/Support/Regex.h"
7+
8+
namespace circt {
9+
namespace hw {
10+
#define GEN_PASS_DEF_HWTREESHAKE
11+
#include "circt/Dialect/HW/Passes.h.inc"
12+
} // namespace hw
13+
} // namespace circt
14+
15+
struct HWTreeShakePass : circt::hw::impl::HWTreeShakeBase<HWTreeShakePass> {
16+
void runOnOperation() override;
17+
};
18+
19+
void HWTreeShakePass::runOnOperation() {
20+
auto root = getOperation();
21+
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
22+
root.walk(
23+
[&](circt::hw::HWModuleLike mod) { allModules[mod.getName()] = mod; });
24+
25+
llvm::DenseSet<circt::hw::HWModuleLike> visited;
26+
auto visit = [&allModules, &visited](auto &self,
27+
circt::hw::HWModuleLike mod) -> void {
28+
if (visited.contains(mod))
29+
return;
30+
visited.insert(mod);
31+
mod.walk([&](circt::hw::InstanceOp inst) {
32+
auto modName = inst.getModuleName();
33+
self(self, allModules.at(modName));
34+
});
35+
};
36+
37+
for (const auto &kept : keep) {
38+
auto lookup = allModules.find(kept);
39+
if (lookup == allModules.end())
40+
continue; // Silently ignore missing modules
41+
visit(visit, lookup->getSecond());
42+
}
43+
44+
for (auto &mod : allModules) {
45+
if (!visited.contains(mod.getSecond())) {
46+
mod.getSecond()->remove();
47+
}
48+
}
49+
}
50+
51+
std::unique_ptr<mlir::Pass> circt::hw::createHWTreeShakePass() {
52+
return std::make_unique<HWTreeShakePass>();
53+
}

‎test/Dialect/HW/expunge-module.mlir

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}})" %s | FileCheck %s --check-prefixes FOO,BAR,BAZ,COMMON
2+
// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}},hw-tree-shake{keep=foo})" %s | FileCheck %s --check-prefixes FOO,BAR,COMMON
3+
// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}},hw-tree-shake{keep=baz})" %s | FileCheck %s --check-prefixes BAZ,COMMON
4+
5+
module {
6+
hw.module @foo(in %bar1_baz1__out : i2, out test : i1) {
7+
%bar1.self_out = hw.instance "bar1" @bar(self_in: %0: i1) -> (self_out: i1)
8+
%bar2.self_out = hw.instance "bar2" @bar(self_in: %bar1.self_out: i1) -> (self_out: i1)
9+
%0 = comb.extract %bar1_baz1__out from 0 : (i2) -> i1
10+
hw.output %bar2.self_out : i1
11+
}
12+
hw.module private @bar(in %self_in : i1, out self_out : i1) {
13+
%baz1.out = hw.instance "baz1" @baz(in: %self_in: i1) -> (out: i1)
14+
%baz2.out = hw.instance "baz2" @baz(in: %baz1.out: i1) -> (out: i1)
15+
hw.output %baz2.out : i1
16+
}
17+
hw.module private @baz(in %in : i1, out out : i1) {
18+
hw.output %in : i1
19+
}
20+
}
21+
22+
// COMMON: module {
23+
// FOO-NEXT: hw.module @foo(in %bar1_baz1__out : i2, in %bar1_baz1__out_0 : i1, in %bar1_baz2__out : i1, in %bar2_baz1__out : i1, in %meow_out : i1, out test : i1, out bar1_baz1__in : i1, out bar1_baz2__in : i1, out bar2_baz1__in : i1, out meow_in : i1) {
24+
// FOO-NEXT: %bar1.self_out, %bar1.nya_in, %bar1.baz2__in = hw.instance "bar1" @bar(self_in: %0: i1, nya_out: %bar1_baz1__out_0: i1, baz2__out: %bar1_baz2__out: i1) -> (self_out: i1, nya_in: i1, baz2__in: i1)
25+
// FOO-NEXT: %bar2.self_out, %bar2.nya_in, %bar2.baz2__in = hw.instance "bar2" @bar(self_in: %bar1.self_out: i1, nya_out: %bar2_baz1__out: i1, baz2__out: %meow_out: i1) -> (self_out: i1, nya_in: i1, baz2__in: i1)
26+
// FOO-NEXT: %0 = comb.extract %bar1_baz1__out from 0 : (i2) -> i1
27+
// FOO-NEXT: hw.output %bar2.self_out, %bar1.nya_in, %bar1.baz2__in, %bar2.nya_in, %bar2.baz2__in : i1, i1, i1, i1, i1
28+
// FOO-NEXT: }
29+
// BAR-NEXT: hw.module private @bar(in %self_in : i1, in %nya_out : i1, in %baz2__out : i1, out self_out : i1, out nya_in : i1, out baz2__in : i1) {
30+
// BAR-NEXT: hw.output %baz2__out, %self_in, %nya_out : i1, i1, i1
31+
// BAR-NEXT: }
32+
// BAZ-NEXT: hw.module private @baz(in %in : i1, out out : i1) {
33+
// BAZ-NEXT: hw.output %in : i1
34+
// BAZ-NEXT: }

0 commit comments

Comments
 (0)
Please sign in to comment.