Skip to content

Commit 5a10e10

Browse files
cloud-fanmarkhamstra
authored andcommitted
[SPARK-10389] [SQL] support order by non-attribute grouping expression on Aggregate
For example, we can write `SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1` in PostgreSQL, and we should support this in Spark SQL. Author: Wenchen Fan <[email protected]> Closes apache#8548 from cloud-fan/support-order-by-non-attribute.
1 parent 9a625f3 commit 5a10e10

File tree

2 files changed

+52
-39
lines changed

2 files changed

+52
-39
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -576,43 +576,47 @@ class Analyzer(
576576
filter
577577
}
578578

579-
case sort @ Sort(sortOrder, global,
580-
aggregate @ Aggregate(grouping, originalAggExprs, child))
579+
case sort @ Sort(sortOrder, global, aggregate: Aggregate)
581580
if aggregate.resolved && !sort.resolved =>
582581

583582
// Try resolving the ordering as though it is in the aggregate clause.
584583
try {
585-
val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
586-
val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
587-
val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
588-
def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
589-
590-
// Expressions that have an aggregate can be pushed down.
591-
val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
592-
593-
// Attribute references, that are missing from the order but are present in the grouping
594-
// expressions can also be pushed down.
595-
val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
596-
val missingAttributes = requiredAttributes -- aggregate.outputSet
597-
val validPushdownAttributes =
598-
missingAttributes.filter(a => grouping.exists(a.semanticEquals))
599-
600-
// If resolution was successful and we see the ordering either has an aggregate in it or
601-
// it is missing something that is projected away by the aggregate, add the ordering
602-
// the original aggregate operator.
603-
if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
604-
val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
605-
case (order, evaluated) => order.copy(child = evaluated.toAttribute)
606-
}
607-
val aggExprsWithOrdering: Seq[NamedExpression] =
608-
resolvedAggregateOrdering ++ originalAggExprs
609-
610-
Project(aggregate.output,
611-
Sort(evaluatedOrderings, global,
612-
aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
613-
} else {
614-
sort
584+
val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")())
585+
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
586+
val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
587+
val resolvedAliasedOrdering: Seq[Alias] =
588+
resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]]
589+
590+
// If we pass the analysis check, then the ordering expressions should only reference to
591+
// aggregate expressions or grouping expressions, and it's safe to push them down to
592+
// Aggregate.
593+
checkAnalysis(resolvedAggregate)
594+
595+
val originalAggExprs = aggregate.aggregateExpressions.map(
596+
CleanupAliases.trimNonTopLevelAliases(_).asInstanceOf[NamedExpression])
597+
598+
// If the ordering expression is same with original aggregate expression, we don't need
599+
// to push down this ordering expression and can reference the original aggregate
600+
// expression instead.
601+
val needsPushDown = ArrayBuffer.empty[NamedExpression]
602+
val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
603+
case (evaluated, order) =>
604+
val index = originalAggExprs.indexWhere {
605+
case Alias(child, _) => child semanticEquals evaluated.child
606+
case other => other semanticEquals evaluated.child
607+
}
608+
609+
if (index == -1) {
610+
needsPushDown += evaluated
611+
order.copy(child = evaluated.toAttribute)
612+
} else {
613+
order.copy(child = originalAggExprs(index).toAttribute)
614+
}
615615
}
616+
617+
Project(aggregate.output,
618+
Sort(evaluatedOrderings, global,
619+
aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
616620
} catch {
617621
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
618622
// just return the original plan.
@@ -621,9 +625,7 @@ class Analyzer(
621625
}
622626

623627
protected def containsAggregate(condition: Expression): Boolean = {
624-
condition
625-
.collect { case ae: AggregateExpression => ae }
626-
.nonEmpty
628+
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
627629
}
628630
}
629631

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,10 +1712,21 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
17121712
}
17131713

17141714
test("SPARK-10130 type coercion for IF should have children resolved first") {
1715-
val df = Seq((1, 1), (-1, 1)).toDF("key", "value")
1716-
df.registerTempTable("src")
1717-
checkAnswer(
1718-
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
1715+
withTempTable("src") {
1716+
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
1717+
checkAnswer(
1718+
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
1719+
}
1720+
}
1721+
1722+
test("SPARK-10389: order by non-attribute grouping expression on Aggregate") {
1723+
withTempTable("src") {
1724+
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
1725+
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"),
1726+
Seq(Row(1), Row(1)))
1727+
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"),
1728+
Seq(Row(1), Row(1)))
1729+
}
17191730
}
17201731

17211732
test("SortMergeJoin returns wrong results when using UnsafeRows") {

0 commit comments

Comments
 (0)