@@ -361,47 +361,47 @@ object SparseMatrix {
361
361
* @param entries Array of (i, j, value) tuples
362
362
* @return The corresponding `SparseMatrix`
363
363
*/
364
def fromCOO(numRows: Int, numCols: Int, entries: Iterable[(Int, Int, Double)]): SparseMatrix = {
  // Sort column-major (by column, then by row): entries of one column become contiguous and
  // duplicates of the same (row, col) cell become adjacent, so a single pass can merge them.
  val sortedEntries = entries.toSeq.sortBy(v => (v._2, v._1))
  val numEntries = sortedEntries.size
  if (sortedEntries.nonEmpty) {
    // Since the entries are sorted by column index, we only need to check the first and the last.
    for (col <- Seq(sortedEntries.head._2, sortedEntries.last._2)) {
      require(col >= 0 && col < numCols, s"Column index out of range [0, $numCols): $col.")
    }
  }
  val colPtrs = new Array[Int](numCols + 1)
  // numEntries is only an upper bound on the stored count: zeros are dropped and
  // duplicates are summed into one cell.
  val rowIndices = MArrayBuilder.make[Int]
  rowIndices.sizeHint(numEntries)
  val values = MArrayBuilder.make[Double]
  values.sizeHint(numEntries)
  var nnz = 0
  // (prevRow, prevCol, prevVal) is the cell currently being accumulated; it is written out
  // only when the next distinct cell is reached. prevVal == 0 means "nothing to write".
  var prevCol = 0
  var prevRow = -1
  var prevVal = 0.0
  // Append a dummy entry to include the last one at the end of the loop.
  // Its column index (numCols) also makes the loop fill every remaining column pointer.
  (sortedEntries.view :+ (numRows, numCols, 1.0)).foreach { case (i, j, v) =>
    if (v != 0) {
      if (i == prevRow && j == prevCol) {
        // Duplicate coordinate: sum into the pending cell.
        prevVal += v
      } else {
        if (prevVal != 0) {
          // Row indices are validated lazily, when a cell is actually stored.
          require(prevRow >= 0 && prevRow < numRows,
            s"Row index out of range [0, $numRows): $prevRow.")
          nnz += 1
          rowIndices += prevRow
          values += prevVal
        }
        prevRow = i
        prevVal = v
        // Close every column before j: each of them ends at the current stored count.
        while (prevCol < j) {
          colPtrs(prevCol + 1) = nnz
          prevCol += 1
        }
      }
    }
  }
  new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), values.result())
}
406
406
407
407
/**
@@ -413,54 +413,59 @@ object SparseMatrix {
413
413
new SparseMatrix (n, n, (0 to n).toArray, (0 until n).toArray, Array .fill(n)(1.0 ))
414
414
}
415
415
416
/**
 * Generates the skeleton of a random `SparseMatrix` with a given random number generator.
 * The values of the matrix returned are undefined.
 *
 * @param numRows number of rows, must be greater than 0
 * @param numCols number of columns, must be greater than 0
 * @param density fraction of cells that get a stored entry, in the range [0.0, 1.0]
 * @param rng random number generator used to choose the stored positions
 * @return a `SparseMatrix` with `ceil(numRows * numCols * density)` stored entries
 */
private def genRandMatrix(
    numRows: Int,
    numCols: Int,
    density: Double,
    rng: Random): SparseMatrix = {
  require(numRows > 0, s"numRows must be greater than 0 but got $numRows")
  require(numCols > 0, s"numCols must be greater than 0 but got $numCols")
  require(density >= 0.0 && density <= 1.0,
    s"density must be a double in the range 0.0 <= d <= 1.0. Currently, density: $density")
  // Compute the cell count in Long so that numRows * numCols cannot overflow Int.
  val size = numRows.toLong * numCols
  val expected = size * density
  assert(expected < Int.MaxValue,
    "The expected number of nonzeros cannot be greater than Int.MaxValue.")
  val nnz = math.ceil(expected).toInt
  if (density == 0.0) {
    new SparseMatrix(numRows, numCols, new Array[Int](numCols + 1), Array[Int](), Array[Double]())
  } else if (density == 1.0) {
    // Every cell is stored: column j starts at j * numRows and rows cycle 0 until numRows.
    val colPtrs = Array.tabulate(numCols + 1)(j => j * numRows)
    val rowIndices = Array.tabulate(size.toInt)(idx => idx % numRows)
    new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](numRows * numCols))
  } else if (density < 0.34) {
    // draw-by-draw, expected number of iterations is less than 1.5 * nnz
    val entries = MHashSet[(Int, Int)]()
    while (entries.size < nnz) {
      entries += ((rng.nextInt(numRows), rng.nextInt(numCols)))
    }
    SparseMatrix.fromCOO(numRows, numCols, entries.map(v => (v._1, v._2, 1.0)))
  } else {
    // selection-rejection method: walk the cells in column-major order and keep each one with
    // probability (entries still needed) / (cells not yet visited). That ratio reaches 1.0
    // whenever the two counts are equal, so exactly nnz cells are chosen by the end.
    var idx = 0L
    var numSelected = 0
    var j = 0
    val colPtrs = new Array[Int](numCols + 1)
    val rowIndices = new Array[Int](nnz)
    // Visit every column, even after all nnz cells have been chosen, so that the pointers of
    // the trailing (empty) columns are set to nnz instead of being left at 0.
    while (j < numCols) {
      // The row cursor must restart for every column; otherwise only the first column
      // would ever be sampled.
      var i = 0
      while (i < numRows && numSelected < nnz) {
        if (rng.nextDouble() < 1.0 * (nnz - numSelected) / (size - idx)) {
          rowIndices(numSelected) = i
          numSelected += 1
        }
        i += 1
        idx += 1
      }
      colPtrs(j + 1) = numSelected
      j += 1
    }
    new SparseMatrix(numRows, numCols, colPtrs, rowIndices, new Array[Double](nnz))
  }
}
465
470
466
471
/**
@@ -735,16 +740,13 @@ object Matrices {
735
740
val numCols = matrices(0 ).numCols
736
741
var hasSparse = false
737
742
var numRows = 0
738
- var valsLength = 0
739
743
matrices.foreach { mat =>
740
744
require(numCols == mat.numCols, " The number of rows of the matrices in this sequence, " +
741
745
" don't match!" )
742
746
mat match {
743
747
case sparse : SparseMatrix =>
744
748
hasSparse = true
745
- valsLength += sparse.values.length
746
749
case dense : DenseMatrix =>
747
- valsLength += dense.values.length
748
750
case _ => throw new IllegalArgumentException (" Unsupported matrix format. Expected " +
749
751
s " SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}" )
750
752
}
0 commit comments