
Commit ba33096

cloud-fan authored and rxin committed
[SPARK-9068][SQL] refactor the implicit type cast code
based on apache#7348

Author: Wenchen Fan <[email protected]>

Closes apache#7420 from cloud-fan/type-check and squashes the following commits:

7633fa9 [Wenchen Fan] revert
fe169b0 [Wenchen Fan] improve test
03b70da [Wenchen Fan] enhance implicit type cast
1 parent 42dea3a commit ba33096

File tree: 13 files changed (+81, -126 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 11 additions & 22 deletions
@@ -675,10 +675,10 @@ object HiveTypeCoercion {
       case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
         findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
           if (b.inputType.acceptsType(commonType)) {
-            // If the expression accepts the tighest common type, cast to that.
+            // If the expression accepts the tightest common type, cast to that.
             val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
             val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
-            b.makeCopy(Array(newLeft, newRight))
+            b.withNewChildren(Seq(newLeft, newRight))
           } else {
             // Otherwise, don't do anything with the expression.
             b
@@ -697,7 +697,7 @@ object HiveTypeCoercion {
         // general implicit casting.
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
           if (in.dataType == NullType && !expected.acceptsType(NullType)) {
-            Cast(in, expected.defaultConcreteType)
+            Literal.create(null, expected.defaultConcreteType)
           } else {
             in
           }
@@ -719,27 +719,22 @@ object HiveTypeCoercion {
     @Nullable val ret: Expression = (inType, expectedType) match {

       // If the expected type is already a parent of the input type, no need to cast.
-      case _ if expectedType.isSameType(inType) => e
+      case _ if expectedType.acceptsType(inType) => e

       // Cast null type (usually from null literals) into target types
       case (NullType, target) => Cast(e, target.defaultConcreteType)

-      // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
-      // already a number, leave it as is.
-      case (_: NumericType, NumericType) => e
-
       // If the function accepts any numeric type and the input is a string, we follow the hive
       // convention and cast that input into a double
       case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)

-      // Implicit cast among numeric types
+      // Implicit cast among numeric types. When we reach here, input type is not acceptable.
+
       // If input is a numeric type but not decimal, and we expect a decimal type,
       // cast the input to unlimited precision decimal.
-      case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
-        Cast(e, DecimalType.Unlimited)
+      case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
       // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
-      case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
-      case (_: NumericType, target: NumericType) => e
+      case (_: NumericType, target: NumericType) => Cast(e, target)

       // Implicit cast between date time types
       case (DateType, TimestampType) => Cast(e, TimestampType)
@@ -753,15 +748,9 @@ object HiveTypeCoercion {
       case (StringType, BinaryType) => Cast(e, BinaryType)
       case (any, StringType) if any != StringType => Cast(e, StringType)

-      // Type collection.
-      // First see if we can find our input type in the type collection. If we can, then just
-      // use the current expression; otherwise, find the first one we can implicitly cast.
-      case (_, TypeCollection(types)) =>
-        if (types.exists(_.isSameType(inType))) {
-          e
-        } else {
-          types.flatMap(implicitCast(e, _)).headOption.orNull
-        }
+      // When we reach here, input type is not acceptable for any types in this type collection,
+      // try to find the first one we can implicitly cast.
+      case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull

       // Else, just return the same input expression
       case _ => null
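
For illustration, a minimal standalone sketch of the resolution order implicitCast follows after this change (hypothetical toy types, not Spark's internal API): accept the input as-is when the expected type already accepts it, cast string to double when a numeric type is expected, and for a type collection try each member in order and keep the first cast that succeeds.

    // Toy sketch of the implicit-cast resolution order above (not Spark code).
    object ImplicitCastSketch {
      sealed trait AbstractType { def acceptsType(other: ConcreteType): Boolean }
      sealed trait ConcreteType extends AbstractType {
        def acceptsType(other: ConcreteType): Boolean = this == other
      }
      case object StringT extends ConcreteType
      case object IntT extends ConcreteType
      case object DoubleT extends ConcreteType
      case object NumericT extends AbstractType {
        def acceptsType(other: ConcreteType): Boolean = other == IntT || other == DoubleT
      }
      case class TypeCollectionT(types: Seq[AbstractType]) extends AbstractType {
        def acceptsType(other: ConcreteType): Boolean = types.exists(_.acceptsType(other))
      }

      def implicitCast(in: ConcreteType, expected: AbstractType): Option[ConcreteType] =
        (in, expected) match {
          // Expected type already accepts the input: no cast needed.
          case _ if expected.acceptsType(in) => Some(in)
          // A numeric expectation turns a string into a double (Hive convention).
          case (StringT, NumericT) => Some(DoubleT)
          // No member of the collection accepts the input: take the first castable one.
          case (_, TypeCollectionT(types)) => types.flatMap(implicitCast(in, _)).headOption
          case _ => None
        }

      def main(args: Array[String]): Unit = {
        println(implicitCast(StringT, NumericT))                             // Some(DoubleT)
        println(implicitCast(StringT, TypeCollectionT(Seq(IntT, NumericT)))) // Some(DoubleT)
        println(implicitCast(StringT, IntT))                                 // None
      }
    }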

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 9 additions & 11 deletions
@@ -386,17 +386,15 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
   override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)

   override def checkInputDataTypes(): TypeCheckResult = {
-    // First call the checker for ExpectsInputTypes, and then check whether left and right have
-    // the same type.
-    super.checkInputDataTypes() match {
-      case TypeCheckResult.TypeCheckSuccess =>
-        if (left.dataType != right.dataType) {
-          TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
-            s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
-        } else {
-          TypeCheckResult.TypeCheckSuccess
-        }
-      case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
+    // First check whether left and right have the same type, then check if the type is acceptable.
+    if (left.dataType != right.dataType) {
+      TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+        s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
+    } else if (!inputType.acceptsType(left.dataType)) {
+      TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
+        s" not ${left.dataType.simpleString}")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
     }
   }
 }
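
As a usage-level illustration of the reordered check (a simplified standalone sketch with hypothetical names, not the actual BinaryOperator API): a mismatch between the two children is reported first, and only then is the operand type checked against inputType.

    // Toy sketch of the new BinaryOperator check order (not Spark code).
    object BinaryOpCheckSketch {
      sealed trait TypeCheckResult
      case object TypeCheckSuccess extends TypeCheckResult
      case class TypeCheckFailure(message: String) extends TypeCheckResult

      def check(leftType: String, rightType: String, accepted: Set[String],
                prettyString: String): TypeCheckResult =
        if (leftType != rightType) {
          // Differing child types are reported before acceptance is considered.
          TypeCheckFailure(s"differing types in '$prettyString' ($leftType and $rightType).")
        } else if (!accepted.contains(leftType)) {
          TypeCheckFailure(s"'$prettyString' accepts ${accepted.mkString(" or ")} type, not $leftType")
        } else {
          TypeCheckSuccess
        }

      def main(args: Array[String]): Unit = {
        println(check("int", "string", Set("int", "bigint"), "(a + b)"))    // differing types
        println(check("string", "string", Set("int", "bigint"), "(a + b)")) // not accepted
        println(check("int", "int", Set("int", "bigint"), "(a + b)"))       // TypeCheckSuccess
      }
    }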

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 0 additions & 2 deletions
@@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
   }

   override def symbol: String = "max"
-  override def prettyName: String = symbol
 }

 case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -375,7 +374,6 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
   }

   override def symbol: String = "min"
-  override def prettyName: String = symbol
 }

 case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala

Lines changed: 4 additions & 4 deletions
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
  */
 case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {

-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType

   override def symbol: String = "&"

@@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
  */
 case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {

-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType

   override def symbol: String = "|"

@@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
  */
 case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {

-  override def inputType: AbstractDataType = TypeCollection.Bitwise
+  override def inputType: AbstractDataType = IntegralType

   override def symbol: String = "^"

@@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
  */
 case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {

-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise)
+  override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)

   override def dataType: DataType = child.dataType

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala

Lines changed: 2 additions & 2 deletions
@@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
       TypeCheckResult.TypeCheckFailure(
         s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
     } else if (trueValue.dataType != falseValue.dataType) {
-      TypeCheckResult.TypeCheckFailure(
-        s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
+      TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+        s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
     } else {
       TypeCheckResult.TypeCheckSuccess
     }

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 12 additions & 33 deletions
@@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
   private[sql] def defaultConcreteType: DataType

   /**
-   * Returns true if this data type is the same type as `other`. This is different that equality
-   * as equality will also consider data type parametrization, such as decimal precision.
+   * Returns true if `other` is an acceptable input type for a function that expects this,
+   * possibly abstract DataType.
    *
    * {{{
    *   // this should return true
-   *   DecimalType.isSameType(DecimalType(10, 2))
-   *
-   *   // this should return false
-   *   NumericType.isSameType(DecimalType(10, 2))
-   * }}}
-   */
-  private[sql] def isSameType(other: DataType): Boolean
-
-  /**
-   * Returns true if `other` is an acceptable input type for a function that expectes this,
-   * possibly abstract, DataType.
-   *
-   * {{{
-   *   // this should return true
-   *   DecimalType.isSameType(DecimalType(10, 2))
+   *   DecimalType.acceptsType(DecimalType(10, 2))
    *
    *   // this should return true as well
    *   NumericType.acceptsType(DecimalType(10, 2))
    * }}}
    */
-  private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
+  private[sql] def acceptsType(other: DataType): Boolean

   /** Readable string representation for the type. */
   private[sql] def simpleString: String
@@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])

   override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType

-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean =
-    types.exists(_.isSameType(other))
+    types.exists(_.acceptsType(other))

   override private[sql] def simpleString: String = {
     types.map(_.simpleString).mkString("(", " or ", ")")
@@ -107,13 +91,6 @@ private[sql] object TypeCollection {
     TimestampType, DateType,
     StringType, BinaryType)

-  /**
-   * Types that can be used in bitwise operations.
-   */
-  val Bitwise = TypeCollection(
-    BooleanType,
-    ByteType, ShortType, IntegerType, LongType)
-
   def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)

   def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
@@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType {

   override private[sql] def simpleString: String = "any"

-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean = true
 }

@@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {

   override private[sql] def simpleString: String = "numeric"

-  override private[sql] def isSameType(other: DataType): Boolean = false
-
   override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
 }


-private[sql] object IntegralType {
+private[sql] object IntegralType extends AbstractDataType {
   /**
    * Enables matching against IntegralType for expressions:
    * {{{
@@ -198,6 +171,12 @@ private[sql] object IntegralType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]
+
+  override private[sql] def defaultConcreteType: DataType = IntegerType
+
+  override private[sql] def simpleString: String = "integral"
+
+  override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
 }

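To make the new contract concrete, here is a standalone sketch (hypothetical toy types, outside Spark) of the single acceptsType method that replaces isSameType: acceptance ignores type parametrization, so an abstract decimal accepts any precision and scale, and an abstract numeric accepts any concrete numeric type, matching the ScalaDoc example above.

    // Toy sketch of the acceptsType contract (not Spark code).
    object AcceptsTypeSketch {
      trait AbstractType { def acceptsType(other: ConcreteType): Boolean }
      trait ConcreteType extends AbstractType {
        // A concrete type accepts exactly itself (parameters included).
        def acceptsType(other: ConcreteType): Boolean = this == other
      }
      trait NumericConcrete extends ConcreteType
      case class DecimalT(precision: Int, scale: Int) extends NumericConcrete
      case object DoubleT extends NumericConcrete
      case object StringT extends ConcreteType

      // Abstract counterparts, analogous to the DecimalType / NumericType companions.
      object AnyDecimal extends AbstractType {
        def acceptsType(other: ConcreteType): Boolean = other.isInstanceOf[DecimalT]
      }
      object AnyNumeric extends AbstractType {
        def acceptsType(other: ConcreteType): Boolean = other.isInstanceOf[NumericConcrete]
      }

      def main(args: Array[String]): Unit = {
        println(AnyDecimal.acceptsType(DecimalT(10, 2))) // true, regardless of precision/scale
        println(AnyNumeric.acceptsType(DecimalT(10, 2))) // true as well
        println(AnyNumeric.acceptsType(StringT))         // false
      }
    }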
sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {

   override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)

-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[ArrayType]
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {

   override private[sql] def defaultConcreteType: DataType = this

-  override private[sql] def isSameType(other: DataType): Boolean = this == other
+  override private[sql] def acceptsType(other: DataType): Boolean = this == other
 }

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 1 addition & 1 deletion
@@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {

   override private[sql] def defaultConcreteType: DataType = Unlimited

-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[DecimalType]
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ object MapType extends AbstractDataType {

   override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)

-  override private[sql] def isSameType(other: DataType): Boolean = {
+  override private[sql] def acceptsType(other: DataType): Boolean = {
     other.isInstanceOf[MapType]
   }
