Skip to content

Commit b349b77

Browse files
committed
[SPARK-5307] SerializationDebugger to help debug NotSerializableException - take 2
This patch adds a SerializationDebugger that appends the serialization path to a NotSerializableException. When a NotSerializableException is encountered, the debugger walks the object graph to find the path to the object that cannot be serialized, and builds a message that helps the user locate that object. Compared with an earlier attempt, this one provides extra information, including field names, array offsets, writeExternal calls, etc.
1 parent 1955645 commit b349b77

File tree

3 files changed

+437
-1
lines changed

3 files changed

+437
-1
lines changed

core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,12 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
3939
* the stream 'resets' object class descriptions have to be re-written)
4040
*/
4141
def writeObject[T: ClassTag](t: T): SerializationStream = {
42-
objOut.writeObject(t)
42+
try {
43+
objOut.writeObject(t)
44+
} catch {
45+
case e: NotSerializableException =>
46+
throw SerializationDebugger.improveException(t, e)
47+
}
4348
counter += 1
4449
if (counterReset > 0 && counter >= counterReset) {
4550
objOut.reset()
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.serializer
19+
20+
import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField}
21+
import java.lang.reflect.{Field, Method}
22+
import java.security.AccessController
23+
24+
import scala.annotation.tailrec
25+
import scala.collection.mutable
26+
27+
28+
private[serializer] object SerializationDebugger {
29+
30+
/**
 * Improve the given NotSerializableException with the serialization path leading from the given
 * object to the problematic object.
 *
 * The debugger is best-effort: if debugging is disabled, or if walking the object graph fails
 * with a non-fatal error, the original exception is returned unchanged so that the real
 * serialization failure is never masked by a failure of the debugger itself.
 *
 * @param obj the root object whose serialization failed
 * @param e the exception raised while serializing `obj`
 * @return a new exception whose message is the message of `e` followed by the serialization
 *         stack, or `e` itself when no improved exception could be built
 */
def improveException(obj: Any, e: NotSerializableException): NotSerializableException = {
  if (enableDebugging) {
    try {
      new NotSerializableException(
        e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t-" + _).mkString("\n"))
    } catch {
      // Traversal invokes reflection and user code (toString, writeExternal, writeReplace),
      // any of which may throw. Fall back to the original exception in that case.
      case scala.util.control.NonFatal(_) => e
    }
  } else {
    e
  }
}
42+
43+
/**
 * Locate a not serializable object reachable from `obj` and describe how it is reached. The
 * traversal is modeled after OpenJDK's serialization mechanism and covers:
 *   - primitives
 *   - arrays of primitives
 *   - arrays of non-primitive objects
 *   - Serializable objects
 *   - Externalizable objects
 *   - writeReplace
 *
 * Classes that override writeObject are not handled yet.
 *
 * @param obj the root of the object graph to inspect
 * @return the path as a list of descriptions, with the offending object first and the root
 *         last; an empty list if no not serializable object was found
 */
def find(obj: Any): List[String] = {
  // A fresh debugger per call so that the visited-object bookkeeping is never shared.
  val debugger = new SerializationDebugger()
  debugger.visit(obj, Nil)
}
58+
59+
// Whether improveException should compute a serialization path. Initialized to the negation
// of the JVM flag "sun.io.serialization.extendedDebugInfo"; mutable so that code in this
// package can toggle it.
private[serializer] var enableDebugging: Boolean = {
  val readJvmFlag =
    new sun.security.action.GetBooleanAction("sun.io.serialization.extendedDebugInfo")
  val jvmExtendedDebugInfo = AccessController.doPrivileged(readJvmFlag).booleanValue()
  !jvmExtendedDebugInfo
}
63+
64+
private class SerializationDebugger {

  /** A set to track the list of objects we have visited, to avoid cycles in the graph. */
  private val visited = new mutable.HashSet[Any]

  /**
   * Visit the object and its fields and stop when we find an object that is not serializable.
   * Return the path as a list. If everything can be serialized, return an empty list.
   *
   * @param o the object to inspect; null and already visited objects yield an empty list
   * @param stack descriptions of the path taken so far, most recent element first
   * @return the description of the not serializable object prepended to the path that led to
   *         it, or an empty list if none was found below `o`
   */
  def visit(o: Any, stack: List[String]): List[String] = {
    if (o == null) {
      List.empty
    } else if (visited.contains(o)) {
      List.empty
    } else {
      visited += o
      o match {
        // Primitive value, string, and primitive arrays are always serializable
        case _ if o.getClass.isPrimitive => List.empty
        case _: String => List.empty
        case _ if o.getClass.isArray && o.getClass.getComponentType.isPrimitive => List.empty

        // Traverse non primitive array.
        case a: Array[_] if o.getClass.isArray && !o.getClass.getComponentType.isPrimitive =>
          val elem = s"array (class ${a.getClass.getName}, size ${a.length})"
          visitArray(a, elem :: stack)

        case e: java.io.Externalizable =>
          val elem = s"externalizable object (class ${e.getClass.getName}, $e)"
          visitExternalizable(e, elem :: stack)

        case s: Object with java.io.Serializable =>
          val elem = s"object (class ${s.getClass.getName}, $s)"
          visitSerializable(s, elem :: stack)

        case _ =>
          // Found an object that is not serializable!
          s"object not serializable (class: ${o.getClass.getName}, value: $o)" :: stack
      }
    }
  }

  /**
   * Visit every element of a non primitive array in index order and return the path of the
   * first element under which a not serializable object is found, or an empty list.
   */
  private def visitArray(o: Array[_], stack: List[String]): List[String] = {
    var i = 0
    while (i < o.length) {
      val childStack = visit(o(i), s"element of array (index: $i)" :: stack)
      if (childStack.nonEmpty) {
        return childStack
      }
      i += 1
    }
    return List.empty
  }

  /**
   * Visit the objects an Externalizable writes. The object's own writeExternal is invoked
   * against a ListObjectOutput, and the objects collected in its outputArray are visited in
   * order.
   */
  private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] =
  {
    val fieldList = new ListObjectOutput
    o.writeExternal(fieldList)
    val childObjects = fieldList.outputArray
    var i = 0
    while (i < childObjects.length) {
      val childStack = visit(childObjects(i), "writeExternal data" :: stack)
      if (childStack.nonEmpty) {
        return childStack
      }
      i += 1
    }
    return List.empty
  }

  /**
   * Visit the object-typed fields of a Serializable object, slot by slot, and return the path
   * of the first field under which a not serializable object is found, or an empty list.
   * Slots whose class declares a writeObject method are currently skipped.
   */
  private def visitSerializable(o: Object, stack: List[String]): List[String] = {
    // An object contains multiple slots in serialization.
    // Get the slots and visit fields in all of them.
    val (finalObj, desc) = findObjectAndDescriptor(o)
    val slotDescs = desc.getSlotDescs
    var i = 0
    while (i < slotDescs.length) {
      val slotDesc = slotDescs(i)
      if (slotDesc.hasWriteObjectMethod) {
        // TODO: Handle classes that specify writeObject method.
      } else {
        val fields: Array[ObjectStreamField] = slotDesc.getFields
        val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
        // The field values are indexed after the primitive fields: fields(numPrims + j)
        // describes objFieldValues(j).
        val numPrims = fields.length - objFieldValues.length
        // Read the values through the slot's own descriptor so that they line up with the
        // array size and the field descriptors above, which also come from slotDesc. Using
        // the descriptor of the object's class here mismatches them for superclass slots.
        slotDesc.getObjFieldValues(finalObj, objFieldValues)

        var j = 0
        while (j < objFieldValues.length) {
          val fieldDesc = fields(numPrims + j)
          val elem = s"field (class: ${slotDesc.getName}" +
            s", name: ${fieldDesc.getName}" +
            s", type: ${fieldDesc.getType})"
          val childStack = visit(objFieldValues(j), elem :: stack)
          if (childStack.nonEmpty) {
            return childStack
          }
          j += 1
        }

      }
      i += 1
    }
    return List.empty
  }
}
169+
170+
/**
 * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles
 * writeReplace in Serializable. It starts with the object itself, and keeps calling the
 * writeReplace method until the class has no writeReplace method or the replacement has the
 * same class as the object it replaced.
 *
 * @param o the object to resolve; must not be null
 * @return the object that would finally be written, paired with its class descriptor
 */
@tailrec
private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = {
  val cl = o.getClass
  val desc = ObjectStreamClass.lookupAny(cl)
  if (!desc.hasWriteReplaceMethod) {
    (o, desc)
  } else {
    // writeReplace
    val replaced = desc.invokeWriteReplace(o)
    // A replacement of the same class (for example writeReplace returning `this`) ends the
    // chain; recursing on it would call writeReplace again and never terminate.
    // NOTE(review): a null replacement is not handled and fails on getClass below.
    if (replaced.getClass == cl) {
      (replaced, desc)
    } else {
      findObjectAndDescriptor(replaced)
    }
  }
}
186+
187+
/**
 * An [[ObjectOutput]] stand-in used to observe a writeExternal call: every object passed to
 * `writeObject` is recorded and exposed through `outputArray`, while all other write
 * operations, `flush` and `close` do nothing.
 */
private class ListObjectOutput extends ObjectOutput {
  // Objects handed to writeObject, in call order.
  private val written = new mutable.ArrayBuffer[Any]

  /** A snapshot of the objects recorded so far. */
  def outputArray: Array[Any] = written.toArray

  // The only call that is recorded.
  override def writeObject(o: Any): Unit = { written += o }

  // Raw byte output is ignored.
  override def write(i: Int): Unit = {}
  override def write(bytes: Array[Byte]): Unit = {}
  override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = {}

  // Primitive and string output is ignored.
  override def writeBoolean(b: Boolean): Unit = {}
  override def writeByte(i: Int): Unit = {}
  override def writeShort(i: Int): Unit = {}
  override def writeChar(i: Int): Unit = {}
  override def writeInt(i: Int): Unit = {}
  override def writeLong(l: Long): Unit = {}
  override def writeFloat(v: Float): Unit = {}
  override def writeDouble(v: Double): Unit = {}
  override def writeBytes(s: String): Unit = {}
  override def writeChars(s: String): Unit = {}
  override def writeUTF(s: String): Unit = {}

  // Stream lifecycle calls are no-ops.
  override def flush(): Unit = {}
  override def close(): Unit = {}
}
212+
213+
/**
 * Extension methods that expose non-public members of [[ObjectStreamClass]] by invoking
 * them reflectively on the wrapped descriptor.
 */
implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {

  /** The descriptors of the class data slots, obtained from `getClassDataLayout`. */
  def getSlotDescs: Array[ObjectStreamClass] = {
    val layout = objectStreamClassGetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]]
    layout.map { slot =>
      classDataSlotDesc.get(slot).asInstanceOf[ObjectStreamClass]
    }
  }

  /** Result of the descriptor's `hasWriteObjectMethod`. */
  def hasWriteObjectMethod: Boolean = {
    val answer = objectStreamClassHasWriteObjectMethod.invoke(desc)
    answer.asInstanceOf[Boolean]
  }

  /** Result of the descriptor's `hasWriteReplaceMethod`. */
  def hasWriteReplaceMethod: Boolean = {
    val answer = objectStreamClassHasWriteReplaceMethod.invoke(desc)
    answer.asInstanceOf[Boolean]
  }

  /** Result of the descriptor's `invokeWriteReplace` applied to `obj`. */
  def invokeWriteReplace(obj: Object): Object = {
    val replacement = objectStreamClassInvokeWriteReplace.invoke(desc, obj)
    replacement
  }

  /** Result of the descriptor's `getNumObjFields`. */
  def getNumObjFields: Int = {
    val count = objectStreamClassGetNumObjFields.invoke(desc)
    count.asInstanceOf[Int]
  }

  /** Calls the descriptor's `getObjFieldValues`, passing `obj` and the `out` array. */
  def getObjFieldValues(obj: Object, out: Array[Object]): Unit = {
    objectStreamClassGetObjFieldValues.invoke(desc, obj, out)
  }
}
241+
242+
/** Reflective handle to `ObjectStreamClass.getClassDataLayout`, made accessible. */
private val objectStreamClassGetClassDataLayout: Method = {
  val method = classOf[ObjectStreamClass].getDeclaredMethod("getClassDataLayout")
  method.setAccessible(true)
  method
}
248+
249+
/** Reflective handle to `ObjectStreamClass.hasWriteObjectMethod`, made accessible. */
private val objectStreamClassHasWriteObjectMethod: Method = {
  val method = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteObjectMethod")
  method.setAccessible(true)
  method
}
255+
256+
/** Reflective handle to `ObjectStreamClass.hasWriteReplaceMethod`, made accessible. */
private val objectStreamClassHasWriteReplaceMethod: Method = {
  val method = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteReplaceMethod")
  method.setAccessible(true)
  method
}
262+
263+
/** Reflective handle to `ObjectStreamClass.invokeWriteReplace(Object)`, made accessible. */
private val objectStreamClassInvokeWriteReplace: Method = {
  val method =
    classOf[ObjectStreamClass].getDeclaredMethod("invokeWriteReplace", classOf[Object])
  method.setAccessible(true)
  method
}
269+
270+
/** Reflective handle to `ObjectStreamClass.getNumObjFields`, made accessible. */
private val objectStreamClassGetNumObjFields: Method = {
  val method = classOf[ObjectStreamClass].getDeclaredMethod("getNumObjFields")
  method.setAccessible(true)
  method
}
276+
277+
/**
 * Reflective handle to `ObjectStreamClass.getObjFieldValues(Object, Object[])`, made
 * accessible.
 */
private val objectStreamClassGetObjFieldValues: Method = {
  val parameterTypes = Seq[Class[_]](classOf[Object], classOf[Array[Object]])
  val method =
    classOf[ObjectStreamClass].getDeclaredMethod("getObjFieldValues", parameterTypes: _*)
  method.setAccessible(true)
  method
}
284+
285+
/** Reflective handle to the `desc` field of `ObjectStreamClass$ClassDataSlot`, made accessible. */
private val classDataSlotDesc: Field = {
  val slotClass = Class.forName("java.io.ObjectStreamClass$ClassDataSlot")
  val field = slotClass.getDeclaredField("desc")
  field.setAccessible(true)
  field
}
291+
}

0 commit comments

Comments
 (0)