Skip to content

Commit 0cb7ea2

Browse files
committed
marmbrus's comments.
1 parent 3cec464 commit 0cb7ea2

File tree

8 files changed

+66
-29
lines changed

8 files changed

+66
-29
lines changed

mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ private[mllib] object Loader {
110110
assert(loadedFields.contains(field.name), s"Unable to parse model data." +
111111
s" Expected field with name ${field.name} was missing in loaded schema:" +
112112
s" ${loadedFields.mkString(", ")}")
113-
assert(DataType.equalsIgnoreNullability(loadedFields(field.name), field.dataType),
113+
assert(loadedFields(field.name).sameType(field.dataType),
114114
s"Unable to parse model data. Expected field $field but found field" +
115115
s" with different type: ${loadedFields(field.name)}")
116116
}

sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ object DataType {
181181
/**
182182
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
183183
*/
184-
private[spark] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
184+
private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
185185
(left, right) match {
186186
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
187187
equalsIgnoreNullability(leftElementType, rightElementType)
@@ -213,7 +213,7 @@ object DataType {
213213
* if and only if for all every pair of fields, `to.nullable` is true, or both
214214
* of `fromField.nullable` and `toField.nullable` are false.
215215
*/
216-
private[spark] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
216+
private[sql] def equalsIgnoreCompatibleNullability(from: DataType, to: DataType): Boolean = {
217217
(from, to) match {
218218
case (ArrayType(fromElement, fn), ArrayType(toElement, tn)) =>
219219
(tn || !fn) && equalsIgnoreCompatibleNullability(fromElement, toElement)
@@ -235,20 +235,6 @@ object DataType {
235235
case (fromDataType, toDataType) => fromDataType == toDataType
236236
}
237237
}
238-
239-
/** Sets all nullable/containsNull/valueContainsNull to true. */
240-
private[spark] def alwaysNullable(dataType: DataType): DataType = dataType match {
241-
case ArrayType(elementType, _) =>
242-
ArrayType(alwaysNullable(elementType), containsNull = true)
243-
case MapType(keyType, valueType, _) =>
244-
MapType(alwaysNullable(keyType), alwaysNullable(valueType), valueContainsNull = true)
245-
case StructType(fields) =>
246-
val newFields = fields.map { field =>
247-
StructField(field.name, alwaysNullable(field.dataType), nullable = true)
248-
}
249-
StructType(newFields)
250-
case other => other
251-
}
252238
}
253239

254240

@@ -281,6 +267,16 @@ abstract class DataType {
281267
def prettyJson: String = pretty(render(jsonValue))
282268

283269
def simpleString: String = typeName
270+
271+
/** Check if `this` and `other` are the same data type when ignoring nullability
272+
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
273+
*/
274+
def sameType(other: DataType): Boolean = DataType.equalsIgnoreNullability(this, other)
275+
276+
/** Returns the same data type but set all nullability fields are true
277+
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
278+
*/
279+
def asNullable: DataType
284280
}
285281

286282
/**
@@ -296,6 +292,8 @@ class NullType private() extends DataType {
296292
// this type. Otherwise, the companion object would be of type "NullType$" in byte code.
297293
// Defined with a private constructor so the companion object is the only possible instantiation.
298294
override def defaultSize: Int = 1
295+
296+
override def asNullable: NullType = this
299297
}
300298

301299
case object NullType extends NullType
@@ -361,6 +359,8 @@ class StringType private() extends NativeType with PrimitiveType {
361359
* The default size of a value of the StringType is 4096 bytes.
362360
*/
363361
override def defaultSize: Int = 4096
362+
363+
override def asNullable: StringType = this
364364
}
365365

366366
case object StringType extends StringType
@@ -395,6 +395,8 @@ class BinaryType private() extends NativeType with PrimitiveType {
395395
* The default size of a value of the BinaryType is 4096 bytes.
396396
*/
397397
override def defaultSize: Int = 4096
398+
399+
override def asNullable: BinaryType = this
398400
}
399401

400402
case object BinaryType extends BinaryType
@@ -420,6 +422,8 @@ class BooleanType private() extends NativeType with PrimitiveType {
420422
* The default size of a value of the BooleanType is 1 byte.
421423
*/
422424
override def defaultSize: Int = 1
425+
426+
override def asNullable: BooleanType = this
423427
}
424428

425429
case object BooleanType extends BooleanType
@@ -450,6 +454,8 @@ class TimestampType private() extends NativeType {
450454
* The default size of a value of the TimestampType is 12 bytes.
451455
*/
452456
override def defaultSize: Int = 12
457+
458+
override def asNullable: TimestampType = this
453459
}
454460

455461
case object TimestampType extends TimestampType
@@ -478,6 +484,8 @@ class DateType private() extends NativeType {
478484
* The default size of a value of the DateType is 4 bytes.
479485
*/
480486
override def defaultSize: Int = 4
487+
488+
override def asNullable: DateType = this
481489
}
482490

483491
case object DateType extends DateType
@@ -536,6 +544,8 @@ class LongType private() extends IntegralType {
536544
override def defaultSize: Int = 8
537545

538546
override def simpleString = "bigint"
547+
548+
override def asNullable: LongType = this
539549
}
540550

541551
case object LongType extends LongType
@@ -565,6 +575,8 @@ class IntegerType private() extends IntegralType {
565575
override def defaultSize: Int = 4
566576

567577
override def simpleString = "int"
578+
579+
override def asNullable: IntegerType = this
568580
}
569581

570582
case object IntegerType extends IntegerType
@@ -594,6 +606,8 @@ class ShortType private() extends IntegralType {
594606
override def defaultSize: Int = 2
595607

596608
override def simpleString = "smallint"
609+
610+
override def asNullable: ShortType = this
597611
}
598612

599613
case object ShortType extends ShortType
@@ -623,6 +637,8 @@ class ByteType private() extends IntegralType {
623637
override def defaultSize: Int = 1
624638

625639
override def simpleString = "tinyint"
640+
641+
override def asNullable: ByteType = this
626642
}
627643

628644
case object ByteType extends ByteType
@@ -689,6 +705,8 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalT
689705
case Some(PrecisionInfo(precision, scale)) => s"decimal($precision,$scale)"
690706
case None => "decimal(10,0)"
691707
}
708+
709+
override def asNullable: DecimalType = this
692710
}
693711

694712

@@ -747,6 +765,8 @@ class DoubleType private() extends FractionalType {
747765
* The default size of a value of the DoubleType is 8 bytes.
748766
*/
749767
override def defaultSize: Int = 8
768+
769+
override def asNullable: DoubleType = this
750770
}
751771

752772
case object DoubleType extends DoubleType
@@ -775,6 +795,8 @@ class FloatType private() extends FractionalType {
775795
* The default size of a value of the FloatType is 4 bytes.
776796
*/
777797
override def defaultSize: Int = 4
798+
799+
override def asNullable: FloatType = this
778800
}
779801

780802
case object FloatType extends FloatType
@@ -823,6 +845,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
823845
override def defaultSize: Int = 100 * elementType.defaultSize
824846

825847
override def simpleString = s"array<${elementType.simpleString}>"
848+
849+
override def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true)
826850
}
827851

828852

@@ -1068,6 +1092,15 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
10681092
*/
10691093
private[sql] def merge(that: StructType): StructType =
10701094
StructType.merge(this, that).asInstanceOf[StructType]
1095+
1096+
override def asNullable: StructType = {
1097+
val newFields = fields.map {
1098+
case StructField(name, dataType, nullable, metadata) =>
1099+
StructField(name, dataType.asNullable, nullable = true, metadata)
1100+
}
1101+
1102+
StructType(newFields)
1103+
}
10711104
}
10721105

10731106

@@ -1120,6 +1153,9 @@ case class MapType(
11201153
override def defaultSize: Int = 100 * (keyType.defaultSize + valueType.defaultSize)
11211154

11221155
override def simpleString = s"map<${keyType.simpleString},${valueType.simpleString}>"
1156+
1157+
override def asNullable: MapType =
1158+
MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true)
11231159
}
11241160

11251161

@@ -1173,4 +1209,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
11731209
* The default size of a value of the UserDefinedType is 4096 bytes.
11741210
*/
11751211
override def defaultSize: Int = 4096
1212+
1213+
override def sameType(other: DataType): Boolean = ???
1214+
1215+
override def asNullable: DataType = ???
11761216
}

sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ private[sql] case class JSONRelation(
131131

132132
override def equals(other: Any): Boolean = other match {
133133
case that: JSONRelation =>
134-
(this.path == that.path) && (DataType.equalsIgnoreNullability(this.schema, that.schema))
134+
(this.path == that.path) && this.schema.sameType(that.schema)
135135
case _ => false
136136
}
137137
}

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,7 @@ private[sql] object ParquetRelation {
175175
ParquetRelation.enableLogForwarding()
176176
// This is a hack. We always set nullable/containsNull/valueContainsNull to true
177177
// for the schema of a parquet data.
178-
val schema =
179-
DataType.alwaysNullable(StructType.fromAttributes(attributes)).asInstanceOf[StructType]
178+
val schema = StructType.fromAttributes(attributes).asNullable
180179
val newAttributes = schema.toAttributes
181180
ParquetTypesConverter.writeMetaData(newAttributes, path, conf)
182181
new ParquetRelation(path.toString, Some(conf), sqlContext) {

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,8 +280,7 @@ private[sql] case class InsertIntoParquetTable(
280280
val conf = ContextUtil.getConfiguration(job)
281281
// This is a hack. We always set nullable/containsNull/valueContainsNull to true
282282
// for the schema of a parquet data.
283-
val schema =
284-
DataType.alwaysNullable(StructType.fromAttributes(relation.output)).asInstanceOf[StructType]
283+
val schema = StructType.fromAttributes(relation.output).asNullable
285284
RowWriteSupport.setSchema(schema.toAttributes, conf)
286285

287286
val fspath = new Path(relation.path)

sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ private[sql] class DefaultSource
120120
val df =
121121
sqlContext.createDataFrame(
122122
data.queryExecution.toRdd,
123-
DataType.alwaysNullable(data.schema).asInstanceOf[StructType])
123+
data.schema.asNullable)
124124
val createdRelation =
125125
createRelation(sqlContext, parameters, df.schema).asInstanceOf[ParquetRelation2]
126126
createdRelation.insert(df, overwrite = mode == SaveMode.Overwrite)
@@ -185,11 +185,11 @@ private[sql] case class ParquetRelation2(
185185
val schemaEquality = if (shouldMergeSchemas) {
186186
shouldMergeSchemas == relation.shouldMergeSchemas
187187
} else {
188-
DataType.equalsIgnoreNullability(schema, relation.schema)
188+
schema.sameType(relation.schema)
189189
}
190190

191191
val maybeSchemaEquality = (maybeMetastoreSchema, relation.maybeMetastoreSchema) match {
192-
case (Some(left), Some(right)) => DataType.equalsIgnoreNullability(left, right)
192+
case (Some(left), Some(right)) => left.sameType(right)
193193
case (left, right) => left == right
194194
}
195195

sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
5656
child: LogicalPlan) = {
5757
val newChildOutput = expectedOutput.zip(child.output).map {
5858
case (expected, actual) =>
59-
val needCast = !DataType.equalsIgnoreNullability(expected.dataType, actual.dataType)
59+
val needCast = !expected.dataType.sameType(actual.dataType)
6060
// We want to make sure the filed names in the data to be inserted exactly match
6161
// names in the schema.
6262
val needRename = expected.name != actual.name

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
634634
p
635635
} else if (childOutputDataTypes.size == tableOutputDataTypes.size &&
636636
childOutputDataTypes.zip(tableOutputDataTypes)
637-
.forall { case (left, right) => DataType.equalsIgnoreNullability(left, right) }) {
637+
.forall { case (left, right) => left.sameType(right) }) {
638638
// If both types ignoring nullability of ArrayType, MapType, StructType are the same,
639639
// use InsertIntoHiveTable instead of InsertIntoTable.
640640
InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite)
@@ -682,8 +682,7 @@ private[hive] case class InsertIntoHiveTable(
682682
override def output = child.output
683683

684684
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
685-
case (childAttr, tableAttr) =>
686-
DataType.equalsIgnoreNullability(childAttr.dataType, tableAttr.dataType)
685+
case (childAttr, tableAttr) => childAttr.dataType.sameType(tableAttr.dataType)
687686
}
688687
}
689688

0 commit comments

Comments
 (0)