Skip to content

Commit 81172aa

Browse files
author
Davies Liu
committed
fix predict
1 parent 84324fb commit 81172aa

File tree

2 files changed

+26
-12
lines changed

2 files changed

+26
-12
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.mllib.tree.model
1919

20+
import org.apache.spark.api.java.JavaRDD
2021
import org.apache.spark.annotation.Experimental
2122
import org.apache.spark.mllib.tree.configuration.Algo._
2223
import org.apache.spark.rdd.RDD
@@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
5253
features.map(x => predict(x))
5354
}
5455

56+
57+
/**
58+
* Predict values for the given data set using the trained model.
59+
*
60+
* @param features JavaRDD representing data points to be predicted
61+
* @return JavaRDD of predictions for each of the given data points
62+
*/
63+
def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
64+
predict(features.rdd)
65+
}
66+
5567
/**
5668
* Get number of nodes in tree, including leaf nodes.
5769
*/

python/pyspark/mllib/tree.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,13 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
124124
Predict: 0.0
125125
Else (feature 0 > 0.0)
126126
Predict: 1.0
127-
>>> model.predict(array([1.0])) > 0
128-
True
129-
>>> model.predict(array([0.0])) == 0
130-
True
127+
>>> model.predict(array([1.0]))
128+
1.0
129+
>>> model.predict(array([0.0]))
130+
0.0
131+
>>> rdd = sc.parallelize([[1.0], [0.0]])
132+
>>> model.predict(rdd).collect()
133+
[1.0, 0.0]
131134
"""
132135
return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
133136
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
@@ -170,14 +173,13 @@ def trainRegressor(data, categoricalFeaturesInfo,
170173
... ]
171174
>>>
172175
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
173-
>>> model.predict(array([0.0, 1.0])) == 1
174-
True
175-
>>> model.predict(array([0.0, 0.0])) == 0
176-
True
177-
>>> model.predict(SparseVector(2, {1: 1.0})) == 1
178-
True
179-
>>> model.predict(SparseVector(2, {1: 0.0})) == 0
180-
True
176+
>>> model.predict(SparseVector(2, {1: 1.0}))
177+
1.0
178+
>>> model.predict(SparseVector(2, {1: 0.0}))
179+
0.0
180+
>>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
181+
>>> model.predict(rdd).collect()
182+
[1.0, 0.0]
181183
"""
182184
return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
183185
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)

0 commit comments

Comments
 (0)