Commit cc36487
[SPARK-3046] use executor's class loader as the default serializer classloader

The serializer is not always used in an executor thread (e.g. connection manager, broadcast), in which case the classloader might not have the user jar set, leading to corruption in deserialization.

https://issues.apache.org/jira/browse/SPARK-3046
https://issues.apache.org/jira/browse/SPARK-2878

Author: Reynold Xin <[email protected]>

Closes #1972 from rxin/kryoBug and squashes the following commits:

c1c7bf0 [Reynold Xin] Made change to JavaSerializer.
7204c33 [Reynold Xin] Added imports back.
d879e67 [Reynold Xin] [SPARK-3046] use executor's class loader as the default serializer class loader.

1 parent c703229 commit cc36487
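The core of the bug: class resolution that leans on the calling thread's context classloader only works on threads that had the user jar installed. A minimal sketch of that failure mode (illustrative names, not code from this commit; com.example.UserClass stands in for a class shipped only in the user jar):

    // Illustrative sketch of the failure mode, not Spark code.
    val userClassName = "com.example.UserClass" // hypothetical user-jar class

    def loadUserClass(): Class[_] =
      // Succeeds on an executor task thread, whose context classloader
      // includes the user jar; throws ClassNotFoundException on e.g. a
      // connection-manager or broadcast thread, which never had the jar added.
      Class.forName(userClassName, true, Thread.currentThread.getContextClassLoader)

Pinning one default classloader on the serializer, as this commit does, removes the dependence on which thread happens to deserialize.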

6 files changed, +128 -4 lines changed

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 3 additions & 0 deletions

@@ -99,6 +99,9 @@ private[spark] class Executor(
   private val urlClassLoader = createClassLoader()
   private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
 
+  // Set the classloader for serializer
+  env.serializer.setDefaultClassLoader(urlClassLoader)
+
   // Akka's message frame size. If task result is bigger than this, we use the block manager
   // to send the result back.
   private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)

core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala

Lines changed: 7 additions & 2 deletions

@@ -63,7 +63,9 @@ extends DeserializationStream {
   def close() { objIn.close() }
 }
 
-private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
+private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
+  extends SerializerInstance {
+
   def serialize[T: ClassTag](t: T): ByteBuffer = {
     val bos = new ByteArrayOutputStream()
     val out = serializeStream(bos)
@@ -109,7 +111,10 @@
 class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
   private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
 
-  def newInstance(): SerializerInstance = new JavaSerializerInstance(counterReset)
+  override def newInstance(): SerializerInstance = {
+    val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
+    new JavaSerializerInstance(counterReset, classLoader)
+  }
 
   override def writeExternal(out: ObjectOutput) {
     out.writeInt(counterReset)
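The hunk shows the loader being chosen but not consumed; in Java serialization the standard way to honor an explicit classloader is to override ObjectInputStream.resolveClass. A generic sketch of that idiom (LoaderAwareObjectInputStream is an illustrative name, not the commit's exact code):

    import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

    // Resolves deserialized classes against an explicit loader instead of
    // whatever the JVM would pick by default.
    class LoaderAwareObjectInputStream(in: InputStream, loader: ClassLoader)
      extends ObjectInputStream(in) {
      override def resolveClass(desc: ObjectStreamClass): Class[_] =
        Class.forName(desc.getName, false, loader)
    }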

core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala

Lines changed: 8 additions & 1 deletion

@@ -61,7 +61,9 @@ class KryoSerializer(conf: SparkConf)
     val instantiator = new EmptyScalaKryoInstantiator
     val kryo = instantiator.newKryo()
     kryo.setRegistrationRequired(registrationRequired)
-    val classLoader = Thread.currentThread.getContextClassLoader
+
+    val oldClassLoader = Thread.currentThread.getContextClassLoader
+    val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
 
     // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops.
     // Do this before we invoke the user registrator so the user registrator can override this.
@@ -84,10 +86,15 @@ class KryoSerializer(conf: SparkConf)
       try {
         val reg = Class.forName(regCls, true, classLoader).newInstance()
           .asInstanceOf[KryoRegistrator]
+
+        // Use the default classloader when calling the user registrator.
+        Thread.currentThread.setContextClassLoader(classLoader)
        reg.registerClasses(kryo)
       } catch {
         case e: Exception =>
           throw new SparkException(s"Failed to invoke $regCls", e)
+      } finally {
+        Thread.currentThread.setContextClassLoader(oldClassLoader)
       }
     }
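This hunk wraps the user registrator in a save/set/restore of the thread's context classloader, with the restore in a finally block so the original loader comes back even when registration throws. The same idiom, factored into a standalone helper (illustrative, not part of this commit):

    // Runs `body` with `loader` as the context classloader, always restoring
    // the previous loader afterwards, even if `body` throws.
    def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
      val old = Thread.currentThread.getContextClassLoader
      Thread.currentThread.setContextClassLoader(loader)
      try body finally Thread.currentThread.setContextClassLoader(old)
    }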

core/src/main/scala/org/apache/spark/serializer/Serializer.scala

Lines changed: 17 additions & 0 deletions

@@ -44,6 +44,23 @@ import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
  */
 @DeveloperApi
 trait Serializer {
+
+  /**
+   * Default ClassLoader to use in deserialization. Implementations of [[Serializer]] should
+   * make sure it is using this when set.
+   */
+  @volatile protected var defaultClassLoader: Option[ClassLoader] = None
+
+  /**
+   * Sets a class loader for the serializer to use in deserialization.
+   *
+   * @return this Serializer object
+   */
+  def setDefaultClassLoader(classLoader: ClassLoader): Serializer = {
+    defaultClassLoader = Some(classLoader)
+    this
+  }
+
   def newInstance(): SerializerInstance
 }
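Because setDefaultClassLoader returns the Serializer, configuration chains naturally at construction time. A hedged usage sketch (userLoader is an illustrative stand-in for a loader that contains the user jars):

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.{JavaSerializer, Serializer}

    val userLoader: ClassLoader = Thread.currentThread.getContextClassLoader // stand-in
    val ser: Serializer = new JavaSerializer(new SparkConf).setDefaultClassLoader(userLoader)

    // Instances created afterwards resolve classes against userLoader,
    // regardless of which thread calls newInstance().
    val instance = ser.newInstance()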

core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.spark.util.Utils
+
+import com.esotericsoftware.kryo.Kryo
+import org.scalatest.FunSuite
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, TestUtils}
+import org.apache.spark.SparkContext._
+import org.apache.spark.serializer.KryoDistributedTest._
+
+class KryoSerializerDistributedSuite extends FunSuite {
+
+  test("kryo objects are serialised consistently in different processes") {
+    val conf = new SparkConf(false)
+    conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+    conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName)
+    conf.set("spark.task.maxFailures", "1")
+
+    val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName))
+    conf.setJars(List(jar.getPath))
+
+    val sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+    val original = Thread.currentThread.getContextClassLoader
+    val loader = new java.net.URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader)
+    SparkEnv.get.serializer.setDefaultClassLoader(loader)
+
+    val cachedRDD = sc.parallelize((0 until 10).map((_, new MyCustomClass)), 3).cache()
+
+    // Randomly mix the keys so that the join below will require a shuffle with each partition
+    // sending data to multiple other partitions.
+    val shuffledRDD = cachedRDD.map { case (i, o) => (i * i * i - 10 * i * i, o)}
+
+    // Join the two RDDs, and force evaluation
+    assert(shuffledRDD.join(cachedRDD).collect().size == 1)
+
+    LocalSparkContext.stop(sc)
+  }
+}
+
+object KryoDistributedTest {
+  class MyCustomClass
+
+  class AppJarRegistrator extends KryoRegistrator {
+    override def registerClasses(k: Kryo) {
+      val classLoader = Thread.currentThread.getContextClassLoader
+      k.register(Class.forName(AppJarRegistrator.customClassName, true, classLoader))
+    }
+  }
+
+  object AppJarRegistrator {
+    val customClassName = "KryoSerializerDistributedSuiteCustomClass"
+  }
+}

core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala

Lines changed: 22 additions & 1 deletion

@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
 import com.esotericsoftware.kryo.Kryo
 import org.scalatest.FunSuite
 
-import org.apache.spark.SharedSparkContext
+import org.apache.spark.{SparkConf, SharedSparkContext}
 import org.apache.spark.serializer.KryoTest._
 
 class KryoSerializerSuite extends FunSuite with SharedSparkContext {
@@ -217,8 +217,29 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
     val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance())
     assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist"))
   }
+
+  test("default class loader can be set by a different thread") {
+    val ser = new KryoSerializer(new SparkConf)
+
+    // First serialize the object
+    val serInstance = ser.newInstance()
+    val bytes = serInstance.serialize(new ClassLoaderTestingObject)
+
+    // Deserialize the object to make sure normal deserialization works
+    serInstance.deserialize[ClassLoaderTestingObject](bytes)
+
+    // Set a special, broken ClassLoader and make sure we get an exception on deserialization
+    ser.setDefaultClassLoader(new ClassLoader() {
+      override def loadClass(name: String) = throw new UnsupportedOperationException
+    })
+    intercept[UnsupportedOperationException] {
+      ser.newInstance().deserialize[ClassLoaderTestingObject](bytes)
+    }
+  }
 }
 
+class ClassLoaderTestingObject
+
 class KryoSerializerResizableOutputSuite extends FunSuite {
   import org.apache.spark.SparkConf
   import org.apache.spark.SparkContext