Skip to content

Commit c51a24d

Browse files
committed
convert datetime to date
1 parent 5670626 commit c51a24d

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

python/pyspark/sql.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,10 @@ def _restore_object(dataType, obj):
791791

792792
def _create_object(cls, v):
793793
""" Create an customized object with class `cls`. """
794+
# datetime.date would be deserialized as datetime.datetime
795+
# from java type, so we need to set it back.
796+
if cls is datetime.date and isinstance(v, datetime.datetime):
797+
return v.date()
794798
return cls(v) if v is not None else v
795799

796800

@@ -804,14 +808,16 @@ def getter(self):
804808
return getter
805809

806810

807-
def _has_struct(dt):
808-
"""Return whether `dt` is or has StructType in it"""
811+
def _has_struct_or_date(dt):
812+
"""Return whether `dt` is or has StructType/DateType in it"""
809813
if isinstance(dt, StructType):
810814
return True
811815
elif isinstance(dt, ArrayType):
812-
return _has_struct(dt.elementType)
816+
return _has_struct_or_date(dt.elementType)
813817
elif isinstance(dt, MapType):
814-
return _has_struct(dt.valueType)
818+
return _has_struct_or_date(dt.valueType)
819+
elif isinstance(dt, DateType):
820+
return True
815821
return False
816822

817823

@@ -824,7 +830,7 @@ def _create_properties(fields):
824830
or keyword.iskeyword(name)):
825831
warnings.warn("field name %s can not be accessed in Python,"
826832
"use position to access it instead" % name)
827-
if _has_struct(f.dataType):
833+
if _has_struct_or_date(f.dataType):
828834
# delay creating object until accessing it
829835
getter = _create_getter(f.dataType, i)
830836
else:
@@ -879,6 +885,9 @@ def Dict(d):
879885

880886
return Dict
881887

888+
elif isinstance(dataType, DateType):
889+
return datetime.date
890+
882891
elif not isinstance(dataType, StructType):
883892
raise Exception("unexpected data type: %s" % dataType)
884893

@@ -1098,7 +1107,7 @@ def applySchema(self, rdd, schema):
10981107
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
10991108
... x.time, x.map["a"], x.struct.b, x.list, x.null))
11001109
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
1101-
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.datetime(2010, 1, 1, 0, 0),
1110+
(127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
11021111
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
11031112
11041113
>>> srdd.registerTempTable("table2")

0 commit comments

Comments
 (0)