Skip to content

[SPARK-4233] [SQL] WIP: Simplify the UDAF API (Interface) #3247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ package org.apache.spark.sql

import scala.util.hashing.MurmurHash3

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.{StructType, DateUtils}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericRow}

object Row {
/**
Expand Down Expand Up @@ -305,6 +305,49 @@ trait Row extends Serializable {
*/
def getAs[T](i: Int): T = apply(i).asInstanceOf[T]

/*
 * Syntactic-sugar accessors that take a BoundReference instead of a raw
 * Int ordinal. Each forwarder delegates to the corresponding ordinal-based
 * accessor via `bound.ordinal`, so behavior is identical to the Int-indexed
 * API; `@inline` asks the compiler to erase the indirection.
 *
 * TODO(review): should this sugar be promoted to a documented public API?
 * NOTE(review): this couples Row to catalyst's BoundReference — confirm the
 * dependency direction is acceptable before exposing it publicly.
 */
@inline
final def apply(bound: BoundReference): Any = apply(bound.ordinal)
@inline
final def isNullAt(bound: BoundReference): Boolean = isNullAt(bound.ordinal)
@inline
final def getInt(bound: BoundReference): Int = getInt(bound.ordinal)
@inline
final def getLong(bound: BoundReference): Long = getLong(bound.ordinal)
@inline
final def getDouble(bound: BoundReference): Double = getDouble(bound.ordinal)
@inline
final def getBoolean(bound: BoundReference): Boolean = getBoolean(bound.ordinal)
@inline
final def getShort(bound: BoundReference): Short = getShort(bound.ordinal)
@inline
final def getByte(bound: BoundReference): Byte = getByte(bound.ordinal)
@inline
final def getFloat(bound: BoundReference): Float = getFloat(bound.ordinal)
@inline
final def getString(bound: BoundReference): String = getString(bound.ordinal)
@inline
final def getAs[T](bound: BoundReference): T = getAs[T](bound.ordinal)
@inline
final def getDecimal(bound: BoundReference): java.math.BigDecimal = getDecimal(bound.ordinal)
@inline
final def getDate(bound: BoundReference): java.sql.Date = getDate(bound.ordinal)
@inline
final def getSeq[T](bound: BoundReference): Seq[T] = getSeq(bound.ordinal)
@inline
final def getList[T](bound: BoundReference): java.util.List[T] = getList(bound.ordinal)
@inline
final def getMap[K, V](bound: BoundReference): scala.collection.Map[K, V] = {
getMap(bound.ordinal)
}
@inline
final def getJavaMap[K, V](bound: BoundReference): java.util.Map[K, V] = {
getJavaMap(bound.ordinal)
}
@inline
final def getStruct(bound: BoundReference): Row = getStruct(bound.ordinal)
/* End of the BoundReference-based syntactic-sugar accessors. */

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these changes really needed for this patch?
The interfaces of Row are related to all the other operators.
I think that, if necessary, you should first make a separate PR to add these interfaces to Row.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with that.

override def toString(): String = s"[${this.mkString(",")}]"

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,15 +271,18 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {

protected lazy val function: Parser[Expression] =
( SUM ~> "(" ~> expression <~ ")" ^^ { case exp => Sum(exp) }
| SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => SumDistinct(exp) }
| SUM ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => Sum(exp, true) }
| COUNT ~ "(" ~> "*" <~ ")" ^^ { case _ => Count(Literal(1)) }
| COUNT ~ "(" ~> expression <~ ")" ^^ { case exp => Count(exp) }
| COUNT ~> "(" ~> DISTINCT ~> repsep(expression, ",") <~ ")" ^^
{ case exps => CountDistinct(exps) }
{ case exps => CountDistinct(exps) }
| COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) }
// TODO approximate is not supported, will convert into COUNT
| APPROXIMATE ~ COUNT ~ "(" ~ DISTINCT ~> expression <~ ")" ^^
{ case exp => ApproxCountDistinct(exp) }
{ case exp => CountDistinct(exp :: Nil) }
| APPROXIMATE ~> "(" ~> floatLit ~ ")" ~ COUNT ~ "(" ~ DISTINCT ~ expression <~ ")" ^^
{ case s ~ _ ~ _ ~ _ ~ _ ~ e => ApproxCountDistinct(e, s.toDouble) }
// TODO approximate s.toDouble
{ case s ~ _ ~ _ ~ _ ~ _ ~ e => CountDistinct(e :: Nil) }
| FIRST ~ "(" ~> expression <~ ")" ^^ { case exp => First(exp) }
| LAST ~ "(" ~> expression <~ ")" ^^ { case exp => Last(exp) }
| AVG ~ "(" ~> expression <~ ")" ^^ { case exp => Average(exp) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ class Analyzer(catalog: Catalog,
Project(
projectList.flatMap {
case s: Star => s.expand(child.output, resolver)
case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) =>
case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) =>
val expandedArgs = args.flatMap {
case s: Star => s.expand(child.output, resolver)
case o => o :: Nil
Expand Down Expand Up @@ -385,8 +385,8 @@ class Analyzer(catalog: Catalog,
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan =>
q transformExpressions {
case u @ UnresolvedFunction(name, children) if u.childrenResolved =>
registry.lookupFunction(name, children)
case u @ UnresolvedFunction(name, children, distinct) if u.childrenResolved =>
registry.lookupFunction(name, children, distinct)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,12 @@ trait FunctionRegistry {

def registerFunction(name: String, builder: FunctionBuilder): Unit

def lookupFunction(name: String, children: Seq[Expression]): Expression

def caseSensitive: Boolean

def lookupFunction(
name: String,
children: Seq[Expression],
distinct: Boolean = false): Expression
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to consider this change carefully,
because I think distinct is not relevant to simple UDFs and UDTFs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but we can't tell whether it's a UDF or a UDAF — only the catalog itself knows.

Any better idea on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I don't have a better idea.
Do you know of any similar logic for looking up user-defined functions in Hive?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Spark SQL, we put everything about the expression into its constructor argument list, but Hive requires an additional object to keep the expression info, and both of them are initialized/used within the executors. We need to throw an exception if it's not a UDAF when distinct is set to true.

}

trait OverrideFunctionRegistry extends FunctionRegistry {
Expand All @@ -39,8 +42,13 @@ trait OverrideFunctionRegistry extends FunctionRegistry {
functionBuilders.put(name, builder)
}

abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name, children))
abstract override def lookupFunction(
name: String,
children: Seq[Expression],
distinct: Boolean = false): Expression = {
functionBuilders.get(name)
.map(_(children))
.getOrElse(super.lookupFunction(name, children, distinct))
}
}

Expand All @@ -51,7 +59,10 @@ class SimpleFunctionRegistry(val caseSensitive: Boolean) extends FunctionRegistr
functionBuilders.put(name, builder)
}

override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
override def lookupFunction(
name: String,
children: Seq[Expression],
distinct: Boolean = false): Expression = {
functionBuilders(name)(children)
}
}
Expand All @@ -65,7 +76,10 @@ object EmptyFunctionRegistry extends FunctionRegistry {
throw new UnsupportedOperationException
}

override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
def lookupFunction(
name: String,
children: Seq[Expression],
distinct: Boolean = false): Expression = {
throw new UnsupportedOperationException
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,10 +270,10 @@ trait HiveTypeCoercion {
case i @ In(a, b) if a.dataType == TimestampType && b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

case Sum(e) if e.dataType == StringType =>
Sum(Cast(e, DoubleType))
case Average(e) if e.dataType == StringType =>
Average(Cast(e, DoubleType))
case Sum(e, distinct) if e.dataType == StringType =>
Sum(Cast(e, DoubleType), distinct)
case Average(e, distinct) if e.dataType == StringType =>
Average(Cast(e, DoubleType), distinct)
case Sqrt(e) if e.dataType == StringType =>
Sqrt(Cast(e, DoubleType))
}
Expand Down Expand Up @@ -484,25 +484,21 @@ trait HiveTypeCoercion {
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))

// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
case Sum(e @ FractionalType()) if e.dataType != DoubleType => Sum(Cast(e, DoubleType))

case s @ SumDistinct(e @ DecimalType()) => s // Decimal is already the biggest.
case SumDistinct(e @ IntegralType()) if e.dataType != LongType =>
SumDistinct(Cast(e, LongType))
case SumDistinct(e @ FractionalType()) if e.dataType != DoubleType =>
SumDistinct(Cast(e, DoubleType))

case s @ Average(e @ DecimalType()) => s // Decimal is already the biggest.
case Average(e @ IntegralType()) if e.dataType != LongType =>
Average(Cast(e, LongType))
case Average(e @ FractionalType()) if e.dataType != DoubleType =>
Average(Cast(e, DoubleType))
case s @ Sum(e @ DecimalType(), _) => s // Decimal is already the biggest.
case Sum(e @ IntegralType(), distinct) if e.dataType != LongType =>
Sum(Cast(e, LongType), distinct)
case Sum(e @ FractionalType(), distinct) if e.dataType != DoubleType =>
Sum(Cast(e, DoubleType), distinct)

case s @ Average(e @ DecimalType(), _) => s // Decimal is already the biggest.
case Average(e @ IntegralType(), distinct) if e.dataType != LongType =>
Average(Cast(e, LongType), distinct)
case Average(e @ FractionalType(), distinct) if e.dataType != DoubleType =>
Average(Cast(e, DoubleType), distinct)

// Hive lets you do aggregation of timestamps... for some reason
case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType))
case Average(e @ TimestampType()) => Average(Cast(e, DoubleType))
case Sum(e @ TimestampType(), distinct) => Sum(Cast(e, DoubleType), distinct)
case Average(e @ TimestampType(), distinct) => Average(Cast(e, DoubleType), distinct)

// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
override def toString: String = s"'$name"
}

case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression {
case class UnresolvedFunction(
name: String,
children: Seq[Expression],
distinct: Boolean = false) extends Expression {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,12 @@ package object dsl {
}

def sum(e: Expression): Expression = Sum(e)
def sumDistinct(e: Expression): Expression = SumDistinct(e)
def sumDistinct(e: Expression): Expression = Sum(e, true)
def count(e: Expression): Expression = Count(e)
def countDistinct(e: Expression*): Expression = CountDistinct(e)
def approxCountDistinct(e: Expression, rsd: Double = 0.05): Expression =
ApproxCountDistinct(e, rsd)
def countDistinct(e: Expression) = CountDistinct(e :: Nil)
def countDistinct(e: Expression*) = CountDistinct(e)
// TODO we don't support approximate, will convert it into Count
def approxCountDistinct(e: Expression, rsd: Double = 0.05) = CountDistinct(e :: Nil)
def avg(e: Expression): Expression = Average(e)
def first(e: Expression): Expression = First(e)
def last(e: Expression): Expression = Last(e)
Expand All @@ -161,6 +162,7 @@ package object dsl {
def abs(e: Expression): Expression = Abs(e)

implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }

// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
override def expr: Expression = Literal(s)
Expand Down
Loading