Commit 87b7d37

Use the proper serializer in limit.

1 parent 9b79246

File tree

1 file changed (+16, -8 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 16 additions & 8 deletions
@@ -20,13 +20,14 @@ package execution
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
   override def output = projectList.map(_.toAttribute)

@@ -70,17 +71,24 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
  * data to a single partition to compute the global limit.
  */
 case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+  // partition local limit -> exchange into one partition -> partition local limit again
+
   override def otherCopyArgs = sc :: Nil
 
   override def output = child.output
 
   override def executeCollect() = child.execute().map(_.copy()).take(limit)
 
   override def execute() = {
-    child.execute()
-      .mapPartitions(_.take(limit).map(_.copy()))
-      .coalesce(1, shuffle = true)
-      .mapPartitions(_.take(limit))
+    val rdd = child.execute().mapPartitions { iter =>
+      val mutablePair = new MutablePair[Boolean, Row]()
+      iter.take(limit).map(row => mutablePair.update(false, row))
+    }
+    val part = new HashPartitioner(1)
+    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.mapPartitions(_.take(limit).map(_._2))
   }
 }
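The new execute() takes at most `limit` rows in each partition, tags the survivors with a dummy Boolean key (reusing a single MutablePair per partition to avoid allocating a tuple per row), shuffles them into one partition via a ShuffledRDD whose serializer is explicitly set to SparkSqlSerializer, and then applies the limit once more. Below is a minimal standalone sketch of the same local-then-global limit pattern using only the public RDD API, without the SQL rows, MutablePair reuse, or custom serializer from the commit; it assumes Spark 1.3+ (pair-RDD implicits in scope automatically), and the object name, master, and input data are illustrative only.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

// Hypothetical standalone example; not part of the commit above.
object LocalThenGlobalLimit {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("limit-sketch"))
    val limit = 5
    val data = sc.parallelize(1 to 100, numSlices = 4)

    val limited = data
      .mapPartitions(_.take(limit))        // partition-local limit
      .map(row => (false, row))            // dummy key so every row hashes to the same partition
      .partitionBy(new HashPartitioner(1)) // exchange into a single partition
      .values
      .mapPartitions(_.take(limit))        // global limit on the single remaining partition

    println(limited.collect().mkString(", "))
    sc.stop()
  }
}

Shuffling with a dummy key into one partition does the same job as the old coalesce(1, shuffle = true), but building the ShuffledRDD directly, as the commit does, is what allows setSerializer to plug in the SQL-specific serializer for the exchanged rows.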
