Skip to content

Commit ff2e618

Browse files
committed
add testsuite
1 parent 1a8da2a commit ff2e618

File tree

3 files changed

+110
-1
lines changed

3 files changed

+110
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
3333

3434
object LeftSemiJoin extends Strategy with PredicateHelper {
3535
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
36+
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
37+
if sqlContext.autoBroadcastJoinThreshold > 0 &&
38+
right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
39+
val semiJoin = joins.BroadcastLeftSemiJoinHash(
40+
leftKeys, rightKeys, planLater(left), planLater(right))
41+
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
3642
// Find left semi joins where at least some predicates can be evaluated by matching join keys
3743
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
3844
val semiJoin = joins.LeftSemiJoinHash(

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
4848
case j: LeftSemiJoinBNL => j
4949
case j: CartesianProduct => j
5050
case j: BroadcastNestedLoopJoin => j
51+
case j: BroadcastLeftSemiJoinHash => j
5152
}
5253

5354
assert(operators.size === 1)
@@ -382,4 +383,39 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
382383
""".stripMargin),
383384
(null, 10) :: Nil)
384385
}
386+
// Checks which physical operator is selected for a LEFT SEMI JOIN under two settings of the
// auto-broadcast threshold: a very large threshold (broadcast variant expected) and -1
// (non-broadcast hash variant expected).
test("broadcasted left semi join operator selection") {
  clearCache()
  sql("CACHE TABLE testData")
  // Remember the threshold in effect before the test so it can be put back afterwards.
  val tmp = autoBroadcastJoinThreshold

  try {
    sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000""")
    Seq(
      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
        classOf[BroadcastLeftSemiJoinHash])
    ).foreach {
      case (query, joinClass) => assertJoin(query, joinClass)
    }

    sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""")

    Seq(
      ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
        classOf[LeftSemiJoinHash])
    ).foreach {
      case (query, joinClass) => assertJoin(query, joinClass)
    }
  } finally {
    // Restore the saved value as-is. The previous code wrote "-$tmp", which negated the
    // threshold instead of restoring it. Running this in `finally` also keeps a failed
    // assertion from leaking the modified setting or the cached table.
    sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""")
    sql("UNCACHE TABLE testData")
  }
}
409+
410+
// Result check for a LEFT SEMI JOIN of testData2 against testData on `key = a`.
test("left semi join") {
  val semiJoined = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")

  // Rows the query is expected to return.
  val expectedRows = Seq(
    (1, 1),
    (1, 2),
    (2, 1),
    (2, 2),
    (3, 1),
    (3, 2))

  checkAnswer(semiJoined, expectedRows)
}
385421
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ import org.scalatest.BeforeAndAfterAll
2222
import scala.reflect.ClassTag
2323

2424
import org.apache.spark.sql.{SQLConf, QueryTest}
25-
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
25+
import org.apache.spark.sql.catalyst.plans.logical.NativeCommand
26+
import org.apache.spark.sql.execution.joins._
2627
import org.apache.spark.sql.hive.test.TestHive
2728
import org.apache.spark.sql.hive.test.TestHive._
2829
import org.apache.spark.sql.hive.execution._
@@ -193,4 +194,70 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
193194
)
194195
}
195196

197+
// Checks that a LEFT SEMI JOIN over relations whose estimated size is within the
// auto-broadcast threshold is planned as BroadcastLeftSemiJoinHash, and that the same query
// is planned as LeftSemiJoinHash once the threshold is set to -1.
test("auto converts to broadcast left semi join, by size estimate of a relation") {
  /**
   * Runs one scenario.
   *
   * @param before         hook invoked before the query is planned
   * @param after          hook invoked after all checks have passed
   * @param query          SQL text of the left semi join under test
   * @param expectedAnswer rows the query must return
   * @param ct             class tag of the logical relation node whose size estimate is checked
   */
  def mkTest(
      before: () => Unit,
      after: () => Unit,
      query: String,
      expectedAnswer: Seq[Any],
      ct: ClassTag[_]): Unit = {
    before()

    val broadcastRdd = sql(query)

    // Assert that both relations of the given class are no larger than the threshold.
    val sizes = broadcastRdd.queryExecution.analyzed.collect {
      case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes
    }
    assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold
      && sizes(0) <= autoBroadcastJoinThreshold,
      "query should contain two relations, each of which has size smaller than autoConvertSize")

    // Using `sparkPlan` because for relevant patterns in HashJoin to be
    // matched, other strategies need to be applied.
    val broadcastJoins = broadcastRdd.queryExecution.sparkPlan.collect {
      case j: BroadcastLeftSemiJoinHash => j
    }
    assert(broadcastJoins.size === 1,
      s"actual query plans do not contain broadcast join: ${broadcastRdd.queryExecution}")

    checkAnswer(broadcastRdd, expectedAnswer) // check correctness of output

    TestHive.settings.synchronized {
      val tmp = autoBroadcastJoinThreshold

      try {
        sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""")
        val nonBroadcastRdd = sql(query)

        val stillBroadcast = nonBroadcastRdd.queryExecution.sparkPlan.collect {
          case j: BroadcastLeftSemiJoinHash => j
        }
        // The messages name the operator that is actually checked here
        // (BroadcastLeftSemiJoinHash); they previously said BroadcastHashJoin.
        assert(stillBroadcast.isEmpty,
          "BroadcastLeftSemiJoinHash still planned even though it is switched off")

        val shuffledJoins = nonBroadcastRdd.queryExecution.sparkPlan.collect {
          case j: LeftSemiJoinHash => j
        }
        assert(shuffledJoins.size === 1,
          "LeftSemiJoinHash should be planned when BroadcastLeftSemiJoinHash is turned off")
      } finally {
        // Always restore the threshold, even when an assertion above fails, so that the
        // modified setting does not leak into later tests.
        sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""")
      }
    }

    after()
  }

  /** Tests for MetastoreRelation */
  val leftSemiJoinQuery =
    """SELECT * FROM src a
      |left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
  val expectedAnswer = (86, "val_86") :: Nil

  mkTest(
    () => (),
    () => (),
    leftSemiJoinQuery,
    expectedAnswer,
    implicitly[ClassTag[MetastoreRelation]]
  )
}
196263
}

0 commit comments

Comments
 (0)