
Commit 35d6a99

jacek-lewandowski authored and JoshRosen committed
[SPARK-7436] Fixed instantiation of custom recovery mode factory and added tests
Author: Jacek Lewandowski <[email protected]>

Closes #5977 from jacek-lewandowski/SPARK-7436 and squashes the following commits:

ff0a3c2 [Jacek Lewandowski] SPARK-7436: Fixed instantiation of custom recovery mode factory and added tests
1 parent 008a60d commit 35d6a99

File tree

3 files changed: +208 -4 lines changed


core/src/main/scala/org/apache/spark/deploy/master/Master.scala

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ private[master] class Master(
         (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
       case "CUSTOM" =>
         val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
-        val factory = clazz.getConstructor(conf.getClass, Serialization.getClass)
+        val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization])
           .newInstance(conf, SerializationExtension(context.system))
           .asInstanceOf[StandaloneRecoveryModeFactory]
         (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
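
Why this one-line change matters: in Scala, Serialization.getClass evaluates to the class of the Serialization companion object (akka.serialization.Serialization$), not the class of Serialization instances, so the old getConstructor lookup could never match a factory's (SparkConf, Serialization) constructor and failed with NoSuchMethodException whenever CUSTOM recovery mode was configured. A minimal sketch of the distinction (assumes Akka and Spark core on the classpath; the comments show the expected output):

import akka.serialization.Serialization
import org.apache.spark.SparkConf

object CompanionClassPitfall {
  def main(args: Array[String]): Unit = {
    // Class of Serialization *instances* -- the type a factory constructor declares.
    println(classOf[Serialization])    // class akka.serialization.Serialization

    // Class of the Serialization *companion object* -- what the old code passed.
    println(Serialization.getClass)    // class akka.serialization.Serialization$

    // getConstructor matches parameter types exactly, so requesting a
    // (SparkConf, Serialization$) constructor on a class that declares
    // (SparkConf, Serialization) throws NoSuchMethodException.
    println(classOf[SparkConf] == (new SparkConf(false)).getClass)  // true: conf.getClass happened to be fine
  }
}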
Lines changed: 110 additions & 0 deletions (new file)
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// This file is placed in different package to make sure all of these components work well
+// when they are outside of org.apache.spark.
+package other.supplier
+
+import scala.collection.mutable
+import scala.reflect.ClassTag
+
+import akka.serialization.Serialization
+
+import org.apache.spark.SparkConf
+import org.apache.spark.deploy.master._
+
+class CustomRecoveryModeFactory(
+  conf: SparkConf,
+  serialization: Serialization
+) extends StandaloneRecoveryModeFactory(conf, serialization) {
+
+  CustomRecoveryModeFactory.instantiationAttempts += 1
+
+  /**
+   * PersistenceEngine defines how the persistent data(Information about worker, driver etc..)
+   * is handled for recovery.
+   *
+   */
+  override def createPersistenceEngine(): PersistenceEngine =
+    new CustomPersistenceEngine(serialization)
+
+  /**
+   * Create an instance of LeaderAgent that decides who gets elected as master.
+   */
+  override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent =
+    new CustomLeaderElectionAgent(master)
+}
+
+object CustomRecoveryModeFactory {
+  @volatile var instantiationAttempts = 0
+}
+
+class CustomPersistenceEngine(serialization: Serialization) extends PersistenceEngine {
+  val data = mutable.HashMap[String, Array[Byte]]()
+
+  CustomPersistenceEngine.lastInstance = Some(this)
+
+  /**
+   * Defines how the object is serialized and persisted. Implementation will
+   * depend on the store used.
+   */
+  override def persist(name: String, obj: Object): Unit = {
+    CustomPersistenceEngine.persistAttempts += 1
+    serialization.serialize(obj) match {
+      case util.Success(bytes) => data += name -> bytes
+      case util.Failure(cause) => throw new RuntimeException(cause)
+    }
+  }
+
+  /**
+   * Defines how the object referred by its name is removed from the store.
+   */
+  override def unpersist(name: String): Unit = {
+    CustomPersistenceEngine.unpersistAttempts += 1
+    data -= name
+  }
+
+  /**
+   * Gives all objects, matching a prefix. This defines how objects are
+   * read/deserialized back.
+   */
+  override def read[T: ClassTag](prefix: String): Seq[T] = {
+    CustomPersistenceEngine.readAttempts += 1
+    val clazz = implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]
+    val results = for ((name, bytes) <- data; if name.startsWith(prefix))
+      yield serialization.deserialize(bytes, clazz)
+
+    results.find(_.isFailure).foreach {
+      case util.Failure(cause) => throw new RuntimeException(cause)
+    }
+
+    results.flatMap(_.toOption).toSeq
+  }
+}
+
+object CustomPersistenceEngine {
+  @volatile var persistAttempts = 0
+  @volatile var unpersistAttempts = 0
+  @volatile var readAttempts = 0
+
+  @volatile var lastInstance: Option[CustomPersistenceEngine] = None
+}
+
+class CustomLeaderElectionAgent(val masterActor: LeaderElectable) extends LeaderElectionAgent {
+  masterActor.electedLeader()
+}
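
For reference, the fixture above is enabled the same way a production deployment would plug in its own recovery mode: two configuration keys that Master reads when recovery mode is CUSTOM. A hedged usage sketch (the class name here is the test fixture; a real cluster would substitute its own factory implementation and place it on the Master's classpath):

import org.apache.spark.SparkConf

// Sketch: point a standalone Master at a user-supplied recovery mode factory.
// These are exactly the two keys the new MasterSuite test sets below.
val conf = new SparkConf()
  .set("spark.deploy.recoveryMode", "CUSTOM")
  .set("spark.deploy.recoveryMode.factory",
    "other.supplier.CustomRecoveryModeFactory")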

core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala

Lines changed: 97 additions & 3 deletions
@@ -17,12 +17,20 @@
 
 package org.apache.spark.deploy.master
 
+import java.util.Date
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
 import akka.actor.Address
-import org.scalatest.FunSuite
+import org.scalatest.{FunSuite, Matchers}
+import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
 
-import org.apache.spark.{SSLOptions, SparkConf, SparkException}
+import org.apache.spark.deploy._
+import org.apache.spark.{SparkConf, SparkException}
 
-class MasterSuite extends FunSuite {
+class MasterSuite extends FunSuite with Matchers {
 
   test("toAkkaUrl") {
     val conf = new SparkConf(loadDefaults = false)
@@ -63,4 +71,90 @@ class MasterSuite extends FunSuite {
     }
     assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage)
   }
+
+  test("can use a custom recovery mode factory") {
+    val conf = new SparkConf(loadDefaults = false)
+    conf.set("spark.deploy.recoveryMode", "CUSTOM")
+    conf.set("spark.deploy.recoveryMode.factory",
+      classOf[CustomRecoveryModeFactory].getCanonicalName)
+
+    val instantiationAttempts = CustomRecoveryModeFactory.instantiationAttempts
+
+    val commandToPersist = new Command(
+      mainClass = "",
+      arguments = Nil,
+      environment = Map.empty,
+      classPathEntries = Nil,
+      libraryPathEntries = Nil,
+      javaOpts = Nil
+    )
+
+    val appToPersist = new ApplicationInfo(
+      startTime = 0,
+      id = "test_app",
+      desc = new ApplicationDescription(
+        name = "",
+        maxCores = None,
+        memoryPerExecutorMB = 0,
+        command = commandToPersist,
+        appUiUrl = "",
+        eventLogDir = None,
+        eventLogCodec = None,
+        coresPerExecutor = None),
+      submitDate = new Date(),
+      driver = null,
+      defaultCores = 0
+    )
+
+    val driverToPersist = new DriverInfo(
+      startTime = 0,
+      id = "test_driver",
+      desc = new DriverDescription(
+        jarUrl = "",
+        mem = 0,
+        cores = 0,
+        supervise = false,
+        command = commandToPersist
+      ),
+      submitDate = new Date()
+    )
+
+    val workerToPersist = new WorkerInfo(
+      id = "test_worker",
+      host = "127.0.0.1",
+      port = 10000,
+      cores = 0,
+      memory = 0,
+      actor = null,
+      webUiPort = 0,
+      publicAddress = ""
+    )
+
+    val (actorSystem, port, uiPort, restPort) =
+      Master.startSystemAndActor("127.0.0.1", 7077, 8080, conf)
+
+    try {
+      Await.result(actorSystem.actorSelection("/user/Master").resolveOne(10 seconds), 10 seconds)
+
+      CustomPersistenceEngine.lastInstance.isDefined shouldBe true
+      val persistenceEngine = CustomPersistenceEngine.lastInstance.get
+
+      persistenceEngine.addApplication(appToPersist)
+      persistenceEngine.addDriver(driverToPersist)
+      persistenceEngine.addWorker(workerToPersist)
+
+      val (apps, drivers, workers) = persistenceEngine.readPersistedData()
+
+      apps.map(_.id) should contain(appToPersist.id)
+      drivers.map(_.id) should contain(driverToPersist.id)
+      workers.map(_.id) should contain(workerToPersist.id)
+
+    } finally {
+      actorSystem.shutdown()
+      actorSystem.awaitTermination()
+    }
+
+    CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts
+  }
 }
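
A note on what the round trip above relies on: addApplication, addDriver, addWorker and readPersistedData are concrete helpers on the PersistenceEngine trait that delegate to the abstract persist/read primitives using fixed name prefixes, which is why CustomPersistenceEngine only has to implement persist, unpersist and read. A simplified, hypothetical sketch of that contract (EngineSketch and its String-keyed signatures are illustrative, not the exact Spark source):

import scala.reflect.ClassTag

// Hypothetical reduction of the contract the test exercises.
trait EngineSketch {
  // Storage primitives a custom engine supplies.
  def persist(name: String, obj: AnyRef): Unit
  def read[T: ClassTag](prefix: String): Seq[T]

  // Concrete helpers delegate using well-known prefixes, so an object stored
  // via addApplication is found again by a prefix read -- the round trip the
  // assertions above verify.
  final def addApplication(id: String, app: AnyRef): Unit = persist("app_" + id, app)
  final def readApplications[T: ClassTag](): Seq[T] = read[T]("app_")
}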
