Skip to content

[SPARK-8432] [SQL] fix hashCode() and equals() of BinaryType in Row #6876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java
Original file line number Diff line number Diff line change
Expand Up @@ -155,27 +155,6 @@ public int fieldIndex(String name) {
throw new UnsupportedOperationException();
}

/**
* A generic version of Row.equals(Row), which is used for tests.
*/
@Override
public boolean equals(Object other) {
if (other instanceof Row) {
Row row = (Row) other;
int n = size();
if (n != row.size()) {
return false;
}
for (int i = 0; i < n; i ++) {
if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
return false;
}
}
return true;
}
return false;
}

@Override
public InternalRow copy() {
final int n = size();
Expand Down
32 changes: 0 additions & 32 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.sql

import scala.util.hashing.MurmurHash3

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.StructType

Expand Down Expand Up @@ -365,36 +363,6 @@ trait Row extends Serializable {
false
}

override def equals(that: Any): Boolean = that match {
case null => false
case that: Row =>
if (this.length != that.length) {
return false
}
var i = 0
val len = this.length
while (i < len) {
if (apply(i) != that.apply(i)) {
return false
}
i += 1
}
true
case _ => false
}

override def hashCode: Int = {
// Using Scala's Seq hash code implementation.
var n = 0
var h = MurmurHash3.seqSeed
val len = length
while (n < len) {
h = MurmurHash3.mix(h, apply(n).##)
n += 1
}
MurmurHash3.finalizeHash(h, n)
}

/* ---------------------- utility methods for Scala ---------------------- */

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,78 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.catalyst.expressions._

/**
* An abstract class for row used internal in Spark SQL, which only contain the columns as
* internal types.
*/
abstract class InternalRow extends Row {
// A default implementation to change the return type
override def copy(): InternalRow = {this}
override def copy(): InternalRow = this

override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[Row]) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will we change it to isInstanceOf[InternalRow] after #6869?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

return false
}

val other = o.asInstanceOf[Row]
if (length != other.length) {
return false
}

var i = 0
while (i < length) {
if (isNullAt(i) != other.isNullAt(i)) {
return false
}
if (!isNullAt(i)) {
val o1 = apply(i)
val o2 = other.apply(i)
if (o1.isInstanceOf[Array[Byte]]) {
// handle equality of Array[Byte]
val b1 = o1.asInstanceOf[Array[Byte]]
if (!o2.isInstanceOf[Array[Byte]] ||
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
return false
}
} else if (o1 != o2) {
return false
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about

(o1, o2) match {
  case (b1: Array[Byte], b2: Array[Byte]) => java.util.Arrays.equals(b1, b2)
  case _ => o1 == o2
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

match is also slow in Scala

}
i += 1
}
true
}

// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37
var i = 0
while (i < length) {
val update: Int =
if (isNullAt(i)) {
0
} else {
apply(i) match {
case b: Boolean => if (b) 0 else 1
case b: Byte => b.toInt
case s: Short => s.toInt
case i: Int => i
case l: Long => (l ^ (l >>> 32)).toInt
case f: Float => java.lang.Float.floatToIntBits(f)
case d: Double =>
val b = java.lang.Double.doubleToLongBits(d)
(b ^ (b >>> 32)).toInt
case a: Array[Byte] => java.util.Arrays.hashCode(a)
case other => other.hashCode()
}
}
result = 37 * result + update
i += 1
}
result
}
}

object InternalRow {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
case FloatType => s"Float.floatToIntBits($col)"
case DoubleType =>
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
case BinaryType => s"java.util.Arrays.hashCode($col)"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also update equals for generated code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's already done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, genEqual has already handled BinaryType.

case _ => s"$col.hashCode()"
}
s"isNullAt($i) ? 0 : ($nonNull)"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,58 +121,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
}
}

// TODO(davies): add getDate and getDecimal

// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
var result: Int = 37

var i = 0
while (i < values.length) {
val update: Int =
if (isNullAt(i)) {
0
} else {
apply(i) match {
case b: Boolean => if (b) 0 else 1
case b: Byte => b.toInt
case s: Short => s.toInt
case i: Int => i
case l: Long => (l ^ (l >>> 32)).toInt
case f: Float => java.lang.Float.floatToIntBits(f)
case d: Double =>
val b = java.lang.Double.doubleToLongBits(d)
(b ^ (b >>> 32)).toInt
case other => other.hashCode()
}
}
result = 37 * result + update
i += 1
}
result
}

override def equals(o: Any): Boolean = o match {
case other: InternalRow =>
if (values.length != other.length) {
return false
}

var i = 0
while (i < values.length) {
if (isNullAt(i) != other.isNullAt(i)) {
return false
}
if (apply(i) != other.apply(i)) {
return false
}
i += 1
}
true

case _ => false
}

override def copy(): InternalRow = this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,23 @@ trait ExpressionEvalHelper {

protected def checkEvaluation(
expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
checkEvaluationWithoutCodegen(expression, expected, inputRow)
checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
checkEvaluationWithOptimization(expression, expected, inputRow)
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
checkEvaluationWithOptimization(expression, catalystValue, inputRow)
}

/**
* Check the equality between result of expression and expected value, it will handle
* Array[Byte].
*/
protected def checkResult(result: Any, expected: Any): Boolean = {
(result, expected) match {
case (result: Array[Byte], expected: Array[Byte]) =>
java.util.Arrays.equals(result, expected)
case _ => result == expected
}
}

protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
Expand All @@ -55,7 +68,7 @@ trait ExpressionEvalHelper {
val actual = try evaluate(expression, inputRow) catch {
case e: Exception => fail(s"Exception evaluating $expression", e)
}
if (actual != expected) {
if (!checkResult(actual, expected)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect evaluation (codegen off): $expression, " +
s"actual: $actual, " +
Expand Down Expand Up @@ -83,7 +96,7 @@ trait ExpressionEvalHelper {
}

val actual = plan(inputRow).apply(0)
if (actual != expected) {
if (!checkResult(actual, expected)) {
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
}
Expand All @@ -109,7 +122,7 @@ trait ExpressionEvalHelper {
}

val actual = plan(inputRow)
val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
val expectedRow = new GenericRow(Array[Any](expected))
if (actual.hashCode() != expectedRow.hashCode()) {
fail(
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,79 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types._


class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

// TODO: Add tests for all data types.
test("null") {
checkEvaluation(Literal.create(null, BooleanType), null)
checkEvaluation(Literal.create(null, ByteType), null)
checkEvaluation(Literal.create(null, ShortType), null)
checkEvaluation(Literal.create(null, IntegerType), null)
checkEvaluation(Literal.create(null, LongType), null)
checkEvaluation(Literal.create(null, FloatType), null)
checkEvaluation(Literal.create(null, LongType), null)
checkEvaluation(Literal.create(null, StringType), null)
checkEvaluation(Literal.create(null, BinaryType), null)
checkEvaluation(Literal.create(null, DecimalType()), null)
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
}

test("boolean literals") {
checkEvaluation(Literal(true), true)
checkEvaluation(Literal(false), false)
}

test("int literals") {
checkEvaluation(Literal(1), 1)
checkEvaluation(Literal(0L), 0L)
List(0, 1, Int.MinValue, Int.MaxValue).foreach { d =>
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toLong), d.toLong)
checkEvaluation(Literal(d.toShort), d.toShort)
checkEvaluation(Literal(d.toByte), d.toByte)
}
checkEvaluation(Literal(Long.MinValue), Long.MinValue)
checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
}

test("double literals") {
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach {
d => {
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toFloat), d.toFloat)
}
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
checkEvaluation(Literal(d), d)
checkEvaluation(Literal(d.toFloat), d.toFloat)
}
checkEvaluation(Literal(Double.MinValue), Double.MinValue)
checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
checkEvaluation(Literal(Float.MinValue), Float.MinValue)
checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)

}

test("string literals") {
checkEvaluation(Literal(""), "")
checkEvaluation(Literal("test"), "test")
checkEvaluation(Literal.create(null, StringType), null)
checkEvaluation(Literal("\0"), "\0")
}

test("sum two literals") {
checkEvaluation(Add(Literal(1), Literal(1)), 2)
}

test("binary literals") {
checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
}

test("decimal") {
List(0.0, 1.2, 1.1111, 5).foreach { d =>
checkEvaluation(Literal(Decimal(d)), Decimal(d))
checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt))
checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong))
checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)),
Decimal((d * 1000L).toLong, 10, 1))
}
}

// TODO(davies): add tests for ArrayType, MapType and StructType
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringLength(regEx), 5, create_row("abdef"))
checkEvaluation(StringLength(regEx), 0, create_row(""))
checkEvaluation(StringLength(regEx), null, create_row(null))
// TODO currently bug in codegen, let's temporally disable this
// checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}


}
Loading