Skip to content

Commit 1850d90

Browse files
sarutak authored and marmbrus committed
[SPARK-4536][SQL] Add sqrt and abs to Spark SQL DSL
Spark SQL has embedded sqrt and abs but the DSL doesn't support those functions. Author: Kousuke Saruta <[email protected]> Closes #3401 from sarutak/dsl-missing-operator and squashes the following commits: 07700cf [Kousuke Saruta] Modified Literal(null, NullType) to Literal(null) in DslQuerySuite 8f366f8 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into dsl-missing-operator 1b88e2e [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into dsl-missing-operator 0396f89 [Kousuke Saruta] Added sqrt and abs to Spark SQL DSL (cherry picked from commit e75e04f) Signed-off-by: Michael Armbrust <[email protected]>
1 parent b97c27f commit 1850d90

File tree

4 files changed

+74
-1
lines changed

4 files changed

+74
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ package object dsl {
147147
def max(e: Expression) = Max(e)
148148
def upper(e: Expression) = Upper(e)
149149
def lower(e: Expression) = Lower(e)
150+
def sqrt(e: Expression) = Sqrt(e)
151+
def abs(e: Expression) = Abs(e)
150152

151153
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
152154
// TODO more implicit class for literal?

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
2121
import org.apache.spark.sql.catalyst.types._
22-
import scala.math.pow
2322

2423
case class UnaryMinus(child: Expression) extends UnaryExpression {
2524
type EvaluatedType = Any

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,4 +282,72 @@ class DslQuerySuite extends QueryTest {
282282
(1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
283283
)
284284
}
285+
286+
test("sqrt") {
287+
checkAnswer(
288+
testData.select(sqrt('key)).orderBy('key asc),
289+
(1 to 100).map(n => Seq(math.sqrt(n)))
290+
)
291+
292+
checkAnswer(
293+
testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc),
294+
(1 to 100).map(n => Seq(math.sqrt(n), n))
295+
)
296+
297+
checkAnswer(
298+
testData.select(sqrt(Literal(null))),
299+
(1 to 100).map(_ => Seq(null))
300+
)
301+
}
302+
303+
test("abs") {
304+
checkAnswer(
305+
testData.select(abs('key)).orderBy('key asc),
306+
(1 to 100).map(n => Seq(n))
307+
)
308+
309+
checkAnswer(
310+
negativeData.select(abs('key)).orderBy('key desc),
311+
(1 to 100).map(n => Seq(n))
312+
)
313+
314+
checkAnswer(
315+
testData.select(abs(Literal(null))),
316+
(1 to 100).map(_ => Seq(null))
317+
)
318+
}
319+
320+
test("upper") {
321+
checkAnswer(
322+
lowerCaseData.select(upper('l)),
323+
('a' to 'd').map(c => Seq(c.toString.toUpperCase()))
324+
)
325+
326+
checkAnswer(
327+
testData.select(upper('value), 'key),
328+
(1 to 100).map(n => Seq(n.toString, n))
329+
)
330+
331+
checkAnswer(
332+
testData.select(upper(Literal(null))),
333+
(1 to 100).map(n => Seq(null))
334+
)
335+
}
336+
337+
test("lower") {
338+
checkAnswer(
339+
upperCaseData.select(lower('L)),
340+
('A' to 'F').map(c => Seq(c.toString.toLowerCase()))
341+
)
342+
343+
checkAnswer(
344+
testData.select(lower('value), 'key),
345+
(1 to 100).map(n => Seq(n.toString, n))
346+
)
347+
348+
checkAnswer(
349+
testData.select(lower(Literal(null))),
350+
(1 to 100).map(n => Seq(null))
351+
)
352+
}
285353
}

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ object TestData {
3232
(1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
3333
testData.registerTempTable("testData")
3434

35+
val negativeData = TestSQLContext.sparkContext.parallelize(
36+
(1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD
37+
negativeData.registerTempTable("negativeData")
38+
3539
case class LargeAndSmallInts(a: Int, b: Int)
3640
val largeAndSmallInts =
3741
TestSQLContext.sparkContext.parallelize(

0 commit comments

Comments (0)