Skip to content

Commit 5819d33

Browse files
author
Davies Liu
committed
unify equals() and hashCode()
1 parent 0fff25d commit 5819d33

File tree

5 files changed

+58
-142
lines changed

5 files changed

+58
-142
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -155,39 +155,6 @@ public int fieldIndex(String name) {
155155
throw new UnsupportedOperationException();
156156
}
157157

158-
/**
159-
* A generic version of Row.equals(Row), which is used for tests.
160-
*/
161-
@Override
162-
public boolean equals(Object other) {
163-
if (other instanceof Row) {
164-
Row row = (Row) other;
165-
int n = size();
166-
if (n != row.size()) {
167-
return false;
168-
}
169-
for (int i = 0; i < n; i ++) {
170-
if (isNullAt(i) != row.isNullAt(i)) {
171-
return false;
172-
}
173-
if (!isNullAt(i)) {
174-
Object o1 = get(i);
175-
Object o2 = row.get(i);
176-
if (o1 instanceof byte[]) {
177-
// handle equals() of byte[]
178-
if (!(o2 instanceof byte[]) || !java.util.Arrays.equals((byte[])o1, (byte[])o2)) {
179-
return false;
180-
}
181-
} else if (!o1.equals(o2)) {
182-
return false;
183-
}
184-
}
185-
}
186-
return true;
187-
}
188-
return false;
189-
}
190-
191158
@Override
192159
public InternalRow copy() {
193160
final int n = size();
@@ -227,15 +194,4 @@ public String mkString(String sep) {
227194
public String mkString(String start, String sep, String end) {
228195
return toSeq().mkString(start, sep, end);
229196
}
230-
231-
/*
232-
* Returns hash code based on bytes in `arr`
233-
* */
234-
protected int bytesHashCode(byte[] arr) {
235-
int hash = 0;
236-
for (int i = 0; i < arr.length; i++) {
237-
hash = hash * 37 + (int)arr[i];
238-
}
239-
return hash;
240-
}
241197
}

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717

1818
package org.apache.spark.sql
1919

20-
import scala.util.hashing.MurmurHash3
21-
2220
import org.apache.spark.sql.catalyst.expressions.GenericRow
2321
import org.apache.spark.sql.types.StructType
2422

@@ -365,36 +363,6 @@ trait Row extends Serializable {
365363
false
366364
}
367365

368-
override def equals(that: Any): Boolean = that match {
369-
case null => false
370-
case that: Row =>
371-
if (this.length != that.length) {
372-
return false
373-
}
374-
var i = 0
375-
val len = this.length
376-
while (i < len) {
377-
if (apply(i) != that.apply(i)) {
378-
return false
379-
}
380-
i += 1
381-
}
382-
true
383-
case _ => false
384-
}
385-
386-
override def hashCode: Int = {
387-
// Using Scala's Seq hash code implementation.
388-
var n = 0
389-
var h = MurmurHash3.seqSeed
390-
val len = length
391-
while (n < len) {
392-
h = MurmurHash3.mix(h, apply(n).##)
393-
n += 1
394-
}
395-
MurmurHash3.finalizeHash(h, n)
396-
}
397-
398366
/* ---------------------- utility methods for Scala ---------------------- */
399367

400368
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.sql.catalyst
1919

2020
import org.apache.spark.sql.Row
21-
import org.apache.spark.sql.catalyst.expressions.GenericRow
21+
import org.apache.spark.sql.catalyst.expressions._
2222

2323
/**
2424
* An abstract class for rows used internally in Spark SQL, which only contains the columns as
@@ -27,6 +27,62 @@ import org.apache.spark.sql.catalyst.expressions.GenericRow
2727
abstract class InternalRow extends Row {
2828
// A default implementation to change the return type
2929
override def copy(): InternalRow = {this}
30+
31+
// A default version (slow), used for tests
32+
override def equals(o: Any): Boolean = o match {
33+
case other: InternalRow =>
34+
if (length != other.length) {
35+
return false
36+
}
37+
38+
for (i <- 0 until length) {
39+
if (isNullAt(i) != other.isNullAt(i)) {
40+
return false
41+
}
42+
if (!isNullAt(i)) {
43+
val o1 = apply(i)
44+
val o2 = other.apply(i)
45+
if (o1.isInstanceOf[Array[Byte]]) {
46+
if (!o2.isInstanceOf[Array[Byte]] ||
47+
!java.util.Arrays.equals(o1.asInstanceOf[Array[Byte]], o2.asInstanceOf[Array[Byte]])) {
48+
return false
49+
}
50+
} else if (o1 != o2) {
51+
return false
52+
}
53+
}
54+
}
55+
true
56+
case _ => false
57+
}
58+
59+
// Custom hashCode function that matches the efficient code generated version.
60+
override def hashCode: Int = {
61+
var result: Int = 37
62+
63+
for (i <- 0 until length) {
64+
val update: Int =
65+
if (isNullAt(i)) {
66+
0
67+
} else {
68+
apply(i) match {
69+
case b: Boolean => if (b) 0 else 1
70+
case b: Byte => b.toInt
71+
case s: Short => s.toInt
72+
case i: Int => i
73+
case l: Long => (l ^ (l >>> 32)).toInt
74+
case f: Float => java.lang.Float.floatToIntBits(f)
75+
case d: Double =>
76+
val b = java.lang.Double.doubleToLongBits(d)
77+
(b ^ (b >>> 32)).toInt
78+
case a: Array[Byte] => java.util.Arrays.hashCode(a)
79+
case other => other.hashCode()
80+
}
81+
}
82+
result = 37 * result + update
83+
}
84+
result
85+
}
3086
}
3187

3288
object InternalRow {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
127127
case FloatType => s"Float.floatToIntBits($col)"
128128
case DoubleType =>
129129
s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))"
130-
case BinaryType => s"bytesHashCode($col)"
130+
case BinaryType => s"java.util.Arrays.hashCode($col)"
131131
case _ => s"$col.hashCode()"
132132
}
133133
s"isNullAt($i) ? 0 : ($nonNull)"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -121,70 +121,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow {
121121
}
122122
}
123123

124-
// TODO(davies): add getDate and getDecimal
125-
126-
// Custom hashCode function that matches the efficient code generated version.
127-
override def hashCode: Int = {
128-
var result: Int = 37
129-
130-
var i = 0
131-
while (i < values.length) {
132-
val update: Int =
133-
if (isNullAt(i)) {
134-
0
135-
} else {
136-
apply(i) match {
137-
case b: Boolean => if (b) 0 else 1
138-
case b: Byte => b.toInt
139-
case s: Short => s.toInt
140-
case i: Int => i
141-
case l: Long => (l ^ (l >>> 32)).toInt
142-
case f: Float => java.lang.Float.floatToIntBits(f)
143-
case d: Double =>
144-
val b = java.lang.Double.doubleToLongBits(d)
145-
(b ^ (b >>> 32)).toInt
146-
case a: Array[Byte] => a.map(_.toInt).fold(0)(_ * 37 + _)
147-
case other => other.hashCode()
148-
}
149-
}
150-
result = 37 * result + update
151-
i += 1
152-
}
153-
result
154-
}
155-
156-
override def equals(o: Any): Boolean = o match {
157-
case other: InternalRow =>
158-
if (values.length != other.length) {
159-
return false
160-
}
161-
162-
var i = 0
163-
while (i < values.length) {
164-
if (isNullAt(i) != other.isNullAt(i)) {
165-
return false
166-
}
167-
if (!isNullAt(i)) {
168-
val o1 = apply(i)
169-
val o2 = other.apply(i)
170-
if (o1.isInstanceOf[Array[Byte]]) {
171-
val b1 = o1.asInstanceOf[Array[Byte]]
172-
if (!o2.isInstanceOf[Array[Byte]] ||
173-
java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
174-
return false
175-
}
176-
177-
} else if (apply(i) != other.apply(i)) {
178-
return false
179-
}
180-
}
181-
i += 1
182-
}
183-
true
184-
185-
case _ => false
186-
}
187-
188124
override def copy(): InternalRow = this
189125
}
190126

0 commit comments

Comments
 (0)