Skip to content

Commit f99401a

Browse files
marmbrus authored and rxin committed
[SQL] Improve column pruning in the optimizer.
Author: Michael Armbrust <[email protected]> Closes #378 from marmbrus/columnPruning and squashes the following commits: 779da56 [Michael Armbrust] More consistent naming. 1a4e9ea [Michael Armbrust] More comments. 2f4e7b9 [Michael Armbrust] Improve column pruning in the optimizer.
1 parent 930b70f commit f99401a

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,56 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
3333
Batch("Filter Pushdown", Once,
3434
CombineFilters,
3535
PushPredicateThroughProject,
36-
PushPredicateThroughInnerJoin) :: Nil
36+
PushPredicateThroughInnerJoin,
37+
ColumnPruning) :: Nil
38+
}
39+
40+
/**
41+
* Attempts to eliminate the reading of unneeded columns from the query plan using the following
42+
* transformations:
43+
*
44+
* - Inserting Projections beneath the following operators:
45+
* - Aggregate
46+
* - Project <- Join
47+
* - Collapse adjacent projections, performing alias substitution.
48+
*/
49+
object ColumnPruning extends Rule[LogicalPlan] {
  /**
   * Rewrites `plan` using `transform`, applying one of three cases:
   *
   *  - `Aggregate` whose child outputs attributes the aggregate does not reference: a
   *    [[Project]] restricted to the referenced attributes is inserted beneath it.
   *  - `Project` directly over a `Join`: each join input that outputs attributes needed
   *    neither by the project list nor by the join condition is wrapped in a [[Project]].
   *  - `Project` directly over another `Project`: the two are collapsed into one by
   *    substituting the inner projection's aliases into the outer project list.
   *
   * @param plan the logical plan to optimize.
   * @return the rewritten plan; nodes that match none of the cases are left as they are.
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
      // Project away references that are not needed to calculate the required aggregates.
      a.copy(child = Project(a.references.toSeq, child))

    case Project(projectList, Join(left, right, joinType, condition)) =>
      // Collect the list of all references required either above or to evaluate the condition.
      val allReferences: Set[Attribute] =
        projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty)

      /** Applies a projection when the child is producing unnecessary attributes */
      def prunedChild(c: LogicalPlan): LogicalPlan = {
        // The subset of this child's output that is actually consumed above or by the condition.
        val required = allReferences.filter(c.outputSet.contains)
        // Prune only when the child outputs something that is NOT required. The difference has
        // to be `output -- required`; `required -- output` is always empty because `required`
        // is by construction a subset of the child's output.
        if ((c.outputSet -- required).nonEmpty) {
          Project(required.toSeq, c)
        } else {
          c
        }
      }

      Project(projectList, Join(prunedChild(left), prunedChild(right), joinType, condition))

    case Project(projectList1, Project(projectList2, child)) =>
      // Create a map of Aliases to their values from the child projection.
      // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
      val aliasMap = projectList2.collect {
        case a @ Alias(e, _) => (a.toAttribute: Expression, a)
      }.toMap

      // Substitute any attributes that are produced by the child projection, so that we safely
      // eliminate it.
      // e.g., 'SELECT c + 1 FROM (SELECT a + b AS c ...' produces 'SELECT a + b + 1 ...'
      // TODO: Fix TransformBase to avoid the cast below.
      val substitutedProjection = projectList1.map(_.transform {
        case a if aliasMap.contains(a) => aliasMap(a)
      }).asInstanceOf[Seq[NamedExpression]]

      Project(substitutedProjection, child)
  }
}
3887

3988
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ case class Aggregate(
127127
extends UnaryNode {
128128

129129
def output = aggregateExpressions.map(_.toAttribute)
130-
def references = child.references
130+
def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
131131
}
132132

133133
case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {

0 commit comments

Comments
 (0)