Skip to content

[SPARK-4277] Support external shuffle service on Standalone Worker #3142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,15 +343,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with
*/
def getSecretKey(): String = secretKey

override def getSaslUser(appId: String): String = {
val myAppId = sparkConf.getAppId
require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
getSaslUser()
}

override def getSecretKey(appId: String): String = {
val myAppId = sparkConf.getAppId
require(appId == myAppId, s"SASL appId $appId did not match my appId ${myAppId}")
getSecretKey()
}
// Default SecurityManager only has a single secret key, so ignore appId.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realized that these checks were not useful (since the secret key is already doing the security checks). Additionally, it doesn't make sense on the Worker.

override def getSaslUser(appId: String): String = getSaslUser()
override def getSecretKey(appId: String): String = getSecretKey()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.deploy.worker

import org.apache.spark.{Logging, SparkConf, SecurityManager}
import org.apache.spark.network.TransportContext
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.sasl.SaslRpcHandler
import org.apache.spark.network.server.TransportServer
import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler

/**
* Provides a server from which Executors can read shuffle files (rather than reading directly from
* each other), to provide uninterrupted access to the files in the face of executors being turned
* off or killed.
*
* Optionally requires SASL authentication in order to read. See [[SecurityManager]].
*/
private[worker]
class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
extends Logging {

private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
private val useSasl: Boolean = securityManager.isAuthenticationEnabled()

private val transportConf = SparkTransportConf.fromSparkConf(sparkConf)
private val blockHandler = new ExternalShuffleBlockHandler()
private val transportContext: TransportContext = {
val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
new TransportContext(transportConf, handler)
}

private var server: TransportServer = _

/** Starts the external shuffle service if the user has configured us to. */
def startIfEnabled() {
if (enabled) {
require(server == null, "Shuffle server already started")
logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
server = transportContext.createServer(port)
}
}

def stop() {
if (enabled && server != null) {
server.close()
server = null
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ private[spark] class Worker(
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]

// The shuffle service is not actually started unless configured.
val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)

val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else host
Expand Down Expand Up @@ -154,6 +157,7 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
registerWithMaster()
Expand Down Expand Up @@ -419,6 +423,7 @@ private[spark] class Worker(
registrationRetryTimer.foreach(_.cancel())
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
shuffleService.stop()
webUi.stop()
metricsSystem.stop()
}
Expand All @@ -441,7 +446,8 @@ private[spark] object Worker extends Logging {
cores: Int,
memory: Int,
masterUrls: Array[String],
workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
workDir: String,
workerNumber: Option[Int] = None): (ActorSystem, Int) = {

// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ final class ShuffleBlockFetcherIterator(
* Current [[FetchResult]] being processed. We track this so we can release the current buffer
* in case of a runtime exception when processing the current buffer.
*/
private[this] var currentResult: FetchResult = null
@volatile private[this] var currentResult: FetchResult = null
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was supposed to go into #3128 as we check this nullness of this parameter in multiple threads.


/**
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,6 @@ class NettyBlockTransferSecuritySuite extends FunSuite with MockitoSugar with Sh
}
}

test("security mismatch app ids") {
val conf0 = new SparkConf()
.set("spark.authenticate", "true")
.set("spark.authenticate.secret", "good")
.set("spark.app.id", "app-id")
val conf1 = conf0.clone.set("spark.app.id", "other-id")
testConnection(conf0, conf1) match {
case Success(_) => fail("Should have failed")
case Failure(t) => t.getMessage should include ("SASL appId app-id did not match")
}
}

/**
* Creates two servers with different configurations and sees if they can talk.
* Returns Success() if they can transfer a block, and Failure() if the block transfer was failed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ public void encode(ByteBuf buf) {

public static SaslMessage decode(ByteBuf buf) {
if (buf.readByte() != TAG_BYTE) {
throw new IllegalStateException("Expected SaslMessage, received something else");
throw new IllegalStateException("Expected SaslMessage, received something else"
+ " (maybe your client does not have SASL enabled?)");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improved error message requested by @andrewor14

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good

}

int idLength = buf.readInt();
Expand Down