Skip to content

Commit

Permalink
fix: [~] row comparer & add list aprox comparer
Browse files Browse the repository at this point in the history
  • Loading branch information
eruizalo committed Apr 26, 2024
1 parent b4993d8 commit 81642fd
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,51 @@ object DataFrameSuiteBase {
def approxEquals(r1: Row, r2: Row, tolTimestamp: Duration): Boolean =
approxEquals(r1, r2, 0, tolTimestamp)

private def compareTimestamp(t1: Timestamp, t2: Timestamp,
tolTimestamp: Duration): Boolean = {
!(Duration.between(t1.toInstant, t2.toInstant).abs.compareTo(tolTimestamp) > 0)
}

private def compareDouble(d1: Double, d2: Double, tol: Double): Boolean = {
if (java.lang.Double.isNaN(d1) != java.lang.Double.isNaN(d2)) {
return false
}
if (abs(d1 - d2) > tol) {
return false
}
true
}

private def compareFloat(f1: Float, f2: Float, tol: Double): Boolean = {
if (java.lang.Float.isNaN(f1) != java.lang.Float.isNaN(f2)) {
return false
}
if (abs(f1 - f2) > tol) {
return false
}
true
}

private def compareJavaBigDecimal(d1: java.math.BigDecimal,
d2: java.math.BigDecimal,
tol: Double): Boolean = {
if (d1.compareTo(d2) != 0) {
if (d1.subtract(d2).abs.compareTo(new java.math.BigDecimal(tol)) > 0) {
return false
}
}
true
}

private def compareScalaBigDecimal(d1: scala.math.BigDecimal,
d2: scala.math.BigDecimal,
tol: Double): Boolean = {
if ((d1 - d2).abs > tol) {
return false
}
true
}

/** Approximate equality, based on equals from [[Row]] */
def approxEquals(r1: Row, r2: Row, tol: Double,
tolTimestamp: Duration): Boolean = {
Expand All @@ -457,47 +502,76 @@ object DataFrameSuiteBase {
}

case f1: Float =>
if (java.lang.Float.isNaN(f1) !=
java.lang.Float.isNaN(o2.asInstanceOf[Float]))
{
return false
}
if (abs(f1 - o2.asInstanceOf[Float]) > tol) {
if (!compareFloat(f1, o2.asInstanceOf[Float], tol)) {
return false
}

case d1: Double =>
if (java.lang.Double.isNaN(d1) !=
java.lang.Double.isNaN(o2.asInstanceOf[Double]))
{
return false
}
if (abs(d1 - o2.asInstanceOf[Double]) > tol) {
if (!compareDouble(d1, o2.asInstanceOf[Double], tol)) {
return false
}

case d1: java.math.BigDecimal =>
if (d1.compareTo(o2.asInstanceOf[java.math.BigDecimal]) != 0) {
if (d1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs
.compareTo(new java.math.BigDecimal(tol)) > 0) {
return false
}
if (!compareJavaBigDecimal(d1, o2.asInstanceOf[java.math.BigDecimal], tol)) {
return false
}

case d1: scala.math.BigDecimal =>
if ((d1 - o2.asInstanceOf[scala.math.BigDecimal]).abs > tol) {
if (!compareScalaBigDecimal(d1, o2.asInstanceOf[scala.math.BigDecimal], tol)) {
return false
}

case t1: Timestamp =>
val t1Instant = t1.toInstant
val t2Instant = o2.asInstanceOf[Timestamp].toInstant
if (Duration.between(t1Instant, t2Instant).abs.compareTo(tolTimestamp) > 0) {
if (!compareTimestamp(t1, o2.asInstanceOf[Timestamp], tolTimestamp)) {
return false
}

case r1: Row =>
return approxEquals(r1, o2.asInstanceOf[Row], tol, tolTimestamp)
case row1: Row =>
if (!approxEquals(row1, o2.asInstanceOf[Row], tol, tolTimestamp)) {
return false
}

case head :: _ if head.isInstanceOf[Row] =>
o1.asInstanceOf[Seq[Row]].zip(o2.asInstanceOf[Seq[Row]]).foreach {
case (row1, row2) if !approxEquals(row1, row2, tol, tolTimestamp) =>
return false
case _ =>
}

case head :: _ if head.isInstanceOf[Timestamp] =>
o1.asInstanceOf[Seq[Timestamp]].zip(o2.asInstanceOf[Seq[Timestamp]]).foreach {
case (t1, t2) if !compareTimestamp(t1, t2, tolTimestamp) =>
return false
case _ =>
}

case head :: _ if head.isInstanceOf[Double] =>
o1.asInstanceOf[Seq[Double]].zip(o2.asInstanceOf[Seq[Double]]).foreach {
case (d1, d2) if !compareDouble(d1, d2, tol) =>
return false
case _ =>
}

case head :: _ if head.isInstanceOf[Float] =>
o1.asInstanceOf[Seq[Float]].zip(o2.asInstanceOf[Seq[Float]]).foreach {
case (f1, f2) if !compareFloat(f1, f2, tol) =>
return false
case _ =>
}

case head :: _ if head.isInstanceOf[java.math.BigDecimal] =>
o1.asInstanceOf[Seq[java.math.BigDecimal]].zip(o2.asInstanceOf[Seq[java.math.BigDecimal]]).foreach {
case (d1, d2) if !compareJavaBigDecimal(d1, d2, tol) =>
return false
case _ =>
}

case head :: _ if head.isInstanceOf[scala.math.BigDecimal] =>
o1.asInstanceOf[Seq[scala.math.BigDecimal]].zip(o2.asInstanceOf[Seq[scala.math.BigDecimal]]).foreach {
case (d1, d2) if !compareScalaBigDecimal(d1, d2, tol) =>
return false
case _ =>
}

case _ =>
if (o1 != o2) return false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,58 @@ class SampleDataFrameTest extends ScalaDataFrameSuiteBase {
val row12a = Row(new java.math.BigDecimal(1.0 + 1.0E-6))
val row13 = Row(scala.math.BigDecimal(1.0))
val row13a = Row(scala.math.BigDecimal(1.0 + 1.0E-6))
val row14 =
Row("abc", 1.1, Row("any", Row(Timestamp.valueOf("2018-01-12 20:23:13"))))
val row14a =
Row("abc", 1.2, Row("any", Row(Timestamp.valueOf("2018-01-12 20:23:15"))))
val row14 = Row(
"abc", 1.1,
Row("any", Row(Timestamp.valueOf("2018-01-12 20:23:13"))),
"abc"
)
val row14a = Row(
"abc", 1.2,
Row("any", Row(Timestamp.valueOf("2018-01-12 20:23:15"))),
"abc"
)
val row15 = Row(
"some string",
Row(true, "28/02/2024", Seq(Timestamp.valueOf("2018-01-12 20:23:13"))),
Seq(
Row("something", "anything", null, Row("row1"), Row(Seq("Apple"))),
Row("", null, Row(Seq("[email protected]"), "name", ""), Row("row2"), null)
),
Seq(Row(
Seq(1.1),
Seq(1.1f),
Seq(new java.math.BigDecimal(1.1)),
Seq(scala.math.BigDecimal(1.1))
))
)
val row15a = Row(
"some string",
Row(true, "28/02/2024", Seq(Timestamp.valueOf("2018-01-12 20:23:15"))),
Seq(
Row("something", "anything", null, Row("row1"), Row(Seq("Apple"))),
Row("", null, Row(Seq("[email protected]"), "name", ""), Row("row2"), null)
),
Seq(Row(
Seq(1.2),
Seq(1.2f),
Seq(new java.math.BigDecimal(1.2)),
Seq(scala.math.BigDecimal(1.2))
))
)
val row15b = Row(
"some string",
Row(true, "28/02/2024", Seq(Timestamp.valueOf("2018-01-12 20:23:13"))),
Seq(
Row("something", "anything", null, Row("row1"), Row(Seq("Apple"))),
Row("", null, Row(Seq("[email protected]"), "name", ""), Row("row2"), null)
),
Seq(Row(
Seq(1.1),
Seq(1.1f),
Seq(new java.math.BigDecimal(1.1)),
Seq(scala.math.BigDecimal(1.1))
))
)
assert(false === approxEquals(row, row2, 1E-7))
assert(true === approxEquals(row, row2, 1E-5))
assert(true === approxEquals(row3, row3, 1E-5))
Expand All @@ -147,6 +195,8 @@ class SampleDataFrameTest extends ScalaDataFrameSuiteBase {
assert(true === approxEquals(row12, row12a, 1.0E-6))
assert(true === approxEquals(row13, row13a, 1.0E-6))
assert(true === approxEquals(row14, row14a, 0.1, Duration.ofSeconds(5)))
assert(true === approxEquals(row15, row15a, 0.2, Duration.ofSeconds(3)))
assert(false === approxEquals(row15, row15b, 0, Duration.ZERO))
}

test("verify hive function support") {
Expand Down

0 comments on commit 81642fd

Please sign in to comment.