Skip to content

Commit 705d963

Browse files
committed
Add a rule for resolving ORDER BY expressions that reference attributes not present in the SELECT clause.
1 parent 82cabda commit 705d963

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
4848
Batch("Resolution", fixedPoint,
4949
ResolveReferences ::
5050
ResolveRelations ::
51+
ResolveSortReferences ::
5152
NewRelationInstances ::
5253
ImplicitGenerate ::
5354
StarExpansion ::
@@ -120,6 +121,51 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
120121
}
121122
}
122123

124+
/**
125+
* In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
126+
* clause. This rule detects such queries and adds the required attributes to the original
127+
* projection, so that they will be available during sorting. Another projection is added to
128+
* remove these attributes after sorting.
129+
*/
130+
object ResolveSortReferences extends Rule[LogicalPlan] {
  /**
   * Rewrites an unresolved `Sort` that sits directly on top of a resolved `Project` or
   * `Aggregate` so that attributes referenced only by the ORDER BY clause are produced by the
   * operator below the sort, and then removed again by an extra `Project` above the sort.
   *
   * @param plan the logical plan to rewrite
   * @return the rewritten plan; `Sort` nodes for which no missing attribute could be found are
   *         returned unchanged
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
      // Names that the ordering refers to but that have not been resolved yet.
      val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
      // Try to resolve those names against the input of the projection.
      val resolved = unresolved.flatMap(child.resolveChildren)
      val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet

      val missingInProject = requiredAttributes -- p.output
      if (missingInProject.nonEmpty) {
        // Add missing attributes and then project them away after the sort.
        // The outer projection must select the *output attributes* of the original projection
        // rather than repeat `projectList`: repeating it would re-evaluate aliases and computed
        // expressions against the inner projection, whose output no longer exposes their inputs.
        Project(p.output,
          Sort(ordering,
            Project(projectList ++ missingInProject, child)))
      } else {
        s // Nothing we can do here. Return original plan.
      }

    case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
      val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
      // A small hack to create an object that will allow us to resolve any references that
      // refer to named expressions that are present in the grouping expressions.
      val groupingRelation = LocalRelation(
        grouping.collect { case ne: NamedExpression => ne.toAttribute }
      )

      // Diagnostic traces only: this rule runs on every matching plan, so keep them at debug.
      logDebug(s"Grouping expressions: $groupingRelation")
      val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
      val missingInAggs = resolved -- a.outputSet
      logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
      if (missingInAggs.nonEmpty) {
        // Add missing grouping exprs and then project them away after the sort.
        Project(a.output,
          Sort(ordering,
            Aggregate(grouping, aggs ++ missingInAggs, child)))
      } else {
        s // Nothing we can do here. Return original plan.
      }
  }
}
168+
123169
/**
124170
* Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]].
125171
*/
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.execution
19+
20+
import scala.reflect.ClassTag
21+
22+
import org.apache.spark.sql.{SQLConf, QueryTest}
23+
import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin}
24+
import org.apache.spark.sql.hive.test.TestHive
25+
import org.apache.spark.sql.hive.test.TestHive._
26+
27+
/**
28+
* A collection of hive query tests where we generate the answers ourselves instead of depending on
29+
* Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
30+
* valid, but Hive currently cannot execute it.
31+
*/
32+
class SQLQuerySuite extends QueryTest {
  // ORDER BY references a column that the SELECT list does not output.
  test("ordering not in select") {
    val actual = sql("SELECT key FROM src ORDER BY value")
    // Reference answer: sort inside a subquery that still exposes `value`, then drop it.
    val expected =
      sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq
    checkAnswer(actual, expected)
  }

  // ORDER BY references a grouping column that the aggregate's SELECT list does not output.
  test("ordering not in agg") {
    val actual = sql("SELECT key FROM src GROUP BY key, value ORDER BY value")
    // Reference answer: aggregate and sort in a subquery that keeps `value`, then drop it.
    val expected = sql("""
      SELECT key
      FROM (
        SELECT key, value
        FROM src
        GROUP BY key, value
        ORDER BY value) a""").collect().toSeq
    checkAnswer(actual, expected)
  }
}

0 commit comments

Comments
 (0)