Skip to content

Commit 7915fb9

Browse files
jnthntatumcopybara-github
authored andcommitted
Add recursive implementation for cel.@block.
PiperOrigin-RevId: 721450778
1 parent 0c7c871 commit 7915fb9

File tree

5 files changed

+296
-11
lines changed

5 files changed

+296
-11
lines changed

eval/compiler/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ cc_library(
101101
"//common:ast",
102102
"//common:ast_traverse",
103103
"//common:ast_visitor",
104+
"//common:expr",
104105
"//common:kind",
105106
"//common:memory",
106107
"//common:type",

eval/compiler/cel_expression_builder_flat_impl_test.cc

Lines changed: 224 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
#include "cel/expr/conformance/proto3/test_all_types.pb.h"
5353
#include "google/protobuf/arena.h"
5454
#include "google/protobuf/message.h"
55+
#include "google/protobuf/text_format.h"
5556

5657
namespace google::api::expr::runtime {
5758

@@ -103,6 +104,7 @@ struct RecursiveTestCase {
103104
std::string test_name;
104105
std::string expr;
105106
test::CelValueMatcher matcher;
107+
std::string pb_expr;
106108
};
107109

108110
class RecursivePlanTest : public ::testing::TestWithParam<RecursiveTestCase> {
@@ -144,19 +146,29 @@ class RecursivePlanTest : public ::testing::TestWithParam<RecursiveTestCase> {
144146
}
145147
};
146148

147-
absl::StatusOr<ParsedExpr> ParseWithBind(absl::string_view cel) {
149+
absl::StatusOr<ParsedExpr> ParseTestCase(const RecursiveTestCase& test_case) {
148150
static const std::vector<Macro>* kMacros = []() {
149151
auto* result = new std::vector<Macro>(Macro::AllMacros());
150152
absl::c_copy(cel::extensions::bindings_macros(),
151153
std::back_inserter(*result));
152154
return result;
153155
}();
154-
return ParseWithMacros(cel, *kMacros, "<input>");
156+
157+
if (!test_case.expr.empty()) {
158+
return ParseWithMacros(test_case.expr, *kMacros, "<input>");
159+
} else if (!test_case.pb_expr.empty()) {
160+
ParsedExpr result;
161+
if (!google::protobuf::TextFormat::ParseFromString(test_case.pb_expr, &result)) {
162+
return absl::InvalidArgumentError("Failed to parse proto");
163+
}
164+
return result;
165+
}
166+
return absl::InvalidArgumentError("No expression provided");
155167
}
156168

157169
TEST_P(RecursivePlanTest, ParsedExprRecursiveImpl) {
158170
const RecursiveTestCase& test_case = GetParam();
159-
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr));
171+
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case));
160172
cel::RuntimeOptions options;
161173
options.container = "cel.expr.conformance.proto3";
162174
google::protobuf::Arena arena;
@@ -183,7 +195,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveImpl) {
183195

184196
TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) {
185197
const RecursiveTestCase& test_case = GetParam();
186-
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr));
198+
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case));
187199
cel::RuntimeOptions options;
188200
options.container = "cel.expr.conformance.proto3";
189201
google::protobuf::Arena arena;
@@ -216,7 +228,7 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) {
216228

217229
TEST_P(RecursivePlanTest, ParsedExprRecursiveTraceSupport) {
218230
const RecursiveTestCase& test_case = GetParam();
219-
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr));
231+
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case));
220232
cel::RuntimeOptions options;
221233
options.container = "cel.expr.conformance.proto3";
222234
google::protobuf::Arena arena;
@@ -249,7 +261,7 @@ TEST_P(RecursivePlanTest, Disabled) {
249261
google::protobuf::LinkMessageReflection<TestAllTypes>();
250262

251263
const RecursiveTestCase& test_case = GetParam();
252-
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseWithBind(test_case.expr));
264+
ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, ParseTestCase(test_case));
253265
cel::RuntimeOptions options;
254266
options.container = "cel.expr.conformance.proto3";
255267
google::protobuf::Arena arena;
@@ -326,7 +338,212 @@ INSTANTIATE_TEST_SUITE_P(
326338
{"re_matches_receiver",
327339
"(string_abc + string_def).matches(r'(123)?' + r'abc' + r'def')",
328340
test::IsCelBool(true)},
329-
}),
341+
{"block", "", test::IsCelBool(true),
342+
R"pb(
343+
expr {
344+
id: 1
345+
call_expr {
346+
function: "cel.@block"
347+
args {
348+
id: 2
349+
list_expr {
350+
elements { const_expr { int64_value: 8 } }
351+
elements { const_expr { int64_value: 10 } }
352+
}
353+
}
354+
args {
355+
id: 3
356+
call_expr {
357+
function: "_<_"
358+
args { ident_expr { name: "@index0" } }
359+
args { ident_expr { name: "@index1" } }
360+
}
361+
}
362+
}
363+
})pb"},
364+
{"block_with_comprehensions", "", test::IsCelBool(true),
365+
// Something like:
366+
// variables:
367+
// - users: {'bob': ['bar'], 'alice': ['foo', 'bar']}
368+
// - somone_has_bar: users.exists(u, 'bar' in users[u])
369+
// policy:
370+
// - someone_has_bar && !users.exists(u, u == 'eve'))
371+
//
372+
R"pb(
373+
expr {
374+
call_expr {
375+
function: "cel.@block"
376+
args {
377+
list_expr {
378+
elements {
379+
struct_expr: {
380+
entries: {
381+
map_key: { const_expr: { string_value: "bob" } }
382+
value: {
383+
list_expr: {
384+
elements: { const_expr: { string_value: "bar" } }
385+
}
386+
}
387+
}
388+
entries: {
389+
map_key: { const_expr: { string_value: "alice" } }
390+
value: {
391+
list_expr: {
392+
elements: { const_expr: { string_value: "bar" } }
393+
elements: { const_expr: { string_value: "foo" } }
394+
}
395+
}
396+
}
397+
}
398+
}
399+
elements {
400+
id: 16
401+
comprehension_expr: {
402+
iter_var: "u"
403+
iter_range: {
404+
id: 1
405+
ident_expr: { name: "@index0" }
406+
}
407+
accu_var: "__result__"
408+
accu_init: {
409+
id: 9
410+
const_expr: { bool_value: false }
411+
}
412+
loop_condition: {
413+
id: 12
414+
call_expr: {
415+
function: "@not_strictly_false"
416+
args: {
417+
id: 11
418+
call_expr: {
419+
function: "!_"
420+
args: {
421+
id: 10
422+
ident_expr: { name: "__result__" }
423+
}
424+
}
425+
}
426+
}
427+
}
428+
loop_step: {
429+
id: 14
430+
call_expr: {
431+
function: "_||_"
432+
args: {
433+
id: 13
434+
ident_expr: { name: "__result__" }
435+
}
436+
args: {
437+
id: 5
438+
call_expr: {
439+
function: "@in"
440+
args: {
441+
id: 4
442+
const_expr: { string_value: "bar" }
443+
}
444+
args: {
445+
id: 7
446+
call_expr: {
447+
function: "_[_]"
448+
args: {
449+
id: 6
450+
ident_expr: { name: "@index0" }
451+
}
452+
args: {
453+
id: 8
454+
ident_expr: { name: "u" }
455+
}
456+
}
457+
}
458+
}
459+
}
460+
}
461+
}
462+
result: {
463+
id: 15
464+
ident_expr: { name: "__result__" }
465+
}
466+
}
467+
}
468+
}
469+
}
470+
args {
471+
id: 17
472+
call_expr: {
473+
function: "_&&_"
474+
args: {
475+
id: 1
476+
ident_expr: { name: "@index1" }
477+
}
478+
args: {
479+
id: 2
480+
call_expr: {
481+
function: "!_"
482+
args: {
483+
id: 16
484+
comprehension_expr: {
485+
iter_var: "u"
486+
iter_range: {
487+
id: 3
488+
ident_expr: { name: "@index0" }
489+
}
490+
accu_var: "__result__"
491+
accu_init: {
492+
id: 9
493+
const_expr: { bool_value: false }
494+
}
495+
loop_condition: {
496+
id: 12
497+
call_expr: {
498+
function: "@not_strictly_false"
499+
args: {
500+
id: 11
501+
call_expr: {
502+
function: "!_"
503+
args: {
504+
id: 10
505+
ident_expr: { name: "__result__" }
506+
}
507+
}
508+
}
509+
}
510+
}
511+
loop_step: {
512+
id: 14
513+
call_expr: {
514+
function: "_||_"
515+
args: {
516+
id: 13
517+
ident_expr: { name: "__result__" }
518+
}
519+
args: {
520+
id: 7
521+
call_expr: {
522+
function: "_==_"
523+
args: {
524+
id: 6
525+
ident_expr: { name: "u" }
526+
}
527+
args: {
528+
id: 8
529+
const_expr: { string_value: "eve" }
530+
}
531+
}
532+
}
533+
}
534+
}
535+
result: {
536+
id: 15
537+
ident_expr: { name: "__result__" }
538+
}
539+
}
540+
}
541+
}
542+
}
543+
}
544+
}
545+
}
546+
})pb"}}),
330547

331548
[](const testing::TestParamInfo<RecursiveTestCase>& info) -> std::string {
332549
return info.param.test_name;

eval/compiler/flat_expr_builder.cc

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
#include "common/ast.h"
5555
#include "common/ast_traverse.h"
5656
#include "common/ast_visitor.h"
57+
#include "common/expr.h"
5758
#include "common/kind.h"
5859
#include "common/memory.h"
5960
#include "common/type.h"
@@ -2017,15 +2018,42 @@ FlatExprVisitor::CallHandlerResult FlatExprVisitor::HandleBlock(
20172018
const cel::ast_internal::Expr& expr,
20182019
const cel::ast_internal::Call& call_expr) {
20192020
ABSL_DCHECK(call_expr.function() == kBlock);
2020-
if (!block_.has_value() || block_->expr != &expr) {
2021-
SetProgressStatusError(absl::InvalidArgumentError(
2022-
"unexpected number call to internal cel.@block"));
2021+
if (!block_.has_value() || block_->expr != &expr ||
2022+
call_expr.args().size() != 2) {
2023+
SetProgressStatusError(
2024+
absl::InvalidArgumentError("unexpected call to internal cel.@block"));
20232025
return CallHandlerResult::kIntercepted;
20242026
}
2027+
20252028
BlockInfo& block = *block_;
20262029
block.in = false;
20272030
index_manager().ReleaseSlots(block.slot_count);
2028-
AddStep(CreateClearSlotsStep(block.index, block.slot_count, -1));
2031+
2032+
// Check if eligible for recursion and update the plan if so.
2033+
//
2034+
// The first argument to @block is the list of initializers. These don't
2035+
// generate a plan in the main program (they are tracked separately to support
2036+
// lazy evaluation) so we only need to extract the second argument -- the body
2037+
// of the block that uses the initializers.
2038+
ProgramBuilder::Subexpression* body_subexpression =
2039+
program_builder_.GetSubexpression(&call_expr.args()[1]);
2040+
2041+
if (options_.max_recursion_depth != 0 && body_subexpression != nullptr &&
2042+
body_subexpression->IsRecursive() &&
2043+
(options_.max_recursion_depth < 0 ||
2044+
body_subexpression->recursive_program().depth <
2045+
options_.max_recursion_depth)) {
2046+
auto recursive_program = body_subexpression->ExtractRecursiveProgram();
2047+
SetRecursiveStep(
2048+
CreateDirectBlockStep(block.index, block.slot_count,
2049+
std::move(recursive_program.step), expr.id()),
2050+
recursive_program.depth + 1);
2051+
return CallHandlerResult::kIntercepted;
2052+
}
2053+
2054+
// Otherwise, iterative plan.
2055+
AddStep(CreateClearSlotsStep(block.index, block.slot_count, expr.id()));
2056+
20292057
return CallHandlerResult::kIntercepted;
20302058
}
20312059

0 commit comments

Comments
 (0)