From f35176352b77d8a295601479879a4666501eabe8 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Sun, 19 Jan 2014 13:35:32 -0600 Subject: [PATCH 01/14] Add Security to Spark - Akka, Http, ConnectionManager, UI to use servlets --- core/pom.xml | 16 ++ .../org/apache/spark/SparkSaslClient.java | 182 +++++++++++++ .../org/apache/spark/SparkSaslServer.java | 189 +++++++++++++ .../org/apache/spark/HttpFileServer.scala | 5 +- .../scala/org/apache/spark/HttpServer.scala | 58 +++- .../org/apache/spark/SecurityManager.scala | 112 ++++++++ .../scala/org/apache/spark/SparkContext.scala | 2 + .../scala/org/apache/spark/SparkEnv.scala | 19 +- .../apache/spark/broadcast/Broadcast.scala | 4 +- .../spark/broadcast/BroadcastFactory.scala | 3 +- .../spark/broadcast/HttpBroadcast.scala | 56 +++- .../spark/broadcast/TorrentBroadcast.scala | 4 +- .../org/apache/spark/deploy/Client.scala | 4 +- .../apache/spark/deploy/SparkHadoopUtil.scala | 5 + .../spark/deploy/client/TestClient.scala | 4 +- .../apache/spark/deploy/master/Master.scala | 5 +- .../spark/deploy/master/ui/MasterWebUI.scala | 17 +- .../spark/deploy/worker/DriverWrapper.scala | 6 +- .../apache/spark/deploy/worker/Worker.scala | 4 +- .../spark/deploy/worker/ui/WorkerWebUI.scala | 17 +- .../CoarseGrainedExecutorBackend.scala | 4 +- .../org/apache/spark/executor/Executor.scala | 11 +- .../spark/metrics/sink/MetricsServlet.scala | 8 +- .../apache/spark/network/BufferMessage.scala | 8 +- .../org/apache/spark/network/Connection.scala | 50 +++- .../apache/spark/network/ConnectionId.scala | 30 +++ .../spark/network/ConnectionManager.scala | 249 +++++++++++++++++- .../org/apache/spark/network/Message.scala | 1 + .../spark/network/MessageChunkHeader.scala | 10 +- .../apache/spark/network/ReceiverTest.scala | 4 +- .../spark/network/SecurityMessage.scala | 110 ++++++++ .../org/apache/spark/network/SenderTest.scala | 4 +- .../apache/spark/storage/BlockManager.scala | 12 +- .../apache/spark/storage/ThreadingTest.scala | 5 +- .../org/apache/spark/ui/JettyUtils.scala | 119 ++++++--- .../scala/org/apache/spark/ui/SparkUI.scala | 11 +- .../apache/spark/ui/env/EnvironmentUI.scala | 6 +- .../apache/spark/ui/exec/ExecutorsUI.scala | 6 +- .../apache/spark/ui/jobs/JobProgressUI.scala | 10 +- .../spark/ui/storage/BlockManagerUI.scala | 8 +- .../org/apache/spark/util/AkkaUtils.scala | 18 +- .../scala/org/apache/spark/util/Utils.scala | 44 +++- .../org/apache/spark/AkkaUtilsSuite.scala | 230 ++++++++++++++++ .../org/apache/spark/BroadcastSuite.scala | 3 + .../scala/org/apache/spark/DriverSuite.scala | 2 + .../org/apache/spark/FileServerSuite.scala | 24 ++ .../apache/spark/MapOutputTrackerSuite.scala | 6 +- .../spark/storage/BlockManagerSuite.scala | 66 ++--- docs/configuration.md | 37 +++ .../streaming/examples/ActorWordCount.scala | 5 +- pom.xml | 20 ++ project/SparkBuild.scala | 4 + .../spark/repl/ExecutorClassLoader.scala | 32 +++ .../org/apache/spark/repl/SparkILoop.scala | 22 +- .../org/apache/spark/repl/SparkIMain.scala | 4 +- .../org/apache/spark/repl/ReplSuite.scala | 2 + .../spark/deploy/yarn/ApplicationMaster.scala | 37 ++- .../spark/deploy/yarn/WorkerLauncher.scala | 6 +- .../spark/deploy/yarn/ClientArguments.scala | 2 +- .../deploy/yarn/YarnSparkHadoopUtil.scala | 9 + .../spark/deploy/yarn/ApplicationMaster.scala | 23 +- .../spark/deploy/yarn/WorkerLauncher.scala | 6 +- 62 files changed, 1746 insertions(+), 234 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/SparkSaslClient.java create mode 100644 
core/src/main/java/org/apache/spark/SparkSaslServer.java create mode 100644 core/src/main/scala/org/apache/spark/SecurityManager.scala create mode 100644 core/src/main/scala/org/apache/spark/network/ConnectionId.scala create mode 100644 core/src/main/scala/org/apache/spark/network/SecurityMessage.scala create mode 100644 core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala diff --git a/core/pom.xml b/core/pom.xml index 9e5a450d57a47..8656da65df429 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -52,6 +52,18 @@ org.apache.zookeeper zookeeper + + org.eclipse.jetty + jetty-plus + + + org.eclipse.jetty + jetty-security + + + org.eclipse.jetty + jetty-util + org.eclipse.jetty jetty-server @@ -90,6 +102,10 @@ chill-java 0.3.1 + + commons-net + commons-net + ${akka.group} akka-remote_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/SparkSaslClient.java b/core/src/main/java/org/apache/spark/SparkSaslClient.java new file mode 100644 index 0000000000000..5fab593270992 --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkSaslClient.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslClient; + +/** + * Implements SASL Client logic for Spark + * Some of the code borrowed from Giraph and Hadoop + */ +public class SparkSaslClient { + /** Class logger */ + private static Logger LOG = LoggerFactory.getLogger(SparkSaslClient.class); + + /** + * Used to respond to server's counterpart, SaslServer with SASL tokens + * represented as byte arrays. + */ + private SaslClient saslClient; + + /** + * Create a SaslClient for authentication with BSP servers. + */ + public SparkSaslClient(SecurityManager securityMgr) { + try { + saslClient = Sasl.createSaslClient(new String[] { SparkSaslServer.DIGEST }, + null, null, SparkSaslServer.SASL_DEFAULT_REALM, + SparkSaslServer.SASL_PROPS, new SparkSaslClientCallbackHandler(securityMgr)); + } catch (IOException e) { + LOG.error("SaslClient: Could not create SaslClient"); + saslClient = null; + } + } + + /** + * Used to initiate SASL handshake with server. 
+ * @return response to challenge if needed + * @throws IOException + */ + public byte[] firstToken() throws SaslException { + byte[] saslToken = new byte[0]; + if (saslClient.hasInitialResponse()) { + LOG.debug("has initial response"); + saslToken = saslClient.evaluateChallenge(saslToken); + } + return saslToken; + } + + /** + * Determines whether the authentication exchange has completed. + */ + public boolean isComplete() { + return saslClient.isComplete(); + } + + /** + * Respond to server's SASL token. + * @param saslTokenMessage contains server's SASL token + * @return client's response SASL token + */ + public byte[] saslResponse(byte[] saslTokenMessage) throws SaslException { + try { + byte[] retval = saslClient.evaluateChallenge(saslTokenMessage); + return retval; + } catch (SaslException e) { + LOG.error("saslResponse: Failed to respond to SASL server's token:", e); + throw e; + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + public void dispose() throws SaslException { + if (saslClient != null) { + try { + saslClient.dispose(); + saslClient = null; + } catch (SaslException ignored) { + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private static class SparkSaslClientCallbackHandler implements CallbackHandler { + private final String userName; + private final char[] userPassword; + + /** + * Constructor + */ + public SparkSaslClientCallbackHandler(SecurityManager securityMgr) { + this.userName = SparkSaslServer. + encodeIdentifier(securityMgr.getSaslUser().getBytes()); + String secretKey = securityMgr.getSecretKey() ; + String passwd = (secretKey != null) ? secretKey : ""; + this.userPassword = SparkSaslServer.encodePassword(passwd.getBytes()); + } + + /** + * Implementation used to respond to SASL tokens from server. + * + * @param callbacks objects that indicate what credential information the + * server's SaslServer requires from the client. 
+ * @throws UnsupportedCallbackException + */ + public void handle(Callback[] callbacks) + throws UnsupportedCallbackException { + NameCallback nc = null; + PasswordCallback pc = null; + RealmCallback rc = null; + for (Callback callback : callbacks) { + if (callback instanceof RealmChoiceCallback) { + continue; + } else if (callback instanceof NameCallback) { + nc = (NameCallback) callback; + } else if (callback instanceof PasswordCallback) { + pc = (PasswordCallback) callback; + } else if (callback instanceof RealmCallback) { + rc = (RealmCallback) callback; + } else { + throw new UnsupportedCallbackException(callback, + "handle: Unrecognized SASL client callback"); + } + } + if (nc != null) { + if (LOG.isDebugEnabled()) { + LOG.debug("handle: SASL client callback: setting username: " + + userName); + } + nc.setName(userName); + } + if (pc != null) { + if (LOG.isDebugEnabled()) { + LOG.debug("handle: SASL client callback: setting userPassword"); + } + pc.setPassword(userPassword); + } + if (rc != null) { + if (LOG.isDebugEnabled()) { + LOG.debug("handle: SASL client callback: setting realm: " + + rc.getDefaultText()); + } + rc.setText(rc.getDefaultText()); + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/SparkSaslServer.java b/core/src/main/java/org/apache/spark/SparkSaslServer.java new file mode 100644 index 0000000000000..5c2fcf9afca15 --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkSaslServer.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark; + +import org.apache.commons.net.util.Base64; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.TreeMap; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; + +/** + * Encapsulates SASL server logic for Server + */ +public class SparkSaslServer { + /** Logger */ + private static Logger LOG = LoggerFactory.getLogger("SparkSaslServer.class"); + + /** + * Actual SASL work done by this object from javax.security.sasl. + * Initialized below in constructor. 
+ */ + private SaslServer saslServer; + + public static final String SASL_DEFAULT_REALM = "default"; + public static final String DIGEST = "DIGEST-MD5"; + public static final Map SASL_PROPS = + new TreeMap(); + + /** + * Constructor + */ + public SparkSaslServer(SecurityManager securityMgr) { + try { + SASL_PROPS.put(Sasl.QOP, "auth"); + SASL_PROPS.put(Sasl.SERVER_AUTH, "true"); + saslServer = Sasl.createSaslServer(DIGEST, null, SASL_DEFAULT_REALM, SASL_PROPS, + new SaslDigestCallbackHandler(securityMgr)); + } catch (SaslException e) { + LOG.error("SparkSaslServer: Could not create SaslServer: " + e); + saslServer = null; + } + } + + /** + * Determines whether the authentication exchange has completed. + */ + public boolean isComplete() { + return saslServer.isComplete(); + } + + /** + * Used to respond to server SASL tokens. + * + * @param token Server's SASL token + * @return response to send back to the server. + */ + public byte[] response(byte[] token) throws SaslException { + try { + byte[] retval = saslServer.evaluateResponse(token); + if (LOG.isDebugEnabled()) { + LOG.debug("response: Response token length: " + retval.length); + } + return retval; + } catch (SaslException e) { + LOG.error("Response: Failed to evaluate client token of length: " + + token.length + " : " + e); + throw e; + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + public void dispose() throws SaslException { + if (saslServer != null) { + try { + saslServer.dispose(); + saslServer = null; + } catch (SaslException ignored) { + } + } + } + + /** + * Encode a byte[] identifier as a Base64-encoded string. + * + * @param identifier identifier to encode + * @return Base64-encoded string + */ + static String encodeIdentifier(byte[] identifier) { + return new String(Base64.encodeBase64(identifier)); + } + + /** + * Encode a password as a base64-encoded char[] array. + * @param password as a byte array. + * @return password as a char array. 
+ */ + static char[] encodePassword(byte[] password) { + return new String(Base64.encodeBase64(password)).toCharArray(); + } + + /** CallbackHandler for SASL DIGEST-MD5 mechanism */ + public static class SaslDigestCallbackHandler implements CallbackHandler { + + private SecurityManager securityManager; + + /** + * Constructor + */ + public SaslDigestCallbackHandler(SecurityManager securityMgr) { + this.securityManager = securityMgr; + } + + @Override + public void handle(Callback[] callbacks) throws IOException, + UnsupportedCallbackException { + NameCallback nc = null; + PasswordCallback pc = null; + AuthorizeCallback ac = null; + LOG.debug("in the sasl server callback handler"); + for (Callback callback : callbacks) { + if (callback instanceof AuthorizeCallback) { + ac = (AuthorizeCallback) callback; + } else if (callback instanceof NameCallback) { + nc = (NameCallback) callback; + } else if (callback instanceof PasswordCallback) { + pc = (PasswordCallback) callback; + } else if (callback instanceof RealmCallback) { + continue; // realm is ignored + } else { + throw new UnsupportedCallbackException(callback, + "handle: Unrecognized SASL DIGEST-MD5 Callback"); + } + } + if (pc != null) { + char[] password = + encodePassword(securityManager.getSecretKey().getBytes()); + pc.setPassword(password); + } + + if (ac != null) { + String authid = ac.getAuthenticationID(); + String authzid = ac.getAuthorizationID(); + if (authid.equals(authzid)) { + LOG.debug("set auth to true"); + ac.setAuthorized(true); + } else { + LOG.debug("set auth to false"); + ac.setAuthorized(false); + } + if (ac.isAuthorized()) { + LOG.debug("sasl server is authorized"); + ac.setAuthorizedID(authzid); + } + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index a885898ad48d4..4e791813f94f1 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -21,7 +21,7 @@ import java.io.{File} import com.google.common.io.Files import org.apache.spark.util.Utils -private[spark] class HttpFileServer extends Logging { +private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging { var baseDir : File = null var fileDir : File = null @@ -36,9 +36,10 @@ private[spark] class HttpFileServer extends Logging { fileDir.mkdir() jarDir.mkdir() logInfo("HTTP File server directory is " + baseDir) - httpServer = new HttpServer(baseDir) + httpServer = new HttpServer(baseDir, securityManager) httpServer.start() serverUri = httpServer.uri + logDebug("HTTP file server started at: " + serverUri) } def stop() { diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 69a738dc4446a..2bc91ef5318eb 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -20,14 +20,17 @@ package org.apache.spark import java.io.File import java.net.InetAddress +import org.eclipse.jetty.util.security.{Constraint, Password} +import org.eclipse.jetty.security.authentication.DigestAuthenticator +import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService, SecurityHandler} + import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.handler.DefaultHandler -import org.eclipse.jetty.server.handler.HandlerList -import 
org.eclipse.jetty.server.handler.ResourceHandler +import org.eclipse.jetty.server.handler.{DefaultHandler, HandlerList, ResourceHandler} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils + /** * Exception type thrown by HttpServer when it is in the wrong state for an operation. */ @@ -38,7 +41,7 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * as well as classes created by the interpreter when the user types in code. This is just a wrapper * around a Jetty server. */ -private[spark] class HttpServer(resourceBase: File) extends Logging { +private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager) extends Logging { private var server: Server = null private var port: Int = -1 @@ -59,9 +62,50 @@ private[spark] class HttpServer(resourceBase: File) extends Logging { server.setThreadPool(threadPool) val resHandler = new ResourceHandler resHandler.setResourceBase(resourceBase.getAbsolutePath) - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) - server.setHandler(handlerList) + + if (securityManager.isAuthenticationEnabled()) { + logDebug("server is using security") + val constraint = new Constraint() + constraint.setName(Constraint.__DIGEST_AUTH) + constraint.setRoles(Array("user")) + constraint.setAuthenticate(true) + constraint.setDataConstraint(Constraint.DC_NONE) + + val cm = new ConstraintMapping() + cm.setConstraint(constraint) + cm.setPathSpec("/*") + + val sh = new ConstraintSecurityHandler() + + // the hashLoginService lets us do a simply user and + // secret right now. This could be changed to use the + // JAASLoginService for other options. + val hashLogin = new HashLoginService() + + val userCred = new Password(securityManager.getSecretKey()) + if (userCred == null) { + throw new Exception("secret key is null with authentication on") + } + hashLogin.putUser(securityManager.getHttpUser(), userCred, Array("user")) + + logDebug("hashlogin loading user: " + hashLogin.getUsers()) + + sh.setLoginService(hashLogin) + sh.setAuthenticator(new DigestAuthenticator()); + sh.setConstraintMappings(Array(cm)) + + // make sure we go through security handler to get resources + val handlerList = new HandlerList + handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + sh.setHandler(handlerList) + server.setHandler(sh) + } else { + logDebug("server is not using security") + val handlerList = new HandlerList + handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + server.setHandler(handlerList) + } + server.start() port = server.getConnectors()(0).getLocalPort() } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala new file mode 100644 index 0000000000000..7aaceb1bd3767 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.hadoop.io.Text +import org.apache.hadoop.security.Credentials +import org.apache.hadoop.security.UserGroupInformation + +import org.apache.spark.deploy.SparkHadoopUtil + +/** + * Spark class responsible for security. + */ +private[spark] class SecurityManager extends Logging { + + private val isAuthOn = System.getProperty("spark.authenticate", "false").toBoolean + private val isUIAuthOn = System.getProperty("spark.authenticate.ui", "false").toBoolean + private val viewAcls = System.getProperty("spark.ui.view.acls", "").split(',').map(_.trim()).toSet + private val secretKey = generateSecretKey() + logDebug("is auth enabled = " + isAuthOn + " is uiAuth enabled = " + isUIAuthOn) + + /** + * In Yarn mode its uses Hadoop UGI to pass the secret as that + * will keep it protected. For a standalone SPARK cluster + * use a environment variable SPARK_SECRET to specify the secret. + * This probably isn't ideal but only the user who starts the process + * should have access to view the variable (atleast on Linux). + * Since we can't set the environment variable we set the + * java system property SPARK_SECRET so it will automatically + * generate a secret is not specified. This definitely is not + * ideal since users can see it. We should switch to put it in + * a config. + */ + private def generateSecretKey(): String = { + + if (!isAuthenticationEnabled) return null + // first check to see if secret already set, else generate it + if (SparkHadoopUtil.get.isYarnMode) { + val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() + if (credentials != null) { + val secretKey = credentials.getSecretKey(new Text("akkaCookie")) + if (secretKey != null) { + logDebug("in yarn mode, getting secret from credentials") + return new Text(secretKey).toString + } else { + logDebug("getSecretKey: yarn mode, secret key from credentials is null") + } + } else { + logDebug("getSecretKey: yarn mode, credentials are null") + } + } + val secret = System.getProperty("SPARK_SECRET", System.getenv("SPARK_SECRET")) + if (secret != null && !secret.isEmpty()) return secret + // generate one + val sCookie = akka.util.Crypt.generateSecureCookie + + // if we generate we must be the first so lets set it so its used by everyone else + if (SparkHadoopUtil.get.isYarnMode) { + val creds = new Credentials() + creds.addSecretKey(new Text("akkaCookie"), sCookie.getBytes()) + SparkHadoopUtil.get.addCurrentUserCredentials(creds) + logDebug("adding secret to credentials yarn mode") + } else { + System.setProperty("SPARK_SECRET", sCookie) + logDebug("adding secret to java property") + } + return sCookie + } + + def isUIAuthenticationEnabled(): Boolean = return isUIAuthOn + + // allow anyone in the acl list and the application owner + def checkUIViewPermissions(user: String): Boolean = { + if (isUIAuthenticationEnabled() && (user != null)) { + if ((!viewAcls.contains(user)) && (user != System.getProperty("user.name"))) { + return false + } + } + return true + } + + def isAuthenticationEnabled(): Boolean = return isAuthOn + + // user for HTTP connections + def 
getHttpUser(): String = "sparkHttpUser" + + // user to use with SASL connections + def getSaslUser(): String = "sparkSaslUser" + + /** + * Gets the secret key if security is enabled, else returns null. + */ + def getSecretKey(): String = { + return secretKey + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ddd7d60d96bd5..aa657e2f89617 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -135,6 +135,8 @@ class SparkContext( val isLocal = (master == "local" || master.startsWith("local[")) + if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + // Create the Spark execution environment (cache, map output tracker, etc) private[spark] val env = SparkEnv.create( conf, diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ed788560e79f1..64c1dc9ca07f0 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -54,7 +54,8 @@ class SparkEnv private[spark] ( val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, - val conf: SparkConf) extends Logging { + val conf: SparkConf, + val securityManager: SecurityManager) extends Logging { // A mapping of thread ID to amount of memory used for shuffle in bytes // All accesses should be manually synchronized @@ -123,8 +124,9 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean): SparkEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, - conf = conf) + val securityManager = new SecurityManager() + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf, + securityManager = securityManager) // Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port), // figure out which port number Akka actually bound to and set spark.driver.port to it. 
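Illustrative sketch (not from the patch): how a driver-side process would enable and wire the SecurityManager introduced above. The property names (spark.authenticate, spark.authenticate.ui, spark.ui.view.acls, SPARK_SECRET) come from SecurityManager, and the call shapes mirror the SparkEnv.create() hunk; the object name, host, and property values are made up for the example.

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.util.AkkaUtils

object SecuritySetupSketch {
  def main(args: Array[String]) {
    // Turn on connection-level and UI authentication before any component starts.
    System.setProperty("spark.authenticate", "true")
    System.setProperty("spark.authenticate.ui", "true")
    System.setProperty("spark.ui.view.acls", "alice,bob")    // hypothetical users
    // Optional: without this, generateSecretKey() falls back to a generated secure
    // cookie (or, on YARN, to a secret stored in the user's Hadoop credentials).
    System.setProperty("SPARK_SECRET", "not-a-real-secret")

    val conf = new SparkConf()
    val securityManager = new SecurityManager()

    // Akka-based services now receive the SecurityManager explicitly, as in the
    // SparkEnv, Master, Worker and Client hunks of this patch.
    val (actorSystem, _) = AkkaUtils.createActorSystem("spark", "localhost", 0,
      conf = conf, securityManager = securityManager)

    assert(securityManager.isAuthenticationEnabled())
    assert(securityManager.checkUIViewPermissions("alice"))
    actorSystem.shutdown()
  }
}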
@@ -140,7 +142,6 @@ object SparkEnv extends Logging { val name = conf.get(propertyName, defaultClassName) Class.forName(name, true, classLoader).newInstance().asInstanceOf[T] } - val serializerManager = new SerializerManager val serializer = serializerManager.setDefault( @@ -168,11 +169,12 @@ object SparkEnv extends Logging { val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf)), conf) - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, conf) + val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, + conf, securityManager) val connectionManager = blockManager.connectionManager - val broadcastManager = new BroadcastManager(isDriver, conf) + val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val cacheManager = new CacheManager(blockManager) @@ -190,7 +192,7 @@ object SparkEnv extends Logging { val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") - val httpFileServer = new HttpFileServer() + val httpFileServer = new HttpFileServer(securityManager) httpFileServer.initialize() conf.set("spark.fileserver.uri", httpFileServer.serverUri) @@ -231,6 +233,7 @@ object SparkEnv extends Logging { httpFileServer, sparkFilesDir, metricsSystem, - conf) + conf, + securityManager) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index d113d4040594d..da937e7377f3d 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -60,7 +60,7 @@ abstract class Broadcast[T](val id: Long) extends Serializable { } private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging with Serializable { +class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) extends Logging with Serializable { private var initialized = false private var broadcastFactory: BroadcastFactory = null @@ -78,7 +78,7 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf) extends Logging Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver, conf) + broadcastFactory.initialize(isDriver, conf, securityManager) initialized = true } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 940e5ab805100..6beecaeced5be 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.broadcast +import org.apache.spark.SecurityManager import org.apache.spark.SparkConf @@ -26,7 +27,7 @@ import org.apache.spark.SparkConf * entire Spark job. 
*/ trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf): Unit + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 39ee0dbb92841..395fd2cf2bebb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -18,13 +18,13 @@ package org.apache.spark.broadcast import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.URL +import java.net.{Authenticator, PasswordAuthentication, URL, URLConnection, URI} import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream -import org.apache.spark.{SparkConf, HttpServer, Logging, SparkEnv} +import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} @@ -67,7 +67,9 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. */ class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf) { HttpBroadcast.initialize(isDriver, conf) } + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new HttpBroadcast[T](value_, isLocal, id) @@ -83,6 +85,7 @@ private object HttpBroadcast extends Logging { private var bufferSize: Int = 65536 private var serverUri: String = null private var server: HttpServer = null + private var securityManager: SecurityManager = null // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] @@ -92,11 +95,12 @@ private object HttpBroadcast extends Logging { private var compressionCodec: CompressionCodec = null - def initialize(isDriver: Boolean, conf: SparkConf) { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { synchronized { if (!initialized) { bufferSize = conf.getInt("spark.buffer.size", 65536) compress = conf.getBoolean("spark.broadcast.compress", true) + securityManager = securityMgr if (isDriver) { createServer(conf) conf.set("spark.httpBroadcast.uri", serverUri) @@ -126,7 +130,7 @@ private object HttpBroadcast extends Logging { private def createServer(conf: SparkConf) { broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) - server = new HttpServer(broadcastDir) + server = new HttpServer(broadcastDir, securityManager) server.start() serverUri = server.uri logInfo("Broadcast server started at " + serverUri) @@ -149,11 +153,47 @@ private object HttpBroadcast extends Logging { } def read[T](id: Long): T = { + logDebug("broadcast read server: " + serverUri + " id: broadcast-"+id) val url = serverUri + "/" + BroadcastBlockId(id).name + + var uc: URLConnection = null + if (securityManager.isAuthenticationEnabled()) { + val uri = new URI(url) + val userCred = 
securityManager.getSecretKey() + if (userCred == null) { + // if auth is on force the user to specify a password + throw new Exception("secret key is null with authentication on") + } + val userInfo = securityManager.getHttpUser() + ":" + userCred + val newuri = new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(), + uri.getQuery(), uri.getFragment()) + + uc = newuri.toURL().openConnection() + uc.setAllowUserInteraction(false) + logDebug("broadcast security enabled") + + // set our own authenticator to properly negotiate user/password + Authenticator.setDefault( + new Authenticator() { + override def getPasswordAuthentication(): PasswordAuthentication = { + var passAuth: PasswordAuthentication = null + val userInfo = getRequestingURL().getUserInfo() + if (userInfo != null) { + val parts = userInfo.split(":", 2) + passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) + } + return passAuth + } + } + ); + } else { + logDebug("broadcast not using security") + uc = new URL(url).openConnection() + } + val in = { - val httpConnection = new URL(url).openConnection() - httpConnection.setReadTimeout(httpReadTimeout) - val inputStream = httpConnection.getInputStream + uc.setReadTimeout(httpReadTimeout) + val inputStream = uc.getInputStream(); if (compress) { compressionCodec.compressedInputStream(inputStream) } else { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index d351dfc1f56a2..c23a878712f52 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -241,7 +241,9 @@ private[spark] case class TorrentInfo( */ class TorrentBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = new TorrentBroadcast[T](value_, isLocal, id) diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 9987e2300ceb7..76f73271b2b97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -25,7 +25,7 @@ import akka.actor._ import akka.pattern.ask import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.util.{AkkaUtils, Utils} @@ -141,7 +141,7 @@ object Client { // TODO: See if we can initialize akka so return messages are sent back using the same TCP // flow. Else, this (sadly) requires the DriverClient be routable from the Master. 
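Illustrative sketch (not from the patch): the client-side pattern HttpBroadcast.read() uses above, pulled into a small helper so the flow is easier to follow. The object and method names are invented; the credential handling (user-info embedded in the URI plus a java.net.Authenticator answering the DIGEST challenge) is taken directly from the hunk above and applies to any consumer of the secured HttpServer/HttpFileServer.

import java.io.InputStream
import java.net.{Authenticator, PasswordAuthentication, URI, URLConnection}

import org.apache.spark.SecurityManager

object SecureHttpFetchSketch {
  // Answer the server's DIGEST challenge with the "user:secret" embedded in the URL.
  Authenticator.setDefault(new Authenticator() {
    override def getPasswordAuthentication(): PasswordAuthentication = {
      val userInfo = getRequestingURL().getUserInfo()
      if (userInfo == null) {
        null
      } else {
        val parts = userInfo.split(":", 2)
        new PasswordAuthentication(parts(0), parts(1).toCharArray())
      }
    }
  })

  def open(url: String, securityMgr: SecurityManager): InputStream = {
    val uri = new URI(url)
    val conn: URLConnection =
      if (securityMgr.isAuthenticationEnabled()) {
        // Same approach as HttpBroadcast.read(): embed the HTTP user and shared
        // secret as URI user-info so the Authenticator above can hand them back.
        val userInfo = securityMgr.getHttpUser() + ":" + securityMgr.getSecretKey()
        new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(),
          uri.getPath(), uri.getQuery(), uri.getFragment()).toURL().openConnection()
      } else {
        uri.toURL().openConnection()
      }
    conn.getInputStream()
  }
}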
val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, false, conf) + "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager) actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index b479225b45ee9..5af01f59afbc8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -21,6 +21,7 @@ import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{SparkContext, SparkException} @@ -63,6 +64,10 @@ class SparkHadoopUtil { def addCredentials(conf: JobConf) {} def isYarnMode(): Boolean = { false } + + def getCurrentUserCredentials(): Credentials = { null } + + def addCurrentUserCredentials(creds: Credentials) {} } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index ffa909c26b64a..5285cb7a6ac98 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.client import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext, Logging} import org.apache.spark.deploy.{Command, ApplicationDescription} private[spark] object TestClient { @@ -46,7 +46,7 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, - conf = new SparkConf) + conf = new SparkConf, securityManager = new SecurityManager()) val desc = new ApplicationDescription( "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index d9ea96afcf52a..8d97c8e7ebaf5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -31,7 +31,7 @@ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.spark.{SparkConf, Logging, SparkException} +import org.apache.spark.{SecurityManager, SparkConf, Logging, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.MasterMessages._ @@ -704,7 +704,8 @@ private[spark] object Master { def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf) : (ActorSystem, Int, Int) = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, + securityManager = new SecurityManager) val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), 
actorName) val timeout = AkkaUtils.askTimeout(conf) val respFuture = actor.ask(RequestWebUIPort)(timeout) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index ead35662fc75a..787b94e35f99e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -18,7 +18,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.Logging import org.apache.spark.deploy.master.Master @@ -59,12 +60,12 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { val metricsHandlers = master.masterMetricsSystem.getServletHandlers ++ master.applicationMetricsSystem.getServletHandlers - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR)), - ("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), - ("/app", (request: HttpServletRequest) => applicationPage.render(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) + val handlers = metricsHandlers ++ Seq[ServletContextHandler]( + createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"), + createServletHandler("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), + createServletHandler("/app", (request: HttpServletRequest) => applicationPage.render(request)), + createServletHandler("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), + createServletHandler("*", (request: HttpServletRequest) => indexPage.render(request)) ) def stop() { @@ -73,5 +74,5 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { } private[spark] object MasterWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = "org/apache/spark/ui" } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 6f6c101547c3c..0c91c89714009 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.worker import akka.actor._ -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{AkkaUtils, Utils} /** @@ -30,7 +30,7 @@ object DriverWrapper { args.toList match { case workerUrl :: mainClass :: extraArgs => val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", - Utils.localHostName(), 0, false, new SparkConf()) + Utils.localHostName(), 0, false, new SparkConf(), new SecurityManager()) actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") // Delegate to supplied main class @@ -45,4 +45,4 @@ object DriverWrapper { System.exit(-1) } } -} \ No newline at end of file +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 312560d7063a4..17ce27b47627f 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -27,7 +27,7 @@ import scala.concurrent.duration._ import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} @@ -337,7 +337,7 @@ private[spark] object Worker { val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf) + conf = conf, securityManager = new SecurityManager) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, masterUrls, systemName, actorName, workDir, conf), name = actorName) (actorSystem, boundPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 8daa47b2b2435..ff249e92d6484 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,7 +20,8 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker @@ -46,12 +47,12 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val metricsHandlers = worker.metricsSystem.getServletHandlers - val handlers = metricsHandlers ++ Array[(String, Handler)]( - ("/static", createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR)), - ("/log", (request: HttpServletRequest) => log(request)), - ("/logPage", (request: HttpServletRequest) => logPage(request)), - ("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - ("*", (request: HttpServletRequest) => indexPage.render(request)) + val handlers = metricsHandlers ++ Seq[ServletContextHandler]( + createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR, "/static/*"), + createServletHandler("/log", (request: HttpServletRequest) => log(request)), + createServletHandler("/logPage", (request: HttpServletRequest) => logPage(request)), + createServletHandler("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), + createServletHandler("*", (request: HttpServletRequest) => indexPage.render(request)) ) def start() { @@ -198,6 +199,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I } private[spark] object WorkerWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = "org/apache/spark/ui" val DEFAULT_PORT="8081" } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 45b43b403dd8c..ad7f2b97a06f3 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,7 +22,7 @@ import 
java.nio.ByteBuffer import akka.actor._ import akka.remote._ -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext, Logging} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -100,7 +100,7 @@ private[spark] object CoarseGrainedExecutorBackend { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - indestructible = true, conf = new SparkConf) + indestructible = true, conf = new SparkConf, new SecurityManager) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index c1b57f74d7e9a..8ca602b9d39e1 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -69,11 +69,6 @@ private[spark] class Executor( conf.set("spark.local.dir", getYarnLocalDirs()) } - // Create our ClassLoader and set it on this thread - private val urlClassLoader = createClassLoader() - private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) - Thread.currentThread.setContextClassLoader(replClassLoader) - if (!isLocal) { // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire @@ -117,6 +112,12 @@ private[spark] class Executor( } } + // Create our ClassLoader and set it on this thread + // do this after SparkEnv creation so can access the SecurityManager + private val urlClassLoader = createClassLoader() + private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader) + Thread.currentThread.setContextClassLoader(replClassLoader) + // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. 
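Illustrative sketch (not from the patch): the servlet-based handler registration that the MasterWebUI and WorkerWebUI hunks above switch to. Handlers become plain ServletContextHandlers built by JettyUtils, so the UI-level authentication this patch adds can be applied uniformly. SketchWebUI and StatusPage are invented stand-ins for the real UI classes; the createStaticHandler/createServletHandler call shapes are the ones shown in those hunks.

import javax.servlet.http.HttpServletRequest
import scala.xml.Node

import org.eclipse.jetty.servlet.ServletContextHandler

import org.apache.spark.ui.JettyUtils._

// Stand-in for IndexPage / ApplicationPage style page objects.
class StatusPage {
  def render(request: HttpServletRequest): Seq[Node] = <p>ok</p>
}

class SketchWebUI {
  val statusPage = new StatusPage

  // ServletContextHandlers replace the old (String, Handler) pairs, so every route
  // goes through the same servlet machinery rather than raw Jetty handlers.
  val handlers = Seq[ServletContextHandler](
    createStaticHandler("org/apache/spark/ui", "/static/*"),
    createServletHandler("/status", (request: HttpServletRequest) => statusPage.render(request)),
    createServletHandler("*", (request: HttpServletRequest) => statusPage.render(request))
  )
}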
private val akkaFrameSize = { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 99357fede6d06..c59cbbedf64ab 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -24,9 +24,10 @@ import com.fasterxml.jackson.databind.ObjectMapper import java.util.Properties import java.util.concurrent.TimeUnit + import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.ui.JettyUtils @@ -44,8 +45,9 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext val mapper = new ObjectMapper().registerModule( new MetricsModule(TimeUnit.SECONDS, TimeUnit.MILLISECONDS, servletShowSample)) - def getHandlers = Array[(String, Handler)]( - (servletPath, JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json")) + def getHandlers = Array[ServletContextHandler]( + JettyUtils.createServletHandler(servletPath, + JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json")) ) def getMetricsSnapshot(request: HttpServletRequest): String = { diff --git a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala index fb4c65909a9e2..231aacd1524d2 100644 --- a/core/src/main/scala/org/apache/spark/network/BufferMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/BufferMessage.scala @@ -46,9 +46,10 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Max chunk size is " + maxChunkSize) } + val security = if (isSecurityNeg) 1 else 0 if (size == 0 && !gotChunkForSendingOnce) { val newChunk = new MessageChunk( - new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null) + new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null) gotChunkForSendingOnce = true return Some(newChunk) } @@ -66,7 +67,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: } buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) gotChunkForSendingOnce = true return Some(newChunk) } @@ -80,6 +81,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: throw new Exception("Attempting to get chunk from message with multiple data buffers") } val buffer = buffers(0) + val security = if (isSecurityNeg) 1 else 0 if (buffer.remaining > 0) { if (buffer.remaining < chunkSize) { throw new Exception("Not enough space in data buffer for receiving chunk") @@ -87,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer] buffer.position(buffer.position + newBuffer.remaining) val newChunk = new MessageChunk(new MessageChunkHeader( - typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer) + typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer) return Some(newChunk) } None diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 
cba8477ed5723..e20569d297dc5 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -18,6 +18,8 @@ package org.apache.spark.network import org.apache.spark._ +import org.apache.spark.SparkSaslServer +import org.apache.spark.SparkSaslServer.SaslDigestCallbackHandler import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} @@ -30,13 +32,16 @@ import java.net._ private[spark] abstract class Connection(val channel: SocketChannel, val selector: Selector, - val socketRemoteConnectionManagerId: ConnectionManagerId) + val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) extends Logging { - def this(channel_ : SocketChannel, selector_ : Selector) = { + var sparkSaslServer : SparkSaslServer = null + var sparkSaslClient : SparkSaslClient = null + + def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { this(channel_, selector_, ConnectionManagerId.fromSocketAddress( - channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress])) + channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_) } channel.configureBlocking(false) @@ -52,6 +57,16 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, val remoteAddress = getRemoteAddress() + /** + * Used to synchronize client requests: client's work-related requests must + * wait until SASL authentication completes. + */ + private val authenticated = new Object() + + def getAuthenticated(): Object = authenticated + + def isSaslComplete(): Boolean + def resetForceReregister(): Boolean // Read channels typically do not register for write and write does not for read @@ -72,6 +87,15 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, // Will be true for ReceivingConnection, false for SendingConnection. def changeInterestForRead(): Boolean + private def disposeSasl() { + if (sparkSaslServer != null) { + sparkSaslServer.dispose(); + } + if (sparkSaslClient != null) { + sparkSaslClient.dispose() + } + } + // On receiving a write event, should we change the interest for this channel or not ? // Will be false for ReceivingConnection, true for SendingConnection. 
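Illustrative sketch (not from the patch): how the SASL objects that Connection now carries are expected to interact. In the real code the ConnectionManager drives this exchange over SecurityMessages before any application data flows; here the same handshake is run in-process so the call sequence is visible. SaslHandshakeSketch is an invented name; the SparkSaslClient/SparkSaslServer methods are the ones defined earlier in this patch.

import org.apache.spark.{SecurityManager, SparkSaslClient, SparkSaslServer}

object SaslHandshakeSketch {
  def handshake(securityMgr: SecurityManager) {
    val client = new SparkSaslClient(securityMgr)
    val server = new SparkSaslServer(securityMgr)
    try {
      // DIGEST-MD5: the client's first token is empty, the server answers with a
      // challenge, the client responds with a digest, the server finishes with rspauth.
      var token = client.firstToken()
      while (!client.isComplete() || !server.isComplete()) {
        val challenge = server.response(token)
        if (!client.isComplete()) {
          token = client.saslResponse(challenge)
        }
      }
      // Both sides completed, meaning both derived the same shared secret.
    } finally {
      client.dispose()
      server.dispose()
    }
  }
}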
// Actually, for now, should not get triggered for ReceivingConnection @@ -104,6 +128,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, k.cancel() } channel.close() + disposeSasl() callOnCloseCallback() } @@ -171,8 +196,12 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector, - remoteId_ : ConnectionManagerId) - extends Connection(SocketChannel.open, selector_, remoteId_) { + remoteId_ : ConnectionManagerId, id_ : ConnectionId) + extends Connection(SocketChannel.open, selector_, remoteId_, id_) { + + def isSaslComplete(): Boolean = { + if (sparkSaslClient != null) sparkSaslClient.isComplete() else false + } private class Outbox(fair: Int = 0) { val messages = new Queue[Message]() @@ -262,6 +291,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, data as detailed in https://github.com/mesos/spark/pull/791 */ private var needForceReregister = false + val currentBuffers = new ArrayBuffer[ByteBuffer]() /*channel.socket.setSendBufferSize(256 * 1024)*/ @@ -352,6 +382,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // If we have 'seen' pending messages, then reset flag - since we handle that as normal // registering of event (below) if (needForceReregister && buffers.exists(_.remaining() > 0)) resetForceReregister() + currentBuffers ++= buffers } case None => { @@ -419,8 +450,12 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector) - extends Connection(channel_, selector_) { +private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) + extends Connection(channel_, selector_, id_) { + + def isSaslComplete(): Boolean = { + if (sparkSaslServer != null) sparkSaslServer.isComplete() else false + } class Inbox() { val messages = new HashMap[Int, BufferMessage]() @@ -431,6 +466,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis + newMessage.isSecurityNeg = if (header.securityNeg == 1) true else false logDebug( "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala new file mode 100644 index 0000000000000..b26b7ee34534a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network + +private[spark] case class ConnectionId(id : String) {} + +private[spark] object ConnectionId { + + def createConnectionId(connectionManagerId : ConnectionManagerId, secureMsgId : Int) : ConnectionId = { + val connIdStr = connectionManagerId.host + "_" + connectionManagerId.port + "_" + secureMsgId + val connId = new ConnectionId(connIdStr) + return connId + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index e6e01783c8895..a67b17596d19e 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -23,6 +23,7 @@ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.net._ +import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor} import scala.collection.mutable.HashSet @@ -37,7 +38,7 @@ import scala.concurrent.duration._ import org.apache.spark.util.Utils -private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Logging { +private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManager: SecurityManager) extends Logging { class MessageStatus( val message: Message, @@ -53,6 +54,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private val selector = SelectorProvider.provider.openSelector() + // TODO -update to use spark conf + private val numAuthRetries = System.getProperty("spark.core.connection.num.auth.retries","10").toInt + private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), conf.getInt("spark.core.connection.handler.threads.max", 60), @@ -73,6 +77,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi new LinkedBlockingDeque[Runnable]()) private val serverChannel = ServerSocketChannel.open() + private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] with SynchronizedMap[ConnectionId, SendingConnection] private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] private val messageStatuses = new HashMap[Int, MessageStatus] @@ -84,6 +89,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null + private val authEnabled = securityManager.isAuthenticationEnabled() + serverChannel.configureBlocking(false) serverChannel.socket.setReuseAddress(true) serverChannel.socket.setReceiveBufferSize(256 * 1024) @@ -94,6 +101,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort) logInfo("Bound socket to port 
" + serverChannel.socket.getLocalPort() + " with id = " + id) + // used in combination with the ConnectionManagerId to create unique Connection ids + // to be able to track asynchronous messages + private val idCount: AtomicInteger = new AtomicInteger(1) + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } @@ -125,7 +136,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi } finally { writeRunnableStarted.synchronized { writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() + val needReregister = register || conn.resetForceReregister() if (needReregister && conn.changeInterestForWrite()) { conn.registerInterest() } @@ -367,7 +378,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi // accept them all in a tight loop. non blocking accept with no processing, should be fine while (newChannel != null) { try { - val newConnection = new ReceivingConnection(newChannel, selector) + val newConnectionId = ConnectionId.createConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId) newConnection.onReceive(receiveMessage) addListeners(newConnection) addConnection(newConnection) @@ -401,6 +413,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi logInfo("Removing SendingConnection to " + sendingConnectionManagerId) connectionsById -= sendingConnectionManagerId + connectionsAwaitingSasl -= connection.connectionId messageStatuses.synchronized { messageStatuses @@ -475,7 +488,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi val creationTime = System.currentTimeMillis def run() { logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message) + handleMessage(connectionManagerId, message, connection) logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") } } @@ -483,10 +496,131 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi /*handleMessage(connection, message)*/ } - private def handleMessage(connectionManagerId: ConnectionManagerId, message: Message) { + private def handleClientAuthNeg( + waitingConn: SendingConnection, + securityMsg: SecurityMessage, + connectionId : ConnectionId) { + if (waitingConn.isSaslComplete()) { + logDebug("Client sasl completed for id: " + waitingConn.connectionId) + connectionsAwaitingSasl -= waitingConn.connectionId + waitingConn.getAuthenticated().synchronized { + waitingConn.getAuthenticated().notifyAll(); + } + return + } else { + var replyToken : Array[Byte] = null + try { + replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken); + if (waitingConn.isSaslComplete()) { + logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) + connectionsAwaitingSasl -= waitingConn.connectionId + waitingConn.getAuthenticated().synchronized { + waitingConn.getAuthenticated().notifyAll() + } + return + } + var securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId) + var message = securityMsgResp.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) + } catch { + case e: Exception => { + logError("Error doing sasl client: " + e) + waitingConn.close() + 
throw new Exception("error evaluating sasl response: " + e) + } + } + } + } + + private def handleServerAuthNeg( + connection: Connection, + securityMsg: SecurityMessage, + connectionId: ConnectionId) { + if (!connection.isSaslComplete()) { + logDebug("saslContext not established") + var replyToken : Array[Byte] = null + try { + connection.synchronized { + if (connection.sparkSaslServer == null) { + logDebug("Creating sasl Server") + connection.sparkSaslServer = new SparkSaslServer(securityManager) + } + } + replyToken = connection.sparkSaslServer.response(securityMsg.getToken) + if (connection.isSaslComplete()) { + logDebug("Server sasl completed: " + connection.connectionId) + } else { + logDebug("Server sasl not completed: " + connection.connectionId) + } + if (replyToken != null) { + var securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId) + var message = securityMsgResp.toBufferMessage + if (message == null) throw new Exception("Error creating security Message") + sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) + } + } catch { + case e: Exception => { + logError("Error in server auth negotiation: " + e) + // It would probably be better to send an error message telling other side auth failed + // but for now just close + connection.close() + } + } + } else { + logDebug("connection already established for this connection id: " + connection.connectionId) + } + } + + + private def handleAuthentication(conn: Connection, bufferMessage: BufferMessage): Boolean = { + if (bufferMessage.isSecurityNeg) { + logDebug("This is security neg message") + + // parse as SecurityMessage + val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) + val connectionId = new ConnectionId(securityMsg.getConnectionId) + + connectionsAwaitingSasl.get(connectionId) match { + case Some(waitingConn) => { + // Client - this must be in response to us doing Send + logDebug("Client handleAuth for id: " + waitingConn.connectionId) + handleClientAuthNeg(waitingConn, securityMsg, connectionId) + } + case None => { + // Server - someone sent us something and we haven't authenticated yet + logDebug("Server handleAuth for id: " + connectionId) + handleServerAuthNeg(conn, securityMsg, connectionId) + } + } + return true + } else { + if (!conn.isSaslComplete()) { + // We could handle this better and tell the client we need to do authentication + // negotiation, but for now just ignore them. 
+ logError("message sent that is not security negotiation message on connection " + + "not authenticated yet, ignoring it!!") + return true + } + } + return false + } + + private def handleMessage( + connectionManagerId: ConnectionManagerId, + message: Message, + connection: Connection) { logDebug("Handling [" + message + "] from [" + connectionManagerId + "]") message match { case bufferMessage: BufferMessage => { + if (authEnabled) { + val res = handleAuthentication(connection, bufferMessage) + if (res == true) { + // message was security negotiation so skip the rest + logDebug("After handleAuth result was true, returning"); + return + } + } if (bufferMessage.hasAckId) { val sentMessageStatus = messageStatuses.synchronized { messageStatuses.get(bufferMessage.ackId) match { @@ -533,10 +667,65 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi } } + private def checkSendAuthFirst(connManagerId: ConnectionManagerId, conn: SendingConnection) { + // see if we need to do sasl before writing + // this should only be the first negotiation as the Client!!! + if (!conn.isSaslComplete()) { + conn.synchronized { + if (conn.sparkSaslClient == null) { + conn.sparkSaslClient = new SparkSaslClient(securityManager) + var firstResponse: Array[Byte] = null + try { + firstResponse = conn.sparkSaslClient.firstToken() + var securityMsg = SecurityMessage.fromResponse(firstResponse, conn.connectionId.id) + var message = securityMsg.toBufferMessage + if (message == null) throw new Exception("Error creating security message") + sendSecurityMessage(connManagerId, message) + logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId) + connectionsAwaitingSasl += ((conn.connectionId, conn)) + } catch { + case e: Exception => { + logError("Error getting first response from the SaslClient") + conn.close() + throw new Exception("Error getting first response from the SaslClient") + } + } + } + } + } else { + logDebug("Sasl already established ") + } + } + + // allow us to add messages to the inbox for doing sasl negotiating + private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { + def startNewConnection(): SendingConnection = { + val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) + val newConnectionId = ConnectionId.createConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, newConnectionId) + logInfo("creating new sending connection for security! " + newConnectionId ) + registerRequests.enqueue(newConnection) + + newConnection + } + // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... + // If we do re-add it, we should consistently use it everywhere I guess ? 
+ message.senderAddress = id.toSocketAddress() + logDebug("Sending Security [" + message + "] to [" + connManagerId + "]") + val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) + + // send the security message even though the outgoing connection has not been authenticated yet + connection.send(message) + + wakeupSelector() + } + private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + val newConnectionId = ConnectionId.createConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId) + logInfo("Creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) newConnection @@ -544,7 +733,53 @@ // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... // If we do re-add it, we should consistently use it everywhere I guess ? val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection()) + if (authEnabled) { + checkSendAuthFirst(connectionManagerId, connection) + } message.senderAddress = id.toSocketAddress() + logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " connectionid: " + connection.connectionId) + + if (authEnabled) { + // if we aren't authenticated yet, block the sender until authentication completes + try { + connection.getAuthenticated().synchronized { + var totalWaitTimes = 0 + while (!connection.isSaslComplete()) { + // should we specify timeout as fallback? + logDebug("getAuthenticated wait connectionid: " + connection.connectionId) + // have timeout in case remote side never responds + totalWaitTimes += 1 + connection.getAuthenticated().wait(500) + if (totalWaitTimes >= numAuthRetries) { + // took too long to authenticate the connection; something probably went wrong + throw new Exception("Took too long for authentication to " + connectionManagerId + + ", waited " + 500 * numAuthRetries + "ms, failing.") + } + } + } + } catch { + case e: Exception => logError("Exception while waiting for authentication. 
" + e) + + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(message.id) + s match { + case Some(msgStatus) => { + messageStatuses -= message.id + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.synchronized { + msgStatus.attempted = true + msgStatus.acked = false + msgStatus.markDone() + } + } + case None => { + logError("no messageStatus for failed message id: " + message.id) + } + } + } + } + } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) @@ -594,7 +829,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private[spark] object ConnectionManager { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf) + val manager = new ConnectionManager(9999, new SparkConf, new SecurityManager) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None diff --git a/core/src/main/scala/org/apache/spark/network/Message.scala b/core/src/main/scala/org/apache/spark/network/Message.scala index 2612884bdbe15..2deb8b09adaf7 100644 --- a/core/src/main/scala/org/apache/spark/network/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/Message.scala @@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) { var started = false var startTime = -1L var finishTime = -1L + var isSecurityNeg = false def size: Int diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index 235fbc39b3bd2..30666891633bb 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -28,6 +28,7 @@ private[spark] class MessageChunkHeader( val totalSize: Int, val chunkSize: Int, val other: Int, + val securityNeg: Int, val address: InetSocketAddress) { lazy val buffer = { // No need to change this, at 'use' time, we do a reverse lookup of the hostname. @@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader( putInt(totalSize). putInt(chunkSize). putInt(other). + putInt(securityNeg). putInt(ip.size). put(ip). putInt(port). 
@@ -49,12 +51,13 @@ private[spark] class MessageChunkHeader( } override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ + - " and sizes " + totalSize + " / " + chunkSize + " bytes" + " and sizes " + totalSize + " / " + chunkSize + " bytes, securityNeg: " + securityNeg + } private[spark] object MessageChunkHeader { - val HEADER_SIZE = 40 + val HEADER_SIZE = 44 def create(buffer: ByteBuffer): MessageChunkHeader = { if (buffer.remaining != HEADER_SIZE) { @@ -65,11 +68,12 @@ private[spark] object MessageChunkHeader { val totalSize = buffer.getInt() val chunkSize = buffer.getInt() val other = buffer.getInt() + val securityNeg = buffer.getInt() val ipSize = buffer.getInt() val ipBytes = new Array[Byte](ipSize) buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port)) + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 1c9d6030d68d7..ba915dc74cdcf 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -19,11 +19,11 @@ package org.apache.spark.network import java.nio.ByteBuffer import java.net.InetAddress -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object ReceiverTest { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf) + val manager = new ConnectionManager(9999, new SparkConf, new SecurityManager) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala new file mode 100644 index 0000000000000..dd5519f6b39a7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network + +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.StringBuilder + +import org.apache.spark._ +import org.apache.spark.network._ + +private[spark] class SecurityMessage() extends Logging { + + private var connectionId: String = null + private var token: Array[Byte] = null + + def set(byteArr: Array[Byte], newconnectionId: String) { + if (byteArr == null) { + token = new Array[Byte](0) + } else { + token = byteArr + } + connectionId = newconnectionId + } + + def set(buffer: ByteBuffer) { + val idLength = buffer.getInt() + val idBuilder = new StringBuilder(idLength) + for (i <- 1 to idLength) { + idBuilder += buffer.getChar() + } + connectionId = idBuilder.toString() + + val tokenLength = buffer.getInt() + token = new Array[Byte](tokenLength) + if (tokenLength > 0) { + buffer.get(token, 0, tokenLength) + } + } + + def set(bufferMsg: BufferMessage) { + val buffer = bufferMsg.buffers.apply(0) + buffer.clear() + set(buffer) + } + + def getConnectionId: String = { + return connectionId + } + + def getToken: Array[Byte] = { + return token + } + + def toBufferMessage: BufferMessage = { + val startTime = System.currentTimeMillis + val buffers = new ArrayBuffer[ByteBuffer]() + + var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) + buffer.putInt(connectionId.length()) + connectionId.foreach((x: Char) => buffer.putChar(x)) + buffer.putInt(token.length) + + if (token.length > 0) { + buffer.put(token) + } + buffer.flip() + buffers += buffer + + var message = Message.createBufferMessage(buffers) + logDebug("message total size is : " + message.size) + message.isSecurityNeg = true + return message + } + + override def toString: String = { + "SecurityMessage [connId= " + connectionId + ", Token = " + token + "]" + } +} + +private[spark] object SecurityMessage { + + def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(bufferMessage) + newSecurityMessage + } + + def fromResponse(response : Array[Byte], newConnectionId : String) : SecurityMessage = { + val newSecurityMessage = new SecurityMessage() + newSecurityMessage.set(response, newConnectionId) + newSecurityMessage + } +} diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index dcbd183c88d09..195560ca856dc 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.network import java.nio.ByteBuffer import java.net.InetAddress -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object SenderTest { def main(args: Array[String]) { @@ -33,7 +33,7 @@ private[spark] object SenderTest { val targetPort = args(1).toInt val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - val manager = new ConnectionManager(0, new SparkConf) + val manager = new ConnectionManager(0, new SparkConf, new SecurityManager) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ed53558566edf..a39d151d78932 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -30,7 +30,7 @@ import scala.concurrent.duration._ import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} -import org.apache.spark.{SparkConf, Logging, SparkEnv, SparkException} +import org.apache.spark.{SecurityManager, SparkConf, Logging, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -44,7 +44,8 @@ private[spark] class BlockManager( val master: BlockManagerMaster, val defaultSerializer: Serializer, maxMemory: Long, - val conf: SparkConf) + val conf: SparkConf, + securityManager: SecurityManager) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) @@ -63,7 +64,7 @@ private[spark] class BlockManager( if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } - val connectionManager = new ConnectionManager(0, conf) + val connectionManager = new ConnectionManager(0, conf, securityManager) implicit val futureExecContext = connectionManager.futureExecContext val blockManagerId = BlockManagerId( @@ -119,8 +120,9 @@ private[spark] class BlockManager( * Construct a BlockManager with a memory limit set based on system properties. */ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, - serializer: Serializer, conf: SparkConf) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf) + serializer: Serializer, conf: SparkConf, securityManager: SecurityManager) = { + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf, + securityManager) } /** diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 729ba2c550a20..53b56e3ab544c 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -22,7 +22,7 @@ import akka.actor._ import java.util.concurrent.ArrayBlockingQueue import util.Random import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} /** * This class tests the BlockManager and MemoryStore for thread safety and @@ -97,7 +97,8 @@ private[spark] object ThreadingTest { val blockManagerMaster = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf) val blockManager = new BlockManager( - "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf) + "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, + new SecurityManager()) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 7211dbc7c6681..fd83257993622 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui +import java.net.URL +import javax.servlet.http.HttpServlet import javax.servlet.http.{HttpServletResponse, HttpServletRequest} import scala.annotation.tailrec @@ -25,11 +27,14 @@ import 
scala.xml.Node import net.liftweb.json.{JValue, pretty, render} -import org.eclipse.jetty.server.{Server, Request, Handler} +import org.eclipse.jetty.server.{DispatcherType, Server} import org.eclipse.jetty.server.handler.{ResourceHandler, HandlerList, ContextHandler, AbstractHandler} +import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.Logging +import org.apache.spark.SparkEnv +import org.apache.spark.SecurityManager /** Utilities for launching a web server using Jetty's HTTP Server class */ @@ -39,56 +44,103 @@ private[spark] object JettyUtils extends Logging { type Responder[T] = HttpServletRequest => T // Conversions from various types of Responder's to jetty Handlers - implicit def jsonResponderToHandler(responder: Responder[JValue]): Handler = + implicit def jsonResponderToHandler(responder: Responder[JValue]): HttpServlet = createHandler(responder, "text/json", (in: JValue) => pretty(render(in))) - implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): Handler = + implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): HttpServlet = createHandler(responder, "text/html", (in: Seq[Node]) => "" + in.toString) - implicit def textResponderToHandler(responder: Responder[String]): Handler = + implicit def textResponderToHandler(responder: Responder[String]): HttpServlet = createHandler(responder, "text/plain") - def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, - extractFn: T => String = (in: Any) => in.toString): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, + def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, + extractFn: T => String = (in: Any) => in.toString): HttpServlet = { + new HttpServlet { + override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - response.setContentType("%s;charset=utf-8".format(contentType)) - response.setStatus(HttpServletResponse.SC_OK) - baseRequest.setHandled(true) - val result = responder(request) - response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter().println(extractFn(result)) + // First try to get the security Manager from the SparkEnv. 
If that doesn't exist, create + // a new one and rely on the configs being set + val sparkEnv = SparkEnv.get + val securityMgr = if (sparkEnv != null) sparkEnv.securityManager else new SecurityManager() + if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) { + response.setContentType("%s;charset=utf-8".format(contentType)) + response.setStatus(HttpServletResponse.SC_OK) + val result = responder(request) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.getWriter().println(extractFn(result)) + } else { + response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) + response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") + response.sendError(HttpServletResponse.SC_UNAUTHORIZED, + "User is not authorized to access this page."); + } } } } + def createServletHandler(path: String, servlet: HttpServlet): ServletContextHandler = { + val contextHandler = new ServletContextHandler() + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler + } + /** Creates a handler that always redirects the user to a given path */ - def createRedirectHandler(newPath: String): Handler = { - new AbstractHandler { - def handle(target: String, - baseRequest: Request, - request: HttpServletRequest, + def createRedirectHandler(newPath: String, path: String): ServletContextHandler = { + val servlet = new HttpServlet { + override def doGet(request: HttpServletRequest, response: HttpServletResponse) { - response.setStatus(302) - response.setHeader("Location", baseRequest.getRootURL + newPath) - baseRequest.setHandled(true) + // make sure we don't end up with // in the middle + val newUri = new URL(new URL(request.getRequestURL.toString), newPath).toURI + response.sendRedirect(newUri.toString) } } + val contextHandler = new ServletContextHandler() + val holder = new ServletHolder(servlet) + contextHandler.setContextPath(path) + contextHandler.addServlet(holder, "/") + contextHandler } /** Creates a handler for serving files from a static directory */ - def createStaticHandler(resourceBase: String): ResourceHandler = { - val staticHandler = new ResourceHandler + def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = { + val contextHandler = new ServletContextHandler() + val staticHandler = new DefaultServlet + val holder = new ServletHolder(staticHandler) Option(getClass.getClassLoader.getResource(resourceBase)) match { case Some(res) => - staticHandler.setResourceBase(res.toString) + holder.setInitParameter("resourceBase", res.toString) case None => throw new Exception("Could not find resource path for Web UI: " + resourceBase) } - staticHandler + contextHandler.addServlet(holder, path) + contextHandler + } + + private def addFilters(handlers: Seq[ServletContextHandler]) { + val filters : Array[String] = System.getProperty("spark.ui.filters", "").split(',').map(_.trim()) + filters.foreach { + case filter : String => + if (!filter.isEmpty) { + logInfo("Adding filter: " + filter) + val holder : FilterHolder = new FilterHolder() + holder.setClassName(filter) + // get any parameters for each filter + val paramName = filter + ".params" + val params = System.getProperty(paramName, "").split(',').map(_.trim()).toSet + params.foreach { + case param : String => + if (!param.isEmpty) { + val parts = param.split("=") + if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) + } + } + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, 
DispatcherType.ERROR, + DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) + handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } + } + } } /** @@ -97,15 +149,10 @@ private[spark] object JettyUtils extends Logging { * If the desired port number is contented, continues incrementing ports until a free port is * found. Returns the chosen port and the jetty Server object. */ - def startJettyServer(ip: String, port: Int, handlers: Seq[(String, Handler)]): (Server, Int) = { - val handlersToRegister = handlers.map { case(path, handler) => - val contextHandler = new ContextHandler(path) - contextHandler.setHandler(handler) - contextHandler.asInstanceOf[org.eclipse.jetty.server.Handler] - } - + def startJettyServer(ip: String, port: Int, handlers: Seq[ServletContextHandler]): (Server, Int) = { + addFilters(handlers) val handlerList = new HandlerList - handlerList.setHandlers(handlersToRegister.toArray) + handlerList.setHandlers(handlers.toArray) @tailrec def connect(currentPort: Int): (Server, Int) = { diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 50dfdbdf5ae9b..1d5c91c608794 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -19,7 +19,8 @@ package org.apache.spark.ui import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.{Handler, Server} +import org.eclipse.jetty.server.Server +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkContext, SparkEnv} import org.apache.spark.ui.env.EnvironmentUI @@ -36,9 +37,9 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { var boundPort: Option[Int] = None var server: Option[Server] = None - val handlers = Seq[(String, Handler)]( - ("/static", createStaticHandler(SparkUI.STATIC_RESOURCE_DIR)), - ("/", createRedirectHandler("/stages")) + val handlers = Seq[ServletContextHandler] ( + createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static/*"), + createRedirectHandler("/stages", "/") ) val storage = new BlockManagerUI(sc) val jobs = new JobProgressUI(sc) @@ -85,5 +86,5 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { private[spark] object SparkUI { val DEFAULT_PORT = "4040" - val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + val STATIC_RESOURCE_DIR = "org/apache/spark/ui" } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala index 88f41be8d3dd2..37fb87fc37a85 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConversions._ import scala.util.Properties import scala.xml.Node -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.UIUtils @@ -33,8 +33,8 @@ import org.apache.spark.SparkContext private[spark] class EnvironmentUI(sc: SparkContext) { - def getHandlers = Seq[(String, Handler)]( - ("/environment", (request: HttpServletRequest) => envDetails(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/environment", (request: HttpServletRequest) => envDetails(request)) ) def envDetails(request: HttpServletRequest): Seq[Node] = { diff --git 
a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index a31a7e1d58374..82a487110d69e 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{HashMap, HashSet} import scala.xml.Node -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{ExceptionFailure, Logging, SparkContext} import org.apache.spark.executor.TaskMetrics @@ -44,8 +44,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { sc.addSparkListener(listener) } - def getHandlers = Seq[(String, Handler)]( - ("/executors", (request: HttpServletRequest) => render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/executors", (request: HttpServletRequest) => render(request)) ) def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index c1ee2f3d00d66..3516ed57a02e3 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -23,7 +23,7 @@ import java.text.SimpleDateFormat import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import scala.Seq import scala.collection.mutable.{HashSet, ListBuffer, HashMap, ArrayBuffer} @@ -53,9 +53,9 @@ private[spark] class JobProgressUI(val sc: SparkContext) { def formatDuration(ms: Long) = Utils.msDurationToString(ms) - def getHandlers = Seq[(String, Handler)]( - ("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), - ("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), - ("/stages", (request: HttpServletRequest) => indexPage.render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), + createServletHandler("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), + createServletHandler("/stages", (request: HttpServletRequest) => indexPage.render(request)) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index 39f422dd6b90f..8de6a16772dd1 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -21,7 +21,7 @@ import scala.concurrent.duration._ import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.Handler +import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SparkContext} import org.apache.spark.ui.JettyUtils._ @@ -31,8 +31,8 @@ private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { val indexPage = new IndexPage(this) val rddPage = new RDDPage(this) - def getHandlers = Seq[(String, Handler)]( - ("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), - ("/storage", (request: HttpServletRequest) => indexPage.render(request)) + def getHandlers = Seq[ServletContextHandler]( + createServletHandler("/storage/rdd", (request: HttpServletRequest) => 
rddPage.render(request)), + createServletHandler("/storage", (request: HttpServletRequest) => indexPage.render(request)) ) } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 761d378c7fd8b..69fb9e9834408 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -25,11 +25,12 @@ import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SecurityManager} /** * Various utility classes for working with Akka. */ -private[spark] object AkkaUtils { +private[spark] object AkkaUtils extends Logging { /** * Creates an ActorSystem ready for remoting, with various Spark features. Returns both the @@ -42,7 +43,7 @@ private[spark] object AkkaUtils { * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. */ def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false, - conf: SparkConf): (ActorSystem, Int) = { + conf: SparkConf, securityManager: SecurityManager): (ActorSystem, Int) = { val akkaThreads = conf.getInt("spark.akka.threads", 4) val akkaBatchSize = conf.getInt("spark.akka.batchSize", 15) @@ -65,6 +66,15 @@ private[spark] object AkkaUtils { conf.getDouble("spark.akka.failure-detector.threshold", 300.0) val akkaHeartBeatInterval = conf.getInt("spark.akka.heartbeat.interval", 1000) + val secretKey = securityManager.getSecretKey() + val isAuthOn = securityManager.isAuthenticationEnabled() + if (isAuthOn && secretKey == null) { + throw new Exception("Secret key is null with authentication on") + } + val requireCookie = if (isAuthOn) "on" else "off" + val secureCookie = if (isAuthOn) secretKey else "" + logDebug("In createActorSystem, requireCookie is: " + requireCookie) + val akkaConf = ConfigFactory.parseMap(conf.getAkkaConf.toMap[String, String]).withFallback( ConfigFactory.parseString( s""" @@ -72,6 +82,8 @@ private[spark] object AkkaUtils { |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] |akka.stdout-loglevel = "ERROR" |akka.jvm-exit-on-fatal-error = off + |akka.remote.require-cookie = "$requireCookie" + |akka.remote.secure-cookie = "$secureCookie" |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector @@ -88,6 +100,8 @@ private[spark] object AkkaUtils { |akka.remote.log-remote-lifecycle-events = $lifecycleEvents |akka.log-dead-letters = $lifecycleEvents |akka.log-dead-letters-during-shutdown = $lifecycleEvents + |akka.remote.netty.require-cookie = "$requireCookie" + |akka.remote.netty.secure-cookie = "$secureCookie" """.stripMargin)) val actorSystem = if (indestructible) { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index caa9bf4c9280e..31392dfe3492b 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,8 @@ package org.apache.spark.util import java.io._ -import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address} +import java.net.{Authenticator, PasswordAuthentication} +import java.net.{InetAddress, URL, URLConnection, URI, NetworkInterface, Inet4Address, ServerSocket} import java.util.{Locale, Random, UUID} import 
java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} @@ -38,7 +39,7 @@ import org.apache.hadoop.io._ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SparkConf, SparkException, Logging} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, Logging} /** @@ -271,7 +272,44 @@ private[spark] object Utils extends Logging { uri.getScheme match { case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) - val in = new URL(url).openStream() + + var uc: URLConnection = null + // First try to get the security Manager from the SparkEnv. If that doesn't exist, create + // a new one and rely on the configs being set + val sparkEnv = SparkEnv.get + val securityMgr = if (sparkEnv != null) sparkEnv.securityManager else new SecurityManager() + if (securityMgr.isAuthenticationEnabled()) { + val userCred = securityMgr.getSecretKey() + if (userCred == null) { + throw new Exception("secret key is null with authentication on") + } + val userInfo = securityMgr.getHttpUser() + ":" + userCred + val newuri = new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), + uri.getPath(), uri.getQuery(), uri.getFragment()) + uc = newuri.toURL().openConnection() + uc.setAllowUserInteraction(false) + logDebug("in security enabled") + + // set our own authenticator to properly negotiate user/password + Authenticator.setDefault( + new Authenticator() { + override def getPasswordAuthentication(): PasswordAuthentication = { + var passAuth: PasswordAuthentication = null + val userInfo = getRequestingURL().getUserInfo() + if (userInfo != null) { + val parts = userInfo.split(":", 2) + passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) + } + return passAuth + } + } + ); + } else { + logDebug("fetchFile not using security") + uc = new URL(url).openConnection() + } + + val in = uc.getInputStream(); val out = new FileOutputStream(tempFile) Utils.copyStream(in, out, true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala new file mode 100644 index 0000000000000..a5bd9bd938327 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import akka.actor._ +import org.apache.spark.scheduler.MapStatus +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.AkkaUtils +import scala.concurrent.Await + +/** + * Test the AkkaUtils with various security settings. + */ +class AkkaUtilsSuite extends FunSuite with LocalSparkContext { + private val conf = new SparkConf + + test("remote fetch security bad password") { + System.setProperty("spark.authenticate", "true") + System.setProperty("SPARK_SECRET", "good") + + val securityManager = new SecurityManager(); + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + System.setProperty("spark.authenticate", "true") + System.setProperty("SPARK_SECRET", "bad") + val securityManagerBad= new SecurityManager(); + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = conf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + assert(securityManagerBad.isAuthenticationEnabled() === true) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should fail since password wrong + intercept[SparkException] { slaveTracker.getServerStatuses(10, 0) } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security off") { + System.setProperty("spark.authenticate", "false") + System.setProperty("SPARK_SECRET", "bad") + val securityManager = new SecurityManager(); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === false) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + System.setProperty("spark.authenticate", "false") + System.setProperty("SPARK_SECRET", "good") + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = conf, securityManager = securityManager) + val slaveTracker = new MapOutputTracker(conf) + val selection = 
slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + assert(securityManager.isAuthenticationEnabled() === false) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security off + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security pass") { + System.setProperty("spark.authenticate", "true") + System.setProperty("SPARK_SECRET", "good") + val securityManager = new SecurityManager(); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + System.setProperty("spark.authenticate", "true") + System.setProperty("SPARK_SECRET", "good") + + assert(securityManager.isAuthenticationEnabled() === true) + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = conf, securityManager = securityManager) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should succeed since security on and passwords match + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000, 0), size1000))) + + actorSystem.shutdown() + slaveSystem.shutdown() + } + + test("remote fetch security off client") { + System.setProperty("spark.authenticate", "true") + System.setProperty("SPARK_SECRET", "good") + val securityManager = new SecurityManager(); + + val hostname = "localhost" + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + conf = conf, securityManager = securityManager) + System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext + System.setProperty("spark.hostPort", hostname + ":" + boundPort) + + 
assert(securityManager.isAuthenticationEnabled() === true) + + val masterTracker = new MapOutputTrackerMaster(conf) + masterTracker.trackerActor = actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + + System.setProperty("spark.authenticate", "false") + System.setProperty("SPARK_SECRET", "bad") + val securityManagerBad = new SecurityManager(); + + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + conf = conf, securityManager = securityManagerBad) + val slaveTracker = new MapOutputTracker(conf) + val selection = slaveSystem.actorSelection( + s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") + val timeout = AkkaUtils.lookupTimeout(conf) + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + + + assert(securityManagerBad.isAuthenticationEnabled() === false) + + masterTracker.registerShuffle(10, 1) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) + masterTracker.incrementEpoch() + slaveTracker.updateEpoch(masterTracker.getEpoch) + + // this should fail since security on in server and off in client + intercept[SparkException] { slaveTracker.getServerStatuses(10, 0) } + + actorSystem.shutdown() + slaveSystem.shutdown() + } + +} diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index e022accee6d08..3bfd7d94d26ab 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -21,6 +21,9 @@ import org.scalatest.FunSuite class BroadcastSuite extends FunSuite with LocalSparkContext { + System.setProperty("spark.authenticate", "false") + + override def afterEach() { super.afterEach() System.clearProperty("spark.broadcast.factory") diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index fb89537258542..0c6b5b8488878 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { + test("driver should exit after finishing") { val sparkHome = sys.env.get("SPARK_HOME").orElse(sys.props.get("spark.home")).get // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" @@ -50,6 +51,7 @@ class DriverSuite extends FunSuite with Timeouts { */ object DriverWithoutCleanup { def main(args: Array[String]) { + System.setProperty("spark.authenticate", "false") Logger.getRootLogger().setLevel(Level.WARN) val sc = new SparkContext(args(0), "DriverWithoutCleanup") sc.parallelize(1 to 100, 4).count() diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index a2eb9a4e84696..06e597d90186a 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -29,6 +29,12 @@ class FileServerSuite extends FunSuite with LocalSparkContext { @transient var tmpFile: File = _ @transient var tmpJarUrl: String = _ + 
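Editor's note: the tests added in this patch share one pattern for turning authentication on: set spark.authenticate and the shared secret as system properties before any SecurityManager or SparkContext is constructed, since both read those values at creation time. A minimal sketch mirroring the "Distributing files locally security On" test below (the file path and app name are placeholders):

    import org.apache.spark.SparkContext

    System.setProperty("spark.authenticate", "true")
    System.setProperty("SPARK_SECRET", "good")   // shared secret; must match on every process involved

    val sc = new SparkContext("local[4]", "auth-test")
    try {
      sc.addFile("/tmp/some-file.txt")           // placeholder; fetched over the authenticated HTTP file server
      sc.parallelize(1 to 100, 4).count()
    } finally {
      sc.stop()
      System.setProperty("spark.authenticate", "false")
    }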
override def beforeEach() { + super.beforeEach() + resetSparkContext() + System.setProperty("spark.authenticate", "false") + } + override def beforeAll() { super.beforeAll() val tmpDir = new File(Files.createTempDir(), "test") @@ -42,6 +48,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val jarFile = new File(tmpDir, "test.jar") val jarStream = new FileOutputStream(jarFile) val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) + System.setProperty("spark.authenticate", "false") val jarEntry = new JarEntry(textFile.getName) jar.putNextEntry(jarEntry) @@ -76,6 +83,23 @@ class FileServerSuite extends FunSuite with LocalSparkContext { assert(result.toSet === Set((1,200), (2,300), (3,500))) } + test("Distributing files locally security On") { + System.setProperty("spark.authenticate", "true") + System.setProperty("SPARK_SECRET", "good") + + sc = new SparkContext("local[4]", "test") + sc.addFile(tmpFile.toString) + val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) + val result = sc.parallelize(testData).reduceByKey { + val path = SparkFiles.get("FileServerSuite.txt") + val in = new BufferedReader(new FileReader(path)) + val fileVal = in.readLine().toInt + in.close() + _ * fileVal + _ * fileVal + }.collect() + assert(result.toSet === Set((1,200), (2,300), (3,500))) + } + test("Distributing files locally using URL as input") { // addFile("file:///....") sc = new SparkContext("local[4]", "test") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 930c2523caf8c..e9319a9063776 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -97,14 +97,16 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, + securityManager = new SecurityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext val masterTracker = new MapOutputTrackerMaster(conf) masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf) + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, + securityManager = new SecurityManager) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 18aa587662d24..ee9ccfeaf2c9e 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -29,6 +29,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.SecurityManager import org.apache.spark.util.{SizeEstimator, Utils, AkkaUtils, ByteBufferInputStream} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.{SparkConf, 
SparkContext} @@ -40,6 +41,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var actorSystem: ActorSystem = null var master: BlockManagerMaster = null var oldArch: String = null + System.setProperty("spark.authenticate", "false") + val securityMgr = new SecurityManager() // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -50,7 +53,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId) before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("test", "localhost", 0, conf = conf, + securityManager = securityMgr) this.actorSystem = actorSystem conf.set("spark.driver.port", boundPort.toString) @@ -126,7 +130,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -156,8 +160,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, + securityMgr) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -172,7 +177,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -220,7 +225,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -254,7 +259,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -270,7 +275,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) 
@@ -289,7 +294,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -326,7 +331,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -345,7 +350,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -364,7 +369,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -383,7 +388,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -406,7 +411,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -419,7 +424,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -434,7 +439,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -449,7 +454,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager("", 
actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -464,7 +469,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -479,7 +484,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -504,7 +509,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -528,7 +533,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -574,7 +579,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager("", actorSystem, master, serializer, 500, conf) + store = new BlockManager("", actorSystem, master, serializer, 500, conf, securityMgr) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -585,7 +590,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { conf.set("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -593,7 +598,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) 
assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") @@ -601,7 +606,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") @@ -609,28 +614,28 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -644,7 +649,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. - store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf) + store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, + securityMgr) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/docs/configuration.md b/docs/configuration.md index 00864906b3c7b..ef14fdd1c8b8b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -460,6 +460,41 @@ Apart from these, the following properties are also available, and may be useful Note: this setting needs to be configured in the standalone cluster master, not in individual applications; you can set it through SPARK_JAVA_OPTS in spark-env.sh. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark web ui. The filter should be a + standard javax servlet Filter. 
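For illustration only (this sketch is not part of the patch; the class name, header name, and parameter name are made up), a filter plugged in via spark.ui.filters could authenticate the request and expose the resulting user roughly as follows, assuming the deployment identifies users through a trusted header and that later acl checks read the user from getRemoteUser():

    // Hypothetical example only -- class name, header name and parameter name are made up.
    import javax.servlet.{Filter, FilterChain, FilterConfig, ServletRequest, ServletResponse}
    import javax.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse}

    class ExampleAuthFilter extends Filter {
      private var trustedHeader: String = _

      override def init(conf: FilterConfig) {
        // filter parameters (e.g. param1=value1) arrive as servlet init parameters
        trustedHeader = Option(conf.getInitParameter("param1")).getOrElse("X-Authenticated-User")
      }

      override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain) {
        val httpReq = req.asInstanceOf[HttpServletRequest]
        val user = httpReq.getHeader(trustedHeader)   // however the deployment identifies users
        if (user == null) {
          res.asInstanceOf[HttpServletResponse].sendError(HttpServletResponse.SC_UNAUTHORIZED)
        } else {
          // expose the authenticated user through getRemoteUser() for any later acl checks
          val wrapped = new HttpServletRequestWrapper(httpReq) {
            override def getRemoteUser(): String = user
          }
          chain.doFilter(wrapped, res)
        }
      }

      override def destroy() {}
    }
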
Parameters to each filter can also be specified by setting a + java system property of .params='param1=value1,param2=value2' + (e.g.-Dspark.ui.filters=com.test.filter1 -Dcom.test.filter1.params='param1=foo,param2=testing') + + + + spark.authenticate.ui + false + + Whether spark web ui authentication should be on. If enabled this checks the user access + permissions to view the web ui. See spark.ui.view.acls for more details. + Also note this requires the user to be known, if the user comes across as null no checks + are done. Filters can be used to authenticate and set the user. + + + + spark.ui.view.acls + Empty + + Comma separated list of users that have view access to the spark web ui. By default only the + user that started the Spark job has view access. + + + + spark.authenticate + false + + Whether spark authenticates its internal connections. See SPARK_SECRET if not + running on Yarn. + @@ -491,6 +526,8 @@ The following variables can be set in `spark-env.sh`: * `SPARK_JAVA_OPTS`, to add JVM options. This includes Java options like garbage collector settings and any system properties that you'd like to pass with `-D`. One use case is to set some Spark properties differently on this machine, e.g., `-Dspark.local.dir=/disk1,/disk2`. +* `SPARK_SECRET`, Set the secret key used for Spark to authenticate between components. This needs to be set if + not running on Yarn and authentication is enabled. * Options for the Spark [standalone cluster scripts](spark-standalone.html#cluster-launch-scripts), such as number of cores to use on each machine and maximum memory. diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala index 57e1b1f806e82..93c34e43775ce 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -23,7 +23,7 @@ import scala.util.Random import akka.actor.{Actor, ActorRef, Props, actorRef2Scala} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SecurityManager} import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions import org.apache.spark.streaming.receivers.Receiver @@ -113,7 +113,8 @@ object FeederActor { val Seq(host, port) = args.toSeq - val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = new SparkConf)._1 + val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = new SparkConf, + securityManager = new SecurityManager)._1 val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor") println("Feeder started as:" + feeder) diff --git a/pom.xml b/pom.xml index 54072b053cb5e..9385b5dd7042a 100644 --- a/pom.xml +++ b/pom.xml @@ -155,6 +155,21 @@ + + org.eclipse.jetty + jetty-util + 7.6.8.v20121106 + + + org.eclipse.jetty + jetty-security + 7.6.8.v20121106 + + + org.eclipse.jetty + jetty-plus + 7.6.8.v20121106 + org.eclipse.jetty jetty-server @@ -284,6 +299,11 @@ mesos ${mesos.version} + + commons-net + commons-net + 2.2 + io.netty netty-all diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 151b1e7c799c9..936582e722d4b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -208,6 +208,9 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "io.netty" % "netty-all" % "4.0.13.Final", "org.eclipse.jetty" % "jetty-server" % 
"7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-util" % "7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-plus" % "7.6.8.v20121106", + "org.eclipse.jetty" % "jetty-security" % "7.6.8.v20121106", /** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. */ "org.eclipse.jetty.orbit" % "javax.servlet" % "2.5.0.v201103041518" artifacts Artifact("javax.servlet", "jar", "jar"), "org.scalatest" %% "scalatest" % "1.9.1" % "test", @@ -264,6 +267,7 @@ object SparkBuild extends Build { "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", "org.apache.mesos" % "mesos" % "0.13.0", + "commons-net" % "commons-net" % "2.2", "net.java.dev.jets3t" % "jets3t" % "0.7.1", "org.apache.derby" % "derby" % "10.4.2.0" % "test", "org.apache.hadoop" % "hadoop-client" % hadoopVersion excludeAll(excludeJackson, excludeNetty, excludeAsm, excludeCglib), diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 3e171849e3494..158779b621746 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -19,11 +19,15 @@ package org.apache.spark.repl import java.io.{ByteArrayOutputStream, InputStream} import java.net.{URI, URL, URLClassLoader, URLEncoder} +import java.net.Authenticator +import java.net.PasswordAuthentication import java.util.concurrent.{Executors, ExecutorService} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.SparkEnv + import org.objectweb.asm._ import org.objectweb.asm.Opcodes._ @@ -52,7 +56,35 @@ extends ClassLoader(parent) { if (fileSystem != null) fileSystem.open(new Path(directory, pathInDirectory)) else + if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { + val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) + val userCred = SparkEnv.get.securityManager.getSecretKey() + if (userCred == null) { + throw new Exception("secret key is null with authentication on") + } + val userInfo = SparkEnv.get.securityManager.getHttpUser() + ":" + userCred + val newuri = new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), + uri.getPath(), uri.getQuery(), uri.getFragment()) + + // set our own authenticator to properly negotiate user/password + Authenticator.setDefault( + new Authenticator() { + override def getPasswordAuthentication(): PasswordAuthentication = { + var passAuth: PasswordAuthentication = null + val userInfo = getRequestingURL().getUserInfo() + if (userInfo != null) { + val parts = userInfo.split(":", 2) + passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) + } + return passAuth + } + } + ); + + newuri.toURL().openStream() + } else { new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + } } val bytes = readAndTransformClass(name, inputStream) inputStream.close() diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 87d94d51be199..0ab47f3da600b 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -874,6 +874,8 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, }) def process(settings: Settings): Boolean = savingContextLoader { + if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") + this.settings = settings 
createInterpreter() @@ -932,16 +934,9 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, def createSparkContext(): SparkContext = { val execUri = System.getenv("SPARK_EXECUTOR_URI") - val master = this.master match { - case Some(m) => m - case None => { - val prop = System.getenv("MASTER") - if (prop != null) prop else "local" - } - } val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath) val conf = new SparkConf() - .setMaster(master) + .setMaster(getMaster()) .setAppName("Spark shell") .setJars(jars) .set("spark.repl.class.uri", intp.classServer.uri) @@ -956,6 +951,17 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, sparkContext } + private def getMaster(): String = { + val master = this.master match { + case Some(m) => m + case None => { + val prop = System.getenv("MASTER") + if (prop != null) prop else "local" + } + } + master + } + /** process command-line arguments and do as they request */ def process(args: Array[String]): Boolean = { val command = new SparkCommandLine(args.toList, msg => echo(msg)) diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 59fdb0b37a766..927bec19db118 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -34,7 +34,7 @@ import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable import util.stackTraceString -import org.apache.spark.{HttpServer, SparkConf, Logging} +import org.apache.spark.{HttpServer, SparkConf, Logging, SecurityManager} import org.apache.spark.util.Utils // /** directory to save .class files to */ @@ -97,7 +97,7 @@ import org.apache.spark.util.Utils } val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles - val classServer = new HttpServer(outputDir) /** Jetty server that will serve our classes to worker nodes */ + val classServer = new HttpServer(outputDir, new SecurityManager()) /** Jetty server that will serve our classes to worker nodes */ private var currentSettings: Settings = initialSettings var printResults = true // whether to print result lines var totalSilence = false // whether to print anything diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 8203b8f6122e1..799f717eb30da 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -29,6 +29,8 @@ import org.apache.spark.SparkContext class ReplSuite extends FunSuite { + System.setProperty("spark.authenticate", "false") + def runInterpreter(master: String, input: String): String = { val in = new BufferedReader(new StringReader(input + "\n")) val out = new StringWriter() diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 2e46d750c4a38..77c1bda495861 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -27,7 +27,6 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.security.UserGroupInformation import 
org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ @@ -36,7 +35,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{SparkConf, SparkContext, Logging, SecurityManager} import org.apache.spark.util.Utils class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, @@ -81,25 +80,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts resourceManager = registerWithResourceManager() - // Workaround until hadoop moves to something which has - // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) - // ignore result. - // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times - // Hence args.workerCores = numCore disabled above. Any better option? - - // Compute number of threads for akka - //val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() - //if (minimumMemory > 0) { - // val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD - // val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) - - // if (numCore > 0) { - // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 - // TODO: Uncomment when hadoop is on a version which has this fixed. - // args.workerCores = numCore - // } - //} - // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + // setup AmIpFilter for the SparkUI - do this before we start the UI + addAmIpFilter() ApplicationMaster.register(this) // Start the user's JAR @@ -121,6 +103,19 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, System.exit(0) } + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + System.setProperty("spark.ui.filters", amFilter) + val proxy = YarnConfiguration.getProxyHostAndPort(conf) + val parts : Array[String] = proxy.split(":") + val uriBase = "http://" + proxy + + System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + + val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase + System.setProperty("org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) + } + /** Get the Yarn approved local directories. 
*/ private def getLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala index 9fe4d64a0fca0..8dc571b3b3126 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ import akka.actor.Terminated -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{SparkConf, SparkContext, Logging, SecurityManager} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.SplitInfo @@ -50,8 +50,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar private var yarnAllocator: YarnAllocationHandler = _ private var driverClosed:Boolean = false + val securityManager = new SecurityManager() val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf)._1 + conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ // This actor just working as a monitor to watch on Driver Actor. @@ -110,6 +111,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar // we want to be reasonably responsive without causing too many requests to RM. val schedulerInterval = System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + // must be <= timeoutInterval / 2. val interval = math.min(timeoutInterval / 2, schedulerInterval) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 1419f215c78e5..2bd44d930b552 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -134,7 +134,7 @@ class ClientArguments(val args: Array[String], val sparkConf: SparkConf) { " --args ARGS Arguments to be passed to your application's main class.\n" + " Mutliple invocations are possible, each will be passed in order.\n" + " --num-workers NUM Number of workers to start (Default: 2)\n" + - " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1).\n" + " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" + " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + " --worker-memory MEM Memory per Worker (e.g. 
1000M, 2G) (Default: 1G)\n" + diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 2ba2366ead171..89f12b602c32d 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -19,6 +19,7 @@ package org.apache.spark.deploy.yarn import org.apache.spark.deploy.SparkHadoopUtil import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.conf.Configuration @@ -40,4 +41,12 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { val jobCreds = conf.getCredentials() jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) } + + override def getCurrentUserCredentials(): Credentials = { + UserGroupInformation.getCurrentUser().getCredentials() + } + + override def addCurrentUserCredentials(creds: Credentials) { + UserGroupInformation.getCurrentUser().addCredentials(creds) + } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4b777d5fa7a28..0fcad80ae1ce2 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -27,7 +27,6 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.net.NetUtils -import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.protocolrecords._ @@ -37,8 +36,9 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import org.apache.hadoop.yarn.webapp.util.WebAppUtils; -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{SparkConf, SparkContext, Logging, SecurityManager} import org.apache.spark.util.Utils @@ -85,9 +85,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, amClient.init(yarnConf) amClient.start() - // Workaround until hadoop moves to something which has - // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) - // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) + // setup AmIpFilter for the SparkUI - do this before we start the UI + addAmIpFilter() ApplicationMaster.register(this) @@ -110,6 +109,19 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, System.exit(0) } + // add the yarn amIpFilter that Yarn requires for properly securing the UI + private def addAmIpFilter() { + val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + System.setProperty("spark.ui.filters", amFilter) + val proxy = WebAppUtils.getProxyHostAndPort(conf) + val parts : Array[String] = proxy.split(":") + val uriBase = "http://" + proxy + + System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV) + + val params = "PROXY_HOST=" + parts(0) + "," 
+ "PROXY_URI_BASE=" + uriBase + System.setProperty("org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params) + } + /** Get the Yarn approved local directories. */ private def getLocalDirs(): String = { // Hadoop 0.23 and 2.x have different Environment variable names for the @@ -249,7 +261,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, val schedulerInterval = sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000) - // must be <= timeoutInterval / 2. val interval = math.min(timeoutInterval / 2, schedulerInterval) diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala index 78353224fa4b8..516e435e79d08 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import akka.actor._ import akka.remote._ import akka.actor.Terminated -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{SparkConf, SparkContext, Logging, SecurityManager} import org.apache.spark.util.{Utils, AkkaUtils} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.SplitInfo @@ -52,8 +52,9 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar private var amClient: AMRMClient[ContainerRequest] = _ + val securityManager = new SecurityManager() val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, - conf = sparkConf)._1 + conf = sparkConf, securityManager = securityManager)._1 var actor: ActorRef = _ // This actor just working as a monitor to watch on Driver Actor. @@ -105,6 +106,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar val interval = math.min(timeoutInterval / 2, schedulerInterval) reporterThread = launchReporterThread(interval) + // Wait for the reporter thread to Finish. reporterThread.join() From 5721c5ac83b62afb8e8201730e4fc6bc76556e5b Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Mon, 20 Jan 2014 09:17:56 -0600 Subject: [PATCH 02/14] update AkkaUtilsSuite test for the actorSelection changes, fix typos based on comments, and remove extra lines I missed in rebase from AkkaUtils --- .../org/apache/spark/SecurityManager.scala | 4 +- .../org/apache/spark/util/AkkaUtils.scala | 2 - .../org/apache/spark/AkkaUtilsSuite.scala | 45 +++++-------------- 3 files changed, 12 insertions(+), 39 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 7aaceb1bd3767..37ac0138a3dbb 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -36,11 +36,11 @@ private[spark] class SecurityManager extends Logging { logDebug("is auth enabled = " + isAuthOn + " is uiAuth enabled = " + isUIAuthOn) /** - * In Yarn mode its uses Hadoop UGI to pass the secret as that + * In Yarn mode it uses Hadoop UGI to pass the secret as that * will keep it protected. For a standalone SPARK cluster * use a environment variable SPARK_SECRET to specify the secret. * This probably isn't ideal but only the user who starts the process - * should have access to view the variable (atleast on Linux). 
+ * should have access to view the variable (at least on Linux). * Since we can't set the environment variable we set the * java system property SPARK_SECRET so it will automatically * generate a secret is not specified. This definitely is not diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 69fb9e9834408..043241d2fed90 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -100,8 +100,6 @@ private[spark] object AkkaUtils extends Logging { |akka.remote.log-remote-lifecycle-events = $lifecycleEvents |akka.log-dead-letters = $lifecycleEvents |akka.log-dead-letters-during-shutdown = $lifecycleEvents - |akka.remote.netty.require-cookie = "$requireCookie" - |akka.remote.netty.secure-cookie = "$secureCookie" """.stripMargin)) val actorSystem = if (indestructible) { diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala index a5bd9bd938327..1d33dd5db6a32 100644 --- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -51,29 +51,17 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { System.setProperty("SPARK_SECRET", "bad") val securityManagerBad= new SecurityManager(); + assert(securityManagerBad.isAuthenticationEnabled() === true) + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should fail since password wrong - intercept[SparkException] { slaveTracker.getServerStatuses(10, 0) } + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } actorSystem.shutdown() slaveSystem.shutdown() @@ -198,30 +186,17 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { System.setProperty("SPARK_SECRET", "bad") val securityManagerBad = new SecurityManager(); + assert(securityManagerBad.isAuthenticationEnabled() === false) + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) - - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - 
slaveTracker.updateEpoch(masterTracker.getEpoch) - - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000, 0), Array(compressedSize1000))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should fail since security on in server and off in client - intercept[SparkException] { slaveTracker.getServerStatuses(10, 0) } + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + } actorSystem.shutdown() slaveSystem.shutdown() From 6f7ddf38d3b3f3c367df4d0b9a6be3a0bc644e1d Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Sat, 22 Feb 2014 10:37:03 -0600 Subject: [PATCH 03/14] Convert SaslClient and SaslServer to scala, change spark.authenticate.ui to spark.ui.acls.enable, and fix up various other things from review comments --- .../org/apache/spark/SparkSaslClient.java | 182 ------------- .../org/apache/spark/SparkSaslServer.java | 189 -------------- .../scala/org/apache/spark/HttpServer.scala | 75 +++--- .../org/apache/spark/SecurityManager.scala | 246 ++++++++++++++---- .../org/apache/spark/SparkSaslClient.scala | 139 ++++++++++ .../org/apache/spark/SparkSaslServer.scala | 168 ++++++++++++ .../spark/broadcast/HttpBroadcast.scala | 30 +-- .../apache/spark/deploy/SparkHadoopUtil.scala | 5 + .../org/apache/spark/network/Connection.scala | 5 +- .../apache/spark/network/ConnectionId.scala | 16 +- .../spark/network/ConnectionManager.scala | 54 ++-- .../scala/org/apache/spark/util/Utils.scala | 42 ++- docs/configuration.md | 64 +++-- docs/index.md | 1 + .../spark/repl/ExecutorClassLoader.scala | 36 +-- .../deploy/yarn/YarnSparkHadoopUtil.scala | 15 +- 16 files changed, 666 insertions(+), 601 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/SparkSaslClient.java delete mode 100644 core/src/main/java/org/apache/spark/SparkSaslServer.java create mode 100644 core/src/main/scala/org/apache/spark/SparkSaslClient.scala create mode 100644 core/src/main/scala/org/apache/spark/SparkSaslServer.scala diff --git a/core/src/main/java/org/apache/spark/SparkSaslClient.java b/core/src/main/java/org/apache/spark/SparkSaslClient.java deleted file mode 100644 index 5fab593270992..0000000000000 --- a/core/src/main/java/org/apache/spark/SparkSaslClient.java +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; - -import javax.security.auth.callback.Callback; -import javax.security.auth.callback.CallbackHandler; -import javax.security.auth.callback.NameCallback; -import javax.security.auth.callback.PasswordCallback; -import javax.security.auth.callback.UnsupportedCallbackException; -import javax.security.sasl.RealmCallback; -import javax.security.sasl.RealmChoiceCallback; -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslException; -import javax.security.sasl.SaslClient; - -/** - * Implements SASL Client logic for Spark - * Some of the code borrowed from Giraph and Hadoop - */ -public class SparkSaslClient { - /** Class logger */ - private static Logger LOG = LoggerFactory.getLogger(SparkSaslClient.class); - - /** - * Used to respond to server's counterpart, SaslServer with SASL tokens - * represented as byte arrays. - */ - private SaslClient saslClient; - - /** - * Create a SaslClient for authentication with BSP servers. - */ - public SparkSaslClient(SecurityManager securityMgr) { - try { - saslClient = Sasl.createSaslClient(new String[] { SparkSaslServer.DIGEST }, - null, null, SparkSaslServer.SASL_DEFAULT_REALM, - SparkSaslServer.SASL_PROPS, new SparkSaslClientCallbackHandler(securityMgr)); - } catch (IOException e) { - LOG.error("SaslClient: Could not create SaslClient"); - saslClient = null; - } - } - - /** - * Used to initiate SASL handshake with server. - * @return response to challenge if needed - * @throws IOException - */ - public byte[] firstToken() throws SaslException { - byte[] saslToken = new byte[0]; - if (saslClient.hasInitialResponse()) { - LOG.debug("has initial response"); - saslToken = saslClient.evaluateChallenge(saslToken); - } - return saslToken; - } - - /** - * Determines whether the authentication exchange has completed. - */ - public boolean isComplete() { - return saslClient.isComplete(); - } - - /** - * Respond to server's SASL token. - * @param saslTokenMessage contains server's SASL token - * @return client's response SASL token - */ - public byte[] saslResponse(byte[] saslTokenMessage) throws SaslException { - try { - byte[] retval = saslClient.evaluateChallenge(saslTokenMessage); - return retval; - } catch (SaslException e) { - LOG.error("saslResponse: Failed to respond to SASL server's token:", e); - throw e; - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - public void dispose() throws SaslException { - if (saslClient != null) { - try { - saslClient.dispose(); - saslClient = null; - } catch (SaslException ignored) { - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private static class SparkSaslClientCallbackHandler implements CallbackHandler { - private final String userName; - private final char[] userPassword; - - /** - * Constructor - */ - public SparkSaslClientCallbackHandler(SecurityManager securityMgr) { - this.userName = SparkSaslServer. - encodeIdentifier(securityMgr.getSaslUser().getBytes()); - String secretKey = securityMgr.getSecretKey() ; - String passwd = (secretKey != null) ? secretKey : ""; - this.userPassword = SparkSaslServer.encodePassword(passwd.getBytes()); - } - - /** - * Implementation used to respond to SASL tokens from server. 
- * - * @param callbacks objects that indicate what credential information the - * server's SaslServer requires from the client. - * @throws UnsupportedCallbackException - */ - public void handle(Callback[] callbacks) - throws UnsupportedCallbackException { - NameCallback nc = null; - PasswordCallback pc = null; - RealmCallback rc = null; - for (Callback callback : callbacks) { - if (callback instanceof RealmChoiceCallback) { - continue; - } else if (callback instanceof NameCallback) { - nc = (NameCallback) callback; - } else if (callback instanceof PasswordCallback) { - pc = (PasswordCallback) callback; - } else if (callback instanceof RealmCallback) { - rc = (RealmCallback) callback; - } else { - throw new UnsupportedCallbackException(callback, - "handle: Unrecognized SASL client callback"); - } - } - if (nc != null) { - if (LOG.isDebugEnabled()) { - LOG.debug("handle: SASL client callback: setting username: " + - userName); - } - nc.setName(userName); - } - if (pc != null) { - if (LOG.isDebugEnabled()) { - LOG.debug("handle: SASL client callback: setting userPassword"); - } - pc.setPassword(userPassword); - } - if (rc != null) { - if (LOG.isDebugEnabled()) { - LOG.debug("handle: SASL client callback: setting realm: " + - rc.getDefaultText()); - } - rc.setText(rc.getDefaultText()); - } - } - } -} diff --git a/core/src/main/java/org/apache/spark/SparkSaslServer.java b/core/src/main/java/org/apache/spark/SparkSaslServer.java deleted file mode 100644 index 5c2fcf9afca15..0000000000000 --- a/core/src/main/java/org/apache/spark/SparkSaslServer.java +++ /dev/null @@ -1,189 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark; - -import org.apache.commons.net.util.Base64; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Map; -import java.util.TreeMap; - -import javax.security.auth.callback.Callback; -import javax.security.auth.callback.CallbackHandler; -import javax.security.auth.callback.NameCallback; -import javax.security.auth.callback.PasswordCallback; -import javax.security.auth.callback.UnsupportedCallbackException; -import javax.security.sasl.AuthorizeCallback; -import javax.security.sasl.RealmCallback; -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslException; -import javax.security.sasl.SaslServer; -import java.io.IOException; - -/** - * Encapsulates SASL server logic for Server - */ -public class SparkSaslServer { - /** Logger */ - private static Logger LOG = LoggerFactory.getLogger("SparkSaslServer.class"); - - /** - * Actual SASL work done by this object from javax.security.sasl. - * Initialized below in constructor. 
- */ - private SaslServer saslServer; - - public static final String SASL_DEFAULT_REALM = "default"; - public static final String DIGEST = "DIGEST-MD5"; - public static final Map SASL_PROPS = - new TreeMap(); - - /** - * Constructor - */ - public SparkSaslServer(SecurityManager securityMgr) { - try { - SASL_PROPS.put(Sasl.QOP, "auth"); - SASL_PROPS.put(Sasl.SERVER_AUTH, "true"); - saslServer = Sasl.createSaslServer(DIGEST, null, SASL_DEFAULT_REALM, SASL_PROPS, - new SaslDigestCallbackHandler(securityMgr)); - } catch (SaslException e) { - LOG.error("SparkSaslServer: Could not create SaslServer: " + e); - saslServer = null; - } - } - - /** - * Determines whether the authentication exchange has completed. - */ - public boolean isComplete() { - return saslServer.isComplete(); - } - - /** - * Used to respond to server SASL tokens. - * - * @param token Server's SASL token - * @return response to send back to the server. - */ - public byte[] response(byte[] token) throws SaslException { - try { - byte[] retval = saslServer.evaluateResponse(token); - if (LOG.isDebugEnabled()) { - LOG.debug("response: Response token length: " + retval.length); - } - return retval; - } catch (SaslException e) { - LOG.error("Response: Failed to evaluate client token of length: " + - token.length + " : " + e); - throw e; - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. - */ - public void dispose() throws SaslException { - if (saslServer != null) { - try { - saslServer.dispose(); - saslServer = null; - } catch (SaslException ignored) { - } - } - } - - /** - * Encode a byte[] identifier as a Base64-encoded string. - * - * @param identifier identifier to encode - * @return Base64-encoded string - */ - static String encodeIdentifier(byte[] identifier) { - return new String(Base64.encodeBase64(identifier)); - } - - /** - * Encode a password as a base64-encoded char[] array. - * @param password as a byte array. - * @return password as a char array. 
- */ - static char[] encodePassword(byte[] password) { - return new String(Base64.encodeBase64(password)).toCharArray(); - } - - /** CallbackHandler for SASL DIGEST-MD5 mechanism */ - public static class SaslDigestCallbackHandler implements CallbackHandler { - - private SecurityManager securityManager; - - /** - * Constructor - */ - public SaslDigestCallbackHandler(SecurityManager securityMgr) { - this.securityManager = securityMgr; - } - - @Override - public void handle(Callback[] callbacks) throws IOException, - UnsupportedCallbackException { - NameCallback nc = null; - PasswordCallback pc = null; - AuthorizeCallback ac = null; - LOG.debug("in the sasl server callback handler"); - for (Callback callback : callbacks) { - if (callback instanceof AuthorizeCallback) { - ac = (AuthorizeCallback) callback; - } else if (callback instanceof NameCallback) { - nc = (NameCallback) callback; - } else if (callback instanceof PasswordCallback) { - pc = (PasswordCallback) callback; - } else if (callback instanceof RealmCallback) { - continue; // realm is ignored - } else { - throw new UnsupportedCallbackException(callback, - "handle: Unrecognized SASL DIGEST-MD5 Callback"); - } - } - if (pc != null) { - char[] password = - encodePassword(securityManager.getSecretKey().getBytes()); - pc.setPassword(password); - } - - if (ac != null) { - String authid = ac.getAuthenticationID(); - String authzid = ac.getAuthorizationID(); - if (authid.equals(authzid)) { - LOG.debug("set auth to true"); - ac.setAuthorized(true); - } else { - LOG.debug("set auth to false"); - ac.setAuthorized(false); - } - if (ac.isAuthorized()) { - LOG.debug("sasl server is authorized"); - ac.setAuthorizedID(authzid); - } - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 2bc91ef5318eb..d1dc51d0ef8ed 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -63,46 +63,17 @@ private[spark] class HttpServer(resourceBase: File, securityManager: SecurityMan val resHandler = new ResourceHandler resHandler.setResourceBase(resourceBase.getAbsolutePath) - if (securityManager.isAuthenticationEnabled()) { - logDebug("server is using security") - val constraint = new Constraint() - constraint.setName(Constraint.__DIGEST_AUTH) - constraint.setRoles(Array("user")) - constraint.setAuthenticate(true) - constraint.setDataConstraint(Constraint.DC_NONE) - - val cm = new ConstraintMapping() - cm.setConstraint(constraint) - cm.setPathSpec("/*") - - val sh = new ConstraintSecurityHandler() - - // the hashLoginService lets us do a simply user and - // secret right now. This could be changed to use the - // JAASLoginService for other options. 
- val hashLogin = new HashLoginService() - - val userCred = new Password(securityManager.getSecretKey()) - if (userCred == null) { - throw new Exception("secret key is null with authentication on") - } - hashLogin.putUser(securityManager.getHttpUser(), userCred, Array("user")) - - logDebug("hashlogin loading user: " + hashLogin.getUsers()) - - sh.setLoginService(hashLogin) - sh.setAuthenticator(new DigestAuthenticator()); - sh.setConstraintMappings(Array(cm)) + val handlerList = new HandlerList + handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + if (securityManager.isAuthenticationEnabled()) { + logDebug("HttpServer is using security") + val sh = setupSecurityHandler(securityManager) // make sure we go through security handler to get resources - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) sh.setHandler(handlerList) server.setHandler(sh) } else { - logDebug("server is not using security") - val handlerList = new HandlerList - handlerList.setHandlers(Array(resHandler, new DefaultHandler)) + logDebug("HttpServer is not using security") server.setHandler(handlerList) } @@ -111,6 +82,40 @@ private[spark] class HttpServer(resourceBase: File, securityManager: SecurityMan } } + /** + * Setup Jetty to the HashLoginService using a single user with our + * shared secret. Configure it to use DIGEST-MD5 authentication so that the password + * isn't passed in plaintext. + */ + private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = { + val constraint = new Constraint() + // use DIGEST-MD5 as the authentication mechanism + constraint.setName(Constraint.__DIGEST_AUTH) + constraint.setRoles(Array("user")) + constraint.setAuthenticate(true) + constraint.setDataConstraint(Constraint.DC_NONE) + + val cm = new ConstraintMapping() + cm.setConstraint(constraint) + cm.setPathSpec("/*") + val sh = new ConstraintSecurityHandler() + + // the hashLoginService lets us do a single user and + // secret right now. This could be changed to use the + // JAASLoginService for other options. + val hashLogin = new HashLoginService() + + val userCred = new Password(securityMgr.getSecretKey()) + if (userCred == null) { + throw new Exception("Error: secret key is null with authentication on") + } + hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user")) + sh.setLoginService(hashLogin) + sh.setAuthenticator(new DigestAuthenticator()); + sh.setConstraintMappings(Array(cm)) + sh + } + def stop() { if (server == null) { throw new ServerStateException("Server is already stopped") diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 37ac0138a3dbb..f21dff51caaac 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -18,51 +18,185 @@ package org.apache.spark +import java.net.{Authenticator, PasswordAuthentication} import org.apache.hadoop.io.Text import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation - import org.apache.spark.deploy.SparkHadoopUtil +import scala.collection.mutable.ArrayBuffer + /** - * Spark class responsible for security. + * Spark class responsible for security. + * + * In general this class should be instantiated by the SparkEnv and most components + * should access it from that. 
There are some cases where the SparkEnv hasn't been + * initialized yet and this class must be instantiated directly. + * + * Spark currently supports authentication via a shared secret. + * Authentication can be configured to be on via the 'spark.authenticate' configuration + * parameter. This parameter controls whether the Spark communication protocols do + * authentication using the shared secret. This authentication is a basic handshake to + * make sure both sides have the same shared secret and are allowed to communicate. + * If the shared secret is not identical they will not be allowed to communicate. + * + * The Spark UI can also be secured by using javax servlet filters. A user may want to + * secure the UI if it has data that other users should not be allowed to see. The javax + * servlet filter specified by the user can authenticate the user and then once the user + * is logged in, Spark can compare that user versus the view acls to make sure they are + * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * control the behavior of the acls. Note that the person who started the application + * always has view access to the UI. + * + * Spark does not currently support encryption after authentication. + * + * At this point spark has multiple communication protocols that need to be secured and + * different underlying mechisms are used depending on the protocol: + * + * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. + * Akka remoting allows you to specify a secure cookie that will be exchanged + * and ensured to be identical in the connection handshake between the client + * and the server. If they are not identical then the client will be refused + * to connect to the server. There is no control of the underlying + * authentication mechanism so its not clear if the password is passed in + * plaintext or uses DIGEST-MD5 or some other mechanism. + * Akka also has an option to turn on SSL, this option is not currently supported + * but we could add a configuration option in the future. + * + * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty + * for the HttpServer. Jetty supports multiple authentication mechanisms - + * Basic, Digest, Form, Spengo, etc. It also supports multiple different login + * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService + * to authenticate using DIGEST-MD5 via a single user and the shared secret. + * Since we are using DIGEST-MD5, the shared secret is not passed on the wire + * in plaintext. + * We currently do not support SSL (https), but Jetty can be configured to use it + * so we could add a configuration option for this in the future. + * + * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. + * Any clients must specify the user and password. There is a default + * Authenticator installed in the SecurityManager to how it does the authentication + * and in this case gets the user name and password from the request. + * + * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * exchange messages. For this we use the Java SASL + * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 + * as the authentication mechanism. This means the shared secret is not passed + * over the wire in plaintext. + * Note that SASL is pluggable as to what mechanism it uses. 
We currently use + * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. + * Spark currently supports "auth" for the quality of protection, which means + * the connection is not supporting integrity or privacy protection (encryption) + * after authentication. SASL also supports "auth-int" and "auth-conf" which + * SPARK could be support in the future to allow the user to specify the quality + * of protection they want. If we support those, the messages will also have to + * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. + * + * Since the connectionManager does asynchronous messages passing, the SASL + * authentication is a bit more complex. A ConnectionManager can be both a client + * and a Server, so for a particular connection is has to determine what to do. + * A ConnectionId was added to be able to track connections and is used to + * match up incoming messages with connections waiting for authentication. + * If its acting as a client and trying to send a message to another ConnectionManager, + * it blocks the thread calling sendMessage until the SASL negotiation has occurred. + * The ConnectionManager tracks all the sendingConnections using the ConnectionId + * and waits for the response from the server and does the handshake. + * + * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters + * can be used. Yarn requires a specific AmIpFilter be installed for security to work + * properly. For non-Yarn deployments, users can write a filter to go through a companies + * normal login service. If an authentication filter is in place then the SparkUI + * can be configured to check the logged in user against the list of users who have + * view acls to see if that user is authorized. + * The filters can also be used for many different purposes. For instance filters + * could be used for logging, encypryption, or compression. + * + * The exact mechanisms used to generate/distributed the shared secret is deployment specific. + * + * For Yarn deployments, the secret is automatically generated using the Akka remote + * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed + * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels + * of protection. See the Hadoop documentation for more details. Each Spark application on Yarn + * gets a different shared secret. On Yarn, the Spark UI gets configured to use the Hadoop Yarn + * AmIpFilter which requires the user to go through the ResourceManager Proxy. That Proxy is there + * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use + * filters to do authentication. That authentication then happens via the ResourceManager Proxy + * and Spark will use that to do authorization against the view acls. + * + * For other Spark deployments, the shared secret should be specified via the SPARK_SECRET + * environment variable. This isn't ideal but it means only the user who starts the process + * has access to view that variable. Note that Spark does try to generate a secret for + * you if the SPARK_SECRET environment variable is not set, but it gets put into the java + * system property which can be viewed by other users, so setting the SPARK_SECRET environment + * variable is recommended. + * All the nodes (Master and Workers) need to have the same shared secret + * and all the applications running need to have that same shared secret. 
This again + * is not ideal as one user could potentially affect another users application. + * This should be enhanced in the future to provide better protection. + * If the UI needs to be secured the user needs to install a javax servlet filter to do the + * authentication. Spark will then use that user to compare against the view acls to do + * authorization. If not filter is in place the user is generally null and no authorization + * can take place. */ private[spark] class SecurityManager extends Logging { - private val isAuthOn = System.getProperty("spark.authenticate", "false").toBoolean - private val isUIAuthOn = System.getProperty("spark.authenticate.ui", "false").toBoolean - private val viewAcls = System.getProperty("spark.ui.view.acls", "").split(',').map(_.trim()).toSet + // key used to store the spark secret in the Hadoop UGI + private val sparkSecretLookupKey = "sparkCookie" + + private val authOn = System.getProperty("spark.authenticate", "false").toBoolean + private val uiAclsOn = System.getProperty("spark.ui.acls.enable", "false").toBoolean + + // always add the current user and SPARK_USER to the viewAcls + private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""), + Option(System.getenv("SPARK_USER")).getOrElse("")) + aclUsers ++= System.getProperty("spark.ui.view.acls", "").split(',') + private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet + private val secretKey = generateSecretKey() - logDebug("is auth enabled = " + isAuthOn + " is uiAuth enabled = " + isUIAuthOn) - + logDebug("is auth enabled = " + authOn + " is uiAcls enabled = " + uiAclsOn) + + // Set our own authenticator to properly negotiate user/password for HTTP connections. + // This is needed by the HTTP client fetching from the HttpServer. Put here so its + // only set once. + if (authOn) { + Authenticator.setDefault( + new Authenticator() { + override def getPasswordAuthentication(): PasswordAuthentication = { + var passAuth: PasswordAuthentication = null + val userInfo = getRequestingURL().getUserInfo() + if (userInfo != null) { + val parts = userInfo.split(":", 2) + passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) + } + return passAuth + } + } + ); + } + /** - * In Yarn mode it uses Hadoop UGI to pass the secret as that - * will keep it protected. For a standalone SPARK cluster - * use a environment variable SPARK_SECRET to specify the secret. - * This probably isn't ideal but only the user who starts the process - * should have access to view the variable (at least on Linux). - * Since we can't set the environment variable we set the - * java system property SPARK_SECRET so it will automatically - * generate a secret is not specified. This definitely is not - * ideal since users can see it. We should switch to put it in - * a config. + * Generates or looks up the secret key. + * + * The way the key is stored depends on the Spark deployment mode. Yarn + * uses the Hadoop UGI. + * + * For non-Yarn deployments, If the environment variable is not set already + * we generate a secret and since we can't set an environment variable dynamically + * we set the java system property SPARK_SECRET. This will allow it to automatically + * work in certain situations. Others this still will not work and this definitely is + * not ideal since other users can see it. We should switch to put it in + * a config once Spark supports configs. 
*/ private def generateSecretKey(): String = { - if (!isAuthenticationEnabled) return null - // first check to see if secret already set, else generate it + // first check to see if the secret is already set, else generate a new one if (SparkHadoopUtil.get.isYarnMode) { - val credentials = SparkHadoopUtil.get.getCurrentUserCredentials() - if (credentials != null) { - val secretKey = credentials.getSecretKey(new Text("akkaCookie")) - if (secretKey != null) { - logDebug("in yarn mode, getting secret from credentials") - return new Text(secretKey).toString - } else { - logDebug("getSecretKey: yarn mode, secret key from credentials is null") - } + val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey) + if (secretKey != null) { + logDebug("in yarn mode, getting secret from credentials") + return new Text(secretKey).toString } else { - logDebug("getSecretKey: yarn mode, credentials are null") + logDebug("getSecretKey: yarn mode, secret key from credentials is null") } } val secret = System.getProperty("SPARK_SECRET", System.getenv("SPARK_SECRET")) @@ -70,43 +204,57 @@ private[spark] class SecurityManager extends Logging { // generate one val sCookie = akka.util.Crypt.generateSecureCookie - // if we generate we must be the first so lets set it so its used by everyone else + // if we generated the secret then we must be the first so lets set it so t + // gets used by everyone else if (SparkHadoopUtil.get.isYarnMode) { - val creds = new Credentials() - creds.addSecretKey(new Text("akkaCookie"), sCookie.getBytes()) - SparkHadoopUtil.get.addCurrentUserCredentials(creds) + SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, sCookie) logDebug("adding secret to credentials yarn mode") } else { System.setProperty("SPARK_SECRET", sCookie) logDebug("adding secret to java property") } - return sCookie + sCookie } - def isUIAuthenticationEnabled(): Boolean = return isUIAuthOn + /** + * Check to see if Acls for the UI are enabled + * @return true if UI authentication is enabled, otherwise false + */ + def uiAclsEnabled(): Boolean = uiAclsOn - // allow anyone in the acl list and the application owner + /** + * Checks the given user against the view acl list to see if they have + * authorization to view the UI. + * @param user to see if is authorized + * @return true is the user has permission, otherwise false + */ def checkUIViewPermissions(user: String): Boolean = { - if (isUIAuthenticationEnabled() && (user != null)) { - if ((!viewAcls.contains(user)) && (user != System.getProperty("user.name"))) { - return false - } - } - return true + if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true } - def isAuthenticationEnabled(): Boolean = return isAuthOn + /** + * Check to see if authentication for the Spark communication protocols is enabled + * @return true if authentication is enabled, otherwise false + */ + def isAuthenticationEnabled(): Boolean = authOn - // user for HTTP connections + /** + * Gets the user used for authenticating HTTP connections. + * For now use a single hardcoded user. + * @return the HTTP user as a String + */ def getHttpUser(): String = "sparkHttpUser" - // user to use with SASL connections + /** + * Gets the user used for authenticating SASL connections. + * For now use a single hardcoded user. + * @return the SASL user as a String + */ def getSaslUser(): String = "sparkSaslUser" /** - * Gets the secret key if security is enabled, else returns null. + * Gets the secret key. 
+ * @return the secret key as a String if authentication is enabled, otherwise returns null */ - def getSecretKey(): String = { - return secretKey - } + def getSecretKey(): String = secretKey } diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala new file mode 100644 index 0000000000000..2737a82b85fef --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License") you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.IOException +import javax.security.auth.callback.Callback +import javax.security.auth.callback.CallbackHandler +import javax.security.auth.callback.NameCallback +import javax.security.auth.callback.PasswordCallback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.sasl.RealmCallback +import javax.security.sasl.RealmChoiceCallback +import javax.security.sasl.Sasl +import javax.security.sasl.SaslClient +import javax.security.sasl.SaslException + +import scala.collection.JavaConversions.mapAsJavaMap + +/** + * Implements SASL Client logic for Spark + */ +private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { + + /** + * Used to respond to server's counterpart, SaslServer with SASL tokens + * represented as byte arrays. + * + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), + null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + new SparkSaslClientCallbackHandler(securityMgr)) + + /** + * Used to initiate SASL handshake with server. + * @return response to challenge if needed + */ + def firstToken(): Array[Byte] = { + val saslToken: Array[Byte] = + if (saslClient.hasInitialResponse()) { + logDebug("has initial response") + saslClient.evaluateChallenge(new Array[Byte](0)) + } else { + new Array[Byte](0) + } + saslToken + } + + /** + * Determines whether the authentication exchange has completed. + * @return true is complete, otherwise false + */ + def isComplete(): Boolean = { + saslClient.isComplete() + } + + /** + * Respond to server's SASL token. + * @param saslTokenMessage contains server's SASL token + * @return client's response SASL token + */ + def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { + saslClient.evaluateChallenge(saslTokenMessage) + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. 
+ */ + def dispose() { + if (saslClient != null) { + try { + saslClient.dispose() + } catch { + case e: SaslException => // ignored + } finally { + saslClient = null + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends + CallbackHandler { + + private val userName: String = + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) + private val secretKey = securityMgr.getSecretKey() + private val userPassword: Array[Char] = + SparkSaslServer.encodePassword(if (secretKey != null) secretKey.getBytes() else "".getBytes()) + + /** + * Implementation used to respond to SASL request from the server. + * + * @param callbacks objects that indicate what credential information the + * server's SaslServer requires from the client. + */ + override def handle(callbacks: Array[Callback]) { + logDebug("in the sasl client callback handler") + callbacks foreach { + case nc: NameCallback => { + logDebug("handle: SASL client callback: setting username: " + userName) + nc.setName(userName) + } + case pc: PasswordCallback => { + logDebug("handle: SASL client callback: setting userPassword") + pc.setPassword(userPassword) + } + case rc: RealmCallback => { + logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) + rc.setText(rc.getDefaultText()) + } + case cb: RealmChoiceCallback => {} + case cb: Callback => throw + new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala new file mode 100644 index 0000000000000..633f83b46ed31 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License") you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark + +import javax.security.auth.callback.Callback +import javax.security.auth.callback.CallbackHandler +import javax.security.auth.callback.NameCallback +import javax.security.auth.callback.PasswordCallback +import javax.security.auth.callback.UnsupportedCallbackException +import javax.security.sasl.AuthorizeCallback +import javax.security.sasl.RealmCallback +import javax.security.sasl.Sasl +import javax.security.sasl.SaslException +import javax.security.sasl.SaslServer +import scala.collection.JavaConversions.mapAsJavaMap +import org.apache.commons.net.util.Base64 + +/** + * Encapsulates SASL server logic + */ +private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { + + /** + * Actual SASL work done by this object from javax.security.sasl. 
+ */ + private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, + SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + new SparkSaslDigestCallbackHandler(securityMgr)) + + /** + * Determines whether the authentication exchange has completed. + * @return true is complete, otherwise false + */ + def isComplete(): Boolean = { + saslServer.isComplete() + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + def response(token: Array[Byte]): Array[Byte] = { + saslServer.evaluateResponse(token) + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + def dispose() { + if (saslServer != null) { + try { + saslServer.dispose() + } catch { + case e: SaslException => // ignore + } finally { + saslServer = null + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * for SASL DIGEST-MD5 mechanism + */ + private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) + extends CallbackHandler { + + private val userName: String = + SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) + + override def handle(callbacks: Array[Callback]) { + logDebug("In the sasl server callback handler") + callbacks foreach { + case nc: NameCallback => { + logDebug("handle: SASL server callback: setting username") + nc.setName(userName) + } + case pc: PasswordCallback => { + logDebug("handle: SASL server callback: setting userPassword") + val password: Array[Char] = + SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes()) + pc.setPassword(password) + } + case rc: RealmCallback => { + logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) + rc.setText(rc.getDefaultText()) + } + case ac: AuthorizeCallback => { + val authid = ac.getAuthenticationID() + val authzid = ac.getAuthorizationID() + if (authid.equals(authzid)) { + logDebug("set auth to true") + ac.setAuthorized(true) + } else { + logDebug("set auth to false") + ac.setAuthorized(false) + } + if (ac.isAuthorized()) { + logDebug("sasl server is authorized") + ac.setAuthorizedID(authzid) + } + } + case cb: Callback => throw + new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") + } + } + } +} + +private[spark] object SparkSaslServer { + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + val SASL_DEFAULT_REALM = "default" + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + val DIGEST = "DIGEST-MD5" + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") + + /** + * Encode a byte[] identifier as a Base64-encoded string. + * + * @param identifier identifier to encode + * @return Base64-encoded string + */ + def encodeIdentifier(identifier: Array[Byte]): String = { + new String(Base64.encodeBase64(identifier)) + } + + /** + * Encode a password as a base64-encoded char[] array. + * @param password as a byte array. + * @return password as a char array. 
+ */ + def encodePassword(password: Array[Byte]): Array[Char] = { + new String(Base64.encodeBase64(password)).toCharArray() + } +} + diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 395fd2cf2bebb..bcde2e8bb79ee 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -18,7 +18,7 @@ package org.apache.spark.broadcast import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.{Authenticator, PasswordAuthentication, URL, URLConnection, URI} +import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream @@ -158,34 +158,10 @@ private object HttpBroadcast extends Logging { var uc: URLConnection = null if (securityManager.isAuthenticationEnabled()) { - val uri = new URI(url) - val userCred = securityManager.getSecretKey() - if (userCred == null) { - // if auth is on force the user to specify a password - throw new Exception("secret key is null with authentication on") - } - val userInfo = securityManager.getHttpUser() + ":" + userCred - val newuri = new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(), - uri.getQuery(), uri.getFragment()) - + logDebug("broadcast security enabled") + val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) uc = newuri.toURL().openConnection() uc.setAllowUserInteraction(false) - logDebug("broadcast security enabled") - - // set our own authenticator to properly negotiate user/password - Authenticator.setDefault( - new Authenticator() { - override def getPasswordAuthentication(): PasswordAuthentication = { - var passAuth: PasswordAuthentication = null - val userInfo = getRequestingURL().getUserInfo() - if (userInfo != null) { - val parts = userInfo.split(":", 2) - passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) - } - return passAuth - } - } - ); } else { logDebug("broadcast not using security") uc = new URL(url).openConnection() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 5af01f59afbc8..b2dc59aac27f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -68,6 +68,11 @@ class SparkHadoopUtil { def getCurrentUserCredentials(): Credentials = { null } def addCurrentUserCredentials(creds: Credentials) {} + + def addSecretKeyToUserCredentials(key: String, secret: String) {} + + def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null } + } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index e192a8fa9c511..b88f6d12893c9 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -19,14 +19,11 @@ package org.apache.spark.network import org.apache.spark._ import org.apache.spark.SparkSaslServer -import org.apache.spark.SparkSaslServer.SaslDigestCallbackHandler import scala.collection.mutable.{HashMap, Queue, ArrayBuffer} -import java.io._ import java.nio._ import java.nio.channels._ -import java.nio.channels.spi._ import java.net._ @@ -466,7 +463,7 @@ private[spark] class 
ReceivingConnection(channel_ : SocketChannel, selector_ : S val newMessage = Message.create(header).asInstanceOf[BufferMessage] newMessage.started = true newMessage.startTime = System.currentTimeMillis - newMessage.isSecurityNeg = if (header.securityNeg == 1) true else false + newMessage.isSecurityNeg = header.securityNeg == 1 logDebug( "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]") messages += ((newMessage.id, newMessage)) diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala index b26b7ee34534a..a174dfe4d2a00 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala @@ -17,14 +17,18 @@ package org.apache.spark.network -private[spark] case class ConnectionId(id : String) {} +private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { + override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId +} private[spark] object ConnectionId { - def createConnectionId(connectionManagerId : ConnectionManagerId, secureMsgId : Int) : ConnectionId = { - val connIdStr = connectionManagerId.host + "_" + connectionManagerId.port + "_" + secureMsgId - val connId = new ConnectionId(connIdStr) - return connId + def createConnectionIdFromString(connectionIdString: String) : ConnectionId = { + val res = connectionIdString.split("_").map(_.trim()) + if (res.size != 3) { + throw new Exception("Error converting ConnectionId string: " + connectionIdString + + " to a ConnectionId Object") + } + new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt) } } - diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index ce0e44873ee1d..e404e28e9a54f 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -33,10 +33,9 @@ import scala.collection.mutable.SynchronizedQueue import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, Promise, ExecutionContext, Future} -import scala.concurrent.duration.Duration import scala.concurrent.duration._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{SystemClock, Utils} private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManager: SecurityManager) extends Logging { @@ -54,8 +53,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag private val selector = SelectorProvider.provider.openSelector() - // TODO -update to use spark conf - private val numAuthRetries = System.getProperty("spark.core.connection.num.auth.retries","10").toInt + // default to 30 second timeout waiting for authentication + private val authTimeout= System.getProperty("spark.core.connection.auth.wait.timeout","30000").toInt private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -387,7 +386,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag // accept them all in a tight loop. 
non blocking accept with no processing, should be fine while (newChannel != null) { try { - val newConnectionId = ConnectionId.createConnectionId(id, idCount.getAndIncrement.intValue) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId) newConnection.onReceive(receiveMessage) addListeners(newConnection) @@ -506,7 +505,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag /*handleMessage(connection, message)*/ } - private def handleClientAuthNeg( + private def handleClientAuthentication( waitingConn: SendingConnection, securityMsg: SecurityMessage, connectionId : ConnectionId) { @@ -529,21 +528,22 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag } return } - var securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId) + var securityMsgResp = SecurityMessage.fromResponse(replyToken, + securityMsg.getConnectionId.toString()) var message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security message") sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message) } catch { case e: Exception => { - logError("Error doing sasl client: " + e) + logError("Error handling sasl client authentication", e) waitingConn.close() - throw new Exception("error evaluating sasl response: " + e) + throw new Exception("Error evaluating sasl response: " + e) } } } } - private def handleServerAuthNeg( + private def handleServerAuthentication( connection: Connection, securityMsg: SecurityMessage, connectionId: ConnectionId) { @@ -589,18 +589,18 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag // parse as SecurityMessage val securityMsg = SecurityMessage.fromBufferMessage(bufferMessage) - val connectionId = new ConnectionId(securityMsg.getConnectionId) + val connectionId = ConnectionId.createConnectionIdFromString(securityMsg.getConnectionId) connectionsAwaitingSasl.get(connectionId) match { case Some(waitingConn) => { // Client - this must be in response to us doing Send logDebug("Client handleAuth for id: " + waitingConn.connectionId) - handleClientAuthNeg(waitingConn, securityMsg, connectionId) + handleClientAuthentication(waitingConn, securityMsg, connectionId) } case None => { // Server - someone sent us something and we haven't authenticated yet logDebug("Server handleAuth for id: " + connectionId) - handleServerAuthNeg(conn, securityMsg, connectionId) + handleServerAuthentication(conn, securityMsg, connectionId) } } return true @@ -627,7 +627,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag val res = handleAuthentication(connection, bufferMessage) if (res == true) { // message was security negotiation so skip the rest - logDebug("After handleAuth result was true, returning"); + logDebug("After handleAuth result was true, returning") return } } @@ -689,15 +689,15 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() - var securityMsg = SecurityMessage.fromResponse(firstResponse, conn.connectionId.id) + var securityMsg = SecurityMessage.fromResponse(firstResponse, conn.connectionId.toString()) var message = securityMsg.toBufferMessage if (message == null) throw new Exception("Error creating security message") + connectionsAwaitingSasl += ((conn.connectionId, 
conn)) sendSecurityMessage(connManagerId, message) logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId) - connectionsAwaitingSasl += ((conn.connectionId, conn)) } catch { case e: Exception => { - logError("Error getting first response from the SaslClient") + logError("Error getting first response from the SaslClient.", e) conn.close() throw new Exception("Error getting first response from the SaslClient") } @@ -713,7 +713,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) - val newConnectionId = ConnectionId.createConnectionId(id, idCount.getAndIncrement.intValue) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, newConnectionId) logInfo("creating new sending connection for security! " + newConnectionId ) registerRequests.enqueue(newConnection) @@ -736,8 +736,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) - val newConnectionId = ConnectionId.createConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId) + val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId) logDebug("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -757,22 +757,22 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag // if we aren't authenticated yet lets block the senders until authentication completes try { connection.getAuthenticated().synchronized { - var totalWaitTimes = 0 + val clock = SystemClock + val startTime = clock.getTime() + while (!connection.isSaslComplete()) { - // should we specify timeout as fallback? logDebug("getAuthenticated wait connectionid: " + connection.connectionId) // have timeout in case remote side never responds - totalWaitTimes += 1 connection.getAuthenticated().wait(500) - if (totalWaitTimes >= numAuthRetries) { - // took to long to auth connection something probably went wrong + if (((clock.getTime() - startTime) >= authTimeout) && (!connection.isSaslComplete())) { + // took to long to authenticate the connection, something probably went wrong throw new Exception("Took to long for authentication to " + connectionManagerId + - ", waited " + 500 * numAuthRetries + "ms, failing.") + ", waited " + authTimeout + "ms, failing.") } } } } catch { - case e: Exception => logError("Exception while waiting for authentication. 
" + e) + case e: Exception => logError("Exception while waiting for authentication.", e) // need to tell sender it failed messageStatuses.synchronized { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 386d607b21c6a..553d33cedd656 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,7 +18,6 @@ package org.apache.spark.util import java.io._ -import java.net.{Authenticator, PasswordAuthentication} import java.net.{InetAddress, URL, URLConnection, URI, NetworkInterface, Inet4Address, ServerSocket} import java.util.{Locale, Random, UUID} import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} @@ -237,6 +236,22 @@ private[spark] object Utils extends Logging { } } + /** + * Construct a URI container information used for authentication. + * This also sets the default authenticator to properly negotiation the + * user/password based on the URI. + * + * Note this relies on the Authenticator.setDefault being set properly to decode + * the user name and password. This is currently set in the SecurityManager. + */ + def constructURIForAuthentication(uri: URI, securityMgr: SecurityManager): URI = { + val userCred = securityMgr.getSecretKey() + if (userCred == null) throw new Exception("Secret key is null with authentication on") + val userInfo = securityMgr.getHttpUser() + ":" + userCred + new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), uri.getPath(), + uri.getQuery(), uri.getFragment()) + } + /** * Download a file requested by the executor. Supports fetching the file in a variety of ways, * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter. @@ -261,31 +276,10 @@ private[spark] object Utils extends Logging { val sparkEnv = SparkEnv.get val securityMgr = if (sparkEnv != null) sparkEnv.securityManager else new SecurityManager() if (securityMgr.isAuthenticationEnabled()) { - val userCred = securityMgr.getSecretKey() - if (userCred == null) { - throw new Exception("secret key is null with authentication on") - } - val userInfo = securityMgr.getHttpUser() + ":" + userCred - val newuri = new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), - uri.getPath(), uri.getQuery(), uri.getFragment()) + logDebug("fetchFile with security enabled") + val newuri = constructURIForAuthentication(uri, securityMgr) uc = newuri.toURL().openConnection() uc.setAllowUserInteraction(false) - logDebug("in security enabled") - - // set our own authenticator to properly negotiate user/password - Authenticator.setDefault( - new Authenticator() { - override def getPasswordAuthentication(): PasswordAuthentication = { - var passAuth: PasswordAuthentication = null - val userInfo = getRequestingURL().getUserInfo() - if (userInfo != null) { - val parts = userInfo.split(":", 2) - passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) - } - return passAuth - } - } - ); } else { logDebug("fetchFile not using security") uc = new URL(url).openConnection() diff --git a/docs/configuration.md b/docs/configuration.md index 99558cb541bcd..7b7cfeb4d8277 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -147,6 +147,34 @@ Apart from these, the following properties are also available, and may be useful How many stages the Spark UI remembers before garbage collecting. + + spark.ui.filters + None + + Comma separated list of filter class names to apply to the Spark web ui. 
The filter should be a + standard javax servlet Filter. Parameters to each filter can also be specified by setting a + java system property of .params='param1=value1,param2=value2' + (e.g.-Dspark.ui.filters=com.test.filter1 -Dcom.test.filter1.params='param1=foo,param2=testing') + + + + spark.ui.acls.enable + false + + Whether spark web ui acls should are enabled. If enabled, this checks to see if the user has + access permissions to view the web ui. See spark.ui.view.acls for more details. + Also note this requires the user to be known, if the user comes across as null no checks + are done. Filters can be used to authenticate and set the user. + + + + spark.ui.view.acls + Empty + + Comma separated list of users that have view access to the spark web ui. By default only the + user that started the Spark job has view access. + + spark.shuffle.compress true @@ -477,34 +505,6 @@ Apart from these, the following properties are also available, and may be useful Whether to overwrite files added through SparkContext.addFile() when the target file exists and its contents do not match those of the source. - - spark.ui.filters - None - - Comma separated list of filter class names to apply to the Spark web ui. The filter should be a - standard javax servlet Filter. Parameters to each filter can also be specified by setting a - java system property of .params='param1=value1,param2=value2' - (e.g.-Dspark.ui.filters=com.test.filter1 -Dcom.test.filter1.params='param1=foo,param2=testing') - - - - spark.authenticate.ui - false - - Whether spark web ui authentication should be on. If enabled this checks the user access - permissions to view the web ui. See spark.ui.view.acls for more details. - Also note this requires the user to be known, if the user comes across as null no checks - are done. Filters can be used to authenticate and set the user. - - - - spark.ui.view.acls - Empty - - Comma separated list of users that have view access to the spark web ui. By default only the - user that started the Spark job has view access. - - spark.authenticate false @@ -513,6 +513,14 @@ Apart from these, the following properties are also available, and may be useful running on Yarn. + + spark.core.connection.auth.wait.timeout + 30000 + + Number of milliseconds for the connection to wait for authentication to occur before timing + out and giving up. 
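As a rough illustration (not part of the patch; the user names in the acl list are made-up examples), the settings described above are read as Java system properties by the new SecurityManager, so an application could also enable them programmatically before its SparkContext is created:

    // Sketch only, not part of this patch: set the security properties documented above
    // before the SparkContext (and therefore the SparkEnv / SecurityManager) is created.
    System.setProperty("spark.authenticate", "true")
    System.setProperty("spark.ui.acls.enable", "true")
    System.setProperty("spark.ui.view.acls", "alice,bob")   // made-up view acl list
    System.setProperty("spark.core.connection.auth.wait.timeout", "30000")
    val sc = new org.apache.spark.SparkContext("local[2]", "secured-app")

Setting the same properties as -D JVM options, as in the spark.ui.filters example above, works the same way.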
+ + ## Viewing Spark Properties diff --git a/docs/index.md b/docs/index.md index 7fea73024a8a0..9aa85f289190d 100644 --- a/docs/index.md +++ b/docs/index.md @@ -103,6 +103,7 @@ For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to bui * [Configuration](configuration.html): customize Spark via its configuration system * [Tuning Guide](tuning.html): best practices to optimize performance and memory use +* [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware * [Job Scheduling](job-scheduling.html): scheduling resources across and within Spark applications * [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 4cc2b66fe6e27..1aa94079fd0ae 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -18,15 +18,14 @@ package org.apache.spark.repl import java.io.{ByteArrayOutputStream, InputStream} -import java.net.{URI, URL, URLClassLoader, URLEncoder} -import java.net.Authenticator -import java.net.PasswordAuthentication +import java.net.{URI, URL, URLEncoder} import java.util.concurrent.{Executors, ExecutorService} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkEnv +import org.apache.spark.util.Utils import org.objectweb.asm._ import org.objectweb.asm.Opcodes._ @@ -56,35 +55,14 @@ extends ClassLoader(parent) { val inputStream = { if (fileSystem != null) { fileSystem.open(new Path(directory, pathInDirectory)) - else + } else { if (SparkEnv.get.securityManager.isAuthenticationEnabled()) { val uri = new URI(classUri + "/" + urlEncode(pathInDirectory)) - val userCred = SparkEnv.get.securityManager.getSecretKey() - if (userCred == null) { - throw new Exception("secret key is null with authentication on") - } - val userInfo = SparkEnv.get.securityManager.getHttpUser() + ":" + userCred - val newuri = new URI(uri.getScheme(), userInfo, uri.getHost(), uri.getPort(), - uri.getPath(), uri.getQuery(), uri.getFragment()) - - // set our own authenticator to properly negotiate user/password - Authenticator.setDefault( - new Authenticator() { - override def getPasswordAuthentication(): PasswordAuthentication = { - var passAuth: PasswordAuthentication = null - val userInfo = getRequestingURL().getUserInfo() - if (userInfo != null) { - val parts = userInfo.split(":", 2) - passAuth = new PasswordAuthentication(parts(0), parts(1).toCharArray()) - } - return passAuth - } - } - ); - + val newuri = Utils.constructURIForAuthentication(uri, SparkEnv.get.securityManager) newuri.toURL().openStream() - } else { - new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + } else { + new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream() + } } } val bytes = readAndTransformClass(name, inputStream) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 89f12b602c32d..c07f3749981d4 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -17,12 +17,13 @@ package 
org.apache.spark.deploy.yarn -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.conf.Configuration +import org.apache.spark.deploy.SparkHadoopUtil /** * Contains util methods to interact with Hadoop from spark. @@ -49,4 +50,16 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { override def addCurrentUserCredentials(creds: Credentials) { UserGroupInformation.getCurrentUser().addCredentials(creds) } + + override def addSecretKeyToUserCredentials(key: String, secret: String) { + val creds = new Credentials() + creds.addSecretKey(new Text(key), secret.getBytes()) + addCurrentUserCredentials(creds) + } + + override def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { + val credentials = getCurrentUserCredentials() + if (credentials != null) credentials.getSecretKey(new Text(key)) else null + } + } From ed3d1c16cf9a0af6530d2c37e62fb9cdc92ddfcb Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Sat, 22 Feb 2014 10:50:30 -0600 Subject: [PATCH 04/14] Add security.md --- docs/security.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 docs/security.md diff --git a/docs/security.md b/docs/security.md new file mode 100644 index 0000000000000..3fd511d6c509e --- /dev/null +++ b/docs/security.md @@ -0,0 +1,14 @@ +--- +layout: global +title: Spark Security +--- + +Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the 'spark.authenticate' configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. + +The Spark UI can also be secured by using javax servlet filters. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view acls to make sure they are authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' control the behavior of the acls. Note that the person who started the application always has view access to the UI. + +For Spark on Yarn deployments, configuring `spark.authenticate` to true will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. The Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. If an authentication filter is enabled, the acl controls can be used to control which users can view the Spark UI. + +For other types of Spark deployments, the environment variable `SPARK_SECRET` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. The UI can be secured using a javax servlet filter installed via `spark.ui.filters`. If an authentication filter is enabled, the acl controls can be used to control which users can view the Spark UI. + +See [Spark Configuration](configuration.html) for more details on the security configs.
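To make the "basic handshake" above concrete, the sketch below (illustrative only, not part of the patch set; "not-a-real-secret" is a made-up value) drives the SparkSaslClient and SparkSaslServer added here through a DIGEST-MD5 exchange in a single JVM. In real deployments these classes are private[spark] and the exchange runs asynchronously inside the ConnectionManager over SecurityMessages.

    // Sketch only: both ends run in one JVM to show the token flow.
    System.setProperty("spark.authenticate", "true")
    System.setProperty("SPARK_SECRET", "not-a-real-secret")   // made-up shared secret
    val securityManager = new org.apache.spark.SecurityManager()
    val client = new org.apache.spark.SparkSaslClient(securityManager)
    val server = new org.apache.spark.SparkSaslServer(securityManager)

    // DIGEST-MD5 has no initial client response, so the first client token is empty.
    var clientToken = client.firstToken()
    while (!server.isComplete()) {
      val serverToken = server.response(clientToken)      // server challenge / verification
      if (!client.isComplete()) {
        clientToken = client.saslResponse(serverToken)    // client answer using the secret
      }
    }
    client.dispose()
    server.dispose()

If the two sides had different secrets, server.response would be expected to throw a SaslException, and in the ConnectionManager the sending side would give up after spark.core.connection.auth.wait.timeout milliseconds.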
From b514becd7a0173ebeb209c0436e3c2c9f2f40a64 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Sat, 22 Feb 2014 10:52:56 -0600 Subject: [PATCH 05/14] Fix reference to config --- docs/security.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/security.md b/docs/security.md index 3fd511d6c509e..c621c5a2eadde 100644 --- a/docs/security.md +++ b/docs/security.md @@ -3,7 +3,7 @@ layout: global title: Spark Security --- -Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the 'spark.authenticate' configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. +Spark currently supports authentication via a shared secret. Authentication can be configured to be on via the `spark.authenticate` configuration parameter. This parameter controls whether the Spark communication protocols do authentication using the shared secret. This authentication is a basic handshake to make sure both sides have the same shared secret and are allowed to communicate. If the shared secret is not identical they will not be allowed to communicate. The Spark UI can also be secured by using javax servlet filters. A user may want to secure the UI if it has data that other users should not be allowed to see. The javax servlet filter specified by the user can authenticate the user and then once the user is logged in, Spark can compare that user versus the view acls to make sure they are authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' control the behavior of the acls. Note that the person who started the application always has view access to the UI. From ecbfb65860e4fea722537802cf036f0b505d7da9 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Sat, 22 Feb 2014 11:37:41 -0600 Subject: [PATCH 06/14] Fix spacing and formatting --- .../scala/org/apache/spark/HttpServer.scala | 3 ++- .../org/apache/spark/SecurityManager.scala | 24 +++++++++---------- .../org/apache/spark/SparkSaslClient.scala | 15 ++++++------ .../org/apache/spark/SparkSaslServer.scala | 16 ++++++------- .../apache/spark/broadcast/Broadcast.scala | 3 ++- .../spark/broadcast/HttpBroadcast.scala | 2 +- .../spark/deploy/master/ui/MasterWebUI.scala | 3 ++- .../org/apache/spark/network/Connection.scala | 11 +++++---- .../spark/network/ConnectionManager.scala | 24 ++++++++++++------- .../spark/network/MessageChunkHeader.scala | 3 ++- .../org/apache/spark/ui/JettyUtils.scala | 2 +- .../apache/spark/ui/jobs/JobProgressUI.scala | 9 ++++--- 12 files changed, 66 insertions(+), 49 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index d1dc51d0ef8ed..e2aa8c2314e6f 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -41,7 +41,8 @@ private[spark] class ServerStateException(message: String) extends Exception(mes * as well as classes created by the interpreter when the user types in code. This is just a wrapper * around a Jetty server. 
*/ -private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager) extends Logging { +private[spark] class HttpServer(resourceBase: File, securityManager: SecurityManager) + extends Logging { private var server: Server = null private var port: Int = -1 diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index f21dff51caaac..f025fd24a6788 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -1,13 +1,13 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -104,10 +104,10 @@ import scala.collection.mutable.ArrayBuffer * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work - * properly. For non-Yarn deployments, users can write a filter to go through a companies - * normal login service. If an authentication filter is in place then the SparkUI - * can be configured to check the logged in user against the list of users who have - * view acls to see if that user is authorized. + * properly. For non-Yarn deployments, users can write a filter to go through a + * companies normal login service. If an authentication filter is in place then the + * SparkUI can be configured to check the logged in user against the list of users who + * have view acls to see if that user is authorized. * The filters can also be used for many different purposes. For instance filters * could be used for logging, encypryption, or compression. * diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala index 2737a82b85fef..9af0440a357ca 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -1,13 +1,12 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License") you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala index 633f83b46ed31..4a8213d7c296c 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -1,13 +1,12 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License") you may not use this file except in compliance - * with the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -15,6 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ + package org.apache.spark import javax.security.auth.callback.Callback diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index da937e7377f3d..e3c3a12d16f2a 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -60,7 +60,8 @@ abstract class Broadcast[T](val id: Long) extends Serializable { } private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) extends Logging with Serializable { +class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) + extends Logging with Serializable { private var initialized = false private var broadcastFactory: BroadcastFactory = null diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index bcde2e8bb79ee..e8eb04bb10469 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -153,7 +153,7 @@ private object HttpBroadcast extends Logging { } def read[T](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-"+id) + logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) val url = serverUri + "/" + BroadcastBlockId(id).name var uc: URLConnection = null diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index b0b56c96e0907..728fa5f9e0626 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -62,7 +62,8 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { val handlers = metricsHandlers ++ Seq[ServletContextHandler]( createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"), - createServletHandler("/app/json", (request: HttpServletRequest) => applicationPage.renderJson(request)), + createServletHandler("/app/json", + (request: HttpServletRequest) => applicationPage.renderJson(request)), createServletHandler("/app", (request: HttpServletRequest) => applicationPage.render(request)), createServletHandler("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), createServletHandler("*", (request: HttpServletRequest) => indexPage.render(request)) diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index b88f6d12893c9..d47931aaebc70 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -447,8 +447,11 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, // Must be created within selector loop - else deadlock -private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) - extends Connection(channel_, selector_, id_) { +private[spark] class ReceivingConnection( + channel_ : SocketChannel, + selector_ : Selector, + id_ : ConnectionId) + extends Connection(channel_, selector_, id_) { def isSaslComplete(): Boolean = { if (sparkSaslServer != null) sparkSaslServer.isComplete() else false @@ -509,7 +512,7 @@ private[spark] class 
ReceivingConnection(channel_ : SocketChannel, selector_ : S val inbox = new Inbox() val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE) - var onReceiveCallback: (Connection , Message) => Unit = null + var onReceiveCallback: (Connection, Message) => Unit = null var currentChunk: MessageChunk = null channel.register(selector, SelectionKey.OP_READ) @@ -584,7 +587,7 @@ private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : S } } } catch { - case e: Exception => { + case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) callOnExceptionCallback(e) close() diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index e404e28e9a54f..dceb6f632649c 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -37,7 +37,8 @@ import scala.concurrent.duration._ import org.apache.spark.util.{SystemClock, Utils} -private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManager: SecurityManager) extends Logging { +private[spark] class ConnectionManager(port: Int, conf: SparkConf, + securityManager: SecurityManager) extends Logging { class MessageStatus( val message: Message, @@ -54,7 +55,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag private val selector = SelectorProvider.provider.openSelector() // default to 30 second timeout waiting for authentication - private val authTimeout= System.getProperty("spark.core.connection.auth.wait.timeout","30000").toInt + private val authTimeout = System.getProperty("spark.core.connection.auth.wait.timeout", + "30000").toInt private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -564,7 +566,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag logDebug("Server sasl not completed: " + connection.connectionId) } if (replyToken != null) { - var securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId) + var securityMsgResp = SecurityMessage.fromResponse(replyToken, + securityMsg.getConnectionId) var message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security Message") sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) @@ -689,7 +692,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() - var securityMsg = SecurityMessage.fromResponse(firstResponse, conn.connectionId.toString()) + var securityMsg = SecurityMessage.fromResponse(firstResponse, + conn.connectionId.toString()) var message = securityMsg.toBufferMessage if (message == null) throw new Exception("Error creating security message") connectionsAwaitingSasl += ((conn.connectionId, conn)) @@ -714,13 +718,15 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId, newConnectionId) + val newConnection = new 
SendingConnection(inetSocketAddress, selector, connManagerId, + newConnectionId) logInfo("creating new sending connection for security! " + newConnectionId ) registerRequests.enqueue(newConnection) newConnection } - // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ... + // I removed the lookupKey stuff as part of merge ... should I re-add it ? + // We did not find it useful in our test-env ... // If we do re-add it, we should consistently use it everywhere I guess ? message.senderAddress = id.toSocketAddress() logDebug("Sending Security [" + message + "] to [" + connManagerId + "]") @@ -737,7 +743,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port) val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) - val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId) + val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, + newConnectionId) logDebug("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -751,7 +758,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, securityManag checkSendAuthFirst(connectionManagerId, connection) } message.senderAddress = id.toSocketAddress() - logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " connectionid: " + connection.connectionId) + logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " + + "connectionid: " + connection.connectionId) if (authEnabled) { // if we aren't authenticated yet lets block the senders until authentication completes diff --git a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala index 30666891633bb..e850d8c366686 100644 --- a/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala @@ -74,6 +74,7 @@ private[spark] object MessageChunkHeader { buffer.get(ipBytes) val ip = InetAddress.getByAddress(ipBytes) val port = buffer.getInt() - new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, new InetSocketAddress(ip, port)) + new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg, + new InetSocketAddress(ip, port)) } } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 2dfad0326bd3e..63bddea9b2c30 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -121,7 +121,7 @@ private[spark] object JettyUtils extends Logging { } private def addFilters(handlers: Seq[ServletContextHandler]) { - val filters : Array[String] = System.getProperty("spark.ui.filters", "").split(',').map(_.trim()) + val filters: Array[String] = System.getProperty("spark.ui.filters", "").split(',').map(_.trim()) filters.foreach { case filter : String => if (!filter.isEmpty) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index 3516ed57a02e3..cc69aba0b5651 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ 
b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -54,8 +54,11 @@ private[spark] class JobProgressUI(val sc: SparkContext) { def formatDuration(ms: Long) = Utils.msDurationToString(ms) def getHandlers = Seq[ServletContextHandler]( - createServletHandler("/stages/stage", (request: HttpServletRequest) => stagePage.render(request)), - createServletHandler("/stages/pool", (request: HttpServletRequest) => poolPage.render(request)), - createServletHandler("/stages", (request: HttpServletRequest) => indexPage.render(request)) + createServletHandler("/stages/stage", + (request: HttpServletRequest) => stagePage.render(request)), + createServletHandler("/stages/pool", + (request: HttpServletRequest) => poolPage.render(request)), + createServletHandler("/stages", + (request: HttpServletRequest) => indexPage.render(request)) ) } From 50dd9f2438356117e749d4cbd8d0ea8c25746166 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Sat, 22 Feb 2014 12:35:34 -0600 Subject: [PATCH 07/14] fix header in SecurityManager --- core/src/main/scala/org/apache/spark/SecurityManager.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index f025fd24a6788..33570f6272350 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -1,4 +1,3 @@ -/* /* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with From 2f7714722854b42f27479955202cde2d2d2fb281 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Mon, 3 Mar 2014 12:31:33 -0600 Subject: [PATCH 08/14] Rework from comments --- .../org/apache/spark/SecurityManager.scala | 53 ++++++++--------- .../spark/deploy/worker/ui/WorkerWebUI.scala | 4 +- .../org/apache/spark/network/Connection.scala | 4 +- .../apache/spark/network/ConnectionId.scala | 2 +- .../spark/network/ConnectionManager.scala | 11 ++-- .../spark/network/SecurityMessage.scala | 57 ++++++++++++++++++- docs/configuration.md | 4 +- docs/security.md | 2 + 8 files changed, 94 insertions(+), 43 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 33570f6272350..0111f8186ebcf 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -50,7 +50,7 @@ import scala.collection.mutable.ArrayBuffer * Spark does not currently support encryption after authentication. * * At this point spark has multiple communication protocols that need to be secured and - * different underlying mechisms are used depending on the protocol: + * different underlying mechanisms are used depending on the protocol: * * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. * Akka remoting allows you to specify a secure cookie that will be exchanged @@ -108,7 +108,7 @@ import scala.collection.mutable.ArrayBuffer * SparkUI can be configured to check the logged in user against the list of users who * have view acls to see if that user is authorized. * The filters can also be used for many different purposes. For instance filters - * could be used for logging, encypryption, or compression. + * could be used for logging, encryption, or compression. * * The exact mechanisms used to generate/distributed the shared secret is deployment specific. 
* @@ -122,15 +122,11 @@ import scala.collection.mutable.ArrayBuffer * filters to do authentication. That authentication then happens via the ResourceManager Proxy * and Spark will use that to do authorization against the view acls. * - * For other Spark deployments, the shared secret should be specified via the SPARK_SECRET + * For other Spark deployments, the shared secret must be specified via the SPARK_SECRET * environment variable. This isn't ideal but it means only the user who starts the process - * has access to view that variable. Note that Spark does try to generate a secret for - * you if the SPARK_SECRET environment variable is not set, but it gets put into the java - * system property which can be viewed by other users, so setting the SPARK_SECRET environment - * variable is recommended. - * All the nodes (Master and Workers) need to have the same shared secret - * and all the applications running need to have that same shared secret. This again - * is not ideal as one user could potentially affect another users application. + * has access to view that variable. + * All the nodes (Master and Workers) and the applications need to have the same shared secret. + * This again is not ideal as one user could potentially affect another users application. * This should be enhanced in the future to provide better protection. * If the UI needs to be secured the user needs to install a javax servlet filter to do the * authentication. Spark will then use that user to compare against the view acls to do @@ -152,7 +148,8 @@ private[spark] class SecurityManager extends Logging { private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet private val secretKey = generateSecretKey() - logDebug("is auth enabled = " + authOn + " is uiAcls enabled = " + uiAclsOn) + logInfo("SecurityManager, is authentication enabled: " + authOn + + " are ui acls enabled: " + uiAclsOn) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its @@ -170,7 +167,7 @@ private[spark] class SecurityManager extends Logging { return passAuth } } - ); + ) } /** @@ -179,16 +176,12 @@ private[spark] class SecurityManager extends Logging { * The way the key is stored depends on the Spark deployment mode. Yarn * uses the Hadoop UGI. * - * For non-Yarn deployments, If the environment variable is not set already - * we generate a secret and since we can't set an environment variable dynamically - * we set the java system property SPARK_SECRET. This will allow it to automatically - * work in certain situations. Others this still will not work and this definitely is - * not ideal since other users can see it. We should switch to put it in - * a config once Spark supports configs. + * For non-Yarn deployments, If the environment variable is not set + * we throw an exception. 
*/ private def generateSecretKey(): String = { if (!isAuthenticationEnabled) return null - // first check to see if the secret is already set, else generate a new one + // first check to see if the secret is already set, else generate a new one if on yarn if (SparkHadoopUtil.get.isYarnMode) { val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey) if (secretKey != null) { @@ -200,17 +193,17 @@ private[spark] class SecurityManager extends Logging { } val secret = System.getProperty("SPARK_SECRET", System.getenv("SPARK_SECRET")) if (secret != null && !secret.isEmpty()) return secret - // generate one - val sCookie = akka.util.Crypt.generateSecureCookie - - // if we generated the secret then we must be the first so lets set it so t - // gets used by everyone else + val sCookie = if (SparkHadoopUtil.get.isYarnMode) { + // generate one + akka.util.Crypt.generateSecureCookie + } else { + throw new Exception("Error: a secret key must be specified via SPARK_SECRET env variable") + } if (SparkHadoopUtil.get.isYarnMode) { + // if we generated the secret then we must be the first so lets set it so t + // gets used by everyone else SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, sCookie) - logDebug("adding secret to credentials yarn mode") - } else { - System.setProperty("SPARK_SECRET", sCookie) - logDebug("adding secret to java property") + logInfo("adding secret to credentials in yarn mode") } sCookie } @@ -223,7 +216,9 @@ private[spark] class SecurityManager extends Logging { /** * Checks the given user against the view acl list to see if they have - * authorization to view the UI. + * authorization to view the UI. If the UI acls must are disabled + * via spark.ui.acls.enable, all users have view access. 
+ * * @param user to see if is authorized * @return true is the user has permission, otherwise false */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 2e058930497d7..dcdcce4ae31b3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -48,7 +48,7 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val metricsHandlers = worker.metricsSystem.getServletHandlers val handlers = metricsHandlers ++ Seq[ServletContextHandler]( - createStaticHandler(WorkerWebUI.STATIC_RESOURCE_DIR, "/static/*"), + createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static/*"), createServletHandler("/log", (request: HttpServletRequest) => log(request)), createServletHandler("/logPage", (request: HttpServletRequest) => logPage(request)), createServletHandler("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), @@ -199,6 +199,6 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I } private[spark] object WorkerWebUI { - val STATIC_RESOURCE_DIR = "org/apache/spark/ui" + val STATIC_RESOURCE_BASE = "org/apache/spark/ui" val DEFAULT_PORT="8081" } diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index d47931aaebc70..4eadf66d98698 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -32,8 +32,8 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId) extends Logging { - var sparkSaslServer : SparkSaslServer = null - var sparkSaslClient : SparkSaslClient = null + var sparkSaslServer: SparkSaslServer = null + var sparkSaslClient: SparkSaslClient = null def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = { this(channel_, selector_, diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala index a174dfe4d2a00..ffaab677d411a 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala @@ -23,7 +23,7 @@ private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, private[spark] object ConnectionId { - def createConnectionIdFromString(connectionIdString: String) : ConnectionId = { + def createConnectionIdFromString(connectionIdString: String): ConnectionId = { val res = connectionIdString.split("_").map(_.trim()) if (res.size != 3) { throw new Exception("Error converting ConnectionId string: " + connectionIdString + diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index dceb6f632649c..db74a182af358 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -56,7 +56,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, // default to 30 second timeout waiting for authentication private val authTimeout = System.getProperty("spark.core.connection.auth.wait.timeout", - "30000").toInt + "30").toInt 
private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -79,6 +79,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, new LinkedBlockingDeque[Runnable]()) private val serverChannel = ServerSocketChannel.open() + // used to track the SendingConnections waiting to do SASL negotiation private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] with SynchronizedMap[ConnectionId, SendingConnection] private val connectionsByKey = @@ -729,7 +730,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, // We did not find it useful in our test-env ... // If we do re-add it, we should consistently use it everywhere I guess ? message.senderAddress = id.toSocketAddress() - logDebug("Sending Security [" + message + "] to [" + connManagerId + "]") + logTrace("Sending Security [" + message + "] to [" + connManagerId + "]") val connection = connectionsById.getOrElseUpdate(connManagerId, startNewConnection()) //send security message until going connection has been authenticated @@ -745,7 +746,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId) - logDebug("creating new sending connection: " + newConnectionId) + logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) newConnection @@ -772,10 +773,10 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, logDebug("getAuthenticated wait connectionid: " + connection.connectionId) // have timeout in case remote side never responds connection.getAuthenticated().wait(500) - if (((clock.getTime() - startTime) >= authTimeout) && (!connection.isSaslComplete())) { + if (((clock.getTime() - startTime) >= (authTimeout * 1000)) && (!connection.isSaslComplete())) { // took to long to authenticate the connection, something probably went wrong throw new Exception("Took to long for authentication to " + connectionManagerId + - ", waited " + authTimeout + "ms, failing.") + ", waited " + authTimeout + "seconds, failing.") } } } diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala index dd5519f6b39a7..0d9f743b3624b 100644 --- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala @@ -25,6 +25,35 @@ import scala.collection.mutable.StringBuilder import org.apache.spark._ import org.apache.spark.network._ +/** + * SecurityMessage is class that contains the connectionId and sasl token + * used in SASL negotiation. SecurityMessage has routines for converting + * it to and from a BufferMessage so that it can be sent by the ConnectionManager + * and easily consumed by users when received. + * The api was modeled after BlockMessage. + * + * The connectionId is the connectionId of the client side. Since + * message passing is asynchronous and its possible for the server side (receiving) + * to get multiple different types of messages on the same connection the connectionId + * is used to know which connnection the security message is intended for. + * + * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side + * is acting as a client and connecting to node_1. 
SASL negotiation has to occur + * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. + * node_1 receives the message from node_0 but before it can process it and send a response, + * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 + * and sends a security message of its own to authenticate as a client. Now node_0 gets + * the message and it needs to decide if this message is in response to it being a client + * (from the first send) or if its just node_1 trying to connect to it to send data. This + * is where the connectionId field is used. node_0 can lookup the connectionId to see if + * it is in response to it being a client or if its in response to someone sending other data. + * + * The format of a SecurityMessage as its sent is: + * - Length of the ConnectionId + * - ConnectionId + * - Length of the token + * - Token + */ private[spark] class SecurityMessage() extends Logging { private var connectionId: String = null @@ -39,6 +68,9 @@ private[spark] class SecurityMessage() extends Logging { connectionId = newconnectionId } + /** + * Read the given buffer and set the members of this class. + */ def set(buffer: ByteBuffer) { val idLength = buffer.getInt() val idBuilder = new StringBuilder(idLength) @@ -68,10 +100,19 @@ private[spark] class SecurityMessage() extends Logging { return token } + /** + * Create a BufferMessage that can be sent by the ConnectionManager containing + * the security information from this class. + * @return BufferMessage + */ def toBufferMessage: BufferMessage = { val startTime = System.currentTimeMillis val buffers = new ArrayBuffer[ByteBuffer]() + // 4 bytes for the length of the connectionId + // connectionId is of type char so multiple the length by 2 to get number of bytes + // 4 bytes for the length of token + // token is a byte buffer so just take the length var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) buffer.putInt(connectionId.length()) connectionId.foreach((x: Char) => buffer.putChar(x)) @@ -96,15 +137,27 @@ private[spark] class SecurityMessage() extends Logging { private[spark] object SecurityMessage { + /** + * Convert the given BufferMessage to a SecurityMessage by parsing the contents + * of the BufferMessage and populating the SecurityMessage fields. + * @param bufferMessage is a BufferMessage that was received + * @return new SecurityMessage + */ def fromBufferMessage(bufferMessage: BufferMessage): SecurityMessage = { val newSecurityMessage = new SecurityMessage() newSecurityMessage.set(bufferMessage) newSecurityMessage } - def fromResponse(response : Array[Byte], newConnectionId : String) : SecurityMessage = { + /** + * Create a SecurityMessage to send from a given saslResponse. 
+ * @param response is the response to a challenge from the SaslClient or SaslServer + * @param connectionId the client connectionId we are negotiating authentication for + * @return a new SecurityMessage + */ + def fromResponse(response : Array[Byte], connectionId : String) : SecurityMessage = { val newSecurityMessage = new SecurityMessage() - newSecurityMessage.set(response, newConnectionId) + newSecurityMessage.set(response, connectionId) newSecurityMessage } } diff --git a/docs/configuration.md b/docs/configuration.md index 7b7cfeb4d8277..c94ef26a12739 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -515,9 +515,9 @@ Apart from these, the following properties are also available, and may be useful spark.core.connection.auth.wait.timeout - 30000 + 30 - Number of milliseconds for the connection to wait for authentication to occur before timing + Number of seconds for the connection to wait for authentication to occur before timing out and giving up. diff --git a/docs/security.md b/docs/security.md index c621c5a2eadde..aa61dfc354c19 100644 --- a/docs/security.md +++ b/docs/security.md @@ -12,3 +12,5 @@ For Spark on Yarn deployments, configuring `spark.authenticate` to true will aut For other types of Spark deployments, the environment variable `SPARK_SECRET` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. The UI can be secured using a javax servlet filter installed via `spark.ui.filters`. If an authentication filter is enabled, the acls controls can be used to control which users can view the Spark UI. See [Spark Configuration](configuration.html) for more details on the security configs. + +See org.apache.spark.SecurityManager for implementation details about security. 
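[Editorial note, not part of the patch series] As an aside on the wire format documented in the SecurityMessage scaladoc above (4-byte connectionId length, connectionId as 2-byte chars, 4-byte token length, then the raw SASL token), the sketch below is an illustrative encoder/decoder of that layout only; it is not the patch's actual SecurityMessage/BufferMessage code, and the object and method names are made up for this sketch.

    // Illustrative round-trip of the byte layout described in the
    // SecurityMessage comment above.
    import java.nio.ByteBuffer

    object SecurityPayloadSketch {
      def encode(connectionId: String, token: Array[Byte]): ByteBuffer = {
        val buffer = ByteBuffer.allocate(4 + connectionId.length * 2 + 4 + token.length)
        buffer.putInt(connectionId.length)            // length of the connectionId
        connectionId.foreach(c => buffer.putChar(c))  // connectionId, 2 bytes per char
        buffer.putInt(token.length)                   // length of the SASL token
        buffer.put(token)                             // raw token bytes
        buffer.flip()
        buffer
      }

      def decode(buffer: ByteBuffer): (String, Array[Byte]) = {
        val idLength = buffer.getInt()
        val idBuilder = new StringBuilder(idLength)
        (0 until idLength).foreach(_ => idBuilder.append(buffer.getChar()))
        val token = new Array[Byte](buffer.getInt())
        buffer.get(token)
        (idBuilder.toString, token)
      }
    }

A round trip such as decode(encode(id, token)) recovers the original pair, which mirrors how the receiving side reconstructs the connectionId and token before handing the token to its SASL server or client.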
From 4a57acc4f4b96067a19538ed7c06504de31d9025 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Mon, 3 Mar 2014 12:44:09 -0600 Subject: [PATCH 09/14] Change UI createHandler routines to createServlet since they now return servlets --- .../apache/spark/metrics/sink/MetricsServlet.scala | 2 +- .../scala/org/apache/spark/ui/JettyUtils.scala | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index c59cbbedf64ab..1729bcfc92e41 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -47,7 +47,7 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext def getHandlers = Array[ServletContextHandler]( JettyUtils.createServletHandler(servletPath, - JettyUtils.createHandler(request => getMetricsSnapshot(request), "text/json")) + JettyUtils.createServlet(request => getMetricsSnapshot(request), "text/json")) ) def getMetricsSnapshot(request: HttpServletRequest): String = { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 63bddea9b2c30..aa57f2687028e 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -46,16 +46,16 @@ private[spark] object JettyUtils extends Logging { type Responder[T] = HttpServletRequest => T // Conversions from various types of Responder's to jetty Handlers - implicit def jsonResponderToHandler(responder: Responder[JValue]): HttpServlet = - createHandler(responder, "text/json", (in: JValue) => pretty(render(in))) + implicit def jsonResponderToServlet(responder: Responder[JValue]): HttpServlet = + createServlet(responder, "text/json", (in: JValue) => pretty(render(in))) - implicit def htmlResponderToHandler(responder: Responder[Seq[Node]]): HttpServlet = - createHandler(responder, "text/html", (in: Seq[Node]) => "" + in.toString) + implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): HttpServlet = + createServlet(responder, "text/html", (in: Seq[Node]) => "" + in.toString) - implicit def textResponderToHandler(responder: Responder[String]): HttpServlet = - createHandler(responder, "text/plain") + implicit def textResponderToServlet(responder: Responder[String]): HttpServlet = + createServlet(responder, "text/plain") - def createHandler[T <% AnyRef](responder: Responder[T], contentType: String, + def createServlet[T <% AnyRef](responder: Responder[T], contentType: String, extractFn: T => String = (in: Any) => in.toString): HttpServlet = { new HttpServlet { override def doGet(request: HttpServletRequest, From 13733e1532cbd3fcd0bef59d4078771bf58892d2 Mon Sep 17 00:00:00 2001 From: Thomas Graves Date: Wed, 5 Mar 2014 20:09:52 -0600 Subject: [PATCH 10/14] Pass securityManager and SparkConf around where we can. Switch to use sparkConf for reading config whereever possible. Added ConnectionManagerSuite unit tests. 
--- .../org/apache/spark/SecurityManager.scala | 45 ++-- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 6 +- .../org/apache/spark/SparkSaslClient.scala | 42 ++-- .../org/apache/spark/SparkSaslServer.scala | 24 +- .../org/apache/spark/deploy/Client.scala | 2 +- .../spark/deploy/client/TestClient.scala | 3 +- .../apache/spark/deploy/master/Master.scala | 14 +- .../spark/deploy/master/ui/MasterWebUI.scala | 14 +- .../spark/deploy/worker/DriverWrapper.scala | 3 +- .../apache/spark/deploy/worker/Worker.scala | 10 +- .../spark/deploy/worker/ui/WorkerWebUI.scala | 16 +- .../CoarseGrainedExecutorBackend.scala | 3 +- .../org/apache/spark/executor/Executor.scala | 4 +- .../apache/spark/metrics/MetricsSystem.scala | 13 +- .../spark/metrics/sink/ConsoleSink.scala | 4 +- .../apache/spark/metrics/sink/CsvSink.scala | 4 +- .../spark/metrics/sink/GangliaSink.scala | 4 +- .../spark/metrics/sink/GraphiteSink.scala | 4 +- .../apache/spark/metrics/sink/JmxSink.scala | 5 +- .../spark/metrics/sink/MetricsServlet.scala | 8 +- .../org/apache/spark/network/Connection.scala | 1 + .../spark/network/ConnectionManager.scala | 9 +- .../apache/spark/network/ReceiverTest.scala | 3 +- .../org/apache/spark/network/SenderTest.scala | 4 +- .../apache/spark/storage/ThreadingTest.scala | 2 +- .../org/apache/spark/ui/JettyUtils.scala | 54 ++-- .../scala/org/apache/spark/ui/SparkUI.scala | 2 +- .../apache/spark/ui/env/EnvironmentUI.scala | 3 +- .../apache/spark/ui/exec/ExecutorsUI.scala | 3 +- .../apache/spark/ui/jobs/JobProgressUI.scala | 9 +- .../spark/ui/storage/BlockManagerUI.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 8 +- .../org/apache/spark/AkkaUtilsSuite.scala | 66 ++--- .../org/apache/spark/BroadcastSuite.scala | 2 - .../apache/spark/ConnectionManagerSuite.scala | 230 ++++++++++++++++++ .../scala/org/apache/spark/DriverSuite.scala | 1 - .../org/apache/spark/FileServerSuite.scala | 8 +- .../apache/spark/MapOutputTrackerSuite.scala | 4 +- .../spark/metrics/MetricsSystemSuite.scala | 8 +- .../spark/storage/BlockManagerSuite.scala | 4 +- .../scala/org/apache/spark/ui/UISuite.scala | 9 +- docs/configuration.md | 16 +- docs/security.md | 4 +- .../streaming/examples/ActorWordCount.scala | 6 +- .../org/apache/spark/repl/SparkIMain.scala | 11 +- .../org/apache/spark/repl/ReplSuite.scala | 2 - .../spark/deploy/yarn/ApplicationMaster.scala | 3 +- .../spark/deploy/yarn/WorkerLauncher.scala | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 2 +- .../spark/deploy/yarn/WorkerLauncher.scala | 2 +- 51 files changed, 513 insertions(+), 203 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0111f8186ebcf..591978c1d3630 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -122,9 +122,8 @@ import scala.collection.mutable.ArrayBuffer * filters to do authentication. That authentication then happens via the ResourceManager Proxy * and Spark will use that to do authorization against the view acls. * - * For other Spark deployments, the shared secret must be specified via the SPARK_SECRET - * environment variable. This isn't ideal but it means only the user who starts the process - * has access to view that variable. 
+ * For other Spark deployments, the shared secret must be specified via the + * spark.authenticate.secret config. * All the nodes (Master and Workers) and the applications need to have the same shared secret. * This again is not ideal as one user could potentially affect another users application. * This should be enhanced in the future to provide better protection. @@ -133,23 +132,24 @@ import scala.collection.mutable.ArrayBuffer * authorization. If not filter is in place the user is generally null and no authorization * can take place. */ -private[spark] class SecurityManager extends Logging { + +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" - private val authOn = System.getProperty("spark.authenticate", "false").toBoolean - private val uiAclsOn = System.getProperty("spark.ui.acls.enable", "false").toBoolean + private val authOn = sparkConf.getBoolean("spark.authenticate", false) + private val uiAclsOn = sparkConf.getBoolean("spark.ui.acls.enable", false) // always add the current user and SPARK_USER to the viewAcls private val aclUsers = ArrayBuffer[String](System.getProperty("user.name", ""), Option(System.getenv("SPARK_USER")).getOrElse("")) - aclUsers ++= System.getProperty("spark.ui.view.acls", "").split(',') + aclUsers ++= sparkConf.get("spark.ui.view.acls", "").split(',') private val viewAcls = aclUsers.map(_.trim()).filter(!_.isEmpty).toSet private val secretKey = generateSecretKey() logInfo("SecurityManager, is authentication enabled: " + authOn + - " are ui acls enabled: " + uiAclsOn) + " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. // This is needed by the HTTP client fetching from the HttpServer. Put here so its @@ -176,13 +176,13 @@ private[spark] class SecurityManager extends Logging { * The way the key is stored depends on the Spark deployment mode. Yarn * uses the Hadoop UGI. * - * For non-Yarn deployments, If the environment variable is not set - * we throw an exception. + * For non-Yarn deployments, If the config variable is not set + * we throw an exception. 
*/ private def generateSecretKey(): String = { if (!isAuthenticationEnabled) return null // first check to see if the secret is already set, else generate a new one if on yarn - if (SparkHadoopUtil.get.isYarnMode) { + val sCookie = if (SparkHadoopUtil.get.isYarnMode) { val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(sparkSecretLookupKey) if (secretKey != null) { logDebug("in yarn mode, getting secret from credentials") @@ -190,20 +190,19 @@ private[spark] class SecurityManager extends Logging { } else { logDebug("getSecretKey: yarn mode, secret key from credentials is null") } - } - val secret = System.getProperty("SPARK_SECRET", System.getenv("SPARK_SECRET")) - if (secret != null && !secret.isEmpty()) return secret - val sCookie = if (SparkHadoopUtil.get.isYarnMode) { - // generate one - akka.util.Crypt.generateSecureCookie - } else { - throw new Exception("Error: a secret key must be specified via SPARK_SECRET env variable") - } - if (SparkHadoopUtil.get.isYarnMode) { - // if we generated the secret then we must be the first so lets set it so t + val cookie = akka.util.Crypt.generateSecureCookie + // if we generated the secret then we must be the first so lets set it so t // gets used by everyone else - SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, sCookie) + SparkHadoopUtil.get.addSecretKeyToUserCredentials(sparkSecretLookupKey, cookie) logInfo("adding secret to credentials in yarn mode") + cookie + } else { + // user must have set spark.authenticate.secret config + sparkConf.getOption("spark.authenticate.secret") match { + case Some(value) => value + case None => throw new Exception("Error: a secret key must be specified via the " + + "spark.authenticate.secret config") + } } sCookie } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c716d2b69ca58..e7080bd832418 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -640,7 +640,7 @@ class SparkContext( addedFiles(key) = System.currentTimeMillis // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). 
- Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf, env.securityManager) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6f5b8cec90e1a..27920c6b7df74 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -124,7 +124,7 @@ object SparkEnv extends Logging { isDriver: Boolean, isLocal: Boolean): SparkEnv = { - val securityManager = new SecurityManager() + val securityManager = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port, conf = conf, securityManager = securityManager) @@ -197,9 +197,9 @@ object SparkEnv extends Logging { conf.set("spark.fileserver.uri", httpFileServer.serverUri) val metricsSystem = if (isDriver) { - MetricsSystem.createMetricsSystem("driver", conf) + MetricsSystem.createMetricsSystem("driver", conf, securityManager) } else { - MetricsSystem.createMetricsSystem("executor", conf) + MetricsSystem.createMetricsSystem("executor", conf, securityManager) } metricsSystem.start() diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala index 9af0440a357ca..a2a871cbd3c31 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -52,14 +52,16 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg * @return response to challenge if needed */ def firstToken(): Array[Byte] = { - val saslToken: Array[Byte] = - if (saslClient.hasInitialResponse()) { - logDebug("has initial response") - saslClient.evaluateChallenge(new Array[Byte](0)) - } else { - new Array[Byte](0) - } - saslToken + synchronized { + val saslToken: Array[Byte] = + if (saslClient != null && saslClient.hasInitialResponse()) { + logDebug("has initial response") + saslClient.evaluateChallenge(new Array[Byte](0)) + } else { + new Array[Byte](0) + } + saslToken + } } /** @@ -67,7 +69,9 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg * @return true is complete, otherwise false */ def isComplete(): Boolean = { - saslClient.isComplete() + synchronized { + if (saslClient != null) saslClient.isComplete() else false + } } /** @@ -76,7 +80,9 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg * @return client's response SASL token */ def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { - saslClient.evaluateChallenge(saslTokenMessage) + synchronized { + if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) + } } /** @@ -84,13 +90,15 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg * SaslClient might be using. 
*/ def dispose() { - if (saslClient != null) { - try { - saslClient.dispose() - } catch { - case e: SaslException => // ignored - } finally { - saslClient = null + synchronized { + if (saslClient != null) { + try { + saslClient.dispose() + } catch { + case e: SaslException => // ignored + } finally { + saslClient = null + } } } } diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala index 4a8213d7c296c..11fcb2ae3a5c5 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -47,7 +47,9 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi * @return true is complete, otherwise false */ def isComplete(): Boolean = { - saslServer.isComplete() + synchronized { + if (saslServer != null) saslServer.isComplete() else false + } } /** @@ -56,7 +58,9 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi * @return response to send back to the server. */ def response(token: Array[Byte]): Array[Byte] = { - saslServer.evaluateResponse(token) + synchronized { + if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) + } } /** @@ -64,13 +68,15 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi * SaslServer might be using. */ def dispose() { - if (saslServer != null) { - try { - saslServer.dispose() - } catch { - case e: SaslException => // ignore - } finally { - saslServer = null + synchronized { + if (saslServer != null) { + try { + saslServer.dispose() + } catch { + case e: SaslException => // ignore + } finally { + saslServer = null + } } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 76f73271b2b97..c5adf2f1f541a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -141,7 +141,7 @@ object Client { // TODO: See if we can initialize akka so return messages are sent back using the same TCP // flow. Else, this (sadly) requires the DriverClient be routable from the Master. 
val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager) + "driverClient", Utils.localHostName(), 0, false, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index dd44866d27903..58e532776b086 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -45,8 +45,9 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) + val conf = new SparkConf val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, - conf = new SparkConf, securityManager = new SecurityManager()) + conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription( "TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), Some("dummy-spark-home"), "ignored") diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1adf7d29437da..acf66b36a06c5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -40,7 +40,8 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} import org.apache.spark.deploy.master.DriverState.DriverState -private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { +private[spark] class Master(host: String, port: Int, webUiPort: Int, + val securityMgr: SecurityManager) extends Actor with Logging { import context.dispatcher // to use Akka's scheduler.schedule() val conf = new SparkConf @@ -71,8 +72,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act Utils.checkHost(host, "Expected hostname") - val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf) - val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf) + val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) + val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, + securityMgr) val masterSource = new MasterSource(this) val webUi = new MasterWebUI(this, webUiPort) @@ -712,9 +714,11 @@ private[spark] object Master { def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf) : (ActorSystem, Int, Int) = { + val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = new SecurityManager) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName) + securityManager = securityMgr) + val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, + securityMgr), actorName) val timeout = AkkaUtils.askTimeout(conf) val respFuture = actor.ask(RequestWebUIPort)(timeout) val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 728fa5f9e0626..a7bd01e284c8e 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -46,7 +46,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { def start() { try { - val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers) + val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, master.conf) server = Some(srv) boundPort = Some(bPort) logInfo("Started Master web UI at http://%s:%d".format(host, boundPort.get)) @@ -63,10 +63,14 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { val handlers = metricsHandlers ++ Seq[ServletContextHandler]( createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static/*"), createServletHandler("/app/json", - (request: HttpServletRequest) => applicationPage.renderJson(request)), - createServletHandler("/app", (request: HttpServletRequest) => applicationPage.render(request)), - createServletHandler("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - createServletHandler("*", (request: HttpServletRequest) => indexPage.render(request)) + createServlet((request: HttpServletRequest) => applicationPage.renderJson(request), + master.securityMgr)), + createServletHandler("/app", createServlet((request: HttpServletRequest) => applicationPage + .render(request), master.securityMgr)), + createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage + .renderJson(request), master.securityMgr)), + createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render + (request), master.securityMgr)) ) def stop() { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 0c91c89714009..be15138f62406 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -29,8 +29,9 @@ object DriverWrapper { def main(args: Array[String]) { args.toList match { case workerUrl :: mainClass :: extraArgs => + val conf = new SparkConf() val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", - Utils.localHostName(), 0, false, new SparkConf(), new SecurityManager()) + Utils.localHostName(), 0, false, conf, new SecurityManager(conf)) actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") // Delegate to supplied main class diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 7411479b36909..c5cffc76b2f88 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -49,7 +49,8 @@ private[spark] class Worker( actorSystemName: String, actorName: String, workDirPath: String = null, - val conf: SparkConf) + val conf: SparkConf, + val securityMgr: SecurityManager) extends Actor with Logging { import context.dispatcher @@ -92,7 +93,7 @@ private[spark] class Worker( var coresUsed = 0 var memoryUsed = 0 - val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) def coresFree: Int = cores - coresUsed @@ -348,10 +349,11 @@ private[spark] object Worker { val conf = new SparkConf val systemName = "sparkWorker" + 
workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" + val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = new SecurityManager) + conf = conf, securityManager = securityMgr) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, systemName, actorName, workDir, conf), name = actorName) + masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index dcdcce4ae31b3..f5c1e6163c2a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -34,7 +34,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} */ private[spark] class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) - extends Logging { + extends Logging { val timeout = AkkaUtils.askTimeout(worker.conf) val host = Utils.localHostName() val port = requestedPort.getOrElse( @@ -49,15 +49,19 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I val handlers = metricsHandlers ++ Seq[ServletContextHandler]( createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static/*"), - createServletHandler("/log", (request: HttpServletRequest) => log(request)), - createServletHandler("/logPage", (request: HttpServletRequest) => logPage(request)), - createServletHandler("/json", (request: HttpServletRequest) => indexPage.renderJson(request)), - createServletHandler("*", (request: HttpServletRequest) => indexPage.render(request)) + createServletHandler("/log", createServlet((request: HttpServletRequest) => log(request), + worker.securityMgr)), + createServletHandler("/logPage", createServlet((request: HttpServletRequest) => logPage + (request), worker.securityMgr)), + createServletHandler("/json", createServlet((request: HttpServletRequest) => indexPage + .renderJson(request), worker.securityMgr)), + createServletHandler("*", createServlet((request: HttpServletRequest) => indexPage.render + (request), worker.securityMgr)) ) def start() { try { - val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers) + val (srv, bPort) = JettyUtils.startJettyServer(host, port, handlers, worker.conf) server = Some(srv) boundPort = Some(bPort) logInfo("Started Worker web UI at http://%s:%d".format(host, bPort)) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index ad7f2b97a06f3..a4627a0b94b21 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -97,10 +97,11 @@ private[spark] object CoarseGrainedExecutorBackend { // Debug code Utils.checkHost(hostname) + val conf = new SparkConf // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - indestructible = true, conf = new SparkConf, new SecurityManager) + indestructible = true, conf = conf, new SecurityManager(conf)) // set it val 
sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 77e6ad2ceafe0..e69f6f72d3275 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -339,12 +339,12 @@ private[spark] class Executor( // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index de233e416a9dc..906c7933377b6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit import scala.collection.mutable -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} import org.apache.spark.metrics.source.Source @@ -64,7 +64,7 @@ import org.apache.spark.metrics.source.Source * [options] is the specific property of this source or sink. 
*/ private[spark] class MetricsSystem private (val instance: String, - conf: SparkConf) extends Logging { + conf: SparkConf, securityMgr: SecurityManager) extends Logging { val confFile = conf.get("spark.metrics.conf", null) val metricsConfig = new MetricsConfig(Option(confFile)) @@ -131,8 +131,8 @@ private[spark] class MetricsSystem private (val instance: String, val classPath = kv._2.getProperty("class") try { val sink = Class.forName(classPath) - .getConstructor(classOf[Properties], classOf[MetricRegistry]) - .newInstance(kv._2, registry) + .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) + .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { metricsServlet = Some(sink.asInstanceOf[MetricsServlet]) } else { @@ -160,6 +160,7 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem(instance: String, conf: SparkConf): MetricsSystem = - new MetricsSystem(instance, conf) + def createMetricsSystem(instance: String, conf: SparkConf, + securityMgr: SecurityManager): MetricsSystem = + new MetricsSystem(instance, conf, securityMgr) } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index bce257d6e6f47..2fe2f5cb2d219 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -22,9 +22,11 @@ import com.codahale.metrics.{ConsoleReporter, MetricRegistry} import java.util.Properties import java.util.concurrent.TimeUnit +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class ConsoleSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class ConsoleSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val CONSOLE_DEFAULT_PERIOD = 10 val CONSOLE_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 3d1a06a395a72..c8f112c5d1f43 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -23,9 +23,11 @@ import java.io.File import java.util.{Locale, Properties} import java.util.concurrent.TimeUnit +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class CsvSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class CsvSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val CSV_KEY_PERIOD = "period" val CSV_KEY_UNIT = "unit" val CSV_KEY_DIR = "directory" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala index b924907070eb9..28a215f14f0e0 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GangliaSink.scala @@ -24,9 +24,11 @@ import com.codahale.metrics.ganglia.GangliaReporter import com.codahale.metrics.MetricRegistry import info.ganglia.gmetric4j.gmetric.GMetric +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GangliaSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class GangliaSink(val property: Properties, val registry: MetricRegistry, + securityMgr: 
SecurityManager) extends Sink { val GANGLIA_KEY_PERIOD = "period" val GANGLIA_DEFAULT_PERIOD = 10 diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index cdcfec8ca785b..b17e78b370b44 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -24,9 +24,11 @@ import java.net.InetSocketAddress import com.codahale.metrics.MetricRegistry import com.codahale.metrics.graphite.{GraphiteReporter, Graphite} +import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink { +class GraphiteSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val GRAPHITE_DEFAULT_PERIOD = 10 val GRAPHITE_DEFAULT_UNIT = "SECONDS" val GRAPHITE_DEFAULT_PREFIX = "" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index 621d086d415cc..c108072a5ef44 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -21,7 +21,10 @@ import com.codahale.metrics.{JmxReporter, MetricRegistry} import java.util.Properties -class JmxSink(val property: Properties, val registry: MetricRegistry) extends Sink { +import org.apache.spark.SecurityManager + +class JmxSink(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() override def start() { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 1729bcfc92e41..28247d39b2e38 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -29,9 +29,11 @@ import javax.servlet.http.HttpServletRequest import org.eclipse.jetty.servlet.ServletContextHandler +import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils -class MetricsServlet(val property: Properties, val registry: MetricRegistry) extends Sink { +class MetricsServlet(val property: Properties, val registry: MetricRegistry, + securityMgr: SecurityManager) extends Sink { val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" @@ -47,7 +49,9 @@ class MetricsServlet(val property: Properties, val registry: MetricRegistry) ext def getHandlers = Array[ServletContextHandler]( JettyUtils.createServletHandler(servletPath, - JettyUtils.createServlet(request => getMetricsSnapshot(request), "text/json")) + JettyUtils.createServlet( + new JettyUtils.ServletParams(request => getMetricsSnapshot(request), "text/json"), + securityMgr) ) ) def getMetricsSnapshot(request: HttpServletRequest): String = { diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 4eadf66d98698..d683567eac366 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -88,6 +88,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, if (sparkSaslServer != null) { sparkSaslServer.dispose(); } 
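MetricsSystem now instantiates every configured sink reflectively through a (Properties, MetricRegistry, SecurityManager) constructor, which is why each bundled sink above gains an extra securityMgr parameter even when it does not use it. Any sink on the classpath has to expose the same signature; below is a minimal sketch of a conforming sink, where the class name NoOpSink is purely illustrative and it is assumed that the Sink trait of this version only requires start() and stop():

    package org.apache.spark.metrics.sink

    import java.util.Properties

    import com.codahale.metrics.MetricRegistry

    import org.apache.spark.SecurityManager

    // Hypothetical sink showing the constructor shape MetricsSystem now looks up via
    // getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]).
    class NoOpSink(val property: Properties, val registry: MetricRegistry,
        securityMgr: SecurityManager) extends Sink {
      override def start() { }  // nothing to start in this sketch
      override def stop() { }   // nothing to stop
    }

Such a sink would then be wired up through the metrics configuration file in the usual [instance].sink.[name].class form, for example *.sink.noop.class=org.apache.spark.metrics.sink.NoOpSink.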
+ if (sparkSaslClient != null) { sparkSaslClient.dispose() } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index db74a182af358..22e01fbe8c2f7 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -55,8 +55,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, private val selector = SelectorProvider.provider.openSelector() // default to 30 second timeout waiting for authentication - private val authTimeout = System.getProperty("spark.core.connection.auth.wait.timeout", - "30").toInt + private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30) private val handleMessageExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.handler.threads.min", 20), @@ -773,7 +772,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, logDebug("getAuthenticated wait connectionid: " + connection.connectionId) // have timeout in case remote side never responds connection.getAuthenticated().wait(500) - if (((clock.getTime() - startTime) >= (authTimeout * 1000)) && (!connection.isSaslComplete())) { + if (((clock.getTime() - startTime) >= (authTimeout * 1000)) + && (!connection.isSaslComplete())) { // took to long to authenticate the connection, something probably went wrong throw new Exception("Took to long for authentication to " + connectionManagerId + ", waited " + authTimeout + "seconds, failing.") @@ -854,7 +854,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, private[spark] object ConnectionManager { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf, new SecurityManager) + val conf = new SparkConf + val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index ba915dc74cdcf..9208250468d16 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -23,7 +23,8 @@ import org.apache.spark.{SecurityManager, SparkConf} private[spark] object ReceiverTest { def main(args: Array[String]) { - val manager = new ConnectionManager(9999, new SparkConf, new SecurityManager) + val conf = new SparkConf + val manager = new ConnectionManager(9999, conf, new SecurityManager(conf)) println("Started connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index 234b729690fe8..1b7c838ed440a 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -32,8 +32,8 @@ private[spark] object SenderTest { val targetHost = args(0) val targetPort = args(1).toInt val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - - val manager = new ConnectionManager(0, new SparkConf, new SecurityManager) + val conf = new SparkConf + val manager = new ConnectionManager(0, conf, new SecurityManager(conf)) println("Started 
connection manager with id = " + manager.id) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 53b56e3ab544c..adccef757b0c4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -98,7 +98,7 @@ private[spark] object ThreadingTest { actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf))), conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager()) + new SecurityManager(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index aa57f2687028e..b45b1aa241271 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -29,13 +29,13 @@ import scala.xml.Node import net.liftweb.json.{JValue, pretty, render} import org.eclipse.jetty.server.{DispatcherType, Server} -import org.eclipse.jetty.server.handler.{ResourceHandler, HandlerList, ContextHandler, AbstractHandler} +import org.eclipse.jetty.server.handler.HandlerList import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.Logging -import org.apache.spark.SparkEnv import org.apache.spark.SecurityManager +import org.apache.spark.SparkConf /** Utilities for launching a web server using Jetty's HTTP Server class */ @@ -45,31 +45,31 @@ private[spark] object JettyUtils extends Logging { type Responder[T] = HttpServletRequest => T - // Conversions from various types of Responder's to jetty Handlers - implicit def jsonResponderToServlet(responder: Responder[JValue]): HttpServlet = - createServlet(responder, "text/json", (in: JValue) => pretty(render(in))) + class ServletParams[T <% AnyRef](val responder: Responder[T], + val contentType: String, + val extractFn: T => String = (in: Any) => in.toString) {} - implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): HttpServlet = - createServlet(responder, "text/html", (in: Seq[Node]) => "" + in.toString) + // Conversions from various types of Responder's to appropriate servlet parameters + implicit def jsonResponderToServlet(responder: Responder[JValue]): ServletParams[JValue] = + new ServletParams(responder, "text/json", (in: JValue) => pretty(render(in))) - implicit def textResponderToServlet(responder: Responder[String]): HttpServlet = - createServlet(responder, "text/plain") + implicit def htmlResponderToServlet(responder: Responder[Seq[Node]]): ServletParams[Seq[Node]] = + new ServletParams(responder, "text/html", (in: Seq[Node]) => "" + in.toString) - def createServlet[T <% AnyRef](responder: Responder[T], contentType: String, - extractFn: T => String = (in: Any) => in.toString): HttpServlet = { + implicit def textResponderToServlet(responder: Responder[String]): ServletParams[String] = + new ServletParams(responder, "text/plain") + + def createServlet[T <% AnyRef](servletParams: ServletParams[T], + securityMgr: SecurityManager): HttpServlet = { new HttpServlet { override def 
doGet(request: HttpServletRequest, response: HttpServletResponse) { - // First try to get the security Manager from the SparkEnv. If that doesn't exist, create - // a new one and rely on the configs being set - val sparkEnv = SparkEnv.get - val securityMgr = if (sparkEnv != null) sparkEnv.securityManager else new SecurityManager() if (securityMgr.checkUIViewPermissions(request.getRemoteUser())) { - response.setContentType("%s;charset=utf-8".format(contentType)) + response.setContentType("%s;charset=utf-8".format(servletParams.contentType)) response.setStatus(HttpServletResponse.SC_OK) - val result = responder(request) + val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - response.getWriter().println(extractFn(result)) + response.getWriter().println(servletParams.extractFn(result)) } else { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") @@ -120,8 +120,8 @@ private[spark] object JettyUtils extends Logging { contextHandler } - private def addFilters(handlers: Seq[ServletContextHandler]) { - val filters: Array[String] = System.getProperty("spark.ui.filters", "").split(',').map(_.trim()) + private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) { + val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim()) filters.foreach { case filter : String => if (!filter.isEmpty) { @@ -129,8 +129,8 @@ private[spark] object JettyUtils extends Logging { val holder : FilterHolder = new FilterHolder() holder.setClassName(filter) // get any parameters for each filter - val paramName = filter + ".params" - val params = System.getProperty(paramName, "").split(',').map(_.trim()).toSet + val paramName = "spark." + filter + ".params" + val params = conf.get(paramName, "").split(',').map(_.trim()).toSet params.foreach { case param : String => if (!param.isEmpty) { @@ -152,10 +152,10 @@ private[spark] object JettyUtils extends Logging { * If the desired port number is contented, continues incrementing ports until a free port is * found. Returns the chosen port and the jetty Server object. 
*/ - def startJettyServer(hostName: String, port: Int, handlers: Seq[ServletContextHandler]): - (Server, Int) = { + def startJettyServer(hostName: String, port: Int, handlers: Seq[ServletContextHandler], + conf: SparkConf): (Server, Int) = { - addFilters(handlers) + addFilters(handlers, conf) val handlerList = new HandlerList handlerList.setHandlers(handlers.toArray) @@ -167,7 +167,9 @@ private[spark] object JettyUtils extends Logging { server.setThreadPool(pool) server.setHandler(handlerList) - Try { server.start() } match { + Try { + server.start() + } match { case s: Success[_] => (server, server.getConnectors.head.getLocalPort) case f: Failure[_] => diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index b8e2d4e91bdbb..a60546e41be7c 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -55,7 +55,7 @@ private[spark] class SparkUI(sc: SparkContext) extends Logging { /** Bind the HTTP server which backs this web interface */ def bind() { try { - val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers) + val (srv, usedPort) = JettyUtils.startJettyServer(host, port, allHandlers, sc.conf) logInfo("Started Spark Web UI at http://%s:%d".format(host, usedPort)) server = Some(srv) boundPort = Some(usedPort) diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala index 37fb87fc37a85..5728f0244478c 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -34,7 +34,8 @@ import org.apache.spark.SparkContext private[spark] class EnvironmentUI(sc: SparkContext) { def getHandlers = Seq[ServletContextHandler]( - createServletHandler("/environment", (request: HttpServletRequest) => envDetails(request)) + createServletHandler("/environment", + createServlet((request: HttpServletRequest) => envDetails(request), sc.env.securityManager)) ) def envDetails(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 36d86279959d5..7a4b8cabdb769 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -45,7 +45,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { } def getHandlers = Seq[ServletContextHandler]( - createServletHandler("/executors", (request: HttpServletRequest) => render(request)) + createServletHandler("/executors", createServlet((request: HttpServletRequest) => render + (request), sc.env.securityManager)) ) def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index cc69aba0b5651..3a9a85076af0a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -55,10 +55,13 @@ private[spark] class JobProgressUI(val sc: SparkContext) { def getHandlers = Seq[ServletContextHandler]( createServletHandler("/stages/stage", - (request: HttpServletRequest) => stagePage.render(request)), + createServlet((request: HttpServletRequest) => stagePage.render(request), + sc.env.securityManager)), 
createServletHandler("/stages/pool", - (request: HttpServletRequest) => poolPage.render(request)), + createServlet((request: HttpServletRequest) => poolPage.render(request), + sc.env.securityManager)), createServletHandler("/stages", - (request: HttpServletRequest) => indexPage.render(request)) + createServlet((request: HttpServletRequest) => indexPage.render(request), + sc.env.securityManager)) ) } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index 8de6a16772dd1..1b3811b16d6fa 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -32,7 +32,11 @@ private[spark] class BlockManagerUI(val sc: SparkContext) extends Logging { val rddPage = new RDDPage(this) def getHandlers = Seq[ServletContextHandler]( - createServletHandler("/storage/rdd", (request: HttpServletRequest) => rddPage.render(request)), - createServletHandler("/storage", (request: HttpServletRequest) => indexPage.render(request)) + createServletHandler("/storage/rdd", + createServlet((request: HttpServletRequest) => rddPage.render(request), + sc.env.securityManager)), + createServletHandler("/storage", + createServlet((request: HttpServletRequest) => indexPage.render(request), + sc.env.securityManager)) ) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 553d33cedd656..d25f1d7e218b3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -38,7 +38,7 @@ import org.apache.hadoop.io._ import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, Logging} +import org.apache.spark.{SecurityManager, SparkConf, SparkException, Logging} /** @@ -259,7 +259,7 @@ private[spark] object Utils extends Logging { * Throws SparkException if the target file already exists and has different contents than * the requested file. */ - def fetchFile(url: String, targetDir: File, conf: SparkConf) { + def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager) { val filename = url.split("/").last val tempDir = getLocalDir(conf) val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) @@ -271,10 +271,6 @@ private[spark] object Utils extends Logging { logInfo("Fetching " + url + " to " + tempFile) var uc: URLConnection = null - // First try to get the security Manager from the SparkEnv. If that doesn't exist, create - // a new one and rely on the configs being set - val sparkEnv = SparkEnv.get - val securityMgr = if (sparkEnv != null) sparkEnv.securityManager else new SecurityManager() if (securityMgr.isAuthenticationEnabled()) { logDebug("fetchFile with security enabled") val newuri = constructURIForAuthentication(uri, securityMgr) diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala index 1d33dd5db6a32..cd054c1f684ab 100644 --- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -29,13 +29,13 @@ import scala.concurrent.Await * Test the AkkaUtils with various security settings. 
*/ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { - private val conf = new SparkConf test("remote fetch security bad password") { - System.setProperty("spark.authenticate", "true") - System.setProperty("SPARK_SECRET", "good") + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(); + val securityManager = new SecurityManager(conf); val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) @@ -47,9 +47,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") - System.setProperty("spark.authenticate", "true") - System.setProperty("SPARK_SECRET", "bad") - val securityManagerBad= new SecurityManager(); + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(badconf); assert(securityManagerBad.isAuthenticationEnabled() === true) @@ -68,9 +69,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { } test("remote fetch security off") { - System.setProperty("spark.authenticate", "false") - System.setProperty("SPARK_SECRET", "bad") - val securityManager = new SecurityManager(); + val conf = new SparkConf + conf.set("spark.authenticate", "false") + conf.set("spark.authenticate.secret", "bad") + val securityManager = new SecurityManager(conf); val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, @@ -84,18 +86,20 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") - System.setProperty("spark.authenticate", "false") - System.setProperty("SPARK_SECRET", "good") + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + badconf.set("spark.authenticate.secret", "good") + val securityManagerBad = new SecurityManager(badconf); val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = conf, securityManager = securityManager) + conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) - assert(securityManager.isAuthenticationEnabled() === false) + assert(securityManagerBad.isAuthenticationEnabled() === false) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -117,9 +121,10 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { } test("remote fetch security pass") { - System.setProperty("spark.authenticate", "true") - System.setProperty("SPARK_SECRET", "good") - val securityManager = new SecurityManager(); + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf); val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, @@ -133,13 +138,15 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { masterTracker.trackerActor = 
actorSystem.actorOf( Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") - System.setProperty("spark.authenticate", "true") - System.setProperty("SPARK_SECRET", "good") + val goodconf = new SparkConf + goodconf.set("spark.authenticate", "true") + goodconf.set("spark.authenticate.secret", "good") + val securityManagerGood = new SecurityManager(goodconf); - assert(securityManager.isAuthenticationEnabled() === true) + assert(securityManagerGood.isAuthenticationEnabled() === true) val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = conf, securityManager = securityManager) + conf = goodconf, securityManager = securityManagerGood) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") @@ -166,9 +173,11 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { } test("remote fetch security off client") { - System.setProperty("spark.authenticate", "true") - System.setProperty("SPARK_SECRET", "good") - val securityManager = new SecurityManager(); + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + + val securityManager = new SecurityManager(conf); val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, @@ -182,14 +191,15 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { masterTracker.trackerActor = actorSystem.actorOf( Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") - System.setProperty("spark.authenticate", "false") - System.setProperty("SPARK_SECRET", "bad") - val securityManagerBad = new SecurityManager(); + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + badconf.set("spark.authenticate.secret", "bad") + val securityManagerBad = new SecurityManager(badconf); assert(securityManagerBad.isAuthenticationEnabled() === false) val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, - conf = conf, securityManager = securityManagerBad) + conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index 3bfd7d94d26ab..96ba3929c1685 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -21,8 +21,6 @@ import org.scalatest.FunSuite class BroadcastSuite extends FunSuite with LocalSparkContext { - System.setProperty("spark.authenticate", "false") - override def afterEach() { super.afterEach() diff --git a/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala new file mode 100644 index 0000000000000..80f7ec00c74b2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ConnectionManagerSuite.scala @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.FunSuite + +import java.nio._ + +import org.apache.spark.network.{ConnectionManager, Message, ConnectionManagerId} +import scala.concurrent.Await +import scala.concurrent.TimeoutException +import scala.concurrent.duration._ + + +/** + * Test the ConnectionManager with various security settings. + */ +class ConnectionManagerSuite extends FunSuite { + + test("security default off") { + val conf = new SparkConf + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var receivedMessage = false + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + receivedMessage = true + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(manager.id, bufferMessage) + + assert(receivedMessage == true) + + manager.stop() + } + + test("security on same password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + val managerServer = new ConnectionManager(0, conf, securityManager) + var numReceivedServerMessages = 0 + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val count = 10 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + + (0 until count).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(managerServer.id, bufferMessage) + }) + + assert(numReceivedServerMessages == 10) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + test("security mismatch password") { + val conf = new SparkConf + conf.set("spark.authenticate", "true") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "bad") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = 
Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliablySync(managerServer.id, bufferMessage) + + assert(numReceivedServerMessages == 0) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + test("security mismatch auth off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + conf.set("spark.authenticate.secret", "good") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "true") + badconf.set("spark.authenticate.secret", "good") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + (0 until 1).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(managerServer.id, bufferMessage) + }).foreach(f => { + try { + val g = Await.result(f, 1 second) + assert(false) + } catch { + case e: TimeoutException => { + // we should timeout here since the client can't do the negotiation + assert(true) + } + } + }) + + assert(numReceivedServerMessages == 0) + assert(numReceivedMessages == 0) + manager.stop() + managerServer.stop() + } + + test("security auth off") { + val conf = new SparkConf + conf.set("spark.authenticate", "false") + val securityManager = new SecurityManager(conf) + val manager = new ConnectionManager(0, conf, securityManager) + var numReceivedMessages = 0 + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedMessages += 1 + None + }) + + val badconf = new SparkConf + badconf.set("spark.authenticate", "false") + val badsecurityManager = new SecurityManager(badconf) + val managerServer = new ConnectionManager(0, badconf, badsecurityManager) + var numReceivedServerMessages = 0 + + managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + numReceivedServerMessages += 1 + None + }) + + val size = 10 * 1024 * 1024 + val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) + buffer.flip + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + (0 until 10).map(i => { + val bufferMessage = Message.createBufferMessage(buffer.duplicate) + manager.sendMessageReliably(managerServer.id, bufferMessage) + }).foreach(f => { + try { + val g = Await.result(f, 1 second) + if (!g.isDefined) assert(false) else assert(true) + } catch { + case e: Exception => { + assert(false) + } + } + }) + assert(numReceivedServerMessages == 10) + assert(numReceivedMessages == 0) + + manager.stop() + managerServer.stop() + } + + + +} + diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 0c6b5b8488878..510c35a8b5686 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -51,7 +51,6 @@ class DriverSuite extends FunSuite with Timeouts { */ object 
DriverWithoutCleanup { def main(args: Array[String]) { - System.setProperty("spark.authenticate", "false") Logger.getRootLogger().setLevel(Level.WARN) val sc = new SparkContext(args(0), "DriverWithoutCleanup") sc.parallelize(1 to 100, 4).count() diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 06e597d90186a..c3bfaa2855371 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -84,11 +84,13 @@ class FileServerSuite extends FunSuite with LocalSparkContext { } test("Distributing files locally security On") { - System.setProperty("spark.authenticate", "true") - System.setProperty("SPARK_SECRET", "good") + val sparkConf = new SparkConf(false) + sparkConf.set("spark.authenticate", "true") + sparkConf.set("spark.authenticate.secret", "good") + sc = new SparkContext("local[4]", "test", sparkConf) - sc = new SparkContext("local[4]", "test") sc.addFile(tmpFile.toString) + assert(sc.env.securityManager.isAuthenticationEnabled() === true) val testData = Array((1,1), (1,1), (2,1), (3,5), (2,2), (3,0)) val result = sc.parallelize(testData).reduceByKey { val path = SparkFiles.get("FileServerSuite.txt") diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index e9319a9063776..8ce5834cc12e2 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -98,7 +98,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("remote fetch") { val hostname = "localhost" val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, - securityManager = new SecurityManager) + securityManager = new SecurityManager(conf)) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext val masterTracker = new MapOutputTrackerMaster(conf) @@ -106,7 +106,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, - securityManager = new SecurityManager) + securityManager = new SecurityManager(conf)) val slaveTracker = new MapOutputTracker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 71a2c6c498eef..755962cb298e8 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -19,19 +19,21 @@ package org.apache.spark.metrics import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.deploy.master.MasterSource -import org.apache.spark.SparkConf +import org.apache.spark.{SecurityManager, SparkConf} class MetricsSystemSuite extends FunSuite with BeforeAndAfter { var filePath: String = _ var conf: SparkConf = null + var securityMgr: SecurityManager = null before { filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() conf = new SparkConf(false).set("spark.metrics.conf", filePath) + securityMgr = new SecurityManager(conf) } 
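The test updates above all follow one pattern: authentication is no longer toggled through JVM system properties such as spark.authenticate or SPARK_SECRET but through a SparkConf handed to SecurityManager. A minimal sketch of that pattern (the secret value is illustrative):

    import org.apache.spark.{SecurityManager, SparkConf}

    object SecuritySettingsSketch {
      def main(args: Array[String]) {
        // An isolated conf (loadDefaults = false) keeps stray system properties from
        // other tests out; the SecurityManager then reads its settings from this conf.
        val conf = new SparkConf(false)
          .set("spark.authenticate", "true")
          .set("spark.authenticate.secret", "good")
        val securityMgr = new SecurityManager(conf)
        assert(securityMgr.isAuthenticationEnabled())
      }
    }

Because each suite builds its own conf, a "bad" secret can coexist with a "good" one in the same JVM, which is what the mismatched-password tests above rely on.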
test("MetricsSystem with default config") { - val metricsSystem = MetricsSystem.createMetricsSystem("default", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) val sources = metricsSystem.sources val sinks = metricsSystem.sinks @@ -41,7 +43,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter { } test("MetricsSystem with sources add") { - val metricsSystem = MetricsSystem.createMetricsSystem("test", conf) + val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) val sources = metricsSystem.sources val sinks = metricsSystem.sinks diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9b2454026b4ae..180dc044edbb3 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -41,8 +41,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var actorSystem: ActorSystem = null var master: BlockManagerMaster = null var oldArch: String = null - System.setProperty("spark.authenticate", "false") - val securityMgr = new SecurityManager() + conf.set("spark.authenticate", "false") + val securityMgr = new SecurityManager(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index c17bbfe7d35ba..f97f68ae137e5 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -21,6 +21,7 @@ import scala.util.{Failure, Success, Try} import java.net.ServerSocket import org.scalatest.FunSuite import org.eclipse.jetty.server.Server +import org.apache.spark.SparkConf class UISuite extends FunSuite { test("jetty port increases under contention") { @@ -32,15 +33,17 @@ class UISuite extends FunSuite { case Failure(e) => // Either case server port is busy hence setup for test complete } - val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq()) - val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq()) + val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(), + new SparkConf) + val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("0.0.0.0", startPort, Seq(), + new SparkConf) // Allow some wiggle room in case ports on the machine are under contention assert(boundPort1 > startPort && boundPort1 < startPort + 10) assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) } test("jetty binds to port 0 correctly") { - val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq()) + val (jettyServer, boundPort) = JettyUtils.startJettyServer("0.0.0.0", 0, Seq(), new SparkConf) assert(jettyServer.getState === "STARTED") assert(boundPort != 0) Try {new ServerSocket(boundPort)} match { diff --git a/docs/configuration.md b/docs/configuration.md index c94ef26a12739..ef990d7c63896 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -153,8 +153,8 @@ Apart from these, the following properties are also available, and may be useful Comma separated list of filter class names to apply to the Spark web ui. The filter should be a standard javax servlet Filter. 
Parameters to each filter can also be specified by setting a - java system property of <class name of filter>.params='param1=value1,param2=value2' - (e.g.-Dspark.ui.filters=com.test.filter1 -Dcom.test.filter1.params='param1=foo,param2=testing') + java system property of spark.<class name of filter>.params='param1=value1,param2=value2' + (e.g.-Dspark.ui.filters=com.test.filter1 -Dspark.com.test.filter1.params='param1=foo,param2=testing') @@ -509,10 +509,18 @@ Apart from these, the following properties are also available, and may be useful spark.authenticate false - Whether spark authenticates its internal connections. See SPARK_SECRET if not + Whether spark authenticates its internal connections. See spark.authenticate.secret if not running on Yarn. + + spark.authenticate.secret + None + + Set the secret key used for Spark to authenticate between components. This needs to be set if + not running on Yarn and authentication is enabled. + + spark.core.connection.auth.wait.timeout 30 @@ -551,8 +559,6 @@ The following variables can be set in `spark-env.sh`: * `SPARK_JAVA_OPTS`, to add JVM options. This includes Java options like garbage collector settings and any system properties that you'd like to pass with `-D`. One use case is to set some Spark properties differently on this machine, e.g., `-Dspark.local.dir=/disk1,/disk2`. -* `SPARK_SECRET`, Set the secret key used for Spark to authenticate between components. This needs to be set if - not running on Yarn and authentication is enabled. * Options for the Spark [standalone cluster scripts](spark-standalone.html#cluster-launch-scripts), such as number of cores to use on each machine and maximum memory. diff --git a/docs/security.md b/docs/security.md index aa61dfc354c19..9e4218fbcfe7d 100644 --- a/docs/security.md +++ b/docs/security.md @@ -9,7 +9,9 @@ The Spark UI can also be secured by using javax servlet filters. A user may want For Spark on Yarn deployments, configuring `spark.authenticate` to true will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. The Spark UI uses the standard YARN web application proxy mechanism and will authenticate via any installed Hadoop filters. If an authentication filter is enabled, the acls controls can be used by control which users can via the Spark UI. -For other types of Spark deployments, the environment variable `SPARK_SECRET` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. The UI can be secured using a javax servlet filter installed via `spark.ui.filters`. If an authentication filter is enabled, the acls controls can be used by control which users can via the Spark UI. +For other types of Spark deployments, the Spark config `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. The UI can be secured using a javax servlet filter installed via `spark.ui.filters`. If an authentication filter is enabled, the acls controls can be used to control which users can view the Spark UI. + +IMPORTANT NOTE: The NettyBlockFetcherIterator is not secured so do not use netty for the shuffle when running with authentication on. See [Spark Configuration](configuration.html) for more details on the security configs.
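Pulling the documented settings together, a non-YARN application would enable authentication and UI filtering roughly as sketched below; com.example.AuthFilter, the secret value, and the app/master names are placeholders rather than classes or defaults shipped with Spark:

    import org.apache.spark.{SparkConf, SparkContext}

    object SecuredAppSketch {
      def main(args: Array[String]) {
        val conf = new SparkConf()
          .setMaster("local[2]")                                 // placeholder master
          .setAppName("secured-app")
          .set("spark.authenticate", "true")                     // authenticate internal connections
          .set("spark.authenticate.secret", "my-shared-secret")  // shared secret, required when not on Yarn
          .set("spark.ui.filters", "com.example.AuthFilter")     // hypothetical servlet filter class
          .set("spark.com.example.AuthFilter.params", "param1=foo,param2=bar")  // filter params, now namespaced under spark.

        val sc = new SparkContext(conf)
        println(sc.parallelize(1 to 10).count())
        sc.stop()
      }
    }

The same values can equally be supplied as -D system properties (for example through SPARK_JAVA_OPTS), which is the form the configuration.md example above uses.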
diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
index a4e900ce9d0bd..62d3a52615584 100644
--- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala
@@ -112,9 +112,9 @@ object FeederActor {
     }
 
     val Seq(host, port) = args.toSeq
-
-    val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = new SparkConf,
-      securityManager = new SecurityManager)._1
+    val conf = new SparkConf
+    val actorSystem = AkkaUtils.createActorSystem("test", host, port.toInt, conf = conf,
+      securityManager = new SecurityManager(conf))._1
     val feeder = actorSystem.actorOf(Props[FeederActor], "FeederActor")
 
     println("Feeder started as:" + feeder)
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index d5b20d471867a..ae04a69a145df 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -83,15 +83,17 @@ import org.apache.spark.util.Utils
  * @author Moez A. Abdel-Gawad
  * @author Lex Spoon
  */
-  class SparkIMain(initialSettings: Settings, val out: JPrintWriter) extends SparkImports with Logging {
+  class SparkIMain(initialSettings: Settings, val out: JPrintWriter)
+    extends SparkImports with Logging {
     imain =>
 
-    val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
+    val conf = new SparkConf()
 
+    val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
     /** Local directory to save .class files too */
     val outputDir = {
       val tmp = System.getProperty("java.io.tmpdir")
-      val rootDir = new SparkConf().get("spark.repl.classdir", tmp)
+      val rootDir = conf.get("spark.repl.classdir", tmp)
       Utils.createTempDir(rootDir)
     }
     if (SPARK_DEBUG_REPL) {
@@ -99,7 +101,8 @@ import org.apache.spark.util.Utils
     }
 
     val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
-    val classServer = new HttpServer(outputDir, new SecurityManager()) /** Jetty server that will serve our classes to worker nodes */
+    val classServer = new HttpServer(outputDir,
+      new SecurityManager(conf)) /** Jetty server that will serve our classes to worker nodes */
     private var currentSettings: Settings = initialSettings
     var printResults = true // whether to print result lines
     var totalSilence = false // whether to print anything
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 799f717eb30da..8203b8f6122e1 100644
--- a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -29,8 +29,6 @@ import org.apache.spark.SparkContext
 
 class ReplSuite extends FunSuite {
 
-  System.setProperty("spark.authenticate", "false")
-
   def runInterpreter(master: String, input: String): String = {
     val in = new BufferedReader(new StringReader(input + "\n"))
     val out = new StringWriter()
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 77c1bda495861..a68100f4b6ab5 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -113,7 +113,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
       System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
     val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase
-    System.setProperty("org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params)
+    System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params",
+      params)
   }
 
   /** Get the Yarn approved local directories. */
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index 27715722196a2..cc18899ba303d 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -50,7 +50,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
   private var yarnAllocator: YarnAllocationHandler = _
   private var driverClosed:Boolean = false
 
-  val securityManager = new SecurityManager()
+  val securityManager = new SecurityManager(sparkConf)
   val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
     conf = sparkConf, securityManager = securityManager)._1
   var actor: ActorRef = _
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 0fcad80ae1ce2..8ef9d4630d767 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -119,7 +119,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
       System.getenv(ApplicationConstants.APPLICATION_WEB_PROXY_BASE_ENV)
     val params = "PROXY_HOST=" + parts(0) + "," + "PROXY_URI_BASE=" + uriBase
-    System.setProperty("org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params)
+    System.setProperty("spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params)
   }
 
   /** Get the Yarn approved local directories. */
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index 9490b1f4eab74..3eb89ef5e6801 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -52,7 +52,7 @@ class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration, spar
   private var amClient: AMRMClient[ContainerRequest] = _
 
-  val securityManager = new SecurityManager()
+  val securityManager = new SecurityManager(sparkConf)
   val actorSystem: ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0,
     conf = sparkConf, securityManager = securityManager)._1
   var actor: ActorRef = _

From 05ff5e092bd144f1d046de5bc8298d46ca5e69d4 Mon Sep 17 00:00:00 2001
From: Thomas Graves
Date: Wed, 5 Mar 2014 21:00:15 -0600
Subject: [PATCH 11/14] Fix up imports after upmerging to master

---
 core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 3 ++-
 core/src/main/scala/org/apache/spark/util/AkkaUtils.scala | 3 +--
 repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala | 2 +-
 .../scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala | 2 +-
 .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 2 +-
 .../scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala | 2 +-
 6 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 6ffe99587eeb5..7212e1ba1ada4 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -25,7 +25,8 @@ import scala.annotation.tailrec
 import scala.util.{Failure, Success, Try}
 import scala.xml.Node
 
-import net.liftweb.json.{JValue, pretty, render}
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods.{pretty, render}
 
 import org.eclipse.jetty.server.{DispatcherType, Server}
 import org.eclipse.jetty.server.handler.HandlerList
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index 7e42f2ce93873..a6c9a9aaba8eb 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -24,8 +24,7 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem}
 import com.typesafe.config.ConfigFactory
 import org.apache.log4j.{Level, Logger}
 
-import org.apache.spark.SparkConf
-import org.apache.spark.{Logging, SecurityManager}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
 
 /**
  * Various utility classes for working with Akka.
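The hunks above all converge on one construction pattern: a SecurityManager is now built from an explicit SparkConf, and that same conf plus the manager are handed to AkkaUtils.createActorSystem. A minimal sketch of the pattern follows; it assumes the Spark-internal createActorSystem(name, host, port, conf, securityManager) call used in these hunks returns an (ActorSystem, boundPort) pair, as the ._1 accesses imply, and the system name below is just a placeholder.

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.util.{AkkaUtils, Utils}

    object SecureActorSystemSketch {
      def main(args: Array[String]): Unit = {
        // One SparkConf drives both the SecurityManager and the actor system,
        // so the Akka transport and the rest of Spark agree on the same auth settings.
        val conf = new SparkConf()
        val securityManager = new SecurityManager(conf)
        val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
          "sketch", Utils.localHostName, 0, conf = conf, securityManager = securityManager)
        println("Actor system bound to port " + boundPort)
        actorSystem.shutdown()
      }
    }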
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index ae04a69a145df..90a96ad38381e 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -36,7 +36,7 @@ import scala.tools.reflect.StdRuntimeTags._
 import scala.util.control.ControlThrowable
 import util.stackTraceString
 
-import org.apache.spark.{HttpServer, SparkConf, Logging, SecurityManager}
+import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf}
 import org.apache.spark.util.Utils
 
 // /** directory to save .class files to */
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index cc18899ba303d..b735d01df8097 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
 import akka.actor._
 import akka.remote._
 import akka.actor.Terminated
-import org.apache.spark.{SparkConf, SparkContext, Logging, SecurityManager}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
 import org.apache.spark.util.{Utils, AkkaUtils}
 import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 import org.apache.spark.scheduler.SplitInfo
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 7e417e76eede3..8c49a90ca4d1a 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -38,7 +38,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC
 import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
 import org.apache.hadoop.yarn.webapp.util.WebAppUtils;
 
-import org.apache.spark.{Loggin, SecurityManager, SparkConf, SparkContext}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.util.Utils
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
index 3eb89ef5e6801..f1c1fea0b5895 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
 import akka.actor._
 import akka.remote._
 import akka.actor.Terminated
-import org.apache.spark.{SparkConf, SparkContext, Logging, SecurityManager}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
 import org.apache.spark.util.{Utils, AkkaUtils}
 import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 import org.apache.spark.scheduler.SplitInfo

From d1040ecfca0458c88457d0d6e2bc02d1d68c4e47 Mon Sep 17 00:00:00 2001
From: Thomas Graves
Date: Wed, 5 Mar 2014 21:12:57 -0600
Subject: [PATCH 12/14] Fix up various imports

---
 .../main/scala/org/apache/spark/deploy/master/Master.scala | 2 +-
 .../main/scala/org/apache/spark/metrics/MetricsSystem.scala | 2 +-
 core/src/main/scala/org/apache/spark/network/SenderTest.scala | 1 -
 core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 4 +---
 core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +-
 5 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 26d42d533a8c6..2d6d0c33fac7e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -30,7 +30,7 @@ import akka.pattern.ask
 import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
 import akka.serialization.SerializationExtension
 
-import org.apache.spark.{SecurityManager, SparkConf, Logging, SparkException}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
 import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState}
 import org.apache.spark.deploy.DeployMessages._
 import org.apache.spark.deploy.master.DriverState.DriverState
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 80a4985ec54e8..c5bda2078fc14 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
 
 import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry}
 
-import org.apache.spark.{SecurityManager, SparkConf, Logging}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
 import org.apache.spark.metrics.sink.{MetricsServlet, Sink}
 import org.apache.spark.metrics.source.Source
diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
index 1b7c838ed440a..aac2c24a46faa 100644
--- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala
@@ -18,7 +18,6 @@ package org.apache.spark.network
 
 import java.nio.ByteBuffer
-import java.net.InetAddress
 
 import org.apache.spark.{SecurityManager, SparkConf}
 
 private[spark] object SenderTest {
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 7212e1ba1ada4..7c35cd165ad7c 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -33,9 +33,7 @@ import org.eclipse.jetty.server.handler.HandlerList
 import org.eclipse.jetty.servlet.{DefaultServlet, FilterHolder, ServletContextHandler, ServletHolder}
 import org.eclipse.jetty.util.thread.QueuedThreadPool
 
-import org.apache.spark.Logging
-import org.apache.spark.SecurityManager
-import org.apache.spark.SparkConf
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
 
 /** Utilities for launching a web server using Jetty's HTTP Server class */
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 956250cfcde45..0eb2f78b730f6 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,7 +18,7 @@ package org.apache.spark.util
 
 import java.io._
-import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL}
+import java.net.{InetAddress, Inet4Address, NetworkInterface, URI, URL, URLConnection}
 import java.nio.ByteBuffer
 import java.util.{Locale, Random, UUID}
 import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}

From 05eebedc7a69bc90b7bee4f179738167548ca2f9 Mon Sep 17 00:00:00 2001
From: Thomas Graves
Date: Wed, 5 Mar 2014 22:39:51 -0600
Subject: [PATCH 13/14] Fix dependency lost in upmerge

---
 project/SparkBuild.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index e86b83059a951..138aad7561043 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -288,6 +288,7 @@ object SparkBuild extends Build {
         "it.unimi.dsi" % "fastutil" % "6.4.4",
         "colt" % "colt" % "1.2.0",
         "org.apache.mesos" % "mesos" % "0.13.0",
+        "commons-net" % "commons-net" % "2.2",
         "net.java.dev.jets3t" % "jets3t" % "0.7.1" excludeAll(excludeCommonsLogging),
         "org.apache.derby" % "derby" % "10.4.2.0" % "test",
         "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J),

From dfe3918265677a6208696a8cb82710203ee07c19 Mon Sep 17 00:00:00 2001
From: Thomas Graves
Date: Wed, 5 Mar 2014 23:34:19 -0600
Subject: [PATCH 14/14] Fix merge conflict since startUserClass now using runAsUser

---
 .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 6 ++++++
 .../org/apache/spark/deploy/yarn/ApplicationMaster.scala | 5 +++++
 2 files changed, 11 insertions(+)

diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 094b8d586ae5a..bb574f415293a 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -90,6 +90,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
     addAmIpFilter()
 
     ApplicationMaster.register(this)
+
+    // Call this to force generation of secret so it gets populated into the
+    // hadoop UGI. This has to happen before the startUserClass which does a
+    // doAs in order for the credentials to be passed on to the worker containers.
+    val securityMgr = new SecurityManager(sparkConf)
+
     // Start the user's JAR
     userThread = startUserClass()
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 8c49a90ca4d1a..b48a2d50db5ef 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -96,6 +96,11 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
 
     ApplicationMaster.register(this)
 
+    // Call this to force generation of secret so it gets populated into the
+    // hadoop UGI. This has to happen before the startUserClass which does a
+    // doAs in order for the credentials to be passed on to the worker containers.
+    val securityMgr = new SecurityManager(sparkConf)
+
     // Start the user's JAR
     userThread = startUserClass()
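The comment added to both ApplicationMaster variants carries the key ordering constraint of this final patch: the SecurityManager has to be constructed, so that the generated secret lands in the current Hadoop UGI, before startUserClass() switches to the application user via doAs, because the credentials visible at that point are what get passed on to the worker containers. A self-contained sketch of that ordering against the public Hadoop UserGroupInformation API follows; the key alias, user name, and secret generation are illustrative stand-ins rather than Spark's actual SecurityManager internals.

    import java.security.PrivilegedExceptionAction
    import org.apache.hadoop.io.Text
    import org.apache.hadoop.security.{Credentials, UserGroupInformation}

    object SecretBeforeDoAsSketch {
      def main(args: Array[String]): Unit = {
        // 1. Generate a shared secret and attach it to the current user's credentials.
        //    Constructing the SecurityManager early forces this step to happen first.
        val secret = java.util.UUID.randomUUID().toString.getBytes("UTF-8")
        val creds = new Credentials()
        creds.addSecretKey(new Text("sparkCookie"), secret)   // illustrative key alias
        UserGroupInformation.getCurrentUser.addCredentials(creds)

        // 2. Only then run the user class under doAs: the proxy user gets a copy of the
        //    caller's credentials, so work launched inside run() can still see the secret.
        val appUser = UserGroupInformation.createRemoteUser("app-user")   // illustrative
        appUser.addCredentials(UserGroupInformation.getCurrentUser.getCredentials)
        appUser.doAs(new PrivilegedExceptionAction[Unit] {
          override def run(): Unit = {
            // The ApplicationMaster would invoke startUserClass() at this point.
          }
        })
      }
    }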