
Commit 585638e

adrian-wang authored and marmbrus committed
[SPARK-2213] [SQL] sort merge join for spark sql
Thanks for the initial work from Ishiihara in #3173.

This PR introduces a new join method, sort merge join. It first ensures that keys with the same value land in the same partition and that the rows inside each partition are sorted by key; we can then run down both sides together and find matching rows using a [sort merge join](http://en.wikipedia.org/wiki/Sort-merge_join). This way we don't have to hold the whole hash table of one side in memory as hash join does, so memory usage is lower. This PR would also benefit from #3438, which makes the sorting phase much more efficient. We introduced a new configuration, "spark.sql.planner.sortMergeJoin", to switch between this join (`true`) and ShuffledHashJoin (`false`); its default value should probably be `false` at first.

Author: Daoyuan Wang <[email protected]>
Author: Michael Armbrust <[email protected]>

This patch had conflicts when merged, resolved by
Committer: Michael Armbrust <[email protected]>

Closes #5208 from adrian-wang/smj and squashes the following commits:

2493b9f [Daoyuan Wang] fix style
5049d88 [Daoyuan Wang] propagate rowOrdering for RangePartitioning
f91a2ae [Daoyuan Wang] yin's comment: use external sort if option is enabled, add comments
f515cd2 [Daoyuan Wang] yin's comment: outputOrdering, join suite refine
ec8061b [Daoyuan Wang] minor change
413fd24 [Daoyuan Wang] Merge pull request #3 from marmbrus/pr/5208
952168a [Michael Armbrust] add type
5492884 [Michael Armbrust] copy when ordering
7ddd656 [Michael Armbrust] Cleanup addition of ordering requirements
b198278 [Daoyuan Wang] inherit ordering in project
c8e82a3 [Daoyuan Wang] fix style
6e897dd [Daoyuan Wang] hide boundReference from manually construct RowOrdering for key compare in smj
8681d73 [Daoyuan Wang] refactor Exchange and fix copy for sorting
2875ef2 [Daoyuan Wang] fix changed configuration
61d7f49 [Daoyuan Wang] add omitted comment
00a4430 [Daoyuan Wang] fix bug
078d69b [Daoyuan Wang] address comments: add comments, do sort in shuffle, and others
3af6ba5 [Daoyuan Wang] use buffer for only one side
171001f [Daoyuan Wang] change default outputordering
47455c9 [Daoyuan Wang] add apache license
...
a28277f [Daoyuan Wang] fix style
645c70b [Daoyuan Wang] address comments using sort
068c35d [Daoyuan Wang] fix new style and add some tests
925203b [Daoyuan Wang] address comments
07ce92f [Daoyuan Wang] fix ArrayIndexOutOfBound
42fca0e [Daoyuan Wang] code clean
e3ec096 [Daoyuan Wang] fix comment style..
2edd235 [Daoyuan Wang] fix outputpartitioning
57baa40 [Daoyuan Wang] fix sort eval bug
303b6da [Daoyuan Wang] fix several errors
95db7ad [Daoyuan Wang] fix brackets for if-statement
4464f16 [Daoyuan Wang] fix error
880d8e9 [Daoyuan Wang] sort merge join for spark sql
1 parent 4754e16 commit 585638e
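For readers new to the technique, here is a minimal standalone sketch of the merge phase described above, using plain Scala collections in place of the sorted, co-partitioned row iterators the real operator works with (illustration only, not code from this patch):

```scala
// Merge-join two inputs that are both sorted ascending by key.
// For each run of equal keys, only one side's run is buffered and the cross
// product with the other side is emitted, so neither input is ever held in a
// full hash table (cf. the "use buffer for only one side" commit above).
def sortMergeJoin[K: Ordering, L, R](
    left: Seq[(K, L)],
    right: Seq[(K, R)]): Seq[(L, R)] = {
  val ord = implicitly[Ordering[K]]
  val out = scala.collection.mutable.ArrayBuffer.empty[(L, R)]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val cmp = ord.compare(left(i)._1, right(j)._1)
    if (cmp < 0) i += 1            // left key too small, advance left
    else if (cmp > 0) j += 1       // right key too small, advance right
    else {
      val key = left(i)._1
      val start = j
      while (j < right.length && ord.equiv(right(j)._1, key)) j += 1
      val rightRun = right.slice(start, j)   // buffered run of matching right rows
      while (i < left.length && ord.equiv(left(i)._1, key)) {
        rightRun.foreach { case (_, r) => out += ((left(i)._2, r)) }
        i += 1
      }
    }
  }
  out.toSeq
}

// sortMergeJoin(Seq(1 -> "a", 2 -> "b", 2 -> "c"), Seq(2 -> "x", 3 -> "y"))
// => Seq(("b", "x"), ("c", "x"))
```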

File tree

11 files changed, +534 -33 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala

Lines changed: 8 additions & 2 deletions
@@ -17,8 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.types.{UTF8String, StructType, NativeType}
-
+import org.apache.spark.sql.types.{UTF8String, DataType, StructType, NativeType}
 
 /**
  * An extended interface to [[Row]] that allows the values for each column to be updated. Setting
@@ -239,3 +238,10 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
     return 0
   }
 }
+
+object RowOrdering {
+  def forSchema(dataTypes: Seq[DataType]): RowOrdering =
+    new RowOrdering(dataTypes.zipWithIndex.map {
+      case(dt, index) => new SortOrder(BoundReference(index, dt, nullable = true), Ascending)
+    })
+}
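A quick illustration of what the new helper builds, assuming a two-column schema of integers (a usage sketch, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{GenericRow, RowOrdering}
import org.apache.spark.sql.types.IntegerType

// Compares rows ascending on every column, addressing columns by position
// (BoundReference) rather than by name.
val ordering = RowOrdering.forSchema(Seq(IntegerType, IntegerType))

val smaller = new GenericRow(Array[Any](1, 5))
val larger  = new GenericRow(Array[Any](1, 9))
ordering.compare(smaller, larger)   // negative: equal on column 0, 5 < 9 on column 1
```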

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 13 additions & 0 deletions
@@ -94,6 +94,9 @@ sealed trait Partitioning {
    * only compatible if the `numPartitions` of them is the same.
    */
   def compatibleWith(other: Partitioning): Boolean
+
+  /** Returns the expressions that are used to key the partitioning. */
+  def keyExpressions: Seq[Expression]
 }
 
 case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -106,6 +109,8 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
     case UnknownPartitioning(_) => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 case object SinglePartition extends Partitioning {
@@ -117,6 +122,8 @@ case object SinglePartition extends Partitioning {
     case SinglePartition => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 case object BroadcastPartitioning extends Partitioning {
@@ -128,6 +135,8 @@
     case SinglePartition => true
     case _ => false
   }
+
+  override def keyExpressions: Seq[Expression] = Nil
 }
 
 /**
@@ -158,6 +167,8 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
     case _ => false
   }
 
+  override def keyExpressions: Seq[Expression] = expressions
+
   override def eval(input: Row = null): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
@@ -200,6 +211,8 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
     case _ => false
   }
 
+  override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+
   override def eval(input: Row): EvaluatedType =
     throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
 }
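To see what `keyExpressions` exposes for the two shuffle-producing partitionings, a small sketch (assumes a hypothetical integer attribute `a`; not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, RangePartitioning}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()

// Hash partitioning keys on the expression list itself.
HashPartitioning(Seq(a), 200).keyExpressions                          // Seq(a)

// Range partitioning keys on the children of its sort orders.
RangePartitioning(Seq(SortOrder(a, Ascending)), 200).keyExpressions   // Seq(a)
```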

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 8 additions & 0 deletions
@@ -47,6 +47,7 @@ private[spark] object SQLConf {
   // Options that control which operators can be chosen by the query planner. These should be
   // considered hints and may be ignored by future versions of Spark SQL.
   val EXTERNAL_SORT = "spark.sql.planner.externalSort"
+  val SORTMERGE_JOIN = "spark.sql.planner.sortMergeJoin"
 
   // This is only used for the thriftserver
   val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
@@ -128,6 +129,13 @@ private[sql] class SQLConf extends Serializable {
   /** When true the planner will use the external sort, which may spill to disk. */
   private[spark] def externalSortEnabled: Boolean = getConf(EXTERNAL_SORT, "false").toBoolean
 
+  /**
+   * Sort merge join would sort the two side of join first, and then iterate both sides together
+   * only once to get all matches. Using sort merge join can save a lot of memory usage compared
+   * to HashJoin.
+   */
+  private[spark] def sortMergeJoinEnabled: Boolean = getConf(SORTMERGE_JOIN, "false").toBoolean
+
   /**
    * When set to true, Spark SQL will use the Scala compiler at runtime to generate custom bytecode
    * that evaluates expressions found in queries. In general this custom code runs much faster
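For reference, turning the new code path on from a SQLContext looks roughly like this (a sketch; the query and table names are made up):

```scala
// Opt in to sort merge join; the default introduced by this patch is "false".
sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")

// Optionally also allow the Sort operators added by the planner to spill to disk.
sqlContext.setConf("spark.sql.planner.externalSort", "true")

// Subsequent equi-joins are planned with SortMergeJoin instead of ShuffledHashJoin.
val joined = sqlContext.sql(
  "SELECT * FROM orders o JOIN customers c ON o.customer_id = c.id")
```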

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
@@ -1081,7 +1081,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
   @transient
   protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
     val batches =
-      Batch("Add exchange", Once, AddExchange(self)) :: Nil
+      Batch("Add exchange", Once, EnsureRequirements(self)) :: Nil
   }
 
   protected[sql] def openSession(): SQLSession = {

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 120 additions & 28 deletions
@@ -19,24 +19,42 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions.{Attribute, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.util.MutablePair
 
+object Exchange {
+  /**
+   * Returns true when the ordering expressions are a subset of the key.
+   * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]].
+   */
+  def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
+    desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
+  }
+}
+
 /**
  * :: DeveloperApi ::
+ * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
+ * resulting partition based on expressions from the partition key. It is invalid to construct an
+ * exchange operator with a `newOrdering` that cannot be calculated using the partitioning key.
  */
 @DeveloperApi
-case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
+case class Exchange(
+    newPartitioning: Partitioning,
+    newOrdering: Seq[SortOrder],
+    child: SparkPlan)
+  extends UnaryNode {
 
   override def outputPartitioning: Partitioning = newPartitioning
 
+  override def outputOrdering: Seq[SortOrder] = newOrdering
+
   override def output: Seq[Attribute] = child.output
 
   /** We must copy rows when sort based shuffle is on */
@@ -45,6 +63,20 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
   private val bypassMergeThreshold =
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
 
+  private val keyOrdering = {
+    if (newOrdering.nonEmpty) {
+      val key = newPartitioning.keyExpressions
+      val boundOrdering = newOrdering.map { o =>
+        val ordinal = key.indexOf(o.child)
+        if (ordinal == -1) sys.error(s"Invalid ordering on $o requested for $newPartitioning")
+        o.copy(child = BoundReference(ordinal, o.child.dataType, o.child.nullable))
+      }
+      new RowOrdering(boundOrdering)
+    } else {
+      null // Ordering will not be used
+    }
+  }
+
   override def execute(): RDD[Row] = attachTree(this , "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
@@ -56,7 +88,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
         // we can avoid the defensive copies to improve performance. In the long run, we probably
         // want to include information in shuffle dependencies to indicate whether elements in the
         // source RDD should be copied.
-        val rdd = if (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) {
+        val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold
+
+        val rdd = if (willMergeSort || newOrdering.nonEmpty) {
           child.execute().mapPartitions { iter =>
             val hashExpressions = newMutableProjection(expressions, child.output)()
             iter.map(r => (hashExpressions(r).copy(), r.copy()))
@@ -69,12 +103,17 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
           }
         }
         val part = new HashPartitioner(numPartitions)
-        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Row, Row](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val rdd = if (sortBasedShuffleOn) {
+        val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
          child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
           child.execute().mapPartitions { iter =>
@@ -87,7 +126,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
         implicit val ordering = new RowOrdering(sortingExpressions, child.output)
 
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
+        val shuffled =
+          if (newOrdering.nonEmpty) {
+            new ShuffledRDD[Row, Null, Null](rdd, part).setKeyOrdering(keyOrdering)
+          } else {
+            new ShuffledRDD[Row, Null, Null](rdd, part)
+          }
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
 
         shuffled.map(_._1)
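To make the new `canSortWithShuffle` contract above concrete, a brief sketch with hypothetical attributes `a` and `b` (not from the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val b = AttributeReference("b", IntegerType)()
val partitioning = HashPartitioning(Seq(a), 200)

// Ordering only on the partitioning key: the shuffle itself can produce the order,
// so Exchange is built with a newOrdering and calls setKeyOrdering on the ShuffledRDD.
Exchange.canSortWithShuffle(partitioning, Seq(SortOrder(a, Ascending)))   // true

// Ordering on an expression outside the key: a separate Sort/ExternalSort is needed.
Exchange.canSortWithShuffle(partitioning, Seq(SortOrder(b, Ascending)))   // false
```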
@@ -120,27 +164,34 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {
  * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
  * of input data meets the
  * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
- * each operator by inserting [[Exchange]] Operators where required.
+ * each operator by inserting [[Exchange]] Operators where required. Also ensure that the
+ * required input partition ordering requirements are met.
  */
-private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
   def numPartitions: Int = sqlContext.conf.numShufflePartitions
 
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     case operator: SparkPlan =>
-      // Check if every child's outputPartitioning satisfies the corresponding
+      // True iff every child's outputPartitioning satisfies the corresponding
       // required data distribution.
       def meetsRequirements: Boolean =
-        !operator.requiredChildDistribution.zip(operator.children).map {
+        operator.requiredChildDistribution.zip(operator.children).forall {
           case (required, child) =>
             val valid = child.outputPartitioning.satisfies(required)
             logDebug(
               s"${if (valid) "Valid" else "Invalid"} distribution," +
                 s"required: $required current: ${child.outputPartitioning}")
             valid
-        }.exists(!_)
+        }
 
-      // Check if outputPartitionings of children are compatible with each other.
+      // True iff any of the children are incorrectly sorted.
+      def needsAnySort: Boolean =
+        operator.requiredChildOrdering.zip(operator.children).exists {
+          case (required, child) => required.nonEmpty && required != child.outputOrdering
+        }
+
+      // True iff outputPartitionings of children are compatible with each other.
       // It is possible that every child satisfies its required data distribution
       // but two children have incompatible outputPartitionings. For example,
       // A dataset is range partitioned by "a.asc" (RangePartitioning) and another
@@ -157,28 +208,69 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
         case Seq(a,b) => a compatibleWith b
       }.exists(!_)
 
-      // Check if the partitioning we want to ensure is the same as the child's output
-      // partitioning. If so, we do not need to add the Exchange operator.
-      def addExchangeIfNecessary(partitioning: Partitioning, child: SparkPlan): SparkPlan =
-        if (child.outputPartitioning != partitioning) Exchange(partitioning, child) else child
+      // Adds Exchange or Sort operators as required
+      def addOperatorsIfNecessary(
+          partitioning: Partitioning,
+          rowOrdering: Seq[SortOrder],
+          child: SparkPlan): SparkPlan = {
+        val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
+        val needsShuffle = child.outputPartitioning != partitioning
+        val canSortWithShuffle = Exchange.canSortWithShuffle(partitioning, rowOrdering)
+
+        if (needSort && needsShuffle && canSortWithShuffle) {
+          Exchange(partitioning, rowOrdering, child)
+        } else {
+          val withShuffle = if (needsShuffle) {
+            Exchange(partitioning, Nil, child)
+          } else {
+            child
+          }
 
-      if (meetsRequirements && compatible) {
+          val withSort = if (needSort) {
+            if (sqlContext.conf.externalSortEnabled) {
+              ExternalSort(rowOrdering, global = false, withShuffle)
+            } else {
+              Sort(rowOrdering, global = false, withShuffle)
+            }
+          } else {
+            withShuffle
+          }
+
+          withSort
+        }
+      }
+
+      if (meetsRequirements && compatible && !needsAnySort) {
        operator
       } else {
         // At least one child does not satisfies its required data distribution or
         // at least one child's outputPartitioning is not compatible with another child's
         // outputPartitioning. In this case, we need to add Exchange operators.
-        val repartitionedChildren = operator.requiredChildDistribution.zip(operator.children).map {
-          case (AllTuples, child) =>
-            addExchangeIfNecessary(SinglePartition, child)
-          case (ClusteredDistribution(clustering), child) =>
-            addExchangeIfNecessary(HashPartitioning(clustering, numPartitions), child)
-          case (OrderedDistribution(ordering), child) =>
-            addExchangeIfNecessary(RangePartitioning(ordering, numPartitions), child)
-          case (UnspecifiedDistribution, child) => child
-          case (dist, _) => sys.error(s"Don't know how to ensure $dist")
+        val requirements =
+          (operator.requiredChildDistribution, operator.requiredChildOrdering, operator.children)
+
+        val fixedChildren = requirements.zipped.map {
+          case (AllTuples, rowOrdering, child) =>
+            addOperatorsIfNecessary(SinglePartition, rowOrdering, child)
+          case (ClusteredDistribution(clustering), rowOrdering, child) =>
+            addOperatorsIfNecessary(HashPartitioning(clustering, numPartitions), rowOrdering, child)
+          case (OrderedDistribution(ordering), rowOrdering, child) =>
+            addOperatorsIfNecessary(RangePartitioning(ordering, numPartitions), rowOrdering, child)
+
+          case (UnspecifiedDistribution, Seq(), child) =>
+            child
+          case (UnspecifiedDistribution, rowOrdering, child) =>
+            if (sqlContext.conf.externalSortEnabled) {
+              ExternalSort(rowOrdering, global = false, child)
+            } else {
+              Sort(rowOrdering, global = false, child)
+            }
+
+          case (dist, ordering, _) =>
+            sys.error(s"Don't know how to ensure $dist with ordering $ordering")
         }
-        operator.withNewChildren(repartitionedChildren)
+
+        operator.withNewChildren(fixedChildren)
       }
   }
 }
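In plain terms, `addOperatorsIfNecessary` chooses one of three plan shapes per child. A condensed, self-contained restatement of that decision with the inputs reduced to booleans (illustration only; the real method builds SparkPlan nodes):

```scala
// Returns a description of the plan fragment EnsureRequirements would build for a child.
def planChild(needSort: Boolean, needsShuffle: Boolean, canSortWithShuffle: Boolean): String =
  if (needSort && needsShuffle && canSortWithShuffle) {
    // Sorting happens inside the shuffle via ShuffledRDD.setKeyOrdering.
    "Exchange(partitioning, rowOrdering, child)"
  } else {
    val shuffled = if (needsShuffle) "Exchange(partitioning, Nil, child)" else "child"
    // Otherwise a Sort (or ExternalSort, when spark.sql.planner.externalSort=true) goes on top.
    if (needSort) s"Sort(rowOrdering, global = false, $shuffled)" else shuffled
  }

// A sort merge join child that is already hash-partitioned on the join key but unsorted:
planChild(needSort = true, needsShuffle = false, canSortWithShuffle = true)
// => "Sort(rowOrdering, global = false, child)"
```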

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 6 additions & 0 deletions
@@ -72,6 +72,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
   def requiredChildDistribution: Seq[Distribution] =
     Seq.fill(children.size)(UnspecifiedDistribution)
 
+  /** Specifies how data is ordered in each partition. */
+  def outputOrdering: Seq[SortOrder] = Nil
+
+  /** Specifies sort order for each partition requirements on the input data for this operator. */
+  def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
+
   /**
    * Runs this query returning the result as an RDD.
    */
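For a binary operator that needs both inputs sorted within each partition (as the new sort merge join does), the per-child requirement is simply one `Seq[SortOrder]` per child. A small sketch of building such a value for hypothetical join keys (not code from the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.IntegerType

// Hypothetical join keys for the left and right children.
val leftKeys  = Seq(AttributeReference("k", IntegerType)())
val rightKeys = Seq(AttributeReference("k", IntegerType)())

// One Seq[SortOrder] per child: both inputs must arrive sorted ascending by their keys.
// EnsureRequirements satisfies this either inside the shuffle or with an added Sort node.
val requiredChildOrdering: Seq[Seq[SortOrder]] =
  Seq(leftKeys.map(SortOrder(_, Ascending)), rightKeys.map(SortOrder(_, Ascending)))
```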

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 10 additions & 1 deletion
@@ -90,6 +90,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold =>
         makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft)
 
+      // If the sort merge join option is set, we want to use sort merge join prior to hashjoin
+      // for now let's support inner join first, then add outer join
+      case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
+        if sqlContext.conf.sortMergeJoinEnabled =>
+        val mergeJoin =
+          joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
+        condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
+
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>
         val buildSide =
           if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
@@ -309,7 +317,8 @@
       case logical.OneRowRelation =>
         execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
       case logical.Repartition(expressions, child) =>
-        execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
+        execution.Exchange(
+          HashPartitioning(expressions, numPartitions), Nil, planLater(child)) :: Nil
       case e @ EvaluatePython(udf, child, _) =>
         BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil
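Taken together with the SQLConf change, the new strategy case means an inner equi-join is planned as a SortMergeJoin when the flag is on, with any remaining non-equi predicate kept as a Filter on top. A hypothetical query shape (illustration; table and column names made up):

```scala
sqlContext.setConf("spark.sql.planner.sortMergeJoin", "true")

// o.customer_id = c.id becomes the sort merge join keys; the non-equi predicate that
// references both sides stays in `condition`, so the strategy emits
// Filter(o.amount > c.credit_limit, SortMergeJoin(...)).
val df = sqlContext.sql(
  """SELECT o.id, c.name
    |FROM orders o JOIN customers c
    |  ON o.customer_id = c.id AND o.amount > c.credit_limit""".stripMargin)
```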
