
Commit 9a313d5

making classes that needn't be public private, adding automatic file closure, adding new tests
1 parent edf5829 commit 9a313d5


8 files changed, 107 additions and 19 deletions


core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 5 additions & 5 deletions
@@ -523,7 +523,7 @@ class SparkContext(config: SparkConf) extends Logging {
     val job = new NewHadoopJob(hadoopConfiguration)
     NewFileInputFormat.addInputPath(job, new Path(path))
     val updateConf = job.getConfiguration
-    new RawFileRDD(
+    new BinaryFileRDD(
       this,
       classOf[ByteInputFormat],
       classOf[String],
@@ -548,7 +548,7 @@ class SparkContext(config: SparkConf) extends Logging {
     val job = new NewHadoopJob(hadoopConfiguration)
     NewFileInputFormat.addInputPath(job, new Path(path))
     val updateConf = job.getConfiguration
-    new RawFileRDD(
+    new BinaryFileRDD(
       this,
       classOf[StreamInputFormat],
       classOf[String],
@@ -565,9 +565,9 @@ class SparkContext(config: SparkConf) extends Logging {
    * @param path Directory to the input data files
    * @return An RDD of data with values, RDD[(Array[Byte])]
    */
-  def fixedLengthBinaryFiles(path: String): RDD[Array[Byte]] = {
-    val lines = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
-    val data = lines.map{ case (k, v) => v.getBytes}
+  def binaryRecords(path: String): RDD[Array[Byte]] = {
+    val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
+    val data = br.map{ case (k, v) => v.getBytes}
     data
   }
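For orientation, here is a minimal usage sketch of the two public entry points touched above, sc.binaryFiles and sc.binaryRecords. It assumes a local SparkContext and a hypothetical input directory; the "recordLength" configuration key is the same one the new FileSuite test below sets.

import org.apache.spark.{SparkConf, SparkContext}

object BinaryReadSketch {
  def main(args: Array[String]): Unit = {
    // Assumption: local mode and a hypothetical directory of binary part files.
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("binary-read-sketch"))

    // binaryFiles: one (file path, whole file contents) pair per file in the directory
    val perFile = sc.binaryFiles("/tmp/binary-input")
    perFile.collect().foreach { case (path, bytes) => println(s"$path -> ${bytes.length} bytes") }

    // binaryRecords: flat binary files split into fixed-length records; the record
    // length is taken from the "recordLength" Hadoop configuration key (see FileSuite below)
    sc.hadoopConfiguration.setInt("recordLength", 6)
    val records = sc.binaryRecords("/tmp/binary-input")
    println(s"record count: ${records.count()}")

    sc.stop()
  }
}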

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 3 additions & 3 deletions
@@ -289,7 +289,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * @param minPartitions A suggestion value of the minimal splitting number for input data.
    */
   def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
-  JavaPairRDD[String,Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
+  JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
 
   /**
    * Load data from a flat binary file, assuming each record is a set of numbers
@@ -299,8 +299,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * @param path Directory to the input data files
    * @return An RDD of data with values, JavaRDD[(Array[Byte])]
    */
-  def fixedLengthBinaryFiles(path: String): JavaRDD[Array[Byte]] = {
-    new JavaRDD(sc.fixedLengthBinaryFiles(path))
+  def binaryRecords(path: String): JavaRDD[Array[Byte]] = {
+    new JavaRDD(sc.binaryRecords(path))
   }
 
   /** Get an RDD for a Hadoop SequenceFile with given key and value types.

core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala

Lines changed: 2 additions & 2 deletions
@@ -28,7 +28,7 @@ import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAt
  * a parameter recordLength in the Hadoop configuration.
  */
 
-object FixedLengthBinaryInputFormat {
+private[spark] object FixedLengthBinaryInputFormat {
 
   /**
    * This function retrieves the recordLength by checking the configuration parameter
@@ -42,7 +42,7 @@ object FixedLengthBinaryInputFormat {
 
 }
 
-class FixedLengthBinaryInputFormat extends FileInputFormat[LongWritable, BytesWritable] {
+private[spark] class FixedLengthBinaryInputFormat extends FileInputFormat[LongWritable, BytesWritable] {
 
 
   /**
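The scaladoc above notes that the record length is retrieved from a Hadoop configuration parameter. As a minimal, hedged sketch of such a lookup (the object name and the -1 default are assumptions; only the "recordLength" key is taken from this commit, via the new FileSuite test):

import org.apache.hadoop.conf.Configuration

// Illustrative sketch, not the commit's implementation: look up the fixed record
// length under the "recordLength" key, the same key the new FileSuite test sets below.
object RecordLengthLookup {
  def recordLength(conf: Configuration): Int =
    // -1 acts as a sentinel for "not configured", so callers can fail fast
    conf.getInt("recordLength", -1)
}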

core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileSplit
  * VALUE = the record itself (BytesWritable)
  *
  */
-class FixedLengthBinaryRecordReader extends RecordReader[LongWritable, BytesWritable] {
+private[spark] class FixedLengthBinaryRecordReader extends RecordReader[LongWritable, BytesWritable] {
 
   override def close() {
     if (fileInputStream != null) {

core/src/main/scala/org/apache/spark/input/RawFileInput.scala

Lines changed: 17 additions & 6 deletions
@@ -73,18 +73,29 @@ abstract class StreamBasedRecordReader[T](
 
   private val key = path.toString
   private var value: T = null.asInstanceOf[T]
+  // the file to be read when nextkeyvalue is called
+  private lazy val fileIn: FSDataInputStream = fs.open(path)
+
   override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
-  override def close() = {}
+  override def close() = {
+    // make sure the file is closed
+    try {
+      fileIn.close()
+    } catch {
+      case ioe: java.io.IOException => // do nothing
+    }
+  }
 
   override def getProgress = if (processed) 1.0f else 0.0f
 
   override def getCurrentKey = key
 
   override def getCurrentValue = value
 
+
   override def nextKeyValue = {
     if (!processed) {
-      val fileIn: FSDataInputStream = fs.open(path)
+
       value = parseStream(fileIn)
       processed = true
       true
@@ -104,7 +115,7 @@ abstract class StreamBasedRecordReader[T](
 /**
  * Reads the record in directly as a stream for other objects to manipulate and handle
  */
-class StreamRecordReader(
+private[spark] class StreamRecordReader(
     split: CombineFileSplit,
     context: TaskAttemptContext,
     index: Integer)
@@ -117,7 +128,7 @@ class StreamRecordReader(
  * A class for extracting the information from the file using the
  * BinaryRecordReader (as Byte array)
  */
-class StreamInputFormat extends StreamFileInputFormat[DataInputStream] {
+private[spark] class StreamInputFormat extends StreamFileInputFormat[DataInputStream] {
   override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
   {
     new CombineFileRecordReader[String,DataInputStream](
@@ -146,7 +157,7 @@ abstract class BinaryRecordReader[T](
 }
 
 
-class ByteRecordReader(
+private[spark] class ByteRecordReader(
     split: CombineFileSplit,
     context: TaskAttemptContext,
     index: Integer)
@@ -158,7 +169,7 @@ class ByteRecordReader(
 /**
  * A class for reading the file using the BinaryRecordReader (as Byte array)
  */
-class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
+private[spark] class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
   override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
   {
     new CombineFileRecordReader[String,Array[Byte]](
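The key change in this file is the automatic file closure: the input stream is now opened lazily and released in close(), with an IOException during close deliberately swallowed. A small, self-contained sketch of the same lazy-open/close-on-close pattern (the class and method names here are illustrative, not part of the commit):

import java.io.{FileInputStream, IOException, InputStream}

// Illustrative sketch of the pattern used above, not Spark code.
class LazyStreamHolder(path: String) {
  // the stream is opened only when first read, mirroring the lazy val in the diff
  private lazy val in: InputStream = new FileInputStream(path)

  def readAll(): Array[Byte] =
    // forcing `in` here triggers the lazy open
    Iterator.continually(in.read()).takeWhile(_ != -1).map(_.toByte).toArray

  def close(): Unit =
    // always attempt to close, ignoring an IOException on close as the diff does
    try in.close()
    catch { case _: IOException => () }
}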

core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala

Lines changed: 3 additions & 2 deletions
@@ -23,10 +23,10 @@ package org.apache.spark.rdd
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
-import org.apache.spark.{Partition, SparkContext}
+import org.apache.spark.{InterruptibleIterator, TaskContext, Partition, SparkContext}
 import org.apache.spark.input.StreamFileInputFormat
 
-private[spark] class RawFileRDD[T](
+private[spark] class BinaryFileRDD[T](
     sc : SparkContext,
     inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
     keyClass: Class[String],
@@ -35,6 +35,7 @@ private[spark] class RawFileRDD[T](
     minPartitions: Int)
   extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) {
 
+
   override def getPartitions: Array[Partition] = {
     val inputFormat = inputFormatClass.newInstance
     inputFormat match {

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 22 additions & 0 deletions
@@ -836,6 +836,28 @@ public Tuple2<Integer, String> call(Tuple2<IntWritable, Text> pair) {
     Assert.assertEquals(pairs, readRDD.collect());
   }
 
+  @Test
+  public void binaryFiles() throws Exception {
+    // Reusing the wholeText files example
+    byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
+    byte[] content2 = "spark is also easy to use.\n".getBytes("utf-8");
+
+    String tempDirName = tempDir.getAbsolutePath();
+    File file1 = new File(tempDirName + "/part-00000");
+    Files.write(content1, file1);
+    File file2 = new File(tempDirName + "/part-00001");
+    Files.write(content2, file2);
+
+    JavaPairRDD<String, byte[]> readRDD = sc.binaryFiles(tempDirName,3);
+    List<Tuple2<String, byte[]>> result = readRDD.collect();
+    for (Tuple2<String, byte[]> res : result) {
+      if (res._1()==file1.toString())
+        Assert.assertArrayEquals(content1,res._2());
+      else
+        Assert.assertArrayEquals(content2,res._2());
+    }
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void writeWithNewAPIHadoopFile() {

core/src/test/scala/org/apache/spark/FileSuite.scala

Lines changed: 54 additions & 0 deletions
@@ -224,6 +224,60 @@ class FileSuite extends FunSuite with LocalSparkContext {
     assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
   }
 
+  test("byte stream input") {
+    sc = new SparkContext("local", "test")
+    val outputDir = new File(tempDir, "output").getAbsolutePath
+    val outFile = new File(outputDir, "part-00000.bin")
+    val outFileName = outFile.toPath().toString()
+
+    // create file
+    val testOutput = Array[Byte](1,2,3,4,5,6)
+    val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+    // write data to file
+    val file = new java.io.FileOutputStream(outFile)
+    val channel = file.getChannel
+    channel.write(bbuf)
+    channel.close()
+    file.close()
+
+    val inRdd = sc.binaryFiles(outFileName)
+    val (infile: String, indata: Array[Byte]) = inRdd.first
+
+    // Try reading the output back as an object file
+    assert(infile === outFileName)
+    assert(indata === testOutput)
+  }
+
+  test("fixed length byte stream input") {
+    // a fixed length of 6 bytes
+
+    sc = new SparkContext("local", "test")
+
+    val outputDir = new File(tempDir, "output").getAbsolutePath
+    val outFile = new File(outputDir, "part-00000.bin")
+    val outFileName = outFile.toPath().toString()
+
+    // create file
+    val testOutput = Array[Byte](1,2,3,4,5,6)
+    val testOutputCopies = 10
+    val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+    // write data to file
+    val file = new java.io.FileOutputStream(outFile)
+    val channel = file.getChannel
+    for(i <- 1 to testOutputCopies) channel.write(bbuf)
+    channel.close()
+    file.close()
+    sc.hadoopConfiguration.setInt("recordLength",testOutput.length)
+
+    val inRdd = sc.binaryRecords(outFileName)
+    // make sure there are enough elements
+    assert(inRdd.count== testOutputCopies)
+
+    // now just compare the first one
+    val indata: Array[Byte] = inRdd.first
+    assert(indata === testOutput)
+  }
+
   test("file caching") {
     sc = new SparkContext("local", "test")
     val out = new FileWriter(tempDir + "/input")
