@@ -20,13 +20,14 @@ package execution
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext
-
+import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.util.MutablePair
+
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode {
   override def output = projectList.map(_.toAttribute)
@@ -70,17 +71,24 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
  * data to a single partition to compute the global limit.
  */
 case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
+  // partition local limit -> exchange into one partition -> partition local limit again
+
   override def otherCopyArgs = sc :: Nil
 
   override def output = child.output
 
   override def executeCollect() = child.execute().map(_.copy()).take(limit)
 
   override def execute() = {
-    child.execute()
-      .mapPartitions(_.take(limit).map(_.copy()))
-      .coalesce(1, shuffle = true)
-      .mapPartitions(_.take(limit))
+    val rdd = child.execute().mapPartitions { iter =>
+      val mutablePair = new MutablePair[Boolean, Row]()
+      iter.take(limit).map(row => mutablePair.update(false, row))
+    }
+    val part = new HashPartitioner(1)
+    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
+    shuffled.mapPartitions(_.take(limit).map(_._2))
   }
 }
 
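For context, the rewritten execute() applies the limit in two phases: each partition keeps at most limit rows, every surviving row is paired with a constant key and shuffled into a single partition via HashPartitioner(1), and the limit is applied once more on that partition. The production code reuses one MutablePair per partition to avoid allocating a tuple per row and routes the shuffle through SparkSqlSerializer, which is why it constructs the ShuffledRDD directly. The sketch below shows the same two-phase pattern on a plain RDD using the stable pair-RDD API (partitionBy) instead; the object name, master setting, and sample data are illustrative only, not part of this change.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // pair-RDD implicits on Spark 1.x of this era

// Illustrative sketch: two-phase limit on a plain RDD, mirroring the shape of
// Limit.execute() above without the SQL-specific Row, MutablePair, and
// SparkSqlSerializer pieces.
object LimitSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("limit-sketch"))
    val limit = 5
    val data = sc.parallelize(1 to 1000, 8)

    val limited = data
      .mapPartitions(_.take(limit))            // phase 1: partition-local limit
      .map(row => (false, row))                // constant key => all rows hash to one partition
      .partitionBy(new HashPartitioner(1))     // shuffle the few surviving rows together
      .mapPartitions(_.take(limit).map(_._2))  // phase 2: global limit on the single partition

    println(limited.collect().toSeq)
    sc.stop()
  }
}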