Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit f9d6a73

Browse files
committed
Add a deterministic method to Expression.
1 parent 90c6069 commit f9d6a73

File tree

5 files changed

+61
-1
lines changed

5 files changed

+61
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@ abstract class Expression extends TreeNode[Expression] {
3737
* - A [[Cast]] or [[UnaryMinus]] is foldable if its child is foldable
3838
*/
3939
def foldable: Boolean = false
40+
41+
/**
42+
* Returns true when an expression always returns the same result for a specific set of
43+
* input values.
44+
*/
45+
// TODO: Need to well define what are explicit input values and implicit input values.
46+
def deterministic: Boolean = true
4047
def nullable: Boolean
4148
def references: AttributeSet = AttributeSet(children.flatMap(_.references.iterator))
4249

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ abstract class RDG(seed: Long) extends LeafExpression with Serializable {
3838
*/
3939
@transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId())
4040

41+
override def deterministic: Boolean = false
42+
4143
override def nullable: Boolean = false
4244

4345
override def dataType: DataType = DoubleType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,16 @@ object ColumnPruning extends Rule[LogicalPlan] {
179179
* expressions into one single expression.
180180
*/
181181
object ProjectCollapsing extends Rule[LogicalPlan] {
182+
183+
/** Returns true if any expression in projectList is non-deterministic. */
184+
private def hasNondeterministic(projectList: Seq[NamedExpression]): Boolean = {
185+
projectList.exists(expr => expr.find(!_.deterministic).isDefined)
186+
}
187+
182188
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
183-
case Project(projectList1, Project(projectList2, child)) =>
189+
// We only collapse these two Projects when all expressions in the lower projection are deterministic.
190+
case Project(projectList1, Project(projectList2, child))
191+
if !hasNondeterministic(projectList2) =>
184192
// Create a map of Aliases to their values from the child projection.
185193
// e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
186194
val aliasMap = AttributeMap(projectList2.collect {

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql
1919

2020
import org.scalatest.Matchers._
2121

22+
import org.apache.spark.sql.execution.Project
2223
import org.apache.spark.sql.functions._
2324
import org.apache.spark.sql.test.TestSQLContext
2425
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -453,6 +454,44 @@ class ColumnExpressionSuite extends QueryTest {
453454
assert(row.getDouble(1) <= 1.0)
454455
assert(row.getDouble(1) >= 0.0)
455456
}
457+
458+
def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = {
459+
val projects = df.queryExecution.executedPlan.collect {
460+
case project: Project => project
461+
}
462+
assert(projects.size === expectedNumProjects)
463+
}
464+
465+
// We first create a plan with two Projects.
466+
// Project [rand + 1 AS rand1, rand - 1 AS rand2]
467+
// Project [key, Rand 5 AS rand]
468+
// LogicalRDD [key, value]
469+
// Because the Rand function is not deterministic, the column rand is not deterministic.
470+
// So, in the optimizer, we will not collapse Project [rand + 1 AS rand1, rand - 1 AS rand2]
471+
// and Project [key, Rand 5 AS rand]. The final plan still has two Projects.
472+
val dfWithTwoProjects =
473+
testData
474+
.select('key, rand(5L).as("rand"))
475+
.select(('rand + 1).as("rand1"), ('rand - 1).as("rand2"))
476+
checkNumProjects(dfWithTwoProjects, 2)
477+
478+
// Now, we add one more project rand1 - rand2 on top of the query plan.
479+
// Since rand1 and rand2 are deterministic (they basically apply +/- to the generated
480+
// rand value), we can collapse rand1 - rand2 to the Project generating rand1 and rand2.
481+
// So, the plan will be optimized from ...
482+
// Project [(rand1 - rand2) AS (rand1 - rand2)]
483+
// Project [rand + 1 AS rand1, rand - 1 AS rand2]
484+
// Project [key, Rand 5 AS rand]
485+
// LogicalRDD [key, value]
486+
// to ...
487+
// Project [((rand + 1 AS rand1) - (rand - 1 AS rand2)) AS (rand1 - rand2)]
488+
// Project [key, Rand 5 AS rand]
489+
// LogicalRDD [key, value]
490+
val dfWithThreeProjects = dfWithTwoProjects.select('rand1 - 'rand2)
491+
checkNumProjects(dfWithThreeProjects, 2)
492+
dfWithThreeProjects.collect().foreach { row =>
493+
assert(row.getDouble(0) === 2.0 +- 0.0001)
494+
}
456495
}
457496

458497
test("randn") {

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, childre
7878

7979
type UDFType = UDF
8080

81+
override def deterministic: Boolean = isUDFDeterministic
82+
8183
override def nullable: Boolean = true
8284

8385
@transient
@@ -140,6 +142,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
140142
extends Expression with HiveInspectors with Logging {
141143
type UDFType = GenericUDF
142144

145+
override def deterministic: Boolean = isUDFDeterministic
146+
143147
override def nullable: Boolean = true
144148

145149
@transient

0 commit comments

Comments
 (0)