@@ -17,17 +17,21 @@
 
 package org.apache.spark.rdd
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.mapred._
+import org.apache.hadoop.util.Progressable
+
+import scala.collection.mutable.{ArrayBuffer, HashSet}
 import scala.util.Random
 
-import org.scalatest.FunSuite
 import com.google.common.io.Files
-import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.conf.{Configuration, Configurable}
-
-import org.apache.spark.SparkContext._
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter,
+  OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter,
+  TaskAttemptContext => NewTaskAttempContext}
 import org.apache.spark.{Partitioner, SharedSparkContext}
+import org.apache.spark.SparkContext._
+import org.scalatest.FunSuite
 
 class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
   test("aggregateByKey") {
@@ -467,7 +471,7 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
 
     // No error, non-configurable formats still work
-    pairs.saveAsNewAPIHadoopFile[FakeFormat]("ignored")
+    pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored")
 
     /*
       Check that configurable formats get configured:
@@ -478,6 +482,17 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
     pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored")
   }
 
+  test("saveAsHadoopFile should respect configured output committers") {
+    val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
+    val conf = new JobConf()
+    conf.setOutputCommitter(classOf[FakeOutputCommitter])
+
+    FakeOutputCommitter.ran = false
+    pairs.saveAsHadoopFile("ignored", pairs.keyClass, pairs.valueClass, classOf[FakeOutputFormat], conf)
+
+    assert(FakeOutputCommitter.ran, "OutputCommitter was never called")
+  }
+
   test("lookup") {
     val pairs = sc.parallelize(Array((1,2), (3,4), (5,6), (5,7)))
 
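
Note how the new test observes the committer: `FakeOutputCommitter.ran` is a `var` on a top-level companion object (defined in the next hunk). Because the suite runs Spark in local mode, the task that invokes `commitTask` and the driver code that asserts on the flag share one JVM, so a plain JVM-global flag is enough to communicate between them.
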
@@ -621,40 +636,86 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
   and the test will therefore throw InstantiationException when saveAsNewAPIHadoopFile
   tries to instantiate them with Class.newInstance.
  */
+
+/*
+ * Original Hadoop API
+ */
 class FakeWriter extends RecordWriter[Integer, Integer] {
+  override def write(key: Integer, value: Integer): Unit = ()
 
-  def close(p1: TaskAttemptContext) = ()
+  override def close(reporter: Reporter): Unit = ()
+}
+
+class FakeOutputCommitter() extends OutputCommitter() {
+  override def setupJob(jobContext: JobContext): Unit = ()
+
+  override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true
+
+  override def setupTask(taskContext: TaskAttemptContext): Unit = ()
+
+  override def commitTask(taskContext: TaskAttemptContext): Unit = {
+    FakeOutputCommitter.ran = true
+    ()
+  }
+
+  override def abortTask(taskContext: TaskAttemptContext): Unit = ()
+}
+
+/*
+ * Used to communicate state between the test harness and the OutputCommitter.
+ */
+object FakeOutputCommitter {
+  var ran = false
+}
+
+class FakeOutputFormat() extends OutputFormat[Integer, Integer]() {
+  override def getRecordWriter(
+      ignored: FileSystem,
+      job: JobConf, name: String,
+      progress: Progressable): RecordWriter[Integer, Integer] = {
+    new FakeWriter()
+  }
+
+  override def checkOutputSpecs(ignored: FileSystem, job: JobConf): Unit = ()
+}
+
+/*
+ * New-style Hadoop API
+ */
+class NewFakeWriter extends NewRecordWriter[Integer, Integer] {
+
+  def close(p1: NewTaskAttempContext) = ()
 
   def write(p1: Integer, p2: Integer) = ()
 
 }
 
-class FakeCommitter extends OutputCommitter {
-  def setupJob(p1: JobContext) = ()
+class NewFakeCommitter extends NewOutputCommitter {
+  def setupJob(p1: NewJobContext) = ()
 
-  def needsTaskCommit(p1: TaskAttemptContext): Boolean = false
+  def needsTaskCommit(p1: NewTaskAttempContext): Boolean = false
 
-  def setupTask(p1: TaskAttemptContext) = ()
+  def setupTask(p1: NewTaskAttempContext) = ()
 
-  def commitTask(p1: TaskAttemptContext) = ()
+  def commitTask(p1: NewTaskAttempContext) = ()
 
-  def abortTask(p1: TaskAttemptContext) = ()
+  def abortTask(p1: NewTaskAttempContext) = ()
 }
 
-class FakeFormat() extends OutputFormat[Integer, Integer]() {
+class NewFakeFormat() extends NewOutputFormat[Integer, Integer]() {
 
-  def checkOutputSpecs(p1: JobContext) = ()
+  def checkOutputSpecs(p1: NewJobContext) = ()
 
-  def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
-    new FakeWriter()
+  def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
+    new NewFakeWriter()
   }
 
-  def getOutputCommitter(p1: TaskAttemptContext): OutputCommitter = {
-    new FakeCommitter()
+  def getOutputCommitter(p1: NewTaskAttempContext): NewOutputCommitter = {
+    new NewFakeCommitter()
   }
 }
 
-class ConfigTestFormat() extends FakeFormat() with Configurable {
+class ConfigTestFormat() extends NewFakeFormat() with Configurable {
 
   var setConfCalled = false
   def setConf(p1: Configuration) = {
@@ -664,7 +725,7 @@ class ConfigTestFormat() extends FakeFormat() with Configurable {
 
   def getConf: Configuration = null
 
-  override def getRecordWriter(p1: TaskAttemptContext): RecordWriter[Integer, Integer] = {
+  override def getRecordWriter(p1: NewTaskAttempContext): NewRecordWriter[Integer, Integer] = {
     assert(setConfCalled, "setConf was never called")
     super.getRecordWriter(p1)
   }
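
Taken together, the behavior this suite now pins down is that `saveAsHadoopFile` (the old `mapred`-API path) must honor a committer configured via `JobConf.setOutputCommitter` instead of ignoring it. A minimal user-side sketch of that pattern (the object, committer class, and output path are hypothetical; assumes a local `SparkContext` and Hadoop on the classpath):

import org.apache.hadoop.mapred.{FileOutputCommitter, JobConf, TaskAttemptContext, TextOutputFormat}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

// Hypothetical committer: defers to FileOutputCommitter but records each commit.
class LoggingCommitter extends FileOutputCommitter {
  override def commitTask(context: TaskAttemptContext): Unit = {
    println(s"committing ${context.getTaskAttemptID}")
    super.commitTask(context)
  }
}

object CommitterDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("committer-demo").setMaster("local"))

    // Configure the committer on the JobConf; with this change, saveAsHadoopFile
    // uses it rather than unconditionally instantiating the default committer.
    val conf = new JobConf(sc.hadoopConfiguration)
    conf.setOutputCommitter(classOf[LoggingCommitter])

    sc.parallelize(Seq((1, "a"), (2, "b")))
      .saveAsHadoopFile("/tmp/committer-demo", classOf[java.lang.Integer], classOf[String],
        classOf[TextOutputFormat[java.lang.Integer, String]], conf)

    sc.stop()
  }
}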