Skip to content

Commit a3270b0

Browse files
committed
Address Andrew's comments
1 parent 4090902 commit a3270b0

File tree

5 files changed

+43
-63
lines changed

5 files changed

+43
-63
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
6969
*/
7070
def close(): Unit
7171

72+
/**
73+
* Returns the content through the [[Iterator]] interface.
74+
*/
75+
final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
76+
7277
/**
7378
* Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
7479
*/
@@ -108,7 +113,6 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
108113
}
109114
}
110115

111-
def toIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
112116
}
113117

114118

sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ case class SampleNode(
6363
(new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
6464
}
6565
sampler.setSeed(_seed)
66-
iterator = sampler.sample(child.toIterator)
66+
iterator = sampler.sample(child.asIterator)
6767
}
6868

6969
override def next(): Boolean = {

sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,23 +29,21 @@ case class TakeOrderedAndProjectNode(
2929
projectList: Option[Seq[NamedExpression]],
3030
child: LocalNode) extends UnaryLocalNode(conf) {
3131

32-
override def output: Seq[Attribute] = {
33-
val projectOutput = projectList.map(_.map(_.toAttribute))
34-
projectOutput.getOrElse(child.output)
35-
}
36-
3732
private[this] var projection: Option[Projection] = _
38-
3933
private[this] var ord: InterpretedOrdering = _
40-
4134
private[this] var iterator: Iterator[InternalRow] = _
42-
4335
private[this] var currentRow: InternalRow = _
4436

37+
override def output: Seq[Attribute] = {
38+
val projectOutput = projectList.map(_.map(_.toAttribute))
39+
projectOutput.getOrElse(child.output)
40+
}
41+
4542
override def open(): Unit = {
4643
child.open()
4744
projection = projectList.map(new InterpretedProjection(_, child.output))
4845
ord = new InterpretedOrdering(sortOrder, child.output)
46+
// Priority keeps the largest elements, so let's reverse the ordering.
4947
val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse)
5048
while (child.next()) {
5149
queue += child.fetch()
@@ -58,7 +56,10 @@ case class TakeOrderedAndProjectNode(
5856
override def next(): Boolean = {
5957
if (iterator.hasNext) {
6058
val _currentRow = iterator.next()
61-
currentRow = projection.map(p => p(_currentRow)).getOrElse(_currentRow)
59+
currentRow = projection match {
60+
case Some(p) => p(_currentRow)
61+
case None => _currentRow
62+
}
6263
true
6364
} else {
6465
false

sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,24 @@
1717

1818
package org.apache.spark.sql.execution.local
1919

20-
import org.apache.spark.sql.Column
21-
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
22-
2320
class SampleNodeSuite extends LocalNodeTest {
2421

2522
import testImplicits._
2623

27-
def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
28-
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
29-
col.expr match {
30-
case expr: SortOrder =>
31-
expr
32-
case expr: Expression =>
33-
SortOrder(expr, Ascending)
34-
}
24+
private def testSample(withReplacement: Boolean): Unit = {
25+
test(s"withReplacement: $withReplacement") {
26+
val seed = 0L
27+
val input = sqlContext.sparkContext.
28+
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
29+
toDF("key", "value")
30+
checkAnswer(
31+
input,
32+
node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
33+
input.sample(withReplacement, 0.3, seed).collect()
34+
)
3535
}
36-
sortOrder
37-
}
38-
39-
test("withReplacement: true") {
40-
val seed = 0L
41-
val input = sqlContext.sparkContext.
42-
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
43-
toDF("key", "value")
44-
checkAnswer(
45-
input,
46-
node => SampleNode(conf, 0.0, 0.3, true, seed, node),
47-
input.sample(true, 0.3, seed).collect()
48-
)
4936
}
5037

51-
test("withReplacement: false") {
52-
val seed = 0L
53-
val input = sqlContext.sparkContext.
54-
parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
55-
toDF("key", "value")
56-
checkAnswer(
57-
input,
58-
node => SampleNode(conf, 0.0, 0.3, false, seed, node),
59-
input.sample(false, 0.3, seed).collect()
60-
)
61-
}
38+
testSample(withReplacement = true)
39+
testSample(withReplacement = false)
6240
}

sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
2424

2525
import testImplicits._
2626

27-
def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
27+
private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
2828
val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
2929
col.expr match {
3030
case expr: SortOrder =>
@@ -36,22 +36,19 @@ class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
3636
sortOrder
3737
}
3838

39-
test("asc") {
40-
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
41-
checkAnswer(
42-
input,
43-
node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(input.col("key")), None, node),
44-
input.sort(input.col("key")).limit(5).collect()
45-
)
39+
private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
40+
val testCaseName = if (desc) "desc" else "asc"
41+
test(testCaseName) {
42+
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
43+
val sortColumn = if (desc) input.col("key").desc else input.col("key")
44+
checkAnswer(
45+
input,
46+
node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
47+
input.sort(sortColumn).limit(5).collect()
48+
)
49+
}
4650
}
4751

48-
test("desc") {
49-
val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
50-
checkAnswer(
51-
input,
52-
node =>
53-
TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(input.col("key").desc), None, node),
54-
input.sort(input.col("key").desc).limit(5).collect()
55-
)
56-
}
52+
testTakeOrderedAndProjectNode(desc = false)
53+
testTakeOrderedAndProjectNode(desc = true)
5754
}

0 commit comments

Comments (0)