1
+ #include " circt/Dialect/HW/HWOps.h"
2
+ #include " circt/Dialect/HW/HWPasses.h"
3
+ #include " circt/Dialect/HW/HWTypes.h"
4
+ #include " mlir/Pass/Pass.h"
5
+ #include " llvm/ADT/DenseMap.h"
6
+ #include " llvm/ADT/ImmutableList.h"
7
+ #include " llvm/Support/Regex.h"
8
+
9
+ namespace circt {
10
+ namespace hw {
11
+ #define GEN_PASS_DEF_HWEXPUNGEMODULE
12
+ #include " circt/Dialect/HW/Passes.h.inc"
13
+ } // namespace hw
14
+ } // namespace circt
15
+
16
+ namespace {
17
+ struct HWExpungeModulePass
18
+ : circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
19
+ void runOnOperation () override ;
20
+ };
21
+
22
+ struct InstPathSeg {
23
+ llvm::StringRef seg;
24
+
25
+ InstPathSeg (llvm::StringRef seg) : seg(seg) {}
26
+ const llvm::StringRef &getSeg () const { return seg; }
27
+ operator llvm::StringRef () const { return seg; }
28
+
29
+ void Profile (llvm::FoldingSetNodeID &ID) const { ID.AddString (seg); }
30
+ };
31
+ using InstPath = llvm::ImmutableList<InstPathSeg>;
32
+ std::string defaultPrefix (InstPath path) {
33
+ std::string accum;
34
+ while (!path.isEmpty ()) {
35
+ accum += path.getHead ().getSeg ();
36
+ accum += " _" ;
37
+ path = path.getTail ();
38
+ }
39
+ accum += " _" ;
40
+ return std::move (accum);
41
+ }
42
+
43
+ // The regex for port prefix specification
44
+ // "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
45
+ // Unfortunately, the LLVM Regex cannot capture repeating capture groups, so
46
+ // manually parse the spec This parser may accept identifiers with invalid
47
+ // characters
48
+
49
+ std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>,
50
+ std::string>
51
+ parsePrefixSpec (llvm::StringRef in, InstPath::Factory &listFac) {
52
+ auto [l, r] = in.split (" =" );
53
+ if (r == " " )
54
+ return " No '=' found in input" ;
55
+ auto [ll, lr] = l.split (" :" );
56
+ if (lr == " " )
57
+ return " No ':' found before '='" ;
58
+ llvm::SmallVector<llvm::StringRef, 4 > segs;
59
+ while (lr != " " ) {
60
+ auto [seg, rest] = lr.split (" ." );
61
+ segs.push_back (seg);
62
+ lr = rest;
63
+ }
64
+ InstPath path;
65
+ for (auto &seg : llvm::reverse (segs))
66
+ path = listFac.add (seg, path);
67
+ return std::make_tuple (ll, path, r);
68
+ }
69
+ } // namespace
70
+
71
+ void HWExpungeModulePass::runOnOperation () {
72
+ auto root = getOperation ();
73
+ llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
74
+ root.walk (
75
+ [&](circt::hw::HWModuleLike mod) { allModules[mod.getName ()] = mod; });
76
+ // Reverse-topo order generated by module instances.
77
+ // The topo sorted order does not change throught out the operation. It only
78
+ // gets weakened, but still valid.
79
+ llvm::SmallVector<circt::hw::HWModuleLike> revTopo;
80
+ llvm::DenseSet<mlir::StringRef> visited;
81
+ auto visit = [&allModules, &visited,
82
+ &revTopo](auto &self, circt::hw::HWModuleLike mod) -> void {
83
+ if (visited.contains (mod.getName ()))
84
+ return ;
85
+ visited.insert (mod.getName ());
86
+
87
+ mod.walk ([&](circt::hw::InstanceOp inst) {
88
+ auto instModName = inst.getModuleName ();
89
+ auto instMod = allModules.lookup (instModName);
90
+ if (!instMod)
91
+ inst.emitError (" Unknown module " ) << instModName;
92
+ self (self, instMod);
93
+ });
94
+
95
+ // Reverse topo order, insert on stack pop
96
+ revTopo.push_back (mod);
97
+ };
98
+ for (auto &[_, mod] : allModules)
99
+ visit (visit, mod);
100
+ assert (revTopo.size () == allModules.size ());
101
+
102
+ // Instance path.
103
+ InstPath::Factory pathFactory;
104
+
105
+ // Process port prefix specifications
106
+ // (Module name, Instance path) -> Prefix
107
+ llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef>
108
+ designatedPrefixes;
109
+ bool containsFailure = false ;
110
+ for (const auto &raw : portPrefixes) {
111
+ auto matched = parsePrefixSpec (raw, pathFactory);
112
+ if (std::holds_alternative<std::string>(matched)) {
113
+ llvm::errs () << " Invalid port prefix specification: " << raw << " \n " ;
114
+ llvm::errs () << " Error: " << std::get<std::string>(matched) << " \n " ;
115
+ containsFailure = true ;
116
+ continue ;
117
+ }
118
+
119
+ auto [module, path, prefix] =
120
+ std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(
121
+ matched);
122
+ if (!allModules.contains (module)) {
123
+ llvm::errs () << " Module not found in port prefix specification: "
124
+ << module << " \n " ;
125
+ llvm::errs () << " From specification: " << raw << " \n " ;
126
+ containsFailure = true ;
127
+ continue ;
128
+ }
129
+
130
+ // Skip checking instance paths' existence. Non-existent paths are ignored
131
+ designatedPrefixes.insert ({{module, path}, prefix});
132
+ }
133
+
134
+ if (containsFailure)
135
+ return signalPassFailure ();
136
+
137
+ // Instance path * prefix name
138
+ using ReplacedDescendent = std::pair<InstPath, std::string>;
139
+ // This map holds the expunged descendents of a module
140
+ llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>>
141
+ expungedDescendents;
142
+ for (auto &expunging : this ->modules ) {
143
+ // Clear expungedDescendents
144
+ for (auto &it : expungedDescendents)
145
+ it.getSecond ().clear ();
146
+
147
+ auto expungingMod = allModules.lookup (expunging);
148
+ if (!expungingMod)
149
+ continue ; // Ignored missing modules
150
+ auto expungingModTy = expungingMod.getHWModuleType ();
151
+ auto expungingModPorts = expungingModTy.getPorts ();
152
+
153
+ auto createPortsOn = [&expungingModPorts](circt::hw::HWModuleOp mod,
154
+ const std::string &prefix,
155
+ auto genOutput, auto emitInput) {
156
+ mlir::OpBuilder builder (mod);
157
+ // Create ports using *REVERSE* direction of their definitions
158
+ for (auto &port : expungingModPorts) {
159
+ auto defaultName = prefix + port.name .getValue ();
160
+ auto finalName = defaultName;
161
+ if (port.dir == circt::hw::PortInfo::Input) {
162
+ auto val = genOutput (port);
163
+ assert (val.getType () == port.type );
164
+ mod.appendOutput (finalName, val);
165
+ } else if (port.dir == circt::hw::PortInfo::Output) {
166
+ auto [_, arg] = mod.appendInput (finalName, port.type );
167
+ emitInput (port, arg);
168
+ }
169
+ }
170
+ };
171
+
172
+ for (auto &processingRaw : revTopo) {
173
+ // Skip extmodule and intmodule because they cannot contain anything
174
+ if (!llvm::isa<circt::hw::HWModuleOp>(processingRaw))
175
+ continue ;
176
+ circt::hw::HWModuleOp processing =
177
+ llvm::cast<circt::hw::HWModuleOp>(processingRaw);
178
+
179
+ std::optional<decltype (expungedDescendents.lookup (" " )) *>
180
+ outerExpDescHold = {};
181
+ auto getOuterExpDesc = [&]() -> decltype (**outerExpDescHold) {
182
+ if (!outerExpDescHold.has_value ())
183
+ outerExpDescHold = {
184
+ &expungedDescendents.insert ({processing.getName (), {}})
185
+ .first ->getSecond ()};
186
+ return **outerExpDescHold;
187
+ };
188
+
189
+ mlir::OpBuilder outerBuilder (processing);
190
+
191
+ processing.walk ([&](circt::hw::InstanceOp inst) {
192
+ auto instName = inst.getInstanceName ();
193
+ auto instMod = allModules.lookup (inst.getModuleName ());
194
+
195
+ if (instMod.getOutputNames ().size () != inst.getResults ().size () ||
196
+ instMod.getNumInputPorts () != inst.getInputs ().size ()) {
197
+ // Module have been modified during this pass, create new instances
198
+ assert (instMod.getNumOutputPorts () >= inst.getResults ().size ());
199
+ assert (instMod.getNumInputPorts () >= inst.getInputs ().size ());
200
+
201
+ auto instModInTypes = instMod.getInputTypes ();
202
+
203
+ llvm::SmallVector<mlir::Value> newInputs;
204
+ newInputs.reserve (instMod.getNumInputPorts ());
205
+
206
+ outerBuilder.setInsertionPointAfter (inst);
207
+
208
+ // Appended inputs are at the end of the input list
209
+ for (size_t i = 0 ; i < instMod.getNumInputPorts (); ++i) {
210
+ mlir::Value input;
211
+ if (i < inst.getNumInputPorts ()) {
212
+ input = inst.getInputs ()[i];
213
+ if (auto existingName = inst.getInputName (i))
214
+ assert (existingName == instMod.getInputName (i));
215
+ } else {
216
+ input =
217
+ outerBuilder
218
+ .create <mlir::UnrealizedConversionCastOp>(
219
+ inst.getLoc (), instModInTypes[i], mlir::ValueRange{})
220
+ .getResult (0 );
221
+ }
222
+ newInputs.push_back (input);
223
+ }
224
+
225
+ auto newInst = outerBuilder.create <circt::hw::InstanceOp>(
226
+ inst.getLoc (), instMod, inst.getInstanceNameAttr (), newInputs,
227
+ inst.getParameters (),
228
+ inst.getInnerSym ().value_or <circt::hw::InnerSymAttr>({}));
229
+
230
+ for (size_t i = 0 ; i < inst.getNumResults (); ++i)
231
+ assert (inst.getOutputName (i) == instMod.getOutputName (i));
232
+ inst.replaceAllUsesWith (
233
+ newInst.getResults ().slice (0 , inst.getNumResults ()));
234
+ inst.erase ();
235
+ inst = newInst;
236
+ }
237
+
238
+ llvm::DenseMap<llvm::StringRef, mlir::Value> instOMap;
239
+ llvm::DenseMap<llvm::StringRef, mlir::Value> instIMap;
240
+ assert (instMod.getOutputNames ().size () == inst.getResults ().size ());
241
+ for (auto [oname, oval] :
242
+ llvm::zip (instMod.getOutputNames (), inst.getResults ()))
243
+ instOMap[llvm::cast<mlir::StringAttr>(oname).getValue ()] = oval;
244
+ assert (instMod.getInputNames ().size () == inst.getInputs ().size ());
245
+ for (auto [iname, ival] :
246
+ llvm::zip (instMod.getInputNames (), inst.getInputs ()))
247
+ instIMap[llvm::cast<mlir::StringAttr>(iname).getValue ()] = ival;
248
+
249
+ // Get outer expunged descendent first because it may modify the map and
250
+ // invalidate iterators.
251
+ auto &outerExpDesc = getOuterExpDesc ();
252
+ auto instExpDesc = expungedDescendents.find (inst.getModuleName ());
253
+
254
+ if (inst.getModuleName () == expunging) {
255
+ // Handle the directly expunged module
256
+ // input maps also useful for directly expunged instance
257
+
258
+ auto singletonPath = pathFactory.create (instName);
259
+
260
+ auto designatedPrefix =
261
+ designatedPrefixes.find ({processing.getName (), singletonPath});
262
+ std::string prefix = designatedPrefix != designatedPrefixes.end ()
263
+ ? designatedPrefix->getSecond ().str ()
264
+ : (instName + " __" ).str ();
265
+
266
+ // Port name collision is still possible, but current relying on MLIR
267
+ // to automatically rename input arguments.
268
+ // TODO: name collision detect
269
+
270
+ createPortsOn (
271
+ processing, prefix,
272
+ [&](circt::hw::ModulePort port) {
273
+ // Generate output for outer module, so input for us
274
+ return instIMap.at (port.name );
275
+ },
276
+ [&](circt::hw::ModulePort port, mlir::Value val) {
277
+ // Generated input for outer module, replace inst results
278
+ assert (instOMap.contains (port.name ));
279
+ instOMap[port.name ].replaceAllUsesWith (val);
280
+ });
281
+
282
+ outerExpDesc.emplace_back (singletonPath, prefix);
283
+
284
+ assert (instExpDesc == expungedDescendents.end () ||
285
+ instExpDesc->getSecond ().size () == 0 );
286
+ inst.erase ();
287
+ } else if (instExpDesc != expungedDescendents.end ()) {
288
+ // Handle all transitive descendents
289
+ if (instExpDesc->second .size () == 0 )
290
+ return ;
291
+ llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
292
+ for (const auto &exp : instExpDesc->second ) {
293
+ auto newPath = pathFactory.add (instName, exp .first );
294
+ auto designatedPrefix =
295
+ designatedPrefixes.find ({processing.getName (), newPath});
296
+ std::string prefix = designatedPrefix != designatedPrefixes.end ()
297
+ ? designatedPrefix->getSecond ().str ()
298
+ : defaultPrefix (newPath);
299
+
300
+ // TODO: name collision detect
301
+
302
+ createPortsOn (
303
+ processing, prefix,
304
+ [&](circt::hw::ModulePort port) {
305
+ // Generate output for outer module, directly forward from
306
+ // inner inst
307
+ return instOMap.at ((exp .second + port.name .getValue ()).str ());
308
+ },
309
+ [&](circt::hw::ModulePort port, mlir::Value val) {
310
+ // Generated input for outer module, replace inst results.
311
+ // The operand in question has to be an backedge
312
+ auto in =
313
+ instIMap.at ((exp .second + port.name .getValue ()).str ());
314
+ auto inDef = in.getDefiningOp ();
315
+ assert (llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
316
+ in.replaceAllUsesWith (val);
317
+ inDef->erase ();
318
+ });
319
+
320
+ outerExpDesc.emplace_back (newPath, prefix);
321
+ }
322
+ }
323
+ });
324
+ }
325
+ }
326
+ }
327
+
328
+ std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass () {
329
+ return std::make_unique<HWExpungeModulePass>();
330
+ }
0 commit comments