
Commit a484030

willb authored and rxin committed
SPARK-897: preemptively serialize closures
These commits cause `ClosureCleaner.clean` to attempt to serialize the cleaned closure with the default closure serializer and throw a `SparkException` if doing so fails. This behavior is enabled by default but can be disabled at individual callsites of `SparkContext.clean`. Commit 98e01ae fixes some no-op assertions in `GraphSuite` that this work exposed; I'm happy to put that in a separate PR if that would be more appropriate.

Author: William Benton <[email protected]>

Closes #143 from willb/spark-897 and squashes the following commits:

bceab8a [William Benton] Commented DStream corner cases for serializability checking.
64d04d2 [William Benton] FailureSuite now checks both messages and causes.
3b3f74a [William Benton] Stylistic and doc cleanups.
b215dea [William Benton] Fixed spurious failures in ImplicitOrderingSuite
be1ecd6 [William Benton] Don't check serializability of DStream transforms.
abe816b [William Benton] Make proactive serializability checking optional.
5bfff24 [William Benton] Adds proactive closure-serializability checking
ed2ccf0 [William Benton] Test cases for SPARK-897.
1 parent 66135a3 commit a484030

File tree: 6 files changed, +196 -36 lines

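For context on what the diffs below change in practice: a serializability failure now surfaces when the closure is cleaned, that is, when the transformation is defined, rather than only when a job is submitted. A minimal sketch of the kind of user code affected; the `Helper` class and all names here are illustrative, not part of the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext, SparkException}

// Not Serializable, so any closure that captures an instance cannot be serialized.
class Helper {
  def describe(x: Int): String = s"value: $x"
}

object ProactiveCheckExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("proactive-check"))
    val helper = new Helper
    try {
      // map() passes its closure through SparkContext.clean; with this commit,
      // clean() serializes the closure eagerly, so the failure is reported here
      // as SparkException("Task not serializable") instead of at job-submission time.
      sc.parallelize(1 to 10).map(x => helper.describe(x)).count()
    } catch {
      case e: SparkException => println(s"caught eagerly: ${e.getMessage}")
    } finally {
      sc.stop()
    }
  }
}
```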

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 10 additions & 2 deletions

@@ -1203,9 +1203,17 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Clean a closure to make it ready to serialized and send to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
+   * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
+   * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
+   * if not.
+   *
+   * @param f the closure to clean
+   * @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
+   * @throws <tt>SparkException<tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not
+   *   serializable
    */
-  private[spark] def clean[F <: AnyRef](f: F): F = {
-    ClosureCleaner.clean(f)
+  private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = {
+    ClosureCleaner.clean(f, checkSerializable)
     f
   }
 
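Since `clean` is `private[spark]`, only Spark-internal code sees the new parameter, but its two modes can be sketched from a file that declares itself inside the `org.apache.spark` package (purely for illustration; the object and closure below are invented, not from the patch):

```scala
package org.apache.spark // placed here only so the private[spark] clean method is visible

object CleanModesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "clean-modes")
    val notSerializable = new Object // java.lang.Object is not Serializable
    val closure = (x: Int) => x + notSerializable.hashCode

    // Default after this commit: clean() round-trips the closure through the
    // closure serializer and throws SparkException("Task not serializable").
    try sc.clean(closure)
    catch { case e: SparkException => println(s"eager check tripped: ${e.getMessage}") }

    // Opting out restores the old behaviour: the closure is cleaned but not
    // serialized, so any failure is deferred until tasks are actually shipped.
    val cleanedOnly = sc.clean(closure, checkSerializable = false)
    println(s"cleaned without checking: ${cleanedOnly ne null}")

    sc.stop()
  }
}
```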

core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 14 additions & 2 deletions

@@ -25,7 +25,7 @@ import scala.collection.mutable.Set
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
-import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.{Logging, SparkEnv, SparkException}
 
 private[spark] object ClosureCleaner extends Logging {
   // Get an ASM class reader for a given class from the JAR that loaded it
@@ -101,7 +101,7 @@ private[spark] object ClosureCleaner extends Logging {
     }
   }
 
-  def clean(func: AnyRef) {
+  def clean(func: AnyRef, checkSerializable: Boolean = true) {
     // TODO: cache outerClasses / innerClasses / accessedFields
     val outerClasses = getOuterClasses(func)
     val innerClasses = getInnerClasses(func)
@@ -153,6 +153,18 @@ private[spark] object ClosureCleaner extends Logging {
       field.setAccessible(true)
       field.set(func, outer)
     }
+
+    if (checkSerializable) {
+      ensureSerializable(func)
+    }
+  }
+
+  private def ensureSerializable(func: AnyRef) {
+    try {
+      SparkEnv.get.closureSerializer.newInstance().serialize(func)
+    } catch {
+      case ex: Exception => throw new SparkException("Task not serializable", ex)
+    }
   }
 
   private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
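The new check is just a round trip through the closure serializer. For readers outside Spark, the same fail-fast idea can be reproduced with plain Java serialization; this standalone helper is illustrative and not part of the patch:

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object SerializabilityCheckSketch {
  // Mirrors the intent of ClosureCleaner.ensureSerializable, but uses plain
  // Java serialization instead of SparkEnv.get.closureSerializer.
  def ensureSerializable(obj: AnyRef): Unit = {
    val out = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      out.writeObject(obj)
    } catch {
      case ex: NotSerializableException =>
        throw new RuntimeException("Task not serializable", ex)
    } finally {
      out.close()
    }
  }

  def main(args: Array[String]): Unit = {
    ensureSerializable("strings serialize fine")
    try {
      ensureSerializable(new Object) // java.lang.Object is not Serializable
    } catch {
      case ex: RuntimeException => println(s"${ex.getMessage}: ${ex.getCause}")
    }
  }
}
```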

core/src/test/scala/org/apache/spark/FailureSuite.scala

Lines changed: 10 additions & 4 deletions

@@ -22,6 +22,8 @@ import org.scalatest.FunSuite
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.NonSerializable
 
+import java.io.NotSerializableException
+
 // Common state shared by FailureSuite-launched tasks. We use a global object
 // for this because any local variables used in the task closures will rightfully
 // be copied for each task, so there's no other way for them to share state.
@@ -102,7 +104,8 @@ class FailureSuite extends FunSuite with LocalSparkContext {
       results.collect()
     }
     assert(thrown.getClass === classOf[SparkException])
-    assert(thrown.getMessage.contains("NotSerializableException"))
+    assert(thrown.getMessage.contains("NotSerializableException") ||
+      thrown.getCause.getClass === classOf[NotSerializableException])
 
     FailureSuiteState.clear()
   }
@@ -116,21 +119,24 @@ class FailureSuite extends FunSuite with LocalSparkContext {
       sc.parallelize(1 to 10, 2).map(x => a).count()
     }
     assert(thrown.getClass === classOf[SparkException])
-    assert(thrown.getMessage.contains("NotSerializableException"))
+    assert(thrown.getMessage.contains("NotSerializableException") ||
+      thrown.getCause.getClass === classOf[NotSerializableException])
 
     // Non-serializable closure in an earlier stage
     val thrown1 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).map(x => (x, a)).partitionBy(new HashPartitioner(3)).count()
     }
     assert(thrown1.getClass === classOf[SparkException])
-    assert(thrown1.getMessage.contains("NotSerializableException"))
+    assert(thrown1.getMessage.contains("NotSerializableException") ||
+      thrown1.getCause.getClass === classOf[NotSerializableException])
 
     // Non-serializable closure in foreach function
     val thrown2 = intercept[SparkException] {
       sc.parallelize(1 to 10, 2).foreach(x => println(a))
     }
     assert(thrown2.getClass === classOf[SparkException])
-    assert(thrown2.getMessage.contains("NotSerializableException"))
+    assert(thrown2.getMessage.contains("NotSerializableException") ||
+      thrown2.getCause.getClass === classOf[NotSerializableException])
 
     FailureSuiteState.clear()
   }

core/src/test/scala/org/apache/spark/ImplicitOrderingSuite.scala

Lines changed: 52 additions & 23 deletions

@@ -19,9 +19,29 @@ package org.apache.spark
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
 
 class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
+  // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
+  test("basic inference of Orderings"){
+    sc = new SparkContext("local", "test")
+    val rdd = sc.parallelize(1 to 10)
+
+    // These RDD methods are in the companion object so that the unserializable ScalaTest Engine
+    // won't be reachable from the closure object
+
+    // Infer orderings after basic maps to particular types
+    val basicMapExpectations = ImplicitOrderingSuite.basicMapExpectations(rdd)
+    basicMapExpectations.map({case (met, explain) => assert(met, explain)})
+
+    // Infer orderings for other RDD methods
+    val otherRDDMethodExpectations = ImplicitOrderingSuite.otherRDDMethodExpectations(rdd)
+    otherRDDMethodExpectations.map({case (met, explain) => assert(met, explain)})
+  }
+}
+
+private object ImplicitOrderingSuite {
   class NonOrderedClass {}
 
   class ComparableClass extends Comparable[ComparableClass] {
@@ -31,27 +51,36 @@ class ImplicitOrderingSuite extends FunSuite with LocalSparkContext {
   class OrderedClass extends Ordered[OrderedClass] {
     override def compare(o: OrderedClass): Int = ???
   }
-
-  // Tests that PairRDDFunctions grabs an implicit Ordering in various cases where it should.
-  test("basic inference of Orderings"){
-    sc = new SparkContext("local", "test")
-    val rdd = sc.parallelize(1 to 10)
-
-    // Infer orderings after basic maps to particular types
-    assert(rdd.map(x => (x, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (1, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (x.toString, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (null, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty)
-    assert(rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined)
-    assert(rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined)
-
-    // Infer orderings for other RDD methods
-    assert(rdd.groupBy(x => x).keyOrdering.isDefined)
-    assert(rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty)
-    assert(rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined)
-    assert(rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined)
-    assert(rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined)
-    assert(rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined)
+
+  def basicMapExpectations(rdd: RDD[Int]) = {
+    List((rdd.map(x => (x, x)).keyOrdering.isDefined,
+           "rdd.map(x => (x, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (1, x)).keyOrdering.isDefined,
+           "rdd.map(x => (1, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (x.toString, x)).keyOrdering.isDefined,
+           "rdd.map(x => (x.toString, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (null, x)).keyOrdering.isDefined,
+           "rdd.map(x => (null, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty,
+           "rdd.map(x => (new NonOrderedClass, x)).keyOrdering.isEmpty"),
+         (rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined,
+           "rdd.map(x => (new ComparableClass, x)).keyOrdering.isDefined"),
+         (rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined,
+           "rdd.map(x => (new OrderedClass, x)).keyOrdering.isDefined"))
   }
-}
+
+  def otherRDDMethodExpectations(rdd: RDD[Int]) = {
+    List((rdd.groupBy(x => x).keyOrdering.isDefined,
+           "rdd.groupBy(x => x).keyOrdering.isDefined"),
+         (rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty,
+           "rdd.groupBy(x => new NonOrderedClass).keyOrdering.isEmpty"),
+         (rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined,
+           "rdd.groupBy(x => new ComparableClass).keyOrdering.isDefined"),
+         (rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined,
+           "rdd.groupBy(x => new OrderedClass).keyOrdering.isDefined"),
+         (rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined,
+           "rdd.groupBy((x: Int) => x, 5).keyOrdering.isDefined"),
+         (rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined,
+           "rdd.groupBy((x: Int) => x, new HashPartitioner(5)).keyOrdering.isDefined"))
+  }
+}
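The comment in the new test explains the refactoring: a closure defined in a suite method that touches any member of the suite captures `this`, and with it the non-serializable ScalaTest machinery, whereas a closure built in a standalone object does not. A stripped-down illustration of the two shapes (all names here are invented for the example):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

class CaptureShapeExample {
  // Stand-in for a non-serializable member such as ScalaTest's Engine.
  val engine = new Object

  // This closure reads a field of the enclosing instance, so it captures
  // `this` and drags `engine` along; serializing it would fail.
  def capturesSuite(sc: SparkContext): RDD[Int] =
    sc.parallelize(1 to 10).map(x => x + engine.hashCode)

  // Delegating to the companion object keeps the suite instance out of the
  // closure entirely, so the resulting closure is serializable.
  def capturesNothing(sc: SparkContext): RDD[Int] =
    CaptureShapeExample.shifted(sc)
}

object CaptureShapeExample {
  def shifted(sc: SparkContext): RDD[Int] =
    sc.parallelize(1 to 10).map(x => x + 1)
}
```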
core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala

Lines changed: 90 additions & 0 deletions

@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer;
+
+import java.io.NotSerializableException
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkException
+import org.apache.spark.SharedSparkContext
+
+/* A trivial (but unserializable) container for trivial functions */
+class UnserializableClass {
+  def op[T](x: T) = x.toString
+
+  def pred[T](x: T) = x.toString.length % 2 == 0
+}
+
+class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContext {
+
+  def fixture = (sc.parallelize(0 until 1000).map(_.toString), new UnserializableClass)
+
+  test("throws expected serialization exceptions on actions") {
+    val (data, uc) = fixture
+
+    val ex = intercept[SparkException] {
+      data.map(uc.op(_)).count
+    }
+
+    assert(ex.getMessage.contains("Task not serializable"))
+  }
+
+  // There is probably a cleaner way to eliminate boilerplate here, but we're
+  // iterating over a map from transformation names to functions that perform that
+  // transformation on a given RDD, creating one test case for each
+
+  for (transformation <-
+      Map("map" -> xmap _, "flatMap" -> xflatMap _, "filter" -> xfilter _,
+          "mapWith" -> xmapWith _, "mapPartitions" -> xmapPartitions _,
+          "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _,
+          "mapPartitionsWithContext" -> xmapPartitionsWithContext _,
+          "filterWith" -> xfilterWith _)) {
+    val (name, xf) = transformation
+
+    test(s"$name transformations throw proactive serialization exceptions") {
+      val (data, uc) = fixture
+
+      val ex = intercept[SparkException] {
+        xf(data, uc)
+      }
+
+      assert(ex.getMessage.contains("Task not serializable"),
+        s"RDD.$name doesn't proactively throw NotSerializableException")
+    }
+  }
+
+  private def xmap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.map(y=>uc.op(y))
+  private def xmapWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapWith(x => x.toString)((x,y)=>x + uc.op(y))
+  private def xflatMap(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.flatMap(y=>Seq(uc.op(y)))
+  private def xfilter(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filter(y=>uc.pred(y))
+  private def xfilterWith(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.filterWith(x => x.toString)((x,y)=>uc.pred(y))
+  private def xmapPartitions(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitions(_.map(y=>uc.op(y)))
+  private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y)))
+  private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] =
+    x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y)))
+
+}
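The suite above pins down which transformations now fail eagerly. On the user side, the usual remedy is to keep the unserializable object out of the closure, for example by copying the piece of state the closure actually needs into a local val first (a generic sketch with invented names, not part of the patch):

```scala
import org.apache.spark.rdd.RDD

// Illustrative stand-in for the suite's UnserializableClass.
class UnserializableHelper(val prefix: String) {
  def tag(x: String): String = prefix + x
}

object ClosureWorkaroundSketch {
  // Trips the proactive check: the closure captures `helper`, which cannot be
  // serialized, so map() now throws SparkException("Task not serializable").
  def broken(data: RDD[String], helper: UnserializableHelper): RDD[String] =
    data.map(x => helper.tag(x))

  // Passes the check: only the String `prefix` is captured, not the helper.
  def fixed(data: RDD[String], helper: UnserializableHelper): RDD[String] = {
    val prefix = helper.prefix
    data.map(x => prefix + x)
  }
}
```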

streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala

Lines changed: 20 additions & 5 deletions

@@ -532,23 +532,32 @@ abstract class DStream[T: ClassTag] (
    * 'this' DStream will be registered as an output stream and therefore materialized.
    */
   def foreachRDD(foreachFunc: (RDD[T], Time) => Unit) {
-    new ForEachDStream(this, context.sparkContext.clean(foreachFunc)).register()
+    // because the DStream is reachable from the outer object here, and because
+    // DStreams can't be serialized with closures, we can't proactively check
+    // it for serializability and so we pass the optional false to SparkContext.clean
+    new ForEachDStream(this, context.sparkContext.clean(foreachFunc, false)).register()
   }
 
   /**
    * Return a new DStream in which each RDD is generated by applying a function
    * on each RDD of 'this' DStream.
    */
   def transform[U: ClassTag](transformFunc: RDD[T] => RDD[U]): DStream[U] = {
-    transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r)))
+    // because the DStream is reachable from the outer object here, and because
+    // DStreams can't be serialized with closures, we can't proactively check
+    // it for serializability and so we pass the optional false to SparkContext.clean
+    transform((r: RDD[T], t: Time) => context.sparkContext.clean(transformFunc(r), false))
   }
 
   /**
    * Return a new DStream in which each RDD is generated by applying a function
    * on each RDD of 'this' DStream.
    */
   def transform[U: ClassTag](transformFunc: (RDD[T], Time) => RDD[U]): DStream[U] = {
-    val cleanedF = context.sparkContext.clean(transformFunc)
+    // because the DStream is reachable from the outer object here, and because
+    // DStreams can't be serialized with closures, we can't proactively check
+    // it for serializability and so we pass the optional false to SparkContext.clean
+    val cleanedF = context.sparkContext.clean(transformFunc, false)
     val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
       assert(rdds.length == 1)
       cleanedF(rdds.head.asInstanceOf[RDD[T]], time)
@@ -563,7 +572,10 @@ abstract class DStream[T: ClassTag] (
   def transformWith[U: ClassTag, V: ClassTag](
       other: DStream[U], transformFunc: (RDD[T], RDD[U]) => RDD[V]
     ): DStream[V] = {
-    val cleanedF = ssc.sparkContext.clean(transformFunc)
+    // because the DStream is reachable from the outer object here, and because
+    // DStreams can't be serialized with closures, we can't proactively check
+    // it for serializability and so we pass the optional false to SparkContext.clean
+    val cleanedF = ssc.sparkContext.clean(transformFunc, false)
     transformWith(other, (rdd1: RDD[T], rdd2: RDD[U], time: Time) => cleanedF(rdd1, rdd2))
   }
 
@@ -574,7 +586,10 @@ abstract class DStream[T: ClassTag] (
   def transformWith[U: ClassTag, V: ClassTag](
      other: DStream[U], transformFunc: (RDD[T], RDD[U], Time) => RDD[V]
    ): DStream[V] = {
-    val cleanedF = ssc.sparkContext.clean(transformFunc)
+    // because the DStream is reachable from the outer object here, and because
+    // DStreams can't be serialized with closures, we can't proactively check
+    // it for serializability and so we pass the optional false to SparkContext.clean
+    val cleanedF = ssc.sparkContext.clean(transformFunc, false)
     val realTransformFunc = (rdds: Seq[RDD[_]], time: Time) => {
       assert(rdds.length == 2)
       val rdd1 = rdds(0).asInstanceOf[RDD[T]]
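The repeated comment above is the corner case noted in the commit message: the wrapper closures that DStream builds around user functions reference the DStream itself (for example via `context`), so serializing them up front would reject valid programs. A small streaming sketch of a call that goes through one of these code paths; host, port, and batch interval are arbitrary:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object DStreamTransformSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("dstream-transform")
    val ssc = new StreamingContext(conf, Seconds(1))
    val lines = ssc.socketTextStream("localhost", 9999)

    // DStream.transform wraps this function in a closure that also references
    // the DStream, which is why the diff passes checkSerializable = false for
    // these call sites; defining the transformation therefore does not trip
    // the eager check, even though the user function itself is harmless.
    val lengths = lines.transform(rdd => rdd.map(_.length))
    lengths.print()

    // ssc.start() and ssc.awaitTermination() are omitted in this sketch.
    ssc.stop()
  }
}
```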
