SPARK-729: Closures not always serialized at capture time #189

Closed
wants to merge 12 commits into from
16 changes: 11 additions & 5 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -896,7 +896,9 @@ class SparkContext(
require(p >= 0 && p < rdd.partitions.size, s"Invalid partition requested: $p")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
// There's no need to check this function for serializability,
// since it will be run right away.
val cleanedFunc = clean(func, false)
Contributor
IMO you might as well clone it here too, because it could be modified in other threads while the job is running.
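As an illustration of the concern in this comment, here is a hedged, standalone sketch in plain Scala (not Spark code; all names below are made up): a closure that reads shared mutable state can observe writes from another thread while the job built from it is still running, so cloning the closure up front would pin down the values it sees.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Standalone sketch: a "job" applies a closure over shared mutable state element
// by element while another thread mutates that state mid-run.
object ConcurrentMutationSketch {
  def main(args: Array[String]): Unit = {
    val factor = new AtomicInteger(1)
    val f = (x: Int) => x * factor.get   // closure over shared mutable state

    val mutator = new Thread(new Runnable {
      def run(): Unit = { Thread.sleep(50); factor.set(100) }
    })
    mutator.start()

    // Early elements are computed with factor = 1, later ones with factor = 100,
    // so a single logical job can produce internally inconsistent results.
    val results = (1 to 5).map { x => Thread.sleep(25); f(x) }
    mutator.join()
    println(results)   // e.g. Vector(1, 2, 300, 400, 500), depending on timing
  }
}
```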

logInfo("Starting job: " + callSite)
val start = System.nanoTime
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
@@ -1026,14 +1028,18 @@ class SparkContext(
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}

/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*
* @param f closure to be cleaned and optionally serialized
* @param captureNow whether or not to serialize this closure and capture any free
* variables immediately; defaults to true. If this is set and f is not serializable,
* it will raise an exception.
*/
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
f
private[spark] def clean[F <: AnyRef : ClassTag](f: F, captureNow: Boolean = true): F = {
ClosureCleaner.clean(f, captureNow)
}
Contributor
Instead of adding two methods, just add a default argument to the first one (captureNow: Boolean = true).
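For reference, a minimal sketch of the single-method shape this comment asks for (and which the `SparkContext` hunk above now reflects); `Cleaner` below is a stand-in object with placeholder bodies, not Spark's API:

```scala
// Stand-in sketch: one `clean` method with a defaulted `captureNow` flag instead
// of two separate methods. Only the shape matters here; the bodies are stubs.
object Cleaner {
  def clean[F <: AnyRef](f: F, captureNow: Boolean = true): F = {
    // ... rewrite/trim the closure's $outer chain here ...
    if (captureNow) cloneViaSerializing(f) else f
  }

  // Placeholder for a serialize/deserialize round trip (see ClosureCleaner below).
  private def cloneViaSerializing[F <: AnyRef](f: F): F = f
}

// Call sites then read:
//   Cleaner.clean(func)         // default: check serializability and capture now
//   Cleaner.clean(func, false)  // just clean; defer serialization to task launch
```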


/**
6 changes: 4 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -640,14 +640,16 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
sc.runJob(this, (iter: Iterator[T]) => f(iter))
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}

/**
21 changes: 20 additions & 1 deletion core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -22,10 +22,14 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import scala.collection.mutable.Map
import scala.collection.mutable.Set

import scala.reflect.ClassTag

import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._

import org.apache.spark.Logging
import org.apache.spark.SparkEnv
import org.apache.spark.SparkException

private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +105,7 @@ private[spark] object ClosureCleaner extends Logging {
}
}

def clean(func: AnyRef) {
def clean[F <: AnyRef : ClassTag](func: F, captureNow: Boolean = true): F = {
// TODO: cache outerClasses / innerClasses / accessedFields
val outerClasses = getOuterClasses(func)
val innerClasses = getInnerClasses(func)
@@ -150,6 +154,21 @@ private[spark] object ClosureCleaner extends Logging {
field.setAccessible(true)
field.set(func, outer)
}

if (captureNow) {
cloneViaSerializing(func)
} else {
func
}
}

private def cloneViaSerializing[T: ClassTag](func: T): T = {
try {
val serializer = SparkEnv.get.closureSerializer.newInstance()
serializer.deserialize[T](serializer.serialize[T](func))
} catch {
case ex: Exception => throw new SparkException("Task not serializable: " + ex.toString)
}
}

private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
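A self-contained sketch of what the `cloneViaSerializing` round trip above buys, using plain Java serialization in place of Spark's closure serializer and assuming the closure and everything it captures are Java-serializable: the clone fails immediately if the closure cannot be serialized, and it freezes the closure's free variables at the moment of the round trip.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

object CaptureNowSketch {
  // Clone a value by round-tripping it through Java serialization. Throws
  // NotSerializableException right away if `value` cannot be serialized.
  def cloneViaSerializing[T](value: T): T = {
    val buffer = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(buffer)
    out.writeObject(value)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
    in.readObject().asInstanceOf[T]
  }

  def main(args: Array[String]): Unit = {
    var zero = 0
    val f = (x: Int) => x + zero            // closure over a mutable local
    val captured = cloneViaSerializing(f)   // snapshot taken here

    zero = 5                                // mutation after the snapshot
    println(f(1))        // 6: the original closure still sees the mutation
    println(captured(1)) // 1: the clone kept the value it captured
  }
}
```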
17 changes: 16 additions & 1 deletion core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -107,7 +107,7 @@ class FailureSuite extends FunSuite with LocalSparkContext {
FailureSuiteState.clear()
}

test("failure because task closure is not serializable") {
test("failure because closure in final-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

@@ -118,13 +118,27 @@ class FailureSuite extends FunSuite with LocalSparkContext {
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.contains("NotSerializableException"))

FailureSuiteState.clear()
}

test("failure because closure in early-stage task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in an earlier stage
val thrown1 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
}
assert(thrown1.getClass === classOf[SparkException])
assert(thrown1.getMessage.contains("NotSerializableException"))

FailureSuiteState.clear()
}

test("failure because closure in foreach task is not serializable") {
sc = new SparkContext("local[1,1]", "test")
val a = new NonSerializable

// Non-serializable closure in foreach function
val thrown2 = intercept[SparkException] {
sc.parallelize(1 to 10, 2).foreach(x => println(a))
@@ -135,6 +149,7 @@
FailureSuiteState.clear()
}


// TODO: Need to add tests with shuffle fetch failures.
}

94 changes: 94 additions & 0 deletions core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.serializer;

import java.io.NotSerializableException

import org.scalatest.FunSuite

import org.apache.spark.rdd.RDD
import org.apache.spark.SparkException
import org.apache.spark.SharedSparkContext

/* A trivial (but unserializable) container for trivial functions */
class UnserializableClass {
def op[T](x: T) = x.toString

def pred[T](x: T) = x.toString.length % 2 == 0
}

class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {

def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)

test("throws expected serialization exceptions on actions") {
val (data, uc) = fixture

val ex = intercept[SparkException] {
data.map(uc.op(_)).count
}

assert(ex.getMessage.matches(".*Task not serializable.*"))
}

// There is probably a cleaner way to eliminate boilerplate here, but we're
// iterating over a map from transformation names to functions that perform that
// transformation on a given RDD, creating one test case for each

for (transformation <-
Map("map" -> map _, "flatMap" -> flatMap _, "filter" -> filter _, "mapWith" -> mapWith _,
"mapPartitions" -> mapPartitions _, "mapPartitionsWithIndex" -> mapPartitionsWithIndex _,
"mapPartitionsWithContext" -> mapPartitionsWithContext _, "filterWith" -> filterWith _)) {
val (name, xf) = transformation

test(s"$name transformations throw proactive serialization exceptions") {
val (data, uc) = fixture

val ex = intercept[SparkException] {
xf(data, uc)
}

assert(ex.getMessage.matches(".*Task not serializable.*"), s"RDD.$name doesn't proactively throw NotSerializableException")
}
}

def map(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.map(y => uc.op(y))

def mapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapWith(x => x.toString)((x,y) => x + uc.op(y))

def flatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.flatMap(y=>Seq(uc.op(y)))

def filter(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filter(y=>uc.pred(y))

def filterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.filterWith(x => x.toString)((x,y) => uc.pred(y))

def mapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitions(_.map(y => uc.op(y)))

def mapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithIndex((_, it) => it.map(y => uc.op(y)))

def mapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
x.mapPartitionsWithContext((_, it) => it.map(y => uc.op(y)))

}
@@ -50,6 +50,27 @@ class ClosureCleanerSuite extends FunSuite {
val obj = new TestClassWithNesting(1)
assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
}

test("capturing free variables in closures at RDD definition") {
val obj = new TestCaptureVarClass()
val (ones, onesPlusZeroes) = obj.run()

assert(ones === onesPlusZeroes)
}

test("capturing free variable fields in closures at RDD definition") {
val obj = new TestCaptureFieldClass()
val (ones, onesPlusZeroes) = obj.run()

assert(ones === onesPlusZeroes)
}

test("capturing arrays in closures at RDD definition") {
val obj = new TestCaptureArrayEltClass()
val (observed, expected) = obj.run()

assert(observed === expected)
}
}

// A non-serializable class we create in closures to make sure that we aren't
@@ -143,3 +164,50 @@ class TestClassWithNesting(val y: Int) extends Serializable {
}
}
}

class TestCaptureFieldClass extends Serializable {
class ZeroBox extends Serializable {
var zero = 0
}

def run(): (Int, Int) = {
val zb = new ZeroBox

withSpark(new SparkContext("local", "test")) {sc =>
val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
val onesPlusZeroes = ones.map(_ + zb.zero)

zb.zero = 5

(ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
}
}
}

class TestCaptureArrayEltClass extends Serializable {
def run(): (Int, Int) = {
withSpark(new SparkContext("local", "test")) {sc =>
val rdd = sc.parallelize(1 to 10)
val data = Array(1, 2, 3)
val expected = data(0)
val mapped = rdd.map(x => data(0))
data(0) = 4
(mapped.first, expected)
}
}
}

class TestCaptureVarClass extends Serializable {
def run(): (Int, Int) = {
var zero = 0

withSpark(new SparkContext("local", "test")) {sc =>
val ones = sc.parallelize(Array(1, 1, 1, 1, 1))
val onesPlusZeroes = ones.map(_ + zero)

zero = 5

(ones.reduce(_ + _), onesPlusZeroes.reduce(_ + _))
}
}
}
@@ -62,7 +62,7 @@ class GraphSuite extends FunSuite with LocalSparkContext {
assert( graph.edges.count() === rawEdges.size )
// Vertices not explicitly provided but referenced by edges should be created automatically
assert( graph.vertices.count() === 100)
graph.triplets.map { et =>
graph.triplets.collect.map { et =>
assert((et.srcId < 10 && et.srcAttr) || (et.srcId >= 10 && !et.srcAttr))
assert((et.dstId < 10 && et.dstAttr) || (et.dstId >= 10 && !et.dstAttr))
}
@@ -533,15 +533,15 @@ abstract class DStream[T: ClassTag] (
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
}

/**
* Return a new DStream in which each RDD is generated by applying a function
* on each RDD of 'this' DStream.
*/
def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
val cleanedF = context.sparkContext.clean(transformFunc)
val cleanedF = context.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 1)
cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -556,7 +556,7 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
): DStream[V] = {
val cleanedF = ssc.sparkContext.clean(transformFunc)
val cleanedF = ssc.sparkContext.clean(transformFunc, false)
transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
}

@@ -567,7 +567,7 @@ abstract class DStream[T: ClassTag] (
def transformWith[U: ClassTag, V: ClassTag](
other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
): DStream[V] = {
val cleanedF = ssc.sparkContext.clean(transformFunc)
val cleanedF = ssc.sparkContext.clean(transformFunc, false)
val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
assert(rdds.length == 2)
val rdd1 = rdds(0).asInstanceOf[RDD[T]]