Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions internal/ref/ref.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,3 @@ func ParseDataPath(s string) (ast.Ref, error) {

return path.Ref(ast.DefaultRootDocument), nil
}

// ArrayPath will take an ast.Array and build an ast.Ref using the ast.Terms in the Array
func ArrayPath(a *ast.Array) ast.Ref {
ref := make(ast.Ref, 0, a.Len())

a.Foreach(func(term *ast.Term) {
ref = append(ref, term)
})

return ref
}
44 changes: 18 additions & 26 deletions v1/topdown/builtins/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ func NewOperandErr(pos int, f string, a ...any) error {

// NewOperandTypeErr returns an operand error indicating the operand's type was wrong.
func NewOperandTypeErr(pos int, got ast.Value, expected ...string) error {

if len(expected) == 1 {
return NewOperandErr(pos, "must be %v but got %v", expected[0], ast.ValueName(got))
}
Expand All @@ -138,7 +137,6 @@ func NewOperandTypeErr(pos int, got ast.Value, expected ...string) error {
// NewOperandElementErr returns an operand error indicating an element in the
// composite operand was wrong.
func NewOperandElementErr(pos int, composite ast.Value, got ast.Value, expected ...string) error {

tpe := ast.ValueName(composite)

if len(expected) == 1 {
Expand All @@ -150,7 +148,6 @@ func NewOperandElementErr(pos int, composite ast.Value, got ast.Value, expected

// NewOperandEnumErr returns an operand error indicating a value was wrong.
func NewOperandEnumErr(pos int, expected ...string) error {

if len(expected) == 1 {
return NewOperandErr(pos, "must be %v", expected[0])
}
Expand Down Expand Up @@ -192,30 +189,27 @@ func BigIntOperand(x ast.Value, pos int) (*big.Int, error) {
// NumberOperand converts x to a number. If the cast fails, a descriptive error is
// returned.
func NumberOperand(x ast.Value, pos int) (ast.Number, error) {
n, ok := x.(ast.Number)
if !ok {
return ast.Number(""), NewOperandTypeErr(pos, x, "number")
if n, ok := x.(ast.Number); ok {
return n, nil
}
return n, nil
return ast.Number(""), NewOperandTypeErr(pos, x, "number")
}

// SetOperand converts x to a set. If the cast fails, a descriptive error is
// returned.
func SetOperand(x ast.Value, pos int) (ast.Set, error) {
s, ok := x.(ast.Set)
if !ok {
return nil, NewOperandTypeErr(pos, x, "set")
if s, ok := x.(ast.Set); ok {
return s, nil
}
return s, nil
return nil, NewOperandTypeErr(pos, x, "set")
}

// StringOperand returns x as [ast.String], or a descriptive error if the conversion fails.
func StringOperand(x ast.Value, pos int) (ast.String, error) {
s, ok := x.(ast.String)
if !ok {
return ast.String(""), NewOperandTypeErr(pos, x, "string")
if s, ok := x.(ast.String); ok {
return s, nil
}
return s, nil
return ast.String(""), NewOperandTypeErr(pos, x, "string")
}

// StringOperandByteSlice returns x a []byte, assuming x is [ast.String], or a descriptive error
Expand All @@ -229,24 +223,22 @@ func StringOperandByteSlice(x ast.Value, pos int) ([]byte, error) {
return util.StringToByteSlice(string(s)), nil
}

// ObjectOperand converts x to an object. If the cast fails, a descriptive
// ObjectOperand converts x to an object. If the conversion fails, a descriptive
// error is returned.
func ObjectOperand(x ast.Value, pos int) (ast.Object, error) {
o, ok := x.(ast.Object)
if !ok {
return nil, NewOperandTypeErr(pos, x, "object")
func ObjectOperand(x ast.Value, pos int) (o ast.Object, err error) {
if o, ok := x.(ast.Object); ok {
return o, nil
}
return o, nil
return nil, NewOperandTypeErr(pos, x, "object")
}

// ArrayOperand converts x to an array. If the cast fails, a descriptive
// ArrayOperand converts x to an array. If the conversion fails, a descriptive
// error is returned.
func ArrayOperand(x ast.Value, pos int) (*ast.Array, error) {
a, ok := x.(*ast.Array)
if !ok {
return nil, NewOperandTypeErr(pos, x, "array")
if a, ok := x.(*ast.Array); ok {
return a, nil
}
return a, nil
return nil, NewOperandTypeErr(pos, x, "array")
}

// NumberToFloat converts n to a big float.
Expand Down
39 changes: 15 additions & 24 deletions v1/topdown/object.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
package topdown

import (
"github.com/open-policy-agent/opa/internal/ref"
"cmp"

"github.com/open-policy-agent/opa/v1/ast"
"github.com/open-policy-agent/opa/v1/topdown/builtins"
)
Expand All @@ -24,10 +25,7 @@ func builtinObjectUnion(_ BuiltinContext, operands []*ast.Term, iter func(*ast.T
if objA.Len() == 0 {
return iter(operands[1])
}
if objB.Len() == 0 {
return iter(operands[0])
}
if objA.Compare(objB) == 0 {
if objB.Len() == 0 || objA.Compare(objB) == 0 {
return iter(operands[0])
}

Expand Down Expand Up @@ -126,34 +124,27 @@ func builtinObjectFilter(_ BuiltinContext, operands []*ast.Term, iter func(*ast.
}

func builtinObjectGet(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
object, err := builtins.ObjectOperand(operands[0].Value, 1)
// silly micro optimization: initial ref to last item avoids
// later bounds checks as 1 and 0 then known to be valid indices
defaultValue, path, curr := operands[2], operands[1], operands[0]

object, err := builtins.ObjectOperand(curr.Value, 1)
if err != nil {
return err
}

// if the get key is not an array, attempt to get the top level key for the operand value in the object
path, ok := operands[1].Value.(*ast.Array)
arr, ok := path.Value.(*ast.Array)
if !ok {
if ret := object.Get(operands[1]); ret != nil {
return iter(ret)
}

return iter(operands[2])
return iter(cmp.Or(object.Get(path), defaultValue))
}

// if the path is empty, then we skip selecting nested keys and return the whole object
if path.Len() == 0 {
return iter(operands[0])
}

// build an ast.Ref from the array and see if it matches within the object
pathRef := ref.ArrayPath(path)
value, err := object.Find(pathRef)
if err != nil {
return iter(operands[2])
for i := range arr.Len() {
if curr = curr.Get(arr.Elem(i)); curr == nil {
break
}
}

return iter(ast.NewTerm(value))
return iter(cmp.Or(curr, defaultValue))
}

func builtinObjectKeys(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
Expand Down
36 changes: 36 additions & 0 deletions v1/topdown/object_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,39 @@ func BenchmarkObjectUnionNSlow(b *testing.B) {
}
}
}

// 72.64 ns/op 56 B/op 2 allocs/op
// 45.49 ns/op 0 B/op 0 allocs/op
func BenchmarkObjectGetFound(b *testing.B) {
obj := ast.MustParseTerm(`{"a": {"b": {"c": {"d": 1}}}}`)
arr := ast.ArrayTerm(ast.InternedTerm("a"), ast.InternedTerm("b"), ast.InternedTerm("c"), ast.InternedTerm("d"))
def := ast.NullTerm()

ops := []*ast.Term{obj, arr, def}
exp := eqIter(ast.InternedTerm(1))
bcx := BuiltinContext{}

for b.Loop() {
if err := builtinObjectGet(bcx, ops, exp); err != nil {
b.Fatal(err)
}
}
}

// 48.15 ns/op 32 B/op 1 allocs/op
// 36.74 ns/op 0 B/op 0 allocs/op
func BenchmarkObjectGetNotFound(b *testing.B) {
obj := ast.MustParseTerm(`{"a": {"b": {"c": {"d": 1}}}}`)
arr := ast.ArrayTerm(ast.InternedTerm("a"), ast.InternedTerm("b"), ast.InternedTerm("c"), ast.InternedTerm("e"))
def := ast.NullTerm()

ops := []*ast.Term{obj, arr, def}
exp := eqIter(def)
bcx := BuiltinContext{}

for b.Loop() {
if err := builtinObjectGet(bcx, ops, exp); err != nil {
b.Fatal(err)
}
}
}
Loading