@@ -26,6 +26,22 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
                          HasRegParam):
     """
     Logistic regression.
+
+    >>> from pyspark.sql import Row
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
+        Row(label=1.0, features=Vectors.dense(1.0)), \
+        Row(label=0.0, features=Vectors.sparse(1, [], []))]))
+    >>> lr = LogisticRegression() \
+        .setMaxIter(5) \
+        .setRegParam(0.01)
+    >>> model = lr.fit(dataset)
+    >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
+    >>> print model.transform(test0).first().prediction
+    0.0
+    >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
+    >>> print model.transform(test1).first().prediction
+    1.0
     """
 
     def __init__(self):
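For orientation, the doctest above can be read as a small end-to-end script. The sketch below restates it outside doctest form, assuming this era's API, where SQLContext.inferSchema builds a DataFrame from an RDD of Rows (later releases renamed this createDataFrame). The app name and master here are illustrative, not part of the patch.

# Sketch only: the new doctest flow as a standalone Python 2 script,
# assuming this Spark version's API.
from pyspark.context import SparkContext
from pyspark.sql import SQLContext, Row
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import LogisticRegression

sc = SparkContext("local[2]", "lr-example")  # illustrative app name/master
sqlCtx = SQLContext(sc)

# One dense positive example and one sparse negative example.
dataset = sqlCtx.inferSchema(sc.parallelize([
    Row(label=1.0, features=Vectors.dense(1.0)),
    Row(label=0.0, features=Vectors.sparse(1, [], []))]))

# Setters return self, so they chain; fit() returns a fitted
# LogisticRegressionModel wrapping the JVM model.
model = LogisticRegression().setMaxIter(5).setRegParam(0.01).fit(dataset)

# transform() appends a 'prediction' column to the input DataFrame.
test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
print model.transform(test0).first().prediction  # expected 0.0

sc.stop()

The setters returning self is what lets the doctest chain .setMaxIter(5).setRegParam(0.01) across backslash continuation lines.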
@@ -52,3 +68,21 @@ def __init__(self, java_model):
     @property
     def _java_class(self):
         return "org.apache.spark.ml.classification.LogisticRegressionModel"
+
+
+if __name__ == "__main__":
+    import doctest
+    from pyspark.context import SparkContext
+    from pyspark.sql import SQLContext
+    globs = globals().copy()
+    # The small batch size here ensures that we see multiple batches,
+    # even in these small test examples:
+    sc = SparkContext("local[2]", "ml.classification tests")
+    sqlCtx = SQLContext(sc)
+    globs['sc'] = sc
+    globs['sqlCtx'] = sqlCtx
+    (failure_count, test_count) = doctest.testmod(
+        globs=globs, optionflags=doctest.ELLIPSIS)
+    sc.stop()
+    if failure_count:
+        exit(-1)
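A note on the runner: sc and sqlCtx are injected into a copy of the module's globals, so every doctest in the file can reference them without constructing its own contexts. Below is a minimal, Spark-free sketch of that globs-injection pattern; the half function and the limit name are hypothetical, introduced only to illustrate it.

import doctest

def half(x):
    """
    References a name that exists only in the injected globals.

    >>> half(limit)
    5.0
    """
    return x / 2.0

globs = globals().copy()
globs['limit'] = 10.0  # injected, like sc/sqlCtx in the patch above
(failure_count, test_count) = doctest.testmod(
    globs=globs, optionflags=doctest.ELLIPSIS)
if failure_count:
    exit(-1)

doctest.testmod() runs each >>> example against a shallow copy of globs, so the injected names are visible to every docstring in the module.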