Skip to content

Commit 4a37b94

Browse files
mengxrmarkhamstra
authored andcommitted
[SPARK-7036][MLLIB] ALS.train should support DataFrames in PySpark
SchemaRDD works with ALS.train in 1.2, so we should continue support DataFrames for compatibility. coderxiang Author: Xiangrui Meng <[email protected]> Closes apache#5619 from mengxr/SPARK-7036 and squashes the following commits: dfcaf5a [Xiangrui Meng] ALS.train should support DataFrames in PySpark (cherry picked from commit 686dd74) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent c1e2a19 commit 4a37b94

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

python/pyspark/mllib/recommendation.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pyspark.rdd import RDD
2222
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
2323
from pyspark.mllib.util import JavaLoader, JavaSaveable
24+
from pyspark.sql import DataFrame
2425

2526
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
2627

@@ -77,18 +78,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
7778
True
7879
7980
>>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
80-
>>> model.predict(2,2)
81+
>>> model.predict(2, 2)
82+
3.8...
83+
84+
>>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
85+
>>> model = ALS.train(df, 1, nonnegative=True, seed=10)
86+
>>> model.predict(2, 2)
8187
3.8...
8288
8389
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
84-
>>> model.predict(2,2)
90+
>>> model.predict(2, 2)
8591
0.4...
8692
8793
>>> import os, tempfile
8894
>>> path = tempfile.mkdtemp()
8995
>>> model.save(sc, path)
9096
>>> sameModel = MatrixFactorizationModel.load(sc, path)
91-
>>> sameModel.predict(2,2)
97+
>>> sameModel.predict(2, 2)
9298
0.4...
9399
>>> sameModel.predictAll(testset).collect()
94100
[Rating(...
@@ -124,13 +130,20 @@ class ALS(object):
124130

125131
@classmethod
126132
def _prepare(cls, ratings):
127-
assert isinstance(ratings, RDD), "ratings should be RDD"
133+
if isinstance(ratings, RDD):
134+
pass
135+
elif isinstance(ratings, DataFrame):
136+
ratings = ratings.rdd
137+
else:
138+
raise TypeError("Ratings should be represented by either an RDD or a DataFrame, "
139+
"but got %s." % type(ratings))
128140
first = ratings.first()
129-
if not isinstance(first, Rating):
130-
if isinstance(first, (tuple, list)):
131-
ratings = ratings.map(lambda x: Rating(*x))
132-
else:
133-
raise ValueError("rating should be RDD of Rating or tuple/list")
141+
if isinstance(first, Rating):
142+
pass
143+
elif isinstance(first, (tuple, list)):
144+
ratings = ratings.map(lambda x: Rating(*x))
145+
else:
146+
raise TypeError("Expect a Rating or a tuple/list, but got %s." % type(first))
134147
return ratings
135148

136149
@classmethod
@@ -151,8 +164,11 @@ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alp
151164
def _test():
152165
import doctest
153166
import pyspark.mllib.recommendation
167+
from pyspark.sql import SQLContext
154168
globs = pyspark.mllib.recommendation.__dict__.copy()
155-
globs['sc'] = SparkContext('local[4]', 'PythonTest')
169+
sc = SparkContext('local[4]', 'PythonTest')
170+
globs['sc'] = sc
171+
globs['sqlContext'] = SQLContext(sc)
156172
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
157173
globs['sc'].stop()
158174
if failure_count:

0 commit comments

Comments
 (0)