Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1b1428a

Browse files
committedJan 12, 2025··
[HW] Add Passes: hw-expunge-module, hw-tree-shake
1 parent 0d35c61 commit 1b1428a

File tree

6 files changed

+418
-0
lines changed

6 files changed

+418
-0
lines changed
 

‎include/circt/Dialect/HW/HWPasses.h

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ std::unique_ptr<mlir::Pass> createFlattenIOPass(bool recursiveFlag = true,
3333
std::unique_ptr<mlir::Pass> createVerifyInnerRefNamespacePass();
3434
std::unique_ptr<mlir::Pass> createFlattenModulesPass();
3535
std::unique_ptr<mlir::Pass> createFooWiresPass();
36+
std::unique_ptr<mlir::Pass> createHWExpungeModulePass();
37+
std::unique_ptr<mlir::Pass> createHWTreeShakePass();
3638

3739
/// Generate the code for registering passes.
3840
#define GEN_PASS_REGISTRATION

‎include/circt/Dialect/HW/Passes.td

+32
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,38 @@ def VerifyInnerRefNamespace : Pass<"hw-verify-irn"> {
7575
let constructor = "circt::hw::createVerifyInnerRefNamespacePass()";
7676
}
7777

78+
def HWExpungeModule : Pass<"hw-expunge-module", "mlir::ModuleOp"> {
79+
let summary = "Remove module from the hierarchy, and recursively expose their ports to upper level.";
80+
let description = [{
81+
This pass removes a list of modules from the hierarchy on-by-one, recursively exposing their ports to upper level.
82+
The newly generated ports are by default named as <instance_path>__<port_name>. During a naming conflict, an warning would be genreated,
83+
and an random suffix would be added to the <instance_path> part.
84+
85+
For each given (transitive) parent module, the prefix can alternatively be specified by option instead of using the instance path.
86+
}];
87+
let constructor = "circt::hw::createHWExpungeModulePass()";
88+
89+
let options = [
90+
ListOption<"modules", "modules", "std::string",
91+
"Comma separated list of module names to be removed from the hierarchy.">,
92+
ListOption<"portPrefixes", "port-prefixes", "std::string",
93+
"Specify the prefix for ports of a given parent module's expunged childen. Each specification is formatted as <module>:<instance-path>=<prefix>. Only affect the top-most level module of the instance path.">,
94+
];
95+
}
96+
97+
def HWTreeShake : Pass<"hw-tree-shake", "mlir::ModuleOp"> {
98+
let summary = "Remove unused modules.";
99+
let description = [{
100+
This pass removes all modules besides a specified list of modules and their transitive dependencies.
101+
}];
102+
let constructor = "circt::hw::createHWTreeShakePass()";
103+
104+
let options = [
105+
ListOption<"keep", "keep", "std::string",
106+
"Comma separated list of module names to be kept.">,
107+
];
108+
}
109+
78110
/**
79111
* Tutorial Pass, doesn't do anything interesting
80112
*/

‎lib/Dialect/HW/Transforms/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ add_circt_dialect_library(CIRCTHWTransforms
66
VerifyInnerRefNamespace.cpp
77
FlattenModules.cpp
88
FooWires.cpp
9+
HWExpungeModule.cpp
10+
HWTreeShake.cpp
911

1012
DEPENDS
1113
CIRCTHWTransformsIncGen
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
#include "circt/Dialect/HW/HWOps.h"
2+
#include "circt/Dialect/HW/HWPasses.h"
3+
#include "circt/Dialect/HW/HWTypes.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "llvm/ADT/DenseMap.h"
6+
#include "llvm/ADT/ImmutableList.h"
7+
#include "llvm/Support/Regex.h"
8+
9+
namespace circt {
10+
namespace hw {
11+
#define GEN_PASS_DEF_HWEXPUNGEMODULE
12+
#include "circt/Dialect/HW/Passes.h.inc"
13+
} // namespace hw
14+
} // namespace circt
15+
16+
namespace {
17+
struct HWExpungeModulePass : circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
18+
void runOnOperation() override;
19+
};
20+
21+
struct InstPathSeg {
22+
llvm::StringRef seg;
23+
24+
InstPathSeg(llvm::StringRef seg) : seg(seg) {}
25+
const llvm::StringRef& getSeg() const { return seg; }
26+
operator llvm::StringRef() const { return seg; }
27+
28+
void Profile(llvm::FoldingSetNodeID &ID) const {
29+
ID.AddString(seg);
30+
}
31+
};
32+
using InstPath = llvm::ImmutableList<InstPathSeg>;
33+
std::string defaultPrefix(InstPath path) {
34+
std::string accum;
35+
while(!path.isEmpty()) {
36+
accum += path.getHead().getSeg();
37+
accum += "_";
38+
path = path.getTail();
39+
}
40+
accum += "_";
41+
return std::move(accum);
42+
}
43+
44+
// The regex for port prefix specification
45+
// "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
46+
// Unfortunately, the LLVM Regex cannot capture repeating capture groups, so manually parse the spec
47+
// This parser may accept identifiers with invalid characters
48+
49+
std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>, std::string> parsePrefixSpec(llvm::StringRef in, InstPath::Factory &listFac) {
50+
auto [l, r] = in.split("=");
51+
if (r == "") return "No '=' found in input";
52+
auto [ll, lr] = l.split(":");
53+
if (lr == "") return "No ':' found before '='";
54+
llvm::SmallVector<llvm::StringRef, 4> segs;
55+
while(lr != "") {
56+
auto [seg, rest] = lr.split(".");
57+
segs.push_back(seg);
58+
lr = rest;
59+
}
60+
InstPath path;
61+
for(auto &seg : llvm::reverse(segs))
62+
path = listFac.add(seg, path);
63+
return std::make_tuple(ll, path, r);
64+
}
65+
} // namespace
66+
67+
void HWExpungeModulePass::runOnOperation() {
68+
auto root = getOperation();
69+
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
70+
root.walk([&](circt::hw::HWModuleLike mod) {
71+
allModules[mod.getName()] = mod;
72+
});
73+
// Reverse-topo order generated by module instances.
74+
// The topo sorted order does not change throught out the operation. It only gets weakened, but
75+
// still valid.
76+
llvm::SmallVector<circt::hw::HWModuleLike> revTopo;
77+
llvm::DenseSet<mlir::StringRef> visited;
78+
auto visit = [&allModules, &visited, &revTopo](auto &self, circt::hw::HWModuleLike mod) -> void {
79+
if(visited.contains(mod.getName())) return;
80+
visited.insert(mod.getName());
81+
82+
mod.walk([&](circt::hw::InstanceOp inst) {
83+
auto instModName = inst.getModuleName();
84+
auto instMod = allModules.lookup(instModName);
85+
if(!instMod) inst.emitError("Unknown module ") << instModName;
86+
self(self, instMod);
87+
});
88+
89+
// Reverse topo order, insert on stack pop
90+
revTopo.push_back(mod);
91+
};
92+
for(auto &[_, mod] : allModules) visit(visit, mod);
93+
assert(revTopo.size() == allModules.size());
94+
95+
// Instance path.
96+
InstPath::Factory pathFactory;
97+
98+
// Process port prefix specifications
99+
// (Module name, Instance path) -> Prefix
100+
llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef> designatedPrefixes;
101+
bool containsFailure = false;
102+
for(const auto &raw : portPrefixes) {
103+
auto matched = parsePrefixSpec(raw, pathFactory);
104+
if(std::holds_alternative<std::string>(matched)) {
105+
llvm::errs() << "Invalid port prefix specification: " << raw << "\n";
106+
llvm::errs() << "Error: " << std::get<std::string>(matched) << "\n";
107+
containsFailure = true;
108+
continue;
109+
}
110+
111+
auto [module, path, prefix] = std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(matched);
112+
if(!allModules.contains(module)) {
113+
llvm::errs() << "Module not found in port prefix specification: " << module << "\n";
114+
llvm::errs() << "From specification: " << raw << "\n";
115+
containsFailure = true;
116+
continue;
117+
}
118+
119+
// Skip checking instance paths' existence. Non-existent paths are ignored
120+
designatedPrefixes.insert({{module, path}, prefix});
121+
}
122+
123+
if(containsFailure) return signalPassFailure();
124+
125+
// Instance path * prefix name
126+
using ReplacedDescendent = std::pair<InstPath, std::string>;
127+
// This map holds the expunged descendents of a module
128+
llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>> expungedDescendents;
129+
for(auto &expunging : this->modules) {
130+
// Clear expungedDescendents
131+
for(auto &it : expungedDescendents)
132+
it.getSecond().clear();
133+
134+
auto expungingMod = allModules.lookup(expunging);
135+
if(!expungingMod) continue; // Ignored missing modules
136+
auto expungingModTy = expungingMod.getHWModuleType();
137+
auto expungingModPorts = expungingModTy.getPorts();
138+
139+
auto createPortsOn = [&expungingModPorts](
140+
circt::hw::HWModuleOp mod,
141+
const std::string &prefix,
142+
auto genOutput,
143+
auto emitInput
144+
) {
145+
mlir::OpBuilder builder(mod);
146+
// Create ports using *REVERSE* direction of their definitions
147+
for(auto &port : expungingModPorts) {
148+
auto defaultName = prefix + port.name.getValue();
149+
auto finalName = defaultName;
150+
if(port.dir == circt::hw::PortInfo::Input) {
151+
auto val = genOutput(port);
152+
assert(val.getType() == port.type);
153+
mod.appendOutput(finalName, val);
154+
} else if(port.dir == circt::hw::PortInfo::Output) {
155+
auto [_, arg] = mod.appendInput(finalName, port.type);
156+
emitInput(port, arg);
157+
}
158+
}
159+
};
160+
161+
for(auto &processingRaw : revTopo) {
162+
// Skip extmodule and intmodule because they cannot contain anything
163+
if(!llvm::isa<circt::hw::HWModuleOp>(processingRaw)) continue;
164+
circt::hw::HWModuleOp processing = llvm::cast<circt::hw::HWModuleOp>(processingRaw);
165+
166+
std::optional<decltype(expungedDescendents.lookup("")) *> outerExpDescHold = {};
167+
auto getOuterExpDesc = [&]() -> decltype(**outerExpDescHold) {
168+
if(!outerExpDescHold.has_value())
169+
outerExpDescHold = { &expungedDescendents.insert({processing.getName(), {}}).first->getSecond() };
170+
return **outerExpDescHold;
171+
};
172+
173+
mlir::OpBuilder outerBuilder(processing);
174+
175+
processing.walk([&](circt::hw::InstanceOp inst) {
176+
auto instName = inst.getInstanceName();
177+
auto instMod = allModules.lookup(inst.getModuleName());
178+
179+
if(
180+
instMod.getOutputNames().size() != inst.getResults().size()
181+
|| instMod.getNumInputPorts() != inst.getInputs().size()
182+
) {
183+
// Module have been modified during this pass, create new instances
184+
assert(instMod.getNumOutputPorts() >= inst.getResults().size());
185+
assert(instMod.getNumInputPorts() >= inst.getInputs().size());
186+
187+
auto instModInTypes = instMod.getInputTypes();
188+
189+
llvm::SmallVector<mlir::Value> newInputs;
190+
newInputs.reserve(instMod.getNumInputPorts());
191+
192+
outerBuilder.setInsertionPointAfter(inst);
193+
194+
// Appended inputs are at the end of the input list
195+
for(size_t i = 0; i < instMod.getNumInputPorts(); ++i) {
196+
mlir::Value input;
197+
if(i < inst.getNumInputPorts()) {
198+
input = inst.getInputs()[i];
199+
if(auto existingName = inst.getInputName(i))
200+
assert(existingName == instMod.getInputName(i));
201+
} else {
202+
input = outerBuilder.create<mlir::UnrealizedConversionCastOp>(
203+
inst.getLoc(), instModInTypes[i], mlir::ValueRange{}).getResult(0);
204+
}
205+
newInputs.push_back(input);
206+
}
207+
208+
auto newInst = outerBuilder.create<circt::hw::InstanceOp>(
209+
inst.getLoc(),
210+
instMod,
211+
inst.getInstanceNameAttr(),
212+
newInputs,
213+
inst.getParameters(),
214+
inst.getInnerSym().value_or<circt::hw::InnerSymAttr>({})
215+
);
216+
217+
for(size_t i = 0; i < inst.getNumResults(); ++i)
218+
assert(inst.getOutputName(i) == instMod.getOutputName(i));
219+
inst.replaceAllUsesWith(newInst.getResults().slice(0, inst.getNumResults()));
220+
inst.erase();
221+
inst = newInst;
222+
}
223+
224+
llvm::DenseMap<llvm::StringRef, mlir::Value> instOMap;
225+
llvm::DenseMap<llvm::StringRef, mlir::Value> instIMap;
226+
assert(instMod.getOutputNames().size() == inst.getResults().size());
227+
for(auto [oname, oval] : llvm::zip(instMod.getOutputNames(), inst.getResults()))
228+
instOMap[llvm::cast<mlir::StringAttr>(oname).getValue()] = oval;
229+
assert(instMod.getInputNames().size() == inst.getInputs().size());
230+
for(auto [iname, ival] : llvm::zip(instMod.getInputNames(), inst.getInputs()))
231+
instIMap[llvm::cast<mlir::StringAttr>(iname).getValue()] = ival;
232+
233+
// Get outer expunged descendent first because it may modify the map and invalidate iterators.
234+
auto &outerExpDesc = getOuterExpDesc();
235+
auto instExpDesc = expungedDescendents.find(inst.getModuleName());
236+
237+
if(inst.getModuleName() == expunging) {
238+
// Handle the directly expunged module
239+
// input maps also useful for directly expunged instance
240+
241+
auto singletonPath = pathFactory.create(instName);
242+
243+
auto designatedPrefix = designatedPrefixes.find({processing.getName(), singletonPath});
244+
std::string prefix = designatedPrefix != designatedPrefixes.end() ? designatedPrefix->getSecond().str() : (instName + "__").str();
245+
246+
// Port name collision is still possible, but current relying on MLIR
247+
// to automatically rename input arguments.
248+
// TODO: name collision detect
249+
250+
createPortsOn(processing, prefix, [&](circt::hw::ModulePort port) {
251+
// Generate output for outer module, so input for us
252+
return instIMap.at(port.name);
253+
}, [&](circt::hw::ModulePort port, mlir::Value val) {
254+
// Generated input for outer module, replace inst results
255+
assert(instOMap.contains(port.name));
256+
instOMap[port.name].replaceAllUsesWith(val);
257+
});
258+
259+
outerExpDesc.emplace_back(singletonPath, prefix);
260+
261+
assert(instExpDesc == expungedDescendents.end() || instExpDesc->getSecond().size() == 0);
262+
inst.erase();
263+
} else if(instExpDesc != expungedDescendents.end()) {
264+
// Handle all transitive descendents
265+
if(instExpDesc->second.size() == 0) return;
266+
llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
267+
for(const auto &exp : instExpDesc->second) {
268+
auto newPath = pathFactory.add(instName, exp.first);
269+
auto designatedPrefix = designatedPrefixes.find({processing.getName(), newPath});
270+
std::string prefix = designatedPrefix != designatedPrefixes.end() ? designatedPrefix->getSecond().str() : defaultPrefix(newPath);
271+
272+
// TODO: name collision detect
273+
274+
createPortsOn(processing, prefix, [&](circt::hw::ModulePort port) {
275+
// Generate output for outer module, directly forward from inner inst
276+
return instOMap.at((exp.second + port.name.getValue()).str());
277+
}, [&](circt::hw::ModulePort port, mlir::Value val) {
278+
// Generated input for outer module, replace inst results.
279+
// The operand in question has to be an backedge
280+
auto in = instIMap.at((exp.second + port.name.getValue()).str());
281+
auto inDef = in.getDefiningOp();
282+
assert(llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
283+
in.replaceAllUsesWith(val);
284+
inDef->erase();
285+
});
286+
287+
outerExpDesc.emplace_back(newPath, prefix);
288+
}
289+
}
290+
});
291+
}
292+
}
293+
}
294+
295+
std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass() {
296+
return std::make_unique<HWExpungeModulePass>();
297+
}
+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#include "circt/Dialect/HW/HWOps.h"
2+
#include "circt/Dialect/HW/HWPasses.h"
3+
#include "circt/Dialect/HW/HWTypes.h"
4+
#include "mlir/Pass/Pass.h"
5+
#include "llvm/ADT/DenseMap.h"
6+
#include "llvm/Support/Regex.h"
7+
8+
namespace circt {
9+
namespace hw {
10+
#define GEN_PASS_DEF_HWTREESHAKE
11+
#include "circt/Dialect/HW/Passes.h.inc"
12+
} // namespace hw
13+
} // namespace circt
14+
15+
struct HWTreeShakePass : circt::hw::impl::HWTreeShakeBase<HWTreeShakePass> {
16+
void runOnOperation() override;
17+
};
18+
19+
void HWTreeShakePass::runOnOperation() {
20+
auto root = getOperation();
21+
llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
22+
root.walk([&](circt::hw::HWModuleLike mod) {
23+
allModules[mod.getName()] = mod;
24+
});
25+
26+
llvm::DenseSet<circt::hw::HWModuleLike> visited;
27+
auto visit = [&allModules, &visited](auto &self, circt::hw::HWModuleLike mod) -> void {
28+
if(visited.contains(mod)) return;
29+
visited.insert(mod);
30+
mod.walk([&](circt::hw::InstanceOp inst) {
31+
auto modName = inst.getModuleName();
32+
self(self, allModules.at(modName));
33+
});
34+
};
35+
36+
for(const auto &kept : keep) {
37+
auto lookup = allModules.find(kept);
38+
if(lookup == allModules.end()) continue; // Silently ignore missing modules
39+
visit(visit, lookup->getSecond());
40+
}
41+
42+
for(auto &mod : allModules) {
43+
if(!visited.contains(mod.getSecond())) {
44+
mod.getSecond()->remove();
45+
}
46+
}
47+
}
48+
49+
std::unique_ptr<mlir::Pass> circt::hw::createHWTreeShakePass() {
50+
return std::make_unique<HWTreeShakePass>();
51+
}

‎test/Dialect/HW/expunge-module.mlir

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}})" %s | FileCheck %s --check-prefixes FOO,BAR,BAZ,COMMON
2+
// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}},hw-tree-shake{keep=foo})" %s | FileCheck %s --check-prefixes FOO,BAR,COMMON
3+
// RUN: circt-opt --pass-pipeline="builtin.module(hw-expunge-module{modules={baz,b} port-prefixes={foo:bar2.baz2=meow_,bar:baz1=nya_}},hw-tree-shake{keep=baz})" %s | FileCheck %s --check-prefixes BAZ,COMMON
4+
5+
module {
6+
hw.module @foo(in %bar1_baz1__out : i2, out test : i1) {
7+
%bar1.self_out = hw.instance "bar1" @bar(self_in: %0: i1) -> (self_out: i1)
8+
%bar2.self_out = hw.instance "bar2" @bar(self_in: %bar1.self_out: i1) -> (self_out: i1)
9+
%0 = comb.extract %bar1_baz1__out from 0 : (i2) -> i1
10+
hw.output %bar2.self_out : i1
11+
}
12+
hw.module private @bar(in %self_in : i1, out self_out : i1) {
13+
%baz1.out = hw.instance "baz1" @baz(in: %self_in: i1) -> (out: i1)
14+
%baz2.out = hw.instance "baz2" @baz(in: %baz1.out: i1) -> (out: i1)
15+
hw.output %baz2.out : i1
16+
}
17+
hw.module private @baz(in %in : i1, out out : i1) {
18+
hw.output %in : i1
19+
}
20+
}
21+
22+
// COMMON: module {
23+
// FOO-NEXT: hw.module @foo(in %bar1_baz1__out : i2, in %bar1_baz1__out_0 : i1, in %bar1_baz2__out : i1, in %bar2_baz1__out : i1, in %meow_out : i1, out test : i1, out bar1_baz1__in : i1, out bar1_baz2__in : i1, out bar2_baz1__in : i1, out meow_in : i1) {
24+
// FOO-NEXT: %bar1.self_out, %bar1.nya_in, %bar1.baz2__in = hw.instance "bar1" @bar(self_in: %0: i1, nya_out: %bar1_baz1__out_0: i1, baz2__out: %bar1_baz2__out: i1) -> (self_out: i1, nya_in: i1, baz2__in: i1)
25+
// FOO-NEXT: %bar2.self_out, %bar2.nya_in, %bar2.baz2__in = hw.instance "bar2" @bar(self_in: %bar1.self_out: i1, nya_out: %bar2_baz1__out: i1, baz2__out: %meow_out: i1) -> (self_out: i1, nya_in: i1, baz2__in: i1)
26+
// FOO-NEXT: %0 = comb.extract %bar1_baz1__out from 0 : (i2) -> i1
27+
// FOO-NEXT: hw.output %bar2.self_out, %bar1.nya_in, %bar1.baz2__in, %bar2.nya_in, %bar2.baz2__in : i1, i1, i1, i1, i1
28+
// FOO-NEXT: }
29+
// BAR-NEXT: hw.module private @bar(in %self_in : i1, in %nya_out : i1, in %baz2__out : i1, out self_out : i1, out nya_in : i1, out baz2__in : i1) {
30+
// BAR-NEXT: hw.output %baz2__out, %self_in, %nya_out : i1, i1, i1
31+
// BAR-NEXT: }
32+
// BAZ-NEXT: hw.module private @baz(in %in : i1, out out : i1) {
33+
// BAZ-NEXT: hw.output %in : i1
34+
// BAZ-NEXT: }

0 commit comments

Comments
 (0)
Please sign in to comment.