Skip to content

Commit 5135200

Browse files
committed
Fix spill reading for large rows; add test
1 parent 2f48777 commit 5135200

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,8 +39,8 @@ final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
3939
private long keyPrefix;
4040
private int numRecordsRemaining;
4141

42-
private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)?
43-
private final Object baseObject = arr;
42+
private byte[] arr = new byte[1024 * 1024];
43+
private Object baseObject = arr;
4444
private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
4545

4646
public UnsafeSorterSpillReader(
@@ -63,6 +63,10 @@ public boolean hasNext() {
6363
public void loadNext() throws IOException {
6464
recordLength = din.readInt();
6565
keyPrefix = din.readLong();
66+
if (recordLength > arr.length) {
67+
arr = new byte[recordLength];
68+
baseObject = arr;
69+
}
6670
ByteStreams.readFully(in, arr, 0, recordLength);
6771
numRecordsRemaining--;
6872
if (numRecordsRemaining == 0) {

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,17 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
6363
}
6464
}
6565

66+
test("sorting does not crash for large inputs") {
67+
val sortOrder = 'a.asc :: Nil
68+
val stringLength = 1024 * 1024 * 2
69+
checkAnswer(
70+
Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
71+
UnsafeExternalSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
72+
Sort(sortOrder, global = true, _: SparkPlan),
73+
sortAnswers = false
74+
)
75+
}
76+
6677
// Test sorting on different data types
6778
for (
6879
dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType)

0 commit comments

Comments
 (0)