Skip to content

Commit 33640ec

Browse files
committed
auto alias expressions in analyzer
1 parent a1e3649 commit 33640ec

File tree

11 files changed

+74
-86
lines changed

11 files changed

+74
-86
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
9999
protected val WHERE = Keyword("WHERE")
100100
protected val WITH = Keyword("WITH")
101101

102-
protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = {
103-
exprs.zipWithIndex.map {
104-
case (ne: NamedExpression, _) => ne
105-
case (e, i) => Alias(e, s"c$i")()
106-
}
107-
}
108-
109102
protected lazy val start: Parser[LogicalPlan] =
110103
start1 | insert | cte
111104

@@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser {
130123
val base = r.getOrElse(OneRowRelation)
131124
val withFilter = f.map(Filter(_, base)).getOrElse(base)
132125
val withProjection = g
133-
.map(Aggregate(_, assignAliases(p), withFilter))
134-
.getOrElse(Project(assignAliases(p), withFilter))
126+
.map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter))
127+
.getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter))
135128
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
136129
val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
137130
val withOrder = o.map(_(withHaving)).getOrElse(withHaving)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,14 @@ class Analyzer(
7070
Batch("Resolution", fixedPoint,
7171
ResolveRelations ::
7272
ResolveReferences ::
73+
ResolveAliases ::
7374
ResolveGroupingAnalytics ::
7475
ResolveSortReferences ::
7576
ResolveGenerate ::
7677
ResolveFunctions ::
7778
ExtractWindowExpressions ::
7879
GlobalAggregates ::
7980
UnresolvedHavingClauseAttributes ::
80-
TrimGroupingAliases ::
8181
typeCoercionRules ++
8282
extendedResolutionRules : _*)
8383
)
@@ -131,13 +131,28 @@ class Analyzer(
131131
}
132132
}
133133

134-
/**
135-
* Removes no-op Alias expressions from the plan.
136-
*/
137-
object TrimGroupingAliases extends Rule[LogicalPlan] {
138-
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
139-
case Aggregate(groups, aggs, child) =>
140-
Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child)
134+
object ResolveAliases extends Rule[LogicalPlan] {
135+
private def assignAliases(exprs: Seq[Expression]) = {
136+
var i = -1
137+
exprs.map(_ transformDown {
138+
case u @ UnresolvedAlias(child) =>
139+
child match {
140+
case ne: NamedExpression => ne
141+
case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
142+
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
143+
case e if !e.resolved => u
144+
case other =>
145+
i += 1
146+
Alias(other, s"c$i")()
147+
}
148+
}).asInstanceOf[Seq[NamedExpression]]
149+
}
150+
151+
def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
152+
case Aggregate(groups, aggs, child) if child.resolved =>
153+
Aggregate(groups, assignAliases(aggs), child)
154+
case Project(projectList, child) if child.resolved =>
155+
Project(assignAliases(projectList), child)
141156
}
142157
}
143158

@@ -228,7 +243,7 @@ class Analyzer(
228243
}
229244

230245
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
231-
case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
246+
case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) =>
232247
i.copy(table = EliminateSubQueries(getTable(u)))
233248
case u: UnresolvedRelation =>
234249
getTable(u)
@@ -352,8 +367,12 @@ class Analyzer(
352367
q.asInstanceOf[GroupingAnalytics].gid
353368
case u @ UnresolvedAttribute(nameParts) =>
354369
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
355-
val result =
356-
withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) }
370+
val result = withPosition(u) {
371+
q.resolveChildren(nameParts, resolver).map {
372+
case UnresolvedAlias(child) => child
373+
case other => other
374+
}.getOrElse(u)
375+
}
357376
logDebug(s"Resolving $u to $result")
358377
result
359378
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
@@ -586,19 +605,7 @@ class Analyzer(
586605
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
587606
private object AliasedGenerator {
588607
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
589-
case Alias(g: Generator, name)
590-
if g.resolved &&
591-
g.elementTypes.size > 1 &&
592-
java.util.regex.Pattern.matches("_c[0-9]+", name) => {
593-
// Assume the default name given by parser is "_c[0-9]+",
594-
// TODO in long term, move the naming logic from Parser to Analyzer.
595-
// In projection, Parser gave default name for TGF as does for normal UDF,
596-
// but the TGF probably have multiple output columns/names.
597-
// e.g. SELECT explode(map(key, value)) FROM src;
598-
// Let's simply ignore the default given name for this case.
599-
Some((g, Nil))
600-
}
601-
case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
608+
case Alias(g: Generator, name) if g.elementTypes.size > 1 =>
602609
// If not given the default names, and the TGF with multiple output columns
603610
failAnalysis(
604611
s"""Expect multiple names given for ${g.getClass.getName},

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,7 @@ trait CheckAnalysis {
9595
case e => e.children.foreach(checkValidAggregateExpression)
9696
}
9797

98-
val cleaned = aggregateExprs.map(_.transform {
99-
// Should trim aliases around `GetField`s. These aliases are introduced while
100-
// resolving struct field accesses, because `GetField` is not a `NamedExpression`.
101-
// (Should we just turn `GetField` into a `NamedExpression`?)
102-
case Alias(g, _) => g
103-
})
104-
105-
cleaned.foreach(checkValidAggregateExpression)
98+
aggregateExprs.foreach(checkValidAggregateExpression)
10699

107100
case _ => // Fallbacks to the following checks
108101
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,17 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression)
206206

207207
override def toString: String = s"$child[$extraction]"
208208
}
209+
210+
case class UnresolvedAlias(child: Expression) extends NamedExpression with trees.UnaryNode[Expression] {
211+
override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
212+
override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers")
213+
override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
214+
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
215+
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
216+
override def name: String = throw new UnresolvedException(this, "name")
217+
218+
override lazy val resolved = false
219+
220+
override def eval(input: Row = null): Any =
221+
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
222+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,16 +94,22 @@ trait ExtractValue extends UnaryExpression {
9494
self: Product =>
9595
}
9696

97+
abstract class ExtractValueWithStruct extends ExtractValue {
98+
self: Product =>
99+
100+
def field: StructField
101+
override def foldable: Boolean = child.foldable
102+
override def toString: String = s"$child.${field.name}"
103+
}
104+
97105
/**
98106
* Returns the value of fields in the Struct `child`.
99107
*/
100108
case class GetStructField(child: Expression, field: StructField, ordinal: Int)
101-
extends ExtractValue {
109+
extends ExtractValueWithStruct {
102110

103111
override def dataType: DataType = field.dataType
104112
override def nullable: Boolean = child.nullable || field.nullable
105-
override def foldable: Boolean = child.foldable
106-
override def toString: String = s"$child.${field.name}"
107113

108114
override def eval(input: InternalRow): Any = {
109115
val baseValue = child.eval(input).asInstanceOf[InternalRow]
@@ -118,12 +124,10 @@ case class GetArrayStructFields(
118124
child: Expression,
119125
field: StructField,
120126
ordinal: Int,
121-
containsNull: Boolean) extends ExtractValue {
127+
containsNull: Boolean) extends ExtractValueWithStruct {
122128

123129
override def dataType: DataType = ArrayType(field.dataType, containsNull)
124130
override def nullable: Boolean = child.nullable
125-
override def foldable: Boolean = child.foldable
126-
override def toString: String = s"$child.${field.name}"
127131

128132
override def eval(input: InternalRow): Any = {
129133
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
1919

2020
import org.apache.spark.Logging
2121
import org.apache.spark.sql.AnalysisException
22-
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver}
22+
import org.apache.spark.sql.catalyst.analysis._
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.QueryPlan
2525
import org.apache.spark.sql.catalyst.trees.TreeNode

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import scala.language.implicitConversions
2121

2222
import org.apache.spark.annotation.Experimental
2323
import org.apache.spark.Logging
24-
import org.apache.spark.sql.expressions.Window
2524
import org.apache.spark.sql.functions.lit
2625
import org.apache.spark.sql.catalyst.expressions._
2726
import org.apache.spark.sql.catalyst.analysis._

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -629,11 +629,7 @@ class DataFrame private[sql](
629629
@scala.annotation.varargs
630630
def select(cols: Column*): DataFrame = {
631631
val namedExpressions = cols.map {
632-
case Column(expr: NamedExpression) => expr
633-
// Leave an unaliased explode with an empty list of names since the analyzer will generate the
634-
// correct defaults after the nested expression's type has been resolved.
635-
case Column(explode: Explode) => MultiAlias(explode, Nil)
636-
case Column(expr: Expression) => Alias(expr, expr.prettyString)()
632+
case Column(expr: Expression) => UnresolvedAlias(expr)
637633
}
638634
// When user continuously call `select`, speed up analysis by collapsing `Project`
639635
import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import scala.collection.JavaConversions._
2121
import scala.language.implicitConversions
2222

2323
import org.apache.spark.annotation.Experimental
24-
import org.apache.spark.sql.catalyst.analysis.Star
24+
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, Star}
2525
import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
2727
import org.apache.spark.sql.types.NumericType
@@ -70,27 +70,24 @@ class GroupedData protected[sql](
7070
groupingExprs: Seq[Expression],
7171
private val groupType: GroupedData.GroupType) {
7272

73-
private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
73+
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
7474
val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
75-
val retainedExprs = groupingExprs.map {
76-
case expr: NamedExpression => expr
77-
case expr: Expression => Alias(expr, expr.prettyString)()
78-
}
79-
retainedExprs ++ aggExprs
80-
} else {
81-
aggExprs
82-
}
75+
groupingExprs ++ aggExprs
76+
} else {
77+
aggExprs
78+
}
8379

80+
val aliasedAgg = aggregates.map(UnresolvedAlias(_))
8481
groupType match {
8582
case GroupedData.GroupByType =>
8683
DataFrame(
87-
df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
84+
df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
8885
case GroupedData.RollupType =>
8986
DataFrame(
90-
df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
87+
df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
9188
case GroupedData.CubeType =>
9289
DataFrame(
93-
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
90+
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
9491
}
9592
}
9693

@@ -112,10 +109,7 @@ class GroupedData protected[sql](
112109
namedExpr
113110
}
114111
}
115-
toDF(columnExprs.map { c =>
116-
val a = f(c)
117-
Alias(a, a.prettyString)()
118-
})
112+
toDF(columnExprs.map(f))
119113
}
120114

121115
private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -169,8 +163,7 @@ class GroupedData protected[sql](
169163
*/
170164
def agg(exprs: Map[String, String]): DataFrame = {
171165
toDF(exprs.map { case (colName, expr) =>
172-
val a = strToExpr(expr)(df(colName).expr)
173-
Alias(a, a.prettyString)()
166+
strToExpr(expr)(df(colName).expr)
174167
}.toSeq)
175168
}
176169

@@ -224,10 +217,7 @@ class GroupedData protected[sql](
224217
*/
225218
@scala.annotation.varargs
226219
def agg(expr: Column, exprs: Column*): DataFrame = {
227-
toDF((expr +: exprs).map(_.expr).map {
228-
case expr: NamedExpression => expr
229-
case expr: Expression => Alias(expr, expr.prettyString)()
230-
})
220+
toDF((expr +: exprs).map(_.expr))
231221
}
232222

233223
/**

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package org.apache.spark.sql
1919

2020
import java.sql.Timestamp
2121

22-
import org.apache.spark.sql.catalyst.plans.logical
2322
import org.apache.spark.sql.test.TestSQLContext.implicits._
2423
import org.apache.spark.sql.test._
2524

0 commit comments

Comments (0)