Commit b8dd7eb

simplify the unit test
1 parent 871764c commit b8dd7eb

File tree: 3 files changed (+50, -3 lines)

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
6 additions, 0 deletions

@@ -188,6 +188,12 @@ private[spark] class ExternalSorter[K, V, C](

   private val spills = new ArrayBuffer[SpilledFile]

+  /**
+   * Number of files this sorter has spilled so far.
+   * Exposed for testing.
+   */
+  private[spark] def numSpills: Int = spills.size
+
   override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined

sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala
1 addition, 1 deletion

@@ -72,7 +72,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance
   override def writeKey[T: ClassTag](key: T): SerializationStream = {
     // The key is only needed on the map side when computing partition ids. It does not need to
     // be shuffled.
-    assert(key.isInstanceOf[Int])
+    assert(null == key || key.isInstanceOf[Int])
     this
   }

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
43 additions, 2 deletions

@@ -17,13 +17,16 @@

 package org.apache.spark.sql.execution

-import java.io.{DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}
+import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStream}

-import org.apache.spark.SparkFunSuite
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
+import org.apache.spark._


 /**
@@ -87,4 +90,42 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
     assert(!deserializerIter.hasNext)
     assert(input.closed)
   }
+
+  test("SPARK-10466: external sorter spilling with unsafe row serializer") {
+    val conf = new SparkConf()
+      .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+      .set("spark.shuffle.memoryFraction", "0.0001")
+    var sc: SparkContext = null
+    var outputFile: File = null
+    try {
+      sc = new SparkContext("local", "test", conf)
+      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+      val data = (1 to 1000).iterator.map { i =>
+        (i, toUnsafeRow(Row(i), Array(IntegerType)))
+      }
+      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+        partitioner = Some(new HashPartitioner(10)),
+        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
+
+      // Ensure we spilled something and have to merge them later
+      assert(sorter.numSpills === 0)
+      sorter.insertAll(data)
+      assert(sorter.numSpills > 0)
+
+      // Merging spilled files should not throw assertion error
+      val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
+      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
+      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+
+    } finally {
+      // Clean up
+      if (sc != null) {
+        sc.stop()
+      }
+      if (outputFile != null) {
+        outputFile.delete()
+      }
+    }
+  }
 }
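Note: the new test calls a toUnsafeRow helper that already exists elsewhere in UnsafeRowSerializerSuite, so it does not appear in this diff. A minimal sketch of what such a helper can look like, assuming only the UnsafeProjection and CatalystTypeConverters APIs brought in by the imports above (the exact body in the suite may differ):

  // Hypothetical sketch, placed inside the suite class:
  // convert an external Row into an UnsafeRow for the given column types.
  def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
    // Projection from Catalyst's InternalRow representation to UnsafeRow.
    val projection = UnsafeProjection.create(schema)
    // Convert the external Row to an InternalRow before projecting it.
    val internalRow =
      CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
    projection(internalRow)
  }

The test then pushes 1,000 such rows through an ExternalSorter configured with very low memory thresholds, so the sorter is forced to spill (numSpills > 0) and later merge the spilled files through UnsafeRowSerializer via writePartitionedFile, which is the path the relaxed writeKey assertion fixes.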
