Skip to content

Commit 479df75

Browse files
committed
support order by non-attribute grouping expression on Aggregate
1 parent d65656c commit 479df75

File tree

2 files changed

+33
-38
lines changed

2 files changed

+33
-38
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 18 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -560,43 +560,29 @@ class Analyzer(
560560
filter
561561
}
562562

563-
case sort @ Sort(sortOrder, global,
564-
aggregate @ Aggregate(grouping, originalAggExprs, child))
563+
case sort @ Sort(sortOrder, global, aggregate: Aggregate)
565564
if aggregate.resolved && !sort.resolved =>
566565

567566
// Try resolving the ordering as though it is in the aggregate clause.
568567
try {
569-
val aliasedOrder = sortOrder.map(o => Alias(o.child, "aggOrder")())
570-
val aggregatedOrdering = Aggregate(grouping, aliasedOrder, child)
568+
val aliasedOrdering = sortOrder.map(o => Alias(o.child, "aggOrder")())
569+
val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
571570
val resolvedOperator: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate]
572-
def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
573-
574-
// Expressions that have an aggregate can be pushed down.
575-
val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
576-
577-
// Attribute references, that are missing from the order but are present in the grouping
578-
// expressions can also be pushed down.
579-
val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
580-
val missingAttributes = requiredAttributes -- aggregate.outputSet
581-
val validPushdownAttributes =
582-
missingAttributes.filter(a => grouping.exists(a.semanticEquals))
583-
584-
// If resolution was successful and we see the ordering either has an aggregate in it or
585-
// it is missing something that is projected away by the aggregate, add the ordering
586-
// the original aggregate operator.
587-
if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
588-
val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedAggregateOrdering).map {
589-
case (order, evaluated) => order.copy(child = evaluated.toAttribute)
590-
}
591-
val aggExprsWithOrdering: Seq[NamedExpression] =
592-
resolvedAggregateOrdering ++ originalAggExprs
593-
594-
Project(aggregate.output,
595-
Sort(evaluatedOrderings, global,
596-
aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
597-
} else {
598-
sort
571+
val resolvedOrdering = resolvedOperator.aggregateExpressions
572+
573+
// If we pass the analysis check, then the ordering expressions should only reference to
574+
// aggregate expressions or grouping expressions, and it's safe to push them down to
575+
// Aggregate.
576+
checkAnalysis(resolvedOperator)
577+
// todo: some ordering expressions can be evaluated with existing aggregate expressions
578+
// and we don't need to push them down to Aggregate.
579+
val evaluatedOrderings: Seq[SortOrder] = sortOrder.zip(resolvedOrdering).map {
580+
case (order, evaluated) => order.copy(child = evaluated.toAttribute)
599581
}
582+
val aggExprsWithOrdering = aggregate.aggregateExpressions ++ resolvedOrdering
583+
Project(aggregate.output,
584+
Sort(evaluatedOrderings, global,
585+
aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
600586
} catch {
601587
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
602588
// just return the original plan.
@@ -605,9 +591,7 @@ class Analyzer(
605591
}
606592

607593
protected def containsAggregate(condition: Expression): Boolean = {
608-
condition
609-
.collect { case ae: AggregateExpression => ae }
610-
.nonEmpty
594+
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
611595
}
612596
}
613597

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,9 +1712,20 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
17121712
}
17131713

17141714
test("SPARK-10130 type coercion for IF should have children resolved first") {
1715-
val df = Seq((1, 1), (-1, 1)).toDF("key", "value")
1716-
df.registerTempTable("src")
1717-
checkAnswer(
1718-
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
1715+
withTempTable("src") {
1716+
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
1717+
checkAnswer(
1718+
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
1719+
}
1720+
}
1721+
1722+
test("SPARK-10389: order by non-attribute grouping expression on Aggregate") {
1723+
withTempTable("src") {
1724+
Seq((1, 1), (-1, 1)).toDF("key", "value").registerTempTable("src")
1725+
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY key + 1"),
1726+
Seq(Row(1), Row(1)))
1727+
checkAnswer(sql("SELECT MAX(value) FROM src GROUP BY key + 1 ORDER BY (key + 1) * 2"),
1728+
Seq(Row(1), Row(1)))
1729+
}
17191730
}
17201731
}

0 commit comments

Comments
 (0)