From 40fed21254ae0625b708727b4cfc858686548924 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Thu, 26 Mar 2015 21:57:15 -0700 Subject: [PATCH 1/4] shrink the commits --- .../main/scala/org/apache/spark/sql/Row.scala | 47 +- .../apache/spark/sql/catalyst/SqlParser.scala | 11 +- .../sql/catalyst/analysis/Analyzer.scala | 6 +- .../catalyst/analysis/FunctionRegistry.scala | 26 +- .../catalyst/analysis/HiveTypeCoercion.scala | 38 +- .../sql/catalyst/analysis/unresolved.scala | 5 +- .../spark/sql/catalyst/dsl/package.scala | 10 +- .../sql/catalyst/expressions/aggregates.scala | 916 +++++++----------- .../sql/catalyst/expressions/package.scala | 2 + .../spark/sql/catalyst/expressions/rows.scala | 17 +- .../sql/catalyst/optimizer/Optimizer.scala | 10 +- .../sql/catalyst/planning/patterns.scala | 37 +- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../org/apache/spark/sql/GroupedData.scala | 10 +- .../spark/sql/execution/Aggregate.scala | 536 +++++++--- .../sql/execution/GeneratedAggregate.scala | 27 +- .../spark/sql/execution/SparkStrategies.scala | 50 +- .../org/apache/spark/sql/functions.scala | 6 +- .../spark/sql/execution/PlannerSuite.scala | 23 +- .../spark/sql/hive/HiveInspectors.scala | 10 + .../org/apache/spark/sql/hive/HiveQl.scala | 13 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 198 ++-- ...ions #1-0-dc640a3b0e7f23e9052c454a739ba9db | 309 ++++++ ...ions #2-0-4b1bcbd566a255e2b694ec9d8bacb825 | 309 ++++++ ...ions #3-0-e4e01312d01a7a08cff2ac43196f6ea4 | 309 ++++++ ...ions #4-0-ff859636795b1019ad74567bf4ba095f | 309 ++++++ ...ions #5-0-5c7cdc7d4bc610cec923b54d3f3d696a | 309 ++++++ ...ions #6-0-dde25ab17e3198c18468e738f0464cf4 | 309 ++++++ ...sions #7-0-dc8a898b293d22742b62ce236e72f77 | 309 ++++++ ...sions #1-0-30038eb221d9d91ff4a098a57c1a5f9 | 1 + ...ions #2-0-75a3974aac80b9c47f23519da6a68876 | 1 + ...ions #3-0-8341e7bf739124bef28729aabb9fe542 | 1 + ...ions #4-0-679efde7a074d99d8dd227b4903b92f8 | 1 + ...ions #5-0-1e35f970b831ecfffdaff828428aea51 | 1 + ...ions #6-0-528e3454b467687ee9c1074cc7864660 | 1 + .../sql/hive/execution/AggregateSuite.scala | 168 ++++ 36 files changed, 3389 insertions(+), 948 deletions(-) create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #1-0-dc640a3b0e7f23e9052c454a739ba9db create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #2-0-4b1bcbd566a255e2b694ec9d8bacb825 create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #3-0-e4e01312d01a7a08cff2ac43196f6ea4 create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #4-0-ff859636795b1019ad74567bf4ba095f create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #5-0-5c7cdc7d4bc610cec923b54d3f3d696a create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #6-0-dde25ab17e3198c18468e738f0464cf4 create mode 100644 sql/hive/src/test/resources/golden/aggregation with group by expressions #7-0-dc8a898b293d22742b62ce236e72f77 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #2-0-75a3974aac80b9c47f23519da6a68876 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #3-0-8341e7bf739124bef28729aabb9fe542 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions 
#4-0-679efde7a074d99d8dd227b4903b92f8 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #5-0-1e35f970b831ecfffdaff828428aea51 create mode 100644 sql/hive/src/test/resources/golden/aggregation without group by expressions #6-0-528e3454b467687ee9c1074cc7864660 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index d794f034f5578..baca061c565b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 -import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.{StructType, DateUtils} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericRow} object Row { /** @@ -305,6 +305,49 @@ trait Row extends Serializable { */ def getAs[T](i: Int): T = apply(i).asInstanceOf[T] + /* TODO make the syntactic sugar it as API? */ + @inline + final def apply(bound: BoundReference): Any = apply(bound.ordinal) + @inline + final def isNullAt(bound: BoundReference): Boolean = isNullAt(bound.ordinal) + @inline + final def getInt(bound: BoundReference): Int = getInt(bound.ordinal) + @inline + final def getLong(bound: BoundReference): Long = getLong(bound.ordinal) + @inline + final def getDouble(bound: BoundReference): Double = getDouble(bound.ordinal) + @inline + final def getBoolean(bound: BoundReference): Boolean = getBoolean(bound.ordinal) + @inline + final def getShort(bound: BoundReference): Short = getShort(bound.ordinal) + @inline + final def getByte(bound: BoundReference): Byte = getByte(bound.ordinal) + @inline + final def getFloat(bound: BoundReference): Float = getFloat(bound.ordinal) + @inline + final def getString(bound: BoundReference): String = getString(bound.ordinal) + @inline + final def getAs[T](bound: BoundReference): T = getAs[T](bound.ordinal) + @inline + final def getDecimal(bound: BoundReference): java.math.BigDecimal = getDecimal(bound.ordinal) + @inline + final def getDate(bound: BoundReference): java.sql.Date = getDate(bound.ordinal) + @inline + final def getSeq[T](bound: BoundReference): Seq[T] = getSeq(bound.ordinal) + @inline + final def getList[T](bound: BoundReference): java.util.List[T] = getList(bound.ordinal) + @inline + final def getMap[K, V](bound: BoundReference): scala.collection.Map[K, V] = { + getMap(bound.ordinal) + } + @inline + final def getJavaMap[K, V](bound: BoundReference): java.util.Map[K, V] = { + getJavaMap(bound.ordinal) + } + @inline + final def getStruct(bound: BoundReference): Row = getStruct(bound.ordinal) + /* end of the syntactic sugar it as API */ + override def toString(): String = s"[${this.mkString(",")}]" /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index ea7d44a3723d1..e61846d8dd7c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -271,15 +271,18 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected lazy val function: Parser[Expression] = ( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) } - | 
SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) } + | SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => Sum(exp, true) } | COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) } | COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) } | COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^ - { case exps => CountDistinct(exps) } + { case exps => CountDistinct(exps) } + | COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } + // TODO approximate is not supported, will convert into COUNT | APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^ - { case exp => ApproxCountDistinct(exp) } + { case exp => CountDistinct(exp :: Nil) } | APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^ - { case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) } + // TODO approximate s.toDouble + { case s ~ _ ~ _ ~ _ ~ _ ~ e => CountDistinct(e :: Nil) } | FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) } | LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) } | AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 44eceb0b372e6..ca2ff6258ef33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -200,7 +200,7 @@ class Analyzer(catalog: Catalog, Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) - case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) => + case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil @@ -385,8 +385,8 @@ class Analyzer(catalog: Catalog, def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressions { - case u @ UnresolvedFunction(name, children) if u.childrenResolved => - registry.lookupFunction(name, children) + case u @ UnresolvedFunction(name, children, distinct) if u.childrenResolved => + registry.lookupFunction(name, children, distinct) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c43ea55899695..9499ac7fec1fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -26,9 +26,12 @@ trait FunctionRegistry { def registerFunction(name: String, builder: FunctionBuilder): Unit - def lookupFunction(name: String, children: Seq[Expression]): Expression - def caseSensitive: Boolean + + def lookupFunction( + name: String, + children: Seq[Expression], + distinct: Boolean = false): Expression } trait OverrideFunctionRegistry extends FunctionRegistry { @@ -39,8 +42,13 @@ trait OverrideFunctionRegistry extends FunctionRegistry { functionBuilders.put(name, builder) } - abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children)) + abstract override def lookupFunction( + name: 
String, + children: Seq[Expression], + distinct: Boolean = false): Expression = { + functionBuilders.get(name) + .map(_(children)) + .getOrElse(super.lookupFunction(name, children, distinct)) } } @@ -51,7 +59,10 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr functionBuilders.put(name, builder) } - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + override def lookupFunction( + name: String, + children: Seq[Expression], + distinct: Boolean = false): Expression = { functionBuilders(name)(children) } } @@ -65,7 +76,10 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + def lookupFunction( + name: String, + children: Seq[Expression], + distinct: Boolean = false): Expression = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 34ef7d28cc7f2..0dd5815a518f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -270,10 +270,10 @@ trait HiveTypeCoercion { case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) => i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) - case Sum(e) if e.dataType == StringType => - Sum(Cast(e, DoubleType)) - case Average(e) if e.dataType == StringType => - Average(Cast(e, DoubleType)) + case Sum(e, distinct) if e.dataType == StringType => + Sum(Cast(e, DoubleType), distinct) + case Average(e, distinct) if e.dataType == StringType => + Average(Cast(e, DoubleType), distinct) case Sqrt(e) if e.dataType == StringType => Sqrt(Cast(e, DoubleType)) } @@ -484,25 +484,21 @@ trait HiveTypeCoercion { children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. - case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. - case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType)) - case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType)) - - case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest. - case SumDistinct(e @ IntegralType()) if e.dataType != LongType => - SumDistinct(Cast(e, LongType)) - case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType => - SumDistinct(Cast(e, DoubleType)) - - case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest. - case Average(e @ IntegralType()) if e.dataType != LongType => - Average(Cast(e, LongType)) - case Average(e @ FractionalType()) if e.dataType != DoubleType => - Average(Cast(e, DoubleType)) + case s @ Sum(e @ DecimalType(), _) => s // Decimal is already the biggest. + case Sum(e @ IntegralType(), distinct) if e.dataType != LongType => + Sum(Cast(e, LongType), distinct) + case Sum(e @ FractionalType(), distinct) if e.dataType != DoubleType => + Sum(Cast(e, DoubleType), distinct) + + case s @ Average(e @ DecimalType(), _) => s // Decimal is already the biggest. 
+ case Average(e @ IntegralType(), distinct) if e.dataType != LongType => + Average(Cast(e, LongType), distinct) + case Average(e @ FractionalType(), distinct) if e.dataType != DoubleType => + Average(Cast(e, DoubleType), distinct) // Hive lets you do aggregation of timestamps... for some reason - case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) - case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + case Sum(e @ TimestampType(), distinct) => Sum(Cast(e, DoubleType), distinct) + case Average(e @ TimestampType(), distinct) => Average(Cast(e, DoubleType), distinct) // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 300e9ba187bc5..91b9510a886ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -68,7 +68,10 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def toString: String = s"'$name" } -case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { +case class UnresolvedFunction( + name: String, + children: Seq[Expression], + distinct: Boolean = false) extends Expression { override def dataType: DataType = throw new UnresolvedException(this, "dataType") override def foldable: Boolean = throw new UnresolvedException(this, "foldable") override def nullable: Boolean = throw new UnresolvedException(this, "nullable") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 145f062dd6817..4ccccffbd9b52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -145,11 +145,12 @@ package object dsl { } def sum(e: Expression): Expression = Sum(e) - def sumDistinct(e: Expression): Expression = SumDistinct(e) + def sumDistinct(e: Expression): Expression = Sum(e, true) def count(e: Expression): Expression = Count(e) - def countDistinct(e: Expression*): Expression = CountDistinct(e) - def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression = - ApproxCountDistinct(e, rsd) + def countDistinct(e: Expression) = CountDistinct(e :: Nil) + def countDistinct(e: Expression*) = CountDistinct(e) + // TODO we don't support approximate, will convert it into Count + def approxCountDistinct(e: Expression, rsd: Double = 0.05) = CountDistinct(e :: Nil) def avg(e: Expression): Expression = Average(e) def first(e: Expression): Expression = First(e) def last(e: Expression): Expression = Last(e) @@ -161,6 +162,7 @@ package object dsl { def abs(e: Expression): Expression = Abs(e) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } + // TODO more implicit class for literal? 
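The hunks above thread a DISTINCT flag through the parser, the analyzer, the function registry, the dsl, and the type-coercion rules, replacing the dedicated SumDistinct node with a constructor flag. A minimal, self-contained sketch of the idea, using toy Expr/Col/Sum/CountDistinct stand-ins for the catalyst classes (all names here are hypothetical, not the real API):

```scala
// Toy mirror of the patch's approach: DISTINCT becomes a constructor flag on the
// aggregate node instead of a parallel SumDistinct class hierarchy.
sealed trait Expr
case class Col(name: String) extends Expr
case class Sum(child: Expr, distinct: Boolean = false) extends Expr
case class CountDistinct(children: Seq[Expr]) extends Expr

object DslDemo extends App {
  // What the rewritten dsl helpers now produce:
  def sumDistinct(e: Expr): Expr = Sum(e, distinct = true)  // was SumDistinct(e)
  def countDistinct(es: Expr*): Expr = CountDistinct(es)

  println(sumDistinct(Col("x")))              // Sum(Col(x),true)
  println(countDistinct(Col("x"), Col("y")))  // CountDistinct over both columns
}
```

Making DISTINCT data on the expression rather than a separate class means rules such as HiveTypeCoercion above only need one case per function, propagating the flag unchanged.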
implicit class DslString(val s: String) extends ImplicitOperators { override def expr: Expression = Literal(s) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 30da4faa3f1c6..3892c89053902 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -17,285 +17,159 @@ package org.apache.spark.sql.catalyst.expressions -import com.clearspring.analytics.stream.cardinality.HyperLogLog - import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.trees -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.util.collection.OpenHashSet -abstract class AggregateExpression extends Expression { - self: Product => - /** - * Creates a new instance that can be used to compute this aggregate expression for a group - * of input rows/ - */ - def newInstance(): AggregateFunction - /** - * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are - * replaced with a physical aggregate operator at runtime. - */ - override def eval(input: Row = null): EvaluatedType = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") -} +/** + * This mirrors org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.Mode. + * It is just a hint to UDAF developers about which stage is about to be processed. + * However, we probably don't want developers to depend on these details; it is kept + * only for consistency with Hive (when integrated with Hive), and we need to figure + * out whether we can work around it soon. + */ +@deprecated +trait Mode /** - * Represents an aggregation that has been rewritten to be performed in two steps. - * - * @param finalEvaluation an aggregate expression that evaluates to same final result as the - * original aggregation. - * @param partialEvaluations A sequence of [[NamedExpression]]s that can be computed on partial - * data sets and are required to compute the `finalEvaluation`. + * PARTIAL1: from original data to partial aggregation data: iterate() and + * terminatePartial() will be called. */ -case class SplitEvaluation( - finalEvaluation: Expression, - partialEvaluations: Seq[NamedExpression]) +@deprecated +case object PARTIAL1 extends Mode /** - * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples. - * These partial evaluations can then be combined to compute the actual answer. + * PARTIAL2: from partial aggregation data to partial aggregation data: + * merge() and terminatePartial() will be called. */ -abstract class PartialAggregate extends AggregateExpression { - self: Product => +@deprecated +case object PARTIAL2 extends Mode +/** + * FINAL: from partial aggregation to full aggregation: merge() and + * terminate() will be called. + */ +@deprecated +case object FINAL extends Mode +/** + * COMPLETE: from original data directly to full aggregation: iterate() and + * terminate() will be called. + */ +@deprecated +case object COMPLETE extends Mode - /** - * Returns a [[SplitEvaluation]] that computes this aggregation using partial aggregation. - */ - def asPartial: SplitEvaluation -} /** - * A specific implementation of an aggregate function. Used to wrap a generic - [[AggregateExpression]] with an algorithm that will be used to compute one specific result. + * The aggregation function interface. + * All of these functions will be called within Spark executors. */ -abstract class AggregateFunction - extends AggregateExpression with Serializable with trees.LeafNode[Expression] { +trait AggregateFunction { self: Product => - override type EvaluatedType = Any - - /** Base should return the generic aggregate expression that this function is computing */ - val base: AggregateExpression - - override def nullable: Boolean = base.nullable - override def dataType: DataType = base.dataType + // Specify the BoundReferences for the aggregation buffer + def initialBoundReference(buffers: Seq[BoundReference]): Unit - def update(input: Row): Unit + // Initialize (reinitialize) the aggregation buffer + def reset(buf: MutableRow): Unit - // Do we really need this? - override def newInstance(): AggregateFunction = { - makeCopy(productIterator.map { case a: AnyRef => a }.toArray) - } -} + // The aggregate function is expected to fill the aggregation buffer as it + // is fed each value in the group + def iterate(arguments: Any, buf: MutableRow): Unit -case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + // Merge two aggregation buffers, writing the result back into the latter one + def merge(value: Row, buf: MutableRow): Unit - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"MIN($child)" - - override def asPartial: SplitEvaluation = { - val partialMin = Alias(Min(child), "PartialMin")() - SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil) - } - - override def newInstance(): MinFunction = new MinFunction(child, this) -} - -case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. - - val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = GreaterThan(currentMin, expr) - - override def update(input: Row): Unit = { - if (currentMin.value == null) { - currentMin.value = expr.eval(input) - } else if(cmp.eval(input) == true) { - currentMin.value = expr.eval(input) - } - } - - override def eval(input: Row): Any = currentMin.value -} - -case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"MAX($child)" - - override def asPartial: SplitEvaluation = { - val partialMax = Alias(Max(child), "PartialMax")() - SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil) - } - - override def newInstance(): MaxFunction = new MaxFunction(child, this) -} - -case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization.
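The new AggregateFunction trait above replaces the old update/eval pair with an explicit buffer lifecycle: initialBoundReference binds buffer slots on the executor, reset clears a group's buffer, iterate folds each input value in, merge combines two partial buffers, and terminate produces the final result. A toy sketch of that lifecycle under stated assumptions, with Array[Any] standing in for MutableRow and a plain Int slot standing in for BoundReference (ToyAggregate/ToySum are hypothetical names, not the patch's classes):

```scala
// Stripped-down reset/iterate/merge/terminate lifecycle.
trait ToyAggregate {
  def reset(buf: Array[Any]): Unit                       // (re)initialize the buffer slot
  def iterate(arg: Any, buf: Array[Any]): Unit           // consume one input value
  def merge(partial: Array[Any], buf: Array[Any]): Unit  // combine two buffers
  def terminate(buf: Array[Any]): Any                    // produce the final result
}

class ToySum(slot: Int) extends ToyAggregate {
  def reset(buf: Array[Any]): Unit = buf(slot) = null
  def iterate(arg: Any, buf: Array[Any]): Unit = if (arg != null) {
    buf(slot) =
      if (buf(slot) == null) arg
      else buf(slot).asInstanceOf[Long] + arg.asInstanceOf[Long]
  }
  def merge(partial: Array[Any], buf: Array[Any]): Unit = iterate(partial(slot), buf)
  def terminate(buf: Array[Any]): Any = buf(slot)
}

object LifecycleDemo extends App {
  val sum = new ToySum(0)
  // one buffer per map-side "partition", then one reduce-side merge
  val partials = Seq(Seq[Any](1L, null, 2L), Seq[Any](3L)).map { part =>
    val buf = new Array[Any](1)
    sum.reset(buf)
    part.foreach(sum.iterate(_, buf))
    buf
  }
  val finalBuf = new Array[Any](1)
  sum.reset(finalBuf)
  partials.foreach(sum.merge(_, finalBuf))
  println(sum.terminate(finalBuf)) // 6
}
```

Running it prints 6: two map-side buffers are built with reset/iterate and then combined with merge, which is exactly the two-phase shape the Hive-style Mode constants above describe.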
- - val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType) - val cmp = LessThan(currentMax, expr) - - override def update(input: Row): Unit = { - if (currentMax.value == null) { - currentMax.value = expr.eval(input) - } else if(cmp.eval(input) == true) { - currentMax.value = expr.eval(input) - } - } + // Semantically we probably don't need this, however, we need it when + // integrating with Hive UDAF(GenericUDAF) + @deprecated + def terminatePartial(buf: MutableRow): Unit = {} - override def eval(input: Row): Any = currentMax.value + // Output the final result by feeding the aggregation buffer + def terminate(input: Row): Any } -case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { +trait AggregateExpression extends Expression with AggregateFunction { + self: Product => + type EvaluatedType = Any - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"COUNT($child)" + var mode: Mode = COMPLETE - override def asPartial: SplitEvaluation = { - val partialCount = Alias(Count(child), "PartialCount")() - SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) + def initial(m: Mode): Unit = { + this.mode = m } - override def newInstance(): CountFunction = new CountFunction(child, this) -} - -case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { - def this() = this(null) + // Aggregation Buffer data types + def bufferDataType: Seq[DataType] = Nil + // Is it a distinct aggregate expression? + def distinct: Boolean + // Is it a distinct like aggregate expression (e.g. Min/Max is distinctLike, while avg is not) + def distinctLike: Boolean = false - override def children: Seq[Expression] = expressions + def nullable = true - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"COUNT(DISTINCT ${expressions.mkString(",")})" - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(expressions, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(expressions), "partialSets")() - SplitEvaluation( - CombineSetsAndCount(partialSet.toAttribute), - partialSet :: Nil) - } + override def eval(input: Row): EvaluatedType = children.map(_.eval(input)) } -case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { - def this() = this(null) +abstract class UnaryAggregateExpression extends UnaryExpression with AggregateExpression { + self: Product => - override def children: Seq[Expression] = expressions - override def nullable: Boolean = false - override def dataType: ArrayType = ArrayType(expressions.head.dataType) - override def toString: String = s"AddToHashSet(${expressions.mkString(",")})" - override def newInstance(): CollectHashSetFunction = new CollectHashSetFunction(expressions, this) + override def eval(input: Row): EvaluatedType = child.eval(input) } -case class CollectHashSetFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { +case class Min( + child: Expression) + extends UnaryAggregateExpression { - def this() = this(null, null) // Required for serialization. 
+ override def distinct: Boolean = false + override def distinctLike: Boolean = true + override def dataType = child.dataType + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"MIN($child)" - val seen = new OpenHashSet[Any]() + /* The below code will be called in executors, be sure to make the instance transientable */ + @transient var arg: MutableLiteral = _ + @transient var buffer: MutableLiteral = _ + @transient var cmp: LessThan = _ + @transient var aggr: BoundReference = _ - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: Row): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) - } + /* Initialization on executors */ + override def initialBoundReference(buffers: Seq[BoundReference]): Unit = { + aggr = buffers(0) + arg = MutableLiteral(null, dataType) + buffer = MutableLiteral(null, dataType) + cmp = LessThan(arg, buffer) } - override def eval(input: Row): Any = { - seen - } -} - -case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { - def this() = this(null) - - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = false - override def dataType: DataType = LongType - override def toString: String = s"CombineAndCount($inputSet)" - override def newInstance(): CombineSetsAndCountFunction = { - new CombineSetsAndCountFunction(inputSet, this) + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null } -} - -case class CombineSetsAndCountFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: Row): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + arg.value = argument + buffer.value = buf(aggr) + if (buf.isNullAt(aggr) || cmp.eval(null) == true) { + buf(aggr) = argument + } } } - override def eval(input: Row): Any = seen.size.toLong -} - -case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) - extends AggregateExpression with trees.UnaryNode[Expression] { - - override def nullable: Boolean = false - override def dataType: DataType = child.dataType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctPartitionFunction = { - new ApproxCountDistinctPartitionFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) - extends AggregateExpression with trees.UnaryNode[Expression] { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - override def newInstance(): ApproxCountDistinctMergeFunction = { - new ApproxCountDistinctMergeFunction(child, this, relativeSD) - } -} - -case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) - extends PartialAggregate with trees.UnaryNode[Expression] { - - override def nullable: Boolean = false - override def dataType: LongType.type = LongType - override def toString: String = s"APPROXIMATE COUNT(DISTINCT $child)" - - override def asPartial: 
SplitEvaluation = { - val partialCount = - Alias(ApproxCountDistinctPartition(child, relativeSD), "PartialApproxCountDistinct")() - - SplitEvaluation( - ApproxCountDistinctMerge(partialCount.toAttribute, relativeSD), - partialCount :: Nil) + override def merge(value: Row, rowBuf: MutableRow): Unit = { + if (!value.isNullAt(aggr)) { + arg.value = value(aggr) + buffer.value = rowBuf(aggr) + if (rowBuf.isNullAt(aggr) || cmp.eval(null) == true) { + rowBuf(aggr) = arg.value + } + } } - override def newInstance(): CountDistinctFunction = new CountDistinctFunction(child :: Nil, this) + override def terminate(row: Row): Any = aggr.eval(row) } -case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - - override def nullable: Boolean = true +case class Average(child: Expression, distinct: Boolean = false) + extends UnaryAggregateExpression { + override def nullable = false - override def dataType: DataType = child.dataType match { + override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType(precision + 4, scale + 4) // Add 4 digits after decimal point, like Hive case DecimalType.Unlimited => @@ -304,426 +178,342 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN DoubleType } - override def toString: String = s"AVG($child)" + override def bufferDataType: Seq[DataType] = LongType :: dataType :: Nil + override def toString = s"AVG($child)" - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - // Turn the child to unlimited decimals for calculation, before going back to fixed - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var count: BoundReference = _ + @transient var sum: BoundReference = _ - val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) - val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited) - SplitEvaluation( - Cast(Divide(castedSum, castedCount), dataType), - partialCount :: partialSum :: Nil) + // for iterate + @transient var arg: MutableLiteral = _ + @transient var cast: Expression = _ + @transient var add: Add = _ - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() + // for merge + @transient var argInMerge: MutableLiteral = _ + @transient var addInMerge: Add = _ - val castedSum = Cast(Sum(partialSum.toAttribute), dataType) - val castedCount = Cast(Sum(partialCount.toAttribute), dataType) - SplitEvaluation( - Divide(castedSum, castedCount), - partialCount :: partialSum :: Nil) - } - } + // for terminate + @transient var divide: Divide = _ - override def newInstance(): AverageFunction = new AverageFunction(child, this) -} + /* Initialization on executors */ + override def initialBoundReference(buffers: Seq[BoundReference]): Unit = { + count = buffers(0) + sum = buffers(1) -case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + arg = MutableLiteral(null, child.dataType) + cast = if (arg.dataType != dataType) Cast(arg, dataType) else arg + add = Add(cast, sum) - override def nullable: Boolean = true + argInMerge = MutableLiteral(null, dataType) + addInMerge = Add(argInMerge, sum) - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, 
scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType + divide = Divide(sum, Cast(count, dataType)) } - override def toString: String = s"SUM($child)" + override def reset(buf: MutableRow): Unit = { + buf(count) = 0L + buf(sum) = null + } - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() - SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), - partialSum :: Nil) + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + arg.value = argument + buf(count) = buf.getLong(count) + 1 + if (buf.isNullAt(sum)) { + buf(sum) = cast.eval() + } else { + buf(sum) = add.eval(buf) + } + } + } - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - CombineSum(partialSum.toAttribute), - partialSum :: Nil) + override def merge(value: Row, buf: MutableRow): Unit = { + if (!value.isNullAt(sum)) { + buf(count) = value.getLong(count) + buf.getLong(count) + if (buf.isNullAt(sum)) { + buf(sum) = value(sum) + } else { + argInMerge.value = value(sum) + buf(sum) = addInMerge.eval(buf) + } } } - override def newInstance(): SumFunction = new SumFunction(child, this) + override def terminate(row: Row): Any = if (count.eval(row) == 0) null else divide.eval(row) } -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} +case class Max(child: Expression) + extends UnaryAggregateExpression { + override def distinct: Boolean = false + override def distinctLike: Boolean = true -case class SumDistinct(child: Expression) - extends PartialAggregate with trees.UnaryNode[Expression] { + override def nullable = true + override def dataType = child.dataType + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"MAX($child)" - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - override def toString: String = s"SUM(DISTINCT $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + @transient var arg: MutableLiteral = _ + @transient var 
buffer: MutableLiteral = _ + @transient var cmp: GreaterThan = _ - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + arg = MutableLiteral(null, dataType) + buffer = MutableLiteral(null, dataType) + cmp = GreaterThan(arg, buffer) } -} - -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { - def this() = this(null, null) - override def children: Seq[Expression] = inputSet :: Nil - override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { - - def this() = this(null, null) // Required for serialization. - val seen = new OpenHashSet[Any]() - - override def update(input: Row): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + arg.value = argument + buffer.value = buf(aggr) + if (buf.isNullAt(aggr) || cmp.eval(null) == true) { + buf(aggr) = argument + } } } - override def eval(input: Row): Any = { - val casted = seen.asInstanceOf[OpenHashSet[Row]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) + override def merge(value: Row, rowBuf: MutableRow): Unit = { + if (!value.isNullAt(aggr)) { + arg.value = value(aggr) + buffer.value = rowBuf(aggr) + if (rowBuf.isNullAt(aggr) || cmp.eval(null) == true) { + rowBuf(aggr) = arg.value + } } } -} -case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, this) -} - -case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, this) + override def terminate(row: Row): Any = aggr.eval(row) } -case class AverageFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class Count(child: Expression) + extends 
UnaryAggregateExpression { + def distinct: Boolean = false + override def nullable = false + override def dataType = LongType + override def bufferDataType: Seq[DataType] = LongType :: Nil + override def toString = s"COUNT($child)" - def this() = this(null, null) // Required for serialization. + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ - private val calcType = - expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - - private var count: Long = _ - private val sum = MutableLiteral(zero.eval(null), calcType) + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + } - private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType)) + override def reset(buf: MutableRow): Unit = { + buf(aggr) = 0L + } - override def eval(input: Row): Any = { - if (count == 0L) { - null - } else { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(Divide( - Cast(sum, DecimalType.Unlimited), - Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null) - case _ => - Divide( - Cast(sum, dataType), - Cast(Literal(count), dataType)).eval(null) + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + if (buf.isNullAt(aggr)) { + buf(aggr) = 1L + } else { + buf(aggr) = buf.getLong(aggr) + 1L } } } - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1 - sum.update(addFunction(evaluatedExpr), input) + override def merge(value: Row, rowBuf: MutableRow): Unit = { + if (value.isNullAt(aggr)) { + // do nothing + } else if (rowBuf.isNullAt(aggr)) { + rowBuf(aggr) = value(aggr) + } else { + rowBuf(aggr) = value.getLong(aggr) + rowBuf.getLong(aggr) } } -} -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. + override def terminate(row: Row): Any = aggr.eval(row) +} - var count: Long = _ +case class CountDistinct(children: Seq[Expression]) + extends AggregateExpression { + def distinct: Boolean = true + override def nullable = false + override def dataType = LongType + override def toString = s"COUNT($children)" + override def bufferDataType: Seq[DataType] = LongType :: Nil - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) } - override def eval(input: Row): Any = count -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. 
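Count and CountDistinct above share the same Long counter buffer; the difference is the null rule. COUNT(x) skips a row when x is null, while COUNT(DISTINCT a, b, ...) skips a row when any argument is null, and its iterate only counts, so deduplication of the DISTINCT inputs is expected to happen in the Aggregate operator before iterate is called. A small, self-contained illustration of that rule (toy data, not the real operator):

```scala
// Null handling for COUNT vs COUNT(DISTINCT ...), mirrored on a plain Seq.
object CountNullDemo extends App {
  // rows of (a, b); nulls only in b for simplicity
  val rows = Seq(("1", "a"), ("2", null), ("2", null), ("3", "b"))

  val count = rows.count(_._2 != null)                 // COUNT(b)            = 2
  val distinctCount = rows
    .filter { case (a, b) => a != null && b != null }  // drop rows with ANY null arg
    .distinct
    .size                                              // COUNT(DISTINCT a, b) = 2

  println((count, distinctCount))
}
```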
- - private val hyperLogLog = new HyperLogLog(relativeSD) + override def reset(buf: MutableRow): Unit = { + buf(aggr) = 0L + } - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (!argument.asInstanceOf[Seq[_]].exists(_ == null)) { + // CountDistinct supports multiple expression, and ONLY IF + // none of its expressions value equals null + if (buf.isNullAt(aggr)) { + buf(aggr) = 1L + } else { + buf(aggr) = buf.getLong(aggr) + 1L + } } } - override def eval(input: Row): Any = hyperLogLog -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) + override def merge(value: Row, rowBuf: MutableRow): Unit = { + if (value.isNullAt(aggr)) { + // do nothing + } else if (rowBuf.isNullAt(aggr)) { + rowBuf(aggr) = value(aggr) + } else { + rowBuf(aggr) = value.getLong(aggr) + rowBuf.getLong(aggr) + } } - override def eval(input: Row): Any = hyperLogLog.cardinality() + override def terminate(row: Row): Any = aggr.eval(row) } -case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) + /** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class Sum(child: Expression, distinct: Boolean = false) + extends UnaryAggregateExpression { + override def nullable = true + override def dataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } - private val sum = MutableLiteral(null, calcType) + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"SUM($child)" - private val addFunction = - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ + @transient var arg: MutableLiteral = _ + @transient var sum: Add = _ - override def update(input: Row): Unit = { - sum.update(addFunction, input) + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + arg = MutableLiteral(null, dataType) + sum = Add(arg, aggr) } - override def eval(input: Row): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - 
Cast(sum, dataType).eval(null) - case _ => sum.eval(null) - } + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null } -} - -case class CombineSumFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { - - def this() = this(null, null) // Required for serialization. - - private val calcType = - expr.dataType match { - case DecimalType.Fixed(_, _) => - DecimalType.Unlimited - case _ => - expr.dataType - } - - private val zero = Cast(Literal(0), calcType) - private val sum = MutableLiteral(null, calcType) - - private val addFunction = - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero)) - - override def update(input: Row): Unit = { - val result = expr.eval(input) - // partial sum result can be null only when no input rows present - if(result != null) { - sum.update(addFunction, input) + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + if (buf.isNullAt(aggr)) { + buf(aggr) = argument + } else { + arg.value = argument + buf(aggr) = sum.eval(buf) + } } } - override def eval(input: Row): Any = { - expr.dataType match { - case DecimalType.Fixed(_, _) => - Cast(sum, dataType).eval(null) - case _ => sum.eval(null) + override def merge(value: Row, buf: MutableRow): Unit = { + if (!value.isNullAt(aggr)) { + arg.value = value(aggr) + if (buf.isNullAt(aggr)) { + buf(aggr) = arg.value + } else { + buf(aggr) = sum.eval(buf) + } } } + + override def terminate(row: Row): Any = aggr.eval(row) } -case class SumDistinctFunction(expr: Expression, base: AggregateExpression) - extends AggregateFunction { +case class First(child: Expression, distinct: Boolean = false) + extends UnaryAggregateExpression { + override def nullable = true + override def dataType = child.dataType + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"FIRST($child)" - def this() = this(null, null) // Required for serialization. + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ - private val seen = new scala.collection.mutable.HashSet[Any]() + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) + } - override def update(input: Row): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - seen += evaluatedExpr - } + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null } - override def eval(input: Row): Any = { - if (seen.size == 0) { - null - } else { - Cast(Literal( - seen.reduceLeft( - dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - dataType).eval(null) + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (buf.isNullAt(aggr)) { + if (argument != null) { + buf(aggr) = argument + } } } -} - -case class CountDistinctFunction( - @transient expr: Seq[Expression], - @transient base: AggregateExpression) - extends AggregateFunction { - - def this() = this(null, null) // Required for serialization. 
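The Sum implementation above resets its buffer to null, skips null arguments in iterate, and merges only non-null partials; the quoted three-case comment is about distinguishing "no data" from "data equals null" when combining partial results. A quick check of the resulting per-group behavior, as a plain-Scala sketch (sqlSum is a hypothetical helper, not the patch's code):

```scala
// Per-group Sum semantics implied by the iterate/merge logic above:
// nulls are skipped, and a group with no non-null input yields null.
object SumNullDemo extends App {
  def sqlSum(xs: Seq[java.lang.Long]): Any = {
    val nonNull = xs.filter(_ ne null)
    if (nonNull.isEmpty) null else nonNull.map(_.longValue).sum
  }

  println(sqlSum(Seq(1L, null, 2L)))  // 3    (null skipped)
  println(sqlSum(Seq(null, null)))    // null (buffer never written)
  println(sqlSum(Seq.empty))          // null (no data)
}
```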
- val seen = new OpenHashSet[Any]() - - @transient - val distinctValue = new InterpretedProjection(expr) - - override def update(input: Row): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) + override def merge(value: Row, buf: MutableRow): Unit = { + if (buf.isNullAt(aggr)) { + if (!value.isNullAt(aggr)) { + buf(aggr) = value(aggr) + } } } - override def eval(input: Row): Any = seen.size.toLong + override def terminate(row: Row): Any = aggr.eval(row) } -case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. +case class Last(child: Expression, distinct: Boolean = false) + extends UnaryAggregateExpression { + override def nullable = true + override def dataType = child.dataType + override def bufferDataType: Seq[DataType] = dataType :: Nil + override def toString = s"LAST($child)" - var result: Any = null + /* The below code will be called in executors, be sure to mark the instance as transient */ + @transient var aggr: BoundReference = _ - override def update(input: Row): Unit = { - if (result == null) { - result = expr.eval(input) - } + override def initialBoundReference(buffers: Seq[BoundReference]) = { + aggr = buffers(0) } - override def eval(input: Row): Any = result -} - -case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. - - var result: Any = null + override def reset(buf: MutableRow): Unit = { + buf(aggr) = null + } - override def update(input: Row): Unit = { - result = input + override def iterate(argument: Any, buf: MutableRow): Unit = { + if (argument != null) { + buf(aggr) = argument + } } - override def eval(input: Row): Any = { - if (result != null) expr.eval(result.asInstanceOf[Row]) else null + override def merge(value: Row, buf: MutableRow): Unit = { + if (!value.isNullAt(aggr)) { + buf(aggr) = value(aggr) + } } + + override def terminate(row: Row): Any = aggr.eval(row) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index fbc97b2e75312..bc94a99e3bf1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -77,4 +77,6 @@ package object expressions { /** Uses the given row to store the output of the projection. */ def target(row: MutableRow): MutableProjection } + + type AggrBuffer = GenericMutableRow } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index a8983df208318..fd8d1ae80e3e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.types.{StructType, NativeType} - /** * An extended interface to [[Row]] that allows the values for each column to be updated. Setting * a value through a primitive function implicitly marks that column as not null. 
@@ -28,7 +27,6 @@ trait MutableRow extends Row { def setNullAt(i: Int): Unit def update(ordinal: Int, value: Any) - def setInt(ordinal: Int, value: Int) def setLong(ordinal: Int, value: Long) def setDouble(ordinal: Int, value: Double) @@ -37,6 +35,21 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) + + final def setByte(ordinal: Int, value: Int) { setByte(ordinal, value.asInstanceOf[Byte]) } + final def setShort(ordinal: Int, value: Int) { setByte(ordinal, value.asInstanceOf[Short]) } + + final def update(bound: BoundReference, value: Any) { update(bound.ordinal, value) } + final def setInt(bound: BoundReference, value: Int) { setInt(bound.ordinal, value) } + final def setLong(bound: BoundReference, value: Long) { setLong(bound.ordinal, value) } + final def setDouble(bound: BoundReference, value: Double) { setDouble(bound.ordinal, value) } + final def setBoolean(bound: BoundReference, value: Boolean) { setBoolean(bound.ordinal, value) } + final def setShort(bound: BoundReference, value: Int) { setShort(bound.ordinal, value) } + final def setShort(bound: BoundReference, value: Short) { setShort(bound.ordinal, value) } + final def setByte(bound: BoundReference, value: Byte) { setByte(bound.ordinal, value) } + final def setByte(bound: BoundReference, value: Int) { setByte(bound.ordinal, value) } + final def setFloat(bound: BoundReference, value: Float) { setFloat(bound.ordinal, value) } + final def setString(bound: BoundReference, value: String) { setString(bound.ordinal, value) } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c23d3b61887c6..255b4b03c3104 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -642,12 +642,14 @@ object DecimalAggregates extends Rule[LogicalPlan] { val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale) + case Sum(e @ DecimalType.Expression(prec, scale), distinct) + if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(Sum(UnscaledValue(e), distinct), prec + 10, scale) - case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + case Average(e @ DecimalType.Expression(prec, scale), distinct) + if prec + 4 <= MAX_DOUBLE_DIGITS => Cast( - Divide(Average(UnscaledValue(e)), Literal(math.pow(10.0, scale), DoubleType)), + Divide(Average(UnscaledValue(e), distinct), Literal(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c8c643f7d17a..234e126b8f03b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -127,55 +127,43 @@ object PartialAggregation { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case logical.Aggregate(groupingExpressions, aggregateExpressions, child) => - // Collect all aggregate expressions. 
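The Row and MutableRow overloads added in this patch are pure forwarding sugar: each accessor or setter keyed by an Int ordinal gains a twin keyed by a BoundReference, so aggregate implementations can address buffer slots symbolically instead of by raw position. A hypothetical stand-alone mirror of the pattern (ToyBound/ToyRow are stand-ins for BoundReference/MutableRow, not the real classes):

```scala
// Forwarding sugar: the same slot operations, keyed by a bound-reference handle.
case class ToyBound(ordinal: Int)

class ToyRow(values: Array[Any]) {
  def isNullAt(i: Int): Boolean = values(i) == null
  def getLong(i: Int): Long = values(i).asInstanceOf[Long]
  def update(i: Int, v: Any): Unit = values(i) = v
  // the sugar: each Int-keyed method gains a ToyBound-keyed twin
  @inline final def isNullAt(b: ToyBound): Boolean = isNullAt(b.ordinal)
  @inline final def getLong(b: ToyBound): Long = getLong(b.ordinal)
  @inline final def update(b: ToyBound, v: Any): Unit = update(b.ordinal, v)
}

object SugarDemo extends App {
  val countSlot = ToyBound(0)
  val buf = new ToyRow(Array[Any](null))
  if (buf.isNullAt(countSlot)) buf(countSlot) = 0L  // buf(slot) = v calls update
  buf(countSlot) = buf.getLong(countSlot) + 1L
  println(buf.getLong(countSlot)) // 1
}
```

Because the wrappers are final and @inline, they should compile down to the ordinal-based calls, so the symbolic style costs nothing per row.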
+ // Collect all aggregate expressions that can be computed partially. val allAggregates = aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a}) - // Collect all aggregate expressions that can be computed partially. - val partialAggregates = - aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p}) // Only do partial aggregation if supported by all aggregate expressions. - if (allAggregates.size == partialAggregates.size) { - // Create a map of expressions to their partial evaluations for all aggregate expressions. - val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] = - partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap - + if (!allAggregates.exists(_.distinct)) { // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = - groupingExpressions.filter(!_.isInstanceOf[Literal]).map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap + val namedGroupingExpressions = groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + } + val substitutions = namedGroupingExpressions.toMap // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp { - case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) => - partialEvaluations(new TreeNodeRef(e)).finalEvaluation - + case e: Expression if substitutions.contains(e) => + substitutions(e).toAttribute case e: Expression => // Should trim aliases around `GetField`s. These aliases are introduced while // resolving struct field accesses, because `GetField` is not a `NamedExpression`. // (Should we just turn `GetField` into a `NamedExpression`?) - namedGroupingExpressions + substitutions .get(e.transform { case Alias(g: GetField, _) => g }) .map(_.toAttribute) .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] - val partialComputation = - (namedGroupingExpressions.values ++ - partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq - - val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq + val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) Some( (namedGroupingAttributes, rewrittenAggregateExpressions, groupingExpressions, - partialComputation, + aggregateExpressions, child)) } else { None @@ -184,7 +172,6 @@ object PartialAggregation { } } - /** * A pattern that finds joins with equality conditions that can be evaluated using equi-join. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 4c80359cf07af..1867e5139ff3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -783,7 +783,7 @@ class DataFrame private[sql]( // The list of summary statistics to compute, in the form of expressions. 
val statistics = List[(String, Expression => Expression)]( "count" -> Count, - "mean" -> Average, + "mean" -> (Average(_, false)), "stddev" -> stddevExpr, "min" -> Min, "max" -> Max) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45a63ae26ed71..4a1ba64630a35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -69,10 +69,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) private[this] def strToExpr(expr: String): (Expression => Expression) = { expr.toLowerCase match { - case "avg" | "average" | "mean" => Average + case "avg" | "average" | "mean" => Average(_, false) case "max" => Max case "min" => Min - case "sum" => Sum + case "sum" => Sum(_, false) case "count" | "size" => // Turn count(*) into count(1) (inputExpr: Expression) => inputExpr match { @@ -177,7 +177,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def mean(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames:_*)(Average(_, false)) } /** @@ -197,7 +197,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def avg(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Average) + aggregateNumericColumns(colNames:_*)(Average(_, false)) } /** @@ -217,6 +217,6 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { - aggregateNumericColumns(colNames:_*)(Sum) + aggregateNumericColumns(colNames:_*)(Sum(_, false)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 18b1ba4c5c4b9..31ebdb35805ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -17,181 +17,459 @@ package org.apache.spark.sql.execution -import java.util.HashMap +import scala.collection._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD + import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.SQLContext /** - * :: DeveloperApi :: - * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each - * group. + * An aggregate that needs to be computed for each row in a group. * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param child the input data source. + * @param aggregate AggregateExpression, associated with the function + * @param substitution A MutableLiteral used to refer to the result of this aggregate in the final + * output. 
 */
-@DeveloperApi
-case class Aggregate(
-    partial: Boolean,
-    groupingExpressions: Seq[Expression],
-    aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)
-  extends UnaryNode {
-
-  override def requiredChildDistribution: List[Distribution] = {
-    if (partial) {
-      UnspecifiedDistribution :: Nil
-    } else {
-      if (groupingExpressions == Nil) {
-        AllTuples :: Nil
-      } else {
-        ClusteredDistribution(groupingExpressions) :: Nil
-      }
-    }
+sealed case class AggregateFunctionBind(
+    aggregate: AggregateExpression,
+    substitution: MutableLiteral)
+
+sealed class InputBufferSeens(
+    var input: Row,
+    var buffer: MutableRow,
+    var seens: Array[mutable.HashSet[Any]] = null) {
+  def this() {
+    this(new GenericMutableRow(0), null)
   }
 
-  override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
+  def withInput(row: Row): InputBufferSeens = {
+    this.input = row
+    this
+  }
 
-  /**
-   * An aggregate that needs to be computed for each row in a group.
-   *
-   * @param unbound Unbound version of this aggregate, used for result substitution.
-   * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer.
-   * @param resultAttribute An attribute used to refer to the result of this aggregate in the final
-   *                        output.
-   */
-  case class ComputedAggregate(
-      unbound: AggregateExpression,
-      aggregate: AggregateExpression,
-      resultAttribute: AttributeReference)
-
-  /** A list of aggregates that need to be computed for each group. */
-  private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
-    agg.collect {
-      case a: AggregateExpression =>
-        ComputedAggregate(
-          a,
-          BindReferences.bindReference(a, child.output),
-          AttributeReference(s"aggResult:$a", a.dataType, a.nullable)())
+  def withBuffer(row: MutableRow): InputBufferSeens = {
+    this.buffer = row
+    this
+  }
+
+  def withSeens(seens: Array[mutable.HashSet[Any]]): InputBufferSeens = {
+    this.seens = seens
+    this
+  }
+}
+
+sealed trait Aggregate {
+  self: Product =>
+  // HACK: Generators don't correctly preserve their output through serializations so we grab
+  // our child's output attributes statically here.
+  val childOutput = child.output
+  val isGlobalAggregation = groupingExpressions.isEmpty
+
+  def computedAggregates: Array[AggregateExpression] = {
+    boundProjection.flatMap { expr =>
+      expr.collect {
+        case ae: AggregateExpression => ae
+      }
    }
  }.toArray
 
-  /** The schema of the result of all aggregate evaluations */
-  private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
+  // This is a hack: instead of relying on BindReferences for the aggregation buffer
+  // schema in PostShuffle, we use a fixed protocol, represented as the BoundReferences
+  // into the aggregation buffer in PostShuffle.
+  @transient lazy val bufferSchema: Array[AttributeReference] =
+    computedAggregates.zipWithIndex.flatMap { case (ca, idx) =>
+      ca.bufferDataType.zipWithIndex.map { case (dt, i) =>
+        AttributeReference(s"aggr.${idx}_$i", dt)() }
+    }.toArray
+
+  // Tuples of aggregate expressions with supporting information:
+  // (AggregateExpression, aggregate function, placeholder for the AggregateExpression result)
+  @transient lazy val aggregateFunctionBinds: Array[AggregateFunctionBind] = {
+    var pos = 0
+    computedAggregates.map { ae =>
+      ae.initial(mode)
+
-  /** Creates a new aggregate buffer for a group.
*/ - private[this] def newAggregateBuffer(): Array[AggregateFunction] = { - val buffer = new Array[AggregateFunction](computedAggregates.length) - var i = 0 - while (i < computedAggregates.length) { - buffer(i) = computedAggregates(i).aggregate.newInstance() - i += 1 + // we connect all of the aggregation buffers in a single Row, + // and "BIND" the attribute references in a Hack way. + val bufferDataTypes = ae.bufferDataType + ae.initialBoundReference(for (i <- 0 until bufferDataTypes.length) yield { + BoundReference(pos + i, bufferDataTypes(i), true) + }) + pos += bufferDataTypes.length + + AggregateFunctionBind(ae, MutableLiteral(null, ae.dataType)) } - buffer } - /** Named attributes used to substitute grouping attributes into the final result. */ - private[this] val namedGroups = groupingExpressions.map { - case ne: NamedExpression => ne -> ne.toAttribute - case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute + @transient lazy val groupByProjection = if (groupingExpressions.isEmpty) { + InterpretedMutableProjection(Nil) + } else { + new InterpretedMutableProjection(groupingExpressions, childOutput) } + // Indicate which stage we are running into + def mode: Mode + // This is provided by SparkPlan + def child: SparkPlan + // Group By Key Expressions + def groupingExpressions: Seq[Expression] + // Bounded Projection + def boundProjection: Seq[NamedExpression] +} + +sealed trait PreShuffle extends Aggregate { + self: Product => + + def boundProjection: Seq[NamedExpression] = projection.map { + case a: Attribute => // Attribute will be converted into BoundReference + Alias( + BindReferences.bindReference(a: Expression, childOutput), a.name)(a.exprId, a.qualifiers) + case a: NamedExpression => BindReferences.bindReference(a, childOutput) + } + + // The expression list for output, this is the unbound expressions + def projection: Seq[NamedExpression] +} + +sealed trait PostShuffle extends Aggregate { + self: Product => /** - * A map of substitutions that are used to insert the aggregate expressions and grouping - * expression into the final result expression. + * Substituted version of boundProjection expressions which are used to compute final + * output rows given a group and the result of all aggregate computations. 
*/ - private[this] val resultMap = - (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap + @transient lazy val finalExpressions = { + val resultMap = aggregateFunctionBinds.map { ae => ae.aggregate -> ae.substitution }.toMap + boundProjection.map { agg => + agg.transform { + case e: AggregateExpression if resultMap.contains(e) => resultMap(e) + } + } + }.map(e => {BindReferences.bindReference(e: Expression, childOutput)}) + + @transient lazy val finalProjection = new InterpretedMutableProjection(finalExpressions) + + def aggregateFunctionBinds: Array[AggregateFunctionBind] + + def createIterator( + aggregates: Array[AggregateExpression], + iterator: Iterator[InputBufferSeens]) = { + val substitutions = aggregateFunctionBinds.map(_.substitution) + + new Iterator[Row] { + override final def hasNext: Boolean = iterator.hasNext + + override final def next(): Row = { + val keybuffer = iterator.next() + + var idx = 0 + while (idx < aggregates.length) { + // substitute the AggregateExpression value + substitutions(idx).value = aggregates(idx).terminate(keybuffer.buffer) + idx += 1 + } + + finalProjection(keybuffer.input) + } + } + } +} + +/** + * :: DeveloperApi :: + * Groups input data by `groupingExpressions` and computes the `projection` for each + * group. + * + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param projection expressions that are computed for each group. + * @param namedGroupingAttributes the attributes represent the output of the groupby expressions + * @param child the input data source. + */ +@DeveloperApi +case class AggregatePreShuffle( + groupingExpressions: Seq[Expression], + projection: Seq[NamedExpression], + namedGroupingAttributes: Seq[Attribute], + child: SparkPlan) + extends UnaryNode with PreShuffle { + + override def requiredChildDistribution = UnspecifiedDistribution :: Nil + + override def output = bufferSchema.map(_.toAttribute) ++ namedGroupingAttributes + + override def mode: Mode = PARTIAL1 // iterate & terminalPartial will be called /** - * Substituted version of aggregateExpressions expressions which are used to compute final - * output rows given a group and the result of all aggregate computations. + * Create Iterator for the in-memory hash map. 
*/ - private[this] val resultExpressions = aggregateExpressions.map { agg => - agg.transform { - case e: Expression if resultMap.contains(e) => resultMap(e) + private[this] def createIterator( + functions: Array[AggregateExpression], + iterator: Iterator[InputBufferSeens]) = { + new Iterator[Row] { + private[this] val joinedRow = new JoinedRow + + override final def hasNext: Boolean = iterator.hasNext + + override final def next(): Row = { + val keybuffer = iterator.next() + var idx = 0 + while (idx < functions.length) { + functions(idx).terminatePartial(keybuffer.buffer) + idx += 1 + } + + joinedRow(keybuffer.buffer, keybuffer.input).copy() + } } } - override def execute(): RDD[Row] = attachTree(this, "execute") { - if (groupingExpressions.isEmpty) { - child.execute().mapPartitions { iter => - val buffer = newAggregateBuffer() - var currentRow: Row = null + override def execute() = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val aggregates = aggregateFunctionBinds.map(_.aggregate) + + if (groupingExpressions.isEmpty) { + // without group by keys + val buffer = new GenericMutableRow(bufferSchema.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + idx += 1 + } + while (iter.hasNext) { - currentRow = iter.next() - var i = 0 - while (i < buffer.length) { - buffer(i).update(currentRow) - i += 1 + val currentRow = iter.next() + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.iterate(ae.eval(currentRow), buffer) + idx += 1 } } - val resultProjection = new InterpretedProjection(resultExpressions, computedSchema) - val aggregateResults = new GenericMutableRow(computedAggregates.length) - var i = 0 - while (i < buffer.length) { - aggregateResults(i) = buffer(i).eval(EmptyRow) - i += 1 + createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer))) + } else { + val results = new mutable.HashMap[Row, InputBufferSeens]() + while (iter.hasNext) { + val currentRow = iter.next() + + val keys = groupByProjection(currentRow) + results.get(keys) match { + case Some(inputbuffer) => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.iterate(ae.eval(currentRow), inputbuffer.buffer) + idx += 1 + } + case None => + val buffer = new GenericMutableRow(bufferSchema.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + val value = ae.eval(currentRow) + // TODO distinctLike? 
We need to store the "seen" for + // AggregationExpression that distinctLike=true + // This is a trade off between memory & computing + ae.reset(buffer) + ae.iterate(value, buffer) + idx += 1 + } + + val copies = keys.copy() + results.put(copies, new InputBufferSeens(copies, buffer)) + } } - Iterator(resultProjection(aggregateResults)) + createIterator(aggregates, results.valuesIterator) } - } else { - child.execute().mapPartitions { iter => - val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) + } + } +} + +case class AggregatePostShuffle( + groupingExpressions: Seq[Expression], + boundProjection: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with PostShuffle { + + override def output = boundProjection.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + + override def mode: Mode = FINAL // merge & terminate will be called + + override def execute() = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val aggregates = aggregateFunctionBinds.map(_.aggregate) + if (groupingExpressions.isEmpty) { + val buffer = new GenericMutableRow(bufferSchema.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + idx += 1 + } + + while (iter.hasNext) { + val currentRow = iter.next() + + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.merge(currentRow, buffer) + idx += 1 + } + } - var currentRow: Row = null + createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer))) + } else { + val results = new mutable.HashMap[Row, InputBufferSeens]() while (iter.hasNext) { - currentRow = iter.next() - val currentGroup = groupingProjection(currentRow) - var currentBuffer = hashTable.get(currentGroup) - if (currentBuffer == null) { - currentBuffer = newAggregateBuffer() - hashTable.put(currentGroup.copy(), currentBuffer) + val currentRow = iter.next() + val keys = groupByProjection(currentRow) + results.get(keys) match { + case Some(pair) => + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.merge(currentRow, pair.buffer) + idx += 1 + } + case None => + val buffer = new GenericMutableRow(bufferSchema.length) + var idx = 0 + while (idx < aggregates.length) { + val ae = aggregates(idx) + ae.reset(buffer) + ae.merge(currentRow, buffer) + idx += 1 + } + results.put(keys.copy(), new InputBufferSeens(currentRow.copy(), buffer)) } + } + + createIterator(aggregates, results.valuesIterator) + } + } + } +} - var i = 0 - while (i < currentBuffer.length) { - currentBuffer(i).update(currentRow) - i += 1 +// TODO Currently even if only a single DISTINCT exists in the aggregate expressions, we will +// not do partial aggregation (aggregating before shuffling), all of the data have to be shuffled +// to the reduce side and do aggregation directly, this probably causes the performance regression +// for Aggregation Function like CountDistinct etc. 
+case class DistinctAggregate( + groupingExpressions: Seq[Expression], + projection: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with PreShuffle with PostShuffle { + override def output = boundProjection.map(_.toAttribute) + + override def requiredChildDistribution: Seq[Distribution] = if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + + override def mode: Mode = COMPLETE // iterate() & terminate() will be called + + override def execute() = attachTree(this, "execute") { + child.execute().mapPartitions { iter => + val aggregates = aggregateFunctionBinds.map(_.aggregate) + if (groupingExpressions.isEmpty) { + val buffer = new GenericMutableRow(bufferSchema.length) + // TODO save the memory only for those DISTINCT aggregate expressions + val seens = new Array[mutable.HashSet[Any]](aggregateFunctionBinds.length) + + var idx = 0 + while (idx < aggregateFunctionBinds.length) { + val ae = aggregates(idx) + ae.reset(buffer) + + if (ae.distinct) { + seens(idx) = new mutable.HashSet[Any]() } + + idx += 1 } + val ibs = new InputBufferSeens().withBuffer(buffer).withSeens(seens) - new Iterator[Row] { - private[this] val hashTableIter = hashTable.entrySet().iterator() - private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) - private[this] val resultProjection = - new InterpretedMutableProjection( - resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow4 - - override final def hasNext: Boolean = hashTableIter.hasNext - - override final def next(): Row = { - val currentEntry = hashTableIter.next() - val currentGroup = currentEntry.getKey - val currentBuffer = currentEntry.getValue - - var i = 0 - while (i < currentBuffer.length) { - // Evaluating an aggregate buffer returns the result. No row is required since we - // already added all rows in the group using update. 
- aggregateResults(i) = currentBuffer(i).eval(EmptyRow) - i += 1 + while (iter.hasNext) { + val currentRow = iter.next() + + var idx = 0 + while (idx < aggregateFunctionBinds.length) { + val ae = aggregates(idx) + val value = ae.eval(currentRow) + + if (ae.distinct) { + if (!seens(idx).contains(value)) { + ae.iterate(value, buffer) + seens(idx).add(value) + } + } else { + ae.iterate(value, buffer) } - resultProjection(joinedRow(aggregateResults, currentGroup)) + idx += 1 } } + + createIterator(aggregates, Iterator(ibs)) + } else { + val results = new mutable.HashMap[Row, InputBufferSeens]() + + while (iter.hasNext) { + val currentRow = iter.next() + + val keys = groupByProjection(currentRow) + results.get(keys) match { + case Some(inputBufferSeens) => + var idx = 0 + while (idx < aggregateFunctionBinds.length) { + val ae = aggregates(idx) + val value = ae.eval(currentRow) + + if (ae.distinct) { + if (!inputBufferSeens.seens(idx).contains(value)) { + ae.iterate(value, inputBufferSeens.buffer) + inputBufferSeens.seens(idx).add(value) + } + } else { + ae.iterate(value, inputBufferSeens.buffer) + } + idx += 1 + } + case None => + val buffer = new GenericMutableRow(bufferSchema.length) + // TODO save the memory only for those DISTINCT aggregate expressions + val seens = new Array[mutable.HashSet[Any]](aggregateFunctionBinds.length) + + var idx = 0 + while (idx < aggregateFunctionBinds.length) { + val ae = aggregates(idx) + val value = ae.eval(currentRow) + ae.reset(buffer) + ae.iterate(value, buffer) + + if (ae.distinct) { + val seen = new mutable.HashSet[Any]() + if (value != null) { + seen.add(value) + } + seens.update(idx, seen) + } + + idx += 1 + } + results.put(keys.copy(), new InputBufferSeens(currentRow.copy(), buffer, seens)) + } + } + + createIterator(aggregates, results.valuesIterator) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 89682d25ca7dc..94bdb32d2831d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -83,7 +83,7 @@ case class GeneratedAggregate( AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case s @ Sum(expr) => + case s @ Sum(expr, distinct) => val calcType = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -109,7 +109,7 @@ case class GeneratedAggregate( AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case a @ Average(expr) => + case a @ Average(expr, distinct) => val calcType = expr.dataType match { case DecimalType.Fixed(_, _) => @@ -164,29 +164,6 @@ case class GeneratedAggregate( initialValue :: Nil, updateMax :: Nil, currentMax) - - case CollectHashSet(Seq(expr)) => - val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)() - val initialValue = NewSet(expr.dataType) - val addToSet = AddItemToSet(expr, set) - - AggregateEvaluation( - set :: Nil, - initialValue :: Nil, - addToSet :: Nil, - set) - - case CombineSetsAndCount(inputSet) => - val ArrayType(inputType, _) = inputSet.dataType - val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)() - val initialValue = NewSet(inputType) - val collectSets = CombineSets(set, inputSet) - - AggregateEvaluation( - set :: Nil, - initialValue :: Nil, - collectSets :: Nil, - CountSet(set)) } val computationSchema = 
computeFunctions.flatMap(_.schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 2b581152e5f77..5311771b3d319 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -112,55 +112,27 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object HashAggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Aggregations that can be performed in two phases, before and after the shuffle. - - // Cases where all aggregates can be codegened. case PartialAggregation( namedGroupingAttributes, rewrittenAggregateExpressions, groupingExpressions, - partialComputation, - child) - if canBeCodeGened( - allAggregates(partialComputation) ++ - allAggregates(rewrittenAggregateExpressions)) && - codegenEnabled => - execution.GeneratedAggregate( - partial = false, + aggregateExpressions, + child) => + execution.AggregatePostShuffle( namedGroupingAttributes, rewrittenAggregateExpressions, - execution.GeneratedAggregate( - partial = true, + execution.AggregatePreShuffle( groupingExpressions, - partialComputation, + aggregateExpressions, + namedGroupingAttributes, planLater(child))) :: Nil - - // Cases where some aggregate can not be codegened - case PartialAggregation( - namedGroupingAttributes, - rewrittenAggregateExpressions, - groupingExpressions, - partialComputation, - child) => - execution.Aggregate( - partial = false, - namedGroupingAttributes, - rewrittenAggregateExpressions, - execution.Aggregate( - partial = true, - groupingExpressions, - partialComputation, - planLater(child))) :: Nil - case _ => Nil } - def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists { - case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false - // The generated set implementation is pretty limited ATM. 
- case CollectHashSet(exprs) if exprs.size == 1 && - Seq(IntegerType, LongType).contains(exprs.head.dataType) => false - case _ => true - } + def containsDistinct(aggregateExpressions: Seq[NamedExpression]) = + aggregateExpressions.flatMap(_.collect { + case ae: AggregateExpression if ae.distinct => true + }).contains(true) def allAggregates(exprs: Seq[Expression]): Seq[AggregateExpression] = exprs.flatMap(_.collect { case a: AggregateExpression => a }) @@ -281,7 +253,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Expand(projections, output, child) => execution.Expand(projections, output, planLater(child)) :: Nil case logical.Aggregate(group, agg, child) => - execution.Aggregate(partial = false, group, agg, planLater(child)) :: Nil + execution.DistinctAggregate(group, agg, planLater(child)) :: Nil case logical.Sample(fraction, withReplacement, seed, child) => execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil case logical.LocalRelation(output, data) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 111e751588a8b..2c9ef0949520e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -127,7 +127,7 @@ object functions { * * @group agg_funcs */ - def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def sumDistinct(e: Column): Column = Sum(e.expr, true) /** * Aggregate function: returns the sum of distinct values in the expression. @@ -177,7 +177,7 @@ object functions { * * @group agg_funcs */ - def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr) + def approxCountDistinct(e: Column): Column = CountDistinct(e.expr :: Nil) // TODO /** * Aggregate function: returns the approximate number of distinct items in a group. @@ -191,7 +191,7 @@ object functions { * * @group agg_funcs */ - def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd) + def approxCountDistinct(e: Column, rsd: Double): Column = CountDistinct(e.expr :: Nil) // TODO /** * Aggregate function: returns the approximate number of distinct items in a group. 
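The net effect of the strategy rewrite is easiest to see on two small queries. The sketch below is illustrative only: `df` is a hypothetical DataFrame with columns `key` and `value`, the commented plan shapes are paraphrased rather than exact, and `Exchange` stands for the shuffle the planner inserts to satisfy each operator's requiredChildDistribution.

    import org.apache.spark.sql.functions._

    // `df` is a hypothetical DataFrame with columns `key` and `value`.

    // No DISTINCT: PartialAggregation matches and HashAggregation plans two phases,
    // aggregating both before and after the shuffle, roughly:
    //   AggregatePostShuffle(keys, rewrittenAggregateExpressions,
    //     Exchange(AggregatePreShuffle(groupingExpressions, aggregateExpressions, keys, child)))
    df.groupBy("key").agg(sum("value"))

    // With a DISTINCT aggregate (distinct = true on the AggregateExpression),
    // PartialAggregation declines to match, so the default strategy plans a single
    // reduce-side operator, roughly:
    //   DistinctAggregate(keys, aggregateExpressions, Exchange(child))
    df.groupBy("key").agg(sumDistinct("value"))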
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 523be56df65ba..3f3924364124e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -43,24 +43,27 @@ class PlannerSuite extends FunSuite {
   }
 
   test("count is partially aggregated") {
-    val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed
-    val planned = HashAggregation(query).head
-    val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n }
+    val query = testData.groupBy('value).agg(count('key)).queryExecution.sparkPlan
+    val preshuffles = query.collect { case n: AggregatePreShuffle => n }
+    val postshuffles = query.collect { case n: AggregatePostShuffle => n }
 
-    assert(aggregations.size === 2)
+    assert(preshuffles.size === 1)
+    assert(postshuffles.size === 1)
   }
 
   test("count distinct is partially aggregated") {
-    val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+    val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.sparkPlan
+    // TODO currently only reducer-side aggregation is supported for DISTINCT
+    val shuffles = query.collect { case n: DistinctAggregate => n }
+    assert(shuffles.size === 1)
   }
 
   test("mixed aggregates are partially aggregated") {
     val query =
-      testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed
-    val planned = HashAggregation(query)
-    assert(planned.nonEmpty)
+      testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.sparkPlan
+    // TODO currently only reducer-side aggregation is supported for DISTINCT
+    val shuffles = query.collect { case n: DistinctAggregate => n }
+    assert(shuffles.size === 1)
   }
 
   test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 4afa2e71d77cc..cd3361a2466e8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.hive
 
 import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar}
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
 import org.apache.hadoop.hive.serde2.objectinspector._
 import org.apache.hadoop.hive.serde2.objectinspector.primitive._
 import org.apache.hadoop.hive.serde2.{io => hiveIo}
@@ -30,6 +31,15 @@ import org.apache.spark.sql.types._
 /* Implicit conversions */
 import scala.collection.JavaConversions._
 
+private[hive] trait HiveUDAFMode {
+  def toHiveMode(m: Mode) = m match {
+    case PARTIAL1 => GenericUDAFEvaluator.Mode.PARTIAL1
+    case PARTIAL2 => GenericUDAFEvaluator.Mode.PARTIAL2
+    case FINAL => GenericUDAFEvaluator.Mode.FINAL
+    case COMPLETE => GenericUDAFEvaluator.Mode.COMPLETE
+  }
+}
+
 /**
  * 1.
The Underlying data type in catalyst and in Hive * In catalyst: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index c45c4ad70fae9..a5c9db58139fa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1072,14 +1072,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C UnresolvedStar(Some(name)) /* Aggregate Functions */ + case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(COUNT(), Nil) :: arg :: Nil) => Count(nodeToExpr(arg)) - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => Count(Literal(1)) - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => CountDistinct(args.map(nodeToExpr)) case Token("TOK_FUNCTION", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg)) - case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg)) case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg)) + /* Distinct Aggregate Functions */ + case Token("TOK_FUNCTIONDI", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg), true) + case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => + CountDistinct(args.map(nodeToExpr)) + case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => Sum(nodeToExpr(arg), true) + case Token("TOK_FUNCTIONDI", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg)) + case Token("TOK_FUNCTIONDI", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg)) /* System functions about string operations */ case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg)) @@ -1209,6 +1214,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => UnresolvedFunction(name, args.map(nodeToExpr)) + case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), true) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => UnresolvedFunction(name, UnresolvedStar(None) :: Nil) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 47305571e579e..51824abef2162 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Generate, Project, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ + import org.apache.spark.sql.catalyst.analysis.MultiAlias import org.apache.spark.sql.catalyst.errors.TreeNodeException @@ -47,25 +48,28 @@ private[hive] abstract class HiveFunctionRegistry def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) - def lookupFunction(name: String, children: Seq[Expression]): Expression = { + def lookupFunction( + name: String, + children: Seq[Expression], + distinct: Boolean = false): Expression = { // We only look it up to see if it exists, 
but do not include it in the HiveUDF since it is // not always serializable. val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( sys.error(s"Couldn't find function $name")) + val funcClazz = functionInfo.getFunctionClass val functionClassName = functionInfo.getFunctionClass.getName - if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + if (classOf[UDF].isAssignableFrom(funcClazz)) { HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + } else if (classOf[GenericUDF].isAssignableFrom(funcClazz)) { HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(funcClazz)) { + HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children, distinct, false) + } else if (classOf[UDAF].isAssignableFrom(funcClazz)) { + HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children, distinct, true) + } else if (classOf[GenericUDTF].isAssignableFrom(funcClazz)) { HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") @@ -157,7 +161,7 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr @transient protected lazy val isUDFDeterministic = { val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - (udfType != null && udfType.deterministic()) + udfType != null && udfType.deterministic() } override def foldable: Boolean = @@ -191,74 +195,118 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr private[hive] case class HiveGenericUdaf( funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression + children: Seq[Expression], + distinct: Boolean, + isUDAF: Boolean) extends AggregateExpression with HiveInspectors { - type UDFType = AbstractGenericUDAFResolver + protected def createEvaluator = resolver.getEvaluator( + new SimpleGenericUDAFParameterInfo(inspectors, false, false)) + + // Hive UDAF evaluator @transient - protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() + lazy val evaluator = createEvaluator @transient - protected lazy val objectInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) - resolver.getEvaluator(parameterInfo) - .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + protected lazy val resolver: AbstractGenericUDAFResolver = if (isUDAF) { + // if it's UDAF, we need the UDAF bridge + new GenericUDAFBridge(funcWrapper.createFunction()) + } else { + funcWrapper.createFunction() } + // Output data object inspector @transient - protected lazy val inspectors = children.map(toInspector) - - def dataType: DataType = inspectorToDataType(objectInspector) + lazy val objectInspector = createEvaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - def nullable: Boolean = true - - override def toString: String = { - 
s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + // Aggregation Buffer Inspector + @transient + lazy val bufferObjectInspector = { + createEvaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inspectors) } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) -} - -/** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUdaf( - funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) extends AggregateExpression - with HiveInspectors { - - type UDFType = UDAF - + // Input arguments object inspectors @transient - protected lazy val resolver: AbstractGenericUDAFResolver = - new GenericUDAFBridge(funcWrapper.createFunction()) + lazy val inspectors = children.map(toInspector).toArray @transient - protected lazy val objectInspector = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) - resolver.getEvaluator(parameterInfo) - .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + override val distinctLike: Boolean = { + val annotation = evaluator.getClass().getAnnotation(classOf[HiveUDFType]) + if (annotation == null || !annotation.distinctLike()) false else true + } + override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + + // Aggregation Buffer Data Type, We assume only 1 element for the Hive Aggregation Buffer + // It will be StructType if more than 1 element (Actually will be StructSettableObjectInspector) + override def bufferDataType: Seq[DataType] = inspectorToDataType(bufferObjectInspector) :: Nil + + // Output data type + override def dataType: DataType = inspectorToDataType(objectInspector) + + /////////////////////////////////////////////////////////////////////////////////////////////// + // The following code will be called within the executors // + /////////////////////////////////////////////////////////////////////////////////////////////// + @transient var bound: BoundReference = _ + + override def initialBoundReference(buffers: Seq[BoundReference]) = { + bound = buffers(0) + mode match { + case FINAL => evaluator.init(GenericUDAFEvaluator.Mode.FINAL, Array(bufferObjectInspector)) + case COMPLETE => evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + case PARTIAL1 => evaluator.init(GenericUDAFEvaluator.Mode.PARTIAL1, inspectors) + } } - @transient - protected lazy val inspectors = children.map(toInspector) + // Initialize (reinitialize) the aggregation buffer + override def reset(buf: MutableRow): Unit = { + val buffer = evaluator.getNewAggregationBuffer + .asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] + evaluator.reset(buffer) + // This is a hack, we never use the mutable row as buffer, but define our own buffer, + // which is set as the first element of the buffer + buf(bound) = buffer + } - def dataType: DataType = inspectorToDataType(objectInspector) + // Expect the aggregate function fills the aggregation buffer when fed with each value + // in the group + override def iterate(arguments: Any, buf: MutableRow): Unit = { + val args = arguments.asInstanceOf[Seq[AnyRef]].zip(inspectors).map { + case (value, oi) => wrap(value, oi) + }.toArray - def nullable: Boolean = true + evaluator.iterate( + buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal), + args) + } - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + // Merge 2 aggregation buffer, and write back to the 
latter one
+  override def merge(value: Row, buf: MutableRow): Unit = {
+    val buffer = buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)
+    evaluator.merge(buffer, wrap(value.get(bound.ordinal), bufferObjectInspector))
+  }
+
+  @deprecated
+  override def terminatePartial(buf: MutableRow): Unit = {
+    val buffer = buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)
+    // convert the opaque Hive buffer into a Catalyst value so it can be serialized
+    buf(bound) = unwrap(evaluator.terminatePartial(buffer), bufferObjectInspector)
   }
 
-  def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true)
+  // Output the final result by feeding the aggregation buffer
+  override def terminate(input: Row): Any = {
+    unwrap(evaluator.terminate(
+      input.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)),
+      objectInspector)
+  }
 }
 
 /**
  * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
- * [[catalyst.expressions.Generator Generator]].  Note that the semantics of Generators do not allow
- * Generators to maintain state in between input rows.  Thus UDTFs that rely on partitioning
- * dependent operations like calls to `close()` before producing output will not operate the same as
- * in Hive. However, in practice this should not affect compatibility for most sane UDTFs
+ * [[catalyst.expressions.Generator Generator]].  Note that the semantics of Generators do not
+ * allow Generators to maintain state in between input rows.  Thus UDTFs that rely on partitioning
+ * dependent operations like calls to `close()` before producing output will not operate the same
+ * as in Hive. However, in practice this should not affect compatibility for most sane UDTFs
  * (e.g. explode or GenericUDTFParseUrlTuple).
 *
 * Operators that require maintaining state in between input rows should instead be implemented as
@@ -333,9 +381,6 @@ private[hive] case class HiveGenericUdtf(
     }
   }
 }
 
-/**
- * Resolve Udtfs Alias.
- */ private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case p @ Project(projectList, _) @@ -349,46 +394,3 @@ private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] { } } -private[hive] case class HiveUdafFunction( - funcWrapper: HiveFunctionWrapper, - exprs: Seq[Expression], - base: AggregateExpression, - isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction - with HiveInspectors { - - def this() = this(null, null, null) - - private val resolver = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[AbstractGenericUDAFResolver]() - } - - private val inspectors = exprs.map(toInspector).toArray - - private val function = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) - } - - private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - - private val buffer = - function.getNewAggregationBuffer - - override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector) - - @transient - val inputProjection = new InterpretedProjection(exprs) - - @transient - protected lazy val cached = new Array[AnyRef](exprs.length) - - def update(input: Row): Unit = { - val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached)) - } -} - diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #1-0-dc640a3b0e7f23e9052c454a739ba9db b/sql/hive/src/test/resources/golden/aggregation with group by expressions #1-0-dc640a3b0e7f23e9052c454a739ba9db new file mode 100644 index 0000000000000..9b2e37e3af110 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #1-0-dc640a3b0e7f23e9052c454a739ba9db @@ -0,0 +1,309 @@ +3 3 0 0 +5 1 2 2 +7 1 4 4 +8 3 5 5 +11 1 8 8 +12 1 9 9 +13 1 10 10 +14 1 11 11 +15 2 12 12 +18 2 15 15 +20 1 17 17 +21 2 18 18 +22 1 19 19 +23 1 20 20 +27 2 24 24 +29 2 26 26 +30 1 27 27 +31 1 28 28 +33 1 30 30 +36 1 33 33 +37 1 34 34 +38 3 35 35 +40 2 37 37 +44 1 41 41 +45 2 42 42 +46 1 43 43 +47 1 44 44 +50 1 47 47 +54 2 51 51 +56 1 53 53 +57 1 54 54 +60 1 57 57 +61 2 58 58 +67 1 64 64 +68 1 65 65 +69 1 66 66 +70 2 67 67 +72 1 69 69 +73 3 70 70 +75 2 72 72 +77 1 74 74 +79 2 76 76 +80 1 77 77 +81 1 78 78 +83 1 80 80 +85 1 82 82 +86 2 83 83 +87 2 84 84 +88 1 85 85 +89 1 86 86 +90 1 87 87 +93 3 90 90 +95 1 92 92 +98 2 95 95 +99 1 96 96 +100 2 97 97 +101 2 98 98 +103 2 100 100 +106 2 103 103 +107 2 104 104 +108 1 105 105 +114 1 111 111 +116 2 113 113 +117 1 114 114 +119 1 116 116 +121 2 118 118 +122 3 119 119 +123 2 120 120 +128 2 125 125 +129 1 126 126 +131 3 128 128 +132 2 129 129 +134 1 131 131 +136 1 133 133 +137 2 134 134 +139 1 136 136 +140 2 137 137 +141 4 138 138 +146 1 143 143 +148 1 145 145 +149 2 146 146 +152 2 149 149 +153 1 150 150 +155 2 152 152 +156 1 153 153 +158 1 155 155 +159 1 156 156 +160 1 157 157 +161 1 158 158 +163 1 160 160 +165 1 162 162 +166 1 163 163 +167 2 164 164 +168 2 165 165 +169 1 166 166 +170 3 167 167 +171 1 168 168 +172 4 169 169 +173 1 170 170 +175 2 172 172 +177 2 174 174 +178 2 175 175 +179 2 176 176 +180 1 177 177 +181 1 178 178 +182 2 179 179 +183 1 180 180 +184 1 181 181 +186 1 183 183 +189 1 186 186 +190 3 187 187 +192 1 189 189 +193 1 190 190 +194 2 191 191 +195 1 192 192 +196 3 193 193 +197 1 194 194 +198 2 195 195 +199 1 196 196 +200 2 197 197 +202 3 199 199 +203 
2 200 200 +204 1 201 201 +205 1 202 202 +206 2 203 203 +208 2 205 205 +210 2 207 207 +211 3 208 208 +212 2 209 209 +216 2 213 213 +217 1 214 214 +219 2 216 216 +220 2 217 217 +221 1 218 218 +222 2 219 219 +224 2 221 221 +225 1 222 222 +226 2 223 223 +227 2 224 224 +229 1 226 226 +231 1 228 228 +232 2 229 229 +233 5 230 230 +236 2 233 233 +238 1 235 235 +240 2 237 237 +241 2 238 238 +242 2 239 239 +244 1 241 241 +245 2 242 242 +247 1 244 244 +250 1 247 247 +251 1 248 248 +252 1 249 249 +255 1 252 252 +258 2 255 255 +259 2 256 256 +260 1 257 257 +261 1 258 258 +263 1 260 260 +265 1 262 262 +266 1 263 263 +268 2 265 265 +269 1 266 266 +275 2 272 272 +276 3 273 273 +277 1 274 274 +278 1 275 275 +280 4 277 277 +281 2 278 278 +283 2 280 280 +284 2 281 281 +285 2 282 282 +286 1 283 283 +287 1 284 284 +288 1 285 285 +289 1 286 286 +290 1 287 287 +291 2 288 288 +292 1 289 289 +294 1 291 291 +295 1 292 292 +299 1 296 296 +301 3 298 298 +305 1 302 302 +308 1 305 305 +309 1 306 306 +310 2 307 307 +311 1 308 308 +312 2 309 309 +313 1 310 310 +314 3 311 311 +318 1 315 315 +319 3 316 316 +320 2 317 317 +321 3 318 318 +324 2 321 321 +325 2 322 322 +326 1 323 323 +328 2 325 325 +330 3 327 327 +334 2 331 331 +335 1 332 332 +336 2 333 333 +338 1 335 335 +339 1 336 336 +341 1 338 338 +342 1 339 339 +344 1 341 341 +345 2 342 342 +347 2 344 344 +348 1 345 345 +351 5 348 348 +354 1 351 351 +356 2 353 353 +359 1 356 356 +363 1 360 360 +365 1 362 362 +367 1 364 364 +368 1 365 365 +369 1 366 366 +370 2 367 367 +371 1 368 368 +372 3 369 369 +376 1 373 373 +377 1 374 374 +378 1 375 375 +380 1 377 377 +381 1 378 378 +382 1 379 379 +385 2 382 382 +387 3 384 384 +389 1 386 386 +392 1 389 389 +395 1 392 392 +396 1 393 393 +397 1 394 394 +398 2 395 395 +399 3 396 396 +400 2 397 397 +402 2 399 399 +403 1 400 400 +404 5 401 401 +405 1 402 402 +406 3 403 403 +407 2 404 404 +409 4 406 406 +410 1 407 407 +412 3 409 409 +414 1 411 411 +416 2 413 413 +417 2 414 414 +420 3 417 417 +421 1 418 418 +422 1 419 419 +424 1 421 421 +427 2 424 424 +430 1 427 427 +432 2 429 429 +433 3 430 430 +434 3 431 431 +435 1 432 432 +438 1 435 435 +439 1 436 436 +440 1 437 437 +441 3 438 438 +442 2 439 439 +446 1 443 443 +447 1 444 444 +449 1 446 446 +451 1 448 448 +452 1 449 449 +455 1 452 452 +456 1 453 453 +457 3 454 454 +458 1 455 455 +460 1 457 457 +461 2 458 458 +462 2 459 459 +463 1 460 460 +465 2 462 462 +466 2 463 463 +469 3 466 466 +470 1 467 467 +471 4 468 468 +472 5 469 469 +473 1 470 470 +475 1 472 472 +478 1 475 475 +480 1 477 477 +481 2 478 478 +482 1 479 479 +483 3 480 480 +484 1 481 481 +485 1 482 482 +486 1 483 483 +487 1 484 484 +488 1 485 485 +490 1 487 487 +492 4 489 489 +493 1 490 490 +494 1 491 491 +495 2 492 492 +496 1 493 493 +497 1 494 494 +498 1 495 495 +499 1 496 496 +500 1 497 497 +501 3 498 498 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #2-0-4b1bcbd566a255e2b694ec9d8bacb825 b/sql/hive/src/test/resources/golden/aggregation with group by expressions #2-0-4b1bcbd566a255e2b694ec9d8bacb825 new file mode 100644 index 0000000000000..af074a6ab2ab8 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #2-0-4b1bcbd566a255e2b694ec9d8bacb825 @@ -0,0 +1,309 @@ +3 3 0 0 0 +5 1 2 2 2 +7 1 4 4 4 +8 3 5 5 15 +11 1 8 8 8 +12 1 9 9 9 +13 1 10 10 10 +14 1 11 11 11 +15 2 12 12 24 +18 2 15 15 30 +20 1 17 17 17 +21 2 18 18 36 +22 1 19 19 19 +23 1 20 20 20 +27 2 24 24 48 +29 2 26 26 52 +30 1 27 27 27 +31 1 28 28 28 +33 1 30 30 30 +36 1 33 33 33 +37 1 34 34 34 +38 
3 35 35 105 +40 2 37 37 74 +44 1 41 41 41 +45 2 42 42 84 +46 1 43 43 43 +47 1 44 44 44 +50 1 47 47 47 +54 2 51 51 102 +56 1 53 53 53 +57 1 54 54 54 +60 1 57 57 57 +61 2 58 58 116 +67 1 64 64 64 +68 1 65 65 65 +69 1 66 66 66 +70 2 67 67 134 +72 1 69 69 69 +73 3 70 70 210 +75 2 72 72 144 +77 1 74 74 74 +79 2 76 76 152 +80 1 77 77 77 +81 1 78 78 78 +83 1 80 80 80 +85 1 82 82 82 +86 2 83 83 166 +87 2 84 84 168 +88 1 85 85 85 +89 1 86 86 86 +90 1 87 87 87 +93 3 90 90 270 +95 1 92 92 92 +98 2 95 95 190 +99 1 96 96 96 +100 2 97 97 194 +101 2 98 98 196 +103 2 100 100 200 +106 2 103 103 206 +107 2 104 104 208 +108 1 105 105 105 +114 1 111 111 111 +116 2 113 113 226 +117 1 114 114 114 +119 1 116 116 116 +121 2 118 118 236 +122 3 119 119 357 +123 2 120 120 240 +128 2 125 125 250 +129 1 126 126 126 +131 3 128 128 384 +132 2 129 129 258 +134 1 131 131 131 +136 1 133 133 133 +137 2 134 134 268 +139 1 136 136 136 +140 2 137 137 274 +141 4 138 138 552 +146 1 143 143 143 +148 1 145 145 145 +149 2 146 146 292 +152 2 149 149 298 +153 1 150 150 150 +155 2 152 152 304 +156 1 153 153 153 +158 1 155 155 155 +159 1 156 156 156 +160 1 157 157 157 +161 1 158 158 158 +163 1 160 160 160 +165 1 162 162 162 +166 1 163 163 163 +167 2 164 164 328 +168 2 165 165 330 +169 1 166 166 166 +170 3 167 167 501 +171 1 168 168 168 +172 4 169 169 676 +173 1 170 170 170 +175 2 172 172 344 +177 2 174 174 348 +178 2 175 175 350 +179 2 176 176 352 +180 1 177 177 177 +181 1 178 178 178 +182 2 179 179 358 +183 1 180 180 180 +184 1 181 181 181 +186 1 183 183 183 +189 1 186 186 186 +190 3 187 187 561 +192 1 189 189 189 +193 1 190 190 190 +194 2 191 191 382 +195 1 192 192 192 +196 3 193 193 579 +197 1 194 194 194 +198 2 195 195 390 +199 1 196 196 196 +200 2 197 197 394 +202 3 199 199 597 +203 2 200 200 400 +204 1 201 201 201 +205 1 202 202 202 +206 2 203 203 406 +208 2 205 205 410 +210 2 207 207 414 +211 3 208 208 624 +212 2 209 209 418 +216 2 213 213 426 +217 1 214 214 214 +219 2 216 216 432 +220 2 217 217 434 +221 1 218 218 218 +222 2 219 219 438 +224 2 221 221 442 +225 1 222 222 222 +226 2 223 223 446 +227 2 224 224 448 +229 1 226 226 226 +231 1 228 228 228 +232 2 229 229 458 +233 5 230 230 1150 +236 2 233 233 466 +238 1 235 235 235 +240 2 237 237 474 +241 2 238 238 476 +242 2 239 239 478 +244 1 241 241 241 +245 2 242 242 484 +247 1 244 244 244 +250 1 247 247 247 +251 1 248 248 248 +252 1 249 249 249 +255 1 252 252 252 +258 2 255 255 510 +259 2 256 256 512 +260 1 257 257 257 +261 1 258 258 258 +263 1 260 260 260 +265 1 262 262 262 +266 1 263 263 263 +268 2 265 265 530 +269 1 266 266 266 +275 2 272 272 544 +276 3 273 273 819 +277 1 274 274 274 +278 1 275 275 275 +280 4 277 277 1108 +281 2 278 278 556 +283 2 280 280 560 +284 2 281 281 562 +285 2 282 282 564 +286 1 283 283 283 +287 1 284 284 284 +288 1 285 285 285 +289 1 286 286 286 +290 1 287 287 287 +291 2 288 288 576 +292 1 289 289 289 +294 1 291 291 291 +295 1 292 292 292 +299 1 296 296 296 +301 3 298 298 894 +305 1 302 302 302 +308 1 305 305 305 +309 1 306 306 306 +310 2 307 307 614 +311 1 308 308 308 +312 2 309 309 618 +313 1 310 310 310 +314 3 311 311 933 +318 1 315 315 315 +319 3 316 316 948 +320 2 317 317 634 +321 3 318 318 954 +324 2 321 321 642 +325 2 322 322 644 +326 1 323 323 323 +328 2 325 325 650 +330 3 327 327 981 +334 2 331 331 662 +335 1 332 332 332 +336 2 333 333 666 +338 1 335 335 335 +339 1 336 336 336 +341 1 338 338 338 +342 1 339 339 339 +344 1 341 341 341 +345 2 342 342 684 +347 2 344 344 688 +348 1 345 345 345 +351 5 348 348 1740 +354 1 351 351 351 +356 2 353 353 
706 +359 1 356 356 356 +363 1 360 360 360 +365 1 362 362 362 +367 1 364 364 364 +368 1 365 365 365 +369 1 366 366 366 +370 2 367 367 734 +371 1 368 368 368 +372 3 369 369 1107 +376 1 373 373 373 +377 1 374 374 374 +378 1 375 375 375 +380 1 377 377 377 +381 1 378 378 378 +382 1 379 379 379 +385 2 382 382 764 +387 3 384 384 1152 +389 1 386 386 386 +392 1 389 389 389 +395 1 392 392 392 +396 1 393 393 393 +397 1 394 394 394 +398 2 395 395 790 +399 3 396 396 1188 +400 2 397 397 794 +402 2 399 399 798 +403 1 400 400 400 +404 5 401 401 2005 +405 1 402 402 402 +406 3 403 403 1209 +407 2 404 404 808 +409 4 406 406 1624 +410 1 407 407 407 +412 3 409 409 1227 +414 1 411 411 411 +416 2 413 413 826 +417 2 414 414 828 +420 3 417 417 1251 +421 1 418 418 418 +422 1 419 419 419 +424 1 421 421 421 +427 2 424 424 848 +430 1 427 427 427 +432 2 429 429 858 +433 3 430 430 1290 +434 3 431 431 1293 +435 1 432 432 432 +438 1 435 435 435 +439 1 436 436 436 +440 1 437 437 437 +441 3 438 438 1314 +442 2 439 439 878 +446 1 443 443 443 +447 1 444 444 444 +449 1 446 446 446 +451 1 448 448 448 +452 1 449 449 449 +455 1 452 452 452 +456 1 453 453 453 +457 3 454 454 1362 +458 1 455 455 455 +460 1 457 457 457 +461 2 458 458 916 +462 2 459 459 918 +463 1 460 460 460 +465 2 462 462 924 +466 2 463 463 926 +469 3 466 466 1398 +470 1 467 467 467 +471 4 468 468 1872 +472 5 469 469 2345 +473 1 470 470 470 +475 1 472 472 472 +478 1 475 475 475 +480 1 477 477 477 +481 2 478 478 956 +482 1 479 479 479 +483 3 480 480 1440 +484 1 481 481 481 +485 1 482 482 482 +486 1 483 483 483 +487 1 484 484 484 +488 1 485 485 485 +490 1 487 487 487 +492 4 489 489 1956 +493 1 490 490 490 +494 1 491 491 491 +495 2 492 492 984 +496 1 493 493 493 +497 1 494 494 494 +498 1 495 495 495 +499 1 496 496 496 +500 1 497 497 497 +501 3 498 498 1494 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #3-0-e4e01312d01a7a08cff2ac43196f6ea4 b/sql/hive/src/test/resources/golden/aggregation with group by expressions #3-0-e4e01312d01a7a08cff2ac43196f6ea4 new file mode 100644 index 0000000000000..85c30588b72e4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #3-0-e4e01312d01a7a08cff2ac43196f6ea4 @@ -0,0 +1,309 @@ +3 1 0 0 0 +5 1 2 2 2 +7 1 4 4 4 +8 1 5 5 5 +11 1 8 8 8 +12 1 9 9 9 +13 1 10 10 10 +14 1 11 11 11 +15 1 12 12 12 +18 1 15 15 15 +20 1 17 17 17 +21 1 18 18 18 +22 1 19 19 19 +23 1 20 20 20 +27 1 24 24 24 +29 1 26 26 26 +30 1 27 27 27 +31 1 28 28 28 +33 1 30 30 30 +36 1 33 33 33 +37 1 34 34 34 +38 1 35 35 35 +40 1 37 37 37 +44 1 41 41 41 +45 1 42 42 42 +46 1 43 43 43 +47 1 44 44 44 +50 1 47 47 47 +54 1 51 51 51 +56 1 53 53 53 +57 1 54 54 54 +60 1 57 57 57 +61 1 58 58 58 +67 1 64 64 64 +68 1 65 65 65 +69 1 66 66 66 +70 1 67 67 67 +72 1 69 69 69 +73 1 70 70 70 +75 1 72 72 72 +77 1 74 74 74 +79 1 76 76 76 +80 1 77 77 77 +81 1 78 78 78 +83 1 80 80 80 +85 1 82 82 82 +86 1 83 83 83 +87 1 84 84 84 +88 1 85 85 85 +89 1 86 86 86 +90 1 87 87 87 +93 1 90 90 90 +95 1 92 92 92 +98 1 95 95 95 +99 1 96 96 96 +100 1 97 97 97 +101 1 98 98 98 +103 1 100 100 100 +106 1 103 103 103 +107 1 104 104 104 +108 1 105 105 105 +114 1 111 111 111 +116 1 113 113 113 +117 1 114 114 114 +119 1 116 116 116 +121 1 118 118 118 +122 1 119 119 119 +123 1 120 120 120 +128 1 125 125 125 +129 1 126 126 126 +131 1 128 128 128 +132 1 129 129 129 +134 1 131 131 131 +136 1 133 133 133 +137 1 134 134 134 +139 1 136 136 136 +140 1 137 137 137 +141 1 138 138 138 +146 1 143 143 143 +148 1 145 145 145 +149 1 146 146 146 +152 1 149 
149 149 +153 1 150 150 150 +155 1 152 152 152 +156 1 153 153 153 +158 1 155 155 155 +159 1 156 156 156 +160 1 157 157 157 +161 1 158 158 158 +163 1 160 160 160 +165 1 162 162 162 +166 1 163 163 163 +167 1 164 164 164 +168 1 165 165 165 +169 1 166 166 166 +170 1 167 167 167 +171 1 168 168 168 +172 1 169 169 169 +173 1 170 170 170 +175 1 172 172 172 +177 1 174 174 174 +178 1 175 175 175 +179 1 176 176 176 +180 1 177 177 177 +181 1 178 178 178 +182 1 179 179 179 +183 1 180 180 180 +184 1 181 181 181 +186 1 183 183 183 +189 1 186 186 186 +190 1 187 187 187 +192 1 189 189 189 +193 1 190 190 190 +194 1 191 191 191 +195 1 192 192 192 +196 1 193 193 193 +197 1 194 194 194 +198 1 195 195 195 +199 1 196 196 196 +200 1 197 197 197 +202 1 199 199 199 +203 1 200 200 200 +204 1 201 201 201 +205 1 202 202 202 +206 1 203 203 203 +208 1 205 205 205 +210 1 207 207 207 +211 1 208 208 208 +212 1 209 209 209 +216 1 213 213 213 +217 1 214 214 214 +219 1 216 216 216 +220 1 217 217 217 +221 1 218 218 218 +222 1 219 219 219 +224 1 221 221 221 +225 1 222 222 222 +226 1 223 223 223 +227 1 224 224 224 +229 1 226 226 226 +231 1 228 228 228 +232 1 229 229 229 +233 1 230 230 230 +236 1 233 233 233 +238 1 235 235 235 +240 1 237 237 237 +241 1 238 238 238 +242 1 239 239 239 +244 1 241 241 241 +245 1 242 242 242 +247 1 244 244 244 +250 1 247 247 247 +251 1 248 248 248 +252 1 249 249 249 +255 1 252 252 252 +258 1 255 255 255 +259 1 256 256 256 +260 1 257 257 257 +261 1 258 258 258 +263 1 260 260 260 +265 1 262 262 262 +266 1 263 263 263 +268 1 265 265 265 +269 1 266 266 266 +275 1 272 272 272 +276 1 273 273 273 +277 1 274 274 274 +278 1 275 275 275 +280 1 277 277 277 +281 1 278 278 278 +283 1 280 280 280 +284 1 281 281 281 +285 1 282 282 282 +286 1 283 283 283 +287 1 284 284 284 +288 1 285 285 285 +289 1 286 286 286 +290 1 287 287 287 +291 1 288 288 288 +292 1 289 289 289 +294 1 291 291 291 +295 1 292 292 292 +299 1 296 296 296 +301 1 298 298 298 +305 1 302 302 302 +308 1 305 305 305 +309 1 306 306 306 +310 1 307 307 307 +311 1 308 308 308 +312 1 309 309 309 +313 1 310 310 310 +314 1 311 311 311 +318 1 315 315 315 +319 1 316 316 316 +320 1 317 317 317 +321 1 318 318 318 +324 1 321 321 321 +325 1 322 322 322 +326 1 323 323 323 +328 1 325 325 325 +330 1 327 327 327 +334 1 331 331 331 +335 1 332 332 332 +336 1 333 333 333 +338 1 335 335 335 +339 1 336 336 336 +341 1 338 338 338 +342 1 339 339 339 +344 1 341 341 341 +345 1 342 342 342 +347 1 344 344 344 +348 1 345 345 345 +351 1 348 348 348 +354 1 351 351 351 +356 1 353 353 353 +359 1 356 356 356 +363 1 360 360 360 +365 1 362 362 362 +367 1 364 364 364 +368 1 365 365 365 +369 1 366 366 366 +370 1 367 367 367 +371 1 368 368 368 +372 1 369 369 369 +376 1 373 373 373 +377 1 374 374 374 +378 1 375 375 375 +380 1 377 377 377 +381 1 378 378 378 +382 1 379 379 379 +385 1 382 382 382 +387 1 384 384 384 +389 1 386 386 386 +392 1 389 389 389 +395 1 392 392 392 +396 1 393 393 393 +397 1 394 394 394 +398 1 395 395 395 +399 1 396 396 396 +400 1 397 397 397 +402 1 399 399 399 +403 1 400 400 400 +404 1 401 401 401 +405 1 402 402 402 +406 1 403 403 403 +407 1 404 404 404 +409 1 406 406 406 +410 1 407 407 407 +412 1 409 409 409 +414 1 411 411 411 +416 1 413 413 413 +417 1 414 414 414 +420 1 417 417 417 +421 1 418 418 418 +422 1 419 419 419 +424 1 421 421 421 +427 1 424 424 424 +430 1 427 427 427 +432 1 429 429 429 +433 1 430 430 430 +434 1 431 431 431 +435 1 432 432 432 +438 1 435 435 435 +439 1 436 436 436 +440 1 437 437 437 +441 1 438 438 438 +442 1 439 439 439 +446 1 443 443 443 +447 1 444 
444 444 +449 1 446 446 446 +451 1 448 448 448 +452 1 449 449 449 +455 1 452 452 452 +456 1 453 453 453 +457 1 454 454 454 +458 1 455 455 455 +460 1 457 457 457 +461 1 458 458 458 +462 1 459 459 459 +463 1 460 460 460 +465 1 462 462 462 +466 1 463 463 463 +469 1 466 466 466 +470 1 467 467 467 +471 1 468 468 468 +472 1 469 469 469 +473 1 470 470 470 +475 1 472 472 472 +478 1 475 475 475 +480 1 477 477 477 +481 1 478 478 478 +482 1 479 479 479 +483 1 480 480 480 +484 1 481 481 481 +485 1 482 482 482 +486 1 483 483 483 +487 1 484 484 484 +488 1 485 485 485 +490 1 487 487 487 +492 1 489 489 489 +493 1 490 490 490 +494 1 491 491 491 +495 1 492 492 492 +496 1 493 493 493 +497 1 494 494 494 +498 1 495 495 495 +499 1 496 496 496 +500 1 497 497 497 +501 1 498 498 498 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #4-0-ff859636795b1019ad74567bf4ba095f b/sql/hive/src/test/resources/golden/aggregation with group by expressions #4-0-ff859636795b1019ad74567bf4ba095f new file mode 100644 index 0000000000000..85c30588b72e4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #4-0-ff859636795b1019ad74567bf4ba095f @@ -0,0 +1,309 @@ +3 1 0 0 0 +5 1 2 2 2 +7 1 4 4 4 +8 1 5 5 5 +11 1 8 8 8 +12 1 9 9 9 +13 1 10 10 10 +14 1 11 11 11 +15 1 12 12 12 +18 1 15 15 15 +20 1 17 17 17 +21 1 18 18 18 +22 1 19 19 19 +23 1 20 20 20 +27 1 24 24 24 +29 1 26 26 26 +30 1 27 27 27 +31 1 28 28 28 +33 1 30 30 30 +36 1 33 33 33 +37 1 34 34 34 +38 1 35 35 35 +40 1 37 37 37 +44 1 41 41 41 +45 1 42 42 42 +46 1 43 43 43 +47 1 44 44 44 +50 1 47 47 47 +54 1 51 51 51 +56 1 53 53 53 +57 1 54 54 54 +60 1 57 57 57 +61 1 58 58 58 +67 1 64 64 64 +68 1 65 65 65 +69 1 66 66 66 +70 1 67 67 67 +72 1 69 69 69 +73 1 70 70 70 +75 1 72 72 72 +77 1 74 74 74 +79 1 76 76 76 +80 1 77 77 77 +81 1 78 78 78 +83 1 80 80 80 +85 1 82 82 82 +86 1 83 83 83 +87 1 84 84 84 +88 1 85 85 85 +89 1 86 86 86 +90 1 87 87 87 +93 1 90 90 90 +95 1 92 92 92 +98 1 95 95 95 +99 1 96 96 96 +100 1 97 97 97 +101 1 98 98 98 +103 1 100 100 100 +106 1 103 103 103 +107 1 104 104 104 +108 1 105 105 105 +114 1 111 111 111 +116 1 113 113 113 +117 1 114 114 114 +119 1 116 116 116 +121 1 118 118 118 +122 1 119 119 119 +123 1 120 120 120 +128 1 125 125 125 +129 1 126 126 126 +131 1 128 128 128 +132 1 129 129 129 +134 1 131 131 131 +136 1 133 133 133 +137 1 134 134 134 +139 1 136 136 136 +140 1 137 137 137 +141 1 138 138 138 +146 1 143 143 143 +148 1 145 145 145 +149 1 146 146 146 +152 1 149 149 149 +153 1 150 150 150 +155 1 152 152 152 +156 1 153 153 153 +158 1 155 155 155 +159 1 156 156 156 +160 1 157 157 157 +161 1 158 158 158 +163 1 160 160 160 +165 1 162 162 162 +166 1 163 163 163 +167 1 164 164 164 +168 1 165 165 165 +169 1 166 166 166 +170 1 167 167 167 +171 1 168 168 168 +172 1 169 169 169 +173 1 170 170 170 +175 1 172 172 172 +177 1 174 174 174 +178 1 175 175 175 +179 1 176 176 176 +180 1 177 177 177 +181 1 178 178 178 +182 1 179 179 179 +183 1 180 180 180 +184 1 181 181 181 +186 1 183 183 183 +189 1 186 186 186 +190 1 187 187 187 +192 1 189 189 189 +193 1 190 190 190 +194 1 191 191 191 +195 1 192 192 192 +196 1 193 193 193 +197 1 194 194 194 +198 1 195 195 195 +199 1 196 196 196 +200 1 197 197 197 +202 1 199 199 199 +203 1 200 200 200 +204 1 201 201 201 +205 1 202 202 202 +206 1 203 203 203 +208 1 205 205 205 +210 1 207 207 207 +211 1 208 208 208 +212 1 209 209 209 +216 1 213 213 213 +217 1 214 214 214 +219 1 216 216 216 +220 1 217 217 217 +221 1 218 218 218 +222 1 219 219 219 +224 1 221 221 221 +225 1 
222 222 222 +226 1 223 223 223 +227 1 224 224 224 +229 1 226 226 226 +231 1 228 228 228 +232 1 229 229 229 +233 1 230 230 230 +236 1 233 233 233 +238 1 235 235 235 +240 1 237 237 237 +241 1 238 238 238 +242 1 239 239 239 +244 1 241 241 241 +245 1 242 242 242 +247 1 244 244 244 +250 1 247 247 247 +251 1 248 248 248 +252 1 249 249 249 +255 1 252 252 252 +258 1 255 255 255 +259 1 256 256 256 +260 1 257 257 257 +261 1 258 258 258 +263 1 260 260 260 +265 1 262 262 262 +266 1 263 263 263 +268 1 265 265 265 +269 1 266 266 266 +275 1 272 272 272 +276 1 273 273 273 +277 1 274 274 274 +278 1 275 275 275 +280 1 277 277 277 +281 1 278 278 278 +283 1 280 280 280 +284 1 281 281 281 +285 1 282 282 282 +286 1 283 283 283 +287 1 284 284 284 +288 1 285 285 285 +289 1 286 286 286 +290 1 287 287 287 +291 1 288 288 288 +292 1 289 289 289 +294 1 291 291 291 +295 1 292 292 292 +299 1 296 296 296 +301 1 298 298 298 +305 1 302 302 302 +308 1 305 305 305 +309 1 306 306 306 +310 1 307 307 307 +311 1 308 308 308 +312 1 309 309 309 +313 1 310 310 310 +314 1 311 311 311 +318 1 315 315 315 +319 1 316 316 316 +320 1 317 317 317 +321 1 318 318 318 +324 1 321 321 321 +325 1 322 322 322 +326 1 323 323 323 +328 1 325 325 325 +330 1 327 327 327 +334 1 331 331 331 +335 1 332 332 332 +336 1 333 333 333 +338 1 335 335 335 +339 1 336 336 336 +341 1 338 338 338 +342 1 339 339 339 +344 1 341 341 341 +345 1 342 342 342 +347 1 344 344 344 +348 1 345 345 345 +351 1 348 348 348 +354 1 351 351 351 +356 1 353 353 353 +359 1 356 356 356 +363 1 360 360 360 +365 1 362 362 362 +367 1 364 364 364 +368 1 365 365 365 +369 1 366 366 366 +370 1 367 367 367 +371 1 368 368 368 +372 1 369 369 369 +376 1 373 373 373 +377 1 374 374 374 +378 1 375 375 375 +380 1 377 377 377 +381 1 378 378 378 +382 1 379 379 379 +385 1 382 382 382 +387 1 384 384 384 +389 1 386 386 386 +392 1 389 389 389 +395 1 392 392 392 +396 1 393 393 393 +397 1 394 394 394 +398 1 395 395 395 +399 1 396 396 396 +400 1 397 397 397 +402 1 399 399 399 +403 1 400 400 400 +404 1 401 401 401 +405 1 402 402 402 +406 1 403 403 403 +407 1 404 404 404 +409 1 406 406 406 +410 1 407 407 407 +412 1 409 409 409 +414 1 411 411 411 +416 1 413 413 413 +417 1 414 414 414 +420 1 417 417 417 +421 1 418 418 418 +422 1 419 419 419 +424 1 421 421 421 +427 1 424 424 424 +430 1 427 427 427 +432 1 429 429 429 +433 1 430 430 430 +434 1 431 431 431 +435 1 432 432 432 +438 1 435 435 435 +439 1 436 436 436 +440 1 437 437 437 +441 1 438 438 438 +442 1 439 439 439 +446 1 443 443 443 +447 1 444 444 444 +449 1 446 446 446 +451 1 448 448 448 +452 1 449 449 449 +455 1 452 452 452 +456 1 453 453 453 +457 1 454 454 454 +458 1 455 455 455 +460 1 457 457 457 +461 1 458 458 458 +462 1 459 459 459 +463 1 460 460 460 +465 1 462 462 462 +466 1 463 463 463 +469 1 466 466 466 +470 1 467 467 467 +471 1 468 468 468 +472 1 469 469 469 +473 1 470 470 470 +475 1 472 472 472 +478 1 475 475 475 +480 1 477 477 477 +481 1 478 478 478 +482 1 479 479 479 +483 1 480 480 480 +484 1 481 481 481 +485 1 482 482 482 +486 1 483 483 483 +487 1 484 484 484 +488 1 485 485 485 +490 1 487 487 487 +492 1 489 489 489 +493 1 490 490 490 +494 1 491 491 491 +495 1 492 492 492 +496 1 493 493 493 +497 1 494 494 494 +498 1 495 495 495 +499 1 496 496 496 +500 1 497 497 497 +501 1 498 498 498 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #5-0-5c7cdc7d4bc610cec923b54d3f3d696a b/sql/hive/src/test/resources/golden/aggregation with group by expressions #5-0-5c7cdc7d4bc610cec923b54d3f3d696a new file mode 100644 index 
0000000000000..cdbe53cffd36e --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #5-0-5c7cdc7d4bc610cec923b54d3f3d696a @@ -0,0 +1,309 @@ +6 4 6 6 6 +10 6 10 10 10 +14 8 14 14 14 +16 9 16 16 16 +22 12 22 22 22 +24 13 24 24 24 +26 14 26 26 26 +28 15 28 28 28 +30 16 30 30 30 +36 19 36 36 36 +40 21 40 40 40 +42 22 42 42 42 +44 23 44 44 44 +46 24 46 46 46 +54 28 54 54 54 +58 30 58 58 58 +60 31 60 60 60 +62 32 62 62 62 +66 34 66 66 66 +72 37 72 72 72 +74 38 74 74 74 +76 39 76 76 76 +80 41 80 80 80 +88 45 88 88 88 +90 46 90 90 90 +92 47 92 92 92 +94 48 94 94 94 +100 51 100 100 100 +108 55 108 108 108 +112 57 112 112 112 +114 58 114 114 114 +120 61 120 120 120 +122 62 122 122 122 +134 68 134 134 134 +136 69 136 136 136 +138 70 138 138 138 +140 71 140 140 140 +144 73 144 144 144 +146 74 146 146 146 +150 76 150 150 150 +154 78 154 154 154 +158 80 158 158 158 +160 81 160 160 160 +162 82 162 162 162 +166 84 166 166 166 +170 86 170 170 170 +172 87 172 172 172 +174 88 174 174 174 +176 89 176 176 176 +178 90 178 178 178 +180 91 180 180 180 +186 94 186 186 186 +190 96 190 190 190 +196 99 196 196 196 +198 100 198 198 198 +200 101 200 200 200 +202 102 202 202 202 +206 104 206 206 206 +212 107 212 212 212 +214 108 214 214 214 +216 109 216 216 216 +228 115 228 228 228 +232 117 232 232 232 +234 118 234 234 234 +238 120 238 238 238 +242 122 242 242 242 +244 123 244 244 244 +246 124 246 246 246 +256 129 256 256 256 +258 130 258 258 258 +262 132 262 262 262 +264 133 264 264 264 +268 135 268 268 268 +272 137 272 272 272 +274 138 274 274 274 +278 140 278 278 278 +280 141 280 280 280 +282 142 282 282 282 +292 147 292 292 292 +296 149 296 296 296 +298 150 298 298 298 +304 153 304 304 304 +306 154 306 306 306 +310 156 310 310 310 +312 157 312 312 312 +316 159 316 316 316 +318 160 318 318 318 +320 161 320 320 320 +322 162 322 322 322 +326 164 326 326 326 +330 166 330 330 330 +332 167 332 332 332 +334 168 334 334 334 +336 169 336 336 336 +338 170 338 338 338 +340 171 340 340 340 +342 172 342 342 342 +344 173 344 344 344 +346 174 346 346 346 +350 176 350 350 350 +354 178 354 354 354 +356 179 356 356 356 +358 180 358 358 358 +360 181 360 360 360 +362 182 362 362 362 +364 183 364 364 364 +366 184 366 366 366 +368 185 368 368 368 +372 187 372 372 372 +378 190 378 378 378 +380 191 380 380 380 +384 193 384 384 384 +386 194 386 386 386 +388 195 388 388 388 +390 196 390 390 390 +392 197 392 392 392 +394 198 394 394 394 +396 199 396 396 396 +398 200 398 398 398 +400 201 400 400 400 +404 203 404 404 404 +406 204 406 406 406 +408 205 408 408 408 +410 206 410 410 410 +412 207 412 412 412 +416 209 416 416 416 +420 211 420 420 420 +422 212 422 422 422 +424 213 424 424 424 +432 217 432 432 432 +434 218 434 434 434 +438 220 438 438 438 +440 221 440 440 440 +442 222 442 442 442 +444 223 444 444 444 +448 225 448 448 448 +450 226 450 450 450 +452 227 452 452 452 +454 228 454 454 454 +458 230 458 458 458 +462 232 462 462 462 +464 233 464 464 464 +466 234 466 466 466 +472 237 472 472 472 +476 239 476 476 476 +480 241 480 480 480 +482 242 482 482 482 +484 243 484 484 484 +488 245 488 488 488 +490 246 490 490 490 +494 248 494 494 494 +500 251 500 500 500 +502 252 502 502 502 +504 253 504 504 504 +510 256 510 510 510 +516 259 516 516 516 +518 260 518 518 518 +520 261 520 520 520 +522 262 522 522 522 +526 264 526 526 526 +530 266 530 530 530 +532 267 532 532 532 +536 269 536 536 536 +538 270 538 538 538 +550 276 550 550 550 +552 277 552 552 552 +554 278 554 554 554 +556 279 556 556 556 +560 281 560 560 
560 +562 282 562 562 562 +566 284 566 566 566 +568 285 568 568 568 +570 286 570 570 570 +572 287 572 572 572 +574 288 574 574 574 +576 289 576 576 576 +578 290 578 578 578 +580 291 580 580 580 +582 292 582 582 582 +584 293 584 584 584 +588 295 588 588 588 +590 296 590 590 590 +598 300 598 598 598 +602 302 602 602 602 +610 306 610 610 610 +616 309 616 616 616 +618 310 618 618 618 +620 311 620 620 620 +622 312 622 622 622 +624 313 624 624 624 +626 314 626 626 626 +628 315 628 628 628 +636 319 636 636 636 +638 320 638 638 638 +640 321 640 640 640 +642 322 642 642 642 +648 325 648 648 648 +650 326 650 650 650 +652 327 652 652 652 +656 329 656 656 656 +660 331 660 660 660 +668 335 668 668 668 +670 336 670 670 670 +672 337 672 672 672 +676 339 676 676 676 +678 340 678 678 678 +682 342 682 682 682 +684 343 684 684 684 +688 345 688 688 688 +690 346 690 690 690 +694 348 694 694 694 +696 349 696 696 696 +702 352 702 702 702 +708 355 708 708 708 +712 357 712 712 712 +718 360 718 718 718 +726 364 726 726 726 +730 366 730 730 730 +734 368 734 734 734 +736 369 736 736 736 +738 370 738 738 738 +740 371 740 740 740 +742 372 742 742 742 +744 373 744 744 744 +752 377 752 752 752 +754 378 754 754 754 +756 379 756 756 756 +760 381 760 760 760 +762 382 762 762 762 +764 383 764 764 764 +770 386 770 770 770 +774 388 774 774 774 +778 390 778 778 778 +784 393 784 784 784 +790 396 790 790 790 +792 397 792 792 792 +794 398 794 794 794 +796 399 796 796 796 +798 400 798 798 798 +800 401 800 800 800 +804 403 804 804 804 +806 404 806 806 806 +808 405 808 808 808 +810 406 810 810 810 +812 407 812 812 812 +814 408 814 814 814 +818 410 818 818 818 +820 411 820 820 820 +824 413 824 824 824 +828 415 828 828 828 +832 417 832 832 832 +834 418 834 834 834 +840 421 840 840 840 +842 422 842 842 842 +844 423 844 844 844 +848 425 848 848 848 +854 428 854 854 854 +860 431 860 860 860 +864 433 864 864 864 +866 434 866 866 866 +868 435 868 868 868 +870 436 870 870 870 +876 439 876 876 876 +878 440 878 878 878 +880 441 880 880 880 +882 442 882 882 882 +884 443 884 884 884 +892 447 892 892 892 +894 448 894 894 894 +898 450 898 898 898 +902 452 902 902 902 +904 453 904 904 904 +910 456 910 910 910 +912 457 912 912 912 +914 458 914 914 914 +916 459 916 916 916 +920 461 920 920 920 +922 462 922 922 922 +924 463 924 924 924 +926 464 926 926 926 +930 466 930 930 930 +932 467 932 932 932 +938 470 938 938 938 +940 471 940 940 940 +942 472 942 942 942 +944 473 944 944 944 +946 474 946 946 946 +950 476 950 950 950 +956 479 956 956 956 +960 481 960 960 960 +962 482 962 962 962 +964 483 964 964 964 +966 484 966 966 966 +968 485 968 968 968 +970 486 970 970 970 +972 487 972 972 972 +974 488 974 974 974 +976 489 976 976 976 +980 491 980 980 980 +984 493 984 984 984 +986 494 986 986 986 +988 495 988 988 988 +990 496 990 990 990 +992 497 992 992 992 +994 498 994 994 994 +996 499 996 996 996 +998 500 998 998 998 +1000 501 1000 1000 1000 +1002 502 1002 1002 1002 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #6-0-dde25ab17e3198c18468e738f0464cf4 b/sql/hive/src/test/resources/golden/aggregation with group by expressions #6-0-dde25ab17e3198c18468e738f0464cf4 new file mode 100644 index 0000000000000..d735316e70d59 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #6-0-dde25ab17e3198c18468e738f0464cf4 @@ -0,0 +1,309 @@ +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 
0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 diff --git a/sql/hive/src/test/resources/golden/aggregation with group by expressions #7-0-dc8a898b293d22742b62ce236e72f77 b/sql/hive/src/test/resources/golden/aggregation with group by expressions #7-0-dc8a898b293d22742b62ce236e72f77 new file mode 100644 index 0000000000000..d735316e70d59 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation with group by expressions #7-0-dc8a898b293d22742b62ce236e72f77 @@ -0,0 +1,309 @@ +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 
0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 +0.0 0.0 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 new file mode 100644 index 0000000000000..5111dd17161f0 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #1-0-30038eb221d9d91ff4a098a57c1a5f9 @@ -0,0 +1 @@ +500 498 0 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #2-0-75a3974aac80b9c47f23519da6a68876 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #2-0-75a3974aac80b9c47f23519da6a68876 new file mode 100644 index 0000000000000..b4d2e5cc256d4 --- /dev/null +++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #2-0-75a3974aac80b9c47f23519da6a68876 @@ -0,0 +1 @@ +500 498 0 130091 diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #3-0-8341e7bf739124bef28729aabb9fe542 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #3-0-8341e7bf739124bef28729aabb9fe542 new file mode 100644 index 0000000000000..276664a61678d --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/aggregation without group by expressions #3-0-8341e7bf739124bef28729aabb9fe542
@@ -0,0 +1 @@
+309 498 0 79136
diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #4-0-679efde7a074d99d8dd227b4903b92f8 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #4-0-679efde7a074d99d8dd227b4903b92f8
new file mode 100644
index 0000000000000..276664a61678d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #4-0-679efde7a074d99d8dd227b4903b92f8
@@ -0,0 +1 @@
+309 498 0 79136
diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #5-0-1e35f970b831ecfffdaff828428aea51 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #5-0-1e35f970b831ecfffdaff828428aea51
new file mode 100644
index 0000000000000..ce71b00ee105c
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #5-0-1e35f970b831ecfffdaff828428aea51
@@ -0,0 +1 @@
+503 499 2 130096
diff --git a/sql/hive/src/test/resources/golden/aggregation without group by expressions #6-0-528e3454b467687ee9c1074cc7864660 b/sql/hive/src/test/resources/golden/aggregation without group by expressions #6-0-528e3454b467687ee9c1074cc7864660
new file mode 100644
index 0000000000000..418739f242dd4
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/aggregation without group by expressions #6-0-528e3454b467687ee9c1074cc7864660
@@ -0,0 +1 @@
+313 500 3 79140
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala
new file mode 100644
index 0000000000000..0af38b1ab401e
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregateSuite.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import org.apache.spark.sql.hive.test.TestHive
+import org.scalatest.BeforeAndAfter
+
+class AggregateSuite extends HiveComparisonTest with BeforeAndAfter {
+  override def beforeAll() {
+    TestHive.cacheTables = true
+  }
+
+  override def afterAll() {
+    TestHive.cacheTables = false
+  }
+
+  createQueryTest("aggregation without group by expressions #1",
+    """
+      |SELECT
+      | count(value),
+      | max(key),
+      | min(key)
+      |FROM src
+    """.stripMargin, false)
+
+  createQueryTest("aggregation without group by expressions #2",
+    """
+      |SELECT
+      | count(value),
+      | max(key),
+      | min(key),
+      | sum(key)
+      |FROM src
+    """.stripMargin, false)
+
+  createQueryTest("aggregation without group by expressions #3",
+    """
+      |SELECT
+      | count(distinct value),
+      | max(key),
+      | min(key),
+      | sum(distinct key)
+      |FROM src
+    """.stripMargin, false)
+
+  createQueryTest("aggregation without group by expressions #4",
+    """
+      |SELECT
+      | count(distinct value),
+      | max(distinct key),
+      | min(distinct key),
+      | sum(distinct key)
+      |FROM src
+    """.stripMargin, false)
+
+  createQueryTest("aggregation without group by expressions #5",
+    """
+      |SELECT
+      | count(value) + 3,
+      | max(key) + 1,
+      | min(key) + 2,
+      | sum(key) + 5
+      |FROM src
+    """.stripMargin, false)
+
+  createQueryTest("aggregation without group by expressions #6",
+    """
+      |SELECT
+      | count(distinct value) + 4,
+      | max(distinct key) + 2,
+      | min(distinct key) + 3,
+      | sum(distinct key) + 4
+      |FROM src
+    """.stripMargin, false)
+
+  createQueryTest("aggregation with group by expressions #1",
+    """
+      |SELECT key + 3 as a, count(value), max(key), min(key)
+      |FROM src group by key, value
+      |ORDER BY a
+    """.stripMargin, false)
+
+  createQueryTest("aggregation with group by expressions #2",
+    """
+      |SELECT
+      | key + 3 as a,
+      | count(value),
+      | max(key),
+      | min(key),
+      | sum(key)
+      |FROM src
+      |GROUP BY key, value
+      |ORDER BY a
+    """.stripMargin, false)
+
+  createQueryTest("aggregation with group by expressions #3",
+    """
+      |SELECT
+      | key + 3 as a,
+      | count(distinct value),
+      | max(key), min(key),
+      | sum(distinct key)
+      |FROM src
+      |GROUP BY key, value
+      |ORDER BY a
+    """.stripMargin, false)
+
+  createQueryTest("aggregation with group by expressions #4",
+    """
+      |SELECT
+      | key + 3 as a,
+      | count(distinct value),
+      | max(distinct key),
+      | min(distinct key),
+      | sum(distinct key)
+      |FROM src
+      |GROUP BY key, value
+      |ORDER BY a
+    """.stripMargin, false)
+
+  createQueryTest("aggregation with group by expressions #5",
+    """
+      |SELECT
+      | (key + 3) * 2 as a,
+      | (key + 3) + count(distinct value),
+      | (key + 3) + max(distinct (key + 3)),
+      | (key + 3) + min(distinct key + 3),
+      | (key + 3) + sum(distinct (key + 3))
+      |FROM src
+      |GROUP BY key + 3, value
+      |ORDER BY a
+    """.stripMargin, false)
+
+  createQueryTest("aggregation with group by expressions #6",
+    """
+      |SELECT
+      | stddev_pop(key) as a,
+      | stddev_samp(key) as b
+      |FROM src
+      |GROUP BY key + 3, value
+      |ORDER BY a, b
+    """.stripMargin, false)
+
+  createQueryTest("aggregation with group by expressions #7",
+    """
+      |SELECT
+      | stddev_pop(distinct key) as a,
+      | stddev_samp(distinct key) as b
+      |FROM src
+      |GROUP BY key + 3, value
+      |ORDER BY a, b
+    """.stripMargin, false)
+}
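
For reference, each createQueryTest above compares the query's output against the corresponding golden file added earlier in this patch; any of the queries can also be exercised ad hoc against the TestHive fixture. A minimal sketch (not part of the patch), assuming the Spark 1.3-era TestHive.sql API and the standard 500-row `src` test table:

import org.apache.spark.sql.hive.test.TestHive

object RunGoldenQuery {
  def main(args: Array[String]): Unit = {
    // The first "without group by" query; its golden answer above is "500 498 0".
    val rows = TestHive.sql(
      """
        |SELECT count(value), max(key), min(key)
        |FROM src
      """.stripMargin).collect()
    rows.foreach(println)
  }
}
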
From 341e7089e45951abb16d6e01b6dafc52840de1a6 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Thu, 26 Mar 2015 23:33:03 -0700
Subject: [PATCH 2/4] fix bug in test

---
 .../apache/spark/sql/catalyst/expressions/arithmetic.scala    | 3 +--
 .../main/scala/org/apache/spark/sql/execution/Aggregate.scala | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1f6526ef66c56..07a383a86a3e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -78,8 +78,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
 
   override lazy val resolved =
     left.resolved && right.resolved &&
-    left.dataType == right.dataType &&
-    !DecimalType.isFixed(left.dataType)
+    left.dataType == right.dataType
 
   def dataType: DataType = {
     if (!resolved) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 31ebdb35805ac..685854ec1ff24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -407,7 +407,7 @@ case class DistinctAggregate(
           val value = ae.eval(currentRow)
 
           if (ae.distinct) {
-            if (!seens(idx).contains(value)) {
+            if (value != null && !seens(idx).contains(value)) {
               ae.iterate(value, buffer)
               seens(idx).add(value)
             }
@@ -434,7 +434,7 @@ case class DistinctAggregate(
           val value = ae.eval(currentRow)
 
           if (ae.distinct) {
-            if (!inputBufferSeens.seens(idx).contains(value)) {
+            if (value != null && !inputBufferSeens.seens(idx).contains(value)) {
               ae.iterate(value, inputBufferSeens.buffer)
               inputBufferSeens.seens(idx).add(value)
             }
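
The guard added in both hunks follows SQL semantics: DISTINCT aggregates ignore NULL inputs, so a null value must never enter the per-expression "seen" set. A minimal self-contained sketch of the same shape (plain Scala, not the Spark code):

object DistinctNullSemantics {
  // COUNT(DISTINCT x) ignores NULLs, so they are filtered before the seen-set
  // probe, mirroring the `value != null && !seens(idx).contains(value)` guard.
  def countDistinct(values: Seq[Any]): Int = {
    val seen = scala.collection.mutable.HashSet.empty[Any]
    values.foreach { v =>
      if (v != null && !seen.contains(v)) {
        seen.add(v)
      }
    }
    seen.size
  }

  def main(args: Array[String]): Unit = {
    println(countDistinct(Seq("a", null, "b", "a", null))) // prints 2, not 3
  }
}
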
From b539baf87993d1eab09561281027c179b98da7a5 Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Fri, 27 Mar 2015 09:40:34 -0700
Subject: [PATCH 3/4] fix the null-handling regression in Sum and in the Average UDAF

---
 .../spark/sql/catalyst/expressions/aggregates.scala     | 6 ++++++
 .../main/scala/org/apache/spark/sql/hive/hiveUdfs.scala | 9 ++++-----
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 3892c89053902..22f53888c5d76 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -413,6 +413,8 @@ case class Sum(child: Expression, distinct: Boolean = false)
   @transient var arg: MutableLiteral = _
   @transient var sum: Add = _
 
+  lazy val DEFAULT_VALUE = Cast(Literal(0, IntegerType), dataType).eval()
+
   override def initialBoundReference(buffers: Seq[BoundReference]) = {
     aggr = buffers(0)
     arg = MutableLiteral(null, dataType)
@@ -431,6 +433,10 @@ case class Sum(child: Expression, distinct: Boolean = false)
         arg.value = argument
         buf(aggr) = sum.eval(buf)
       }
+    } else {
+      if (buf.isNullAt(aggr)) {
+        buf(aggr) = DEFAULT_VALUE
+      }
     }
   }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 51824abef2162..8fd3a4ceced9b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -261,7 +261,6 @@ private[hive] case class HiveGenericUdaf(
   // Initialize (reinitialize) the aggregation buffer
   override def reset(buf: MutableRow): Unit = {
     val buffer = evaluator.getNewAggregationBuffer
-      .asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
     evaluator.reset(buffer)
     // This is a hack, we never use the mutable row as buffer, but define our own buffer,
     // which is set as the first element of the buffer
@@ -276,19 +275,19 @@ private[hive] case class HiveGenericUdaf(
     }.toArray
 
     evaluator.iterate(
-      buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal),
+      buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal),
       args)
   }
 
   // Merge 2 aggregation buffer, and write back to the later one
   override def merge(value: Row, buf: MutableRow): Unit = {
-    val buffer = buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)
+    val buffer = buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal)
     evaluator.merge(buffer, wrap(value.get(bound.ordinal), bufferObjectInspector))
   }
 
   @deprecated
   override def terminatePartial(buf: MutableRow): Unit = {
-    val buffer = buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)
+    val buffer = buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal)
     // this is for serialization
     buf(bound) = unwrap(evaluator.terminatePartial(buffer), bufferObjectInspector)
   }
@@ -296,7 +295,7 @@ private[hive] case class HiveGenericUdaf(
   // Output the final result by feeding the aggregation buffer
   override def terminate(input: Row): Any = {
     unwrap(evaluator.terminate(
-      input.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)),
+      input.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal)),
       objectInspector)
   }
 }
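
The AbstractAggregationBuffer-to-AggregationBuffer change matters because Hive's GenericUDAFEvaluator.getNewAggregationBuffer() is declared to return the AggregationBuffer interface; AbstractAggregationBuffer is only one possible implementation (introduced in Hive 0.13), so the unconditional downcast can throw ClassCastException for evaluators that still return a plain AggregationBuffer. A minimal sketch of the interface-level pattern, assuming the Hive 0.13 UDAF API on the classpath:

import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer

object UdafBuffers {
  // Allocate and reset a buffer without assuming a concrete buffer class.
  def freshBuffer(evaluator: GenericUDAFEvaluator): AggregationBuffer = {
    val buffer = evaluator.getNewAggregationBuffer // declared as AggregationBuffer
    evaluator.reset(buffer)
    buffer
  }
}
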
From 13f4f15b59b4bca39409828576947f960704c18e Mon Sep 17 00:00:00 2001
From: Cheng Hao
Date: Mon, 30 Mar 2015 23:07:40 -0700
Subject: [PATCH 4/4] using OpenHashSet instead

---
 .../spark/sql/execution/Aggregate.scala | 100 +++++++++---------
 1 file changed, 51 insertions(+), 49 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 685854ec1ff24..6f199a4e7cd97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -17,10 +17,10 @@
 
 package org.apache.spark.sql.execution
 
-import scala.collection._
-
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.collection.{OpenHashSet, OpenHashMap}
+
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -39,7 +39,7 @@ sealed case class AggregateFunctionBind(
 sealed class InputBufferSeens(
     var input: Row,
     var buffer: MutableRow,
-    var seens: Array[mutable.HashSet[Any]] = null) {
+    var seens: Array[OpenHashSet[Any]] = null) {
   def this() {
     this(new GenericMutableRow(0), null)
   }
@@ -54,7 +54,7 @@ sealed class InputBufferSeens(
     this
   }
 
-  def withSeens(seens: Array[mutable.HashSet[Any]]): InputBufferSeens = {
+  def withSeens(seens: Array[OpenHashSet[Any]]): InputBufferSeens = {
    this.seens = seens
    this
  }
@@ -250,20 +250,13 @@ case class AggregatePreShuffle(
       createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer)))
     } else {
-      val results = new mutable.HashMap[Row, InputBufferSeens]()
+      val results = new OpenHashMap[Row, InputBufferSeens]()
 
       while (iter.hasNext) {
         val currentRow = iter.next()
         val keys = groupByProjection(currentRow)
-        results.get(keys) match {
-          case Some(inputbuffer) =>
-            var idx = 0
-            while (idx < aggregates.length) {
-              val ae = aggregates(idx)
-              ae.iterate(ae.eval(currentRow), inputbuffer.buffer)
-              idx += 1
-            }
-          case None =>
+        results(keys) match {
+          case null =>
             val buffer = new GenericMutableRow(bufferSchema.length)
             var idx = 0
             while (idx < aggregates.length) {
@@ -278,11 +271,19 @@ case class AggregatePreShuffle(
             }
 
             val copies = keys.copy()
-            results.put(copies, new InputBufferSeens(copies, buffer))
+            results(copies) = new InputBufferSeens(copies, buffer)
+          case inputbuffer =>
+            var idx = 0
+            while (idx < aggregates.length) {
+              val ae = aggregates(idx)
+              ae.iterate(ae.eval(currentRow), inputbuffer.buffer)
+              idx += 1
+            }
+
         }
       }
-      createIterator(aggregates, results.valuesIterator)
+      createIterator(aggregates, results.iterator.map(_._2))
     }
   }
 }
@@ -328,32 +329,32 @@ case class AggregatePostShuffle(
       createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer)))
     } else {
-      val results = new mutable.HashMap[Row, InputBufferSeens]()
+      val results = new OpenHashMap[Row, InputBufferSeens]()
 
       while (iter.hasNext) {
         val currentRow = iter.next()
         val keys = groupByProjection(currentRow)
-        results.get(keys) match {
-          case Some(pair) =>
+        results(keys) match {
+          case null =>
+            val buffer = new GenericMutableRow(bufferSchema.length)
             var idx = 0
             while (idx < aggregates.length) {
              val ae = aggregates(idx)
-              ae.merge(currentRow, pair.buffer)
+              ae.reset(buffer)
+              ae.merge(currentRow, buffer)
              idx += 1
            }
-          case None =>
-            val buffer = new GenericMutableRow(bufferSchema.length)
+            results(keys.copy()) = new InputBufferSeens(currentRow.copy(), buffer)
+          case pair =>
            var idx = 0
            while (idx < aggregates.length) {
              val ae = aggregates(idx)
-              ae.reset(buffer)
-              ae.merge(currentRow, buffer)
+              ae.merge(currentRow, pair.buffer)
              idx += 1
            }
-            results.put(keys.copy(), new InputBufferSeens(currentRow.copy(), buffer))
        }
      }
-      createIterator(aggregates, results.valuesIterator)
+      createIterator(aggregates, results.iterator.map(_._2))
    }
  }
 }
@@ -383,7 +384,7 @@ case class DistinctAggregate(
     if (groupingExpressions.isEmpty) {
       val buffer = new GenericMutableRow(bufferSchema.length)
       // TODO save the memory only for those DISTINCT aggregate expressions
-      val seens = new Array[mutable.HashSet[Any]](aggregateFunctionBinds.length)
+      val seens = new Array[OpenHashSet[Any]](aggregateFunctionBinds.length)
 
       var idx = 0
       while (idx < aggregateFunctionBinds.length) {
@@ -391,7 +392,7 @@ case class DistinctAggregate(
         ae.reset(buffer)
 
         if (ae.distinct) {
-          seens(idx) = new mutable.HashSet[Any]()
+          seens(idx) = new OpenHashSet[Any]()
         }
 
         idx += 1
@@ -420,56 +421,57 @@ case class DistinctAggregate(
       createIterator(aggregates, Iterator(ibs))
     } else {
-      val results = new mutable.HashMap[Row, InputBufferSeens]()
+      val results = new OpenHashMap[Row, InputBufferSeens]()
 
       while (iter.hasNext) {
         val currentRow = iter.next()
         val keys = groupByProjection(currentRow)
-        results.get(keys) match {
-          case Some(inputBufferSeens) =>
+        results(keys) match {
+          case null =>
+            val buffer = new GenericMutableRow(bufferSchema.length)
+            // TODO save the memory only for those DISTINCT aggregate expressions
+            val seens = new Array[OpenHashSet[Any]](aggregateFunctionBinds.length)
+
             var idx = 0
             while (idx < aggregateFunctionBinds.length) {
               val ae = aggregates(idx)
               val value = ae.eval(currentRow)
 
+              ae.reset(buffer)
+              ae.iterate(value, buffer)
+
               if (ae.distinct) {
-                if (value != null && !inputBufferSeens.seens(idx).contains(value)) {
-                  ae.iterate(value, inputBufferSeens.buffer)
-                  inputBufferSeens.seens(idx).add(value)
+                val seen = new OpenHashSet[Any]()
+                if (value != null) {
+                  seen.add(value)
                 }
-              } else {
-                ae.iterate(value, inputBufferSeens.buffer)
+                seens.update(idx, seen)
               }
+
               idx += 1
             }
-          case None =>
-            val buffer = new GenericMutableRow(bufferSchema.length)
-            // TODO save the memory only for those DISTINCT aggregate expressions
-            val seens = new Array[mutable.HashSet[Any]](aggregateFunctionBinds.length)
+            results(keys.copy()) = new InputBufferSeens(currentRow.copy(), buffer, seens)
+          case inputBufferSeens =>
            var idx = 0
            while (idx < aggregateFunctionBinds.length) {
              val ae = aggregates(idx)
              val value = ae.eval(currentRow)
 
-              ae.reset(buffer)
-              ae.iterate(value, buffer)
-
              if (ae.distinct) {
-                val seen = new mutable.HashSet[Any]()
-                if (value != null) {
-                  seen.add(value)
+                if (value != null && !inputBufferSeens.seens(idx).contains(value)) {
+                  ae.iterate(value, inputBufferSeens.buffer)
+                  inputBufferSeens.seens(idx).add(value)
                }
-                seens.update(idx, seen)
+              } else {
+                ae.iterate(value, inputBufferSeens.buffer)
              }
              idx += 1
            }
-            results.put(keys.copy(), new InputBufferSeens(currentRow.copy(), buffer, seens))
        }
      }
-      createIterator(aggregates, results.valuesIterator)
+      createIterator(aggregates, results.iterator.map(_._2))
    }
  }
 }
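
For context on this final patch: org.apache.spark.util.collection.OpenHashMap differs from scala.collection.mutable.HashMap in that, as the `case null` matching in the hunks above implies, `apply` returns null (rather than an Option) for a missing key, and insertion is written `map(key) = value`. That is why the lookups switched from `results.get(keys) match { case Some(...) / case None }` to `results(keys) match { case null / ... }` and from `results.put` to `results(...) = ...`, saving a second hash probe per input row. A toy stand-in (not Spark's implementation, which is marked private[spark]) illustrating that lookup contract:

import scala.collection.JavaConverters._

class NullReturningMap[K, V >: Null <: AnyRef] {
  private val underlying = new java.util.HashMap[K, V]()
  def apply(key: K): V = underlying.get(key) // null when the key is absent
  def update(key: K, value: V): Unit = underlying.put(key, value)
  def iterator: Iterator[(K, V)] =
    underlying.entrySet().iterator().asScala.map(e => (e.getKey, e.getValue))
}

object NullLookupDemo {
  def main(args: Array[String]): Unit = {
    val m = new NullReturningMap[String, java.lang.Integer]
    m("a") = 1
    // One probe both finds the group and classifies it as present or absent.
    m("a") match {
      case null  => println("absent")
      case found => println("found " + found)
    }
  }
}
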