Skip to content

Commit

Permalink
Fix bug in comparison of unordered lists
Browse files Browse the repository at this point in the history
  • Loading branch information
loveleif committed Apr 14, 2023
1 parent 73986c7 commit 9af5db7
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ case class CypherPropertyMap(properties: Map[String, CypherValue] = Map.empty)

trait CypherList extends CypherValue {
def elements: List[CypherValue]
protected[tck] lazy val sortedElements: List[CypherValue] = elements.sorted(CypherValue.ordering)
}

/**
Expand All @@ -128,7 +129,8 @@ case class CypherOrderedList(elements: List[CypherValue] = List.empty) extends C
case _ => false
}

override def hashCode(): Int = Hashing.productHash(this)
// Hash code needs to be on sorted elements for comparison with unordered lists to work
override def hashCode(): Int = sortedElements.hashCode()
}

/**
Expand All @@ -141,14 +143,11 @@ private[tck] case class CypherUnorderedList(elements: List[CypherValue] = List.e

override def equals(obj: scala.Any): Boolean = obj match {
case null => false
case other: CypherOrderedList =>
other.elements.sorted(CypherValue.ordering) == elements.sorted(CypherValue.ordering)
case other: CypherUnorderedList =>
other.elements.sorted(CypherValue.ordering) == elements.sorted(CypherValue.ordering)
case other: CypherList => other.sortedElements == sortedElements
case _ => false
}

override def hashCode(): Int = Hashing.productHash(this)
override def hashCode(): Int = sortedElements.hashCode()
}

case object CypherNull extends CypherValue {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
package org.opencypher.tools.tck.api

import org.opencypher.tools.tck.values.CypherString
import org.opencypher.tools.tck.values.CypherValue
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers

Expand All @@ -49,4 +50,53 @@ class CypherValueRecordsTest extends AnyFunSuite with Matchers {

a.equalsUnordered(b) shouldBe false
}

test("compare unordered lists") {
val a = records(List(
Map("foo" -> CypherValue("['Ada', 'Danielle']", orderedLists = false)),
Map("foo" -> CypherValue("['Carl']", orderedLists = false)),
Map("foo" -> CypherValue("['Danielle']", orderedLists = false)),
Map("foo" -> CypherValue("[]", orderedLists = false)),
Map("foo" -> CypherValue("['Bob', 'Carl']", orderedLists = false))
))
val b = records(List(
Map("foo" -> CypherValue("['Carl', 'Bob']")),
Map("foo" -> CypherValue("['Ada', 'Danielle']")),
Map("foo" -> CypherValue("['Danielle']")),
Map("foo" -> CypherValue("['Carl']")),
Map("foo" -> CypherValue("[]"))
))

a.equalsUnordered(b) shouldBe true
b.equalsUnordered(a) shouldBe true
a should not equal(b)
b should not equal(a)
}

test("compare unordered lists 2") {
val a = records(List(
Map("foo" -> CypherValue("['Bob', 'Carl']", orderedLists = false)),
Map("foo" -> CypherValue("['Ada', 'Danielle']", orderedLists = false)),
Map("foo" -> CypherValue("['Danielle']", orderedLists = false)),
Map("foo" -> CypherValue("['Carl']", orderedLists = false)),
Map("foo" -> CypherValue("[]", orderedLists = false))
))
val b = records(List(
Map("foo" -> CypherValue("['Carl', 'Bob']")),
Map("foo" -> CypherValue("['Ada', 'Danielle']")),
Map("foo" -> CypherValue("['Danielle']")),
Map("foo" -> CypherValue("['Carl']")),
Map("foo" -> CypherValue("[]"))
))

a.equalsUnordered(b) shouldBe true
b.equalsUnordered(a) shouldBe true
a shouldBe (b)
b shouldBe (a)
}

private def records(columns: List[Map[String, CypherValue]]): CypherValueRecords = {
val header = columns.flatMap(_.keySet).distinct
CypherValueRecords(header, columns)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,15 @@ class CypherValueTest extends AnyFunSuite with Matchers {
val uList1 = CypherUnorderedList(List(CypherInteger(2), CypherInteger(1)).sorted(CypherValue.ordering))
val uList2 = CypherUnorderedList(List(CypherInteger(1), CypherInteger(2)).sorted(CypherValue.ordering))

oList1 should equal(oList1)
oList2 should equal(oList2)
uList1 should equal(uList1)
uList2 should equal(uList2)

oList1 should not equal oList2
oList2 should not equal oList1

uList1 should equal(oList1)
uList1 should equal(oList2)
uList1 should equal(uList2)
uList2 should equal(oList1)
uList2 should equal(oList2)
uList2 should equal(uList1)

oList1 should equal(uList1)
oList2 should equal(uList1)
uList2 should equal(uList1)
oList1 should equal(uList2)
oList2 should equal(uList2)
uList1 should equal(uList2)
assertReallyEqual(uList1, oList1)
assertReallyEqual(uList1, oList2)
assertReallyEqual(uList1, uList2)
assertReallyEqual(uList2, oList1)
assertReallyEqual(uList2, oList2)
assertReallyEqual(uList2, uList1)
}

test("list comparisons with strings") {
Expand All @@ -69,48 +57,39 @@ class CypherValueTest extends AnyFunSuite with Matchers {
val uList1 = CypherUnorderedList(List(CypherString("name"), CypherString("age")).sorted(CypherValue.ordering))
val uList2 = CypherUnorderedList(List(CypherString("age"), CypherString("name")).sorted(CypherValue.ordering))

oList1 should equal(oList1)
oList2 should equal(oList2)
uList1 should equal(uList1)
uList2 should equal(uList2)

oList1 should not equal oList2
oList2 should not equal oList1

uList1 should equal(oList1)
uList1 should equal(oList2)
uList1 should equal(uList2)
uList2 should equal(oList1)
uList2 should equal(oList2)
uList2 should equal(uList1)

oList1 should equal(uList1)
oList2 should equal(uList1)
uList2 should equal(uList1)
oList1 should equal(uList2)
oList2 should equal(uList2)
uList1 should equal(uList2)
assertReallyEqual(uList1, oList1)
assertReallyEqual(uList1, oList2)
assertReallyEqual(uList1, uList2)
assertReallyEqual(uList2, oList1)
assertReallyEqual(uList2, oList2)
}

test("lists that are equal should have the same hashCode") {
val oList = CypherOrderedList(List(CypherString("Foo")))
val uList = CypherUnorderedList(List(CypherString("Foo")))

oList should equal(uList)
oList.hashCode() should equal(uList.hashCode())
assertReallyEqual(oList, uList)
}

test("lists that are equal should really have the same hashCode") {
val a = CypherValue("['Carl', 'Bob']", orderedLists = false)
val b = CypherValue("['Bob', 'Carl']")
assertReallyEqual(a, b)
}

test("list comparisons simple example") {
val orderedItems1 = List(CypherString("name"), CypherString("age"), CypherString("address"))
val orderedItems2 = List(CypherString("age"), CypherString("name"), CypherString("address"))
val l1 = CypherUnorderedList(orderedItems1.sorted(CypherValue.ordering))
val l2 = CypherOrderedList(orderedItems1)
l1 should equal(l2)
l2 should equal(l1)
assertReallyEqual(l1, l2)
}

test("node comparison with labelled nodes") {
CypherNode(Set("A", "B")) should equal(CypherNode(Set("B", "A")))
assertReallyEqual(CypherNode(Set("A", "B")), CypherNode(Set("B", "A")))
CypherNode(Set("A", "C")) should not equal(CypherNode(Set("A", "B")))
}

Expand All @@ -124,45 +103,61 @@ class CypherValueTest extends AnyFunSuite with Matchers {
CypherNode(scala.collection.immutable.SortedSet("C", "B")(Ordering.String.reverse))
)

CypherOrderedList(nodeList1) should equal(CypherOrderedList(nodeList2))
CypherOrderedList(nodeList2) should equal(CypherOrderedList(nodeList1))
CypherOrderedList(nodeList1) should equal(CypherUnorderedList(nodeList2))
CypherOrderedList(nodeList2) should equal(CypherUnorderedList(nodeList1))
CypherOrderedList(nodeList1.reverse) should equal(CypherUnorderedList(nodeList2))
CypherOrderedList(nodeList1) should equal(CypherUnorderedList(nodeList2.reverse))
CypherOrderedList(nodeList1.reverse) should equal(CypherUnorderedList(nodeList1))
CypherOrderedList(nodeList1) should equal(CypherUnorderedList(nodeList1.reverse))
assertReallyEqual(CypherOrderedList(nodeList1), CypherOrderedList(nodeList2))
assertReallyEqual(CypherOrderedList(nodeList1), CypherUnorderedList(nodeList2))
assertReallyEqual(CypherOrderedList(nodeList1.reverse), CypherUnorderedList(nodeList2))
assertReallyEqual(CypherOrderedList(nodeList1), CypherUnorderedList(nodeList2.reverse))
assertReallyEqual(CypherOrderedList(nodeList1.reverse), CypherUnorderedList(nodeList1))
assertReallyEqual(CypherOrderedList(nodeList1), CypherUnorderedList(nodeList1.reverse))

CypherOrderedList(nodeList1.reverse) should not equal(CypherOrderedList(nodeList2))
CypherOrderedList(nodeList1) should not equal(CypherOrderedList(nodeList2.reverse))
CypherOrderedList(nodeList1.reverse) should not equal(CypherOrderedList(nodeList1))
CypherOrderedList(nodeList1) should not equal(CypherOrderedList(nodeList1.reverse))

CypherUnorderedList(nodeList1) should equal(CypherUnorderedList(nodeList2))
CypherUnorderedList(nodeList2) should equal(CypherUnorderedList(nodeList1))
CypherUnorderedList(nodeList1) should equal(CypherOrderedList(nodeList2))
CypherUnorderedList(nodeList2) should equal(CypherOrderedList(nodeList1))
CypherUnorderedList(nodeList1.reverse) should equal(CypherUnorderedList(nodeList2))
CypherUnorderedList(nodeList1) should equal(CypherUnorderedList(nodeList2.reverse))
CypherUnorderedList(nodeList1.reverse) should equal(CypherOrderedList(nodeList2))
CypherUnorderedList(nodeList1) should equal(CypherOrderedList(nodeList2.reverse))
CypherUnorderedList(nodeList1.reverse) should equal(CypherUnorderedList(nodeList1))
CypherUnorderedList(nodeList1) should equal(CypherUnorderedList(nodeList1.reverse))
CypherUnorderedList(nodeList1.reverse) should equal(CypherOrderedList(nodeList1))
CypherUnorderedList(nodeList1) should equal(CypherOrderedList(nodeList1.reverse))
assertReallyEqual(CypherUnorderedList(nodeList1), CypherUnorderedList(nodeList2))
assertReallyEqual(CypherUnorderedList(nodeList1), CypherOrderedList(nodeList2))
assertReallyEqual(CypherUnorderedList(nodeList2), CypherOrderedList(nodeList1))
assertReallyEqual(CypherUnorderedList(nodeList1.reverse), CypherUnorderedList(nodeList2))
assertReallyEqual(CypherUnorderedList(nodeList1), CypherUnorderedList(nodeList2.reverse))
assertReallyEqual(CypherUnorderedList(nodeList1.reverse), CypherOrderedList(nodeList2))
assertReallyEqual(CypherUnorderedList(nodeList1), CypherOrderedList(nodeList2.reverse))
assertReallyEqual(CypherUnorderedList(nodeList1.reverse), CypherUnorderedList(nodeList1))
assertReallyEqual(CypherUnorderedList(nodeList1), CypherUnorderedList(nodeList1.reverse))
assertReallyEqual(CypherUnorderedList(nodeList1.reverse), CypherOrderedList(nodeList1))
assertReallyEqual(CypherUnorderedList(nodeList1), CypherOrderedList(nodeList1.reverse))
}

test("list of lists comparison with labelled nodes") {
CypherValue("[[(:A:D), (:B:C)], [(:AA:DD), (:BB:CC)]]", orderedLists = false) should
equal(CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = false))

CypherValue("[[(:A:D), (:B:C)], [(:AA:DD), (:BB:CC)]]", orderedLists = true) should
equal(CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = true))

CypherValue("[[(:AA:DD), (:BB:CC)], [(:A:D), (:B:C)]]", orderedLists = false) should
equal(CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = false))

assertReallyEqual(
CypherValue("[[(:A:D), (:B:C)], [(:AA:DD), (:BB:CC)]]", orderedLists = false),
CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = false)
)
assertReallyEqual(
CypherValue("[[(:A:D), (:B:C)], [(:AA:DD), (:BB:CC)]]", orderedLists = true),
CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = false)
)
assertReallyEqual(
CypherValue("[[(:A:D), (:B:C)], [(:AA:DD), (:BB:CC)]]", orderedLists = true),
CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = true)
)
assertReallyEqual(
CypherValue("[[(:AA:DD), (:BB:CC)], [(:A:D), (:B:C)]]", orderedLists = false),
CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = false)
)
assertReallyEqual(
CypherValue("[[(:AA:DD), (:BB:CC)], [(:A:D), (:B:C)]]", orderedLists = false),
CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = true)
)
CypherValue("[[(:AA:DD), (:BB:CC)], [(:A:D), (:B:C)]]", orderedLists = true) should
not equal(CypherValue("[[(:D:A), (:C:B)], [(:DD:AA), (:CC:BB)]]", orderedLists = true))
}

private def assertReallyEqual(a: CypherValue, b: CypherValue): Unit = {
a shouldBe a
b shouldBe b
a shouldBe b
b shouldBe a
a.hashCode() shouldBe b.hashCode()
}
}

0 comments on commit 9af5db7

Please sign in to comment.