Skip to content

Commit 3c045c7

Browse files
Optimize the Constant Folding by adding more rules
1 parent 2645d4f commit 3c045c7

File tree

3 files changed

+20
-20
lines changed

3 files changed

+20
-20
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ abstract class Expression extends TreeNode[Expression] {
4444
* - A [[expressions.Cast Cast]] or [[expressions.UnaryMinus UnaryMinus]] is foldable if its
4545
* child is foldable.
4646
*/
47-
// TODO: Supporting more foldable expressions. For example, deterministic Hive UDFs.
4847
def foldable: Boolean = false
4948
def nullable: Boolean
5049
def references: Set[Attribute]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -94,25 +94,9 @@ object ConstantFolding extends Rule[LogicalPlan] {
9494
case q: LogicalPlan => q transformExpressionsDown {
9595
// Skip redundant folding of literals.
9696
case l: Literal => l
97-
case e @ If(Literal(v, _), trueValue, falseValue) => if(v == true) trueValue else falseValue
98-
case e @ In(Literal(v, _), list) if(list.exists(c => c match {
99-
case Literal(candidate, _) if(candidate == v) => true
100-
case _ => false
101-
})) => Literal(true, BooleanType)
102-
case e if e.foldable => Literal(e.eval(null), e.dataType)
103-
}
104-
}
105-
}
106-
107-
/**
108-
* The expression may be constant value, due to one or more of its children expressions is null or
109-
* not null constantly, replaces [[catalyst.expressions.Expression Expressions]] with equivalent
110-
* [[catalyst.expressions.Literal Literal]] values if possible caused by that.
111-
*/
112-
object NullPropagation extends Rule[LogicalPlan] {
113-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
114-
case q: LogicalPlan => q transformExpressionsUp {
115-
case l: Literal => l
97+
case e @ Count(Literal(null, _)) => Literal(null, e.dataType)
98+
case e @ Sum(Literal(null, _)) => Literal(null, e.dataType)
99+
case e @ Average(Literal(null, _)) => Literal(null, e.dataType)
116100
case e @ IsNull(Literal(null, _)) => Literal(true, BooleanType)
117101
case e @ IsNull(Literal(_, _)) => Literal(false, BooleanType)
118102
case e @ IsNull(c @ Rand) => Literal(false, BooleanType)
@@ -135,6 +119,11 @@ object NullPropagation extends Rule[LogicalPlan] {
135119
Coalesce(newChildren)
136120
}
137121
}
122+
case e @ If(Literal(v, _), trueValue, falseValue) => if(v == true) trueValue else falseValue
123+
case e @ In(Literal(v, _), list) if(list.exists(c => c match {
124+
case Literal(candidate, _) if(candidate == v) => true
125+
case _ => false
126+
})) => Literal(true, BooleanType)
138127
// TODO put exceptional cases(Unary & Binary Expression) before here.
139128
case e: UnaryExpression => e.child match {
140129
case Literal(null, _) => Literal(null, e.dataType)
@@ -143,6 +132,7 @@ object NullPropagation extends Rule[LogicalPlan] {
143132
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
144133
case left :: Literal(null, _) :: Nil => Literal(null, e.dataType)
145134
}
135+
case e if e.foldable => Literal(e.eval(null), e.dataType)
146136
}
147137
}
148138
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
2222
import org.apache.hadoop.hive.common.`type`.HiveDecimal
2323
import org.apache.hadoop.hive.ql.exec.UDF
2424
import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
25+
import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
2526
import org.apache.hadoop.hive.ql.udf.generic._
2627
import org.apache.hadoop.hive.serde2.objectinspector._
2728
import org.apache.hadoop.hive.serde2.objectinspector.primitive._
@@ -213,6 +214,16 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
213214

214215
@transient
215216
protected lazy val returnInspector = function.initialize(argumentInspectors.toArray)
217+
218+
@transient
219+
protected lazy val isUDFDeterministic = {
220+
val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
221+
(udfType != null && udfType.deterministic())
222+
}
223+
224+
override def foldable = {
225+
isUDFDeterministic && children.foldLeft(true)((prev, n) => prev && n.foldable)
226+
}
216227

217228
val dataType: DataType = inspectorToDataType(returnInspector)
218229

0 commit comments

Comments
 (0)