Skip to content

Commit 3b95ff6

Browse files
云峤 authored and nemccarthy committed
[SPARK-7294][SQL] ADD BETWEEN
Author: 云峤 <[email protected]> Author: kaka1992 <[email protected]> Closes apache#5839 from kaka1992/master and squashes the following commits: b15360d [kaka1992] Fix python unit test in sql/test. =_= I forget to commit this file last time. f928816 [kaka1992] Fix python style in sql/test. d2e7f72 [kaka1992] Fix python style in sql/test. c54d904 [kaka1992] Fix empty map bug. 7e64d1e [云峤] Update 7b9b858 [云峤] undo f080f8d [云峤] update pep8 76f0c51 [云峤] Merge remote-tracking branch 'remotes/upstream/master' 7d62368 [云峤] [SPARK-7294] ADD BETWEEN baf839b [云峤] [SPARK-7294] ADD BETWEEN d11d5b9 [云峤] [SPARK-7294] ADD BETWEEN
1 parent d3d39a0 commit 3b95ff6

File tree

4 files changed

+38
-0
lines changed

4 files changed

+38
-0
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,6 +1405,13 @@ def cast(self, dataType):
14051405
raise TypeError("unexpected type: %s" % type(dataType))
14061406
return Column(jc)
14071407

1408+
@ignore_unicode_prefix
def between(self, lowerBound, upperBound):
    """A boolean expression that is evaluated to true if the value of this
    expression lies between ``lowerBound`` and ``upperBound``, with both
    endpoints included.

    The bounds may be columns or literal values — anything accepted by the
    ``>=`` / ``<=`` operators of this expression. Equivalent to
    ``(self >= lowerBound) & (self <= upperBound)``.

    :param lowerBound: lower end of the range (inclusive)
    :param upperBound: upper end of the range (inclusive)
    :return: a boolean column expression
    """
    return (self >= lowerBound) & (self <= upperBound)
1414+
14081415
def __repr__(self):
14091416
return 'Column<%s>' % self._jc.toString().encode('utf8')
14101417

python/pyspark/sql/tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,14 @@ def test_rand_functions(self):
453453
for row in rndn:
454454
assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
455455

456+
def test_between_function(self):
    """Column.between keeps exactly the rows whose `a` lies in [b, c]."""
    rows = [
        Row(a=1, b=2, c=3),
        Row(a=2, b=1, c=3),
        Row(a=4, b=1, c=4),
    ]
    df = self.sc.parallelize(rows).toDF()
    # The first row is excluded (1 < 2); the other two satisfy b <= a <= c.
    result = df.filter(df.a.between(df.b, df.c)).collect()
    self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], result)
463+
456464
def test_save_and_load(self):
457465
df = self.df
458466
tmpPath = tempfile.mkdtemp()

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,15 @@ class Column(protected[sql] val expr: Expression) extends Logging {
295295
*/
296296
def eqNullSafe(other: Any): Column = this <=> other
297297

298+
/**
 * True if the current column lies between `lowerBound` and `upperBound`,
 * with both endpoints included (i.e. `lowerBound <= this <= upperBound`).
 *
 * @group java_expr_ops
 */
def between(lowerBound: Any, upperBound: Any): Column =
  (this >= lowerBound) && (this <= upperBound)
306+
298307
/**
299308
* True if the current expression is null.
300309
*

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,20 @@ class ColumnExpressionSuite extends QueryTest {
208208
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
209209
}
210210

211+
test("between") {
  // (a, b, c) triples; a row should survive the filter when b <= a <= c.
  val df = TestSQLContext.sparkContext.parallelize(Seq(
    (0, 1, 2),
    (1, 2, 3),
    (2, 1, 0),
    (2, 2, 4),
    (3, 1, 6),
    (3, 2, 0))).toDF("a", "b", "c")
  // Derive the expected rows with plain Scala comparisons over the collected data.
  val expected = df.collect().toSeq.filter { r =>
    r.getInt(1) <= r.getInt(0) && r.getInt(0) <= r.getInt(2)
  }
  checkAnswer(df.filter($"a".between($"b", $"c")), expected)
}
224+
211225
val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
212226
Row(false, false) ::
213227
Row(false, true) ::

0 commit comments

Comments (0)