
Commit dda6752

Commit some missing code from an old git stash.
1 parent 58f36d0 commit dda6752

7 files changed, 231 additions and 18 deletions

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 15 additions & 8 deletions
@@ -137,7 +137,7 @@ public void spill() throws IOException {
     openSorter();
   }
 
-  private long freeMemory() {
+  public long freeMemory() {
     long memoryFreed = 0;
     final Iterator<MemoryBlock> iter = allocatedPages.iterator();
     while (iter.hasNext()) {
@@ -223,13 +223,20 @@ public void insertRecord(
   }
 
   public UnsafeSorterIterator getSortedIterator() throws IOException {
-    final UnsafeSorterSpillMerger spillMerger =
-      new UnsafeSorterSpillMerger(recordComparator, prefixComparator);
-    for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
-      spillMerger.addSpill(spillWriter.getReader(blockManager));
+    final UnsafeSorterIterator inMemoryIterator = sorter.getSortedIterator();
+    if (!spillWriters.isEmpty()) {
+      final UnsafeSorterSpillMerger spillMerger =
+        new UnsafeSorterSpillMerger(recordComparator, prefixComparator);
+      for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
+        spillMerger.addSpill(spillWriter.getReader(blockManager));
+      }
+      spillWriters.clear();
+      if (inMemoryIterator.hasNext()) {
+        spillMerger.addSpill(inMemoryIterator);
+      }
+      return spillMerger.getSortedIterator();
+    } else {
+      return inMemoryIterator;
     }
-    spillWriters.clear();
-    spillMerger.addSpill(sorter.getSortedIterator());
-    return spillMerger.getSortedIterator();
   }
 }
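
For context on the change above: getSortedIterator() now merges the already-sorted spill files with the already-sorted in-memory run, and skips the merge entirely when nothing has spilled. The merge itself is a standard k-way merge of sorted runs. A minimal sketch of that idea in plain Scala, using a priority queue over ordinary iterators rather than the actual UnsafeSorterSpillMerger / UnsafeSorterIterator classes (illustrative only, not the Spark implementation):

    import scala.collection.mutable

    // Merges already-sorted iterators into one sorted iterator, the same idea the spill
    // merger applies to the spill readers plus the in-memory run.
    def mergeSorted[T](runs: Seq[Iterator[T]])(implicit ord: Ordering[T]): Iterator[T] = {
      // Min-heap keyed on each run's current head element.
      val heap = mutable.PriorityQueue.empty[(T, Iterator[T])](
        Ordering.by[(T, Iterator[T]), T](_._1).reverse)
      runs.foreach { it => if (it.hasNext) heap.enqueue((it.next(), it)) }
      new Iterator[T] {
        override def hasNext: Boolean = heap.nonEmpty
        override def next(): T = {
          val (value, run) = heap.dequeue()
          if (run.hasNext) heap.enqueue((run.next(), run)) // refill from the run we consumed
          value
        }
      }
    }

    // Example: three sorted runs merge into one sorted stream.
    // mergeSorted(Seq(Iterator(1, 4), Iterator(2, 3), Iterator(5))).toList == List(1, 2, 3, 4, 5)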

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java

Lines changed: 6 additions & 4 deletions
@@ -19,6 +19,7 @@
 
 import java.util.Comparator;
 
+import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.util.collection.Sorter;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
 
@@ -50,10 +51,10 @@ private static final class SortComparator implements Comparator<RecordPointerAnd
     public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) {
       final int prefixComparisonResult = prefixComparator.compare(r1.keyPrefix, r2.keyPrefix);
       if (prefixComparisonResult == 0) {
-        final Object baseObject1 = memoryManager.getPage(r2.recordPointer);
-        final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer);
+        final Object baseObject1 = memoryManager.getPage(r1.recordPointer);
+        final long baseOffset1 = memoryManager.getOffsetInPage(r1.recordPointer) + 4; // skip length
         final Object baseObject2 = memoryManager.getPage(r2.recordPointer);
-        final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer);
+        final long baseOffset2 = memoryManager.getOffsetInPage(r2.recordPointer) + 4; // skip length
         return recordComparator.compare(baseObject1, baseOffset1, baseObject2, baseOffset2);
       } else {
         return prefixComparisonResult;
@@ -146,7 +147,8 @@ public boolean hasNext() {
     public void loadNext() {
       final long recordPointer = sortBuffer[position];
       baseObject = memoryManager.getPage(recordPointer);
-      baseOffset = memoryManager.getOffsetInPage(recordPointer);
+      baseOffset = memoryManager.getOffsetInPage(recordPointer) + 4; // Skip over record length
+      recordLength = PlatformDependent.UNSAFE.getInt(baseObject, baseOffset - 4);
       keyPrefix = sortBuffer[position + 1];
       position += 2;
     }
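
The repeated `+ 4` adjustments reflect a record layout in which each record is stored as a 4-byte length followed by the row bytes, so the comparator and loadNext() must step over the length word before touching the data, and can read the length back from `baseOffset - 4`. A rough sketch of that layout with a plain ByteBuffer standing in for a memory page; writeRecord/readRecord are illustrative helpers, not Spark APIs:

    import java.nio.ByteBuffer

    // Each record in a page: [4-byte length][record bytes ...]
    // writeRecord returns the offset of the length word; readers skip 4 bytes to reach the data.
    def writeRecord(page: ByteBuffer, record: Array[Byte]): Int = {
      val recordOffset = page.position()
      page.putInt(record.length)   // length prefix
      page.put(record)             // record payload
      recordOffset
    }

    def readRecord(page: ByteBuffer, recordOffset: Int): Array[Byte] = {
      val length = page.getInt(recordOffset)   // read the length word, as loadNext() now does
      val data = new Array[Byte](length)
      val view = page.duplicate()
      view.position(recordOffset + 4)          // skip over the 4-byte length before reading data
      view.get(data)
      data
    }

    // val page = ByteBuffer.allocate(1 << 20)
    // val off = writeRecord(page, Array[Byte](1, 2, 3))
    // readRecord(page, off).toSeq == Seq[Byte](1, 2, 3)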

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 31 additions & 0 deletions
@@ -279,6 +279,37 @@ public Object get(int i) {
     }
   }
 
+  /**
+   * Generic `get()`, for use in toString(). This method is for debugging only and is probably very
+   * slow to call due to having to reflect on the schema.
+   */
+  private Object genericGet(int i) {
+    assertIndexIsValid(i);
+    assert (schema != null) : "Schema must be defined when calling genericGet()";
+    final DataType dataType = schema.fields()[i].dataType();
+    if (isNullAt(i) || dataType == NullType) {
+      return null;
+    } else if (dataType == StringType) {
+      return getUTF8String(i);
+    } else if (dataType == BooleanType) {
+      return getBoolean(i);
+    } else if (dataType == ByteType) {
+      return getByte(i);
+    } else if (dataType == ShortType) {
+      return getShort(i);
+    } else if (dataType == IntegerType) {
+      return getInt(i);
+    } else if (dataType == LongType) {
+      return getLong(i);
+    } else if (dataType == FloatType) {
+      return getFloat(i);
+    } else if (dataType == DoubleType) {
+      return getDouble(i);
+    } else {
+      throw new UnsupportedOperationException();
+    }
+  }
+
   @Override
   public boolean isNullAt(int i) {
     assertIndexIsValid(i);
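
genericGet() is debugging support meant to back toString(). For illustration, the same schema-driven type dispatch can be written from the Scala side against the public Row/DataType API; this debugString helper is a hypothetical sketch, not part of UnsafeRow:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types._

    // Illustrative debug helper: renders a row by dispatching on the schema's declared types,
    // mirroring the if/else chain in genericGet.
    def debugString(row: Row, schema: StructType): String = {
      val values = schema.fields.zipWithIndex.map { case (field, i) =>
        if (row.isNullAt(i)) "null"
        else field.dataType match {
          case BooleanType => row.getBoolean(i).toString
          case ByteType    => row.getByte(i).toString
          case ShortType   => row.getShort(i).toString
          case IntegerType => row.getInt(i).toString
          case LongType    => row.getLong(i).toString
          case FloatType   => row.getFloat(i).toString
          case DoubleType  => row.getDouble(i).toString
          case StringType  => row.getString(i)
          case other       => s"<unsupported: $other>"
        }
      }
      values.mkString("[", ",", "]")
    }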

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 54 additions & 4 deletions
@@ -32,6 +32,18 @@ import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 
+import scala.util.control.NonFatal
+
+object Exchange {
+  /**
+   * Returns true when the ordering expressions are a subset of the key.
+   * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]].
+   */
+  def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
+    desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
+  }
+}
+
 /**
  * :: DeveloperApi ::
  * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
@@ -181,6 +193,10 @@ case class Exchange(
         }
       }
       val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
+      if (newOrdering.nonEmpty) {
+        println("Shuffling with a key ordering")
+        shuffled.setKeyOrdering(keyOrdering)
+      }
       shuffled.setSerializer(serializer)
       shuffled.map(_._2)
 
@@ -292,6 +308,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       partitioning: Partitioning,
       rowOrdering: Seq[SortOrder],
       child: SparkPlan): SparkPlan = {
+    logInfo("In addOperatorsIfNecessary")
     val needSort = rowOrdering.nonEmpty && child.outputOrdering != rowOrdering
     val needsShuffle = child.outputPartitioning != partitioning
 
@@ -301,9 +318,27 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
        child
      }
 
-    val withSort = if (needSort) {
-      if (sqlContext.conf.externalSortEnabled) {
-        ExternalSort(rowOrdering, global = false, withShuffle)
+    val withSort = if (needSort) {
+      // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
+      // supports the given schema.
+      val supportsUnsafeRowConversion: Boolean = try {
+        new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray)
+        true
+      } catch {
+        case NonFatal(e) =>
+          false
+      }
+      logInfo(s"For row with data types ${withShuffle.schema.map(_.dataType)}, " +
+        s"supportsUnsafeRowConversion = $supportsUnsafeRowConversion")
+      if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
+        logInfo("Using unsafe external sort!")
+        UnsafeExternalSort(rowOrdering, global = false, withShuffle)
+      } else if (sqlContext.conf.externalSortEnabled) {
+        logInfo("Not using unsafe sort")
+        ExternalSort(rowOrdering, global = false, withShuffle)
+      } else {
+        Sort(rowOrdering, global = false, withShuffle)
+      }
     } else {
       Sort(rowOrdering, global = false, withShuffle)
     }
@@ -317,6 +352,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     if (meetsRequirements && compatible && !needsAnySort) {
       operator
     } else {
+      logInfo("Looking through Exchange")
       // At least one child does not satisfies its required data distribution or
       // at least one child's outputPartitioning is not compatible with another child's
       // outputPartitioning. In this case, we need to add Exchange operators.
@@ -334,7 +370,21 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
        case (UnspecifiedDistribution, Seq(), child) =>
          child
        case (UnspecifiedDistribution, rowOrdering, child) =>
-          if (sqlContext.conf.externalSortEnabled) {
+          // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
+          // supports the given schema.
+          val supportsUnsafeRowConversion: Boolean = try {
+            new UnsafeRowConverter(child.schema.map(_.dataType).toArray)
+            true
+          } catch {
+            case NonFatal(e) =>
+              false
+          }
+          logInfo(s"For row with data types ${child.schema.map(_.dataType)}, " +
+            s"supportsUnsafeRowConversion = $supportsUnsafeRowConversion")
+          if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
+            logInfo("Using unsafe external sort!")
+            UnsafeExternalSort(rowOrdering, global = false, child)
+          } else if (sqlContext.conf.externalSortEnabled) {
            ExternalSort(rowOrdering, global = false, child)
          } else {
            Sort(rowOrdering, global = false, child)
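
The try-the-converter probe above appears twice in this file. If the TODO hack stays, it could be factored into one helper along these lines; this sketch reuses the UnsafeRowConverter constructor exactly as the patch does (the import path shown is an assumption), and is not an existing Spark utility:

    import scala.util.control.NonFatal

    import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter
    import org.apache.spark.sql.types.DataType

    // Returns true when UnsafeRowConverter accepts every data type in the schema.
    // Construction failing for an unsupported type is exactly the signal the patch relies on.
    def supportsUnsafeRowConversion(dataTypes: Seq[DataType]): Boolean = {
      try {
        new UnsafeRowConverter(dataTypes.toArray)
        true
      } catch {
        case NonFatal(_) => false
      }
    }

    // Usage mirroring the patch: supportsUnsafeRowConversion(child.schema.map(_.dataType))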

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 5 additions & 2 deletions
@@ -74,7 +74,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
    * [[org.apache.spark.sql.SQLConf.AUTO_BROADCASTJOIN_THRESHOLD]], the planner would mark it as the
    * ''build'' relation and mark the other relation as the ''stream'' side. The build table will be
    * ''broadcasted'' to all of the executors involved in the join, as a
-   * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they
+   * [[org.apache.spark.broadcast.Broadcast]] object. If both estimates exceed the threshold, they
    * will instead be used to decide the build side in a [[joins.ShuffledHashJoin]].
    */
   object HashJoin extends Strategy with PredicateHelper {
@@ -102,8 +102,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // for now let's support inner join first, then add outer join
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right)
         if sqlContext.conf.sortMergeJoinEnabled =>
-        val mergeJoin =
+        val mergeJoin = if (sqlContext.conf.unsafeEnabled) {
+          joins.UnsafeSortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
+        } else {
           joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right))
+        }
         condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil
 
       case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 119 additions & 0 deletions
@@ -17,6 +17,12 @@
 
 package org.apache.spark.sql.execution
 
+import java.util.Arrays
+
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.util.collection.unsafe.sort.{RecordComparator, PrefixComparator, UnsafeExternalSorter}
+import org.apache.spark.{TaskContext, SparkEnv, HashPartitioner, SparkConf}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.shuffle.sort.SortShuffleManager
@@ -245,6 +251,119 @@ case class ExternalSort(
   override def outputOrdering: Seq[SortOrder] = sortOrder
 }
 
+/**
+ * :: DeveloperApi ::
+ * TODO(josh): document
+ * Performs a sort, spilling to disk as needed.
+ * @param global when true performs a global sort of all partitions by shuffling the data first
+ *               if necessary.
+ */
+@DeveloperApi
+case class UnsafeExternalSort(
+    sortOrder: Seq[SortOrder],
+    global: Boolean,
+    child: SparkPlan)
+  extends UnaryNode {
+
+  private[this] val numFields: Int = child.schema.size
+  private[this] val schema: StructType = child.schema
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
+
+  protected override def doExecute(): RDD[Row] = attachTree(this, "sort") {
+    // TODO(josh): This code is unreadably messy; this should be split into a separate file
+    // and written in Java.
+    assert (codegenEnabled)
+    def doSort(iterator: Iterator[Row]): Iterator[Row] = {
+      val ordering = newOrdering(sortOrder, child.output)
+      val rowConverter = new UnsafeRowConverter(schema.map(_.dataType).toArray)
+      var rowConversionScratchSpace = new Array[Long](1024)
+      val prefixComparator = new PrefixComparator {
+        override def compare(prefix1: Long, prefix2: Long): Int = 0
+      }
+      val recordComparator = new RecordComparator {
+        private[this] val row1 = new UnsafeRow
+        private[this] val row2 = new UnsafeRow
+        override def compare(
+            baseObj1: scala.Any, baseOff1: Long, baseObj2: scala.Any, baseOff2: Long): Int = {
+          row1.pointTo(baseObj1, baseOff1, numFields, schema)
+          row2.pointTo(baseObj2, baseOff2, numFields, schema)
+          ordering.compare(row1, row2)
+        }
+      }
+      val sorter = new UnsafeExternalSorter(
+        TaskContext.get.taskMemoryManager(),
+        SparkEnv.get.shuffleMemoryManager,
+        SparkEnv.get.blockManager,
+        TaskContext.get,
+        recordComparator,
+        prefixComparator,
+        4096,
+        SparkEnv.get.conf
+      )
+      while (iterator.hasNext) {
+        val row: Row = iterator.next()
+        val sizeRequirement = rowConverter.getSizeRequirement(row)
+        if (sizeRequirement / 8 > rowConversionScratchSpace.length) {
+          rowConversionScratchSpace = new Array[Long](sizeRequirement / 8)
+        } else {
+          // Zero out the buffer that's used to hold the current row. This is necessary in order
+          // to ensure that rows hash properly, since garbage data from the previous row could
+          // otherwise end up as padding in this row. As a performance optimization, we only zero
+          // out the portion of the buffer that we'll actually write to.
+          Arrays.fill(rowConversionScratchSpace, 0, sizeRequirement / 8, 0)
+        }
+        val bytesWritten =
+          rowConverter.writeRow(row, rowConversionScratchSpace, PlatformDependent.LONG_ARRAY_OFFSET)
+        assert (bytesWritten == sizeRequirement)
+        val prefix: Long = 0 // dummy prefix until we implement prefix calculation
+        sorter.insertRecord(
+          rowConversionScratchSpace,
+          PlatformDependent.LONG_ARRAY_OFFSET,
+          sizeRequirement,
+          prefix
+        )
+      }
+      val sortedIterator = sorter.getSortedIterator
+      // TODO: need to avoid memory leaks on exceptions, etc. by wrapping in resource cleanup blocks
+      // TODO: need to clean up spill files after success or failure.
+      new Iterator[Row] {
+        private[this] val row = new UnsafeRow()
+        override def hasNext: Boolean = sortedIterator.hasNext
+
+        override def next(): Row = {
+          sortedIterator.loadNext()
+          if (hasNext) {
+            row.pointTo(
+              sortedIterator.getBaseObject, sortedIterator.getBaseOffset, numFields, schema)
+            println("Returned row " + row)
+            row
+          } else {
+            val rowDataCopy = new Array[Byte](sortedIterator.getRecordLength)
+            PlatformDependent.copyMemory(
+              sortedIterator.getBaseObject,
+              sortedIterator.getBaseOffset,
+              rowDataCopy,
+              PlatformDependent.BYTE_ARRAY_OFFSET,
+              sortedIterator.getRecordLength
+            )
+            row.pointTo(rowDataCopy, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, schema)
+            sorter.freeMemory()
+            row
+          }
+        }
+      }
+    }
+    child.execute().mapPartitions(doSort, preservesPartitioning = true)
+  }
+
+  override def output: Seq[Attribute] = child.output
+
+  override def outputOrdering: Seq[SortOrder] = sortOrder
+}
+
+
 /**
  * :: DeveloperApi ::
  * Return a new RDD that has exactly `numPartitions` partitions.
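
One subtlety in the iterator returned by doSort above: the final row's bytes are copied out of the sorter's pages before freeMemory() is called, so the row handed back stays valid after the pages are released. The same copy-before-release pattern in a self-contained sketch, with plain byte arrays standing in for the sorter's memory pages (illustrative only):

    // Hand back the last element of a resource-backed iterator by copying it before the
    // backing storage is released, as next() does for the final UnsafeRow.
    final class CopyLastIterator(backing: Array[Array[Byte]], release: () => Unit)
      extends Iterator[Array[Byte]] {

      private[this] var position = 0

      override def hasNext: Boolean = position < backing.length

      override def next(): Array[Byte] = {
        val record = backing(position)
        position += 1
        if (hasNext) {
          record                      // safe to return directly while storage is still alive
        } else {
          val copy = record.clone()   // defensive copy of the final record
          release()                   // now the backing storage can be freed
          copy
        }
      }
    }

    // val it = new CopyLastIterator(Array(Array[Byte](1), Array[Byte](2)), () => println("freed"))
    // it.toList  // prints "freed" exactly once, after the last element has been copied out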

sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/SortMergeCompatibilitySuite.scala

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution
 
 import org.apache.spark.sql.SQLConf
 import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.test.TestSQLContext._
 
 /**
  * Runs the test cases that are included in the hive distribution with sort merge join is true.
