
Commit 0cf6002

adrian-wang authored and marmbrus committed
[SPARK-1495][SQL]add support for left semi join
Just submit another solution for apache#395

Author: Daoyuan <[email protected]>
Author: Michael Armbrust <[email protected]>
Author: Daoyuan Wang <[email protected]>

Closes apache#837 from adrian-wang/left-semi-join-support and squashes the following commits:

d39cd12 [Daoyuan Wang] Merge pull request #1 from marmbrus/pr/837
6713c09 [Michael Armbrust] Better debugging for failed query tests.
035b73e [Michael Armbrust] Add test for left semi that can't be done with a hash join.
5ec6fa4 [Michael Armbrust] Add left semi to SQL Parser.
4c726e5 [Daoyuan] improvement according to Michael
8d4a121 [Daoyuan] add golden files for leftsemijoin
83a3c8a [Daoyuan] scala style fix
14cff80 [Daoyuan] add support for left semi join
1 parent 35630c8 commit 0cf6002

37 files changed (+216, -3 lines)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 0 deletions

@@ -131,6 +131,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
   protected val OUTER = Keyword("OUTER")
   protected val RIGHT = Keyword("RIGHT")
   protected val SELECT = Keyword("SELECT")
+  protected val SEMI = Keyword("SEMI")
   protected val STRING = Keyword("STRING")
   protected val SUM = Keyword("SUM")
   protected val TRUE = Keyword("TRUE")
@@ -241,6 +242,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {

   protected lazy val joinType: Parser[JoinType] =
     INNER ^^^ Inner |
+    LEFT ~ SEMI ^^^ LeftSemi |
     LEFT ~ opt(OUTER) ^^^ LeftOuter |
     RIGHT ~ opt(OUTER) ^^^ RightOuter |
     FULL ~ opt(OUTER) ^^^ FullOuter
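
With the SEMI keyword and the new joinType alternative in place, the parser accepts LEFT SEMI JOIN directly in SQL text. A minimal sketch of the new syntax in use (table names here are hypothetical, and an in-scope SQLContext is assumed, as in the test suites):

  // Returns rows from customers that have at least one matching order;
  // columns from orders do not appear in the result.
  val matched = sqlContext.sql(
    "SELECT * FROM customers LEFT SEMI JOIN orders ON customers.id = orders.customerId")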

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 5 additions & 0 deletions

@@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
     case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
       logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
       splitPredicates(predicates ++ condition, join)
+    // All predicates can be evaluated for a left semi join (those that are in the WHERE
+    // clause can only come from the left table, so they can all be pushed down).
+    case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
+      logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
+      splitPredicates(predicates ++ condition, join)
     case join @ Join(left, right, joinType, condition) =>
       logger.debug(s"Considering hash join on: $condition")
       splitPredicates(condition.toSeq, join)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala

Lines changed: 1 addition & 0 deletions

@@ -22,3 +22,4 @@ case object Inner extends JoinType
 case object LeftOuter extends JoinType
 case object RightOuter extends JoinType
 case object FullOuter extends JoinType
+case object LeftSemi extends JoinType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 7 additions & 2 deletions

@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical

 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
 import org.apache.spark.sql.catalyst.types._

 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -81,7 +81,12 @@ case class Join(
     condition: Option[Expression]) extends BinaryNode {

   def references = condition.map(_.references).getOrElse(Set.empty)
-  def output = left.output ++ right.output
+  def output = joinType match {
+    case LeftSemi =>
+      left.output
+    case _ =>
+      left.output ++ right.output
+  }
 }

 case class InsertIntoTable(
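
This output override is what gives the left semi join its projection semantics: the logical plan exposes only the left child's attributes, so SELECT * over a left semi join returns left-side columns only. The new SQLQuerySuite test below reflects this, expecting two-column rows from the left table despite the self-join.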

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 0 deletions

@@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val strategies: Seq[Strategy] =
       TakeOrdered ::
       PartialAggregation ::
+      LeftSemiJoin ::
       HashJoin ::
       ParquetOperations ::
       BasicOperators ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 16 additions & 0 deletions

@@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
 private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>

+  object LeftSemiJoin extends Strategy with PredicateHelper {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      // Find left semi joins where at least some predicates can be evaluated by matching hash
+      // keys using the HashFilteredJoin pattern.
+      case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+        val semiJoin = execution.LeftSemiJoinHash(
+          leftKeys, rightKeys, planLater(left), planLater(right))
+        condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+      // No predicate can be evaluated by matching hash keys.
+      case logical.Join(left, right, LeftSemi, condition) =>
+        execution.LeftSemiJoinBNL(
+          planLater(left), planLater(right), condition)(sparkContext) :: Nil
+      case _ => Nil
+    }
+  }
+
   object HashJoin extends Strategy with PredicateHelper {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       // Find inner joins where at least some predicates can be evaluated by matching hash keys
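
The strategy dispatches on the join condition: equi-conditions that HashFilteredJoin can turn into hash keys plan as a LeftSemiJoinHash (plus a Filter for any residual predicates), while everything else falls back to the broadcast nested-loop operator. As a rough illustration (hedged; actual plans depend on the optimizer, and testData2 comes from the test fixtures):

  // Equi-join keys exist, so this can plan as LeftSemiJoinHash.
  sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a = y.a")

  // A non-equi condition yields no hash keys, so this falls through to LeftSemiJoinBNL.
  sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2")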

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 131 additions & 0 deletions

@@ -140,6 +140,137 @@ case class HashJoin(
   }
 }

+/**
+ * :: DeveloperApi ::
+ * Builds the right table's join keys into a HashSet, then iterates over the left table,
+ * keeping only the rows whose join keys appear in the set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def outputPartitioning: Partitioning = left.outputPartitioning
+
+  override def requiredChildDistribution =
+    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+  val (buildPlan, streamedPlan) = (right, left)
+  val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
+
+  def output = left.output
+
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
+
+  def execute() = {
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      val hashTable = new java.util.HashSet[Row]()
+      var currentRow: Row = null
+
+      // Create a hash set of the build-side keys.
+      while (buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if (!rowKey.anyNull) {
+          val keyExists = hashTable.contains(rowKey)
+          if (!keyExists) {
+            hashTable.add(rowKey)
+          }
+        }
+      }
+
+      new Iterator[Row] {
+        private[this] var currentStreamedRow: Row = _
+        private[this] var currentHashMatched: Boolean = false
+
+        private[this] val joinKeys = streamSideKeyGenerator()
+
+        override final def hasNext: Boolean =
+          streamIter.hasNext && fetchNext()
+
+        override final def next() = {
+          currentStreamedRow
+        }
+
+        /**
+         * Searches the streamed iterator for the next row that has at least one match in
+         * the hash table.
+         *
+         * @return true if the search is successful, and false if the streamed iterator
+         *         runs out of tuples.
+         */
+        private final def fetchNext(): Boolean = {
+          currentHashMatched = false
+          while (!currentHashMatched && streamIter.hasNext) {
+            currentStreamedRow = streamIter.next()
+            if (!joinKeys(currentStreamedRow).anyNull) {
+              currentHashMatched = hashTable.contains(joinKeys.currentValue)
+            }
+          }
+          currentHashMatched
+        }
+      }
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no
+ * join keys usable for a hash join.
+ */
+@DeveloperApi
+case class LeftSemiJoinBNL(
+    streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
+    (@transient sc: SparkContext)
+  extends BinaryNode {
+  // TODO: Override requiredChildDistribution.
+
+  override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+  override def otherCopyArgs = sc :: Nil
+
+  def output = left.output
+
+  /** The streamed relation. */
+  def left = streamed
+  /** The broadcast relation. */
+  def right = broadcast
+
+  @transient lazy val boundCondition =
+    InterpretedPredicate(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
+
+  def execute() = {
+    val broadcastedRelation =
+      sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+    streamed.execute().mapPartitions { streamedIter =>
+      val joinedRow = new JoinedRow
+
+      streamedIter.filter(streamedRow => {
+        var i = 0
+        var matched = false
+
+        while (i < broadcastedRelation.value.size && !matched) {
+          val broadcastedRow = broadcastedRelation.value(i)
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+            matched = true
+          }
+          i += 1
+        }
+        matched
+      })
+    }
+  }
+}
+
 /**
  * :: DeveloperApi ::
  */
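
Both operators implement the same logical semantics: keep each left row that has at least one right-side match, and never duplicate a left row however many matches exist. A minimal sketch of the two strategies on plain Scala collections, stripped of Spark's Row/Projection/RDD machinery (function and type names here are illustrative, not part of the patch):

  // Hash variant: build a set of right-side keys, then filter the left side.
  // Mirrors LeftSemiJoinHash for equi-join conditions.
  def semiJoinHash[L, R, K](left: Seq[L], right: Seq[R])
                           (leftKey: L => K, rightKey: R => K): Seq[L] = {
    val buildSide: Set[K] = right.map(rightKey).toSet   // build phase
    left.filter(row => buildSide.contains(leftKey(row))) // stream phase
  }

  // Nested-loop variant: scan the (broadcast) right side per left row,
  // stopping at the first match. Mirrors LeftSemiJoinBNL for arbitrary conditions.
  def semiJoinNestedLoop[L, R](left: Seq[L], right: Seq[R])
                              (condition: (L, R) => Boolean): Seq[L] =
    left.filter(l => right.exists(r => condition(l, r)))

  // e.g. semiJoinHash(Seq(1, 2, 3), Seq(2, 3, 3, 4))(identity, identity) == Seq(2, 3)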

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
       fail(
         s"""
           |Exception thrown while executing query:
-          |${rdd.logicalPlan}
+          |${rdd.queryExecution}
           |== Exception ==
          |$e
        """.stripMargin)

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 7 additions & 0 deletions

@@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest {
       arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
   }

+  test("left semi greater than predicate") {
+    checkAnswer(
+      sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
+      Seq((3,1), (3,2))
+    )
+  }
+
   test("index into array of arrays") {
     checkAnswer(
       sql(
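
Because x.a >= y.a + 2 is not an equality, HashFilteredJoin extracts no hash keys, so this test exercises the LeftSemiJoinBNL fallback (the "left semi that can't be done with a hash join" from the commit message). Note also that the expected rows carry only x's two columns, per the LeftSemi output definition above.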

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 0 deletions

@@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       DataSinks,
       Scripts,
       PartialAggregation,
+      LeftSemiJoin,
       HashJoin,
       BasicOperators,
       CartesianProduct,
