Skip to content

Commit bc53ce8

Browse files
committed
fix NaiveBayes
1 parent 8e5604b commit bc53ce8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

python/pyspark/mllib/classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class NaiveBayesModel(object):
8484
- pi: vector of logs of class priors (dimension C)
8585
- theta: matrix of logs of class conditional probabilities (CxD)
8686
87-
>>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3)
87+
>>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 1.0]).reshape(3,3)
8888
>>> model = NaiveBayes.train(sc.parallelize(data))
8989
>>> model.predict(array([0.0, 1.0]))
9090
0
@@ -98,7 +98,7 @@ def __init__(self, pi, theta):
9898

9999
def predict(self, x):
100100
"""Return the most likely class for a data vector x"""
101-
return numpy.argmax(self.pi + dot(x, self.theta))
101+
return numpy.argmax(self.pi + dot(x, self.theta.transpose()))
102102

103103
class NaiveBayes(object):
104104
@classmethod

0 commit comments

Comments
 (0)