@@ -181,7 +181,7 @@ object DataType {
181
181
/**
182
182
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
183
183
*/
184
- private [spark ] def equalsIgnoreNullability (left : DataType , right : DataType ): Boolean = {
184
+ private [types ] def equalsIgnoreNullability (left : DataType , right : DataType ): Boolean = {
185
185
(left, right) match {
186
186
case (ArrayType (leftElementType, _), ArrayType (rightElementType, _)) =>
187
187
equalsIgnoreNullability(leftElementType, rightElementType)
@@ -213,7 +213,7 @@ object DataType {
213
213
* if and only if for all every pair of fields, `to.nullable` is true, or both
214
214
* of `fromField.nullable` and `toField.nullable` are false.
215
215
*/
216
- private [spark ] def equalsIgnoreCompatibleNullability (from : DataType , to : DataType ): Boolean = {
216
+ private [sql ] def equalsIgnoreCompatibleNullability (from : DataType , to : DataType ): Boolean = {
217
217
(from, to) match {
218
218
case (ArrayType (fromElement, fn), ArrayType (toElement, tn)) =>
219
219
(tn || ! fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)
@@ -235,20 +235,6 @@ object DataType {
235
235
case (fromDataType, toDataType) => fromDataType == toDataType
236
236
}
237
237
}
238
-
239
- /** Sets all nullable/containsNull/valueContainsNull to true. */
240
- private [spark] def alwaysNullable (dataType : DataType ): DataType = dataType match {
241
- case ArrayType (elementType, _) =>
242
- ArrayType (alwaysNullable(elementType), containsNull = true )
243
- case MapType (keyType, valueType, _) =>
244
- MapType (alwaysNullable(keyType), alwaysNullable(valueType), valueContainsNull = true )
245
- case StructType (fields) =>
246
- val newFields = fields.map { field =>
247
- StructField (field.name, alwaysNullable(field.dataType), nullable = true )
248
- }
249
- StructType (newFields)
250
- case other => other
251
- }
252
238
}
253
239
254
240
@@ -281,6 +267,16 @@ abstract class DataType {
281
267
def prettyJson : String = pretty(render(jsonValue))
282
268
283
269
def simpleString : String = typeName
270
+
271
+ /** Check if `this` and `other` are the same data type when ignoring nullability
272
+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
273
+ */
274
+ def sameType (other : DataType ): Boolean = DataType .equalsIgnoreNullability(this , other)
275
+
276
+ /** Returns the same data type but set all nullability fields are true
277
+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
278
+ */
279
+ def asNullable : DataType
284
280
}
285
281
286
282
/**
@@ -296,6 +292,8 @@ class NullType private() extends DataType {
296
292
// this type. Otherwise, the companion object would be of type "NullType$" in byte code.
297
293
// Defined with a private constructor so the companion object is the only possible instantiation.
298
294
override def defaultSize : Int = 1
295
+
296
+ override def asNullable : NullType = this
299
297
}
300
298
301
299
case object NullType extends NullType
@@ -361,6 +359,8 @@ class StringType private() extends NativeType with PrimitiveType {
361
359
* The default size of a value of the StringType is 4096 bytes.
362
360
*/
363
361
override def defaultSize : Int = 4096
362
+
363
+ override def asNullable : StringType = this
364
364
}
365
365
366
366
case object StringType extends StringType
@@ -395,6 +395,8 @@ class BinaryType private() extends NativeType with PrimitiveType {
395
395
* The default size of a value of the BinaryType is 4096 bytes.
396
396
*/
397
397
override def defaultSize : Int = 4096
398
+
399
+ override def asNullable : BinaryType = this
398
400
}
399
401
400
402
case object BinaryType extends BinaryType
@@ -420,6 +422,8 @@ class BooleanType private() extends NativeType with PrimitiveType {
420
422
* The default size of a value of the BooleanType is 1 byte.
421
423
*/
422
424
override def defaultSize : Int = 1
425
+
426
+ override def asNullable : BooleanType = this
423
427
}
424
428
425
429
case object BooleanType extends BooleanType
@@ -450,6 +454,8 @@ class TimestampType private() extends NativeType {
450
454
* The default size of a value of the TimestampType is 12 bytes.
451
455
*/
452
456
override def defaultSize : Int = 12
457
+
458
+ override def asNullable : TimestampType = this
453
459
}
454
460
455
461
case object TimestampType extends TimestampType
@@ -478,6 +484,8 @@ class DateType private() extends NativeType {
478
484
* The default size of a value of the DateType is 4 bytes.
479
485
*/
480
486
override def defaultSize : Int = 4
487
+
488
+ override def asNullable : DateType = this
481
489
}
482
490
483
491
case object DateType extends DateType
@@ -536,6 +544,8 @@ class LongType private() extends IntegralType {
536
544
override def defaultSize : Int = 8
537
545
538
546
override def simpleString = " bigint"
547
+
548
+ override def asNullable : LongType = this
539
549
}
540
550
541
551
case object LongType extends LongType
@@ -565,6 +575,8 @@ class IntegerType private() extends IntegralType {
565
575
override def defaultSize : Int = 4
566
576
567
577
override def simpleString = " int"
578
+
579
+ override def asNullable : IntegerType = this
568
580
}
569
581
570
582
case object IntegerType extends IntegerType
@@ -594,6 +606,8 @@ class ShortType private() extends IntegralType {
594
606
override def defaultSize : Int = 2
595
607
596
608
override def simpleString = " smallint"
609
+
610
+ override def asNullable : ShortType = this
597
611
}
598
612
599
613
case object ShortType extends ShortType
@@ -623,6 +637,8 @@ class ByteType private() extends IntegralType {
623
637
override def defaultSize : Int = 1
624
638
625
639
override def simpleString = " tinyint"
640
+
641
+ override def asNullable : ByteType = this
626
642
}
627
643
628
644
case object ByteType extends ByteType
@@ -689,6 +705,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
689
705
case Some (PrecisionInfo (precision, scale)) => s " decimal( $precision, $scale) "
690
706
case None => " decimal(10,0)"
691
707
}
708
+
709
+ override def asNullable : DecimalType = this
692
710
}
693
711
694
712
@@ -747,6 +765,8 @@ class DoubleType private() extends FractionalType {
747
765
* The default size of a value of the DoubleType is 8 bytes.
748
766
*/
749
767
override def defaultSize : Int = 8
768
+
769
+ override def asNullable : DoubleType = this
750
770
}
751
771
752
772
case object DoubleType extends DoubleType
@@ -775,6 +795,8 @@ class FloatType private() extends FractionalType {
775
795
* The default size of a value of the FloatType is 4 bytes.
776
796
*/
777
797
override def defaultSize : Int = 4
798
+
799
+ override def asNullable : FloatType = this
778
800
}
779
801
780
802
case object FloatType extends FloatType
@@ -823,6 +845,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
823
845
override def defaultSize : Int = 100 * elementType.defaultSize
824
846
825
847
override def simpleString = s " array< ${elementType.simpleString}> "
848
+
849
+ override def asNullable : ArrayType = ArrayType (elementType.asNullable, containsNull = true )
826
850
}
827
851
828
852
@@ -1068,6 +1092,15 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
1068
1092
*/
1069
1093
private [sql] def merge (that : StructType ): StructType =
1070
1094
StructType .merge(this , that).asInstanceOf [StructType ]
1095
+
1096
+ override def asNullable : StructType = {
1097
+ val newFields = fields.map {
1098
+ case StructField (name, dataType, nullable, metadata) =>
1099
+ StructField (name, dataType.asNullable, nullable = true , metadata)
1100
+ }
1101
+
1102
+ StructType (newFields)
1103
+ }
1071
1104
}
1072
1105
1073
1106
@@ -1120,6 +1153,9 @@ case class MapType(
1120
1153
override def defaultSize : Int = 100 * (keyType.defaultSize + valueType.defaultSize)
1121
1154
1122
1155
override def simpleString = s " map< ${keyType.simpleString}, ${valueType.simpleString}> "
1156
+
1157
+ override def asNullable : MapType =
1158
+ MapType (keyType.asNullable, valueType.asNullable, valueContainsNull = true )
1123
1159
}
1124
1160
1125
1161
@@ -1173,4 +1209,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
1173
1209
* The default size of a value of the UserDefinedType is 4096 bytes.
1174
1210
*/
1175
1211
override def defaultSize : Int = 4096
1212
+
1213
+ override def sameType (other : DataType ): Boolean = ???
1214
+
1215
+ override def asNullable : DataType = ???
1176
1216
}
0 commit comments