
Commit 0a4844f

[SPARK-7462] By default retain group by columns in aggregate
Updated Java, Scala, Python, and R.

Author: Reynold Xin <[email protected]>
Author: Shivaram Venkataraman <[email protected]>

Closes apache#5996 from rxin/groupby-retain and squashes the following commits:

aac7119 [Reynold Xin] Merge branch 'groupby-retain' of github.com:rxin/spark into groupby-retain
f6858f6 [Reynold Xin] Merge branch 'master' into groupby-retain
5f923c0 [Reynold Xin] Merge pull request alteryx#15 from shivaram/sparkr-groupby-retrain
c1de670 [Shivaram Venkataraman] Revert workaround in SparkR to retain grouped cols. Based on reverting code added in commit amplab-extras@9a6be74
b8b87e1 [Reynold Xin] Fixed DataFrameJoinSuite.
d910141 [Reynold Xin] Updated rest of the files
1e6e666 [Reynold Xin] [SPARK-7462] By default retain group by columns in aggregate
1 parent 1b46556 commit 0a4844f
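In short, GroupedData.agg now keeps the grouping columns in its output by default, controlled by the new spark.sql.retainGroupColumns flag. A minimal Scala sketch of the user-visible change; the people data and column names are illustrative, and it assumes a SQLContext whose implicits are imported so that toDF is available:

    import org.apache.spark.sql.functions._

    // Hypothetical data: two rows for Alice, one for Bob.
    val people = Seq(("Alice", 2), ("Alice", 5), ("Bob", 7)).toDF("name", "age")

    // Before this commit:  [Row(MIN(age)=2), Row(MIN(age)=7)]
    // After this commit:   [Row(name=Alice, MIN(age)=2), Row(name=Bob, MIN(age)=7)]
    people.groupBy("name").agg(min("age")).collect()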

File tree

10 files changed: +218 additions, -172 deletions

R/pkg/R/group.R

Lines changed: 1 addition & 3 deletions
@@ -102,9 +102,7 @@ setMethod("agg",
                 }
               }
               jcols <- lapply(cols, function(c) { c@jc })
-              # the GroupedData.agg(col, cols*) API does not contain grouping Column
-              sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "aggWithGrouping",
-                                 x@sgd, listToSeq(jcols))
+              sdf <- callJMethod(x@sgd, "agg", jcols[[1]], listToSeq(jcols[-1]))
             } else {
               stop("agg can only support Column or character")
             }

python/pyspark/sql/dataframe.py

Lines changed: 1 addition & 1 deletion
@@ -1069,7 +1069,7 @@ def agg(self, *exprs):

        >>> from pyspark.sql import functions as F
        >>> gdf.agg(F.min(df.age)).collect()
-        [Row(MIN(age)=2), Row(MIN(age)=5)]
+        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
        """
        assert exprs, "exprs should not be empty"
        if len(exprs) == 1 and isinstance(exprs[0], dict):

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 12 additions & 3 deletions
@@ -135,8 +135,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   }

   /**
-   * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
-   * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+   * Compute aggregates by specifying a series of aggregate columns. Note that this function by
+   * default retains the grouping columns in its output. To not retain grouping columns, set
+   * `spark.sql.retainGroupColumns` to false.
    *
    * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
    *
@@ -158,7 +159,15 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
     }
-    DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+      val retainedExprs = groupingExprs.map {
+        case expr: NamedExpression => expr
+        case expr: Expression => Alias(expr, expr.prettyString)()
+      }
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
+    } else {
+      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
+    }
   }

   /**
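What the new branch means for the output schema, as a rough sketch; df, key and value are placeholder names, and the exact generated column names (derived from Expression.prettyString) may differ:

    // Grouping by a plain column: it is already a NamedExpression, so it is kept as-is
    // and prepended to the aggregate expressions.
    df.groupBy("key").agg(sum("value"))        // columns: key, SUM(value)

    // Grouping by a computed expression: it is wrapped in an Alias (named via prettyString)
    // before being prepended.
    df.groupBy($"key" + 1).agg(sum("value"))   // columns: (key + 1), SUM(value)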

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 6 additions & 0 deletions
@@ -74,6 +74,9 @@ private[spark] object SQLConf {
   // See SPARK-6231.
   val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = "spark.sql.selfJoinAutoResolveAmbiguity"

+  // Whether to retain group by columns or not in GroupedData.agg.
+  val DATAFRAME_RETAIN_GROUP_COLUMNS = "spark.sql.retainGroupColumns"
+
   val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"

   val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"
@@ -242,6 +245,9 @@ private[sql] class SQLConf extends Serializable with CatalystConf {

   private[spark] def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY, "true").toBoolean
+
+  private[spark] def dataFrameRetainGroupColumns: Boolean =
+    getConf(DATAFRAME_RETAIN_GROUP_COLUMNS, "true").toBoolean

   /** ********************** SQLConf functionality methods ************ */
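To get the old behavior back, the flag can be flipped at runtime; a small sketch, where sqlContext stands for whichever SQLContext is in scope:

    // Drop grouping columns from agg() output, i.e. the pre-SPARK-7462 behavior.
    sqlContext.setConf("spark.sql.retainGroupColumns", "false")

    // Switch back to the new default.
    sqlContext.setConf("spark.sql.retainGroupColumns", "true")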

sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala

Lines changed: 0 additions & 11 deletions
@@ -72,17 +72,6 @@ private[r] object SQLUtils {
     sqlContext.createDataFrame(rowRDD, schema)
   }

-  // A helper to include grouping columns in Agg()
-  def aggWithGrouping(gd: GroupedData, exprs: Column*): DataFrame = {
-    val aggExprs = exprs.map { col =>
-      col.expr match {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.simpleString)()
-      }
-    }
-    gd.toDF(aggExprs)
-  }
-
   def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
     df.map(r => rowToRBytes(r))
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ private[sql] object StatFunctions extends Logging {
   /** Generate a table of frequencies for the elements of two columns. */
   private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
     val tableName = s"${col1}_$col2"
-    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e6.toInt)
+    val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
     if (counts.length == 1e6.toInt) {
       logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
         "the pairs. Please try reducing the amount of distinct items in your columns.")

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala (new file)

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.types.DecimalType
+
+
+class DataFrameAggregateSuite extends QueryTest {
+
+  test("groupBy") {
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b").as("totB")).agg(sum('totB)),
+      Row(9)
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(count("*")),
+      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(Map("*" -> "count")),
+      Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+    )
+    checkAnswer(
+      testData2.groupBy("a").agg(Map("b" -> "sum")),
+      Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+    )
+
+    val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
+      .toDF("key", "value1", "value2", "rest")
+
+    checkAnswer(
+      df1.groupBy("key").min(),
+      df1.groupBy("key").min("value1", "value2").collect()
+    )
+    checkAnswer(
+      df1.groupBy("key").min("value2"),
+      Seq(Row("a", 0), Row("b", 4))
+    )
+  }
+
+  test("spark.sql.retainGroupColumns config") {
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(1, 3), Row(2, 3), Row(3, 3))
+    )
+
+    TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false")
+    checkAnswer(
+      testData2.groupBy("a").agg(sum($"b")),
+      Seq(Row(3), Row(3), Row(3))
+    )
+    TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true")
+  }
+
+  test("agg without groups") {
+    checkAnswer(
+      testData2.agg(sum('b)),
+      Row(9)
+    )
+  }
+
+  test("average") {
+    checkAnswer(
+      testData2.agg(avg('a)),
+      Row(2.0))
+
+    // Also check mean
+    checkAnswer(
+      testData2.agg(mean('a)),
+      Row(2.0))
+
+    checkAnswer(
+      testData2.agg(avg('a), sumDistinct('a)), // non-partial
+      Row(2.0, 6.0) :: Nil)
+
+    checkAnswer(
+      decimalData.agg(avg('a)),
+      Row(new java.math.BigDecimal(2.0)))
+    checkAnswer(
+      decimalData.agg(avg('a), sumDistinct('a)), // non-partial
+      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+
+    checkAnswer(
+      decimalData.agg(avg('a cast DecimalType(10, 2))),
+      Row(new java.math.BigDecimal(2.0)))
+    // non-partial
+    checkAnswer(
+      decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))),
+      Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil)
+  }
+
+  test("null average") {
+    checkAnswer(
+      testData3.agg(avg('b)),
+      Row(2.0))
+
+    checkAnswer(
+      testData3.agg(avg('b), countDistinct('b)),
+      Row(2.0, 1))
+
+    checkAnswer(
+      testData3.agg(avg('b), sumDistinct('b)), // non-partial
+      Row(2.0, 2.0))
+  }
+
+  test("zero average") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(avg('a)),
+      Row(null))
+
+    checkAnswer(
+      emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial
+      Row(null, null))
+  }
+
+  test("count") {
+    assert(testData2.count() === testData2.map(_ => 1).count())
+
+    checkAnswer(
+      testData2.agg(count('a), sumDistinct('a)), // non-partial
+      Row(6, 6.0))
+  }
+
+  test("null count") {
+    checkAnswer(
+      testData3.groupBy('a).agg(count('b)),
+      Seq(Row(1, 0), Row(2, 1))
+    )
+
+    checkAnswer(
+      testData3.groupBy('a).agg(count('a + 'b)),
+      Seq(Row(1, 0), Row(2, 1))
+    )
+
+    checkAnswer(
+      testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)),
+      Row(2, 1, 2, 2, 1)
+    )
+
+    checkAnswer(
+      testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial
+      Row(1, 1, 2)
+    )
+  }
+
+  test("zero count") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    assert(emptyTableData.count() === 0)
+
+    checkAnswer(
+      emptyTableData.agg(count('a), sumDistinct('a)), // non-partial
+      Row(0, null))
+  }
+
+  test("zero sum") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(sum('a)),
+      Row(null))
+  }
+
+  test("zero sum distinct") {
+    val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b")
+    checkAnswer(
+      emptyTableData.agg(sumDistinct('a)),
+      Row(null))
+  }
+
+}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ class DataFrameJoinSuite extends QueryTest {
       df.join(df, df("key") === df("key") && df("value") === 1),
       Row(1, "1", 1, "1") :: Nil)

-    val left = df.groupBy("key").agg($"key", count("*"))
-    val right = df.groupBy("key").agg($"key", sum("key"))
+    val left = df.groupBy("key").agg(count("*"))
+    val right = df.groupBy("key").agg(sum("key"))
     checkAnswer(
       left.join(right, left("key") === right("key")),
       Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil)
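The test change above is the typical migration: code that re-selected the grouping column inside agg should drop it, since the column is now retained automatically and listing it again would yield a duplicate. A sketch with placeholder names (df with a key column):

    // Before: the grouping column had to be listed explicitly to survive the aggregation.
    val oldStyle = df.groupBy("key").agg($"key", count("*"))

    // After: "key" is kept automatically, so the explicit $"key" is no longer needed
    // (and leaving it in would now produce a second key column).
    val newStyle = df.groupBy("key").agg(count("*"))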
