package org.apache.spark.rdd

import java.io.{FileNotFoundException, IOException}
+import java.security.PrivilegedExceptionAction
import java.text.SimpleDateFormat
import java.util.{Date, Locale}

@@ -29,6 +30,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.mapred.lib.CombineFileSplit
import org.apache.hadoop.mapreduce.TaskType
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.util.ReflectionUtils

import org.apache.spark._
@@ -124,6 +126,8 @@ class HadoopRDD[K, V](
      minPartitions)
  }

+  private val doAsUserName = UserGroupInformation.getCurrentUser.getUserName
+
  protected val jobConfCacheKey: String = "rdd_%d_job_conf".format(id)

  protected val inputFormatCacheKey: String = "rdd_%d_input_format".format(id)
@@ -220,7 +224,7 @@ class HadoopRDD[K, V](
    }
  }

-  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
+  def doCompute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
    val iter = new NextIterator[(K, V)] {

      private val split = theSplit.asInstanceOf[HadoopPartition]
@@ -326,7 +330,7 @@ class HadoopRDD[K, V](
        if (getBytesReadCallback.isDefined) {
          updateBytesRead()
        } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
-                   split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
+            split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
          // If we can't get the bytes read from the FS stats, fall back to the split size,
          // which may be inaccurate.
          try {
@@ -342,6 +346,29 @@ class HadoopRDD[K, V](
    new InterruptibleIterator[(K, V)](context, iter)
  }

+  override def compute(theSplit: Partition, context: TaskContext): InterruptibleIterator[(K, V)] = {
+    val ugi = UserGroupInformation.getCurrentUser
+
+    if (ugi.getUserName == doAsUserName) {
+      doCompute(theSplit: Partition, context: TaskContext)
+    } else {
+      val doAsAction = new PrivilegedExceptionAction[InterruptibleIterator[(K, V)]]() {
+        override def run(): InterruptibleIterator[(K, V)] = {
+          try {
+            doCompute(theSplit: Partition, context: TaskContext)
+          } catch {
+            case e: Exception =>
+              log.error("Error when HadoopRDD computing: ", e)
+              throw e
+          }
+        }
+      }
+
+      val proxyUgi = UserGroupInformation.createProxyUser(doAsUserName, ugi)
+      proxyUgi.doAs(doAsAction)
+    }
+  }
+
  /** Maps over a partition, providing the InputSplit that was used as the base of the partition. */
  @DeveloperApi
  def mapPartitionsWithInputSplit[U: ClassTag](
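The change above renames the existing compute implementation to doCompute and adds a compute override: when the task's current user differs from doAsUserName (the user captured when the RDD was constructed), doCompute is re-run inside UserGroupInformation.doAs with a proxy UGI, so that Hadoop filesystem calls made while reading the split are attributed to the RDD's creating user rather than the executor's process user.

Below is a minimal, self-contained sketch of the same UserGroupInformation proxy-user/doAs pattern outside Spark, under stated assumptions: the ProxyUserExample object, the literal user name "alice", and the /tmp listing are illustrative only and are not part of this patch. Impersonation must also be permitted on the cluster via the hadoop.proxyuser.<real-user>.hosts/.groups settings in core-site.xml.

import java.security.PrivilegedExceptionAction

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.UserGroupInformation

object ProxyUserExample {
  def main(args: Array[String]): Unit = {
    // User the filesystem calls should be attributed to; in the patch this is
    // UserGroupInformation.getCurrentUser.getUserName captured at RDD creation.
    // "alice" is only an illustrative value.
    val doAsUserName = "alice"

    val realUgi = UserGroupInformation.getCurrentUser
    // Build a UGI that acts as doAsUserName while authenticating as realUgi.
    val proxyUgi = UserGroupInformation.createProxyUser(doAsUserName, realUgi)

    // doAs runs the action under the proxy identity on the current thread, so any
    // Hadoop RPC issued inside it (here, a directory listing) carries that identity.
    val entries = proxyUgi.doAs(new PrivilegedExceptionAction[Int]() {
      override def run(): Int = {
        val fs = FileSystem.get(new Configuration())
        fs.listStatus(new Path("/tmp")).length
      }
    })
    println(s"/tmp has $entries entries when listed as $doAsUserName")
  }
}

One point of the design, as reflected in the diff, is that the doAs wrapping happens per compute call rather than once per executor, so whichever thread runs the task picks up the proxy identity only for the duration of reading that partition.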