Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit e797dba

Browse files
yhuai authored and rxin committed
[SPARK-7965] [SPARK-7972] [SQL] Handle expressions containing multiple window expressions and make parser match window frames in case insensitive way
JIRAs: https://issues.apache.org/jira/browse/SPARK-7965 https://issues.apache.org/jira/browse/SPARK-7972 Author: Yin Huai <[email protected]> Closes apache#6524 from yhuai/7965-7972 and squashes the following commits: c12c79c [Yin Huai] Add doc for returned value. de64328 [Yin Huai] Address rxin's comments. fc9b1ad [Yin Huai] wip 2996da4 [Yin Huai] scala style 20b65b7 [Yin Huai] Handle expressions containing multiple window expressions. 9568b21 [Yin Huai] case insensitive matches 41f633d [Yin Huai] Failed test case.
1 parent 7f74bb3 commit e797dba

File tree

3 files changed

+134
-32
lines changed

3 files changed

+134
-32
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -633,25 +633,35 @@ class Analyzer(
633633
* it into the plan tree.
634634
*/
635635
object ExtractWindowExpressions extends Rule[LogicalPlan] {
636-
def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
636+
private def hasWindowFunction(projectList: Seq[NamedExpression]): Boolean =
637637
projectList.exists(hasWindowFunction)
638638

639-
def hasWindowFunction(expr: NamedExpression): Boolean = {
639+
private def hasWindowFunction(expr: NamedExpression): Boolean = {
640640
expr.find {
641641
case window: WindowExpression => true
642642
case _ => false
643643
}.isDefined
644644
}
645645

646646
/**
647-
* From a Seq of [[NamedExpression]]s, extract window expressions and
648-
* other regular expressions.
647+
* From a Seq of [[NamedExpression]]s, extract expressions containing window expressions and
648+
* other regular expressions that do not contain any window expression. For example, for
649+
* `col1, Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5)`, we will extract
650+
* `col1`, `col2 + col3`, `col4`, and `col5` out and replace their appearances in
651+
* the window expression as attribute references. So, the first returned value will be
652+
* `[Sum(_w0) OVER (PARTITION BY _w1 ORDER BY _w2)]` and the second returned value will be
653+
* [col1, col2 + col3 as _w0, col4 as _w1, col5 as _w2].
654+
*
655+
* @return (seq of expressions containing at least one window expression,
656+
* seq of non-window expressions)
649657
*/
650-
def extract(
658+
private def extract(
651659
expressions: Seq[NamedExpression]): (Seq[NamedExpression], Seq[NamedExpression]) = {
652-
// First, we simple partition the input expressions to two part, one having
653-
// WindowExpressions and another one without WindowExpressions.
654-
val (windowExpressions, regularExpressions) = expressions.partition(hasWindowFunction)
660+
// First, we partition the input expressions into two parts. For the first part,
661+
// every expression in it contains at least one WindowExpression.
662+
// Expressions in the second part do not have any WindowExpression.
663+
val (expressionsWithWindowFunctions, regularExpressions) =
664+
expressions.partition(hasWindowFunction)
655665

656666
// Then, we need to extract those regular expressions used in the WindowExpression.
657667
// For example, when we have col1 - Sum(col2 + col3) OVER (PARTITION BY col4 ORDER BY col5),
@@ -660,8 +670,8 @@ class Analyzer(
660670
val extractedExprBuffer = new ArrayBuffer[NamedExpression]()
661671
def extractExpr(expr: Expression): Expression = expr match {
662672
case ne: NamedExpression =>
663-
// If a named expression is not in regularExpressions, add extract it and replace it
664-
// with an AttributeReference.
673+
// If a named expression is not in regularExpressions, add it to
674+
// extractedExprBuffer and replace it with an AttributeReference.
665675
val missingExpr =
666676
AttributeSet(Seq(expr)) -- (regularExpressions ++ extractedExprBuffer)
667677
if (missingExpr.nonEmpty) {
@@ -678,8 +688,9 @@ class Analyzer(
678688
withName.toAttribute
679689
}
680690

681-
// Now, we extract expressions from windowExpressions by using extractExpr.
682-
val newWindowExpressions = windowExpressions.map {
691+
// Now, we extract regular expressions from expressionsWithWindowFunctions
692+
// by using extractExpr.
693+
val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
683694
_.transform {
684695
// Extracts children expressions of a WindowFunction (input parameters of
685696
// a WindowFunction).
@@ -705,37 +716,80 @@ class Analyzer(
705716
}.asInstanceOf[NamedExpression]
706717
}
707718

708-
(newWindowExpressions, regularExpressions ++ extractedExprBuffer)
709-
}
719+
(newExpressionsWithWindowFunctions, regularExpressions ++ extractedExprBuffer)
720+
} // end of extract
710721

711722
/**
712723
* Adds operators for Window Expressions. Every Window operator handles a single Window Spec.
713724
*/
714-
def addWindow(windowExpressions: Seq[NamedExpression], child: LogicalPlan): LogicalPlan = {
715-
// First, we group window expressions based on their Window Spec.
716-
val groupedWindowExpression = windowExpressions.groupBy { expr =>
717-
val windowSpec = expr.collectFirst {
725+
private def addWindow(
726+
expressionsWithWindowFunctions: Seq[NamedExpression],
727+
child: LogicalPlan): LogicalPlan = {
728+
// First, we need to extract all WindowExpressions from expressionsWithWindowFunctions
729+
// and put those extracted WindowExpressions to extractedWindowExprBuffer.
730+
// This step is needed because it is possible that an expression contains multiple
731+
// WindowExpressions with different Window Specs.
732+
// After extracting WindowExpressions, we need to construct a project list to generate
733+
// expressionsWithWindowFunctions based on extractedWindowExprBuffer.
734+
// For example, for "sum(a) over (...) / sum(b) over (...)", we will first extract
735+
// "sum(a) over (...)" and "sum(b) over (...)" out, and assign "_we0" as the alias to
736+
// "sum(a) over (...)" and "_we1" as the alias to "sum(b) over (...)".
737+
// Then, the projectList will be [_we0/_we1].
738+
val extractedWindowExprBuffer = new ArrayBuffer[NamedExpression]()
739+
val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map {
740+
// We need to use transformDown because we want to trigger
741+
// "case alias @ Alias(window: WindowExpression, _)" first.
742+
_.transformDown {
743+
case alias @ Alias(window: WindowExpression, _) =>
744+
// If a WindowExpression has an assigned alias, just use it.
745+
extractedWindowExprBuffer += alias
746+
alias.toAttribute
747+
case window: WindowExpression =>
748+
// If there is no alias assigned to the WindowExpression, we create an
749+
// internal column.
750+
val withName = Alias(window, s"_we${extractedWindowExprBuffer.length}")()
751+
extractedWindowExprBuffer += withName
752+
withName.toAttribute
753+
}.asInstanceOf[NamedExpression]
754+
}
755+
756+
// Second, we group extractedWindowExprBuffer based on their Window Spec.
757+
val groupedWindowExpressions = extractedWindowExprBuffer.groupBy { expr =>
758+
val distinctWindowSpec = expr.collect {
718759
case window: WindowExpression => window.windowSpec
760+
}.distinct
761+
762+
// We do a final check and see if we only have a single Window Spec defined in an
763+
// expression.
764+
if (distinctWindowSpec.length == 0 ) {
765+
failAnalysis(s"$expr does not have any WindowExpression.")
766+
} else if (distinctWindowSpec.length > 1) {
767+
// newExpressionsWithWindowFunctions only have expressions with a single
768+
// WindowExpression. If we reach here, we have a bug.
769+
failAnalysis(s"$expr has multiple Window Specifications ($distinctWindowSpec)." +
770+
s"Please file a bug report with this error message, stack trace, and the query.")
771+
} else {
772+
distinctWindowSpec.head
719773
}
720-
windowSpec.getOrElse(
721-
failAnalysis(s"$windowExpressions does not have any WindowExpression."))
722774
}.toSeq
723775

724-
// For every Window Spec, we add a Window operator and set currentChild as the child of it.
776+
// Third, for every Window Spec, we add a Window operator and set currentChild as the
777+
// child of it.
725778
var currentChild = child
726779
var i = 0
727-
while (i < groupedWindowExpression.size) {
728-
val (windowSpec, windowExpressions) = groupedWindowExpression(i)
780+
while (i < groupedWindowExpressions.size) {
781+
val (windowSpec, windowExpressions) = groupedWindowExpressions(i)
729782
// Set currentChild to the newly created Window operator.
730783
currentChild = Window(currentChild.output, windowExpressions, windowSpec, currentChild)
731784

732-
// Move to next WindowExpression.
785+
// Move to next Window Spec.
733786
i += 1
734787
}
735788

736-
// We return the top operator.
737-
currentChild
738-
}
789+
// Finally, we create a Project to output currentChild's output and
790+
// newExpressionsWithWindowFunctions.
791+
Project(currentChild.output ++ newExpressionsWithWindowFunctions, currentChild)
792+
} // end of addWindow
739793

740794
// We have to use transformDown at here to make sure the rule of
741795
// "Aggregate with Having clause" will be triggered.

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,6 +1561,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
15611561
""".stripMargin)
15621562
}
15631563

1564+
/* Case insensitive matches for Window Specification */
1565+
val PRECEDING = "(?i)preceding".r
1566+
val FOLLOWING = "(?i)following".r
1567+
val CURRENT = "(?i)current".r
15641568
def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match {
15651569
case Token(windowName, Nil) :: Nil =>
15661570
// Refer to a window spec defined in the window clause.
@@ -1614,11 +1618,19 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
16141618
} else {
16151619
val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame)
16161620
def nodeToBoundary(node: Node): FrameBoundary = node match {
1617-
case Token("preceding", Token(count, Nil) :: Nil) =>
1618-
if (count == "unbounded") UnboundedPreceding else ValuePreceding(count.toInt)
1619-
case Token("following", Token(count, Nil) :: Nil) =>
1620-
if (count == "unbounded") UnboundedFollowing else ValueFollowing(count.toInt)
1621-
case Token("current", Nil) => CurrentRow
1621+
case Token(PRECEDING(), Token(count, Nil) :: Nil) =>
1622+
if (count.toLowerCase() == "unbounded") {
1623+
UnboundedPreceding
1624+
} else {
1625+
ValuePreceding(count.toInt)
1626+
}
1627+
case Token(FOLLOWING(), Token(count, Nil) :: Nil) =>
1628+
if (count.toLowerCase() == "unbounded") {
1629+
UnboundedFollowing
1630+
} else {
1631+
ValueFollowing(count.toInt)
1632+
}
1633+
case Token(CURRENT(), Nil) => CurrentRow
16221634
case _ =>
16231635
throw new NotImplementedError(
16241636
s"""No parse rules for the Window Frame Boundary based on Node ${node.getName}

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,42 @@ class SQLQuerySuite extends QueryTest {
780780
).map(i => Row(i._1, i._2, i._3, i._4)))
781781
}
782782

783+
test("window function: multiple window expressions in a single expression") {
784+
val nums = sparkContext.parallelize(1 to 10).map(x => (x, x % 2)).toDF("x", "y")
785+
nums.registerTempTable("nums")
786+
787+
val expected =
788+
Row(1, 1, 1, 55, 1, 57) ::
789+
Row(0, 2, 3, 55, 2, 60) ::
790+
Row(1, 3, 6, 55, 4, 65) ::
791+
Row(0, 4, 10, 55, 6, 71) ::
792+
Row(1, 5, 15, 55, 9, 79) ::
793+
Row(0, 6, 21, 55, 12, 88) ::
794+
Row(1, 7, 28, 55, 16, 99) ::
795+
Row(0, 8, 36, 55, 20, 111) ::
796+
Row(1, 9, 45, 55, 25, 125) ::
797+
Row(0, 10, 55, 55, 30, 140) :: Nil
798+
799+
val actual = sql(
800+
"""
801+
|SELECT
802+
| y,
803+
| x,
804+
| sum(x) OVER w1 AS running_sum,
805+
| sum(x) OVER w2 AS total_sum,
806+
| sum(x) OVER w3 AS running_sum_per_y,
807+
| ((sum(x) OVER w1) + (sum(x) OVER w2) + (sum(x) OVER w3)) as combined2
808+
|FROM nums
809+
|WINDOW w1 AS (ORDER BY x ROWS BETWEEN UnBOUNDED PRECEDiNG AND CuRRENT RoW),
810+
| w2 AS (ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOuNDED FoLLOWING),
811+
| w3 AS (PARTITION BY y ORDER BY x ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
812+
""".stripMargin)
813+
814+
checkAnswer(actual, expected)
815+
816+
dropTempTable("nums")
817+
}
818+
783819
test("test case key when") {
784820
(1 to 5).map(i => (i, i.toString)).toDF("k", "v").registerTempTable("t")
785821
checkAnswer(

0 commit comments

Comments
 (0)