Commit 5b0a42c

Davies Liu authored and marmbrus committed
[SPARK-5898] [SPARK-5896] [SQL] [PySpark] create DataFrame from pandas and tuple/list
Fix createDataFrame() from pandas DataFrame (not tested by jenkins, depends on SPARK-5693). It also supports creating a DataFrame from a plain tuple/list without column names; `_1`, `_2` will be used as the column names.

Author: Davies Liu <[email protected]>

Closes #4679 from davies/pandas and squashes the following commits:

c0cbe0b [Davies Liu] fix tests
8466d1d [Davies Liu] fix create DataFrame from pandas
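For reference, a minimal sketch of the new behavior, using the `sc`/`sqlCtx` names from the doctests (the last call assumes pandas is installed):

    import pandas as pd

    # Plain tuples without column names: fields default to _1, _2, ...
    sqlCtx.createDataFrame([('Alice', 1)]).collect()
    # [Row(_1=u'Alice', _2=1)]

    # A pandas DataFrame: its column names become the schema.
    pdf = pd.DataFrame([('Alice', 1)], columns=['name', 'age'])
    sqlCtx.createDataFrame(pdf).collect()
    # [Row(name=u'Alice', age=1)]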
1 parent 4a17eed commit 5b0a42c

File tree

3 files changed: +20 -20 lines

python/pyspark/sql/context.py

Lines changed: 10 additions & 2 deletions
@@ -351,6 +351,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         :return: a DataFrame
 
         >>> l = [('Alice', 1)]
+        >>> sqlCtx.createDataFrame(l).collect()
+        [Row(_1=u'Alice', _2=1)]
         >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
         [Row(name=u'Alice', age=1)]
 
@@ -359,6 +361,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         [Row(age=1, name=u'Alice')]
 
         >>> rdd = sc.parallelize(l)
+        >>> sqlCtx.createDataFrame(rdd).collect()
+        [Row(_1=u'Alice', _2=1)]
         >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
         >>> df.collect()
         [Row(name=u'Alice', age=1)]
@@ -377,14 +381,17 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         >>> df3 = sqlCtx.createDataFrame(rdd, schema)
         >>> df3.collect()
         [Row(name=u'Alice', age=1)]
+
+        >>> sqlCtx.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
+        [Row(name=u'Alice', age=1)]
         """
         if isinstance(data, DataFrame):
             raise TypeError("data is already a DataFrame")
 
         if has_pandas and isinstance(data, pandas.DataFrame):
-            data = self._sc.parallelize(data.to_records(index=False))
             if schema is None:
                 schema = list(data.columns)
+            data = [r.tolist() for r in data.to_records(index=False)]
 
         if not isinstance(data, RDD):
             try:
@@ -399,7 +406,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None):
         if isinstance(schema, (list, tuple)):
             first = data.first()
             if not isinstance(first, (list, tuple)):
-                raise ValueError("each row in `rdd` should be list or tuple")
+                raise ValueError("each row in `rdd` should be list or tuple, "
+                                 "but got %r" % type(first))
             row_cls = Row(*schema)
             schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
 
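The pandas branch above no longer parallelizes the raw numpy records; a standalone sketch of the conversion it now performs (pure pandas, no Spark needed):

    import pandas as pd

    pdf = pd.DataFrame([('Alice', 1)], columns=['name', 'age'])
    # Column names become the schema when none is given.
    schema = list(pdf.columns)  # ['name', 'age']
    # to_records(index=False) yields numpy records; tolist() unwraps each
    # record into plain Python values that schema inference can handle.
    data = [r.tolist() for r in pdf.to_records(index=False)]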

python/pyspark/sql/tests.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def test_serialize_nested_array_and_map(self):
         self.assertEqual("2", row.d)
 
     def test_infer_schema(self):
-        d = [Row(l=[], d={}),
+        d = [Row(l=[], d={}, s=None),
             Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
         rdd = self.sc.parallelize(d)
         df = self.sqlCtx.createDataFrame(rdd)
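The extra `s=None` field only works because of the types.py change below: `_infer_type(None)` now returns `NullType()` instead of raising, and the `s=""` in the second sampled row lets inference resolve the column to a string type. A quick check against the (private) helper:

    from pyspark.sql.types import NullType, _infer_type

    isinstance(_infer_type(None), NullType)  # True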

python/pyspark/sql/types.py

Lines changed: 9 additions & 17 deletions
@@ -604,7 +604,7 @@ def _infer_type(obj):
     ExamplePointUDT
     """
     if obj is None:
-        raise ValueError("Can not infer type for None")
+        return NullType()
 
     if hasattr(obj, '__UDT__'):
         return obj.__UDT__
@@ -637,15 +637,14 @@ def _infer_schema(row):
     if isinstance(row, dict):
         items = sorted(row.items())
 
-    elif isinstance(row, tuple):
+    elif isinstance(row, (tuple, list)):
         if hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
         elif hasattr(row, "__FIELDS__"):  # Row
             items = zip(row.__FIELDS__, tuple(row))
-        elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
-            items = row
         else:
-            raise ValueError("Can't infer schema from tuple")
+            names = ['_%d' % i for i in range(1, len(row) + 1)]
+            items = zip(names, row)
 
     elif hasattr(row, "__dict__"):  # object
         items = sorted(row.__dict__.items())
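The new fallback in isolation: a plain tuple or list with no field metadata gets positional names, which is where the `_1`, `_2` columns in the doctests come from.

    row = ('Alice', 1)
    names = ['_%d' % i for i in range(1, len(row) + 1)]
    list(zip(names, row))  # [('_1', 'Alice'), ('_2', 1)]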
@@ -812,17 +811,10 @@ def convert_struct(obj):
         if obj is None:
             return
 
-        if isinstance(obj, tuple):
-            if hasattr(obj, "_fields"):
-                d = dict(zip(obj._fields, obj))
-            elif hasattr(obj, "__FIELDS__"):
-                d = dict(zip(obj.__FIELDS__, obj))
-            elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
-                d = dict(obj)
-            else:
-                raise ValueError("unexpected tuple: %s" % str(obj))
+        if isinstance(obj, (tuple, list)):
+            return tuple(conv(v) for v, conv in zip(obj, converters))
 
-        elif isinstance(obj, dict):
+        if isinstance(obj, dict):
             d = obj
         elif hasattr(obj, "__dict__"):  # object
             d = obj.__dict__
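The eight-line tuple special-casing collapses into a single positional zip against the per-field converters. A toy illustration with stand-in converters (the real ones are built from the StructType's fields):

    converters = [lambda v: v.upper(), int]  # stand-ins, for illustration only
    obj = ('alice', '1')
    tuple(conv(v) for v, conv in zip(obj, converters))  # ('ALICE', 1)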
@@ -1022,7 +1014,7 @@ def _verify_type(obj, dataType):
         return
 
     _type = type(dataType)
-    assert _type in _acceptable_types, "unkown datatype: %s" % dataType
+    assert _type in _acceptable_types, "unknown datatype: %s" % dataType
 
     # subclass of them can not be deserialized in JVM
     if type(obj) not in _acceptable_types[_type]:
@@ -1040,7 +1032,7 @@ def _verify_type(obj, dataType):
 
     elif isinstance(dataType, StructType):
         if len(obj) != len(dataType.fields):
-            raise ValueError("Length of object (%d) does not match with"
+            raise ValueError("Length of object (%d) does not match with "
                              "length of fields (%d)" % (len(obj), len(dataType.fields)))
         for v, f in zip(obj, dataType.fields):
            _verify_type(v, f.dataType)
