
Commit 238c83c

Fixed several scala-style issues, changed the structure of binaryFiles, removed excessive classes, and added new tests. The caching tests still have a serialization issue, but that should be easy to fix as well.
1 parent 932a206 commit 238c83c

File tree (5 files changed, +146 -78 lines):

core/src/main/scala/org/apache/spark/SparkContext.scala
core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
core/src/main/scala/org/apache/spark/input/RawFileInput.scala
core/src/test/java/org/apache/spark/JavaAPISuite.java
core/src/test/scala/org/apache/spark/FileSuite.scala

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 11 additions & 25 deletions
@@ -40,7 +40,7 @@ import org.apache.mesos.MesosNativeLibrary
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, ByteInputFormat, FixedLengthBinaryInputFormat}
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
 import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd._
 import org.apache.spark.scheduler._
@@ -510,27 +510,6 @@ class SparkContext(config: SparkConf) extends Logging {
       minPartitions).setName(path)
   }
 
-  /**
-   * Get an RDD for a Hadoop-readable dataset as byte-streams for each file
-   * (useful for binary data)
-   *
-   * @param minPartitions A suggestion value of the minimal splitting number for input data.
-   *
-   * @note Small files are preferred, large file is also allowable, but may cause bad performance.
-   */
-  def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
-    RDD[(String, Array[Byte])] = {
-    val job = new NewHadoopJob(hadoopConfiguration)
-    NewFileInputFormat.addInputPath(job, new Path(path))
-    val updateConf = job.getConfiguration
-    new BinaryFileRDD(
-      this,
-      classOf[ByteInputFormat],
-      classOf[String],
-      classOf[Array[Byte]],
-      updateConf,
-      minPartitions).setName(path)
-  }
 
   /**
    * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
@@ -543,7 +522,7 @@ class SparkContext(config: SparkConf) extends Logging {
    * @note Small files are preferred, large file is also allowable, but may cause bad performance.
    */
   @DeveloperApi
-  def dataStreamFiles(path: String, minPartitions: Int = defaultMinPartitions):
+  def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
     RDD[(String, PortableDataStream)] = {
     val job = new NewHadoopJob(hadoopConfiguration)
     NewFileInputFormat.addInputPath(job, new Path(path))
@@ -563,10 +542,17 @@ class SparkContext(config: SparkConf) extends Logging {
    * bytes per record is constant (see FixedLengthBinaryInputFormat)
    *
    * @param path Directory to the input data files
+   * @param recordLength The length at which to split the records
    * @return An RDD of data with values, RDD[(Array[Byte])]
    */
-  def binaryRecords(path: String): RDD[Array[Byte]] = {
-    val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
+  def binaryRecords(path: String, recordLength: Int,
+    conf: Configuration = hadoopConfiguration): RDD[Array[Byte]] = {
+    conf.setInt("recordLength",recordLength)
+    val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
+      classOf[FixedLengthBinaryInputFormat],
+      classOf[LongWritable],
+      classOf[BytesWritable],
+      conf=conf)
     val data = br.map{ case (k, v) => v.getBytes}
     data
   }
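After this change, binaryFiles returns lazily opened PortableDataStream values instead of eagerly materialized byte arrays, and binaryRecords takes the record length explicitly rather than reading it from the Hadoop configuration. A minimal caller-side sketch against the API as it appears in this diff (the paths, app name, and object name are placeholders, not part of the commit):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // pair RDD operations such as mapValues

object BinaryInputExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("binary-input-example").setMaster("local"))

    // Each element is (file path, PortableDataStream); toArray() materializes one file on demand.
    val sizes = sc.binaryFiles("/tmp/binary-input")
      .mapValues(stream => stream.toArray().length)
    sizes.collect().foreach { case (path, len) => println(s"$path: $len bytes") }

    // Fixed-length records: each element is a byte array of exactly recordLength bytes.
    val records = sc.binaryRecords("/tmp/binary-records", recordLength = 6)
    println(s"record count: ${records.count()}")

    sc.stop()
  }
}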

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 7 additions & 7 deletions
@@ -36,7 +36,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf}
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
 
 import org.apache.spark._
-import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
+import org.apache.spark.SparkContext._
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.{EmptyRDD, RDD}
@@ -256,8 +256,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    *
    * @param minPartitions A suggestion value of the minimal splitting number for input data.
    */
-  def dataStreamFiles(path: String, minPartitions: Int = defaultMinPartitions):
-    JavaPairRDD[String,PortableDataStream] = new JavaPairRDD(sc.dataStreamFiles(path,minPartitions))
+  def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
+    JavaPairRDD[String,PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
 
   /**
    * Read a directory of files as DataInputStream from HDFS,
@@ -288,8 +288,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    *
    * @param minPartitions A suggestion value of the minimal splitting number for input data.
    */
-  def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
-    JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
+  def binaryArrays(path: String, minPartitions: Int = defaultMinPartitions):
+    JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions).mapValues(_.toArray()))
 
   /**
    * Load data from a flat binary file, assuming each record is a set of numbers
@@ -299,8 +299,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * @param path Directory to the input data files
    * @return An RDD of data with values, JavaRDD[(Array[Byte])]
    */
-  def binaryRecords(path: String): JavaRDD[Array[Byte]] = {
-    new JavaRDD(sc.binaryRecords(path))
+  def binaryRecords(path: String,recordLength: Int): JavaRDD[Array[Byte]] = {
+    new JavaRDD(sc.binaryRecords(path,recordLength))
   }
 
   /** Get an RDD for a Hadoop SequenceFile with given key and value types.
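The Java wrapper tracks the Scala API: the renamed binaryFiles returns PortableDataStream values, and the new binaryArrays helper is simply binaryFiles followed by mapValues(_.toArray()). A spark-shell style sketch of that same composition against the Scala API (the path and partition count are placeholders; sc is the shell's SparkContext):

// spark-shell session; the path is a placeholder
import org.apache.spark.SparkContext._   // brings in pair RDD operations like mapValues

val streams = sc.binaryFiles("/tmp/binary-input", minPartitions = 3)  // RDD[(String, PortableDataStream)]
val arrays  = streams.mapValues(_.toArray())                          // RDD[(String, Array[Byte])], what binaryArrays wraps
arrays.mapValues(_.length).collect().foreach(println)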

core/src/main/scala/org/apache/spark/input/RawFileInput.scala

Lines changed: 24 additions & 35 deletions
@@ -46,7 +46,7 @@ abstract class StreamFileInputFormat[T]
       if (file.isDir) 0L else file.getLen
     }.sum
 
-    val maxSplitSize = Math.ceil(totalLen*1.0/files.length).toLong
+    val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong
     super.setMaxSplitSize(maxSplitSize)
   }
 
@@ -61,8 +61,10 @@ abstract class StreamFileInputFormat[T]
  */
 class PortableDataStream(split: CombineFileSplit, context: TaskAttemptContext, index: Integer)
   extends Serializable {
-
+  // transient forces file to be reopened after being moved (serialization)
+  @transient
   private var fileIn: FSDataInputStream = null.asInstanceOf[FSDataInputStream]
+  @transient
   private var isOpen = false
   /**
    * Calculate the path name independently of opening the file
@@ -76,13 +78,25 @@ class PortableDataStream(split: CombineFileSplit, context: TaskAttemptContext, i
    * create a new DataInputStream from the split and context
    */
   def open(): FSDataInputStream = {
-    val pathp = split.getPath(index)
-    val fs = pathp.getFileSystem(context.getConfiguration)
-    fileIn = fs.open(pathp)
-    isOpen=true
+    if (!isOpen) {
+      val pathp = split.getPath(index)
+      val fs = pathp.getFileSystem(context.getConfiguration)
+      fileIn = fs.open(pathp)
+      isOpen=true
+    }
     fileIn
   }
 
+  /**
+   * Read the file as a byte array
+   */
+  def toArray(): Array[Byte] = {
+    open()
+    val innerBuffer = ByteStreams.toByteArray(fileIn)
+    close()
+    innerBuffer
+  }
+
   /**
    * close the file (if it is already open)
    */
@@ -131,7 +145,7 @@ abstract class StreamBasedRecordReader[T](
 
   override def nextKeyValue = {
     if (!processed) {
-      val fileIn = new PortableDataStream(split,context,index)
+      val fileIn = new PortableDataStream(split, context, index)
       value = parseStream(fileIn)
       fileIn.close() // if it has not been open yet, close does nothing
       key = fileIn.getPath
@@ -157,7 +171,7 @@ private[spark] class StreamRecordReader(
     split: CombineFileSplit,
     context: TaskAttemptContext,
     index: Integer)
-  extends StreamBasedRecordReader[PortableDataStream](split,context,index) {
+  extends StreamBasedRecordReader[PortableDataStream](split, context, index) {
 
   def parseStream(inStream: PortableDataStream): PortableDataStream = inStream
 }
@@ -170,7 +184,7 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat
   override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
   {
     new CombineFileRecordReader[String,PortableDataStream](
-      split.asInstanceOf[CombineFileSplit],taContext,classOf[StreamRecordReader]
+      split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]
     )
   }
 }
@@ -193,29 +207,4 @@ abstract class BinaryRecordReader[T](
     parseByteArray(innerBuffer)
   }
   def parseByteArray(inArray: Array[Byte]): T
-}
-
-
-
-private[spark] class ByteRecordReader(
-    split: CombineFileSplit,
-    context: TaskAttemptContext,
-    index: Integer)
-  extends BinaryRecordReader[Array[Byte]](split,context,index) {
-
-  override def parseByteArray(inArray: Array[Byte]) = inArray
-}
-
-/**
- * A class for reading the file using the BinaryRecordReader (as Byte array)
- */
-private[spark] class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
-  override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
-  {
-    new CombineFileRecordReader[String,Array[Byte]](
-      split.asInstanceOf[CombineFileSplit],taContext,classOf[ByteRecordReader]
-    )
-  }
-}
-
-
+}
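PortableDataStream now opens its FSDataInputStream lazily and marks it @transient, so a stream that has been serialized (for example when the RDD is cached) simply reopens the file the next time open() or toArray() is called. A caller-side sketch of reading part of a file incrementally instead of materializing it with toArray(); the object and helper names are illustrative, and the 16-byte header size is an assumption:

import java.io.DataInputStream
import org.apache.spark.input.PortableDataStream

object PortableDataStreamUsage {
  // Read only the first 16 bytes of a file rather than pulling the whole file into memory.
  def readHeader(stream: PortableDataStream): Array[Byte] = {
    val in: DataInputStream = stream.open()  // lazily (re)opens the underlying FSDataInputStream
    try {
      val header = new Array[Byte](16)
      in.readFully(header)                   // assumes each file is at least 16 bytes long
      header
    } finally {
      stream.close()                         // close() is a no-op if the stream was never opened
    }
  }
}

Applied to the output of sc.binaryFiles, for example as streams.mapValues(PortableDataStreamUsage.readHeader), each task only reads the bytes it actually needs.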

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 61 additions & 4 deletions
@@ -23,6 +23,7 @@
 import java.net.URI;
 import java.util.*;
 
+import org.apache.spark.input.PortableDataStream;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -852,12 +853,68 @@ public void binaryFiles() throws Exception {
     FileChannel channel1 = fos1.getChannel();
     ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
     channel1.write(bbuf);
+    channel1.close();
+    JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName,3);
+    List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
+    for (Tuple2<String, PortableDataStream> res : result) {
+      Assert.assertArrayEquals(content1, res._2().toArray());
+    }
+  }
+
+  @Test
+  public void binaryFilesCaching() throws Exception {
+    // Reusing the wholeText files example
+    byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
+
+
+    String tempDirName = tempDir.getAbsolutePath();
+    File file1 = new File(tempDirName + "/part-00000");
+
+    FileOutputStream fos1 = new FileOutputStream(file1);
+
+    FileChannel channel1 = fos1.getChannel();
+    ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+    channel1.write(bbuf);
+    channel1.close();
+
+    JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName,3).cache();
+    readRDD.foreach(new VoidFunction<Tuple2<String,PortableDataStream>>() {
+      @Override
+      public void call(Tuple2<String, PortableDataStream> stringPortableDataStreamTuple2) throws Exception {
+        stringPortableDataStreamTuple2._2().getPath();
+        stringPortableDataStreamTuple2._2().toArray(); // force the file to read
+      }
+    });
+
+    List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
+    for (Tuple2<String, PortableDataStream> res : result) {
+      Assert.assertArrayEquals(content1, res._2().toArray());
+    }
+  }
 
+  @Test
+  public void binaryRecords() throws Exception {
+    // Reusing the wholeText files example
+    byte[] content1 = "spark isn't always easy to use.\n".getBytes("utf-8");
+    int numOfCopies = 10;
+    String tempDirName = tempDir.getAbsolutePath();
+    File file1 = new File(tempDirName + "/part-00000");
+
+    FileOutputStream fos1 = new FileOutputStream(file1);
+
+    FileChannel channel1 = fos1.getChannel();
+
+    for (int i=0;i<numOfCopies;i++) {
+      ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+      channel1.write(bbuf);
+    }
+    channel1.close();
 
-    JavaPairRDD<String, byte[]> readRDD = sc.binaryFiles(tempDirName,3);
-    List<Tuple2<String, byte[]>> result = readRDD.collect();
-    for (Tuple2<String, byte[]> res : result) {
-      Assert.assertArrayEquals(content1, res._2());
+    JavaRDD<byte[]> readRDD = sc.binaryRecords(tempDirName,content1.length);
+    Assert.assertEquals(numOfCopies,readRDD.count());
+    List<byte[]> result = readRDD.collect();
+    for (byte[] res : result) {
+      Assert.assertArrayEquals(content1, res);
     }
   }
 
core/src/test/scala/org/apache/spark/FileSuite.scala

Lines changed: 43 additions & 7 deletions
@@ -19,6 +19,8 @@ package org.apache.spark
 
 import java.io.{File, FileWriter}
 
+import org.apache.spark.input.PortableDataStream
+
 import scala.io.Source
 
 import com.google.common.io.Files
@@ -240,35 +242,69 @@ class FileSuite extends FunSuite with LocalSparkContext {
     file.close()
 
     val inRdd = sc.binaryFiles(outFileName)
-    val (infile: String, indata: Array[Byte]) = inRdd.first
+    val (infile: String, indata: PortableDataStream) = inRdd.first
 
     // Try reading the output back as an object file
     assert(infile === outFileName)
-    assert(indata === testOutput)
+    assert(indata.toArray === testOutput)
+  }
+
+  test("portabledatastream caching tests") {
+    sc = new SparkContext("local", "test")
+    val outFile = new File(tempDir, "record-bytestream-00000.bin")
+    val outFileName = outFile.getAbsolutePath()
+
+    // create file
+    val testOutput = Array[Byte](1,2,3,4,5,6)
+    val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+    // write data to file
+    val file = new java.io.FileOutputStream(outFile)
+    val channel = file.getChannel
+    channel.write(bbuf)
+    channel.close()
+    file.close()
+
+    val inRdd = sc.binaryFiles(outFileName).cache()
+    inRdd.foreach{
+      curData: (String, PortableDataStream) =>
+        curData._2.toArray() // force the file to read
+    }
+    val mappedRdd = inRdd.map{
+      curData: (String, PortableDataStream) =>
+        (curData._2.getPath(),curData._2)
+    }
+    val (infile: String, indata: PortableDataStream) = mappedRdd.first
+
+    // Try reading the output back as an object file
+
+    assert(indata.toArray === testOutput)
   }
 
   test("fixed record length binary file as byte array") {
     // a fixed length of 6 bytes
 
     sc = new SparkContext("local", "test")
+
     val outFile = new File(tempDir, "record-bytestream-00000.bin")
     val outFileName = outFile.getAbsolutePath()
 
     // create file
     val testOutput = Array[Byte](1,2,3,4,5,6)
     val testOutputCopies = 10
-    val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+
     // write data to file
     val file = new java.io.FileOutputStream(outFile)
     val channel = file.getChannel
-    for(i <- 1 to testOutputCopies) channel.write(bbuf)
+    for(i <- 1 to testOutputCopies) {
+      val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+      channel.write(bbuf)
+    }
     channel.close()
     file.close()
-    sc.hadoopConfiguration.setInt("recordLength",testOutput.length)
 
-    val inRdd = sc.binaryRecords(outFileName)
+    val inRdd = sc.binaryRecords(outFileName, testOutput.length)
     // make sure there are enough elements
-    assert(inRdd.count== testOutputCopies)
+    assert(inRdd.count == testOutputCopies)
 
     // now just compare the first one
     val indata: Array[Byte] = inRdd.first