Skip to content

Commit 303b6da

Browse files
committed
fix several errors
1 parent 95db7ad commit 303b6da

File tree

3 files changed

+49
-43
lines changed

3 files changed

+49
-43
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
8686
}
8787
}
8888
val sortingExpressions = expressions.map(s => new SortOrder(s, Ascending))
89-
implicit val ordering = new RowOrdering(sortingExpressions, child.output)
89+
val ordering = new RowOrdering(sortingExpressions, child.output)
9090
val part = new HashPartitioner(numPartitions)
9191
val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part).setKeyOrdering(ordering)
9292
shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,21 @@ case class SortMergeJoin(
6767
private[this] var currentlPosition: Int = -1
6868
private[this] var currentrPosition: Int = -1
6969

70-
override final def hasNext: Boolean =
71-
(currentlPosition != -1 && currentlPosition < currentlMatches.size) ||
72-
nextMatchingPair
70+
override final def hasNext: Boolean = currentlPosition != -1 || nextMatchingPair
7371

7472
override final def next(): Row = {
73+
if (!hasNext) {
74+
return null
75+
}
7576
val joinedRow =
7677
joinRow(currentlMatches(currentlPosition), currentrMatches(currentrPosition))
7778
currentrPosition += 1
7879
if (currentrPosition >= currentrMatches.size) {
7980
currentlPosition += 1
8081
currentrPosition = 0
82+
if (currentlPosition >= currentlMatches.size) {
83+
currentlPosition = -1
84+
}
8185
}
8286
joinedRow
8387
}
@@ -100,13 +104,13 @@ case class SortMergeJoin(
100104
}
101105
}
102106

103-
// initialize iterator
104-
private def initialize() = {
107+
private def fetchFirst() = {
105108
fetchLeft()
106109
fetchRight()
110+
currentrPosition = 0
107111
}
108-
109-
initialize()
112+
// initialize iterator
113+
fetchFirst()
110114

111115
/**
112116
* Searches the left/right iterator for the next rows that match.
@@ -115,49 +119,49 @@ case class SortMergeJoin(
115119
* of tuples.
116120
*/
117121
private def nextMatchingPair(): Boolean = {
118-
currentlPosition = -1
119-
currentlMatches = null
120-
var stop: Boolean = false
121-
while (!stop && leftElement != null && rightElement != null) {
122-
if (ordering.compare(leftKey, rightKey) > 0) {
123-
fetchRight()
124-
} else if (ordering.compare(leftKey, rightKey) < 0) {
125-
fetchLeft()
126-
} else {
127-
stop = true
122+
if (currentlPosition > -1) {
123+
true
124+
} else {
125+
currentlPosition = -1
126+
currentlMatches = null
127+
var stop: Boolean = false
128+
while (!stop && leftElement != null && rightElement != null) {
129+
if (ordering.compare(leftKey, rightKey) == 0 && !leftKey.anyNull) {
130+
stop = true
131+
} else if (ordering.compare(leftKey, rightKey) > 0 || rightKey.anyNull) {
132+
fetchRight()
133+
} else { //if (ordering.compare(leftKey, rightKey) < 0 || leftKey.anyNull)
134+
fetchLeft()
135+
}
128136
}
129-
}
130-
currentrMatches = new CompactBuffer[Row]()
131-
while (stop && rightElement != null) {
132-
if (!rightKey.anyNull) {
137+
currentrMatches = new CompactBuffer[Row]()
138+
while (stop && rightElement != null) {
133139
currentrMatches += rightElement
140+
fetchRight()
141+
if (ordering.compare(leftKey, rightKey) != 0) {
142+
stop = false
143+
}
134144
}
135-
fetchRight()
136-
if (ordering.compare(leftKey, rightKey) != 0) {
145+
if (currentrMatches.size > 0) {
137146
stop = false
138-
}
139-
}
140-
if (currentrMatches.size > 0) {
141-
stop = false
142-
currentlMatches = new CompactBuffer[Row]()
143-
val leftMatch = leftKey.copy()
144-
while (!stop && leftElement != null) {
145-
if (!leftKey.anyNull) {
147+
currentlMatches = new CompactBuffer[Row]()
148+
val leftMatch = leftKey.copy()
149+
while (!stop && leftElement != null) {
146150
currentlMatches += leftElement
147-
}
148-
fetchLeft()
149-
if (ordering.compare(leftKey, leftMatch) != 0) {
150-
stop = true
151+
fetchLeft()
152+
if (ordering.compare(leftKey, leftMatch) != 0) {
153+
stop = true
154+
}
151155
}
152156
}
153-
}
154157

155-
if (currentlMatches == null) {
156-
false
157-
} else {
158-
currentlPosition = 0
159-
currentrPosition = 0
160-
true
158+
if (currentlMatches == null) {
159+
false
160+
} else {
161+
currentlPosition = 0
162+
currentrPosition = 0
163+
true
164+
}
161165
}
162166
}
163167
}

sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
144144
expectedAnswer: Seq[Row],
145145
ct: ClassTag[_]) = {
146146
before()
147+
conf.setConf("spark.sql.autoSortMergeJoin", "false")
147148

148149
var df = sql(query)
149150

@@ -178,6 +179,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
178179
sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=$tmp""")
179180
}
180181

182+
conf.setConf("spark.sql.autoSortMergeJoin", "true")
181183
after()
182184
}
183185

0 commit comments

Comments
 (0)