
Commit c35237b

Added test for checkpoint serialization in StreamingContext.start()
1 parent 3c4c1f9 commit c35237b

4 files changed: +86, -33 lines

core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal

 import org.apache.spark.Logging

-private[serializer] object SerializationDebugger extends Logging {
+private[spark] object SerializationDebugger extends Logging {

   /**
    * Improve the given NotSerializableException with the serialization path leading from the given
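
Why this one-line change matters: promoting the object from private[serializer] to private[spark] makes the debugger callable from other Spark modules, which is exactly what StreamingContext.validate() does later in this commit. A minimal sketch of that usage pattern (not part of the commit), assuming code that lives under the org.apache.spark package, that find returns the serialization path as a List[String] (which its use below suggests), and that `value` stands for any object you suspect is not serializable:

    import org.apache.spark.serializer.SerializationDebugger

    // Walk the object graph of `value` and render the path that cannot be serialized.
    val trace = SerializationDebugger.find(value).map("\t- " + _).mkString("\n")
    println("Serialization stack:\n" + trace)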

streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala

Lines changed: 42 additions & 28 deletions
@@ -102,6 +102,44 @@ object Checkpoint extends Logging {
       Seq.empty
     }
   }
+
+  /** Serialize the checkpoint, or throw any exception that occurs */
+  def serialize(checkpoint: Checkpoint, conf: SparkConf): Array[Byte] = {
+    val compressionCodec = CompressionCodec.createCodec(conf)
+    val bos = new ByteArrayOutputStream()
+    val zos = compressionCodec.compressedOutputStream(bos)
+    val oos = new ObjectOutputStream(zos)
+    Utils.tryWithSafeFinally {
+      oos.writeObject(checkpoint)
+    } {
+      oos.close()
+    }
+    bos.toByteArray
+  }
+
+  /** Deserialize a checkpoint from the input stream, or throw any exception that occurs */
+  def deserialize(inputStream: InputStream, conf: SparkConf): Checkpoint = {
+    val compressionCodec = CompressionCodec.createCodec(conf)
+    var ois: ObjectInputStreamWithLoader = null
+    Utils.tryWithSafeFinally {
+
+      // ObjectInputStream uses the last defined user-defined class loader in the stack
+      // to find classes, which may be the wrong class loader. Hence, an inherited version
+      // of ObjectInputStream is used to explicitly use the current thread's default class
+      // loader to find and load classes. This is a well-known Java issue and has popped up
+      // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+      val zis = compressionCodec.compressedInputStream(inputStream)
+      ois = new ObjectInputStreamWithLoader(zis,
+        Thread.currentThread().getContextClassLoader)
+      val cp = ois.readObject.asInstanceOf[Checkpoint]
+      cp.validate()
+      cp
+    } {
+      if (ois != null) {
+        ois.close()
+      }
+    }
+  }
 }


@@ -189,17 +227,10 @@ class CheckpointWriter(
   }

   def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) {
-    val bos = new ByteArrayOutputStream()
-    val zos = compressionCodec.compressedOutputStream(bos)
-    val oos = new ObjectOutputStream(zos)
-    Utils.tryWithSafeFinally {
-      oos.writeObject(checkpoint)
-    } {
-      oos.close()
-    }
     try {
+      val bytes = Checkpoint.serialize(checkpoint, conf)
       executor.execute(new CheckpointWriteHandler(
-        checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater))
+        checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
       logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
     } catch {
       case rej: RejectedExecutionException =>
@@ -264,25 +295,8 @@ object CheckpointReader extends Logging {
     checkpointFiles.foreach(file => {
       logInfo("Attempting to load checkpoint from file " + file)
       try {
-        var ois: ObjectInputStreamWithLoader = null
-        var cp: Checkpoint = null
-        Utils.tryWithSafeFinally {
-          val fis = fs.open(file)
-          // ObjectInputStream uses the last defined user-defined class loader in the stack
-          // to find classes, which maybe the wrong class loader. Hence, a inherited version
-          // of ObjectInputStream is used to explicitly use the current thread's default class
-          // loader to find and load classes. This is a well know Java issue and has popped up
-          // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
-          val zis = compressionCodec.compressedInputStream(fis)
-          ois = new ObjectInputStreamWithLoader(zis,
-            Thread.currentThread().getContextClassLoader)
-          cp = ois.readObject.asInstanceOf[Checkpoint]
-        } {
-          if (ois != null) {
-            ois.close()
-          }
-        }
-        cp.validate()
+        val fis = fs.open(file)
+        val cp = Checkpoint.deserialize(fis, conf)
         logInfo("Checkpoint successfully loaded from file " + file)
         logInfo("Checkpoint was generated at time " + cp.checkpointTime)
         return Some(cp)
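
These two helpers centralize logic that CheckpointWriter and CheckpointReader previously duplicated, and they deliberately propagate any exception instead of swallowing it. A minimal round-trip sketch (not part of the commit), assuming code inside the org.apache.spark.streaming package (Checkpoint and ssc.conf are package-private) and an existing StreamingContext `ssc`:

    import java.io.ByteArrayInputStream

    val checkpoint = new Checkpoint(ssc, Time(0))
    // Fails fast with e.g. NotSerializableException if the graph cannot be serialized.
    val bytes: Array[Byte] = Checkpoint.serialize(checkpoint, ssc.conf)
    // deserialize calls cp.validate() before handing the checkpoint back.
    val restored = Checkpoint.deserialize(new ByteArrayInputStream(bytes), ssc.conf)
    assert(restored.checkpointTime == checkpoint.checkpointTime)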

streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala

Lines changed: 25 additions & 4 deletions
@@ -17,12 +17,13 @@

 package org.apache.spark.streaming

-import java.io.InputStream
+import java.io.{NotSerializableException, InputStream}
 import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

 import scala.collection.Map
 import scala.collection.mutable.Queue
 import scala.reflect.ClassTag
+import scala.util.control.NonFatal

 import akka.actor.{Props, SupervisorStrategy}
 import org.apache.hadoop.conf.Configuration
@@ -35,13 +36,14 @@ import org.apache.spark._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.input.FixedLengthBinaryInputFormat
 import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark.serializer.SerializationDebugger
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming.StreamingContextState._
 import org.apache.spark.streaming.dstream._
 import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver}
 import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener}
 import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab}
-import org.apache.spark.util.CallSite
+import org.apache.spark.util.{Utils, CallSite}

 /**
  * Main entry point for Spark Streaming functionality. It provides methods used to create
@@ -235,6 +237,10 @@ class StreamingContext private[streaming] (
     }
   }

+  private[streaming] def isCheckpointingEnabled: Boolean = {
+    checkpointDir != null
+  }
+
   private[streaming] def initialCheckpoint: Checkpoint = {
     if (isCheckpointPresent) cp_ else null
   }
@@ -523,11 +529,26 @@
     assert(graph != null, "Graph is null")
     graph.validate()

-    assert(
-      checkpointDir == null || checkpointDuration != null,
+    require(
+      !isCheckpointingEnabled || checkpointDuration != null,
       "Checkpoint directory has been set, but the graph checkpointing interval has " +
         "not been set. Please use StreamingContext.checkpoint() to set the interval."
     )
+
+    // Verify whether the DStream checkpoint is serializable
+    if (isCheckpointingEnabled) {
+      val checkpoint = new Checkpoint(this, Time.apply(0))
+      try {
+        Checkpoint.serialize(checkpoint, conf)
+      } catch {
+        case e: NotSerializableException =>
+          throw new IllegalArgumentException(
+            "DStream checkpointing has been enabled but the DStreams with their functions " +
+              "are not serializable\nSerialization stack:\n" +
+              SerializationDebugger.find(checkpoint).map("\t- " + _).mkString("\n")
+          )
+      }
+    }
   }

   /**
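
The user-visible effect: with a checkpoint directory set, a non-serializable DStream graph now aborts start() immediately instead of failing later inside the asynchronous CheckpointWriter. A hedged illustration (the socket field and stream names are assumptions for demonstration, not from the commit):

    import java.net.Socket
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val ssc = new StreamingContext(conf, Seconds(1))   // `conf` is an existing SparkConf
    ssc.checkpoint("/tmp/streaming-checkpoint")

    val socket = new Socket()                          // java.net.Socket is not serializable
    val lines = ssc.socketTextStream("localhost", 9999)
    lines.foreachRDD { rdd =>
      if (!socket.isClosed) {                          // the closure captures `socket`
        rdd.count()
      }
    }

    ssc.start()  // throws IllegalArgumentException listing the serialization stack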

streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala

Lines changed: 18 additions & 0 deletions
@@ -132,6 +132,24 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
     }
   }

+  test("start with non-serializable DStream checkpoints") {
+    val checkpointDir = Utils.createTempDir()
+    ssc = new StreamingContext(conf, batchDuration)
+    ssc.checkpoint(checkpointDir.getAbsolutePath)
+    addInputStream(ssc).foreachRDD { rdd =>
+      // Refer to this.appName from inside closure so that this closure refers to
+      // the instance of StreamingContextSuite, and is therefore not serializable
+      rdd.count() + appName
+    }
+
+    // Test whether start() fails early when checkpointing is enabled
+    intercept[IllegalArgumentException] {
+      ssc.start()
+    }
+    assert(ssc.getState() !== StreamingContextState.ACTIVE)
+    assert(StreamingContext.getActive().isEmpty)
+  }
+
   test("start multiple times") {
     ssc = new StreamingContext(master, appName, batchDuration)
     addInputStream(ssc).register()
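
To run just this suite, something like the following should work with Spark's sbt build (the exact invocation is an assumption, not stated in the commit):

    build/sbt "streaming/testOnly org.apache.spark.streaming.StreamingContextSuite"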
