Skip to content

Commit a46144a

Browse files
committed
WIP
1 parent af31335 commit a46144a

File tree

2 files changed

+185
-0
lines changed

2 files changed

+185
-0
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.sql.catalyst.dsl.expressions._
21+
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
22+
23+
class SortSuite extends SparkPlanTest {

  // Verifies that ExternalSort orders rows by the first and then the second
  // tuple field ascending, matching Scala's default tuple ordering (`.sorted`).
  test("basic sorting using ExternalSort") {
    val data = Seq(("Hello", 4), ("Hello", 1), ("World", 8))

    // Sort ascending on both columns of the input tuples.
    val ordering = Seq(SortOrder('_1, Ascending), SortOrder('_2, Ascending))

    checkAnswer(
      data,
      plan => new ExternalSort(ordering, global = false, plan),
      data.sorted)
  }
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import scala.util.control.NonFatal
21+
import scala.reflect.runtime.universe.TypeTag
22+
23+
import org.apache.spark.SparkFunSuite
24+
import org.apache.spark.sql.test.TestSQLContext
25+
import org.apache.spark.sql.{Row, DataFrame}
26+
import org.apache.spark.sql.catalyst.util._
27+
28+
/**
29+
* Base class for writing tests for individual physical operators. For an example of how this class
30+
* can be used, see [[SortSuite]].
31+
*/
32+
class SparkPlanTest extends SparkFunSuite {

  /**
   * Runs the plan and makes sure the answer matches the expected result.
   * @param input the input data to be used.
   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
   *                     the physical operator that's being tested.
   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
   */
  protected def checkAnswer(
      input: DataFrame,
      planFunction: SparkPlan => SparkPlan,
      expectedAnswer: Seq[Row]): Unit = {
    SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer) match {
      case Some(errorMessage) => fail(errorMessage)
      case None =>
    }
  }

  /**
   * Runs the plan and makes sure the answer matches the expected result.
   * @param input the input data to be used.
   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
   *                     the physical operator that's being tested.
   * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s (e.g. tuples), one
   *                       per expected output row.
   */
  protected def checkAnswer[A <: Product : TypeTag](
      input: Seq[A],
      planFunction: SparkPlan => SparkPlan,
      expectedAnswer: Seq[A]): Unit = {
    val inputDf = TestSQLContext.createDataFrame(input)
    // BUG FIX: each expected tuple must become a Row whose fields are the
    // tuple's elements. `Row.apply(t)` would create a single-column Row
    // wrapping the entire tuple, so the comparison against the executed
    // plan's multi-column rows would always fail.
    val expectedRows = expectedAnswer.map(t => Row.fromSeq(t.productIterator.toSeq))
    SparkPlanTest.checkAnswer(inputDf, planFunction, expectedRows) match {
      case Some(errorMessage) => fail(errorMessage)
      case None =>
    }
  }
}
70+
71+
/**
72+
* Helper methods for writing tests of individual physical operators.
73+
*/
74+
/**
 * Helper methods for writing tests of individual physical operators.
 */
object SparkPlanTest {

  /**
   * Runs the plan and makes sure the answer matches the expected result.
   * @param input the input data to be used.
   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
   *                     the physical operator that's being tested.
   * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
   * @return None on success, or Some(message) describing the exception or mismatch.
   */
  def checkAnswer(
      input: DataFrame,
      planFunction: SparkPlan => SparkPlan,
      expectedAnswer: Seq[Row]): Option[String] = {

    val outputPlan = planFunction(input.queryExecution.sparkPlan)

    // Converts data to types that we can do equality comparison using Scala collections.
    // For BigDecimal type, the Scala type has a better definition of equality test (similar to
    // Java's java.math.BigDecimal.compareTo).
    // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
    // equality test.
    // Rows are then sorted by their string form so that row order is ignored.
    // This function is copied from Catalyst's QueryTest.
    def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
      val converted: Seq[Row] = answer.map { s =>
        Row.fromSeq(s.toSeq.map {
          case d: java.math.BigDecimal => BigDecimal(d)
          case b: Array[Byte] => b.toSeq
          case o => o
        })
      }
      converted.sortBy(_.toString())
    }

    // Execute the plan, capturing non-fatal failures as Left so the whole
    // method stays expression-oriented (no early `return`). Fatal errors
    // (OOM, interruption, ...) still propagate.
    val executed: Either[String, Seq[Row]] =
      try {
        Right(outputPlan.executeCollect().toSeq)
      } catch {
        case NonFatal(e) =>
          Left(
            s"""
              | Exception thrown while executing Spark plan:
              | $outputPlan
              | == Exception ==
              | $e
              | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
            """.stripMargin)
      }

    executed match {
      case Left(errorMessage) => Some(errorMessage)
      case Right(sparkAnswer) =>
        // Normalize once and reuse; the original recomputed prepareAnswer(expectedAnswer).
        val prepExpected = prepareAnswer(expectedAnswer)
        val prepActual = prepareAnswer(sparkAnswer)
        if (prepExpected != prepActual) {
          Some(
            s"""
              | Results do not match for Spark plan:
              | $outputPlan
              | == Results ==
              | ${sideBySide(
                s"== Correct Answer - ${expectedAnswer.size} ==" +:
                  prepExpected.map(_.toString()),
                s"== Spark Answer - ${sparkAnswer.size} ==" +:
                  prepActual.map(_.toString())).mkString("\n")}
            """.stripMargin)
        } else {
          None
        }
    }
  }
}
140+

0 commit comments

Comments
 (0)