Commit 5bed68c

Merge remote-tracking branch 'upstream/master' into SPARK-32201
2 parents: 3cd411f + 8c7d6f9

188 files changed, +6260 −2814 lines changed


LICENSE

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ BSD 3-Clause
 ------------
 
 python/lib/py4j-*-src.zip
-python/pyspark/cloudpickle.py
+python/pyspark/cloudpickle/*.py
 python/pyspark/join.py
 core/src/main/resources/org/apache/spark/ui/static/d3.min.js
 
core/src/main/scala/org/apache/spark/MapOutputTracker.scala

Lines changed: 35 additions & 3 deletions
@@ -49,7 +49,7 @@ import org.apache.spark.util._
  *
  * All public methods of this class are thread-safe.
  */
-private class ShuffleStatus(numPartitions: Int) {
+private class ShuffleStatus(numPartitions: Int) extends Logging {
 
   private val (readLock, writeLock) = {
     val lock = new ReentrantReadWriteLock()
@@ -121,12 +121,28 @@ private class ShuffleStatus(numPartitions: Int) {
     mapStatuses(mapIndex) = status
   }
 
+  /**
+   * Update the map output location (e.g. during migration).
+   */
+  def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock {
+    val mapStatusOpt = mapStatuses.find(_.mapId == mapId)
+    mapStatusOpt match {
+      case Some(mapStatus) =>
+        logInfo(s"Updating map output for ${mapId} to ${bmAddress}")
+        mapStatus.updateLocation(bmAddress)
+        invalidateSerializedMapOutputStatusCache()
+      case None =>
+        logError(s"Asked to update map output ${mapId} for untracked map status.")
+    }
+  }
+
   /**
    * Remove the map output which was served by the specified block manager.
    * This is a no-op if there is no registered map output or if the registered output is from a
    * different block manager.
    */
   def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
+    logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
     if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
       _numAvailableOutputs -= 1
       mapStatuses(mapIndex) = null
@@ -139,6 +155,7 @@ private class ShuffleStatus(numPartitions: Int) {
    * outputs which are served by an external shuffle server (if one exists).
    */
   def removeOutputsOnHost(host: String): Unit = withWriteLock {
+    logDebug(s"Removing outputs for host ${host}")
     removeOutputsByFilter(x => x.host == host)
   }
 
@@ -148,6 +165,7 @@ private class ShuffleStatus(numPartitions: Int) {
    * still registered with that execId.
    */
   def removeOutputsOnExecutor(execId: String): Unit = withWriteLock {
+    logDebug(s"Removing outputs for execId ${execId}")
     removeOutputsByFilter(x => x.executorId == execId)
   }
 
@@ -265,7 +283,7 @@ private[spark] class MapOutputTrackerMasterEndpoint(
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = context.senderAddress.hostPort
-      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
+      logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}")
       tracker.post(new GetMapOutputMessage(shuffleId, context))
 
     case StopMapOutputTracker =>
@@ -465,6 +483,15 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  def updateMapOutput(shuffleId: Int, mapId: Long, bmAddress: BlockManagerId): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.updateMapOutput(mapId, bmAddress)
+      case None =>
+        logError(s"Asked to update map output for unknown shuffle ${shuffleId}")
+    }
+  }
+
   def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
     shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
   }
@@ -745,7 +772,12 @@ private[spark] class MapOutputTrackerMaster(
   override def stop(): Unit = {
     mapOutputRequests.offer(PoisonPill)
     threadpool.shutdown()
-    sendTracker(StopMapOutputTracker)
+    try {
+      sendTracker(StopMapOutputTracker)
+    } catch {
+      case e: SparkException =>
+        logError("Could not tell tracker we are stopping.", e)
+    }
     trackerEndpoint = null
     shuffleStatuses.clear()
   }
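For context, a minimal sketch (not from this diff) of how the new update path might be driven from the driver side once a shuffle file has been migrated; the shuffle/map IDs, host, and port are made-up placeholders, and `tracker` stands in for the driver's MapOutputTrackerMaster.

    // Hypothetical driver-side call: after a shuffle block migrates, re-point the
    // tracked MapStatus at the new block manager so reducers fetch from the new host.
    import org.apache.spark.storage.BlockManagerId

    val newLocation = BlockManagerId("exec-2", "10.0.0.5", 7337)
    tracker.updateMapOutput(shuffleId = 0, mapId = 42L, bmAddress = newLocation)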

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 2 additions & 1 deletion
@@ -367,7 +367,8 @@ object SparkEnv extends Logging {
             externalShuffleClient
           } else {
             None
-          }, blockManagerInfo)),
+          }, blockManagerInfo,
+          mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])),
       registerOrLookupEndpoint(
         BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
         new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),

core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala

Lines changed: 7 additions & 0 deletions
@@ -108,6 +108,13 @@ private[deploy] object DeployMessages {
 
   case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage
 
+  /**
+   * Used by the MasterWebUI to request the master to decommission all workers that are active on
+   * any of the given hostnames.
+   * @param hostnames: A list of hostnames without the ports. Like "localhost", "foo.bar.com" etc
+   */
+  case class DecommissionWorkersOnHosts(hostnames: Seq[String])
+
   // Master to Worker
 
   sealed trait RegisterWorkerResponse

core/src/main/scala/org/apache/spark/deploy/master/Master.scala

Lines changed: 37 additions & 0 deletions
@@ -22,7 +22,9 @@ import java.util.{Date, Locale}
 import java.util.concurrent.{ScheduledFuture, TimeUnit}
 
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable
 import scala.util.Random
+import scala.util.control.NonFatal
 
 import org.apache.spark.{SecurityManager, SparkConf, SparkException}
 import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil}
@@ -525,6 +527,13 @@ private[deploy] class Master(
     case KillExecutors(appId, executorIds) =>
       val formattedExecutorIds = formatExecutorIds(executorIds)
       context.reply(handleKillExecutors(appId, formattedExecutorIds))
+
+    case DecommissionWorkersOnHosts(hostnames) =>
+      if (state != RecoveryState.STANDBY) {
+        context.reply(decommissionWorkersOnHosts(hostnames))
+      } else {
+        context.reply(0)
+      }
   }
 
   override def onDisconnected(address: RpcAddress): Unit = {
@@ -863,6 +872,34 @@ private[deploy] class Master(
     true
   }
 
+  /**
+   * Decommission all workers that are active on any of the given hostnames. The decommissioning is
+   * asynchronously done by enqueueing WorkerDecommission messages to self. No checks are done about
+   * the prior state of the worker. So an already decommissioned worker will match as well.
+   *
+   * @param hostnames: A list of hostnames without the ports. Like "localhost", "foo.bar.com" etc
+   *
+   * Returns the number of workers that matched the hostnames.
+   */
+  private def decommissionWorkersOnHosts(hostnames: Seq[String]): Integer = {
+    val hostnamesSet = hostnames.map(_.toLowerCase(Locale.ROOT)).toSet
+    val workersToRemove = addressToWorker
+      .filterKeys(addr => hostnamesSet.contains(addr.host.toLowerCase(Locale.ROOT)))
+      .values
+
+    val workersToRemoveHostPorts = workersToRemove.map(_.hostPort)
+    logInfo(s"Decommissioning the workers with host:ports ${workersToRemoveHostPorts}")
+
+    // The workers are removed async to avoid blocking the receive loop for the entire batch
+    workersToRemove.foreach(wi => {
+      logInfo(s"Sending the worker decommission to ${wi.id} and ${wi.endpoint}")
+      self.send(WorkerDecommission(wi.id, wi.endpoint))
+    })
+
+    // Return the count of workers actually removed
+    workersToRemove.size
+  }
+
   private def decommissionWorker(worker: WorkerInfo): Unit = {
     if (worker.state != WorkerState.DECOMMISSIONED) {
       logInfo("Decommissioning worker %s on %s:%d".format(worker.id, worker.host, worker.port))

core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala

Lines changed: 47 additions & 1 deletion
@@ -17,9 +17,14 @@
 
 package org.apache.spark.deploy.master.ui
 
-import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
+import java.net.{InetAddress, NetworkInterface, SocketException}
+import java.util.Locale
+import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
+
+import org.apache.spark.deploy.DeployMessages.{DecommissionWorkersOnHosts, MasterStateResponse, RequestMasterState}
 import org.apache.spark.deploy.master.Master
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.UI.MASTER_UI_DECOMMISSION_ALLOW_MODE
 import org.apache.spark.internal.config.UI.UI_KILL_ENABLED
 import org.apache.spark.ui.{SparkUI, WebUI}
 import org.apache.spark.ui.JettyUtils._
@@ -36,6 +41,7 @@ class MasterWebUI(
 
   val masterEndpointRef = master.self
   val killEnabled = master.conf.get(UI_KILL_ENABLED)
+  val decommissionAllowMode = master.conf.get(MASTER_UI_DECOMMISSION_ALLOW_MODE)
 
   initialize()
 
@@ -49,6 +55,27 @@ class MasterWebUI(
       "/app/kill", "/", masterPage.handleAppKillRequest, httpMethods = Set("POST")))
     attachHandler(createRedirectHandler(
       "/driver/kill", "/", masterPage.handleDriverKillRequest, httpMethods = Set("POST")))
+    attachHandler(createServletHandler("/workers/kill", new HttpServlet {
+      override def doPost(req: HttpServletRequest, resp: HttpServletResponse): Unit = {
+        val hostnames: Seq[String] = Option(req.getParameterValues("host"))
+          .getOrElse(Array[String]()).toSeq
+        if (!isDecommissioningRequestAllowed(req)) {
+          resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
+        } else {
+          val removedWorkers = masterEndpointRef.askSync[Integer](
+            DecommissionWorkersOnHosts(hostnames))
+          logInfo(s"Decommissioning of hosts $hostnames decommissioned $removedWorkers workers")
+          if (removedWorkers > 0) {
+            resp.setStatus(HttpServletResponse.SC_OK)
+          } else if (removedWorkers == 0) {
+            resp.sendError(HttpServletResponse.SC_NOT_FOUND)
+          } else {
+            // We shouldn't even see this case.
+            resp.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+          }
+        }
+      }
+    }, ""))
   }
 
   def addProxy(): Unit = {
@@ -64,6 +91,25 @@ class MasterWebUI(
     maybeWorkerUiAddress.orElse(maybeAppUiAddress)
   }
 
+  private def isLocal(address: InetAddress): Boolean = {
+    if (address.isAnyLocalAddress || address.isLoopbackAddress) {
+      return true
+    }
+    try {
+      NetworkInterface.getByInetAddress(address) != null
+    } catch {
+      case _: SocketException => false
+    }
+  }
+
+  private def isDecommissioningRequestAllowed(req: HttpServletRequest): Boolean = {
+    decommissionAllowMode match {
+      case "ALLOW" => true
+      case "LOCAL" => isLocal(InetAddress.getByName(req.getRemoteAddr))
+      case _ => false
+    }
+  }
+
 }
 
 private[master] object MasterWebUI {
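A hedged sketch of hitting the new endpoint from plain JDK HTTP code; the Master UI address and hostnames are assumptions, and the caller must be permitted by spark.master.ui.decommission.allow.mode.

    // Illustrative client for the /workers/kill endpoint: the servlet reads repeated
    // "host" parameters via getParameterValues, so they can be passed in the query string.
    import java.net.{HttpURLConnection, URL}

    val url = new URL(
      "http://localhost:8080/workers/kill?host=worker-1.example.com&host=worker-2.example.com")
    val conn = url.openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("POST")
    println(s"Decommission request returned HTTP ${conn.getResponseCode}")
    conn.disconnect()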

core/src/main/scala/org/apache/spark/internal/config/UI.scala

Lines changed: 12 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.internal.config
 
+import java.util.Locale
 import java.util.concurrent.TimeUnit
 
 import org.apache.spark.network.util.ByteUnit
@@ -191,4 +192,15 @@ private[spark] object UI {
     .version("3.0.0")
     .stringConf
     .createOptional
+
+  val MASTER_UI_DECOMMISSION_ALLOW_MODE = ConfigBuilder("spark.master.ui.decommission.allow.mode")
+    .doc("Specifies the behavior of the Master Web UI's /workers/kill endpoint. Possible choices" +
+      " are: `LOCAL` means allow this endpoint from IP's that are local to the machine running" +
+      " the Master, `DENY` means to completely disable this endpoint, `ALLOW` means to allow" +
+      " calling this endpoint from any IP.")
+    .internal()
+    .version("3.1.0")
+    .stringConf
+    .transform(_.toUpperCase(Locale.ROOT))
+    .createWithDefault("LOCAL")
 }
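For reference, a hedged example of overriding the new (internal) setting; the transform above upper-cases the value, so `allow` and `ALLOW` are equivalent.

    // Illustrative only: open the /workers/kill endpoint to any caller instead of the
    // LOCAL default. DENY disables the endpoint entirely.
    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.master.ui.decommission.allow.mode", "ALLOW")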

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 34 additions & 0 deletions
@@ -420,6 +420,29 @@ package object config {
       .booleanConf
       .createWithDefault(false)
 
+  private[spark] val STORAGE_DECOMMISSION_SHUFFLE_BLOCKS_ENABLED =
+    ConfigBuilder("spark.storage.decommission.shuffleBlocks.enabled")
+      .doc("Whether to transfer shuffle blocks during block manager decommissioning. Requires " +
+        "a migratable shuffle resolver (like sort based shuffle)")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[spark] val STORAGE_DECOMMISSION_SHUFFLE_MAX_THREADS =
+    ConfigBuilder("spark.storage.decommission.shuffleBlocks.maxThreads")
+      .doc("Maximum number of threads to use in migrating shuffle files.")
+      .version("3.1.0")
+      .intConf
+      .checkValue(_ > 0, "The maximum number of threads should be positive")
+      .createWithDefault(8)
+
+  private[spark] val STORAGE_DECOMMISSION_RDD_BLOCKS_ENABLED =
+    ConfigBuilder("spark.storage.decommission.rddBlocks.enabled")
+      .doc("Whether to transfer RDD blocks during block manager decommissioning.")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   private[spark] val STORAGE_DECOMMISSION_MAX_REPLICATION_FAILURE_PER_BLOCK =
     ConfigBuilder("spark.storage.decommission.maxReplicationFailuresPerBlock")
       .internal()
@@ -1843,6 +1866,17 @@ package object config {
     .timeConf(TimeUnit.MILLISECONDS)
     .createOptional
 
+  private[spark] val EXECUTOR_DECOMMISSION_KILL_INTERVAL =
+    ConfigBuilder("spark.executor.decommission.killInterval")
+      .doc("Duration after which a decommissioned executor will be killed forcefully. " +
+        "This config is useful for cloud environments where we know in advance when " +
+        "an executor is going to go down after decommissioning signal i.e. around 2 mins " +
+        "in aws spot nodes, 1/2 hrs in spot block nodes etc. This config is currently " +
+        "used to decide what tasks running on decommission executors to speculate.")
+      .version("3.1.0")
+      .timeConf(TimeUnit.SECONDS)
+      .createOptional
+
   private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir")
     .doc("Staging directory used while submitting applications.")
     .version("2.0.0")

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 4 additions & 1 deletion
@@ -168,7 +168,10 @@ private[spark] class NettyBlockTransferService(
     // Everything else is encoded using our binary protocol.
     val metadata = JavaUtils.bufferToArray(serializer.newInstance().serialize((level, classTag)))
 
-    val asStream = blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM)
+    // We always transfer shuffle blocks as a stream for simplicity with the receiving code since
+    // they are always written to disk. Otherwise we check the block size.
+    val asStream = (blockData.size() > conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) ||
+      blockId.isShuffle)
     val callback = new RpcResponseCallback {
       override def onSuccess(response: ByteBuffer): Unit = {
         logTrace(s"Successfully uploaded block $blockId${if (asStream) " as stream" else ""}")

core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala

Lines changed: 13 additions & 2 deletions
@@ -30,12 +30,15 @@ import org.apache.spark.util.Utils
 
 /**
  * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
- * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks.
+ * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing
+ * on to the reduce tasks.
  */
 private[spark] sealed trait MapStatus {
-  /** Location where this task was run. */
+  /** Location where this task output is. */
   def location: BlockManagerId
 
+  def updateLocation(newLoc: BlockManagerId): Unit
+
   /**
    * Estimated size for the reduce block, in bytes.
    *
@@ -126,6 +129,10 @@ private[spark] class CompressedMapStatus(
 
   override def location: BlockManagerId = loc
 
+  override def updateLocation(newLoc: BlockManagerId): Unit = {
+    loc = newLoc
+  }
+
   override def getSizeForBlock(reduceId: Int): Long = {
     MapStatus.decompressSize(compressedSizes(reduceId))
   }
@@ -178,6 +185,10 @@ private[spark] class HighlyCompressedMapStatus private (
 
   override def location: BlockManagerId = loc
 
+  override def updateLocation(newLoc: BlockManagerId): Unit = {
+    loc = newLoc
+  }
+
   override def getSizeForBlock(reduceId: Int): Long = {
     assert(hugeBlockSizes != null)
     if (emptyBlocks.contains(reduceId)) {
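A minimal sketch of the new mutability (assuming code living inside the org.apache.spark package, since MapStatus is private[spark]); the sizes and block manager IDs are placeholders.

    // Illustrative only: a MapStatus can now be re-pointed at a new location after
    // its shuffle file has been migrated, without rebuilding the whole status.
    import org.apache.spark.scheduler.MapStatus
    import org.apache.spark.storage.BlockManagerId

    val status = MapStatus(BlockManagerId("exec-1", "10.0.0.4", 7337), Array(100L, 200L), 7L)
    status.updateLocation(BlockManagerId("exec-2", "10.0.0.5", 7337))
    assert(status.location.executorId == "exec-2")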
