
Commit e1d88aa

update the code as suggested
1 parent 03bc3d9 commit e1d88aa

Showing 3 changed files with 69 additions and 64 deletions.

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 20 additions & 12 deletions
```diff
@@ -685,10 +685,13 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), GroupByType)
+  def groupBy(cols: Column*): GroupedData = {
+    GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
+  }

   /**
-   * Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+   * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
    * See [[GroupedData]] for all the available aggregate functions.
    *
    * {{{
@@ -705,10 +708,13 @@ class DataFrame private[sql](
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def rollup(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), RollupType)
+  def rollup(cols: Column*): GroupedData = {
+    GroupedData(this, cols.map(_.expr), GroupedData.RollupType)
+  }

   /**
-   * Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+   * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
    * See [[GroupedData]] for all the available aggregate functions.
    *
    * {{{
@@ -725,7 +731,7 @@ class DataFrame private[sql](
    * @since 1.4.0
    */
   @scala.annotation.varargs
-  def cube(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr), CubeType)
+  def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType)

   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -750,14 +756,15 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def groupBy(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    new GroupedData(this, colNames.map(colName => resolve(colName)), GroupByType)
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
   }

   /**
-   * Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+   * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
    * See [[GroupedData]] for all the available aggregate functions.
    *
-   * This is a variant of groupBy that can only group by existing columns using column names
+   * This is a variant of rollup that can only group by existing columns using column names
    * (i.e. cannot construct expressions).
    *
    * {{{
@@ -776,14 +783,15 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def rollup(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    new GroupedData(this, colNames.map(colName => resolve(colName)), RollupType)
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType)
   }

   /**
-   * Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them.
+   * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
    * See [[GroupedData]] for all the available aggregate functions.
    *
-   * This is a variant of groupBy that can only group by existing columns using column names
+   * This is a variant of cube that can only group by existing columns using column names
    * (i.e. cannot construct expressions).
    *
    * {{{
@@ -802,7 +810,7 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def cube(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    new GroupedData(this, colNames.map(colName => resolve(colName)), CubeType)
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
   }

   /**
```
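For orientation, here is a minimal, hypothetical driver showing how the three entry points touched above are called; the `Record` case class, the app name, and the `local[2]` master are illustrative assumptions, not part of this commit:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Hypothetical input type for the sketch below.
case class Record(a: Int, b: Int)

object GroupingExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("grouping-example").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = sc.parallelize(Seq(Record(1, 2), Record(1, 3), Record(2, 4))).toDF()

    // groupBy: one output row per (a, b) combination present in the data.
    df.groupBy("a", "b").count().show()

    // rollup: grouping sets (a, b), (a) and (), i.e. per-(a, b) counts,
    // per-a subtotals (b is null), and a grand-total row (both null).
    df.rollup("a", "b").count().show()

    // cube: all combinations of the grouping columns,
    // i.e. grouping sets (a, b), (a), (b) and ().
    df.cube("a", "b").count().show()

    sc.stop()
  }
}
```

Either overload behaves the same way; per the scaladoc above, the `Column*` variants (e.g. `df.rollup($"a", $"b")`) also accept constructed expressions such as `$"a" + $"b"`, which the `String*` variants deliberately do not.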

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 45 additions & 33 deletions
```diff
@@ -32,19 +32,31 @@ import org.apache.spark.sql.types.NumericType
 sealed private[sql] trait GroupType

 /**
- * To indicate it's the GroupBy
+ * Companion object for GroupedData
  */
-private[sql] object GroupByType extends GroupType
+private[sql] object GroupedData {
+  def apply(
+      df: DataFrame,
+      groupingExprs: Seq[Expression],
+      groupType: GroupType): GroupedData = {
+    new GroupedData(df, groupingExprs).withNewGroupType(groupType)
+  }

-/**
- * To indicate it's the CUBE
- */
-private[sql] object CubeType extends GroupType
+  /**
+   * To indicate it's the GroupBy
+   */
+  private[sql] object GroupByType extends GroupType

-/**
- * To indicate it's the ROLLUP
- */
-private[sql] object RollupType extends GroupType
+  /**
+   * To indicate it's the CUBE
+   */
+  private[sql] object CubeType extends GroupType
+
+  /**
+   * To indicate it's the ROLLUP
+   */
+  private[sql] object RollupType extends GroupType
+}

 /**
  * :: Experimental ::
@@ -53,35 +65,36 @@ private[sql] object RollupType extends GroupType
  * @since 1.3.0
  */
 @Experimental
-class GroupedData protected[sql](
-    df: DataFrame,
-    groupingExprs: Seq[Expression],
-    groupType: GroupType) {
+class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {

-  protected def aggregateExpressions(aggrExprs: Seq[NamedExpression])
-    : Seq[NamedExpression] = {
-    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
-      val retainedExprs = groupingExprs.map {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.prettyString)()
-      }
-      retainedExprs ++ aggrExprs
-    } else {
-      aggrExprs
-    }
+  private var groupType: GroupType = _
+
+  private[sql] def withNewGroupType(groupType: GroupType): GroupedData = {
+    this.groupType = groupType
+    this
   }

-  protected[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+  private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+    val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      retainedExprs ++ aggExprs
+    } else {
+      aggExprs
+    }
+
     groupType match {
-      case GroupByType =>
+      case GroupedData.GroupByType =>
         DataFrame(
-          df.sqlContext, Aggregate(groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan))
-      case RollupType =>
+          df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+      case GroupedData.RollupType =>
         DataFrame(
-          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
-      case CubeType =>
+          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+      case GroupedData.CubeType =>
         DataFrame(
-          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
+          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
     }
   }

@@ -288,5 +301,4 @@ class GroupedData protected[sql](
   def sum(colNames: String*): DataFrame = {
     aggregateNumericColumns(colNames:_*)(Sum)
   }
-
 }
```
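The shape of this refactoring, in isolation: the companion object's `apply` builds the instance and then attaches the group type through a one-shot builder-style mutator, so the `protected[sql]` constructor no longer carries a `groupType` parameter. A self-contained sketch of the same pattern, with `Widget` and `Mode` as hypothetical stand-ins for `GroupedData` and `GroupType`:

```scala
sealed trait Mode
case object Add extends Mode
case object Multiply extends Mode

object Widget {
  // Companion factory, mirroring GroupedData.apply: construct first,
  // then attach the mode via the one-shot mutator.
  def apply(payload: Seq[Int], mode: Mode): Widget =
    new Widget(payload).withNewMode(mode)
}

class Widget private (payload: Seq[Int]) {
  // Mirrors `private var groupType: GroupType = _`: assigned exactly once,
  // by the companion factory, before the instance escapes to callers.
  private var mode: Mode = _

  // The companion can call this despite `private`, because a Scala class
  // and its companion object share private access.
  private def withNewMode(m: Mode): Widget = {
    this.mode = m
    this
  }

  // Mirrors the `groupType match { ... }` dispatch in toDF.
  def run(): Int = mode match {
    case Add      => payload.sum
    case Multiply => payload.product
  }
}

object WidgetDemo extends App {
  println(Widget(Seq(1, 2, 4), Add).run())      // 7
  println(Widget(Seq(1, 2, 4), Multiply).run()) // 8
}
```

The trade-off, as in the commit, is an uninitialized `var` on a class whose instances are otherwise immutable; it stays safe only as long as every construction path goes through the factory.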

sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala

Lines changed: 4 additions & 19 deletions
```diff
@@ -25,7 +25,10 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._

 case class TestData2Int(a: Int, b: Int)

-class HiveDataFrameAnalyticsSuiteSuite extends QueryTest {
+// TODO ideally we should put the test suite into the package `sql`, as
+// `hive` package is optional in compiling, however, `SQLContext.sql` doesn't
+// support the `cube` or `rollup` yet.
+class HiveDataFrameAnalyticsSuite extends QueryTest {
   val testData =
     TestHive.sparkContext.parallelize(
       TestData2Int(1, 2) ::
@@ -56,22 +59,4 @@ class HiveDataFrameAnalyticsSuiteSuite extends QueryTest {
       sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
     )
   }
-
-  test("spark.sql.retainGroupColumns config") {
-    val oldConf = conf.getConf("spark.sql.retainGroupColumns", "true")
-    try {
-      conf.setConf("spark.sql.retainGroupColumns", "false")
-      checkAnswer(
-        testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
-        sql("select sum(a-b) from mytable group by a + b, b with rollup").collect()
-      )
-
-      checkAnswer(
-        testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
-        sql("select sum(a-b) from mytable group by a + b, b with cube").collect()
-      )
-    } finally {
-      conf.setConf("spark.sql.retainGroupColumns", oldConf)
-    }
-  }
 }
```
