Skip to content

Commit 6d84fc7

Browse files
authored
Fixes for regressions in local variable behavior. (#98)
As a result of the optimization passes, some incorrect behavior was introduced for local variables. This commit restores correct functionality. It also fixes a crash when accessing certain nodes with an undefined key. Signed-off-by: Matthew Johnson <[email protected]>
1 parent ebda1aa commit 6d84fc7

File tree

12 files changed

+227
-51
lines changed

12 files changed

+227
-51
lines changed

CHANGELOG

+15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
# Changelog
22

3+
## 2024-01-19 - Version 0.3.11
4+
Minor improvements and bug fixes.
5+
6+
**New Features**
7+
- Updated to a more recent Trieste version
8+
- More sophisticated logging
9+
10+
**Bug fixes**
11+
- Comprehensions over local variables were not properly capturing the local (regression due to optimization)
12+
- Local variable initializations were order-dependent (regression due to optimization)
13+
- In some circumstances, indexing the data object with an undefined key caused a segfault.
14+
15+
**Other**
16+
- Various CI changes due to issues with GitHub Actions.
17+
318
## 2023-09-21 - Version 0.3.10
419
Instrumentation and optimization.
520

VERSION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0.3.10
1+
0.3.11

examples/rust/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ edition = "2021"
66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
77

88
[dependencies]
9-
regorust = "0.3.10"
9+
regorust = "0.3.11"
1010
clap = { version = "4.0", features = ["derive"] }

src/internal.hh

+1
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ namespace rego
146146
PassDef explicit_enums();
147147
PassDef body_locals(const BuiltIns& builtins);
148148
PassDef value_locals(const BuiltIns& builtins);
149+
PassDef compr_locals(const BuiltIns& builtins);
149150
PassDef rules_to_compr();
150151
PassDef compr();
151152
PassDef absolute_refs();

src/passes/init.cc

+161-45
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
11
#include "internal.hh"
22

3+
#include <algorithm>
4+
#include <cstddef>
5+
#include <deque>
6+
#include <stdexcept>
7+
38
namespace
49
{
510
using namespace rego;
611
using namespace wf::ops;
712

13+
// The local-variable locations found on one operand (lhs or rhs) of a
// candidate init statement. Both sets are populated by side_from().
struct InitSide
14+
{
15+
// Filled by vars_from(): every Var in the operand's subtree whose location
// is one of the tracked locals.
std::set<Location> vars;
16+
// Filled by inits_from(); these are the locations that get erased from the
// set of locals and handed to to_init() when the statement is chosen.
// NOTE(review): presumably the locals this operand can initialize -- confirm
// against the full body of inits_from().
std::set<Location> inits;
17+
};
18+
819
struct InitInfo
920
{
10-
std::set<Location> lhs_vars;
11-
std::set<Location> rhs_vars;
21+
std::size_t index;
22+
InitSide lhs;
23+
InitSide rhs;
1224
};
1325

1426
Node to_init(
@@ -37,11 +49,12 @@ namespace
3749
return LiteralInit << lhs_vars << rhs_vars << (AssignInfix << lhs << rhs);
3850
}
3951

40-
void vars_from(Node node, std::set<Location>& vars)
52+
void inits_from(
53+
Node node, const std::set<Location>& locals, std::set<Location>& inits)
4154
{
42-
if (node->type() == Var)
55+
if (node->type() == Var && contains(locals, node->location()))
4356
{
44-
vars.insert(node->location());
57+
inits.insert(node->location());
4558
return;
4659
}
4760

@@ -78,13 +91,102 @@ namespace
7891

7992
for (Node child : *node)
8093
{
81-
vars_from(child, vars);
94+
inits_from(child, locals, inits);
95+
}
96+
}
97+
98+
// Collects into `vars` the location of every Var node in the subtree rooted
// at `node` whose location is contained in `locals`.
//
// node   - root of the subtree to scan.
// locals - the locations that count as local variables.
// vars   - output set; matching locations are inserted, nothing is removed.
void vars_from(
99+
Node node, const std::set<Location>& locals, std::set<Location>& vars)
100+
{
101+
if (node->type() == Var && contains(locals, node->location()))
102+
{
103+
vars.insert(node->location());
104+
}
105+
106+
// No early return on a match: every child is always visited.
for (Node child : *node)
107+
{
108+
vars_from(child, locals, vars);
109+
}
110+
}
111+
112+
// Builds the InitSide for one operand of a statement by running both
// collectors over the same subtree: inits_from() fills `inits` and
// vars_from() fills `vars`, each restricted to locations in `locals`.
InitSide side_from(Node node, const std::set<Location>& locals)
113+
{
114+
InitSide side;
115+
inits_from(node, locals, side.inits);
116+
vars_from(node, locals, side.vars);
117+
return side;
118+
}
119+
120+
// Returns true if any location in `lhs.inits` has a name that begins with
// "unify$", "out$" or "value$".
// NOTE(review): these prefixes presumably mark compiler-generated
// temporaries (the caller's comment refers to "compiler statements") --
// confirm against the passes that create them.
bool any_compiler_inits(const InitSide& lhs)
121+
{
122+
return std::any_of(lhs.inits.begin(), lhs.inits.end(), [](auto& loc) {
123+
std::string name = loc.str();
124+
return starts_with(name, "unify$") || starts_with(name, "out$") ||
125+
starts_with(name, "value$");
126+
});
127+
}
128+
129+
// Erases every location in `to_remove` from all four sets (lhs/rhs vars and
// inits) of each entry in `init_deque`, and drops entries that are left with
// no inits on either side.
//
// The deque is rotated exactly once: each of the original `count` entries is
// taken from the front, filtered, re-appended at the back if it still has
// inits, and then popped from the front. Surviving entries therefore keep
// their relative order.
void remove_locals(
130+
std::deque<InitInfo>& init_deque, const std::set<Location>& to_remove)
131+
{
132+
// Snapshot the size: the loop itself grows and shrinks the deque.
std::size_t count = init_deque.size();
133+
for (std::size_t i = 0; i < count; ++i)
134+
{
135+
InitInfo& init_stmt = init_deque.front();
136+
for (auto& loc : to_remove)
137+
{
138+
init_stmt.lhs.vars.erase(loc);
139+
init_stmt.lhs.inits.erase(loc);
140+
init_stmt.rhs.vars.erase(loc);
141+
init_stmt.rhs.inits.erase(loc);
142+
}
143+
if (!init_stmt.lhs.inits.empty() || !init_stmt.rhs.inits.empty())
144+
{
145+
// push_back copies the front element; it must happen before the
// pop_front below destroys the original.
init_deque.push_back(init_stmt);
146+
}
147+
148+
init_deque.pop_front();
82149
}
83150
}
84151

152+
// Orders the candidate init statements in `init_deque` and returns the
// chosen ones in the order they were picked.
//
// Each iteration picks one statement, records it, and removes the locations
// it initializes from every remaining candidate (via remove_locals), so a
// location is claimed by at most one returned statement. The loop stops when
// no candidates remain or when the claimed locations equal `locals`.
//
// locals     - the full set of local locations to be initialized.
// init_deque - candidate statements; consumed/modified by this function.
std::vector<InitInfo> sort_init_stmts(
153+
const std::set<Location>& locals, std::deque<InitInfo>& init_deque)
154+
{
155+
std::set<Location> initialized;
156+
std::vector<InitInfo> init_stmts;
157+
while (!init_deque.empty() && initialized != locals)
158+
{
159+
// find the first strict init statement, i.e. one whose lhs or rhs has no
// remaining local vars (find_if returns only the first match)
auto it =
161+
std::find_if(init_deque.begin(), init_deque.end(), [](auto& init_stmt) {
162+
return init_stmt.lhs.vars.empty() || init_stmt.rhs.vars.empty();
163+
});
164+
165+
if (it == init_deque.end())
166+
{
167+
// we have a cycle, so we use the first statement
it = init_deque.begin();
169+
init_stmts.push_back(*it);
170+
}
171+
else
172+
{
173+
init_stmts.push_back(*it);
174+
}
175+
176+
// Gather the chosen statement's inits before erasing it: erase()
// invalidates `it`.
std::set<Location> to_remove;
177+
to_remove.insert(it->lhs.inits.begin(), it->lhs.inits.end());
178+
to_remove.insert(it->rhs.inits.begin(), it->rhs.inits.end());
179+
init_deque.erase(it);
180+
remove_locals(init_deque, to_remove);
181+
initialized.insert(to_remove.begin(), to_remove.end());
182+
}
183+
184+
return init_stmts;
185+
}
186+
85187
void find_init_stmts(Node unifybody, std::set<Location>& locals)
86188
{
87-
// gather all locals
189+
std::deque<InitInfo> potential_init_stmts;
88190
for (std::size_t i = 0; i < unifybody->size(); ++i)
89191
{
90192
Node stmt = unifybody->at(i);
@@ -95,15 +197,6 @@ namespace
95197
else if (stmt->type() == LiteralEnum)
96198
{
97199
locals.erase((stmt / Item)->location());
98-
find_init_stmts(stmt / UnifyBody, locals);
99-
}
100-
else if (stmt->type() == LiteralWith)
101-
{
102-
find_init_stmts(stmt / UnifyBody, locals);
103-
}
104-
else if (stmt->type() == LiteralNot)
105-
{
106-
find_init_stmts(stmt / UnifyBody, locals);
107200
}
108201
else if (stmt->type() == Literal)
109202
{
@@ -115,42 +208,65 @@ namespace
115208

116209
Node lhs = expr->front();
117210
Node rhs = expr->back();
118-
std::set<Location> lhs_vars;
119-
vars_from(lhs, lhs_vars);
120-
std::set<Location> lhs_found;
121-
std::set_intersection(
122-
lhs_vars.begin(),
123-
lhs_vars.end(),
124-
locals.begin(),
125-
locals.end(),
126-
std::inserter(lhs_found, lhs_found.begin()));
127-
128-
std::set<Location> rhs_vars;
129-
vars_from(rhs, rhs_vars);
130-
std::set<Location> rhs_found;
131-
std::set_intersection(
132-
rhs_vars.begin(),
133-
rhs_vars.end(),
134-
locals.begin(),
135-
locals.end(),
136-
std::inserter(rhs_found, rhs_found.begin()));
137-
138-
if (lhs_found.empty() && rhs_found.empty())
139-
{
140-
continue;
141-
}
142211

143-
for (auto& loc : lhs_found)
212+
InitSide lhs_side = side_from(lhs, locals);
213+
InitSide rhs_side = side_from(rhs, locals);
214+
215+
if (any_compiler_inits(lhs_side))
144216
{
145-
locals.erase(loc);
217+
// compiler statements will never be right-assign, so we can
218+
// use this fact later to help resolve some ambiguities
219+
rhs_side.inits.clear();
146220
}
147221

148-
for (auto& loc : rhs_found)
222+
if (lhs_side.inits.empty() && rhs_side.inits.empty())
149223
{
150-
locals.erase(loc);
224+
continue;
151225
}
152226

153-
unifybody->replace_at(i, to_init(lhs, lhs_found, rhs, rhs_found));
227+
potential_init_stmts.push_back({i, lhs_side, rhs_side});
228+
}
229+
}
230+
231+
std::vector<InitInfo> init_stmts =
232+
sort_init_stmts(locals, potential_init_stmts);
233+
for (std::size_t i = 0; i < init_stmts.size(); ++i)
234+
{
235+
InitInfo& init_stmt = init_stmts[i];
236+
Node expr = unifybody->at(init_stmt.index)->front()->front();
237+
238+
Node lhs = expr->front();
239+
Node rhs = expr->back();
240+
241+
for (auto& loc : init_stmt.lhs.inits)
242+
{
243+
locals.erase(loc);
244+
}
245+
246+
for (auto& loc : init_stmt.rhs.inits)
247+
{
248+
locals.erase(loc);
249+
}
250+
251+
unifybody->replace_at(
252+
init_stmt.index,
253+
to_init(lhs, init_stmt.lhs.inits, rhs, init_stmt.rhs.inits));
254+
}
255+
256+
// where appropriate, recurse with the updated locals
257+
for (Node stmt : *unifybody)
258+
{
259+
if (stmt->type() == LiteralEnum)
260+
{
261+
find_init_stmts(stmt / UnifyBody, locals);
262+
}
263+
else if (stmt->type() == LiteralWith)
264+
{
265+
find_init_stmts(stmt / UnifyBody, locals);
266+
}
267+
else if (stmt->type() == LiteralNot)
268+
{
269+
find_init_stmts(stmt / UnifyBody, locals);
154270
}
155271
}
156272
}

src/passes/locals.cc

+8
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,14 @@ namespace rego
134134
RuleObj, [builtins](Node n) { return preprocess_body(n, builtins); });
135135
locals.pre(
136136
RuleSet, [builtins](Node n) { return preprocess_body(n, builtins); });
137+
138+
return locals;
139+
}
140+
141+
PassDef compr_locals(const BuiltIns& builtins)
142+
{
143+
PassDef locals = {
144+
"compr_locals", wf_pass_locals, dir::bottomup | dir::once};
137145
locals.pre(
138146
ArrayCompr, [builtins](Node n) { return preprocess_body(n, builtins); });
139147
locals.pre(

src/passes/rulebody.cc

-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ namespace rego
138138
<< (T(Var)[Lhs] * T(Var)[Rhs] * T(UnifyBody)[UnifyBody])) >>
139139
[](Match& _) {
140140
ACTION();
141-
logging::Debug() << "enum";
142141
Location value = _.fresh({"value"});
143142
return Seq << (Lift << UnifyBody
144143
<< (Local << (Var ^ value) << Undefined))

src/rego.cc

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace rego
3030
explicit_enums(),
3131
body_locals(builtins),
3232
value_locals(builtins),
33+
compr_locals(builtins),
3334
rules_to_compr(),
3435
compr(),
3536
absolute_refs(),

src/unifier.cc

+9-1
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,15 @@ namespace rego
975975
}
976976
else
977977
{
978-
auto maybe_nodes = Resolver::apply_access(container, args[1]->node());
978+
Node index = args[1]->node();
979+
if (index->type() == Undefined)
980+
{
981+
values.push_back(
982+
ValueDef::create(var, Undefined ^ "undefined", sources));
983+
return values;
984+
}
985+
986+
auto maybe_nodes = Resolver::apply_access(container, index);
979987
if (maybe_nodes)
980988
{
981989
Nodes defs = maybe_nodes.value();

tests/regocpp.yaml

+28
Original file line numberDiff line numberDiff line change
@@ -1094,3 +1094,31 @@ cases:
10941094
query: data.every_some.output = x
10951095
want_result:
10961096
- x: true
1097+
- note: regocpp/bug95
1098+
modules:
1099+
- |
1100+
package test
1101+
1102+
x = c {
1103+
a = b
1104+
b = c
1105+
a = 12
1106+
}
1107+
query: data.test.x = x
1108+
want_result:
1109+
- x: 12
1110+
- note: regocpp/bug97
1111+
modules:
1112+
- |
1113+
package test
1114+
1115+
x = y {
1116+
a = [1, 2, 3]
1117+
y = {z | z = a[_]}
1118+
}
1119+
query: data.test.x = x
1120+
want_result:
1121+
- x:
1122+
- 1
1123+
- 2
1124+
- 3

wrappers/python/docs/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
project = 'regopy'
1010
copyright = '2023, Microsoft'
1111
author = 'Microsoft'
12-
release = '0.3.10'
12+
release = '0.3.11'
1313

1414
# -- General configuration ---------------------------------------------------
1515
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

wrappers/rust/regorust/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "regorust"
3-
version = "0.3.10"
3+
version = "0.3.11"
44
edition = "2021"
55
description = "Rust bindings for the rego-cpp Rego compiler and interpreter"
66
license = "MIT"

0 commit comments

Comments
 (0)