Commit aaf2b73

mengxr authored and rxin committed
[SPARK-2361][MLLIB] Use broadcast instead of serializing data directly into task closure
We saw task serialization problems with large feature dimensions, which could be avoided if we do not serialize data directly into the task closure but instead use broadcast variables. This PR uses broadcast in both training and prediction and adds tests to make sure the task size stays small.

Author: Xiangrui Meng <[email protected]>

Closes apache#1427 from mengxr/broadcast-new and squashes the following commits:

b9a1228 [Xiangrui Meng] style update
b97c184 [Xiangrui Meng] minimal change to LBFGS
9ebadcc [Xiangrui Meng] add task size test to RowMatrix
9427bf0 [Xiangrui Meng] add task size tests to linear methods
e0a5cf2 [Xiangrui Meng] add task size test to GD
28a8411 [Xiangrui Meng] add test for NaiveBayes
380778c [Xiangrui Meng] update KMeans test
bccab92 [Xiangrui Meng] add task size test to LBFGS
02103ba [Xiangrui Meng] remove print
e73d68e [Xiangrui Meng] update tests for k-means
174cb15 [Xiangrui Meng] use local-cluster for test with a small akka.frameSize
1928a5a [Xiangrui Meng] add test for KMeans task size
e00c2da [Xiangrui Meng] use broadcast in GD, KMeans
010d076 [Xiangrui Meng] modify NaiveBayesModel and GLM to use broadcast
1 parent b547f69 commit aaf2b73
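
To make the change concrete, here is a minimal sketch of the pattern applied throughout the files below; it is not code from this diff, and the names BroadcastPredictSketch, LargeModel, and predictWithBroadcast are hypothetical. Instead of letting a large object (a weight vector, cluster centers, a model) be captured in every task closure, it is broadcast once and dereferenced inside mapPartitions, so each task only carries a small broadcast handle:

object BroadcastPredictSketch {
  import org.apache.spark.rdd.RDD
  import org.apache.spark.mllib.linalg.Vector

  // Hypothetical stand-in for any large model object (weights, cluster centers, a NaiveBayesModel, ...).
  trait LargeModel extends Serializable {
    def predict(v: Vector): Double
  }

  def predictWithBroadcast(model: LargeModel, testData: RDD[Vector]): RDD[Double] = {
    // Before: testData.map(model.predict) would serialize `model` into every task closure.
    // After: broadcast it once; executors fetch the value at most once and tasks stay small.
    val bcModel = testData.context.broadcast(model)
    testData.mapPartitions { iter =>
      val localModel = bcModel.value
      iter.map(localModel.predict)
    }
  }
}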

19 files changed, +330 -70 lines changed


mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala

Lines changed: 7 additions & 1 deletion
@@ -54,7 +54,13 @@ class NaiveBayesModel private[mllib] (
     }
   }
 
-  override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict)
+  override def predict(testData: RDD[Vector]): RDD[Double] = {
+    val bcModel = testData.context.broadcast(this)
+    testData.mapPartitions { iter =>
+      val model = bcModel.value
+      iter.map(model.predict)
+    }
+  }
 
   override def predict(testData: Vector): Double = {
     labels(brzArgmax(brzPi + brzTheta * testData.toBreeze))

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 12 additions & 7 deletions
@@ -165,18 +165,21 @@ class KMeans private (
       val activeCenters = activeRuns.map(r => centers(r)).toArray
       val costAccums = activeRuns.map(_ => sc.accumulator(0.0))
 
+      val bcActiveCenters = sc.broadcast(activeCenters)
+
       // Find the sum and count of points mapping to each center
       val totalContribs = data.mapPartitions { points =>
-        val runs = activeCenters.length
-        val k = activeCenters(0).length
-        val dims = activeCenters(0)(0).vector.length
+        val thisActiveCenters = bcActiveCenters.value
+        val runs = thisActiveCenters.length
+        val k = thisActiveCenters(0).length
+        val dims = thisActiveCenters(0)(0).vector.length
 
         val sums = Array.fill(runs, k)(BDV.zeros[Double](dims).asInstanceOf[BV[Double]])
         val counts = Array.fill(runs, k)(0L)
 
         points.foreach { point =>
           (0 until runs).foreach { i =>
-            val (bestCenter, cost) = KMeans.findClosest(activeCenters(i), point)
+            val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)
             costAccums(i) += cost
             sums(i)(bestCenter) += point.vector
             counts(i)(bestCenter) += 1

@@ -264,16 +267,17 @@
     // to their squared distance from that run's current centers
     var step = 0
     while (step < initializationSteps) {
+      val bcCenters = data.context.broadcast(centers)
      val sumCosts = data.flatMap { point =>
        (0 until runs).map { r =>
-          (r, KMeans.pointCost(centers(r), point))
+          (r, KMeans.pointCost(bcCenters.value(r), point))
        }
      }.reduceByKey(_ + _).collectAsMap()
      val chosen = data.mapPartitionsWithIndex { (index, points) =>
        val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
        points.flatMap { p =>
          (0 until runs).filter { r =>
-            rand.nextDouble() < 2.0 * KMeans.pointCost(centers(r), p) * k / sumCosts(r)
+            rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
          }.map((_, p))
        }
      }.collect()

@@ -286,9 +290,10 @@
     // Finally, we might have a set of more than k candidate centers for each run; weigh each
     // candidate by the number of points in the dataset mapping to it and run a local k-means++
     // on the weighted centers to pick just k of them
+    val bcCenters = data.context.broadcast(centers)
     val weightMap = data.flatMap { p =>
       (0 until runs).map { r =>
-        ((r, KMeans.findClosest(centers(r), p)._1), 1.0)
+        ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
       }
     }.reduceByKey(_ + _).collectAsMap()
     val finalCenters = (0 until runs).map { r =>

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala

Lines changed: 4 additions & 2 deletions
@@ -38,7 +38,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
   /** Maps given points to their cluster indices. */
   def predict(points: RDD[Vector]): RDD[Int] = {
     val centersWithNorm = clusterCentersWithNorm
-    points.map(p => KMeans.findClosest(centersWithNorm, new BreezeVectorWithNorm(p))._1)
+    val bcCentersWithNorm = points.context.broadcast(centersWithNorm)
+    points.map(p => KMeans.findClosest(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))._1)
   }
 
   /** Maps given points to their cluster indices. */

@@ -51,7 +52,8 @@ class KMeansModel private[mllib] (val clusterCenters: Array[Vector]) extends Ser
    */
   def computeCost(data: RDD[Vector]): Double = {
     val centersWithNorm = clusterCentersWithNorm
-    data.map(p => KMeans.pointCost(centersWithNorm, new BreezeVectorWithNorm(p))).sum()
+    val bcCentersWithNorm = data.context.broadcast(centersWithNorm)
+    data.map(p => KMeans.pointCost(bcCentersWithNorm.value, new BreezeVectorWithNorm(p))).sum()
   }
 
   private def clusterCentersWithNorm: Iterable[BreezeVectorWithNorm] =

mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala

Lines changed: 4 additions & 2 deletions
@@ -163,6 +163,7 @@ object GradientDescent extends Logging {
 
     // Initialize weights as a column vector
     var weights = Vectors.dense(initialWeights.toArray)
+    val n = weights.size
 
     /**
      * For the first iteration, the regVal will be initialized as sum of weight squares

@@ -172,12 +173,13 @@
       weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2
 
     for (i <- 1 to numIterations) {
+      val bcWeights = data.context.broadcast(weights)
      // Sample a subset (fraction miniBatchFraction) of the total data
      // compute and sum up the subgradients on this subset (this is one map-reduce)
      val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i)
-        .aggregate((BDV.zeros[Double](weights.size), 0.0))(
+        .aggregate((BDV.zeros[Double](n), 0.0))(
          seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
-            val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad))
+            val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad))
            (grad, loss + l)
          },
          combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 4 additions & 3 deletions
@@ -195,13 +195,14 @@ object LBFGS extends Logging {
 
     override def calculate(weights: BDV[Double]) = {
       // Have a local copy to avoid the serialization of CostFun object which is not serializable.
-      val localData = data
       val localGradient = gradient
+      val n = weights.length
+      val bcWeights = data.context.broadcast(weights)
 
-      val (gradientSum, lossSum) = localData.aggregate((BDV.zeros[Double](weights.size), 0.0))(
+      val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
        seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
          val l = localGradient.compute(
-            features, label, Vectors.fromBreeze(weights), Vectors.fromBreeze(grad))
+            features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
          (grad, loss + l)
        },
        combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>

mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala

Lines changed: 5 additions & 2 deletions
@@ -56,9 +56,12 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
     val localWeights = weights
+    val bcWeights = testData.context.broadcast(localWeights)
     val localIntercept = intercept
-
-    testData.map(v => predictPoint(v, localWeights, localIntercept))
+    testData.mapPartitions { iter =>
+      val w = bcWeights.value
+      iter.map(v => predictPoint(v, w, localIntercept))
+    }
   }
 
   /**

mllib/src/test/java/org/apache/spark/mllib/classification/JavaLogisticRegressionSuite.java

Lines changed: 0 additions & 2 deletions
@@ -92,8 +92,6 @@ public void runLRUsingStaticMethods() {
         testRDD.rdd(), 100, 1.0, 1.0);
 
     int numAccurate = validatePrediction(validationData, model);
-    System.out.println(numAccurate);
     Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
   }
-
 }

mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala

Lines changed: 17 additions & 1 deletion
@@ -25,7 +25,7 @@ import org.scalatest.Matchers
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 object LogisticRegressionSuite {
 
@@ -126,3 +126,19 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
 }
+
+class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = LogisticRegressionWithSGD.train(points, 2)
+    val predictions = model.predict(points.map(_.features))
+  }
+}

mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala

Lines changed: 19 additions & 1 deletion
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 object NaiveBayesSuite {
 
@@ -96,3 +96,21 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
 }
+
+class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 10
+    val n = 200000
+    val examples = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map { i =>
+        LabeledPoint(random.nextInt(2), Vectors.dense(Array.fill(n)(random.nextDouble())))
+      }
+    }
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = NaiveBayes.train(examples)
+    val predictions = model.predict(examples.map(_.features))
+  }
+}

mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala

Lines changed: 20 additions & 5 deletions
@@ -17,17 +17,16 @@
 
 package org.apache.spark.mllib.classification
 
-import scala.util.Random
 import scala.collection.JavaConversions._
-
-import org.scalatest.FunSuite
+import scala.util.Random
 
 import org.jblas.DoubleMatrix
+import org.scalatest.FunSuite
 
 import org.apache.spark.SparkException
-import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.LocalSparkContext
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression._
+import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
 
 object SVMSuite {
 
@@ -193,3 +192,19 @@ class SVMSuite extends FunSuite with LocalSparkContext {
     new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
   }
 }
+
+class SVMClusterSuite extends FunSuite with LocalClusterSparkContext {
+
+  test("task size should be small in both training and prediction") {
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model = SVMWithSGD.train(points, 2)
+    val predictions = model.predict(points.map(_.features))
+  }
+}
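
The new *ClusterSuite tests above all mix in LocalClusterSparkContext, a test trait added by this PR whose source is not part of the hunks shown here. Going only by the commit message ("use local-cluster for test with a small akka.frameSize"), a rough sketch of what such a trait might look like follows; the master string, memory settings, and frame-size value are assumptions, not the committed code:

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: runs tests against a local cluster whose RPC frame size is deliberately tiny,
// so any task closure larger than roughly 1MB fails loudly instead of passing unnoticed.
trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _

  override def beforeAll() {
    val conf = new SparkConf()
      .setMaster("local-cluster[2, 1, 512]") // assumed: 2 workers, 1 core and 512 MB each
      .setAppName("test-cluster")
      .set("spark.akka.frameSize", "1")      // assumed: 1 MB max frame size
    sc = new SparkContext(conf)
    super.beforeAll()
  }

  override def afterAll() {
    if (sc != null) {
      sc.stop()
    }
    super.afterAll()
  }
}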
