Skip to content

Commit b539baf

Browse files
Fix the bug of reverting the null issue in Sum and also in the Average UDAF
1 parent 341e708 commit b539baf

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,8 @@ case class Sum(child: Expression, distinct: Boolean = false)
413413
@transient var arg: MutableLiteral = _
414414
@transient var sum: Add = _
415415

416+
lazy val DEFAULT_VALUE = Cast(Literal(0, IntegerType), dataType).eval()
417+
416418
override def initialBoundReference(buffers: Seq[BoundReference]) = {
417419
aggr = buffers(0)
418420
arg = MutableLiteral(null, dataType)
@@ -431,6 +433,10 @@ case class Sum(child: Expression, distinct: Boolean = false)
431433
arg.value = argument
432434
buf(aggr) = sum.eval(buf)
433435
}
436+
} else {
437+
if (buf.isNullAt(aggr)) {
438+
buf(aggr) = DEFAULT_VALUE
439+
}
434440
}
435441
}
436442

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ private[hive] case class HiveGenericUdaf(
261261
// Initialize (reinitialize) the aggregation buffer
262262
override def reset(buf: MutableRow): Unit = {
263263
val buffer = evaluator.getNewAggregationBuffer
264-
.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
265264
evaluator.reset(buffer)
266265
// This is a hack, we never use the mutable row as buffer, but define our own buffer,
267266
// which is set as the first element of the buffer
@@ -276,27 +275,27 @@ private[hive] case class HiveGenericUdaf(
276275
}.toArray
277276

278277
evaluator.iterate(
279-
buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal),
278+
buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal),
280279
args)
281280
}
282281

283282
// Merge 2 aggregation buffers, and write the result back to the latter one
284283
override def merge(value: Row, buf: MutableRow): Unit = {
285-
val buffer = buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)
284+
val buffer = buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal)
286285
evaluator.merge(buffer, wrap(value.get(bound.ordinal), bufferObjectInspector))
287286
}
288287

289288
@deprecated
290289
override def terminatePartial(buf: MutableRow): Unit = {
291-
val buffer = buf.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)
290+
val buffer = buf.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal)
292291
// this is for serialization
293292
buf(bound) = unwrap(evaluator.terminatePartial(buffer), bufferObjectInspector)
294293
}
295294

296295
// Output the final result by feeding the aggregation buffer
297296
override def terminate(input: Row): Any = {
298297
unwrap(evaluator.terminate(
299-
input.getAs[GenericUDAFEvaluator.AbstractAggregationBuffer](bound.ordinal)),
298+
input.getAs[GenericUDAFEvaluator.AggregationBuffer](bound.ordinal)),
300299
objectInspector)
301300
}
302301
}

0 commit comments

Comments
 (0)