Commit efe1102

Changing CacheManager and BlockManager to pass iterators directly to the serializer when a 'DISK_ONLY' persist is called.
This is in response to SPARK-942.
1 parent dfd1ebc commit efe1102
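
For context, a minimal standalone sketch (not part of the commit) of the scenario this change targets, assuming a local SparkContext: with DISK_ONLY, the computed partition iterator is now handed straight to the block store and serializer instead of first being buffered in an ArrayBuffer.

import org.apache.spark.SparkContext
import org.apache.spark.storage.StorageLevel

object DiskOnlyPersistSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "disk-only-sketch")
    val rdd = sc.parallelize(1 to 1000000).map(i => (i, i.toString))
    rdd.persist(StorageLevel.DISK_ONLY)
    // First action computes each partition and streams it to an on-disk block.
    println(rdd.count())
    // Second action reads the partitions back from the on-disk blocks.
    println(rdd.count())
    sc.stop()
  }
}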

6 files changed (+28 / -18 lines)
core/src/main/scala/org/apache/spark/CacheManager.scala

Lines changed: 15 additions & 4 deletions
@@ -71,10 +71,21 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
           val computedValues = rdd.computeOrReadCheckpoint(split, context)
           // Persist the result, so long as the task is not running locally
           if (context.runningLocally) { return computedValues }
-          val elements = new ArrayBuffer[Any]
-          elements ++= computedValues
-          blockManager.put(key, elements, storageLevel, tellMaster = true)
-          return elements.iterator.asInstanceOf[Iterator[T]]
+          if (storageLevel == StorageLevel.DISK_ONLY || storageLevel == StorageLevel.DISK_ONLY_2) {
+            blockManager.put(key, computedValues, storageLevel, tellMaster = true)
+            return blockManager.get(key) match {
+              case Some(values) =>
+                return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+              case None =>
+                logInfo("Failure to store %s".format(key));
+                return null;
+            }
+          } else {
+            val elements = new ArrayBuffer[Any]
+            elements ++= computedValues
+            blockManager.put(key, elements, storageLevel, tellMaster = true)
+            return elements.iterator.asInstanceOf[Iterator[T]]
+          }
         } finally {
           loading.synchronized {
             loading.remove(key)

core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ import org.apache.spark.util.ByteBufferInputStream
 
 private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
   val objOut = new ObjectOutputStream(out)
-  def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
+  //Calling reset to avoid memory leak: http://stackoverflow.com/questions/1281549/memory-leak-traps-in-the-java-standard-api
+  def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); objOut.reset(); this }
   def flush() { objOut.flush() }
   def close() { objOut.close() }
 }
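
The reset() call matters because ObjectOutputStream keeps a reference to every object it has written so it can emit back-references on later writes; on a long-lived stream that pins all written data in memory. A small standalone sketch (illustrative only, not Spark code) of the pattern the change adopts:

import java.io.{ObjectOutputStream, OutputStream}

object ObjectStreamResetSketch {
  def main(args: Array[String]) {
    // Discard the bytes; we only care about the stream's internal reference table.
    val sink = new OutputStream { def write(b: Int) {} }
    val objOut = new ObjectOutputStream(sink)
    for (i <- 1 to 100000) {
      objOut.writeObject(Array.fill(1024)(i.toByte))
      objOut.reset() // without this, every written array stays reachable from objOut
    }
    objOut.close()
  }
}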

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 5 additions & 7 deletions
@@ -356,7 +356,7 @@ private[spark] class BlockManager(
               // TODO: Consider creating a putValues that also takes in a iterator?
               val valuesBuffer = new ArrayBuffer[Any]
               valuesBuffer ++= values
-              memoryStore.putValues(blockId, valuesBuffer, level, true).data match {
+              memoryStore.putValues(blockId, valuesBuffer.toIterator, level, true).data match {
                 case Left(values2) =>
                   return Some(values2)
                 case _ =>
@@ -451,9 +451,7 @@ private[spark] class BlockManager(
 
   def put(blockId: BlockId, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
     : Long = {
-    val elements = new ArrayBuffer[Any]
-    elements ++= values
-    put(blockId, elements, level, tellMaster)
+    doPut(blockId, Left(values), level, tellMaster)
   }
 
   /**
@@ -474,7 +472,7 @@ private[spark] class BlockManager(
   def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
     tellMaster: Boolean = true) : Long = {
     require(values != null, "Values is null")
-    doPut(blockId, Left(values), level, tellMaster)
+    doPut(blockId, Left(values.toIterator), level, tellMaster)
   }
 
   /**
@@ -486,7 +484,7 @@ private[spark] class BlockManager(
     doPut(blockId, Right(bytes), level, tellMaster)
   }
 
-  private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
+  private def doPut(blockId: BlockId, data: Either[Iterator[Any], ByteBuffer],
     level: StorageLevel, tellMaster: Boolean = true): Long = {
     require(blockId != null, "BlockId is null")
     require(level != null && level.isValid, "StorageLevel is null or invalid")
@@ -691,7 +689,7 @@ private[spark] class BlockManager(
         logInfo("Writing block " + blockId + " to disk")
         data match {
           case Left(elements) =>
-            diskStore.putValues(blockId, elements, level, false)
+            diskStore.putValues(blockId, elements.toIterator, level, false)
           case Right(bytes) =>
             diskStore.putBytes(blockId, bytes, level)
         }

core/src/main/scala/org/apache/spark/storage/BlockStore.scala

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ abstract class BlockStore(val blockManager: BlockManager) extends Logging {
   * @return a PutResult that contains the size of the data, as well as the values put if
   *         returnValues is true (if not, the result's data field can be null)
   */
-  def putValues(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
+  def putValues(blockId: BlockId, values: Iterator[Any], level: StorageLevel,
     returnValues: Boolean) : PutResult
 
   /**

core/src/main/scala/org/apache/spark/storage/DiskStore.scala

Lines changed: 2 additions & 2 deletions
@@ -57,7 +57,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
 
   override def putValues(
       blockId: BlockId,
-      values: ArrayBuffer[Any],
+      values: Iterator[Any],
       level: StorageLevel,
       returnValues: Boolean)
     : PutResult = {
@@ -66,7 +66,7 @@ private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManage
     val startTime = System.currentTimeMillis
     val file = diskManager.getFile(blockId)
     val outputStream = new FileOutputStream(file)
-    blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
+    blockManager.dataSerializeStream(blockId, outputStream, values)
    val length = file.length
 
    val timeTaken = System.currentTimeMillis - startTime
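
With putValues taking an Iterator, DiskStore can feed elements to dataSerializeStream one at a time rather than materializing the whole block first. A simplified standalone sketch of that streaming pattern (the names here are illustrative, not Spark internals):

import java.io.{BufferedOutputStream, File, FileOutputStream, ObjectOutputStream}

object StreamIteratorToDiskSketch {
  // Serialize an iterator element-by-element, so only one element is in memory at a time.
  def writeAll(values: Iterator[Any], file: File): Long = {
    val out = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream(file)))
    try {
      values.foreach { v =>
        out.writeObject(v)
        out.reset() // mirror JavaSerializationStream: avoid retaining written objects
      }
    } finally {
      out.close()
    }
    file.length
  }

  def main(args: Array[String]) {
    val f = File.createTempFile("block", ".tmp")
    val bytes = writeAll(Iterator.range(0, 1000).map(i => "record-" + i), f)
    println("wrote " + bytes + " bytes")
  }
}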

core/src/main/scala/org/apache/spark/storage/MemoryStore.scala

Lines changed: 3 additions & 3 deletions
@@ -65,17 +65,17 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
 
   override def putValues(
       blockId: BlockId,
-      values: ArrayBuffer[Any],
+      values: Iterator[Any],
       level: StorageLevel,
       returnValues: Boolean)
     : PutResult = {
 
     if (level.deserialized) {
       val sizeEstimate = SizeEstimator.estimate(values.asInstanceOf[AnyRef])
       tryToPut(blockId, values, sizeEstimate, true)
-      PutResult(sizeEstimate, Left(values.iterator))
+      PutResult(sizeEstimate, Left(values))
     } else {
-      val bytes = blockManager.dataSerialize(blockId, values.iterator)
+      val bytes = blockManager.dataSerialize(blockId, values)
       tryToPut(blockId, bytes, bytes.limit, false)
       PutResult(bytes.limit(), Right(bytes.duplicate()))
     }
