Skip to content

Commit 86f7823

Browse files
author
Andrew Or
committed
Implement transitive cleaning + add missing documentation
See in-code comments for more detail on what this means.
1 parent 4c722d7 commit 86f7823

File tree

1 file changed

+208
-41
lines changed

1 file changed

+208
-41
lines changed

core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala

Lines changed: 208 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,20 @@ package org.apache.spark.util
1919

2020
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
2121

22-
import scala.collection.mutable.Map
23-
import scala.collection.mutable.Set
22+
import scala.collection.mutable.{Map, Set}
2423

2524
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
2625
import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
2726

2827
import org.apache.spark.{Logging, SparkEnv, SparkException}
2928

29+
/**
30+
* A cleaner that renders closures serializable if they can be done so safely.
31+
*/
3032
private[spark] object ClosureCleaner extends Logging {
33+
3134
// Get an ASM class reader for a given class from the JAR that loaded it
32-
private def getClassReader(cls: Class[_]): ClassReader = {
35+
def getClassReader(cls: Class[_]): ClassReader = {
3336
// Copy data over, before delegating to ClassReader - else we can run out of open file handles.
3437
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
3538
val resourceStream = cls.getResourceAsStream(className)
@@ -77,6 +80,9 @@ private[spark] object ClosureCleaner extends Logging {
7780
Nil
7881
}
7982

83+
/**
84+
* Return a list of classes that represent closures enclosed in the given closure object.
85+
*/
8086
private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
8187
val seen = Set[Class[_]](obj.getClass)
8288
var stack = List[Class[_]](obj.getClass)
@@ -101,21 +107,110 @@ private[spark] object ClosureCleaner extends Logging {
101107
}
102108
}
103109

104-
def clean(func: AnyRef, checkSerializable: Boolean = true) {
110+
/**
111+
* Clean the given closure in place.
112+
*
113+
* More specifically, this renders the given closure serializable as long as it does not
114+
* explicitly reference unserializable objects.
115+
*
116+
* @param closure the closure to clean
117+
* @param checkSerializable whether to verify that the closure is serializable after cleaning
118+
* @param cleanTransitively whether to clean enclosing closures transitively
119+
*/
120+
def clean(
121+
closure: AnyRef,
122+
checkSerializable: Boolean = true,
123+
cleanTransitively: Boolean = true): Unit = {
124+
clean(closure, checkSerializable, cleanTransitively, Map.empty)
125+
}
126+
127+
/**
128+
* Helper method to clean the given closure in place.
129+
*
130+
* The mechanism is to traverse the hierarchy of enclosing closures and null out any
131+
* references along the way that are not actually used by the starting closure, but are
132+
* nevertheless included in the compiled anonymous classes. Note that it is unsafe to
133+
* simply mutate the enclosing closures, as other code paths may depend on them. Instead,
134+
* we clone each enclosing closure and set the parent pointers accordingly.
135+
*
136+
* By default, closures are cleaned transitively. This means we detect whether enclosing
137+
* objects are actually referenced by the starting one, either directly or transitively,
138+
* and, if not, sever these closures from the hierarchy. In other words, in addition to
139+
* nulling out unused field references, we also null out any parent pointers that refer
140+
* to enclosing objects not actually needed by the starting closure.
141+
*
142+
* For instance, transitive cleaning is necessary in the following scenario:
143+
*
144+
* class SomethingNotSerializable {
145+
* def someValue = 1
146+
* def someMethod(): Unit = scope("one") {
147+
* def x = someValue
148+
* def y = 2
149+
* scope("two") { println(y + 1) }
150+
* }
151+
* def scope(name: String)(body: => Unit) = body
152+
* }
153+
*
154+
* In this example, scope "two" is not serializable because it references scope "one", which
155+
* references SomethingNotSerializable. Note that, however, scope "two" does not actually
156+
* depend on SomethingNotSerializable. This means we can null out the parent pointer of
157+
* a cloned scope "one" and set it the parent of scope "two", such that scope "two" no longer
158+
* references SomethingNotSerializable transitively.
159+
*
160+
* @param func the starting closure to clean
161+
* @param checkSerializable whether to verify that the closure is serializable after cleaning
162+
* @param cleanTransitively whether to clean enclosing closures transitively
163+
* @param accessedFields a map from a class to a set of its fields that are accessed by
164+
* the starting closure
165+
*/
166+
private def clean(
167+
func: AnyRef,
168+
checkSerializable: Boolean,
169+
cleanTransitively: Boolean,
170+
accessedFields: Map[Class[_], Set[String]]) {
171+
172+
// TODO: clean all inner closures first. This requires us to find the inner objects.
105173
// TODO: cache outerClasses / innerClasses / accessedFields
106-
val outerClasses = getOuterClasses(func)
174+
175+
logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}}) +++")
176+
177+
// A list of classes that represents closures enclosed in the given one
107178
val innerClasses = getInnerClasses(func)
179+
180+
// A list of enclosing objects and their respective classes, from innermost to outermost
181+
// An outer object at a given index is of type outer class at the same index
182+
val outerClasses = getOuterClasses(func)
108183
val outerObjects = getOuterObjects(func)
109184

110-
val accessedFields = Map[Class[_], Set[String]]()
111-
185+
logDebug(s" + inner classes: " + innerClasses.size)
186+
innerClasses.foreach { c => logDebug(" " + c.getName) }
187+
logDebug(s" + outer classes: " + outerClasses.size)
188+
outerClasses.foreach { c => logDebug(" " + c.getName) }
189+
logDebug(s" + outer objects: " + outerObjects.size)
190+
outerObjects.foreach { o => logDebug(" " + o) }
191+
192+
// Fail fast if we detect return statements in closures
112193
getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
113-
114-
for (cls <- outerClasses)
115-
accessedFields(cls) = Set[String]()
116-
for (cls <- func.getClass :: innerClasses)
117-
getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
118-
// logInfo("accessedFields: " + accessedFields)
194+
195+
// If accessed fields is not populated yet, we assume that
196+
// the closure we are trying to clean is the starting one
197+
if (accessedFields.isEmpty) {
198+
logDebug(s" + populating accessed fields because this is the starting closure")
199+
// Initialize accessed fields with the outer classes first
200+
// This step is needed to associate the fields to the correct classes later
201+
for (cls <- outerClasses) {
202+
accessedFields(cls) = Set[String]()
203+
}
204+
// Populate accessed fields by visiting all fields and methods accessed by this and
205+
// all of its inner closures. If transitive cleaning is enabled, this may recursively
206+
// visits methods that belong to other classes in search of transitively referenced fields.
207+
for (cls <- func.getClass :: innerClasses) {
208+
getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
209+
}
210+
}
211+
212+
logDebug(s" + fields accessed by starting closure: " + accessedFields.size)
213+
accessedFields.foreach { f => logDebug(" " + f) }
119214

120215
val inInterpreter = {
121216
try {
@@ -126,34 +221,66 @@ private[spark] object ClosureCleaner extends Logging {
126221
}
127222
}
128223

224+
// List of outer (class, object) pairs, ordered from outermost to innermost
225+
// Note that all outer objects but the outermost one (first one in this list) must be closures
129226
var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse
130-
var outer: AnyRef = null
227+
var parent: AnyRef = null
131228
if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) {
132229
// The closure is ultimately nested inside a class; keep the object of that
133230
// class without cloning it since we don't want to clone the user's objects.
134-
outer = outerPairs.head._2
231+
// Note that we still need to keep around the outermost object itself because
232+
// we need it to clone its child closure later (see below).
233+
logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}")
234+
parent = outerPairs.head._2 // e.g. SparkContext
135235
outerPairs = outerPairs.tail
236+
} else if (outerPairs.size > 0) {
237+
logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}")
238+
} else {
239+
logDebug(" + there are no enclosing objects!")
136240
}
241+
137242
// Clone the closure objects themselves, nulling out any fields that are not
138243
// used in the closure we're working on or any of its inner closures.
139244
for ((cls, obj) <- outerPairs) {
140-
outer = instantiateClass(cls, outer, inInterpreter)
245+
logDebug(s" + cloning the object $obj of class ${cls.getName}")
246+
// We null out these unused references by cloning each object and then filling in all
247+
// required fields from the original object. We need the parent here because the Java
248+
// language specification requires the first constructor parameter of any closure to be
249+
// its enclosing object.
250+
val clone = instantiateClass(cls, parent, inInterpreter)
141251
for (fieldName <- accessedFields(cls)) {
142252
val field = cls.getDeclaredField(fieldName)
143253
field.setAccessible(true)
144254
val value = field.get(obj)
145-
// logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
146-
field.set(outer, value)
255+
field.set(clone, value)
147256
}
257+
// If transitive cleaning is enabled, we recursively clean any enclosing closure using
258+
// the already populated accessed fields map of the starting closure
259+
if (cleanTransitively && isClosure(clone.getClass)) {
260+
logDebug(s" + cleaning cloned closure $clone recursively (${cls.getName})")
261+
clean(clone, checkSerializable, cleanTransitively, accessedFields)
262+
}
263+
parent = clone
148264
}
149265

150-
if (outer != null) {
151-
// logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
266+
// Update the parent pointer ($outer) of this closure
267+
if (parent != null) {
152268
val field = func.getClass.getDeclaredField("$outer")
153269
field.setAccessible(true)
154-
field.set(func, outer)
270+
// If the starting closure doesn't actually need our enclosing object, then just null it out
271+
if (accessedFields.contains(func.getClass) &&
272+
!accessedFields(func.getClass).contains("$outer")) {
273+
logDebug(s" + the starting closure doesn't actually need $parent, so we null it out")
274+
field.set(func, null)
275+
} else {
276+
// Update this closure's parent pointer to point to our enclosing object,
277+
// which could either be a cloned closure or the original user object
278+
field.set(func, parent)
279+
}
155280
}
156-
281+
282+
logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned +++")
283+
157284
if (checkSerializable) {
158285
ensureSerializable(func)
159286
}
@@ -167,15 +294,17 @@ private[spark] object ClosureCleaner extends Logging {
167294
}
168295
}
169296

170-
private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = {
171-
// logInfo("Creating a " + cls + " with outer = " + outer)
297+
private def instantiateClass(
298+
cls: Class[_],
299+
enclosingObject: AnyRef,
300+
inInterpreter: Boolean): AnyRef = {
172301
if (!inInterpreter) {
173302
// This is a bona fide closure class, whose constructor has no effects
174303
// other than to set its fields, so use its constructor
175304
val cons = cls.getConstructors()(0)
176305
val params = cons.getParameterTypes.map(createNullValue).toArray
177-
if (outer != null) {
178-
params(0) = outer // First param is always outer object
306+
if (enclosingObject!= null) {
307+
params(0) = enclosingObject // First param is always enclosing object
179308
}
180309
return cons.newInstance(params: _*).asInstanceOf[AnyRef]
181310
} else {
@@ -184,11 +313,10 @@ private[spark] object ClosureCleaner extends Logging {
184313
val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
185314
val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
186315
val obj = newCtor.newInstance().asInstanceOf[AnyRef]
187-
if (outer != null) {
188-
// logInfo("3: Setting $outer on " + cls + " to " + outer);
316+
if (enclosingObject != null) {
189317
val field = cls.getDeclaredField("$outer")
190318
field.setAccessible(true)
191-
field.set(obj, outer)
319+
field.set(obj, enclosingObject)
192320
}
193321
obj
194322
}
@@ -213,29 +341,68 @@ class ReturnStatementFinder extends ClassVisitor(ASM4) {
213341
}
214342
}
215343

344+
/**
345+
* Find the fields accessed by a given class.
346+
*
347+
* The fields are stored in the mutable map passed in by the class that contains them.
348+
* This map is assumed to have its keys already populated by the classes of interest.
349+
*
350+
* @param fields the mutable map that stores the fields to return
351+
* @param specificMethodNames if not empty, only visit methods whose names are in this set
352+
* @param findTransitively if true, find fields indirectly referenced in other classes
353+
*/
216354
private[spark]
217-
class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
218-
override def visitMethod(access: Int, name: String, desc: String,
219-
sig: String, exceptions: Array[String]): MethodVisitor = {
355+
class FieldAccessFinder(
356+
fields: Map[Class[_], Set[String]],
357+
specificMethodNames: Set[String] = Set.empty,
358+
findTransitively: Boolean = true)
359+
extends ClassVisitor(ASM4) {
360+
361+
override def visitMethod(
362+
access: Int,
363+
name: String,
364+
desc: String,
365+
sig: String,
366+
exceptions: Array[String]): MethodVisitor = {
367+
368+
// Ignore this method if we don't want to visit it
369+
if (specificMethodNames.nonEmpty && !specificMethodNames.contains(name)) {
370+
return new MethodVisitor(ASM4) { }
371+
}
372+
220373
new MethodVisitor(ASM4) {
221374
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
222375
if (op == GETFIELD) {
223-
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
224-
output(cl) += name
376+
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
377+
fields(cl) += name
225378
}
226379
}
227380
}
228381

229-
override def visitMethodInsn(op: Int, owner: String, name: String,
230-
desc: String) {
231-
// Check for calls a getter method for a variable in an interpreter wrapper object.
232-
// This means that the corresponding field will be accessed, so we should save it.
233-
if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
234-
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
235-
output(cl) += name
382+
override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) {
383+
if (isInvoke(op)) {
384+
for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
385+
// Check for calls a getter method for a variable in an interpreter wrapper object.
386+
// This means that the corresponding field will be accessed, so we should save it.
387+
if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && !name.endsWith("$outer")) {
388+
fields(cl) += name
389+
}
390+
// Visit other methods to find fields that are transitively referenced
391+
if (findTransitively) {
392+
ClosureCleaner.getClassReader(cl)
393+
.accept(new FieldAccessFinder(fields, Set(name), findTransitively), 0)
394+
}
236395
}
237396
}
238397
}
398+
399+
private def isInvoke(op: Int): Boolean = {
400+
op == INVOKEVIRTUAL ||
401+
op == INVOKESPECIAL ||
402+
op == INVOKEDYNAMIC ||
403+
op == INVOKEINTERFACE ||
404+
op == INVOKESTATIC
405+
}
239406
}
240407
}
241408
}

0 commit comments

Comments
 (0)