17
17
18
18
package org .apache .spark .sql .execution
19
19
20
- import scala .collection ._
21
-
22
20
import org .apache .spark .annotation .DeveloperApi
23
21
22
+ import org .apache .spark .util .collection .{OpenHashSet , OpenHashMap }
23
+
24
24
import org .apache .spark .sql .catalyst .errors ._
25
25
import org .apache .spark .sql .catalyst .expressions ._
26
26
import org .apache .spark .sql .catalyst .plans .physical ._
@@ -39,7 +39,7 @@ sealed case class AggregateFunctionBind(
39
39
sealed class InputBufferSeens (
40
40
var input : Row , //
41
41
var buffer : MutableRow ,
42
- var seens : Array [mutable. HashSet [Any ]] = null ) {
42
+ var seens : Array [OpenHashSet [Any ]] = null ) {
43
43
def this () {
44
44
this (new GenericMutableRow (0 ), null )
45
45
}
@@ -54,7 +54,7 @@ sealed class InputBufferSeens(
54
54
this
55
55
}
56
56
57
- def withSeens (seens : Array [mutable. HashSet [Any ]]): InputBufferSeens = {
57
+ def withSeens (seens : Array [OpenHashSet [Any ]]): InputBufferSeens = {
58
58
this .seens = seens
59
59
this
60
60
}
@@ -250,20 +250,13 @@ case class AggregatePreShuffle(
250
250
251
251
createIterator(aggregates, Iterator (new InputBufferSeens ().withBuffer(buffer)))
252
252
} else {
253
- val results = new mutable. HashMap [Row , InputBufferSeens ]()
253
+ val results = new OpenHashMap [Row , InputBufferSeens ]()
254
254
while (iter.hasNext) {
255
255
val currentRow = iter.next()
256
256
257
257
val keys = groupByProjection(currentRow)
258
- results.get(keys) match {
259
- case Some (inputbuffer) =>
260
- var idx = 0
261
- while (idx < aggregates.length) {
262
- val ae = aggregates(idx)
263
- ae.iterate(ae.eval(currentRow), inputbuffer.buffer)
264
- idx += 1
265
- }
266
- case None =>
258
+ results(keys) match {
259
+ case null =>
267
260
val buffer = new GenericMutableRow (bufferSchema.length)
268
261
var idx = 0
269
262
while (idx < aggregates.length) {
@@ -278,11 +271,19 @@ case class AggregatePreShuffle(
278
271
}
279
272
280
273
val copies = keys.copy()
281
- results.put(copies, new InputBufferSeens (copies, buffer))
274
+ results(copies) = new InputBufferSeens (copies, buffer)
275
+ case inputbuffer =>
276
+ var idx = 0
277
+ while (idx < aggregates.length) {
278
+ val ae = aggregates(idx)
279
+ ae.iterate(ae.eval(currentRow), inputbuffer.buffer)
280
+ idx += 1
281
+ }
282
+
282
283
}
283
284
}
284
285
285
- createIterator(aggregates, results.valuesIterator )
286
+ createIterator(aggregates, results.iterator.map(_._2) )
286
287
}
287
288
}
288
289
}
@@ -328,32 +329,32 @@ case class AggregatePostShuffle(
328
329
329
330
createIterator(aggregates, Iterator (new InputBufferSeens ().withBuffer(buffer)))
330
331
} else {
331
- val results = new mutable. HashMap [Row , InputBufferSeens ]()
332
+ val results = new OpenHashMap [Row , InputBufferSeens ]()
332
333
while (iter.hasNext) {
333
334
val currentRow = iter.next()
334
335
val keys = groupByProjection(currentRow)
335
- results.get(keys) match {
336
- case Some (pair) =>
336
+ results(keys) match {
337
+ case null =>
338
+ val buffer = new GenericMutableRow (bufferSchema.length)
337
339
var idx = 0
338
340
while (idx < aggregates.length) {
339
341
val ae = aggregates(idx)
340
- ae.merge(currentRow, pair.buffer)
342
+ ae.reset(buffer)
343
+ ae.merge(currentRow, buffer)
341
344
idx += 1
342
345
}
343
- case None =>
344
- val buffer = new GenericMutableRow (bufferSchema.length)
346
+ results(keys.copy()) = new InputBufferSeens (currentRow.copy(), buffer)
347
+ case pair =>
345
348
var idx = 0
346
349
while (idx < aggregates.length) {
347
350
val ae = aggregates(idx)
348
- ae.reset(buffer)
349
- ae.merge(currentRow, buffer)
351
+ ae.merge(currentRow, pair.buffer)
350
352
idx += 1
351
353
}
352
- results.put(keys.copy(), new InputBufferSeens (currentRow.copy(), buffer))
353
354
}
354
355
}
355
356
356
- createIterator(aggregates, results.valuesIterator )
357
+ createIterator(aggregates, results.iterator.map(_._2) )
357
358
}
358
359
}
359
360
}
@@ -383,15 +384,15 @@ case class DistinctAggregate(
383
384
if (groupingExpressions.isEmpty) {
384
385
val buffer = new GenericMutableRow (bufferSchema.length)
385
386
// TODO save the memory only for those DISTINCT aggregate expressions
386
- val seens = new Array [mutable. HashSet [Any ]](aggregateFunctionBinds.length)
387
+ val seens = new Array [OpenHashSet [Any ]](aggregateFunctionBinds.length)
387
388
388
389
var idx = 0
389
390
while (idx < aggregateFunctionBinds.length) {
390
391
val ae = aggregates(idx)
391
392
ae.reset(buffer)
392
393
393
394
if (ae.distinct) {
394
- seens(idx) = new mutable. HashSet [Any ]()
395
+ seens(idx) = new OpenHashSet [Any ]()
395
396
}
396
397
397
398
idx += 1
@@ -420,56 +421,57 @@ case class DistinctAggregate(
420
421
421
422
createIterator(aggregates, Iterator (ibs))
422
423
} else {
423
- val results = new mutable. HashMap [Row , InputBufferSeens ]()
424
+ val results = new OpenHashMap [Row , InputBufferSeens ]()
424
425
425
426
while (iter.hasNext) {
426
427
val currentRow = iter.next()
427
428
428
429
val keys = groupByProjection(currentRow)
429
- results.get(keys) match {
430
- case Some (inputBufferSeens) =>
430
+ results(keys) match {
431
+ case null =>
432
+ val buffer = new GenericMutableRow (bufferSchema.length)
433
+ // TODO save the memory only for those DISTINCT aggregate expressions
434
+ val seens = new Array [OpenHashSet [Any ]](aggregateFunctionBinds.length)
435
+
431
436
var idx = 0
432
437
while (idx < aggregateFunctionBinds.length) {
433
438
val ae = aggregates(idx)
434
439
val value = ae.eval(currentRow)
440
+ ae.reset(buffer)
441
+ ae.iterate(value, buffer)
435
442
436
443
if (ae.distinct) {
437
- if (value != null && ! inputBufferSeens.seens(idx).contains(value)) {
438
- ae.iterate (value, inputBufferSeens.buffer)
439
- inputBufferSeens.seens(idx) .add(value)
444
+ val seen = new OpenHashSet [ Any ]()
445
+ if (value != null ) {
446
+ seen .add(value)
440
447
}
441
- } else {
442
- ae.iterate(value, inputBufferSeens.buffer)
448
+ seens.update(idx, seen)
443
449
}
450
+
444
451
idx += 1
445
452
}
446
- case None =>
447
- val buffer = new GenericMutableRow (bufferSchema.length)
448
- // TODO save the memory only for those DISTINCT aggregate expressions
449
- val seens = new Array [mutable.HashSet [Any ]](aggregateFunctionBinds.length)
453
+ results(keys.copy()) = new InputBufferSeens (currentRow.copy(), buffer, seens)
450
454
455
+ case inputBufferSeens =>
451
456
var idx = 0
452
457
while (idx < aggregateFunctionBinds.length) {
453
458
val ae = aggregates(idx)
454
459
val value = ae.eval(currentRow)
455
- ae.reset(buffer)
456
- ae.iterate(value, buffer)
457
460
458
461
if (ae.distinct) {
459
- val seen = new mutable. HashSet [ Any ]()
460
- if (value != null ) {
461
- seen .add(value)
462
+ if (value != null && ! inputBufferSeens.seens(idx).contains(value)) {
463
+ ae.iterate (value, inputBufferSeens.buffer)
464
+ inputBufferSeens.seens(idx) .add(value)
462
465
}
463
- seens.update(idx, seen)
466
+ } else {
467
+ ae.iterate(value, inputBufferSeens.buffer)
464
468
}
465
-
466
469
idx += 1
467
470
}
468
- results.put(keys.copy(), new InputBufferSeens (currentRow.copy(), buffer, seens))
469
471
}
470
472
}
471
473
472
- createIterator(aggregates, results.valuesIterator )
474
+ createIterator(aggregates, results.iterator.map(_._2) )
473
475
}
474
476
}
475
477
}
0 commit comments