@@ -791,6 +791,10 @@ def _restore_object(dataType, obj):
791
791
792
792
def _create_object (cls , v ):
793
793
""" Create an customized object with class `cls`. """
794
+ # datetime.date would be deserialized as datetime.datetime
795
+ # from java type, so we need to set it back.
796
+ if cls is datetime .date and isinstance (v , datetime .datetime ):
797
+ return v .date ()
794
798
return cls (v ) if v is not None else v
795
799
796
800
@@ -804,14 +808,16 @@ def getter(self):
804
808
return getter
805
809
806
810
807
- def _has_struct (dt ):
808
- """Return whether `dt` is or has StructType in it"""
811
+ def _has_struct_or_date (dt ):
812
+ """Return whether `dt` is or has StructType/DateType in it"""
809
813
if isinstance (dt , StructType ):
810
814
return True
811
815
elif isinstance (dt , ArrayType ):
812
- return _has_struct (dt .elementType )
816
+ return _has_struct_or_date (dt .elementType )
813
817
elif isinstance (dt , MapType ):
814
- return _has_struct (dt .valueType )
818
+ return _has_struct_or_date (dt .valueType )
819
+ elif isinstance (dt , DateType ):
820
+ return True
815
821
return False
816
822
817
823
@@ -824,7 +830,7 @@ def _create_properties(fields):
824
830
or keyword .iskeyword (name )):
825
831
warnings .warn ("field name %s can not be accessed in Python,"
826
832
"use position to access it instead" % name )
827
- if _has_struct (f .dataType ):
833
+ if _has_struct_or_date (f .dataType ):
828
834
# delay creating object until accessing it
829
835
getter = _create_getter (f .dataType , i )
830
836
else :
@@ -879,6 +885,9 @@ def Dict(d):
879
885
880
886
return Dict
881
887
888
+ elif isinstance (dataType , DateType ):
889
+ return datetime .date
890
+
882
891
elif not isinstance (dataType , StructType ):
883
892
raise Exception ("unexpected data type: %s" % dataType )
884
893
@@ -1098,7 +1107,7 @@ def applySchema(self, rdd, schema):
1098
1107
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
1099
1108
... x.time, x.map["a"], x.struct.b, x.list, x.null))
1100
1109
>>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
1101
- (127, -128, -32768, 32767, 2147483647, 1.0, datetime.datetime (2010, 1, 1, 0, 0 ),
1110
+ (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date (2010, 1, 1),
1102
1111
datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
1103
1112
1104
1113
>>> srdd.registerTempTable("table2")
0 commit comments