From c0f85bb676443cd0d6d73f8beb4139fb062fef8c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 10 Dec 2015 21:44:23 -0800 Subject: [PATCH 1/3] fix passing null into ScalaUDF --- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 5deb2f81d1738..ae34a46ecc4d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1032,8 +1032,8 @@ case class ScalaUDF( val funcArguments = converterTerms.zipWithIndex.map { case (converter, i) => val eval = evals(i) - val dt = children(i).dataType - s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)}) ${eval.value})" + val boxedType = ctx.boxedType(children(i).dataType) + s"$converter.apply(${eval.isNull} ? ($boxedType) null : ($boxedType) ${eval.value})" }.mkString(",") val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + From c96b512a8a30575b24e8e9dbba24e0ac7ae16121 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 10 Dec 2015 22:21:32 -0800 Subject: [PATCH 2/3] add test --- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index ae34a46ecc4d7..371cf722a3063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1032,8 +1032,8 @@ case class ScalaUDF( val funcArguments = converterTerms.zipWithIndex.map { case (converter, i) => val eval = evals(i) - val boxedType = ctx.boxedType(children(i).dataType) - s"$converter.apply(${eval.isNull} ? ($boxedType) null : ($boxedType) ${eval.value})" + val dt = children(i).dataType + s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)})(${eval.value}))" }.mkString(",") val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8887dc68a50e7..5b74973b739c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1148,6 +1148,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { } checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) + sqlContext.udf.register("boxedUDF", (i: java.lang.Integer) => if (i == null) -10 else i * 2) + checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, -2) :: Nil) + val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) } From 2125a1bac7edd75c1602806089cee6eff2e11660 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 11 Dec 2015 00:44:03 -0800 Subject: [PATCH 3/3] fix bug in handling result (null) --- .../sql/catalyst/expressions/ScalaUDF.scala | 31 ++++++++++--------- .../org/apache/spark/sql/DataFrameSuite.scala | 9 +++--- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 371cf722a3063..85faa19bbf5ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -1029,24 +1029,27 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val funcArguments = converterTerms.zipWithIndex.map { - case (converter, i) => - val eval = evals(i) - val dt = children(i).dataType - s"$converter.apply(${eval.isNull} ? null : (${ctx.boxedType(dt)})(${eval.value}))" - }.mkString(",") - val callFunc = s"${ctx.boxedType(ctx.javaType(dataType))} $resultTerm = " + - s"(${ctx.boxedType(ctx.javaType(dataType))})${catalystConverterTerm}" + - s".apply($funcTerm.apply($funcArguments));" + val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" + (convert, argTerm) + }.unzip - evalCode + s""" - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - Boolean ${ev.isNull}; + val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " + + s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" + s""" + $evalCode + ${converters.mkString("\n")} $callFunc - ${ev.value} = $resultTerm; - ${ev.isNull} = $resultTerm == null; + boolean ${ev.isNull} = $resultTerm == null; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $resultTerm; + } """ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5b74973b739c1..5353fefaf4b84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1144,12 +1144,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { // passing null into the UDF that could handle it val boxedUDF = udf[java.lang.Integer, java.lang.Integer] { - (i: java.lang.Integer) => if (i == null) -10 else i * 2 + (i: java.lang.Integer) => if (i == null) -10 else null } - checkAnswer(df.select(boxedUDF($"age")), Row(44) :: Row(-10) :: Nil) + checkAnswer(df.select(boxedUDF($"age")), Row(null) :: Row(-10) :: Nil) - sqlContext.udf.register("boxedUDF", (i: java.lang.Integer) => if (i == null) -10 else i * 2) - checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, -2) :: Nil) + sqlContext.udf.register("boxedUDF", + (i: java.lang.Integer) => (if (i == null) -10 else null): java.lang.Integer) + checkAnswer(sql("select boxedUDF(null), boxedUDF(-1)"), Row(-10, null) :: Nil) val primitiveUDF = udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil)