Skip to content

Commit 14cff80

Browse files
committed
add support for left semi join
1 parent 753b04d commit 14cff80

27 files changed

+197
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ case object Inner extends JoinType
2222
case object LeftOuter extends JoinType
2323
case object RightOuter extends JoinType
2424
case object FullOuter extends JoinType
25+
case object LeftSemi extends JoinType

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
193193
val strategies: Seq[Strategy] =
194194
TakeOrdered ::
195195
PartialAggregation ::
196+
LeftSemiJoin ::
196197
HashJoin ::
197198
ParquetOperations ::
198199
BasicOperators ::

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
2828
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
2929
self: SQLContext#SparkPlanner =>
3030

31+
object LeftSemiJoin extends Strategy with PredicateHelper {
32+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
33+
// Find leftsemi joins where at least some predicates can be evaluated by matching hash keys
34+
// using the HashFilteredJoin pattern.
35+
case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
36+
val semiJoin =
37+
execution.LeftSemiJoinHash(leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
38+
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
39+
// no predicate can be evaluated by matching hash keys
40+
case logical.Join(left, right, LeftSemi, condition) =>
41+
execution.LeftSemiJoinBNL(
42+
planLater(left), planLater(right), LeftSemi, condition)(sparkContext) :: Nil
43+
case _ => Nil
44+
}
45+
}
46+
3147
object HashJoin extends Strategy with PredicateHelper {
3248
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
3349
// Find inner joins where at least some predicates can be evaluated by matching hash keys

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,150 @@ case class HashJoin(
140140
}
141141
}
142142

143+
/**
144+
* :: DeveloperApi ::
145+
*/
146+
@DeveloperApi
147+
case class LeftSemiJoinHash(
148+
leftKeys: Seq[Expression],
149+
rightKeys: Seq[Expression],
150+
buildSide: BuildSide,
151+
left: SparkPlan,
152+
right: SparkPlan) extends BinaryNode {
153+
154+
override def outputPartitioning: Partitioning = left.outputPartitioning
155+
156+
override def requiredChildDistribution =
157+
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
158+
159+
val (buildPlan, streamedPlan) = buildSide match {
160+
case BuildLeft => (left, right)
161+
case BuildRight => (right, left)
162+
}
163+
164+
val (buildKeys, streamedKeys) = buildSide match {
165+
case BuildLeft => (leftKeys, rightKeys)
166+
case BuildRight => (rightKeys, leftKeys)
167+
}
168+
169+
def output = left.output
170+
171+
@transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
172+
@transient lazy val streamSideKeyGenerator =
173+
() => new MutableProjection(streamedKeys, streamedPlan.output)
174+
175+
def execute() = {
176+
177+
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
178+
// TODO: Use Spark's HashMap implementation.
179+
val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
180+
var currentRow: Row = null
181+
182+
// Create a mapping of buildKeys -> rows
183+
while (buildIter.hasNext) {
184+
currentRow = buildIter.next()
185+
val rowKey = buildSideKeyGenerator(currentRow)
186+
if(!rowKey.anyNull) {
187+
val existingMatchList = hashTable.get(rowKey)
188+
val matchList = if (existingMatchList == null) {
189+
val newMatchList = new ArrayBuffer[Row]()
190+
hashTable.put(rowKey, newMatchList)
191+
newMatchList
192+
} else {
193+
existingMatchList
194+
}
195+
matchList += currentRow.copy()
196+
}
197+
}
198+
199+
new Iterator[Row] {
200+
private[this] var currentStreamedRow: Row = _
201+
private[this] var currentHashMatched: Boolean = false
202+
203+
private[this] val joinKeys = streamSideKeyGenerator()
204+
205+
override final def hasNext: Boolean =
206+
streamIter.hasNext && fetchNext()
207+
208+
override final def next() = {
209+
currentStreamedRow
210+
}
211+
212+
/**
213+
* Searches the streamed iterator for the next row that has at least one match in hashtable.
214+
*
215+
* @return true if the search is successful, and false the streamed iterator runs out of
216+
* tuples.
217+
*/
218+
private final def fetchNext(): Boolean = {
219+
currentHashMatched = false
220+
while (!currentHashMatched && streamIter.hasNext) {
221+
currentStreamedRow = streamIter.next()
222+
if (!joinKeys(currentStreamedRow).anyNull) {
223+
currentHashMatched = true
224+
}
225+
}
226+
currentHashMatched
227+
}
228+
}
229+
}
230+
}
231+
}
232+
233+
/**
234+
* :: DeveloperApi ::
235+
*/
236+
@DeveloperApi
237+
case class LeftSemiJoinBNL(
238+
streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
239+
(@transient sc: SparkContext)
240+
extends BinaryNode {
241+
// TODO: Override requiredChildDistribution.
242+
243+
override def outputPartitioning: Partitioning = streamed.outputPartitioning
244+
245+
override def otherCopyArgs = sc :: Nil
246+
247+
def output = left.output
248+
249+
/** The Streamed Relation */
250+
def left = streamed
251+
/** The Broadcast relation */
252+
def right = broadcast
253+
254+
@transient lazy val boundCondition =
255+
InterpretedPredicate(
256+
condition
257+
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
258+
.getOrElse(Literal(true)))
259+
260+
261+
def execute() = {
262+
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
263+
264+
val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
265+
val joinedRow = new JoinedRow
266+
267+
streamedIter.filter(streamedRow => {
268+
var i = 0
269+
var matched = false
270+
271+
while (i < broadcastedRelation.value.size && !matched) {
272+
// TODO: One bitset per partition instead of per row.
273+
val broadcastedRow = broadcastedRelation.value(i)
274+
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
275+
matched = true
276+
}
277+
i += 1
278+
}
279+
matched
280+
}).map(streamedRow => (streamedRow, null))
281+
}
282+
283+
streamedPlusMatches.map(_._1)
284+
}
285+
}
286+
143287
/**
144288
* :: DeveloperApi ::
145289
*/

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
224224
DataSinks,
225225
Scripts,
226226
PartialAggregation,
227+
LeftSemiJoin,
227228
HashJoin,
228229
BasicOperators,
229230
CartesianProduct,

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ private[hive] object HiveQl {
680680
case "TOK_RIGHTOUTERJOIN" => RightOuter
681681
case "TOK_LEFTOUTERJOIN" => LeftOuter
682682
case "TOK_FULLOUTERJOIN" => FullOuter
683+
case "TOK_LEFTSEMIJOIN" => LeftSemi
683684
}
684685
assert(other.size <= 1, "Unhandled join clauses.")
685686
Join(nodeToRelation(relation1),

sql/hive/src/test/resources/golden/leftsemijoin-0-80b6466213face7fbcb0de044611e1f5

Whitespace-only changes.

sql/hive/src/test/resources/golden/leftsemijoin-1-d1f6a3dea28a5f0fee08026bf33d9129

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
0

sql/hive/src/test/resources/golden/leftsemijoin-3-b07d292423312aafa5e5762a579decd2

Whitespace-only changes.

0 commit comments

Comments
 (0)