1
+ #include " circt/Dialect/HW/HWInstanceGraph.h"
2
+ #include " circt/Dialect/HW/HWOps.h"
3
+ #include " circt/Dialect/HW/HWPasses.h"
4
+ #include " circt/Dialect/HW/HWTypes.h"
5
+ #include " mlir/Pass/Pass.h"
6
+ #include " llvm/ADT/DenseMap.h"
7
+ #include " llvm/ADT/ImmutableList.h"
8
+ #include " llvm/ADT/PostOrderIterator.h"
9
+ #include " llvm/Support/Regex.h"
10
+ #include < numeric>
11
+
12
+ namespace circt {
13
+ namespace hw {
14
+ #define GEN_PASS_DEF_HWEXPUNGEMODULE
15
+ #include " circt/Dialect/HW/Passes.h.inc"
16
+ } // namespace hw
17
+ } // namespace circt
18
+
19
+ namespace {
20
+ struct HWExpungeModulePass
21
+ : circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
22
+ void runOnOperation () override ;
23
+ };
24
+
25
+ struct InstPathSeg {
26
+ llvm::StringRef seg;
27
+
28
+ InstPathSeg (llvm::StringRef seg) : seg(seg) {}
29
+ const llvm::StringRef &getSeg () const { return seg; }
30
+ operator llvm::StringRef () const { return seg; }
31
+
32
+ void Profile (llvm::FoldingSetNodeID &ID) const { ID.AddString (seg); }
33
+ };
34
+ using InstPath = llvm::ImmutableList<InstPathSeg>;
35
+ std::string defaultPrefix (InstPath path) {
36
+ std::string accum;
37
+ while (!path.isEmpty ()) {
38
+ accum += path.getHead ().getSeg ();
39
+ accum += " _" ;
40
+ path = path.getTail ();
41
+ }
42
+ accum += " _" ;
43
+ return std::move (accum);
44
+ }
45
+
46
+ // The regex for port prefix specification
47
+ // "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
48
+ // Unfortunately, the LLVM Regex cannot capture repeating capture groups, so
49
+ // manually parse the spec This parser may accept identifiers with invalid
50
+ // characters
51
+
52
+ std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>,
53
+ std::string>
54
+ parsePrefixSpec (llvm::StringRef in, InstPath::Factory &listFac) {
55
+ auto [l, r] = in.split (" =" );
56
+ if (r == " " )
57
+ return " No '=' found in input" ;
58
+ auto [ll, lr] = l.split (" :" );
59
+ if (lr == " " )
60
+ return " No ':' found before '='" ;
61
+ llvm::SmallVector<llvm::StringRef, 4 > segs;
62
+ while (lr != " " ) {
63
+ auto [seg, rest] = lr.split (" ." );
64
+ segs.push_back (seg);
65
+ lr = rest;
66
+ }
67
+ InstPath path;
68
+ for (auto &seg : llvm::reverse (segs))
69
+ path = listFac.add (seg, path);
70
+ return std::make_tuple (ll, path, r);
71
+ }
72
+ } // namespace
73
+
74
+ void HWExpungeModulePass::runOnOperation () {
75
+ auto root = getOperation ();
76
+ llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
77
+ root.walk (
78
+ [&](circt::hw::HWModuleLike mod) { allModules[mod.getName ()] = mod; });
79
+
80
+ // The instance graph. We only use this graph to traverse the hierarchy in
81
+ // post order. The order does not change throught out the operation, onlygets
82
+ // weakened, but still valid. So we keep this cached instance graph throughout
83
+ // the pass.
84
+ auto &instanceGraph = getAnalysis<circt::hw::InstanceGraph>();
85
+
86
+ // Instance path.
87
+ InstPath::Factory pathFactory;
88
+
89
+ // Process port prefix specifications
90
+ // (Module name, Instance path) -> Prefix
91
+ llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef>
92
+ designatedPrefixes;
93
+ bool containsFailure = false ;
94
+ for (const auto &raw : portPrefixes) {
95
+ auto matched = parsePrefixSpec (raw, pathFactory);
96
+ if (std::holds_alternative<std::string>(matched)) {
97
+ llvm::errs () << " Invalid port prefix specification: " << raw << " \n " ;
98
+ llvm::errs () << " Error: " << std::get<std::string>(matched) << " \n " ;
99
+ containsFailure = true ;
100
+ continue ;
101
+ }
102
+
103
+ auto [module, path, prefix] =
104
+ std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(
105
+ matched);
106
+ if (!allModules.contains (module)) {
107
+ llvm::errs () << " Module not found in port prefix specification: "
108
+ << module << " \n " ;
109
+ llvm::errs () << " From specification: " << raw << " \n " ;
110
+ containsFailure = true ;
111
+ continue ;
112
+ }
113
+
114
+ // Skip checking instance paths' existence. Non-existent paths are ignored
115
+ designatedPrefixes.insert ({{module, path}, prefix});
116
+ }
117
+
118
+ if (containsFailure)
119
+ return signalPassFailure ();
120
+
121
+ // Instance path * prefix name
122
+ using ReplacedDescendent = std::pair<InstPath, std::string>;
123
+ // This map holds the expunged descendents of a module
124
+ llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>>
125
+ expungedDescendents;
126
+ for (auto &expunging : this ->modules ) {
127
+ // Clear expungedDescendents
128
+ for (auto &it : expungedDescendents)
129
+ it.getSecond ().clear ();
130
+
131
+ auto expungingMod = allModules.lookup (expunging);
132
+ if (!expungingMod)
133
+ continue ; // Ignored missing modules
134
+ auto expungingModTy = expungingMod.getHWModuleType ();
135
+ auto expungingModPorts = expungingModTy.getPorts ();
136
+
137
+ auto createPortsOn = [&expungingModPorts](circt::hw::HWModuleOp mod,
138
+ const std::string &prefix,
139
+ auto genOutput, auto emitInput) {
140
+ mlir::OpBuilder builder (mod);
141
+ // Create ports using *REVERSE* direction of their definitions
142
+ for (auto &port : expungingModPorts) {
143
+ auto defaultName = prefix + port.name .getValue ();
144
+ auto finalName = defaultName;
145
+ if (port.dir == circt::hw::PortInfo::Input) {
146
+ auto val = genOutput (port);
147
+ assert (val.getType () == port.type );
148
+ mod.appendOutput (finalName, val);
149
+ } else if (port.dir == circt::hw::PortInfo::Output) {
150
+ auto [_, arg] = mod.appendInput (finalName, port.type );
151
+ emitInput (port, arg);
152
+ }
153
+ }
154
+ };
155
+
156
+ for (auto &instGraphNode : llvm::post_order (&instanceGraph)) {
157
+ // Skip extmodule and intmodule because they cannot contain anything
158
+ circt::hw::HWModuleOp processing =
159
+ llvm::dyn_cast_if_present<circt::hw::HWModuleOp>(
160
+ instGraphNode->getModule ().getOperation ());
161
+ if (!processing)
162
+ continue ;
163
+
164
+ std::optional<decltype (expungedDescendents.lookup (" " )) *>
165
+ outerExpDescHold = {};
166
+ auto getOuterExpDesc = [&]() -> decltype (**outerExpDescHold) {
167
+ if (!outerExpDescHold.has_value ())
168
+ outerExpDescHold = {
169
+ &expungedDescendents.insert ({processing.getName (), {}})
170
+ .first ->getSecond ()};
171
+ return **outerExpDescHold;
172
+ };
173
+
174
+ mlir::OpBuilder outerBuilder (processing);
175
+
176
+ processing.walk ([&](circt::hw::InstanceOp inst) {
177
+ auto instName = inst.getInstanceName ();
178
+ auto instMod = allModules.lookup (inst.getModuleName ());
179
+
180
+ if (instMod.getOutputNames ().size () != inst.getResults ().size () ||
181
+ instMod.getNumInputPorts () != inst.getInputs ().size ()) {
182
+ // Module have been modified during this pass, create new instances
183
+ assert (instMod.getNumOutputPorts () >= inst.getResults ().size ());
184
+ assert (instMod.getNumInputPorts () >= inst.getInputs ().size ());
185
+
186
+ auto instModInTypes = instMod.getInputTypes ();
187
+
188
+ llvm::SmallVector<mlir::Value> newInputs;
189
+ newInputs.reserve (instMod.getNumInputPorts ());
190
+
191
+ outerBuilder.setInsertionPointAfter (inst);
192
+
193
+ // Appended inputs are at the end of the input list
194
+ for (size_t i = 0 ; i < instMod.getNumInputPorts (); ++i) {
195
+ mlir::Value input;
196
+ if (i < inst.getNumInputPorts ()) {
197
+ input = inst.getInputs ()[i];
198
+ if (auto existingName = inst.getInputName (i))
199
+ assert (existingName == instMod.getInputName (i));
200
+ } else {
201
+ input =
202
+ outerBuilder
203
+ .create <mlir::UnrealizedConversionCastOp>(
204
+ inst.getLoc (), instModInTypes[i], mlir::ValueRange{})
205
+ .getResult (0 );
206
+ }
207
+ newInputs.push_back (input);
208
+ }
209
+
210
+ auto newInst = outerBuilder.create <circt::hw::InstanceOp>(
211
+ inst.getLoc (), instMod, inst.getInstanceNameAttr (), newInputs,
212
+ inst.getParameters (),
213
+ inst.getInnerSym ().value_or <circt::hw::InnerSymAttr>({}));
214
+
215
+ for (size_t i = 0 ; i < inst.getNumResults (); ++i)
216
+ assert (inst.getOutputName (i) == instMod.getOutputName (i));
217
+ inst.replaceAllUsesWith (
218
+ newInst.getResults ().slice (0 , inst.getNumResults ()));
219
+ inst.erase ();
220
+ inst = newInst;
221
+ }
222
+
223
+ llvm::StringMap<mlir::Value> instOMap;
224
+ llvm::StringMap<mlir::Value> instIMap;
225
+ assert (instMod.getOutputNames ().size () == inst.getResults ().size ());
226
+ for (auto [oname, oval] :
227
+ llvm::zip (instMod.getOutputNames (), inst.getResults ()))
228
+ instOMap[llvm::cast<mlir::StringAttr>(oname).getValue ()] = oval;
229
+ assert (instMod.getInputNames ().size () == inst.getInputs ().size ());
230
+ for (auto [iname, ival] :
231
+ llvm::zip (instMod.getInputNames (), inst.getInputs ()))
232
+ instIMap[llvm::cast<mlir::StringAttr>(iname).getValue ()] = ival;
233
+
234
+ // Get outer expunged descendent first because it may modify the map and
235
+ // invalidate iterators.
236
+ auto &outerExpDesc = getOuterExpDesc ();
237
+ auto instExpDesc = expungedDescendents.find (inst.getModuleName ());
238
+
239
+ if (inst.getModuleName () == expunging) {
240
+ // Handle the directly expunged module
241
+ // input maps also useful for directly expunged instance
242
+
243
+ auto singletonPath = pathFactory.create (instName);
244
+
245
+ auto designatedPrefix =
246
+ designatedPrefixes.find ({processing.getName (), singletonPath});
247
+ std::string prefix = designatedPrefix != designatedPrefixes.end ()
248
+ ? designatedPrefix->getSecond ().str ()
249
+ : (instName + " __" ).str ();
250
+
251
+ // Port name collision is still possible, but current relying on MLIR
252
+ // to automatically rename input arguments.
253
+ // TODO: name collision detect
254
+
255
+ createPortsOn (
256
+ processing, prefix,
257
+ [&](circt::hw::ModulePort port) {
258
+ // Generate output for outer module, so input for us
259
+ return instIMap.at (port.name );
260
+ },
261
+ [&](circt::hw::ModulePort port, mlir::Value val) {
262
+ // Generated input for outer module, replace inst results
263
+ assert (instOMap.contains (port.name ));
264
+ instOMap[port.name ].replaceAllUsesWith (val);
265
+ });
266
+
267
+ outerExpDesc.emplace_back (singletonPath, prefix);
268
+
269
+ assert (instExpDesc == expungedDescendents.end () ||
270
+ instExpDesc->getSecond ().size () == 0 );
271
+ inst.erase ();
272
+ } else if (instExpDesc != expungedDescendents.end ()) {
273
+ // Handle all transitive descendents
274
+ if (instExpDesc->second .size () == 0 )
275
+ return ;
276
+ llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
277
+ for (const auto &exp : instExpDesc->second ) {
278
+ auto newPath = pathFactory.add (instName, exp .first );
279
+ auto designatedPrefix =
280
+ designatedPrefixes.find ({processing.getName (), newPath});
281
+ std::string prefix = designatedPrefix != designatedPrefixes.end ()
282
+ ? designatedPrefix->getSecond ().str ()
283
+ : defaultPrefix (newPath);
284
+
285
+ // TODO: name collision detect
286
+
287
+ createPortsOn (
288
+ processing, prefix,
289
+ [&](circt::hw::ModulePort port) {
290
+ // Generate output for outer module, directly forward from
291
+ // inner inst
292
+ return instOMap.at ((exp .second + port.name .getValue ()).str ());
293
+ },
294
+ [&](circt::hw::ModulePort port, mlir::Value val) {
295
+ // Generated input for outer module, replace inst results.
296
+ // The operand in question has to be an backedge
297
+ auto in =
298
+ instIMap.at ((exp .second + port.name .getValue ()).str ());
299
+ auto inDef = in.getDefiningOp ();
300
+ assert (llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
301
+ in.replaceAllUsesWith (val);
302
+ inDef->erase ();
303
+ });
304
+
305
+ outerExpDesc.emplace_back (newPath, prefix);
306
+ }
307
+ }
308
+ });
309
+ }
310
+ }
311
+ }
312
+
313
+ std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass () {
314
+ return std::make_unique<HWExpungeModulePass>();
315
+ }
0 commit comments