Skip to content

Commit 6f4cadf

Browse files
Davies Liumarmbrus
authored andcommitted
[SPARK-8432] [SQL] fix hashCode() and equals() of BinaryType in Row
Also added more tests in LiteralExpressionSuite Author: Davies Liu <[email protected]> Closes #6876 from davies/fix_hashcode and squashes the following commits: 429c2c0 [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_hashcode 32d9811 [Davies Liu] fix test a0626ed [Davies Liu] Merge branch 'master' of github.com:apache/spark into fix_hashcode 89c2432 [Davies Liu] fix style bd20780 [Davies Liu] check with catalyst types 41caec6 [Davies Liu] change for to while d96929b [Davies Liu] address comment 6ad2a90 [Davies Liu] fix style 5819d33 [Davies Liu] unify equals() and hashCode() 0fff25d [Davies Liu] fix style 53c38b1 [Davies Liu] fix hashCode() and equals() of BinaryType in Row
1 parent 7b1450b commit 6f4cadf

File tree

10 files changed

+139
-135
lines changed

10 files changed

+139
-135
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -155,27 +155,6 @@ public int fieldIndex(String name) {
155155
throw new UnsupportedOperationException();
156156
}
157157

158-
/**
159-
* A generic version of Row.equals(Row), which is used for tests.
160-
*/
161-
@Override
162-
public boolean equals(Object other) {
163-
if (other instanceof Row) {
164-
Row row = (Row) other;
165-
int n = size();
166-
if (n != row.size()) {
167-
return false;
168-
}
169-
for (int i = 0; i < n; i ++) {
170-
if (isNullAt(i) != row.isNullAt(i) || (!isNullAt(i) && !get(i).equals(row.get(i)))) {
171-
return false;
172-
}
173-
}
174-
return true;
175-
}
176-
return false;
177-
}
178-
179158
@Override
180159
public InternalRow copy() {
181160
final int n = size();

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.sql
1919

20-
import scala.util.hashing.MurmurHash3
21-
2220
import org.apache.spark.sql.catalyst.expressions.GenericRow
2321
import org.apache.spark.sql.types.StructType
2422

@@ -365,36 +363,6 @@ trait Row extends Serializable {
365363
false
366364
}
367365

368-
override def equals(that: Any): Boolean = that match {
369-
case null => false
370-
case that: Row =>
371-
if (this.length != that.length) {
372-
return false
373-
}
374-
var i = 0
375-
val len = this.length
376-
while (i < len) {
377-
if (apply(i) != that.apply(i)) {
378-
return false
379-
}
380-
i += 1
381-
}
382-
true
383-
case _ => false
384-
}
385-
386-
override def hashCode: Int = {
387-
// Using Scala's Seq hash code implementation.
388-
var n = 0
389-
var h = MurmurHash3.seqSeed
390-
val len = length
391-
while (n < len) {
392-
h = MurmurHash3.mix(h, apply(n).##)
393-
n += 1
394-
}
395-
MurmurHash3.finalizeHash(h, n)
396-
}
397-
398366
/* ---------------------- utility methods for Scala ---------------------- */
399367

400368
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,78 @@
1818
package org.apache.spark.sql.catalyst
1919

2020
import org.apache.spark.sql.Row
21-
import org.apache.spark.sql.catalyst.expressions.GenericRow
21+
import org.apache.spark.sql.catalyst.expressions._
2222

2323
/**
2424
* An abstract class for row used internal in Spark SQL, which only contain the columns as
2525
* internal types.
2626
*/
2727
abstract class InternalRow extends Row {
2828
// A default implementation to change the return type
29-
override def copy(): InternalRow = {this}
29+
override def copy(): InternalRow = this
30+
31+
override def equals(o: Any): Boolean = {
32+
if (!o.isInstanceOf[Row]) {
33+
return false
34+
}
35+
36+
val other = o.asInstanceOf[Row]
37+
if (length != other.length) {
38+
return false
39+
}
40+
41+
var i = 0
42+
while (i < length) {
43+
if (isNullAt(i) != other.isNullAt(i)) {
44+
return false
45+
}
46+
if (!isNullAt(i)) {
47+
val o1 = apply(i)
48+
val o2 = other.apply(i)
49+
if (o1.isInstanceOf[Array[Byte]]) {
50+
// handle equality of Array[Byte]
51+
val b1 = o1.asInstanceOf[Array[Byte]]
52+
if (!o2.isInstanceOf[Array[Byte]] ||
53+
!java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
54+
return false
55+
}
56+
} else if (o1 != o2) {
57+
return false
58+
}
59+
}
60+
i += 1
61+
}
62+
true
63+
}
64+
65+
// Custom hashCode function that matches the efficient code generated version.
66+
override def hashCode: Int = {
67+
var result: Int = 37
68+
var i = 0
69+
while (i < length) {
70+
val update: Int =
71+
if (isNullAt(i)) {
72+
0
73+
} else {
74+
apply(i) match {
75+
case b: Boolean => if (b) 0 else 1
76+
case b: Byte => b.toInt
77+
case s: Short => s.toInt
78+
case i: Int => i
79+
case l: Long => (l ^ (l >>> 32)).toInt
80+
case f: Float => java.lang.Float.floatToIntBits(f)
81+
case d: Double =>
82+
val b = java.lang.Double.doubleToLongBits(d)
83+
(b ^ (b >>> 32)).toInt
84+
case a: Array[Byte] => java.util.Arrays.hashCode(a)
85+
case other => other.hashCode()
86+
}
87+
}
88+
result = 37 * result + update
89+
i += 1
90+
}
91+
result
92+
}
3093
}
3194

3295
object InternalRow {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
127127
case FloatType => s"Float.floatToIntBits($col)"
128128
case DoubleType =>
129129
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
130+
case BinaryType => s"java.util.Arrays.hashCode($col)"
130131
case _ => s"$col.hashCode()"
131132
}
132133
s"isNullAt($i) ? 0 : ($nonNull)"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -121,58 +121,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
121121
}
122122
}
123123

124-
// TODO(davies): add getDate and getDecimal
125-
126-
// Custom hashCode function that matches the efficient code generated version.
127-
override def hashCode: Int = {
128-
var result: Int = 37
129-
130-
var i = 0
131-
while (i < values.length) {
132-
val update: Int =
133-
if (isNullAt(i)) {
134-
0
135-
} else {
136-
apply(i) match {
137-
case b: Boolean => if (b) 0 else 1
138-
case b: Byte => b.toInt
139-
case s: Short => s.toInt
140-
case i: Int => i
141-
case l: Long => (l ^ (l >>> 32)).toInt
142-
case f: Float => java.lang.Float.floatToIntBits(f)
143-
case d: Double =>
144-
val b = java.lang.Double.doubleToLongBits(d)
145-
(b ^ (b >>> 32)).toInt
146-
case other => other.hashCode()
147-
}
148-
}
149-
result = 37 * result + update
150-
i += 1
151-
}
152-
result
153-
}
154-
155-
override def equals(o: Any): Boolean = o match {
156-
case other: InternalRow =>
157-
if (values.length != other.length) {
158-
return false
159-
}
160-
161-
var i = 0
162-
while (i < values.length) {
163-
if (isNullAt(i) != other.isNullAt(i)) {
164-
return false
165-
}
166-
if (apply(i) != other.apply(i)) {
167-
return false
168-
}
169-
i += 1
170-
}
171-
true
172-
173-
case _ => false
174-
}
175-
176124
override def copy(): InternalRow = this
177125
}
178126

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,23 @@ trait ExpressionEvalHelper {
3838

3939
protected def checkEvaluation(
4040
expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
41-
checkEvaluationWithoutCodegen(expression, expected, inputRow)
42-
checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
43-
checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
44-
checkEvaluationWithOptimization(expression, expected, inputRow)
41+
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
42+
checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
43+
checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
44+
checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
45+
checkEvaluationWithOptimization(expression, catalystValue, inputRow)
46+
}
47+
48+
/**
49+
* Check the equality between result of expression and expected value, it will handle
50+
* Array[Byte].
51+
*/
52+
protected def checkResult(result: Any, expected: Any): Boolean = {
53+
(result, expected) match {
54+
case (result: Array[Byte], expected: Array[Byte]) =>
55+
java.util.Arrays.equals(result, expected)
56+
case _ => result == expected
57+
}
4558
}
4659

4760
protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
@@ -55,7 +68,7 @@ trait ExpressionEvalHelper {
5568
val actual = try evaluate(expression, inputRow) catch {
5669
case e: Exception => fail(s"Exception evaluating $expression", e)
5770
}
58-
if (actual != expected) {
71+
if (!checkResult(actual, expected)) {
5972
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
6073
fail(s"Incorrect evaluation (codegen off): $expression, " +
6174
s"actual: $actual, " +
@@ -83,7 +96,7 @@ trait ExpressionEvalHelper {
8396
}
8497

8598
val actual = plan(inputRow).apply(0)
86-
if (actual != expected) {
99+
if (!checkResult(actual, expected)) {
87100
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
88101
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
89102
}
@@ -109,7 +122,7 @@ trait ExpressionEvalHelper {
109122
}
110123

111124
val actual = plan(inputRow)
112-
val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
125+
val expectedRow = new GenericRow(Array[Any](expected))
113126
if (actual.hashCode() != expectedRow.hashCode()) {
114127
fail(
115128
s"""

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,38 +18,79 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
21-
import org.apache.spark.sql.types.StringType
21+
import org.apache.spark.sql.types._
2222

2323

2424
class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
2525

26-
// TODO: Add tests for all data types.
26+
test("null") {
27+
checkEvaluation(Literal.create(null, BooleanType), null)
28+
checkEvaluation(Literal.create(null, ByteType), null)
29+
checkEvaluation(Literal.create(null, ShortType), null)
30+
checkEvaluation(Literal.create(null, IntegerType), null)
31+
checkEvaluation(Literal.create(null, LongType), null)
32+
checkEvaluation(Literal.create(null, FloatType), null)
33+
checkEvaluation(Literal.create(null, LongType), null)
34+
checkEvaluation(Literal.create(null, StringType), null)
35+
checkEvaluation(Literal.create(null, BinaryType), null)
36+
checkEvaluation(Literal.create(null, DecimalType()), null)
37+
checkEvaluation(Literal.create(null, ArrayType(ByteType, true)), null)
38+
checkEvaluation(Literal.create(null, MapType(StringType, IntegerType)), null)
39+
checkEvaluation(Literal.create(null, StructType(Seq.empty)), null)
40+
}
2741

2842
test("boolean literals") {
2943
checkEvaluation(Literal(true), true)
3044
checkEvaluation(Literal(false), false)
3145
}
3246

3347
test("int literals") {
34-
checkEvaluation(Literal(1), 1)
35-
checkEvaluation(Literal(0L), 0L)
48+
List(0, 1, Int.MinValue, Int.MaxValue).foreach { d =>
49+
checkEvaluation(Literal(d), d)
50+
checkEvaluation(Literal(d.toLong), d.toLong)
51+
checkEvaluation(Literal(d.toShort), d.toShort)
52+
checkEvaluation(Literal(d.toByte), d.toByte)
53+
}
54+
checkEvaluation(Literal(Long.MinValue), Long.MinValue)
55+
checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
3656
}
3757

3858
test("double literals") {
39-
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach {
40-
d => {
41-
checkEvaluation(Literal(d), d)
42-
checkEvaluation(Literal(d.toFloat), d.toFloat)
43-
}
59+
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
60+
checkEvaluation(Literal(d), d)
61+
checkEvaluation(Literal(d.toFloat), d.toFloat)
4462
}
63+
checkEvaluation(Literal(Double.MinValue), Double.MinValue)
64+
checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
65+
checkEvaluation(Literal(Float.MinValue), Float.MinValue)
66+
checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)
67+
4568
}
4669

4770
test("string literals") {
71+
checkEvaluation(Literal(""), "")
4872
checkEvaluation(Literal("test"), "test")
49-
checkEvaluation(Literal.create(null, StringType), null)
73+
checkEvaluation(Literal("\0"), "\0")
5074
}
5175

5276
test("sum two literals") {
5377
checkEvaluation(Add(Literal(1), Literal(1)), 2)
5478
}
79+
80+
test("binary literals") {
81+
checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
82+
checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
83+
}
84+
85+
test("decimal") {
86+
List(0.0, 1.2, 1.1111, 5).foreach { d =>
87+
checkEvaluation(Literal(Decimal(d)), Decimal(d))
88+
checkEvaluation(Literal(Decimal(d.toInt)), Decimal(d.toInt))
89+
checkEvaluation(Literal(Decimal(d.toLong)), Decimal(d.toLong))
90+
checkEvaluation(Literal(Decimal((d * 1000L).toLong, 10, 1)),
91+
Decimal((d * 1000L).toLong, 10, 1))
92+
}
93+
}
94+
95+
// TODO(davies): add tests for ArrayType, MapType and StructType
5596
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
222222
checkEvaluation(StringLength(regEx), 5, create_row("abdef"))
223223
checkEvaluation(StringLength(regEx), 0, create_row(""))
224224
checkEvaluation(StringLength(regEx), null, create_row(null))
225-
// TODO currently bug in codegen, let's temporally disable this
226-
// checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
225+
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
227226
}
228-
229-
230227
}

0 commit comments

Comments
 (0)