diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index bf3c3a6ceb5ef..26872b197d232 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -20,6 +20,8 @@ package org.apache.spark
import java.lang.ref.{ReferenceQueue, WeakReference}
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import scala.reflect.ClassTag
+import scala.util.DynamicVariable
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -63,6 +65,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
with SynchronizedBuffer[CleanerListener]
private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
+
+ // Reference counts for registered broadcasts, keyed by broadcast id. A broadcast's data is
+ // removed only once the count for its id drops to zero.
+ private var broadcastRefCounts = Map.empty[Long, Long]
/**
* Whether the cleaning thread will block on cleanup tasks.
@@ -102,9 +106,25 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
/** Register a Broadcast for cleanup when it is garbage collected. */
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
+ incBroadcastRefCount(broadcast.id)
registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
}
+ private def incBroadcastRefCount(bid: Long) {
+ val newRefCount: Long = this.broadcastRefCounts.getOrElse(bid, 0L) + 1
+ this.broadcastRefCounts = this.broadcastRefCounts + (bid -> newRefCount)
+ }
+
+ private def decBroadcastRefCount(bid: Long): Long = {
+ this.broadcastRefCounts.get(bid) match {
+ case Some(rc) if rc > 0 =>
+ this.broadcastRefCounts = this.broadcastRefCounts + (bid -> (rc - 1))
+ rc - 1
+ case _ => 0L
+ }
+ }
+
/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
@@ -161,14 +181,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
/** Perform broadcast cleanup. */
def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
- try {
- logDebug("Cleaning broadcast " + broadcastId)
- broadcastManager.unbroadcast(broadcastId, true, blocking)
- listeners.foreach(_.broadcastCleaned(broadcastId))
- logInfo("Cleaned broadcast " + broadcastId)
- } catch {
- case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
+ decBroadcastRefCount(broadcastId) match {
+ case refsRemaining if refsRemaining > 0 => // other registered clones still reference this broadcast; defer cleanup
+ case _ => try {
+ logDebug("Cleaning broadcast " + broadcastId)
+ broadcastManager.unbroadcast(broadcastId, true, blocking)
+ listeners.foreach(_.broadcastCleaned(broadcastId))
+ logInfo("Cleaned broadcast " + broadcastId)
+ } catch {
+ case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
+ }
}
+
}
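// Illustrative sketch (not part of the patch): the reference-counting scheme used by
// registerBroadcastForCleanup and doCleanupBroadcast above, reduced to a standalone example.
// The object and method names here are hypothetical; the real code keys the map by Broadcast.id.
object RefCountSketch {
  private var counts = Map.empty[Long, Long]

  /** Called whenever another clone of the broadcast with this id is registered. */
  def retain(id: Long): Unit = {
    counts += id -> (counts.getOrElse(id, 0L) + 1)
  }

  /** Returns true only when the last registered clone has been released. */
  def release(id: Long): Boolean = {
    val remaining = math.max(counts.getOrElse(id, 0L) - 1, 0L)
    counts += id -> remaining
    remaining == 0L
  }
}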
private def blockManagerMaster = sc.env.blockManager.master
@@ -179,8 +203,19 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
// to ensure more reliable testing.
}
-private object ContextCleaner {
+private[spark] object ContextCleaner {
private val REF_QUEUE_POLL_TIMEOUT = 100
+ val currentCleaner = new DynamicVariable[Option[ContextCleaner]](None)
+
+ /**
+ * Runs the given thunk with a dynamically-scoped binding for the current ContextCleaner.
+ * This is necessary for blocks of code that serialize and deserialize broadcast variable
+ * objects, since all clones of a Broadcast object b need to be re-registered with the
+ * context cleaner that is tracking b.
+ */
+ def withCurrentCleaner[T: ClassTag](cc: Option[ContextCleaner])(thunk: => T): T = {
+ currentCleaner.withValue(cc)(thunk)
+ }
}
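// Illustrative sketch (not part of the patch): how scala.util.DynamicVariable provides the
// thread-local, dynamically scoped binding that withCurrentCleaner relies on. The value is
// visible only inside the thunk passed to withValue and reverts automatically afterwards.
import scala.util.DynamicVariable

object DynamicScopeSketch {
  val current = new DynamicVariable[Option[String]](None)

  def main(args: Array[String]): Unit = {
    println(current.value)               // None: no binding in scope yet
    current.withValue(Some("cleaner")) {
      println(current.value)             // Some(cleaner): bound for the duration of the thunk
    }
    println(current.value)               // None again once the thunk returns
  }
}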
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8819e73d17fb2..9c93be0b6518a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1058,7 +1058,9 @@ class SparkContext(config: SparkConf) extends Logging {
throw new SparkException("SparkContext has been shutdown")
}
val callSite = getCallSite
- val cleanedFunc = clean(func)
+ // There's no need to check this function for serializability,
+ // since it will be run right away.
+ val cleanedFunc = clean(func, false)
logInfo("Starting job: " + callSite.short)
val start = System.nanoTime
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1212,9 +1214,8 @@ class SparkContext(config: SparkConf) extends Logging {
* @throws SparkException if checkSerializable is set but f is not
* serializable
*/
- private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
- ClosureCleaner.clean(f, checkSerializable)
- f
+ private[spark] def clean[F <: AnyRef : ClassTag](f: F, checkSerializable: Boolean = true): F = {
+ ClosureCleaner.clean(f, checkSerializable, this)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 76956f6a345d1..ee0936d83ab8f 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -17,9 +17,9 @@
package org.apache.spark.broadcast
-import java.io.Serializable
+import java.io.{ObjectInputStream, Serializable}
-import org.apache.spark.SparkException
+import org.apache.spark.{ContextCleaner, SparkException}
import scala.reflect.ClassTag
@@ -129,4 +129,12 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
}
override def toString = "Broadcast(" + id + ")"
+
+ private def readObject(in: ObjectInputStream) {
+ in.defaultReadObject()
+ // Re-register this deserialized clone with the ContextCleaner (if any) that is currently in
+ // dynamic scope, so the clone participates in reference counting for this broadcast id.
+ ContextCleaner.currentCleaner.value.foreach(_.registerBroadcastForCleanup(this))
+ }
}
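// Illustrative sketch (not part of the patch): how the readObject hook above cooperates with
// ContextCleaner.withCurrentCleaner. Cloning a value through the closure serializer inside a
// withCurrentCleaner block causes every Broadcast reachable from it to re-register with that
// cleaner. `value` and `cleaner` are assumed inputs, and the sketch assumes it lives in the
// org.apache.spark package since ContextCleaner is private[spark]; the serializer calls mirror
// those used in ClosureCleaner and ContextCleanerSuite in this patch.
import scala.reflect.ClassTag
import org.apache.spark.{ContextCleaner, SparkEnv}

object ReregisterSketch {
  def cloneWithReregistration[T: ClassTag](value: T, cleaner: Option[ContextCleaner]): T = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    ContextCleaner.withCurrentCleaner(cleaner) {
      // Broadcast.readObject runs during this deserialization, while currentCleaner is bound.
      ser.deserialize[T](ser.serialize[T](value))
    }
  }
}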
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 4e841bc992bff..a025934f46435 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -733,14 +733,16 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
- sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
+ val cleanF = sc.clean(f)
+ sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
- sc.runJob(this, (iter: Iterator[T]) => f(iter))
+ val cleanF = sc.clean(f)
+ sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index e3f52f6ff1e63..3be9e418a9132 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,12 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set
+import scala.reflect.ClassTag
+
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
-import org.apache.spark.{Logging, SparkEnv, SparkException}
+import org.apache.spark.{ContextCleaner, Logging, SparkContext, SparkEnv, SparkException}
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -100,8 +102,8 @@ private[spark] object ClosureCleaner extends Logging {
null
}
}
-
- def clean(func: AnyRef, checkSerializable: Boolean = true) {
+
+ def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true, sc: SparkContext): F = {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -154,14 +156,19 @@ private[spark] object ClosureCleaner extends Logging {
field.set(func, outer)
}
- if (checkSerializable) {
- ensureSerializable(func)
+ if (captureNow) {
+ ContextCleaner.withCurrentCleaner(sc.cleaner){
+ cloneViaSerializing(func)
+ }
+ } else {
+ func
}
}
- private def ensureSerializable(func: AnyRef) {
+ private def cloneViaSerializing[T: ClassTag](func: T): T = {
try {
- SparkEnv.get.closureSerializer.newInstance().serialize(func)
+ val serializer = SparkEnv.get.closureSerializer.newInstance()
+ serializer.deserialize[T](serializer.serialize[T](func))
} catch {
case ex: Exception => throw new SparkException("Task not serializable", ex)
}
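// Illustrative sketch (not part of the patch): why the serialize/deserialize round trip in
// cloneViaSerializing "freezes" a closure's free variables at the point where clean() runs.
// This standalone version uses plain Java serialization instead of Spark's closure serializer.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object CaptureNowSketch {
  def cloneBySerializing[T <: AnyRef](x: T): T = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(x)
    out.close()
    new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray)).readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    var zero = 0
    val live: Int => Int = x => x + zero   // closes over the mutable variable
    val frozen = cloneBySerializing(live)  // the clone captures zero == 0 right now
    zero = 5
    println(live(1))     // 6: the original closure still sees the mutated variable
    println(frozen(1))   // 1: the clone kept the value captured at clone time
  }
}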
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 13b415cccb647..0d6c0377be75a 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -22,6 +22,7 @@ import java.lang.ref.WeakReference
import scala.collection.mutable.{HashSet, SynchronizedSet}
import scala.language.existentials
import scala.language.postfixOps
+import scala.reflect.ClassTag
import scala.util.Random
import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -141,6 +142,33 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
postGCTester.assertCleanup()
}
+ test("automatically cleanup broadcast only after all extant clones become unreachable") {
+ var broadcast = newBroadcast
+
+ // clone this broadcast variable
+ var broadcastClone = cloneBySerializing(broadcast)
+
+ val id = broadcast.id
+
+ // eliminate all strong references to the original broadcast; keep the clone
+ broadcast = null
+
+ // Test that GC does not cause broadcast cleanup, since a strong reference to a
+ // clone of the broadcast with the given id still exists
+ val preGCTester = new CleanerTester(sc, broadcastIds = Seq(id))
+ runGC()
+ intercept[Exception] {
+ preGCTester.assertCleanup()(timeout(1000 millis))
+ }
+
+ // Test that GC causes broadcast cleanup after dereferencing the clone
+ val postGCTester = new CleanerTester(sc, broadcastIds = Seq(id))
+ broadcastClone = null
+ runGC()
+ postGCTester.assertCleanup()
+ }
+
test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
@@ -242,7 +270,14 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
Thread.sleep(200)
}
}
-
+
+ def cloneBySerializing[T <: Any : ClassTag](ref: T): T = {
+ val serializer = SparkEnv.get.closureSerializer.newInstance()
+ ContextCleaner.withCurrentCleaner[T](sc.cleaner){
+ serializer.deserialize(serializer.serialize(ref))
+ }
+ }
+
def cleaner = sc.cleaner.get
}
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index e755d2e309398..67b1d84ac15b0 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -110,7 +110,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}
- test("failure because task closure is not serializable") {
+ test("failure because closure in final-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable
@@ -122,6 +122,13 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown.getMessage.contains("NotSerializableException") ||
thrown.getCause.getClass === classOf[NotSerializableException])
+ FailureSuiteState.clear()
+ }
+
+ test("failure because closure in early-stage task is not serializable") {
+ sc = new SparkContext("local[1,1]", "test")
+ val a = new NonSerializable
+
// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
@@ -130,6 +137,13 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown1.getMessage.contains("NotSerializableException") ||
thrown1.getCause.getClass === classOf[NotSerializableException])
+ FailureSuiteState.clear()
+ }
+
+ test("failure because closure in foreach task is not serializable") {
+ sc = new SparkContext("local[1,1]", "test")
+ val a = new NonSerializable
+
// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
@@ -141,5 +155,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}
+
// TODO: Need to add tests with shuffle fetch failures.
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
index 5d15a68ac7e4f..304f60c013d92 100644
--- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -51,11 +51,11 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
// transformation on a given RDD, creating one test case for each
for (transformation <-
- Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
- "mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
- "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
- "mapPartitionsWithContext" -> xmapPartitionsWithContext _,
- "filterWith" -> xfilterWith _)) {
+ Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _,
+ "mapWith" -> mapWith _, "mapPartitions" -> mapPartitions _,
+ "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
+ "mapPartitionsWithContext" -> mapPartitionsWithContext _,
+ "filterWith" -> filterWith _)) {
val (name, xf) = transformation
test(s"$name transformations throw proactive serialization exceptions") {
@@ -70,21 +70,28 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
}
}
- private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.map(y=>uc.op(y))
- private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapWith(x => x.toString)((x,y)=>x + uc.op(y))
- private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.map(y => uc.op(y))
+
+ def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapWith(x => x.toString)((x,y) => x + uc.op(y))
+
+ def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.flatMap(y=>Seq(uc.op(y)))
- private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
+
+ def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filter(y=>uc.pred(y))
- private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.filterWith(x => x.toString)((x,y)=>uc.pred(y))
- private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapPartitions(_.map(y=>uc.op(y)))
- private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
- private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
- x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))
+
+ def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.filterWith(x => x.toString)((x,y) => uc.pred(y))
+
+ def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitions(_.map(y => uc.op(y)))
+
+ def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))
+
+ def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
+ x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))
}
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 054ef54e746a5..781c562d93f05 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -63,6 +63,20 @@ class ClosureCleanerSuite extends FunSuite {
val result = TestObjectWithNestedReturns.run()
assert(result == 1)
}
+
+ test("capturing free variables in closures at RDD definition") {
+ val obj = new TestCaptureVarClass()
+ val (ones, onesPlusZeroes) = obj.run()
+
+ assert(ones === onesPlusZeroes)
+ }
+
+ test("capturing free variable fields in closures at RDD definition") {
+ val obj = new TestCaptureFieldClass()
+ val (ones, onesPlusZeroes) = obj.run()
+
+ assert(ones === onesPlusZeroes)
+ }
}
// A non-serializable class we create in closures to make sure that we aren't
@@ -180,3 +194,37 @@ class TestClassWithNesting(val y: Int) extends Serializable {
}
}
}
+
+class TestCaptureFieldClass extends Serializable {
+ class ZeroBox extends Serializable {
+ var zero = 0
+ }
+
+ def run(): (Int, Int) = {
+ val zb = new ZeroBox
+
+ withSpark(new SparkContext("local", "test")) {sc =>
+ val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
+ val onesPlusZeroes = ones.map(_ + zb.zero)
+
+ zb.zero = 5
+
+ (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
+ }
+ }
+}
+
+class TestCaptureVarClass extends Serializable {
+ def run(): (Int, Int) = {
+ var zero = 0
+
+ withSpark(new SparkContext("local", "test")) {sc =>
+ val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
+ val onesPlusZeroes = ones.map(_ + zero)
+
+ zero = 5
+
+ (ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
+ }
+ }
+}