@@ -199,17 +199,15 @@ object LBFGS extends Logging {
199
199
val n = weights.length
200
200
val bcWeights = data.context.broadcast(weights)
201
201
202
- val (gradientSum, lossSum) = data.mapPartitions { iter =>
203
- val cumGrad = Vectors .dense(new Array [Double ](n))
204
- val thisWeights = Vectors .fromBreeze(bcWeights.value)
205
- var loss = 0.0
206
- iter.foreach { case (label, features) =>
207
- loss += localGradient.compute(features, label, thisWeights, cumGrad)
208
- }
209
- Iterator ((cumGrad.toBreeze.asInstanceOf [BDV [Double ]], loss))
210
- }.reduce { case ((grad1, loss1), (grad2, loss2)) =>
211
- (grad1 += grad2, loss1 + loss2)
212
- }
202
+ val (gradientSum, lossSum) = data.aggregate((BDV .zeros[Double ](n), 0.0 ))(
203
+ seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) =>
204
+ val l = localGradient.compute(
205
+ features, label, Vectors .fromBreeze(bcWeights.value), Vectors .fromBreeze(grad))
206
+ (grad, loss + l)
207
+ },
208
+ combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) =>
209
+ (grad1 += grad2, loss1 + loss2)
210
+ })
213
211
214
212
/**
215
213
* regVal is sum of weight squares if it's L2 updater;
0 commit comments