From cea3a7f403a6c0443267c17066a7990959c14b96 Mon Sep 17 00:00:00 2001 From: Anders Eknert Date: Thu, 4 Jun 2026 14:54:55 +0200 Subject: [PATCH] ast: Clean up code for value comparisons MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `ast.Compare(any, any)` function is a beast better avoided, and the `any` args type mean some AST values (like strings) escape to the heap when boxed. Previous work already ensured it wasn't called too often — this just moves it further along by having all `ast.Value`s do their own comparisons with the help of a new function to easily compare 2 different value types. Also: - topdown: slightly cheaper object.union_n implementation - eval: remove unused expr field on evalNot - eval: rename fmtVarTerm -> fmtVar - term: remove unused termSlice type - builtins: cheaper Builtin.Ref() Signed-off-by: Anders Eknert --- v1/ast/builtins.go | 11 +-- v1/ast/compare.go | 59 ++++++++-------- v1/ast/policy.go | 11 ++- v1/ast/term.go | 115 ++++++++++++++------------------ v1/ast/term_test.go | 6 +- v1/rego/rego.go | 3 +- v1/topdown/comparison.go | 12 ++-- v1/topdown/eval.go | 23 +++---- v1/topdown/eval_test.go | 4 +- v1/topdown/object.go | 100 +++++++++++++++++---------- v1/topdown/object_bench_test.go | 110 ++++++++++++++++++++++++------ 11 files changed, 266 insertions(+), 188 deletions(-) diff --git a/v1/ast/builtins.go b/v1/ast/builtins.go index 78bd9a7d872..d557a9bd4bd 100644 --- a/v1/ast/builtins.go +++ b/v1/ast/builtins.go @@ -3662,11 +3662,12 @@ func (b *Builtin) Call(operands ...*Term) *Term { // Ref returns a Ref that refers to the built-in function. func (b *Builtin) Ref() Ref { - parts := strings.Split(b.Name, ".") - ref := make(Ref, len(parts)) - ref[0] = VarTerm(parts[0]) - for i := 1; i < len(parts); i++ { - ref[i] = InternedTerm(parts[i]) + numParts := strings.Count(b.Name, ".") + 1 + curr, remaining, ok := strings.Cut(b.Name, ".") + ref := append(make(Ref, 0, numParts), VarTerm(curr)) + for ok { + curr, remaining, ok = strings.Cut(remaining, ".") + ref = append(ref, InternedTerm(curr)) } return ref } diff --git a/v1/ast/compare.go b/v1/ast/compare.go index 03149056702..1bf990a6ed3 100644 --- a/v1/ast/compare.go +++ b/v1/ast/compare.go @@ -38,7 +38,6 @@ import ( // is empty. // Other comparisons are consistent but not defined. func Compare(a, b any) int { - if t, ok := a.(*Term); ok { if t == nil { a = nil @@ -77,7 +76,7 @@ func Compare(a, b any) int { switch a := a.(type) { case *Not: b := b.(*Not) - return Compare(a.Body, b.Body) + return a.Compare(b) case Null: return 0 case Boolean: @@ -120,25 +119,13 @@ func Compare(a, b any) int { return a.Compare(b.(Set)) case *ArrayComprehension: b := b.(*ArrayComprehension) - if cmp := Compare(a.Term, b.Term); cmp != 0 { - return cmp - } - return a.Body.Compare(b.Body) + return a.Compare(b) case *ObjectComprehension: b := b.(*ObjectComprehension) - if cmp := Compare(a.Key, b.Key); cmp != 0 { - return cmp - } - if cmp := Compare(a.Value, b.Value); cmp != 0 { - return cmp - } - return a.Body.Compare(b.Body) + return a.Compare(b) case *SetComprehension: b := b.(*SetComprehension) - if cmp := Compare(a.Term, b.Term); cmp != 0 { - return cmp - } - return a.Body.Compare(b.Body) + return a.Compare(b) case Call: return termSliceCompare(a, b.(Call)) case *Expr: @@ -173,14 +160,8 @@ func Compare(a, b any) int { panic(fmt.Sprintf("illegal value: %T", a)) } -type termSlice []*Term - -func (s termSlice) Less(i, j int) bool { return Compare(s[i].Value, s[j].Value) < 0 } -func (s termSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s termSlice) Len() int { return len(s) } - -func sortOrder(x any) int { - switch x.(type) { +func valueSortOrder(v Value) int { + switch v.(type) { case Null: return 0 case Boolean: @@ -209,6 +190,28 @@ func sortOrder(x any) int { return 12 case Call: return 13 + case *Not: + return 111 + } + return 10000000 +} + +func valueTypeCompare[A, B Value](a A, b B) int { + sortA := valueSortOrder(a) + sortB := valueSortOrder(b) + + if sortA < sortB { + return -1 + } else if sortB < sortA { + return 1 + } + return 0 +} + +func sortOrder(x any) int { + switch v := x.(type) { + case Value: + return valueSortOrder(v) case Args: return 14 case *Expr: @@ -223,8 +226,6 @@ func sortOrder(x any) int { return 104 case *With: return 110 - case *Not: - return 111 case *Head: return 120 case Body: @@ -294,7 +295,7 @@ func rulesCompare(a, b []*Rule) int { func termSliceCompare(a, b []*Term) int { minLen := min(len(b), len(a)) for i := range minLen { - if cmp := Compare(a[i], b[i]); cmp != 0 { + if cmp := a[i].Value.Compare(b[i].Value); cmp != 0 { return cmp } } @@ -309,7 +310,7 @@ func termSliceCompare(a, b []*Term) int { func withSliceCompare(a, b []*With) int { minLen := min(len(b), len(a)) for i := range minLen { - if cmp := Compare(a[i], b[i]); cmp != 0 { + if cmp := a[i].Compare(b[i]); cmp != 0 { return cmp } } diff --git a/v1/ast/policy.go b/v1/ast/policy.go index 470c6ab6f76..f3bb43640f0 100644 --- a/v1/ast/policy.go +++ b/v1/ast/policy.go @@ -562,7 +562,7 @@ func (imp *Import) Compare(other *Import) int { } else if other == nil { return 1 } - if cmp := Compare(imp.Path, other.Path); cmp != 0 { + if cmp := imp.Path.Value.Compare(other.Path.Value); cmp != 0 { return cmp } @@ -906,10 +906,10 @@ func (head *Head) Compare(other *Head) int { } else if !head.Assign && other.Assign { return 1 } - if cmp := Compare(head.Args, other.Args); cmp != 0 { + if cmp := termSliceCompare(head.Args, other.Args); cmp != 0 { return cmp } - if cmp := Compare(head.Reference, other.Reference); cmp != 0 { + if cmp := termSliceCompare(head.Reference, other.Reference); cmp != 0 { return cmp } if cmp := VarCompare(head.Name, other.Name); cmp != 0 { @@ -1218,7 +1218,6 @@ func (expr *Expr) Equal(other *Expr) bool { // Otherwise, the expression terms are compared normally. If both expressions // have the same terms, the modifiers are compared. func (expr *Expr) Compare(other *Expr) int { - if expr == nil { if other == nil { return 0 @@ -1252,7 +1251,7 @@ func (expr *Expr) Compare(other *Expr) int { switch t := expr.Terms.(type) { case *Term: - if cmp := Compare(t.Value, other.Terms.(*Term).Value); cmp != 0 { + if cmp := t.Value.Compare(other.Terms.(*Term).Value); cmp != 0 { return cmp } case []*Term: @@ -1268,7 +1267,7 @@ func (expr *Expr) Compare(other *Expr) int { return cmp } case *Not: - if cmp := Compare(t, other.Terms.(*Not)); cmp != 0 { + if cmp := t.Compare(other.Terms.(*Not)); cmp != 0 { return cmp } case *LogicalAnd: diff --git a/v1/ast/term.go b/v1/ast/term.go index fb1e7aeae51..178bc09e389 100644 --- a/v1/ast/term.go +++ b/v1/ast/term.go @@ -857,12 +857,10 @@ func (num Number) Equal(other Value) bool { // Compare compares num to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (num Number) Compare(other Value) int { - // Optimize for the common case, as calling Compare allocates on heap. if otherNum, yes := other.(Number); yes { return NumberCompare(num, otherNum) } - - return Compare(num, other) + return valueTypeCompare(num, other) } // Find returns the current value or a not found error. @@ -974,7 +972,7 @@ func (str String) Compare(other Value) int { return 1 } - return Compare(str, other) + return valueTypeCompare(str, other) } // Find returns the current value or a not found error. @@ -1059,7 +1057,7 @@ func (ts *TemplateString) Compare(other Value) int { return 0 } - return Compare(ts, other) + return valueTypeCompare(ts, other) } func (ts *TemplateString) Find(path Ref) (Value, error) { @@ -1171,7 +1169,7 @@ func (v Var) Compare(other Value) int { if otherVar, ok := other.(Var); ok { return strings.Compare(string(v), string(otherVar)) } - return Compare(v, other) + return valueTypeCompare(v, other) } // Find returns the current value or a not found error. @@ -1364,8 +1362,7 @@ func (ref Ref) Compare(other Value) int { if o, ok := other.(Ref); ok { return termSliceCompare(ref, o) } - - return Compare(ref, other) + return valueTypeCompare(ref, other) } // Find returns the current value or a "not found" error. @@ -1634,16 +1631,7 @@ func (arr *Array) Compare(other Value) int { return termSliceCompare(arr.elems, b.elems) } - sortA := sortOrder(arr) - sortB := sortOrder(other) - - if sortA < sortB { - return -1 - } else if sortB < sortA { - return 1 - } - - return Compare(arr, other) + return valueTypeCompare(arr, other) } // Find returns the value at the index or an out-of-range error. @@ -1927,15 +1915,11 @@ func (s *set) sortedKeys() []*Term { // Compare compares s to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (s *set) Compare(other Value) int { - o1 := sortOrder(s) - o2 := sortOrder(other) - if o1 < o2 { - return -1 - } else if o1 > o2 { - return 1 + if t, ok := other.(*set); ok { + return slices.CompareFunc(s.sortedKeys(), t.sortedKeys(), TermValueCompare) } - t := other.(*set) - return termSliceCompare(s.sortedKeys(), t.sortedKeys()) + + return valueTypeCompare(s, other) } // Find returns the set or dereferences the element itself. @@ -2204,12 +2188,8 @@ func (l *lazyObj) force() Object { } func (l *lazyObj) Compare(other Value) int { - o1 := sortOrder(l) - o2 := sortOrder(other) - if o1 < o2 { - return -1 - } else if o2 < o1 { - return 1 + if c := valueTypeCompare(l, other); c != 0 { + return c } return l.force().Compare(other) } @@ -2414,12 +2394,8 @@ func (obj *object) Compare(other Value) int { if x, ok := other.(*lazyObj); ok { other = x.force() } - o1 := sortOrder(obj) - o2 := sortOrder(other) - if o1 < o2 { - return -1 - } else if o2 < o1 { - return 1 + if c := valueTypeCompare(obj, other); c != 0 { + return c } a := obj b := other.(*object) @@ -2431,27 +2407,14 @@ func (obj *object) Compare(other Value) int { minLen = len(bkeys) } for i := range minLen { - keysCmp := Compare(akeys[i].key, bkeys[i].key) - if keysCmp < 0 { - return -1 + if c := akeys[i].key.Value.Compare(bkeys[i].key.Value); c != 0 { + return c } - if keysCmp > 0 { - return 1 - } - valA := akeys[i].value - valB := bkeys[i].value - valCmp := Compare(valA, valB) - if valCmp != 0 { - return valCmp + if c := akeys[i].value.Value.Compare(bkeys[i].value.Value); c != 0 { + return c } } - if len(akeys) < len(bkeys) { - return -1 - } - if len(bkeys) < len(akeys) { - return 1 - } - return 0 + return len(akeys) - len(bkeys) } // Find returns the value at the key or undefined. @@ -2506,7 +2469,7 @@ func KeyHashEqual(x, y Value) bool { } } - return Compare(x, y) == 0 + return x.Compare(y) == 0 } // Hash returns the hash code for the Value. @@ -2906,13 +2869,19 @@ func (ac *ArrayComprehension) Copy() *ArrayComprehension { // Equal returns true if ac is equal to other. func (ac *ArrayComprehension) Equal(other Value) bool { - return Compare(ac, other) == 0 + return ac.Compare(other) == 0 } // Compare compares ac to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (ac *ArrayComprehension) Compare(other Value) int { - return Compare(ac, other) + if bc, ok := other.(*ArrayComprehension); ok { + if c := ac.Term.Value.Compare(bc.Term.Value); c != 0 { + return c + } + return ac.Body.Compare(bc.Body) + } + return valueTypeCompare(ac, other) } // Find returns the current value or a not found error. @@ -2967,13 +2936,22 @@ func (oc *ObjectComprehension) Copy() *ObjectComprehension { // Equal returns true if oc is equal to other. func (oc *ObjectComprehension) Equal(other Value) bool { - return Compare(oc, other) == 0 + return oc.Compare(other) == 0 } // Compare compares oc to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (oc *ObjectComprehension) Compare(other Value) int { - return Compare(oc, other) + if bc, ok := other.(*ObjectComprehension); ok { + if c := oc.Key.Value.Compare(bc.Key.Value); c != 0 { + return c + } + if c := oc.Value.Value.Compare(bc.Value.Value); c != 0 { + return c + } + return oc.Body.Compare(bc.Body) + } + return valueTypeCompare(oc, other) } // Find returns the current value or a not found error. @@ -3025,13 +3003,19 @@ func (sc *SetComprehension) Copy() *SetComprehension { // Equal returns true if sc is equal to other. func (sc *SetComprehension) Equal(other Value) bool { - return Compare(sc, other) == 0 + return sc.Compare(other) == 0 } // Compare compares sc to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (sc *SetComprehension) Compare(other Value) int { - return Compare(sc, other) + if oc, ok := other.(*SetComprehension); ok { + if c := sc.Term.Value.Compare(oc.Term.Value); c != 0 { + return c + } + return sc.Body.Compare(oc.Body) + } + return valueTypeCompare(sc, other) } // Find returns the current value or a not found error. @@ -3074,7 +3058,10 @@ func (c Call) Copy() Call { // Compare compares c to other, return <0, 0, or >0 if it is less than, equal to, // or greater than other. func (c Call) Compare(other Value) int { - return Compare(c, other) + if oc, ok := other.(Call); ok { + return termSliceCompare(c, oc) + } + return valueTypeCompare(c, other) } // Find returns the current value or a not found error. diff --git a/v1/ast/term_test.go b/v1/ast/term_test.go index 47990b7ef5b..1f074b33427 100644 --- a/v1/ast/term_test.go +++ b/v1/ast/term_test.go @@ -12,7 +12,7 @@ import ( "math/rand" "reflect" "runtime" - "sort" + "slices" "strings" "sync" "testing" @@ -986,7 +986,7 @@ func TestSetConcurrentReads(t *testing.T) { s.Add(numbers[i]) } // In-place sort on numbers. - sort.Sort(termSlice(numbers)) + slices.SortFunc(numbers, TermValueCompare) // Check if race condition on key sorting is present. var wg sync.WaitGroup @@ -1028,7 +1028,7 @@ func TestObjectConcurrentReads(t *testing.T) { o.Insert(numbers[i], NullTerm()) } // In-place sort on numbers. - sort.Sort(termSlice(numbers)) + slices.SortFunc(numbers, TermValueCompare) // Check if race condition on key sorting is present. var wg sync.WaitGroup diff --git a/v1/rego/rego.go b/v1/rego/rego.go index 6f2a824cb5a..6e4ef530547 100644 --- a/v1/rego/rego.go +++ b/v1/rego/rego.go @@ -2173,8 +2173,7 @@ func (r *Rego) parseQuery(queryImports []*ast.Import, m metrics.Metrics) (ast.Bo func parserOptionsFromRegoVersionImport(imports []*ast.Import, popts ast.ParserOptions) (ast.ParserOptions, error) { for _, imp := range imports { - path := imp.Path.Value.(ast.Ref) - if ast.Compare(path, ast.RegoV1CompatibleRef) == 0 { + if ast.RegoV1CompatibleRef.Compare(imp.Path.Value) == 0 { popts.RegoVersion = ast.RegoV1 return popts, nil } diff --git a/v1/topdown/comparison.go b/v1/topdown/comparison.go index 6c10129faaf..d0434a028b1 100644 --- a/v1/topdown/comparison.go +++ b/v1/topdown/comparison.go @@ -9,27 +9,27 @@ import "github.com/open-policy-agent/opa/v1/ast" type compareFunc func(a, b ast.Value) bool func compareGreaterThan(a, b ast.Value) bool { - return ast.Compare(a, b) > 0 + return a.Compare(b) > 0 } func compareGreaterThanEq(a, b ast.Value) bool { - return ast.Compare(a, b) >= 0 + return a.Compare(b) >= 0 } func compareLessThan(a, b ast.Value) bool { - return ast.Compare(a, b) < 0 + return a.Compare(b) < 0 } func compareLessThanEq(a, b ast.Value) bool { - return ast.Compare(a, b) <= 0 + return a.Compare(b) <= 0 } func compareNotEq(a, b ast.Value) bool { - return ast.Compare(a, b) != 0 + return a.Compare(b) != 0 } func compareEq(a, b ast.Value) bool { - return ast.Compare(a, b) == 0 + return a.Compare(b) == 0 } func builtinCompare(cmp compareFunc) BuiltinFunc { diff --git a/v1/topdown/eval.go b/v1/topdown/eval.go index 87b94d16466..d600201b584 100644 --- a/v1/topdown/eval.go +++ b/v1/topdown/eval.go @@ -496,8 +496,7 @@ func (e *eval) evalStep(iter evalIterator) error { }) } case *ast.Term: - // generateVar inlined here to avoid extra allocations in hot path - rterm := ast.VarTerm(e.fmtVarTerm()) + rterm := ast.VarTerm(e.fmtVar()) if e.partial() { e.inliningControl.PushDisable(rterm.Value, true) @@ -536,9 +535,8 @@ func (e *eval) evalStep(iter evalIterator) error { case *ast.Not: en := evalNot{ - e: e, - not: terms, - expr: expr, + e: e, + not: terms, } err = en.eval(func(e *eval) error { defined = true @@ -575,8 +573,7 @@ func (e *eval) evalStep(iter evalIterator) error { }) } case *ast.Term: - // generateVar inlined here to avoid extra allocations in hot path - rterm := ast.VarTerm(e.fmtVarTerm()) + rterm := ast.VarTerm(e.fmtVar()) err = e.unify(terms, rterm, func() error { if e.saveSet != nil && e.saveSet.Contains(rterm, e.bindings) { return e.saveExpr(ast.NewExpr(rterm), e.bindings, func() error { @@ -600,9 +597,8 @@ func (e *eval) evalStep(iter evalIterator) error { case *ast.Not: en := evalNot{ - e: e, - not: terms, - expr: expr, + e: e, + not: terms, } err = en.eval(func(e *eval) error { return iter(e) @@ -617,7 +613,7 @@ func (e *eval) evalStep(iter evalIterator) error { // Single-purpose fmt.Sprintf replacement for generating variable names with only // one allocation performed instead of 4, and in 1/3 the time. -func (e *eval) fmtVarTerm() string { +func (e *eval) fmtVar() string { buf := make([]byte, 0, len(e.genvarprefix)+util.NumDigitsUint(e.queryID)+util.NumDigitsInt(e.index)+7) buf = append(buf, e.genvarprefix...) @@ -4235,9 +4231,8 @@ func (e *evalEvery) plug(expr *ast.Expr) *ast.Expr { } type evalNot struct { - e *eval - not *ast.Not - expr *ast.Expr + e *eval + not *ast.Not } func (e evalNot) eval(iter evalIterator) error { diff --git a/v1/topdown/eval_test.go b/v1/topdown/eval_test.go index 16b48d62ad6..660aba175a1 100644 --- a/v1/topdown/eval_test.go +++ b/v1/topdown/eval_test.go @@ -1675,7 +1675,7 @@ func TestFmtVarTerm(t *testing.T) { index: 54321, } - res := e.fmtVarTerm() + res := e.fmtVar() if res != "foobar_term_12345_54321" { t.Fatalf("Expected foobar_term_12345_54321 but got %s", res) @@ -1699,6 +1699,6 @@ func BenchmarkFormatVarTerm(b *testing.B) { } for b.Loop() { - _ = e.fmtVarTerm() + _ = e.fmtVar() } } diff --git a/v1/topdown/object.go b/v1/topdown/object.go index c2884e9c78d..299978a291c 100644 --- a/v1/topdown/object.go +++ b/v1/topdown/object.go @@ -29,9 +29,7 @@ func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T return iter(operands[0]) } - r := mergeWithOverwrite(objA, objB) - - return iter(ast.NewTerm(r)) + return iter(ast.NewTerm(mergeWithOverwrite(objA, objB))) } func builtinObjectUnionN(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error { @@ -40,6 +38,21 @@ func builtinObjectUnionN(_ BuiltinContext, operands []*ast.Term, iter func(*ast. return err } + n := arr.Len() + if n == 0 { + return iter(ast.InternedEmptyObject) + } + + first := arr.Elem(0) + obj, ok := first.Value.(ast.Object) + if !ok { + return builtins.NewOperandElementErr(1, arr, first.Value, "object") + } + + if n == 1 { + return iter(first) + } + // Because we need merge-with-overwrite behavior, we can iterate // back-to-front, and get a mostly correct set of key assignments that // give us the "last assignment wins, with merges" behavior we want. @@ -52,20 +65,23 @@ func builtinObjectUnionN(_ BuiltinContext, operands []*ast.Term, iter func(*ast. // Want Output: {"a": {"c": 3}} // First pass: count total keys for pre-allocation - totalSize := 0 - for i := range arr.Len() { - o, ok := arr.Elem(i).Value.(ast.Object) + totalSize := obj.Len() + for i := 1; i < n; i++ { + elem := arr.Elem(i) + o, ok := elem.Value.(ast.Object) if !ok { - return builtins.NewOperandElementErr(1, arr, arr.Elem(i).Value, "object") + return builtins.NewOperandElementErr(1, arr, elem.Value, "object") } totalSize += o.Len() } result := ast.NewObjectWithCapacity(totalSize) frozenKeys := make(map[*ast.Term]struct{}, totalSize) - for i := arr.Len() - 1; i >= 0; i-- { - o := arr.Elem(i).Value.(ast.Object) // Already validated above - mergewithOverwriteInPlace(result, o, frozenKeys) + + for i := n - 1; i >= 0; i-- { + if o := arr.Elem(i).Value.(ast.Object); o.Len() > 0 { + mergewithOverwriteInPlace(result, o, frozenKeys) + } } return iter(ast.NewTerm(result)) @@ -198,39 +214,49 @@ func mergeWithOverwrite(objA, objB ast.Object) ast.Object { // Modifies obj with any new keys from other, and recursively // merges any keys where the values are both objects. -func mergewithOverwriteInPlace(obj, other ast.Object, frozenKeys map[*ast.Term]struct{}) { - other.Foreach(func(k, v *ast.Term) { - v2 := obj.Get(k) - // The key didn't exist in other, keep the original value. - if v2 == nil { - nestedObj, ok := v.Value.(ast.Object) - if !ok { - // v is not an object - obj.Insert(k, v) - } else { - // Copy the nested object so the original object would not be modified - nestedObjCopy := nestedObj.Copy() - obj.Insert(k, ast.NewTerm(nestedObjCopy)) - } +func mergewithOverwriteInPlace(dst, src ast.Object, frozenKeys map[*ast.Term]struct{}) { + if src.Len() == 0 || dst.Compare(src) == 0 { + return + } - return - } - // The key exists in both. Merge or reject change. - updateValueObj, ok2 := v.Value.(ast.Object) - originalValueObj, ok1 := v2.Value.(ast.Object) - // Both are objects? Merge recursively. - if ok1 && ok2 { - // Check to make sure that this key isn't frozen before merging. - if _, ok := frozenKeys[v2]; !ok { - mergewithOverwriteInPlace(originalValueObj, updateValueObj, frozenKeys) - } + src.Foreach(func(k, v *ast.Term) { + if v2 := dst.Get(k); v2 == nil { + // key not in dst, insert from src + dst.Insert(k, copyIfObject(v)) } else { - // Else, original value wins. Freeze the key. - frozenKeys[v2] = struct{}{} + // key in both, merge or reject change + srcObj, ok2 := v.Value.(ast.Object) + dstObj, ok1 := v2.Value.(ast.Object) + // both are objects? Merge recursively. + if ok1 && ok2 { + // Check to make sure that this key isn't frozen before merging. + if _, ok := frozenKeys[v2]; !ok { + mergewithOverwriteInPlace(dstObj, srcObj, frozenKeys) + } + } else { + // Else, original value wins. Freeze the key. + frozenKeys[v2] = struct{}{} + } } }) } +// copyIfObject returns term in which objects are copied recursively +// other values are returned as-is. This is much cheaper than .Copy() +// and sufficient for the use case of merging, as sets and arrays are +// overwritten rather than merged. +func copyIfObject(term *ast.Term) *ast.Term { + switch val := term.Value.(type) { + case ast.Object: + cpy, _ := val.Map(func(k, v *ast.Term) (*ast.Term, *ast.Term, error) { + return k, copyIfObject(v), nil + }) + return ast.NewTerm(cpy) + default: + return term + } +} + func init() { RegisterBuiltinFunc(ast.ObjectUnion.Name, builtinObjectUnion) RegisterBuiltinFunc(ast.ObjectUnionN.Name, builtinObjectUnionN) diff --git a/v1/topdown/object_bench_test.go b/v1/topdown/object_bench_test.go index 9e549e11905..1929a2bca0b 100644 --- a/v1/topdown/object_bench_test.go +++ b/v1/topdown/object_bench_test.go @@ -26,39 +26,36 @@ func genNxMObjectBenchmarkData(n, m int) ast.Value { } func BenchmarkObjectUnionN(b *testing.B) { - ctx := b.Context() - sizes := []int{10, 100, 250} for _, n := range sizes { for _, m := range sizes { b.Run(fmt.Sprintf("%dx%d", n, m), func(b *testing.B) { store := inmem.NewFromObject(map[string]any{"objs": genNxMObjectBenchmarkData(n, m)}) - module := `package test - - combined := object.union_n(data.objs)` - - query := ast.MustParseBody("data.test.combined") compiler := ast.MustCompileModules(map[string]string{ - "test.rego": module, + "test.rego": "package test\n\ncombined := object.union_n(data.objs)", }) + ctx := b.Context() b.ResetTimer() - for b.Loop() { - err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error { - _, err := NewQuery(query). - WithCompiler(compiler). - WithStore(store). - WithTransaction(txn). - Run(ctx) - - return err - }) + err := storage.Txn(ctx, store, storage.TransactionParams{}, func(txn storage.Transaction) error { + q := NewQuery(ast.MustParseBody("data.test.combined")). + WithCompiler(compiler). + WithStore(store). + WithTransaction(txn) - if err != nil { - b.Fatal(err) + for b.Loop() { + if _, err := q.Run(ctx); err != nil { + b.Fatal(err) + } } + + return nil + }) + + if err != nil { + b.Fatal(err) } }) } @@ -108,6 +105,79 @@ func BenchmarkObjectUnionNSlow(b *testing.B) { } } +// empty_array-16 102861856 11.5 ns/op 0 B/op 0 allocs/op +// single_object-16 43886583 25.9 ns/op 0 B/op 0 allocs/op +// merge_empty-16 6329768 189.9 ns/op 392 B/op 8 allocs/op +// merge_equal-16 5543514 216.9 ns/op 400 B/op 8 allocs/op +// merge_non-overlapping-16 4018898 296.0 ns/op 456 B/op 10 allocs/op +// merge_overlapping-16 4234546 282.7 ns/op 576 B/op 10 allocs/op +// merge_nested-16 2450734 489.9 ns/op 816 B/op 18 allocs/op +// merge_nested_with_conflict-16 2441014 492.6 ns/op 920 B/op 17 allocs/op +// merge_nested_1_equal_branch-16 1553234 773.5 ns/op 1272 B/op 24 allocs/op +// merge_nested_no_equal_branch-16 1490298 804.1 ns/op 1344 B/op 27 allocs/op +func BenchmarkObjectUnionNCallOnly(b *testing.B) { + cases := []struct { + name string + objs []string + want string + }{ + {"empty array", []string{}, `{}`}, + {"single object", []string{ + `{"a": 1}`, + }, `{"a": 1}`}, + {"merge empty", []string{ + `{"a": 1}`, + `{}`, + }, `{"a": 1}`}, + {"merge equal", []string{ + `{"a": 1}`, + `{"a": 1}`, + }, `{"a": 1}`}, + {"merge non-overlapping", []string{ + `{"a": 1}`, + `{"b": 2}`, + }, `{"a": 1, "b": 2}`}, + {"merge overlapping", []string{ + `{"a": 1}`, + `{"a": 2}`, + }, `{"a": 2}`}, + {"merge nested", []string{ + `{"a": {"b": 1}}`, + `{"a": {"c": 2}}`, + }, `{"a": {"b": 1, "c": 2}}`}, + {"merge nested with conflict", []string{ + `{"a": {"b": 1}}`, + `{"a": {"b": 2}}`, + }, `{"a": {"b": 2}}`}, + {"merge nested 1 equal branch", []string{ + `{"a": {"b": 1}, "b": {"c": 1}}`, + `{"a": {"b": 1}, "b": {"c": 2}}`, + }, `{"a": {"b": 1}, "b": {"c": 2}}`}, + {"merge nested no equal branch", []string{ + `{"a": {"b": 1}, "b": {"c": 1}}`, + `{"a": {"b": 2}, "b": {"d": 2}}`, + }, `{"a": {"b": 2}, "b": {"c": 1, "d": 2}}`}, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + arr := make([]*ast.Term, len(tc.objs)) + for i, o := range tc.objs { + arr[i] = ast.MustParseTerm(o) + } + + exp := ast.MustParseTerm(tc.want) + ops := []*ast.Term{ast.ArrayTerm(arr...)} + + for b.Loop() { + if err := builtinObjectUnionN(BuiltinContext{}, ops, eqIter(exp)); err != nil { + b.Fatal(err) + } + } + }) + } +} + // 72.64 ns/op 56 B/op 2 allocs/op // 45.49 ns/op 0 B/op 0 allocs/op func BenchmarkObjectGetFound(b *testing.B) {