SPARK-729: Predictable closure capture #1322

Closed · wants to merge 10 commits
51 changes: 43 additions & 8 deletions core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -20,6 +20,8 @@ package org.apache.spark
import java.lang.ref.{ReferenceQueue, WeakReference}

import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.reflect.ClassTag
import scala.util.DynamicVariable

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -63,6 +65,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
with SynchronizedBuffer[CleanerListener]

private val cleaningThread = new Thread() { override def run() { keepCleaning() }}

private var broadcastRefCounts = Map(0L -> 0L)

/**
* Whether the cleaning thread will block on cleanup tasks.
@@ -102,9 +106,25 @@

/** Register a Broadcast for cleanup when it is garbage collected. */
def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
incBroadcastRefCount(broadcast.id)
registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
}

private def incBroadcastRefCount[T](bid: Long) {
val newRefCount: Long = this.broadcastRefCounts.getOrElse(bid, 0L) + 1
this.broadcastRefCounts = this.broadcastRefCounts + Pair(bid, newRefCount)
}

private def decBroadcastRefCount[T](bid: Long) = {
this.broadcastRefCounts.get(bid) match {
case Some(rc:Long) if rc > 0 => {
this.broadcastRefCounts = this.broadcastRefCounts + Pair(bid, rc - 1)
rc - 1
}
case _ => 0
}
}

/** Register an object for cleanup. */
private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
@@ -161,14 +181,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {

/** Perform broadcast cleanup. */
def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) {
try {
logDebug("Cleaning broadcast " + broadcastId)
broadcastManager.unbroadcast(broadcastId, true, blocking)
listeners.foreach(_.broadcastCleaned(broadcastId))
logInfo("Cleaned broadcast " + broadcastId)
} catch {
case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
decBroadcastRefCount(broadcastId) match {
case x if x > 0 => {}
case _ => try {
logDebug("Cleaning broadcast " + broadcastId)
broadcastManager.unbroadcast(broadcastId, true, blocking)
listeners.foreach(_.broadcastCleaned(broadcastId))
logInfo("Cleaned broadcast " + broadcastId)
} catch {
case e: Exception => logError("Error cleaning broadcast " + broadcastId, e)
}
}

}

private def blockManagerMaster = sc.env.blockManager.master
@@ -179,8 +203,19 @@
// to ensure that more reliable testing.
}

private object ContextCleaner {
private[spark] object ContextCleaner {
private val REF_QUEUE_POLL_TIMEOUT = 100
val currentCleaner = new DynamicVariable[Option[ContextCleaner]](None)

/**
* Runs the given thunk with a dynamically-scoped binding for the current ContextCleaner.
* This is necessary for blocks of code that serialize and deserialize broadcast variable
* objects, since all clones of a Broadcast object <tt>b</tt> need to be re-registered with the
* context cleaner that is tracking <tt>b</tt>.
*/
def withCurrentCleaner[T <: Any : ClassTag](cc: Option[ContextCleaner])(thunk: => T) = {
currentCleaner.withValue(cc)(thunk)
}
}

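The reference counting added to ContextCleaner can be exercised in isolation. Below is a minimal, standalone Scala sketch of the same idea, using hypothetical names (RefCountSketch, register, release) rather than the patch's methods: every registration of a broadcast id increments a counter, every cleanup request decrements it, and the actual cleanup action runs only when the last reference is released.

import scala.collection.mutable

// Hypothetical stand-in for ContextCleaner's broadcastRefCounts map.
object RefCountSketch {
  private val refCounts = mutable.Map.empty[Long, Long].withDefaultValue(0L)

  // Analogous to registerBroadcastForCleanup: every clone of broadcast `id` adds a reference.
  def register(id: Long): Unit = refCounts(id) += 1

  // Analogous to doCleanupBroadcast: only the final release performs the real cleanup.
  def release(id: Long)(cleanup: Long => Unit): Unit = {
    val remaining = math.max(refCounts(id) - 1, 0L)
    refCounts(id) = remaining
    if (remaining == 0L) cleanup(id)
  }
}

object RefCountDemo extends App {
  RefCountSketch.register(7L)                                // original broadcast
  RefCountSketch.register(7L)                                // a deserialized clone
  RefCountSketch.release(7L)(id => println(s"cleaned $id"))  // no-op: one reference remains
  RefCountSketch.release(7L)(id => println(s"cleaned $id"))  // prints "cleaned 7"
}
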
9 changes: 5 additions & 4 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1058,7 +1058,9 @@ class SparkContext(config: SparkConf) extends Logging {
throw new SparkException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
// There's no need to check this function for serializability,
// since it will be run right away.
val cleanedFunc = clean(func, false)
logInfo("Starting job: " + callSite.short)
val start = System.nanoTime
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1212,9 +1214,8 @@
* @throws <tt>SparkException</tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
* serializable
*/
private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
ClosureCleaner.clean(f, checkSerializable)
f
private[spark] def clean[F <: AnyRef : ClassTag](f: F, checkSerializable: Boolean = true): F = {
ClosureCleaner.clean(f, checkSerializable, this)
}

12 changes: 10 additions & 2 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -17,9 +17,9 @@

package org.apache.spark.broadcast

import java.io.Serializable
import java.io.{ObjectInputStream, Serializable}

import org.apache.spark.SparkException
import org.apache.spark.{ContextCleaner, SparkException}

import scala.reflect.ClassTag

@@ -129,4 +129,12 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
}

override def toString = "Broadcast(" + id + ")"

private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
ContextCleaner.currentCleaner.value match {
case None => {}
case Some(cc: ContextCleaner) => cc.registerBroadcastForCleanup(this)
}
}
}
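
The readObject hook above only has an effect when deserialization happens inside ContextCleaner.withCurrentCleaner, which binds the dynamically scoped currentCleaner variable. The same pattern can be reproduced outside Spark; the sketch below uses plain Java serialization and hypothetical names (Registry, Tracked, withRegistry) to show how a clone produced during deserialization re-registers itself with whatever registry is in scope.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
import scala.util.DynamicVariable

// Hypothetical registry standing in for ContextCleaner.
class Registry { val seen = ArrayBuffer.empty[Long] }

object Registry {
  // Dynamically scoped "current registry", analogous to ContextCleaner.currentCleaner.
  val current = new DynamicVariable[Option[Registry]](None)
  def withRegistry[T](r: Option[Registry])(thunk: => T): T = current.withValue(r)(thunk)
}

// Stand-in for Broadcast: each deserialized clone registers with the current registry, if any.
class Tracked(val id: Long) extends Serializable {
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    Registry.current.value.foreach(_.seen += id)
  }
}

object ReRegistrationDemo extends App {
  def roundTrip[T](x: T): T = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(x)
    out.close()
    new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)).readObject().asInstanceOf[T]
  }

  val registry = new Registry
  // Cloning inside withRegistry registers the clone; cloning outside registers nothing.
  val restored = Registry.withRegistry(Some(registry)) { roundTrip(new Tracked(42L)) }
  println(registry.seen)  // ArrayBuffer(42): the clone re-registered itself during readObject
}
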
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -733,14 +733,16 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
sc.runJob(this, (iter: Iterator[T]) => f(iter))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}

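Routing foreach and foreachPartition through clean() means a closure that captures something non-serializable now fails when the action is invoked, not later on an executor. A hedged, self-contained illustration: the FailFastForeachDemo and NotSerializableThing names are made up for this sketch, while SparkConf, SparkContext, parallelize, foreach, and SparkException are the standard Spark API.

import org.apache.spark.{SparkConf, SparkContext, SparkException}

object FailFastForeachDemo {
  class NotSerializableThing  // deliberately does not implement Serializable

  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("fail-fast-foreach"))
    val thing = new NotSerializableThing
    try {
      // With foreach now calling sc.clean(f), the serialization failure should surface here,
      // when the closure is captured, rather than when tasks are later shipped to executors.
      sc.parallelize(1 to 10).foreach(x => println((x, thing)))
    } catch {
      case e: SparkException => println("closure rejected at capture time: " + e.getMessage)
    } finally {
      sc.stop()
    }
  }
}
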
21 changes: 14 additions & 7 deletions core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,12 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set

import scala.reflect.ClassTag

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.{Logging, SparkEnv, SparkException, SparkContext, ContextCleaner}

private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -100,8 +102,8 @@ private[spark] object ClosureCleaner extends Logging {
null
}
}

def clean(func: AnyRef, checkSerializable: Boolean = true) {
def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true, sc: SparkContext): F = {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -154,14 +156,19 @@
field.set(func, outer)
}

if (checkSerializable) {
ensureSerializable(func)
if (captureNow) {
ContextCleaner.withCurrentCleaner(sc.cleaner){
cloneViaSerializing(func)
}
} else {
func
}
}

private def ensureSerializable(func: AnyRef) {
private def cloneViaSerializing[T: ClassTag](func: T): T = {
try {
SparkEnv.get.closureSerializer.newInstance().serialize(func)
val serializer = SparkEnv.get.closureSerializer.newInstance()
serializer.deserialize[T](serializer.serialize[T](func))
} catch {
case ex: Exception => throw new SparkException("Task not serializable", ex)
}
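
cloneViaSerializing replaces the old serialize-and-discard check: the closure is round-tripped through the closure serializer and the resulting clone is what clean() returns, so any values it captured are frozen at the moment clean() ran. The standalone sketch below reproduces that effect with plain Java serialization instead of SparkEnv's closureSerializer (cloneBySerializing here is an assumed helper for the example, not Spark code).

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object CaptureAtCleanTimeDemo {
  // Generic deep clone via a serialize/deserialize round trip, the same shape as
  // cloneViaSerializing but using java.io serialization so the example stands alone.
  def cloneBySerializing[T](x: T): T = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(x)
    out.close()
    new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)).readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    var factor = 2
    val f: Int => Int = x => x * factor

    // Snapshot the closure now, as clean(..., captureNow = true) would.
    val captured = cloneBySerializing(f)

    factor = 10
    println(f(3))         // 30: the original closure still sees later updates to `factor`
    println(captured(3))  // 6:  the clone froze the value that was captured at "clean" time
  }
}
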
37 changes: 36 additions & 1 deletion core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -22,6 +22,7 @@ import java.lang.ref.WeakReference
import scala.collection.mutable.{HashSet, SynchronizedSet}
import scala.language.existentials
import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.Random

import org.scalatest.{BeforeAndAfter, FunSuite}
@@ -141,6 +142,33 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
postGCTester.assertCleanup()
}

test("automatically cleanup broadcast only after all extant clones become unreachable") {
var broadcast = newBroadcast

// clone this broadcast variable
var broadcastClone = cloneBySerializing(broadcast)

val id = broadcast.id

// eliminate all strong references to the original broadcast; keep the clone
broadcast = null

// Test that GC does not cause broadcast cleanup since a strong reference to a
// clone of the broadcast with the given id still exists
val preGCTester = new CleanerTester(sc, broadcastIds = Seq(id))
runGC()
intercept[Exception] {
preGCTester.assertCleanup()(timeout(1000 millis))
}

// Test that GC causes broadcast cleanup after dereferencing the clone
val postGCTester = new CleanerTester(sc, broadcastIds = Seq(id))
broadcastClone = null
runGC()
postGCTester.assertCleanup()
}


test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
@@ -242,7 +270,14 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
Thread.sleep(200)
}
}


def cloneBySerializing[T <: Any : ClassTag](ref: T): T = {
val serializer = SparkEnv.get.closureSerializer.newInstance()
ContextCleaner.withCurrentCleaner[T](sc.cleaner){
serializer.deserialize(serializer.serialize(ref))
}
}

def cleaner = sc.cleaner.get
}

17 changes: 16 additions & 1 deletion core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -110,7 +110,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}

test("failure because task closure is not serializable") {
test("failure because closure in final-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

@@ -122,6 +122,13 @@
assert(thrown.getMessage.contains("NotSerializableException") ||
thrown.getCause.getClass === classOf[NotSerializableException])

FailureSuiteState.clear()
}

test("failure because closure in early-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
@@ -130,6 +137,13 @@
assert(thrown1.getMessage.contains("NotSerializableException") ||
thrown1.getCause.getClass === classOf[NotSerializableException])

FailureSuiteState.clear()
}

test("failure because closure in foreach task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
Expand All @@ -141,5 +155,6 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}


// TODO: Need to add tests with shuffle fetch failures.
}
@@ -51,11 +51,11 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex
// transformation on a given RDD, creating one test case for each

for (transformation <-
Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
"mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
"mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
"mapPartitionsWithContext" -> xmapPartitionsWithContext _,
"filterWith" -> xfilterWith _)) {
Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _,
"mapWith" -> mapWith _, "mapPartitions" -> mapPartitions _,
"mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
"mapPartitionsWithContext" -> mapPartitionsWithContext _,
"filterWith" -> filterWith _)) {
val (name, xf) = transformation

test(s"$name transformations throw proactive serialization exceptions") {
@@ -70,21 +70,28 @@
}
}

private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.map(y=>uc.op(y))
private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapWith(x => x.toString)((x,y)=>x + uc.op(y))
private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.map(y => uc.op(y))

def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapWith(x => x.toString)((x,y) => x + uc.op(y))

def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.flatMap(y=>Seq(uc.op(y)))
private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =

def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filter(y=>uc.pred(y))
private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filterWith(x => x.toString)((x,y)=>uc.pred(y))
private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitions(_.map(y=>uc.op(y)))
private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))

def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filterWith(x => x.toString)((x,y) => uc.pred(y))

def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitions(_.map(y => uc.op(y)))

def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))

def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))

}