@@ -26,6 +26,25 @@ import org.apache.spark.sql.catalyst.expressions._
26
26
import org .apache .spark .sql .catalyst .plans .logical .{Rollup , Cube , Aggregate }
27
27
import org .apache .spark .sql .types .NumericType
28
28
29
+ /**
30
+ * The Grouping Type
31
+ */
32
+ sealed private [sql] trait GroupType
33
+
34
+ /**
35
+ * To indicate it's the GroupBy
36
+ */
37
+ private [sql] object GroupByType extends GroupType
38
+
39
+ /**
40
+ * To indicate it's the CUBE
41
+ */
42
+ private [sql] object CubeType extends GroupType
43
+
44
+ /**
45
+ * To indicate it's the ROLLUP
46
+ */
47
+ private [sql] object RollupType extends GroupType
29
48
30
49
/**
31
50
* :: Experimental ::
@@ -34,10 +53,13 @@ import org.apache.spark.sql.types.NumericType
34
53
* @since 1.3.0
35
54
*/
36
55
@ Experimental
37
- class GroupedData protected [sql](df : DataFrame , groupingExprs : Seq [Expression ]) {
56
+ class GroupedData protected [sql](
57
+ df : DataFrame ,
58
+ groupingExprs : Seq [Expression ],
59
+ groupType : GroupType ) {
38
60
39
61
protected def aggregateExpressions (aggrExprs : Seq [NamedExpression ])
40
- : Seq [NamedExpression ] = {
62
+ : Seq [NamedExpression ] = {
41
63
if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
42
64
val retainedExprs = groupingExprs.map {
43
65
case expr : NamedExpression => expr
@@ -50,8 +72,17 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
50
72
}
51
73
52
74
protected [sql] implicit def toDF (aggExprs : Seq [NamedExpression ]): DataFrame = {
53
- DataFrame (
54
- df.sqlContext, Aggregate (groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan))
75
+ groupType match {
76
+ case GroupByType =>
77
+ DataFrame (
78
+ df.sqlContext, Aggregate (groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan))
79
+ case RollupType =>
80
+ DataFrame (
81
+ df.sqlContext, Rollup (groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
82
+ case CubeType =>
83
+ DataFrame (
84
+ df.sqlContext, Cube (groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
85
+ }
55
86
}
56
87
57
88
private [this ] def aggregateNumericColumns (colNames : String * )(f : Expression => Expression )
@@ -259,27 +290,3 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
259
290
}
260
291
261
292
}
262
-
263
- /**
264
- * A set of methods for aggregations on a [[DataFrame ]] cube, created by [[DataFrame.cube ]].
265
- */
266
- private [sql] class CubedData protected [sql](df : DataFrame , groupingExprs : Seq [Expression ])
267
- extends GroupedData (df, groupingExprs) {
268
-
269
- protected [sql] implicit override def toDF (aggExprs : Seq [NamedExpression ]): DataFrame = {
270
- DataFrame (
271
- df.sqlContext, Cube (groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
272
- }
273
- }
274
-
275
- /**
276
- * A set of methods for aggregations on a [[DataFrame ]] rollup, created by [[DataFrame.rollup ]].
277
- */
278
- private [sql] class RollupedData protected [sql](df : DataFrame , groupingExprs : Seq [Expression ])
279
- extends GroupedData (df, groupingExprs) {
280
-
281
- protected [sql] implicit override def toDF (aggExprs : Seq [NamedExpression ]): DataFrame = {
282
- DataFrame (
283
- df.sqlContext, Rollup (groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
284
- }
285
- }
0 commit comments