Skip to content

Commit 77ba3a6

Browse files
committed
Add PreAnalyzer.
1 parent 5bbcd13 commit 77ba3a6

File tree

4 files changed

+99
-3
lines changed

4 files changed

+99
-3
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.expressions._
21+
import org.apache.spark.sql.catalyst.plans.logical._
22+
import org.apache.spark.sql.catalyst.rules._
23+
24+
/**
 * A [[RuleExecutor]] over [[LogicalPlan]]s that runs a single batch, `ResolveSelfJoin`, which
 * rewrites joins whose two sides expose the same attributes.
 *
 * @param caseSensitive selects `caseSensitiveResolution` (true) or `caseInsensitiveResolution`
 *                      (false) as the value of [[resolver]]
 * @param maxIterations iteration limit handed to the `FixedPoint` strategy of the batch
 */
class PreAnalyzer(caseSensitive: Boolean = true,
                  maxIterations: Int = 100) extends RuleExecutor[LogicalPlan] {

  // Not referenced by any rule in this class; kept because it is a public member.
  val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution

  val fixedPoint = FixedPoint(maxIterations)

  lazy val batches: Seq[Batch] = Seq(
    Batch("ResolveSelfJoin", fixedPoint, ResolveSelfJoin)
  )

  /**
   * Special handling for cases where a self-join introduces duplicate expression ids: when the
   * left and right outputs of a [[Join]] overlap, one conflicting node on the right side is
   * replaced by a copy with fresh output attributes and references to the old attributes in the
   * right side are rewritten accordingly.
   */
  object ResolveSelfJoin extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
      // Leave nodes alone until their children are resolved.
      case p: LogicalPlan if !p.childrenResolved => p

      case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
        val conflictingAttributes = left.outputSet.intersect(right.outputSet)
        logDebug(s"Conflicting attributes ${conflictingAttributes.mkString(",")} in $j")

        val firstConflict = right.collect {
          // Handle base relations that might appear more than once.
          case oldVersion: MultiInstanceRelation
              if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
            val newVersion = oldVersion.newInstance()
            (oldVersion, newVersion)

          // Handle projects that create conflicting aliases.
          case oldVersion @ Project(projectList, _)
              if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
            (oldVersion, oldVersion.copy(projectList = newAliases(projectList)))

          // Handle aggregates that create conflicting aliases.
          case oldVersion @ Aggregate(_, aggregateExpressions, _)
              if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
            (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
        }.headOption // Only handle first case found, others will be fixed on the next pass.

        firstConflict match {
          // The conflict originates from a node kind this rule does not know how to
          // re-instantiate. Leave the join untouched instead of failing with a bare
          // NoSuchElementException from `.head`.
          case None => j

          case Some((oldRelation, newRelation)) =>
            // Old output attribute -> corresponding attribute of the fresh copy (by position).
            val attributeRewrites = AttributeMap(oldRelation.output.zip(newRelation.output))
            val newRight = right transformUp {
              case r if r == oldRelation => newRelation
            } transformUp {
              case other => other transformExpressions {
                case a: Attribute => attributeRewrites.getOrElse(a, a)
              }
            }
            j.copy(right = newRight)
        }
    }

    /**
     * Rebuilds every [[Alias]] in `expressions` from its child and name so that it is a new
     * `Alias` instance; all other expressions are returned as-is.
     */
    def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
      expressions.map {
        case a: Alias => Alias(a.child, a.name)()
        case other => other
      }
    }

    /** Collects the attributes produced by the [[Alias]] entries of `projectList`. */
    def findAliases(projectList: Seq[NamedExpression]): AttributeSet = {
      AttributeSet(projectList.collect { case a: Alias => a.toAttribute })
    }
  }
}
86+

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
113113
@transient
114114
protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry(true)
115115

116+
@transient
117+
// Pre-analysis rule executor, built with PreAnalyzer's default constructor arguments.
protected[sql] lazy val preAnalyzer: PreAnalyzer = new PreAnalyzer()
116118
@transient
117119
protected[sql] lazy val analyzer: Analyzer =
118120
new Analyzer(catalog, functionRegistry, caseSensitive = true) {
@@ -1104,9 +1106,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
11041106
* access to the intermediate phases of query execution for developers.
11051107
*/
11061108
@DeveloperApi
1107-
protected[sql] class QueryExecution(val logical: LogicalPlan) {
1109+
protected[sql] class QueryExecution(val rawPlan: LogicalPlan) {
11081110
def assertAnalyzed(): Unit = checkAnalysis(analyzed)
11091111

1112+
// The raw input plan after the pre-analysis pass; this is the plan handed to `analyzer` below.
lazy val logical: LogicalPlan = preAnalyzer(rawPlan)
11101113
lazy val analyzed: LogicalPlan = analyzer(logical)
11111114
lazy val withCachedData: LogicalPlan = {
11121115
assertAnalyzed()

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ class DataFrameSuite extends QueryTest {
113113
checkAnswer(
114114
df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("x.str").count(),
115115
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
116+
117+
checkAnswer(
118+
df.as('x).join(df.as('y), $"x.str" === $"y.str").groupBy("y.str").count(),
119+
Row("1", 1) :: Row("2", 1) :: Row("3", 1) :: Nil)
116120
}
117121

118122
test("explode") {

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
376376
override protected[sql] val planner = hivePlanner
377377

378378
/** Extends QueryExecution with hive specific features. */
379-
protected[sql] class QueryExecution(logicalPlan: LogicalPlan)
380-
extends super.QueryExecution(logicalPlan) {
379+
protected[sql] class QueryExecution(rawPlan: LogicalPlan)
380+
extends super.QueryExecution(rawPlan) {
381+
382+
// Reuse the parent's pre-analyzed plan (`logical`) rather than running preAnalyzer on rawPlan a
// second time: a second run repeats the work and may yield a plan with different fresh
// expression ids than the one the analyzer consumes.
lazy val logicalPlan: LogicalPlan = logical
383+
381384
// Like what we do in runHive, makes sure the session represented by the
382385
// `sessionState` field is activated.
383386
if (SessionState.get() != sessionState) {

0 commit comments

Comments
 (0)