
Commit aeb680e

[SPARK-3386] Reuse SerializerInstance in shuffle code paths
1 parent 968ad97 commit aeb680e

File tree

8 files changed: +32 -24 lines

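The same pattern is applied in every touched file: obtain a SerializerInstance once (per executor backend, per shuffle writer group, per spill, or per fetch iterator) and reuse it, instead of calling serializer.newInstance() each time a stream is opened. Below is a minimal sketch of the idea, using hypothetical class names (EagerWriter, ReusingWriter) rather than the actual Spark classes changed in this commit:

// Sketch only: EagerWriter and ReusingWriter are hypothetical illustrations,
// not classes from this commit.
import java.io.OutputStream

import org.apache.spark.serializer.{SerializationStream, Serializer, SerializerInstance}

// Before: a fresh SerializerInstance is created for every stream that is opened,
// which is wasteful for serializers with expensive instantiation (e.g. Kryo).
class EagerWriter(serializer: Serializer) {
  def open(out: OutputStream): SerializationStream =
    serializer.newInstance().serializeStream(out)
}

// After: the instance is created once and reused for every stream. This assumes the
// writer is confined to a single thread, since a SerializerInstance need not be
// thread-safe.
class ReusingWriter(serializer: Serializer) {
  private[this] val serInstance: SerializerInstance = serializer.newInstance()

  def open(out: OutputStream): SerializationStream =
    serInstance.serializeStream(out)
}

The thread-confinement caveat is why the cached instances in the diff below live in per-task or per-iterator fields (private[this] vals) rather than being shared more broadly.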

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

Lines changed: 3 additions & 1 deletion
@@ -30,6 +30,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.worker.WorkerWatcher
 import org.apache.spark.scheduler.TaskDescription
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.util.{SignalLogger, Utils}
 
 private[spark] class CoarseGrainedExecutorBackend(
@@ -47,6 +48,8 @@ private[spark] class CoarseGrainedExecutorBackend(
   var executor: Executor = null
   @volatile var driver: Option[RpcEndpointRef] = None
 
+  private[this] val ser: SerializerInstance = env.closureSerializer.newInstance()
+
   override def onStart() {
     import scala.concurrent.ExecutionContext.Implicits.global
     logInfo("Connecting to driver: " + driverUrl)
@@ -83,7 +86,6 @@ private[spark] class CoarseGrainedExecutorBackend(
       logError("Received LaunchTask command but executor was null")
       System.exit(1)
     } else {
-      val ser = env.closureSerializer.newInstance()
       val taskDesc = ser.deserialize[TaskDescription](data.value)
       logInfo("Got assigned task " + taskDesc.taskId)
       executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,

core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala

Lines changed: 4 additions & 2 deletions
@@ -113,11 +113,12 @@ class FileShuffleBlockManager(conf: SparkConf)
       private var fileGroup: ShuffleFileGroup = null
 
       val openStartTime = System.nanoTime
+      val serializerInstance = serializer.newInstance()
       val writers: Array[BlockObjectWriter] = if (consolidateShuffleFiles) {
         fileGroup = getUnusedFileGroup()
         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
-          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
+          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializerInstance, bufferSize,
             writeMetrics)
         }
       } else {
@@ -133,7 +134,8 @@ class FileShuffleBlockManager(conf: SparkConf)
               logWarning(s"Failed to remove existing shuffle file $blockFile")
             }
           }
-          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
+          blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
+            writeMetrics)
         }
       }
       // Creating the file to write to and creating a disk writer both involve interacting with

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 4 additions & 4 deletions
@@ -37,7 +37,7 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{SerializerInstance, Serializer}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.util._
@@ -646,13 +646,13 @@ private[spark] class BlockManager(
   def getDiskWriter(
       blockId: BlockId,
       file: File,
-      serializer: Serializer,
+      serializerInstance: SerializerInstance,
       bufferSize: Int,
       writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites,
-      writeMetrics)
+    new DiskBlockObjectWriter(blockId, file, serializerInstance, bufferSize, compressStream,
+      syncWrites, writeMetrics)
   }
 
   /**

core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,7 @@ import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream}
 import java.nio.channels.FileChannel
 
 import org.apache.spark.Logging
-import org.apache.spark.serializer.{SerializationStream, Serializer}
+import org.apache.spark.serializer.{SerializerInstance, SerializationStream}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.util.Utils
 
@@ -71,7 +71,7 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
 private[spark] class DiskBlockObjectWriter(
     blockId: BlockId,
     file: File,
-    serializer: Serializer,
+    serializerInstance: SerializerInstance,
     bufferSize: Int,
     compressStream: OutputStream => OutputStream,
     syncWrites: Boolean,
@@ -134,7 +134,7 @@ private[spark] class DiskBlockObjectWriter(
     ts = new TimeTrackingOutputStream(fos)
     channel = fos.getChannel()
     bs = compressStream(new BufferedOutputStream(ts, bufferSize))
-    objOut = serializer.newInstance().serializeStream(bs)
+    objOut = serializerInstance.serializeStream(bs)
     initialized = true
     this
   }

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 4 additions & 2 deletions
@@ -27,7 +27,7 @@ import org.apache.spark.{Logging, TaskContext}
 import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.Serializer
+import org.apache.spark.serializer.{SerializerInstance, Serializer}
 import org.apache.spark.util.{CompletionIterator, Utils}
 
 /**
@@ -106,6 +106,8 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
 
+  private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
    * longer place fetched blocks into [[results]].
@@ -299,7 +301,7 @@ final class ShuffleBlockFetcherIterator(
       // the scheduler gets a FetchFailedException.
       Try(buf.createInputStream()).map { is0 =>
         val is = blockManager.wrapForCompression(blockId, is0)
-        val iter = serializer.newInstance().deserializeStream(is).asIterator
+        val iter = serializerInstance.deserializeStream(is).asIterator
         CompletionIterator[Any, Iterator[Any]](iter, {
           // Once the iterator is exhausted, release the buffer and set currentResult to null
           // so we don't release it again in cleanup.

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 2 additions & 4 deletions
@@ -151,8 +151,7 @@ class ExternalAppendOnlyMap[K, V, C](
   override protected[this] def spill(collection: SizeTracker): Unit = {
     val (blockId, file) = diskBlockManager.createTempLocalBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
-    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
-      curWriteMetrics)
+    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
     var objectsWritten = 0
 
     // List of batch sizes (bytes) in the order they are written to disk
@@ -179,8 +178,7 @@ class ExternalAppendOnlyMap[K, V, C](
         if (objectsWritten == serializerBatchSize) {
           flush()
           curWriteMetrics = new ShuffleWriteMetrics()
-          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
-            curWriteMetrics)
+          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
         }
       }
       if (objectsWritten > 0) {

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 9 additions & 5 deletions
@@ -272,7 +272,8 @@ private[spark] class ExternalSorter[K, V, C](
     // createTempShuffleBlock here; see SPARK-3426 for more context.
     val (blockId, file) = diskBlockManager.createTempShuffleBlock()
     curWriteMetrics = new ShuffleWriteMetrics()
-    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
+    var writer = blockManager.getDiskWriter(
+      blockId, file, serInstance, fileBufferSize, curWriteMetrics)
     var objectsWritten = 0   // Objects written since the last flush
 
     // List of batch sizes (bytes) in the order they are written to disk
@@ -308,7 +309,8 @@ private[spark] class ExternalSorter[K, V, C](
         if (objectsWritten == serializerBatchSize) {
           flush()
           curWriteMetrics = new ShuffleWriteMetrics()
-          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
+          writer = blockManager.getDiskWriter(
+            blockId, file, serInstance, fileBufferSize, curWriteMetrics)
         }
       }
       if (objectsWritten > 0) {
@@ -358,7 +360,9 @@ private[spark] class ExternalSorter[K, V, C](
       // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
       // createTempShuffleBlock here; see SPARK-3426 for more context.
       val (blockId, file) = diskBlockManager.createTempShuffleBlock()
-      blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
+      val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize,
+        curWriteMetrics)
+      writer.open()
     }
     // Creating the file to write to and creating a disk writer both involve interacting with
     // the disk, and can take a long time in aggregate when we open many files, so should be
@@ -749,8 +753,8 @@ private[spark] class ExternalSorter[K, V, C](
       // partition and just write everything directly.
       for ((id, elements) <- this.partitionedIterator) {
         if (elements.hasNext) {
-          val writer = blockManager.getDiskWriter(
-            blockId, outputFile, ser, fileBufferSize, context.taskMetrics.shuffleWriteMetrics.get)
+          val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
+            context.taskMetrics.shuffleWriteMetrics.get)
           for (elem <- elements) {
             writer.write(elem)
           }

core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -30,7 +30,7 @@ class BlockObjectWriterSuite extends FunSuite {
     val file = new File(Utils.createTempDir(), "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
 
     writer.write(Long.box(20))
     // Record metrics update on every write
@@ -52,7 +52,7 @@ class BlockObjectWriterSuite extends FunSuite {
     val file = new File(Utils.createTempDir(), "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
 
     writer.write(Long.box(20))
     // Record metrics update on every write
@@ -75,7 +75,7 @@ class BlockObjectWriterSuite extends FunSuite {
     val file = new File(Utils.createTempDir(), "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
-      new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
 
     writer.open()
     writer.close()
