Skip to content

Commit e00fc1f

Browse files
committed
[HW] Add Passes: hw-expunge-module, hw-tree-shake
1 parent b951ce7 commit e00fc1f

File tree

6 files changed

+438
-0
lines changed

6 files changed

+438
-0
lines changed

include/circt/Dialect/HW/HWPasses.h

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
3434
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
3535
std::unique_ptr<mlir::Pass> createFooWiresPass();
3636
std::unique_ptr<mlir::Pass> createHWAggregateToCombPass();
37+
std::unique_ptr<mlir::Pass> createHWExpungeModulePass();
38+
std::unique_ptr<mlir::Pass> createHWTreeShakePass();
3739

3840
/// Generate the code for registering passes.
3941
#define GEN_PASS_REGISTRATION

include/circt/Dialect/HW/Passes.td

+32
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,38 @@ def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> {
7575
let constructor = "circt::hw::createVerifyInnerRefNamespacePass()";
7676
}
7777

78+
def HWExpungeModule : Pass<"hw-expunge-module", "mlir::ModuleOp"> {
79+
let summary = "Remove module from the hierarchy, and recursively expose their ports to upper level.";
80+
let description = [{
81+
This pass removes a list of modules from the hierarchy on-by-one, recursively exposing their ports to upper level.
82+
The newly generated ports are by default named as <instance_path>__<port_name>. During a naming conflict, an warning would be genreated,
83+
and an random suffix would be added to the <instance_path> part.
84+
85+
For each given (transitive) parent module, the prefix can alternatively be specified by option instead of using the instance path.
86+
}];
87+
let constructor = "circt::hw::createHWExpungeModulePass()";
88+
89+
let options = [
90+
ListOption<"modules", "modules", "std::string",
91+
"Comma separated list of module names to be removed from the hierarchy.">,
92+
ListOption<"portPrefixes", "port-prefixes", "std::string",
93+
"Specify the prefix for ports of a given parent module's expunged childen. Each specification is formatted as <module>:<instance-path>=<prefix>. Only affect the top-most level module of the instance path.">,
94+
];
95+
}
96+
97+
def HWTreeShake : Pass<"hw-tree-shake", "mlir::ModuleOp"> {
98+
let summary = "Remove unused modules.";
99+
let description = [{
100+
This pass removes all modules besides a specified list of modules and their transitive dependencies.
101+
}];
102+
let constructor = "circt::hw::createHWTreeShakePass()";
103+
104+
let options = [
105+
ListOption<"keep", "keep", "std::string",
106+
"Comma separated list of module names to be kept.">,
107+
];
108+
}
109+
78110
/**
79111
* Tutorial Pass, doesn't do anything interesting
80112
*/

lib/Dialect/HW/Transforms/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ add_circt_dialect_library(CIRCTHWTransforms
77
VerifyInnerRefNamespace.cpp
88
FlattenModules.cpp
99
FooWires.cpp
10+
HWExpungeModule.cpp
11+
HWTreeShake.cpp
1012

1113
DEPENDS
1214
CIRCTHWTransformsIncGen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
#include "circt/Dialect/HW/HWInstanceGraph.h"
2+
#include "circt/Dialect/HW/HWOps.h"
3+
#include "circt/Dialect/HW/HWPasses.h"
4+
#include "circt/Dialect/HW/HWTypes.h"
5+
#include "mlir/Pass/Pass.h"
6+
#include "llvm/ADT/DenseMap.h"
7+
#include "llvm/ADT/ImmutableList.h"
8+
#include "llvm/ADT/PostOrderIterator.h"
9+
#include "llvm/Support/Regex.h"
10+
#include <numeric>
11+
12+
namespace circt {
13+
namespace hw {
14+
#define GEN_PASS_DEF_HWEXPUNGEMODULE
15+
#include "circt/Dialect/HW/Passes.h.inc"
16+
} // namespace hw
17+
} // namespace circt
18+
19+
namespace {
20+
struct HWExpungeModulePass
21+
: circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
22+
void runOnOperation() override;
23+
};
24+
25+
struct InstPathSeg {
26+
llvm::StringRef seg;
27+
28+
InstPathSeg(llvm::StringRef seg) : seg(seg) {}
29+
const llvm::StringRef &getSeg() const { return seg; }
30+
operator llvm::StringRef() const { return seg; }
31+
32+
void Profile(llvm::FoldingSetNodeID &ID) const { ID.AddString(seg); }
33+
};
34+
using InstPath = llvm::ImmutableList<InstPathSeg>;
35+
std::string defaultPrefix(InstPath path) {
36+
std::string accum;
37+
while (!path.isEmpty()) {
38+
accum += path.getHead().getSeg();
39+
accum += "_";
40+
path = path.getTail();
41+
}
42+
accum += "_";
43+
return std::move(accum);
44+
}
45+
46+
// The regex for port prefix specification
47+
// "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
48+
// Unfortunately, the LLVM Regex cannot capture repeating capture groups, so
49+
// manually parse the spec This parser may accept identifiers with invalid
50+
// characters
51+
52+
std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>,
53+
std::string>
54+
parsePrefixSpec(llvm::StringRef in, InstPath::Factory &listFac) {
55+
auto [l, r] = in.split("=");
56+
if (r == "")
57+
return "No '=' found in input";
58+
auto [ll, lr] = l.split(":");
59+
if (lr == "")
60+
return "No ':' found before '='";
61+
llvm::SmallVector<llvm::StringRef, 4> segs;
62+
while (lr != "") {
63+
auto [seg, rest] = lr.split(".");
64+
segs.push_back(seg);
65+
lr = rest;
66+
}
67+
InstPath path;
68+
for (auto &seg : llvm::reverse(segs))
69+
path = listFac.add(seg, path);
70+
return std::make_tuple(ll, path, r);
71+
}
72+
} // namespace
73+
74+
void HWExpungeModulePass::runOnOperation() {
75+
auto root = getOperation();
76+
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
77+
root.walk(
78+
[&](circt::hw::HWModuleLike mod) { allModules[mod.getName()] = mod; });
79+
80+
// The instance graph. We only use this graph to traverse the hierarchy in
81+
// post order. The order does not change throught out the operation, onlygets
82+
// weakened, but still valid. So we keep this cached instance graph throughout
83+
// the pass.
84+
auto &instanceGraph = getAnalysis<circt::hw::InstanceGraph>();
85+
86+
// Instance path.
87+
InstPath::Factory pathFactory;
88+
89+
// Process port prefix specifications
90+
// (Module name, Instance path) -> Prefix
91+
llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef>
92+
designatedPrefixes;
93+
bool containsFailure = false;
94+
for (const auto &raw : portPrefixes) {
95+
auto matched = parsePrefixSpec(raw, pathFactory);
96+
if (std::holds_alternative<std::string>(matched)) {
97+
llvm::errs() << "Invalid port prefix specification: " << raw << "\n";
98+
llvm::errs() << "Error: " << std::get<std::string>(matched) << "\n";
99+
containsFailure = true;
100+
continue;
101+
}
102+
103+
auto [module, path, prefix] =
104+
std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(
105+
matched);
106+
if (!allModules.contains(module)) {
107+
llvm::errs() << "Module not found in port prefix specification: "
108+
<< module << "\n";
109+
llvm::errs() << "From specification: " << raw << "\n";
110+
containsFailure = true;
111+
continue;
112+
}
113+
114+
// Skip checking instance paths' existence. Non-existent paths are ignored
115+
designatedPrefixes.insert({{module, path}, prefix});
116+
}
117+
118+
if (containsFailure)
119+
return signalPassFailure();
120+
121+
// Instance path * prefix name
122+
using ReplacedDescendent = std::pair<InstPath, std::string>;
123+
// This map holds the expunged descendents of a module
124+
llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>>
125+
expungedDescendents;
126+
for (auto &expunging : this->modules) {
127+
// Clear expungedDescendents
128+
for (auto &it : expungedDescendents)
129+
it.getSecond().clear();
130+
131+
auto expungingMod = allModules.lookup(expunging);
132+
if (!expungingMod)
133+
continue; // Ignored missing modules
134+
auto expungingModTy = expungingMod.getHWModuleType();
135+
auto expungingModPorts = expungingModTy.getPorts();
136+
137+
auto createPortsOn = [&expungingModPorts](circt::hw::HWModuleOp mod,
138+
const std::string &prefix,
139+
auto genOutput, auto emitInput) {
140+
mlir::OpBuilder builder(mod);
141+
// Create ports using *REVERSE* direction of their definitions
142+
for (auto &port : expungingModPorts) {
143+
auto defaultName = prefix + port.name.getValue();
144+
auto finalName = defaultName;
145+
if (port.dir == circt::hw::PortInfo::Input) {
146+
auto val = genOutput(port);
147+
assert(val.getType() == port.type);
148+
mod.appendOutput(finalName, val);
149+
} else if (port.dir == circt::hw::PortInfo::Output) {
150+
auto [_, arg] = mod.appendInput(finalName, port.type);
151+
emitInput(port, arg);
152+
}
153+
}
154+
};
155+
156+
for (auto &instGraphNode : llvm::post_order(&instanceGraph)) {
157+
// Skip extmodule and intmodule because they cannot contain anything
158+
circt::hw::HWModuleOp processing =
159+
llvm::dyn_cast_if_present<circt::hw::HWModuleOp>(
160+
instGraphNode->getModule().getOperation());
161+
if (!processing)
162+
continue;
163+
164+
std::optional<decltype(expungedDescendents.lookup("")) *>
165+
outerExpDescHold = {};
166+
auto getOuterExpDesc = [&]() -> decltype(**outerExpDescHold) {
167+
if (!outerExpDescHold.has_value())
168+
outerExpDescHold = {
169+
&expungedDescendents.insert({processing.getName(), {}})
170+
.first->getSecond()};
171+
return **outerExpDescHold;
172+
};
173+
174+
mlir::OpBuilder outerBuilder(processing);
175+
176+
processing.walk([&](circt::hw::InstanceOp inst) {
177+
auto instName = inst.getInstanceName();
178+
auto instMod = allModules.lookup(inst.getModuleName());
179+
180+
if (instMod.getOutputNames().size() != inst.getResults().size() ||
181+
instMod.getNumInputPorts() != inst.getInputs().size()) {
182+
// Module have been modified during this pass, create new instances
183+
assert(instMod.getNumOutputPorts() >= inst.getResults().size());
184+
assert(instMod.getNumInputPorts() >= inst.getInputs().size());
185+
186+
auto instModInTypes = instMod.getInputTypes();
187+
188+
llvm::SmallVector<mlir::Value> newInputs;
189+
newInputs.reserve(instMod.getNumInputPorts());
190+
191+
outerBuilder.setInsertionPointAfter(inst);
192+
193+
// Appended inputs are at the end of the input list
194+
for (size_t i = 0; i < instMod.getNumInputPorts(); ++i) {
195+
mlir::Value input;
196+
if (i < inst.getNumInputPorts()) {
197+
input = inst.getInputs()[i];
198+
if (auto existingName = inst.getInputName(i))
199+
assert(existingName == instMod.getInputName(i));
200+
} else {
201+
input =
202+
outerBuilder
203+
.create<mlir::UnrealizedConversionCastOp>(
204+
inst.getLoc(), instModInTypes[i], mlir::ValueRange{})
205+
.getResult(0);
206+
}
207+
newInputs.push_back(input);
208+
}
209+
210+
auto newInst = outerBuilder.create<circt::hw::InstanceOp>(
211+
inst.getLoc(), instMod, inst.getInstanceNameAttr(), newInputs,
212+
inst.getParameters(),
213+
inst.getInnerSym().value_or<circt::hw::InnerSymAttr>({}));
214+
215+
for (size_t i = 0; i < inst.getNumResults(); ++i)
216+
assert(inst.getOutputName(i) == instMod.getOutputName(i));
217+
inst.replaceAllUsesWith(
218+
newInst.getResults().slice(0, inst.getNumResults()));
219+
inst.erase();
220+
inst = newInst;
221+
}
222+
223+
llvm::StringMap<mlir::Value> instOMap;
224+
llvm::StringMap<mlir::Value> instIMap;
225+
assert(instMod.getOutputNames().size() == inst.getResults().size());
226+
for (auto [oname, oval] :
227+
llvm::zip(instMod.getOutputNames(), inst.getResults()))
228+
instOMap[llvm::cast<mlir::StringAttr>(oname).getValue()] = oval;
229+
assert(instMod.getInputNames().size() == inst.getInputs().size());
230+
for (auto [iname, ival] :
231+
llvm::zip(instMod.getInputNames(), inst.getInputs()))
232+
instIMap[llvm::cast<mlir::StringAttr>(iname).getValue()] = ival;
233+
234+
// Get outer expunged descendent first because it may modify the map and
235+
// invalidate iterators.
236+
auto &outerExpDesc = getOuterExpDesc();
237+
auto instExpDesc = expungedDescendents.find(inst.getModuleName());
238+
239+
if (inst.getModuleName() == expunging) {
240+
// Handle the directly expunged module
241+
// input maps also useful for directly expunged instance
242+
243+
auto singletonPath = pathFactory.create(instName);
244+
245+
auto designatedPrefix =
246+
designatedPrefixes.find({processing.getName(), singletonPath});
247+
std::string prefix = designatedPrefix != designatedPrefixes.end()
248+
? designatedPrefix->getSecond().str()
249+
: (instName + "__").str();
250+
251+
// Port name collision is still possible, but current relying on MLIR
252+
// to automatically rename input arguments.
253+
// TODO: name collision detect
254+
255+
createPortsOn(
256+
processing, prefix,
257+
[&](circt::hw::ModulePort port) {
258+
// Generate output for outer module, so input for us
259+
return instIMap.at(port.name);
260+
},
261+
[&](circt::hw::ModulePort port, mlir::Value val) {
262+
// Generated input for outer module, replace inst results
263+
assert(instOMap.contains(port.name));
264+
instOMap[port.name].replaceAllUsesWith(val);
265+
});
266+
267+
outerExpDesc.emplace_back(singletonPath, prefix);
268+
269+
assert(instExpDesc == expungedDescendents.end() ||
270+
instExpDesc->getSecond().size() == 0);
271+
inst.erase();
272+
} else if (instExpDesc != expungedDescendents.end()) {
273+
// Handle all transitive descendents
274+
if (instExpDesc->second.size() == 0)
275+
return;
276+
llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
277+
for (const auto &exp : instExpDesc->second) {
278+
auto newPath = pathFactory.add(instName, exp.first);
279+
auto designatedPrefix =
280+
designatedPrefixes.find({processing.getName(), newPath});
281+
std::string prefix = designatedPrefix != designatedPrefixes.end()
282+
? designatedPrefix->getSecond().str()
283+
: defaultPrefix(newPath);
284+
285+
// TODO: name collision detect
286+
287+
createPortsOn(
288+
processing, prefix,
289+
[&](circt::hw::ModulePort port) {
290+
// Generate output for outer module, directly forward from
291+
// inner inst
292+
return instOMap.at((exp.second + port.name.getValue()).str());
293+
},
294+
[&](circt::hw::ModulePort port, mlir::Value val) {
295+
// Generated input for outer module, replace inst results.
296+
// The operand in question has to be an backedge
297+
auto in =
298+
instIMap.at((exp.second + port.name.getValue()).str());
299+
auto inDef = in.getDefiningOp();
300+
assert(llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
301+
in.replaceAllUsesWith(val);
302+
inDef->erase();
303+
});
304+
305+
outerExpDesc.emplace_back(newPath, prefix);
306+
}
307+
}
308+
});
309+
}
310+
}
311+
}
312+
313+
std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass() {
314+
return std::make_unique<HWExpungeModulePass>();
315+
}

0 commit comments

Comments
 (0)