
Commit bd11b01

sryza authored and pwendell committed
[SPARK-7896] Allow ChainedBuffer to store more than 2 GB
Author: Sandy Ryza <[email protected]>

Closes apache#6440 from sryza/sandy-spark-7896 and squashes the following commits:

49d8a0d [Sandy Ryza] Fix bug introduced when reading over record boundaries
6006856 [Sandy Ryza] Fix overflow issues
006b4b2 [Sandy Ryza] Fix scalastyle by removing non ascii characters
8b000ca [Sandy Ryza] Add ascii art to describe layout of data in metaBuffer
f2053c0 [Sandy Ryza] Fix negative overflow issue
0368c78 [Sandy Ryza] Initialize size as 0
a5a4820 [Sandy Ryza] Use explicit types for all numbers in ChainedBuffer
b7e0213 [Sandy Ryza] SPARK-7896. Allow ChainedBuffer to store more than 2 GB
1 parent 852f4de commit bd11b01

2 files changed: 55 additions, 42 deletions


core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala

Lines changed: 24 additions & 22 deletions
@@ -28,11 +28,13 @@ import scala.collection.mutable.ArrayBuffer
  * occupy a contiguous segment of memory.
  */
 private[spark] class ChainedBuffer(chunkSize: Int) {
-  private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt
-  assert(math.pow(2, chunkSizeLog2).toInt == chunkSize,
+
+  private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros(
+    java.lang.Long.highestOneBit(chunkSize))
+  assert((1 << chunkSizeLog2) == chunkSize,
     s"ChainedBuffer chunk size $chunkSize must be a power of two")
   private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
-  private var _size: Int = _
+  private var _size: Long = 0

   /**
    * Feed bytes from this buffer into a BlockObjectWriter.
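
A minimal standalone sketch (not part of the patch, 4 MB chunk size assumed) of what the new chunkSizeLog2 computation does: highestOneBit isolates the most significant set bit and numberOfTrailingZeros converts it to an exact log2, replacing the floating-point math.log division that could round badly.

object ChunkSizeLog2Sketch {
  def main(args: Array[String]): Unit = {
    val chunkSize: Int = 4 * 1024 * 1024 // assumed example chunk size: 4 MB

    // highestOneBit keeps only the top set bit; numberOfTrailingZeros then
    // gives its exact log2, with no floating-point rounding involved.
    val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros(
      java.lang.Long.highestOneBit(chunkSize))

    assert((1 << chunkSizeLog2) == chunkSize,
      s"chunk size $chunkSize must be a power of two")
    println(s"chunkSize=$chunkSize, chunkSizeLog2=$chunkSizeLog2") // prints 22
  }
}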
@@ -41,16 +43,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
    * @param os OutputStream to read into.
    * @param len Number of bytes to read.
    */
-  def read(pos: Int, os: OutputStream, len: Int): Unit = {
+  def read(pos: Long, os: OutputStream, len: Int): Unit = {
     if (pos + len > _size) {
       throw new IndexOutOfBoundsException(
         s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
     }
-    var chunkIndex = pos >> chunkSizeLog2
-    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
-    var written = 0
+    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
+    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
+    var written: Int = 0
     while (written < len) {
-      val toRead = math.min(len - written, chunkSize - posInChunk)
+      val toRead: Int = math.min(len - written, chunkSize - posInChunk)
       os.write(chunks(chunkIndex), posInChunk, toRead)
       written += toRead
       chunkIndex += 1
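
A standalone sketch (assumed chunk size, made-up position) of the addressing arithmetic above: with pos now a Long, the chunk index comes from a right shift and the in-chunk offset from subtracting that chunk's start, both of which fit back into Ints.

object ChunkAddressingSketch {
  def main(args: Array[String]): Unit = {
    val chunkSizeLog2 = 22          // assumed: 4 MB chunks
    val chunkSize = 1 << chunkSizeLog2

    val pos: Long = 3L * 1024 * 1024 * 1024 + 12345 // a position past the 2 GB Int limit

    // Same arithmetic as read()/write(): shift for the chunk index, then
    // subtract that chunk's start offset to get the position inside the chunk.
    val chunkIndex: Int = (pos >> chunkSizeLog2).toInt
    val posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt

    assert(chunkIndex.toLong * chunkSize + posInChunk == pos)
    println(s"pos=$pos -> chunkIndex=$chunkIndex, posInChunk=$posInChunk")
  }
}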
@@ -66,16 +68,16 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
    * @param offs Offset in the byte array to read to.
    * @param len Number of bytes to read.
    */
-  def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+  def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
     if (pos + len > _size) {
       throw new IndexOutOfBoundsException(
         s"Read of $len bytes at position $pos would go past size of buffer")
     }
-    var chunkIndex = pos >> chunkSizeLog2
-    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
-    var written = 0
+    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
+    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
+    var written: Int = 0
     while (written < len) {
-      val toRead = math.min(len - written, chunkSize - posInChunk)
+      val toRead: Int = math.min(len - written, chunkSize - posInChunk)
       System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
       written += toRead
       chunkIndex += 1
@@ -91,22 +93,22 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
    * @param offs Offset in the byte array to write from.
    * @param len Number of bytes to write.
    */
-  def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+  def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
     if (pos > _size) {
       throw new IndexOutOfBoundsException(
         s"Write at position $pos starts after end of buffer ${_size}")
     }
     // Grow if needed
-    val endChunkIndex = (pos + len - 1) >> chunkSizeLog2
+    val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt
     while (endChunkIndex >= chunks.length) {
       chunks += new Array[Byte](chunkSize)
     }

-    var chunkIndex = pos >> chunkSizeLog2
-    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
-    var written = 0
+    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
+    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
+    var written: Int = 0
     while (written < len) {
-      val toWrite = math.min(len - written, chunkSize - posInChunk)
+      val toWrite: Int = math.min(len - written, chunkSize - posInChunk)
       System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
       written += toWrite
       chunkIndex += 1
@@ -119,19 +121,19 @@ private[spark] class ChainedBuffer(chunkSize: Int) {
   /**
    * Total size of buffer that can be written to without allocating additional memory.
    */
-  def capacity: Int = chunks.size * chunkSize
+  def capacity: Long = chunks.size.toLong * chunkSize

   /**
    * Size of the logical buffer.
    */
-  def size: Int = _size
+  def size: Long = _size
 }

 /**
  * Output stream that writes to a ChainedBuffer.
  */
 private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
-  private var pos = 0
+  private var pos: Long = 0

   override def write(b: Int): Unit = {
     throw new UnsupportedOperationException()
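
A small sketch (made-up numbers, not from the patch) of the overflow that the capacity: Long change avoids: multiplying chunk count by chunk size in Int arithmetic wraps once the product reaches 2^31 bytes.

object CapacityOverflowSketch {
  def main(args: Array[String]): Unit = {
    val numChunks = 1024
    val chunkSize = 4 * 1024 * 1024 // 4 MB

    val intCapacity: Int = numChunks * chunkSize           // wraps around to 0
    val longCapacity: Long = numChunks.toLong * chunkSize  // 4294967296 (4 GB)

    println(s"Int arithmetic:  $intCapacity")
    println(s"Long arithmetic: $longCapacity")
  }
}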

core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala

Lines changed: 31 additions & 20 deletions
@@ -41,6 +41,13 @@ import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
  *
  * Currently, only sorting by partition is supported.
  *
+ * Each record is laid out inside the metaBuffer as follows. keyStart, a long, is split across
+ * two integers:
+ *
+ * +-------------+------------+------------+-------------+
+ * |         keyStart         |  keyValLen | partitionId |
+ * +-------------+------------+------------+-------------+
+ *
  * @param metaInitialRecords The initial number of entries in the metadata buffer.
  * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
  * @param serializerInstance the serializer used for serializing inserted records.
@@ -68,19 +75,15 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
     }

     val keyStart = kvBuffer.size
-    if (keyStart < 0) {
-      throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes")
-    }
     kvSerializationStream.writeKey[Any](key)
-    kvSerializationStream.flush()
-    val valueStart = kvBuffer.size
     kvSerializationStream.writeValue[Any](value)
     kvSerializationStream.flush()
-    val valueEnd = kvBuffer.size
+    val keyValLen = (kvBuffer.size - keyStart).toInt

-    metaBuffer.put(keyStart)
-    metaBuffer.put(valueStart)
-    metaBuffer.put(valueEnd)
+    // keyStart, a long, gets split across two ints
+    metaBuffer.put(keyStart.toInt)
+    metaBuffer.put((keyStart >> 32).toInt)
+    metaBuffer.put(keyValLen)
     metaBuffer.put(partition)
   }
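
A standalone sketch (made-up offsets, not Spark code) of the long-to-two-ints split used in the put calls above, including the unsigned mask needed on the way back so a lower word with its sign bit set does not corrupt the reassembled offset.

import java.nio.IntBuffer

object KeyStartSplitSketch {
  def main(args: Array[String]): Unit = {
    // A keyStart past 4 GB whose lower 32 bits have the sign bit set
    val keyStart: Long = 0x1F0000000L

    val metaBuffer = IntBuffer.allocate(4)
    metaBuffer.put(keyStart.toInt)          // lower 32 bits (negative as an Int)
    metaBuffer.put((keyStart >> 32).toInt)  // upper 32 bits
    metaBuffer.put(123)                     // keyValLen, illustrative
    metaBuffer.put(7)                       // partitionId, illustrative

    // Reassemble: mask the lower word so sign extension does not leak into the upper half
    val lower32 = metaBuffer.get(0)
    val upper32 = metaBuffer.get(1)
    val recovered = (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)

    assert(recovered == keyStart, s"expected $keyStart, got $recovered")
    println(s"stored $keyStart, lower=$lower32, upper=$upper32, recovered $recovered")
  }
}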

@@ -114,7 +117,7 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
     }
   }

-  override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity
+  override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity

   override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
     : WritablePartitionedIterator = {
@@ -128,10 +131,10 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
     var pos = 0

     def writeNext(writer: BlockObjectWriter): Unit = {
-      val keyStart = metaBuffer.get(pos + KEY_START)
-      val valueEnd = metaBuffer.get(pos + VAL_END)
+      val keyStart = getKeyStartPos(metaBuffer, pos)
+      val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
       pos += RECORD_SIZE
-      kvBuffer.read(keyStart, writer, valueEnd - keyStart)
+      kvBuffer.read(keyStart, writer, keyValLen)
       writer.recordWritten()
     }
     def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
@@ -163,23 +166,26 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
 private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
     extends InputStream {

+  import PartitionedSerializedPairBuffer._
+
   private var metaBufferPos = 0
   private var kvBufferPos =
-    if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0
+    if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0

   override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)

   override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
     if (metaBufferPos >= metaBuffer.position) {
       return -1
     }
-    val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos
+    val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) -
+      (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt
     val toRead = math.min(bytesRemainingInRecord, len)
     kvBuffer.read(kvBufferPos, bytes, offs, toRead)
     if (toRead == bytesRemainingInRecord) {
       metaBufferPos += RECORD_SIZE
       if (metaBufferPos < metaBuffer.position) {
-        kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START)
+        kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos)
       }
     } else {
       kvBufferPos += toRead
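
A standalone sketch (assumed record offsets and lengths) of the new bytesRemainingInRecord arithmetic: rather than a stored valueEnd, the remaining bytes come from the Int keyValLen minus how far the stream has already advanced past the record's Long keyStart.

object BytesRemainingSketch {
  def main(args: Array[String]): Unit = {
    val keyStart: Long = 3L * 1024 * 1024 * 1024 // record starts past 2 GB
    val keyValLen: Int = 4096                    // serialized key+value length
    var kvBufferPos: Long = keyStart             // stream position within the record

    // Simulate two partial reads of 1500 bytes each.
    for (len <- Seq(1500, 1500)) {
      val bytesRemainingInRecord = (keyValLen - (kvBufferPos - keyStart)).toInt
      val toRead = math.min(bytesRemainingInRecord, len)
      kvBufferPos += toRead
      println(s"read $toRead bytes, ${keyValLen - (kvBufferPos - keyStart)} bytes left in record")
    }
  }
}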
@@ -246,9 +252,14 @@ private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuf
 }

 private[spark] object PartitionedSerializedPairBuffer {
-  val KEY_START = 0
-  val VAL_START = 1
-  val VAL_END = 2
+  val KEY_START = 0 // keyStart, a long, gets split across two ints
+  val KEY_VAL_LEN = 2
   val PARTITION = 3
-  val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata
+  val RECORD_SIZE = PARTITION + 1 // num ints of metadata
+
+  def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
+    val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
+    val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
+    (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
+  }
 }
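
To round out the constants above, a standalone sketch (illustrative values only) that walks a metaBuffer laid out with the new four-int records, reading keyStart via the same two-int reconstruction as getKeyStartPos and stepping by RECORD_SIZE.

import java.nio.IntBuffer

object MetaBufferWalkSketch {
  val KEY_START = 0      // keyStart occupies this slot and the next one
  val KEY_VAL_LEN = 2
  val PARTITION = 3
  val RECORD_SIZE = PARTITION + 1

  def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
    val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
    val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
    (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
  }

  def main(args: Array[String]): Unit = {
    val metaBuffer = IntBuffer.allocate(2 * RECORD_SIZE)
    // Record 0: keyStart = 100, keyValLen = 50, partition = 0 (made-up values)
    metaBuffer.put(100).put(0).put(50).put(0)
    // Record 1: keyStart = 150, keyValLen = 80, partition = 3 (made-up values)
    metaBuffer.put(150).put(0).put(80).put(3)

    var pos = 0
    while (pos < metaBuffer.position) {
      val keyStart = getKeyStartPos(metaBuffer, pos)
      val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
      val partition = metaBuffer.get(pos + PARTITION)
      println(s"record at $pos: keyStart=$keyStart, keyValLen=$keyValLen, partition=$partition")
      pos += RECORD_SIZE
    }
  }
}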
