
Commit b5a4efa (parent d16a944)
Author: kai

(1) Add broadcast hash outer join, (2) Fix SparkPlanTest

File tree: 8 files changed, +450 -102 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 13 additions & 2 deletions
@@ -118,8 +118,19 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil

       case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) =>
-        joins.HashOuterJoin(
-          leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+        joinType match {
+          case LeftOuter if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
+              right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
+            joins.BroadcastHashOuterJoin(
+              leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+          case RightOuter if sqlContext.conf.autoBroadcastJoinThreshold > 0 &&
+              left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
+            joins.BroadcastHashOuterJoin(
+              leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+          case _ =>
+            joins.ShuffledHashOuterJoin(
+              leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil
+        }

       case _ => Nil
     }
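
With this change the planner picks between the broadcast and shuffled variants using spark.sql.autoBroadcastJoinThreshold and the estimated size (statistics.sizeInBytes) of the side opposite the outer side. A minimal sketch of exercising the rule from user code, assuming a Spark 1.x sqlContext and hypothetical registered tables big_table and small_table whose size statistics are visible to the planner:

// Enable broadcasting for build sides estimated under 10 MB (hypothetical setup).
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", (10 * 1024 * 1024).toString)

val df = sqlContext.sql(
  "SELECT * FROM big_table b LEFT OUTER JOIN small_table s ON b.key = s.key")
// If small_table's estimated size is below the threshold, the physical plan
// should contain a BroadcastHashOuterJoin node; print it to check.
println(df.queryExecution.executedPlan)

// Disabling the threshold forces ShuffledHashOuterJoin for the same query.
sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "-1")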
sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala

Lines changed: 120 additions & 0 deletions

@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
+import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.util.ThreadUtils
+
+import scala.concurrent._
+import scala.concurrent.duration._
+
+/**
+ * :: DeveloperApi ::
+ * Performs an outer hash join for two child relations. When the output RDD of this operator is
+ * being constructed, a Spark job is asynchronously started to calculate the values for the
+ * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
+ * relation is not shuffled.
+ */
+@DeveloperApi
+case class BroadcastHashOuterJoin(
+    leftKeys: Seq[Expression],
+    rightKeys: Seq[Expression],
+    joinType: JoinType,
+    condition: Option[Expression],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode with HashOuterJoin {
+
+  val timeout = {
+    val timeoutValue = sqlContext.conf.broadcastTimeout
+    if (timeoutValue < 0) {
+      Duration.Inf
+    } else {
+      timeoutValue.seconds
+    }
+  }
+
+  override def requiredChildDistribution =
+    UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+
+  private[this] lazy val (buildPlan, streamedPlan) = joinType match {
+    case RightOuter => (left, right)
+    case LeftOuter => (right, left)
+    case x =>
+      throw new IllegalArgumentException(
+        s"BroadcastHashOuterJoin should not take $x as the JoinType")
+  }
+
+  private[this] lazy val (buildKeys, streamedKeys) = joinType match {
+    case RightOuter => (leftKeys, rightKeys)
+    case LeftOuter => (rightKeys, leftKeys)
+    case x =>
+      throw new IllegalArgumentException(
+        s"BroadcastHashOuterJoin should not take $x as the JoinType")
+  }
+
+  @transient
+  private val broadcastFuture = future {
+    // Note that we use .execute().collect() because we don't want to convert data to Scala types
+    val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
+    // buildHashTable uses code-generated rows as keys, which are not serializable
+    val hashed = new GeneralHashedRelation(
+      buildHashTable(input.iterator, newProjection(buildKeys, buildPlan.output)))
+    sparkContext.broadcast(hashed)
+  }(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
+
+  override def doExecute(): RDD[InternalRow] = {
+    val broadcastRelation = Await.result(broadcastFuture, timeout)
+
+    streamedPlan.execute().mapPartitions { streamedIter =>
+      val joinedRow = new JoinedRow()
+      val hashTable = broadcastRelation.value
+      val keyGenerator = newProjection(streamedKeys, streamedPlan.output)
+
+      joinType match {
+        case LeftOuter =>
+          streamedIter.flatMap(currentRow => {
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withLeft(currentRow)
+            leftOuterIterator(rowKey, joinedRow, hashTable.getOrElse(rowKey, EMPTY_LIST))
+          })

+        case RightOuter =>
+          streamedIter.flatMap(currentRow => {
+            val rowKey = keyGenerator(currentRow)
+            joinedRow.withRight(currentRow)
+            rightOuterIterator(rowKey, hashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
+          })

+        case x =>
+          throw new IllegalArgumentException(
+            s"BroadcastHashOuterJoin should not take $x as the JoinType")
+      }
+    }
+  }
+}
+
+object BroadcastHashOuterJoin {
+
+  private val broadcastHashOuterJoinExecutionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-outer-join", 128))
+}
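
Worth noting in the new operator: the broadcast is kicked off in a future on a dedicated daemon thread pool as soon as the operator is constructed, and doExecute() only blocks, bounded by spark.sql.broadcastTimeout, if that job has not finished by the time the streamed side starts executing. Below is a distilled, self-contained sketch of the same idiom; the local SparkContext, the toy Map build side, and the global ExecutionContext standing in for the dedicated pool are all illustrative assumptions, not part of this commit.

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

import org.apache.spark.{SparkConf, SparkContext}

object AsyncBroadcastSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    implicit val ec: ExecutionContext = ExecutionContext.global // stand-in for the dedicated pool

    // Start building and broadcasting the (toy) hash side immediately...
    val broadcastFuture = Future {
      val buildSide = Map(1 -> "a", 2 -> "b") // stand-in for the hashed relation
      sc.broadcast(buildSide)
    }

    // ...and block only when the streamed side actually needs it, with a
    // timeout analogous to spark.sql.broadcastTimeout.
    val relation = Await.result(broadcastFuture, 300.seconds)
    val joined = sc.parallelize(Seq(1, 2, 3)).map(k => k -> relation.value.get(k))
    joined.collect().foreach(println) // (1,Some(a)), (2,Some(b)), (3,None)
    sc.stop()
  }
}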

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala

Lines changed: 22 additions & 73 deletions
@@ -19,42 +19,32 @@ package org.apache.spark.sql.execution.joins

 import java.util.{HashMap => JavaHashMap}

-import org.apache.spark.rdd.RDD
-
-import scala.collection.JavaConversions._
-
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.util.collection.CompactBuffer

-/**
- * :: DeveloperApi ::
- * Performs a hash based outer join for two child relations by shuffling the data using
- * the join keys. This operator requires loading the associated partition in both side into memory.
- */
 @DeveloperApi
-case class HashOuterJoin(
-    leftKeys: Seq[Expression],
-    rightKeys: Seq[Expression],
-    joinType: JoinType,
-    condition: Option[Expression],
-    left: SparkPlan,
-    right: SparkPlan) extends BinaryNode {
-
-  override def outputPartitioning: Partitioning = joinType match {
+trait HashOuterJoin {
+  self: SparkPlan =>
+
+  val leftKeys: Seq[Expression]
+  val rightKeys: Seq[Expression]
+  val joinType: JoinType
+  val condition: Option[Expression]
+  val left: SparkPlan
+  val right: SparkPlan
+
+  override def outputPartitioning: Partitioning = joinType match {
     case LeftOuter => left.outputPartitioning
     case RightOuter => right.outputPartitioning
     case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
     case x =>
       throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
   }

-  override def requiredChildDistribution: Seq[ClusteredDistribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
   override def output: Seq[Attribute] = {
     joinType match {
       case LeftOuter =>

@@ -68,8 +58,8 @@ case class HashOuterJoin(
     }
   }

-  @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null)
-  @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow]
+  @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null)
+  @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]()

   @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length)
   @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length)

@@ -80,7 +70,7 @@ case class HashOuterJoin(
   // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
   // iterator for performance purpose.

-  private[this] def leftOuterIterator(
+  protected[this] def leftOuterIterator(
       key: InternalRow,
       joinedRow: JoinedRow,
       rightIter: Iterable[InternalRow]): Iterator[InternalRow] = {

@@ -89,7 +79,7 @@ case class HashOuterJoin(
       val temp = rightIter.collect {
         case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
       }
-      if (temp.size == 0) {
+      if (temp.isEmpty) {
         joinedRow.withRight(rightNullRow).copy :: Nil
       } else {
         temp

@@ -101,18 +91,17 @@ case class HashOuterJoin(
     ret.iterator
   }

-  private[this] def rightOuterIterator(
+  protected[this] def rightOuterIterator(
       key: InternalRow,
       leftIter: Iterable[InternalRow],
       joinedRow: JoinedRow): Iterator[InternalRow] = {
-
     val ret: Iterable[InternalRow] = {
       if (!key.anyNull) {
         val temp = leftIter.collect {
           case l if boundCondition(joinedRow.withLeft(l)) =>
-            joinedRow.copy
+            joinedRow.copy()
         }
-        if (temp.size == 0) {
+        if (temp.isEmpty) {
           joinedRow.withLeft(leftNullRow).copy :: Nil
         } else {
           temp

@@ -124,10 +113,9 @@ case class HashOuterJoin(
     ret.iterator
   }

-  private[this] def fullOuterIterator(
+  protected[this] def fullOuterIterator(
       key: InternalRow, leftIter: Iterable[InternalRow], rightIter: Iterable[InternalRow],
       joinedRow: JoinedRow): Iterator[InternalRow] = {
-
     if (!key.anyNull) {
       // Store the positions of records in right, if one of its associated row satisfy
       // the join condition.

@@ -171,7 +159,7 @@ case class HashOuterJoin(
     }
   }

-  private[this] def buildHashTable(
+  protected[this] def buildHashTable(
       iter: Iterator[InternalRow],
       keyGenerator: Projection): JavaHashMap[InternalRow, CompactBuffer[InternalRow]] = {
     val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]]()

@@ -190,43 +178,4 @@ case class HashOuterJoin(

     hashTable
   }
-
-  protected override def doExecute(): RDD[InternalRow] = {
-    val joinedRow = new JoinedRow()
-    left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
-      // TODO this probably can be replaced by external sort (sort merged join?)
-
-      joinType match {
-        case LeftOuter =>
-          val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
-          val keyGenerator = newProjection(leftKeys, left.output)
-          leftIter.flatMap( currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withLeft(currentRow)
-            leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST))
-          })
-
-        case RightOuter =>
-          val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
-          val keyGenerator = newProjection(rightKeys, right.output)
-          rightIter.flatMap ( currentRow => {
-            val rowKey = keyGenerator(currentRow)
-            joinedRow.withRight(currentRow)
-            rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow)
-          })
-
-        case FullOuter =>
-          val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
-          val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
-          (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
-            fullOuterIterator(key,
-              leftHashTable.getOrElse(key, EMPTY_LIST),
-              rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow)
-          }
-
-        case x =>
-          throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
-      }
-    }
-  }
 }
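
The refactoring above turns HashOuterJoin from a concrete operator into a mixin: a trait with a self: SparkPlan => self-type whose former constructor parameters become abstract vals, which ShuffledHashOuterJoin and BroadcastHashOuterJoin then satisfy automatically through their case-class parameters. A minimal, self-contained sketch of that pattern, with hypothetical names standing in for the Spark types:

// Hypothetical stand-ins for SparkPlan and the two join operators.
abstract class Plan {
  def nodeName: String
}

trait OuterJoinLike { self: Plan =>
  // Abstract member; a case class satisfies it with a constructor parameter.
  val joinKeys: Seq[String]

  // Shared logic can use both the trait's members and Plan's API via `self`.
  def describe: String = s"$nodeName on keys ${joinKeys.mkString(", ")}"
}

case class ShuffledJoin(joinKeys: Seq[String]) extends Plan with OuterJoinLike {
  def nodeName: String = "ShuffledJoin"
}

case class BroadcastJoin(joinKeys: Seq[String]) extends Plan with OuterJoinLike {
  def nodeName: String = "BroadcastJoin"
}

// BroadcastJoin(Seq("key")).describe == "BroadcastJoin on keys key"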

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala

Lines changed: 7 additions & 0 deletions
@@ -32,6 +32,13 @@ import org.apache.spark.util.collection.CompactBuffer
 private[joins] sealed trait HashedRelation {
   def get(key: InternalRow): CompactBuffer[InternalRow]

+  def getOrElse(
+      key: InternalRow,
+      default: CompactBuffer[InternalRow]): CompactBuffer[InternalRow] = {
+    val v = get(key)
+    if (v eq null) default else v
+  }
+
   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
   protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
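
The new getOrElse checks v eq null rather than wrapping the lookup in an Option: the relation is backed by a java.util.HashMap, which returns null on a miss, and sharing a single empty buffer (EMPTY_LIST above) avoids allocating a default on every probe of the join's hot path. A small illustration of the same pattern over a plain Java map; all names here are hypothetical:

import java.util.{HashMap => JavaHashMap}

object GetOrElseSketch {
  private val EMPTY: Seq[Int] = Seq.empty // shared default, allocated once

  def main(args: Array[String]): Unit = {
    val table = new JavaHashMap[String, Seq[Int]]()
    table.put("a", Seq(1, 2))

    def getOrElse(key: String): Seq[Int] = {
      val v = table.get(key)      // java.util.HashMap returns null on a miss
      if (v eq null) EMPTY else v // reference check, no Option allocation
    }

    println(getOrElse("a")) // List(1, 2)
    println(getOrElse("b")) // List()
  }
}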
