Skip to content

Commit 07fa191

Browse files
wangxiaojingmarmbrus
authored andcommitted
[SPARK-4570][SQL]add BroadcastLeftSemiJoinHash
JIRA issue: [SPARK-4570](https://issues.apache.org/jira/browse/SPARK-4570) We are planning to create a `BroadcastLeftSemiJoinHash` to implement the broadcast join for `left semijoin` In left semijoin : If the size of data from right side is smaller than the user-settable threshold `AUTO_BROADCASTJOIN_THRESHOLD`, the planner would mark it as the `broadcast` relation and mark the other relation as the stream side. The broadcast table will be broadcasted to all of the executors involved in the join, as a `org.apache.spark.broadcast.Broadcast` object. It will use `joins.BroadcastLeftSemiJoinHash`.,else it will use `joins.LeftSemiJoinHash`. The benchmark suggests these made the optimized version 4x faster when `left semijoin` <pre><code> Original: left semi join : 9288 ms Optimized: left semi join : 1963 ms </code></pre> The micro benchmark load `data1/kv3.txt` into a normal Hive table. Benchmark code: <pre><code> def benchmark(f: => Unit) = { val begin = System.currentTimeMillis() f val end = System.currentTimeMillis() end - begin } val sc = new SparkContext( new SparkConf() .setMaster("local") .setAppName(getClass.getSimpleName.stripSuffix("$"))) val hiveContext = new HiveContext(sc) import hiveContext._ sql("drop table if exists left_table") sql("drop table if exists right_table") sql( """create table left_table (key int, value string) """.stripMargin) sql( s"""load data local inpath "/data1/kv3.txt" into table left_table""") sql( """create table right_table (key int, value string) """.stripMargin) sql( """ |from left_table |insert overwrite table right_table |select left_table.key, left_table.value """.stripMargin) val leftSimeJoin = sql( """select a.key from left_table a |left semi join right_table b on a.key = b.key""".stripMargin) val leftSemiJoinDuration = benchmark(leftSimeJoin.count()) println(s"left semi join : $leftSemiJoinDuration ms ") </code></pre> Author: wangxiaojing <[email protected]> Closes #3442 from wangxiaojing/SPARK-4570 and squashes the following commits: a4a43c9 [wangxiaojing] rebase f103983 [wangxiaojing] change style fbe4887 [wangxiaojing] change style ff2e618 [wangxiaojing] add testsuite 1a8da2a [wangxiaojing] add BroadcastLeftSemiJoinHash
1 parent 8f29b7c commit 07fa191

File tree

4 files changed

+160
-1
lines changed

4 files changed

+160
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
3333

3434
object LeftSemiJoin extends Strategy with PredicateHelper {
3535
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
36+
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right)
37+
if sqlContext.autoBroadcastJoinThreshold > 0 &&
38+
right.statistics.sizeInBytes <= sqlContext.autoBroadcastJoinThreshold =>
39+
val semiJoin = joins.BroadcastLeftSemiJoinHash(
40+
leftKeys, rightKeys, planLater(left), planLater(right))
41+
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
3642
// Find left semi joins where at least some predicates can be evaluated by matching join keys
3743
case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
3844
val semiJoin = joins.LeftSemiJoinHash(
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.joins
19+
20+
import org.apache.spark.annotation.DeveloperApi
21+
import org.apache.spark.sql.catalyst.expressions.{Expression, Row}
22+
import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution
23+
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
24+
25+
/**
26+
* :: DeveloperApi ::
27+
* Build the right table's join keys into a HashSet, and iteratively go through the left
28+
* table, to find the if join keys are in the Hash set.
29+
*/
30+
@DeveloperApi
31+
case class BroadcastLeftSemiJoinHash(
32+
leftKeys: Seq[Expression],
33+
rightKeys: Seq[Expression],
34+
left: SparkPlan,
35+
right: SparkPlan) extends BinaryNode with HashJoin {
36+
37+
override val buildSide = BuildRight
38+
39+
override def output = left.output
40+
41+
override def execute() = {
42+
val buildIter= buildPlan.execute().map(_.copy()).collect().toIterator
43+
val hashSet = new java.util.HashSet[Row]()
44+
var currentRow: Row = null
45+
46+
// Create a Hash set of buildKeys
47+
while (buildIter.hasNext) {
48+
currentRow = buildIter.next()
49+
val rowKey = buildSideKeyGenerator(currentRow)
50+
if (!rowKey.anyNull) {
51+
val keyExists = hashSet.contains(rowKey)
52+
if (!keyExists) {
53+
hashSet.add(rowKey)
54+
}
55+
}
56+
}
57+
58+
val broadcastedRelation = sparkContext.broadcast(hashSet)
59+
60+
streamedPlan.execute().mapPartitions { streamIter =>
61+
val joinKeys = streamSideKeyGenerator()
62+
streamIter.filter(current => {
63+
!joinKeys(current).anyNull && broadcastedRelation.value.contains(joinKeys.currentValue)
64+
})
65+
}
66+
}
67+
}

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
4848
case j: LeftSemiJoinBNL => j
4949
case j: CartesianProduct => j
5050
case j: BroadcastNestedLoopJoin => j
51+
case j: BroadcastLeftSemiJoinHash => j
5152
}
5253

5354
assert(operators.size === 1)
@@ -382,4 +383,41 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
382383
""".stripMargin),
383384
(null, 10) :: Nil)
384385
}
386+
387+
test("broadcasted left semi join operator selection") {
388+
clearCache()
389+
sql("CACHE TABLE testData")
390+
val tmp = autoBroadcastJoinThreshold
391+
392+
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000")
393+
Seq(
394+
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
395+
classOf[BroadcastLeftSemiJoinHash])
396+
).foreach {
397+
case (query, joinClass) => assertJoin(query, joinClass)
398+
}
399+
400+
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
401+
402+
Seq(
403+
("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash])
404+
).foreach {
405+
case (query, joinClass) => assertJoin(query, joinClass)
406+
}
407+
408+
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString)
409+
sql("UNCACHE TABLE testData")
410+
}
411+
412+
test("left semi join") {
413+
val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
414+
checkAnswer(rdd,
415+
(1, 1) ::
416+
(1, 2) ::
417+
(2, 1) ::
418+
(2, 2) ::
419+
(3, 1) ::
420+
(3, 2) :: Nil)
421+
422+
}
385423
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterAll
2222
import scala.reflect.ClassTag
2323

2424
import org.apache.spark.sql.{SQLConf, QueryTest}
25-
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
25+
import org.apache.spark.sql.execution.joins._
2626
import org.apache.spark.sql.hive.test.TestHive
2727
import org.apache.spark.sql.hive.test.TestHive._
2828
import org.apache.spark.sql.hive.execution._
@@ -193,4 +193,52 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
193193
)
194194
}
195195

196+
test("auto converts to broadcast left semi join, by size estimate of a relation") {
197+
val leftSemiJoinQuery =
198+
"""SELECT * FROM src a
199+
|left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
200+
val answer = (86, "val_86") :: Nil
201+
202+
var rdd = sql(leftSemiJoinQuery)
203+
204+
// Assert src has a size smaller than the threshold.
205+
val sizes = rdd.queryExecution.analyzed.collect {
206+
case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass
207+
.isAssignableFrom(r.getClass) =>
208+
r.statistics.sizeInBytes
209+
}
210+
assert(sizes.size === 2 && sizes(1) <= autoBroadcastJoinThreshold
211+
&& sizes(0) <= autoBroadcastJoinThreshold,
212+
s"query should contain two relations, each of which has size smaller than autoConvertSize")
213+
214+
// Using `sparkPlan` because for relevant patterns in HashJoin to be
215+
// matched, other strategies need to be applied.
216+
var bhj = rdd.queryExecution.sparkPlan.collect {
217+
case j: BroadcastLeftSemiJoinHash => j
218+
}
219+
assert(bhj.size === 1,
220+
s"actual query plans do not contain broadcast join: ${rdd.queryExecution}")
221+
222+
checkAnswer(rdd, answer) // check correctness of output
223+
224+
TestHive.settings.synchronized {
225+
val tmp = autoBroadcastJoinThreshold
226+
227+
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
228+
rdd = sql(leftSemiJoinQuery)
229+
bhj = rdd.queryExecution.sparkPlan.collect {
230+
case j: BroadcastLeftSemiJoinHash => j
231+
}
232+
assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
233+
234+
val shj = rdd.queryExecution.sparkPlan.collect {
235+
case j: LeftSemiJoinHash => j
236+
}
237+
assert(shj.size === 1,
238+
"LeftSemiJoinHash should be planned when BroadcastHashJoin is turned off")
239+
240+
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp")
241+
}
242+
243+
}
196244
}

0 commit comments

Comments
 (0)