Skip to content

Commit bd20780

Browse files
author
Davies Liu
committed
check with catalyst types
1 parent 41caec6 commit bd20780

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,23 @@ trait ExpressionEvalHelper {
3838

3939
protected def checkEvaluation(
4040
expression: Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = {
41-
checkEvaluationWithoutCodegen(expression, expected, inputRow)
42-
checkEvaluationWithGeneratedMutableProjection(expression, expected, inputRow)
43-
checkEvaluationWithGeneratedProjection(expression, expected, inputRow)
44-
checkEvaluationWithOptimization(expression, expected, inputRow)
41+
val catalystValue = CatalystTypeConverters.convertToCatalyst(expected)
42+
checkEvaluationWithoutCodegen(expression, catalystValue, inputRow)
43+
checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow)
44+
checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow)
45+
checkEvaluationWithOptimization(expression, catalystValue, inputRow)
46+
}
47+
48+
/**
49+
* Check the equality between result of expression and expected value, it will handle
50+
* Array[Byte].
51+
*/
52+
protected def checkResult(result: Any, expected: Any): Boolean = {
53+
(result, expected) match {
54+
case (result: Array[Byte], expected: Array[Byte]) =>
55+
java.util.Arrays.equals(result, expected)
56+
case _ => result == expected
57+
}
4558
}
4659

4760
protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
@@ -55,7 +68,7 @@ trait ExpressionEvalHelper {
5568
val actual = try evaluate(expression, inputRow) catch {
5669
case e: Exception => fail(s"Exception evaluating $expression", e)
5770
}
58-
if (actual != expected) {
71+
if (!checkResult(actual, expected)) {
5972
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
6073
fail(s"Incorrect evaluation (codegen off): $expression, " +
6174
s"actual: $actual, " +
@@ -83,7 +96,7 @@ trait ExpressionEvalHelper {
8396
}
8497

8598
val actual = plan(inputRow).apply(0)
86-
if (actual != expected) {
99+
if (!checkResult(actual, expected)) {
87100
val input = if (inputRow == EmptyRow) "" else s", input: $inputRow"
88101
fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input")
89102
}
@@ -109,7 +122,7 @@ trait ExpressionEvalHelper {
109122
}
110123

111124
val actual = plan(inputRow)
112-
val expectedRow = new GenericRow(Array[Any](CatalystTypeConverters.convertToCatalyst(expected)))
125+
val expectedRow = new GenericRow(Array[Any](expected))
113126
if (actual.hashCode() != expectedRow.hashCode()) {
114127
fail(
115128
s"""

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,6 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
222222
checkEvaluation(StringLength(regEx), 5, create_row("abdef"))
223223
checkEvaluation(StringLength(regEx), 0, create_row(""))
224224
checkEvaluation(StringLength(regEx), null, create_row(null))
225-
// TODO currently bug in codegen, let's temporally disable this
226-
// checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
225+
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
227226
}
228-
229-
230227
}

unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
package org.apache.spark.unsafe.types;
1919

20+
import javax.annotation.Nullable;
2021
import java.io.Serializable;
2122
import java.io.UnsupportedEncodingException;
2223
import java.util.Arrays;
23-
import javax.annotation.Nullable;
2424

2525
import org.apache.spark.unsafe.PlatformDependent;
2626

@@ -196,10 +196,6 @@ public int compare(final UTF8String other) {
196196
public boolean equals(final Object other) {
197197
if (other instanceof UTF8String) {
198198
return Arrays.equals(bytes, ((UTF8String) other).getBytes());
199-
} else if (other instanceof String) {
200-
// Used only in unit tests.
201-
String s = (String) other;
202-
return bytes.length >= s.length() && length() == s.length() && toString().equals(s);
203199
} else {
204200
return false;
205201
}

0 commit comments

Comments
 (0)