Skip to content

Commit a4f4dbf

Browse files
committed
add unit test for LR
1 parent 7521d1c commit a4f4dbf

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

python/pyspark/ml/classification.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,22 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
2626
HasRegParam):
2727
"""
2828
Logistic regression.
29+
30+
>>> from pyspark.sql import Row
31+
>>> from pyspark.mllib.linalg import Vectors
32+
>>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
33+
Row(label=1.0, features=Vectors.dense(1.0)), \
34+
Row(label=0.0, features=Vectors.sparse(1, [], []))]))
35+
>>> lr = LogisticRegression() \
36+
.setMaxIter(5) \
37+
.setRegParam(0.01)
38+
>>> model = lr.fit(dataset)
39+
>>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
40+
>>> print model.transform(test0).first().prediction
41+
0.0
42+
>>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
43+
>>> print model.transform(test1).first().prediction
44+
1.0
2945
"""
3046

3147
def __init__(self):
@@ -52,3 +68,21 @@ def __init__(self, java_model):
5268
@property
5369
def _java_class(self):
5470
return "org.apache.spark.ml.classification.LogisticRegressionModel"
71+
72+
73+
if __name__ == "__main__":
74+
import doctest
75+
from pyspark.context import SparkContext
76+
from pyspark.sql import SQLContext
77+
globs = globals().copy()
78+
# Run the doctests against a "local[2]" SparkContext; the examples reach it
79+
# (and the SQLContext built on it) through the `sc` / `sqlCtx` globals below:
80+
sc = SparkContext("local[2]", "ml.feature tests")
81+
sqlCtx = SQLContext(sc)
82+
globs['sc'] = sc
83+
globs['sqlCtx'] = sqlCtx
84+
(failure_count, test_count) = doctest.testmod(
85+
globs=globs, optionflags=doctest.ELLIPSIS)
86+
sc.stop()
87+
if failure_count:
88+
exit(-1)

python/pyspark/ml/feature.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _java_class(self):
7373
if __name__ == "__main__":
7474
import doctest
7575
from pyspark.context import SparkContext
76-
from pyspark.sql import Row, SQLContext
76+
from pyspark.sql import SQLContext
7777
globs = globals().copy()
7878
# The small batch size here ensures that we see multiple batches,
7979
# even in these small test examples:

0 commit comments

Comments
 (0)