Skip to content

[SPARK-18890][CORE] Move task serialization from the TaskSetManager to the CoarseGrainedSchedulerBackend #15505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ private[spark] class CoarseGrainedExecutorBackend(
if (executor == null) {
exitExecutor(1, "Received LaunchTask command but executor was null")
} else {
val taskDesc = TaskDescription.decode(data.value)
val (taskDesc, serializedTask) = TaskDescription.decode(data.value)
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskDesc)
executor.launchTask(this, taskDesc, serializedTask)
}

case KillTask(taskId, _, interruptThread) =>
Expand Down
14 changes: 9 additions & 5 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,11 @@ private[spark] class Executor(

private[executor] def numRunningTasks: Int = runningTasks.size()

def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
val tr = new TaskRunner(context, taskDescription)
/**
 * Wraps the given task in a TaskRunner, records it in `runningTasks` under its task id,
 * and hands it to the executor's thread pool.
 *
 * @param context backend passed through to the TaskRunner
 * @param taskDescription metadata for the task; its `taskId` is the key in `runningTasks`
 * @param serializedTask the still-serialized task bytes, passed through to the TaskRunner
 */
def launchTask(
    context: ExecutorBackend,
    taskDescription: TaskDescription,
    serializedTask: ByteBuffer): Unit = {
  val runner = new TaskRunner(context, taskDescription, serializedTask)
  // Register before scheduling so the runner is already tracked when it starts executing.
  runningTasks.put(taskDescription.taskId, runner)
  threadPool.execute(runner)
}
Expand Down Expand Up @@ -208,7 +211,8 @@ private[spark] class Executor(

class TaskRunner(
execBackend: ExecutorBackend,
private val taskDescription: TaskDescription)
private val taskDescription: TaskDescription,
private val serializedTask: ByteBuffer)
extends Runnable {

val taskId = taskDescription.taskId
Expand Down Expand Up @@ -287,8 +291,8 @@ private[spark] class Executor(
Executor.taskDeserializationProps.set(taskDescription.properties)

updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
task = ser.deserialize[Task[Any]](
taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
task = Utils.deserialize(serializedTask,
Thread.currentThread.getContextClassLoader).asInstanceOf[Task[Any]]
task.localProperties = taskDescription.properties
task.setTaskMemoryManager(taskMemoryManager)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ import java.util.Properties

import scala.collection.JavaConverters._
import scala.collection.mutable.{HashMap, Map}
import scala.util.control.NonFatal

import org.apache.spark.TaskNotSerializableException
import org.apache.spark.internal.Logging
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}

/**
Expand Down Expand Up @@ -52,8 +55,26 @@ private[spark] class TaskDescription(
val addedFiles: Map[String, Long],
val addedJars: Map[String, Long],
val properties: Properties,
val serializedTask: ByteBuffer) {

// Task object corresponding to the TaskDescription. This is only defined on the driver; on
// the executor, the Task object is handled separately from the TaskDescription so that it can
// be deserialized after the TaskDescription is deserialized.
@transient private val task: Task[_] = null) extends Logging {

/**
 * Serializes the task held by this TaskDescription and returns the resulting bytes wrapped
 * in a ByteBuffer.
 *
 * This method should only be used on the driver (to serialize a task to send to an executor).
 *
 * Any non-fatal failure during serialization is logged and rethrown as a
 * TaskNotSerializableException that wraps the original cause.
 */
def serializeTask(): ByteBuffer =
  try {
    val taskBytes = Utils.serialize(task)
    ByteBuffer.wrap(taskBytes)
  } catch {
    case NonFatal(cause) =>
      logError(s"Failed to serialize task $taskId.", cause)
      throw new TaskNotSerializableException(cause)
  }
override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index)
}

Expand All @@ -66,6 +87,7 @@ private[spark] object TaskDescription {
}
}

@throws[TaskNotSerializableException]
def encode(taskDescription: TaskDescription): ByteBuffer = {
val bytesOut = new ByteBufferOutputStream(4096)
val dataOut = new DataOutputStream(bytesOut)
Expand All @@ -89,8 +111,8 @@ private[spark] object TaskDescription {
dataOut.writeUTF(value)
}

// Write the task. The task is already serialized, so write it directly to the byte buffer.
Utils.writeByteBuffer(taskDescription.serializedTask, bytesOut)
// Serialize and write the task.
Utils.writeByteBuffer(taskDescription.serializeTask(), bytesOut)

dataOut.close()
bytesOut.close()
Expand All @@ -106,7 +128,7 @@ private[spark] object TaskDescription {
map
}

def decode(byteBuffer: ByteBuffer): TaskDescription = {
def decode(byteBuffer: ByteBuffer): (TaskDescription, ByteBuffer) = {
val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer))
val taskId = dataIn.readLong()
val attemptNumber = dataIn.readInt()
Expand All @@ -130,7 +152,8 @@ private[spark] object TaskDescription {
// Create a sub-buffer for the serialized task into its own buffer (to be deserialized later).
val serializedTask = byteBuffer.slice()

new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars,
properties, serializedTask)
(new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars,
properties),
serializedTask)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -277,23 +277,15 @@ private[spark] class TaskSchedulerImpl private[scheduler](
val execId = shuffledOffers(i).executorId
val host = shuffledOffers(i).host
if (availableCpus(i) >= CPUS_PER_TASK) {
try {
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorIdToRunningTaskIds(execId).add(tid)
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
launchedTask = true
}
} catch {
case e: TaskNotSerializableException =>
logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
// Do not offer resources for this task, but don't throw an error to allow other
// task sets to be submitted.
return launchedTask
for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
tasks(i) += task
val tid = task.taskId
taskIdToTaskSetManager(tid) = taskSet
taskIdToExecutorId(tid) = execId
executorIdToRunningTaskIds(execId).add(tid)
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
launchedTask = true
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ private[spark] class TaskSetManager(

// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()

val tasks = taskSet.tasks
val numTasks = tasks.length
Expand Down Expand Up @@ -413,7 +412,6 @@ private[spark] class TaskSetManager(
* @param host the host Id of the offered resource
* @param maxLocality the maximum locality we want to schedule the tasks at
*/
@throws[TaskNotSerializableException]
def resourceOffer(
execId: String,
host: String,
Expand Down Expand Up @@ -454,33 +452,15 @@ private[spark] class TaskSetManager(
currentLocalityIndex = getLocalityIndex(taskLocality)
lastLaunchTime = curTime
}
// Serialize and return the task
val serializedTask: ByteBuffer = try {
ser.serialize(task)
} catch {
// If the task cannot be serialized, then there's no point to re-attempt the task,
// as it will always fail. So just abort the whole task-set.
case NonFatal(e) =>
val msg = s"Failed to serialize task $taskId, not attempting to retry it."
logError(msg, e)
abort(s"$msg Exception during serialization: $e")
throw new TaskNotSerializableException(e)
}
if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
!emittedTaskSizeWarning) {
emittedTaskSizeWarning = true
logWarning(s"Stage ${task.stageId} contains a task of very large size " +
s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " +
s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.")
}

addRunningTask(taskId)

// We used to log the time it takes to serialize the task, but task size is already
// a good proxy to task serialization time.
// val timeTaken = clock.getTime() - startTime
val taskName = s"task ${info.id} in stage ${taskSet.id}"
logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " +
s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)")
s"partition ${task.partitionId}, $taskLocality)")

sched.dagScheduler.taskStarted(task, info)
new TaskDescription(
Expand All @@ -492,7 +472,7 @@ private[spark] class TaskSetManager(
sched.sc.addedFiles,
sched.sc.addedJars,
task.localProperties,
serializedTask)
task)
}
} else {
None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@

package org.apache.spark.scheduler.cluster

import java.nio.ByteBuffer
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.concurrent.Future
import scala.concurrent.duration.Duration
import scala.util.control.NonFatal

import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.prepareSerializedTask
import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils}

/**
Expand Down Expand Up @@ -256,28 +258,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp

// Launch tasks returned by a set of resource offers
private def launchTasks(tasks: Seq[Seq[TaskDescription]]) {
val abortedTaskSets = new HashSet[TaskSetManager]()
for (task <- tasks.flatten) {
val serializedTask = TaskDescription.encode(task)
if (serializedTask.limit >= maxRpcMessageSize) {
scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
try {
var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
"spark.rpc.message.maxSize (%d bytes). Consider increasing " +
"spark.rpc.message.maxSize or using broadcast variables for large values."
msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)
taskSetMgr.abort(msg)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
}
}
else {
val serializedTask = prepareSerializedTask(scheduler, task,
abortedTaskSets, maxRpcMessageSize)
if (serializedTask != null) {
val executorData = executorDataMap(task.executorId)
executorData.freeCores -= scheduler.CPUS_PER_TASK

logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
s"${executorData.executorHost}.")

logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " +
s"${executorData.executorHost}, serializedTask: ${serializedTask.limit} bytes.")
executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
}
}
Expand Down Expand Up @@ -621,6 +610,77 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
}

private[spark] object CoarseGrainedSchedulerBackend {
private[spark] object CoarseGrainedSchedulerBackend extends Logging {
val ENDPOINT_NAME = "CoarseGrainedScheduler"

// Abort the TaskSetManager that owns the given task. Any exception thrown by abort() is logged, not rethrown.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remind me why we need this — why do we ignore exceptions here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why the old code ignored the exception, just copied it here. :)

/**
 * Aborts the TaskSetManager registered for `taskId`, if there is one, while holding the
 * scheduler's monitor.
 *
 * An Exception thrown by `abort` is logged and swallowed rather than propagated to the caller.
 *
 * @param scheduler scheduler whose `taskIdToTaskSetManager` map is consulted
 * @param taskId id of the task whose task set should be aborted
 * @param msg abort message; passed by name, so it is not evaluated when no TaskSetManager
 *            is registered for the task
 * @param exception optional cause forwarded to `abort`
 */
private def abortTaskSetManager(
    scheduler: TaskSchedulerImpl,
    taskId: Long,
    msg: => String,
    exception: Option[Throwable] = None): Unit = {
  scheduler.synchronized {
    for (manager <- scheduler.taskIdToTaskSetManager.get(taskId)) {
      try {
        manager.abort(msg, exception)
      } catch {
        // Best effort: a failure inside the abort callback must not escape to the caller.
        case e: Exception => logError("Exception while aborting taskset", e)
      }
    }
  }
}

/**
 * Looks up the TaskSetManager registered for `taskId` while holding the scheduler's monitor.
 *
 * @return the manager, or None when the scheduler has no entry for this task id
 */
private def getTaskSetManager(
    scheduler: TaskSchedulerImpl,
    taskId: Long): Option[TaskSetManager] = {
  scheduler.synchronized {
    scheduler.taskIdToTaskSetManager.get(taskId)
  }
}

/**
 * Encodes `task` for sending to an executor.
 *
 * Returns null (callers check for this) in three cases:
 *  - `abortedTaskSets` is non-empty and the task's TaskSetManager is a zombie or is already
 *    in `abortedTaskSets`, in which case the task is not encoded at all;
 *  - encoding (or the lookup preceding it) throws a non-fatal exception, in which case the
 *    task set is aborted and its manager is added to `abortedTaskSets`;
 *  - the encoded size is at least `maxRpcMessageSize`, in which case the task set is aborted
 *    and its manager is added to `abortedTaskSets`.
 *
 * Otherwise the large-task warning check is run and the encoded buffer is returned.
 *
 * @param scheduler scheduler used to look up and abort the task's TaskSetManager
 * @param task description of the task to encode
 * @param abortedTaskSets managers aborted so far; mutated when this call aborts one
 * @param maxRpcMessageSize size limit in bytes for the encoded task
 */
private[scheduler] def prepareSerializedTask(
    scheduler: TaskSchedulerImpl,
    task: TaskDescription,
    abortedTaskSets: HashSet[TaskSetManager],
    maxRpcMessageSize: Long): ByteBuffer = {
  val encoded: ByteBuffer =
    try {
      // The manager lookup is only performed once at least one task set has been aborted.
      val skipTask = abortedTaskSets.nonEmpty &&
        getTaskSetManager(scheduler, task.taskId).exists { manager =>
          manager.isZombie || abortedTaskSets.contains(manager)
        }
      if (skipTask) null else TaskDescription.encode(task)
    } catch {
      case NonFatal(e) =>
        abortTaskSetManager(scheduler, task.taskId,
          s"Failed to serialize task ${task.taskId}, not attempting to retry it.", Some(e))
        getTaskSetManager(scheduler, task.taskId).foreach(abortedTaskSets.add)
        null
    }

  if (encoded == null) {
    null
  } else if (encoded.limit >= maxRpcMessageSize) {
    val msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
      "spark.rpc.message.maxSize (%d bytes). Consider increasing " +
      "spark.rpc.message.maxSize or using broadcast variables for large values."
    abortTaskSetManager(scheduler, task.taskId,
      msg.format(task.taskId, task.index, encoded.limit, maxRpcMessageSize))
    getTaskSetManager(scheduler, task.taskId).foreach(manager => abortedTaskSets.add(manager))
    null
  } else {
    maybeEmitTaskSizeWarning(scheduler, encoded, task.taskId)
    encoded
  }
}

/**
 * Logs a warning when the serialized task is larger than
 * `TaskSetManager.TASK_SIZE_TO_WARN_KB` kilobytes.
 *
 * The warning is only logged if the task's TaskSetManager is found and its
 * `emittedTaskSizeWarning` flag is not yet set; the flag is set before logging.
 *
 * @param scheduler scheduler used to look up the task's TaskSetManager
 * @param serializedTask encoded task whose `limit` is taken as its size in bytes
 * @param taskId id of the task the buffer belongs to
 */
private def maybeEmitTaskSizeWarning(
    scheduler: TaskSchedulerImpl,
    serializedTask: ByteBuffer,
    taskId: Long): Unit = {
  val taskSizeBytes = serializedTask.limit
  if (taskSizeBytes > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) {
    for (manager <- getTaskSetManager(scheduler, taskId) if !manager.emittedTaskSizeWarning) {
      manager.emittedTaskSizeWarning = true
      val stageId = manager.taskSet.stageId
      logWarning(s"Stage $stageId contains a task of very large size " +
        s"(${taskSizeBytes / 1024} KB). The maximum recommended task size is " +
        s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.")
    }
  }
}
}
Loading