Skip to content

Commit d5d6473

Browse files
cloud-fanyhuai
authored andcommitted
[SPARK-10442] [SQL] fix string to boolean cast
When we cast string to boolean in hive, it returns `true` if the length of string is > 0, and spark SQL follows this behavior. However, this behavior is very different from other SQL systems: 1. [presto](https://github.com/facebook/presto/blob/master/presto-main/src/main/java/com/facebook/presto/type/VarcharOperators.java#L89-L118) will return `true` for 't' 'true' '1', `false` for 'f' 'false' '0', throw exception for others. 2. [redshift](http://docs.aws.amazon.com/redshift/latest/dg/r_Boolean_type.html) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 3. [postgresql](http://www.postgresql.org/docs/devel/static/datatype-boolean.html) will return `true` for 't' 'true' 'y' 'yes' 'on' '1', `false` for 'f' 'false' 'n' 'no' 'off' '0', throw exception for others. 4. [vertica](https://my.vertica.com/docs/5.0/HTML/Master/2983.htm) will return `true` for 't' 'true' 'y' 'yes' '1', `false` for 'f' 'false' 'n' 'no' '0', null for others. 5. [impala](http://www.cloudera.com/content/cloudera/en/documentation/cloudera-impala/latest/topics/impala_boolean.html) throw exception when try to cast string to boolean. 6. mysql, oracle, sqlserver don't have boolean type Whether we should change the cast behavior according to other SQL system or not is not decided yet, this PR is a test to see if we changed, how many compatibility tests will fail. Author: Wenchen Fan <[email protected]> Closes #8698 from cloud-fan/string2boolean.
1 parent c373866 commit d5d6473

File tree

4 files changed

+82
-24
lines changed

4 files changed

+82
-24
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import java.math.{BigDecimal => JavaBigDecimal}
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2424
import org.apache.spark.sql.catalyst.expressions.codegen._
25-
import org.apache.spark.sql.catalyst.util.DateTimeUtils
25+
import org.apache.spark.sql.catalyst.util.{StringUtils, DateTimeUtils}
2626
import org.apache.spark.sql.types._
2727
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
2828

@@ -140,7 +140,15 @@ case class Cast(child: Expression, dataType: DataType)
140140
// UDFToBoolean
141141
private[this] def castToBoolean(from: DataType): Any => Any = from match {
142142
case StringType =>
143-
buildCast[UTF8String](_, _.numBytes() != 0)
143+
buildCast[UTF8String](_, s => {
144+
if (StringUtils.isTrueString(s)) {
145+
true
146+
} else if (StringUtils.isFalseString(s)) {
147+
false
148+
} else {
149+
null
150+
}
151+
})
144152
case TimestampType =>
145153
buildCast[Long](_, t => t != 0)
146154
case DateType =>
@@ -646,7 +654,17 @@ case class Cast(child: Expression, dataType: DataType)
646654

647655
private[this] def castToBooleanCode(from: DataType): CastFunction = from match {
648656
case StringType =>
649-
(c, evPrim, evNull) => s"$evPrim = $c.numBytes() != 0;"
657+
val stringUtils = StringUtils.getClass.getName.stripSuffix("$")
658+
(c, evPrim, evNull) =>
659+
s"""
660+
if ($stringUtils.isTrueString($c)) {
661+
$evPrim = true;
662+
} else if ($stringUtils.isFalseString($c)) {
663+
$evPrim = false;
664+
} else {
665+
$evNull = true;
666+
}
667+
"""
650668
case TimestampType =>
651669
(c, evPrim, evNull) => s"$evPrim = $c != 0;"
652670
case DateType =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.util
1919

2020
import java.util.regex.Pattern
2121

22+
import org.apache.spark.unsafe.types.UTF8String
23+
2224
object StringUtils {
2325

2426
// replace the _ with .{1} exactly match 1 time of any character
@@ -44,4 +46,10 @@ object StringUtils {
4446
v
4547
}
4648
}
49+
50+
private[this] val trueStrings = Set("t", "true", "y", "yes", "1").map(UTF8String.fromString)
51+
private[this] val falseStrings = Set("f", "false", "n", "no", "0").map(UTF8String.fromString)
52+
53+
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
54+
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
4755
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -503,9 +503,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
503503
}
504504

505505
test("cast from array") {
506-
val array = Literal.create(Seq("123", "abc", "", null),
506+
val array = Literal.create(Seq("123", "true", "f", null),
507507
ArrayType(StringType, containsNull = true))
508-
val array_notNull = Literal.create(Seq("123", "abc", ""),
508+
val array_notNull = Literal.create(Seq("123", "true", "f"),
509509
ArrayType(StringType, containsNull = false))
510510

511511
checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
@@ -522,7 +522,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
522522
{
523523
val ret = cast(array, ArrayType(BooleanType, containsNull = true))
524524
assert(ret.resolved === true)
525-
checkEvaluation(ret, Seq(true, true, false, null))
525+
checkEvaluation(ret, Seq(null, true, false, null))
526526
}
527527
{
528528
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
@@ -541,12 +541,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
541541
{
542542
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
543543
assert(ret.resolved === true)
544-
checkEvaluation(ret, Seq(true, true, false))
544+
checkEvaluation(ret, Seq(null, true, false))
545545
}
546546
{
547547
val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
548548
assert(ret.resolved === true)
549-
checkEvaluation(ret, Seq(true, true, false))
549+
checkEvaluation(ret, Seq(null, true, false))
550550
}
551551

552552
{
@@ -557,10 +557,10 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
557557

558558
test("cast from map") {
559559
val map = Literal.create(
560-
Map("a" -> "123", "b" -> "abc", "c" -> "", "d" -> null),
560+
Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
561561
MapType(StringType, StringType, valueContainsNull = true))
562562
val map_notNull = Literal.create(
563-
Map("a" -> "123", "b" -> "abc", "c" -> ""),
563+
Map("a" -> "123", "b" -> "true", "c" -> "f"),
564564
MapType(StringType, StringType, valueContainsNull = false))
565565

566566
checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
@@ -577,7 +577,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
577577
{
578578
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
579579
assert(ret.resolved === true)
580-
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false, "d" -> null))
580+
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null))
581581
}
582582
{
583583
val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
@@ -600,12 +600,12 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
600600
{
601601
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
602602
assert(ret.resolved === true)
603-
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
603+
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
604604
}
605605
{
606606
val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
607607
assert(ret.resolved === true)
608-
checkEvaluation(ret, Map("a" -> true, "b" -> true, "c" -> false))
608+
checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
609609
}
610610
{
611611
val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
@@ -630,8 +630,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
630630
val struct = Literal.create(
631631
InternalRow(
632632
UTF8String.fromString("123"),
633-
UTF8String.fromString("abc"),
634-
UTF8String.fromString(""),
633+
UTF8String.fromString("true"),
634+
UTF8String.fromString("f"),
635635
null),
636636
StructType(Seq(
637637
StructField("a", StringType, nullable = true),
@@ -641,8 +641,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
641641
val struct_notNull = Literal.create(
642642
InternalRow(
643643
UTF8String.fromString("123"),
644-
UTF8String.fromString("abc"),
645-
UTF8String.fromString("")),
644+
UTF8String.fromString("true"),
645+
UTF8String.fromString("f")),
646646
StructType(Seq(
647647
StructField("a", StringType, nullable = false),
648648
StructField("b", StringType, nullable = false),
@@ -672,7 +672,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
672672
StructField("c", BooleanType, nullable = true),
673673
StructField("d", BooleanType, nullable = true))))
674674
assert(ret.resolved === true)
675-
checkEvaluation(ret, InternalRow(true, true, false, null))
675+
checkEvaluation(ret, InternalRow(null, true, false, null))
676676
}
677677
{
678678
val ret = cast(struct, StructType(Seq(
@@ -704,15 +704,15 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
704704
StructField("b", BooleanType, nullable = true),
705705
StructField("c", BooleanType, nullable = true))))
706706
assert(ret.resolved === true)
707-
checkEvaluation(ret, InternalRow(true, true, false))
707+
checkEvaluation(ret, InternalRow(null, true, false))
708708
}
709709
{
710710
val ret = cast(struct_notNull, StructType(Seq(
711711
StructField("a", BooleanType, nullable = true),
712712
StructField("b", BooleanType, nullable = true),
713713
StructField("c", BooleanType, nullable = false))))
714714
assert(ret.resolved === true)
715-
checkEvaluation(ret, InternalRow(true, true, false))
715+
checkEvaluation(ret, InternalRow(null, true, false))
716716
}
717717

718718
{
@@ -731,8 +731,8 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
731731
test("complex casting") {
732732
val complex = Literal.create(
733733
Row(
734-
Seq("123", "abc", ""),
735-
Map("a" ->"123", "b" -> "abc", "c" -> ""),
734+
Seq("123", "true", "f"),
735+
Map("a" ->"123", "b" -> "true", "c" -> "f"),
736736
Row(0)),
737737
StructType(Seq(
738738
StructField("a",
@@ -755,11 +755,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
755755
assert(ret.resolved === true)
756756
checkEvaluation(ret, Row(
757757
Seq(123, null, null),
758-
Map("a" -> true, "b" -> true, "c" -> false),
758+
Map("a" -> null, "b" -> true, "c" -> false),
759759
Row(0L)))
760760
}
761761

762-
test("case between string and interval") {
762+
test("cast between string and interval") {
763763
import org.apache.spark.unsafe.types.CalendarInterval
764764

765765
checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType),
@@ -769,4 +769,23 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
769769
StringType),
770770
"interval 1 years 3 months -3 days")
771771
}
772+
773+
test("cast string to boolean") {
774+
checkCast("t", true)
775+
checkCast("true", true)
776+
checkCast("tRUe", true)
777+
checkCast("y", true)
778+
checkCast("yes", true)
779+
checkCast("1", true)
780+
781+
checkCast("f", false)
782+
checkCast("false", false)
783+
checkCast("FAlsE", false)
784+
checkCast("n", false)
785+
checkCast("no", false)
786+
checkCast("0", false)
787+
788+
checkEvaluation(cast("abc", BooleanType), null)
789+
checkEvaluation(cast("", BooleanType), null)
790+
}
772791
}

sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,19 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
375375
}
376376
}
377377

378+
test("saveAsTable()/load() - partitioned table - boolean type") {
379+
sqlContext.range(2)
380+
.select('id, ('id % 2 === 0).as("b"))
381+
.write.partitionBy("b").saveAsTable("t")
382+
383+
withTable("t") {
384+
checkAnswer(
385+
sqlContext.table("t").sort('id),
386+
Row(0, true) :: Row(1, false) :: Nil
387+
)
388+
}
389+
}
390+
378391
test("saveAsTable()/load() - partitioned table - Overwrite") {
379392
partitionedTestDF.write
380393
.format(dataSourceName)

0 commit comments

Comments
 (0)