Commit c712fbf
committed

Converts types of values based on defined schema.

1 parent 4ceeb66
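
At a high level, the change moves type coercion of values arriving from PySpark onto the Scala side: each value is converted to the JVM type that the user-defined schema declares (Python ints into Byte/Short, Python floats into Float, Calendar into Timestamp, and lists/maps/structs recursively). Below is a minimal standalone sketch of that pattern-matching idea, using simplified stand-in types rather than Spark's actual DataType hierarchy:

// Simplified stand-ins for Spark's DataType hierarchy; the names are illustrative only.
sealed trait SimpleType
case object SByteType extends SimpleType
case object SShortType extends SimpleType
case object SFloatType extends SimpleType
case object STimestampType extends SimpleType

object SchemaConvertSketch {
  // Coerce a loosely typed incoming value into the JVM type implied by the schema entry.
  def convert(obj: Any, dataType: SimpleType): Any = (obj, dataType) match {
    case (null, _)                                => null
    case (i: Int, SByteType)                      => i.toByte
    case (i: Int, SShortType)                     => i.toShort
    case (d: Double, SFloatType)                  => d.toFloat
    case (c: java.util.Calendar, STimestampType)  => new java.sql.Timestamp(c.getTime.getTime)
    case (other, _)                               => other
  }

  def main(args: Array[String]): Unit = {
    println(convert(127, SByteType))   // 127, now held as a Byte
    println(convert(1.0, SFloatType))  // 1.0, now held as a Float
  }
}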

File tree

5 files changed: +134 -96 lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 2 additions & 1 deletion

@@ -544,7 +544,8 @@ private[spark] object PythonRDD extends Logging {
   }

   /**
-   * Convert an RDD of serialized Python dictionaries to Scala Maps
+   * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
+   * It is only used by pyspark.sql.
    * TODO: Support more Python types.
    */
  def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {

python/pyspark/sql.py

Lines changed: 23 additions & 35 deletions

@@ -111,12 +111,7 @@ def __repr__(self):
 class FloatType(object):
     """Spark SQL FloatType

-    For now, please use L{DoubleType} instead of using L{FloatType}.
-    Because query evaluation is done in Scala, java.lang.Double will be be used
-    for Python float numbers. Because the underlying JVM type of FloatType is
-    java.lang.Float (in Java) and Float (in scala), and we are trying to cast the type,
-    there will be a java.lang.ClassCastException
-    if FloatType (Python) is used.
+    The data type representing single precision floating-point values.

     """
     __metaclass__ = PrimitiveTypeSingleton
@@ -128,12 +123,7 @@ def __repr__(self):
 class ByteType(object):
     """Spark SQL ByteType

-    For now, please use L{IntegerType} instead of using L{ByteType}.
-    Because query evaluation is done in Scala, java.lang.Integer will be be used
-    for Python int numbers. Because the underlying JVM type of ByteType is
-    java.lang.Byte (in Java) and Byte (in scala), and we are trying to cast the type,
-    there will be a java.lang.ClassCastException
-    if ByteType (Python) is used.
+    The data type representing int values with 1 singed byte.

     """
     __metaclass__ = PrimitiveTypeSingleton
@@ -170,12 +160,7 @@ def __repr__(self):
 class ShortType(object):
     """Spark SQL ShortType

-    For now, please use L{IntegerType} instead of using L{ShortType}.
-    Because query evaluation is done in Scala, java.lang.Integer will be be used
-    for Python int numbers. Because the underlying JVM type of ShortType is
-    java.lang.Short (in Java) and Short (in scala), and we are trying to cast the type,
-    there will be a java.lang.ClassCastException
-    if ShortType (Python) is used.
+    The data type representing int values with 2 signed bytes.

     """
     __metaclass__ = PrimitiveTypeSingleton
@@ -198,7 +183,6 @@ def __init__(self, elementType, containsNull=False):

         :param elementType: the data type of elements.
         :param containsNull: indicates whether the list contains None values.
-        :return:

         >>> ArrayType(StringType) == ArrayType(StringType, False)
         True
@@ -238,7 +222,6 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
         :param keyType: the data type of keys.
         :param valueType: the data type of values.
         :param valueContainsNull: indicates whether values contains null values.
-        :return:

         >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True)
         True
@@ -279,7 +262,6 @@ def __init__(self, name, dataType, nullable):
         :param name: the name of this field.
         :param dataType: the data type of this field.
         :param nullable: indicates whether values of this field can be null.
-        :return:

         >>> StructField("f1", StringType, True) == StructField("f1", StringType, True)
         True
@@ -314,8 +296,6 @@ class StructType(object):
     """
     def __init__(self, fields):
         """Creates a StructType
-        :param fields:
-        :return:

         >>> struct1 = StructType([StructField("f1", StringType, True)])
         >>> struct2 = StructType([StructField("f1", StringType, True)])
@@ -342,11 +322,7 @@ def __ne__(self, other):


 def _parse_datatype_list(datatype_list_string):
-    """Parses a list of comma separated data types.
-
-    :param datatype_list_string:
-    :return:
-    """
+    """Parses a list of comma separated data types."""
     index = 0
     datatype_list = []
     start = 0
@@ -372,9 +348,6 @@ def _parse_datatype_list(datatype_list_string):
 def _parse_datatype_string(datatype_string):
     """Parses the given data type string.

-    :param datatype_string:
-    :return:
-
     >>> def check_datatype(datatype):
     ...     scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__())
     ...     python_datatype = _parse_datatype_string(scala_datatype.toString())
@@ -582,9 +555,6 @@ def inferSchema(self, rdd):

     def applySchema(self, rdd, schema):
         """Applies the given schema to the given RDD of L{dict}s.
-        :param rdd:
-        :param schema:
-        :return:

         >>> schema = StructType([StructField("field1", IntegerType(), False),
         ...                      StructField("field2", StringType(), False)])
@@ -594,9 +564,27 @@ def applySchema(self, rdd, schema):
         >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
         ...                     {"field1" : 3, "field2": "row3"}]
         True
+        >>> from datetime import datetime
+        >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0,
+        ...     "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2},
+        ...     "list": [1, 2, 3]}])
+        >>> schema = StructType([
+        ...     StructField("byte", ByteType(), False),
+        ...     StructField("short", ShortType(), False),
+        ...     StructField("float", FloatType(), False),
+        ...     StructField("time", TimestampType(), False),
+        ...     StructField("map", MapType(StringType(), IntegerType(), False), False),
+        ...     StructField("struct", StructType([StructField("b", ShortType(), False)]), False),
+        ...     StructField("list", ArrayType(ByteType(), False), False),
+        ...     StructField("null", DoubleType(), True)])
+        >>> srdd = sqlCtx.applySchema(rdd, schema).map(
+        ...     lambda x: (
+        ...         x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null))
+        >>> srdd.collect()[0]
+        (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
         """
         jrdd = self._pythonToJavaMap(rdd._jrdd)
-        srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema.__repr__())
+        srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__())
         return SchemaRDD(srdd, self)

     def registerRDDAsTable(self, rdd, tableName):
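
The rewritten docstrings describe ByteType and ShortType by their JVM width (1 and 2 signed bytes), which matters because Python's int is arbitrary precision. A small plain-Scala sketch (independent of Spark) of the narrowing that happens on the JVM side, and of why the doctest above sticks to values such as 127 and -32768 that actually fit:

object NarrowingSketch {
  def main(args: Array[String]): Unit = {
    println(127.toByte)       // 127: the largest value a signed byte can hold
    println(128.toByte)       // -128: out of range, so the value wraps
    println((-32768).toShort) // -32768: the smallest Short, as used in the doctest
    println(1.0.toFloat)      // 1.0: a Double narrowed to single precision Float
  }
}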

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 3 additions & 2 deletions

@@ -313,13 +313,13 @@ case class StructType(fields: Seq[StructField]) extends DataType {
    */
   lazy val fieldNames: Seq[String] = fields.map(_.name)
   private lazy val fieldNamesSet: Set[String] = fieldNames.toSet
-
+  private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap
   /**
    * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not
    * have a name matching the given name, `null` will be returned.
    */
   def apply(name: String): StructField = {
-    fields.find(f => f.name == name).getOrElse(
+    nameToField.get(name).getOrElse(
       throw new IllegalArgumentException(s"Field ${name} does not exist."))
   }

@@ -333,6 +333,7 @@ case class StructType(fields: Seq[StructField]) extends DataType {
       throw new IllegalArgumentException(
         s"Field ${nonExistFields.mkString(",")} does not exist.")
     }
+    // Preserve the original order of fields.
     StructType(fields.filter(f => names.contains(f.name)))
   }
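
The StructType change above precomputes a name-to-field map once, so apply(name) becomes a hash lookup instead of a linear scan over fields on every call (applySchemaToPythonRDD looks fields up by name for every record). A minimal standalone sketch of the same pattern; Field and SimpleSchema here are simplified stand-ins, not Spark's classes:

case class Field(name: String, dataType: String)

class SimpleSchema(fields: Seq[Field]) {
  // Built lazily, once, so repeated lookups by name avoid scanning the whole field list.
  private lazy val nameToField: Map[String, Field] = fields.map(f => f.name -> f).toMap

  def apply(name: String): Field =
    nameToField.getOrElse(name, throw new IllegalArgumentException(s"Field ${name} does not exist."))
}

object SchemaLookupSketch {
  def main(args: Array[String]): Unit = {
    val schema = new SimpleSchema(Seq(Field("byte", "ByteType"), Field("time", "TimestampType")))
    println(schema("time")) // Field(time,TimestampType)
  }
}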

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 103 additions & 57 deletions

@@ -125,29 +125,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
     new SchemaRDD(this, logicalPlan)
   }

-  /**
-   * Parses the data type in our internal string representation. The data type string should
-   * have the same format as the one generated by `toString` in scala.
-   * It is only used by PySpark.
-   */
-  private[sql] def parseDataType(dataTypeString: String): DataType = {
-    val parser = org.apache.spark.sql.catalyst.types.DataType
-    parser(dataTypeString)
-  }
-
-  /**
-   * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark.
-   */
-  private[sql] def applySchema(rdd: RDD[Map[String, _]], schemaString: String): SchemaRDD = {
-    val schema = parseDataType(schemaString).asInstanceOf[StructType]
-    val rowRdd = rdd.mapPartitions { iter =>
-      iter.map { map =>
-        new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
-      }
-    }
-    new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd)))
-  }
-
   /**
    * Loads a Parquet file, returning the result as a [[SchemaRDD]].
    *
@@ -438,6 +415,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
     import scala.collection.JavaConversions._
+
     def typeOfComplexValue: PartialFunction[Any, DataType] = {
       case c: java.util.Calendar => TimestampType
       case c: java.util.List[_] =>
@@ -453,48 +431,116 @@ class SQLContext(@transient val sparkContext: SparkContext)
     def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue

     val firstRow = rdd.first()
-    val schema = StructType(
-      firstRow.map { case (fieldName, obj) =>
-        StructField(fieldName, typeOfObject(obj), true)
-      }.toSeq)
-
-    def needTransform(obj: Any): Boolean = obj match {
-      case c: java.util.List[_] => true
-      case c: java.util.Map[_, _] => true
-      case c if c.getClass.isArray => true
-      case c: java.util.Calendar => true
-      case c => false
+    val fields = firstRow.map {
+      case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true)
+    }.toSeq
+
+    applySchemaToPythonRDD(rdd, StructType(fields))
+  }
+
+  /**
+   * Parses the data type in our internal string representation. The data type string should
+   * have the same format as the one generated by `toString` in scala.
+   * It is only used by PySpark.
+   */
+  private[sql] def parseDataType(dataTypeString: String): DataType = {
+    val parser = org.apache.spark.sql.catalyst.types.DataType
+    parser(dataTypeString)
+  }
+
+  /**
+   * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark.
+   */
+  private[sql] def applySchemaToPythonRDD(
+      rdd: RDD[Map[String, _]],
+      schemaString: String): SchemaRDD = {
+    val schema = parseDataType(schemaString).asInstanceOf[StructType]
+    applySchemaToPythonRDD(rdd, schema)
+  }
+
+  /**
+   * Apply a schema defined by the schema to an RDD. It is only used by PySpark.
+   */
+  private[sql] def applySchemaToPythonRDD(
+      rdd: RDD[Map[String, _]],
+      schema: StructType): SchemaRDD = {
+    import scala.collection.JavaConversions._
+    import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
+
+    def needsConversion(dataType: DataType): Boolean = dataType match {
+      case ByteType => true
+      case ShortType => true
+      case FloatType => true
+      case TimestampType => true
+      case ArrayType(_, _) => true
+      case MapType(_, _, _) => true
+      case StructType(_) => true
+      case other => false
     }

-    // convert JList, JArray into Seq, convert JMap into Map
-    // convert Calendar into Timestamp
-    def transform(obj: Any): Any = obj match {
-      case c: java.util.List[_] => c.map(transform).toSeq
-      case c: java.util.Map[_, _] => c.map {
-        case (key, value) => (key, transform(value))
-      }.toMap
-      case c if c.getClass.isArray =>
-        c.asInstanceOf[Array[_]].map(transform).toSeq
-      case c: java.util.Calendar =>
-        new java.sql.Timestamp(c.getTime().getTime())
-      case c => c
+    // Converts value to the type specified by the data type.
+    // Because Python does not have data types for TimestampType, FloatType, ShortType, and
+    // ByteType, we need to explicitly convert values in columns of these data types to the desired
+    // JVM data types.
+    def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match {
+      // TODO: We should check nullable
+      case (null, _) => null
+
+      case (c: java.util.List[_], ArrayType(elementType, _)) =>
+        val converted = c.map { e => convert(e, elementType)}
+        JListWrapper(converted)
+
+      case (c: java.util.Map[_, _], struct: StructType) =>
+        val row = new GenericMutableRow(struct.fields.length)
+        struct.fields.zipWithIndex.foreach {
+          case (field, i) =>
+            val value = convert(c.get(field.name), field.dataType)
+            row.update(i, value)
+        }
+        row
+
+      case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+        val converted = c.map {
+          case (key, value) =>
+            (convert(key, keyType), convert(value, valueType))
+        }
+        JMapWrapper(converted)
+
+      case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
+        val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType))
+        converted: Seq[Any]
+
+      case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime())
+      case (c: Int, ByteType) => c.toByte
+      case (c: Int, ShortType) => c.toShort
+      case (c: Double, FloatType) => c.toFloat
+
+      case (c, _) => c
+    }
+
+    val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
+      rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) })
+    } else {
+      rdd
     }

-    val need = firstRow.exists { case (key, value) => needTransform(value) }
-    val transformed = if (need) {
-      rdd.mapPartitions { iter =>
-        iter.map {
-          m => m.map {case (key, value) => (key, transform(value))}
+    val rowRdd = convertedRdd.mapPartitions { iter =>
+      val row = new GenericMutableRow(schema.fields.length)
+      val fieldsWithIndex = schema.fields.zipWithIndex
+      iter.map { m =>
+        // We cannot use m.values because the order of values returned by m.values may not
+        // match fields order.
+        fieldsWithIndex.foreach {
+          case (field, i) =>
+            val value =
+              m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull
+            row.update(i, value)
         }
-      }
-    } else rdd

-    val rowRdd = transformed.mapPartitions { iter =>
-      iter.map { map =>
-        new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
+        row: Row
       }
     }
+
     new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd)))
   }
-
 }
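
Two details in the new applySchemaToPythonRDD are worth noting: values are converted field by field according to the schema, and the output row is filled by looking each schema field up by name, because the iteration order of the incoming map need not match the schema's field order (hence the comment about not using m.values). A standalone sketch of that row-assembly step, with a plain Array[Any] standing in for Spark's GenericMutableRow:

object RowAssemblySketch {
  def main(args: Array[String]): Unit = {
    val fieldNames = Seq("byte", "short", "float")                                  // schema order
    val record: Map[String, Any] = Map("float" -> 1.0, "byte" -> 127, "short" -> -32768) // arbitrary key order

    // Fill each position by field name so the row follows the schema, not the map's ordering.
    val row = new Array[Any](fieldNames.length)
    fieldNames.zipWithIndex.foreach { case (name, i) =>
      row(i) = record.get(name).orNull
    }
    println(row.mkString(", ")) // 127, -32768, 1.0 -- in schema order
  }
}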

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 3 additions & 1 deletion

@@ -17,7 +17,7 @@

 package org.apache.spark.sql

-import java.util.{Map => JMap, List => JList, Set => JSet}
+import java.util.{Map => JMap, List => JList}

 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
@@ -380,6 +380,8 @@
    * Converts a JavaRDD to a PythonRDD. It is used by pyspark.
    */
   private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+    import scala.collection.Map
+
     def toJava(obj: Any, dataType: DataType): Any = dataType match {
       case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct)
       case array: ArrayType => obj match {