Skip to content

Commit 200bef0

Browse files
committed
optimize computeYtY and updateBlock
1 parent 16788a6 commit 200bef0

File tree

1 file changed

+45
-35
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/recommendation

1 file changed

+45
-35
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala

Lines changed: 45 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -210,21 +210,47 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
210210
*/
211211
def computeYtY(factors: RDD[(Int, Array[Array[Double]])]) = {
212212
if (implicitPrefs) {
213-
Option(
214-
factors.flatMapValues { case factorArray =>
215-
factorArray.view.map { vector =>
216-
val x = new DoubleMatrix(vector)
217-
x.mmul(x.transpose())
218-
}
219-
}.reduceByKeyLocally((a, b) => a.addi(b))
220-
.values
221-
.reduce((a, b) => a.addi(b))
222-
)
213+
val n = rank * (rank + 1) / 2
214+
val LYtY = factors.values.aggregate(new DoubleMatrix(n))( seqOp = (L, Y) => {
215+
Y.foreach(y => dspr(1.0, new DoubleMatrix(y), L))
216+
L
217+
}, combOp = (L1, L2) => {
218+
L1.addi(L2)
219+
})
220+
val YtY = new DoubleMatrix(rank, rank)
221+
fillFullMatrix(LYtY, YtY)
222+
Option(YtY)
223223
} else {
224224
None
225225
}
226226
}
227227

228+
/**
229+
* Adding x * x.t to a matrix, the same as BLAS's DSPR.
230+
*
231+
* @param x a vector of length n
232+
* @param L the lower triangular part of the matrix packed in an array (row major)
233+
*/
234+
private def dspr(alpha: Double, x: DoubleMatrix, L: DoubleMatrix) = {
235+
val n = x.length
236+
var i = 0
237+
var j = 0
238+
var idx = 0
239+
var axi = 0.0
240+
val xd = x.data
241+
val Ld = L.data
242+
while (i < n) {
243+
axi = alpha * xd(i)
244+
j = 0
245+
while (j <= i) {
246+
Ld(idx) += axi * xd(j)
247+
j += 1
248+
idx += 1
249+
}
250+
i += 1
251+
}
252+
}
253+
228254
/**
229255
* Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
230256
*/
@@ -376,18 +402,21 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
376402
for (productBlock <- 0 until numBlocks) {
377403
for (p <- 0 until blockFactors(productBlock).length) {
378404
val x = new DoubleMatrix(blockFactors(productBlock)(p))
379-
fillXtX(x, tempXtX)
405+
tempXtX.fill(0.0)
406+
dspr(1.0, x, tempXtX)
380407
val (us, rs) = inLinkBlock.ratingsForBlock(productBlock)(p)
381408
for (i <- 0 until us.length) {
382409
implicitPrefs match {
383410
case false =>
384411
userXtX(us(i)).addi(tempXtX)
412+
// dspr(1.0, x, userXtX(us(i)))
385413
SimpleBlas.axpy(rs(i), x, userXy(us(i)))
386414
case true =>
387415
// Extension to the original paper to handle rs(i) < 0. confidence is a function
388416
// of |rs(i)| instead so that it is never negative:
389417
val confidence = 1 + alpha * abs(rs(i))
390-
userXtX(us(i)).addi(tempXtX.mul(confidence - 1))
418+
SimpleBlas.axpy(confidence - 1.0, tempXtX, userXtX(us(i)))
419+
// dspr(confidence - 1.0, x, userXtX(us(i)))
391420
// For rs(i) < 0, the corresponding entry in P is 0 now, not 1 -- negative rs(i)
392421
// means we try to reconstruct 0. We add terms only where P = 1, so, term below
393422
// is now only added for rs(i) > 0:
@@ -400,38 +429,19 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
400429
}
401430

402431
// Solve the least-squares problem for each user and return the new feature vectors
403-
userXtX.zipWithIndex.map{ case (triangularXtX, index) =>
432+
userXtX.zip(userXy).map { case (triangularXtX, rhs) =>
404433
// Compute the full XtX matrix from the lower-triangular part we got above
405434
fillFullMatrix(triangularXtX, fullXtX)
406435
// Add regularization
407436
(0 until rank).foreach(i => fullXtX.data(i*rank + i) += lambda)
408437
// Solve the resulting matrix, which is symmetric and positive-definite
409438
implicitPrefs match {
410-
case false => Solve.solvePositive(fullXtX, userXy(index)).data
411-
case true => Solve.solvePositive(fullXtX.add(YtY.value.get), userXy(index)).data
439+
case false => Solve.solvePositive(fullXtX, rhs).data
440+
case true => Solve.solvePositive(fullXtX.addi(YtY.value.get), rhs).data
412441
}
413442
}
414443
}
415444

416-
/**
417-
* Set xtxDest to the lower-triangular part of x transpose * x. For efficiency in summing
418-
* these matrices, we store xtxDest as only rank * (rank+1) / 2 values, namely the values
419-
* at (0,0), (1,0), (1,1), (2,0), (2,1), (2,2), etc in that order.
420-
*/
421-
private def fillXtX(x: DoubleMatrix, xtxDest: DoubleMatrix) {
422-
var i = 0
423-
var pos = 0
424-
while (i < x.length) {
425-
var j = 0
426-
while (j <= i) {
427-
xtxDest.data(pos) = x.data(i) * x.data(j)
428-
pos += 1
429-
j += 1
430-
}
431-
i += 1
432-
}
433-
}
434-
435445
/**
436446
* Given a triangular matrix in the order of fillXtX above, compute the full symmetric square
437447
* matrix that it represents, storing it into destMatrix.
@@ -455,7 +465,7 @@ class ALS private (var numBlocks: Int, var rank: Int, var iterations: Int, var l
455465

456466

457467
/**
458-
* Top-level methods for calling Alternating Least Squares (ALS) matrix factorizaton.
468+
* Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
459469
*/
460470
object ALS {
461471
/**

0 commit comments

Comments
 (0)