1
+ #include " circt/Dialect/HW/HWOps.h"
2
+ #include " circt/Dialect/HW/HWPasses.h"
3
+ #include " circt/Dialect/HW/HWTypes.h"
4
+ #include " circt/Dialect/HW/HWInstanceGraph.h"
5
+ #include " mlir/Pass/Pass.h"
6
+ #include " llvm/ADT/DenseMap.h"
7
+ #include " llvm/ADT/ImmutableList.h"
8
+ #include " llvm/ADT/PostOrderIterator.h"
9
+ #include " llvm/Support/Regex.h"
10
+ #include < numeric>
11
+
12
+ namespace circt {
13
+ namespace hw {
14
+ #define GEN_PASS_DEF_HWEXPUNGEMODULE
15
+ #include " circt/Dialect/HW/Passes.h.inc"
16
+ } // namespace hw
17
+ } // namespace circt
18
+
19
+ namespace {
20
+ struct HWExpungeModulePass
21
+ : circt::hw::impl::HWExpungeModuleBase<HWExpungeModulePass> {
22
+ void runOnOperation () override ;
23
+ };
24
+
25
+ struct InstPathSeg {
26
+ llvm::StringRef seg;
27
+
28
+ InstPathSeg (llvm::StringRef seg) : seg(seg) {}
29
+ const llvm::StringRef &getSeg () const { return seg; }
30
+ operator llvm::StringRef () const { return seg; }
31
+
32
+ void Profile (llvm::FoldingSetNodeID &ID) const { ID.AddString (seg); }
33
+ };
34
+ using InstPath = llvm::ImmutableList<InstPathSeg>;
35
+ std::string defaultPrefix (InstPath path) {
36
+ std::string accum;
37
+ while (!path.isEmpty ()) {
38
+ accum += path.getHead ().getSeg ();
39
+ accum += " _" ;
40
+ path = path.getTail ();
41
+ }
42
+ accum += " _" ;
43
+ return std::move (accum);
44
+ }
45
+
46
+ // The regex for port prefix specification
47
+ // "^([@#a-zA-Z0-9_]+):([a-zA-Z0-9_]+)(\\.[a-zA-Z0-9_]+)*=([a-zA-Z0-9_]+)$"
48
+ // Unfortunately, the LLVM Regex cannot capture repeating capture groups, so
49
+ // manually parse the spec This parser may accept identifiers with invalid
50
+ // characters
51
+
52
+ std::variant<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>,
53
+ std::string>
54
+ parsePrefixSpec (llvm::StringRef in, InstPath::Factory &listFac) {
55
+ auto [l, r] = in.split (" =" );
56
+ if (r == " " )
57
+ return " No '=' found in input" ;
58
+ auto [ll, lr] = l.split (" :" );
59
+ if (lr == " " )
60
+ return " No ':' found before '='" ;
61
+ llvm::SmallVector<llvm::StringRef, 4 > segs;
62
+ while (lr != " " ) {
63
+ auto [seg, rest] = lr.split (" ." );
64
+ segs.push_back (seg);
65
+ lr = rest;
66
+ }
67
+ InstPath path;
68
+ for (auto &seg : llvm::reverse (segs))
69
+ path = listFac.add (seg, path);
70
+ return std::make_tuple (ll, path, r);
71
+ }
72
+ } // namespace
73
+
74
+ void HWExpungeModulePass::runOnOperation () {
75
+ auto root = getOperation ();
76
+ llvm::DenseMap<mlir::StringRef, circt::hw::HWModuleLike> allModules;
77
+ root.walk (
78
+ [&](circt::hw::HWModuleLike mod) { allModules[mod.getName ()] = mod; });
79
+
80
+ // The instance graph. We only use this graph to traverse the hierarchy in post order.
81
+ // The order does not change throught out the operation, onlygets weakened, but still valid.
82
+ // So we keep this cached instance graph throughout the pass.
83
+ auto &instanceGraph = getAnalysis<circt::hw::InstanceGraph>();
84
+
85
+ // Instance path.
86
+ InstPath::Factory pathFactory;
87
+
88
+ // Process port prefix specifications
89
+ // (Module name, Instance path) -> Prefix
90
+ llvm::DenseMap<std::pair<mlir::StringRef, InstPath>, mlir::StringRef>
91
+ designatedPrefixes;
92
+ bool containsFailure = false ;
93
+ for (const auto &raw : portPrefixes) {
94
+ auto matched = parsePrefixSpec (raw, pathFactory);
95
+ if (std::holds_alternative<std::string>(matched)) {
96
+ llvm::errs () << " Invalid port prefix specification: " << raw << " \n " ;
97
+ llvm::errs () << " Error: " << std::get<std::string>(matched) << " \n " ;
98
+ containsFailure = true ;
99
+ continue ;
100
+ }
101
+
102
+ auto [module, path, prefix] =
103
+ std::get<std::tuple<llvm::StringRef, InstPath, llvm::StringRef>>(
104
+ matched);
105
+ if (!allModules.contains (module)) {
106
+ llvm::errs () << " Module not found in port prefix specification: "
107
+ << module << " \n " ;
108
+ llvm::errs () << " From specification: " << raw << " \n " ;
109
+ containsFailure = true ;
110
+ continue ;
111
+ }
112
+
113
+ // Skip checking instance paths' existence. Non-existent paths are ignored
114
+ designatedPrefixes.insert ({{module, path}, prefix});
115
+ }
116
+
117
+ if (containsFailure)
118
+ return signalPassFailure ();
119
+
120
+ // Instance path * prefix name
121
+ using ReplacedDescendent = std::pair<InstPath, std::string>;
122
+ // This map holds the expunged descendents of a module
123
+ llvm::DenseMap<llvm::StringRef, llvm::SmallVector<ReplacedDescendent>>
124
+ expungedDescendents;
125
+ for (auto &expunging : this ->modules ) {
126
+ // Clear expungedDescendents
127
+ for (auto &it : expungedDescendents)
128
+ it.getSecond ().clear ();
129
+
130
+ auto expungingMod = allModules.lookup (expunging);
131
+ if (!expungingMod)
132
+ continue ; // Ignored missing modules
133
+ auto expungingModTy = expungingMod.getHWModuleType ();
134
+ auto expungingModPorts = expungingModTy.getPorts ();
135
+
136
+ auto createPortsOn = [&expungingModPorts](circt::hw::HWModuleOp mod,
137
+ const std::string &prefix,
138
+ auto genOutput, auto emitInput) {
139
+ mlir::OpBuilder builder (mod);
140
+ // Create ports using *REVERSE* direction of their definitions
141
+ for (auto &port : expungingModPorts) {
142
+ auto defaultName = prefix + port.name .getValue ();
143
+ auto finalName = defaultName;
144
+ if (port.dir == circt::hw::PortInfo::Input) {
145
+ auto val = genOutput (port);
146
+ assert (val.getType () == port.type );
147
+ mod.appendOutput (finalName, val);
148
+ } else if (port.dir == circt::hw::PortInfo::Output) {
149
+ auto [_, arg] = mod.appendInput (finalName, port.type );
150
+ emitInput (port, arg);
151
+ }
152
+ }
153
+ };
154
+
155
+ for (auto &instGraphNode : llvm::post_order (&instanceGraph)) {
156
+ // Skip extmodule and intmodule because they cannot contain anything
157
+ circt::hw::HWModuleOp processing =
158
+ llvm::dyn_cast_if_present<circt::hw::HWModuleOp>(
159
+ instGraphNode->getModule ().getOperation ());
160
+ if (!processing)
161
+ continue ;
162
+
163
+ std::optional<decltype (expungedDescendents.lookup (" " )) *>
164
+ outerExpDescHold = {};
165
+ auto getOuterExpDesc = [&]() -> decltype (**outerExpDescHold) {
166
+ if (!outerExpDescHold.has_value ())
167
+ outerExpDescHold = {
168
+ &expungedDescendents.insert ({processing.getName (), {}})
169
+ .first ->getSecond ()};
170
+ return **outerExpDescHold;
171
+ };
172
+
173
+ mlir::OpBuilder outerBuilder (processing);
174
+
175
+ processing.walk ([&](circt::hw::InstanceOp inst) {
176
+ auto instName = inst.getInstanceName ();
177
+ auto instMod = allModules.lookup (inst.getModuleName ());
178
+
179
+ if (instMod.getOutputNames ().size () != inst.getResults ().size () ||
180
+ instMod.getNumInputPorts () != inst.getInputs ().size ()) {
181
+ // Module have been modified during this pass, create new instances
182
+ assert (instMod.getNumOutputPorts () >= inst.getResults ().size ());
183
+ assert (instMod.getNumInputPorts () >= inst.getInputs ().size ());
184
+
185
+ auto instModInTypes = instMod.getInputTypes ();
186
+
187
+ llvm::SmallVector<mlir::Value> newInputs;
188
+ newInputs.reserve (instMod.getNumInputPorts ());
189
+
190
+ outerBuilder.setInsertionPointAfter (inst);
191
+
192
+ // Appended inputs are at the end of the input list
193
+ for (size_t i = 0 ; i < instMod.getNumInputPorts (); ++i) {
194
+ mlir::Value input;
195
+ if (i < inst.getNumInputPorts ()) {
196
+ input = inst.getInputs ()[i];
197
+ if (auto existingName = inst.getInputName (i))
198
+ assert (existingName == instMod.getInputName (i));
199
+ } else {
200
+ input =
201
+ outerBuilder
202
+ .create <mlir::UnrealizedConversionCastOp>(
203
+ inst.getLoc (), instModInTypes[i], mlir::ValueRange{})
204
+ .getResult (0 );
205
+ }
206
+ newInputs.push_back (input);
207
+ }
208
+
209
+ auto newInst = outerBuilder.create <circt::hw::InstanceOp>(
210
+ inst.getLoc (), instMod, inst.getInstanceNameAttr (), newInputs,
211
+ inst.getParameters (),
212
+ inst.getInnerSym ().value_or <circt::hw::InnerSymAttr>({}));
213
+
214
+ for (size_t i = 0 ; i < inst.getNumResults (); ++i)
215
+ assert (inst.getOutputName (i) == instMod.getOutputName (i));
216
+ inst.replaceAllUsesWith (
217
+ newInst.getResults ().slice (0 , inst.getNumResults ()));
218
+ inst.erase ();
219
+ inst = newInst;
220
+ }
221
+
222
+ llvm::StringMap<mlir::Value> instOMap;
223
+ llvm::StringMap<mlir::Value> instIMap;
224
+ assert (instMod.getOutputNames ().size () == inst.getResults ().size ());
225
+ for (auto [oname, oval] :
226
+ llvm::zip (instMod.getOutputNames (), inst.getResults ()))
227
+ instOMap[llvm::cast<mlir::StringAttr>(oname).getValue ()] = oval;
228
+ assert (instMod.getInputNames ().size () == inst.getInputs ().size ());
229
+ for (auto [iname, ival] :
230
+ llvm::zip (instMod.getInputNames (), inst.getInputs ()))
231
+ instIMap[llvm::cast<mlir::StringAttr>(iname).getValue ()] = ival;
232
+
233
+ // Get outer expunged descendent first because it may modify the map and
234
+ // invalidate iterators.
235
+ auto &outerExpDesc = getOuterExpDesc ();
236
+ auto instExpDesc = expungedDescendents.find (inst.getModuleName ());
237
+
238
+ if (inst.getModuleName () == expunging) {
239
+ // Handle the directly expunged module
240
+ // input maps also useful for directly expunged instance
241
+
242
+ auto singletonPath = pathFactory.create (instName);
243
+
244
+ auto designatedPrefix =
245
+ designatedPrefixes.find ({processing.getName (), singletonPath});
246
+ std::string prefix = designatedPrefix != designatedPrefixes.end ()
247
+ ? designatedPrefix->getSecond ().str ()
248
+ : (instName + " __" ).str ();
249
+
250
+ // Port name collision is still possible, but current relying on MLIR
251
+ // to automatically rename input arguments.
252
+ // TODO: name collision detect
253
+
254
+ createPortsOn (
255
+ processing, prefix,
256
+ [&](circt::hw::ModulePort port) {
257
+ // Generate output for outer module, so input for us
258
+ return instIMap.at (port.name );
259
+ },
260
+ [&](circt::hw::ModulePort port, mlir::Value val) {
261
+ // Generated input for outer module, replace inst results
262
+ assert (instOMap.contains (port.name ));
263
+ instOMap[port.name ].replaceAllUsesWith (val);
264
+ });
265
+
266
+ outerExpDesc.emplace_back (singletonPath, prefix);
267
+
268
+ assert (instExpDesc == expungedDescendents.end () ||
269
+ instExpDesc->getSecond ().size () == 0 );
270
+ inst.erase ();
271
+ } else if (instExpDesc != expungedDescendents.end ()) {
272
+ // Handle all transitive descendents
273
+ if (instExpDesc->second .size () == 0 )
274
+ return ;
275
+ llvm::DenseMap<llvm::StringRef, mlir::Value> newInputs;
276
+ for (const auto &exp : instExpDesc->second ) {
277
+ auto newPath = pathFactory.add (instName, exp .first );
278
+ auto designatedPrefix =
279
+ designatedPrefixes.find ({processing.getName (), newPath});
280
+ std::string prefix = designatedPrefix != designatedPrefixes.end ()
281
+ ? designatedPrefix->getSecond ().str ()
282
+ : defaultPrefix (newPath);
283
+
284
+ // TODO: name collision detect
285
+
286
+ createPortsOn (
287
+ processing, prefix,
288
+ [&](circt::hw::ModulePort port) {
289
+ // Generate output for outer module, directly forward from
290
+ // inner inst
291
+ return instOMap.at ((exp .second + port.name .getValue ()).str ());
292
+ },
293
+ [&](circt::hw::ModulePort port, mlir::Value val) {
294
+ // Generated input for outer module, replace inst results.
295
+ // The operand in question has to be an backedge
296
+ auto in =
297
+ instIMap.at ((exp .second + port.name .getValue ()).str ());
298
+ auto inDef = in.getDefiningOp ();
299
+ assert (llvm::isa<mlir::UnrealizedConversionCastOp>(inDef));
300
+ in.replaceAllUsesWith (val);
301
+ inDef->erase ();
302
+ });
303
+
304
+ outerExpDesc.emplace_back (newPath, prefix);
305
+ }
306
+ }
307
+ });
308
+ }
309
+ }
310
+ }
311
+
312
+ std::unique_ptr<mlir::Pass> circt::hw::createHWExpungeModulePass () {
313
+ return std::make_unique<HWExpungeModulePass>();
314
+ }
0 commit comments