
Commit af01858

Andrew Or authored and markhamstra committed
[SPARK-11078] Ensure spilling tests actually spill
Author: Andrew Or <[email protected]>

Closes apache#9124 from andrewor14/spilling-tests.
1 parent d061dda commit af01858

8 files changed: +535 -582 lines


core/src/main/scala/org/apache/spark/TestUtils.scala

Lines changed: 52 additions & 1 deletion
@@ -21,12 +21,16 @@ import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
 import java.net.{URI, URL}
 import java.util.jar.{JarEntry, JarOutputStream}
 
-import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 import com.google.common.base.Charsets.UTF_8
 import com.google.common.io.{ByteStreams, Files}
 import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
 
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler._
 import org.apache.spark.util.Utils
 
 /**
@@ -153,4 +157,51 @@ private[spark] object TestUtils {
       " @Override public String toString() { return \"" + toStringValue + "\"; }}")
     createCompiledClass(className, destDir, sourceFile, classpathUrls)
   }
+
+  /**
+   * Run some code involving jobs submitted to the given context and assert that the jobs spilled.
+   */
+  def assertSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
+    val spillListener = new SpillListener
+    sc.addSparkListener(spillListener)
+    body
+    assert(spillListener.numSpilledStages > 0, s"expected $identifier to spill, but did not")
+  }
+
+  /**
+   * Run some code involving jobs submitted to the given context and assert that the jobs
+   * did not spill.
+   */
+  def assertNotSpilled[T](sc: SparkContext, identifier: String)(body: => T): Unit = {
+    val spillListener = new SpillListener
+    sc.addSparkListener(spillListener)
+    body
+    assert(spillListener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
+  }
+
+}
+
+
+/**
+ * A [[SparkListener]] that detects whether spills have occurred in Spark jobs.
+ */
+private class SpillListener extends SparkListener {
+  private val stageIdToTaskMetrics = new mutable.HashMap[Int, ArrayBuffer[TaskMetrics]]
+  private val spilledStageIds = new mutable.HashSet[Int]
+
+  def numSpilledStages: Int = spilledStageIds.size
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+    stageIdToTaskMetrics.getOrElseUpdate(
+      taskEnd.stageId, new ArrayBuffer[TaskMetrics]) += taskEnd.taskMetrics
+  }
+
+  override def onStageCompleted(stageComplete: SparkListenerStageCompleted): Unit = {
+    val stageId = stageComplete.stageInfo.stageId
+    val metrics = stageIdToTaskMetrics.remove(stageId).toSeq.flatten
+    val spilled = metrics.map(_.memoryBytesSpilled).sum > 0
+    if (spilled) {
+      spilledStageIds += stageId
+    }
+  }
 }
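
The reworked suites drive their spilling checks through these two helpers. A minimal sketch of the call pattern (the jobs inside the bodies are illustrative placeholders, not part of this file, and assume a SparkContext `sc` configured so the first job really does spill, e.g. via the force-spill threshold added in Spillable.scala below):

    // Hypothetical test fragment: wrap the job under test in assertSpilled / assertNotSpilled.
    TestUtils.assertSpilled(sc, "external sort") {
      sc.parallelize(1 to 100000, 2).map(i => (i, i)).sortByKey().count()
    }
    TestUtils.assertNotSpilled(sc, "small job") {
      sc.parallelize(1 to 10, 2).count()
    }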

core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala

Lines changed: 4 additions & 2 deletions
@@ -139,8 +139,10 @@ class ShuffleMemoryManager protected (
       throw new SparkException(
         s"Internal error: release called on $numBytes bytes but task only has $curMem")
     }
-    taskMemory(taskAttemptId) -= numBytes
-    memoryManager.releaseExecutionMemory(numBytes)
+    if (taskMemory.contains(taskAttemptId)) {
+      taskMemory(taskAttemptId) -= numBytes
+      memoryManager.releaseExecutionMemory(numBytes)
+    }
     memoryManager.notifyAll() // Notify waiters in tryToAcquire that memory has been freed
   }
 

core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala

Lines changed: 6 additions & 0 deletions
@@ -95,6 +95,12 @@ class ExternalAppendOnlyMap[K, V, C](
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
 
+  /**
+   * Number of files this map has spilled so far.
+   * Exposed for testing.
+   */
+  private[collection] def numSpills: Int = spilledMaps.size
+
   /**
    * Insert the given key and value into the map.
    */
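
A sketch of how a suite could use the new hook (illustrative, not part of the commit): it assumes a running local SparkContext so SparkEnv is initialized, and it must live in the org.apache.spark.util.collection package, as the existing suite does, because numSpills is package-private. The combiner functions and record count are placeholders; with the force-spill threshold from Spillable.scala set low, the map should spill at least once.

    import scala.collection.mutable.ArrayBuffer

    // Hypothetical test fragment: build a combiner-based map, insert enough records to
    // trigger a spill, then check the new counter.
    val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](
      (v: Int) => ArrayBuffer(v),                                  // createCombiner
      (buf: ArrayBuffer[Int], v: Int) => buf += v,                 // mergeValue
      (b1: ArrayBuffer[Int], b2: ArrayBuffer[Int]) => b1 ++= b2)   // mergeCombiners
    (1 to 100000).foreach { i => map.insert(i, i) }
    assert(map.numSpills > 0, "map did not spill")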

core/src/main/scala/org/apache/spark/util/collection/Spillable.scala

Lines changed: 21 additions & 16 deletions
@@ -43,10 +43,15 @@ private[spark] trait Spillable[C] extends Logging {
   private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
 
   // Initial threshold for the size of a collection before we start tracking its memory usage
-  // Exposed for testing
+  // For testing only
   private[this] val initialMemoryThreshold: Long =
     SparkEnv.get.conf.getLong("spark.shuffle.spill.initialMemoryThreshold", 5 * 1024 * 1024)
 
+  // Force this collection to spill when there are this many elements in memory
+  // For testing only
+  private[this] val numElementsForceSpillThreshold: Long =
+    SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MaxValue)
+
   // Threshold for this collection's size in bytes before we start tracking its memory usage
   // To avoid a large number of small spills, initialize this to a value orders of magnitude > 0
   private[this] var myMemoryThreshold = initialMemoryThreshold
@@ -69,27 +74,27 @@ private[spark] trait Spillable[C] extends Logging {
    * @return true if `collection` was spilled to disk; false otherwise
    */
   protected def maybeSpill(collection: C, currentMemory: Long): Boolean = {
+    var shouldSpill = false
     if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
       // Claim up to double our current memory from the shuffle memory pool
       val amountToRequest = 2 * currentMemory - myMemoryThreshold
       val granted = shuffleMemoryManager.tryToAcquire(amountToRequest)
       myMemoryThreshold += granted
-      if (myMemoryThreshold <= currentMemory) {
-        // We were granted too little memory to grow further (either tryToAcquire returned 0,
-        // or we already had more memory than myMemoryThreshold); spill the current collection
-        _spillCount += 1
-        logSpillage(currentMemory)
-
-        spill(collection)
-
-        _elementsRead = 0
-        // Keep track of spills, and release memory
-        _memoryBytesSpilled += currentMemory
-        releaseMemoryForThisThread()
-        return true
-      }
+      // If we were granted too little memory to grow further (either tryToAcquire returned 0,
+      // or we already had more memory than myMemoryThreshold), spill the current collection
+      shouldSpill = currentMemory >= myMemoryThreshold
+    }
+    shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold
+    // Actually spill
+    if (shouldSpill) {
+      _spillCount += 1
+      logSpillage(currentMemory)
+      spill(collection)
+      _elementsRead = 0
+      _memoryBytesSpilled += currentMemory
+      releaseMemoryForThisThread()
     }
-    false
+    shouldSpill
   }
 
   /**
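
The new spark.shuffle.spill.numElementsForceSpillThreshold key is what lets the reworked suites spill with small datasets instead of multi-million-element inputs. One way a test could wire it together with TestUtils.assertSpilled (a sketch; the key name and default come from this diff, while the job, values, and master URL are illustrative):

    import org.apache.spark.{SparkConf, SparkContext, TestUtils}

    // Hypothetical test fragment: force a spill after ~100 elements, then assert that the
    // stages of a small aggregation actually reported spilled bytes.
    val conf = new SparkConf()
      .set("spark.shuffle.spill.numElementsForceSpillThreshold", "100")
    val sc = new SparkContext("local[2]", "force-spill-demo", conf)
    try {
      TestUtils.assertSpilled(sc, "reduceByKey") {
        sc.parallelize(1 to 100000, 2).map(i => (i % 100, i)).reduceByKey(_ + _).count()
      }
    } finally {
      sc.stop()
    }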

core/src/test/scala/org/apache/spark/DistributedSuite.scala

Lines changed: 26 additions & 13 deletions
@@ -203,22 +203,35 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContext
   }
 
   test("compute without caching when no partitions fit in memory") {
-    sc = new SparkContext(clusterUrl, "test")
-    // data will be 4 million * 4 bytes = 16 MB in size, but our memoryFraction set the cache
-    // to only 50 KB (0.0001 of 512 MB), so no partitions should fit in memory
-    val data = sc.parallelize(1 to 4000000, 2).persist(StorageLevel.MEMORY_ONLY_SER)
-    assert(data.count() === 4000000)
-    assert(data.count() === 4000000)
-    assert(data.count() === 4000000)
+    val size = 10000
+    val conf = new SparkConf()
+      .set("spark.storage.unrollMemoryThreshold", "1024")
+      .set("spark.testing.memory", (size / 2).toString)
+    sc = new SparkContext(clusterUrl, "test", conf)
+    val data = sc.parallelize(1 to size, 2).persist(StorageLevel.MEMORY_ONLY)
+    assert(data.count() === size)
+    assert(data.count() === size)
+    assert(data.count() === size)
+    // ensure only a subset of partitions were cached
+    val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true)
+    assert(rddBlocks.size === 0, s"expected no RDD blocks, found ${rddBlocks.size}")
   }
 
   test("compute when only some partitions fit in memory") {
-    sc = new SparkContext(clusterUrl, "test", new SparkConf)
-    // TODO: verify that only a subset of partitions fit in memory (SPARK-11078)
-    val data = sc.parallelize(1 to 4000000, 20).persist(StorageLevel.MEMORY_ONLY_SER)
-    assert(data.count() === 4000000)
-    assert(data.count() === 4000000)
-    assert(data.count() === 4000000)
+    val size = 10000
+    val numPartitions = 10
+    val conf = new SparkConf()
+      .set("spark.storage.unrollMemoryThreshold", "1024")
+      .set("spark.testing.memory", (size * numPartitions).toString)
+    sc = new SparkContext(clusterUrl, "test", conf)
+    val data = sc.parallelize(1 to size, numPartitions).persist(StorageLevel.MEMORY_ONLY)
+    assert(data.count() === size)
+    assert(data.count() === size)
+    assert(data.count() === size)
+    // ensure only a subset of partitions were cached
+    val rddBlocks = sc.env.blockManager.master.getMatchingBlockIds(_.isRDD, askSlaves = true)
+    assert(rddBlocks.size > 0, "no RDD blocks found")
+    assert(rddBlocks.size < numPartitions, s"too many RDD blocks found, expected <$numPartitions")
   }
 
   test("passing environment variables to cluster") {
