Commit d910141

Updated rest of the files

1 parent 1e6e666 commit d910141

File tree

4 files changed: +7 additions, -6 deletions

python/pyspark/sql/dataframe.py

Lines changed: 1 addition & 1 deletion

@@ -1069,7 +1069,7 @@ def agg(self, *exprs):
 
         >>> from pyspark.sql import functions as F
         >>> gdf.agg(F.min(df.age)).collect()
-        [Row(MIN(age)=2), Row(MIN(age)=5)]
+        [Row(name=u'Alice', MIN(age)=2), Row(name=u'Bob', MIN(age)=5)]
         """
         assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
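The updated doctest captures the behavioral change in this commit: groupBy(...).agg(...) now keeps the grouping column (name) in its output. For reference, a minimal Scala sketch of the same default (the data and column names mirror the doctest and are illustrative; it assumes a Spark 1.4-era sqlContext with its implicits imported, e.g. in spark-shell):

    import org.apache.spark.sql.functions.min
    import sqlContext.implicits._  // assumes a SQLContext named sqlContext

    val df = Seq(("Alice", 2), ("Bob", 5)).toDF("name", "age")

    // The grouping column is now retained alongside the aggregate:
    // [Alice, 2] and [Bob, 5] rather than [2] and [5].
    df.groupBy("name").agg(min($"age")).show()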

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 3 additions & 2 deletions

@@ -135,8 +135,9 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   }
 
   /**
-   * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
-   * class, the resulting [[DataFrame]] won't automatically include the grouping columns.
+   * Compute aggregates by specifying a series of aggregate columns. Note that this function by
+   * default retains the grouping columns in its output. To not retain grouping columns, set
+   * `spark.sql.retainGroupColumns` to false.
    *
    * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
    *
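The rewritten Scaladoc names the escape hatch for the old behavior. A sketch of toggling it, under the same sqlContext assumption as above; spark.sql.retainGroupColumns is the conf key documented in the comment:

    import org.apache.spark.sql.functions.min
    import sqlContext.implicits._

    val df = Seq(("Alice", 2), ("Bob", 5)).toDF("name", "age")

    // Opt out of the new default to get the pre-1.4 output back:
    sqlContext.setConf("spark.sql.retainGroupColumns", "false")
    df.groupBy("name").agg(min($"age")).show()  // only MIN(age), no name column

    sqlContext.setConf("spark.sql.retainGroupColumns", "true")  // restore default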

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 1 addition & 1 deletion

@@ -102,7 +102,7 @@ private[sql] object StatFunctions extends Logging {
   /** Generate a table of frequencies for the elements of two columns. */
   private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
     val tableName = s"${col1}_$col2"
-    val counts = df.groupBy(col1, col2).agg(col(col1), col(col2), count("*")).take(1e6.toInt)
+    val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
     if (counts.length == 1e6.toInt) {
       logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
         "the pairs. Please try reducing the amount of distinct items in your columns.")

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -62,7 +62,7 @@ class DataFrameSuite extends QueryTest {
     val df = Seq((1,(1,1))).toDF()
 
     checkAnswer(
-      df.groupBy("_1").agg(col("_1"), sum("_2._1")).toDF("key", "total"),
+      df.groupBy("_1").agg(sum("_2._1")).toDF("key", "total"),
       Row(1, 1) :: Nil)
   }
 
@@ -127,7 +127,7 @@ class DataFrameSuite extends QueryTest {
       df2
         .select('_1 as 'letter, 'number)
         .groupBy('letter)
-        .agg('letter, countDistinct('number)),
+        .agg(countDistinct('number)),
       Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
     )
   }
