@@ -19,17 +19,20 @@ package org.apache.spark.util
19
19
20
20
import java .io .{ByteArrayInputStream , ByteArrayOutputStream }
21
21
22
- import scala .collection .mutable .Map
23
- import scala .collection .mutable .Set
22
+ import scala .collection .mutable .{Map , Set }
24
23
25
24
import com .esotericsoftware .reflectasm .shaded .org .objectweb .asm .{ClassReader , ClassVisitor , MethodVisitor , Type }
26
25
import com .esotericsoftware .reflectasm .shaded .org .objectweb .asm .Opcodes ._
27
26
28
27
import org .apache .spark .{Logging , SparkEnv , SparkException }
29
28
29
+ /**
30
+ * A cleaner that renders closures serializable if they can be done so safely.
31
+ */
30
32
private [spark] object ClosureCleaner extends Logging {
33
+
31
34
// Get an ASM class reader for a given class from the JAR that loaded it
32
- private def getClassReader (cls : Class [_]): ClassReader = {
35
+ def getClassReader (cls : Class [_]): ClassReader = {
33
36
// Copy data over, before delegating to ClassReader - else we can run out of open file handles.
34
37
val className = cls.getName.replaceFirst(" ^.*\\ ." , " " ) + " .class"
35
38
val resourceStream = cls.getResourceAsStream(className)
@@ -77,6 +80,9 @@ private[spark] object ClosureCleaner extends Logging {
77
80
Nil
78
81
}
79
82
83
+ /**
84
+ * Return a list of classes that represent closures enclosed in the given closure object.
85
+ */
80
86
private def getInnerClasses (obj : AnyRef ): List [Class [_]] = {
81
87
val seen = Set [Class [_]](obj.getClass)
82
88
var stack = List [Class [_]](obj.getClass)
@@ -101,21 +107,110 @@ private[spark] object ClosureCleaner extends Logging {
101
107
}
102
108
}
103
109
104
- def clean (func : AnyRef , checkSerializable : Boolean = true ) {
110
+ /**
111
+ * Clean the given closure in place.
112
+ *
113
+ * More specifically, this renders the given closure serializable as long as it does not
114
+ * explicitly reference unserializable objects.
115
+ *
116
+ * @param closure the closure to clean
117
+ * @param checkSerializable whether to verify that the closure is serializable after cleaning
118
+ * @param cleanTransitively whether to clean enclosing closures transitively
119
+ */
120
+ def clean (
121
+ closure : AnyRef ,
122
+ checkSerializable : Boolean = true ,
123
+ cleanTransitively : Boolean = true ): Unit = {
124
+ clean(closure, checkSerializable, cleanTransitively, Map .empty)
125
+ }
126
+
127
+ /**
128
+ * Helper method to clean the given closure in place.
129
+ *
130
+ * The mechanism is to traverse the hierarchy of enclosing closures and null out any
131
+ * references along the way that are not actually used by the starting closure, but are
132
+ * nevertheless included in the compiled anonymous classes. Note that it is unsafe to
133
+ * simply mutate the enclosing closures, as other code paths may depend on them. Instead,
134
+ * we clone each enclosing closure and set the parent pointers accordingly.
135
+ *
136
+ * By default, closures are cleaned transitively. This means we detect whether enclosing
137
+ * objects are actually referenced by the starting one, either directly or transitively,
138
+ * and, if not, sever these closures from the hierarchy. In other words, in addition to
139
+ * nulling out unused field references, we also null out any parent pointers that refer
140
+ * to enclosing objects not actually needed by the starting closure.
141
+ *
142
+ * For instance, transitive cleaning is necessary in the following scenario:
143
+ *
144
+ * class SomethingNotSerializable {
145
+ * def someValue = 1
146
+ * def someMethod(): Unit = scope("one") {
147
+ * def x = someValue
148
+ * def y = 2
149
+ * scope("two") { println(y + 1) }
150
+ * }
151
+ * def scope(name: String)(body: => Unit) = body
152
+ * }
153
+ *
154
+ * In this example, scope "two" is not serializable because it references scope "one", which
155
+ * references SomethingNotSerializable. Note that, however, scope "two" does not actually
156
+ * depend on SomethingNotSerializable. This means we can null out the parent pointer of
157
+ * a cloned scope "one" and set it the parent of scope "two", such that scope "two" no longer
158
+ * references SomethingNotSerializable transitively.
159
+ *
160
+ * @param func the starting closure to clean
161
+ * @param checkSerializable whether to verify that the closure is serializable after cleaning
162
+ * @param cleanTransitively whether to clean enclosing closures transitively
163
+ * @param accessedFields a map from a class to a set of its fields that are accessed by
164
+ * the starting closure
165
+ */
166
+ private def clean (
167
+ func : AnyRef ,
168
+ checkSerializable : Boolean ,
169
+ cleanTransitively : Boolean ,
170
+ accessedFields : Map [Class [_], Set [String ]]) {
171
+
172
+ // TODO: clean all inner closures first. This requires us to find the inner objects.
105
173
// TODO: cache outerClasses / innerClasses / accessedFields
106
- val outerClasses = getOuterClasses(func)
174
+
175
+ logDebug(s " +++ Cleaning closure $func ( ${func.getClass.getName}}) +++ " )
176
+
177
+ // A list of classes that represents closures enclosed in the given one
107
178
val innerClasses = getInnerClasses(func)
179
+
180
+ // A list of enclosing objects and their respective classes, from innermost to outermost
181
+ // An outer object at a given index is of type outer class at the same index
182
+ val outerClasses = getOuterClasses(func)
108
183
val outerObjects = getOuterObjects(func)
109
184
110
- val accessedFields = Map [Class [_], Set [String ]]()
111
-
185
+ logDebug(s " + inner classes: " + innerClasses.size)
186
+ innerClasses.foreach { c => logDebug(" " + c.getName) }
187
+ logDebug(s " + outer classes: " + outerClasses.size)
188
+ outerClasses.foreach { c => logDebug(" " + c.getName) }
189
+ logDebug(s " + outer objects: " + outerObjects.size)
190
+ outerObjects.foreach { o => logDebug(" " + o) }
191
+
192
+ // Fail fast if we detect return statements in closures
112
193
getClassReader(func.getClass).accept(new ReturnStatementFinder (), 0 )
113
-
114
- for (cls <- outerClasses)
115
- accessedFields(cls) = Set [String ]()
116
- for (cls <- func.getClass :: innerClasses)
117
- getClassReader(cls).accept(new FieldAccessFinder (accessedFields), 0 )
118
- // logInfo("accessedFields: " + accessedFields)
194
+
195
+ // If accessed fields is not populated yet, we assume that
196
+ // the closure we are trying to clean is the starting one
197
+ if (accessedFields.isEmpty) {
198
+ logDebug(s " + populating accessed fields because this is the starting closure " )
199
+ // Initialize accessed fields with the outer classes first
200
+ // This step is needed to associate the fields to the correct classes later
201
+ for (cls <- outerClasses) {
202
+ accessedFields(cls) = Set [String ]()
203
+ }
204
+ // Populate accessed fields by visiting all fields and methods accessed by this and
205
+ // all of its inner closures. If transitive cleaning is enabled, this may recursively
206
+ // visits methods that belong to other classes in search of transitively referenced fields.
207
+ for (cls <- func.getClass :: innerClasses) {
208
+ getClassReader(cls).accept(new FieldAccessFinder (accessedFields), 0 )
209
+ }
210
+ }
211
+
212
+ logDebug(s " + fields accessed by starting closure: " + accessedFields.size)
213
+ accessedFields.foreach { f => logDebug(" " + f) }
119
214
120
215
val inInterpreter = {
121
216
try {
@@ -126,34 +221,66 @@ private[spark] object ClosureCleaner extends Logging {
126
221
}
127
222
}
128
223
224
+ // List of outer (class, object) pairs, ordered from outermost to innermost
225
+ // Note that all outer objects but the outermost one (first one in this list) must be closures
129
226
var outerPairs : List [(Class [_], AnyRef )] = (outerClasses zip outerObjects).reverse
130
- var outer : AnyRef = null
227
+ var parent : AnyRef = null
131
228
if (outerPairs.size > 0 && ! isClosure(outerPairs.head._1)) {
132
229
// The closure is ultimately nested inside a class; keep the object of that
133
230
// class without cloning it since we don't want to clone the user's objects.
134
- outer = outerPairs.head._2
231
+ // Note that we still need to keep around the outermost object itself because
232
+ // we need it to clone its child closure later (see below).
233
+ logDebug(s " + outermost object is not a closure, so do not clone it: ${outerPairs.head}" )
234
+ parent = outerPairs.head._2 // e.g. SparkContext
135
235
outerPairs = outerPairs.tail
236
+ } else if (outerPairs.size > 0 ) {
237
+ logDebug(s " + outermost object is a closure, so we just keep it: ${outerPairs.head}" )
238
+ } else {
239
+ logDebug(" + there are no enclosing objects!" )
136
240
}
241
+
137
242
// Clone the closure objects themselves, nulling out any fields that are not
138
243
// used in the closure we're working on or any of its inner closures.
139
244
for ((cls, obj) <- outerPairs) {
140
- outer = instantiateClass(cls, outer, inInterpreter)
245
+ logDebug(s " + cloning the object $obj of class ${cls.getName}" )
246
+ // We null out these unused references by cloning each object and then filling in all
247
+ // required fields from the original object. We need the parent here because the Java
248
+ // language specification requires the first constructor parameter of any closure to be
249
+ // its enclosing object.
250
+ val clone = instantiateClass(cls, parent, inInterpreter)
141
251
for (fieldName <- accessedFields(cls)) {
142
252
val field = cls.getDeclaredField(fieldName)
143
253
field.setAccessible(true )
144
254
val value = field.get(obj)
145
- // logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
146
- field.set(outer, value)
255
+ field.set(clone, value)
147
256
}
257
+ // If transitive cleaning is enabled, we recursively clean any enclosing closure using
258
+ // the already populated accessed fields map of the starting closure
259
+ if (cleanTransitively && isClosure(clone.getClass)) {
260
+ logDebug(s " + cleaning cloned closure $clone recursively ( ${cls.getName}) " )
261
+ clean(clone, checkSerializable, cleanTransitively, accessedFields)
262
+ }
263
+ parent = clone
148
264
}
149
265
150
- if ( outer != null ) {
151
- // logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
266
+ // Update the parent pointer ($ outer) of this closure
267
+ if (parent != null ) {
152
268
val field = func.getClass.getDeclaredField(" $outer" )
153
269
field.setAccessible(true )
154
- field.set(func, outer)
270
+ // If the starting closure doesn't actually need our enclosing object, then just null it out
271
+ if (accessedFields.contains(func.getClass) &&
272
+ ! accessedFields(func.getClass).contains(" $outer" )) {
273
+ logDebug(s " + the starting closure doesn't actually need $parent, so we null it out " )
274
+ field.set(func, null )
275
+ } else {
276
+ // Update this closure's parent pointer to point to our enclosing object,
277
+ // which could either be a cloned closure or the original user object
278
+ field.set(func, parent)
279
+ }
155
280
}
156
-
281
+
282
+ logDebug(s " +++ closure $func ( ${func.getClass.getName}) is now cleaned +++ " )
283
+
157
284
if (checkSerializable) {
158
285
ensureSerializable(func)
159
286
}
@@ -167,15 +294,17 @@ private[spark] object ClosureCleaner extends Logging {
167
294
}
168
295
}
169
296
170
- private def instantiateClass (cls : Class [_], outer : AnyRef , inInterpreter : Boolean ): AnyRef = {
171
- // logInfo("Creating a " + cls + " with outer = " + outer)
297
+ private def instantiateClass (
298
+ cls : Class [_],
299
+ enclosingObject : AnyRef ,
300
+ inInterpreter : Boolean ): AnyRef = {
172
301
if (! inInterpreter) {
173
302
// This is a bona fide closure class, whose constructor has no effects
174
303
// other than to set its fields, so use its constructor
175
304
val cons = cls.getConstructors()(0 )
176
305
val params = cons.getParameterTypes.map(createNullValue).toArray
177
- if (outer != null ) {
178
- params(0 ) = outer // First param is always outer object
306
+ if (enclosingObject != null ) {
307
+ params(0 ) = enclosingObject // First param is always enclosing object
179
308
}
180
309
return cons.newInstance(params : _* ).asInstanceOf [AnyRef ]
181
310
} else {
@@ -184,11 +313,10 @@ private[spark] object ClosureCleaner extends Logging {
184
313
val parentCtor = classOf [java.lang.Object ].getDeclaredConstructor()
185
314
val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
186
315
val obj = newCtor.newInstance().asInstanceOf [AnyRef ]
187
- if (outer != null ) {
188
- // logInfo("3: Setting $outer on " + cls + " to " + outer);
316
+ if (enclosingObject != null ) {
189
317
val field = cls.getDeclaredField(" $outer" )
190
318
field.setAccessible(true )
191
- field.set(obj, outer )
319
+ field.set(obj, enclosingObject )
192
320
}
193
321
obj
194
322
}
@@ -213,29 +341,68 @@ class ReturnStatementFinder extends ClassVisitor(ASM4) {
213
341
}
214
342
}
215
343
344
+ /**
345
+ * Find the fields accessed by a given class.
346
+ *
347
+ * The fields are stored in the mutable map passed in by the class that contains them.
348
+ * This map is assumed to have its keys already populated by the classes of interest.
349
+ *
350
+ * @param fields the mutable map that stores the fields to return
351
+ * @param specificMethodNames if not empty, only visit methods whose names are in this set
352
+ * @param findTransitively if true, find fields indirectly referenced in other classes
353
+ */
216
354
private [spark]
217
- class FieldAccessFinder (output : Map [Class [_], Set [String ]]) extends ClassVisitor (ASM4 ) {
218
- override def visitMethod (access : Int , name : String , desc : String ,
219
- sig : String , exceptions : Array [String ]): MethodVisitor = {
355
+ class FieldAccessFinder (
356
+ fields : Map [Class [_], Set [String ]],
357
+ specificMethodNames : Set [String ] = Set .empty,
358
+ findTransitively : Boolean = true )
359
+ extends ClassVisitor (ASM4 ) {
360
+
361
+ override def visitMethod (
362
+ access : Int ,
363
+ name : String ,
364
+ desc : String ,
365
+ sig : String ,
366
+ exceptions : Array [String ]): MethodVisitor = {
367
+
368
+ // Ignore this method if we don't want to visit it
369
+ if (specificMethodNames.nonEmpty && ! specificMethodNames.contains(name)) {
370
+ return new MethodVisitor (ASM4 ) { }
371
+ }
372
+
220
373
new MethodVisitor (ASM4 ) {
221
374
override def visitFieldInsn (op : Int , owner : String , name : String , desc : String ) {
222
375
if (op == GETFIELD ) {
223
- for (cl <- output .keys if cl.getName == owner.replace('/' , '.' )) {
224
- output (cl) += name
376
+ for (cl <- fields .keys if cl.getName == owner.replace('/' , '.' )) {
377
+ fields (cl) += name
225
378
}
226
379
}
227
380
}
228
381
229
- override def visitMethodInsn (op : Int , owner : String , name : String ,
230
- desc : String ) {
231
- // Check for calls a getter method for a variable in an interpreter wrapper object.
232
- // This means that the corresponding field will be accessed, so we should save it.
233
- if (op == INVOKEVIRTUAL && owner.endsWith(" $iwC" ) && ! name.endsWith(" $outer" )) {
234
- for (cl <- output.keys if cl.getName == owner.replace('/' , '.' )) {
235
- output(cl) += name
382
+ override def visitMethodInsn (op : Int , owner : String , name : String , desc : String ) {
383
+ if (isInvoke(op)) {
384
+ for (cl <- fields.keys if cl.getName == owner.replace('/' , '.' )) {
385
+ // Check for calls a getter method for a variable in an interpreter wrapper object.
386
+ // This means that the corresponding field will be accessed, so we should save it.
387
+ if (op == INVOKEVIRTUAL && owner.endsWith(" $iwC" ) && ! name.endsWith(" $outer" )) {
388
+ fields(cl) += name
389
+ }
390
+ // Visit other methods to find fields that are transitively referenced
391
+ if (findTransitively) {
392
+ ClosureCleaner .getClassReader(cl)
393
+ .accept(new FieldAccessFinder (fields, Set (name), findTransitively), 0 )
394
+ }
236
395
}
237
396
}
238
397
}
398
+
399
+ private def isInvoke (op : Int ): Boolean = {
400
+ op == INVOKEVIRTUAL ||
401
+ op == INVOKESPECIAL ||
402
+ op == INVOKEDYNAMIC ||
403
+ op == INVOKEINTERFACE ||
404
+ op == INVOKESTATIC
405
+ }
239
406
}
240
407
}
241
408
}
0 commit comments