Skip to content

Commit 010d076

Browse files
committed
modify NaiveBayesModel and GLM to use broadcast
1 parent cb09e93 commit 010d076

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,13 @@ class NaiveBayesModel private[mllib] (
5454
}
5555
}
5656

57-
override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
57+
override def predict(testData: RDD[Vector]): RDD[Double] = {
58+
val bcModel = testData.context.broadcast(this)
59+
testData.mapPartitions { iter =>
60+
val model = bcModel.value
61+
iter.map(model.predict)
62+
}
63+
}
5864

5965
override def predict(testData: Vector): Double = {
6066
labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
5656
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
5757
// and intercept are needed.
5858
val localWeights = weights
59+
val bcWeights = testData.context.broadcast(localWeights)
5960
val localIntercept = intercept
60-
61-
testData.map(v => predictPoint(v, localWeights, localIntercept))
61+
testData.mapPartitions { iter =>
62+
val w = bcWeights.value
63+
iter.map(v => predictPoint(v, w, localIntercept))
64+
}
6265
}
6366

6467
/**

0 commit comments

Comments
 (0)