1
+ #include " circt/Dialect/HW/HWOps.h"
2
+ #include " circt/Dialect/HW/HWPasses.h"
3
+ #include " circt/Dialect/HW/HWTypes.h"
4
+ #include " mlir/Pass/Pass.h"
5
+ #include " llvm/ADT/DenseMap.h"
6
+ #include " llvm/ADT/ImmutableList.h"
7
+ #include " llvm/Support/Regex.h"
8
+
9
+ namespace circt {
10
+ namespace hw {
11
+ #define GEN_PASS_DEF_HWEXPUNGEMODULE
12
+ #include " circt/Dialect/HW/Passes.h.inc"
13
+ } // namespace hw
14
+ } // namespace circt
15
+
16
+ namespace {
17
+ struct HWExpungeModulePass : circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
18
+ void runOnOperation () override ;
19
+ };
20
+
21
+ struct InstPathSeg {
22
+ llvm::StringRef seg;
23
+
24
+ InstPathSeg (llvm::StringRef seg) : seg(seg) {}
25
+ const llvm::StringRef& getSeg () const { return seg; }
26
+ operator llvm::StringRef () const { return seg; }
27
+
28
+ void Profile (llvm::FoldingSetNodeID &ID) const {
29
+ ID.AddString (seg);
30
+ }
31
+ };
32
+ using InstPath = llvm::ImmutableList<InstPathSeg>;
33
+ std::string defaultPrefix (InstPath path) {
34
+ std::string accum;
35
+ while (!path.isEmpty ()) {
36
+ accum += path.getHead ().getSeg ();
37
+ accum += " _" ;
38
+ path = path.getTail ();
39
+ }
40
+ accum += " _" ;
41
+ return std::move (accum);
42
+ }
43
+
44
+ // The regex for port prefix specification
45
+ // "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
46
+ // Unfortunately, the LLVM Regex cannot capture repeating capture groups, so manually parse the spec
47
+ // This parser may accept identifiers with invalid characters
48
+
49
+ std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>, std::string> parsePrefixSpec (llvm::StringRef in, InstPath::Factory &listFac) {
50
+ auto [l, r] = in.split (" =" );
51
+ if (r == " " ) return " No '=' found in input" ;
52
+ auto [ll, lr] = l.split (" :" );
53
+ if (lr == " " ) return " No ':' found before '='" ;
54
+ llvm::SmallVector<llvm::StringRef, 4 > segs;
55
+ while (lr != " " ) {
56
+ auto [seg, rest] = lr.split (" ." );
57
+ segs.push_back (seg);
58
+ lr = rest;
59
+ }
60
+ InstPath path;
61
+ for (auto &seg : llvm::reverse (segs))
62
+ path = listFac.add (seg, path);
63
+ return std::make_tuple (ll, path, r);
64
+ }
65
+ } // namespace
66
+
67
+ void HWExpungeModulePass::runOnOperation () {
68
+ auto root = getOperation ();
69
+ llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
70
+ root.walk ([&](circt::hw::HWModuleLike mod) {
71
+ allModules[mod.getName ()] = mod;
72
+ });
73
+ // Reverse-topo order generated by module instances.
74
+ // The topo sorted order does not change throught out the operation. It only gets weakened, but
75
+ // still valid.
76
+ llvm::SmallVector<circt::hw::HWModuleLike> revTopo;
77
+ llvm::DenseSet<mlir::StringRef> visited;
78
+ auto visit = [&allModules, &visited, &revTopo](auto &self, circt::hw::HWModuleLike mod) -> void {
79
+ if (visited.contains (mod.getName ())) return ;
80
+ visited.insert (mod.getName ());
81
+
82
+ mod.walk ([&](circt::hw::InstanceOp inst) {
83
+ auto instModName = inst.getModuleName ();
84
+ auto instMod = allModules.lookup (instModName);
85
+ if (!instMod) inst.emitError (" Unknown module " ) << instModName;
86
+ self (self, instMod);
87
+ });
88
+
89
+ // Reverse topo order, insert on stack pop
90
+ revTopo.push_back (mod);
91
+ };
92
+ for (auto &[_, mod] : allModules) visit (visit, mod);
93
+ assert (revTopo.size () == allModules.size ());
94
+
95
+ // Instance path.
96
+ InstPath::Factory pathFactory;
97
+
98
+ // Process port prefix specifications
99
+ // (Module name, Instance path) -> Prefix
100
+ llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef> designatedPrefixes;
101
+ bool containsFailure = false ;
102
+ for (const auto &raw : portPrefixes) {
103
+ auto matched = parsePrefixSpec (raw, pathFactory);
104
+ if (std::holds_alternative<std::string>(matched)) {
105
+ llvm::errs () << " Invalid port prefix specification: " << raw << " \n " ;
106
+ llvm::errs () << " Error: " << std::get<std::string>(matched) << " \n " ;
107
+ containsFailure = true ;
108
+ continue ;
109
+ }
110
+
111
+ auto [module, path, prefix] = std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(matched);
112
+ if (!allModules.contains (module)) {
113
+ llvm::errs () << " Module not found in port prefix specification: " << module << " \n " ;
114
+ llvm::errs () << " From specification: " << raw << " \n " ;
115
+ containsFailure = true ;
116
+ continue ;
117
+ }
118
+
119
+ // Skip checking instance paths' existence. Non-existent paths are ignored
120
+ designatedPrefixes.insert ({{module, path}, prefix});
121
+ }
122
+
123
+ if (containsFailure) return signalPassFailure ();
124
+
125
+ // Instance path * prefix name
126
+ using ReplacedDescendent = std::pair<InstPath, std::string>;
127
+ // This map holds the expunged descendents of a module
128
+ llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>> expungedDescendents;
129
+ for (auto &expunging : this ->modules ) {
130
+ // Clear expungedDescendents
131
+ for (auto &it : expungedDescendents)
132
+ it.getSecond ().clear ();
133
+
134
+ auto expungingMod = allModules.lookup (expunging);
135
+ if (!expungingMod) continue ; // Ignored missing modules
136
+ auto expungingModTy = expungingMod.getHWModuleType ();
137
+ auto expungingModPorts = expungingModTy.getPorts ();
138
+
139
+ auto createPortsOn = [&expungingModPorts](
140
+ circt::hw::HWModuleOp mod,
141
+ const std::string &prefix,
142
+ auto genOutput,
143
+ auto emitInput
144
+ ) {
145
+ mlir::OpBuilder builder (mod);
146
+ // Create ports using *REVERSE* direction of their definitions
147
+ for (auto &port : expungingModPorts) {
148
+ auto defaultName = prefix + port.name .getValue ();
149
+ auto finalName = defaultName;
150
+ if (port.dir == circt::hw::PortInfo::Input) {
151
+ auto val = genOutput (port);
152
+ assert (val.getType () == port.type );
153
+ mod.appendOutput (finalName, val);
154
+ } else if (port.dir == circt::hw::PortInfo::Output) {
155
+ auto [_, arg] = mod.appendInput (finalName, port.type );
156
+ emitInput (port, arg);
157
+ }
158
+ }
159
+ };
160
+
161
+ for (auto &processingRaw : revTopo) {
162
+ // Skip extmodule and intmodule because they cannot contain anything
163
+ if (!llvm::isa<circt::hw::HWModuleOp>(processingRaw)) continue ;
164
+ circt::hw::HWModuleOp processing = llvm::cast<circt::hw::HWModuleOp>(processingRaw);
165
+
166
+ std::optional<decltype (expungedDescendents.lookup (" " )) *> outerExpDescHold = {};
167
+ auto getOuterExpDesc = [&]() -> decltype (**outerExpDescHold) {
168
+ if (!outerExpDescHold.has_value ())
169
+ outerExpDescHold = { &expungedDescendents.insert ({processing.getName (), {}}).first ->getSecond () };
170
+ return **outerExpDescHold;
171
+ };
172
+
173
+ mlir::OpBuilder outerBuilder (processing);
174
+
175
+ processing.walk ([&](circt::hw::InstanceOp inst) {
176
+ auto instName = inst.getInstanceName ();
177
+ auto instMod = allModules.lookup (inst.getModuleName ());
178
+
179
+ if (
180
+ instMod.getOutputNames ().size () != inst.getResults ().size ()
181
+ || instMod.getNumInputPorts () != inst.getInputs ().size ()
182
+ ) {
183
+ // Module have been modified during this pass, create new instances
184
+ assert (instMod.getNumOutputPorts () >= inst.getResults ().size ());
185
+ assert (instMod.getNumInputPorts () >= inst.getInputs ().size ());
186
+
187
+ auto instModInTypes = instMod.getInputTypes ();
188
+
189
+ llvm::SmallVector<mlir::Value> newInputs;
190
+ newInputs.reserve (instMod.getNumInputPorts ());
191
+
192
+ outerBuilder.setInsertionPointAfter (inst);
193
+
194
+ // Appended inputs are at the end of the input list
195
+ for (size_t i = 0 ; i < instMod.getNumInputPorts (); ++i) {
196
+ mlir::Value input;
197
+ if (i < inst.getNumInputPorts ()) {
198
+ input = inst.getInputs ()[i];
199
+ if (auto existingName = inst.getInputName (i))
200
+ assert (existingName == instMod.getInputName (i));
201
+ } else {
202
+ input = outerBuilder.create <mlir::UnrealizedConversionCastOp>(
203
+ inst.getLoc (), instModInTypes[i], mlir::ValueRange{}).getResult (0 );
204
+ }
205
+ newInputs.push_back (input);
206
+ }
207
+
208
+ auto newInst = outerBuilder.create <circt::hw::InstanceOp>(
209
+ inst.getLoc (),
210
+ instMod,
211
+ inst.getInstanceNameAttr (),
212
+ newInputs,
213
+ inst.getParameters (),
214
+ inst.getInnerSym ().value_or <circt::hw::InnerSymAttr>({})
215
+ );
216
+
217
+ for (size_t i = 0 ; i < inst.getNumResults (); ++i)
218
+ assert (inst.getOutputName (i) == instMod.getOutputName (i));
219
+ inst.replaceAllUsesWith (newInst.getResults ().slice (0 , inst.getNumResults ()));
220
+ inst.erase ();
221
+ inst = newInst;
222
+ }
223
+
224
+ llvm::DenseMap<llvm::StringRef, mlir::Value> instOMap;
225
+ llvm::DenseMap<llvm::StringRef, mlir::Value> instIMap;
226
+ assert (instMod.getOutputNames ().size () == inst.getResults ().size ());
227
+ for (auto [oname, oval] : llvm::zip (instMod.getOutputNames (), inst.getResults ()))
228
+ instOMap[llvm::cast<mlir::StringAttr>(oname).getValue ()] = oval;
229
+ assert (instMod.getInputNames ().size () == inst.getInputs ().size ());
230
+ for (auto [iname, ival] : llvm::zip (instMod.getInputNames (), inst.getInputs ()))
231
+ instIMap[llvm::cast<mlir::StringAttr>(iname).getValue ()] = ival;
232
+
233
+ // Get outer expunged descendent first because it may modify the map and invalidate iterators.
234
+ auto &outerExpDesc = getOuterExpDesc ();
235
+ auto instExpDesc = expungedDescendents.find (inst.getModuleName ());
236
+
237
+ if (inst.getModuleName () == expunging) {
238
+ // Handle the directly expunged module
239
+ // input maps also useful for directly expunged instance
240
+
241
+ auto singletonPath = pathFactory.create (instName);
242
+
243
+ auto designatedPrefix = designatedPrefixes.find ({processing.getName (), singletonPath});
244
+ std::string prefix = designatedPrefix != designatedPrefixes.end () ? designatedPrefix->getSecond ().str () : (instName + " __" ).str ();
245
+
246
+ // Port name collision is still possible, but current relying on MLIR
247
+ // to automatically rename input arguments.
248
+ // TODO: name collision detect
249
+
250
+ createPortsOn (processing, prefix, [&](circt::hw::ModulePort port) {
251
+ // Generate output for outer module, so input for us
252
+ return instIMap.at (port.name );
253
+ }, [&](circt::hw::ModulePort port, mlir::Value val) {
254
+ // Generated input for outer module, replace inst results
255
+ assert (instOMap.contains (port.name ));
256
+ instOMap[port.name ].replaceAllUsesWith (val);
257
+ });
258
+
259
+ outerExpDesc.emplace_back (singletonPath, prefix);
260
+
261
+ assert (instExpDesc == expungedDescendents.end () || instExpDesc->getSecond ().size () == 0 );
262
+ inst.erase ();
263
+ } else if (instExpDesc != expungedDescendents.end ()) {
264
+ // Handle all transitive descendents
265
+ if (instExpDesc->second .size () == 0 ) return ;
266
+ llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
267
+ for (const auto &exp : instExpDesc->second ) {
268
+ auto newPath = pathFactory.add (instName, exp .first );
269
+ auto designatedPrefix = designatedPrefixes.find ({processing.getName (), newPath});
270
+ std::string prefix = designatedPrefix != designatedPrefixes.end () ? designatedPrefix->getSecond ().str () : defaultPrefix (newPath);
271
+
272
+ // TODO: name collision detect
273
+
274
+ createPortsOn (processing, prefix, [&](circt::hw::ModulePort port) {
275
+ // Generate output for outer module, directly forward from inner inst
276
+ return instOMap.at ((exp .second + port.name .getValue ()).str ());
277
+ }, [&](circt::hw::ModulePort port, mlir::Value val) {
278
+ // Generated input for outer module, replace inst results.
279
+ // The operand in question has to be an backedge
280
+ auto in = instIMap.at ((exp .second + port.name .getValue ()).str ());
281
+ auto inDef = in.getDefiningOp ();
282
+ assert (llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
283
+ in.replaceAllUsesWith (val);
284
+ inDef->erase ();
285
+ });
286
+
287
+ outerExpDesc.emplace_back (newPath, prefix);
288
+ }
289
+ }
290
+ });
291
+ }
292
+ }
293
+ }
294
+
295
+ std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass () {
296
+ return std::make_unique<HWExpungeModulePass>();
297
+ }
0 commit comments