Commit d468a88

Update for InternalRow refactoring

1 parent 269cf86

File tree: 5 files changed, +39 −66 lines

sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
Lines changed: 9 additions & 9 deletions

@@ -27,7 +27,7 @@
 import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.sql.AbstractScalaRowIterator;
-import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRowConverter;
 import org.apache.spark.sql.types.StructType;
@@ -43,22 +43,22 @@ final class UnsafeExternalRowSorter {
   private final UnsafeRowConverter rowConverter;
   private final RowComparator rowComparator;
   private final PrefixComparator prefixComparator;
-  private final Function1<Row, Long> prefixComputer;
+  private final Function1<InternalRow, Long> prefixComputer;

   public UnsafeExternalRowSorter(
       StructType schema,
-      Ordering<Row> ordering,
+      Ordering<InternalRow> ordering,
       PrefixComparator prefixComparator,
       // TODO: if possible, avoid this boxing of the return value
-      Function1<Row, Long> prefixComputer) {
+      Function1<InternalRow, Long> prefixComputer) {
     this.schema = schema;
     this.rowConverter = new UnsafeRowConverter(schema);
     this.rowComparator = new RowComparator(ordering, schema);
     this.prefixComparator = prefixComparator;
     this.prefixComputer = prefixComputer;
   }

-  public Iterator<Row> sort(Iterator<Row> inputIterator) throws IOException {
+  public Iterator<InternalRow> sort(Iterator<InternalRow> inputIterator) throws IOException {
     final SparkEnv sparkEnv = SparkEnv.get();
     final TaskContext taskContext = TaskContext.get();
     byte[] rowConversionBuffer = new byte[1024 * 8];
@@ -74,7 +74,7 @@ public Iterator<Row> sort(Iterator<Row> inputIterator) throws IOException {
     );
     try {
       while (inputIterator.hasNext()) {
-        final Row row = inputIterator.next();
+        final InternalRow row = inputIterator.next();
         final int sizeRequirement = rowConverter.getSizeRequirement(row);
         if (sizeRequirement > rowConversionBuffer.length) {
           rowConversionBuffer = new byte[sizeRequirement];
@@ -108,7 +108,7 @@ public boolean hasNext() {
       }

       @Override
-      public Row next() {
+      public InternalRow next() {
         try {
           sortedIterator.loadNext();
           if (hasNext()) {
@@ -150,12 +150,12 @@ public Row next() {

   private static final class RowComparator extends RecordComparator {
     private final StructType schema;
-    private final Ordering<Row> ordering;
+    private final Ordering<InternalRow> ordering;
     private final int numFields;
     private final UnsafeRow row1 = new UnsafeRow();
     private final UnsafeRow row2 = new UnsafeRow();

-    public RowComparator(Ordering<Row> ordering, StructType schema) {
+    public RowComparator(Ordering<InternalRow> ordering, StructType schema) {
       this.schema = schema;
       this.numFields = schema.length();
       this.ordering = ordering;
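For context, this is how the refactored sorter is driven from Scala. The sketch below is illustrative only and not part of this commit: the names `schema`, `ordering`, and `input` are assumed to be in scope, the import path for PrefixComparator is an assumption, and the no-op prefix mirrors the dummy strategy in UnsafeExternalSort further down.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.collection.unsafe.sort.PrefixComparator

// Illustrative sketch (not in this commit): sort a stream of InternalRows
// with a no-op prefix, as UnsafeExternalSort.doSort does below.
def sortRows(
    schema: StructType,
    ordering: Ordering[InternalRow],
    input: Iterator[InternalRow]): Iterator[InternalRow] = {
  val prefixComparator = new PrefixComparator {
    override def compare(prefix1: Long, prefix2: Long): Int = 0  // all prefixes tie
  }
  // Dummy prefix computer; the TODO in the constructor notes the Long gets boxed.
  def prefixComputer(row: InternalRow): Long = 0
  new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(input)
}

Since the class is package-private, such a caller would have to live in org.apache.spark.sql.execution, as UnsafeExternalSort does.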

sql/catalyst/src/main/scala/org/apache/spark/sql/AbstractScalaRowIterator.scala
Lines changed: 3 additions & 1 deletion

@@ -17,9 +17,11 @@

 package org.apache.spark.sql

+import org.apache.spark.sql.catalyst.InternalRow
+
 /**
  * Shim to allow us to implement [[scala.Iterator]] in Java. Scala 2.11+ has an AbstractIterator
  * class for this, but that class is `private[scala]` in 2.10. We need to explicitly fix this to
  * `Row` in order to work around a spurious IntelliJ compiler error.
  */
-private[spark] abstract class AbstractScalaRowIterator extends Iterator[Row]
+private[spark] abstract class AbstractScalaRowIterator extends Iterator[InternalRow]
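To make the shim's purpose concrete, a hypothetical subclass (not in this commit; the name is invented) only has to supply `hasNext` and `next`, which is exactly the shape Java code such as UnsafeExternalRowSorter needs:

import org.apache.spark.sql.catalyst.InternalRow

// Hypothetical example: a one-element iterator over InternalRow built on the shim.
private[spark] class SingleRowIterator(row: InternalRow) extends AbstractScalaRowIterator {
  private[this] var consumed = false
  override def hasNext: Boolean = !consumed
  override def next(): InternalRow = {
    consumed = true
    row
  }
}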

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
Lines changed: 15 additions & 43 deletions

@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution

 import scala.util.control.NonFatal

-import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
@@ -35,16 +34,6 @@ import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.MutablePair
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}

-object Exchange {
-  /**
-   * Returns true when the ordering expressions are a subset of the key.
-   * if true, ShuffledRDD can use `setKeyOrdering(orderingKey)` to sort within [[Exchange]].
-   */
-  def canSortWithShuffle(partitioning: Partitioning, desiredOrdering: Seq[SortOrder]): Boolean = {
-    desiredOrdering.map(_.child).toSet.subsetOf(partitioning.keyExpressions.toSet)
-  }
-}
-
 /**
  * :: DeveloperApi ::
  * Performs a shuffle that will result in the desired `newPartitioning`. Optionally sorts each
@@ -194,9 +183,6 @@ case class Exchange(
       }
     }
     val shuffled = new ShuffledRDD[InternalRow, InternalRow, InternalRow](rdd, part)
-    if (newOrdering.nonEmpty) {
-      shuffled.setKeyOrdering(keyOrdering)
-    }
     shuffled.setSerializer(serializer)
     shuffled.map(_._2)
@@ -317,23 +303,20 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       child
     }

-    val withSort = if (needSort) {
-      // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
-      // supports the given schema.
-      val supportsUnsafeRowConversion: Boolean = try {
-        new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray)
-        true
-      } catch {
-        case NonFatal(e) =>
-          false
-      }
-      if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
-        UnsafeExternalSort(rowOrdering, global = false, withShuffle)
-      } else if (sqlContext.conf.externalSortEnabled) {
-        ExternalSort(rowOrdering, global = false, withShuffle)
-      } else {
-        Sort(rowOrdering, global = false, withShuffle)
-      }
+    val withSort = if (needSort) {
+      // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
+      // supports the given schema.
+      val supportsUnsafeRowConversion: Boolean = try {
+        new UnsafeRowConverter(withShuffle.schema.map(_.dataType).toArray)
+        true
+      } catch {
+        case NonFatal(e) =>
+          false
+      }
+      if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
+        UnsafeExternalSort(rowOrdering, global = false, withShuffle)
+      } else if (sqlContext.conf.externalSortEnabled) {
+        ExternalSort(rowOrdering, global = false, withShuffle)
     } else {
       Sort(rowOrdering, global = false, withShuffle)
     }
@@ -364,18 +347,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       case (UnspecifiedDistribution, Seq(), child) =>
         child
       case (UnspecifiedDistribution, rowOrdering, child) =>
-        // TODO(josh): this is a hack. Need a better way to determine whether UnsafeRow
-        // supports the given schema.
-        val supportsUnsafeRowConversion: Boolean = try {
-          new UnsafeRowConverter(child.schema.map(_.dataType).toArray)
-          true
-        } catch {
-          case NonFatal(e) =>
-            false
-        }
-        if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
-          UnsafeExternalSort(rowOrdering, global = false, child)
-        } else if (sqlContext.conf.externalSortEnabled) {
+        if (sqlContext.conf.externalSortEnabled) {
           ExternalSort(rowOrdering, global = false, child)
         } else {
           Sort(rowOrdering, global = false, child)
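Distilled from the hunks above, the operator selection that survives in EnsureRequirements is roughly the following sketch (illustrative, not verbatim; it assumes the file's existing imports plus a `sqlContext` in scope, and reuses the try/catch probe the TODO calls a hack):

// Sketch of the fallback chain: unsafe external sort when the schema is
// convertible to UnsafeRow and the flag is on, else external sort, else
// the in-memory Sort operator.
def chooseSortOperator(rowOrdering: Seq[SortOrder], child: SparkPlan): SparkPlan = {
  val supportsUnsafeRowConversion: Boolean =
    try {
      // Probe: the converter's constructor throws for unsupported data types.
      new UnsafeRowConverter(child.schema.map(_.dataType).toArray)
      true
    } catch {
      case NonFatal(_) => false
    }
  if (sqlContext.conf.unsafeEnabled && supportsUnsafeRowConversion) {
    UnsafeExternalSort(rowOrdering, global = false, child)
  } else if (sqlContext.conf.externalSortEnabled) {
    ExternalSort(rowOrdering, global = false, child)
  } else {
    Sort(rowOrdering, global = false, child)
  }
}

Note that the last hunk removes the unsafe branch from the UnspecifiedDistribution case, so after this commit that path falls back directly to ExternalSort or Sort.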

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
Lines changed: 3 additions & 3 deletions

@@ -268,15 +268,15 @@ case class UnsafeExternalSort(
   override def requiredChildDistribution: Seq[Distribution] =
     if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil

-  protected override def doExecute(): RDD[Row] = attachTree(this, "sort") {
+  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
     assert (codegenEnabled)
-    def doSort(iterator: Iterator[Row]): Iterator[Row] = {
+    def doSort(iterator: Iterator[InternalRow]): Iterator[InternalRow] = {
       val ordering = newOrdering(sortOrder, child.output)
       val prefixComparator = new PrefixComparator {
         override def compare(prefix1: Long, prefix2: Long): Int = 0
       }
       // TODO: do real prefix comparison. For dev/testing purposes, this is a dummy implementation.
-      def prefixComputer(row: Row): Long = 0
+      def prefixComputer(row: InternalRow): Long = 0
       new UnsafeExternalRowSorter(schema, ordering, prefixComparator, prefixComputer).sort(iterator)
     }
     child.execute().mapPartitions(doSort, preservesPartitioning = true)
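The TODO above leaves the prefix as a stub. Purely as an illustration of what it is a stub for (hypothetical, not in this commit), a schema whose leading sort column is a non-null long could compute and compare real prefixes like this:

// Hypothetical replacement for the dummy prefix logic: compare on the first
// sort column's raw long value so most comparisons never touch the full row.
val prefixComparator = new PrefixComparator {
  override def compare(prefix1: Long, prefix2: Long): Int =
    java.lang.Long.compare(prefix1, prefix2)
}
def prefixComputer(row: InternalRow): Long = row.getLong(0)  // assumes a non-null long at ordinal 0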

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala
Lines changed: 9 additions & 10 deletions

@@ -22,7 +22,6 @@ import java.util.NoSuchElementException
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
 import org.apache.spark.util.collection.CompactBuffer
@@ -64,24 +63,24 @@ case class SortMergeJoin(
     val rightResults = right.execute().map(_.copy())

     leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
-      new Iterator[Row] {
+      new Iterator[InternalRow] {
         // Mutable per row objects.
         private[this] val joinRow = new JoinedRow5
-        private[this] var leftElement: Row = _
-        private[this] var rightElement: Row = _
-        private[this] var leftKey: Row = _
-        private[this] var rightKey: Row = _
-        private[this] var rightMatches: CompactBuffer[Row] = _
+        private[this] var leftElement: InternalRow = _
+        private[this] var rightElement: InternalRow = _
+        private[this] var leftKey: InternalRow = _
+        private[this] var rightKey: InternalRow = _
+        private[this] var rightMatches: CompactBuffer[InternalRow] = _
         private[this] var rightPosition: Int = -1
         private[this] var stop: Boolean = false
-        private[this] var matchKey: Row = _
+        private[this] var matchKey: InternalRow = _

         // initialize iterator
         initialize()

         override final def hasNext: Boolean = nextMatchingPair()

-        override final def next(): Row = {
+        override final def next(): InternalRow = {
           if (hasNext) {
             // we are using the buffered right rows and run down the left iterator
             val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
@@ -144,7 +143,7 @@ case class SortMergeJoin(
             fetchLeft()
           }
         }
-        rightMatches = new CompactBuffer[Row]()
+        rightMatches = new CompactBuffer[InternalRow]()
         if (stop) {
           stop = false
           // iterate the right side to buffer all rows that match
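For orientation, the fields above implement the classic sorted-merge pattern: advance whichever side has the smaller key, and on a key match buffer the right-side group (the CompactBuffer) so each matching left row can be paired with every buffered right row. A self-contained sketch over plain pairs (illustrative only; the real iterator streams and copies InternalRows):

// Sort-merge inner join over pre-sorted inputs, sketched with plain pairs.
def mergeJoin[K, L, R](left: Seq[(K, L)], right: Seq[(K, R)])
                      (implicit ord: Ordering[K]): Seq[(K, (L, R))] = {
  val out = Seq.newBuilder[(K, (L, R))]
  var i = 0
  var j = 0
  while (i < left.length && j < right.length) {
    val c = ord.compare(left(i)._1, right(j)._1)
    if (c < 0) i += 1        // left key smaller: advance left
    else if (c > 0) j += 1   // right key smaller: advance right
    else {
      val key = left(i)._1
      val start = j
      while (j < right.length && ord.equiv(right(j)._1, key)) j += 1  // buffer the right group
      while (i < left.length && ord.equiv(left(i)._1, key)) {
        var jj = start
        while (jj < j) { out += ((key, (left(i)._2, right(jj)._2))); jj += 1 }
        i += 1
      }
    }
  }
  out.result()
}

For example, mergeJoin(Seq(1 -> "a", 2 -> "b"), Seq(2 -> "x", 2 -> "y")) yields Seq((2, ("b", "x")), (2, ("b", "y"))).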
