Skip to content

Commit 1b1428a

Browse files
committed
[HW] Add Passes: hw-expunge-module, hw-tree-shake
1 parent 0d35c61 commit 1b1428a

File tree

6 files changed

+418
-0
lines changed

6 files changed

+418
-0
lines changed

include/circt/Dialect/HW/HWPasses.h

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
3333
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
3434
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
3535
std::unique_ptr<mlir::Pass> createFooWiresPass();
36+
std::unique_ptr<mlir::Pass> createHWExpungeModulePass();
37+
std::unique_ptr<mlir::Pass> createHWTreeShakePass();
3638

3739
/// Generate the code for registering passes.
3840
#define GEN_PASS_REGISTRATION

include/circt/Dialect/HW/Passes.td

+32
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,38 @@ def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> {
7575
let constructor = "circt::hw::createVerifyInnerRefNamespacePass()";
7676
}
7777

78+
def HWExpungeModule : Pass<"hw-expunge-module", "mlir::ModuleOp"> {
79+
let summary = "Remove module from the hierarchy, and recursively expose their ports to upper level.";
80+
let description = [{
81+
This pass removes a list of modules from the hierarchy on-by-one, recursively exposing their ports to upper level.
82+
The newly generated ports are by default named as <instance_path>__<port_name>. During a naming conflict, an warning would be genreated,
83+
and an random suffix would be added to the <instance_path> part.
84+
85+
For each given (transitive) parent module, the prefix can alternatively be specified by option instead of using the instance path.
86+
}];
87+
let constructor = "circt::hw::createHWExpungeModulePass()";
88+
89+
let options = [
90+
ListOption<"modules", "modules", "std::string",
91+
"Comma separated list of module names to be removed from the hierarchy.">,
92+
ListOption<"portPrefixes", "port-prefixes", "std::string",
93+
"Specify the prefix for ports of a given parent module's expunged childen. Each specification is formatted as <module>:<instance-path>=<prefix>. Only affect the top-most level module of the instance path.">,
94+
];
95+
}
96+
97+
def HWTreeShake : Pass<"hw-tree-shake", "mlir::ModuleOp"> {
98+
let summary = "Remove unused modules.";
99+
let description = [{
100+
This pass removes all modules besides a specified list of modules and their transitive dependencies.
101+
}];
102+
let constructor = "circt::hw::createHWTreeShakePass()";
103+
104+
let options = [
105+
ListOption<"keep", "keep", "std::string",
106+
"Comma separated list of module names to be kept.">,
107+
];
108+
}
109+
78110
/**
79111
* Tutorial Pass, doesn't do anything interesting
80112
*/

lib/Dialect/HW/Transforms/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ add_circt_dialect_library(CIRCTHWTransforms
66
VerifyInnerRefNamespace.cpp
77
FlattenModules.cpp
88
FooWires.cpp
9+
HWExpungeModule.cpp
10+
HWTreeShake.cpp
911

1012
DEPENDS
1113
CIRCTHWTransformsIncGen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
#include "circt/Dialect/HW/HWOps.h"
2+
#include "circt/Dialect/HW/HWPasses.h"
3+
#include "circt/Dialect/HW/HWTypes.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "llvm/ADT/DenseMap.h"
6+
#include "llvm/ADT/ImmutableList.h"
7+
#include "llvm/Support/Regex.h"
8+
9+
namespace circt {
10+
namespace hw {
11+
#define GEN_PASS_DEF_HWEXPUNGEMODULE
12+
#include "circt/Dialect/HW/Passes.h.inc"
13+
} // namespace hw
14+
} // namespace circt
15+
16+
namespace {
17+
struct HWExpungeModulePass : circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
18+
void runOnOperation() override;
19+
};
20+
21+
struct InstPathSeg {
22+
llvm::StringRef seg;
23+
24+
InstPathSeg(llvm::StringRef seg) : seg(seg) {}
25+
const llvm::StringRef& getSeg() const { return seg; }
26+
operator llvm::StringRef() const { return seg; }
27+
28+
void Profile(llvm::FoldingSetNodeID &ID) const {
29+
ID.AddString(seg);
30+
}
31+
};
32+
using InstPath = llvm::ImmutableList<InstPathSeg>;
33+
std::string defaultPrefix(InstPath path) {
34+
std::string accum;
35+
while(!path.isEmpty()) {
36+
accum += path.getHead().getSeg();
37+
accum += "_";
38+
path = path.getTail();
39+
}
40+
accum += "_";
41+
return std::move(accum);
42+
}
43+
44+
// The regex for port prefix specification
45+
// "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
46+
// Unfortunately, the LLVM Regex cannot capture repeating capture groups, so manually parse the spec
47+
// This parser may accept identifiers with invalid characters
48+
49+
std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>, std::string> parsePrefixSpec(llvm::StringRef in, InstPath::Factory &listFac) {
50+
auto [l, r] = in.split("=");
51+
if (r == "") return "No '=' found in input";
52+
auto [ll, lr] = l.split(":");
53+
if (lr == "") return "No ':' found before '='";
54+
llvm::SmallVector<llvm::StringRef, 4> segs;
55+
while(lr != "") {
56+
auto [seg, rest] = lr.split(".");
57+
segs.push_back(seg);
58+
lr = rest;
59+
}
60+
InstPath path;
61+
for(auto &seg : llvm::reverse(segs))
62+
path = listFac.add(seg, path);
63+
return std::make_tuple(ll, path, r);
64+
}
65+
} // namespace
66+
67+
void HWExpungeModulePass::runOnOperation() {
68+
auto root = getOperation();
69+
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
70+
root.walk([&](circt::hw::HWModuleLike mod) {
71+
allModules[mod.getName()] = mod;
72+
});
73+
// Reverse-topo order generated by module instances.
74+
// The topo sorted order does not change throught out the operation. It only gets weakened, but
75+
// still valid.
76+
llvm::SmallVector<circt::hw::HWModuleLike> revTopo;
77+
llvm::DenseSet<mlir::StringRef> visited;
78+
auto visit = [&allModules, &visited, &revTopo](auto &self, circt::hw::HWModuleLike mod) -> void {
79+
if(visited.contains(mod.getName())) return;
80+
visited.insert(mod.getName());
81+
82+
mod.walk([&](circt::hw::InstanceOp inst) {
83+
auto instModName = inst.getModuleName();
84+
auto instMod = allModules.lookup(instModName);
85+
if(!instMod) inst.emitError("Unknown module ") << instModName;
86+
self(self, instMod);
87+
});
88+
89+
// Reverse topo order, insert on stack pop
90+
revTopo.push_back(mod);
91+
};
92+
for(auto &[_, mod] : allModules) visit(visit, mod);
93+
assert(revTopo.size() == allModules.size());
94+
95+
// Instance path.
96+
InstPath::Factory pathFactory;
97+
98+
// Process port prefix specifications
99+
// (Module name, Instance path) -> Prefix
100+
llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef> designatedPrefixes;
101+
bool containsFailure = false;
102+
for(const auto &raw : portPrefixes) {
103+
auto matched = parsePrefixSpec(raw, pathFactory);
104+
if(std::holds_alternative<std::string>(matched)) {
105+
llvm::errs() << "Invalid port prefix specification: " << raw << "\n";
106+
llvm::errs() << "Error: " << std::get<std::string>(matched) << "\n";
107+
containsFailure = true;
108+
continue;
109+
}
110+
111+
auto [module, path, prefix] = std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(matched);
112+
if(!allModules.contains(module)) {
113+
llvm::errs() << "Module not found in port prefix specification: " << module << "\n";
114+
llvm::errs() << "From specification: " << raw << "\n";
115+
containsFailure = true;
116+
continue;
117+
}
118+
119+
// Skip checking instance paths' existence. Non-existent paths are ignored
120+
designatedPrefixes.insert({{module, path}, prefix});
121+
}
122+
123+
if(containsFailure) return signalPassFailure();
124+
125+
// Instance path * prefix name
126+
using ReplacedDescendent = std::pair<InstPath, std::string>;
127+
// This map holds the expunged descendents of a module
128+
llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>> expungedDescendents;
129+
for(auto &expunging : this->modules) {
130+
// Clear expungedDescendents
131+
for(auto &it : expungedDescendents)
132+
it.getSecond().clear();
133+
134+
auto expungingMod = allModules.lookup(expunging);
135+
if(!expungingMod) continue; // Ignored missing modules
136+
auto expungingModTy = expungingMod.getHWModuleType();
137+
auto expungingModPorts = expungingModTy.getPorts();
138+
139+
auto createPortsOn = [&expungingModPorts](
140+
circt::hw::HWModuleOp mod,
141+
const std::string &prefix,
142+
auto genOutput,
143+
auto emitInput
144+
) {
145+
mlir::OpBuilder builder(mod);
146+
// Create ports using *REVERSE* direction of their definitions
147+
for(auto &port : expungingModPorts) {
148+
auto defaultName = prefix + port.name.getValue();
149+
auto finalName = defaultName;
150+
if(port.dir == circt::hw::PortInfo::Input) {
151+
auto val = genOutput(port);
152+
assert(val.getType() == port.type);
153+
mod.appendOutput(finalName, val);
154+
} else if(port.dir == circt::hw::PortInfo::Output) {
155+
auto [_, arg] = mod.appendInput(finalName, port.type);
156+
emitInput(port, arg);
157+
}
158+
}
159+
};
160+
161+
for(auto &processingRaw : revTopo) {
162+
// Skip extmodule and intmodule because they cannot contain anything
163+
if(!llvm::isa<circt::hw::HWModuleOp>(processingRaw)) continue;
164+
circt::hw::HWModuleOp processing = llvm::cast<circt::hw::HWModuleOp>(processingRaw);
165+
166+
std::optional<decltype(expungedDescendents.lookup("")) *> outerExpDescHold = {};
167+
auto getOuterExpDesc = [&]() -> decltype(**outerExpDescHold) {
168+
if(!outerExpDescHold.has_value())
169+
outerExpDescHold = { &expungedDescendents.insert({processing.getName(), {}}).first->getSecond() };
170+
return **outerExpDescHold;
171+
};
172+
173+
mlir::OpBuilder outerBuilder(processing);
174+
175+
processing.walk([&](circt::hw::InstanceOp inst) {
176+
auto instName = inst.getInstanceName();
177+
auto instMod = allModules.lookup(inst.getModuleName());
178+
179+
if(
180+
instMod.getOutputNames().size() != inst.getResults().size()
181+
|| instMod.getNumInputPorts() != inst.getInputs().size()
182+
) {
183+
// Module have been modified during this pass, create new instances
184+
assert(instMod.getNumOutputPorts() >= inst.getResults().size());
185+
assert(instMod.getNumInputPorts() >= inst.getInputs().size());
186+
187+
auto instModInTypes = instMod.getInputTypes();
188+
189+
llvm::SmallVector<mlir::Value> newInputs;
190+
newInputs.reserve(instMod.getNumInputPorts());
191+
192+
outerBuilder.setInsertionPointAfter(inst);
193+
194+
// Appended inputs are at the end of the input list
195+
for(size_t i = 0; i < instMod.getNumInputPorts(); ++i) {
196+
mlir::Value input;
197+
if(i < inst.getNumInputPorts()) {
198+
input = inst.getInputs()[i];
199+
if(auto existingName = inst.getInputName(i))
200+
assert(existingName == instMod.getInputName(i));
201+
} else {
202+
input = outerBuilder.create<mlir::UnrealizedConversionCastOp>(
203+
inst.getLoc(), instModInTypes[i], mlir::ValueRange{}).getResult(0);
204+
}
205+
newInputs.push_back(input);
206+
}
207+
208+
auto newInst = outerBuilder.create<circt::hw::InstanceOp>(
209+
inst.getLoc(),
210+
instMod,
211+
inst.getInstanceNameAttr(),
212+
newInputs,
213+
inst.getParameters(),
214+
inst.getInnerSym().value_or<circt::hw::InnerSymAttr>({})
215+
);
216+
217+
for(size_t i = 0; i < inst.getNumResults(); ++i)
218+
assert(inst.getOutputName(i) == instMod.getOutputName(i));
219+
inst.replaceAllUsesWith(newInst.getResults().slice(0, inst.getNumResults()));
220+
inst.erase();
221+
inst = newInst;
222+
}
223+
224+
llvm::DenseMap<llvm::StringRef, mlir::Value> instOMap;
225+
llvm::DenseMap<llvm::StringRef, mlir::Value> instIMap;
226+
assert(instMod.getOutputNames().size() == inst.getResults().size());
227+
for(auto [oname, oval] : llvm::zip(instMod.getOutputNames(), inst.getResults()))
228+
instOMap[llvm::cast<mlir::StringAttr>(oname).getValue()] = oval;
229+
assert(instMod.getInputNames().size() == inst.getInputs().size());
230+
for(auto [iname, ival] : llvm::zip(instMod.getInputNames(), inst.getInputs()))
231+
instIMap[llvm::cast<mlir::StringAttr>(iname).getValue()] = ival;
232+
233+
// Get outer expunged descendent first because it may modify the map and invalidate iterators.
234+
auto &outerExpDesc = getOuterExpDesc();
235+
auto instExpDesc = expungedDescendents.find(inst.getModuleName());
236+
237+
if(inst.getModuleName() == expunging) {
238+
// Handle the directly expunged module
239+
// input maps also useful for directly expunged instance
240+
241+
auto singletonPath = pathFactory.create(instName);
242+
243+
auto designatedPrefix = designatedPrefixes.find({processing.getName(), singletonPath});
244+
std::string prefix = designatedPrefix != designatedPrefixes.end() ? designatedPrefix->getSecond().str() : (instName + "__").str();
245+
246+
// Port name collision is still possible, but current relying on MLIR
247+
// to automatically rename input arguments.
248+
// TODO: name collision detect
249+
250+
createPortsOn(processing, prefix, [&](circt::hw::ModulePort port) {
251+
// Generate output for outer module, so input for us
252+
return instIMap.at(port.name);
253+
}, [&](circt::hw::ModulePort port, mlir::Value val) {
254+
// Generated input for outer module, replace inst results
255+
assert(instOMap.contains(port.name));
256+
instOMap[port.name].replaceAllUsesWith(val);
257+
});
258+
259+
outerExpDesc.emplace_back(singletonPath, prefix);
260+
261+
assert(instExpDesc == expungedDescendents.end() || instExpDesc->getSecond().size() == 0);
262+
inst.erase();
263+
} else if(instExpDesc != expungedDescendents.end()) {
264+
// Handle all transitive descendents
265+
if(instExpDesc->second.size() == 0) return;
266+
llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
267+
for(const auto &exp : instExpDesc->second) {
268+
auto newPath = pathFactory.add(instName, exp.first);
269+
auto designatedPrefix = designatedPrefixes.find({processing.getName(), newPath});
270+
std::string prefix = designatedPrefix != designatedPrefixes.end() ? designatedPrefix->getSecond().str() : defaultPrefix(newPath);
271+
272+
// TODO: name collision detect
273+
274+
createPortsOn(processing, prefix, [&](circt::hw::ModulePort port) {
275+
// Generate output for outer module, directly forward from inner inst
276+
return instOMap.at((exp.second + port.name.getValue()).str());
277+
}, [&](circt::hw::ModulePort port, mlir::Value val) {
278+
// Generated input for outer module, replace inst results.
279+
// The operand in question has to be an backedge
280+
auto in = instIMap.at((exp.second + port.name.getValue()).str());
281+
auto inDef = in.getDefiningOp();
282+
assert(llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
283+
in.replaceAllUsesWith(val);
284+
inDef->erase();
285+
});
286+
287+
outerExpDesc.emplace_back(newPath, prefix);
288+
}
289+
}
290+
});
291+
}
292+
}
293+
}
294+
295+
std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass() {
296+
return std::make_unique<HWExpungeModulePass>();
297+
}

0 commit comments

Comments
 (0)