Commit e682724

Graph should support the checkpoint operation
1 parent 6eb1b6f commit e682724

File tree

7 files changed (+59 -8 lines)


core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 7 additions & 8 deletions
@@ -1279,21 +1279,20 @@ abstract class RDD[T: ClassTag](
   }
 
   // Avoid handling doCheckpoint multiple times to prevent excessive recursion
-  @transient private var doCheckpointCalled = false
+  @transient private var doCheckpointCalled = 0
 
   /**
    * Performs the checkpointing of this RDD by saving this. It is called after a job using this RDD
    * has completed (therefore the RDD has been materialized and potentially stored in memory).
    * doCheckpoint() is called recursively on the parent RDDs.
    */
   private[spark] def doCheckpoint() {
-    if (!doCheckpointCalled) {
-      doCheckpointCalled = true
-      if (checkpointData.isDefined) {
-        checkpointData.get.doCheckpoint()
-      } else {
-        dependencies.foreach(_.rdd.doCheckpoint())
-      }
+    if (checkpointData == None && doCheckpointCalled == 0) {
+      dependencies.foreach(_.rdd.doCheckpoint())
+      doCheckpointCalled = 1
+    } else if (checkpointData.isDefined && doCheckpointCalled < 2) {
+      checkpointData.get.doCheckpoint()
+      doCheckpointCalled = 2
     }
   }

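Note on the new flag semantics: with the old boolean, the first job that ran on an RDD marked it as handled once and for all, so a checkpoint() requested after that job (for example after a count()) was silently ignored. The integer flag distinguishes "recursed into parents without checkpoint data" (1) from "actually checkpointed" (2), so a later job can still materialize the checkpoint. A minimal sketch of the scenario this enables, assuming a running SparkContext `sc` with a checkpoint directory already set (the variable names are placeholders, not part of the commit):

    val rdd = sc.makeRDD(1 to 4).flatMap(x => 1 to x)
    rdd.count()        // first job: doCheckpointCalled becomes 1, nothing is saved yet
    rdd.checkpoint()   // mark the RDD for checkpointing after the fact
    rdd.collect()      // next job: checkpointData is defined, so the RDD is written out now
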
core/src/test/scala/org/apache/spark/CheckpointSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -55,6 +55,17 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     assert(flatMappedRDD.collect() === result)
   }
 
+  test("After call count method, checkpoint should also work") {
+    val parCollection = sc.makeRDD(1 to 4)
+    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+    flatMappedRDD.count
+    flatMappedRDD.checkpoint()
+    assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+    val result = flatMappedRDD.collect()
+    assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+    assert(flatMappedRDD.collect() === result)
+  }
+
   test("RDDs with one-to-one dependencies") {
     testRDD(_.map(x => x.toString))
     testRDD(_.flatMap(x => 1 to x))

graphx/src/main/scala/org/apache/spark/graphx/Graph.scala

Lines changed: 2 additions & 0 deletions
@@ -96,6 +96,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab
    */
   def cache(): Graph[VD, ED]
 
+  def checkpoint():Unit
+
   /**
    * Uncaches only the vertices of this graph, leaving the edges alone. This is useful in iterative
    * algorithms that modify the vertex attributes but reuse the edges. This method can be used to

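Since Graph.checkpoint() is new public API, a short usage sketch may help; it follows the pattern of the GraphSuite test added below, and the names `sc`, `checkpointDir`, and `edgeRDD` are assumptions, not part of the commit:

    sc.setCheckpointDir(checkpointDir.getAbsolutePath)   // any writable directory
    val graph = Graph.fromEdges(edgeRDD, 1.0F)           // edgeRDD: an RDD[Edge[Int]]
    graph.checkpoint()                                    // marks the vertex and edge RDDs
    graph.vertices.count()                                // an action materializes the checkpoint
    graph.edges.count()
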
graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala

Lines changed: 4 additions & 0 deletions
@@ -64,6 +64,10 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
     this
   }
 
+  override def checkpoint(): Unit = {
+    partitionsRDD.checkpoint()
+  }
+
   /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */
   override def cache(): this.type = {
     partitionsRDD.persist(targetStorageLevel)

graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala

Lines changed: 5 additions & 0 deletions
@@ -65,6 +65,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected (
     this
   }
 
+  override def checkpoint(): Unit = {
+    vertices.checkpoint()
+    replicatedVertexView.edges.checkpoint()
+  }
+
   override def unpersistVertices(blocking: Boolean = true): Graph[VD, ED] = {
     vertices.unpersist(blocking)
     // TODO: unpersist the replicated vertices in `replicatedVertexView` but leave the edges alone

graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala

Lines changed: 4 additions & 0 deletions
@@ -65,6 +65,10 @@ class VertexRDDImpl[VD] private[graphx] (
     this
   }
 
+  override def checkpoint(): Unit = {
+    partitionsRDD.checkpoint()
+  }
+
   /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */
   override def cache(): this.type = {
     partitionsRDD.persist(targetStorageLevel)

graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala

Lines changed: 26 additions & 0 deletions
@@ -19,6 +19,8 @@ package org.apache.spark.graphx
 
 import org.scalatest.FunSuite
 
+import com.google.common.io.Files
+
 import org.apache.spark.SparkContext
 import org.apache.spark.graphx.Graph._
 import org.apache.spark.graphx.PartitionStrategy._
@@ -365,4 +367,28 @@ class GraphSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test("checkpoint") {
+    val checkpointDir = Files.createTempDir()
+    checkpointDir.deleteOnExit()
+    withSpark { sc =>
+      sc.setCheckpointDir(checkpointDir.getAbsolutePath)
+      val ring = (0L to 100L).zip((1L to 99L) :+ 0L).map { case (a, b) => Edge(a, b, 1)}
+      val rdd = sc.parallelize(ring)
+      val graph = Graph.fromEdges(rdd, 1.0F)
+      graph.checkpoint()
+      val edgesDependencies = graph.edges.partitionsRDD.dependencies
+      val verticesDependencies = graph.vertices.partitionsRDD.dependencies
+      val edges = graph.edges.collect().map(_.attr)
+      val vertices = graph.vertices.collect().map(_._2)
+
+      graph.vertices.count()
+      graph.edges.count()
+
+      assert(graph.edges.partitionsRDD.dependencies != edgesDependencies)
+      assert(graph.vertices.partitionsRDD.dependencies != verticesDependencies)
+      assert(graph.vertices.collect().map(_._2) === vertices)
+      assert(graph.edges.collect().map(_.attr) === edges)
+    }
+  }
+
 }

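Taken together, the test exercises the whole path: checkpoint() only marks the underlying partitionsRDDs, the two count() actions trigger RDD.doCheckpoint(), and the assertions then confirm that the dependencies of both partitionsRDDs have been replaced (the lineage now starts from the checkpointed data) while the collected vertex and edge values are unchanged.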