Commit 96db384

support datetime type for SchemaRDD
1 parent a2715cc commit 96db384

File tree: 4 files changed, +50 -8 lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 2 additions & 2 deletions
@@ -550,11 +550,11 @@ private[spark] object PythonRDD extends Logging {
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler
-      // TODO: Figure out why flatMap is necessay for pyspark
       iter.flatMap { row =>
         unpickle.loads(row) match {
+          // in case of objects are pickled in batch mode
           case objs: java.util.ArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
-          // Incase the partition doesn't have a collection
+          // not in batch mode
           case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
         }
       }
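
The match above covers both ways PySpark pickles rows: in batch mode a single pickle holds a list of rows, otherwise it holds one row. Below is a minimal, Spark-free sketch of the same flatMap pattern; FlattenSketch and its sample data are made up for illustration and are not part of the commit.

import scala.collection.JavaConverters._

// Hypothetical standalone sketch: batched pickles arrive as a java.util.ArrayList
// of maps, unbatched pickles as a single map; flatMap turns both into one stream.
object FlattenSketch {
  def flatten(rows: Iterator[Any]): Iterator[Map[String, Any]] = rows.flatMap {
    case objs: java.util.ArrayList[java.util.Map[String, Any] @unchecked] =>
      objs.asScala.map(_.asScala.toMap)
    case obj: java.util.Map[String @unchecked, Any @unchecked] =>
      Seq(obj.asScala.toMap)
  }

  def main(args: Array[String]): Unit = {
    val single = new java.util.HashMap[String, Any](); single.put("a", 1)
    val batch = new java.util.ArrayList[java.util.Map[String, Any]](); batch.add(single)
    // Both inputs yield the same flattened rows.
    println(flatten(Iterator(batch, single)).toList)
  }
}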

python/pyspark/sql.py

Lines changed: 5 additions & 4 deletions
@@ -47,12 +47,13 @@ def __init__(self, sparkContext, sqlContext=None):
         ...
         ValueError:...
 
-        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
-        ... "boolean" : True}])
+        >>> from datetime import datetime
+        >>> allTypes = sc.parallelize([{"int": 1, "string": "string", "double": 1.0, "long": 1L,
+        ... "boolean": True, "time": datetime(2010, 1, 1, 1, 1, 1)}])
         >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
-        ... x.boolean))
+        ... x.boolean, x.time))
         >>> srdd.collect()[0]
-        (1, u'string', 1.0, 1, True)
+        (1, u'string', 1.0, 1, True, datetime.datetime(2010, 1, 1, 1, 1, 1))
         """
         self._sc = sparkContext
         self._jsc = self._sc._jsc
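
The doctest relies on inferSchema mapping the new datetime field to a timestamp column on the Scala side. A rough stand-in for that mapping is sketched below; it assumes, as the SQLContext change further down does, that the pickler delivers Python datetimes to the JVM as java.util.Calendar. TypeForSketch is hypothetical and the catalyst.types import path is assumed for this branch.

import org.apache.spark.sql.catalyst.types._   // import path assumed for this branch

// Hypothetical stand-in for SQLContext's private typeFor helper, covering the
// doctest's field types; the Calendar case is the one this commit adds.
object TypeForSketch {
  def typeForSketch(obj: Any): DataType = obj match {
    case _: java.lang.Integer  => IntegerType
    case _: java.lang.Long     => LongType
    case _: java.lang.Double   => DoubleType
    case _: java.lang.String   => StringType
    case _: java.lang.Boolean  => BooleanType
    case _: java.util.Calendar => TimestampType
  }
}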

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 38 additions & 2 deletions
@@ -357,16 +357,52 @@ class SQLContext(@transient val sparkContext: SparkContext)
       case c: java.util.Map[_, _] =>
         val (key, value) = c.head
         MapType(typeFor(key), typeFor(value))
+      case c: java.util.Calendar => TimestampType
       case c if c.getClass.isArray =>
         val elem = c.asInstanceOf[Array[_]].head
         ArrayType(typeFor(elem))
       case c => throw new Exception(s"Object of type $c cannot be used")
     }
-    val schema = rdd.first().map { case (fieldName, obj) =>
+    val firstRow = rdd.first()
+    val schema = firstRow.map { case (fieldName, obj) =>
       AttributeReference(fieldName, typeFor(obj), true)()
     }.toSeq
 
-    val rowRdd = rdd.mapPartitions { iter =>
+    def needTransform(obj: Any): Boolean = obj match {
+      case c: java.util.List[_] => c.exists(needTransform)
+      case c: java.util.Set[_] => c.exists(needTransform)
+      case c: java.util.Map[_, _] => c.exists {
+        case (key, value) => needTransform(key) || needTransform(value)
+      }
+      case c if c.getClass.isArray =>
+        c.asInstanceOf[Array[_]].exists(needTransform)
+      case c: java.util.Calendar => true
+      case c => false
+    }
+
+    def transform(obj: Any): Any = obj match {
+      case c: java.util.List[_] => c.map(transform)
+      case c: java.util.Set[_] => c.map(transform)
+      case c: java.util.Map[_, _] => c.map {
+        case (key, value) => (transform(key), transform(value))
+      }
+      case c if c.getClass.isArray =>
+        c.asInstanceOf[Array[_]].map(transform)
+      case c: java.util.Calendar =>
+        new java.sql.Timestamp(c.getTime().getTime())
+      case c => c
+    }
+
+    val need = firstRow.exists {case (key, value) => needTransform(value)}
+    val transformed = if (need) {
+      rdd.mapPartitions { iter =>
+        iter.map {
+          m => m.map {case (key, value) => (key, transform(value))}
+        }
+      }
+    } else rdd
+
+    val rowRdd = transformed.mapPartitions { iter =>
       iter.map { map =>
         new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row
       }
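
needTransform is probed only against the first row, so RDDs without datetime (or nested datetime) values skip the extra per-partition map entirely. Here is a minimal standalone sketch of the conversion that transform applies when it does run; TransformSketch and its sample row are illustrative only.

// Hypothetical standalone sketch of the conversion applied above: java.util.Calendar
// values become java.sql.Timestamp so they can be stored under Catalyst's
// TimestampType; everything else passes through unchanged.
object TransformSketch {
  def transform(obj: Any): Any = obj match {
    case c: java.util.Calendar => new java.sql.Timestamp(c.getTime().getTime())
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val cal = java.util.Calendar.getInstance()
    cal.set(2010, 0, 1, 1, 1, 1)               // Calendar months are 0-based: January
    val row = Map("time" -> cal, "int" -> 1)
    // Only the Calendar value is rewritten; the Int is untouched.
    println(row.map { case (k, v) => (k, transform(v)) })
  }
}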

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 5 additions & 0 deletions
@@ -395,6 +395,11 @@ class SchemaRDD(
           arr.asInstanceOf[Array[Any]].map {
             element => rowToMap(element.asInstanceOf[Row], struct)
           }
+        case t: java.sql.Timestamp => {
+          val c = java.util.Calendar.getInstance()
+          c.setTimeInMillis(t.getTime())
+          c
+        }
         case other => other
       }
       map.put(attrName, arrayValues)
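
This is the reverse mapping for rows flowing back to Python: a java.sql.Timestamp column is wrapped in a java.util.Calendar, which the pickler is then expected to turn into a Python datetime (that last step is an assumption here). A tiny hypothetical sketch of just the Timestamp-to-Calendar wrapping:

// The millisecond instant survives the Timestamp -> Calendar round trip.
object TimestampToCalendarSketch {
  def main(args: Array[String]): Unit = {
    val ts = java.sql.Timestamp.valueOf("2010-01-01 01:01:01")
    val cal = java.util.Calendar.getInstance()
    cal.setTimeInMillis(ts.getTime())
    println(cal.getTime)   // same instant as ts
  }
}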
