
Commit e626ac5
zsxwing authored and Andrew Or committed
[SPARK-9992] [SPARK-9994] [SPARK-9998] [SQL] Implement the local TopK, sample and intersect operators
This PR is in conflict with #8535. I will update this one when #8535 gets merged.

Author: zsxwing <[email protected]>

Closes #8573 from zsxwing/more-local-operators.
1 parent 1eede3b commit e626ac5

8 files changed, +353 -1 lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 1 addition & 1 deletion
@@ -138,7 +138,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
  *                    will be ub - lb.
  * @param withReplacement Whether to sample with replacement.
  * @param seed the random seed
- * @param child the QueryPlan
+ * @param child the SparkPlan
  */
 @DeveloperApi
 case class Sample(
sql/core/src/main/scala/org/apache/spark/sql/execution/local/IntersectNode.scala

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.local

import scala.collection.mutable

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute

case class IntersectNode(conf: SQLConf, left: LocalNode, right: LocalNode)
  extends BinaryLocalNode(conf) {

  override def output: Seq[Attribute] = left.output

  private[this] var leftRows: mutable.HashSet[InternalRow] = _

  private[this] var currentRow: InternalRow = _

  override def open(): Unit = {
    left.open()
    leftRows = mutable.HashSet[InternalRow]()
    while (left.next()) {
      leftRows += left.fetch().copy()
    }
    left.close()
    right.open()
  }

  override def next(): Boolean = {
    currentRow = null
    while (currentRow == null && right.next()) {
      currentRow = right.fetch()
      if (!leftRows.contains(currentRow)) {
        currentRow = null
      }
    }
    currentRow != null
  }

  override def fetch(): InternalRow = currentRow

  override def close(): Unit = {
    left.close()
    right.close()
  }

}
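For orientation, here is a hedged sketch of how a caller might drive the open()/next()/fetch()/close() contract that IntersectNode implements. The driver loop and the `leftNode`, `rightNode`, and `process` names are illustrative stand-ins, not part of this commit:

```scala
// Illustrative only: driving a LocalNode by hand. `leftNode`, `rightNode`,
// and `process` are hypothetical stand-ins, not part of this commit.
val node = IntersectNode(conf, leftNode, rightNode)
node.open()
try {
  while (node.next()) {
    // fetch() may hand back a reused InternalRow, which is why open() above
    // stores left.fetch().copy() in the HashSet rather than the row itself.
    process(node.fetch())
  }
} finally {
  node.close()
}
```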

sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala

Lines changed: 5 additions & 0 deletions
@@ -69,6 +69,11 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
    */
   def close(): Unit
 
+  /**
+   * Returns the content through the [[Iterator]] interface.
+   */
+  final def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)
+
   /**
    * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
    */
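The LocalNodeIterator that asIterator constructs is not included in this excerpt. A minimal sketch of such an adapter, assuming only the next()/fetch() contract shown above (the real class in the Spark source may differ in detail):

```scala
// Hedged sketch of an Iterator adapter over a LocalNode; not the actual
// LocalNodeIterator from the Spark source.
private[local] class LocalNodeIterator(node: LocalNode) extends Iterator[InternalRow] {
  private[this] var buffered: InternalRow = _
  private[this] var finished = false

  override def hasNext: Boolean = {
    // Pull the next row lazily and remember when the node is exhausted.
    if (buffered == null && !finished) {
      if (node.next()) buffered = node.fetch() else finished = true
    }
    buffered != null
  }

  override def next(): InternalRow = {
    if (!hasNext) throw new NoSuchElementException("End of stream")
    val row = buffered
    buffered = null
    row
  }
}
```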
sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.local

import java.util.Random

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

/**
 * Sample the dataset.
 *
 * @param conf the SQLConf
 * @param lowerBound Lower-bound of the sampling probability (usually 0.0)
 * @param upperBound Upper-bound of the sampling probability. The expected fraction sampled
 *                   will be ub - lb.
 * @param withReplacement Whether to sample with replacement.
 * @param seed the random seed
 * @param child the LocalNode
 */
case class SampleNode(
    conf: SQLConf,
    lowerBound: Double,
    upperBound: Double,
    withReplacement: Boolean,
    seed: Long,
    child: LocalNode) extends UnaryLocalNode(conf) {

  override def output: Seq[Attribute] = child.output

  private[this] var iterator: Iterator[InternalRow] = _

  private[this] var currentRow: InternalRow = _

  override def open(): Unit = {
    child.open()
    val (sampler, _seed) = if (withReplacement) {
      val random = new Random(seed)
      // Disable gap sampling since the gap sampling method buffers two rows internally,
      // requiring us to copy the row, which is more expensive than the random number generator.
      (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
        // Use the seed for partition 0, like PartitionwiseSampledRDD, to generate
        // the same result as DataFrame
        random.nextLong())
    } else {
      (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
    }
    sampler.setSeed(_seed)
    iterator = sampler.sample(child.asIterator)
  }

  override def next(): Boolean = {
    if (iterator.hasNext) {
      currentRow = iterator.next()
      true
    } else {
      false
    }
  }

  override def fetch(): InternalRow = currentRow

  override def close(): Unit = child.close()

}
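SampleNode mirrors the sampler selection of the Sample physical operator: PoissonSampler for with-replacement sampling, BernoulliCellSampler otherwise. As a hedged illustration of the Bernoulli-cell idea only (independent of Spark's actual BernoulliCellSampler internals), a row is kept exactly when its uniform draw lands in [lowerBound, upperBound), so the expected kept fraction is upperBound - lowerBound:

```scala
import java.util.Random

// Illustrative sketch only; Spark's BernoulliCellSampler may differ in detail.
def bernoulliCellSample[T](items: Iterator[T],
                           lowerBound: Double,
                           upperBound: Double,
                           seed: Long): Iterator[T] = {
  val rng = new Random(seed)
  // Keep an item iff its uniform draw falls inside the [lowerBound, upperBound) cell.
  items.filter { _ =>
    val x = rng.nextDouble()
    lowerBound <= x && x < upperBound
  }
}
```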
sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.util.BoundedPriorityQueue

case class TakeOrderedAndProjectNode(
    conf: SQLConf,
    limit: Int,
    sortOrder: Seq[SortOrder],
    projectList: Option[Seq[NamedExpression]],
    child: LocalNode) extends UnaryLocalNode(conf) {

  private[this] var projection: Option[Projection] = _
  private[this] var ord: InterpretedOrdering = _
  private[this] var iterator: Iterator[InternalRow] = _
  private[this] var currentRow: InternalRow = _

  override def output: Seq[Attribute] = {
    val projectOutput = projectList.map(_.map(_.toAttribute))
    projectOutput.getOrElse(child.output)
  }

  override def open(): Unit = {
    child.open()
    projection = projectList.map(new InterpretedProjection(_, child.output))
    ord = new InterpretedOrdering(sortOrder, child.output)
    // The priority queue keeps the largest elements, so reverse the ordering
    // to retain the smallest `limit` rows under sortOrder.
    val queue = new BoundedPriorityQueue[InternalRow](limit)(ord.reverse)
    while (child.next()) {
      queue += child.fetch()
    }
    // Close the child eagerly since we no longer need it.
    child.close()
    iterator = queue.iterator
  }

  override def next(): Boolean = {
    if (iterator.hasNext) {
      val _currentRow = iterator.next()
      currentRow = projection match {
        case Some(p) => p(_currentRow)
        case None => _currentRow
      }
      true
    } else {
      false
    }
  }

  override def fetch(): InternalRow = currentRow

  override def close(): Unit = child.close()

}
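BoundedPriorityQueue is a Spark-internal utility that retains the `limit` largest elements under the ordering it is given; passing ord.reverse therefore retains the smallest `limit` rows, i.e. the first rows of the sort. A hedged standard-library sketch of the same trick (not Spark's BoundedPriorityQueue):

```scala
import scala.collection.mutable

// Illustrative top-K sketch; not Spark's BoundedPriorityQueue.
// Keeps the k smallest items under `ord` by evicting from a max-heap
// whose head is always the worst (largest) of the currently kept items.
def smallestK[T](items: Iterator[T], k: Int)(implicit ord: Ordering[T]): Seq[T] = {
  require(k > 0, "k must be positive")
  val heap = mutable.PriorityQueue.empty[T](ord) // max-heap: head is largest
  items.foreach { item =>
    if (heap.size < k) {
      heap.enqueue(item)
    } else if (ord.lt(item, heap.head)) {
      heap.dequeue() // evict the current largest
      heap.enqueue(item)
    }
  }
  heap.dequeueAll.reverse // ascending under ord, like sort(...).limit(k)
}
```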
sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.local

class IntersectNodeSuite extends LocalNodeTest {

  import testImplicits._

  test("basic") {
    val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
    val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")

    checkAnswer2(
      input1,
      input2,
      (node1, node2) => IntersectNode(conf, node1, node2),
      input1.intersect(input2).collect()
    )
  }
}
sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.local

class SampleNodeSuite extends LocalNodeTest {

  import testImplicits._

  private def testSample(withReplacement: Boolean): Unit = {
    test(s"withReplacement: $withReplacement") {
      val seed = 0L
      val input = sqlContext.sparkContext.
        parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
        toDF("key", "value")
      checkAnswer(
        input,
        node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
        input.sample(withReplacement, 0.3, seed).collect()
      )
    }
  }

  testSample(withReplacement = true)
  testSample(withReplacement = false)
}
sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala

Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.local

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}

class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {

  import testImplicits._

  private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
    val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
      col.expr match {
        case expr: SortOrder =>
          expr
        case expr: Expression =>
          SortOrder(expr, Ascending)
      }
    }
    sortOrder
  }

  private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
    val testCaseName = if (desc) "desc" else "asc"
    test(testCaseName) {
      val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
      val sortColumn = if (desc) input.col("key").desc else input.col("key")
      checkAnswer(
        input,
        node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
        input.sort(sortColumn).limit(5).collect()
      )
    }
  }

  testTakeOrderedAndProjectNode(desc = false)
  testTakeOrderedAndProjectNode(desc = true)
}
