Skip to content

Commit 4c726e5

Browse files
committed
improvement according to Michael
1 parent 8d4a121 commit 4c726e5

File tree

3 files changed

+26
-34
lines changed

3 files changed

+26
-34
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
119119
case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
120120
logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
121121
splitPredicates(predicates ++ condition, join)
122+
// All predicates can be evaluated for left semi join (those that are in the WHERE
123+
// clause can only come from the left table, so they can all be pushed down.)
124+
case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
125+
logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
126+
splitPredicates(predicates ++ condition, join)
122127
case join @ Join(left, right, joinType, condition) =>
123128
logger.debug(s"Considering hash join on: $condition")
124129
splitPredicates(condition.toSeq, join)

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
3030

3131
object LeftSemiJoin extends Strategy with PredicateHelper {
3232
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
33-
// Find leftsemi joins where at least some predicates can be evaluated by matching hash keys
34-
// using the HashFilteredJoin pattern.
33+
// Find left semi joins where at least some predicates can be evaluated by matching hash
34+
// keys using the HashFilteredJoin pattern.
3535
case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
3636
val semiJoin = execution.LeftSemiJoinHash(
37-
leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
37+
leftKeys, rightKeys, planLater(left), planLater(right))
3838
condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
3939
// no predicate can be evaluated by matching hash keys
4040
case logical.Join(left, right, LeftSemi, condition) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -142,29 +142,23 @@ case class HashJoin(
142142

143143
/**
144144
* :: DeveloperApi ::
145+
* Build the right table's join keys into a HashSet, and iteratively go through the left
146+
* table, checking whether each row's join keys are in the hash set.
145147
*/
146148
@DeveloperApi
147149
case class LeftSemiJoinHash(
148-
leftKeys: Seq[Expression],
149-
rightKeys: Seq[Expression],
150-
buildSide: BuildSide,
151-
left: SparkPlan,
152-
right: SparkPlan) extends BinaryNode {
150+
leftKeys: Seq[Expression],
151+
rightKeys: Seq[Expression],
152+
left: SparkPlan,
153+
right: SparkPlan) extends BinaryNode {
153154

154155
override def outputPartitioning: Partitioning = left.outputPartitioning
155156

156157
override def requiredChildDistribution =
157158
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
158159

159-
val (buildPlan, streamedPlan) = buildSide match {
160-
case BuildLeft => (left, right)
161-
case BuildRight => (right, left)
162-
}
163-
164-
val (buildKeys, streamedKeys) = buildSide match {
165-
case BuildLeft => (leftKeys, rightKeys)
166-
case BuildRight => (rightKeys, leftKeys)
167-
}
160+
val (buildPlan, streamedPlan) = (right, left)
161+
val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
168162

169163
def output = left.output
170164

@@ -175,24 +169,18 @@ case class LeftSemiJoinHash(
175169
def execute() = {
176170

177171
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
178-
// TODO: Use Spark's HashMap implementation.
179-
val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
172+
val hashTable = new java.util.HashSet[Row]()
180173
var currentRow: Row = null
181174

182-
// Create a mapping of buildKeys -> rows
175+
// Create a Hash set of buildKeys
183176
while (buildIter.hasNext) {
184177
currentRow = buildIter.next()
185178
val rowKey = buildSideKeyGenerator(currentRow)
186179
if(!rowKey.anyNull) {
187-
val existingMatchList = hashTable.get(rowKey)
188-
val matchList = if (existingMatchList == null) {
189-
val newMatchList = new ArrayBuffer[Row]()
190-
hashTable.put(rowKey, newMatchList)
191-
newMatchList
192-
} else {
193-
existingMatchList
180+
val keyExists = hashTable.contains(rowKey)
181+
if (!keyExists) {
182+
hashTable.add(rowKey)
194183
}
195-
matchList += currentRow.copy()
196184
}
197185
}
198186

@@ -220,7 +208,7 @@ case class LeftSemiJoinHash(
220208
while (!currentHashMatched && streamIter.hasNext) {
221209
currentStreamedRow = streamIter.next()
222210
if (!joinKeys(currentStreamedRow).anyNull) {
223-
currentHashMatched = true
211+
currentHashMatched = hashTable.contains(joinKeys.currentValue)
224212
}
225213
}
226214
currentHashMatched
@@ -232,6 +220,8 @@ case class LeftSemiJoinHash(
232220

233221
/**
234222
* :: DeveloperApi ::
223+
* Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no join keys
224+
* for hash join.
235225
*/
236226
@DeveloperApi
237227
case class LeftSemiJoinBNL(
@@ -261,26 +251,23 @@ case class LeftSemiJoinBNL(
261251
def execute() = {
262252
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
263253

264-
val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
254+
streamed.execute().mapPartitions { streamedIter =>
265255
val joinedRow = new JoinedRow
266256

267257
streamedIter.filter(streamedRow => {
268258
var i = 0
269259
var matched = false
270260

271261
while (i < broadcastedRelation.value.size && !matched) {
272-
// TODO: One bitset per partition instead of per row.
273262
val broadcastedRow = broadcastedRelation.value(i)
274263
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
275264
matched = true
276265
}
277266
i += 1
278267
}
279268
matched
280-
}).map(streamedRow => (streamedRow, null))
269+
})
281270
}
282-
283-
streamedPlusMatches.map(_._1)
284271
}
285272
}
286273

0 commit comments

Comments
 (0)