@@ -142,29 +142,23 @@ case class HashJoin(
142
142
143
143
/**
144
144
* :: DeveloperApi ::
145
+ * Build the right table's join keys into a HashSet, and iteratively go through the left
146
+ * table, to check whether the join keys are present in the hash set.
145
147
*/
146
148
@ DeveloperApi
147
149
case class LeftSemiJoinHash (
148
- leftKeys : Seq [Expression ],
149
- rightKeys : Seq [Expression ],
150
- buildSide : BuildSide ,
151
- left : SparkPlan ,
152
- right : SparkPlan ) extends BinaryNode {
150
+ leftKeys : Seq [Expression ],
151
+ rightKeys : Seq [Expression ],
152
+ left : SparkPlan ,
153
+ right : SparkPlan ) extends BinaryNode {
153
154
154
155
override def outputPartitioning : Partitioning = left.outputPartitioning
155
156
156
157
override def requiredChildDistribution =
157
158
ClusteredDistribution (leftKeys) :: ClusteredDistribution (rightKeys) :: Nil
158
159
159
- val (buildPlan, streamedPlan) = buildSide match {
160
- case BuildLeft => (left, right)
161
- case BuildRight => (right, left)
162
- }
163
-
164
- val (buildKeys, streamedKeys) = buildSide match {
165
- case BuildLeft => (leftKeys, rightKeys)
166
- case BuildRight => (rightKeys, leftKeys)
167
- }
160
+ val (buildPlan, streamedPlan) = (right, left)
161
+ val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
168
162
169
163
def output = left.output
170
164
@@ -175,24 +169,18 @@ case class LeftSemiJoinHash(
175
169
def execute () = {
176
170
177
171
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
178
- // TODO: Use Spark's HashMap implementation.
179
- val hashTable = new java.util.HashMap [Row , ArrayBuffer [Row ]]()
172
+ val hashTable = new java.util.HashSet [Row ]()
180
173
var currentRow : Row = null
181
174
182
- // Create a mapping of buildKeys -> rows
175
+ // Create a Hash set of buildKeys
183
176
while (buildIter.hasNext) {
184
177
currentRow = buildIter.next()
185
178
val rowKey = buildSideKeyGenerator(currentRow)
186
179
if (! rowKey.anyNull) {
187
- val existingMatchList = hashTable.get(rowKey)
188
- val matchList = if (existingMatchList == null ) {
189
- val newMatchList = new ArrayBuffer [Row ]()
190
- hashTable.put(rowKey, newMatchList)
191
- newMatchList
192
- } else {
193
- existingMatchList
180
+ val keyExists = hashTable.contains(rowKey)
181
+ if (! keyExists) {
182
+ hashTable.add(rowKey)
194
183
}
195
- matchList += currentRow.copy()
196
184
}
197
185
}
198
186
@@ -220,7 +208,7 @@ case class LeftSemiJoinHash(
220
208
while (! currentHashMatched && streamIter.hasNext) {
221
209
currentStreamedRow = streamIter.next()
222
210
if (! joinKeys(currentStreamedRow).anyNull) {
223
- currentHashMatched = true
211
+ currentHashMatched = hashTable.contains(joinKeys.currentValue)
224
212
}
225
213
}
226
214
currentHashMatched
@@ -232,6 +220,8 @@ case class LeftSemiJoinHash(
232
220
233
221
/**
234
222
* :: DeveloperApi ::
223
+ * Using BroadcastNestedLoopJoin to calculate left semi join result when there are no join keys
224
+ * for hash join.
235
225
*/
236
226
@ DeveloperApi
237
227
case class LeftSemiJoinBNL (
@@ -261,26 +251,23 @@ case class LeftSemiJoinBNL(
261
251
def execute () = {
262
252
val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
263
253
264
- val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
254
+ streamed.execute().mapPartitions { streamedIter =>
265
255
val joinedRow = new JoinedRow
266
256
267
257
streamedIter.filter(streamedRow => {
268
258
var i = 0
269
259
var matched = false
270
260
271
261
while (i < broadcastedRelation.value.size && ! matched) {
272
- // TODO: One bitset per partition instead of per row.
273
262
val broadcastedRow = broadcastedRelation.value(i)
274
263
if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
275
264
matched = true
276
265
}
277
266
i += 1
278
267
}
279
268
matched
280
- }).map(streamedRow => (streamedRow, null ))
269
+ })
281
270
}
282
-
283
- streamedPlusMatches.map(_._1)
284
271
}
285
272
}
286
273
0 commit comments