
[SPARK-7767] [STREAMING] Added test for checkpoint serialization in StreamingContext.start() #6292

Closed · wants to merge 2 commits
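In short: Checkpoint gains serialize/deserialize helpers, CheckpointWriter and CheckpointReader are refactored to use them, and StreamingContext.start() (via validate()) now serializes a throwaway checkpoint up front, so a non-serializable DStream graph fails fast with a readable serialization stack instead of only failing later when the first checkpoint is written. Below is a hedged sketch of the user-facing scenario this guards against; the Holder class, host/port, and paths are illustrative and not part of the patch:

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Sketch only: a checkpointed application whose foreachRDD closure captures a
// non-serializable object. With this patch, ssc.start() throws
// NotSerializableException immediately instead of failing at checkpoint time.
class Holder(val tag: String)  // deliberately not Serializable

object FailFastExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("fail-fast-demo")
    val ssc = new StreamingContext(conf, Seconds(1))
    ssc.checkpoint("/tmp/fail-fast-demo")  // enables DStream checkpointing

    val holder = new Holder("boom")
    ssc.socketTextStream("localhost", 9999).foreachRDD { rdd =>
      // The closure captures holder, so the DStream graph cannot be serialized.
      println(rdd.count() + holder.tag)
    }

    ssc.start()  // now fails here, early and with a serialization stack
    ssc.awaitTermination()
  }
}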
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal

import org.apache.spark.Logging

-private[serializer] object SerializationDebugger extends Logging {
+private[spark] object SerializationDebugger extends Logging {

  /**
   * Improve the given NotSerializableException with the serialization path leading from the given
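The one-line visibility change above (private[serializer] to private[spark]) exists so that code outside the serializer package, here the streaming module, can reuse the debugger. A minimal sketch of the intended use, assuming the caller lives under an org.apache.spark package; describeFailure is an illustrative name, not part of the patch:

import org.apache.spark.serializer.SerializationDebugger

// Sketch only: build the same kind of message that StreamingContext.validate()
// produces further down in this diff.
def describeFailure(obj: AnyRef): String = {
  // find() returns the chain of objects and fields leading to the non-serializable one
  "Serialization stack:\n" + SerializationDebugger.find(obj).map("\t- " + _).mkString("\n")
}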
@@ -102,6 +102,44 @@ object Checkpoint extends Logging {
      Seq.empty
    }
  }
+
+  /** Serialize the checkpoint, or throw any exception that occurs */
+  def serialize(checkpoint: Checkpoint, conf: SparkConf): Array[Byte] = {
+    val compressionCodec = CompressionCodec.createCodec(conf)
+    val bos = new ByteArrayOutputStream()
+    val zos = compressionCodec.compressedOutputStream(bos)
+    val oos = new ObjectOutputStream(zos)
+    Utils.tryWithSafeFinally {
+      oos.writeObject(checkpoint)
+    } {
+      oos.close()
+    }
+    bos.toByteArray
+  }
+
+  /** Deserialize a checkpoint from the input stream, or throw any exception that occurs */
+  def deserialize(inputStream: InputStream, conf: SparkConf): Checkpoint = {
+    val compressionCodec = CompressionCodec.createCodec(conf)
+    var ois: ObjectInputStreamWithLoader = null
+    Utils.tryWithSafeFinally {
+
+      // ObjectInputStream uses the last defined user-defined class loader in the stack
+      // to find classes, which may be the wrong class loader. Hence, an inherited version
+      // of ObjectInputStream is used to explicitly use the current thread's default class
+      // loader to find and load classes. This is a well-known Java issue and has popped up
+      // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
+      val zis = compressionCodec.compressedInputStream(inputStream)
+      ois = new ObjectInputStreamWithLoader(zis,
+        Thread.currentThread().getContextClassLoader)
+      val cp = ois.readObject.asInstanceOf[Checkpoint]
+      cp.validate()
+      cp
+    } {
+      if (ois != null) {
+        ois.close()
+      }
+    }
+  }
}
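For orientation, a minimal sketch of how the two new helpers compose; it assumes code inside the org.apache.spark.streaming package (Checkpoint and these methods are package-private there) and an already-built StreamingContext named ssc:

import java.io.ByteArrayInputStream

// Sketch only: round-trip a Checkpoint through the new helpers.
val conf = ssc.sparkContext.getConf
val checkpoint = new Checkpoint(ssc, Time(0))

// serialize() propagates a NotSerializableException if any DStream closure cannot be serialized
val bytes: Array[Byte] = Checkpoint.serialize(checkpoint, conf)
val restored: Checkpoint = Checkpoint.deserialize(new ByteArrayInputStream(bytes), conf)
assert(restored.checkpointTime == checkpoint.checkpointTime)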


@@ -189,17 +227,10 @@ class CheckpointWriter(
  }

  def write(checkpoint: Checkpoint, clearCheckpointDataLater: Boolean) {
-    val bos = new ByteArrayOutputStream()
-    val zos = compressionCodec.compressedOutputStream(bos)
-    val oos = new ObjectOutputStream(zos)
-    Utils.tryWithSafeFinally {
-      oos.writeObject(checkpoint)
-    } {
-      oos.close()
-    }
    try {
+      val bytes = Checkpoint.serialize(checkpoint, conf)
      executor.execute(new CheckpointWriteHandler(
-        checkpoint.checkpointTime, bos.toByteArray, clearCheckpointDataLater))
+        checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
      logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
    } catch {
      case rej: RejectedExecutionException =>
@@ -264,25 +295,8 @@ object CheckpointReader extends Logging {
    checkpointFiles.foreach(file => {
      logInfo("Attempting to load checkpoint from file " + file)
      try {
-        var ois: ObjectInputStreamWithLoader = null
-        var cp: Checkpoint = null
-        Utils.tryWithSafeFinally {
-          val fis = fs.open(file)
-          // ObjectInputStream uses the last defined user-defined class loader in the stack
-          // to find classes, which may be the wrong class loader. Hence, an inherited version
-          // of ObjectInputStream is used to explicitly use the current thread's default class
-          // loader to find and load classes. This is a well-known Java issue and has popped up
-          // in other places (e.g., http://jira.codehaus.org/browse/GROOVY-1627)
-          val zis = compressionCodec.compressedInputStream(fis)
-          ois = new ObjectInputStreamWithLoader(zis,
-            Thread.currentThread().getContextClassLoader)
-          cp = ois.readObject.asInstanceOf[Checkpoint]
-        } {
-          if (ois != null) {
-            ois.close()
-          }
-        }
-        cp.validate()
+        val fis = fs.open(file)
+        val cp = Checkpoint.deserialize(fis, conf)
        logInfo("Checkpoint successfully loaded from file " + file)
        logInfo("Checkpoint was generated at time " + cp.checkpointTime)
        return Some(cp)
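The simplified read path above is what runs on recovery. Not part of this diff, but as a reminder of how it is typically reached from user code (the checkpoint directory and app name below are illustrative):

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

// Sketch only: CheckpointReader.read is driven by StreamingContext.getOrCreate,
// which rebuilds the context from checkpoint files when they exist and otherwise
// calls the supplied factory.
def createContext(): StreamingContext = {
  val conf = new SparkConf().setMaster("local[2]").setAppName("recovery-demo")
  val ssc = new StreamingContext(conf, Seconds(1))
  ssc.checkpoint("/tmp/recovery-demo")
  // ... define the DStream graph here ...
  ssc
}

val ssc = StreamingContext.getOrCreate("/tmp/recovery-demo", createContext _)
ssc.start()
ssc.awaitTermination()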
@@ -17,7 +17,7 @@

package org.apache.spark.streaming

-import java.io.InputStream
+import java.io.{InputStream, NotSerializableException}
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}

import scala.collection.Map
@@ -35,6 +35,7 @@ import org.apache.spark._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.input.FixedLengthBinaryInputFormat
import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark.serializer.SerializationDebugger
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContextState._
import org.apache.spark.streaming.dstream._
@@ -235,6 +236,10 @@ class StreamingContext private[streaming] (
    }
  }

+  private[streaming] def isCheckpointingEnabled: Boolean = {
+    checkpointDir != null
+  }
+
  private[streaming] def initialCheckpoint: Checkpoint = {
    if (isCheckpointPresent) cp_ else null
  }
@@ -523,11 +528,26 @@ class StreamingContext private[streaming] (
    assert(graph != null, "Graph is null")
    graph.validate()

-    assert(
-      checkpointDir == null || checkpointDuration != null,
+    require(
+      !isCheckpointingEnabled || checkpointDuration != null,
      "Checkpoint directory has been set, but the graph checkpointing interval has " +
        "not been set. Please use StreamingContext.checkpoint() to set the interval."
    )

+    // Verify whether the DStream checkpoint is serializable
+    if (isCheckpointingEnabled) {
+      val checkpoint = new Checkpoint(this, Time.apply(0))
+      try {
+        Checkpoint.serialize(checkpoint, conf)
+      } catch {
+        case e: NotSerializableException =>
+          throw new NotSerializableException(
+            "DStream checkpointing has been enabled but the DStreams with their functions " +
+              "are not serializable\nSerialization stack:\n" +
+              SerializationDebugger.find(checkpoint).map("\t- " + _).mkString("\n")
+          )
+      }
+    }
  }

  /**
@@ -17,21 +17,21 @@

package org.apache.spark.streaming

-import java.io.File
+import java.io.{File, NotSerializableException}
import java.util.concurrent.atomic.AtomicInteger

import org.apache.commons.io.FileUtils
-import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}
-import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
+import org.scalatest.concurrent.Timeouts
import org.scalatest.exceptions.TestFailedDueToTimeoutException
import org.scalatest.time.SpanSugar._
+import org.scalatest.{Assertions, BeforeAndAfter, FunSuite}

-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}


class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging {
@@ -132,6 +132,25 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
    }
  }

+  test("start with non-serializable DStream checkpoints") {
+    val checkpointDir = Utils.createTempDir()
+    ssc = new StreamingContext(conf, batchDuration)
+    ssc.checkpoint(checkpointDir.getAbsolutePath)
+    addInputStream(ssc).foreachRDD { rdd =>
+      // Refer to this.appName from inside closure so that this closure refers to
+      // the instance of StreamingContextSuite, and is therefore not serializable
+      rdd.count() + appName
+    }
+
+    // Test whether start() fails early when checkpointing is enabled
+    val exception = intercept[NotSerializableException] {
+      ssc.start()
+    }
+    assert(exception.getMessage().contains("DStreams with their functions are not serializable"))
+    assert(ssc.getState() !== StreamingContextState.ACTIVE)
+    assert(StreamingContext.getActive().isEmpty)
+  }

test("start multiple times") {
ssc = new StreamingContext(master, appName, batchDuration)
addInputStream(ssc).register()
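Not part of the diff, but for readers who hit the new error: the usual remedy is to copy whatever the closure needs into a local val, so it no longer captures the enclosing non-serializable instance. A hedged sketch, assuming an input DStream named stream inside a class like the test suite above:

// Sketch only: capture a local copy of the field instead of the enclosing instance.
val localAppName = appName  // a plain String, which is serializable
stream.foreachRDD { rdd =>
  rdd.count() + localAppName  // the closure now captures only localAppName
}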