Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions v1/ast/builtins.go
Original file line number Diff line number Diff line change
Expand Up @@ -3662,11 +3662,12 @@ func (b *Builtin) Call(operands ...*Term) *Term {

// Ref returns a Ref that refers to the built-in function.
func (b *Builtin) Ref() Ref {
parts := strings.Split(b.Name, ".")
ref := make(Ref, len(parts))
ref[0] = VarTerm(parts[0])
for i := 1; i < len(parts); i++ {
ref[i] = InternedTerm(parts[i])
numParts := strings.Count(b.Name, ".") + 1
curr, remaining, ok := strings.Cut(b.Name, ".")
ref := append(make(Ref, 0, numParts), VarTerm(curr))
for ok {
curr, remaining, ok = strings.Cut(remaining, ".")
ref = append(ref, InternedTerm(curr))
}
return ref
}
Expand Down
59 changes: 30 additions & 29 deletions v1/ast/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
// is empty.
// Other comparisons are consistent but not defined.
func Compare(a, b any) int {

if t, ok := a.(*Term); ok {
if t == nil {
a = nil
Expand Down Expand Up @@ -77,7 +76,7 @@ func Compare(a, b any) int {
switch a := a.(type) {
case *Not:
b := b.(*Not)
return Compare(a.Body, b.Body)
return a.Compare(b)
case Null:
return 0
case Boolean:
Expand Down Expand Up @@ -120,25 +119,13 @@ func Compare(a, b any) int {
return a.Compare(b.(Set))
case *ArrayComprehension:
b := b.(*ArrayComprehension)
if cmp := Compare(a.Term, b.Term); cmp != 0 {
return cmp
}
return a.Body.Compare(b.Body)
return a.Compare(b)
case *ObjectComprehension:
b := b.(*ObjectComprehension)
if cmp := Compare(a.Key, b.Key); cmp != 0 {
return cmp
}
if cmp := Compare(a.Value, b.Value); cmp != 0 {
return cmp
}
return a.Body.Compare(b.Body)
return a.Compare(b)
case *SetComprehension:
b := b.(*SetComprehension)
if cmp := Compare(a.Term, b.Term); cmp != 0 {
return cmp
}
return a.Body.Compare(b.Body)
return a.Compare(b)
case Call:
return termSliceCompare(a, b.(Call))
case *Expr:
Expand Down Expand Up @@ -173,14 +160,8 @@ func Compare(a, b any) int {
panic(fmt.Sprintf("illegal value: %T", a))
}

type termSlice []*Term

func (s termSlice) Less(i, j int) bool { return Compare(s[i].Value, s[j].Value) < 0 }
func (s termSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s termSlice) Len() int { return len(s) }

func sortOrder(x any) int {
switch x.(type) {
func valueSortOrder(v Value) int {
switch v.(type) {
case Null:
return 0
case Boolean:
Expand Down Expand Up @@ -209,6 +190,28 @@ func sortOrder(x any) int {
return 12
case Call:
return 13
case *Not:
return 111
}
return 10000000
}

func valueTypeCompare[A, B Value](a A, b B) int {
sortA := valueSortOrder(a)
sortB := valueSortOrder(b)

if sortA < sortB {
return -1
} else if sortB < sortA {
return 1
}
return 0
}

func sortOrder(x any) int {
switch v := x.(type) {
case Value:
return valueSortOrder(v)
case Args:
return 14
case *Expr:
Expand All @@ -223,8 +226,6 @@ func sortOrder(x any) int {
return 104
case *With:
return 110
case *Not:
return 111
case *Head:
return 120
case Body:
Expand Down Expand Up @@ -294,7 +295,7 @@ func rulesCompare(a, b []*Rule) int {
func termSliceCompare(a, b []*Term) int {
minLen := min(len(b), len(a))
for i := range minLen {
if cmp := Compare(a[i], b[i]); cmp != 0 {
if cmp := a[i].Value.Compare(b[i].Value); cmp != 0 {
return cmp
}
}
Expand All @@ -309,7 +310,7 @@ func termSliceCompare(a, b []*Term) int {
func withSliceCompare(a, b []*With) int {
minLen := min(len(b), len(a))
for i := range minLen {
if cmp := Compare(a[i], b[i]); cmp != 0 {
if cmp := a[i].Compare(b[i]); cmp != 0 {
return cmp
}
}
Expand Down
11 changes: 5 additions & 6 deletions v1/ast/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ func (imp *Import) Compare(other *Import) int {
} else if other == nil {
return 1
}
if cmp := Compare(imp.Path, other.Path); cmp != 0 {
if cmp := imp.Path.Value.Compare(other.Path.Value); cmp != 0 {
return cmp
}

Expand Down Expand Up @@ -906,10 +906,10 @@ func (head *Head) Compare(other *Head) int {
} else if !head.Assign && other.Assign {
return 1
}
if cmp := Compare(head.Args, other.Args); cmp != 0 {
if cmp := termSliceCompare(head.Args, other.Args); cmp != 0 {
return cmp
}
if cmp := Compare(head.Reference, other.Reference); cmp != 0 {
if cmp := termSliceCompare(head.Reference, other.Reference); cmp != 0 {
return cmp
}
if cmp := VarCompare(head.Name, other.Name); cmp != 0 {
Expand Down Expand Up @@ -1218,7 +1218,6 @@ func (expr *Expr) Equal(other *Expr) bool {
// Otherwise, the expression terms are compared normally. If both expressions
// have the same terms, the modifiers are compared.
func (expr *Expr) Compare(other *Expr) int {

if expr == nil {
if other == nil {
return 0
Expand Down Expand Up @@ -1252,7 +1251,7 @@ func (expr *Expr) Compare(other *Expr) int {

switch t := expr.Terms.(type) {
case *Term:
if cmp := Compare(t.Value, other.Terms.(*Term).Value); cmp != 0 {
if cmp := t.Value.Compare(other.Terms.(*Term).Value); cmp != 0 {
return cmp
}
case []*Term:
Expand All @@ -1268,7 +1267,7 @@ func (expr *Expr) Compare(other *Expr) int {
return cmp
}
case *Not:
if cmp := Compare(t, other.Terms.(*Not)); cmp != 0 {
if cmp := t.Compare(other.Terms.(*Not)); cmp != 0 {
return cmp
}
case *LogicalAnd:
Expand Down
Loading
Loading