diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index b376ecd301eab..8f4909ea2939e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -92,9 +92,9 @@ private[spark] class CoarseGrainedExecutorBackend( if (executor == null) { exitExecutor(1, "Received LaunchTask command but executor was null") } else { - val taskDesc = TaskDescription.decode(data.value) + val (taskDesc, serializedTask) = TaskDescription.decode(data.value) logInfo("Got assigned task " + taskDesc.taskId) - executor.launchTask(this, taskDesc) + executor.launchTask(this, taskDesc, serializedTask) } case KillTask(taskId, _, interruptThread) => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 975a6e4eeb33a..54449b1bcd2c4 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -150,8 +150,11 @@ private[spark] class Executor( private[executor] def numRunningTasks: Int = runningTasks.size() - def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = { - val tr = new TaskRunner(context, taskDescription) + def launchTask( + context: ExecutorBackend, + taskDescription: TaskDescription, + serializedTask: ByteBuffer): Unit = { + val tr = new TaskRunner(context, taskDescription, serializedTask) runningTasks.put(taskDescription.taskId, tr) threadPool.execute(tr) } @@ -208,7 +211,8 @@ private[spark] class Executor( class TaskRunner( execBackend: ExecutorBackend, - private val taskDescription: TaskDescription) + private val taskDescription: TaskDescription, + private val serializedTask: ByteBuffer) extends Runnable { val taskId = taskDescription.taskId @@ -287,8 +291,8 @@ private[spark] class Executor( Executor.taskDeserializationProps.set(taskDescription.properties) updateDependencies(taskDescription.addedFiles, taskDescription.addedJars) - task = ser.deserialize[Task[Any]]( - taskDescription.serializedTask, Thread.currentThread.getContextClassLoader) + task = Utils.deserialize(serializedTask, + Thread.currentThread.getContextClassLoader).asInstanceOf[Task[Any]] task.localProperties = taskDescription.properties task.setTaskMemoryManager(taskMemoryManager) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 78aa5c40010cc..c3c5ec9e0e1ce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -23,7 +23,10 @@ import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, Map} +import scala.util.control.NonFatal +import org.apache.spark.TaskNotSerializableException +import org.apache.spark.internal.Logging import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** @@ -52,8 +55,26 @@ private[spark] class TaskDescription( val addedFiles: Map[String, Long], val addedJars: Map[String, Long], val properties: Properties, - val serializedTask: ByteBuffer) { - + // Task object corresponding to the TaskDescription. 
This is only defined on the driver; on + // the executor, the Task object is handled separately from the TaskDescription so that it can + // be deserialized after the TaskDescription is deserialized. + @transient private val task: Task[_] = null) extends Logging { + + /** + * Serializes the task for this TaskDescription and returns the serialized task. + * + * This method should only be used on the driver (to serialize a task to send to a executor). + */ + def serializeTask(): ByteBuffer = { + try { + ByteBuffer.wrap(Utils.serialize(task)) + } catch { + case NonFatal(e) => + val msg = s"Failed to serialize task $taskId." + logError(msg, e) + throw new TaskNotSerializableException(e) + } + } override def toString: String = "TaskDescription(TID=%d, index=%d)".format(taskId, index) } @@ -66,6 +87,7 @@ private[spark] object TaskDescription { } } + @throws[TaskNotSerializableException] def encode(taskDescription: TaskDescription): ByteBuffer = { val bytesOut = new ByteBufferOutputStream(4096) val dataOut = new DataOutputStream(bytesOut) @@ -89,8 +111,8 @@ private[spark] object TaskDescription { dataOut.writeUTF(value) } - // Write the task. The task is already serialized, so write it directly to the byte buffer. - Utils.writeByteBuffer(taskDescription.serializedTask, bytesOut) + // Serialize and write the task. + Utils.writeByteBuffer(taskDescription.serializeTask(), bytesOut) dataOut.close() bytesOut.close() @@ -106,7 +128,7 @@ private[spark] object TaskDescription { map } - def decode(byteBuffer: ByteBuffer): TaskDescription = { + def decode(byteBuffer: ByteBuffer): (TaskDescription, ByteBuffer) = { val dataIn = new DataInputStream(new ByteBufferInputStream(byteBuffer)) val taskId = dataIn.readLong() val attemptNumber = dataIn.readInt() @@ -130,7 +152,8 @@ private[spark] object TaskDescription { // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). val serializedTask = byteBuffer.slice() - new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, - properties, serializedTask) + (new TaskDescription(taskId, attemptNumber, executorId, name, index, taskFiles, taskJars, + properties), + serializedTask) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bfbcfa1aa386f..9c5fcf14941ea 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -277,23 +277,15 @@ private[spark] class TaskSchedulerImpl private[scheduler]( val execId = shuffledOffers(i).executorId val host = shuffledOffers(i).host if (availableCpus(i) >= CPUS_PER_TASK) { - try { - for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { - tasks(i) += task - val tid = task.taskId - taskIdToTaskSetManager(tid) = taskSet - taskIdToExecutorId(tid) = execId - executorIdToRunningTaskIds(execId).add(tid) - availableCpus(i) -= CPUS_PER_TASK - assert(availableCpus(i) >= 0) - launchedTask = true - } - } catch { - case e: TaskNotSerializableException => - logError(s"Resource offer failed, task set ${taskSet.name} was not serializable") - // Do not offer resources for this task, but don't throw an error to allow other - // task sets to be submitted. 
- return launchedTask + for (task <- taskSet.resourceOffer(execId, host, maxLocality)) { + tasks(i) += task + val tid = task.taskId + taskIdToTaskSetManager(tid) = taskSet + taskIdToExecutorId(tid) = execId + executorIdToRunningTaskIds(execId).add(tid) + availableCpus(i) -= CPUS_PER_TASK + assert(availableCpus(i) >= 0) + launchedTask = true } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3b25513bea057..fd84ce58d60d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -65,7 +65,6 @@ private[spark] class TaskSetManager( // Serializer for closures and tasks. val env = SparkEnv.get - val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks val numTasks = tasks.length @@ -413,7 +412,6 @@ private[spark] class TaskSetManager( * @param host the host Id of the offered resource * @param maxLocality the maximum locality we want to schedule the tasks at */ - @throws[TaskNotSerializableException] def resourceOffer( execId: String, host: String, @@ -454,25 +452,7 @@ private[spark] class TaskSetManager( currentLocalityIndex = getLocalityIndex(taskLocality) lastLaunchTime = curTime } - // Serialize and return the task - val serializedTask: ByteBuffer = try { - ser.serialize(task) - } catch { - // If the task cannot be serialized, then there's no point to re-attempt the task, - // as it will always fail. So just abort the whole task-set. - case NonFatal(e) => - val msg = s"Failed to serialize task $taskId, not attempting to retry it." - logError(msg, e) - abort(s"$msg Exception during serialization: $e") - throw new TaskNotSerializableException(e) - } - if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 && - !emittedTaskSizeWarning) { - emittedTaskSizeWarning = true - logWarning(s"Stage ${task.stageId} contains a task of very large size " + - s"(${serializedTask.limit / 1024} KB). 
The maximum recommended task size is " + - s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") - } + addRunningTask(taskId) // We used to log the time it takes to serialize the task, but task size is already @@ -480,7 +460,7 @@ private[spark] class TaskSetManager( // val timeTaken = clock.getTime() - startTime val taskName = s"task ${info.id} in stage ${taskSet.id}" logInfo(s"Starting $taskName (TID $taskId, $host, executor ${info.executorId}, " + - s"partition ${task.partitionId}, $taskLocality, ${serializedTask.limit} bytes)") + s"partition ${task.partitionId}, $taskLocality)") sched.dagScheduler.taskStarted(task, info) new TaskDescription( @@ -492,7 +472,7 @@ private[spark] class TaskSetManager( sched.sc.addedFiles, sched.sc.addedJars, task.localProperties, - serializedTask) + task) } } else { None diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 94abe30bb12f2..e16575d35f636 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -17,13 +17,14 @@ package org.apache.spark.scheduler.cluster +import java.nio.ByteBuffer import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future -import scala.concurrent.duration.Duration +import scala.util.control.NonFatal import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} import org.apache.spark.internal.Logging @@ -31,6 +32,7 @@ import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.prepareSerializedTask import org.apache.spark.util.{RpcUtils, SerializableBuffer, ThreadUtils, Utils} /** @@ -256,28 +258,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Launch tasks returned by a set of resource offers private def launchTasks(tasks: Seq[Seq[TaskDescription]]) { + val abortedTaskSets = new HashSet[TaskSetManager]() for (task <- tasks.flatten) { - val serializedTask = TaskDescription.encode(task) - if (serializedTask.limit >= maxRpcMessageSize) { - scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr => - try { - var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + - "spark.rpc.message.maxSize (%d bytes). Consider increasing " + - "spark.rpc.message.maxSize or using broadcast variables for large values." 
- msg = msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize) - taskSetMgr.abort(msg) - } catch { - case e: Exception => logError("Exception in error callback", e) - } - } - } - else { + val serializedTask = prepareSerializedTask(scheduler, task, + abortedTaskSets, maxRpcMessageSize) + if (serializedTask != null) { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - - logDebug(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + - s"${executorData.executorHost}.") - + logInfo(s"Launching task ${task.taskId} on executor id: ${task.executorId} hostname: " + + s"${executorData.executorHost}, serializedTask: ${serializedTask.limit} bytes.") executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask))) } } @@ -621,6 +610,77 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } } -private[spark] object CoarseGrainedSchedulerBackend { +private[spark] object CoarseGrainedSchedulerBackend extends Logging { val ENDPOINT_NAME = "CoarseGrainedScheduler" + + // abort TaskSetManager without exception + private def abortTaskSetManager( + scheduler: TaskSchedulerImpl, + taskId: Long, + msg: => String, + exception: Option[Throwable] = None): Unit = scheduler.synchronized { + scheduler.taskIdToTaskSetManager.get(taskId).foreach { taskSetMgr => + try { + taskSetMgr.abort(msg, exception) + } catch { + case e: Exception => logError("Exception while aborting taskset", e) + } + } + } + + private def getTaskSetManager( + scheduler: TaskSchedulerImpl, + taskId: Long): Option[TaskSetManager] = scheduler.synchronized { + scheduler.taskIdToTaskSetManager.get(taskId) + } + + private[scheduler] def prepareSerializedTask( + scheduler: TaskSchedulerImpl, + task: TaskDescription, + abortedTaskSets: HashSet[TaskSetManager], + maxRpcMessageSize: Long): ByteBuffer = { + var serializedTask: ByteBuffer = null + + try { + if (abortedTaskSets.isEmpty || + !getTaskSetManager(scheduler, task.taskId). + exists(t => t.isZombie || abortedTaskSets.contains(t))) { + serializedTask = TaskDescription.encode(task) + } + } catch { + case NonFatal(e) => + abortTaskSetManager(scheduler, task.taskId, + s"Failed to serialize task ${task.taskId}, not attempting to retry it.", Some(e)) + getTaskSetManager(scheduler, task.taskId).foreach(abortedTaskSets.add) + } + + if (serializedTask != null && serializedTask.limit >= maxRpcMessageSize) { + val msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " + + "spark.rpc.message.maxSize (%d bytes). Consider increasing " + + "spark.rpc.message.maxSize or using broadcast variables for large values." + abortTaskSetManager(scheduler, task.taskId, + msg.format(task.taskId, task.index, serializedTask.limit, maxRpcMessageSize)) + getTaskSetManager(scheduler, task.taskId).foreach(t => abortedTaskSets.add(t)) + serializedTask = null + } else if (serializedTask != null) { + maybeEmitTaskSizeWarning(scheduler, serializedTask, task.taskId) + } + serializedTask + } + + private def maybeEmitTaskSizeWarning( + scheduler: TaskSchedulerImpl, + serializedTask: ByteBuffer, + taskId: Long): Unit = { + if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) { + getTaskSetManager(scheduler, taskId).filterNot(_.emittedTaskSizeWarning). 
+ foreach { taskSetMgr => + taskSetMgr.emittedTaskSizeWarning = true + val stageId = taskSetMgr.taskSet.stageId + logWarning(s"Stage $stageId contains a task of very large size " + + s"(${serializedTask.limit / 1024} KB). The maximum recommended task size is " + + s"${TaskSetManager.TASK_SIZE_TO_WARN_KB} KB.") + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala index 625f998cd4608..8f44b7def2db6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalSchedulerBackend.scala @@ -21,6 +21,8 @@ import java.io.File import java.net.URL import java.nio.ByteBuffer +import scala.collection.mutable.HashSet + import org.apache.spark.{SparkConf, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} @@ -28,7 +30,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.prepareSerializedTask import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.util.RpcUtils private case class ReviveOffers() @@ -59,6 +63,8 @@ private[spark] class LocalEndpoint( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, userClassPath, isLocal = true) + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(SparkEnv.get.conf) + override def receive: PartialFunction[Any, Unit] = { case ReviveOffers => reviveOffers() @@ -82,9 +88,17 @@ private[spark] class LocalEndpoint( def reviveOffers() { val offers = IndexedSeq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores)) + val abortTaskSet = new HashSet[TaskSetManager]() for (task <- scheduler.resourceOffers(offers).flatten) { - freeCores -= scheduler.CPUS_PER_TASK - executor.launchTask(executorBackend, task) + // make sure the task is serializable, + // so that it can be launched in a distributed environment. 
+ val buffer = prepareSerializedTask(scheduler, task, + abortTaskSet, maxRpcMessageSize) + if (buffer != null) { + freeCores -= scheduler.CPUS_PER_TASK + val (taskDesc, serializedTask) = TaskDescription.decode(buffer) + executor.launchTask(executorBackend, taskDesc, serializedTask) + } } } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 10e5233679562..b0403fe9a0097 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -150,7 +150,12 @@ private[spark] object Utils extends Logging { /** Deserialize an object using Java serialization and the given ClassLoader */ def deserialize[T](bytes: Array[Byte], loader: ClassLoader): T = { - val bis = new ByteArrayInputStream(bytes) + deserialize(ByteBuffer.wrap(bytes), loader) + } + + /** Deserialize an object using Java serialization and the given ClassLoader */ + def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = { + val bis = new ByteBufferInputStream(bytes) val ois = new ObjectInputStream(bis) { override def resolveClass(desc: ObjectStreamClass): Class[_] = { // scalastyle:off classforname diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index b743ff5376c49..26399f6bd9973 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -49,7 +49,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0)) - val taskDescription = createFakeTaskDescription(serializedTask) + val taskDescription = createFakeTaskDescription() // we use latches to force the program to run in this order: // +-----------------------------+---------------------------------------+ @@ -99,7 +99,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug try { executor = new Executor("id", "localhost", env, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread - executor.launchTask(mockExecutorBackend, taskDescription) + executor.launchTask(mockExecutorBackend, taskDescription, serializedTask) if (!executorSuiteHelper.latch1.await(5, TimeUnit.SECONDS)) { fail("executor did not send first status update in time") @@ -128,9 +128,9 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug val serializer = new JavaSerializer(conf) val env = createMockEnv(conf, serializer) val serializedTask = serializer.newInstance().serialize(new NonDeserializableTask) - val taskDescription = createFakeTaskDescription(serializedTask) + val taskDescription = createFakeTaskDescription() - val failReason = runTaskAndGetFailReason(taskDescription) + val failReason = runTaskAndGetFailReason(taskDescription, serializedTask) failReason match { case ef: ExceptionFailure => assert(ef.exception.isDefined) @@ -155,7 +155,7 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug mockEnv } - private def createFakeTaskDescription(serializedTask: ByteBuffer): TaskDescription = { + private def createFakeTaskDescription(): TaskDescription = { new TaskDescription( taskId = 0, attemptNumber = 0, @@ -164,17 +164,18 @@ class ExecutorSuite extends SparkFunSuite with 
LocalSparkContext with MockitoSug index = 0, addedFiles = Map[String, Long](), addedJars = Map[String, Long](), - properties = new Properties, - serializedTask) + properties = new Properties) } - private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + private def runTaskAndGetFailReason( + taskDescription: TaskDescription, + serializedTask: ByteBuffer): TaskFailedReason = { val mockBackend = mock[ExecutorBackend] var executor: Executor = null try { executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) // the task will be launched in a dedicated worker thread - executor.launchTask(mockBackend, taskDescription) + executor.launchTask(mockBackend, taskDescription, serializedTask) eventually(timeout(5 seconds), interval(10 milliseconds)) { assert(executor.numRunningTasks === 0) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index 04cccc67e328e..1b5e8ca68b9be 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -17,11 +17,46 @@ package org.apache.spark.scheduler -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import java.io.{IOException, NotSerializableException, ObjectInputStream, ObjectOutputStream} +import java.util.Properties + +import scala.collection.mutable + +import org.mockito.Matchers._ +import org.mockito.Mockito._ + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RetrieveSparkAppConfig} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.{RpcUtils, SerializableBuffer} -class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { +private[spark] class NotSerializablePartitionRDD( + sc: SparkContext, + numPartitions: Int) extends RDD[(Int, Int)](sc, Nil) with Serializable { + + override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = + throw new RuntimeException("should not be reached") + + override def getPartitions: Array[Partition] = (0 until numPartitions).map(i => new Partition { + override def index: Int = i + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + throw new NotSerializableException() + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = {} + }).toArray + + override def getPreferredLocations(partition: Partition): Seq[String] = Nil + + override def toString: String = "DAGSchedulerSuiteRDD " + id +} +class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { test("serialized task larger than max RPC message size") { val conf = new SparkConf conf.set("spark.rpc.message.maxSize", "1") @@ -38,4 +73,75 @@ class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkCo assert(smaller.size === 4) } + test("Scheduler aborts stages that have unserializable partition") { + val conf = new SparkConf() + .setMaster("local-cluster[2, 1, 1024]") + .setAppName("test") + .set("spark.dynamicAllocation.testing", "true") + sc = new SparkContext(conf) + val myRDD = new 
NotSerializablePartitionRDD(sc, 2) + val e = intercept[SparkException] { + myRDD.count() + } + assert(e.getMessage.contains("Failed to serialize task")) + assertResult(10) { + sc.parallelize(1 to 10).count() + } + } + + test("serialization task errors do not affect each other") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + sc = new SparkContext(conf) + val rpcEnv = sc.env.rpcEnv + + val endpointRef = mock(classOf[RpcEndpointRef]) + val mockAddress = mock(classOf[RpcAddress]) + when(endpointRef.address).thenReturn(mockAddress) + val message = RegisterExecutor("1", endpointRef, "localhost", 4, Map.empty) + + val taskScheduler = mock(classOf[TaskSchedulerImpl]) + when(taskScheduler.CPUS_PER_TASK).thenReturn(1) + when(taskScheduler.sc).thenReturn(sc) + when(taskScheduler.mapOutputTracker).thenReturn(sc.env.mapOutputTracker) + val taskIdToTaskSetManager = new mutable.HashMap[Long, TaskSetManager] + when(taskScheduler.taskIdToTaskSetManager).thenReturn(taskIdToTaskSetManager) + val dagScheduler = mock(classOf[DAGScheduler]) + when(taskScheduler.dagScheduler).thenReturn(dagScheduler) + val taskSet1 = FakeTask.createTaskSet(1) + val taskSet2 = FakeTask.createTaskSet(1) + taskSet1.tasks(0) = new NotSerializableFakeTask(1, 0) + + def createTaskDescription(taskId: Long, task: Task[_]): TaskDescription = { + new TaskDescription( + taskId = 1L, + attemptNumber = 0, + executorId = "1", + name = "localhost", + index = 0, + addedFiles = mutable.Map.empty[String, Long], + addedJars = mutable.Map.empty[String, Long], + properties = new Properties(), + task = task) + } + + when(taskScheduler.resourceOffers(any[IndexedSeq[WorkerOffer]])).thenReturn(Seq(Seq( + createTaskDescription(1, taskSet1.tasks.head), + createTaskDescription(2, taskSet2.tasks.head)))) + taskIdToTaskSetManager(1L) = new TaskSetManager(taskScheduler, taskSet1, 1) + taskIdToTaskSetManager(2L) = new TaskSetManager(taskScheduler, taskSet2, 1) + + val backend = new CoarseGrainedSchedulerBackend(taskScheduler, rpcEnv) + backend.start() + backend.driverEndpoint.askSync[Boolean](message) + backend.reviveOffers() + // Make sure that the ReviveOffers message has been processed. + // backend.driverEndpoint is thread safe. 
However, If you modify it, + // please modify the code here + backend.driverEndpoint.askSync[Any](RetrieveSparkAppConfig) + assert(taskIdToTaskSetManager(1L).isZombie === true) + assert(taskIdToTaskSetManager(2L).isZombie === false) + backend.stop() + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8eaf9dfcf49b1..ae98ee7cdbc12 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{IOException, NotSerializableException, ObjectInputStream, ObjectOutputStream} import java.util.Properties import java.util.concurrent.atomic.AtomicBoolean @@ -517,6 +518,32 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou assertDataStructuresEmpty() } + test("unserializable partitioner") { + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new Partitioner { + override def numPartitions = 1 + + override def getPartition(key: Any) = 1 + + @throws(classOf[IOException]) + private def writeObject(out: ObjectOutputStream): Unit = { + throw new NotSerializableException() + } + + @throws(classOf[IOException]) + private def readObject(in: ObjectInputStream): Unit = {} + }) + + // Submit a map stage by itself + submitMapStage(shuffleDep) + assert(failure.getMessage.startsWith( + "Job aborted due to stage failure: Task not serializable")) + sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) + assertDataStructuresEmpty() + } + test("trivial job failure") { submit(new MyRDD(sc, 1, Nil), Array(0)) failed(taskSets(0), "some failure") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 9f1fe0515732e..d8007dde18c65 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -37,7 +37,6 @@ class TaskDescriptionSuite extends SparkFunSuite { originalProperties.put("property1", "18") originalProperties.put("property2", "test value") - // Create a dummy byte buffer for the task. val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) val originalTaskDescription = new TaskDescription( @@ -49,11 +48,15 @@ class TaskDescriptionSuite extends SparkFunSuite { originalFiles, originalJars, originalProperties, - taskBuffer - ) + // Pass in null for the task, because we override the serialize method below anyway (which + // is the only time task is used). + task = null + ) { + override def serializeTask() = taskBuffer + } val serializedTaskDescription = TaskDescription.encode(originalTaskDescription) - val decodedTaskDescription = TaskDescription.decode(serializedTaskDescription) + val (decodedTaskDescription, serializedTask) = TaskDescription.decode(serializedTaskDescription) // Make sure that all of the fields in the decoded task description match the original. 
assert(decodedTaskDescription.taskId === originalTaskDescription.taskId) @@ -64,6 +67,6 @@ class TaskDescriptionSuite extends SparkFunSuite { assert(decodedTaskDescription.addedFiles.equals(originalFiles)) assert(decodedTaskDescription.addedJars.equals(originalJars)) assert(decodedTaskDescription.properties.equals(originalTaskDescription.properties)) - assert(decodedTaskDescription.serializedTask.equals(taskBuffer)) + assert(serializedTask.equals(taskBuffer)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9ae0bcd9b8860..5edd1992002c1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -178,29 +178,6 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(!failedTaskSet) } - test("Scheduler does not crash when tasks are not serializable") { - val taskCpus = 2 - val taskScheduler = setupScheduler("spark.task.cpus" -> taskCpus.toString) - val numFreeCores = 1 - val taskSet = new TaskSet( - Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) - val multiCoreWorkerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", taskCpus), - new WorkerOffer("executor1", "host1", numFreeCores)) - taskScheduler.submitTasks(taskSet) - var taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten - assert(0 === taskDescriptions.length) - assert(failedTaskSet) - assert(failedTaskSetReason.contains("Failed to serialize task")) - - // Now check that we can still submit tasks - // Even if one of the task sets has not-serializable tasks, the other task set should - // still be processed without error - taskScheduler.submitTasks(FakeTask.createTaskSet(1)) - taskScheduler.submitTasks(taskSet) - taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten - assert(taskDescriptions.map(_.executorId) === Seq("executor0")) - } - test("refuse to schedule concurrent attempts for the same stage (SPARK-8103)") { val taskScheduler = setupScheduler() val attempt1 = FakeTask.createTaskSet(1, 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index d03a0c990a02b..41485277aea70 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -611,34 +611,6 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(!manager.emittedTaskSizeWarning) } - test("emit warning when serialized task is large") { - sc = new SparkContext("local", "test") - sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - - val taskSet = new TaskSet(Array(new LargeTask(0)), 0, 0, 0, null) - val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) - - assert(!manager.emittedTaskSizeWarning) - - assert(manager.resourceOffer("exec1", "host1", ANY).get.index === 0) - - assert(manager.emittedTaskSizeWarning) - } - - test("Not serializable exception thrown if the task cannot be serialized") { - sc = new SparkContext("local", "test") - sched = new FakeTaskScheduler(sc, ("exec1", "host1")) - - val taskSet = new TaskSet( - Array(new NotSerializableFakeTask(1, 0), new NotSerializableFakeTask(0, 1)), 0, 0, 0, null) - val manager = new 
TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) - - intercept[TaskNotSerializableException] { - manager.resourceOffer("exec1", "host1", ANY) - } - assert(manager.isZombie) - } - test("abort the job if total size of results is too large") { val conf = new SparkConf().set("spark.driver.maxResultSize", "2m") sc = new SparkContext("local", "test", conf) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index b252539782580..29711a17097ca 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -85,12 +85,12 @@ private[spark] class MesosExecutorBackend } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { - val taskDescription = TaskDescription.decode(taskInfo.getData.asReadOnlyByteBuffer()) + val (taskDesc, serializedTask) = TaskDescription.decode(taskInfo.getData.asReadOnlyByteBuffer()) if (executor == null) { logError("Received launchTask but executor was null") } else { SparkHadoopUtil.get.runAsSparkUser { () => - executor.launchTask(this, taskDescription) + executor.launchTask(this, taskDesc, serializedTask) } } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 7e561916a71e2..e0a720b05ebec 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File +import java.nio.ByteBuffer import java.util.{ArrayList => JArrayList, Collections, List => JList} import scala.collection.JavaConverters._ @@ -29,8 +30,9 @@ import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.prepareSerializedTask import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.util.Utils +import org.apache.spark.util.{RpcUtils, Utils} /** * A SchedulerBackend for running fine-grained tasks on Mesos. 
Each Spark task is mapped to a @@ -67,6 +69,8 @@ private[spark] class MesosFineGrainedSchedulerBackend( private val rejectOfferDurationForUnmetConstraints = getRejectOfferDurationForUnmetConstraints(sc) + private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(sc.conf) + @volatile var appId: String = _ override def start() { @@ -291,24 +295,26 @@ private[spark] class MesosFineGrainedSchedulerBackend( val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] val slavesIdsOfAcceptedOffers = HashSet[String]() - + val abortTaskSet = new HashSet[TaskSetManager]() // Call into the TaskSchedulerImpl - val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) - acceptedOffers - .foreach { offer => - offer.foreach { taskDesc => - val slaveId = taskDesc.executorId - slavesIdsOfAcceptedOffers += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - val (mesosTask, remainingResources) = createMesosTask( - taskDesc, - slaveIdToResources(slaveId), - slaveId) - mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) - .add(mesosTask) - slaveIdToResources(slaveId) = remainingResources - } + val acceptedOffers = scheduler.resourceOffers(workerOffers).flatten + for (task <- acceptedOffers) { + val serializedTask = prepareSerializedTask(scheduler, task, + abortTaskSet, maxRpcMessageSize) + if (serializedTask != null) { + val slaveId = task.executorId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(task.taskId) = slaveId + val (mesosTask, remainingResources) = createMesosTask( + task, + serializedTask, + slaveIdToResources(slaveId), + slaveId) + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(mesosTask) + slaveIdToResources(slaveId) = remainingResources } + } // Reply to the offers val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? 
@@ -334,6 +340,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( /** Turn a Spark TaskDescription into a Mesos task and also resources unused by the task */ def createMesosTask( task: TaskDescription, + serializedTask: ByteBuffer, resources: JList[Resource], slaveId: String): (MesosTaskInfo, JList[Resource]) = { val taskId = TaskID.newBuilder().setValue(task.taskId.toString).build() @@ -351,7 +358,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( .setExecutor(executorInfo) .setName(task.name) .addAllResources(cpuResources.asJava) - .setData(ByteString.copyFrom(TaskDescription.encode(task))) + .setData(ByteString.copyFrom(serializedTask)) .build() (taskInfo, finalResources.asJava) } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala index 4ee85b91830a9..03f0cf7ecb3b1 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackendSuite.scala @@ -255,8 +255,7 @@ class MesosFineGrainedSchedulerBackendSuite index = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], - properties = new Properties(), - ByteBuffer.wrap(new Array[Byte](0))) + properties = new Properties()) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(2) @@ -363,8 +362,7 @@ class MesosFineGrainedSchedulerBackendSuite index = 0, addedFiles = mutable.Map.empty[String, Long], addedJars = mutable.Map.empty[String, Long], - properties = new Properties(), - ByteBuffer.wrap(new Array[Byte](0))) + properties = new Properties()) when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc))) when(taskScheduler.CPUS_PER_TASK).thenReturn(1)
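// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the diff above moves task
// serialization out of TaskSetManager.resourceOffer and into
// TaskDescription.encode on the driver, while TaskDescription.decode on the
// executor now returns the metadata together with the still-serialized task
// bytes, so the Task can be deserialized separately after the executor sets up
// its per-task state. The standalone Scala below mirrors that encode/decode
// split; DemoTask, DemoTaskDescription and deserializeTask are hypothetical
// stand-ins for illustration only, not the Spark classes touched by this diff.
// ---------------------------------------------------------------------------
import java.io._
import java.nio.ByteBuffer

import scala.util.control.NonFatal

// Stand-in for a Spark Task body; it must be serializable to survive encode().
case class DemoTask(stageId: Int, partitionId: Int)

// Driver-side view: scheduling metadata plus the (transient) task object,
// mirroring the new `@transient private val task` field in TaskDescription.
class DemoTaskDescription(
    val taskId: Long,
    val name: String,
    @transient private val task: DemoTask) {

  // Serialize only the task body. A failure surfaces here, per task, so the
  // backend can abort just the offending task set instead of the whole offer.
  def serializeTask(): Array[Byte] = {
    try {
      val bos = new ByteArrayOutputStream()
      val oos = new ObjectOutputStream(bos)
      oos.writeObject(task)
      oos.close()
      bos.toByteArray
    } catch {
      case NonFatal(e) =>
        throw new RuntimeException(s"Failed to serialize task $taskId.", e)
    }
  }
}

object DemoTaskDescription {

  // encode(): write the metadata fields first, then the serialized task bytes.
  def encode(desc: DemoTaskDescription): ByteBuffer = {
    val bytesOut = new ByteArrayOutputStream()
    val dataOut = new DataOutputStream(bytesOut)
    dataOut.writeLong(desc.taskId)
    dataOut.writeUTF(desc.name)
    dataOut.write(desc.serializeTask())
    dataOut.close()
    ByteBuffer.wrap(bytesOut.toByteArray)
  }

  // decode(): read the metadata back, and hand the caller whatever bytes
  // remain as the still-serialized task, to be deserialized later.
  def decode(buffer: ByteBuffer): (DemoTaskDescription, ByteBuffer) = {
    val bytes = new Array[Byte](buffer.remaining())
    buffer.get(bytes)
    val dataIn = new DataInputStream(new ByteArrayInputStream(bytes))
    val taskId = dataIn.readLong()
    val name = dataIn.readUTF()
    val taskBytes = new Array[Byte](dataIn.available())
    dataIn.readFully(taskBytes)
    (new DemoTaskDescription(taskId, name, task = null), ByteBuffer.wrap(taskBytes))
  }

  // Executor-side deserialization of the task body, analogous to the new
  // Utils.deserialize(ByteBuffer, ClassLoader) overload used by TaskRunner.
  def deserializeTask(taskBytes: ByteBuffer, loader: ClassLoader): DemoTask = {
    val bis = new ByteArrayInputStream(
      taskBytes.array(), taskBytes.position(), taskBytes.remaining())
    val ois = new ObjectInputStream(bis) {
      override def resolveClass(desc: ObjectStreamClass): Class[_] =
        Class.forName(desc.getName, false, loader)
    }
    ois.readObject().asInstanceOf[DemoTask]
  }

  def main(args: Array[String]): Unit = {
    val desc = new DemoTaskDescription(42L, "task 0.0 in stage 0.0", DemoTask(0, 0))
    val encoded = encode(desc)              // driver side
    val (meta, taskBytes) = decode(encoded) // executor side
    val task = deserializeTask(taskBytes, getClass.getClassLoader)
    println(s"decoded ${meta.name} (TID ${meta.taskId}) -> $task")
  }
}
// Design note (hedged): keeping the serialized task as a trailing sub-buffer is
// what lets the driver-side prepareSerializedTask in this patch abort only the
// failing task set on a serialization error, while the executor can still fail
// deserialization independently inside TaskRunner.run.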