Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit ef357e1

Browse files
Add Cube / Rollup for dataframe
1 parent c9fa870 commit ef357e1

File tree

3 files changed

+219
-17
lines changed

3 files changed

+219
-17
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,6 +687,46 @@ class DataFrame private[sql](
687687
@scala.annotation.varargs
688688
def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr))
689689

690+
/**
691+
* Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them.
692+
* See [[GroupedData]] for all the available aggregate functions.
693+
*
694+
* {{{
695+
* // Compute the average for all numeric columns rolled up by department and group.
696+
* df.rollup($"department", $"group").avg()
697+
*
698+
* // Compute the max age and average salary, rolled up by department and gender.
699+
* df.rollup($"department", $"gender").agg(Map(
700+
* "salary" -> "avg",
701+
* "age" -> "max"
702+
* ))
703+
* }}}
704+
* @group dfops
705+
* @since 1.4.0
706+
*/
707+
@scala.annotation.varargs
708+
def rollup(cols: Column*): GroupedData = new RollupedData(this, cols.map(_.expr))
709+
710+
/**
711+
* Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them.
712+
* See [[GroupedData]] for all the available aggregate functions.
713+
*
714+
* {{{
715+
* // Compute the average for all numeric columns cubed by department and group.
716+
* df.cube($"department", $"group").avg()
717+
*
718+
* // Compute the max age and average salary, cubed by department and gender.
719+
* df.cube($"department", $"gender").agg(Map(
720+
* "salary" -> "avg",
721+
* "age" -> "max"
722+
* ))
723+
* }}}
724+
* @group dfops
725+
* @since 1.4.0
726+
*/
727+
@scala.annotation.varargs
728+
def cube(cols: Column*): GroupedData = new CubedData(this, cols.map(_.expr))
729+
690730
/**
691731
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
692732
* See [[GroupedData]] for all the available aggregate functions.
@@ -713,6 +753,58 @@ class DataFrame private[sql](
713753
new GroupedData(this, colNames.map(colName => resolve(colName)))
714754
}
715755

756+
/**
757+
* Rollup the [[DataFrame]] using the specified columns, so we can run aggregation on them.
758+
* See [[GroupedData]] for all the available aggregate functions.
759+
*
760+
* This is a variant of rollup that can only group by existing columns using column names
761+
* (i.e. cannot construct expressions).
762+
*
763+
* {{{
764+
* // Compute the average for all numeric columns rolled up by department and group.
765+
* df.rollup("department", "group").avg()
766+
*
767+
* // Compute the max age and average salary, rolled up by department and gender.
768+
* df.rollup("department", "gender").agg(Map(
769+
* "salary" -> "avg",
770+
* "age" -> "max"
771+
* ))
772+
* }}}
773+
* @group dfops
774+
* @since 1.4.0
775+
*/
776+
@scala.annotation.varargs
777+
def rollup(col1: String, cols: String*): GroupedData = {
778+
val colNames: Seq[String] = col1 +: cols
779+
new RollupedData(this, colNames.map(colName => resolve(colName)))
780+
}
781+
782+
/**
783+
* Cube the [[DataFrame]] using the specified columns, so we can run aggregation on them.
784+
* See [[GroupedData]] for all the available aggregate functions.
785+
*
786+
* This is a variant of cube that can only group by existing columns using column names
787+
* (i.e. cannot construct expressions).
788+
*
789+
* {{{
790+
* // Compute the average for all numeric columns cubed by department and group.
791+
* df.cube("department", "group").avg()
792+
*
793+
* // Compute the max age and average salary, cubed by department and gender.
794+
* df.cube("department", "gender").agg(Map(
795+
* "salary" -> "avg",
796+
* "age" -> "max"
797+
* ))
798+
* }}}
799+
* @group dfops
800+
* @since 1.4.0
801+
*/
802+
@scala.annotation.varargs
803+
def cube(col1: String, cols: String*): GroupedData = {
804+
val colNames: Seq[String] = col1 +: cols
805+
new CubedData(this, colNames.map(colName => resolve(colName)))
806+
}
807+
716808
/**
717809
* (Scala-specific) Aggregates on the entire [[DataFrame]] without groups.
718810
* {{{

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.language.implicitConversions
2323
import org.apache.spark.annotation.Experimental
2424
import org.apache.spark.sql.catalyst.analysis.Star
2525
import org.apache.spark.sql.catalyst.expressions._
26-
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
26+
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
2727
import org.apache.spark.sql.types.NumericType
2828

2929

@@ -36,13 +36,22 @@ import org.apache.spark.sql.types.NumericType
3636
@Experimental
3737
class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
3838

39-
private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
40-
val namedGroupingExprs = groupingExprs.map {
41-
case expr: NamedExpression => expr
42-
case expr: Expression => Alias(expr, expr.prettyString)()
39+
protected def aggregateExpressions(aggrExprs: Seq[NamedExpression])
40+
: Seq[NamedExpression] = {
41+
if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
42+
val retainedExprs = groupingExprs.map {
43+
case expr: NamedExpression => expr
44+
case expr: Expression => Alias(expr, expr.prettyString)()
45+
}
46+
retainedExprs ++ aggrExprs
47+
} else {
48+
aggrExprs
4349
}
50+
}
51+
52+
protected[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
4453
DataFrame(
45-
df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
54+
df.sqlContext, Aggregate(groupingExprs, aggregateExpressions(aggExprs), df.logicalPlan))
4655
}
4756

4857
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
@@ -175,19 +184,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
175184
*/
176185
@scala.annotation.varargs
177186
def agg(expr: Column, exprs: Column*): DataFrame = {
178-
val aggExprs = (expr +: exprs).map(_.expr).map {
187+
(expr +: exprs).map(_.expr).map {
179188
case expr: NamedExpression => expr
180189
case expr: Expression => Alias(expr, expr.prettyString)()
181190
}
182-
if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
183-
val retainedExprs = groupingExprs.map {
184-
case expr: NamedExpression => expr
185-
case expr: Expression => Alias(expr, expr.prettyString)()
186-
}
187-
DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
188-
} else {
189-
DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
190-
}
191191
}
192192

193193
/**
@@ -256,5 +256,38 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
256256
@scala.annotation.varargs
257257
def sum(colNames: String*): DataFrame = {
258258
aggregateNumericColumns(colNames:_*)(Sum)
259-
}
259+
}
260+
261+
}
262+
263+
/**
264+
* :: Experimental ::
265+
* A set of methods for aggregations on a [[DataFrame]] cube, created by [[DataFrame.cube]].
266+
*
267+
* @since 1.4.0
268+
*/
269+
@Experimental
270+
class CubedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
271+
extends GroupedData(df, groupingExprs) {
272+
273+
protected[sql] implicit override def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
274+
DataFrame(
275+
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
276+
}
277+
}
278+
279+
/**
280+
* :: Experimental ::
281+
* A set of methods for aggregations on a [[DataFrame]] rollup, created by [[DataFrame.rollup]].
282+
*
283+
* @since 1.4.0
284+
*/
285+
@Experimental
286+
class RollupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
287+
extends GroupedData(df, groupingExprs) {
288+
289+
protected[sql] implicit override def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
290+
DataFrame(
291+
df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregateExpressions(aggExprs)))
292+
}
260293
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive
19+
20+
import org.apache.spark.sql.QueryTest
21+
import org.apache.spark.sql.functions._
22+
import org.apache.spark.sql.hive.test.TestHive
23+
import org.apache.spark.sql.hive.test.TestHive._
24+
import org.apache.spark.sql.hive.test.TestHive.implicits._
25+
26+
case class TestData2Int(a: Int, b: Int)
27+
28+
class HiveDataFrameAnalyticsSuiteSuite extends QueryTest {
29+
val testData =
30+
TestHive.sparkContext.parallelize(
31+
TestData2Int(1, 2) ::
32+
TestData2Int(2, 4) :: Nil).toDF()
33+
34+
testData.registerTempTable("mytable")
35+
36+
test("rollup") {
37+
checkAnswer(
38+
testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
39+
sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect()
40+
)
41+
42+
checkAnswer(
43+
testData.rollup("a", "b").agg(sum("b")),
44+
sql("select a, b, sum(b) from mytable group by a, b with rollup").collect()
45+
)
46+
}
47+
48+
test("cube") {
49+
checkAnswer(
50+
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
51+
sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect()
52+
)
53+
54+
checkAnswer(
55+
testData.cube("a", "b").agg(sum("b")),
56+
sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
57+
)
58+
}
59+
60+
test("spark.sql.retainGroupColumns config") {
61+
val oldConf = conf.getConf("spark.sql.retainGroupColumns", "true")
62+
try {
63+
conf.setConf("spark.sql.retainGroupColumns", "false")
64+
checkAnswer(
65+
testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
66+
sql("select sum(a-b) from mytable group by a + b, b with rollup").collect()
67+
)
68+
69+
checkAnswer(
70+
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
71+
sql("select sum(a-b) from mytable group by a + b, b with cube").collect()
72+
)
73+
} finally {
74+
conf.setConf("spark.sql.retainGroupColumns", oldConf)
75+
}
76+
}
77+
}

0 commit comments

Comments
 (0)