@@ -22,6 +22,7 @@ import java.io.{File, DataOutputStream, ByteArrayInputStream, ByteArrayOutputStr
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.util.Utils
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
@@ -43,9 +44,15 @@ class ClosableByteArrayInputStream(buf: Array[Byte]) extends ByteArrayInputStrea
 class UnsafeRowSerializerSuite extends SparkFunSuite {

   private def toUnsafeRow(row: Row, schema: Array[DataType]): UnsafeRow = {
-    val internalRow = CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow]
+    val converter = unsafeRowConverter(schema)
+    converter(row)
+  }
+
+  private def unsafeRowConverter(schema: Array[DataType]): Row => UnsafeRow = {
     val converter = UnsafeProjection.create(schema)
-    converter.apply(internalRow)
+    (row: Row) => {
+      converter(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow])
+    }
   }

   test("toUnsafeRow() test helper method") {
@@ -92,37 +99,45 @@ class UnsafeRowSerializerSuite extends SparkFunSuite {
   }

   test("SPARK-10466: external sorter spilling with unsafe row serializer") {
-    val conf = new SparkConf()
-      .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
-      .set("spark.shuffle.sort.bypassMergeThreshold", "0")
-      .set("spark.shuffle.memoryFraction", "0.0001")
     var sc: SparkContext = null
     var outputFile: File = null
-    try {
-      sc = new SparkContext("local", "test", conf)
-      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
-      val data = (1 to 1000).iterator.map { i =>
-        (i, toUnsafeRow(Row(i), Array(IntegerType)))
-      }
-      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
-        partitioner = Some(new HashPartitioner(10)),
-        serializer = Some(new UnsafeRowSerializer(numFields = 1)))
+    val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
+    Utils.tryWithSafeFinally {
+      val conf = new SparkConf()
+        .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+        .set("spark.shuffle.sort.bypassMergeThreshold", "0")
+        .set("spark.shuffle.memoryFraction", "0.0001")

-      // Ensure we spilled something and have to merge them later
-      assert(sorter.numSpills === 0)
-      sorter.insertAll(data)
-      assert(sorter.numSpills > 0)
+      sc = new SparkContext("local", "test", conf)
+      outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
+      // prepare data
+      val converter = unsafeRowConverter(Array(IntegerType))
+      val data = (1 to 1000).iterator.map { i =>
+        (i, converter(Row(i)))
+      }
+      val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
+        partitioner = Some(new HashPartitioner(10)),
+        serializer = Some(new UnsafeRowSerializer(numFields = 1)))

-      // Merging spilled files should not throw assertion error
-      val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
-      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
-      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+      // Ensure we spilled something and have to merge them later
+      assert(sorter.numSpills === 0)
+      sorter.insertAll(data)
+      assert(sorter.numSpills > 0)

-    } finally {
+      // Merging spilled files should not throw assertion error
+      val taskContext =
+        new TaskContextImpl(0, 0, 0, 0, null, null, InternalAccumulator.create(sc))
+      taskContext.taskMetrics.shuffleWriteMetrics = Some(new ShuffleWriteMetrics)
+      sorter.writePartitionedFile(ShuffleBlockId(0, 0, 0), taskContext, outputFile)
+    } {
       // Clean up
       if (sc != null) {
         sc.stop()
       }
+
+      // restore the spark env
+      SparkEnv.set(oldEnv)
+
       if (outputFile != null) {
         outputFile.delete()
       }
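The second hunk splits the old toUnsafeRow helper into unsafeRowConverter, so the UnsafeProjection is built once per schema and the returned Row => UnsafeRow function can be applied to many rows; the spill test then converts 1000 rows with a single converter. Below is a minimal standalone sketch of that pattern, assuming Spark SQL is on the classpath; the object and helper names are illustrative only and not part of the patch.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, IntegerType}

// Illustrative sketch: build the projection once, reuse it for every row.
object UnsafeRowConverterSketch {
  def converterFor(schema: Array[DataType]): Row => UnsafeRow = {
    val projection = UnsafeProjection.create(schema)  // created once per schema
    (row: Row) =>
      projection(CatalystTypeConverters.convertToCatalyst(row).asInstanceOf[InternalRow])
  }

  def main(args: Array[String]): Unit = {
    val toUnsafe = converterFor(Array(IntegerType))
    // The projection reuses an internal buffer, so copy() rows that are kept around.
    val rows = (1 to 3).map(i => toUnsafe(Row(i)).copy())
    rows.foreach(r => println(r.getInt(0)))
  }
}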
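The third hunk wraps the test body in Utils.tryWithSafeFinally instead of a plain try/finally, and saves SparkEnv.get up front so the cleanup block can put it back after sc.stop(), since creating the local SparkContext overwrites the JVM-wide SparkEnv. The point of the safe-finally variant is that an exception thrown during cleanup does not mask a failure from the body. Below is a plain-Scala sketch of those semantics using only the standard library; it illustrates the pattern and is not Spark's own implementation in org.apache.spark.util.Utils.

object TryWithSafeFinallySketch {
  // Run block, then finallyBlock; if both throw, keep the block's exception
  // and attach the cleanup failure as a suppressed exception.
  def tryWithSafeFinally[T](block: => T)(finallyBlock: => Unit): T = {
    var originalThrowable: Throwable = null
    try {
      block
    } catch {
      case t: Throwable =>
        originalThrowable = t
        throw t
    } finally {
      try {
        finallyBlock
      } catch {
        case t: Throwable if originalThrowable != null =>
          originalThrowable.addSuppressed(t)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    try {
      tryWithSafeFinally { sys.error("body failed") } { sys.error("cleanup failed") }
    } catch {
      case e: Exception =>
        println(e.getMessage)                                // body failed
        e.getSuppressed.foreach(s => println(s.getMessage))  // cleanup failed
    }
  }
}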