1
1
#include " internal.hh"
2
2
3
+ #include < algorithm>
4
+ #include < cstddef>
5
+ #include < deque>
6
+ #include < stdexcept>
7
+
3
8
namespace
4
9
{
5
10
using namespace rego ;
6
11
using namespace wf ::ops;
7
12
13
+ struct InitSide
14
+ {
15
+ std::set<Location> vars;
16
+ std::set<Location> inits;
17
+ };
18
+
8
19
struct InitInfo
9
20
{
10
- std::set<Location> lhs_vars;
11
- std::set<Location> rhs_vars;
21
+ std::size_t index;
22
+ InitSide lhs;
23
+ InitSide rhs;
12
24
};
13
25
14
26
Node to_init (
@@ -37,11 +49,12 @@ namespace
37
49
return LiteralInit << lhs_vars << rhs_vars << (AssignInfix << lhs << rhs);
38
50
}
39
51
40
- void vars_from (Node node, std::set<Location>& vars)
52
+ void inits_from (
53
+ Node node, const std::set<Location>& locals, std::set<Location>& inits)
41
54
{
42
- if (node->type () == Var)
55
+ if (node->type () == Var && contains (locals, node-> location ()) )
43
56
{
44
- vars .insert (node->location ());
57
+ inits .insert (node->location ());
45
58
return ;
46
59
}
47
60
@@ -78,13 +91,102 @@ namespace
78
91
79
92
for (Node child : *node)
80
93
{
81
- vars_from (child, vars);
94
+ inits_from (child, locals, inits);
95
+ }
96
+ }
97
+
98
+ void vars_from (
99
+ Node node, const std::set<Location>& locals, std::set<Location>& vars)
100
+ {
101
+ if (node->type () == Var && contains (locals, node->location ()))
102
+ {
103
+ vars.insert (node->location ());
104
+ }
105
+
106
+ for (Node child : *node)
107
+ {
108
+ vars_from (child, locals, vars);
109
+ }
110
+ }
111
+
112
+ InitSide side_from (Node node, const std::set<Location>& locals)
113
+ {
114
+ InitSide side;
115
+ inits_from (node, locals, side.inits );
116
+ vars_from (node, locals, side.vars );
117
+ return side;
118
+ }
119
+
120
+ bool any_compiler_inits (const InitSide& lhs)
121
+ {
122
+ return std::any_of (lhs.inits .begin (), lhs.inits .end (), [](auto & loc) {
123
+ std::string name = loc.str ();
124
+ return starts_with (name, " unify$" ) || starts_with (name, " out$" ) ||
125
+ starts_with (name, " value$" );
126
+ });
127
+ }
128
+
129
+ void remove_locals (
130
+ std::deque<InitInfo>& init_deque, const std::set<Location>& to_remove)
131
+ {
132
+ std::size_t count = init_deque.size ();
133
+ for (std::size_t i = 0 ; i < count; ++i)
134
+ {
135
+ InitInfo& init_stmt = init_deque.front ();
136
+ for (auto & loc : to_remove)
137
+ {
138
+ init_stmt.lhs .vars .erase (loc);
139
+ init_stmt.lhs .inits .erase (loc);
140
+ init_stmt.rhs .vars .erase (loc);
141
+ init_stmt.rhs .inits .erase (loc);
142
+ }
143
+ if (!init_stmt.lhs .inits .empty () || !init_stmt.rhs .inits .empty ())
144
+ {
145
+ init_deque.push_back (init_stmt);
146
+ }
147
+
148
+ init_deque.pop_front ();
82
149
}
83
150
}
84
151
152
+ std::vector<InitInfo> sort_init_stmts (
153
+ const std::set<Location>& locals, std::deque<InitInfo>& init_deque)
154
+ {
155
+ std::set<Location> initialized;
156
+ std::vector<InitInfo> init_stmts;
157
+ while (!init_deque.empty () && initialized != locals)
158
+ {
159
+ // find all strict init statements
160
+ auto it =
161
+ std::find_if (init_deque.begin (), init_deque.end (), [](auto & init_stmt) {
162
+ return init_stmt.lhs .vars .empty () || init_stmt.rhs .vars .empty ();
163
+ });
164
+
165
+ if (it == init_deque.end ())
166
+ {
167
+ // we have a cycle, so we use the first statement
168
+ it = init_deque.begin ();
169
+ init_stmts.push_back (*it);
170
+ }
171
+ else
172
+ {
173
+ init_stmts.push_back (*it);
174
+ }
175
+
176
+ std::set<Location> to_remove;
177
+ to_remove.insert (it->lhs .inits .begin (), it->lhs .inits .end ());
178
+ to_remove.insert (it->rhs .inits .begin (), it->rhs .inits .end ());
179
+ init_deque.erase (it);
180
+ remove_locals (init_deque, to_remove);
181
+ initialized.insert (to_remove.begin (), to_remove.end ());
182
+ }
183
+
184
+ return init_stmts;
185
+ }
186
+
85
187
void find_init_stmts (Node unifybody, std::set<Location>& locals)
86
188
{
87
- // gather all locals
189
+ std::deque<InitInfo> potential_init_stmts;
88
190
for (std::size_t i = 0 ; i < unifybody->size (); ++i)
89
191
{
90
192
Node stmt = unifybody->at (i);
@@ -95,15 +197,6 @@ namespace
95
197
else if (stmt->type () == LiteralEnum)
96
198
{
97
199
locals.erase ((stmt / Item)->location ());
98
- find_init_stmts (stmt / UnifyBody, locals);
99
- }
100
- else if (stmt->type () == LiteralWith)
101
- {
102
- find_init_stmts (stmt / UnifyBody, locals);
103
- }
104
- else if (stmt->type () == LiteralNot)
105
- {
106
- find_init_stmts (stmt / UnifyBody, locals);
107
200
}
108
201
else if (stmt->type () == Literal)
109
202
{
@@ -115,42 +208,65 @@ namespace
115
208
116
209
Node lhs = expr->front ();
117
210
Node rhs = expr->back ();
118
- std::set<Location> lhs_vars;
119
- vars_from (lhs, lhs_vars);
120
- std::set<Location> lhs_found;
121
- std::set_intersection (
122
- lhs_vars.begin (),
123
- lhs_vars.end (),
124
- locals.begin (),
125
- locals.end (),
126
- std::inserter (lhs_found, lhs_found.begin ()));
127
-
128
- std::set<Location> rhs_vars;
129
- vars_from (rhs, rhs_vars);
130
- std::set<Location> rhs_found;
131
- std::set_intersection (
132
- rhs_vars.begin (),
133
- rhs_vars.end (),
134
- locals.begin (),
135
- locals.end (),
136
- std::inserter (rhs_found, rhs_found.begin ()));
137
-
138
- if (lhs_found.empty () && rhs_found.empty ())
139
- {
140
- continue ;
141
- }
142
211
143
- for (auto & loc : lhs_found)
212
+ InitSide lhs_side = side_from (lhs, locals);
213
+ InitSide rhs_side = side_from (rhs, locals);
214
+
215
+ if (any_compiler_inits (lhs_side))
144
216
{
145
- locals.erase (loc);
217
+ // compiler statements will never be right-assign, so we can
218
+ // use this fact later to help resolve some ambiguities
219
+ rhs_side.inits .clear ();
146
220
}
147
221
148
- for ( auto & loc : rhs_found )
222
+ if (lhs_side. inits . empty () && rhs_side. inits . empty () )
149
223
{
150
- locals. erase (loc) ;
224
+ continue ;
151
225
}
152
226
153
- unifybody->replace_at (i, to_init (lhs, lhs_found, rhs, rhs_found));
227
+ potential_init_stmts.push_back ({i, lhs_side, rhs_side});
228
+ }
229
+ }
230
+
231
+ std::vector<InitInfo> init_stmts =
232
+ sort_init_stmts (locals, potential_init_stmts);
233
+ for (std::size_t i = 0 ; i < init_stmts.size (); ++i)
234
+ {
235
+ InitInfo& init_stmt = init_stmts[i];
236
+ Node expr = unifybody->at (init_stmt.index )->front ()->front ();
237
+
238
+ Node lhs = expr->front ();
239
+ Node rhs = expr->back ();
240
+
241
+ for (auto & loc : init_stmt.lhs .inits )
242
+ {
243
+ locals.erase (loc);
244
+ }
245
+
246
+ for (auto & loc : init_stmt.rhs .inits )
247
+ {
248
+ locals.erase (loc);
249
+ }
250
+
251
+ unifybody->replace_at (
252
+ init_stmt.index ,
253
+ to_init (lhs, init_stmt.lhs .inits , rhs, init_stmt.rhs .inits ));
254
+ }
255
+
256
+ // where appropriate, recurse with the updated locals
257
+ for (Node stmt : *unifybody)
258
+ {
259
+ if (stmt->type () == LiteralEnum)
260
+ {
261
+ find_init_stmts (stmt / UnifyBody, locals);
262
+ }
263
+ else if (stmt->type () == LiteralWith)
264
+ {
265
+ find_init_stmts (stmt / UnifyBody, locals);
266
+ }
267
+ else if (stmt->type () == LiteralNot)
268
+ {
269
+ find_init_stmts (stmt / UnifyBody, locals);
154
270
}
155
271
}
156
272
}
0 commit comments