Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 9927441

Browse files
aray authored and yhuai committed
[SPARK-13749][SQL] Faster pivot implementation for many distinct values with two phase aggregation
## What changes were proposed in this pull request?

The existing implementation of pivot translates into a single aggregation with one aggregate per distinct pivot value. When the number of distinct pivot values is large (say 1000+), this can get extremely slow, since each input value gets evaluated on every aggregate even though it only affects the value of one of them.

I'm proposing an alternate strategy for when there are 10+ (a somewhat arbitrary threshold) distinct pivot values. We do two phases of aggregation. In the first, we group by the grouping columns plus the pivot column and perform the specified aggregations (one or sometimes more). In the second aggregation, we group by the grouping columns and use the new (non-public) PivotFirst aggregate, which rearranges the outputs of the first aggregation into an array indexed by the pivot value. Finally, we do a project to extract the array entries into the appropriate output columns.

## How was this patch tested?

Additional unit tests in DataFramePivotSuite and manual larger-scale testing.

Author: Andrew Ray <[email protected]>

Closes apache#11583 from aray/fast-pivot.
1 parent 0a30269 commit 9927441

File tree

3 files changed

+296
-33
lines changed

3 files changed

+296
-33
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 55 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -363,43 +363,68 @@ class Analyzer(
363363

364364
object ResolvePivot extends Rule[LogicalPlan] {
365365
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
366-
case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved) => p
366+
case p: Pivot if !p.childrenResolved | !p.aggregates.forall(_.resolved)
367+
| !p.groupByExprs.forall(_.resolved) | !p.pivotColumn.resolved => p
367368
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
368369
val singleAgg = aggregates.size == 1
369-
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
370-
def ifExpr(expr: Expression) = {
371-
If(EqualTo(pivotColumn, value), expr, Literal(null))
370+
def outputName(value: Literal, aggregate: Expression): String = {
371+
if (singleAgg) value.toString else value + "_" + aggregate.sql
372+
}
373+
if (aggregates.forall(a => PivotFirst.supportsDataType(a.dataType))) {
374+
// Since evaluating |pivotValues| if statements for each input row can get slow this is an
375+
// alternate plan that instead uses two steps of aggregation.
376+
val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)())
377+
val namedPivotCol = pivotColumn match {
378+
case n: NamedExpression => n
379+
case _ => Alias(pivotColumn, "__pivot_col")()
380+
}
381+
val bigGroup = groupByExprs :+ namedPivotCol
382+
val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
383+
val castPivotValues = pivotValues.map(Cast(_, pivotColumn.dataType).eval(EmptyRow))
384+
val pivotAggs = namedAggExps.map { a =>
385+
Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, castPivotValues)
386+
.toAggregateExpression()
387+
, "__pivot_" + a.sql)()
388+
}
389+
val secondAgg = Aggregate(groupByExprs, groupByExprs ++ pivotAggs, firstAgg)
390+
val pivotAggAttribute = pivotAggs.map(_.toAttribute)
391+
val pivotOutputs = pivotValues.zipWithIndex.flatMap { case (value, i) =>
392+
aggregates.zip(pivotAggAttribute).map { case (aggregate, pivotAtt) =>
393+
Alias(ExtractValue(pivotAtt, Literal(i), resolver), outputName(value, aggregate))()
394+
}
372395
}
373-
aggregates.map { aggregate =>
374-
val filteredAggregate = aggregate.transformDown {
375-
// Assumption is the aggregate function ignores nulls. This is true for all current
376-
// AggregateFunction's with the exception of First and Last in their default mode
377-
// (which we handle) and possibly some Hive UDAF's.
378-
case First(expr, _) =>
379-
First(ifExpr(expr), Literal(true))
380-
case Last(expr, _) =>
381-
Last(ifExpr(expr), Literal(true))
382-
case a: AggregateFunction =>
383-
a.withNewChildren(a.children.map(ifExpr))
384-
}.transform {
385-
// We are duplicating aggregates that are now computing a different value for each
386-
// pivot value.
387-
// TODO: Don't construct the physical container until after analysis.
388-
case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
396+
Project(groupByExprs ++ pivotOutputs, secondAgg)
397+
} else {
398+
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
399+
def ifExpr(expr: Expression) = {
400+
If(EqualTo(pivotColumn, value), expr, Literal(null))
389401
}
390-
if (filteredAggregate.fastEquals(aggregate)) {
391-
throw new AnalysisException(
392-
s"Aggregate expression required for pivot, found '$aggregate'")
402+
aggregates.map { aggregate =>
403+
val filteredAggregate = aggregate.transformDown {
404+
// Assumption is the aggregate function ignores nulls. This is true for all current
405+
// AggregateFunction's with the exception of First and Last in their default mode
406+
// (which we handle) and possibly some Hive UDAF's.
407+
case First(expr, _) =>
408+
First(ifExpr(expr), Literal(true))
409+
case Last(expr, _) =>
410+
Last(ifExpr(expr), Literal(true))
411+
case a: AggregateFunction =>
412+
a.withNewChildren(a.children.map(ifExpr))
413+
}.transform {
414+
// We are duplicating aggregates that are now computing a different value for each
415+
// pivot value.
416+
// TODO: Don't construct the physical container until after analysis.
417+
case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId)
418+
}
419+
if (filteredAggregate.fastEquals(aggregate)) {
420+
throw new AnalysisException(
421+
s"Aggregate expression required for pivot, found '$aggregate'")
422+
}
423+
Alias(filteredAggregate, outputName(value, aggregate))()
393424
}
394-
val name = if (singleAgg) value.toString else value + "_" + aggregate.sql
395-
Alias(filteredAggregate, name)()
396425
}
426+
Aggregate(groupByExprs, groupByExprs ++ pivotAggregates, child)
397427
}
398-
val newGroupByExprs = groupByExprs.map {
399-
case UnresolvedAlias(e, _) => e
400-
case e => e
401-
}
402-
Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
403428
}
404429
}
405430

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.aggregate
19+
20+
import scala.collection.immutable.HashMap
21+
22+
import org.apache.spark.sql.catalyst.InternalRow
23+
import org.apache.spark.sql.catalyst.expressions._
24+
import org.apache.spark.sql.catalyst.util.GenericArrayData
25+
import org.apache.spark.sql.types._
26+
27+
object PivotFirst {
28+
29+
def supportsDataType(dataType: DataType): Boolean = updateFunction.isDefinedAt(dataType)
30+
31+
// Currently UnsafeRow does not support the generic update method (throws
32+
// UnsupportedOperationException), so we need to explicitly support each DataType.
33+
private val updateFunction: PartialFunction[DataType, (MutableRow, Int, Any) => Unit] = {
34+
case DoubleType =>
35+
(row, offset, value) => row.setDouble(offset, value.asInstanceOf[Double])
36+
case IntegerType =>
37+
(row, offset, value) => row.setInt(offset, value.asInstanceOf[Int])
38+
case LongType =>
39+
(row, offset, value) => row.setLong(offset, value.asInstanceOf[Long])
40+
case FloatType =>
41+
(row, offset, value) => row.setFloat(offset, value.asInstanceOf[Float])
42+
case BooleanType =>
43+
(row, offset, value) => row.setBoolean(offset, value.asInstanceOf[Boolean])
44+
case ShortType =>
45+
(row, offset, value) => row.setShort(offset, value.asInstanceOf[Short])
46+
case ByteType =>
47+
(row, offset, value) => row.setByte(offset, value.asInstanceOf[Byte])
48+
case d: DecimalType =>
49+
(row, offset, value) => row.setDecimal(offset, value.asInstanceOf[Decimal], d.precision)
50+
}
51+
}
52+
53+
/**
54+
* PivotFirst is an aggregate function used in the second phase of a two phase pivot to do the
55+
* required rearrangement of values into pivoted form.
56+
*
57+
* For example on an input of
58+
* A | B
59+
* --+--
60+
* x | 1
61+
* y | 2
62+
* z | 3
63+
*
64+
* with pivotColumn=A, valueColumn=B, and pivotColumnValues=[z,y] the output is [3,2].
65+
*
66+
* @param pivotColumn column that determines which output position to put valueColumn in.
67+
* @param valueColumn the column that is being rearranged.
68+
* @param pivotColumnValues the list of pivotColumn values in the order of desired output. Values
69+
* not listed here will be ignored.
70+
*/
71+
case class PivotFirst(
72+
pivotColumn: Expression,
73+
valueColumn: Expression,
74+
pivotColumnValues: Seq[Any],
75+
mutableAggBufferOffset: Int = 0,
76+
inputAggBufferOffset: Int = 0) extends ImperativeAggregate {
77+
78+
override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil
79+
80+
override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)
81+
82+
override val nullable: Boolean = false
83+
84+
val valueDataType = valueColumn.dataType
85+
86+
override val dataType: DataType = ArrayType(valueDataType)
87+
88+
val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)
89+
90+
val indexSize = pivotIndex.size
91+
92+
private val updateRow: (MutableRow, Int, Any) => Unit = PivotFirst.updateFunction(valueDataType)
93+
94+
override def update(mutableAggBuffer: MutableRow, inputRow: InternalRow): Unit = {
95+
val pivotColValue = pivotColumn.eval(inputRow)
96+
if (pivotColValue != null) {
97+
// We ignore rows whose pivot column value is not in the list of pivot column values.
98+
val index = pivotIndex.getOrElse(pivotColValue, -1)
99+
if (index >= 0) {
100+
val value = valueColumn.eval(inputRow)
101+
if (value != null) {
102+
updateRow(mutableAggBuffer, mutableAggBufferOffset + index, value)
103+
}
104+
}
105+
}
106+
}
107+
108+
override def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit = {
109+
for (i <- 0 until indexSize) {
110+
if (!inputAggBuffer.isNullAt(inputAggBufferOffset + i)) {
111+
val value = inputAggBuffer.get(inputAggBufferOffset + i, valueDataType)
112+
updateRow(mutableAggBuffer, mutableAggBufferOffset + i, value)
113+
}
114+
}
115+
}
116+
117+
override def initialize(mutableAggBuffer: MutableRow): Unit = valueDataType match {
118+
case d: DecimalType =>
119+
// Per doc of setDecimal we need to do this instead of setNullAt for DecimalType.
120+
for (i <- 0 until indexSize) {
121+
mutableAggBuffer.setDecimal(mutableAggBufferOffset + i, null, d.precision)
122+
}
123+
case _ =>
124+
for (i <- 0 until indexSize) {
125+
mutableAggBuffer.setNullAt(mutableAggBufferOffset + i)
126+
}
127+
}
128+
129+
override def eval(input: InternalRow): Any = {
130+
val result = new Array[Any](indexSize)
131+
for (i <- 0 until indexSize) {
132+
result(i) = input.get(mutableAggBufferOffset + i, valueDataType)
133+
}
134+
new GenericArrayData(result)
135+
}
136+
137+
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
138+
copy(inputAggBufferOffset = newInputAggBufferOffset)
139+
140+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
141+
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
142+
143+
144+
override lazy val aggBufferAttributes: Seq[AttributeReference] =
145+
pivotIndex.toList.sortBy(_._2).map(kv => AttributeReference(kv._1.toString, valueDataType)())
146+
147+
override lazy val aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
148+
149+
override lazy val inputAggBufferAttributes: Seq[AttributeReference] =
150+
aggBufferAttributes.map(_.newInstance())
151+
}
152+

sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala

Lines changed: 89 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,31 @@
1717

1818
package org.apache.spark.sql
1919

20+
import org.apache.spark.sql.catalyst.expressions.aggregate.PivotFirst
2021
import org.apache.spark.sql.functions._
2122
import org.apache.spark.sql.internal.SQLConf
2223
import org.apache.spark.sql.test.SharedSQLContext
24+
import org.apache.spark.sql.types._
2325

2426
class DataFramePivotSuite extends QueryTest with SharedSQLContext{
2527
import testImplicits._
2628

27-
test("pivot courses with literals") {
29+
test("pivot courses") {
2830
checkAnswer(
2931
courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
3032
.agg(sum($"earnings")),
3133
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
3234
)
3335
}
3436

35-
test("pivot year with literals") {
37+
test("pivot year") {
3638
checkAnswer(
3739
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
3840
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
3941
)
4042
}
4143

42-
test("pivot courses with literals and multiple aggregations") {
44+
test("pivot courses with multiple aggregations") {
4345
checkAnswer(
4446
courseSales.groupBy($"year")
4547
.pivot("course", Seq("dotNET", "Java"))
@@ -94,4 +96,88 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
9496
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
9597
)
9698
}
99+
100+
// Tests for optimized pivot (with PivotFirst) below
101+
102+
test("optimized pivot planned") {
103+
val df = courseSales.groupBy("year")
104+
// pivot with extra columns to trigger optimization
105+
.pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
106+
.agg(sum($"earnings"))
107+
val queryExecution = sqlContext.executePlan(df.queryExecution.logical)
108+
assert(queryExecution.simpleString.contains("pivotfirst"))
109+
}
110+
111+
112+
test("optimized pivot courses with literals") {
113+
checkAnswer(
114+
courseSales.groupBy("year")
115+
// pivot with extra columns to trigger optimization
116+
.pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
117+
.agg(sum($"earnings"))
118+
.select("year", "dotNET", "Java"),
119+
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
120+
)
121+
}
122+
123+
test("optimized pivot year with literals") {
124+
checkAnswer(
125+
courseSales.groupBy($"course")
126+
// pivot with extra columns to trigger optimization
127+
.pivot("year", Seq(2012, 2013) ++ (1 to 10))
128+
.agg(sum($"earnings"))
129+
.select("course", "2012", "2013"),
130+
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
131+
)
132+
}
133+
134+
test("optimized pivot year with string values (cast)") {
135+
checkAnswer(
136+
courseSales.groupBy("course")
137+
// pivot with extra columns to trigger optimization
138+
.pivot("year", Seq("2012", "2013") ++ (1 to 10).map(_.toString))
139+
.sum("earnings")
140+
.select("course", "2012", "2013"),
141+
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
142+
)
143+
}
144+
145+
test("optimized pivot DecimalType") {
146+
val df = courseSales.select($"course", $"year", $"earnings".cast(DecimalType(10, 2)))
147+
.groupBy("year")
148+
// pivot with extra columns to trigger optimization
149+
.pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
150+
.agg(sum($"earnings"))
151+
.select("year", "dotNET", "Java")
152+
153+
assertResult(IntegerType)(df.schema("year").dataType)
154+
assertResult(DecimalType(20, 2))(df.schema("Java").dataType)
155+
assertResult(DecimalType(20, 2))(df.schema("dotNET").dataType)
156+
157+
checkAnswer(df, Row(2012, BigDecimal(1500000, 2), BigDecimal(2000000, 2)) ::
158+
Row(2013, BigDecimal(4800000, 2), BigDecimal(3000000, 2)) :: Nil)
159+
}
160+
161+
test("PivotFirst supported datatypes") {
162+
val supportedDataTypes: Seq[DataType] = DoubleType :: IntegerType :: LongType :: FloatType ::
163+
BooleanType :: ShortType :: ByteType :: Nil
164+
for (datatype <- supportedDataTypes) {
165+
assertResult(true)(PivotFirst.supportsDataType(datatype))
166+
}
167+
assertResult(true)(PivotFirst.supportsDataType(DecimalType(10, 1)))
168+
assertResult(false)(PivotFirst.supportsDataType(null))
169+
assertResult(false)(PivotFirst.supportsDataType(ArrayType(IntegerType)))
170+
}
171+
172+
test("optimized pivot with multiple aggregations") {
173+
checkAnswer(
174+
courseSales.groupBy($"year")
175+
// pivot with extra columns to trigger optimization
176+
.pivot("course", Seq("dotNET", "Java") ++ (1 to 10).map(_.toString))
177+
.agg(sum($"earnings"), avg($"earnings")),
178+
Row(Seq(2012, 15000.0, 7500.0, 20000.0, 20000.0) ++ Seq.fill(20)(null): _*) ::
179+
Row(Seq(2013, 48000.0, 48000.0, 30000.0, 30000.0) ++ Seq.fill(20)(null): _*) :: Nil
180+
)
181+
}
182+
97183
}

0 commit comments

Comments
 (0)