Commit 68ff3d3

restore the SparkEnv after SparkContext.stop()
1 parent b8dd7eb commit 68ff3d3
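
For context: the test touched below creates its own local SparkContext, which installs a new global SparkEnv; after SparkContext.stop() that environment is torn down, so later tests in the same JVM that relied on the previously active SparkEnv can fail. The fix saves the old environment before the test runs and restores it during cleanup. A minimal sketch of the pattern on its own (not the commit itself; the body comment stands in for whatever the test does):

    import org.apache.spark.{SparkConf, SparkContext, SparkEnv}

    val oldEnv = SparkEnv.get            // whatever env was active before this test (may be null)
    var sc: SparkContext = null
    try {
      sc = new SparkContext("local", "test", new SparkConf())
      // ... test body that needs its own SparkContext ...
    } finally {
      if (sc != null) {
        sc.stop()                        // tears down the SparkEnv this context installed
      }
      SparkEnv.set(oldEnv)               // restore the env so subsequent tests see what they expect
    }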

File tree

1 file changed (+39, -24 lines)


sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala

Lines changed: 39 additions & 24 deletions
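
Besides the SparkEnv handling, the change splits the toUnsafeRow test helper into a reusable unsafeRowConverter, so the spill test builds one UnsafeProjection and reuses it for every row instead of constructing a projection per element. A rough usage sketch under that reading (names match the new helper in the hunks below):

    // Build the projection once...
    val convert: Row => UnsafeRow = unsafeRowConverter(Array(IntegerType))
    // ...then convert rows lazily, as the spill test does.
    val data = (1 to 1000).iterator.map { i => (i, convert(Row(i))) }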
@@ -22,6 +22,7 @@ import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStr
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.Utils
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
@@ -43,9 +44,15 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
 class UnsafeRowSerializerSuite extends SparkFunSuite {
 
   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
-    val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
+    val converter = unsafeRowConverter(schema)
+    converter(row)
+  }
+
+  private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = {
     val converter = UnsafeProjection.create(schema)
-    converter.apply(internalRow)
+    (row: Row) => {
+      converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow])
+    }
   }
 
   test("toUnsafeRow() test helper method") {
@@ -92,37 +99,45 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
   }
 
   test("SPARK-10466: external sorter spilling with unsafe row serializer") {
-    val conf = new SparkConf()
-      .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
-      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
-      .set("spark.shuffle.memoryFraction", "0.0001")
     var sc: SparkContext = null
     var outputFile: File = null
-    try {
-      sc = new SparkContext("local", "test", conf)
-      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
-      val data = (1 to 1000).iterator.map { i =>
-        (i, toUnsafeRow(Row(i), Array(IntegerType)))
-      }
-      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
-        partitioner = Some(new HashPartitioner(10)),
-        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
+    val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
+    Utils.tryWithSafeFinally {
+      val conf = new SparkConf()
+        .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+        .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+        .set("spark.shuffle.memoryFraction", "0.0001")
 
-      // Ensure we spilled something and have to merge them later
-      assert(sorter.numSpills === 0)
-      sorter.insertAll(data)
-      assert(sorter.numSpills > 0)
+      sc = new SparkContext("local", "test", conf)
+      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+      // prepare data
+      val converter = unsafeRowConverter(Array(IntegerType))
+      val data = (1 to 1000).iterator.map { i =>
+        (i, converter(Row(i)))
+      }
+      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+        partitioner = Some(new HashPartitioner(10)),
+        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
 
-      // Merging spilled files should not throw assertion error
-      val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
-      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
-      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+      // Ensure we spilled something and have to merge them later
+      assert(sorter.numSpills === 0)
+      sorter.insertAll(data)
+      assert(sorter.numSpills > 0)
 
-    } finally {
+      // Merging spilled files should not throw assertion error
+      val taskContext =
+        new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
+      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
+      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+    } {
       // Clean up
       if (sc != null) {
        sc.stop()
      }
+
+      // restore the spark env
+      SparkEnv.set(oldEnv)
+
      if (outputFile != null) {
        outputFile.delete()
      }
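
The old try/finally becomes Utils.tryWithSafeFinally, whose second argument list plays the role of the finally block. As I read Spark's utility (a usage sketch, not part of the commit), its advantage is that if both the body and the cleanup throw, the body's exception is the one propagated rather than being masked by the cleanup's:

    import org.apache.spark.util.Utils

    val result = Utils.tryWithSafeFinally {
      // body: runs first; its exception, if any, is the one that propagates
      42
    } {
      // cleanup: always runs, like a finally block
      println("cleaning up")
    }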
