[SPARK-4512] [SQL] Unresolved Attribute Exception in Sort By #3386
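
Summary (as inferred from the diff below): the separate SortPartitions operator is removed and folded into a single logical node, Sort(order, global, child). global = true keeps the old ORDER BY semantics (one total ordering over the whole data set); global = false keeps the old SORT BY semantics (each partition sorted independently). The parser, analyzer, catalyst DSL, planner strategies, and HiveQl translation all thread the flag through, and a regression test for SPARK-4512 is added at the end.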

@@ -204,8 +204,8 @@ class SqlParser extends AbstractSparkSQLParser {
     )
 
   protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] =
-    ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, l) }
-    | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => SortPartitions(o, l) }
+    ( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, true, l) }
+    | SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, false, l) }
     )
 
   protected lazy val ordering: Parser[Seq[SortOrder]] =
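
For reference, the two parser branches now differ only in the global flag they pass to Sort. A minimal sketch of what each clause produces, against a Spark 1.2-era SQLContext; the people table and its columns are hypothetical:

import org.apache.spark.sql.SQLContext

val sqlContext: SQLContext = ???  // hypothetical, e.g. new SQLContext(sc)

// ORDER BY parses to Sort(ordering, global = true, child):
// one total order over the whole result set.
val ordered = sqlContext.sql("SELECT name FROM people ORDER BY age")

// SORT BY parses to Sort(ordering, global = false, child):
// each partition is sorted independently, with no global shuffle.
val sorted = sqlContext.sql("SELECT name FROM people SORT BY age")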
@@ -246,7 +246,7 @@ class Analyzer(catalog: Catalog,
       case p: LogicalPlan if !p.childrenResolved => p
 
       // If the projection list contains Stars, expand it.
-      case p@Project(projectList, child) if containsStar(projectList) =>
+      case p @ Project(projectList, child) if containsStar(projectList) =>
         Project(
           projectList.flatMap {
             case s: Star => s.expand(child.output, resolver)
@@ -310,7 +310,8 @@
    */
   object ResolveSortReferences extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-      case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
+      case s @ Sort(ordering, global, p @ Project(projectList, child))
+          if !s.resolved && p.resolved =>
         val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
         val resolved = unresolved.flatMap(child.resolve(_, resolver))
         val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
@@ -319,13 +320,14 @@
         if (missingInProject.nonEmpty) {
           // Add missing attributes and then project them away after the sort.
           Project(projectList.map(_.toAttribute),
-            Sort(ordering,
+            Sort(ordering, global,
               Project(projectList ++ missingInProject, child)))
         } else {
           logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
           s // Nothing we can do here. Return original plan.
         }
-      case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
+      case s @ Sort(ordering, global, a @ Aggregate(grouping, aggs, child))
+          if !s.resolved && a.resolved =>
         val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
         // A small hack to create an object that will allow us to resolve any references that
         // refer to named expressions that are present in the grouping expressions.
@@ -340,8 +342,7 @@
         if (missingInAggs.nonEmpty) {
           // Add missing grouping exprs and then project them away after the sort.
           Project(a.output,
-            Sort(ordering,
-              Aggregate(grouping, aggs ++ missingInAggs, child)))
+            Sort(ordering, global, Aggregate(grouping, aggs ++ missingInAggs, child)))
         } else {
           s // Nothing we can do here. Return original plan.
         }
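
To make the rewrite concrete, here is a sketch of the plan shape ResolveSortReferences now handles, built with the catalyst DSL this PR also touches (the relation and attribute names are hypothetical):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// Hypothetical two-column relation.
val table = LocalRelation('a.int, 'b.int)

// SELECT a FROM table SORT BY b: 'b is not in the projection,
// so the Sort node starts out unresolved.
val plan = table.select('a).sortBy('b.asc)

// The rule widens the projection so 'b resolves, keeps the sort
// (with its global flag intact), then projects 'b away again:
//   Project(a :: Nil,
//     Sort('b.asc :: Nil, global = false,
//       Project(a :: b :: Nil, table)))

Because SORT BY now also produces a Sort node, this rule fires for it too; previously the per-partition SortPartitions node never matched here, which is what caused the unresolved-attribute exception in SPARK-4512.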
@@ -244,9 +244,9 @@ package object dsl {
         condition: Option[Expression] = None) =
       Join(logicalPlan, otherPlan, joinType, condition)
 
-    def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, logicalPlan)
+    def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, true, logicalPlan)
 
-    def sortBy(sortExprs: SortOrder*) = SortPartitions(sortExprs, logicalPlan)
+    def sortBy(sortExprs: SortOrder*) = Sort(sortExprs, false, logicalPlan)
 
     def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = {
       val aliasedExprs = aggregateExprs.map {
@@ -130,7 +130,16 @@ case class WriteToFile(
   override def output = child.output
 }
 
-case class Sort(order: Seq[SortOrder], child: LogicalPlan) extends UnaryNode {
+/**
+ * @param order The ordering expressions
+ * @param global True means global sorting applies to the entire data set;
+ *               false means sorting applies only within each partition.
+ * @param child Child logical plan
+ */
+case class Sort(
+    order: Seq[SortOrder],
+    global: Boolean,
+    child: LogicalPlan) extends UnaryNode {
   override def output = child.output
 }
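
With SortPartitions folded in, both orderings are the same node distinguished only by the flag. A minimal construction sketch (the attribute and relation are hypothetical):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Sort}

// Hypothetical single-column relation.
val relation = LocalRelation('age.int)
val byAge = SortOrder('age.attr, Ascending) :: Nil

val totalOrder   = Sort(order = byAge, global = true, child = relation)   // ORDER BY
val perPartition = Sort(order = byAge, global = false, child = relation)  // SORT BY

One node with a flag, rather than two node types, is what lets analyzer rules such as ResolveSortReferences treat both cases uniformly.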
5 changes: 2 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -214,7 +214,7 @@ class SchemaRDD(
    * @group Query
    */
   def orderBy(sortExprs: SortOrder*): SchemaRDD =
-    new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
+    new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan))
 
   /**
    * Sorts the results by the given expressions within partition.
@@ -227,7 +227,7 @@
    * @group Query
    */
   def sortBy(sortExprs: SortOrder*): SchemaRDD =
-    new SchemaRDD(sqlContext, SortPartitions(sortExprs, logicalPlan))
+    new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan))
 
   @deprecated("use limit with integer argument", "1.1.0")
   def limit(limitExpr: Expression): SchemaRDD =
@@ -238,7 +238,6 @@
    * {{{
    *   schemaRDD.limit(10)
    * }}}
-   *
    * @group Query
    */
   def limit(limitNum: Int): SchemaRDD =
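
The public API keeps the same two entry points; only the plans they build changed. A usage sketch, assuming a SQLContext with its implicits imported (import sqlContext._ for the Symbol conversions) and a hypothetical people: SchemaRDD with an age column:

val byAgeGlobal = people.orderBy('age.asc)  // one total ordering (shuffles)
val byAgeLocal  = people.sortBy('age.asc)   // sorted within each partition only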
@@ -190,7 +190,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
   object TakeOrdered extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
+      case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
         execution.TakeOrdered(limit, order, planLater(child)) :: Nil
       case _ => Nil
     }
@@ -257,15 +257,14 @@
         execution.Distinct(partial = false,
          execution.Distinct(partial = true, planLater(child))) :: Nil
 
-      case logical.Sort(sortExprs, child) if sqlContext.externalSortEnabled =>
-        execution.ExternalSort(sortExprs, global = true, planLater(child)):: Nil
-      case logical.Sort(sortExprs, child) =>
-        execution.Sort(sortExprs, global = true, planLater(child)):: Nil
-
-      case logical.SortPartitions(sortExprs, child) =>
-        // This sort only sorts tuples within a partition. Its requiredDistribution will be
-        // an UnspecifiedDistribution.
-        execution.Sort(sortExprs, global = false, planLater(child)) :: Nil
+      case logical.Sort(sortExprs, global, child) if sqlContext.externalSortEnabled =>
+        execution.ExternalSort(sortExprs, global, planLater(child)):: Nil
+      case logical.Sort(sortExprs, global, child) =>
+        execution.Sort(sortExprs, global, planLater(child)):: Nil
       case logical.Project(projectList, child) =>
         execution.Project(projectList, planLater(child)) :: Nil
      case logical.Filter(condition, child) =>
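
The physical operators already took a global flag; the strategy simply stops hard-coding it. For context, a paraphrase (not part of this diff) of how the 1.2-era execution.Sort turns the flag into a distribution requirement:

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.physical.{
  Distribution, OrderedDistribution, UnspecifiedDistribution}

// global = true demands OrderedDistribution, which forces a
// range-partitioning exchange before each partition is sorted;
// global = false sorts the existing partitions in place.
def requiredDistribution(global: Boolean, order: Seq[SortOrder]): Seq[Distribution] =
  if (global) OrderedDistribution(order) :: Nil
  else UnspecifiedDistribution :: Nil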
19 changes: 14 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -86,7 +86,7 @@ class DslQuerySuite extends QueryTest {
       Seq(Seq(6)))
   }
 
-  test("sorting") {
+  test("global sorting") {
     checkAnswer(
       testData2.orderBy('a.asc, 'b.asc),
       Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
@@ -120,22 +120,31 @@
       mapData.collect().sortBy(_.data(1)).reverse.toSeq)
   }
 
-  test("sorting #2") {
+  test("partition wide sorting") {
+    // testData2 has 2 partitions in total:
+    // Partition #1 with values:
+    //    (1, 1)
+    //    (1, 2)
+    //    (2, 1)
+    // Partition #2 with values:
+    //    (2, 2)
+    //    (3, 1)
+    //    (3, 2)
     checkAnswer(
       testData2.sortBy('a.asc, 'b.asc),
       Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
 
     checkAnswer(
       testData2.sortBy('a.asc, 'b.desc),
-      Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
+      Seq((1,2), (1,1), (2,1), (2,2), (3,2), (3,1)))
 
     checkAnswer(
       testData2.sortBy('a.desc, 'b.desc),
-      Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
+      Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2)))
 
     checkAnswer(
       testData2.sortBy('a.desc, 'b.asc),
-      Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+      Seq((2,1), (1,1), (1,2), (3,1), (3,2), (2,2)))
   }
 
   test("limit") {
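
The expected sequences follow mechanically from sorting each partition on its own and concatenating in partition order. A plain-Scala check of the 'a.desc, 'b.desc case:

// The two partitions documented in the test comment above.
val p1 = Seq((1, 1), (1, 2), (2, 1))
val p2 = Seq((2, 2), (3, 1), (3, 2))

// sortBy('a.desc, 'b.desc) sorts within each partition only.
val descDesc = Ordering.Tuple2(Ordering.Int.reverse, Ordering.Int.reverse)
val expected = p1.sorted(descDesc) ++ p2.sorted(descDesc)
// expected == Seq((2,1), (1,2), (1,1), (3,2), (3,1), (2,2))

This matches the new expected answer in the test, which looks "unsorted" only if one forgets that no global ordering is promised.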
@@ -55,7 +55,7 @@ object TestData {
       TestData2(2, 1) ::
       TestData2(2, 2) ::
       TestData2(3, 1) ::
-      TestData2(3, 2) :: Nil).toSchemaRDD
+      TestData2(3, 2) :: Nil, 2).toSchemaRDD
   testData2.registerTempTable("testData2")
 
   case class DecimalData(a: BigDecimal, b: BigDecimal)
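
Pinning testData2 to exactly two partitions is what makes the partition-wise expectations above deterministic. The mechanism, sketched with a plain RDD (assumes a live SparkContext sc):

// parallelize splits the sequence into numSlices contiguous chunks, so a
// 6-element list with numSlices = 2 yields exactly the two partitions
// documented in DslQuerySuite.
val rdd = sc.parallelize(Seq(1, 2, 3, 4, 5, 6), numSlices = 2)
rdd.glom().collect()  // Array(Array(1, 2, 3), Array(4, 5, 6))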
@@ -680,16 +680,16 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
         val withSort =
           (orderByClause, sortByClause, distributeByClause, clusterByClause) match {
             case (Some(totalOrdering), None, None, None) =>
-              Sort(totalOrdering.getChildren.map(nodeToSortOrder), withHaving)
+              Sort(totalOrdering.getChildren.map(nodeToSortOrder), true, withHaving)
             case (None, Some(perPartitionOrdering), None, None) =>
-              SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder), withHaving)
+              Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false, withHaving)
             case (None, None, Some(partitionExprs), None) =>
               Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving)
             case (None, Some(perPartitionOrdering), Some(partitionExprs), None) =>
-              SortPartitions(perPartitionOrdering.getChildren.map(nodeToSortOrder),
+              Sort(perPartitionOrdering.getChildren.map(nodeToSortOrder), false,
                 Repartition(partitionExprs.getChildren.map(nodeToExpr), withHaving))
             case (None, None, None, Some(clusterExprs)) =>
-              SortPartitions(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)),
+              Sort(clusterExprs.getChildren.map(nodeToExpr).map(SortOrder(_, Ascending)), false,
                 Repartition(clusterExprs.getChildren.map(nodeToExpr), withHaving))
             case (None, None, None, None) => withHaving
             case _ => sys.error("Unsupported set of ordering / distribution clauses.")
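
Spelled out, the clause combinations above now plan as follows. A hedged sketch assuming a HiveContext hc and the standard src(key, value) Hive test table:

import org.apache.spark.sql.hive.HiveContext

val hc: HiveContext = ???  // hypothetical, e.g. new HiveContext(sc)

hc.sql("SELECT key FROM src ORDER BY key")       // Sort(..., true, plan)
hc.sql("SELECT key FROM src SORT BY key")        // Sort(..., false, plan)
hc.sql("SELECT key FROM src DISTRIBUTE BY key")  // Repartition(..., plan)
hc.sql("SELECT key FROM src DISTRIBUTE BY key SORT BY key")
//   Sort(..., false, Repartition(..., plan))
hc.sql("SELECT key FROM src CLUSTER BY key")
//   CLUSTER BY = DISTRIBUTE BY + ascending SORT BY on the same keys:
//   Sort(ascending keys, false, Repartition(..., plan))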
@@ -132,7 +132,7 @@ abstract class HiveComparisonTest
 
     def isSorted(plan: LogicalPlan): Boolean = plan match {
       case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
-      case PhysicalOperation(_, _, Sort(_, _)) => true
+      case PhysicalOperation(_, _, Sort(_, true, _)) => true
       case _ => plan.children.iterator.exists(isSorted)
     }
@@ -32,6 +32,13 @@ case class Nested3(f3: Int)
  * valid, but Hive currently cannot execute it.
  */
 class SQLQuerySuite extends QueryTest {
+  test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
+    checkAnswer(
+      sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"),
+      sql("SELECT key + key as a FROM src ORDER BY a").collect().toSeq
+    )
+  }
+
   test("CTAS with serde") {
     sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect
     sql(