Skip to content

Commit e7850ed

Browse files
author
Li Pu
committed
use aggregate and axpy
1 parent 827411b commit e7850ed

File tree

2 files changed

+27
-13
lines changed

2 files changed

+27
-13
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,16 @@ object EigenValueDecomposition {
3636
* The caller needs to ensure that the input matrix is real symmetric. This function requires
3737
* memory for `n*(4*k+4)` doubles.
3838
*
39-
* @param mul a function that multiplies the symmetric matrix with a Vector.
39+
* @param mul a function that multiplies the symmetric matrix with a DenseVector.
4040
* @param n dimension of the square matrix (maximum Int.MaxValue).
4141
* @param k number of leading eigenvalues required.
4242
* @param tol tolerance of the eigs computation.
4343
* @return a dense vector of eigenvalues in descending order and a dense matrix of eigenvectors
4444
* (columns of the matrix). The number of computed eigenvalues might be smaller than k.
4545
*/
46-
private[mllib] def symmetricEigs(mul: Vector => Vector, n: Int, k: Int, tol: Double)
46+
private[mllib] def symmetricEigs(mul: DenseVector => DenseVector, n: Int, k: Int, tol: Double)
4747
: (BDV[Double], BDM[Double]) = {
48+
// TODO: remove this function and use eigs in breeze when switching breeze version
4849
require(n > k, s"Number of required eigenvalues $k must be smaller than matrix dimension $n")
4950

5051
val arpack = ARPACK.getInstance()
@@ -84,7 +85,7 @@ object EigenValueDecomposition {
8485
val outputOffset = ipntr(1) - 1
8586
val x = w(inputOffset until inputOffset + n)
8687
val y = w(outputOffset until outputOffset + n)
87-
y := BDV(mul(Vectors.fromBreeze(x)).toArray)
88+
y := BDV(mul(Vectors.fromBreeze(x).asInstanceOf[DenseVector]).toArray)
8889
// call ARPACK
8990
arpack.dsaupd(ido, bmat, n, which, nev.`val`, tolW, resid, ncv, v, n, iparam, ipntr,
9091
workd, workl, workl.length, info)

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ package org.apache.spark.mllib.linalg.distributed
1919

2020
import java.util
2121

22-
import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
22+
import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV}
23+
import breeze.linalg.{svd => brzSvd, axpy => brzAxpy}
2324
import breeze.numerics.{sqrt => brzSqrt}
2425
import com.github.fommil.netlib.BLAS.{getInstance => blas}
2526

@@ -201,16 +202,28 @@ class RowMatrix(
201202
}
202203

203204
/**
204-
* Multiply the Gramian matrix `A^T A` by a Vector on the right.
205+
* Multiply the Gramian matrix `A^T A` by a DenseVector on the right.
205206
*
206-
* @param v a local vector whose length must match the number of columns of this matrix
207-
* @return a local DenseVector representing the product
207+
* @param v a local DenseVector whose length must match the number of columns of this matrix.
208+
* @return a local DenseVector representing the product.
208209
*/
209-
private[mllib] def multiplyGramianMatrix(v: Vector): Vector = {
210-
val bv = rows.map{
211-
row => row.toBreeze * row.toBreeze.dot(v.toBreeze)
212-
}.reduce( (x: BV[Double], y: BV[Double]) => x + y )
213-
Vectors.fromBreeze(bv)
210+
private[mllib] def multiplyGramianMatrix(v: DenseVector): DenseVector = {
211+
val n = numCols().toInt
212+
213+
val bv = rows.aggregate(BDV.zeros[Double](n))(
214+
seqOp = (U, r) => {
215+
val rBrz = r.toBreeze
216+
val a = rBrz.dot(v.toBreeze)
217+
rBrz match {
218+
case _: BDV[_] => brzAxpy(a, rBrz.asInstanceOf[BDV[Double]], U)
219+
case _: BSV[_] => brzAxpy(a, rBrz.asInstanceOf[BSV[Double]], U)
220+
}
221+
U
222+
},
223+
combOp = (U1, U2) => U1 += U2
224+
)
225+
226+
new DenseVector(bv.data)
214227
}
215228

216229
/**
@@ -243,7 +256,7 @@ class RowMatrix(
243256
*
244257
* The decomposition is computed by providing a function that multiplies a vector with A'A to
245258
* ARPACK, and iteratively invoking ARPACK-dsaupd on the master node, from which we recover S and V.
246-
* Then we compute U via easy matrix multiplication as U = A * (V * S-1).
259+
* Then we compute U via easy matrix multiplication as U = A * (V * S^{-1}).
247260
* Note that this approach requires `O(nnz(A))` time.
248261
*
249262
* When the requested eigenvalues k = n, a non-sparse implementation will be used, which requires

0 commit comments

Comments
 (0)