
Commit 87b6ed9

Fix critical issues in test which led to false negatives.
1 parent 8d7fbe7 commit 87b6ed9

File tree

2 files changed: +37 -29 lines changed

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala

Lines changed: 15 additions & 7 deletions
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.BoundReference
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row}
@@ -145,12 +144,15 @@ class SparkPlanTest extends SparkFunSuite {
    * instantiate a reference implementation of the physical operator
    * that's being tested. The result of executing this plan will be
    * treated as the source-of-truth for the test.
+   * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+   *                    to being compared.
    */
   protected def checkAnswer(
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
-      expectedPlanFunction: SparkPlan => SparkPlan): Unit = {
-    SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction) match {
+      expectedPlanFunction: SparkPlan => SparkPlan,
+      sortAnswers: Boolean): Unit = {
+    SparkPlanTest.checkAnswer(input, planFunction, expectedPlanFunction, sortAnswers) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
     }
@@ -175,7 +177,8 @@ object SparkPlanTest {
   def checkAnswer(
       input: DataFrame,
       planFunction: SparkPlan => SparkPlan,
-      expectedPlanFunction: SparkPlan => SparkPlan): Option[String] = {
+      expectedPlanFunction: SparkPlan => SparkPlan,
+      sortAnswers: Boolean): Option[String] = {
 
     val outputPlan = planFunction(input.queryExecution.sparkPlan)
     val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
@@ -210,7 +213,7 @@
       return Some(errorMessage)
     }
 
-    compareAnswers(actualAnswer, expectedAnswer).map { errorMessage =>
+    compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
       s"""
          | Results do not match.
          | Actual result Spark plan:
@@ -262,7 +265,8 @@
 
   private def compareAnswers(
       sparkAnswer: Seq[Row],
-      expectedAnswer: Seq[Row]): Option[String] = {
+      expectedAnswer: Seq[Row],
+      sort: Boolean = true): Option[String] = {
     def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
       // Converts data to types that we can do equality comparison using Scala collections.
       // For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -277,7 +281,11 @@
           case o => o
         })
       }
-      converted.sortBy(_.toString())
+      if (sort) {
+        converted.sortBy(_.toString())
+      } else {
+        converted
+      }
     }
     if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
       val errorMessage =
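
The effect of the new sortAnswers flag is easiest to see in isolation. The sketch below (plain Scala with made-up names CompareSketch and compareRows, not the Spark code itself) mirrors the conditional sort in compareAnswers: with sort = true both result sets are canonicalized by sorting on their toString form, so only the row contents are compared, while with sort = false the row order must also match. Unconditionally sorting both sides is exactly what let a wrongly ordered result from a sort operator pass before this change.

// Minimal sketch of the compareAnswers behavior, assuming rows are plain Seq[Any].
object CompareSketch {
  // When sort = true, both answers are canonicalized by sorting on toString,
  // so only the multiset of rows is compared; when sort = false, order matters too.
  def compareRows(actual: Seq[Seq[Any]], expected: Seq[Seq[Any]], sort: Boolean): Boolean = {
    def prepare(rows: Seq[Seq[Any]]): Seq[Seq[Any]] =
      if (sort) rows.sortBy(_.toString()) else rows
    prepare(actual) == prepare(expected)
  }

  def main(args: Array[String]): Unit = {
    val expected   = Seq(Seq(1), Seq(2), Seq(3))
    val wrongOrder = Seq(Seq(3), Seq(1), Seq(2))
    // With sorting enabled, a wrongly ordered result still "passes" (the false negative):
    assert(compareRows(wrongOrder, expected, sort = true))
    // With sorting disabled, the ordering bug is caught:
    assert(!compareRows(wrongOrder, expected, sort = false))
  }
}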

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala

Lines changed: 22 additions & 22 deletions
@@ -38,29 +38,29 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
 
   // Test sorting on different data types
   // TODO: randomized spilling to ensure that merging is tested at least once for every data type.
-  (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
-    for (
-      nullable <- Seq(true, false);
-      sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
-      randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
-    ) {
-      test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
-        val inputData = Seq.fill(1024)(randomDataGenerator()).filter {
-          case d: Double => !d.isNaN
-          case f: Float => !java.lang.Float.isNaN(f)
-          case x => true
-        }
-        val inputDf = TestSQLContext.createDataFrame(
-          TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
-          StructType(StructField("a", dataType, nullable = true) :: Nil)
-        )
-        assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
-        checkAnswer(
-          inputDf,
-          UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 100),
-          Sort(sortOrder, global = false, _: SparkPlan)
-        )
+  for (
+    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
+    nullable <- Seq(true, false);
+    sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
+    randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
+  ) {
+    test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
+      val inputData = Seq.fill(3)(randomDataGenerator()).filter {
+        case d: Double => !d.isNaN
+        case f: Float => !java.lang.Float.isNaN(f)
+        case x => true
       }
+      val inputDf = TestSQLContext.createDataFrame(
+        TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+        StructType(StructField("a", dataType, nullable = true) :: Nil)
+      )
+      assert(UnsafeExternalSort.supportsSchema(inputDf.schema))
+      checkAnswer(
+        inputDf,
+        UnsafeExternalSort(sortOrder, global = false, _: SparkPlan, testSpillFrequency = 2),
+        Sort(sortOrder, global = false, _: SparkPlan),
+        sortAnswers = false
+      )
     }
   }
 }
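
For reference, the rewrite above folds the outer foreach into a single for-comprehension over all of the generators, so every (dataType, nullable, sortOrder) combination still gets its own test case; the much smaller input (Seq.fill(3)) combined with testSpillFrequency = 2 presumably keeps each test fast while still exercising a spill, and sortAnswers = false makes the suite actually verify output order. The sketch below (made-up names, no ScalaTest or Spark dependency) shows the same cross-product enumeration pattern on its own.

// Sketch of the test-matrix pattern used above: one for-comprehension over several
// generators enumerates the full cross product, one case per combination.
object TestMatrixSketch {
  def main(args: Array[String]): Unit = {
    for (
      dataType <- Seq("IntegerType", "DoubleType", "StringType", "NullType");
      nullable <- Seq(true, false);
      sortOrder <- Seq("ascending", "descending")
    ) {
      // In the real suite this line would be a ScalaTest test(...) { ... } registration.
      println(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder")
    }
  }
}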
