Skip to content

Commit 10489b2

Browse files
committed
[HW] Add Passes: hw-expunge-module, hw-tree-shake
1 parent 0d35c61 commit 10489b2

File tree

6 files changed

+453
-0
lines changed

6 files changed

+453
-0
lines changed

include/circt/Dialect/HW/HWPasses.h

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
3333
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
3434
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
3535
std::unique_ptr<mlir::Pass> createFooWiresPass();
36+
std::unique_ptr<mlir::Pass> createHWExpungeModulePass();
37+
std::unique_ptr<mlir::Pass> createHWTreeShakePass();
3638

3739
/// Generate the code for registering passes.
3840
#define GEN_PASS_REGISTRATION

include/circt/Dialect/HW/Passes.td

+32
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,38 @@ def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> {
7575
let constructor = "circt::hw::createVerifyInnerRefNamespacePass()";
7676
}
7777

78+
def HWExpungeModule : Pass<"hw-expunge-module", "mlir::ModuleOp"> {
79+
let summary = "Remove module from the hierarchy, and recursively expose their ports to upper level.";
80+
let description = [{
81+
This pass removes a list of modules from the hierarchy on-by-one, recursively exposing their ports to upper level.
82+
The newly generated ports are by default named as <instance_path>__<port_name>. During a naming conflict, an warning would be genreated,
83+
and an random suffix would be added to the <instance_path> part.
84+
85+
For each given (transitive) parent module, the prefix can alternatively be specified by option instead of using the instance path.
86+
}];
87+
let constructor = "circt::hw::createHWExpungeModulePass()";
88+
89+
let options = [
90+
ListOption<"modules", "modules", "std::string",
91+
"Comma separated list of module names to be removed from the hierarchy.">,
92+
ListOption<"portPrefixes", "port-prefixes", "std::string",
93+
"Specify the prefix for ports of a given parent module's expunged childen. Each specification is formatted as <module>:<instance-path>=<prefix>. Only affect the top-most level module of the instance path.">,
94+
];
95+
}
96+
97+
def HWTreeShake : Pass<"hw-tree-shake", "mlir::ModuleOp"> {
98+
let summary = "Remove unused modules.";
99+
let description = [{
100+
This pass removes all modules besides a specified list of modules and their transitive dependencies.
101+
}];
102+
let constructor = "circt::hw::createHWTreeShakePass()";
103+
104+
let options = [
105+
ListOption<"keep", "keep", "std::string",
106+
"Comma separated list of module names to be kept.">,
107+
];
108+
}
109+
78110
/**
79111
* Tutorial Pass, doesn't do anything interesting
80112
*/

lib/Dialect/HW/Transforms/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ add_circt_dialect_library(CIRCTHWTransforms
66
VerifyInnerRefNamespace.cpp
77
FlattenModules.cpp
88
FooWires.cpp
9+
HWExpungeModule.cpp
10+
HWTreeShake.cpp
911

1012
DEPENDS
1113
CIRCTHWTransformsIncGen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
#include "circt/Dialect/HW/HWOps.h"
2+
#include "circt/Dialect/HW/HWPasses.h"
3+
#include "circt/Dialect/HW/HWTypes.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "llvm/ADT/DenseMap.h"
6+
#include "llvm/ADT/ImmutableList.h"
7+
#include "llvm/Support/Regex.h"
8+
9+
namespace circt {
10+
namespace hw {
11+
#define GEN_PASS_DEF_HWEXPUNGEMODULE
12+
#include "circt/Dialect/HW/Passes.h.inc"
13+
} // namespace hw
14+
} // namespace circt
15+
16+
namespace {
17+
struct HWExpungeModulePass
18+
: circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
19+
void runOnOperation() override;
20+
};
21+
22+
struct InstPathSeg {
23+
llvm::StringRef seg;
24+
25+
InstPathSeg(llvm::StringRef seg) : seg(seg) {}
26+
const llvm::StringRef &getSeg() const { return seg; }
27+
operator llvm::StringRef() const { return seg; }
28+
29+
void Profile(llvm::FoldingSetNodeID &ID) const { ID.AddString(seg); }
30+
};
31+
using InstPath = llvm::ImmutableList<InstPathSeg>;
32+
std::string defaultPrefix(InstPath path) {
33+
std::string accum;
34+
while (!path.isEmpty()) {
35+
accum += path.getHead().getSeg();
36+
accum += "_";
37+
path = path.getTail();
38+
}
39+
accum += "_";
40+
return std::move(accum);
41+
}
42+
43+
// The regex for port prefix specification
44+
// "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
45+
// Unfortunately, the LLVM Regex cannot capture repeating capture groups, so
46+
// manually parse the spec This parser may accept identifiers with invalid
47+
// characters
48+
49+
std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>,
50+
std::string>
51+
parsePrefixSpec(llvm::StringRef in, InstPath::Factory &listFac) {
52+
auto [l, r] = in.split("=");
53+
if (r == "")
54+
return "No '=' found in input";
55+
auto [ll, lr] = l.split(":");
56+
if (lr == "")
57+
return "No ':' found before '='";
58+
llvm::SmallVector<llvm::StringRef, 4> segs;
59+
while (lr != "") {
60+
auto [seg, rest] = lr.split(".");
61+
segs.push_back(seg);
62+
lr = rest;
63+
}
64+
InstPath path;
65+
for (auto &seg : llvm::reverse(segs))
66+
path = listFac.add(seg, path);
67+
return std::make_tuple(ll, path, r);
68+
}
69+
} // namespace
70+
71+
void HWExpungeModulePass::runOnOperation() {
72+
auto root = getOperation();
73+
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
74+
root.walk(
75+
[&](circt::hw::HWModuleLike mod) { allModules[mod.getName()] = mod; });
76+
// Reverse-topo order generated by module instances.
77+
// The topo sorted order does not change throught out the operation. It only
78+
// gets weakened, but still valid.
79+
llvm::SmallVector<circt::hw::HWModuleLike> revTopo;
80+
llvm::DenseSet<mlir::StringRef> visited;
81+
auto visit = [&allModules, &visited,
82+
&revTopo](auto &self, circt::hw::HWModuleLike mod) -> void {
83+
if (visited.contains(mod.getName()))
84+
return;
85+
visited.insert(mod.getName());
86+
87+
mod.walk([&](circt::hw::InstanceOp inst) {
88+
auto instModName = inst.getModuleName();
89+
auto instMod = allModules.lookup(instModName);
90+
if (!instMod)
91+
inst.emitError("Unknown module ") << instModName;
92+
self(self, instMod);
93+
});
94+
95+
// Reverse topo order, insert on stack pop
96+
revTopo.push_back(mod);
97+
};
98+
for (auto &[_, mod] : allModules)
99+
visit(visit, mod);
100+
assert(revTopo.size() == allModules.size());
101+
102+
// Instance path.
103+
InstPath::Factory pathFactory;
104+
105+
// Process port prefix specifications
106+
// (Module name, Instance path) -> Prefix
107+
llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef>
108+
designatedPrefixes;
109+
bool containsFailure = false;
110+
for (const auto &raw : portPrefixes) {
111+
auto matched = parsePrefixSpec(raw, pathFactory);
112+
if (std::holds_alternative<std::string>(matched)) {
113+
llvm::errs() << "Invalid port prefix specification: " << raw << "\n";
114+
llvm::errs() << "Error: " << std::get<std::string>(matched) << "\n";
115+
containsFailure = true;
116+
continue;
117+
}
118+
119+
auto [module, path, prefix] =
120+
std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(
121+
matched);
122+
if (!allModules.contains(module)) {
123+
llvm::errs() << "Module not found in port prefix specification: "
124+
<< module << "\n";
125+
llvm::errs() << "From specification: " << raw << "\n";
126+
containsFailure = true;
127+
continue;
128+
}
129+
130+
// Skip checking instance paths' existence. Non-existent paths are ignored
131+
designatedPrefixes.insert({{module, path}, prefix});
132+
}
133+
134+
if (containsFailure)
135+
return signalPassFailure();
136+
137+
// Instance path * prefix name
138+
using ReplacedDescendent = std::pair<InstPath, std::string>;
139+
// This map holds the expunged descendents of a module
140+
llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>>
141+
expungedDescendents;
142+
for (auto &expunging : this->modules) {
143+
// Clear expungedDescendents
144+
for (auto &it : expungedDescendents)
145+
it.getSecond().clear();
146+
147+
auto expungingMod = allModules.lookup(expunging);
148+
if (!expungingMod)
149+
continue; // Ignored missing modules
150+
auto expungingModTy = expungingMod.getHWModuleType();
151+
auto expungingModPorts = expungingModTy.getPorts();
152+
153+
auto createPortsOn = [&expungingModPorts](circt::hw::HWModuleOp mod,
154+
const std::string &prefix,
155+
auto genOutput, auto emitInput) {
156+
mlir::OpBuilder builder(mod);
157+
// Create ports using *REVERSE* direction of their definitions
158+
for (auto &port : expungingModPorts) {
159+
auto defaultName = prefix + port.name.getValue();
160+
auto finalName = defaultName;
161+
if (port.dir == circt::hw::PortInfo::Input) {
162+
auto val = genOutput(port);
163+
assert(val.getType() == port.type);
164+
mod.appendOutput(finalName, val);
165+
} else if (port.dir == circt::hw::PortInfo::Output) {
166+
auto [_, arg] = mod.appendInput(finalName, port.type);
167+
emitInput(port, arg);
168+
}
169+
}
170+
};
171+
172+
for (auto &processingRaw : revTopo) {
173+
// Skip extmodule and intmodule because they cannot contain anything
174+
if (!llvm::isa<circt::hw::HWModuleOp>(processingRaw))
175+
continue;
176+
circt::hw::HWModuleOp processing =
177+
llvm::cast<circt::hw::HWModuleOp>(processingRaw);
178+
179+
std::optional<decltype(expungedDescendents.lookup("")) *>
180+
outerExpDescHold = {};
181+
auto getOuterExpDesc = [&]() -> decltype(**outerExpDescHold) {
182+
if (!outerExpDescHold.has_value())
183+
outerExpDescHold = {
184+
&expungedDescendents.insert({processing.getName(), {}})
185+
.first->getSecond()};
186+
return **outerExpDescHold;
187+
};
188+
189+
mlir::OpBuilder outerBuilder(processing);
190+
191+
processing.walk([&](circt::hw::InstanceOp inst) {
192+
auto instName = inst.getInstanceName();
193+
auto instMod = allModules.lookup(inst.getModuleName());
194+
195+
if (instMod.getOutputNames().size() != inst.getResults().size() ||
196+
instMod.getNumInputPorts() != inst.getInputs().size()) {
197+
// Module have been modified during this pass, create new instances
198+
assert(instMod.getNumOutputPorts() >= inst.getResults().size());
199+
assert(instMod.getNumInputPorts() >= inst.getInputs().size());
200+
201+
auto instModInTypes = instMod.getInputTypes();
202+
203+
llvm::SmallVector<mlir::Value> newInputs;
204+
newInputs.reserve(instMod.getNumInputPorts());
205+
206+
outerBuilder.setInsertionPointAfter(inst);
207+
208+
// Appended inputs are at the end of the input list
209+
for (size_t i = 0; i < instMod.getNumInputPorts(); ++i) {
210+
mlir::Value input;
211+
if (i < inst.getNumInputPorts()) {
212+
input = inst.getInputs()[i];
213+
if (auto existingName = inst.getInputName(i))
214+
assert(existingName == instMod.getInputName(i));
215+
} else {
216+
input =
217+
outerBuilder
218+
.create<mlir::UnrealizedConversionCastOp>(
219+
inst.getLoc(), instModInTypes[i], mlir::ValueRange{})
220+
.getResult(0);
221+
}
222+
newInputs.push_back(input);
223+
}
224+
225+
auto newInst = outerBuilder.create<circt::hw::InstanceOp>(
226+
inst.getLoc(), instMod, inst.getInstanceNameAttr(), newInputs,
227+
inst.getParameters(),
228+
inst.getInnerSym().value_or<circt::hw::InnerSymAttr>({}));
229+
230+
for (size_t i = 0; i < inst.getNumResults(); ++i)
231+
assert(inst.getOutputName(i) == instMod.getOutputName(i));
232+
inst.replaceAllUsesWith(
233+
newInst.getResults().slice(0, inst.getNumResults()));
234+
inst.erase();
235+
inst = newInst;
236+
}
237+
238+
llvm::DenseMap<llvm::StringRef, mlir::Value> instOMap;
239+
llvm::DenseMap<llvm::StringRef, mlir::Value> instIMap;
240+
assert(instMod.getOutputNames().size() == inst.getResults().size());
241+
for (auto [oname, oval] :
242+
llvm::zip(instMod.getOutputNames(), inst.getResults()))
243+
instOMap[llvm::cast<mlir::StringAttr>(oname).getValue()] = oval;
244+
assert(instMod.getInputNames().size() == inst.getInputs().size());
245+
for (auto [iname, ival] :
246+
llvm::zip(instMod.getInputNames(), inst.getInputs()))
247+
instIMap[llvm::cast<mlir::StringAttr>(iname).getValue()] = ival;
248+
249+
// Get outer expunged descendent first because it may modify the map and
250+
// invalidate iterators.
251+
auto &outerExpDesc = getOuterExpDesc();
252+
auto instExpDesc = expungedDescendents.find(inst.getModuleName());
253+
254+
if (inst.getModuleName() == expunging) {
255+
// Handle the directly expunged module
256+
// input maps also useful for directly expunged instance
257+
258+
auto singletonPath = pathFactory.create(instName);
259+
260+
auto designatedPrefix =
261+
designatedPrefixes.find({processing.getName(), singletonPath});
262+
std::string prefix = designatedPrefix != designatedPrefixes.end()
263+
? designatedPrefix->getSecond().str()
264+
: (instName + "__").str();
265+
266+
// Port name collision is still possible, but current relying on MLIR
267+
// to automatically rename input arguments.
268+
// TODO: name collision detect
269+
270+
createPortsOn(
271+
processing, prefix,
272+
[&](circt::hw::ModulePort port) {
273+
// Generate output for outer module, so input for us
274+
return instIMap.at(port.name);
275+
},
276+
[&](circt::hw::ModulePort port, mlir::Value val) {
277+
// Generated input for outer module, replace inst results
278+
assert(instOMap.contains(port.name));
279+
instOMap[port.name].replaceAllUsesWith(val);
280+
});
281+
282+
outerExpDesc.emplace_back(singletonPath, prefix);
283+
284+
assert(instExpDesc == expungedDescendents.end() ||
285+
instExpDesc->getSecond().size() == 0);
286+
inst.erase();
287+
} else if (instExpDesc != expungedDescendents.end()) {
288+
// Handle all transitive descendents
289+
if (instExpDesc->second.size() == 0)
290+
return;
291+
llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
292+
for (const auto &exp : instExpDesc->second) {
293+
auto newPath = pathFactory.add(instName, exp.first);
294+
auto designatedPrefix =
295+
designatedPrefixes.find({processing.getName(), newPath});
296+
std::string prefix = designatedPrefix != designatedPrefixes.end()
297+
? designatedPrefix->getSecond().str()
298+
: defaultPrefix(newPath);
299+
300+
// TODO: name collision detect
301+
302+
createPortsOn(
303+
processing, prefix,
304+
[&](circt::hw::ModulePort port) {
305+
// Generate output for outer module, directly forward from
306+
// inner inst
307+
return instOMap.at((exp.second + port.name.getValue()).str());
308+
},
309+
[&](circt::hw::ModulePort port, mlir::Value val) {
310+
// Generated input for outer module, replace inst results.
311+
// The operand in question has to be an backedge
312+
auto in =
313+
instIMap.at((exp.second + port.name.getValue()).str());
314+
auto inDef = in.getDefiningOp();
315+
assert(llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
316+
in.replaceAllUsesWith(val);
317+
inDef->erase();
318+
});
319+
320+
outerExpDesc.emplace_back(newPath, prefix);
321+
}
322+
}
323+
});
324+
}
325+
}
326+
}
327+
328+
std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass() {
329+
return std::make_unique<HWExpungeModulePass>();
330+
}

0 commit comments

Comments
 (0)