Skip to content

Commit 58f36d0

Browse files
committed
Merge in a sketch of a unit test for the new sorter (now failing).
1 parent 2bd8c9a commit 58f36d0

File tree

6 files changed

+348
-3
lines changed

6 files changed

+348
-3
lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,4 +232,4 @@ public UnsafeSorterIterator getSortedIterator() throws IOException {
232232
spillMerger.addSpill(sorter.getSortedIterator());
233233
return spillMerger.getSortedIterator();
234234
}
235-
}
235+
}

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ final class UnsafeSorterSpillMerger {
2626
private final PriorityQueue<UnsafeSorterIterator> priorityQueue;
2727

2828
public UnsafeSorterSpillMerger(
29-
final RecordComparator recordComparator,
30-
final PrefixComparator prefixComparator) {
29+
final RecordComparator recordComparator,
30+
final PrefixComparator prefixComparator) {
3131
final Comparator<UnsafeSorterIterator> comparator = new Comparator<UnsafeSorterIterator>() {
3232

3333
@Override
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution.joins
19+
20+
import java.util.NoSuchElementException
21+
22+
import org.apache.spark.annotation.DeveloperApi
23+
import org.apache.spark.rdd.RDD
24+
import org.apache.spark.sql.Row
25+
import org.apache.spark.sql.catalyst.expressions._
26+
import org.apache.spark.sql.catalyst.plans._
27+
import org.apache.spark.sql.catalyst.plans.physical._
28+
import org.apache.spark.sql.execution.{UnsafeExternalSort, BinaryNode, SparkPlan}
29+
import org.apache.spark.util.collection.CompactBuffer
30+
31+
/**
32+
* :: DeveloperApi ::
33+
* Performs an sort merge join of two child relations.
34+
* TODO(josh): Document
35+
*/
36+
@DeveloperApi
37+
case class UnsafeSortMergeJoin(
38+
leftKeys: Seq[Expression],
39+
rightKeys: Seq[Expression],
40+
left: SparkPlan,
41+
right: SparkPlan) extends BinaryNode {
42+
43+
override def output: Seq[Attribute] = left.output ++ right.output
44+
45+
override def outputPartitioning: Partitioning = left.outputPartitioning
46+
47+
override def requiredChildDistribution: Seq[Distribution] =
48+
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
49+
50+
// this is to manually construct an ordering that can be used to compare keys from both sides
51+
private val keyOrdering: RowOrdering = RowOrdering.forSchema(leftKeys.map(_.dataType))
52+
53+
override def outputOrdering: Seq[SortOrder] = requiredOrders(leftKeys)
54+
55+
@transient protected lazy val leftKeyGenerator = newProjection(leftKeys, left.output)
56+
@transient protected lazy val rightKeyGenerator = newProjection(rightKeys, right.output)
57+
58+
private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] =
59+
keys.map(SortOrder(_, Ascending))
60+
61+
protected override def doExecute(): RDD[Row] = {
62+
// Note that we purposely do not require out input to be sorted. Instead, we'll sort it
63+
// ourselves using UnsafeExternalSorter. Not requiring the input to be sorted will prevent the
64+
// Exchange from pushing the sort into the shuffle, which will allow the shuffle to benefit from
65+
// Project Tungsten's shuffle optimizations which currently cannot be applied to shuffles that
66+
// specify a key ordering.
67+
68+
// Only sort if necessary:
69+
val leftOrder = requiredOrders(leftKeys)
70+
val leftResults = {
71+
if (left.outputOrdering == leftOrder) {
72+
left.execute().map(_.copy())
73+
} else {
74+
new UnsafeExternalSort(leftOrder, global = false, left).execute()
75+
}
76+
}
77+
val rightOrder = requiredOrders(rightKeys)
78+
val rightResults = {
79+
if (right.outputOrdering == rightOrder) {
80+
right.execute().map(_.copy())
81+
} else {
82+
new UnsafeExternalSort(rightOrder, global = false, right).execute()
83+
}
84+
}
85+
86+
leftResults.zipPartitions(rightResults) { (leftIter, rightIter) =>
87+
new Iterator[Row] {
88+
// Mutable per row objects.
89+
private[this] val joinRow = new JoinedRow5
90+
private[this] var leftElement: Row = _
91+
private[this] var rightElement: Row = _
92+
private[this] var leftKey: Row = _
93+
private[this] var rightKey: Row = _
94+
private[this] var rightMatches: CompactBuffer[Row] = _
95+
private[this] var rightPosition: Int = -1
96+
private[this] var stop: Boolean = false
97+
private[this] var matchKey: Row = _
98+
99+
// initialize iterator
100+
initialize()
101+
102+
override final def hasNext: Boolean = nextMatchingPair()
103+
104+
override final def next(): Row = {
105+
if (hasNext) {
106+
// we are using the buffered right rows and run down left iterator
107+
val joinedRow = joinRow(leftElement, rightMatches(rightPosition))
108+
rightPosition += 1
109+
if (rightPosition >= rightMatches.size) {
110+
rightPosition = 0
111+
fetchLeft()
112+
if (leftElement == null || keyOrdering.compare(leftKey, matchKey) != 0) {
113+
stop = false
114+
rightMatches = null
115+
}
116+
}
117+
joinedRow
118+
} else {
119+
// no more result
120+
throw new NoSuchElementException
121+
}
122+
}
123+
124+
private def fetchLeft() = {
125+
if (leftIter.hasNext) {
126+
leftElement = leftIter.next()
127+
println(leftElement)
128+
leftKey = leftKeyGenerator(leftElement)
129+
} else {
130+
leftElement = null
131+
}
132+
}
133+
134+
private def fetchRight() = {
135+
if (rightIter.hasNext) {
136+
rightElement = rightIter.next()
137+
println(right)
138+
rightKey = rightKeyGenerator(rightElement)
139+
} else {
140+
rightElement = null
141+
}
142+
}
143+
144+
private def initialize() = {
145+
fetchLeft()
146+
fetchRight()
147+
}
148+
149+
/**
150+
* Searches the right iterator for the next rows that have matches in left side, and store
151+
* them in a buffer.
152+
*
153+
* @return true if the search is successful, and false if the right iterator runs out of
154+
* tuples.
155+
*/
156+
private def nextMatchingPair(): Boolean = {
157+
if (!stop && rightElement != null) {
158+
// run both side to get the first match pair
159+
while (!stop && leftElement != null && rightElement != null) {
160+
val comparing = keyOrdering.compare(leftKey, rightKey)
161+
// for inner join, we need to filter those null keys
162+
stop = comparing == 0 && !leftKey.anyNull
163+
if (comparing > 0 || rightKey.anyNull) {
164+
fetchRight()
165+
} else if (comparing < 0 || leftKey.anyNull) {
166+
fetchLeft()
167+
}
168+
}
169+
rightMatches = new CompactBuffer[Row]()
170+
if (stop) {
171+
stop = false
172+
// iterate the right side to buffer all rows that matches
173+
// as the records should be ordered, exit when we meet the first that not match
174+
while (!stop && rightElement != null) {
175+
rightMatches += rightElement
176+
fetchRight()
177+
stop = keyOrdering.compare(leftKey, rightKey) != 0
178+
}
179+
if (rightMatches.size > 0) {
180+
rightPosition = 0
181+
matchKey = leftKey
182+
}
183+
}
184+
}
185+
rightMatches != null && rightMatches.size > 0
186+
}
187+
}
188+
}
189+
}
190+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.apache.spark.sql.TestData._
21+
import org.apache.spark.sql.execution.UnsafeExternalSort
22+
import org.apache.spark.sql.execution.joins._
23+
import org.apache.spark.sql.test.TestSQLContext._
24+
import org.apache.spark.sql.test.TestSQLContext.implicits._
25+
import org.scalatest.BeforeAndAfterEach
26+
27+
class UnsafeSortMergeJoinSuite extends QueryTest with BeforeAndAfterEach {
28+
// Ensures tables are loaded.
29+
TestData
30+
31+
conf.setConf(SQLConf.SORTMERGE_JOIN, "true")
32+
conf.setConf(SQLConf.CODEGEN_ENABLED, "true")
33+
conf.setConf(SQLConf.UNSAFE_ENABLED, "true")
34+
conf.setConf(SQLConf.EXTERNAL_SORT, "true")
35+
conf.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, "-1")
36+
37+
test("basic sort merge join test") {
38+
val df = upperCaseData.join(lowerCaseData, $"n" === $"N")
39+
print(df.queryExecution.optimizedPlan)
40+
assert(df.queryExecution.sparkPlan.collect {
41+
case smj: UnsafeSortMergeJoin => smj
42+
}.nonEmpty)
43+
checkAnswer(
44+
df,
45+
Seq(
46+
Row(1, "A", 1, "a"),
47+
Row(2, "B", 2, "b"),
48+
Row(3, "C", 3, "c"),
49+
Row(4, "D", 4, "d")
50+
))
51+
}
52+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.execution
19+
20+
import org.apache.spark.sql.catalyst.CatalystTypeConverters
21+
import org.scalatest.{FunSuite, Matchers}
22+
23+
import org.apache.spark.sql.{SQLConf, SQLContext, Row}
24+
import org.apache.spark.sql.catalyst.expressions._
25+
import org.apache.spark.sql.types._
26+
import org.apache.spark.sql.test.TestSQLContext
27+
import org.apache.spark.sql.test.TestSQLContext.implicits._
28+
29+
class UnsafeExternalSortSuite extends FunSuite with Matchers {
30+
31+
private def createRow(values: Any*): Row = {
32+
new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray)
33+
}
34+
35+
test("basic sorting") {
36+
val sc = TestSQLContext.sparkContext
37+
val sqlContext = new SQLContext(sc)
38+
sqlContext.conf.setConf(SQLConf.CODEGEN_ENABLED, "true")
39+
40+
val schema: StructType = StructType(
41+
StructField("word", StringType, nullable = false) ::
42+
StructField("number", IntegerType, nullable = false) :: Nil)
43+
val sortOrder: Seq[SortOrder] = Seq(
44+
SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
45+
SortOrder(BoundReference(1, IntegerType, nullable = false), Descending))
46+
val rowsToSort: Seq[Row] = Seq(
47+
createRow("Hello", 9),
48+
createRow("World", 4),
49+
createRow("Hello", 7),
50+
createRow("Skinny", 0),
51+
createRow("Constantinople", 9))
52+
SparkPlan.currentContext.set(sqlContext)
53+
val input =
54+
new PhysicalRDD(schema.toAttributes.map(_.toAttribute), sc.parallelize(rowsToSort, 1))
55+
// Treat the existing sort operators as the source-of-truth for this test
56+
val defaultSorted = new Sort(sortOrder, global = false, input).executeCollect()
57+
val externalSorted = new ExternalSort(sortOrder, global = false, input).executeCollect()
58+
val unsafeSorted = new UnsafeExternalSort(sortOrder, global = false, input).executeCollect()
59+
assert (defaultSorted === externalSorted)
60+
assert (unsafeSorted === externalSorted)
61+
}
62+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.execution
19+
20+
import org.apache.spark.sql.SQLConf
21+
import org.apache.spark.sql.hive.test.TestHive
22+
23+
/**
24+
* Runs the test cases that are included in the hive distribution with sort merge join and
25+
* unsafe external sort enabled.
26+
*/
27+
class UnsafeSortMergeCompatibiltySuite extends SortMergeCompatibilitySuite {
28+
override def beforeAll() {
29+
super.beforeAll()
30+
TestHive.setConf(SQLConf.CODEGEN_ENABLED, "true")
31+
TestHive.setConf(SQLConf.UNSAFE_ENABLED, "true")
32+
TestHive.setConf(SQLConf.EXTERNAL_SORT, "true")
33+
}
34+
35+
override def afterAll() {
36+
TestHive.setConf(SQLConf.CODEGEN_ENABLED, "false")
37+
TestHive.setConf(SQLConf.UNSAFE_ENABLED, "false")
38+
TestHive.setConf(SQLConf.EXTERNAL_SORT, "false")
39+
super.afterAll()
40+
}
41+
}

0 commit comments

Comments
 (0)