Skip to content

Commit 6e248dd

Browse files
committed
PEP8 fixes, added test cases.
1 parent 14f36c3 commit 6e248dd

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1762,14 +1762,15 @@ def toPandas(self):
17621762
else:
17631763
dtype = {}
17641764
columns_with_null_int = set()
1765+
17651766
def null_handler(rows, columns_with_null_int):
17661767
for row in rows:
17671768
row = row.asDict()
17681769
for column in columns_with_null_int:
17691770
val = row[column]
17701771
dt = dtype[column]
17711772
if val is not None:
1772-
if abs(val) > 16777216: # Max value before np.float32 loses precision.
1773+
if abs(val) > 16777216: # Max value before np.float32 loses precision.
17731774
val = np.float64(val)
17741775
dt = np.float64
17751776
dtype[column] = np.float64
@@ -1778,7 +1779,7 @@ def null_handler(rows, columns_with_null_int):
17781779
row[column] = val
17791780
row = Row(**row)
17801781
yield row
1781-
row_handler = lambda x,y: x
1782+
row_handler = lambda x, y: x
17821783
for field in self.schema:
17831784
pandas_type = _to_corrected_pandas_type(field.dataType)
17841785
if pandas_type in (np.int8, np.int16, np.int32) and field.nullable:
@@ -1787,8 +1788,8 @@ def null_handler(rows, columns_with_null_int):
17871788
pandas_type = np.float32
17881789
if pandas_type is not None:
17891790
dtype[field.name] = pandas_type
1790-
1791-
pdf = pd.DataFrame.from_records(row_handler(self.collect(), columns_with_null_int), columns=self.columns)
1791+
collected_rows = row_handler(self.collect(), columns_with_null_int)
1792+
pdf = pd.DataFrame.from_records(collected_rows, columns=self.columns)
17921793

17931794
for f, t in dtype.items():
17941795
pdf[f] = pdf[f].astype(t, copy=False)

python/pyspark/sql/tests.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2495,17 +2495,22 @@ def count_bucketed_cols(names, table="pyspark_bucket"):
24952495
def test_to_pandas(self):
24962496
import numpy as np
24972497
schema = StructType().add("a", IntegerType()).add("b", StringType())\
2498-
.add("c", BooleanType()).add("d", FloatType())
2498+
.add("c", BooleanType()).add("d", FloatType())\
2499+
.add("e", IntegerType()).add("f", IntegerType())\
2500+
.add("g", IntegerType())
24992501
data = [
2500-
(1, "foo", True, 3.0), (2, "foo", True, 5.0),
2501-
(3, "bar", False, -1.0), (4, "bar", False, 6.0),
2502+
(1, "foo", True, 3.0, 1, 16777218, None), (2, "foo", True, 5.0, 2, 16777220, None),
2503+
(3, "bar", False, -1.0, 3, 1, None), (4, "bar", False, 6.0, None, None, None),
25022504
]
25032505
df = self.spark.createDataFrame(data, schema)
25042506
types = df.toPandas().dtypes
25052507
self.assertEquals(types[0], np.int32)
25062508
self.assertEquals(types[1], np.object)
25072509
self.assertEquals(types[2], np.bool)
25082510
self.assertEquals(types[3], np.float32)
2511+
self.assertEquals(types[4], np.float32)
2512+
self.assertEquals(types[5], np.float64)
2513+
self.assertEquals(types[6], np.float32)
25092514

25102515
def test_create_dataframe_from_array_of_long(self):
25112516
import array

0 commit comments

Comments (0)