Skip to content

Commit 04c4829

Browse files
committed
Merge pull request #1 from mengxr/SPARK-4409
Some updates for linear algebra utilities
2 parents (10a63a6 + 80cfa29), commit 04c4829

File tree

1 file changed: +72 / -70 lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala

Lines changed: 72 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -361,47 +361,47 @@ object SparseMatrix {
361361
* @param entries Array of (i, j, value) tuples
362362
* @return The corresponding `SparseMatrix`
363363
*/
364-
def fromCOO(numRows: Int, numCols: Int, entries: Array[(Int, Int, Double)]): SparseMatrix = {
365-
val sortedEntries = entries.sortBy(v => (v._2, v._1))
366-
val colPtrs = new Array[Int](numCols + 1)
367-
var nnz = 0
368-
var lastCol = -1
369-
var lastIndex = -1
370-
sortedEntries.foreach { case (i, j, v) =>
371-
require(i >= 0 && j >= 0, "Negative indices given. Please make sure all indices are " +
372-
s"greater than or equal to zero. i: $i, j: $j, value: $v")
373-
if (v != 0.0) {
374-
while (j != lastCol) {
375-
colPtrs(lastCol + 1) = nnz
376-
lastCol += 1
377-
}
378-
val index = j * numRows + i
379-
if (lastIndex != index) {
380-
nnz += 1
381-
lastIndex = index
382-
}
364+
def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = {
365+
val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1))
366+
val numEntries = sortedEntries.size
367+
if (sortedEntries.nonEmpty) {
368+
// Since the entries are sorted by column index, we only need to check the first and the last.
369+
for (col <- Seq(sortedEntries.head._2, sortedEntries.last._2)) {
370+
require(col >= 0 && col < numCols, s"Column index out of range [0, $numCols): $col.")
383371
}
384372
}
385-
while (numCols > lastCol) {
386-
colPtrs(lastCol + 1) = nnz
387-
lastCol += 1
388-
}
389-
val values = new Array[Double](nnz)
390-
val rowIndices = new Array[Int](nnz)
391-
lastIndex = -1
392-
var cnt = -1
393-
sortedEntries.foreach { case (i, j, v) =>
394-
if (v != 0.0) {
395-
val index = j * numRows + i
396-
if (lastIndex != index) {
397-
cnt += 1
398-
lastIndex = index
373+
val colPtrs = new Array[Int](numCols + 1)
374+
val rowIndices = MArrayBuilder.make[Int]
375+
rowIndices.sizeHint(numEntries)
376+
val values = MArrayBuilder.make[Double]
377+
values.sizeHint(numEntries)
378+
var nnz = 0
379+
var prevCol = 0
380+
var prevRow = -1
381+
var prevVal = 0.0
382+
// Append a dummy entry to include the last one at the end of the loop.
383+
(sortedEntries.view :+ (numRows, numCols, 1.0)).foreach { case (i, j, v) =>
384+
if (v != 0) {
385+
if (i == prevRow && j == prevCol) {
386+
prevVal += v
387+
} else {
388+
if (prevVal != 0) {
389+
require(prevRow >= 0 && prevRow < numRows,
390+
s"Row index out of range [0, $numRows): $prevRow.")
391+
nnz += 1
392+
rowIndices += prevRow
393+
values += prevVal
394+
}
395+
prevRow = i
396+
prevVal = v
397+
while (prevCol < j) {
398+
colPtrs(prevCol + 1) = nnz
399+
prevCol += 1
400+
}
399401
}
400-
values(cnt) += v
401-
rowIndices(cnt) = i
402402
}
403403
}
404-
new SparseMatrix(numRows, numCols, colPtrs.toArray, rowIndices, values)
404+
new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), values.result())
405405
}
406406

407407
/**
@@ -413,54 +413,59 @@ object SparseMatrix {
413413
new SparseMatrix(n, n, (0 to n).toArray, (0 until n).toArray, Array.fill(n)(1.0))
414414
}
415415

416-
/** Generates the skeleton of a random `SparseMatrix` with a given random number generator. */
416+
/**
417+
* Generates the skeleton of a random `SparseMatrix` with a given random number generator.
418+
* The values of the matrix returned are undefined.
419+
*/
417420
private def genRandMatrix(
418421
numRows: Int,
419422
numCols: Int,
420423
density: Double,
421424
rng: Random): SparseMatrix = {
422-
require(density >= 0.0 && density <= 1.0, "density must be a double in the range " +
423-
s"0.0 <= d <= 1.0. Currently, density: $density")
424-
val length = math.ceil(numRows * numCols * density).toInt
425-
var i = 0
425+
require(numRows > 0, s"numRows must be greater than 0 but got $numRows")
426+
require(numCols > 0, s"numCols must be greater than 0 but got $numCols")
427+
require(density >= 0.0 && density <= 1.0,
428+
s"density must be a double in the range 0.0 <= d <= 1.0. Currently, density: $density")
429+
val size = numRows.toLong * numCols
430+
val expected = size * density
431+
assert(expected < Int.MaxValue,
432+
"The expected number of nonzeros cannot be greater than Int.MaxValue.")
433+
val nnz = math.ceil(expected).toInt
426434
if (density == 0.0) {
427-
return new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1),
428-
Array[Int](), Array[Double]())
435+
new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1), Array[Int](), Array[Double]())
429436
} else if (density == 1.0) {
430-
val rowIndices = Array.tabulate(numCols, numRows)((j, i) => i).flatten
431-
return new SparseMatrix(numRows, numCols, (0 to numRows * numCols by numRows).toArray,
432-
rowIndices, new Array[Double](numRows * numCols))
433-
}
434-
if (density < 0.34) { // Expected number of iterations is less than 1.5 * length
437+
val colPtrs = Array.tabulate(numCols + 1)(j => j * numRows)
438+
val rowIndices = Array.tabulate(size.toInt)(idx => idx % numRows)
439+
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](numRows * numCols))
440+
} else if (density < 0.34) {
441+
// draw-by-draw, expected number of iterations is less than 1.5 * nnz
435442
val entries = MHashSet[(Int, Int)]()
436-
while (entries.size < length) {
443+
while (entries.size < nnz) {
437444
entries += ((rng.nextInt(numRows), rng.nextInt(numCols)))
438445
}
439-
val entryList = entries.toArray.map(v => (v._1, v._2, 1.0))
440-
SparseMatrix.fromCOO(numRows, numCols, entryList)
441-
} else { // selection - rejection method
446+
SparseMatrix.fromCOO(numRows, numCols, entries.map(v => (v._1, v._2, 1.0)))
447+
} else {
448+
// selection-rejection method
449+
var idx = 0L
450+
var numSelected = 0
451+
var i = 0
442452
var j = 0
443-
val pool = numRows * numCols
444-
val rowIndexBuilder = new MArrayBuilder.ofInt
445453
val colPtrs = new Array[Int](numCols + 1)
446-
while (i < length && j < numCols) {
447-
var passedInPool = j * numRows
448-
var r = 0
449-
while (i < length && r < numRows) {
450-
if (rng.nextDouble() < 1.0 * (length - i) / (pool - passedInPool)) {
451-
rowIndexBuilder += r
452-
i += 1
454+
val rowIndices = new Array[Int](nnz)
455+
while (j < numCols && numSelected < nnz) {
456+
while (i < numRows && numSelected < nnz) {
457+
if (rng.nextDouble() < 1.0 * (nnz - numSelected) / (size - idx)) {
458+
rowIndices(numSelected) = i
459+
numSelected += 1
453460
}
454-
r += 1
455-
passedInPool += 1
461+
i += 1
462+
idx += 1
456463
}
464+
colPtrs(j + 1) = numSelected
457465
j += 1
458-
colPtrs(j) = i
459466
}
460-
val rowIndices = rowIndexBuilder.result()
461-
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](rowIndices.size))
467+
new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](nnz))
462468
}
463-
464469
}
465470

466471
/**
@@ -735,16 +740,13 @@ object Matrices {
735740
val numCols = matrices(0).numCols
736741
var hasSparse = false
737742
var numRows = 0
738-
var valsLength = 0
739743
matrices.foreach { mat =>
740744
require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " +
741745
"don't match!")
742746
mat match {
743747
case sparse: SparseMatrix =>
744748
hasSparse = true
745-
valsLength += sparse.values.length
746749
case dense: DenseMatrix =>
747-
valsLength += dense.values.length
748750
case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " +
749751
s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}")
750752
}

0 commit comments

Comments
 (0)