@@ -576,43 +576,47 @@ class Analyzer(
576
576
filter
577
577
}
578
578
579
- case sort @ Sort (sortOrder, global,
580
- aggregate @ Aggregate (grouping, originalAggExprs, child))
579
+ case sort @ Sort (sortOrder, global, aggregate : Aggregate )
581
580
if aggregate.resolved && ! sort.resolved =>
582
581
583
582
// Try resolving the ordering as though it is in the aggregate clause.
584
583
try {
585
- val aliasedOrder = sortOrder.map(o => Alias (o.child, " aggOrder" )())
586
- val aggregatedOrdering = Aggregate (grouping, aliasedOrder, child)
587
- val resolvedOperator : Aggregate = execute(aggregatedOrdering).asInstanceOf [Aggregate ]
588
- def resolvedAggregateOrdering = resolvedOperator.aggregateExpressions
589
-
590
- // Expressions that have an aggregate can be pushed down.
591
- val needsAggregate = resolvedAggregateOrdering.exists(containsAggregate)
592
-
593
- // Attribute references, that are missing from the order but are present in the grouping
594
- // expressions can also be pushed down.
595
- val requiredAttributes = resolvedAggregateOrdering.map(_.references).reduce(_ ++ _)
596
- val missingAttributes = requiredAttributes -- aggregate.outputSet
597
- val validPushdownAttributes =
598
- missingAttributes.filter(a => grouping.exists(a.semanticEquals))
599
-
600
- // If resolution was successful and we see the ordering either has an aggregate in it or
601
- // it is missing something that is projected away by the aggregate, add the ordering
602
- // the original aggregate operator.
603
- if (resolvedOperator.resolved && (needsAggregate || validPushdownAttributes.nonEmpty)) {
604
- val evaluatedOrderings : Seq [SortOrder ] = sortOrder.zip(resolvedAggregateOrdering).map {
605
- case (order, evaluated) => order.copy(child = evaluated.toAttribute)
606
- }
607
- val aggExprsWithOrdering : Seq [NamedExpression ] =
608
- resolvedAggregateOrdering ++ originalAggExprs
609
-
610
- Project (aggregate.output,
611
- Sort (evaluatedOrderings, global,
612
- aggregate.copy(aggregateExpressions = aggExprsWithOrdering)))
613
- } else {
614
- sort
584
+ val aliasedOrdering = sortOrder.map(o => Alias (o.child, " aggOrder" )())
585
+ val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering)
586
+ val resolvedAggregate : Aggregate = execute(aggregatedOrdering).asInstanceOf [Aggregate ]
587
+ val resolvedAliasedOrdering : Seq [Alias ] =
588
+ resolvedAggregate.aggregateExpressions.asInstanceOf [Seq [Alias ]]
589
+
590
+ // If we pass the analysis check, then the ordering expressions should only reference to
591
+ // aggregate expressions or grouping expressions, and it's safe to push them down to
592
+ // Aggregate.
593
+ checkAnalysis(resolvedAggregate)
594
+
595
+ val originalAggExprs = aggregate.aggregateExpressions.map(
596
+ CleanupAliases .trimNonTopLevelAliases(_).asInstanceOf [NamedExpression ])
597
+
598
+ // If the ordering expression is same with original aggregate expression, we don't need
599
+ // to push down this ordering expression and can reference the original aggregate
600
+ // expression instead.
601
+ val needsPushDown = ArrayBuffer .empty[NamedExpression ]
602
+ val evaluatedOrderings = resolvedAliasedOrdering.zip(sortOrder).map {
603
+ case (evaluated, order) =>
604
+ val index = originalAggExprs.indexWhere {
605
+ case Alias (child, _) => child semanticEquals evaluated.child
606
+ case other => other semanticEquals evaluated.child
607
+ }
608
+
609
+ if (index == - 1 ) {
610
+ needsPushDown += evaluated
611
+ order.copy(child = evaluated.toAttribute)
612
+ } else {
613
+ order.copy(child = originalAggExprs(index).toAttribute)
614
+ }
615
615
}
616
+
617
+ Project (aggregate.output,
618
+ Sort (evaluatedOrderings, global,
619
+ aggregate.copy(aggregateExpressions = originalAggExprs ++ needsPushDown)))
616
620
} catch {
617
621
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
618
622
// just return the original plan.
@@ -621,9 +625,7 @@ class Analyzer(
621
625
}
622
626
623
627
protected def containsAggregate (condition : Expression ): Boolean = {
624
- condition
625
- .collect { case ae : AggregateExpression => ae }
626
- .nonEmpty
628
+ condition.find(_.isInstanceOf [AggregateExpression ]).isDefined
627
629
}
628
630
}
629
631
0 commit comments