
Commit 27ae625

unbiased standard deviation aggregation function

1 parent ebfd91c, commit 27ae625

File tree: 8 files changed, +158 / -7 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala

Lines changed: 84 additions & 0 deletions
@@ -302,3 +302,87 @@ case class Sum(child: Expression) extends AlgebraicAggregate {

   override val evaluateExpression = Cast(currentSum, resultType)
 }
+
+/**
+ * Calculates the unbiased Standard Deviation using the online formula here:
+ * https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+ */
+case class StandardDeviation(child: Expression) extends AlgebraicAggregate {
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  // Return data type.
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
+
+  private lazy val resultType = child.dataType match {
+    case DecimalType.Fixed(p, s) =>
+      DecimalType.bounded(p + 4, s + 4)
+    case _ => DoubleType
+  }
+
+  private lazy val sumDataType = child.dataType match {
+    case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
+    case _ => DoubleType
+  }
+
+  private lazy val currentCount = AttributeReference("currentCount", LongType)()
+  private lazy val currentAvg = AttributeReference("currentAverage", sumDataType)()
+  private lazy val currentMk = AttributeReference("currentMoment", sumDataType)()
+
+  // the values should be updated in a special order, because they re-use each other
+  override lazy val bufferAttributes = currentCount :: currentAvg :: currentMk :: Nil
+
+  override lazy val initialValues = Seq(
+    /* currentCount = */ Literal(0L),
+    /* currentAvg = */ Cast(Literal(0), sumDataType),
+    /* currentMk = */ Cast(Literal(0), sumDataType)
+  )
+
+  override lazy val updateExpressions = {
+    val currentValue = Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)
+    val deltaX = Subtract(currentValue, currentAvg)
+    val updatedCount = If(IsNull(child), currentCount, currentCount + 1L)
+    val updatedAvg = Add(currentAvg, Divide(deltaX, updatedCount))
+    Seq(
+      /* currentCount = */ updatedCount,
+      /* currentAvg = */ If(IsNull(child), currentAvg, updatedAvg),
+      /* currentMk = */ If(IsNull(child),
+        currentMk, Add(currentMk, deltaX * Subtract(currentValue, updatedAvg)))
+    )
+  }
+
+  override lazy val mergeExpressions = {
+    val totalCount = currentCount.left + currentCount.right
+    val deltaX = currentAvg.left - currentAvg.right
+    val deltaX2 = deltaX * deltaX
+    val sumMoments = currentMk.left + currentMk.right
+    val sumLeft = currentAvg.left * currentCount.left
+    val sumRight = currentAvg.right * currentCount.right
+    Seq(
+      /* currentCount = */ totalCount,
+      /* currentAvg = */ If(EqualTo(totalCount, Cast(Literal(0L), LongType)),
+        Cast(Literal(0), sumDataType), (sumLeft + sumRight) / totalCount),
+      /* currentMk = */ If(EqualTo(totalCount, Cast(Literal(0L), LongType)),
+        Cast(Literal(0), sumDataType),
+        sumMoments + deltaX2 * currentCount.left / totalCount * currentCount.right)
+    )
+  }
+
+  override lazy val evaluateExpression = {
+    val count = If(EqualTo(currentCount, Cast(Literal(0L), LongType)),
+      currentCount, currentCount - Cast(Literal(1L), LongType))
+    child.dataType match {
+      case DecimalType.Fixed(p, s) =>
+        // increase the precision and scale to prevent precision loss
+        val dt = DecimalType.bounded(p + 14, s + 4)
+        Cast(Sqrt(Cast(currentMk, dt) / Cast(count, dt)), resultType)
+      case _ =>
+        Sqrt(Cast(currentMk, resultType) / Cast(count, resultType))
+    }
+  }
+}
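To make the buffer algebra above easier to follow, here is a minimal plain-Scala sketch (not part of the commit; the StdBuffer name is invented) of the same math: Welford's online update per row, the parallel merge step from the linked Wikipedia article, and the final unbiased evaluation.

    case class StdBuffer(count: Long, avg: Double, mk: Double) {
      // Mirrors updateExpressions: fold one non-null value into the buffer.
      def update(x: Double): StdBuffer = {
        val n = count + 1
        val delta = x - avg
        val newAvg = avg + delta / n
        StdBuffer(n, newAvg, mk + delta * (x - newAvg))
      }
      // Mirrors mergeExpressions: combine two partial buffers (Chan et al.).
      def merge(that: StdBuffer): StdBuffer = {
        val n = count + that.count
        if (n == 0) StdBuffer(0, 0.0, 0.0)
        else {
          val delta = avg - that.avg
          val newAvg = (avg * count + that.avg * that.count) / n
          StdBuffer(n, newAvg, mk + that.mk + delta * delta * count / n * that.count)
        }
      }
      // Mirrors evaluateExpression: divide by n - 1 for the unbiased estimate;
      // zero or one row has no defined result (the SQL version returns null).
      def result: Option[Double] =
        if (count < 2) None else Some(math.sqrt(mk / (count - 1)))
    }

    // Example: two "partitions" folded separately, then merged, matching
    // how a distributed aggregation combines partial buffers.
    val merged = Seq(Seq(1.0, 2.0), Seq(3.0, 4.0))
      .map(_.foldLeft(StdBuffer(0L, 0.0, 0.0))(_ update _))
      .reduce(_ merge _)
    // merged.result: Some(≈ 1.2909944487358056), the sample stddev of 1, 2, 3, 4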

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/utils.scala

Lines changed: 8 additions & 0 deletions
@@ -164,4 +164,12 @@ object Utils {
     }
     case other => None
   }
+
+  def standardDeviation(e: Expression): Expression = {
+    val std = aggregate.AggregateExpression2(
+      aggregateFunction = aggregate.StandardDeviation(e),
+      mode = aggregate.Complete,
+      isDistinct = false)
+    Alias(std, s"std(${e.prettyString})")()
+  }
 }

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 2 additions & 5 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
 import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}

@@ -1268,15 +1269,11 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def describe(cols: String*): DataFrame = {

-    // TODO: Add stddev as an expression, and remove it from here.
-    def stddevExpr(expr: Expression): Expression =
-      Sqrt(Subtract(Average(Multiply(expr, expr)), Multiply(Average(expr), Average(expr))))
-
     // The list of summary statistics to compute, in the form of expressions.
     val statistics = List[(String, Expression => Expression)](
      "count" -> Count,
      "mean" -> Average,
-     "stddev" -> stddevExpr,
+     "stddev" -> aggregate.Utils.standardDeviation,
      "min" -> Min,
      "max" -> Max)
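The removed stddevExpr computed sqrt(E[X²] − E[X]²), which is both the biased population estimate and numerically unstable: when the mean is large relative to the spread, the two averages agree in most of their significant digits and the subtraction cancels catastrophically. A small illustrative sketch (not part of the commit):

    // Values with a large mean and a tiny spread; the true population
    // variance is 2/3 and the unbiased sample variance is 1.0.
    val xs = Seq(1e8, 1e8 + 1, 1e8 + 2)
    val n = xs.length
    val meanOfSquares = xs.map(x => x * x).sum / n
    val squareOfMean = math.pow(xs.sum / n, 2)
    val naiveVariance = meanOfSquares - squareOfMean
    // Both terms are ~1e16 but only ~2/3 apart, so naiveVariance comes out
    // as a small garbage value (often 0.0 or even negative) in doubles.
    // The online algorithm keeps the running moment small and avoids this.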

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 13 additions & 0 deletions
@@ -23,6 +23,7 @@ import scala.language.implicitConversions
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation
 import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
 import org.apache.spark.sql.types.NumericType

@@ -283,6 +284,18 @@ class GroupedData protected[sql](
     aggregateNumericColumns(colNames : _*)(Min)
   }

+  /**
+   * Compute the sample standard deviation for each numeric column for each group.
+   * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the standard deviation for them.
+   *
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def std(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames : _*)(aggregate.Utils.standardDeviation)
+  }
+
   /**
    * Compute the sum for each numeric columns for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
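A hedged usage sketch of the new grouped aggregation (the data is invented; assumes a SQLContext with its implicits imported so toDF is available):

    val df = Seq((1, 10.0), (1, 30.0), (2, 5.0)).toDF("key", "value")
    df.groupBy("key").std("value").show()
    // key 1 -> 14.142135623730951 (sample stddev of 10.0 and 30.0)
    // key 2 -> null (a single row has no unbiased standard deviation)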

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala

Lines changed: 6 additions & 1 deletion
@@ -87,6 +87,8 @@ class SortBasedAggregationIterator(
   // The aggregation buffer used by the sort-based aggregation.
   private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer

+  private val dataTypes = allAggregateFunctions.flatMap(_.bufferAttributes).map(_.dataType)
+
   /** Processes rows in the current group. It will stop when it find a new group. */
   protected def processCurrentSortedGroup(): Unit = {
     currentGroupingKey = nextGroupingKey

@@ -95,6 +97,7 @@ class SortBasedAggregationIterator(
     var findNextPartition = false
     // firstRowInNextGroup is the first row of this group. We first process it.
     processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
+    println(dataTypes.zipWithIndex.map(d => sortBasedAggregationBuffer.get(d._2, d._1)).mkString("[", ",", "]"))

     // The search will stop when we see the next group or there is no
     // input row left in the iter.

@@ -107,7 +110,9 @@ class SortBasedAggregationIterator(
       // Check if the current row belongs the current input row.
       if (currentGroupingKey == groupingKey) {
         processRow(sortBasedAggregationBuffer, currentRow)
-
+        println("Second")
+        println(currentRow)
+        println(dataTypes.zipWithIndex.map(d => sortBasedAggregationBuffer.get(d._2, d._1)).mkString("[", ",", "]"))
         hasNext = inputKVIterator.next()
       } else {
         // We find a new group.

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 17 additions & 0 deletions
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag}
 import scala.util.Try

 import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.expressions.aggregate.StandardDeviation
 import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
 import org.apache.spark.sql.catalyst.expressions._

@@ -294,6 +295,22 @@ object functions {
   */
  def min(columnName: String): Column = min(Column(columnName))

+  /**
+   * Aggregate function: returns the sample standard deviation of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.5.0
+   */
+  def std(e: Column): Column = aggregate.Utils.standardDeviation(e.expr)
+
+  /**
+   * Aggregate function: returns the sample standard deviation of the values in a group.
+   *
+   * @group agg_funcs
+   * @since 1.5.0
+   */
+  def std(columnName: String): Column = std(Column(columnName))
+
  /**
   * Aggregate function: returns the sum of all values in the expression.
   *
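A hedged usage sketch of both overloads (illustrative only; assumes a SQLContext named sqlContext, whose range method yields a single-column DataFrame that we rename):

    import org.apache.spark.sql.functions.std

    val df = sqlContext.range(0, 10).toDF("value") // one column holding 0..9
    df.select(std("value")).show()      // string-name overload, ≈ 3.0277
    df.select(std(df("value"))).show()  // Column overload, same result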

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -442,7 +442,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
     val describeResult = Seq(
       Row("count", "4", "4"),
       Row("mean", "33.0", "178.0"),
-      Row("stddev", "16.583123951777", "10.0"),
+      Row("stddev", "19.148542155126762", "11.547005383792516"),
       Row("min", "16", "164"),
       Row("max", "60", "192"))
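The expected values change because describe now reports the unbiased sample standard deviation rather than the population one; for n = 4 the two differ by a factor of sqrt(n / (n - 1)). A quick check (illustrative, not part of the suite):

    // Population stddev times sqrt(n / (n - 1)) gives the sample stddev for n = 4.
    val factor = math.sqrt(4.0 / 3.0)
    16.583123951777 * factor   // ≈ 19.148542155126762
    10.0 * factor              // ≈ 11.547005383792516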

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala

Lines changed: 27 additions & 0 deletions
@@ -22,6 +22,7 @@ import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 import org.apache.spark.sql._
+import org.apache.spark.sql.functions.std
 import org.scalatest.BeforeAndAfterAll
 import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}

@@ -84,6 +85,32 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Be
     sqlContext.setConf(SQLConf.USE_SQL_AGGREGATE2.key, originalUseAggregate2.toString)
   }

+  test("test standard deviation") {
+    val df = Seq.tabulate(10)(i => (i, 1)).toDF("val", "key")
+    checkAnswer(
+      df.select(std("val")),
+      Row(3.0276503540974917) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("agg1").groupBy("key").std("value"),
+      Row(1, 10.0) :: Row(2, 0.7071067811865476) :: Row(3, null) ::
+        Row(null, 81.8535277187245) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("agg1").select(std("key"), std("value")),
+      Row(0.7817359599705717, 44.898098909801135) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("agg2").groupBy("key", "value1").std("value2"),
+      Row(1, 10, null) :: Row(1, 30, 42.42640687119285) :: Row(2, -1, null) ::
+        Row(2, 1, 0.0) :: Row(2, null, null) :: Row(3, null, null) :: Row(null, -10, null) ::
+        Row(null, -60, null) :: Row(null, 100, null) :: Row(null, null, null) :: Nil)
+
+    checkAnswer(
+      sqlContext.table("emptyTable").select(std("value")),
+      Row(null) :: Nil)
+  }
+
   test("empty table") {
     // If there is no GROUP BY clause and the table is empty, we will generate a single row.
     checkAnswer(
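As a sanity check on the first expected value (a worked computation, not part of the suite): for the values 0 through 9 the mean is 4.5, the squared deviations sum to 82.5, and the unbiased variance divides by n - 1 = 9.

    val xs = (0 until 10).map(_.toDouble)
    val mean = xs.sum / xs.length                      // 4.5
    val m2 = xs.map(x => (x - mean) * (x - mean)).sum  // 82.5
    math.sqrt(m2 / (xs.length - 1))                    // 3.0276503540974917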

0 commit comments
