
Commit 8d7fbe7

Fixes to multiple spilling-related bugs.

1 parent 82e21c1

5 files changed: +55 -34 lines


core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 3 additions & 5 deletions
@@ -126,17 +126,15 @@ public void spill() throws IOException {
       spillWriters.size() > 1 ? " times" : " time");
 
     final UnsafeSorterSpillWriter spillWriter =
-      new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics);
+      new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics,
+        sorter.numRecords());
     spillWriters.add(spillWriter);
     final UnsafeSorterIterator sortedRecords = sorter.getSortedIterator();
     while (sortedRecords.hasNext()) {
       sortedRecords.loadNext();
       final Object baseObject = sortedRecords.getBaseObject();
       final long baseOffset = sortedRecords.getBaseOffset();
-      // TODO: this assumption that the first long holds a length is not enforced via our interfaces
-      // We need to either always store this via the write path (e.g. not require the caller to do
-      // it), or provide interfaces / hooks for customizing the physical storage format etc.
-      final int recordLength = (int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
+      final int recordLength = sortedRecords.getRecordLength();
       spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix());
     }
     spillWriter.close();
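The spill path above now takes the record length from the sorted-records iterator instead of dereferencing the first long of each record, removing the layout assumption flagged in the deleted TODO. For reference, a minimal sketch of the iterator contract the loop exercises (only the methods visible in this diff; the real UnsafeSorterIterator may declare more):

import java.io.IOException;

// Abbreviated sketch of the iterator surface used by spill(); method names
// are taken from the diff, everything else is illustrative.
public abstract class UnsafeSorterIterator {
  public abstract boolean hasNext();
  public abstract void loadNext() throws IOException;
  public abstract Object getBaseObject();
  public abstract long getBaseOffset();
  public abstract int getRecordLength();  // now the authoritative length for spilling
  public abstract long getKeyPrefix();
}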

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 9 additions & 1 deletion
@@ -89,6 +89,13 @@ public UnsafeInMemorySorter(
     this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager);
   }
 
+  /**
+   * @return the number of records that have been inserted into this sorter.
+   */
+  public int numRecords() {
+    return pointerArrayInsertPosition / 2;
+  }
+
   public long getMemoryUsage() {
     return pointerArray.length * 8L;
   }
@@ -106,7 +113,8 @@ public void expandPointerArray() {
   }
 
   /**
-   * Inserts a record to be sorted.
+   * Inserts a record to be sorted. Assumes that the record pointer points to a record length
+   * stored as a 4-byte integer, followed by the record's bytes.
    *
    * @param recordPointer pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
    * @param keyPrefix a user-defined key prefix
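The division by two in numRecords() reflects the pointer array's layout: each inserted record occupies two adjacent long slots, the encoded record pointer followed by its key prefix. A minimal sketch of that invariant (the two-slot layout is inferred from the division; array growth and sorting are elided):

// Sketch: the pointer array interleaves [recordPointer, keyPrefix] pairs,
// so the number of inserted records is half the insert position.
final class PointerArraySketch {
  private final long[] pointerArray = new long[1024];  // growth elided
  private int pointerArrayInsertPosition = 0;

  void insertRecord(long recordPointer, long keyPrefix) {
    pointerArray[pointerArrayInsertPosition] = recordPointer;
    pointerArray[pointerArrayInsertPosition + 1] = keyPrefix;
    pointerArrayInsertPosition += 2;
  }

  int numRecords() {
    return pointerArrayInsertPosition / 2;
  }
}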

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java

Lines changed: 17 additions & 12 deletions
@@ -25,42 +25,47 @@
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.unsafe.PlatformDependent;
 
+/**
+ * Reads spill files written by {@link UnsafeSorterSpillWriter} (see that class for a description
+ * of the file format).
+ */
 final class UnsafeSorterSpillReader extends UnsafeSorterIterator {
 
-  private final File file;
   private InputStream in;
   private DataInputStream din;
 
-  private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)?
-  private int nextRecordLength;
-
+  // Variables that change with every record read:
+  private int recordLength;
   private long keyPrefix;
+  private int numRecordsRemaining;
+
+  private final byte[] arr = new byte[1024 * 1024]; // TODO: tune this (maybe grow dynamically)?
   private final Object baseObject = arr;
   private final long baseOffset = PlatformDependent.BYTE_ARRAY_OFFSET;
 
   public UnsafeSorterSpillReader(
       BlockManager blockManager,
       File file,
       BlockId blockId) throws IOException {
-    this.file = file;
-    assert (file.length() > 0);
+    assert (file.length() > 0);
     final BufferedInputStream bs = new BufferedInputStream(new FileInputStream(file));
     this.in = blockManager.wrapForCompression(blockId, bs);
     this.din = new DataInputStream(this.in);
-    nextRecordLength = din.readInt();
+    numRecordsRemaining = din.readInt();
   }
 
   @Override
   public boolean hasNext() {
-    return (in != null);
+    return (numRecordsRemaining > 0);
   }
 
   @Override
   public void loadNext() throws IOException {
+    recordLength = din.readInt();
     keyPrefix = din.readLong();
-    ByteStreams.readFully(in, arr, 0, nextRecordLength);
-    nextRecordLength = din.readInt();
-    if (nextRecordLength == UnsafeSorterSpillWriter.EOF_MARKER) {
+    ByteStreams.readFully(in, arr, 0, recordLength);
+    numRecordsRemaining--;
+    if (numRecordsRemaining == 0) {
       in.close();
       in = null;
       din = null;
@@ -79,7 +84,7 @@ public long getBaseOffset() {
 
   @Override
   public int getRecordLength() {
-    return 0;
+    return recordLength;
   }
 
   @Override
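With the record count stored in the file header, the reader counts records down instead of probing for an EOF sentinel, and hasNext() no longer depends on stream state. A standalone sketch of parsing the format with a plain DataInputStream (the real reader additionally routes the stream through BlockManager compression):

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;

// Sketch: parse [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
final class SpillFileDumpSketch {
  static void dump(String path) throws IOException {
    try (DataInputStream din = new DataInputStream(new FileInputStream(path))) {
      int numRecordsRemaining = din.readInt();  // header: record count
      byte[] buf = new byte[1024 * 1024];       // assumes no record exceeds 1 MB
      while (numRecordsRemaining > 0) {
        int recordLength = din.readInt();       // [len (int)]
        long keyPrefix = din.readLong();        // [prefix (long)]
        din.readFully(buf, 0, recordLength);    // [data (bytes)]
        numRecordsRemaining--;
        System.out.println("prefix=" + keyPrefix + ", len=" + recordLength);
      }
    }
  }
}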

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java

Lines changed: 19 additions & 4 deletions
@@ -30,10 +30,14 @@
 import org.apache.spark.storage.TempLocalBlockId;
 import org.apache.spark.unsafe.PlatformDependent;
 
+/**
+ * Spills a list of sorted records to disk. Spill files have the following format:
+ *
+ *   [# of records (int)] [[len (int)][prefix (long)][data (bytes)]...]
+ */
 final class UnsafeSorterSpillWriter {
 
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
-  static final int EOF_MARKER = -1;
 
   // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
   // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
@@ -42,22 +46,29 @@ final class UnsafeSorterSpillWriter {
 
   private final File file;
   private final BlockId blockId;
+  private final int numRecordsToWrite;
   private BlockObjectWriter writer;
+  private int numRecordsSpilled = 0;
 
   public UnsafeSorterSpillWriter(
       BlockManager blockManager,
       int fileBufferSize,
-      ShuffleWriteMetrics writeMetrics) {
+      ShuffleWriteMetrics writeMetrics,
+      int numRecordsToWrite) throws IOException {
     final Tuple2<TempLocalBlockId, File> spilledFileInfo =
       blockManager.diskBlockManager().createTempLocalBlock();
     this.file = spilledFileInfo._2();
     this.blockId = spilledFileInfo._1();
+    this.numRecordsToWrite = numRecordsToWrite;
     // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
     // Our write path doesn't actually use this serializer (since we end up calling the `write()`
     // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
    // around this, we pass a dummy no-op serializer.
     writer = blockManager.getDiskWriter(
       blockId, file, DummySerializerInstance.INSTANCE, fileBufferSize, writeMetrics);
+    // Write the number of records
+    writeIntToBuffer(numRecordsToWrite, 0);
+    writer.write(writeBuffer, 0, 4);
   }
 
   // Based on DataOutputStream.writeLong.
@@ -85,6 +96,12 @@ public void write(
       long baseOffset,
       int recordLength,
       long keyPrefix) throws IOException {
+    if (numRecordsSpilled == numRecordsToWrite) {
+      throw new IllegalStateException(
+        "Number of records written exceeded numRecordsToWrite = " + numRecordsToWrite);
+    } else {
+      numRecordsSpilled++;
+    }
     writeIntToBuffer(recordLength, 0);
     writeLongToBuffer(keyPrefix, 4);
     int dataRemaining = recordLength;
@@ -107,8 +124,6 @@ public void write(
   }
 
   public void close() throws IOException {
-    writeIntToBuffer(EOF_MARKER, 0);
-    writer.write(writeBuffer, 0, 4);
     writer.commitAndClose();
     writer = null;
     writeBuffer = null;
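For clarity, the writer's side of the same framing, sketched with a DataOutputStream so the byte layout is explicit (the real class streams through DiskBlockObjectWriter with a manual write buffer, as shown above):

import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;

// Sketch: emit [# of records (int)] up front, then [len][prefix][data] per record.
final class SpillFileWriteSketch {
  static void write(String path, long[] prefixes, byte[][] records) throws IOException {
    if (prefixes.length != records.length) {
      throw new IllegalArgumentException("one prefix per record required");
    }
    try (DataOutputStream dos = new DataOutputStream(new FileOutputStream(path))) {
      dos.writeInt(records.length);        // header: # of records (int)
      for (int i = 0; i < records.length; i++) {
        dos.writeInt(records[i].length);   // [len (int)]
        dos.writeLong(prefixes[i]);        // [prefix (long)]
        dos.write(records[i]);             // [data (bytes)]
      }
    }
  }
}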

core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java

Lines changed: 7 additions & 12 deletions
@@ -153,22 +153,17 @@ public void testSortingOnlyByPrefix() throws Exception {
     insertNumber(sorter, 3);
     sorter.spill();
     insertNumber(sorter, 4);
+    sorter.spill();
     insertNumber(sorter, 2);
 
     UnsafeSorterIterator iter = sorter.getSortedIterator();
 
-    iter.loadNext();
-    assertEquals(1, iter.getKeyPrefix());
-    iter.loadNext();
-    assertEquals(2, iter.getKeyPrefix());
-    iter.loadNext();
-    assertEquals(3, iter.getKeyPrefix());
-    iter.loadNext();
-    assertEquals(4, iter.getKeyPrefix());
-    iter.loadNext();
-    assertEquals(5, iter.getKeyPrefix());
-    assertFalse(iter.hasNext());
-    // TODO: check that the values are also read back properly.
+    for (int i = 1; i <= 5; i++) {
+      iter.loadNext();
+      assertEquals(i, iter.getKeyPrefix());
+      assertEquals(4, iter.getRecordLength());
+      // TODO: read rest of value.
+    }
 
     // TODO: test for cleanup:
     // assert(tempDir.isEmpty)
