
Commit 10698e1

chenghao-intel authored and liancheng committed
[SPARK-7320] [SQL] Add Cube / Rollup for dataframe
Add `cube` & `rollup` for DataFrame. For example:

```scala
testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b"))
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b"))
```

Author: Cheng Hao <[email protected]>

Closes apache#6257 from chenghao-intel/rollup and squashes the following commits:

7302319 [Cheng Hao] cancel the implicit keyword
a66e38f [Cheng Hao] remove the unnecessary code changes
a2869d4 [Cheng Hao] update the code as comments
c441777 [Cheng Hao] update the code as suggested
84c9564 [Cheng Hao] Remove the CubedData & RollupedData
279584c [Cheng Hao] hiden the CubedData & RollupedData
ef357e1 [Cheng Hao] Add Cube / Rollup for dataframe

(cherry picked from commit 09265ad)
Signed-off-by: Cheng Lian <[email protected]>
1 parent 8689339 commit 10698e1
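For readers skimming the commit: `rollup(a, b)` aggregates over the grouping sets {(a, b), (a), ()}, while `cube(a, b)` covers every subset of the grouping columns. A minimal sketch of the new API, assuming a Spark 1.4 shell with `sqlContext.implicits._` and `org.apache.spark.sql.functions._` in scope; the data and column names are illustrative, not taken from the patch:

```scala
import org.apache.spark.sql.functions._   // sum, etc.
import sqlContext.implicits._             // $"..." column syntax, toDF on Seq

// Illustrative data; any DataFrame works.
val df = Seq(("eng", "m", 100), ("eng", "f", 120), ("ops", "m", 80))
  .toDF("department", "gender", "salary")

// rollup: one row per (department, gender), a subtotal per department
// (gender = null), and a grand total (both null).
df.rollup($"department", $"gender").agg(sum($"salary")).show()

// cube additionally emits the per-gender subtotals (department = null).
df.cube($"department", $"gender").agg(sum($"salary")).show()
```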

File tree

3 files changed (+230, -28 lines)

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 102 additions & 2 deletions
```diff
@@ -685,7 +685,53 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr))
+  def groupBy(cols: Column*): GroupedData = {
+    GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
+  }
+
+  /**
+   * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
+   * See [[GroupedData]] for all the available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns rolluped by department and group.
+   *   df.rollup($"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, rolluped by department and gender.
+   *   df.rollup($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group dfops
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def rollup(cols: Column*): GroupedData = {
+    GroupedData(this, cols.map(_.expr), GroupedData.RollupType)
+  }
+
+  /**
+   * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
+   * See [[GroupedData]] for all the available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns cubed by department and group.
+   *   df.cube($"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, cubed by department and gender.
+   *   df.cube($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group dfops
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType)
 
   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -710,7 +756,61 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def groupBy(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    new GroupedData(this, colNames.map(colName => resolve(colName)))
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+  }
+
+  /**
+   * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
+   * See [[GroupedData]] for all the available aggregate functions.
+   *
+   * This is a variant of rollup that can only group by existing columns using column names
+   * (i.e. cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns rolluped by department and group.
+   *   df.rollup("department", "group").avg()
+   *
+   *   // Compute the max age and average salary, rolluped by department and gender.
+   *   df.rollup($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group dfops
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def rollup(col1: String, cols: String*): GroupedData = {
+    val colNames: Seq[String] = col1 +: cols
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType)
+  }
+
+  /**
+   * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
+   * See [[GroupedData]] for all the available aggregate functions.
+   *
+   * This is a variant of cube that can only group by existing columns using column names
+   * (i.e. cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns cubed by department and group.
+   *   df.cube("department", "group").avg()
+   *
+   *   // Compute the max age and average salary, cubed by department and gender.
+   *   df.cube($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group dfops
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def cube(col1: String, cols: String*): GroupedData = {
+    val colNames: Seq[String] = col1 +: cols
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
   }
 
   /**
```
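Both new operators mirror the two `groupBy` overloads: a `Column*` form that accepts arbitrary expressions and a `String*` form restricted to existing column names. A hedged sketch of the difference, reusing the hypothetical `df` from the example above:

```scala
// String form: grouping keys must be existing column names,
// resolved against the schema.
df.rollup("department", "gender").agg(Map("salary" -> "avg"))

// Column form: arbitrary expressions are allowed as grouping keys,
// as in the test suite's rollup($"a" + $"b", $"b") below.
df.cube($"department", $"salary" % 100).agg(sum($"salary"))
```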

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 66 additions & 26 deletions
```diff
@@ -23,9 +23,40 @@ import scala.language.implicitConversions
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
 import org.apache.spark.sql.types.NumericType
 
+/**
+ * Companion object for GroupedData
+ */
+private[sql] object GroupedData {
+  def apply(
+      df: DataFrame,
+      groupingExprs: Seq[Expression],
+      groupType: GroupType): GroupedData = {
+    new GroupedData(df, groupingExprs, groupType: GroupType)
+  }
+
+  /**
+   * The Grouping Type
+   */
+  trait GroupType
+
+  /**
+   * To indicate it's the GroupBy
+   */
+  object GroupByType extends GroupType
+
+  /**
+   * To indicate it's the CUBE
+   */
+  object CubeType extends GroupType
+
+  /**
+   * To indicate it's the ROLLUP
+   */
+  object RollupType extends GroupType
+}
 
 /**
@@ -34,19 +65,37 @@ import org.apache.spark.sql.types.NumericType
  * @since 1.3.0
  */
 @Experimental
-class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](
+    df: DataFrame,
+    groupingExprs: Seq[Expression],
+    private val groupType: GroupedData.GroupType) {
 
-  private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
-    val namedGroupingExprs = groupingExprs.map {
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
+  private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+    val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      retainedExprs ++ aggExprs
+    } else {
+      aggExprs
+    }
+
+    groupType match {
+      case GroupedData.GroupByType =>
+        DataFrame(
+          df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+      case GroupedData.RollupType =>
+        DataFrame(
+          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+      case GroupedData.CubeType =>
+        DataFrame(
+          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
     }
-    DataFrame(
-      df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
   }
 
   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
-  : Seq[NamedExpression] = {
+  : DataFrame = {
 
     val columnExprs = if (colNames.isEmpty) {
       // No columns specified. Use all numeric columns.
@@ -63,10 +112,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
         namedExpr
       }
     }
-    columnExprs.map { c =>
+    toDF(columnExprs.map { c =>
       val a = f(c)
       Alias(a, a.prettyString)()
-    }
+    })
   }
 
   private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -119,10 +168,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    * @since 1.3.0
    */
   def agg(exprs: Map[String, String]): DataFrame = {
-    exprs.map { case (colName, expr) =>
+    toDF(exprs.map { case (colName, expr) =>
       val a = strToExpr(expr)(df(colName).expr)
       Alias(a, a.prettyString)()
-    }.toSeq
+    }.toSeq)
   }
 
   /**
@@ -175,19 +224,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    */
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame = {
-    val aggExprs = (expr +: exprs).map(_.expr).map {
+    toDF((expr +: exprs).map(_.expr).map {
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
-    }
-    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
-      val retainedExprs = groupingExprs.map {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.prettyString)()
-      }
-      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
-    } else {
-      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
-    }
+    })
   }
 
   /**
@@ -196,7 +236,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    *
    * @since 1.3.0
    */
-  def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")())
+  def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")()))
 
   /**
    * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
@@ -256,5 +296,5 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
     aggregateNumericColumns(colNames:_*)(Sum)
-  }
+  }
 }
```
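Every aggregation entry point (`agg`, `count`, and the numeric-column helpers such as `sum`) now funnels through the single private `toDF`, which pattern-matches on the grouping type to choose the logical operator. A self-contained toy model of that dispatch, not Spark code (`sealed` is added here for exhaustiveness checking; the patch uses a plain `trait`):

```scala
// Toy model of the GroupType dispatch in toDF; names mirror the patch.
sealed trait GroupType
case object GroupByType extends GroupType
case object RollupType extends GroupType
case object CubeType extends GroupType

// Stand-in for building the logical plan; the real code constructs
// Aggregate, Rollup, or Cube operators rather than strings.
def buildPlan(groupType: GroupType, keys: Seq[String], aggs: Seq[String]): String =
  groupType match {
    case GroupByType => s"Aggregate($keys, $aggs)"
    case RollupType  => s"Rollup($keys, $aggs)"
    case CubeType    => s"Cube($keys, $aggs)"
  }
```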
sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala

Lines changed: 62 additions & 0 deletions

```diff
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+case class TestData2Int(a: Int, b: Int)
+
+// TODO ideally we should put the test suite into the package `sql`, as
+// `hive` package is optional in compiling, however, `SQLContext.sql` doesn't
+// support the `cube` or `rollup` yet.
+class HiveDataFrameAnalyticsSuite extends QueryTest {
+  val testData =
+    TestHive.sparkContext.parallelize(
+      TestData2Int(1, 2) ::
+      TestData2Int(2, 4) :: Nil).toDF()
+
+  testData.registerTempTable("mytable")
+
+  test("rollup") {
+    checkAnswer(
+      testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect()
+    )
+
+    checkAnswer(
+      testData.rollup("a", "b").agg(sum("b")),
+      sql("select a, b, sum(b) from mytable group by a, b with rollup").collect()
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect()
+    )
+
+    checkAnswer(
+      testData.cube("a", "b").agg(sum("b")),
+      sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
+    )
+  }
+}
```
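With only two input rows, the expected rollup result can be worked out by hand. A sketch of the rows the first `checkAnswer` should see, derived from the test data above rather than copied from the patch (`None` marks a rolled-up grouping level):

```scala
// testData rows: (a=1, b=2) and (a=2, b=4), so a - b is -1 and -2.
// rollup($"a" + $"b", $"b") aggregates over the sets {(a+b, b), (a+b), ()}.
val expectedRollup = Seq(
  (Some(3), Some(2), -1), // group (a+b = 3, b = 2)
  (Some(6), Some(4), -2), // group (a+b = 6, b = 4)
  (Some(3), None, -1),    // subtotal for a+b = 3
  (Some(6), None, -2),    // subtotal for a+b = 6
  (None, None, -3)        // grand total of sum(a - b)
)
// cube adds the per-b subtotals (None, Some(2), -1) and (None, Some(4), -2).
```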
