Skip to content

Commit 561d31d

Browse files
committed
[SPARK-4614][MLLIB] Slight API changes in Matrix and Matrices
Until we have a full picture of the operators we want to add, it is safer to hide `Matrix.transposeMultiply` in 1.2.0. Another change we want to make is to `Matrices.randn` and `Matrices.rand`, both of which should take a `Random` implementation; otherwise, they are very likely to produce inconsistent RDDs. I also added some unit tests for the matrix factory methods. All of these APIs are new in 1.2, so there are no incompatible changes. brkyvz Author: Xiangrui Meng <[email protected]> Closes apache#3468 from mengxr/SPARK-4614 and squashes the following commits: 3b0e4e2 [Xiangrui Meng] add mima excludes 6bfd8a4 [Xiangrui Meng] hide transposeMultiply; add rng to rand and randn; add unit tests
1 parent 288ce58 commit 561d31d

File tree

3 files changed

+65
-11
lines changed

3 files changed

+65
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@
1717

1818
package org.apache.spark.mllib.linalg
1919

20-
import java.util.Arrays
20+
import java.util.{Random, Arrays}
2121

2222
import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM}
2323

24-
import org.apache.spark.util.random.XORShiftRandom
25-
2624
/**
2725
* Trait for a local matrix.
2826
*/
@@ -67,14 +65,14 @@ sealed trait Matrix extends Serializable {
6765
}
6866

6967
/** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */
70-
def transposeMultiply(y: DenseMatrix): DenseMatrix = {
68+
private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = {
7169
val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix]
7270
BLAS.gemm(true, false, 1.0, this, y, 0.0, C)
7371
C
7472
}
7573

7674
/** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */
77-
def transposeMultiply(y: DenseVector): DenseVector = {
75+
private[mllib] def transposeMultiply(y: DenseVector): DenseVector = {
7876
val output = new DenseVector(new Array[Double](numCols))
7977
BLAS.gemv(true, 1.0, this, y, 0.0, output)
8078
output
@@ -291,22 +289,22 @@ object Matrices {
291289
* Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers.
292290
* @param numRows number of rows of the matrix
293291
* @param numCols number of columns of the matrix
292+
* @param rng a random number generator
294293
* @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1)
295294
*/
296-
def rand(numRows: Int, numCols: Int): Matrix = {
297-
val rand = new XORShiftRandom
298-
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble()))
295+
def rand(numRows: Int, numCols: Int, rng: Random): Matrix = {
296+
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble()))
299297
}
300298

301299
/**
302300
* Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers.
303301
* @param numRows number of rows of the matrix
304302
* @param numCols number of columns of the matrix
303+
* @param rng a random number generator
305304
* @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1)
306305
*/
307-
def randn(numRows: Int, numCols: Int): Matrix = {
308-
val rand = new XORShiftRandom
309-
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian()))
306+
def randn(numRows: Int, numCols: Int, rng: Random): Matrix = {
307+
new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian()))
310308
}
311309

312310
/**

mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717

1818
package org.apache.spark.mllib.linalg
1919

20+
import java.util.Random
21+
22+
import org.mockito.Mockito.when
2023
import org.scalatest.FunSuite
24+
import org.scalatest.mock.MockitoSugar._
2125

2226
class MatricesSuite extends FunSuite {
2327
test("dense matrix construction") {
@@ -112,4 +116,50 @@ class MatricesSuite extends FunSuite {
112116
assert(sparseMat(0, 1) === 10.0)
113117
assert(sparseMat.values(2) === 10.0)
114118
}
119+
120+
test("zeros") {
121+
val mat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix]
122+
assert(mat.numRows === 2)
123+
assert(mat.numCols === 3)
124+
assert(mat.values.forall(_ == 0.0))
125+
}
126+
127+
test("ones") {
128+
val mat = Matrices.ones(2, 3).asInstanceOf[DenseMatrix]
129+
assert(mat.numRows === 2)
130+
assert(mat.numCols === 3)
131+
assert(mat.values.forall(_ == 1.0))
132+
}
133+
134+
test("eye") {
135+
val mat = Matrices.eye(2).asInstanceOf[DenseMatrix]
136+
assert(mat.numRows === 2)
137+
assert(mat.numCols === 2)
138+
assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 1.0))
139+
}
140+
141+
test("rand") {
142+
val rng = mock[Random]
143+
when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0)
144+
val mat = Matrices.rand(2, 2, rng).asInstanceOf[DenseMatrix]
145+
assert(mat.numRows === 2)
146+
assert(mat.numCols === 2)
147+
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
148+
}
149+
150+
test("randn") {
151+
val rng = mock[Random]
152+
when(rng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0)
153+
val mat = Matrices.randn(2, 2, rng).asInstanceOf[DenseMatrix]
154+
assert(mat.numRows === 2)
155+
assert(mat.numCols === 2)
156+
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
157+
}
158+
159+
test("diag") {
160+
val mat = Matrices.diag(Vectors.dense(1.0, 2.0)).asInstanceOf[DenseMatrix]
161+
assert(mat.numRows === 2)
162+
assert(mat.numCols === 2)
163+
assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 2.0))
164+
}
115165
}

project/MimaExcludes.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@ object MimaExcludes {
4747
"org.apache.spark.SparkStageInfoImpl.this"),
4848
ProblemFilters.exclude[MissingMethodProblem](
4949
"org.apache.spark.SparkStageInfo.submissionTime")
50+
) ++ Seq(
51+
// SPARK-4614
52+
ProblemFilters.exclude[MissingMethodProblem](
53+
"org.apache.spark.mllib.linalg.Matrices.randn"),
54+
ProblemFilters.exclude[MissingMethodProblem](
55+
"org.apache.spark.mllib.linalg.Matrices.rand")
5056
)
5157

5258
case v if v.startsWith("1.2") =>

0 commit comments

Comments
 (0)