diff --git a/.rat-excludes b/.rat-excludes
index eaefef1b0aa2e..22b38a335fa69 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -57,3 +57,4 @@ dist/*
 .*iws
 logs
 .*scalastyle-output.xml
+ssl.conf.template
diff --git a/conf/ssl.conf.template b/conf/ssl.conf.template
new file mode 100644
index 0000000000000..403c18c00a2a2
--- /dev/null
+++ b/conf/ssl.conf.template
@@ -0,0 +1,10 @@
+# Spark SSL settings
+
+# ssl.enabled true
+# ssl.keyStore /path/to/your/keyStore
+# ssl.keyStorePassword password
+# ssl.keyPassword password
+# ssl.trustStore /path/to/your/trustStore
+# ssl.trustStorePassword password
+# ssl.enabledAlgorithms [TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA]
+# ssl.protocol TLSv1.2
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index edc3889c9ae51..13e79531ed627 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -25,7 +25,8 @@ import org.apache.spark.util.Utils
 
 private[spark] class HttpFileServer(
     securityManager: SecurityManager,
-    requestedPort: Int = 0)
+    requestedPort: Int = 0,
+    conf: SparkConf)
   extends Logging {
 
   var baseDir : File = null
@@ -41,7 +42,7 @@ private[spark] class HttpFileServer(
     fileDir.mkdir()
     jarDir.mkdir()
     logInfo("HTTP File server directory is " + baseDir)
-    httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server")
+    httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server", conf)
     httpServer.start()
     serverUri = httpServer.uri
     logDebug("HTTP file server started at: " + serverUri)
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 912558d0cab7d..459628bd8cb53 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 
 import java.io.File
 
+import org.eclipse.jetty.server.ssl.SslSocketConnector
 import org.eclipse.jetty.util.security.{Constraint, Password}
 import org.eclipse.jetty.security.authentication.DigestAuthenticator
 import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService}
@@ -45,7 +46,8 @@ private[spark] class HttpServer(
     resourceBase: File,
     securityManager: SecurityManager,
     requestedPort: Int = 0,
-    serverName: String = "HTTP server")
+    serverName: String = "HTTP server",
+    conf: SparkConf)
   extends Logging {
 
   private var server: Server = null
@@ -71,7 +73,10 @@ private[spark] class HttpServer(
    */
   private def doStart(startPort: Int): (Server, Int) = {
     val server = new Server()
-    val connector = new SocketConnector
+
+    val connector = securityManager.sslOptions.createJettySslContextFactory()
+      .map(new SslSocketConnector(_)).getOrElse(new SocketConnector)
+
     connector.setMaxIdleTime(60 * 1000)
     connector.setSoLingerTime(-1)
     connector.setPort(startPort)
@@ -148,13 +153,14 @@ private[spark] class HttpServer(
   }
 
   /**
-   * Get the URI of this HTTP server (http://host:port)
+   * Get the URI of this HTTP server (http://host:port or https://host:port)
    */
   def uri: String = {
     if (server == null) {
       throw new ServerStateException("Server is not started")
     } else {
-      "http://" + Utils.localIpAddress + ":" + port
+      val scheme = if (securityManager.sslOptions.enabled) "https" else "http"
+      s"$scheme://${Utils.localIpAddress}:$port"
     }
   }
 }
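Editorial note: the connector selection above is the heart of the HttpServer change — SSL settings yield a Jetty `SslContextFactory`, which is wrapped in an `SslSocketConnector`, with a plain `SocketConnector` as the fallback. A minimal sketch of that Option-based fallback, not part of the patch, assuming Jetty 8 on the classpath and a hypothetical keystore path:

```scala
import java.io.File

import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.ssl.SslSocketConnector
import org.eclipse.jetty.util.ssl.SslContextFactory

object ConnectorChoiceSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical keystore location; any JKS keystore would do.
    val factory: Option[SslContextFactory] =
      Some(new File("/path/to/keystore")).filter(_.exists).map { ks =>
        val f = new SslContextFactory()
        f.setKeyStorePath(ks.getAbsolutePath)
        f.setKeyStorePassword("password")
        f
      }
    // Same fallback as HttpServer.doStart: SSL connector when configured, plain otherwise.
    val connector = factory.map(new SslSocketConnector(_)).getOrElse(new SocketConnector)
    println(connector.getClass.getSimpleName)
  }
}
```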
diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala
new file mode 100644
index 0000000000000..1804359766ab2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.{FileReader, File}
+import java.util.Properties
+
+import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory}
+import org.eclipse.jetty.util.ssl.SslContextFactory
+
+import scala.util.Try
+
+case class SSLOptions(enabled: Boolean = false,
+    keyStore: Option[File] = None,
+    keyStorePassword: Option[String] = None,
+    keyPassword: Option[String] = None,
+    trustStore: Option[File] = None,
+    trustStorePassword: Option[String] = None,
+    protocol: Option[String] = None,
+    enabledAlgorithms: Set[String] = Set.empty) {
+
+  /**
+   * Creates a Jetty SSL context factory according to the SSL settings represented by this object.
+   */
+  def createJettySslContextFactory(): Option[SslContextFactory] = {
+    if (enabled) {
+      val sslContextFactory = new SslContextFactory()
+
+      keyStore.foreach(file => sslContextFactory.setKeyStorePath(file.getAbsolutePath))
+      trustStore.foreach(file => sslContextFactory.setTrustStore(file.getAbsolutePath))
+      keyStorePassword.foreach(sslContextFactory.setKeyStorePassword)
+      trustStorePassword.foreach(sslContextFactory.setTrustStorePassword)
+      keyPassword.foreach(sslContextFactory.setKeyManagerPassword)
+      protocol.foreach(sslContextFactory.setProtocol)
+      sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*)
+
+      Some(sslContextFactory)
+    } else {
+      None
+    }
+  }
+
+  /**
+   * Creates an Akka configuration object which contains all the SSL settings represented by this
+   * object. It can then be used to compose the final Akka configuration.
+   */
+  def createAkkaConfig: Option[Config] = {
+    import scala.collection.JavaConversions._
+    if (enabled) {
+      Some(ConfigFactory.empty()
+        .withValue("akka.remote.netty.tcp.security.key-store",
+          ConfigValueFactory.fromAnyRef(keyStore.map(_.getAbsolutePath).getOrElse("")))
+        .withValue("akka.remote.netty.tcp.security.key-store-password",
+          ConfigValueFactory.fromAnyRef(keyStorePassword.getOrElse("")))
+        .withValue("akka.remote.netty.tcp.security.trust-store",
+          ConfigValueFactory.fromAnyRef(trustStore.map(_.getAbsolutePath).getOrElse("")))
+        .withValue("akka.remote.netty.tcp.security.trust-store-password",
+          ConfigValueFactory.fromAnyRef(trustStorePassword.getOrElse("")))
+        .withValue("akka.remote.netty.tcp.security.key-password",
+          ConfigValueFactory.fromAnyRef(keyPassword.getOrElse("")))
+        .withValue("akka.remote.netty.tcp.security.random-number-generator",
+          ConfigValueFactory.fromAnyRef(""))
+        .withValue("akka.remote.netty.tcp.security.protocol",
+          ConfigValueFactory.fromAnyRef(protocol.getOrElse("")))
+        .withValue("akka.remote.netty.tcp.security.enabled-algorithms",
+          ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq))
+        .withValue("akka.remote.netty.tcp.enable-ssl",
+          ConfigValueFactory.fromAnyRef(true)))
+    } else {
+      None
+    }
+  }
+
+}
+
+object SSLOptions extends Logging {
+
+  /**
+   * Resolves the SSL configuration file location by checking:
+   * - SPARK_SSL_CONFIG_FILE env variable
+   * - SPARK_CONF_DIR/ssl.conf
+   * - SPARK_HOME/conf/ssl.conf
+   */
+  def defaultConfigFile: Option[File] = {
+    val specifiedFile = Option(System.getenv("SPARK_SSL_CONFIG_FILE")).map(new File(_))
+    val sparkConfDir = Option(System.getenv("SPARK_CONF_DIR")).map(new File(_))
+    val sparkHomeConfDir = Option(System.getenv("SPARK_HOME"))
+      .map(new File(_, "conf"))
+    val defaultFile = (sparkConfDir orElse sparkHomeConfDir).map(new File(_, "ssl.conf"))
+
+    specifiedFile orElse defaultFile
+  }
+
+  /**
+   * Loads the given properties file, falling back to an empty Properties object
+   * if the file cannot be read.
+   */
+  def load(configFile: File): Properties = {
+    logInfo(s"Loading SSL configuration from $configFile")
+    try {
+      val props = new Properties()
+      val reader = new FileReader(configFile)
+      try {
+        props.load(reader)
+        props.put("sslConfigurationFileLocation", configFile.getAbsolutePath)
+        props
+      } finally {
+        reader.close()
+      }
+    } catch {
+      case ex: Throwable =>
+        logWarning(s"The SSL configuration file ${configFile.getAbsolutePath} " +
+          s"could not be loaded. The underlying exception was: ${ex.getMessage}")
+        new Properties
+    }
+  }
+
+  /**
+   * Resolves SSLOptions settings from a given Spark configuration object at a given namespace.
+   * If the SSL settings were loaded from the configuration file, the
+   * ``sslConfigurationFileLocation`` property is present in the Spark configuration. The parent
+   * directory of that location is used as a base directory to resolve relative paths to the
+   * keystore and the truststore.
+   */
+  def parse(conf: SparkConf, ns: String): SSLOptions = {
+    val parentDir = conf.getOption("sslConfigurationFileLocation").map(new File(_).getParentFile)
+      .getOrElse(new File(".")).toPath
+
+    val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = false)
+    val keyStore = Try(conf.get(s"$ns.keyStore")).toOption.map(parentDir.resolve(_).toFile)
+    val keyStorePassword = Try(conf.get(s"$ns.keyStorePassword")).toOption
+    val keyPassword = Try(conf.get(s"$ns.keyPassword")).toOption
+    val trustStore = Try(conf.get(s"$ns.trustStore")).toOption.map(parentDir.resolve(_).toFile)
+    val trustStorePassword = Try(conf.get(s"$ns.trustStorePassword")).toOption
+    val protocol = Try(conf.get(s"$ns.protocol")).toOption
+    val enabledAlgorithms = Try(conf.get(s"$ns.enabledAlgorithms")).toOption
+      .map(_.trim.dropWhile(_ == '[')
+      .takeWhile(_ != ']')).map(_.split(",").map(_.trim).toSet)
+      .getOrElse(Set.empty)
+
+    new SSLOptions(enabled, keyStore, keyStorePassword, keyPassword, trustStore, trustStorePassword,
+      protocol, enabledAlgorithms)
+  }
+
+  /**
+   * Loads the SSL configuration file. If the ``spark.ssl.configFile`` property is among the
+   * system properties, it is assumed to contain the location of the SSL configuration file to be
+   * used. Otherwise, it uses the location returned by [[SSLOptions.defaultConfigFile]].
+   */
+  def load(): Properties = {
+    val file = Option(System.getProperty("spark.ssl.configFile"))
+      .map(new File(_)) orElse defaultConfigFile
+
+    file.fold {
+      logWarning("SSL configuration file not found. SSL will be disabled.")
+      new Properties()
+    } { file =>
+      logInfo(s"Loading SSL configuration from ${file.getAbsolutePath}")
+      load(file)
+    }
+  }
+
+}
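Editorial note: to make the resolution rules concrete, the following throwaway sketch (not part of the patch; the file location and values are invented) shows how `parse` resolves relative store paths against the configuration file's directory and unpacks the bracketed algorithm list:

```scala
import org.apache.spark.{SSLOptions, SparkConf}

object SSLOptionsParseSketch extends App {
  val conf = new SparkConf(loadDefaults = false)
  // Pretend these entries came from /etc/spark/ssl.conf (hypothetical location):
  conf.set("sslConfigurationFileLocation", "/etc/spark/ssl.conf")
  conf.set("ssl.enabled", "true")
  conf.set("ssl.keyStore", "keystore") // relative, resolved against /etc/spark
  conf.set("ssl.enabledAlgorithms",
    "[TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA]")

  val opts = SSLOptions.parse(conf, "ssl")
  // The relative keystore path is resolved against the config file's parent dir.
  assert(opts.keyStore.map(_.getAbsolutePath) == Some("/etc/spark/keystore"))
  // The "[a, b]" syntax becomes a plain Set of cipher suite names.
  assert(opts.enabledAlgorithms ==
    Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"))
}
```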
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 48c4e515885ea..8c4fe1f7b2d55 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -18,7 +18,11 @@
 package org.apache.spark
 
 import java.net.{Authenticator, PasswordAuthentication}
+import java.security.KeyStore
+import java.security.cert.X509Certificate
+import javax.net.ssl._
 
+import com.google.common.io.Files
 import org.apache.hadoop.io.Text
 
 import org.apache.spark.deploy.SparkHadoopUtil
@@ -192,6 +196,43 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
     )
   }
 
+  private[spark] val sslOptions = SSLOptions.parse(sparkConf, "ssl")
+
+  private[spark] val (sslSocketFactory, hostnameVerifier) = if (sslOptions.enabled) {
+    val trustStoreManagers =
+      for (trustStore <- sslOptions.trustStore) yield {
+        val ks = KeyStore.getInstance(KeyStore.getDefaultType)
+        ks.load(Files.asByteSource(sslOptions.trustStore.get).openStream(),
+          sslOptions.trustStorePassword.get.toCharArray)
+
+        val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
+        tmf.init(ks)
+        tmf.getTrustManagers
+      }
+
+    lazy val credulousTrustStoreManagers = Array({
+      logWarning("Using 'accept-all' trust manager for SSL connections.")
+      new X509TrustManager {
+        override def getAcceptedIssuers: Array[X509Certificate] = null
+
+        override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {}
+
+        override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {}
+      }: TrustManager
+    })
+
+    val sslContext = SSLContext.getInstance(sslOptions.protocol.getOrElse("Default"))
+    sslContext.init(null, trustStoreManagers getOrElse credulousTrustStoreManagers, null)
+
+    val hostVerifier = new HostnameVerifier {
+      override def verify(s: String, sslSession: SSLSession): Boolean = true
+    }
+
+    (Some(sslContext.getSocketFactory), Some(hostVerifier))
+  } else {
+    (None, None)
+  }
+
   /**
    * Split a comma separated String, filter out any empty items, and return a Set of strings
    */
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 605df0e929faa..15099636031d0 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -48,6 +48,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
   private[spark] val settings = new HashMap[String, String]()
 
   if (loadDefaults) {
+    // Load SSL settings from SSL configuration file
+    for ((k, v) <- SSLOptions.load().asScala) {
+      settings(k) = v
+    }
+
     // Load any spark.* system properties
     for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) {
       settings(k) = v
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 72716567ca99b..f86759dc93263 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -239,7 +239,7 @@ object SparkEnv extends Logging {
     val httpFileServer =
       if (isDriver) {
         val fileServerPort = conf.getInt("spark.fileserver.port", 0)
-        val server = new HttpFileServer(securityManager, fileServerPort)
+        val server = new HttpFileServer(securityManager, fileServerPort, conf)
         server.initialize()
         conf.set("spark.fileserver.uri", server.serverUri)
         server
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 4cd4f4f96fd16..9ffa85953892b 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -153,7 +153,8 @@ private[broadcast] object HttpBroadcast extends Logging {
   private def createServer(conf: SparkConf) {
     broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
     val broadcastPort = conf.getInt("spark.broadcast.port", 0)
-    server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
+    server = new HttpServer(broadcastDir, securityManager,
+      broadcastPort, "HTTP broadcast server", conf)
     server.start()
     serverUri = server.uri
     logInfo("Broadcast server started at " + serverUri)
@@ -196,6 +197,7 @@ private[broadcast] object HttpBroadcast extends Logging {
       logDebug("broadcast not using security")
       uc = new URL(url).openConnection()
     }
+    Utils.setupSecureURLConnection(uc, securityManager)
 
     val in = {
       uc.setReadTimeout(httpReadTimeout)
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 065ddda50e65e..3b1eb880525c1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -39,7 +39,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
   val timeout = AkkaUtils.askTimeout(conf)
 
   override def preStart() = {
-    masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master))
+    masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master, conf))
 
     context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
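Editorial note: the SparkConf change above means any conf created with defaults picks up the ssl.* entries automatically. A hedged sketch of the lookup order, not part of the patch (the file path is invented):

```scala
import org.apache.spark.SparkConf

object ConfLoadingSketch extends App {
  // Point the loader at an explicit file (hypothetical path); without this
  // property, SSLOptions.defaultConfigFile falls back to SPARK_SSL_CONFIG_FILE,
  // then SPARK_CONF_DIR/ssl.conf, then SPARK_HOME/conf/ssl.conf.
  System.setProperty("spark.ssl.configFile", "/etc/spark/ssl.conf")
  val conf = new SparkConf() // loadDefaults = true pulls in the ssl.* entries
  println(conf.getBoolean("ssl.enabled", defaultValue = false))
}
```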
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 32790053a6be8..5e438369f280d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -77,7 +77,7 @@ private[spark] class AppClient(
     def tryRegisterAllMasters() {
       for (masterUrl <- masterUrls) {
         logInfo("Connecting to master " + masterUrl + "...")
-        val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
+        val actor = context.actorSelection(Master.toAkkaUrl(masterUrl, conf))
         actor ! RegisterApplication(appDescription)
       }
     }
@@ -104,17 +104,17 @@ private[spark] class AppClient(
 
     def changeMaster(url: String) {
       activeMasterUrl = url
-      master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+      master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl, conf))
       masterAddress = activeMasterUrl match {
         case Master.sparkUrlRegex(host, port) =>
-          Address("akka.tcp", Master.systemName, host, port.toInt)
+          Address(AkkaUtils.protocol(conf), Master.systemName, host, port.toInt)
         case x =>
           throw new SparkException("Invalid spark URL: " + x)
       }
     }
 
     private def isPossibleMaster(remoteUrl: Address) = {
-      masterUrls.map(s => Master.toAkkaUrl(s))
+      masterUrls.map(s => Master.toAkkaUrl(s, conf))
         .map(u => AddressFromURIString(u).hostPort)
         .contains(remoteUrl.hostPort)
     }
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 8d99ed442604f..b4d181a7d42e0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -798,10 +798,10 @@ private[spark] object Master extends Logging {
   }
 
   /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
-  def toAkkaUrl(sparkUrl: String): String = {
+  def toAkkaUrl(sparkUrl: String, conf: SparkConf): String = {
     sparkUrl match {
       case sparkUrlRegex(host, port) =>
-        "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+        AkkaUtils.address(systemName, host, port, actorName, conf)
       case _ =>
         throw new SparkException("Invalid master URL: " + sparkUrl)
     }
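Editorial note: threading `conf` into `toAkkaUrl` matters only for the scheme of the resulting actor URL. A small sketch of the two outcomes using the `AkkaUtils.protocol`/`address` helpers added later in this patch (placed in the org.apache.spark package because `AkkaUtils` is `private[spark]`; host and port are made up):

```scala
package org.apache.spark

import org.apache.spark.util.AkkaUtils

object AkkaAddressSketch extends App {
  val plain = new SparkConf(loadDefaults = false)
  val ssl = new SparkConf(loadDefaults = false).set("ssl.enabled", "true")

  // akka.tcp://sparkMaster@myhost:7077/user/Master
  println(AkkaUtils.address("sparkMaster", "myhost", 7077, "Master", plain))
  // akka.ssl.tcp://sparkMaster@myhost:7077/user/Master
  println(AkkaUtils.address("sparkMaster", "myhost", 7077, "Master", ssl))
}
```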
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index acc83d52ce98f..f801978d7a2a8 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -77,7 +77,7 @@ private[spark] class Worker(
   var masterAddress: Address = null
   var activeMasterUrl: String = ""
   var activeMasterWebUiUrl : String = ""
-  val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName)
+  val akkaUrl = AkkaUtils.address(actorSystemName, host, port, actorName, conf)
   @volatile var registered = false
   @volatile var connected = false
   val workerId = generateWorkerId()
@@ -148,10 +148,10 @@ private[spark] class Worker(
     masterLock.synchronized {
       activeMasterUrl = url
       activeMasterWebUiUrl = uiUrl
-      master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+      master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl, conf))
       masterAddress = activeMasterUrl match {
         case Master.sparkUrlRegex(_host, _port) =>
-          Address("akka.tcp", Master.systemName, _host, _port.toInt)
+          Address(AkkaUtils.protocol(conf), Master.systemName, _host, _port.toInt)
         case x =>
           throw new SparkException("Invalid spark URL: " + x)
       }
@@ -162,7 +162,7 @@ private[spark] class Worker(
   def tryRegisterAllMasters() {
     for (masterUrl <- masterUrls) {
       logInfo("Connecting to master " + masterUrl + "...")
-      val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
+      val actor = context.actorSelection(Master.toAkkaUrl(masterUrl, conf))
       actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 04046e2e5d11d..0e1a247179a46 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -27,7 +27,7 @@ import akka.actor._
 import akka.pattern.ask
 import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
 
-import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState}
+import org.apache.spark._
 import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
 import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index c1b0da4b99cf2..b3c6e0fc9a54f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -38,7 +38,7 @@ private[spark] class SimrSchedulerBackend(
   override def start() {
     super.start()
 
-    val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+    val driverUrl = "%s://%s@%s:%s/user/%s".format(
       SparkEnv.driverActorSystemName,
       sc.conf.get("spark.driver.host"),
       sc.conf.get("spark.driver.port"),
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index c1d5ce0a36075..5eb01c6b7be4e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
 import org.apache.spark.deploy.{ApplicationDescription, Command}
 import org.apache.spark.deploy.client.{AppClient, AppClientListener}
 import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
 
 private[spark] class SparkDeploySchedulerBackend(
   scheduler: TaskSchedulerImpl,
@@ -42,11 +42,12 @@ private[spark] class SparkDeploySchedulerBackend(
     super.start()
 
     // The endpoint for executors to talk to us
-    val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+    val driverUrl = AkkaUtils.address(
       SparkEnv.driverActorSystemName,
       conf.get("spark.driver.host"),
       conf.get("spark.driver.port"),
-      CoarseGrainedSchedulerBackend.ACTOR_NAME)
+      CoarseGrainedSchedulerBackend.ACTOR_NAME,
+      conf)
     val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}")
     val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
       .map(Utils.splitCommandString).getOrElse(Seq.empty)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 037fea5854ca3..f321da1339839 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -21,6 +21,8 @@ import java.io.File
 import java.util.{List => JList}
 import java.util.Collections
 
+import org.apache.spark.util.AkkaUtils
+
 import scala.collection.JavaConversions._
 import scala.collection.mutable.{HashMap, HashSet}
 
@@ -138,11 +140,12 @@ private[spark] class CoarseMesosSchedulerBackend(
     }
     val command = CommandInfo.newBuilder()
       .setEnvironment(environment)
-    val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+    val driverUrl = AkkaUtils.address(
       SparkEnv.driverActorSystemName,
       conf.get("spark.driver.host"),
       conf.get("spark.driver.port"),
-      CoarseGrainedSchedulerBackend.ACTOR_NAME)
+      CoarseGrainedSchedulerBackend.ACTOR_NAME,
+      conf)
     val uri = conf.get("spark.executor.uri", null)
     if (uri == null) {
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index e2d32c859bbda..71bf3b16d288d 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -27,7 +27,7 @@ import akka.pattern.ask
 import com.typesafe.config.ConfigFactory
 import org.apache.log4j.{Level, Logger}
 
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException, SSLOptions}
 
 /**
  * Various utility classes for working with Akka.
@@ -91,8 +91,10 @@ private[spark] object AkkaUtils extends Logging {
     val secureCookie = if (isAuthOn) secretKey else ""
     logDebug("In createActorSystem, requireCookie is: " + requireCookie)
 
-    val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback(
-      ConfigFactory.parseString(
+    val akkaSslConfig = securityManager.sslOptions.createAkkaConfig.getOrElse(ConfigFactory.empty())
+
+    val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String])
+      .withFallback(akkaSslConfig).withFallback(ConfigFactory.parseString(
       s"""
       |akka.daemonic = on
       |akka.loggers = [""akka.event.slf4j.Slf4jLogger""]
@@ -196,9 +198,22 @@ private[spark] object AkkaUtils extends Logging {
     val driverHost: String = conf.get("spark.driver.host", "localhost")
     val driverPort: Int = conf.getInt("spark.driver.port", 7077)
     Utils.checkHost(driverHost, "Expected hostname")
-    val url = s"akka.tcp://$driverActorSystemName@$driverHost:$driverPort/user/$name"
+    val url = address(driverActorSystemName, driverHost, driverPort, name, conf)
     val timeout = AkkaUtils.lookupTimeout(conf)
     logInfo(s"Connecting to $name: $url")
     Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
   }
+
+  def protocol(conf: SparkConf): String = {
+    if (conf.getBoolean("ssl.enabled", defaultValue = false)) {
+      "akka.ssl.tcp"
+    } else {
+      "akka.tcp"
+    }
+  }
+
+  def address(systemName: String, host: String, port: Any, actorName: String,
+      conf: SparkConf): String = {
+    s"${protocol(conf)}://$systemName@$host:$port/user/$actorName"
+  }
 }
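Editorial note: the SSL block that `createAkkaConfig` contributes to this fallback chain can be inspected directly. A sketch, not part of the patch, with an invented store path and password:

```scala
import java.io.File

import org.apache.spark.SSLOptions

object AkkaSslConfigSketch extends App {
  val opts = SSLOptions(
    enabled = true,
    keyStore = Some(new File("/path/to/keystore")), // hypothetical path
    keyStorePassword = Some("password"),
    protocol = Some("TLSv1.2"))

  // createAkkaConfig is Some(...) whenever enabled = true.
  val cfg = opts.createAkkaConfig.get
  println(cfg.getBoolean("akka.remote.netty.tcp.enable-ssl"))         // true
  println(cfg.getString("akka.remote.netty.tcp.security.key-store"))  // /path/to/keystore
  println(cfg.getString("akka.remote.netty.tcp.security.protocol"))   // TLSv1.2
}
```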
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 3f0a80b95649c..5cb73fcbabd1e 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -22,6 +22,7 @@ import java.net._
 import java.nio.ByteBuffer
 import java.util.{Properties, Locale, Random, UUID}
 import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
+import javax.net.ssl.HttpsURLConnection
 
 import org.apache.log4j.PropertyConfigurator
@@ -376,6 +377,7 @@ private[spark] object Utils extends Logging {
         logDebug("fetchFile not using security")
         uc = new URL(url).openConnection()
       }
+      Utils.setupSecureURLConnection(uc, securityMgr)
 
       val timeout = conf.getInt("spark.files.fetchTimeout", 60) * 1000
       uc.setConnectTimeout(timeout)
@@ -1516,6 +1518,20 @@ private[spark] object Utils extends Logging {
     PropertyConfigurator.configure(pro)
   }
 
+  /**
+   * If the given URL connection is an HttpsURLConnection, sets its SSL socket factory and
+   * hostname verifier to the ones configured in the given security manager.
+   */
+  def setupSecureURLConnection(urlConnection: URLConnection, sm: SecurityManager): URLConnection = {
+    urlConnection match {
+      case https: HttpsURLConnection =>
+        sm.sslSocketFactory.foreach(https.setSSLSocketFactory)
+        sm.hostnameVerifier.foreach(https.setHostnameVerifier)
+        https
+      case connection => connection
+    }
+  }
+
 }
 
 /**
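Editorial note: call sites treat the new helper as a pass-through that only takes effect for https connections. A sketch of the fetch-side pattern, not part of the patch (the URL is invented, and the sketch sits in the org.apache.spark package because `Utils` and `SecurityManager` are `private[spark]`):

```scala
package org.apache.spark

import java.net.URL

import org.apache.spark.util.Utils

object SecureFetchSketch extends App {
  val sm = new SecurityManager(new SparkConf(loadDefaults = false))
  val uc = new URL("https://example.com/some.jar").openConnection() // hypothetical URL
  // For an HttpsURLConnection this installs the SecurityManager's SSL socket
  // factory and hostname verifier; for plain http it returns the connection as-is.
  Utils.setupSecureURLConnection(uc, sm)
  uc.setConnectTimeout(60 * 1000)
}
```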
diff --git a/core/src/test/resources/bad-ssl.conf b/core/src/test/resources/bad-ssl.conf
new file mode 100644
index 0000000000000..968f974632a8b
--- /dev/null
+++ b/core/src/test/resources/bad-ssl.conf
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Spark SSL settings
+
+ssl.enabled true
+ssl.keyStore untrusted-keystore
+ssl.keyStorePassword password
+ssl.keyPassword password
+ssl.trustStore truststore
+ssl.trustStorePassword password
+ssl.enabledAlgorithms [TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA]
+ssl.protocol TLSv1.2
diff --git a/core/src/test/resources/good-ssl.conf b/core/src/test/resources/good-ssl.conf
new file mode 100644
index 0000000000000..2e5057620d80b
--- /dev/null
+++ b/core/src/test/resources/good-ssl.conf
@@ -0,0 +1,27 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Spark SSL settings
+
+ssl.enabled true
+ssl.keyStore keystore
+ssl.keyStorePassword password
+ssl.keyPassword password
+ssl.trustStore truststore
+ssl.trustStorePassword password
+ssl.enabledAlgorithms [TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA]
+ssl.protocol TLSv1.2
diff --git a/core/src/test/resources/keystore b/core/src/test/resources/keystore
new file mode 100644
index 0000000000000..f8310e39ba1e0
Binary files /dev/null and b/core/src/test/resources/keystore differ
diff --git a/core/src/test/resources/truststore b/core/src/test/resources/truststore
new file mode 100644
index 0000000000000..a6b1d46e1f391
Binary files /dev/null and b/core/src/test/resources/truststore differ
diff --git a/core/src/test/resources/untrusted-keystore b/core/src/test/resources/untrusted-keystore
new file mode 100644
index 0000000000000..6015b02caa128
Binary files /dev/null and b/core/src/test/resources/untrusted-keystore differ
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index 7e18f45de7b5b..6753cb934f00e 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -18,9 +18,13 @@
 package org.apache.spark
 
 import java.io._
+import java.net.URI
 import java.util.jar.{JarEntry, JarOutputStream}
+import javax.net.ssl.SSLHandshakeException
 
 import com.google.common.io.Files
+import org.apache.commons.io.{FileUtils, IOUtils}
+import org.apache.commons.lang3.RandomUtils
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkContext._
@@ -175,4 +179,100 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
     }
   }
 
+  test ("HttpFileServer should work with SSL") {
+    val sparkConf = sparkSSLConfig()
+    val sm = new SecurityManager(sparkConf)
+    val server = new HttpFileServer(sm, 0, sparkConf)
+    try {
+      server.initialize()
+
+      fileTransferTest(server, sm)
+    } finally {
+      server.stop()
+    }
+  }
+
+  test ("HttpFileServer should work with SSL and good credentials") {
+    val sparkConf = sparkSSLConfig()
+    sparkConf.set("spark.authenticate", "true")
+    sparkConf.set("spark.authenticate.secret", "good")
+
+    val sm = new SecurityManager(sparkConf)
+    val server = new HttpFileServer(sm, 0, sparkConf)
+    try {
+      server.initialize()
+
+      fileTransferTest(server, sm)
+    } finally {
+      server.stop()
+    }
+  }
+
+  test ("HttpFileServer should not work with valid SSL and bad credentials") {
+    val sparkConf = sparkSSLConfig()
+    sparkConf.set("spark.authenticate", "true")
+    sparkConf.set("spark.authenticate.secret", "bad")
+
+    val sm = new SecurityManager(sparkConf)
+    val server = new HttpFileServer(sm, 0, sparkConf)
+    try {
+      server.initialize()
+
+      intercept[IOException] {
+        fileTransferTest(server)
+      }
+    } finally {
+      server.stop()
+    }
+  }
+
+  test ("HttpFileServer should not work with SSL when the server is untrusted") {
+    val sparkConf = sparkSSLConfigUntrusted()
+    val sm = new SecurityManager(sparkConf)
+    val server = new HttpFileServer(sm, 0, sparkConf)
+    try {
+      server.initialize()
+
+      intercept[SSLHandshakeException] {
+        fileTransferTest(server)
+      }
+    } finally {
+      server.stop()
+    }
+  }
+
+  def sparkSSLConfig() = {
+    System.setProperty("spark.ssl.configFile", getClass.getResource("/good-ssl.conf").getPath)
+    val conf = new SparkConf
+    conf
+  }
+
+  def sparkSSLConfigUntrusted() = {
+    System.setProperty("spark.ssl.configFile", getClass.getResource("/bad-ssl.conf").getPath)
+    val conf = new SparkConf
+    conf
+  }
+
+  def fileTransferTest(server: HttpFileServer, sm: SecurityManager = null): Unit = {
+    val randomContent = RandomUtils.nextBytes(100)
+    val file = File.createTempFile("FileServerSuite", "sslTests", tmpDir)
+    FileUtils.writeByteArrayToFile(file, randomContent)
+    server.addFile(file)
+
+    val uri = new URI(server.serverUri + "/files/" + file.getName)
+
+    val connection = if (sm != null && sm.isAuthenticationEnabled()) {
+      Utils.constructURIForAuthentication(uri, sm).toURL.openConnection()
+    } else {
+      uri.toURL.openConnection()
+    }
+
+    if (sm != null) {
+      Utils.setupSecureURLConnection(connection, sm)
+    }
+
+    val buf = IOUtils.toByteArray(connection.getInputStream)
+    assert(buf === randomContent)
+  }
+
 }
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
index 53e367a61715b..6f78a7b7bcbba 100644
--- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -38,6 +38,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { sel
   }
 
   def resetSparkContext() = {
+    System.clearProperty("spark.ssl.configFile")
     LocalSparkContext.stop(sc)
     sc = null
   }
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 9702838085627..efcc60ff235c2 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -133,7 +133,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
       securityManager = new SecurityManager(conf))
     val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
-      s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+      AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
     val timeout = AkkaUtils.lookupTimeout(conf)
     slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
 
diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
new file mode 100644
index 0000000000000..c180b60b92c5a
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.{FileInputStream, FileOutputStream, PrintWriter, File}
+import java.util.jar.{JarEntry, JarOutputStream}
+
+import com.google.common.io.Files
+import org.apache.commons.io.{FileUtils, IOUtils}
+import org.apache.commons.lang3.RandomUtils
+import org.apache.spark.util.Utils
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import scala.util.Random
+
+class SSLOptionsSuite extends FunSuite with BeforeAndAfterAll {
+
+  @transient var tmpDir: File = _
+
+  override def beforeAll() {
+    super.beforeAll()
+
+    tmpDir = Files.createTempDir()
+    tmpDir.deleteOnExit()
+  }
+
+  override def afterAll() {
+    super.afterAll()
+    Utils.deleteRecursively(tmpDir)
+  }
+
+  test("test loading existing property file") {
+    val file = File.createTempFile("SSLOptionsSuite", "conf", tmpDir)
+    FileUtils.write(file,
+      """
+        |ssl.some.property someValue
+      """.stripMargin)
+
+    val props = SSLOptions.load(file)
+    assert(props.get("ssl.some.property") === "someValue")
+    assert(props.get("sslConfigurationFileLocation") === file.getAbsolutePath)
+  }
+
+  test("test loading non-existent property file") {
+    val file = File.createTempFile("SSLOptionsSuite", "conf", tmpDir)
+    FileUtils.write(file,
+      """
+        |ssl.some.property someValue
+      """.stripMargin)
+
+    val props = SSLOptions.load(new File(file.getParentFile, Random.nextString(10)))
+    assert(props.get("ssl.some.property") === null)
+    assert(props.get("sslConfigurationFileLocation") === null)
+  }
+
+  test("test loading existing property file by specifying it in system properties") {
+    val file = File.createTempFile("SSLOptionsSuite", "conf", tmpDir)
+    FileUtils.write(file,
+      """
+        |ssl.some.property someValue
+      """.stripMargin)
+
+    System.setProperty("spark.ssl.configFile", file.getAbsolutePath)
+    val props = SSLOptions.load(file)
+    assert(props.get("ssl.some.property") === "someValue")
+    assert(props.get("sslConfigurationFileLocation") === file.getAbsolutePath)
+  }
+
+  test("test resolving property file as spark conf") {
+    val file = File.createTempFile("SSLOptionsSuite", "conf", tmpDir)
+    FileUtils.write(file,
+      """
+        |ssl.enabled true
+        |ssl.keyStore keystore
+        |ssl.keyStorePassword password
+        |ssl.keyPassword password
+        |ssl.trustStore truststore
+        |ssl.trustStorePassword password
+        |ssl.enabledAlgorithms [TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA]
+        |ssl.protocol SSLv3
+      """.stripMargin)
+
+    System.setProperty("spark.ssl.configFile", file.getAbsolutePath)
+    val opts = SSLOptions.parse(new SparkConf(), "ssl")
+
+    assert(opts.enabled === true)
+    assert(opts.trustStore.isDefined === true)
+    assert(opts.trustStore.get.getName === "truststore")
+    assert(opts.trustStore.get.getParent === file.getParent)
+    assert(opts.keyStore.isDefined === true)
+    assert(opts.keyStore.get.getName === "keystore")
+    assert(opts.keyStore.get.getParent === file.getParent)
+    assert(opts.trustStorePassword === Some("password"))
+    assert(opts.keyStorePassword === Some("password"))
+    assert(opts.keyPassword === Some("password"))
+    assert(opts.protocol === Some("SSLv3"))
+    assert(opts.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"))
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
index fcca0867b8072..a7f5020078451 100644
--- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala
@@ -17,8 +17,9 @@
 
 package org.apache.spark
 
-import scala.collection.mutable.ArrayBuffer
+import java.io.File
 
+import org.apache.commons.io.FileUtils
 import org.scalatest.FunSuite
 
 class SecurityManagerSuite extends FunSuite {
@@ -125,6 +126,54 @@ class SecurityManagerSuite extends FunSuite {
 
   }
 
+  test("ssl on setup") {
+    val file = File.createTempFile("SSLOptionsSuite", "conf")
+    file.deleteOnExit()
+    FileUtils.write(file,
+      s"""
+        |ssl.enabled true
+        |ssl.keyStore ${getClass.getResource("/keystore").getPath}
+        |ssl.keyStorePassword password
+        |ssl.keyPassword password
+        |ssl.trustStore ${getClass.getResource("/truststore").getPath}
+        |ssl.trustStorePassword password
+        |ssl.enabledAlgorithms [TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA]
+        |ssl.protocol SSLv3
+      """.stripMargin)
+
+    System.setProperty("spark.ssl.configFile", file.getAbsolutePath)
+    val conf = new SparkConf()
+
+    val securityManager = new SecurityManager(conf)
+
+    assert(securityManager.sslOptions.enabled === true)
+    assert(securityManager.sslSocketFactory.isDefined === true)
+    assert(securityManager.hostnameVerifier.isDefined === true)
+
+    assert(securityManager.sslOptions.trustStore.isDefined === true)
+    assert(securityManager.sslOptions.trustStore.get.getName === "truststore")
+    assert(securityManager.sslOptions.keyStore.isDefined === true)
+    assert(securityManager.sslOptions.keyStore.get.getName === "keystore")
+    assert(securityManager.sslOptions.trustStorePassword === Some("password"))
+    assert(securityManager.sslOptions.keyStorePassword === Some("password"))
+    assert(securityManager.sslOptions.keyPassword === Some("password"))
+    assert(securityManager.sslOptions.protocol === Some("SSLv3"))
+    assert(securityManager.sslOptions.enabledAlgorithms === Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"))
+  }
+
+  test("ssl off setup") {
+    val file = File.createTempFile("SSLOptionsSuite", "conf")
+    file.deleteOnExit()
+
+    System.setProperty("spark.ssl.configFile", file.getAbsolutePath)
+    val conf = new SparkConf()
+
+    val securityManager = new SecurityManager(conf)
+
+    assert(securityManager.sslOptions.enabled === false)
+    assert(securityManager.sslSocketFactory.isDefined === false)
+    assert(securityManager.hostnameVerifier.isDefined === false)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
index 4161aede1d1d0..27d6f04496ac7 100644
--- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.Matchers
 class ClientSuite extends FunSuite with Matchers {
   test("correctly validates driver jar URL's") {
     ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true)
+    ClientArguments.isValidJarUrl("https://someHost:8080/foo.jar") should be (true)
     ClientArguments.isValidJarUrl("file://some/path/to/a/jarFile.jar") should be (true)
     ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo.jar") should be (true)
 
diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
index c4765e53de17b..9a4ca137b39e9 100644
--- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.util
 
+import java.util.concurrent.TimeoutException
+
 import akka.actor._
 import org.apache.spark._
 import org.apache.spark.scheduler.MapStatus
@@ -44,7 +46,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val badconf = new SparkConf
     badconf.set("spark.authenticate", "true")
@@ -57,7 +59,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
       conf = conf, securityManager = securityManagerBad)
     val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
-      s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+      AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
     val timeout = AkkaUtils.lookupTimeout(conf)
     intercept[akka.actor.ActorNotFound] {
       slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
@@ -82,7 +84,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val badconf = new SparkConf
     badconf.set("spark.authenticate", "false")
@@ -93,7 +95,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
       conf = badconf, securityManager = securityManagerBad)
     val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
-      s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+      AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
     val timeout = AkkaUtils.lookupTimeout(conf)
     slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
 
@@ -112,7 +114,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     // this should succeed since security off
     assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
-           Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+      Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
 
     actorSystem.shutdown()
     slaveSystem.shutdown()
@@ -133,7 +135,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val goodconf = new SparkConf
     goodconf.set("spark.authenticate", "true")
@@ -146,7 +148,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
       conf = goodconf, securityManager = securityManagerGood)
     val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
-      s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+      AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
     val timeout = AkkaUtils.lookupTimeout(conf)
     slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
 
@@ -163,7 +165,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     // this should succeed since security on and passwords match
     assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
-           Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+      Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
 
     actorSystem.shutdown()
     slaveSystem.shutdown()
@@ -185,7 +187,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
 
     val masterTracker = new MapOutputTrackerMaster(conf)
     masterTracker.trackerActor = actorSystem.actorOf(
-        Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
 
     val badconf = new SparkConf
     badconf.set("spark.authenticate", "false")
@@ -198,7 +200,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
       conf = badconf, securityManager = securityManagerBad)
     val slaveTracker = new MapOutputTrackerWorker(conf)
     val selection = slaveSystem.actorSelection(
-      s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker")
+      AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
     val timeout = AkkaUtils.lookupTimeout(conf)
     intercept[akka.actor.ActorNotFound] {
       slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
@@ -208,4 +210,185 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext {
     slaveSystem.shutdown()
   }
 
+  test("remote fetch ssl on") {
+    val conf = sparkSSLConfig()
+    val securityManager = new SecurityManager(conf)
+
+    val hostname = "localhost"
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+      conf = conf, securityManager = securityManager)
+    System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+    assert(securityManager.isAuthenticationEnabled() === false)
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
+    masterTracker.trackerActor = actorSystem.actorOf(
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+    val slaveConf = sparkSSLConfig()
+    val securityManagerBad = new SecurityManager(slaveConf)
+
+    val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+      conf = slaveConf, securityManager = securityManagerBad)
+    val slaveTracker = new MapOutputTrackerWorker(conf)
+    val selection = slaveSystem.actorSelection(
+      AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
+    val timeout = AkkaUtils.lookupTimeout(conf)
+    slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+    assert(securityManagerBad.isAuthenticationEnabled() === false)
+
+    masterTracker.registerShuffle(10, 1)
+    masterTracker.incrementEpoch()
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+    val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+    val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+    masterTracker.registerMapOutput(10, 0, new MapStatus(
+      BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+    masterTracker.incrementEpoch()
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+    // this should succeed since security off
+    assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+      Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+    actorSystem.shutdown()
+    slaveSystem.shutdown()
+  }
+
+
+  test("remote fetch ssl on and security enabled") {
+    val conf = sparkSSLConfig()
+    conf.set("spark.authenticate", "true")
+    conf.set("spark.authenticate.secret", "good")
+    val securityManager = new SecurityManager(conf)
+
+    val hostname = "localhost"
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+      conf = conf, securityManager = securityManager)
+    System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+    assert(securityManager.isAuthenticationEnabled() === true)
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
+    masterTracker.trackerActor = actorSystem.actorOf(
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+    val slaveConf = sparkSSLConfig()
+    slaveConf.set("spark.authenticate", "true")
+    slaveConf.set("spark.authenticate.secret", "good")
+    val securityManagerBad = new SecurityManager(slaveConf)
+
+    val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+      conf = slaveConf, securityManager = securityManagerBad)
+    val slaveTracker = new MapOutputTrackerWorker(conf)
+    val selection = slaveSystem.actorSelection(
+      AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
+    val timeout = AkkaUtils.lookupTimeout(conf)
+    slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+
+    assert(securityManagerBad.isAuthenticationEnabled() === true)
+
+    masterTracker.registerShuffle(10, 1)
+    masterTracker.incrementEpoch()
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+    val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+    val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+    masterTracker.registerMapOutput(10, 0, new MapStatus(
+      BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000)))
+    masterTracker.incrementEpoch()
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+    // this should succeed since security on and passwords match
+    assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+      Seq((BlockManagerId("a", "hostA", 1000, 0), size1000)))
+
+    actorSystem.shutdown()
+    slaveSystem.shutdown()
+  }
+
+
+  test("remote fetch ssl on and security enabled - bad credentials") {
+    val conf = sparkSSLConfig()
+    conf.set("spark.authenticate", "true")
+    conf.set("spark.authenticate.secret", "good")
+    val securityManager = new SecurityManager(conf)
+
+    val hostname = "localhost"
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+      conf = conf, securityManager = securityManager)
+    System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+    assert(securityManager.isAuthenticationEnabled() === true)
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
+    masterTracker.trackerActor = actorSystem.actorOf(
+      Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker")
+
+    val slaveConf = sparkSSLConfig()
+    slaveConf.set("spark.authenticate", "true")
+    slaveConf.set("spark.authenticate.secret", "bad")
+    val securityManagerBad = new SecurityManager(slaveConf)
+
+    val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0,
+      conf = slaveConf, securityManager = securityManagerBad)
+    val slaveTracker = new MapOutputTrackerWorker(conf)
+    val selection = slaveSystem.actorSelection(
+      AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf))
+    val timeout = AkkaUtils.lookupTimeout(conf)
+    intercept[akka.actor.ActorNotFound] {
+      slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout)
+    }
+
+    actorSystem.shutdown()
+    slaveSystem.shutdown()
+  }
+
+
+  test("remote fetch ssl on - untrusted server") {
+    val conf = sparkSSLConfigUntrusted()
+    val securityManager = new SecurityManager(conf)
+
+    val hostname = "localhost"
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0,
+      conf = conf, securityManager = securityManager)
+    System.setProperty("spark.hostPort", hostname + ":" + boundPort)
+
+    assert(securityManager.isAuthenticationEnabled() === false)
+
+    val masterTracker = new MapOutputTrackerMaster(conf)
MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + + val slaveConf = sparkSSLConfig() + val securityManagerBad = new SecurityManager(slaveConf) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = slaveConf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTrackerWorker(conf) + val selection = slaveSystem.actorSelection( + AkkaUtils.address("spark", "localhost", boundPort, "MapOutputTracker", conf)) + val timeout = AkkaUtils.lookupTimeout(conf) + intercept[TimeoutException] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + def sparkSSLConfig() = { + System.setProperty("spark.ssl.configFile", getClass.getResource("/good-ssl.conf").getPath) + val conf = new SparkConf + conf + } + + def sparkSSLConfigUntrusted() = { + System.setProperty("spark.ssl.configFile", getClass.getResource("/bad-ssl.conf").getPath) + val conf = new SparkConf + conf + } + } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index b433082dce1a2..23c0161536919 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -162,8 +162,8 @@ object ActorWordCount { */ val lines = ssc.actorStream[String]( - Props(new SampleActorReceiver[String]("akka.tcp://test@%s:%s/user/FeederActor".format( - host, port.toInt))), "SampleReceiver") + Props(new SampleActorReceiver[String]( + AkkaUtils.address("test", host, port, "FeederActor", sparkConf))), "SampleReceiver") // compute wordcount lines.flatMap(_.split("\\s+")).map(x => (x, 1)).reduceByKey(_ + _).print() diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 687e85ca94d3c..8dbe5c6627e16 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -45,7 +45,7 @@ class ExecutorClassLoader(classUri: String, parent: ClassLoader, // Hadoop FileSystem object for our URI, if it isn't using HTTP var fileSystem: FileSystem = { - if (uri.getScheme() == "http") { + if (Set("http", "https", "ftp").contains(uri.getScheme)) { null } else { FileSystem.get(uri, new Configuration()) @@ -78,13 +78,16 @@ class ExecutorClassLoader(classUri: String, parent: ClassLoader, if (fileSystem != null) { fileSystem.open(new Path(directory, pathInDirectory)) } else { - if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val url = if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) - newuri.toURL().openStream() + newuri.toURL } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + new URL(classUri + "/" + urlEncode(pathInDirectory)) } + + Utils.setupSecureURLConnection(url.openConnection(), SparkEnv.get.securityManager) + .getInputStream } } val bytes = readAndTransformClass(name, inputStream) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 84b57cd2dc1af..7fba4e5538d56 100644 --- 
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 84b57cd2dc1af..7fba4e5538d56 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -103,7 +103,7 @@ import org.apache.spark.util.Utils
     val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
     /** Jetty server that will serve our classes to worker nodes */
     val classServerPort = conf.getInt("spark.replClassServer.port", 0)
-    val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
+    val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server", conf)
     private var currentSettings: Settings = initialSettings
     var printResults = true // whether to print result lines
     var totalSilence = false // whether to print anything
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index 53a3e6200e340..51601ed6880a5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -55,8 +55,8 @@ private[streaming] class ReceiverSupervisorImpl(
   private val trackerActor = {
     val ip = env.conf.get("spark.driver.host", "localhost")
     val port = env.conf.getInt("spark.driver.port", 7077)
-    val url = "akka.tcp://%s@%s:%s/user/ReceiverTracker".format(
-      SparkEnv.driverActorSystemName, ip, port)
+    val url = AkkaUtils.address(
+      SparkEnv.driverActorSystemName, ip, port, "ReceiverTracker", env.conf)
     env.actorSystem.actorSelection(url)
   }
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index 155dd88aa2b81..35ed341c13fdc 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -210,11 +210,12 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
     sparkConf.set("spark.driver.host", driverHost)
     sparkConf.set("spark.driver.port", driverPort.toString)
 
-    val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+    val driverUrl = AkkaUtils.address(
       SparkEnv.driverActorSystemName,
       driverHost,
       driverPort.toString,
-      CoarseGrainedSchedulerBackend.ACTOR_NAME)
+      CoarseGrainedSchedulerBackend.ACTOR_NAME,
+      sparkConf)
 
     actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
   }
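The streaming and YARN call sites in this patch, here and in the hunks that follow, all make the same mechanical substitution, so one worked example covers them. Assuming the AkkaUtils.address sketch earlier and hypothetical values for the driver endpoint:

    val conf = new SparkConf()
      .set("spark.driver.host", "driver.example.com")   // hypothetical host
      .set("spark.driver.port", "7077")
    val driverUrl = AkkaUtils.address(
      SparkEnv.driverActorSystemName,
      conf.get("spark.driver.host"),
      conf.get("spark.driver.port"),
      CoarseGrainedSchedulerBackend.ACTOR_NAME,
      conf)
    // Without SSL: akka.tcp://<driverActorSystemName>@driver.example.com:7077/user/<ACTOR_NAME>
    // With SSL:    akka.ssl.tcp://<driverActorSystemName>@driver.example.com:7077/user/<ACTOR_NAME>

Centralizing the scheme choice in one helper avoids the failure mode where an SSL-enabled cluster has one component dialing akka.tcp:// while another listens on akka.ssl.tcp://.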
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 568a6ef932bbd..6e5483e73d177 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -29,7 +29,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
 import org.apache.spark.{Logging, SparkConf, SparkEnv}
 import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl}
 import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.yarn.api.AMRMProtocol
@@ -245,11 +245,12 @@ private[yarn] class YarnAllocationHandler(
       // Deallocate + allocate can result in reusing id's wrongly - so use a different counter
       // (executorIdCounter)
       val executorId = executorIdCounter.incrementAndGet().toString
-      val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+      val driverUrl = AkkaUtils.address(
         SparkEnv.driverActorSystemName,
         sparkConf.get("spark.driver.host"),
         sparkConf.get("spark.driver.port"),
-        CoarseGrainedSchedulerBackend.ACTOR_NAME)
+        CoarseGrainedSchedulerBackend.ACTOR_NAME,
+        sparkConf)
 
       logInfo("launching container on " + containerId + " host " + executorHostname)
       // Just to be safe, simply remove it from pendingReleaseContainers.
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index e093fe4ae6ff8..5d769ccc86490 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -174,11 +174,12 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
     sparkConf.set("spark.driver.host", driverHost)
     sparkConf.set("spark.driver.port", driverPort.toString)
 
-    val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+    val driverUrl = AkkaUtils.address(
       SparkEnv.driverActorSystemName,
       driverHost,
       driverPort.toString,
-      CoarseGrainedSchedulerBackend.ACTOR_NAME)
+      CoarseGrainedSchedulerBackend.ACTOR_NAME,
+      sparkConf)
 
     actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
   }
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
index 0a461749c819d..ffa77d1783074 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala
@@ -29,7 +29,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
 import org.apache.spark.{Logging, SparkConf, SparkEnv}
 import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl}
 import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.yarn.api.ApplicationMasterProtocol
@@ -262,11 +262,12 @@ private[yarn] class YarnAllocationHandler(
         numExecutorsRunning.decrementAndGet()
       } else {
         val executorId = executorIdCounter.incrementAndGet().toString
-        val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+        val driverUrl = AkkaUtils.address(
           SparkEnv.driverActorSystemName,
           sparkConf.get("spark.driver.host"),
           sparkConf.get("spark.driver.port"),
-          CoarseGrainedSchedulerBackend.ACTOR_NAME)
+          CoarseGrainedSchedulerBackend.ACTOR_NAME,
+          sparkConf)
 
         logInfo("Launching container %s for on host %s".format(containerId, executorHostname))