Skip to content

Commit 13f4f15

Browse files
using OpenHashSet instead
1 parent b539baf commit 13f4f15

File tree

1 file changed

+51
-49
lines changed

1 file changed

+51
-49
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala

Lines changed: 51 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717

1818
package org.apache.spark.sql.execution
1919

20-
import scala.collection._
21-
2220
import org.apache.spark.annotation.DeveloperApi
2321

22+
import org.apache.spark.util.collection.{OpenHashSet, OpenHashMap}
23+
2424
import org.apache.spark.sql.catalyst.errors._
2525
import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.catalyst.plans.physical._
@@ -39,7 +39,7 @@ sealed case class AggregateFunctionBind(
3939
sealed class InputBufferSeens(
4040
var input: Row, //
4141
var buffer: MutableRow,
42-
var seens: Array[mutable.HashSet[Any]] = null) {
42+
var seens: Array[OpenHashSet[Any]] = null) {
4343
def this() {
4444
this(new GenericMutableRow(0), null)
4545
}
@@ -54,7 +54,7 @@ sealed class InputBufferSeens(
5454
this
5555
}
5656

57-
def withSeens(seens: Array[mutable.HashSet[Any]]): InputBufferSeens = {
57+
def withSeens(seens: Array[OpenHashSet[Any]]): InputBufferSeens = {
5858
this.seens = seens
5959
this
6060
}
@@ -250,20 +250,13 @@ case class AggregatePreShuffle(
250250

251251
createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer)))
252252
} else {
253-
val results = new mutable.HashMap[Row, InputBufferSeens]()
253+
val results = new OpenHashMap[Row, InputBufferSeens]()
254254
while (iter.hasNext) {
255255
val currentRow = iter.next()
256256

257257
val keys = groupByProjection(currentRow)
258-
results.get(keys) match {
259-
case Some(inputbuffer) =>
260-
var idx = 0
261-
while (idx < aggregates.length) {
262-
val ae = aggregates(idx)
263-
ae.iterate(ae.eval(currentRow), inputbuffer.buffer)
264-
idx += 1
265-
}
266-
case None =>
258+
results(keys) match {
259+
case null =>
267260
val buffer = new GenericMutableRow(bufferSchema.length)
268261
var idx = 0
269262
while (idx < aggregates.length) {
@@ -278,11 +271,19 @@ case class AggregatePreShuffle(
278271
}
279272

280273
val copies = keys.copy()
281-
results.put(copies, new InputBufferSeens(copies, buffer))
274+
results(copies) = new InputBufferSeens(copies, buffer)
275+
case inputbuffer =>
276+
var idx = 0
277+
while (idx < aggregates.length) {
278+
val ae = aggregates(idx)
279+
ae.iterate(ae.eval(currentRow), inputbuffer.buffer)
280+
idx += 1
281+
}
282+
282283
}
283284
}
284285

285-
createIterator(aggregates, results.valuesIterator)
286+
createIterator(aggregates, results.iterator.map(_._2))
286287
}
287288
}
288289
}
@@ -328,32 +329,32 @@ case class AggregatePostShuffle(
328329

329330
createIterator(aggregates, Iterator(new InputBufferSeens().withBuffer(buffer)))
330331
} else {
331-
val results = new mutable.HashMap[Row, InputBufferSeens]()
332+
val results = new OpenHashMap[Row, InputBufferSeens]()
332333
while (iter.hasNext) {
333334
val currentRow = iter.next()
334335
val keys = groupByProjection(currentRow)
335-
results.get(keys) match {
336-
case Some(pair) =>
336+
results(keys) match {
337+
case null =>
338+
val buffer = new GenericMutableRow(bufferSchema.length)
337339
var idx = 0
338340
while (idx < aggregates.length) {
339341
val ae = aggregates(idx)
340-
ae.merge(currentRow, pair.buffer)
342+
ae.reset(buffer)
343+
ae.merge(currentRow, buffer)
341344
idx += 1
342345
}
343-
case None =>
344-
val buffer = new GenericMutableRow(bufferSchema.length)
346+
results(keys.copy()) = new InputBufferSeens(currentRow.copy(), buffer)
347+
case pair =>
345348
var idx = 0
346349
while (idx < aggregates.length) {
347350
val ae = aggregates(idx)
348-
ae.reset(buffer)
349-
ae.merge(currentRow, buffer)
351+
ae.merge(currentRow, pair.buffer)
350352
idx += 1
351353
}
352-
results.put(keys.copy(), new InputBufferSeens(currentRow.copy(), buffer))
353354
}
354355
}
355356

356-
createIterator(aggregates, results.valuesIterator)
357+
createIterator(aggregates, results.iterator.map(_._2))
357358
}
358359
}
359360
}
@@ -383,15 +384,15 @@ case class DistinctAggregate(
383384
if (groupingExpressions.isEmpty) {
384385
val buffer = new GenericMutableRow(bufferSchema.length)
385386
// TODO save the memory only for those DISTINCT aggregate expressions
386-
val seens = new Array[mutable.HashSet[Any]](aggregateFunctionBinds.length)
387+
val seens = new Array[OpenHashSet[Any]](aggregateFunctionBinds.length)
387388

388389
var idx = 0
389390
while (idx < aggregateFunctionBinds.length) {
390391
val ae = aggregates(idx)
391392
ae.reset(buffer)
392393

393394
if (ae.distinct) {
394-
seens(idx) = new mutable.HashSet[Any]()
395+
seens(idx) = new OpenHashSet[Any]()
395396
}
396397

397398
idx += 1
@@ -420,56 +421,57 @@ case class DistinctAggregate(
420421

421422
createIterator(aggregates, Iterator(ibs))
422423
} else {
423-
val results = new mutable.HashMap[Row, InputBufferSeens]()
424+
val results = new OpenHashMap[Row, InputBufferSeens]()
424425

425426
while (iter.hasNext) {
426427
val currentRow = iter.next()
427428

428429
val keys = groupByProjection(currentRow)
429-
results.get(keys) match {
430-
case Some(inputBufferSeens) =>
430+
results(keys) match {
431+
case null =>
432+
val buffer = new GenericMutableRow(bufferSchema.length)
433+
// TODO save the memory only for those DISTINCT aggregate expressions
434+
val seens = new Array[OpenHashSet[Any]](aggregateFunctionBinds.length)
435+
431436
var idx = 0
432437
while (idx < aggregateFunctionBinds.length) {
433438
val ae = aggregates(idx)
434439
val value = ae.eval(currentRow)
440+
ae.reset(buffer)
441+
ae.iterate(value, buffer)
435442

436443
if (ae.distinct) {
437-
if (value != null && !inputBufferSeens.seens(idx).contains(value)) {
438-
ae.iterate(value, inputBufferSeens.buffer)
439-
inputBufferSeens.seens(idx).add(value)
444+
val seen = new OpenHashSet[Any]()
445+
if (value != null) {
446+
seen.add(value)
440447
}
441-
} else {
442-
ae.iterate(value, inputBufferSeens.buffer)
448+
seens.update(idx, seen)
443449
}
450+
444451
idx += 1
445452
}
446-
case None =>
447-
val buffer = new GenericMutableRow(bufferSchema.length)
448-
// TODO save the memory only for those DISTINCT aggregate expressions
449-
val seens = new Array[mutable.HashSet[Any]](aggregateFunctionBinds.length)
453+
results(keys.copy()) = new InputBufferSeens(currentRow.copy(), buffer, seens)
450454

455+
case inputBufferSeens =>
451456
var idx = 0
452457
while (idx < aggregateFunctionBinds.length) {
453458
val ae = aggregates(idx)
454459
val value = ae.eval(currentRow)
455-
ae.reset(buffer)
456-
ae.iterate(value, buffer)
457460

458461
if (ae.distinct) {
459-
val seen = new mutable.HashSet[Any]()
460-
if (value != null) {
461-
seen.add(value)
462+
if (value != null && !inputBufferSeens.seens(idx).contains(value)) {
463+
ae.iterate(value, inputBufferSeens.buffer)
464+
inputBufferSeens.seens(idx).add(value)
462465
}
463-
seens.update(idx, seen)
466+
} else {
467+
ae.iterate(value, inputBufferSeens.buffer)
464468
}
465-
466469
idx += 1
467470
}
468-
results.put(keys.copy(), new InputBufferSeens(currentRow.copy(), buffer, seens))
469471
}
470472
}
471473

472-
createIterator(aggregates, results.valuesIterator)
474+
createIterator(aggregates, results.iterator.map(_._2))
473475
}
474476
}
475477
}

0 commit comments

Comments (0)