Spark-1230: Enable SparkContext.addJars() to load jars absent from CLASSPATH #351

Closed
42 changes: 40 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -18,7 +18,7 @@
package org.apache.spark

import java.io._
import java.net.URI
import java.net.{URI, URL}
import java.util.concurrent.atomic.AtomicInteger
import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
@@ -45,7 +45,8 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
import org.apache.spark.util.{Utils, TimeStampedHashMap, MetadataCleaner, MetadataCleanerType,
ClosureCleaner, SparkURLClassLoader}

/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -141,6 +142,18 @@ class SparkContext(
// An asynchronous listener bus for Spark events
private[spark] val listenerBus = new LiveListenerBus

// Create a classLoader for use by the driver so that jars added via addJar are available to the
// driver. Do this before all other initialization so that any thread pools created for this
// SparkContext use the class loader.
// In the future it might make sense to expose this to users so they can assign it as the
// context class loader for other threads.
// Note that this is config-enabled as classloaders can introduce subtle side effects
private[spark] val classLoader = if (conf.getBoolean("spark.driver.loadAddedJars", false)) {
val loader = new SparkURLClassLoader(Array.empty[URL], this.getClass.getClassLoader)
Thread.currentThread.setContextClassLoader(loader)
Some(loader)
} else None

// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.create(
conf,
@@ -800,6 +813,8 @@ class SparkContext(
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
* NOTE: If you enable spark.driver.loadAddedJars, then the JAR will also be made available
* to this SparkContext and child threads. local: JARs must be available on the driver node.
*/
def addJar(path: String) {
if (path == null) {
@@ -841,6 +856,20 @@
case _ =>
path
}

// Add jar to driver class loader so it is available for driver,
// even if it is not on the classpath
uri.getScheme match {
case null | "file" | "local" =>
// Assume file exists on current (driver) node as well. Unlike executors, driver
// doesn't need to download the jar since it's local.
addUrlToDriverLoader(new URL("file:" + uri.getPath))
case "http" | "https" | "ftp" =>
// Should be handled by the URLClassLoader, pass along entire URL
addUrlToDriverLoader(new URL(path))
case other =>
logWarning(s"URI scheme '$other' for JAR $path is not supported by the driver class loader")
}
}
if (key != null) {
addedJars(key) = System.currentTimeMillis
@@ -850,6 +879,15 @@
postEnvironmentUpdate()
}

private def addUrlToDriverLoader(url: URL) {
classLoader.foreach { loader =>
if (!loader.getURLs.contains(url)) {
logInfo("Adding JAR " + url + " to driver class loader")
loader.addURL(url)
}
}
}

/**
* Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
* any new nodes.
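Taken together, the SparkContext changes support driver programs along the lines of this minimal sketch (the jar path and class name are hypothetical, not part of this patch):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object AddJarOnDriverExample {
  def main(args: Array[String]) {
    // Opt in to the driver-side class loader introduced by this patch.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("addJar-driver-loading")
      .set("spark.driver.loadAddedJars", "true")
    val sc = new SparkContext(conf)

    // Hypothetical jar that was NOT on the classpath when the JVM started.
    sc.addJar("/tmp/extra-udfs.jar")

    // With the flag enabled, the jar is appended to the driver's class loader, which is
    // also the thread's context class loader, so reflective loading now succeeds.
    val clazz = Class.forName("com.example.MyUdf", true,
      Thread.currentThread().getContextClassLoader)
    println("Driver loaded " + clazz.getName)

    sc.stop()
  }
}
```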
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -17,6 +17,7 @@

package org.apache.spark

import java.net.URL
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.concurrent.Await
@@ -132,6 +133,7 @@ object SparkEnv extends Logging {
}

val securityManager = new SecurityManager(conf)
val classLoader = Thread.currentThread.getContextClassLoader

val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf,
securityManager = securityManager)
@@ -142,8 +144,6 @@
conf.set("spark.driver.port", boundPort.toString)
}

val classLoader = Thread.currentThread.getContextClassLoader

// Create an instance of the class named by the given Java system property, or by
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -20,7 +20,7 @@ package org.apache.spark.deploy
import java.io.{PrintStream, File}
import java.net.URL

import org.apache.spark.executor.ExecutorURLClassLoader
import org.apache.spark.util.SparkURLClassLoader

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -198,7 +198,7 @@ object SparkSubmit {
System.err.println("\n")
}

val loader = new ExecutorURLClassLoader(new Array[URL](0),
val loader = new SparkURLClassLoader(new Array[URL](0),
Thread.currentThread.getContextClassLoader)
Thread.currentThread.setContextClassLoader(loader)

@@ -215,7 +215,7 @@
mainMethod.invoke(null, childArgs.toArray)
}

private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) {
private def addJarToClasspath(localJar: String, loader: SparkURLClassLoader) {
val localJarFile = new File(localJar)
if (!localJarFile.exists()) {
printWarning(s"Jar $localJar does not exist, skipping.")
@@ -102,7 +102,7 @@ private[spark] object CoarseGrainedExecutorBackend {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
indestructible = true, conf = conf, new SecurityManager(conf))
indestructible = true, conf = conf, securityManager = new SecurityManager(conf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
8 changes: 4 additions & 4 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -29,7 +29,7 @@ import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.{AkkaUtils, Utils}
import org.apache.spark.util.{AkkaUtils, Utils, SparkURLClassLoader}

/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -291,15 +291,15 @@ private[spark] class Executor(
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
private def createClassLoader(): ExecutorURLClassLoader = {
val loader = Thread.currentThread().getContextClassLoader
private def createClassLoader(): SparkURLClassLoader = {
val loader = this.getClass.getClassLoader

// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
new ExecutorURLClassLoader(urls, loader)
new SparkURLClassLoader(urls, loader)
}

/**
@@ -69,7 +69,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
} catch {
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
taskSetManager.abort(s"ClassNotFound [${cnf.getMessage}] with classloader: " + loader)
case ex: Throwable =>
taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
}
8 changes: 5 additions & 3 deletions core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -42,7 +42,9 @@ private[spark] object AkkaUtils extends Logging {
* of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
*/
def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false,
conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = {
conf: SparkConf, securityManager: SecurityManager,
classLoader: ClassLoader = Thread.currentThread.getContextClassLoader)
: (ActorSystem, Int) = {

val akkaThreads = conf.getInt("spark.akka.threads", 4)
val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15)
@@ -102,9 +104,9 @@
""".stripMargin))

val actorSystem = if (indestructible) {
IndestructibleActorSystem(name, akkaConf)
IndestructibleActorSystem(name, akkaConf, classLoader)
} else {
ActorSystem(name, akkaConf)
ActorSystem(name, akkaConf, classLoader)
}

val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider
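The reason for threading a class loader through createActorSystem: Akka's ActorSystem factory accepts an explicit loader, which it uses for reflective instantiation and message deserialization. A standalone sketch of that call (plain Akka, hypothetical system name):

```scala
import akka.actor.ActorSystem
import com.typesafe.config.ConfigFactory

// Pass the calling thread's context class loader (on the driver this may be the
// SparkURLClassLoader holding jars added via addJar) so Akka resolves classes
// against it instead of whatever loader it would pick by default.
val loader = Thread.currentThread.getContextClassLoader
val system = ActorSystem("sketch", ConfigFactory.load(), loader)
```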
@@ -15,14 +15,14 @@
* limitations under the License.
*/

package org.apache.spark.executor
package org.apache.spark.util

import java.net.{URLClassLoader, URL}

/**
* The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
*/
private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
private[spark] class SparkURLClassLoader(urls: Array[URL], parent: ClassLoader)
extends URLClassLoader(urls, parent) {

override def addURL(url: URL) {
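The class is deliberately tiny; a standalone sketch of the same pattern (hypothetical jar path and class name) shows why exposing `addURL` is all that is needed to load classes that arrive after the loader is created:

```scala
import java.net.{URL, URLClassLoader}

// Same pattern as SparkURLClassLoader: widen the protected addURL so jars can be
// appended to a live class loader at runtime.
class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader)
  extends URLClassLoader(urls, parent) {
  override def addURL(url: URL) {
    super.addURL(url)
  }
}

object MutableURLClassLoaderDemo {
  def main(args: Array[String]) {
    val loader = new MutableURLClassLoader(Array.empty[URL], getClass.getClassLoader)
    // Hypothetical jar; its classes are invisible to the loader until addURL is called.
    loader.addURL(new URL("file:/tmp/extra-udfs.jar"))
    println(Class.forName("com.example.MyUdf", true, loader).getName)
  }
}
```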
24 changes: 4 additions & 20 deletions core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -33,11 +33,12 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
override def beforeEach() {
super.beforeEach()
resetSparkContext()
System.setProperty("spark.authenticate", "false")
}

override def beforeAll() {
super.beforeAll()
System.setProperty("spark.authenticate", "false")

val tmpDir = new File(Files.createTempDir(), "test")
tmpDir.mkdir()

@@ -47,27 +48,10 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
pw.close()

val jarFile = new File(tmpDir, "test.jar")
val jarStream = new FileOutputStream(jarFile)
val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest())
System.setProperty("spark.authenticate", "false")

val jarEntry = new JarEntry(textFile.getName)
jar.putNextEntry(jarEntry)

val in = new FileInputStream(textFile)
val buffer = new Array[Byte](10240)
var nRead = 0
while (nRead <= 0) {
nRead = in.read(buffer, 0, buffer.length)
jar.write(buffer, 0, nRead)
}

in.close()
jar.close()
jarStream.close()
val jarUrl = TestUtils.createJar(Seq(textFile), jarFile)

tmpFile = textFile
tmpJarUrl = jarFile.toURI.toURL.toString
tmpJarUrl = jarUrl.toString
}

test("Distributing files locally") {
60 changes: 59 additions & 1 deletion core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -18,8 +18,10 @@
package org.apache.spark

import java.io.{File, FileWriter}
import java.util.concurrent.Semaphore

import scala.io.Source
import scala.util.Try

import com.google.common.io.Files
import org.apache.hadoop.io._
@@ -32,6 +34,62 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkContext._

class FileSuite extends FunSuite with LocalSparkContext {
val loader = Thread.currentThread.getContextClassLoader
override def afterEach() {
super.afterEach()
Thread.currentThread.setContextClassLoader(loader)
}

test("adding jars to classpath at the driver") {
val tmpDir = Files.createTempDir()
val classFile = TestUtils.createCompiledClass("HelloSpark", tmpDir)
val jarFile = new File(tmpDir, "test.jar")
TestUtils.createJar(Seq(classFile), jarFile)

def canLoadClass(clazz: String) =
Try(Class.forName(clazz, true, Thread.currentThread().getContextClassLoader)).isSuccess

val loadedBefore = canLoadClass("HelloSpark")

val conf = new SparkConf().setMaster("local-cluster[1,1,512]").setAppName("test")
.set("spark.driver.loadAddedJars", "true")

var driverLoadedAfter = false
var childLoadedAfter = false

val sem = new Semaphore(1)
sem.acquire()

new Thread() {
override def run() {
val sc = new SparkContext(conf)
sc.addJar(jarFile.getAbsolutePath)
driverLoadedAfter = canLoadClass("HelloSpark")

// Test visibility in a child thread
val childSem = new Semaphore(1)
childSem.acquire()
new Thread() {
override def run() {
childLoadedAfter = canLoadClass("HelloSpark")
childSem.release()
}
}.start()

childSem.acquire()
sem.release()
}
}.start()
sem.acquire()

// Test visibility in a parent thread
val parentLoadedAfter = canLoadClass("HelloSpark")

assert(false === loadedBefore, "Class visible before being added")
assert(true === driverLoadedAfter, "Class was not visible after being added")
assert(true === childLoadedAfter, "Class was not visible to child thread after being added")
assert(false === parentLoadedAfter, "Class was visible to parent thread after being added")
}

test("text files") {
sc = new SparkContext("local", "test")
@@ -106,7 +164,7 @@ class FileSuite extends FunSuite with LocalSparkContext {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x))
val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x))
nums.saveAsSequenceFile(outputDir)
// Try reading the output back as a SequenceFile
val output = sc.sequenceFile[IntWritable, Text](outputDir)
3 changes: 2 additions & 1 deletion docs/cluster-overview.md
@@ -108,7 +108,8 @@ and `addFile`.
- **hdfs:**, **http:**, **https:**, **ftp:** - these pull down files and JARs from the URI as expected
- **local:** - a URI starting with local:/ is expected to exist as a local file on each worker node. This
means that no network IO will be incurred, and works well for large files/JARs that are pushed to each worker,
or shared via NFS, GlusterFS, etc.
or shared via NFS, GlusterFS, etc. Note that if `spark.driver.loadAddedJars` is set,
then the file must be visible to the node running the SparkContext as well.

Note that JARs and files are copied to the working directory for each SparkContext on the executor nodes.
Over time this can use up a significant amount of space and will need to be cleaned up.
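A small hedged example of the `local:` case above (hypothetical path): a large jar pre-staged on every node, e.g. over NFS, so it is never shipped over the network.

```scala
// The jar already exists at this path on every worker node. If
// spark.driver.loadAddedJars is enabled, the same path must also exist on the machine
// running the SparkContext, because the driver opens it directly instead of downloading it.
sc.addJar("local:/opt/shared/libs/heavy-lib.jar")
```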
10 changes: 10 additions & 0 deletions docs/configuration.md
@@ -433,6 +433,16 @@ Apart from these, the following properties are also available, and may be useful
Port for the driver to listen on.
</td>
</tr>
<tr>
<td>spark.driver.loadAddedJars</td>
<td>false</td>
<td>
If true, the SparkContext creates a dedicated class loader so that jars added via `addJar`
are also available to driver code, even if they were not on the classpath when the driver
started. By default, jars added via `addJar` must already be on the driver's classpath.
Jar contents will be visible to the thread that created the SparkContext and all of its
child threads.
</td>
</tr>
<tr>
<td>spark.cleaner.ttl</td>
<td>(infinite)</td>
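One caveat worth showing explicitly (a sketch, assuming a SparkConf-based setup): the property has to be set before the SparkContext is constructed, because the driver class loader is created during SparkContext initialization and is not installed retroactively.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val conf = new SparkConf()
  .setAppName("app")                          // hypothetical app name
  .set("spark.driver.loadAddedJars", "true")  // must be set before the context is created
val sc = new SparkContext(conf)
sc.addJar("/path/to/extra.jar")               // now visible to driver code and child threads
```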