@@ -23,9 +23,40 @@ import scala.language.implicitConversions
23
23
import org .apache .spark .annotation .Experimental
24
24
import org .apache .spark .sql .catalyst .analysis .Star
25
25
import org .apache .spark .sql .catalyst .expressions ._
26
- import org .apache .spark .sql .catalyst .plans .logical .Aggregate
26
+ import org .apache .spark .sql .catalyst .plans .logical .{ Rollup , Cube , Aggregate }
27
27
import org .apache .spark .sql .types .NumericType
28
28
29
/**
 * Companion object for [[GroupedData]].
 */
private[sql] object GroupedData {

  /**
   * Creates a [[GroupedData]] over `df`, grouped by `groupingExprs` and
   * aggregated according to `groupType` (GROUP BY, CUBE, or ROLLUP).
   */
  def apply(
      df: DataFrame,
      groupingExprs: Seq[Expression],
      groupType: GroupType): GroupedData = {
    // `groupType` already has static type `GroupType`; no ascription needed.
    new GroupedData(df, groupingExprs, groupType)
  }

  /**
   * The Grouping Type.
   *
   * Sealed so the compiler checks `match`es over it for exhaustiveness;
   * all implementations live in this object.
   */
  sealed trait GroupType

  /**
   * To indicate it's the GroupBy
   */
  object GroupByType extends GroupType

  /**
   * To indicate it's the CUBE
   */
  object CubeType extends GroupType

  /**
   * To indicate it's the ROLLUP
   */
  object RollupType extends GroupType
}
29
60
30
61
/**
31
62
* :: Experimental ::
@@ -34,19 +65,37 @@ import org.apache.spark.sql.types.NumericType
34
65
* @since 1.3.0
35
66
*/
36
67
@ Experimental
37
- class GroupedData protected [sql](df : DataFrame , groupingExprs : Seq [Expression ]) {
68
+ class GroupedData protected [sql](
69
+ df : DataFrame ,
70
+ groupingExprs : Seq [Expression ],
71
+ private val groupType : GroupedData .GroupType ) {
38
72
39
- private [sql] implicit def toDF (aggExprs : Seq [NamedExpression ]): DataFrame = {
40
- val namedGroupingExprs = groupingExprs.map {
41
- case expr : NamedExpression => expr
42
- case expr : Expression => Alias (expr, expr.prettyString)()
73
+ private [this ] def toDF (aggExprs : Seq [NamedExpression ]): DataFrame = {
74
+ val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
75
+ val retainedExprs = groupingExprs.map {
76
+ case expr : NamedExpression => expr
77
+ case expr : Expression => Alias (expr, expr.prettyString)()
78
+ }
79
+ retainedExprs ++ aggExprs
80
+ } else {
81
+ aggExprs
82
+ }
83
+
84
+ groupType match {
85
+ case GroupedData .GroupByType =>
86
+ DataFrame (
87
+ df.sqlContext, Aggregate (groupingExprs, aggregates, df.logicalPlan))
88
+ case GroupedData .RollupType =>
89
+ DataFrame (
90
+ df.sqlContext, Rollup (groupingExprs, df.logicalPlan, aggregates))
91
+ case GroupedData .CubeType =>
92
+ DataFrame (
93
+ df.sqlContext, Cube (groupingExprs, df.logicalPlan, aggregates))
43
94
}
44
- DataFrame (
45
- df.sqlContext, Aggregate (groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
46
95
}
47
96
48
97
private [this ] def aggregateNumericColumns (colNames : String * )(f : Expression => Expression )
49
- : Seq [ NamedExpression ] = {
98
+ : DataFrame = {
50
99
51
100
val columnExprs = if (colNames.isEmpty) {
52
101
// No columns specified. Use all numeric columns.
@@ -63,10 +112,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
63
112
namedExpr
64
113
}
65
114
}
66
- columnExprs.map { c =>
115
+ toDF( columnExprs.map { c =>
67
116
val a = f(c)
68
117
Alias (a, a.prettyString)()
69
- }
118
+ })
70
119
}
71
120
72
121
private [this ] def strToExpr (expr : String ): (Expression => Expression ) = {
@@ -119,10 +168,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
119
168
* @since 1.3.0
120
169
*/
121
170
def agg (exprs : Map [String , String ]): DataFrame = {
122
- exprs.map { case (colName, expr) =>
171
+ toDF( exprs.map { case (colName, expr) =>
123
172
val a = strToExpr(expr)(df(colName).expr)
124
173
Alias (a, a.prettyString)()
125
- }.toSeq
174
+ }.toSeq)
126
175
}
127
176
128
177
/**
@@ -175,19 +224,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
175
224
*/
176
225
@ scala.annotation.varargs
177
226
def agg (expr : Column , exprs : Column * ): DataFrame = {
178
- val aggExprs = (expr +: exprs).map(_.expr).map {
227
+ toDF( (expr +: exprs).map(_.expr).map {
179
228
case expr : NamedExpression => expr
180
229
case expr : Expression => Alias (expr, expr.prettyString)()
181
- }
182
- if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
183
- val retainedExprs = groupingExprs.map {
184
- case expr : NamedExpression => expr
185
- case expr : Expression => Alias (expr, expr.prettyString)()
186
- }
187
- DataFrame (df.sqlContext, Aggregate (groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
188
- } else {
189
- DataFrame (df.sqlContext, Aggregate (groupingExprs, aggExprs, df.logicalPlan))
190
- }
230
+ })
191
231
}
192
232
193
233
/**
@@ -196,7 +236,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
196
236
*
197
237
* @since 1.3.0
198
238
*/
199
- def count (): DataFrame = Seq (Alias (Count (Literal (1 )), " count" )())
239
+ def count (): DataFrame = toDF( Seq (Alias (Count (Literal (1 )), " count" )() ))
200
240
201
241
/**
202
242
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
@@ -256,5 +296,5 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
256
296
@ scala.annotation.varargs
257
297
def sum (colNames : String * ): DataFrame = {
258
298
aggregateNumericColumns(colNames:_* )(Sum )
259
- }
299
+ }
260
300
}
0 commit comments