Skip to content

Commit 6351fc8

Browse files
committed
address Davides's comment
1 parent 6035648 commit 6351fc8

File tree

3 files changed

+7
-59
lines changed

3 files changed

+7
-59
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,7 @@ abstract class StddevAgg(child: Expression) extends AlgebraicAggregate {
291291
// we remove the old aggregate functions. Then, we will not need NullType at here.
292292
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))
293293

294-
private val resultType = child.dataType match {
295-
case DecimalType.Fixed(p, s) =>
296-
DecimalType.bounded(p + 14, s + 4)
297-
case _ => DoubleType
298-
}
299-
300-
private val zero = Cast(Literal(0), resultType)
294+
private val resultType = DoubleType
301295

302296
private val preCount = AttributeReference("preCount", resultType)()
303297
private val currentCount = AttributeReference("currentCount", resultType)()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -696,11 +696,7 @@ case class LastFunction(expr: Expression, base: AggregateExpression1) extends Ag
696696
// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
697697
abstract class StddevAgg1(child: Expression) extends UnaryExpression with PartialAggregate1 {
698698
override def nullable: Boolean = true
699-
override def dataType: DataType = child.dataType match {
700-
case DecimalType.Fixed(p, s) =>
701-
DecimalType.bounded(p + 14, s + 4)
702-
case _ => DoubleType
703-
}
699+
override def dataType: DataType = DoubleType
704700

705701
def isSample: Boolean
706702

@@ -742,10 +738,7 @@ case class ComputePartialStd(child: Expression) extends UnaryExpression with Agg
742738

743739
override def children: Seq[Expression] = child :: Nil
744740
override def nullable: Boolean = false
745-
override def dataType: DataType = child.dataType match {
746-
case DecimalType.Fixed(p, s) => ArrayType(DecimalType.bounded(p + 10, s + 4))
747-
case _ => ArrayType(DoubleType)
748-
}
741+
override def dataType: DataType = ArrayType(DoubleType)
749742
override def toString: String = s"computePartialStddev($child)"
750743
override def newInstance(): ComputePartialStdFunction =
751744
new ComputePartialStdFunction(child, this)
@@ -757,10 +750,7 @@ case class ComputePartialStdFunction (
757750
) extends AggregateFunction1 {
758751
def this() = this(null, null) // Required for serialization
759752

760-
private val computeType = expr.dataType match {
761-
case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s + 4)
762-
case _ => DoubleType
763-
}
753+
private val computeType = DoubleType
764754
private val zero = Cast(Literal(0), computeType)
765755
private var partialCount: Long = 0L
766756

@@ -811,10 +801,7 @@ case class MergePartialStd(
811801

812802
override def children: Seq[Expression] = child:: Nil
813803
override def nullable: Boolean = false
814-
override def dataType: DataType = child.dataType match {
815-
case ArrayType(DecimalType.Fixed(p, s), _) => DecimalType.bounded(p + 14, s + 4)
816-
case _ => DoubleType
817-
}
804+
override def dataType: DataType = DoubleType
818805
override def toString: String = s"MergePartialStd($child)"
819806
override def newInstance(): MergePartialStdFunction = {
820807
new MergePartialStdFunction(child, this, isSample)
@@ -828,10 +815,7 @@ case class MergePartialStdFunction(
828815
) extends AggregateFunction1 {
829816
def this() = this (null, null, false) // Required for serialization
830817

831-
private val computeType = expr.dataType match {
832-
case ArrayType(DecimalType.Fixed(p, s), _) => DecimalType.bounded(p + 14, s + 4)
833-
case _ => DoubleType
834-
}
818+
private val computeType = DoubleType
835819
private val zero = Cast(Literal(0), computeType)
836820
private val combineCount = MutableLiteral(zero.eval(null), computeType)
837821
private val combineAvg = MutableLiteral(zero.eval(null), computeType)
@@ -906,10 +890,7 @@ case class StddevFunction(
906890

907891
def this() = this(null, null, false) // Required for serialization
908892

909-
private val computeType = expr.dataType match {
910-
case DecimalType.Fixed(p, s) => DecimalType.bounded(p + 14, s + 4)
911-
case _ => DoubleType
912-
}
893+
private val computeType = DoubleType
913894
private var curCount: Long = 0L
914895
private val zero = Cast(Literal(0), computeType)
915896
private val curAvg = MutableLiteral(zero.eval(null), computeType)

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -303,15 +303,6 @@ object functions {
303303
*/
304304
def stddev(e: Column): Column = Stddev(e.expr)
305305

306-
/**
307-
* Aggregate function: returns the unbiased sample standard deviation
308-
* of the column in a group.
309-
*
310-
* @group agg_funcs
311-
* @since 1.6.0
312-
*/
313-
def stddev(columnName: String): Column = stddev(Column(columnName))
314-
315306
/**
316307
* Aggregate function: returns the population standard deviation of
317308
* the expression in a group.
@@ -321,15 +312,6 @@ object functions {
321312
*/
322313
def stddev_pop(e: Column): Column = StddevPop(e.expr)
323314

324-
/**
325-
* Aggregate function: returns the standard deviation of the column
326-
* in a group.
327-
*
328-
* @group agg_funcs
329-
* @since 1.6.0
330-
*/
331-
def stddev_pop(columnName: String): Column = stddev_pop(Column(columnName))
332-
333315
/**
334316
* Aggregate function: returns the unbiased sample standard deviation of
335317
* the expression in a group.
@@ -339,15 +321,6 @@ object functions {
339321
*/
340322
def stddev_samp(e: Column): Column = StddevSamp(e.expr)
341323

342-
/**
343-
* Aggregate function: returns the unbiased sample standard deviation of the
344-
* column in a group.
345-
*
346-
* @group agg_funcs
347-
* @since 1.6.0
348-
*/
349-
def stddev_samp(columnName: String): Column = stddev_samp(Column(columnName))
350-
351324
/**
352325
* Aggregate function: returns the sum of all values in the expression.
353326
*

0 commit comments

Comments
 (0)