Skip to content

Commit b97c184

Browse files
committed
minimal change to LBFGS
1 parent 9ebadcc commit b97c184

File tree

1 file changed

+9
-11
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/optimization

1 file changed

+9
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -199,17 +199,15 @@ object LBFGS extends Logging {
199199
val n = weights.length
200200
val bcWeights = data.context.broadcast(weights)
201201

202-
val (gradientSum, lossSum) = data.mapPartitions { iter =>
203-
val cumGrad = Vectors.dense(new Array[Double](n))
204-
val thisWeights = Vectors.fromBreeze(bcWeights.value)
205-
var loss = 0.0
206-
iter.foreach { case (label, features) =>
207-
loss += localGradient.compute(features, label, thisWeights, cumGrad)
208-
}
209-
Iterator((cumGrad.toBreeze.asInstanceOf[BDV[Double]], loss))
210-
}.reduce { case ((grad1, loss1), (grad2, loss2)) =>
211-
(grad1 += grad2, loss1 + loss2)
212-
}
202+
val (gradientSum, lossSum) = data.aggregate((BDV.zeros[Double](n), 0.0))(
203+
seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
204+
val l = localGradient.compute(
205+
features, label, Vectors.fromBreeze(bcWeights.value), Vectors.fromBreeze(grad))
206+
(grad, loss + l)
207+
},
208+
combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
209+
(grad1 += grad2, loss1 + loss2)
210+
})
213211

214212
/**
215213
* regVal is sum of weight squares if it's L2 updater;

0 commit comments

Comments
 (0)