Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 97fedf1

Browse files
committed
[SPARK-7432] [MLLIB] fix flaky CrossValidator doctest
The new test uses CV to compare `maxIter=0` and `maxIter=1`, and validates the evaluation result. jkbradley Author: Xiangrui Meng <[email protected]> Closes apache#6572 from mengxr/SPARK-7432 and squashes the following commits: c236bb8 [Xiangrui Meng] fix flacky cv doctest (cherry picked from commit bd97840) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 92a6778 commit 97fedf1

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

python/pyspark/ml/tuning.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -91,20 +91,19 @@ class CrossValidator(Estimator):
  >>> from pyspark.ml.evaluation import BinaryClassificationEvaluator
  >>> from pyspark.mllib.linalg import Vectors
  >>> dataset = sqlContext.createDataFrame(
- ... [(Vectors.dense([0.0, 1.0]), 0.0),
- ... (Vectors.dense([1.0, 2.0]), 1.0),
- ... (Vectors.dense([0.55, 3.0]), 0.0),
- ... (Vectors.dense([0.45, 4.0]), 1.0),
- ... (Vectors.dense([0.51, 5.0]), 1.0)] * 10,
+ ... [(Vectors.dense([0.0]), 0.0),
+ ... (Vectors.dense([0.4]), 1.0),
+ ... (Vectors.dense([0.5]), 0.0),
+ ... (Vectors.dense([0.6]), 1.0),
+ ... (Vectors.dense([1.0]), 1.0)] * 10,
  ... ["features", "label"])
  >>> lr = LogisticRegression()
- >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1, 5]).build()
+ >>> grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
  >>> evaluator = BinaryClassificationEvaluator()
  >>> cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
- >>> # SPARK-7432: The following test is flaky.
- >>> # cvModel = cv.fit(dataset)
- >>> # expected = lr.fit(dataset, {lr.maxIter: 5}).transform(dataset)
- >>> # cvModel.transform(dataset).collect() == expected.collect()
+ >>> cvModel = cv.fit(dataset)
+ >>> evaluator.evaluate(cvModel.transform(dataset))
+ 0.8333...
  """
 
  # a placeholder to make it appear in the generated doc

0 commit comments

Comments
 (0)