17
17
18
18
package org .apache .spark .sql .execution
19
19
20
- import org .apache .spark .SparkContext
21
20
import org .apache .spark .annotation .DeveloperApi
22
21
import org .apache .spark .sql .SQLContext
23
- import org .apache .spark .sql .catalyst .errors ._
24
22
import org .apache .spark .sql .catalyst .expressions ._
25
23
import org .apache .spark .sql .catalyst .plans .physical ._
26
24
import org .apache .spark .sql .catalyst .types ._
@@ -51,8 +49,6 @@ case class GeneratedAggregate(
51
49
child : SparkPlan )(@ transient sqlContext : SQLContext )
52
50
extends UnaryNode with NoBind {
53
51
54
- private def sc = sqlContext.sparkContext
55
-
56
52
override def requiredChildDistribution =
57
53
if (partial) {
58
54
UnspecifiedDistribution :: Nil
@@ -66,24 +62,24 @@ case class GeneratedAggregate(
66
62
67
63
override def otherCopyArgs = sqlContext :: Nil
68
64
69
- def output = aggregateExpressions.map(_.toAttribute)
65
+ override def output = aggregateExpressions.map(_.toAttribute)
70
66
71
- def execute () = {
67
+ override def execute () = {
72
68
val aggregatesToCompute = aggregateExpressions.flatMap { a =>
73
69
a.collect { case agg : AggregateExpression => agg}
74
70
}
75
71
76
72
val computeFunctions = aggregatesToCompute.map {
77
- case c@ Count (expr) =>
78
- val currentCount = AttributeReference (" currentCount" , LongType , true )()
73
+ case c @ Count (expr) =>
74
+ val currentCount = AttributeReference (" currentCount" , LongType , nullable = false )()
79
75
val initialValue = Literal (0L )
80
76
val updateFunction = If (IsNotNull (expr), Add (currentCount, Literal (1L )), currentCount)
81
77
val result = currentCount
82
78
83
79
AggregateEvaluation (currentCount :: Nil , initialValue :: Nil , updateFunction :: Nil , result)
84
80
85
81
case Sum (expr) =>
86
- val currentSum = AttributeReference (" currentSum" , expr.dataType, true )()
82
+ val currentSum = AttributeReference (" currentSum" , expr.dataType, nullable = false )()
87
83
val initialValue = Cast (Literal (0L ), expr.dataType)
88
84
89
85
// Coalasce avoids double calculation...
@@ -93,9 +89,9 @@ case class GeneratedAggregate(
93
89
94
90
AggregateEvaluation (currentSum :: Nil , initialValue :: Nil , updateFunction :: Nil , result)
95
91
96
- case a@ Average (expr) =>
97
- val currentCount = AttributeReference (" currentCount" , LongType , true )()
98
- val currentSum = AttributeReference (" currentSum" , expr.dataType, true )()
92
+ case a @ Average (expr) =>
93
+ val currentCount = AttributeReference (" currentCount" , LongType , nullable = false )()
94
+ val currentSum = AttributeReference (" currentSum" , expr.dataType, nullable = false )()
99
95
val initialCount = Literal (0L )
100
96
val initialSum = Cast (Literal (0L ), expr.dataType)
101
97
val updateCount = If (IsNotNull (expr), Add (currentCount, Literal (1L )), currentCount)
@@ -131,50 +127,70 @@ case class GeneratedAggregate(
131
127
132
128
child.execute().mapPartitions { iter =>
133
129
// Builds a new custom class for holding the results of aggregation for a group.
130
+ @ transient
134
131
val newAggregationBuffer =
135
132
newProjection(computeFunctions.flatMap(_.initialValues), child.output)
136
133
137
134
// A projection that is used to update the aggregate values for a group given a new tuple.
138
135
// This projection should be targeted at the current values for the group and then applied
139
136
// to a joined row of the current values with the new input row.
137
+ @ transient
140
138
val updateProjection =
141
139
newMutableProjection(
142
140
computeFunctions.flatMap(_.update),
143
141
computeFunctions.flatMap(_.schema) ++ child.output)()
144
142
145
143
// A projection that computes the group given an input tuple.
144
+ @ transient
146
145
val groupProjection = newProjection(groupingExpressions, child.output)
147
146
148
147
// A projection that produces the final result, given a computation.
148
+ @ transient
149
149
val resultProjectionBuilder =
150
150
newMutableProjection(
151
151
resultExpressions,
152
152
(namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq)
153
153
154
- val buffers = new java.util.HashMap [Row , MutableRow ]()
155
154
val joinedRow = new JoinedRow
156
155
157
- var currentRow : Row = null
158
- while (iter.hasNext) {
159
- currentRow = iter.next()
160
- val currentGroup = groupProjection(currentRow)
161
- var currentBuffer = buffers.get(currentGroup)
162
- if (currentBuffer == null ) {
163
- currentBuffer = newAggregationBuffer(EmptyRow ).asInstanceOf [MutableRow ]
164
- buffers.put(currentGroup, currentBuffer)
156
+ if (groupingExpressions.isEmpty) {
157
+ // TODO: Codegening anything other than the updateProjection is probably over kill.
158
+ val buffer = newAggregationBuffer(EmptyRow ).asInstanceOf [MutableRow ]
159
+ var currentRow : Row = null
160
+ while (iter.hasNext) {
161
+ currentRow = iter.next()
162
+ updateProjection.target(buffer)(joinedRow(buffer, currentRow))
163
+ }
164
+
165
+ val resultProjection = resultProjectionBuilder()
166
+ Iterator (resultProjection(buffer))
167
+ } else {
168
+ val buffers = new java.util.HashMap [Row , MutableRow ]()
169
+
170
+ var currentRow : Row = null
171
+ while (iter.hasNext) {
172
+ currentRow = iter.next()
173
+ val currentGroup = groupProjection(currentRow)
174
+ var currentBuffer = buffers.get(currentGroup)
175
+ if (currentBuffer == null ) {
176
+ currentBuffer = newAggregationBuffer(EmptyRow ).asInstanceOf [MutableRow ]
177
+ buffers.put(currentGroup, currentBuffer)
178
+ }
179
+ // Target the projection at the current aggregation buffer and then project the updated
180
+ // values.
181
+ updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
165
182
}
166
- updateProjection.target(currentBuffer)(joinedRow(currentBuffer, currentRow))
167
- }
168
183
169
- new Iterator [Row ] {
170
- private [this ] val resultIterator = buffers.entrySet.iterator()
171
- private [this ] val resultProjection = resultProjectionBuilder()
184
+ new Iterator [Row ] {
185
+ private [this ] val resultIterator = buffers.entrySet.iterator()
186
+ private [this ] val resultProjection = resultProjectionBuilder()
172
187
173
- def hasNext = resultIterator.hasNext
188
+ def hasNext = resultIterator.hasNext
174
189
175
- def next () = {
176
- val currentGroup = resultIterator.next()
177
- resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
190
+ def next () = {
191
+ val currentGroup = resultIterator.next()
192
+ resultProjection(joinedRow(currentGroup.getKey, currentGroup.getValue))
193
+ }
178
194
}
179
195
}
180
196
}
0 commit comments