@@ -19,42 +19,32 @@ package org.apache.spark.sql.execution.joins
19
19
20
20
import java .util .{HashMap => JavaHashMap }
21
21
22
- import org .apache .spark .rdd .RDD
23
-
24
- import scala .collection .JavaConversions ._
25
-
26
22
import org .apache .spark .annotation .DeveloperApi
27
23
import org .apache .spark .sql .catalyst .expressions ._
28
- import org .apache .spark .sql .catalyst .plans .physical .{ClusteredDistribution , Partitioning , UnknownPartitioning }
24
+ import org .apache .spark .sql .catalyst .plans .physical .{Partitioning , UnknownPartitioning }
29
25
import org .apache .spark .sql .catalyst .plans .{FullOuter , JoinType , LeftOuter , RightOuter }
30
- import org .apache .spark .sql .execution .{ BinaryNode , SparkPlan }
26
+ import org .apache .spark .sql .execution .SparkPlan
31
27
import org .apache .spark .util .collection .CompactBuffer
32
28
33
- /**
34
- * :: DeveloperApi ::
35
- * Performs a hash based outer join for two child relations by shuffling the data using
36
- * the join keys. This operator requires loading the associated partition in both side into memory.
37
- */
38
29
@ DeveloperApi
39
- case class HashOuterJoin (
40
- leftKeys : Seq [Expression ],
41
- rightKeys : Seq [Expression ],
42
- joinType : JoinType ,
43
- condition : Option [Expression ],
44
- left : SparkPlan ,
45
- right : SparkPlan ) extends BinaryNode {
46
-
47
- override def outputPartitioning : Partitioning = joinType match {
30
+ trait HashOuterJoin {
31
+ self : SparkPlan =>
32
+
33
+ val leftKeys : Seq [Expression ]
34
+ val rightKeys : Seq [Expression ]
35
+ val joinType : JoinType
36
+ val condition : Option [Expression ]
37
+ val left : SparkPlan
38
+ val right : SparkPlan
39
+
40
+ override def outputPartitioning : Partitioning = joinType match {
48
41
case LeftOuter => left.outputPartitioning
49
42
case RightOuter => right.outputPartitioning
50
43
case FullOuter => UnknownPartitioning (left.outputPartitioning.numPartitions)
51
44
case x =>
52
45
throw new IllegalArgumentException (s " HashOuterJoin should not take $x as the JoinType " )
53
46
}
54
47
55
- override def requiredChildDistribution : Seq [ClusteredDistribution ] =
56
- ClusteredDistribution (leftKeys) :: ClusteredDistribution (rightKeys) :: Nil
57
-
58
48
override def output : Seq [Attribute ] = {
59
49
joinType match {
60
50
case LeftOuter =>
@@ -68,8 +58,8 @@ case class HashOuterJoin(
68
58
}
69
59
}
70
60
71
- @ transient private [this ] lazy val DUMMY_LIST = Seq [InternalRow ](null )
72
- @ transient private [this ] lazy val EMPTY_LIST = Seq .empty [InternalRow ]
61
+ @ transient private [this ] lazy val DUMMY_LIST = CompactBuffer [InternalRow ](null )
62
+ @ transient protected [this ] lazy val EMPTY_LIST = CompactBuffer [InternalRow ]()
73
63
74
64
@ transient private [this ] lazy val leftNullRow = new GenericInternalRow (left.output.length)
75
65
@ transient private [this ] lazy val rightNullRow = new GenericInternalRow (right.output.length)
@@ -80,7 +70,7 @@ case class HashOuterJoin(
80
70
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
81
71
// iterator for performance purposes.
82
72
83
- private [this ] def leftOuterIterator (
73
+ protected [this ] def leftOuterIterator (
84
74
key : InternalRow ,
85
75
joinedRow : JoinedRow ,
86
76
rightIter : Iterable [InternalRow ]): Iterator [InternalRow ] = {
@@ -89,7 +79,7 @@ case class HashOuterJoin(
89
79
val temp = rightIter.collect {
90
80
case r if boundCondition(joinedRow.withRight(r)) => joinedRow.copy()
91
81
}
92
- if (temp.size == 0 ) {
82
+ if (temp.isEmpty ) {
93
83
joinedRow.withRight(rightNullRow).copy :: Nil
94
84
} else {
95
85
temp
@@ -101,18 +91,17 @@ case class HashOuterJoin(
101
91
ret.iterator
102
92
}
103
93
104
- private [this ] def rightOuterIterator (
94
+ protected [this ] def rightOuterIterator (
105
95
key : InternalRow ,
106
96
leftIter : Iterable [InternalRow ],
107
97
joinedRow : JoinedRow ): Iterator [InternalRow ] = {
108
-
109
98
val ret : Iterable [InternalRow ] = {
110
99
if (! key.anyNull) {
111
100
val temp = leftIter.collect {
112
101
case l if boundCondition(joinedRow.withLeft(l)) =>
113
- joinedRow.copy
102
+ joinedRow.copy()
114
103
}
115
- if (temp.size == 0 ) {
104
+ if (temp.isEmpty ) {
116
105
joinedRow.withLeft(leftNullRow).copy :: Nil
117
106
} else {
118
107
temp
@@ -124,10 +113,9 @@ case class HashOuterJoin(
124
113
ret.iterator
125
114
}
126
115
127
- private [this ] def fullOuterIterator (
116
+ protected [this ] def fullOuterIterator (
128
117
key : InternalRow , leftIter : Iterable [InternalRow ], rightIter : Iterable [InternalRow ],
129
118
joinedRow : JoinedRow ): Iterator [InternalRow ] = {
130
-
131
119
if (! key.anyNull) {
132
120
// Store the positions of records in right, if one of its associated rows satisfies
133
121
// the join condition.
@@ -171,7 +159,7 @@ case class HashOuterJoin(
171
159
}
172
160
}
173
161
174
- private [this ] def buildHashTable (
162
+ protected [this ] def buildHashTable (
175
163
iter : Iterator [InternalRow ],
176
164
keyGenerator : Projection ): JavaHashMap [InternalRow , CompactBuffer [InternalRow ]] = {
177
165
val hashTable = new JavaHashMap [InternalRow , CompactBuffer [InternalRow ]]()
@@ -190,43 +178,4 @@ case class HashOuterJoin(
190
178
191
179
hashTable
192
180
}
193
-
194
- protected override def doExecute (): RDD [InternalRow ] = {
195
- val joinedRow = new JoinedRow ()
196
- left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
197
- // TODO this probably can be replaced by external sort (sort merged join?)
198
-
199
- joinType match {
200
- case LeftOuter =>
201
- val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
202
- val keyGenerator = newProjection(leftKeys, left.output)
203
- leftIter.flatMap( currentRow => {
204
- val rowKey = keyGenerator(currentRow)
205
- joinedRow.withLeft(currentRow)
206
- leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST ))
207
- })
208
-
209
- case RightOuter =>
210
- val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
211
- val keyGenerator = newProjection(rightKeys, right.output)
212
- rightIter.flatMap ( currentRow => {
213
- val rowKey = keyGenerator(currentRow)
214
- joinedRow.withRight(currentRow)
215
- rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST ), joinedRow)
216
- })
217
-
218
- case FullOuter =>
219
- val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output))
220
- val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
221
- (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
222
- fullOuterIterator(key,
223
- leftHashTable.getOrElse(key, EMPTY_LIST ),
224
- rightHashTable.getOrElse(key, EMPTY_LIST ), joinedRow)
225
- }
226
-
227
- case x =>
228
- throw new IllegalArgumentException (s " HashOuterJoin should not take $x as the JoinType " )
229
- }
230
- }
231
- }
232
181
}
0 commit comments