
Commit 1cb35fe

Add "valueContainsNull" to MapType.
1 parent 3edb3ae commit 1cb35fe

13 files changed (+76, -118 lines)

python/pyspark/sql.py

Lines changed: 16 additions & 8 deletions
@@ -191,28 +191,32 @@ class MapType(object):
     The data type representing dict values.

     """
-    def __init__(self, keyType, valueType):
+    def __init__(self, keyType, valueType, valueContainsNull=True):
         """Creates a MapType
         :param keyType: the data type of keys.
         :param valueType: the data type of values.
+        :param valueContainsNull: indicates whether values contains null values.
         :return:

-        >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType)
+        >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True)
         True
-        >>> MapType(StringType, IntegerType) == MapType(StringType, FloatType)
+        >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType)
         False
         """
         self.keyType = keyType
         self.valueType = valueType
+        self.valueContainsNull = valueContainsNull

     def _get_scala_type_string(self):
         return "MapType(" + self.keyType._get_scala_type_string() + "," + \
-            self.valueType._get_scala_type_string() + ")"
+            self.valueType._get_scala_type_string() + "," + \
+            str(self.valueContainsNull).lower() + ")"

     def __eq__(self, other):
         return (isinstance(other, self.__class__) and \
             self.keyType == other.keyType and \
-            self.valueType == other.valueType)
+            self.valueType == other.valueType and \
+            self.valueContainsNull == other.valueContainsNull)

     def __ne__(self, other):
         return not self.__eq__(other)

@@ -369,7 +373,7 @@ def _parse_datatype_string(datatype_string):
     >>> check_datatype(complex_arraytype)
     True
     >>> # Complex MapType.
-    >>> complex_maptype = MapType(complex_structtype, complex_arraytype)
+    >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False)
     >>> check_datatype(complex_maptype)
     True
     """

@@ -409,8 +413,12 @@ def _parse_datatype_string(datatype_string):
         elementType = _parse_datatype_string(rest_part[:last_comma_index].strip())
         return ArrayType(elementType, containsNull)
     elif type_or_field == "MapType":
-        keyType, valueType = _parse_datatype_list(rest_part.strip())
-        return MapType(keyType, valueType)
+        last_comma_index = rest_part.rfind(",")
+        valueContainsNull = True
+        if rest_part[last_comma_index+1:].strip().lower() == "false":
+            valueContainsNull = False
+        keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip())
+        return MapType(keyType, valueType, valueContainsNull)
     elif type_or_field == "StructField":
         first_comma_index = rest_part.find(",")
         name = rest_part[:first_comma_index].strip()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression {
   override def references = children.flatMap(_.references).toSet
   def dataType = child.dataType match {
     case ArrayType(dt, _) => dt
-    case MapType(_, vt) => vt
+    case MapType(_, vt, _) => vt
   }
   override lazy val resolved =
     childrenResolved &&

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 2 additions & 2 deletions
@@ -85,7 +85,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)

   private lazy val elementTypes = child.dataType match {
     case ArrayType(et, _) => et :: Nil
-    case MapType(kt,vt) => kt :: vt :: Nil
+    case MapType(kt,vt, _) => kt :: vt :: Nil
   }

   // TODO: Move this pattern into Generator.

@@ -105,7 +105,7 @@ case class Explode(attributeNames: Seq[String], child: Expression)
     case ArrayType(_, _) =>
       val inputArray = child.eval(input).asInstanceOf[Seq[Any]]
       if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v)))
-    case MapType(_, _) =>
+    case MapType(_, _, _) =>
       val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]]
       if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) }
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 12 additions & 3 deletions
@@ -53,8 +53,8 @@ object DataType extends RegexParsers {
   }

   protected lazy val mapType: Parser[DataType] =
-    "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ {
-      case t1 ~ _ ~ t2 => MapType(t1, t2)
+    "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ {
+      case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull)
     }

   protected lazy val structField: Parser[StructField] =

@@ -344,7 +344,16 @@ case class StructType(fields: Seq[StructField]) extends DataType {
   def simpleString: String = "struct"
 }

-case class MapType(keyType: DataType, valueType: DataType) extends DataType {
+object MapType {
+  /**
+   * Construct a [[MapType]] object with the given key type and value type.
+   * The `valueContainsNull` is true.
+   */
+  def apply(keyType: DataType, valueType: DataType): MapType =
+    MapType(keyType: DataType, valueType: DataType, true)
+}
+
+case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends DataType {
   private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
     builder.append(s"${prefix}-- key: ${keyType.simpleString}\n")
     builder.append(s"${prefix}-- value: ${valueType.simpleString}\n")
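The companion object keeps old two-argument call sites compiling, while the case class itself now carries the flag. A minimal sketch of the resulting Scala behavior (illustrative, not part of the commit):

import org.apache.spark.sql.catalyst.types._

// The two-argument factory defaults valueContainsNull to true.
val m1 = MapType(StringType, IntegerType)
val m2 = MapType(StringType, IntegerType, true)
assert(m1 == m2)

// The flag participates in case-class equality.
assert(m1 != MapType(StringType, IntegerType, false))

// Pattern matches against the case class now bind three fields.
m1 match {
  case MapType(kt, vt, valueContainsNull) =>
    println(s"key: $kt, value: $vt, value may contain null: $valueContainsNull")
}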

sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java

Lines changed: 24 additions & 1 deletion
@@ -111,7 +111,30 @@ public static MapType createMapType(DataType keyType, DataType valueType) {
       throw new IllegalArgumentException("valueType should not be null.");
     }

-    return new MapType(keyType, valueType);
+    return new MapType(keyType, valueType, true);
+  }
+
+  /**
+   * Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of
+   * values ({@code keyType}), and whether values contain any null value
+   * ({@code valueContainsNull}).
+   * @param keyType
+   * @param valueType
+   * @param valueContainsNull
+   * @return
+   */
+  public static MapType createMapType(
+      DataType keyType,
+      DataType valueType,
+      boolean valueContainsNull) {
+    if (keyType == null) {
+      throw new IllegalArgumentException("keyType should not be null.");
+    }
+    if (valueType == null) {
+      throw new IllegalArgumentException("valueType should not be null.");
+    }
+
+    return new MapType(keyType, valueType, valueContainsNull);
   }

   /**

sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java

Lines changed: 9 additions & 1 deletion
@@ -26,10 +26,12 @@
 public class MapType extends DataType {
   private DataType keyType;
   private DataType valueType;
+  private boolean valueContainsNull;

-  protected MapType(DataType keyType, DataType valueType) {
+  protected MapType(DataType keyType, DataType valueType, boolean valueContainsNull) {
     this.keyType = keyType;
     this.valueType = valueType;
+    this.valueContainsNull = valueContainsNull;
   }

   public DataType getKeyType() {

@@ -40,13 +42,18 @@ public DataType getValueType() {
     return valueType;
   }

+  public boolean isValueContainsNull() {
+    return valueContainsNull;
+  }
+
   @Override
   public boolean equals(Object o) {
     if (this == o) return true;
     if (o == null || getClass() != o.getClass()) return false;

     MapType mapType = (MapType) o;

+    if (valueContainsNull != mapType.valueContainsNull) return false;
     if (!keyType.equals(mapType.keyType)) return false;
     if (!valueType.equals(mapType.valueType)) return false;

@@ -57,6 +64,7 @@ public boolean equals(Object o) {
   public int hashCode() {
     int result = keyType.hashCode();
     result = 31 * result + valueType.hashCode();
+    result = 31 * result + (valueContainsNull ? 1 : 0);
     return result;
   }
 }
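Taken together with the factory methods above, the widened Java API stays backward compatible: the old factory delegates with valueContainsNull = true, and both equals and hashCode now include the flag. A short sketch in Scala (illustrative, not from the commit):

import org.apache.spark.sql.api.java.types.{DataType => JDataType}

// The two-argument factory and the explicit-true factory construct
// equal MapTypes, since equals and hashCode both cover the new flag.
val jm1 = JDataType.createMapType(JDataType.StringType, JDataType.IntegerType)
val jm2 = JDataType.createMapType(JDataType.StringType, JDataType.IntegerType, true)
assert(jm1.equals(jm2) && jm1.hashCode == jm2.hashCode)
assert(jm2.isValueContainsNull)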

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 0 additions & 94 deletions
@@ -443,98 +443,4 @@ class SQLContext(@transient val sparkContext: SparkContext)
     }
     new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd)))
   }
-
-  /**
-   * Returns the equivalent StructField in Scala for the given StructField in Java.
-   */
-  protected def asJavaStructField(scalaStructField: StructField): JStructField = {
-    org.apache.spark.sql.api.java.types.DataType.createStructField(
-      scalaStructField.name,
-      asJavaDataType(scalaStructField.dataType),
-      scalaStructField.nullable)
-  }
-
-  /**
-   * Returns the equivalent DataType in Java for the given DataType in Scala.
-   */
-  protected[sql] def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match {
-    case StringType =>
-      org.apache.spark.sql.api.java.types.DataType.StringType
-    case BinaryType =>
-      org.apache.spark.sql.api.java.types.DataType.BinaryType
-    case BooleanType =>
-      org.apache.spark.sql.api.java.types.DataType.BooleanType
-    case TimestampType =>
-      org.apache.spark.sql.api.java.types.DataType.TimestampType
-    case DecimalType =>
-      org.apache.spark.sql.api.java.types.DataType.DecimalType
-    case DoubleType =>
-      org.apache.spark.sql.api.java.types.DataType.DoubleType
-    case FloatType =>
-      org.apache.spark.sql.api.java.types.DataType.FloatType
-    case ByteType =>
-      org.apache.spark.sql.api.java.types.DataType.ByteType
-    case IntegerType =>
-      org.apache.spark.sql.api.java.types.DataType.IntegerType
-    case LongType =>
-      org.apache.spark.sql.api.java.types.DataType.LongType
-    case ShortType =>
-      org.apache.spark.sql.api.java.types.DataType.ShortType
-
-    case arrayType: ArrayType =>
-      org.apache.spark.sql.api.java.types.DataType.createArrayType(
-        asJavaDataType(arrayType.elementType), arrayType.containsNull)
-    case mapType: MapType =>
-      org.apache.spark.sql.api.java.types.DataType.createMapType(
-        asJavaDataType(mapType.keyType), asJavaDataType(mapType.valueType))
-    case structType: StructType =>
-      org.apache.spark.sql.api.java.types.DataType.createStructType(
-        structType.fields.map(asJavaStructField).asJava)
-  }
-
-  /**
-   * Returns the equivalent StructField in Scala for the given StructField in Java.
-   */
-  protected def asScalaStructField(javaStructField: JStructField): StructField = {
-    StructField(
-      javaStructField.getName,
-      asScalaDataType(javaStructField.getDataType),
-      javaStructField.isNullable)
-  }
-
-  /**
-   * Returns the equivalent DataType in Scala for the given DataType in Java.
-   */
-  protected[sql] def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match {
-    case stringType: org.apache.spark.sql.api.java.types.StringType =>
-      StringType
-    case binaryType: org.apache.spark.sql.api.java.types.BinaryType =>
-      BinaryType
-    case booleanType: org.apache.spark.sql.api.java.types.BooleanType =>
-      BooleanType
-    case timestampType: org.apache.spark.sql.api.java.types.TimestampType =>
-      TimestampType
-    case decimalType: org.apache.spark.sql.api.java.types.DecimalType =>
-      DecimalType
-    case doubleType: org.apache.spark.sql.api.java.types.DoubleType =>
-      DoubleType
-    case floatType: org.apache.spark.sql.api.java.types.FloatType =>
-      FloatType
-    case byteType: org.apache.spark.sql.api.java.types.ByteType =>
-      ByteType
-    case integerType: org.apache.spark.sql.api.java.types.IntegerType =>
-      IntegerType
-    case longType: org.apache.spark.sql.api.java.types.LongType =>
-      LongType
-    case shortType: org.apache.spark.sql.api.java.types.ShortType =>
-      ShortType
-
-    case arrayType: org.apache.spark.sql.api.java.types.ArrayType =>
-      ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull)
-    case mapType: org.apache.spark.sql.api.java.types.MapType =>
-      MapType(asScalaDataType(mapType.getKeyType), asScalaDataType(mapType.getValueType))
-    case structType: org.apache.spark.sql.api.java.types.StructType =>
-      StructType(structType.getFields.map(asScalaStructField))
-  }
-
 }

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala

Lines changed: 2 additions & 2 deletions
@@ -85,11 +85,11 @@ private[sql] object CatalystConverter {
       case StructType(fields: Seq[StructField]) => {
         new CatalystStructConverter(fields.toArray, fieldIndex, parent)
       }
-      case MapType(keyType: DataType, valueType: DataType) => {
+      case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => {
        new CatalystMapConverter(
          Array(
            new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false),
-           new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)),
+           new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)),
          fieldIndex,
          parent)
      }

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
     case t @ ArrayType(_, false) => writeArray(
       t,
       value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
-    case t @ MapType(_, _) => writeMap(
+    case t @ MapType(_, _, _) => writeMap(
       t,
       value.asInstanceOf[CatalystConverter.MapScalaType[_, _]])
     case t @ StructType(_) => writeStruct(

sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala

Lines changed: 5 additions & 1 deletion
@@ -130,6 +130,8 @@ private[parquet] object ParquetTypesConverter extends Logging {
         assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
         val valueType = toDataType(keyValueGroup.getFields.apply(1))
         assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
+        // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true
+        // at here.
         MapType(keyType, valueType)
       }
     case _ => {

@@ -140,6 +142,8 @@ private[parquet] object ParquetTypesConverter extends Logging {
         assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED)
         val valueType = toDataType(keyValueGroup.getFields.apply(1))
         assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED)
+        // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true
+        // at here.
         MapType(keyType, valueType)
       } else if (correspondsToArray(groupType)) { // ArrayType
         val elementType = toDataType(groupType.getFields.apply(0))

@@ -248,7 +252,7 @@ private[parquet] object ParquetTypesConverter extends Logging {
       }
       new ParquetGroupType(repetition, name, fields)
     }
-    case MapType(keyType, valueType) => {
+    case MapType(keyType, valueType, _) => {
       val parquetKeyType =
         fromDataType(
           keyType,
