diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index bf3c3a6ceb5ef..26872b197d232 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -20,6 +20,8 @@ package org.apache.spark
 import java.lang.ref.{ReferenceQueue, WeakReference}
 
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.reflect.ClassTag
+import scala.util.DynamicVariable
 
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
@@ -63,6 +65,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     with SynchronizedBuffer[CleanerListener]
 
   private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
+
+  private var broadcastRefCounts = Map(0L -> 0L)
 
   /**
    * Whether the cleaning thread will block on cleanup tasks.
@@ -102,9 +106,25 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   /** Register a Broadcast for cleanup when it is garbage collected. */
   def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
+    incBroadcastRefCount(broadcast.id)
     registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
   }
 
+  private def incBroadcastRefCount[T](bid: Long) {
+    val newRefCount: Long = this.broadcastRefCounts.getOrElse(bid, 0L) + 1
+    this.broadcastRefCounts = this.broadcastRefCounts + Pair(bid, newRefCount)
+  }
+
+  private def decBroadcastRefCount[T](bid: Long) = {
+    this.broadcastRefCounts.get(bid) match {
+      case Some(rc: Long) if rc > 0 => {
+        this.broadcastRefCounts = this.broadcastRefCounts + Pair(bid, rc - 1)
+        rc - 1
+      }
+      case _ => 0
+    }
+  }
+
   /** Register an object for cleanup. */
   private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
     referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
@@ -161,14 +181,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   /** Perform broadcast cleanup. */
   def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
-    try {
-      logDebug("Cleaning broadcast " + broadcastId)
-      broadcastManager.unbroadcast(broadcastId, true, blocking)
-      listeners.foreach(_.broadcastCleaned(broadcastId))
-      logInfo("Cleaned broadcast " + broadcastId)
-    } catch {
-      case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
+    decBroadcastRefCount(broadcastId) match {
+      case x if x > 0 => {}
+      case _ => try {
+        logDebug("Cleaning broadcast " + broadcastId)
+        broadcastManager.unbroadcast(broadcastId, true, blocking)
+        listeners.foreach(_.broadcastCleaned(broadcastId))
+        logInfo("Cleaned broadcast " + broadcastId)
+      } catch {
+        case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
+      }
     }
+  }
 
   private def blockManagerMaster = sc.env.blockManager.master
@@ -179,8 +203,19 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   // to ensure that more reliable testing.
 }
 
-private object ContextCleaner {
+private[spark] object ContextCleaner {
   private val REF_QUEUE_POLL_TIMEOUT = 100
+  val currentCleaner = new DynamicVariable[Option[ContextCleaner]](None)
+
+  /**
+   * Runs the given thunk with a dynamically-scoped binding for the current ContextCleaner.
+   * This is necessary for blocks of code that serialize and deserialize broadcast variable
+   * objects, since all clones of a Broadcast object b need to be re-registered with the
+   * context cleaner that is tracking b.
+   */
+  def withCurrentCleaner[T <: Any : ClassTag](cc: Option[ContextCleaner])(thunk: => T) = {
+    currentCleaner.withValue(cc)(thunk)
+  }
 }
 
 /**
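Two mechanisms carry this change: a per-id reference count, so a broadcast is only unbroadcast once every registered clone has become unreachable, and a DynamicVariable that scopes the "current" cleaner around code that deserializes broadcasts. A minimal, self-contained sketch of the same idea follows; RefCountedCleaner, withCurrent and RefCountDemo are hypothetical names used only for illustration, not part of this patch.

import scala.util.DynamicVariable

class RefCountedCleaner {
  private var refCounts = Map.empty[Long, Long]

  /** Every registered clone of an id bumps that id's count. */
  def register(id: Long): Unit = synchronized {
    refCounts += id -> (refCounts.getOrElse(id, 0L) + 1)
  }

  /** Returns true only when the last registered reference to id is released. */
  def release(id: Long): Boolean = synchronized {
    val remaining = refCounts.getOrElse(id, 0L) - 1
    refCounts += id -> math.max(remaining, 0L)
    remaining <= 0
  }
}

object RefCountedCleaner {
  // Dynamically scoped "current" cleaner, analogous to ContextCleaner.currentCleaner.
  val current = new DynamicVariable[Option[RefCountedCleaner]](None)

  def withCurrent[T](c: Option[RefCountedCleaner])(thunk: => T): T =
    current.withValue(c)(thunk)
}

object RefCountDemo extends App {
  val cleaner = new RefCountedCleaner
  cleaner.register(0L)                      // original broadcast registers once
  RefCountedCleaner.withCurrent(Some(cleaner)) {
    cleaner.register(0L)                    // a deserialized clone registers again
  }
  println(cleaner.release(0L))              // false: a clone still holds a reference
  println(cleaner.release(0L))              // true: last reference gone, cleanup may proceed
}
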
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8819e73d17fb2..9c93be0b6518a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1058,7 +1058,9 @@ class SparkContext(config: SparkConf) extends Logging {
       throw new SparkException("SparkContext has been shutdown")
     }
     val callSite = getCallSite
-    val cleanedFunc = clean(func)
+    // There's no need to check this function for serializability,
+    // since it will be run right away.
+    val cleanedFunc = clean(func, false)
     logInfo("Starting job: " + callSite.short)
     val start = System.nanoTime
     dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1212,9 +1214,8 @@ class SparkContext(config: SparkConf) extends Logging {
    * @throws SparkException if checkSerializable is set but f is not
    *   serializable
    */
-  private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
-    ClosureCleaner.clean(f, checkSerializable)
-    f
+  private[spark] def clean[F <: AnyRef : ClassTag](f: F, checkSerializable: Boolean = true): F = {
+    ClosureCleaner.clean(f, checkSerializable, this)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 76956f6a345d1..ee0936d83ab8f 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.broadcast
 
-import java.io.Serializable
+import java.io.{ObjectInputStream, Serializable}
 
-import org.apache.spark.SparkException
+import org.apache.spark.{ContextCleaner, SparkException}
 
 import scala.reflect.ClassTag
 
@@ -129,4 +129,12 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
   }
 
   override def toString = "Broadcast(" + id + ")"
+
+  private def readObject(in: ObjectInputStream) {
+    in.defaultReadObject()
+    ContextCleaner.currentCleaner.value match {
+      case None => {}
+      case Some(cc: ContextCleaner) => cc.registerBroadcastForCleanup(this)
+    }
+  }
 }
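The readObject hook above runs for every deserialized copy of a Broadcast, so each clone re-registers itself with whichever cleaner is bound in ContextCleaner.currentCleaner at that moment. A small sketch of the same Java-serialization hook pattern, using the hypothetical names Tracker and Handle rather than Spark's classes:

import java.io._

import scala.util.DynamicVariable

object Tracker {
  // Bound around deserialization; None means "nobody is tracking clones".
  val current = new DynamicVariable[Option[Long => Unit]](None)
}

class Handle(val id: Long) extends Serializable {
  // Java serialization calls this for every deserialized copy of the object.
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    Tracker.current.value.foreach(register => register(id))
  }
}

object HandleDemo extends App {
  val bytes = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(new Handle(42L))
    oos.close()
    bos.toByteArray
  }
  // Deserializing inside withValue re-registers the clone; outside it, nothing happens.
  Tracker.current.withValue(Some(id => println(s"re-registered clone of $id"))) {
    new ObjectInputStream(new ByteArrayInputStream(bytes)).readObject()
  }
}
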
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 4e841bc992bff..a025934f46435 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -733,14 +733,16 @@ abstract class RDD[T: ClassTag](
    * Applies a function f to all elements of this RDD.
    */
   def foreach(f: T => Unit) {
-    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
+    val cleanF = sc.clean(f)
+    sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
   }
 
   /**
    * Applies a function f to each partition of this RDD.
    */
   def foreachPartition(f: Iterator[T] => Unit) {
-    sc.runJob(this, (iter: Iterator[T]) => f(iter))
+    val cleanF = sc.clean(f)
+    sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index e3f52f6ff1e63..3be9e418a9132 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,12 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import scala.collection.mutable.Map
 import scala.collection.mutable.Set
 
+import scala.reflect.ClassTag
+
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
-import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark.{Logging, SparkEnv, SparkException, SparkContext, ContextCleaner}
 
 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -100,8 +102,8 @@ private[spark] object ClosureCleaner extends Logging {
       null
     }
   }
-
-  def clean(func: AnyRef, checkSerializable: Boolean = true) {
+
+  def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true, sc: SparkContext): F = {
     // TODO: cache outerClasses / innerClasses / accessedFields
     val outerClasses = getOuterClasses(func)
     val innerClasses = getInnerClasses(func)
@@ -154,14 +156,19 @@ private[spark] object ClosureCleaner extends Logging {
       field.set(func, outer)
     }
 
-    if (checkSerializable) {
-      ensureSerializable(func)
+    if (captureNow) {
+      ContextCleaner.withCurrentCleaner(sc.cleaner) {
+        cloneViaSerializing(func)
+      }
+    } else {
+      func
     }
   }
 
-  private def ensureSerializable(func: AnyRef) {
+  private def cloneViaSerializing[T: ClassTag](func: T): T = {
     try {
-      SparkEnv.get.closureSerializer.newInstance().serialize(func)
+      val serializer = SparkEnv.get.closureSerializer.newInstance()
+      serializer.deserialize[T](serializer.serialize[T](func))
     } catch {
       case ex: Exception => throw new SparkException("Task not serializable", ex)
     }
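cloneViaSerializing does double duty: it still fails fast with "Task not serializable" for bad closures, and the serializer round trip returns a copy whose free variables are frozen at RDD definition time. A sketch of that round trip with plain Java serialization and a hypothetical deepCopy helper; the patch itself uses SparkEnv's closure serializer, and the sketch assumes Scala function literals are serializable:

import java.io._

object CaptureNow {
  // Round-trip an object through Java serialization: this both checks that it is
  // serializable and returns a copy whose captured values are frozen "now".
  def deepCopy[T <: AnyRef](value: T): T = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(bos)
    oos.writeObject(value)        // throws NotSerializableException for bad closures
    oos.close()
    val ois = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    ois.readObject().asInstanceOf[T]
  }
}

object CaptureNowDemo {
  def run(): (Int, Int) = {
    var zero = 0
    val f: Int => Int = x => x + zero     // closes over the local var via a boxed reference
    val captured = CaptureNow.deepCopy(f) // clone made while zero == 0
    zero = 5                              // mutation after the clone was taken
    (f(1), captured(1))                   // (6, 1): the clone kept the old value
  }

  def main(args: Array[String]): Unit = println(run())
}
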
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 13b415cccb647..0d6c0377be75a 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -22,6 +22,7 @@ import java.lang.ref.WeakReference
 import scala.collection.mutable.{HashSet, SynchronizedSet}
 import scala.language.existentials
 import scala.language.postfixOps
+import scala.reflect.ClassTag
 import scala.util.Random
 
 import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -141,6 +142,33 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     postGCTester.assertCleanup()
   }
 
+  test("automatically cleanup broadcast only after all extant clones become unreachable") {
+    var broadcast = newBroadcast
+
+    // clone this broadcast variable
+    var broadcastClone = cloneBySerializing(broadcast)
+
+    val id = broadcast.id
+
+    // eliminate all strong references to the original broadcast; keep the clone
+    broadcast = null
+
+    // Test that GC does not cause broadcast cleanup since a strong reference to a
+    // clone of the broadcast with the given id still exists
+    val preGCTester = new CleanerTester(sc, broadcastIds = Seq(id))
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
+    }
+
+    // Test that GC causes broadcast cleanup after dereferencing the clone
+    val postGCTester = new CleanerTester(sc, broadcastIds = Seq(id))
+    broadcastClone = null
+    runGC()
+    postGCTester.assertCleanup()
+  }
+
+
   test("automatically cleanup RDD + shuffle + broadcast") {
     val numRdds = 100
     val numBroadcasts = 4 // Broadcasts are more costly
@@ -242,7 +270,14 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
       Thread.sleep(200)
     }
   }
-
+
+  def cloneBySerializing[T <: Any : ClassTag](ref: T): T = {
+    val serializer = SparkEnv.get.closureSerializer.newInstance()
+    ContextCleaner.withCurrentCleaner[T](sc.cleaner) {
+      serializer.deserialize(serializer.serialize(ref))
+    }
+  }
+
   def cleaner = sc.cleaner.get
 }
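The GC-driven assertions in this suite depend on forcing collections and waiting for weak references to clear, which is the pattern behind runGC and CleanerTester. A rough sketch of that technique with a hypothetical GcTestSupport helper, not the suite's actual code:

import java.lang.ref.WeakReference

object GcTestSupport {
  // Force garbage collections until a sentinel weak reference is cleared, so that
  // WeakReference/ReferenceQueue-based cleanup (like ContextCleaner's) gets a chance to run.
  def runGC(timeoutMs: Long = 10000L): Unit = {
    val sentinel = new WeakReference(new Object)
    val deadline = System.currentTimeMillis + timeoutMs
    while (sentinel.get != null) {
      if (System.currentTimeMillis > deadline) {
        sys.error("GC did not collect the sentinel within the timeout")
      }
      System.gc()
      Thread.sleep(200)
    }
  }
}
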
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index e755d2e309398..67b1d84ac15b0 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -110,7 +110,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     FailureSuiteState.clear()
   }
 
-  test("failure because task closure is not serializable") {
+  test("failure because closure in final-stage task is not serializable") {
     sc = new SparkContext("local[1,1]", "test")
     val a = new NonSerializable
 
@@ -122,6 +122,13 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     assert(thrown.getMessage.contains("NotSerializableException") ||
       thrown.getCause.getClass === classOf[NotSerializableException])
 
+    FailureSuiteState.clear()
+  }
+
+  test("failure because closure in early-stage task is not serializable") {
+    sc = new SparkContext("local[1,1]", "test")
+    val a = new NonSerializable
+
     // Non-serializable closure in an earlier stage
     val thrown1 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
     }
@@ -130,6 +137,13 @@ class FailureSuite extends FunSuite with LocalSparkContext {
     assert(thrown1.getMessage.contains("NotSerializableException") ||
       thrown1.getCause.getClass === classOf[NotSerializableException])
 
+    FailureSuiteState.clear()
+  }
+
+  test("failure because closure in foreach task is not serializable") {
+    sc = new SparkContext("local[1,1]", "test")
+    val a = new NonSerializable
+
     // Non-serializable closure in foreach function
     val thrown2 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).foreach(x => println(a))
     }
@@ -141,5 +155,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
 
     FailureSuiteState.clear()
   }
+
   // TODO: Need to add tests with shuffle fetch failures.
 }
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
index 5d15a68ac7e4f..304f60c013d92 100644
--- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -51,11 +51,11 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
   // transformation on a given RDD, creating one test case for each
 
   for (transformation <-
-      Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
-          "mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
-          "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
-          "mapPartitionsWithContext" -> xmapPartitionsWithContext _,
-          "filterWith" -> xfilterWith _)) {
+      Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _,
+          "mapWith" -> mapWith _, "mapPartitions" -> mapPartitions _,
+          "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
+          "mapPartitionsWithContext" -> mapPartitionsWithContext _,
+          "filterWith" -> filterWith _)) {
     val (name, xf) = transformation
 
     test(s"$name transformations throw proactive serialization exceptions") {
@@ -70,21 +70,28 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
     }
   }
 
-  private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.map(y=>uc.op(y))
-  private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapWith(x => x.toString)((x,y)=>x + uc.op(y))
-  private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+  def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.map(y => uc.op(y))
+
+  def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapWith(x => x.toString)((x,y) => x + uc.op(y))
+
+  def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
     x.flatMap(y=>Seq(uc.op(y)))
-  private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
+
+  def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
    x.filter(y=>uc.pred(y))
-  private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.filterWith(x => x.toString)((x,y)=>uc.pred(y))
-  private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapPartitions(_.map(y=>uc.op(y)))
-  private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
-  private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
-    x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))
+
+  def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filterWith(x => x.toString)((x,y) => uc.pred(y))
+
+  def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitions(_.map(y => uc.op(y)))
+
+  def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
+
+  def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))
 }
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 054ef54e746a5..781c562d93f05 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -63,6 +63,20 @@ class ClosureCleanerSuite extends FunSuite {
     val result = TestObjectWithNestedReturns.run()
     assert(result == 1)
   }
+
+  test("capturing free variables in closures at RDD definition") {
+    val obj = new TestCaptureVarClass()
+    val (ones, onesPlusZeroes) = obj.run()
+
+    assert(ones === onesPlusZeroes)
+  }
+
+  test("capturing free variable fields in closures at RDD definition") {
+    val obj = new TestCaptureFieldClass()
+    val (ones, onesPlusZeroes) = obj.run()
+
+    assert(ones === onesPlusZeroes)
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -180,3 +194,37 @@ class TestClassWithNesting(val y: Int) extends Serializable {
     }
   }
 }
+
+class TestCaptureFieldClass extends Serializable {
+  class ZeroBox extends Serializable {
+    var zero = 0
+  }
+
+  def run(): (Int, Int) = {
+    val zb = new ZeroBox
+
+    withSpark(new SparkContext("local", "test")) {sc =>
+      val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
+      val onesPlusZeroes = ones.map(_ + zb.zero)
+
+      zb.zero = 5
+
+      (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
+    }
+  }
+}
+
+class TestCaptureVarClass extends Serializable {
+  def run(): (Int, Int) = {
+    var zero = 0
+
+    withSpark(new SparkContext("local", "test")) {sc =>
+      val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
+      val onesPlusZeroes = ones.map(_ + zero)
+
+      zero = 5
+
+      (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
+    }
+  }
+}
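Taken together, the last two tests pin down the user-visible semantics: mutating a captured variable after an RDD has been defined no longer changes that RDD's result. A sketch of the expected behaviour at the API level, mirroring TestCaptureVarClass and assuming a local SparkContext:

import org.apache.spark.SparkContext

object DefinitionTimeCapture {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "definition-time-capture")
    try {
      var zero = 0
      val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
      val onesPlusZeroes = ones.map(_ + zero)  // closure cloned here, with zero == 0
      zero = 5                                 // too late to affect onesPlusZeroes
      // With definition-time capture both sums are 5; without it the second would be 30.
      println((ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _)))
    } finally {
      sc.stop()
    }
  }
}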